Basically a notebook to test code interactively before writing it in full detail in `src/`.

In [52]:
import numpy as np
import torch
import torchvision
import random
from torch.utils.data import Dataset, DataLoader, Sampler, BatchSampler


random.seed(226)

### Classwise Uniform Sampler

The goal is to write the sampling strategy of randomly pick among the K classes, and randomly pick P examples from these K classes in Pytorch.

In [53]:
class MyDataset(Dataset):
    def __init__(self, xs,ys):
        self.xs = xs
        self.ys = ys
        
        assert len(self.xs) == len(self.ys)
        
    def __len__(self):
        return len(self.xs)
    
    def __getitem__(self, i):
        return self.xs[i], self.ys[i]

In [54]:
mydataset = MyDataset(list(range(101,126)), [0,0,0,0,0, 1,1,1,1,1,2,2,2,2,2,3,3,3, 3,3, 4,4, 4, 4, 4])

In [62]:
start_indices = {x : 5 * x for x in range(5)} #index of first entry of that class
num_examples = {x : 5 for x in range(5)} #number of examples for each label

num_classes = 5

We want to handle a K * P batch. What about test time? Train/validation ? Does validation have the sample sample strategy? It should.

In [56]:
K = 3
P = 2

In [57]:
random.choices(list(range(5)), k=6)

[2, 4, 2, 2, 0, 1]

In [58]:
from collections import defaultdict
from loguru import logger

log = logger.info

In [69]:
class ClassUniformBatchSampler(Sampler):
    """Returns the classes in order, but with the entries of each class shuffled."""
    def __init__(self, dataset, P, K, num_classes=num_classes, 
                 start_indices=start_indices, num_examples=num_examples):
        """Returns batches of size P x K where P is the number of classes per batch and K is the number of samples per class."""
        self.num_classes = num_classes
        self.start_indices = start_indices
        self.num_examples = num_examples 
        self.dataset = dataset
        self.P = P
        self.K = K
        #log(f"P = {P}, K = {K}, len(dataset) = {len(dataset)}")
        
        
    def __iter__(self):
        """Provides a valid permutation of the indices for the dataset."""
        
        #Shuffle indices of each class in place.
        classwise_shuffled_indices = []
        k_at_time = defaultdict(list) #for each class we want the value to be [[k samples], [k samples], ...]
        #log(f"Before for loop in __iter__, num_classes = {num_classes}")
        for i in range(num_classes):
            class_batches = []
            start_idx = self.start_indices[i]
            stop_idx = start_idx + self.num_examples[i]
            #log(f"For class {i}, start_idx = {start_idx} and stop_idx = {stop_idx}")
            #want a random shuffle of the dataset from start_idx -> start_idx + num_examples
            shuffled_examples = random.sample(range(start_idx, stop_idx), stop_idx - start_idx)
            #log(f"For class {i}, shuffled_examples = {shuffled_examples}")
            split_batches = list(torch.split(torch.tensor(shuffled_examples), self.K))
            #log(f"For class {i}, split_batches start = {split_batches}")
            #last_batch = n % self.K
            
            #we want last batch to also have size K using random.choices
            split_batches[-1] = torch.tensor(random.choices(shuffled_examples[-(len(shuffled_examples) % self.K) :], k=self.K))
             #we want to pop() elements so the last batch should be at index 0
            #log(f"For class {i}, split_batches after resizing last element = {split_batches}")
            split_batches = split_batches[::-1]
            k_at_time[i] = [x.tolist() for x in split_batches] #each is a list of k indices.
            #log(f"For class {i}, k_at_time[i] = {k_at_time[i]}")
            classwise_shuffled_indices += shuffled_examples
            
        log(f"k_at_a_time = {k_at_time}")
        
        #We want to extract P classes at a time randomly and pop the k_at_a_time[idx] for each
        #If any class becomes empty, we want to remove it from our set of alive classes.
        #When we have fewer than K classes remaining, we just return what remains.
        
        alive_classes = set(range(num_classes))
        
        while len(alive_classes) >= self.P:
            #sample P classes
            #log(f"In while loop, alive_classes = {alive_classes}")
            selected_indices = random.sample(list(alive_classes), k=self.P)    
            x_batch  = []
            
            
            for idx in selected_indices:
                class_batch = k_at_time[idx].pop()
                if not k_at_time[idx]:
                    alive_classes.remove(idx)
                x_batch += class_batch
                #y_batch += [idx] * len(class_batch)
                
            
            random.shuffle(x_batch)
            #log(f"In while loop, x_batch = {x_batch}")
            yield x_batch
        
        #if there are fewer than P classes remaining, we will just cycle through each class and add a set of K elements in order
        #until we hit PxK elements. If we are done before we hit PxK elements, we ignore the batch. 
        log(f"alive_classes = {alive_classes}")
        curr_batch = []
        num_items_in_batch = 0
        while len(alive_classes) >= 2: #if there is only one alive class, no point having a triplet loss
            for idx in list(alive_classes):
                class_batch = k_at_time[idx].pop()
                if not k_at_time[idx]:
                    alive_classes.remove(idx)
                
                curr_batch += class_batch
                num_items_in_batch += self.K
            
                if num_items_in_batch == self.P * self.K:
                    random.shuffle(curr_batch)
                    yield curr_batch
                    curr_batch = []
                    num_items_in_batch = 0
                    
        #there are fewer than PxK elements remaining    
        
        return
    
    def __len__(self):
        return len(self.dataset)
    

