Explore some ideas how to write samplers more concise

In [2]:
import random
import numpy as np

In [53]:
from torch.utils.data import Sampler
import torch.distributed as dist

In [4]:
# Define a test case

num_elements = [17, 21, 11]
seq_length = 4
batch_size = 3


data_arrs = [np.zeros(n) for n in num_elements]

In [5]:
# Samplers for single-shot dataset

class BatchedSampler(Sampler):
    r"""Sample linear sequences, allows for batching and shuffling.

    Similar to SequentialSampler, but returns a batch of sequences in each iteration.
    """
    def __init__(self, num_elements, seq_length, batch_size, shuffle=False, seed=0):
        self.num_elements = num_elements  # Length of the dataset
        self.seq_length = seq_length      # Length of the sequences to sample
        self.batch_size = batch_size      # Batch size
        self.shuffle = shuffle            # Shuffle the start of the sequences?
        self.seed = seed                  # Seed for shuffling
        self.epoch = 0                    # Increase this after each epoch to get different shuffling in next iteration
    
    def set_epoch(self, epoch):
        """Update epoch to adjust random seed."""
        self.sepoch = epoch
        
    def __iter__(self):
        """Returns fixed-length, ordered sequences that cover the dataset."""
        idx_permuted = [(ix) for ix in range(self.num_elements - self.seq_length )]
        if self.shuffle:
            random.seed(self.seed + self.epoch)
            random.shuffle(idx_permuted)
                      
        # Slicing the list like this takes care of partial batches
        for start in range(0, len(idx_permuted), self.batch_size):
            yield [range(ix, ix + self.seq_length + 1) for ix in idx_permuted[start:start+self.batch_size]]
            

In [60]:
s0 = BatchedSampler(num_elements[0], seq_length, batch_size, shuffle=False, seed=0)

In [61]:
for s in s0:
    print(s)
    for ix in s:
        data_arrs[0][ix]

[range(0, 5), range(1, 6), range(2, 7)]
[range(3, 8), range(4, 9), range(5, 10)]
[range(6, 11), range(7, 12), range(8, 13)]
[range(9, 14), range(10, 15), range(11, 16)]
[range(12, 17)]


In [8]:
l = list(range(10))
random.shuffle(l)

print(l)
for s in range(0, 10, 3):
    print(l[s:s+3])




[6, 5, 3, 9, 7, 0, 1, 4, 2, 8]
[6, 5, 3]
[9, 7, 0]
[1, 4, 2]
[8]


In [9]:
# This is the basic sequential batched sampler for multi-shot dataset.
# It iterates over each shot and returns sequences of a fixed length, starting at 0, until a shot is exhausted.
# Then it proceeds to the next shot


class SequentialBatchedSampler_multi(Sampler):
    r"""Sample batched, linear sequences from multishot dataset.
    
    Args:
        num_elements (List[Int]): Elements per dataset.
        seq_length (Int) : Length of sequences to sample
        batch_size (Int) : Number of sequences to return per iteration.
    """
    def __init__(self, num_elements, seq_length, batch_size):
        self.num_elements = num_elements
        self.num_shots = len(num_elements)
        self.seq_length = seq_length
        self.batch_size = batch_size

    def __iter__(self):
        """Return a batch of linear sequences.
        
        * Always exhaust one dataset, even if it means that the batch will be smaller than 
          requested batch_size, before continuing on the next shot.
        """
        for s in range(0, self.num_shots):
            for start in range(0, self.num_elements[s] - self.seq_length - 1, self.batch_size):
                yield [(s, range(start + b, start + b + self.seq_length + 1)) for b in range(self.batch_size) if start + b + self.seq_length + 1 <= self.num_elements[s]]
    

In [10]:
s1 = SequentialBatchedSampler_multi(num_elements, seq_length, batch_size)

In [11]:
ctr = 0
for s in s1:
    if ctr < 50:
        print(s)
    ctr += 1
    
print(f"{ctr} samples sampled")

