In [1]:
import torch
from torch.utils.data import Dataset, DataLoader

In [2]:
# load les données

fichier = open('villes.txt')
donnees = fichier.read()
villes = donnees.replace('\n', ',').split(',')
villes = [ville for ville in villes if len(ville) > 2]
villes = sorted(villes, key=len)

In [3]:
# création du vocabulaire

vocabulaire = sorted(list(set(''.join(villes))))
vocabulaire = ["<SOS>", "<EOS>"] + vocabulaire

# pour convertir char <-> int
char_to_int = {}
int_to_char = {}

for (c, i) in zip(vocabulaire, range(len(vocabulaire))):
    char_to_int[c] = i
    int_to_char[i] = c

In [4]:
num_sequences = len(villes)
max_len = max([len(ville) for ville in villes]) + 2 # account for <SOS> and <EOS>

X = torch.zeros((num_sequences, max_len))

for i in range(num_sequences):
    X[i] = torch.tensor([char_to_int['<SOS>']] + [char_to_int[c] for c in villes[i]] + [char_to_int['<EOS>']] + [-1] * (max_len - len(villes[i]) - 2))

"""
n_split = int(0.9*X.shape[0])

X_train = X[:n_split]
X_val = X[n_split:]
"""

In [20]:
def get_batch(split, batch_size):
    # returns a batch, according to the data pipeline written in the W&B report
    data = X_train if split == 'train' else X_val

    idx_seed = torch.randint(high=data.shape[0], size=(1,)).item() #sample la ligne seed autour de laquelle on va piocher les exemples

    idx = torch.randint(low = max(0, idx_seed - 4 * batch_size), high = min(data.shape[0], idx_seed + 4 * batch_size), size=(batch_size,)) #samples les indices du batch à produire
    #pq 4 ? bon compromis entre assez large pour pas bcp de répétitions, assez petit pour pas bcp de padding (cf data.ipynb)

    idx_sorted, _ = torch.sort(idx) #on les ordonne pour recuperer facilement la longueur de la plus grande seq. du batch

    X_batch = data[idx_sorted] #on extrait la matrice qui va produire Xb et Yb

    max_len_batch = torch.sum(torch.ne(X_batch[-1], -1)) #longueur de la plus grande seq. du batch : torch.ne(X_batch[-1], -1) crée une matrice masque, avec True si diff de -1, False si egal a -1

    Xb = X_batch[:, :max_len_batch-1] #on selectionne que jusqu'a la len max - 1 (<EOS> du plus long inutile) (le reste n'est que padding)
    Yb = X_batch[:, 1:max_len_batch] #meme que Xb, mais décalé de 1 (avec le <EOS> mais sans le <SOS>)

    #Xb[Xb == 1] = -1 #on remplace le <EOS> par du padding (totalement optionnel)

    Xb = Xb.pin_memory().to('cuda', non_blocking=True)
    Yb = Yb.pin_memory().to('cuda', non_blocking=True)

    return Xb, Yb

#todo : report max-min pour voir cb on perd de compute

In [34]:
Xb, Yb = get_batch('val', 1024)

In [35]:
for i in range(Xb.shape[0]):
    nom_X = ""
    for id in Xb[i]:
        if int(id.item()) == -1:
            nom_X += "<pad>"
        else:
            nom_X += int_to_char[int(id.item())]

    nom_Y = ""
    for id in Yb[i]:
        if int(id.item()) == -1:
            nom_Y += "<pad>"
        else:
            nom_Y += int_to_char[int(id.item())]
            
    print(nom_X)
    print(nom_Y)
    print("**************************************************")

<SOS>la chapelle-sur-coise<EOS><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
la chapelle-sur-coise<EOS><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
**************************************************
<SOS>saint-germain-laprade<EOS><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
saint-germain-laprade<EOS><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
**************************************************
<SOS>saint-martin-labouval<EOS><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
saint-martin-labouval<EOS><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
**************************************************
<SOS>chevry-sous-le-bignon<EOS><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
chevry-sous-le-bignon<EOS><pad><pad><pad><pad><pad><pad><pad><p

### comment déterminer le 4 dans le sample des idx ?

In [432]:
batch_size = 1024

for i in range(50):
    idx_seed = torch.randint(high=X_train.shape[0], size=(1,)).item()
    
    idx = torch.randint(low = max(0, idx_seed - 4 * batch_size), high = min(X_train.shape[0], idx_seed + 4 * batch_size), size=(batch_size,))
    idx_sorted, _ = torch.sort(idx)

    X_batch = X_train[idx_sorted]

    min_len_batch = torch.sum(torch.ne(X_batch[0], -1))
    max_len_batch = torch.sum(torch.ne(X_batch[-1], -1))

    if(min_len_batch.item() > max_len_batch.item()):
        print("ouille")
    
    print(str(min_len_batch.item()) + " | " + str(max_len_batch.item()))
    print("***************************************************************")

8 | 10
***************************************************************
5 | 8
***************************************************************
13 | 19
***************************************************************
12 | 17
***************************************************************
5 | 9
***************************************************************
17 | 23
***************************************************************
5 | 9
***************************************************************
14 | 21
***************************************************************
12 | 18
***************************************************************
8 | 10
***************************************************************
19 | 23
***************************************************************
10 | 12
***************************************************************
9 | 11
***************************************************************
9 | 12
***************************************************************
13

In [228]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader, BatchSampler, Sampler

In [230]:
class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

class BucketBatchSampler(BatchSampler):
    def __iter__(self):
        sorted_indices = sorted(list(range(len(self.sampler.data_source))), key=lambda i: len(self.sampler.data_source[i]))
        return iter([sorted_indices[i:i+self.batch_size] for i in range(0, len(sorted_indices), self.batch_size)])

def collate_fn(batch):
    return pad_sequence([torch.tensor(seq) for seq in batch], batch_first=True)

data = [  # example data: list of sequences
    [1, 2, 3, 4],
    [5, 6, 7],
    [8, 9],
    [10]
]
dataset = MyDataset(data)
bucket_sampler = BucketBatchSampler(Sampler(dataset), batch_size=2, drop_last=False)
dataloader = DataLoader(dataset, batch_sampler=bucket_sampler, collate_fn=collate_fn)

for batch in dataloader:
    print(batch)

AttributeError: 'Sampler' object has no attribute 'data_source'