In [1]:
from datasets import load_dataset
from torchvision import transforms
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms


# Load the dataset once and keep a handle on each split.
ds = load_dataset("slegroux/tiny-imagenet-200-clean")
train_dataset, val_dataset = ds["train"], ds["validation"]

# Training pipeline: random crop / flip / colour jitter for augmentation,
# then tensor conversion and per-channel normalization.
# NOTE(review): the mean/std values look like the usual ImageNet channel
# statistics — confirm they are the intended ones for this dataset.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Validation pipeline: deterministic resize + centre crop (no augmentation),
# followed by the same tensor conversion and normalization as training.
val_transforms = transforms.Compose([
    transforms.Resize(size=256),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def apply_train_transforms(examples):
    """Run every image of a batch through the training transform pipeline.

    Args:
        examples: batch mapping with an ``"image"`` entry that can be iterated
            item by item.

    Returns:
        The same mapping (modified in place) whose ``"image"`` entry is now a
        list holding the output of ``train_transforms`` for each input image.
    """
    augmented = []
    for raw_image in examples["image"]:
        augmented.append(train_transforms(raw_image))
    examples["image"] = augmented
    return examples

def apply_val_transforms(examples):
    """Run every image of a batch through the validation transform pipeline.

    Args:
        examples: batch mapping with an ``"image"`` entry that can be iterated
            item by item.

    Returns:
        The same mapping (modified in place) whose ``"image"`` entry is now a
        list holding the output of ``val_transforms`` for each input image.
    """
    examples["image"] = list(map(val_transforms, examples["image"]))
    return examples

# Register one transform function per split: the training split gets the
# augmenting pipeline, the validation split the deterministic one.
# The transform is applied when an item is indexed
train_dataset.set_transform(apply_train_transforms)
val_dataset.set_transform(apply_val_transforms)


In [28]:
# Batch size shared by the training and validation DataLoaders below.
BATCH_SIZE = 64

# Create DataLoaders
# A custom collate_fn is used so each batch can be unpacked as `images, labels`:
# every dataset item is a dict, and default collation would yield a dict batch instead.
def collate_fn(batch):
    """Collate a list of dataset items into an ``(images, labels)`` pair.

    Args:
        batch: list of items, each a mapping with an ``'image'`` tensor and a
            ``'label'`` value.

    Returns:
        A 2-tuple ``(images, labels)``: the item images stacked along a new
        leading dimension, and the item labels gathered into one tensor.
    """
    images = torch.stack([item['image'] for item in batch])
    labels = torch.tensor([item['label'] for item in batch])
    # Return an ordered tuple, not a set literal: a set has no defined
    # iteration order, so unpacking it could swap images and labels.
    return images, labels

# Both loaders share the same batching, collation and memory-pinning settings.
_loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=collate_fn, pin_memory=True)

train_dataloader = DataLoader(train_dataset, **_loader_kwargs)
val_dataloader = DataLoader(val_dataset, **_loader_kwargs)

# Smoke test: fetch a single training batch and report the tensor shapes.
for images, labels in train_dataloader:
    print(f"Batch images shape: {images.shape}")
    print(f"Batch labels shape: {labels.shape}")
    break


Batch images shape: torch.Size([64, 3, 224, 224])
Batch labels shape: torch.Size([64])


In [None]:
# Inspect the 'label' column of the dataset wrapped by the training loader.
train_dataloader.dataset['label']

98179

In [29]:
from torch.utils.data import random_split, TensorDataset, DataLoader, Subset
# Size of the class-balanced subset and the number of classes it must cover.
fim_size = 8000
num_classes = 200
seed = 42
assert fim_size % num_classes == 0, \
    f"fim_size ({fim_size}) must be divisible by num_classes ({num_classes})"

# Number of samples drawn from every class.
per_class = fim_size // num_classes

# Label column of the dataset behind the training loader, as a tensor.
targets = torch.tensor(train_dataloader.dataset['label'])
# Dedicated generator so the selection is reproducible for a given seed.
g = torch.Generator().manual_seed(seed)

indices_per_class = []
for c in range(num_classes):
    class_idx = torch.nonzero(targets == c).view(-1)  # indices of samples of class c
    # Fail loudly instead of silently producing an unbalanced subset.
    if len(class_idx) < per_class:
        raise ValueError(
            f"Class {c} has only {len(class_idx)} samples, need {per_class}"
        )
    # Shuffle the indices of this class, then keep the first per_class of them.
    perm = class_idx[torch.randperm(len(class_idx), generator=g)]
    indices_per_class.append(perm[:per_class])

# Concatenate all class indices and shuffle globally.
balanced_indices = torch.cat(indices_per_class)
balanced_indices = balanced_indices[torch.randperm(len(balanced_indices), generator=g)]

fim_subset = Subset(train_dataloader.dataset, balanced_indices.tolist())

# Use the same collate_fn as the other loaders in this notebook so batches
# have the same structure everywhere.
fim_loader = DataLoader(
    fim_subset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
)

KeyboardInterrupt: 

In [30]:
import torch
from torch.utils.data import Subset, DataLoader

def balanced_subset_from_labels(dataset, labels, num_classes, fim_size, seed=42):
    """Build a class-balanced random ``Subset`` of ``dataset``.

    Exactly ``fim_size // num_classes`` samples are drawn without replacement
    from each class ``0 .. num_classes - 1``; the combined selection is then
    shuffled. All randomness comes from a generator seeded with ``seed``.

    Args:
        dataset: dataset to wrap; only passed through to ``Subset``.
        labels: one integer class label per dataset index (anything
            ``torch.as_tensor`` accepts).
        num_classes: number of classes to sample from.
        fim_size: total subset size; must be a multiple of ``num_classes``.
        seed: seed for the sampling generator.

    Returns:
        A ``torch.utils.data.Subset`` over the selected indices.

    Raises:
        AssertionError: if ``fim_size`` is not divisible by ``num_classes``.
        ValueError: if any class in ``0 .. num_classes - 1`` has fewer than
            ``fim_size // num_classes`` samples.
    """
    assert fim_size % num_classes == 0, \
        f"fim_size ({fim_size}) must be divisible by num_classes ({num_classes})"
    per_class = fim_size // num_classes

    targets = torch.as_tensor(labels, dtype=torch.long)
    g = torch.Generator().manual_seed(seed)

    # 1) group indices by class with ONE sort; a stable sort keeps the
    #    within-class order fixed, so the same seed always selects the same samples
    order = torch.sort(targets, stable=True).indices

    # 2) find class slice boundaries
    counts = torch.bincount(targets, minlength=num_classes)
    # Only the requested classes must be large enough; labels >= num_classes
    # (extra bincount entries) are ignored rather than reported as missing.
    too_small = counts[:num_classes] < per_class
    if too_small.any():
        bad = torch.nonzero(too_small).view(-1).tolist()
        raise ValueError(f"Not enough samples for classes: {bad}")

    starts = torch.cumsum(counts, dim=0) - counts  # start offset for each class in `order`

    # 3) sample per_class indices from each class slice (fast: slices are small)
    picked = []
    for c in range(num_classes):
        start = starts[c].item()
        cnt = counts[c].item()
        class_indices = order[start:start + cnt]   # indices in original dataset for class c

        perm = torch.randperm(cnt, generator=g)[:per_class]
        picked.append(class_indices[perm])

    # Shuffle across classes so the subset is not ordered by label.
    balanced_indices = torch.cat(picked)
    balanced_indices = balanced_indices[torch.randperm(len(balanced_indices), generator=g)]

    return Subset(dataset, balanced_indices.tolist())


# --- usage: draw the balanced subset from the training data ---
fim_size, num_classes, seed = 8000, 200, 42

labels = train_dataloader.dataset['label']  # or dataset["label"]
fim_subset = balanced_subset_from_labels(
    dataset=train_dataloader.dataset,
    labels=labels,
    num_classes=num_classes,
    fim_size=fim_size,
    seed=seed,
)

# One sample per batch, visited in random order, collated like the other loaders.
fim_loader = DataLoader(
    fim_subset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
)


KeyboardInterrupt: 

In [21]:
# Peek at the first item of the balanced subset to check its structure.
fim_subset[0]

{'image': tensor([[[-2.1179, -2.1179, -2.1179,  ..., -1.5357, -1.4843, -1.4843],
          [-2.1179, -2.1179, -2.1179,  ..., -1.5357, -1.4843, -1.4843],
          [-2.1179, -2.1179, -2.1179,  ..., -1.4843, -1.4672, -1.4672],
          ...,
          [-2.1008, -2.1008, -2.1008,  ..., -1.5185, -1.5185, -1.5185],
          [-2.1008, -2.1008, -2.1008,  ..., -1.2445, -1.2445, -1.2445],
          [-2.1008, -2.1008, -2.1008,  ..., -1.2445, -1.2445, -1.2445]],
 
         [[-2.0357, -2.0357, -2.0357,  ..., -0.8978, -0.8803, -0.8803],
          [-2.0357, -2.0357, -2.0357,  ..., -0.8978, -0.8803, -0.8803],
          [-2.0357, -2.0357, -2.0357,  ..., -0.8277, -0.8277, -0.8277],
          ...,
          [-1.7731, -1.7731, -1.7731,  ..., -1.1779, -1.1779, -1.1779],
          [-1.8256, -1.8256, -1.8256,  ..., -0.8978, -0.8978, -0.8978],
          [-1.8256, -1.8256, -1.8256,  ..., -0.8978, -0.8978, -0.8978]],
 
         [[-1.8044, -1.8044, -1.8044,  ..., -1.3687, -1.3513, -1.3513],
          [-1.8044,

In [None]:
# NOTE(review): this cell indexes each batch by 'pixel_values' / 'labels', but
# the collate_fn defined in this notebook does not build a dict with those
# keys — confirm which batch format is intended before relying on this cell.
for x in train_dataloader:
    print(x['pixel_values'].shape, x['labels'].shape)
    break

NameError: name 'train_loader' is not defined

In [26]:
# Unpack the first batch into two values and print their shapes.
# x and y remain in the kernel namespace afterwards; a later cell displays y.
for x, y in train_dataloader:
    print(x.shape, y.shape)
    break

AttributeError: 'str' object has no attribute 'shape'

In [12]:
# Display whatever y currently holds in the kernel (presumably left over from a batch-unpacking loop — verify).
y

'label'