[(0, range(0, 5)), (0, range(1, 6)), (0, range(2, 7))]
[(0, range(3, 8)), (0, range(4, 9)), (0, range(5, 10))]
[(0, range(6, 11)), (0, range(7, 12)), (0, range(8, 13))]
[(0, range(9, 14)), (0, range(10, 15)), (0, range(11, 16))]
[(1, range(0, 5)), (1, range(1, 6)), (1, range(2, 7))]
[(1, range(3, 8)), (1, range(4, 9)), (1, range(5, 10))]
[(1, range(6, 11)), (1, range(7, 12)), (1, range(8, 13))]
[(1, range(9, 14)), (1, range(10, 15)), (1, range(11, 16))]
[(1, range(12, 17)), (1, range(13, 18)), (1, range(14, 19))]
[(1, range(15, 20)), (1, range(16, 21))]
[(2, range(0, 5)), (2, range(1, 6)), (2, range(2, 7))]
[(2, range(3, 8)), (2, range(4, 9)), (2, range(5, 10))]
12 samples sampled


In [12]:
# Now we modify a bit. The goal is to have it shuffle in a deterministic way.
# Take inspiration from https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html#DistributedSampler

class sampler_v2(Sampler):
    def __init__(self, num_elements, seq_length, batch_size, shuffle=False, seed=0):
        self.num_elements = num_elements
        self.num_shots = len(num_elements)
        self.seq_length = seq_length
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0
        print(f"self.shuffle = {self.shuffle}, self.seed = {self.seed}")
        
    def set_epoch(self, epoch):
        """Sets epoch for this sampler.
        When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.
        
        Args:
            epoch (int) : Epoch number
        """
        self.epoch = epoch
        
    def __iter__(self):
        idx_permuted = [(s, i)  for s in range(self.num_shots) for i in range(self.num_elements[s] - self.seq_length)]
        if self.shuffle:
            print("Shuffling")
            random.seed(self.seed + self.epoch)
            random.shuffle(idx_permuted)

        
        full_batches = len(idx_permuted) // self.batch_size # Number of batches we can fill with the specified batch size
        # Check if the last batch is full or partial
        # We iterate up to num_batches. If in the loop the batch_counter == full_batches, we will have a partial patch
        if len(idx_permuted) != full_batches * self.batch_size:
            remaining_samples = len(idx_permuted) - full_batches * self.batch_size
            partial_batch = True
            num_batches = full_batches + 1
        else: 
            partial_batch = False
            num_batches = full_batches
     
        # Number of batches to draw. Round up.
        #num_batches = self.num_shots * (self.num_elements - self.seq_length) // self.batch_size
        for ix_b in range(0, num_batches):
            # If ix_x is full_batches (remember 0-based indexing and num_batches is excludede in range)
            # we have need to fill a partial batch with the remaining samples.
            if ix_b == full_batches:  
                selected = idx_permuted[-remaining_samples:]
            else:
                # Fill a full batch
                # Select starting points for sequences
                selected = idx_permuted[(ix_b * self.batch_size):((ix_b + 1) * self.batch_size)]
            # Remember to return a list. PyTorch dataloader passes each item in the
            # returned list to dataset.__getidx__. If we only return a single list,
            # each scalar index in that list would be passed to __getidx__.
            # If we return a list of lists, that inner list will be passed to __getidx__.
            # Then this list will be used for indexing.
            # Long story short: pass list of lists, not a single list.
            yield [(s[0], range(s[1], s[1] + self.seq_length + 1)) for s in selected]

In [13]:
s2 = sampler_v2(num_elements, seq_length, batch_size, shuffle=False, seed=1337)

self.shuffle = False, self.seed = 1337


In [14]:
ctr = 0
for s in s2:
    if ctr < 50:
        print(s)
    ctr += 1
    
print(f"{ctr} samples sampled")

[(0, range(0, 5)), (0, range(1, 6)), (0, range(2, 7))]
[(0, range(3, 8)), (0, range(4, 9)), (0, range(5, 10))]
[(0, range(6, 11)), (0, range(7, 12)), (0, range(8, 13))]
[(0, range(9, 14)), (0, range(10, 15)), (0, range(11, 16))]
[(0, range(12, 17)), (1, range(0, 5)), (1, range(1, 6))]
[(1, range(2, 7)), (1, range(3, 8)), (1, range(4, 9))]
[(1, range(5, 10)), (1, range(6, 11)), (1, range(7, 12))]
[(1, range(8, 13)), (1, range(9, 14)), (1, range(10, 15))]
[(1, range(11, 16)), (1, range(12, 17)), (1, range(13, 18))]
[(1, range(14, 19)), (1, range(15, 20)), (1, range(16, 21))]
[(2, range(0, 5)), (2, range(1, 6)), (2, range(2, 7))]
[(2, range(3, 8)), (2, range(4, 9)), (2, range(5, 10))]
[(2, range(6, 11))]
13 samples sampled