In [70]:
mysampler = ClassUniformBatchSampler(dataset=mydataset, P=2, K=2) #each batch has 4 elements from 2 classes

In [71]:
i = 0
for i, indices in enumerate(mysampler):
    print(f"In batch {i}: \n indices = {indices} \n elements = {[mydataset[i] for i in indices]} \n\n")
    
print(f" i = {i}")

2022-11-18 13:46:35.394 | INFO     | __main__:__iter__:43 - k_at_a_time = defaultdict(<class 'list'>, {0: [[0, 0], [3, 2], [1, 4]], 1: [[8, 8], [7, 9], [6, 5]], 2: [[13, 13], [11, 10], [14, 12]], 3: [[17, 17], [18, 15], [16, 19]], 4: [[21, 21], [22, 24], [23, 20]]})
2022-11-18 13:46:35.394 | INFO     | __main__:__iter__:72 - alive_classes = {1}


In batch 0: 
 indices = [4, 1, 12, 14] 
 elements = [(105, 0), (102, 0), (113, 2), (115, 2)] 


In batch 1: 
 indices = [11, 23, 20, 10] 
 elements = [(112, 2), (124, 4), (121, 4), (111, 2)] 


In batch 2: 
 indices = [19, 16, 6, 5] 
 elements = [(120, 3), (117, 3), (107, 1), (106, 1)] 


In batch 3: 
 indices = [13, 3, 13, 2] 
 elements = [(114, 2), (104, 0), (114, 2), (103, 0)] 


In batch 4: 
 indices = [22, 0, 0, 24] 
 elements = [(123, 4), (101, 0), (101, 0), (125, 4)] 


In batch 5: 
 indices = [15, 21, 21, 18] 
 elements = [(116, 3), (122, 4), (122, 4), (119, 3)] 


In batch 6: 
 indices = [9, 7, 17, 17] 
 elements = [(110, 1), (108, 1), (118, 3), (118, 3)] 


 i = 6


Testing the sampling strategy for imbalanced classes.

In [72]:
newdataset = MyDataset(list(range(101,126)), [0,0,0,0,0,0, 1,1,1,1,1,1,2,2,2,2,3,3,3, 3,3,3, 4,4, 4, ])
new_start_indices = {0 : 0, 1: 6, 2: 12, 3:16, 4:22} #index of first entry of that class
new_num_examples = {0:6, 1:6, 2:4, 3:6, 4:3} #number of examples for each label

num_classes = 5
newsampler = ClassUniformBatchSampler(dataset=newdataset, P=2, K=2,
                                      num_examples=new_num_examples,
                                      start_indices=new_start_indices) #each batch has 4 elements from 2 classes

In [73]:
i = 0
for i, indices in enumerate(newsampler):
    print(f"In batch {i}: \n indices = {indices} \n elements = {[newdataset[i] for i in indices]} \n\n")
    
print(f" i = {i}")

2022-11-18 13:46:35.462 | INFO     | __main__:__iter__:43 - k_at_a_time = defaultdict(<class 'list'>, {0: [[4, 3], [4, 2], [1, 5]], 1: [[8, 6], [9, 7], [11, 6]], 2: [[12, 15], [14, 12]], 3: [[21, 16], [21, 20], [18, 17]], 4: [[22, 22], [24, 23]]})
2022-11-18 13:46:35.462 | INFO     | __main__:__iter__:72 - alive_classes = {0}


In batch 0: 
 indices = [6, 17, 18, 11] 
 elements = [(107, 1), (118, 3), (119, 3), (112, 1)] 


In batch 1: 
 indices = [21, 12, 14, 20] 
 elements = [(122, 3), (113, 2), (115, 2), (121, 3)] 


In batch 2: 
 indices = [21, 7, 9, 16] 
 elements = [(122, 3), (108, 1), (110, 1), (117, 3)] 


In batch 3: 
 indices = [15, 24, 23, 12] 
 elements = [(116, 2), (125, 4), (124, 4), (113, 2)] 


In batch 4: 
 indices = [6, 1, 8, 5] 
 elements = [(107, 1), (102, 0), (109, 1), (106, 0)] 


In batch 5: 
 indices = [22, 22, 2, 4] 
 elements = [(123, 4), (123, 4), (103, 0), (105, 0)] 


 i = 5
