In [105]:
import torch
from torch.utils.data import Dataset, DataLoader, Sampler, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist

from typing import Optional, Tuple, List
try:
    from collections.abc import Iterable
except ImportError:
    from collections import Iterable
    
import numpy as np
import random

from transformers.file_utils import cached_property

from fairseq.data.data_utils import batch_by_size

'''
The batch_by_size source comes from
https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/data_utils.py

Building the Cython extension on its own is quite annoying,
so it is easier to install fairseq itself:

git clone https://github.com/facebookresearch/fairseq &&\
cd fairseq &&\
pip install -e . ## an editable installation is recommended
'''


def _reset_seeds(seed=1234):
    """Seed the global Python, NumPy and PyTorch random number generators.

    Args:
        seed: value handed to each library's seeding function.
    """
    # The three generators are independent, so the seeding order is irrelevant.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

class RandomDataset(Dataset):
    """Synthetic dataset of random 1-D float arrays with random lengths.

    Each item is a ``(sample, src_len)`` pair, where ``sample`` is a NumPy
    array holding ``src_len`` random floats.

    Args:
        num_data: number of samples to generate.
        seed: seed passed to ``_reset_seeds`` before the data is generated;
            note that this resets the global RNG state as a side effect.
    """

    def __init__(self, num_data=100000, seed=2023):
        _reset_seeds(seed)
        self.src_lens = list(np.random.randint(60, 2048, num_data))
        self.samples = [np.random.rand(src_len) for src_len in self.src_lens]
        # Cap applied to per-sample token counts when building dynamic batches.
        self.max_target_length = 4096

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(sample, src_len)`` pair stored at position ``idx``."""
        return self.samples[idx], self.src_lens[idx]

    def make_sortish_sampler(self, batch_size, distributed=False, shuffle=False, **kwargs):
        """Build a sampler that orders indices by source length.

        Args:
            batch_size: chunk size used by the sortish ordering.
            distributed: if True, return a ``DistributedSortishSampler`` over
                this dataset; otherwise a ``SortishSampler`` over ``src_lens``.
            shuffle: forwarded to the sampler.
            **kwargs: forwarded to ``DistributedSortishSampler`` only; ignored
                in the non-distributed case.
        """
        if distributed:
            return DistributedSortishSampler(self, batch_size, shuffle=shuffle, **kwargs)
        else:
            return SortishSampler(self.src_lens, batch_size, shuffle=shuffle)

    def make_dynamic_sampler(self, max_tokens_per_batch=1024, **kwargs):
        """Build a list of index batches bounded by a token budget.

        Indices are sorted longest-first, grouped by fairseq's
        ``batch_by_size`` so that each batch respects ``max_tokens_per_batch``,
        and the resulting batches are then shuffled.  The batch with the
        largest approximate padded size is moved to the front so that an
        out-of-memory failure shows up on the first step.

        Args:
            max_tokens_per_batch: token budget passed to ``batch_by_size`` as
                ``max_tokens``.
            **kwargs: accepted for interface compatibility; currently unused.

        Returns:
            A list of batches, each a sequence of dataset indices, suitable as
            the ``batch_sampler`` argument of a ``DataLoader``.  Empty when the
            dataset is empty.
        """
        # Nothing to batch: without this guard the np.argmax call below would
        # raise on an empty sequence.
        if len(self) == 0:
            return []

        # With shuffle=False the sampler is a plain longest-first sort, so the
        # value passed as its batch size has no influence on the ordering.
        sorted_indices = list(self.make_sortish_sampler(max_tokens_per_batch, shuffle=False))

        def num_tokens_fn(i):
            return min(self.src_lens[i], self.max_target_length)

        # Call the fairseq cython function.  Per its docstring it yields
        # mini-batches of indices bucketed by size (batches may contain
        # sequences of different lengths):
        #   indices: ordered list of dataset indices
        #   num_tokens_fn: returns the number of tokens at a given index
        #   num_tokens_vec: optional precomputed token counts (faster batching)
        #   max_tokens: max number of tokens in each batch
        #   max_sentences: max number of sentences in each batch
        #   required_batch_size_multiple: require batch size to be less than N
        #       or a multiple of N
        #   fixed_shapes: if given, only these (bsz, len) shapes are created and
        #       max_sentences / required_batch_size_multiple are ignored
        batch_sampler: List[List[int]] = batch_by_size(
            sorted_indices,
            num_tokens_fn=num_tokens_fn,
            num_tokens_vec=None,
            max_tokens=max_tokens_per_batch,
            max_sentences=None,
            required_batch_size_multiple=1,
            fixed_shapes=None,
        )
        if len(batch_sampler) == 0:
            return []

        shuffled_batches = [batch_sampler[i] for i in np.random.permutation(range(len(batch_sampler)))]

        # move the largest batch to the front to OOM quickly (uses an approximation for padding)
        approximate_toks_per_batch = [max(self.src_lens[i] for i in batch) * len(batch) for batch in shuffled_batches]
        largest_batch_idx = np.argmax(approximate_toks_per_batch)
        shuffled_batches[0], shuffled_batches[largest_batch_idx] = (
            shuffled_batches[largest_batch_idx],
            shuffled_batches[0],
        )

        return shuffled_batches

'''
The port of fairseq's batch_by_size (bbs) to Hugging Face style is taken from
https://github.com/huggingface/transformers/blob/main/examples/research_projects/seq2seq-distillation/finetune.py#L267C37-L267C57
https://github.com/huggingface/transformers/blob/main/examples/research_projects/seq2seq-distillation/utils.py#L160

This could be implemented more precisely and made to wrap a Hugging Face dataset.
'''

def sortish_sampler_indices(
    data: List, 
    bs: int, 
    shuffle=True
) -> np.ndarray:
    """Order indices by descending key with a bit of randomness. From fastai repo.

    Args:
        data: sequence of sort keys (e.g. source lengths), one per sample.
        bs: batch size; indices are grouped into chunks of this size.
        shuffle: if False, return a plain descending argsort of ``data``.
            If True, randomly permute the indices, sort them inside
            mega-chunks of ``bs * 50`` elements, cut the result into chunks
            of ``bs``, put the chunk starting with the largest key first and
            randomly permute the remaining chunks.

    Returns:
        A 1-D NumPy array containing each index of ``data`` exactly once.
    """
    if not shuffle:
        return np.argsort(np.array(data) * -1)

    # np.concatenate rejects an empty list of chunks, so handle this up front.
    if len(data) == 0:
        return np.array([], dtype=int)

    def key_fn(i):
        return data[i]

    idxs = np.random.permutation(len(data))
    sz = bs * 50
    ck_idx = [idxs[i : i + sz] for i in range(0, len(idxs), sz)]
    sort_idx = np.concatenate([sorted(s, key=key_fn, reverse=True) for s in ck_idx])
    sz = bs
    ck_idx = [sort_idx[i : i + sz] for i in range(0, len(sort_idx), sz)]
    max_ck = np.argmax([key_fn(ck[0]) for ck in ck_idx])  # find the chunk with the largest key,
    ck_idx[0], ck_idx[max_ck] = ck_idx[max_ck], ck_idx[0]  # then make sure it goes first.

    remaining = ck_idx[1:]
    if remaining:
        # Permute chunk positions rather than the chunk list itself: the last
        # chunk may be shorter than the others, and a ragged list of arrays
        # cannot be turned into a single NumPy array to shuffle.
        order = np.random.permutation(len(remaining))
        tail = np.concatenate([remaining[j] for j in order])
    else:
        tail = np.array([], dtype=int)
    return np.concatenate((ck_idx[0], tail))

class SortishSampler(Sampler):
    """Sampler yielding indices ordered by source length, with some randomness.

    The ordering itself is produced by ``sortish_sampler_indices`` (adapted
    from the fastai repo) each time the sampler is iterated.
    """

    def __init__(self, data, batch_size, shuffle=True):
        # ``data`` holds one sort key per sample.
        self.data = data
        self.bs = batch_size
        self.shuffle = shuffle

    def __len__(self) -> int:
        """Return the number of keys, i.e. the number of indices yielded."""
        return len(self.data)

    def __iter__(self):
        """Compute a fresh ordering and return an iterator over it."""
        ordering = sortish_sampler_indices(self.data, self.bs, shuffle=self.shuffle)
        return iter(ordering)

class DistributedSortishSampler(Sampler):
    """Sortish sampler restricted to this process's shard of the dataset.

    Adapted from torch's DistributedSampler
    (https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler).

    Args:
        dataset: dataset used for sampling.  It must support ``len()`` and
            expose a ``src_lens`` sequence with one length per sample.
        batch_size (int): chunk size forwarded to ``sortish_sampler_indices``.
        num_replicas (int, optional): number of processes participating in
            distributed training.  By default the world size is retrieved
            from the current distributed group.
        rank (int, optional): rank of the current process within
            ``num_replicas``.  By default the rank is retrieved from the
            current distributed group.
        add_extra_examples (bool, optional): if True (default), indices are
            repeated so that every replica receives the same number of
            samples.  If False, no padding is added and replicas may differ
            in size by one sample.
        shuffle (bool, optional): forwarded to ``sortish_sampler_indices``.

    Raises:
        RuntimeError: if ``num_replicas`` or ``rank`` has to be looked up but
            the distributed package is not available.
    """

    def __init__(
        self, 
        dataset, 
        batch_size, 
        num_replicas=None, 
        rank=None, 
        add_extra_examples=True, 
        shuffle=True
    ):
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        if add_extra_examples:
            # Ceiling division in pure integer arithmetic; avoids relying on
            # the math module, which this file never imports.
            self.num_samples = -(-len(self.dataset) // self.num_replicas)
            self.total_size = self.num_samples * self.num_replicas
        else:
            self.total_size = len(dataset)
            # available_indices reads dataset, rank, num_replicas and
            # total_size, all of which are set at this point.
            self.num_samples = len(self.available_indices)
        self.batch_size = batch_size
        self.add_extra_examples = add_extra_examples
        self.shuffle = shuffle

    def __iter__(self) -> Iterable:
        """Return an iterator over this replica's indices in sortish order.

        Randomness (when ``shuffle`` is True) comes from
        ``sortish_sampler_indices``; ``self.epoch`` is not used here.
        """
        sortish_data = [self.dataset.src_lens[i] for i in self.available_indices]
        sortish_indices = sortish_sampler_indices(sortish_data, self.batch_size, shuffle=self.shuffle)
        # Map positions within the shard back to dataset indices.
        indices = [self.available_indices[i] for i in sortish_indices]
        assert len(indices) == self.num_samples
        return iter(indices)

    @cached_property
    def available_indices(self) -> List[int]:
        """Dataset indices assigned to this replica (computed once, then cached)."""
        indices = list(range(len(self.dataset)))
        # add extra samples to make it evenly divisible; repeat the index list
        # as often as needed when the padding exceeds the dataset size
        padding = self.total_size - len(indices)
        if padding > 0 and indices:
            repeats = -(-padding // len(indices))
            indices += (indices * repeats)[:padding]
        assert len(indices) == self.total_size
        # subsample
        available_indices = indices[self.rank : self.total_size : self.num_replicas]
        return available_indices

    def __len__(self):
        """Return the number of samples this replica yields per pass."""
        return self.num_samples

    def set_epoch(self, epoch):
        """Record the current epoch number."""
        self.epoch = epoch

In [106]:
def collator(batch):
    """Right-pad a batch of 1-D samples with zeros and stack them.

    Args:
        batch: sequence of ``(sample, seq_len)`` pairs.

    Returns:
        A ``(stacked, seq_lens)`` tuple: ``stacked`` is the float tensor built
        by stacking the padded samples along dim 0, ``seq_lens`` is the list
        of the ``seq_len`` values in batch order.
    """
    samples, seq_lens = zip(*batch)
    seq_lens = list(seq_lens)
    longest = max(seq_lens)

    padded = []
    for sample, seq_len in zip(samples, seq_lens):
        tensor = torch.FloatTensor(sample)
        # Padding is derived from the reported seq_len, not the tensor's size.
        shortfall = longest - seq_len
        if shortfall != 0:
            tensor = torch.cat((tensor, torch.zeros(shortfall)), dim=0)
        padded.append(tensor)

    return torch.stack(padded, dim=0), seq_lens

## random sampler

In [110]:
# Baseline: fixed-size batches drawn in random order.
random_dataset = RandomDataset()
data_loader = DataLoader(
    dataset=random_dataset,
    sampler=RandomSampler(random_dataset),
    batch_sampler=None,
    batch_size=128,
    shuffle=False,
    drop_last=False,
    collate_fn=collator,
    num_workers=0,
    pin_memory=False,
)

print("total number of iteration: {}".format(len(data_loader)))
for x in data_loader:
    padded, seq_lens = x  # padded tensor and per-sample lengths, as returned by collator
    padded_tokens = padded.size(0) * padded.size(1)
    real_tokens = sum(seq_lens)
    # NOTE: the tuple printed under the "min/max" label is actually (max, min).
    length_range = (max(seq_lens), min(seq_lens))
    padding_pct = (padded_tokens - real_tokens) / padded_tokens * 100
    print(
        "bsz : {}, total tokens: {}, total tokens w/o padding: {}, min/max seq_lens in batch: {}, padding %: {:.2f}%".format(
            padded.size(),
            padded_tokens,
            real_tokens,
            length_range,
            padding_pct,
        )
    )

total number of iteration: 782
bsz : torch.Size([128, 2046]), total tokens: 261888, total tokens w/o padding: 132097, min/max seq_lens in batch: (2046, 87), padding %: 49.56%
bsz : torch.Size([128, 2047]), total tokens: 262016, total tokens w/o padding: 126377, min/max seq_lens in batch: (2047, 73), padding %: 51.77%
bsz : torch.Size([128, 2030]), total tokens: 259840, total tokens w/o padding: 141837, min/max seq_lens in batch: (2030, 87), padding %: 45.41%
bsz : torch.Size([128, 2042]), total tokens: 261376, total tokens w/o padding: 131386, min/max seq_lens in batch: (2042, 67), padding %: 49.73%
bsz : torch.Size([128, 2042]), total tokens: 261376, total tokens w/o padding: 138499, min/max seq_lens in batch: (2042, 116), padding %: 47.01%
bsz : torch.Size([128, 2042]), total tokens: 261376, total tokens w/o padding: 135383, min/max seq_lens in batch: (2042, 63), padding %: 48.20%
bsz : torch.Size([128, 2047]), total tokens: 262016, total tokens w/o padding: 131189, min/max seq_lens 

bsz : torch.Size([128, 2040]), total tokens: 261120, total tokens w/o padding: 131865, min/max seq_lens in batch: (2040, 64), padding %: 49.50%
bsz : torch.Size([128, 2000]), total tokens: 256000, total tokens w/o padding: 125576, min/max seq_lens in batch: (2000, 68), padding %: 50.95%
bsz : torch.Size([128, 2015]), total tokens: 257920, total tokens w/o padding: 132901, min/max seq_lens in batch: (2015, 63), padding %: 48.47%
bsz : torch.Size([128, 2046]), total tokens: 261888, total tokens w/o padding: 127742, min/max seq_lens in batch: (2046, 76), padding %: 51.22%
bsz : torch.Size([128, 2041]), total tokens: 261248, total tokens w/o padding: 131988, min/max seq_lens in batch: (2041, 67), padding %: 49.48%
bsz : torch.Size([128, 2017]), total tokens: 258176, total tokens w/o padding: 124695, min/max seq_lens in batch: (2017, 69), padding %: 51.70%
bsz : torch.Size([128, 2024]), total tokens: 259072, total tokens w/o padding: 132848, min/max seq_lens in batch: (2024, 62), padding %:

bsz : torch.Size([128, 2044]), total tokens: 261632, total tokens w/o padding: 122580, min/max seq_lens in batch: (2044, 66), padding %: 53.15%
bsz : torch.Size([128, 2038]), total tokens: 260864, total tokens w/o padding: 135414, min/max seq_lens in batch: (2038, 85), padding %: 48.09%
bsz : torch.Size([128, 2005]), total tokens: 256640, total tokens w/o padding: 130490, min/max seq_lens in batch: (2005, 81), padding %: 49.15%
bsz : torch.Size([128, 2044]), total tokens: 261632, total tokens w/o padding: 143304, min/max seq_lens in batch: (2044, 136), padding %: 45.23%
bsz : torch.Size([128, 2039]), total tokens: 260992, total tokens w/o padding: 147139, min/max seq_lens in batch: (2039, 107), padding %: 43.62%
bsz : torch.Size([128, 1999]), total tokens: 255872, total tokens w/o padding: 131033, min/max seq_lens in batch: (1999, 63), padding %: 48.79%
bsz : torch.Size([128, 2038]), total tokens: 260864, total tokens w/o padding: 141810, min/max seq_lens in batch: (2038, 94), padding 

bsz : torch.Size([128, 2034]), total tokens: 260352, total tokens w/o padding: 141075, min/max seq_lens in batch: (2034, 60), padding %: 45.81%
bsz : torch.Size([128, 2038]), total tokens: 260864, total tokens w/o padding: 147155, min/max seq_lens in batch: (2038, 62), padding %: 43.59%
bsz : torch.Size([128, 2031]), total tokens: 259968, total tokens w/o padding: 132271, min/max seq_lens in batch: (2031, 67), padding %: 49.12%
bsz : torch.Size([128, 2044]), total tokens: 261632, total tokens w/o padding: 133598, min/max seq_lens in batch: (2044, 94), padding %: 48.94%
bsz : torch.Size([128, 2043]), total tokens: 261504, total tokens w/o padding: 139205, min/max seq_lens in batch: (2043, 60), padding %: 46.77%
bsz : torch.Size([128, 2009]), total tokens: 257152, total tokens w/o padding: 129317, min/max seq_lens in batch: (2009, 81), padding %: 49.71%
bsz : torch.Size([128, 2045]), total tokens: 261760, total tokens w/o padding: 146789, min/max seq_lens in batch: (2045, 66), padding %:

bsz : torch.Size([128, 2043]), total tokens: 261504, total tokens w/o padding: 136607, min/max seq_lens in batch: (2043, 77), padding %: 47.76%
bsz : torch.Size([128, 2036]), total tokens: 260608, total tokens w/o padding: 127784, min/max seq_lens in batch: (2036, 69), padding %: 50.97%
bsz : torch.Size([128, 2047]), total tokens: 262016, total tokens w/o padding: 128841, min/max seq_lens in batch: (2047, 70), padding %: 50.83%
bsz : torch.Size([128, 2041]), total tokens: 261248, total tokens w/o padding: 131331, min/max seq_lens in batch: (2041, 108), padding %: 49.73%
bsz : torch.Size([128, 2015]), total tokens: 257920, total tokens w/o padding: 127249, min/max seq_lens in batch: (2015, 65), padding %: 50.66%
bsz : torch.Size([128, 2042]), total tokens: 261376, total tokens w/o padding: 138334, min/max seq_lens in batch: (2042, 77), padding %: 47.07%
bsz : torch.Size([128, 2045]), total tokens: 261760, total tokens w/o padding: 136237, min/max seq_lens in batch: (2045, 104), padding 

bsz : torch.Size([128, 2046]), total tokens: 261888, total tokens w/o padding: 141910, min/max seq_lens in batch: (2046, 61), padding %: 45.81%
bsz : torch.Size([128, 2035]), total tokens: 260480, total tokens w/o padding: 134304, min/max seq_lens in batch: (2035, 61), padding %: 48.44%
bsz : torch.Size([128, 2046]), total tokens: 261888, total tokens w/o padding: 137978, min/max seq_lens in batch: (2046, 72), padding %: 47.31%
bsz : torch.Size([128, 2030]), total tokens: 259840, total tokens w/o padding: 134073, min/max seq_lens in batch: (2030, 81), padding %: 48.40%
bsz : torch.Size([128, 2030]), total tokens: 259840, total tokens w/o padding: 140765, min/max seq_lens in batch: (2030, 65), padding %: 45.83%
bsz : torch.Size([128, 2016]), total tokens: 258048, total tokens w/o padding: 145157, min/max seq_lens in batch: (2016, 63), padding %: 43.75%
bsz : torch.Size([128, 2027]), total tokens: 259456, total tokens w/o padding: 128346, min/max seq_lens in batch: (2027, 60), padding %:

bsz : torch.Size([128, 2003]), total tokens: 256384, total tokens w/o padding: 138786, min/max seq_lens in batch: (2003, 112), padding %: 45.87%
bsz : torch.Size([128, 2043]), total tokens: 261504, total tokens w/o padding: 127094, min/max seq_lens in batch: (2043, 66), padding %: 51.40%
bsz : torch.Size([128, 2041]), total tokens: 261248, total tokens w/o padding: 135658, min/max seq_lens in batch: (2041, 69), padding %: 48.07%
bsz : torch.Size([128, 2031]), total tokens: 259968, total tokens w/o padding: 135383, min/max seq_lens in batch: (2031, 70), padding %: 47.92%
bsz : torch.Size([128, 2042]), total tokens: 261376, total tokens w/o padding: 147457, min/max seq_lens in batch: (2042, 77), padding %: 43.58%
bsz : torch.Size([128, 2019]), total tokens: 258432, total tokens w/o padding: 133698, min/max seq_lens in batch: (2019, 79), padding %: 48.27%
bsz : torch.Size([128, 2032]), total tokens: 260096, total tokens w/o padding: 132269, min/max seq_lens in batch: (2032, 64), padding %

bsz : torch.Size([128, 2036]), total tokens: 260608, total tokens w/o padding: 135012, min/max seq_lens in batch: (2036, 62), padding %: 48.19%
bsz : torch.Size([128, 2045]), total tokens: 261760, total tokens w/o padding: 135946, min/max seq_lens in batch: (2045, 62), padding %: 48.06%
bsz : torch.Size([128, 2033]), total tokens: 260224, total tokens w/o padding: 142105, min/max seq_lens in batch: (2033, 79), padding %: 45.39%
bsz : torch.Size([128, 2031]), total tokens: 259968, total tokens w/o padding: 133971, min/max seq_lens in batch: (2031, 64), padding %: 48.47%
bsz : torch.Size([128, 2045]), total tokens: 261760, total tokens w/o padding: 146578, min/max seq_lens in batch: (2045, 94), padding %: 44.00%
bsz : torch.Size([128, 2046]), total tokens: 261888, total tokens w/o padding: 136486, min/max seq_lens in batch: (2046, 72), padding %: 47.88%
bsz : torch.Size([128, 2002]), total tokens: 256256, total tokens w/o padding: 132018, min/max seq_lens in batch: (2002, 62), padding %:

bsz : torch.Size([128, 2042]), total tokens: 261376, total tokens w/o padding: 127812, min/max seq_lens in batch: (2042, 61), padding %: 51.10%
bsz : torch.Size([128, 2011]), total tokens: 257408, total tokens w/o padding: 121982, min/max seq_lens in batch: (2011, 62), padding %: 52.61%
bsz : torch.Size([128, 2047]), total tokens: 262016, total tokens w/o padding: 135511, min/max seq_lens in batch: (2047, 77), padding %: 48.28%
bsz : torch.Size([128, 2040]), total tokens: 261120, total tokens w/o padding: 138427, min/max seq_lens in batch: (2040, 66), padding %: 46.99%
bsz : torch.Size([128, 2047]), total tokens: 262016, total tokens w/o padding: 137085, min/max seq_lens in batch: (2047, 70), padding %: 47.68%
bsz : torch.Size([128, 2023]), total tokens: 258944, total tokens w/o padding: 129616, min/max seq_lens in batch: (2023, 60), padding %: 49.94%
bsz : torch.Size([128, 2023]), total tokens: 258944, total tokens w/o padding: 141958, min/max seq_lens in batch: (2023, 128), padding %

bsz : torch.Size([128, 2046]), total tokens: 261888, total tokens w/o padding: 136959, min/max seq_lens in batch: (2046, 90), padding %: 47.70%
bsz : torch.Size([128, 2027]), total tokens: 259456, total tokens w/o padding: 144642, min/max seq_lens in batch: (2027, 66), padding %: 44.25%
bsz : torch.Size([128, 2035]), total tokens: 260480, total tokens w/o padding: 144105, min/max seq_lens in batch: (2035, 62), padding %: 44.68%
bsz : torch.Size([128, 2002]), total tokens: 256256, total tokens w/o padding: 121320, min/max seq_lens in batch: (2002, 92), padding %: 52.66%
bsz : torch.Size([128, 2040]), total tokens: 261120, total tokens w/o padding: 131820, min/max seq_lens in batch: (2040, 61), padding %: 49.52%
bsz : torch.Size([128, 2038]), total tokens: 260864, total tokens w/o padding: 133945, min/max seq_lens in batch: (2038, 75), padding %: 48.65%
bsz : torch.Size([128, 2041]), total tokens: 261248, total tokens w/o padding: 132142, min/max seq_lens in batch: (2041, 72), padding %:

bsz : torch.Size([128, 2030]), total tokens: 259840, total tokens w/o padding: 139035, min/max seq_lens in batch: (2030, 60), padding %: 46.49%
bsz : torch.Size([128, 2027]), total tokens: 259456, total tokens w/o padding: 133603, min/max seq_lens in batch: (2027, 65), padding %: 48.51%
bsz : torch.Size([128, 2046]), total tokens: 261888, total tokens w/o padding: 137191, min/max seq_lens in batch: (2046, 60), padding %: 47.61%
bsz : torch.Size([128, 2043]), total tokens: 261504, total tokens w/o padding: 131528, min/max seq_lens in batch: (2043, 61), padding %: 49.70%
bsz : torch.Size([128, 2046]), total tokens: 261888, total tokens w/o padding: 150726, min/max seq_lens in batch: (2046, 74), padding %: 42.45%
bsz : torch.Size([128, 2027]), total tokens: 259456, total tokens w/o padding: 112678, min/max seq_lens in batch: (2027, 63), padding %: 56.57%
bsz : torch.Size([128, 1998]), total tokens: 255744, total tokens w/o padding: 138167, min/max seq_lens in batch: (1998, 67), padding %:

## sequential sampler

In [111]:
# Baseline: fixed-size batches drawn in dataset order.
random_dataset = RandomDataset()
data_loader = DataLoader(
    dataset=random_dataset,
    sampler=SequentialSampler(random_dataset),
    batch_sampler=None,
    batch_size=128,
    shuffle=False,
    drop_last=False,
    collate_fn=collator,
    num_workers=0,
    pin_memory=False,
)

print("total number of iteration: {}".format(len(data_loader)))
for x in data_loader:
    padded, seq_lens = x  # padded tensor and per-sample lengths, as returned by collator
    padded_tokens = padded.size(0) * padded.size(1)
    real_tokens = sum(seq_lens)
    # NOTE: the tuple printed under the "min/max" label is actually (max, min).
    length_range = (max(seq_lens), min(seq_lens))
    padding_pct = (padded_tokens - real_tokens) / padded_tokens * 100
    print(
        "bsz : {}, total tokens: {}, total tokens w/o padding: {}, min/max seq_lens in batch: {}, padding %: {:.2f}%".format(
            padded.size(),
            padded_tokens,
            real_tokens,
            length_range,
            padding_pct,
        )
    )

total number of iteration: 782
bsz : torch.Size([128, 2033]), total tokens: 260224, total tokens w/o padding: 131676, min/max seq_lens in batch: (2033, 99), padding %: 49.40%
bsz : torch.Size([128, 2036]), total tokens: 260608, total tokens w/o padding: 131425, min/max seq_lens in batch: (2036, 89), padding %: 49.57%
bsz : torch.Size([128, 2040]), total tokens: 261120, total tokens w/o padding: 145690, min/max seq_lens in batch: (2040, 61), padding %: 44.21%
bsz : torch.Size([128, 2031]), total tokens: 259968, total tokens w/o padding: 142966, min/max seq_lens in batch: (2031, 81), padding %: 45.01%
bsz : torch.Size([128, 2034]), total tokens: 260352, total tokens w/o padding: 139517, min/max seq_lens in batch: (2034, 63), padding %: 46.41%
bsz : torch.Size([128, 2043]), total tokens: 261504, total tokens w/o padding: 145594, min/max seq_lens in batch: (2043, 104), padding %: 44.32%
bsz : torch.Size([128, 2044]), total tokens: 261632, total tokens w/o padding: 125540, min/max seq_lens 

bsz : torch.Size([128, 2047]), total tokens: 262016, total tokens w/o padding: 141901, min/max seq_lens in batch: (2047, 97), padding %: 45.84%
bsz : torch.Size([128, 2039]), total tokens: 260992, total tokens w/o padding: 132017, min/max seq_lens in batch: (2039, 67), padding %: 49.42%
bsz : torch.Size([128, 2047]), total tokens: 262016, total tokens w/o padding: 130612, min/max seq_lens in batch: (2047, 86), padding %: 50.15%
bsz : torch.Size([128, 2042]), total tokens: 261376, total tokens w/o padding: 147808, min/max seq_lens in batch: (2042, 142), padding %: 43.45%
bsz : torch.Size([128, 2038]), total tokens: 260864, total tokens w/o padding: 133317, min/max seq_lens in batch: (2038, 65), padding %: 48.89%
bsz : torch.Size([128, 2014]), total tokens: 257792, total tokens w/o padding: 134290, min/max seq_lens in batch: (2014, 69), padding %: 47.91%
bsz : torch.Size([128, 2035]), total tokens: 260480, total tokens w/o padding: 136267, min/max seq_lens in batch: (2035, 105), padding 

bsz : torch.Size([128, 2007]), total tokens: 256896, total tokens w/o padding: 133350, min/max seq_lens in batch: (2007, 70), padding %: 48.09%
bsz : torch.Size([128, 2038]), total tokens: 260864, total tokens w/o padding: 137063, min/max seq_lens in batch: (2038, 80), padding %: 47.46%
bsz : torch.Size([128, 2044]), total tokens: 261632, total tokens w/o padding: 133763, min/max seq_lens in batch: (2044, 90), padding %: 48.87%
bsz : torch.Size([128, 2024]), total tokens: 259072, total tokens w/o padding: 125080, min/max seq_lens in batch: (2024, 62), padding %: 51.72%
bsz : torch.Size([128, 2046]), total tokens: 261888, total tokens w/o padding: 130741, min/max seq_lens in batch: (2046, 61), padding %: 50.08%
bsz : torch.Size([128, 2026]), total tokens: 259328, total tokens w/o padding: 140336, min/max seq_lens in batch: (2026, 83), padding %: 45.88%
bsz : torch.Size([128, 2047]), total tokens: 262016, total tokens w/o padding: 144679, min/max seq_lens in batch: (2047, 122), padding %

bsz : torch.Size([128, 2010]), total tokens: 257280, total tokens w/o padding: 132040, min/max seq_lens in batch: (2010, 62), padding %: 48.68%
bsz : torch.Size([128, 1983]), total tokens: 253824, total tokens w/o padding: 135084, min/max seq_lens in batch: (1983, 61), padding %: 46.78%
bsz : torch.Size([128, 2005]), total tokens: 256640, total tokens w/o padding: 135643, min/max seq_lens in batch: (2005, 72), padding %: 47.15%
bsz : torch.Size([128, 2047]), total tokens: 262016, total tokens w/o padding: 138784, min/max seq_lens in batch: (2047, 62), padding %: 47.03%
bsz : torch.Size([128, 2028]), total tokens: 259584, total tokens w/o padding: 138488, min/max seq_lens in batch: (2028, 62), padding %: 46.65%
bsz : torch.Size([128, 1993]), total tokens: 255104, total tokens w/o padding: 129963, min/max seq_lens in batch: (1993, 61), padding %: 49.05%
bsz : torch.Size([128, 2047]), total tokens: 262016, total tokens w/o padding: 143684, min/max seq_lens in batch: (2047, 72), padding %:

bsz : torch.Size([128, 2034]), total tokens: 260352, total tokens w/o padding: 125228, min/max seq_lens in batch: (2034, 62), padding %: 51.90%
bsz : torch.Size([128, 2042]), total tokens: 261376, total tokens w/o padding: 125353, min/max seq_lens in batch: (2042, 70), padding %: 52.04%
bsz : torch.Size([128, 2047]), total tokens: 262016, total tokens w/o padding: 124242, min/max seq_lens in batch: (2047, 65), padding %: 52.58%
bsz : torch.Size([128, 2039]), total tokens: 260992, total tokens w/o padding: 150146, min/max seq_lens in batch: (2039, 63), padding %: 42.47%
bsz : torch.Size([128, 2028]), total tokens: 259584, total tokens w/o padding: 130923, min/max seq_lens in batch: (2028, 71), padding %: 49.56%
bsz : torch.Size([128, 2007]), total tokens: 256896, total tokens w/o padding: 128634, min/max seq_lens in batch: (2007, 78), padding %: 49.93%
bsz : torch.Size([128, 2040]), total tokens: 261120, total tokens w/o padding: 138124, min/max seq_lens in batch: (2040, 61), padding %:

bsz : torch.Size([128, 2044]), total tokens: 261632, total tokens w/o padding: 135128, min/max seq_lens in batch: (2044, 78), padding %: 48.35%
bsz : torch.Size([128, 2045]), total tokens: 261760, total tokens w/o padding: 144183, min/max seq_lens in batch: (2045, 65), padding %: 44.92%
bsz : torch.Size([128, 2044]), total tokens: 261632, total tokens w/o padding: 138225, min/max seq_lens in batch: (2044, 74), padding %: 47.17%
bsz : torch.Size([128, 2018]), total tokens: 258304, total tokens w/o padding: 135298, min/max seq_lens in batch: (2018, 66), padding %: 47.62%
bsz : torch.Size([128, 2020]), total tokens: 258560, total tokens w/o padding: 135561, min/max seq_lens in batch: (2020, 110), padding %: 47.57%
bsz : torch.Size([128, 2038]), total tokens: 260864, total tokens w/o padding: 133071, min/max seq_lens in batch: (2038, 72), padding %: 48.99%
bsz : torch.Size([128, 2038]), total tokens: 260864, total tokens w/o padding: 149295, min/max seq_lens in batch: (2038, 66), padding %

bsz : torch.Size([128, 2036]), total tokens: 260608, total tokens w/o padding: 130562, min/max seq_lens in batch: (2036, 62), padding %: 49.90%
bsz : torch.Size([128, 2044]), total tokens: 261632, total tokens w/o padding: 134316, min/max seq_lens in batch: (2044, 69), padding %: 48.66%
bsz : torch.Size([128, 2045]), total tokens: 261760, total tokens w/o padding: 136073, min/max seq_lens in batch: (2045, 64), padding %: 48.02%
bsz : torch.Size([128, 1999]), total tokens: 255872, total tokens w/o padding: 135581, min/max seq_lens in batch: (1999, 72), padding %: 47.01%
bsz : torch.Size([128, 2045]), total tokens: 261760, total tokens w/o padding: 143698, min/max seq_lens in batch: (2045, 67), padding %: 45.10%
bsz : torch.Size([128, 2031]), total tokens: 259968, total tokens w/o padding: 142199, min/max seq_lens in batch: (2031, 60), padding %: 45.30%
bsz : torch.Size([128, 1983]), total tokens: 253824, total tokens w/o padding: 131806, min/max seq_lens in batch: (1983, 61), padding %:

bsz : torch.Size([128, 2045]), total tokens: 261760, total tokens w/o padding: 132182, min/max seq_lens in batch: (2045, 106), padding %: 49.50%
bsz : torch.Size([128, 2045]), total tokens: 261760, total tokens w/o padding: 144083, min/max seq_lens in batch: (2045, 63), padding %: 44.96%
bsz : torch.Size([128, 2039]), total tokens: 260992, total tokens w/o padding: 133980, min/max seq_lens in batch: (2039, 77), padding %: 48.67%
bsz : torch.Size([128, 2047]), total tokens: 262016, total tokens w/o padding: 140757, min/max seq_lens in batch: (2047, 63), padding %: 46.28%
bsz : torch.Size([128, 1991]), total tokens: 254848, total tokens w/o padding: 133066, min/max seq_lens in batch: (1991, 91), padding %: 47.79%
bsz : torch.Size([128, 2047]), total tokens: 262016, total tokens w/o padding: 127512, min/max seq_lens in batch: (2047, 78), padding %: 51.33%
bsz : torch.Size([128, 2016]), total tokens: 258048, total tokens w/o padding: 138263, min/max seq_lens in batch: (2016, 87), padding %

bsz : torch.Size([128, 2035]), total tokens: 260480, total tokens w/o padding: 146041, min/max seq_lens in batch: (2035, 65), padding %: 43.93%
bsz : torch.Size([128, 2038]), total tokens: 260864, total tokens w/o padding: 126888, min/max seq_lens in batch: (2038, 60), padding %: 51.36%
bsz : torch.Size([128, 2023]), total tokens: 258944, total tokens w/o padding: 137269, min/max seq_lens in batch: (2023, 70), padding %: 46.99%
bsz : torch.Size([128, 2040]), total tokens: 261120, total tokens w/o padding: 128640, min/max seq_lens in batch: (2040, 77), padding %: 50.74%
bsz : torch.Size([128, 2008]), total tokens: 257024, total tokens w/o padding: 130453, min/max seq_lens in batch: (2008, 64), padding %: 49.24%
bsz : torch.Size([128, 2046]), total tokens: 261888, total tokens w/o padding: 140703, min/max seq_lens in batch: (2046, 105), padding %: 46.27%
bsz : torch.Size([128, 2032]), total tokens: 260096, total tokens w/o padding: 142076, min/max seq_lens in batch: (2032, 64), padding %

bsz : torch.Size([128, 2037]), total tokens: 260736, total tokens w/o padding: 127955, min/max seq_lens in batch: (2037, 86), padding %: 50.93%
bsz : torch.Size([128, 2043]), total tokens: 261504, total tokens w/o padding: 130358, min/max seq_lens in batch: (2043, 62), padding %: 50.15%
bsz : torch.Size([128, 2044]), total tokens: 261632, total tokens w/o padding: 140425, min/max seq_lens in batch: (2044, 70), padding %: 46.33%
bsz : torch.Size([128, 2001]), total tokens: 256128, total tokens w/o padding: 133104, min/max seq_lens in batch: (2001, 73), padding %: 48.03%
bsz : torch.Size([128, 2043]), total tokens: 261504, total tokens w/o padding: 126404, min/max seq_lens in batch: (2043, 85), padding %: 51.66%
bsz : torch.Size([128, 2047]), total tokens: 262016, total tokens w/o padding: 141863, min/max seq_lens in batch: (2047, 81), padding %: 45.86%
bsz : torch.Size([128, 2030]), total tokens: 259840, total tokens w/o padding: 132507, min/max seq_lens in batch: (2030, 78), padding %:

bsz : torch.Size([128, 2045]), total tokens: 261760, total tokens w/o padding: 140474, min/max seq_lens in batch: (2045, 77), padding %: 46.33%
bsz : torch.Size([128, 2041]), total tokens: 261248, total tokens w/o padding: 129521, min/max seq_lens in batch: (2041, 81), padding %: 50.42%
bsz : torch.Size([128, 2038]), total tokens: 260864, total tokens w/o padding: 137965, min/max seq_lens in batch: (2038, 71), padding %: 47.11%
bsz : torch.Size([128, 2037]), total tokens: 260736, total tokens w/o padding: 130096, min/max seq_lens in batch: (2037, 82), padding %: 50.10%
bsz : torch.Size([128, 2039]), total tokens: 260992, total tokens w/o padding: 141951, min/max seq_lens in batch: (2039, 110), padding %: 45.61%
bsz : torch.Size([128, 2037]), total tokens: 260736, total tokens w/o padding: 144876, min/max seq_lens in batch: (2037, 62), padding %: 44.44%
bsz : torch.Size([128, 2035]), total tokens: 260480, total tokens w/o padding: 129769, min/max seq_lens in batch: (2035, 73), padding %

## dynamic batching

In [112]:
# Demo cell: drive a DataLoader with the token-budget ("dynamic") batch sampler
# and print, for every batch, how much of the padded tensor is padding.
random_dataset = RandomDataset()
dynamic_batch_sampler = random_dataset.make_dynamic_sampler(max_tokens_per_batch=250000)

# NOTE: `batch_sampler` is mutually exclusive with `batch_size`, `shuffle`,
# `sampler` and `drop_last`. DataLoader raises a ValueError if any of them is
# moved away from its default, so they are spelled out here at their defaults.
data_loader = DataLoader(
    random_dataset,
    batch_size=1,  # must stay 1 (the default) when batch_sampler is given
    shuffle=False,  # must stay False when batch_sampler is given
    sampler=None,  # must stay None when batch_sampler is given
    batch_sampler=dynamic_batch_sampler,
    num_workers=0,
    collate_fn=collator,
    pin_memory=False,
    drop_last=False,  # must stay False when batch_sampler is given
)


print("total number of iteration: {}".format(len(data_loader)))
for x in data_loader:
    # `collator` is defined elsewhere in the file; presumably x[0] is the padded
    # 2-D batch tensor and x[1] holds the per-sample sequence lengths -- verify
    # against the collator if its return layout changes.
    padded, seq_lens = x[0], x[1]
    padded_tokens = padded.size(0) * padded.size(1)  # rows * padded length
    real_tokens = sum(seq_lens)
    print("bsz : {}, total tokens: {}, total tokens w/o padding: {}, min/max seq_lens in batch: {}, padding %: {:.2f}%".format(
            padded.size(),
            padded_tokens,
            real_tokens,
            # Order matches the "min/max" label (it was previously (max, min)).
            (min(seq_lens), max(seq_lens)),
            (padded_tokens - real_tokens) / padded_tokens * 100,
        )
    )

total number of iteration: 424
bsz : torch.Size([500, 500]), total tokens: 250000, total tokens w/o padding: 247473, min/max seq_lens in batch: (500, 490), padding %: 1.01%
bsz : torch.Size([185, 1350]), total tokens: 249750, total tokens w/o padding: 249448, min/max seq_lens in batch: (1350, 1346), padding %: 0.12%
bsz : torch.Size([213, 1171]), total tokens: 249423, total tokens w/o padding: 248969, min/max seq_lens in batch: (1171, 1166), padding %: 0.18%
bsz : torch.Size([165, 1508]), total tokens: 248820, total tokens w/o padding: 248596, min/max seq_lens in batch: (1508, 1505), padding %: 0.09%
bsz : torch.Size([397, 629]), total tokens: 249713, total tokens w/o padding: 247981, min/max seq_lens in batch: (629, 621), padding %: 0.69%
bsz : torch.Size([219, 1140]), total tokens: 249660, total tokens w/o padding: 249289, min/max seq_lens in batch: (1140, 1136), padding %: 0.15%
bsz : torch.Size([146, 1701]), total tokens: 248346, total tokens w/o padding: 248202, min/max seq_lens i

bsz : torch.Size([965, 259]), total tokens: 249935, total tokens w/o padding: 240998, min/max seq_lens in batch: (259, 241), padding %: 3.58%
bsz : torch.Size([197, 1264]), total tokens: 249008, total tokens w/o padding: 248721, min/max seq_lens in batch: (1264, 1261), padding %: 0.12%
bsz : torch.Size([204, 1224]), total tokens: 249696, total tokens w/o padding: 249332, min/max seq_lens in batch: (1224, 1220), padding %: 0.15%
bsz : torch.Size([185, 1346]), total tokens: 249010, total tokens w/o padding: 248703, min/max seq_lens in batch: (1346, 1343), padding %: 0.12%
bsz : torch.Size([156, 1594]), total tokens: 248664, total tokens w/o padding: 248422, min/max seq_lens in batch: (1594, 1591), padding %: 0.10%
bsz : torch.Size([134, 1860]), total tokens: 249240, total tokens w/o padding: 249073, min/max seq_lens in batch: (1860, 1858), padding %: 0.07%
bsz : torch.Size([142, 1756]), total tokens: 249352, total tokens w/o padding: 249197, min/max seq_lens in batch: (1756, 1754), paddi

bsz : torch.Size([129, 1934]), total tokens: 249486, total tokens w/o padding: 249376, min/max seq_lens in batch: (1934, 1932), padding %: 0.04%
bsz : torch.Size([236, 1059]), total tokens: 249924, total tokens w/o padding: 249475, min/max seq_lens in batch: (1059, 1055), padding %: 0.18%
bsz : torch.Size([147, 1693]), total tokens: 248871, total tokens w/o padding: 248583, min/max seq_lens in batch: (1693, 1690), padding %: 0.12%
bsz : torch.Size([224, 1113]), total tokens: 249312, total tokens w/o padding: 248835, min/max seq_lens in batch: (1113, 1109), padding %: 0.19%
bsz : torch.Size([170, 1468]), total tokens: 249560, total tokens w/o padding: 249325, min/max seq_lens in batch: (1468, 1465), padding %: 0.09%
bsz : torch.Size([200, 1245]), total tokens: 249000, total tokens w/o padding: 248621, min/max seq_lens in batch: (1245, 1241), padding %: 0.15%
bsz : torch.Size([133, 1874]), total tokens: 249242, total tokens w/o padding: 249012, min/max seq_lens in batch: (1874, 1871), pa

bsz : torch.Size([170, 1465]), total tokens: 249050, total tokens w/o padding: 248751, min/max seq_lens in batch: (1465, 1462), padding %: 0.12%
bsz : torch.Size([136, 1833]), total tokens: 249288, total tokens w/o padding: 249176, min/max seq_lens in batch: (1833, 1831), padding %: 0.04%
bsz : torch.Size([230, 1083]), total tokens: 249090, total tokens w/o padding: 248570, min/max seq_lens in batch: (1083, 1079), padding %: 0.21%
bsz : torch.Size([136, 1836]), total tokens: 249696, total tokens w/o padding: 249535, min/max seq_lens in batch: (1836, 1833), padding %: 0.06%
bsz : torch.Size([143, 1748]), total tokens: 249964, total tokens w/o padding: 249748, min/max seq_lens in batch: (1748, 1745), padding %: 0.09%
bsz : torch.Size([137, 1818]), total tokens: 249066, total tokens w/o padding: 248813, min/max seq_lens in batch: (1818, 1815), padding %: 0.10%
bsz : torch.Size([282, 886]), total tokens: 249852, total tokens w/o padding: 249098, min/max seq_lens in batch: (886, 880), paddi

bsz : torch.Size([136, 1831]), total tokens: 249016, total tokens w/o padding: 248827, min/max seq_lens in batch: (1831, 1828), padding %: 0.08%
bsz : torch.Size([206, 1211]), total tokens: 249466, total tokens w/o padding: 249139, min/max seq_lens in batch: (1211, 1207), padding %: 0.13%
bsz : torch.Size([132, 1881]), total tokens: 248292, total tokens w/o padding: 248151, min/max seq_lens in batch: (1881, 1879), padding %: 0.06%
bsz : torch.Size([233, 1069]), total tokens: 249077, total tokens w/o padding: 248586, min/max seq_lens in batch: (1069, 1065), padding %: 0.20%
bsz : torch.Size([145, 1722]), total tokens: 249690, total tokens w/o padding: 249421, min/max seq_lens in batch: (1722, 1719), padding %: 0.11%
bsz : torch.Size([141, 1768]), total tokens: 249288, total tokens w/o padding: 249020, min/max seq_lens in batch: (1768, 1765), padding %: 0.11%
bsz : torch.Size([141, 1773]), total tokens: 249993, total tokens w/o padding: 249822, min/max seq_lens in batch: (1773, 1770), pa