In [15]:
s2.set_epoch(1)
ctr = 0
for s in s2:
    if ctr < 5:
        print(s)
    ctr += 1
    
print(f"{ctr} samples sampled")

[(0, range(0, 5)), (0, range(1, 6)), (0, range(2, 7))]
[(0, range(3, 8)), (0, range(4, 9)), (0, range(5, 10))]
[(0, range(6, 11)), (0, range(7, 12)), (0, range(8, 13))]
[(0, range(9, 14)), (0, range(10, 15)), (0, range(11, 16))]
[(0, range(12, 17)), (1, range(0, 5)), (1, range(1, 6))]
13 samples sampled


## Explore index splitting for distributed samplers

In [16]:
# We are working with rank and num_replicas. Rank is the MPI rank. num_replicas is the MPI_WORLD_SIZE
# Try how we can split a list between these

total_size = 10

idx_permuted = list(range(total_size))
random.seed(1337)
random.shuffle(idx_permuted)

print("Entire list: ", idx_permuted)

num_replicas = 3
for rank in range(num_replicas):
    this_rank = idx_permuted[rank:total_size:num_replicas]
    print(f"rank={rank}, {len(this_rank)} elements: ", this_rank)



Entire list:  [0, 3, 7, 2, 1, 6, 4, 5, 8, 9]
rank=0, 4 elements:  [0, 2, 4, 9]
rank=1, 3 elements:  [3, 1, 5]
rank=2, 3 elements:  [7, 6, 8]


In [235]:
idx_permuted[None:total_size:None] == idx_permuted

True

In [57]:
# Samplers for single-shot dataset, using multiple replicas

class sampler_v3(Sampler):
    r"""Sample linear sequences, allows for batching and shuffling.

    Similar to SequentialSampler, but returns a batch of sequences in each iteration.
    """
    def __init__(self, num_elements, seq_length, batch_size, num_replicas=None, rank=None, shuffle=False, seed=0):
        self.num_elements = num_elements  # Length of the dataset
        self.seq_length = seq_length      # Length of the sequences to sample
        self.batch_size = batch_size      # Batch size
        self.shuffle = shuffle            # Shuffle the start of the sequences?
        self.seed = seed                  # Seed for shuffling
        self.epoch = 0                    # Increase this after each epoch to get different shuffling in next iteration
        
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval"
                f" [0, {num_replicas - 1}]")
        self.num_replicas = num_replicas
        self.rank = rank
        
        print(f"Sampler, rank={self.rank}")

            
    def set_epoch(self, epoch):
        """Update epoch to adjust random seed."""
        self.sepoch = epoch
        
    def __iter__(self):
        """Returns fixed-length, ordered sequences that cover the dataset."""
        idx_permuted = [(ix) for ix in range(self.num_elements - self.seq_length )]
        if self.shuffle:
            random.seed(self.seed + self.epoch)
            random.shuffle(idx_permuted)
            
        idx_permuted = idx_permuted[self.rank:self.num_elements:self.num_replicas]
        print(f"{self.rank}, {self.num_elements}, {self.num_replicas}: idx_permuted = {idx_permuted}")
                      
        # Slicing the list like this takes care of partial batches
        for start in range(0, len(idx_permuted), self.batch_size):
            yield [range(ix, ix + self.seq_length + 1) for ix in idx_permuted[start:start+self.batch_size]]

In [66]:
ix = 2
s3 = sampler_v3(num_elements[ix], seq_length, batch_size, num_replicas=1, rank=0, shuffle=True, seed=1337)
s3_0 = sampler_v3(num_elements[ix], seq_length, batch_size, num_replicas=2, rank=0, shuffle=True, seed=1337)
s3_1 = sampler_v3(num_elements[ix], seq_length, batch_size, num_replicas=2, rank=1, shuffle=True, seed=1337)

Sampler, rank=0
Sampler, rank=0
Sampler, rank=1


