PyTorch: creating buckets of sequences with similar lengths
---------------------------------------
Let us use a class generated with ChatGPT to build batches whose sequences have similar lengths.

#### Load the data with a custom dataloader

The following module defines batch samplers and a collate function for a custom dataset very similar to the one we created in [tokenizing_sentences](creating_tokenizer_for_all_sentences_3.ipynb).

In [57]:
%%writefile wolof-translate/wolof_translate/utils/bucket_iterator.py

import torch
import numpy as np
from torch.utils.data import Sampler
from torch.nn.utils.rnn import pad_sequence
from math import ceil

class SequenceLengthBatchSampler(Sampler):
    """Batch sampler that groups dataset indices into length buckets.

    Every sample's length is computed once with ``sort_key``. Indices are sorted
    by that length and each one is placed in the first bucket (in
    ``boundaries`` order) whose boundary is >= its length. Samples longer than
    every boundary go to a final overflow bucket. Each bucket is then cut into
    consecutive batches of its own batch size, so short sequences can use large
    batches and long sequences small ones. The last batch of a bucket may be
    smaller than the bucket's batch size.

    Args:
        dataset: Indexable dataset supporting ``len()``; each item is passed to
            ``sort_key``.
        boundaries: Inclusive upper length bound of each bucket (normally
            ascending).
        batch_sizes: Batch size of each bucket. It needs at least
            ``len(boundaries) + 1`` entries; the extra one is for the overflow
            bucket. Additional entries are ignored.
        sort_key: Callable mapping a dataset item to the length used for both
            sorting and bucketing. The default is the longer of the item's first
            two fields. NOTE(review): ``collate_fn`` in this module unpacks items
            as (input, input_mask, target, target_mask), in which case ``x[1]``
            is the input mask, not the target. Pass a custom ``sort_key`` if the
            target length should count.

    Raises:
        ValueError: If ``batch_sizes`` does not cover every bucket.
    """

    def __init__(self, dataset, boundaries, batch_sizes,
                 sort_key=lambda x: max(len(x[0]), len(x[1]))):
        n_buckets = len(boundaries) + 1
        if len(batch_sizes) < n_buckets:
            # zip() in __iter__ would otherwise silently drop the uncovered buckets
            raise ValueError(
                f"batch_sizes needs at least {n_buckets} entries (one per boundary "
                f"plus the overflow bucket), got {len(batch_sizes)}"
            )
        self.dataset = dataset
        self.boundaries = boundaries
        self.batch_sizes = batch_sizes
        self.sort_key = sort_key
        # Buckets of indices; built by __iter__ or lazily by __len__.
        self.batches = None

    def _build_buckets(self):
        """Return one list of length-sorted indices per bucket (overflow last)."""
        # Index each sample exactly once: dataset access may tokenize on the fly.
        lengths = [self.sort_key(self.dataset[i]) for i in range(len(self.dataset))]
        sorted_indices = sorted(range(len(lengths)), key=lengths.__getitem__)

        buckets = [[] for _ in range(len(self.boundaries) + 1)]
        for idx in sorted_indices:
            # First boundary that fits, otherwise the overflow bucket.
            slot = next(
                (k for k, boundary in enumerate(self.boundaries) if lengths[idx] <= boundary),
                len(self.boundaries),
            )
            buckets[slot].append(idx)  # stays sorted: indices arrive in length order
        return buckets

    def __iter__(self):
        self.batches = self._build_buckets()
        for bucket, batch_size in zip(self.batches, self.batch_sizes):
            for start in range(0, len(bucket), batch_size):
                yield bucket[start:start + batch_size]

    def __len__(self):
        # Works before the first iteration (DataLoader/tqdm call len() up front).
        if self.batches is None:
            self.batches = self._build_buckets()
        return sum(
            ceil(len(bucket) / batch_size)
            for bucket, batch_size in zip(self.batches, self.batch_sizes)
        )

class BucketSampler(Sampler):
    """Batch sampler yielding fixed-size batches of length-sorted indices.

    Indices are ordered by ``sort_key`` with ``np.argsort`` and cut into
    consecutive chunks of ``batch_size`` (the last chunk may be shorter). When
    ``batch_size > 1`` the chunk order is shuffled in place with NumPy's global
    RNG. Each batch therefore still holds samples of similar length, but batches
    are not visited from shortest to longest.

    Args:
        dataset: Indexable dataset supporting ``len()``.
        batch_size: Number of indices per batch.
        sort_key: Callable mapping a dataset item to the value it is sorted by;
            by default the longer of the item's first two fields.
    """

    def __init__(self, dataset, batch_size, sort_key=lambda x: max(len(x[0]), len(x[1]))):
        self.sort_key = sort_key
        self.batch_size = batch_size
        self.dataset = dataset

    def __iter__(self):
        size = self.batch_size
        keys = [self.sort_key(self.dataset[pos]) for pos in range(len(self.dataset))]
        order = np.argsort(keys)

        chunks = []
        for lo in range(0, len(order), size):
            chunks.append(order[lo:lo + size])

        if size > 1:
            # Randomize the visiting order of the (internally sorted) chunks.
            np.random.shuffle(chunks)

        yield from (chunk.tolist() for chunk in chunks)

    def __len__(self):
        n_items = len(self.dataset)
        return ceil(n_items / self.batch_size)


