In [5]:
import torchvision

from torch.utils.data import DataLoader

# Import of our CategoriesSampler module
from src.datasets.sampler import CategoriesSampler
from src.utils import warp_tqdm

In [6]:
# Fetch the FashionMNIST test split (train=False); files are downloaded into
# the local "exemple_sampler_data" folder when missing, and every image is
# converted with torchvision's ToTensor transform.
test_set = torchvision.datasets.FashionMNIST(
    root="exemple_sampler_data",
    transform=torchvision.transforms.ToTensor(),
    train=False,
    download=True,
)

# Example of initializing our custom CategoriesSampler object and loading batches using PyTorch's DataLoader

In [7]:
### Build the CategoriesSampler object ###
#
# CategoriesSampler arguments:
#     label    : all labels of the dataset
#     n_batch  : number of batches to load
#     n_cls    : number of classification ways (n_ways)
#     s_shot   : support shots
#     q_shot   : query shots (number of query shots per class we would have
#                in the standard balanced setting, e.g. 15)
#     balanced : 'balanced'  -> balanced query class distribution
#                               (standard class-balanced few-shot setting)
#                'dirichlet' -> Dirichlet distribution over the query data
#                               (realistic class-imbalanced few-shot setting)
#     alpha    : concentration parameter of the Dirichlet distribution
#
# Result:
#     sampler  : CategoriesSampler object that yields a batch when iterated
sampler = CategoriesSampler(
    label=test_set.targets,
    n_batch=1,
    n_cls=5,
    s_shot=5,
    q_shot=15,
    balanced='dirichlet',
    alpha=2,
)

# Wrap the dataset in a DataLoader whose batches are chosen by the sampler
test_loader = DataLoader(
    test_set,
    batch_sampler=sampler,
    num_workers=8,
    pin_memory=True,
)

# Walk over the loader to retrieve each batch.
# NOTE(review): the meaning of the second warp_tqdm argument (False) is not
# visible here -- confirm against src.utils.
for i, (data, labels) in enumerate(warp_tqdm(test_loader, False)):
    # x : torch.tensor [n_support + n_query, channel, H, W]
    #     laid out as [support_data, query_data]
    #
    # y : torch.tensor [n_support + n_query]
    #     laid out as [support_labels, query_labels]
    #
    # Class ordering inside a batch:
    #     support data and labels cycle through the classes:
    #         [a b c d e a b c d e a b c d e ...]
    #     query data and labels are grouped class by class:
    #         [a a a a a a a a b b b b b b b c c c c c d d d d d e e e e e ...]
    x, y = data, labels





  0%|                                                     | 0/1 [00:00<?, ?it/s][A[A[A[A



100%|█████████████████████████████████████████████| 1/1 [00:00<00:00,  5.44it/s][A[A[A[A