In [67]:
# Sampler for single rank
s3.set_epoch(1)
ctr = 0
for s in s3:
    if ctr < 5:
        print(s)
    ctr += 1

print(f"{ctr} samples sampled")


for ss in [s3_0, s3_1]:
    print("=================================================")
    ss.set_epoch(1)
    ctr = 0
    for s in ss:
        if ctr < 5:
            print(s)
        ctr += 1

    print(f"{ctr} samples sampled")

0, 11, 1: idx_permuted = [0, 5, 3, 1, 2, 6, 4]
[range(0, 5), range(5, 10), range(3, 8)]
[range(1, 6), range(2, 7), range(6, 11)]
[range(4, 9)]
3 samples sampled
0, 11, 2: idx_permuted = [0, 3, 2, 4]
[range(0, 5), range(3, 8), range(2, 7)]
[range(4, 9)]
2 samples sampled
1, 11, 2: idx_permuted = [5, 1, 6]
[range(5, 10), range(1, 6), range(6, 11)]
1 samples sampled


In [75]:
# Now we modify a bit. The goal is to have it shuffle in a deterministic way.
# Take inspiration from https://pytorch.org/docs/stable/_modules/torch/utils/data/distributed.html#DistributedSampler

class sampler_v4(Sampler):
    def __init__(self, num_elements, seq_length, batch_size, num_replicas=None, rank=None, shuffle=False, seed=0):
        self.num_elements = num_elements
        self.num_shots = len(num_elements)
        self.seq_length = seq_length
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0
        
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval"
                f" [0, {num_replicas - 1}]")
        self.num_replicas = num_replicas
        self.rank = rank
        
        print(f"Sampler, rank={self.rank}")
        
    def set_epoch(self, epoch):
        """Sets epoch for this sampler.
        When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.
        
        Args:
            epoch (int) : Epoch number
        """
        self.epoch = epoch
        
    def __iter__(self):
        idx_permuted = [(s, i)  for s in range(self.num_shots) for i in range(self.num_elements[s] - self.seq_length)]
        if self.shuffle:
            print("Shuffling")
            random.seed(self.seed + self.epoch)
            random.shuffle(idx_permuted)
        
        # Sub-sample for replicas.
        ll = len(idx_permuted)
        idx_permuted = idx_permuted[self.rank:ll:self.num_replicas]
        print(f"{self.rank}, {len(idx_permuted)}, {self.num_replicas}: idx_permuted = {idx_permuted}")
    
        
        full_batches = len(idx_permuted) // self.batch_size # Number of batches we can fill with the specified batch size
        # Check if the last batch is full or partial
        # We iterate up to num_batches. If in the loop the batch_counter == full_batches, we will have a partial patch
        if len(idx_permuted) != full_batches * self.batch_size:
            remaining_samples = len(idx_permuted) - full_batches * self.batch_size
            partial_batch = True
            num_batches = full_batches + 1
        else: 
            partial_batch = False
            num_batches = full_batches
     
        # Number of batches to draw. Round up.
        #num_batches = self.num_shots * (self.num_elements - self.seq_length) // self.batch_size
        for ix_b in range(0, num_batches):
            # If ix_x is full_batches (remember 0-based indexing and num_batches is excludede in range)
            # we have need to fill a partial batch with the remaining samples.
            if ix_b == full_batches:  
                selected = idx_permuted[-remaining_samples:]
            else:
                # Fill a full batch
                # Select starting points for sequences
                selected = idx_permuted[(ix_b * self.batch_size):((ix_b + 1) * self.batch_size)]
            # Remember to return a list. PyTorch dataloader passes each item in the
            # returned list to dataset.__getidx__. If we only return a single list,
            # each scalar index in that list would be passed to __getidx__.
            # If we return a list of lists, that inner list will be passed to __getidx__.
            # Then this list will be used for indexing.
            # Long story short: pass list of lists, not a single list.
            yield [(s[0], range(s[1], s[1] + self.seq_length + 1)) for s in selected]

Sampler, rank=0


In [77]:
ctr = 0
for s in s4:
    if ctr < 50:
        print(s)
    ctr += 1
    
print(f"{ctr} samples sampled")

