In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from pathlib import Path

In [2]:
# Side length (in pixels) every image is resized to before batching.
IMG_SIZE = 224

# Per-channel (R, G, B) statistics handed to Normalize. Kept in one place so the
# training and evaluation pipelines can never drift apart.
# NOTE(review): these match the values torchvision documents for its
# ImageNet-pretrained models — confirm the downstream model expects them.
NORM_MEAN = (0.485, 0.456, 0.406)
NORM_STD = (0.229, 0.224, 0.225)

# Training pipeline: resize, then random photometric/geometric augmentation,
# then tensor conversion and normalisation.
train_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ColorJitter(brightness=0.2,
                           contrast=0.2,
                           saturation=0.2,
                           hue=0.1),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Validation/test pipeline: deterministic — same resize and normalisation as
# training, but no random augmentation.
val_test_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

In [None]:
root_dir   = Path("dataset")
batch_size = 64
num_workers = 4
# Pinned (page-locked) host memory only speeds up host-to-GPU copies; without a
# CUDA device it brings no benefit and DataLoader emits a warning, so enable it
# only when CUDA is actually available.
pin_memory = torch.cuda.is_available()

# Datasets: ImageFolder maps each sub-directory of the split to one class.
train_ds = datasets.ImageFolder(root_dir / "train_extracted",
                                transform=train_transforms)
val_ds   = datasets.ImageFolder(root_dir / "val_extracted",
                                transform=val_test_transforms)
# test_ds  = datasets.ImageFolder(root_dir / "test_extracted",
#                                 transform=val_test_transforms)

# Each ImageFolder derives its own class -> index mapping from the folders it
# finds. If the splits disagree (missing/extra class folder), validation labels
# would silently refer to different classes than the training labels.
if val_ds.class_to_idx != train_ds.class_to_idx:
    mismatched = sorted(set(train_ds.class_to_idx) ^ set(val_ds.class_to_idx))
    raise ValueError(
        "train and val splits have different class-to-index mappings; "
        f"class folders present in only one split: {mismatched}"
    )

# DataLoaders
# Training: reshuffled every epoch; the last incomplete batch is dropped so
# every training batch has exactly `batch_size` samples.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                      num_workers=num_workers, pin_memory=pin_memory,
                      drop_last=True)
# Validation: fixed order, every sample kept.
val_dl   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False,
                      num_workers=num_workers, pin_memory=pin_memory)
# test_dl  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False,
#                       num_workers=num_workers, pin_memory=pin_memory)

# Quick sanity check
# Invert class_to_idx so predicted indices can be turned back into folder names.
idx_to_class = {v: k for k, v in train_ds.class_to_idx.items()}
print(f"{len(idx_to_class)} classes detected:", idx_to_class)

# Pull a single training batch to confirm the pipeline runs end to end.
imgs, labels = next(iter(train_dl))
print("Batch tensor shape:", imgs.shape)
print("Labels shape:", labels.shape)

📚  20 classes detected: {0: '00175_Animalia_Arthropoda_Insecta_Blattodea_Blaberidae_Aptera_fusca', 1: '00176_Animalia_Arthropoda_Insecta_Blattodea_Blaberidae_Panchlora_nivea', 2: '00177_Animalia_Arthropoda_Insecta_Blattodea_Blaberidae_Pycnoscelus_surinamensis', 3: '00178_Animalia_Arthropoda_Insecta_Blattodea_Blattidae_Blatta_orientalis', 4: '00179_Animalia_Arthropoda_Insecta_Blattodea_Blattidae_Periplaneta_americana', 5: '00180_Animalia_Arthropoda_Insecta_Blattodea_Blattidae_Periplaneta_australasiae', 6: '00181_Animalia_Arthropoda_Insecta_Blattodea_Blattidae_Periplaneta_fuliginosa', 7: '00182_Animalia_Arthropoda_Insecta_Blattodea_Ectobiidae_Pseudomops_septentrionalis', 8: '00443_Animalia_Arthropoda_Insecta_Diptera_Culicidae_Aedes_aegypti', 9: '00444_Animalia_Arthropoda_Insecta_Diptera_Culicidae_Aedes_albopictus', 10: '00445_Animalia_Arthropoda_Insecta_Diptera_Culicidae_Aedes_vexans', 11: '00446_Animalia_Arthropoda_Insecta_Diptera_Culicidae_Culex_quinquefasciatus', 12: '00447_Animalia_A