def collate_fn(batch):
    """Right-pad every field of a batch of samples to a common length.

    Each sample is unpacked as
    ``(input_seq, input_mask, target_seq, target_mask)``. Each of the four
    fields is stacked separately with ``pad_sequence(..., batch_first=True)``,
    using its default padding value of 0.

    NOTE(review): Target sequences are also padded with 0. Confirm that the
    model/loss ignores these positions (e.g. via the target mask).

    Returns:
        Tuple ``(inputs, input_masks, targets, target_masks)`` of padded tensors.
    """
    inputs, input_masks, targets, target_masks = zip(*batch)

    return tuple(
        pad_sequence(field, batch_first=True)
        for field in (inputs, input_masks, targets, target_masks)
    )


Overwriting wolof-translate/wolof_translate/utils/bucket_iterator.py


In [58]:
%run wolof-translate/wolof_translate/data/dataset_v4.py
%run wolof-translate/wolof_translate/utils/bucket_iterator.py

-----------------

Let us create two datasets, one for training and another for validation (only the validation one is loaded here to test the sampler). We first need to load and split the sentences.

In [59]:
from wolof_translate.utils.split_with_valid import split_data
from wolof_translate.data.dataset_v4 import SentenceDataset
# The samplers, collate_fn and torch come from the `%run bucket_iterator.py` cell;
# the equivalent package import would be:
# from wolof_translate.utils.bucket_iterator import SequenceLengthBatchSampler, collate_fn
from transformers import T5TokenizerFast

# split the data (presumably writes the train/valid CSVs read below; random_state=0 for a reproducible split)
split_data(random_state=0, csv_file='corpora_v6.csv')

# tokenizer
tokenizer = T5TokenizerFast('wolof-translate/wolof_translate/tokenizers/t5_tokenizers/tokenizer_v5.model')

# load the validation data (only this split is used to test the sampler)
valid_dataset = SentenceDataset('data/extractions/new_data/valid_set.csv', tokenizer)

# 7 boundaries -> 8 buckets (the last one holds everything longer than 171);
# one batch size per bucket, shrinking as the sequences get longer
sampler = SequenceLengthBatchSampler(valid_dataset, [2, 31, 59, 87, 115, 143, 171], [256, 128, 64, 32, 16, 8, 4, 2])
dataloader = torch.utils.data.DataLoader(valid_dataset, batch_sampler=sampler, collate_fn=collate_fn)


In [60]:
# Count the batches and print the padded input shape of each one.
i = 0
for i, (input_, mask_, labels, _) in enumerate(dataloader, start=1):
    print(input_.shape)

torch.Size([1, 2])
torch.Size([128, 11])
torch.Size([85, 31])
torch.Size([42, 59])
torch.Size([16, 83])
torch.Size([6, 115])
torch.Size([4, 139])
torch.Size([1, 153])
torch.Size([2, 224])
torch.Size([1, 248])


In [61]:
from tqdm import tqdm

In [62]:
progress = tqdm(dataloader)

# Walk the whole dataloader, showing the running batch number in the bar.
i = 0
for i, batch in enumerate(progress, start=1):
    progress.set_description(f"Batch {i}")

Batch 10: 100%|██████████| 10/10 [00:01<00:00,  8.64it/s]


In [63]:
# Number of batches as reported by the batch sampler's __len__
len(dataloader)

10

In [64]:
# Batches actually produced by the tqdm loop; should match len(dataloader)
i

10

--------------------------

Let us create two datasets again, one for training and another for validation (only the validation one is loaded here, this time batched with `BucketSampler`). We first need to load and split the sentences.

In [71]:
from wolof_translate.utils.split_with_valid import split_data
from wolof_translate.data.dataset_v4 import SentenceDataset
# The samplers, collate_fn and torch come from the `%run bucket_iterator.py` cell;
# the equivalent package import would be:
# from wolof_translate.utils.bucket_iterator import BucketSampler, collate_fn
from transformers import T5TokenizerFast

# split the data (presumably writes the train/valid CSVs read below; random_state=0 for a reproducible split)
split_data(random_state=0, csv_file='corpora_v6.csv')

# tokenizer
tokenizer = T5TokenizerFast('wolof-translate/wolof_translate/tokenizers/t5_tokenizers/tokenizer_v5.model')

# load the validation data (only this split is used to test the sampler)
valid_dataset = SentenceDataset('data/extractions/new_data/valid_set.csv', tokenizer)

# chunks of 16 length-sorted samples; since 16 > 1 the chunk order is shuffled with np.random
sampler = BucketSampler(valid_dataset, 16)
dataloader = torch.utils.data.DataLoader(valid_dataset, batch_sampler=sampler, collate_fn=collate_fn)


In [72]:
# Count the batches and print the padded input shape of each one.
i = 0
for i, (input_, mask_, labels, _) in enumerate(dataloader, start=1):
    print(input_.shape)

torch.Size([16, 8])
torch.Size([16, 10])
torch.Size([16, 14])
torch.Size([16, 59])
torch.Size([14, 248])
torch.Size([16, 39])
torch.Size([16, 23])
torch.Size([16, 6])
torch.Size([16, 7])
torch.Size([16, 12])
torch.Size([16, 16])
torch.Size([16, 8])
torch.Size([16, 6])
torch.Size([16, 10])
torch.Size([16, 47])
torch.Size([16, 83])
torch.Size([16, 9])
torch.Size([16, 30])


In [73]:
from tqdm import tqdm

In [74]:
progress = tqdm(dataloader)

# Walk the whole dataloader, showing the running batch number in the bar.
i = 0
for i, batch in enumerate(progress, start=1):
    progress.set_description(f"Batch {i}")

Batch 18: 100%|██████████| 18/18 [00:00<00:00, 58.15it/s]


In [75]:
# Number of batches as reported by BucketSampler.__len__ (ceil(len(dataset) / 16))
len(dataloader)

18

In [76]:
# Batches actually produced by the tqdm loop; should match len(dataloader)
i

18