0, 37, 1: idx_permuted = [(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), (0, 8), (0, 9), (0, 10), (0, 11), (0, 12), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11), (1, 12), (1, 13), (1, 14), (1, 15), (1, 16), (2, 0), (2, 1), (2, 2), (2, 3), (2, 4), (2, 5), (2, 6)]
[(0, range(0, 5)), (0, range(1, 6)), (0, range(2, 7))]
[(0, range(3, 8)), (0, range(4, 9)), (0, range(5, 10))]
[(0, range(6, 11)), (0, range(7, 12)), (0, range(8, 13))]
[(0, range(9, 14)), (0, range(10, 15)), (0, range(11, 16))]
[(0, range(12, 17)), (1, range(0, 5)), (1, range(1, 6))]
[(1, range(2, 7)), (1, range(3, 8)), (1, range(4, 9))]
[(1, range(5, 10)), (1, range(6, 11)), (1, range(7, 12))]
[(1, range(8, 13)), (1, range(9, 14)), (1, range(10, 15))]
[(1, range(11, 16)), (1, range(12, 17)), (1, range(13, 18))]
[(1, range(14, 19)), (1, range(15, 20)), (1, range(16, 21))]
[(2, range(0, 5)), (2, range(1, 6)), (2, range(2, 7))]
[(2, range(3, 8)), (2, range(4, 9

In [84]:
# In single tasking (num_replicas=1), the dist iterator should be the same as the single-process iterator
s4 = sampler_v4(num_elements, seq_length, batch_size, num_replicas=1, rank=0, shuffle=False, seed=1337)
s2 = sampler_v2(num_elements, seq_length, batch_size, shuffle=False, seed=1337)
for sa, sb in zip(s2, s4):
    assert(sa == sb)

Sampler, rank=0
self.shuffle = False, self.seed = 1337
0, 37, 1: idx_permuted = [(0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), (0, 8), (0, 9), (0, 10), (0, 11), (0, 12), (1, 0), (1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7), (1, 8), (1, 9), (1, 10), (1, 11), (1, 12), (1, 13), (1, 14), (1, 15), (1, 16), (2, 0), (2, 1), (2, 2), (2, 3), (2, 4), (2, 5), (2, 6)]


In [87]:
s4_0 = sampler_v4(num_elements, seq_length, batch_size, num_replicas=2, rank=0, shuffle=False, seed=1337)
s4_1 = sampler_v4(num_elements, seq_length, batch_size, num_replicas=2, rank=1, shuffle=False, seed=1337)

Sampler, rank=0
Sampler, rank=1


In [88]:
for ss in [s4_0, s4_1]:
    print("=================================================")
    ss.set_epoch(1)
    ctr = 0
    for s in ss:
        if ctr < 5:
            print(s)
        ctr += 1

    print(f"{ctr} samples sampled")

0, 19, 2: idx_permuted = [(0, 0), (0, 2), (0, 4), (0, 6), (0, 8), (0, 10), (0, 12), (1, 1), (1, 3), (1, 5), (1, 7), (1, 9), (1, 11), (1, 13), (1, 15), (2, 0), (2, 2), (2, 4), (2, 6)]
[(0, range(0, 5)), (0, range(2, 7)), (0, range(4, 9))]
[(0, range(6, 11)), (0, range(8, 13)), (0, range(10, 15))]
[(0, range(12, 17)), (1, range(1, 6)), (1, range(3, 8))]
[(1, range(5, 10)), (1, range(7, 12)), (1, range(9, 14))]
[(1, range(11, 16)), (1, range(13, 18)), (1, range(15, 20))]
7 samples sampled
1, 18, 2: idx_permuted = [(0, 1), (0, 3), (0, 5), (0, 7), (0, 9), (0, 11), (1, 0), (1, 2), (1, 4), (1, 6), (1, 8), (1, 10), (1, 12), (1, 14), (1, 16), (2, 1), (2, 3), (2, 5)]
[(0, range(1, 6)), (0, range(3, 8)), (0, range(5, 10))]
[(0, range(7, 12)), (0, range(9, 14)), (0, range(11, 16))]
[(1, range(0, 5)), (1, range(2, 7)), (1, range(4, 9))]
[(1, range(6, 11)), (1, range(8, 13)), (1, range(10, 15))]
[(1, range(12, 17)), (1, range(14, 19)), (1, range(16, 21))]
6 samples sampled
