# Transforms & DataLoaders

## 🎯 Concept Primer
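Image preprocessing: resize and normalize every split. Apply random augmentation to the training split only, and keep validation/test preprocessing deterministic.

**Expected:** DataLoaders that yield image batches of shape `[B, 3, H, W]` (batch, channels, height, width)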
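To make the `[B, 3, H, W]` layout concrete: each sample leaves the transform pipeline as a channel-first `[3, H, W]` array, and batching stacks `B` of them along a new leading axis. The sketch below mimics that stacking with NumPy's `np.stack`. PyTorch's default collate function does the equivalent with `torch.stack` on tensors. The helper name `collate_images` is hypothetical and used only for illustration.

```python
import numpy as np

def collate_images(samples):
    """Stack a list of (image, label) pairs into (images, labels) arrays.

    Each image is assumed to already be channel-first: shape (3, H, W).
    """
    images = np.stack([img for img, _ in samples], axis=0)  # (B, 3, H, W)
    labels = np.array([label for _, label in samples])      # (B,)
    return images, labels

# Four fake 3-channel 224x224 samples with labels 0..3
samples = [(np.zeros((3, 224, 224), dtype=np.float32), g) for g in range(4)]
images, labels = collate_images(samples)
print(images.shape, labels.shape)  # (4, 3, 224, 224) (4,)
```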

## 📋 Objectives
1. Define transforms for train/val/test
2. Create a custom `Dataset` class
3. Set up DataLoaders
4. Verify batch shapes

## 🔧 Setup

```python
# TODO 1: Import libraries
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
```

## 🔄 Define Transforms

### TODO 2: Create transform pipelines

**Train:** Resize, augment (random horizontal flip; rotation can be added later), normalize  
**Val/Test:** Resize, normalize only (no randomness)

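```python
# TODO 2: Transforms
# ImageNet channel statistics, matching ImageNet-pretrained backbones
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Train: resize, random augmentation, then tensor conversion + normalization
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),  # flips with p=0.5 by default
    transforms.ToTensor(),              # PIL image -> float tensor [C, H, W] in [0, 1]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Val/test: same preprocessing, no random augmentation
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])
test_transform = val_transform  # test reuses the deterministic pipeline
```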
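What the last two steps compute, written out: `ToTensor()` divides 8-bit pixel values by 255 and moves channels first, and `Normalize` applies `(x - mean[c]) / std[c]` to each channel `c`. The NumPy sketch below reproduces that arithmetic for illustration only. It is not torchvision's implementation, and the function name `to_tensor_and_normalize` is hypothetical.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def to_tensor_and_normalize(img_hwc):
    """Mimic ToTensor() + Normalize() on an (H, W, 3) uint8 array."""
    x = img_hwc.astype(np.float32) / 255.0   # scale [0, 255] -> [0, 1]
    x = np.transpose(x, (2, 0, 1))           # HWC -> CHW
    # Broadcast the per-channel stats over the H and W axes
    return (x - IMAGENET_MEAN[:, None, None]) / IMAGENET_STD[:, None, None]

white = np.full((2, 2, 3), 255, dtype=np.uint8)
out = to_tensor_and_normalize(white)
print(out.shape)     # (3, 2, 2)
print(out[:, 0, 0])  # approx. [2.249, 2.429, 2.64] for a pure white pixel
```

Normalized values are centered near zero for "typical" ImageNet colors, which is the input range the pretrained weights were trained on.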

## 📦 Custom Dataset

### TODO 3: Create Dataset class

**Expected:** `__getitem__` returns `(image, label)`

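```python
# TODO 3: Dataset class
class RetinalDataset(Dataset):
    """Map-style dataset: one row of `df` per fundus image."""

    def __init__(self, df, img_dir, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]  # positional lookup, independent of the DataFrame index
        # The CSV stores names without an extension; files are assumed to be .jpg
        img_path = os.path.join(self.img_dir, f"{row['Image name']}.jpg")

        # Force 3 channels so every sample becomes [3, H, W] after ToTensor
        image = Image.open(img_path).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)

        label = int(row['Retinopathy grade'])  # plain int; default collate makes it a tensor
        return image, label
```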
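Because `__getitem__` only opens a file when that index is requested, a missing image surfaces as a `FileNotFoundError` partway through an epoch. A cheap up-front check is sketched below. It assumes, like the class above, that each name maps to `<img_dir>/<name>.jpg`; the helper `find_missing_images` is hypothetical and not part of PyTorch.

```python
import os

def find_missing_images(image_names, img_dir, ext=".jpg"):
    """Return the names whose <img_dir>/<name><ext> file does not exist."""
    return [
        name for name in image_names
        if not os.path.isfile(os.path.join(img_dir, f"{name}{ext}"))
    ]
```

In this notebook it could be called as `find_missing_images(train_labels_df['Image name'], train_images_folder)` before building the dataset; an empty list means every CSV row has a matching file.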

## 🔄 DataLoaders

### TODO 4: Create DataLoaders

**Expected:** Batch size=32, shape [B, 3, 224, 224]

```python
# TODO 4: DataLoaders
train_labels_df = pd.read_csv('../../../datasets/diabetic_retinopathy_images/groundtruths/training_labels.csv')
train_images_folder = '../../../datasets/diabetic_retinopathy_images/images/training_images_small'

train_dataset = RetinalDataset(train_labels_df, train_images_folder, train_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)  # reshuffled every epoch

# One full pass over the loader, printing the shape of every batch
for images, labels in train_loader:
    print(f"Batch shape: {images.shape}")
```

```
Batch shape: torch.Size([32, 3, 224, 224])
Batch shape: torch.Size([32, 3, 224, 224])
Batch shape: torch.Size([32, 3, 224, 224])
Batch shape: torch.Size([32, 3, 224, 224])
Batch shape: torch.Size([32, 3, 224, 224])
Batch shape: torch.Size([32, 3, 224, 224])
Batch shape: torch.Size([32, 3, 224, 224])
Batch shape: torch.Size([32, 3, 224, 224])
Batch shape: torch.Size([32, 3, 224, 224])
Batch shape: torch.Size([32, 3, 224, 224])
Batch shape: torch.Size([32, 3, 224, 224])
Batch shape: torch.Size([32, 3, 224, 224])
Batch shape: torch.Size([29, 3, 224, 224])
```
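The 13 batches printed above follow from simple arithmetic. With `N` samples and batch size `B`, the loader yields `ceil(N / B)` batches, and the last one holds `N - (ceil(N / B) - 1) * B` samples, unless `drop_last=True` discards that partial batch. A small standard-library sketch (the function `batch_sizes` is hypothetical) reproduces the pattern for N = 413, B = 32:

```python
import math

def batch_sizes(n_samples, batch_size, drop_last=False):
    """List the size of each batch a DataLoader-style loop would produce."""
    if drop_last:
        # Only full batches survive
        return [batch_size] * (n_samples // batch_size)
    n_batches = math.ceil(n_samples / batch_size)
    return [min(batch_size, n_samples - i * batch_size) for i in range(n_batches)]

sizes = batch_sizes(413, 32)
print(len(sizes), sizes[-1])                      # 13 29
print(len(batch_sizes(413, 32, drop_last=True)))  # 12
```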


## 🤔 Reflection
1. Batch shape correct?
2. Augmentation choices?

**Reflection**

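1. **Batch shape correct?**
   - Yes. Each batch comes out as `[batch_size, channels, height, width] = [32, 3, 224, 224]`.
   - The final batch is `[29, 3, 224, 224]` because the dataset's 413 images are not divisible by 32 (12 × 32 = 384, leaving 29). This is normal behavior; pass `drop_last=True` to `DataLoader` if every batch must be full.

2. **Augmentation choices?**
   - ✅ `Resize((224, 224))`: matches the standard 224×224 input resolution that ImageNet-pretrained ResNets (and EfficientNet-B0) were trained at. A fixed square size does distort the aspect ratio of non-square images.
   - ✅ `RandomHorizontalFlip()`: mirrors each image left/right with probability 0.5. Left and right eyes are roughly mirror images, so a flipped fundus still looks like a plausible retina, and the flip adds data diversity without new labels.
   - 🔁 Later (on the train split only) we can add `RandomRotation`, `ColorJitter`, etc., to make the model more robust to variation in imaging conditions.
   - ✅ `ToTensor()` and `Normalize(mean, std)`: `ToTensor()` converts the PIL image to a float tensor of shape `[3, H, W]` with values in [0, 1]; `Normalize` then standardizes each channel with the ImageNet mean/std that the pretrained weights expect.
   - ⚠️ Validation/test should use the **same preprocessing minus the random augmentation**, so that evaluation is deterministic and reflects the images the model will actually see. (This is already prepared as `val_transform`; we'll use it once the splits are in place.)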
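To make the flip concrete on a channel-first `[3, H, W]` array: mirroring left/right reverses the width axis, and a random flip applies it with probability `p` (0.5 by default in `RandomHorizontalFlip`), otherwise returning the image unchanged. The NumPy sketch below illustrates that behavior only; `random_hflip` is a hypothetical name, not torchvision's code.

```python
import numpy as np

def random_hflip(img_chw, p=0.5, rng=None):
    """Flip a (C, H, W) array along its width axis with probability p."""
    rng = np.random.default_rng() if rng is None else rng
    if rng.random() < p:
        return img_chw[:, :, ::-1]  # reverse the columns -> mirror image
    return img_chw                  # unchanged with probability 1 - p

img = np.arange(12).reshape(1, 2, 6)
print(random_hflip(img, p=1.0)[0, 0])  # [5 4 3 2 1 0]
```

Because the flip is random, applying it to validation/test data would make repeated evaluations disagree, which is why it lives only in `train_transform`.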

**Your reflection:**

*Write here*

## 📌 Summary
✅ Transforms defined  
✅ DataLoaders ready

**Next:** `03_simple_cnn_scaffold.ipynb`