Pytorch create buckets with same lengths
---------------------------------------
Let us use a class create by chat gpt to generate batches of sequences of same lengths.

#### Load the data with a custom dataloader

The following custom dataset is very similar to that we created in [tokenizing_sentences](creating_tokenizer_for_all_sentences_3.ipynb).

In [1]:
%%writefile wolof-translate/wolof_translate/utils/bucket_iterator.py

import torch
from torch.utils.data import Sampler
from torch.nn.utils.rnn import pad_sequence

class SequenceLengthBatchSampler(Sampler):
    def __init__(self, dataset, boundaries, batch_sizes):
        self.dataset = dataset
        self.boundaries = boundaries
        self.batch_sizes = batch_sizes

    def __iter__(self):
        indices = list(range(len(self.dataset)))  # Get indices of the dataset
        sorted_indices = sorted(indices, key=lambda i: max(len(self.dataset[i][0]), len(self.dataset[i][1])))  # Sort indices based on sequence length
        self.batches = []

        # Group indices into batches of sequences with the same length
        for boundary in self.boundaries:
            batch = [i for i in sorted_indices if len(self.dataset[i][0]) <= boundary]  # Filter indices based on length boundary
            self.batches.append(batch)
            sorted_indices = [i for i in sorted_indices if i not in batch]  # Remove processed indices

        # Add remaining indices to the last batch
        self.batches.append(sorted_indices)

        # Yield batches with the corresponding batch sizes
        for batch_indices, batch_size in zip(self.batches, self.batch_sizes):
            for i in range(0, len(batch_indices), batch_size):
                yield batch_indices[i:i + batch_size]

    def __len__(self):
        return sum(len(batch) // batch_size + 1 for batch, batch_size in zip(self.batches, self.batch_sizes))

def collate_fn(batch):
    # Separate the input sequences, target sequences, and attention masks
    input_seqs, input_masks, target_seqs, target_masks = zip(*batch)

    # Pad the input sequences to have the same length
    padded_input_seqs = pad_sequence(input_seqs, batch_first=True)

    # Pad the target sequences to have the same length
    padded_target_seqs = pad_sequence(target_seqs, batch_first=True)

    # Pad the input masks to have the same length
    padded_input_masks = pad_sequence(input_masks, batch_first=True)

    # Pad the labels masks to have the same length
    padded_target_masks = pad_sequence(target_masks, batch_first=True)

    return padded_input_seqs, padded_input_masks, padded_target_seqs, padded_target_masks


Overwriting wolof-translate/wolof_translate/utils/bucket_iterator.py


In [2]:
%run wolof-translate/wolof_translate/data/dataset_v4.py
%run wolof-translate/wolof_translate/utils/bucket_iterator.py

  from .autonotebook import tqdm as notebook_tqdm


Let us create two datasets. One for the training and another for the validation. We need to upload and split the sentences before.

In [3]:
from wolof_translate.utils.split_with_valid import split_data
from wolof_translate.data.dataset_v4 import SentenceDataset
# from wolof_translate.utils.bucket_iterator import SameLengthBatchSampler, collate_fn
from transformers import T5TokenizerFast

# split the data
split_data(random_state=0, csv_file='ad_sentences.csv')

# tokenizer
tokenizer = T5TokenizerFast('wolof-translate/wolof_translate/tokenizers/t5_tokenizers/tokenizer_v4.model')

# load the train data
train_dataset = SentenceDataset('data/extractions/new_data/train_set.csv', tokenizer)

sampler = SequenceLengthBatchSampler(train_dataset, [2, 23, 43, 64, 84, 104], [256, 128, 64, 32, 16, 8, 4])
dataloader = torch.utils.data.DataLoader(train_dataset, batch_sampler=sampler, collate_fn=collate_fn)


In [4]:
i = 0
for input_, mask_, labels, _ in dataloader:
    i+=1
    print(input_.shape)

torch.Size([4, 2])
torch.Size([128, 6])
torch.Size([128, 7])
torch.Size([128, 8])
torch.Size([128, 8])
torch.Size([128, 9])
torch.Size([128, 10])
torch.Size([128, 11])
torch.Size([128, 12])
torch.Size([128, 13])
torch.Size([128, 15])
torch.Size([128, 20])
torch.Size([27, 23])
torch.Size([62, 43])
torch.Size([32, 63])
torch.Size([2, 64])
torch.Size([16, 72])
torch.Size([11, 84])
torch.Size([8, 93])
torch.Size([3, 99])
torch.Size([4, 151])
torch.Size([4, 206])


In [5]:
len(dataloader)

23

In [6]:
i

22