In [None]:
# data_loader.ipynb

# --- 1. Import Packages ---
import pandas as pd
import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

from dataset_class import MessidorOpenCVDataset
from preprocess_class import OpenCV_DR_Preprocessor
from transforms import light_transform, heavy_transform, test_transform

# --- 2. Configuration and Preprocessor ---
# Values that used to be repeated inline are defined once here, so pointing the
# notebook at another copy of the data (or changing loader settings) is a
# one-line edit.
# NOTE: MESSIDOR_ROOT is a machine-specific absolute path; change it to the
# location of your local MESSIDOR copy before running.
MESSIDOR_ROOT = '/Users/abohane/Desktop/THEIA Datasets/MESSIDOR'
LABEL_COLUMN = 'Retinopathy grade'  # annotation column used to stratify the split
LOADER_KWARGS = dict(batch_size=32, num_workers=4, pin_memory=True)

preprocessor = OpenCV_DR_Preprocessor(apply_clahe=True)


def build_messidor_dataset(light, heavy, subset_df=None):
    """Construct a MessidorOpenCVDataset, optionally restricted to a subset of rows.

    Uses the module-level ``MESSIDOR_ROOT`` and ``preprocessor``.

    Args:
        light: Value passed as the dataset's ``light_transform``.
        heavy: Value passed as the dataset's ``heavy_transform``.
        subset_df: Optional annotation table. When given, it replaces the
            dataset's ``.data`` attribute after its index has been reset to
            0..n-1 (the old index is dropped).

    Returns:
        The constructed dataset instance.
    """
    dataset = MessidorOpenCVDataset(
        root_dir=MESSIDOR_ROOT,
        preprocessor=preprocessor,
        light_transform=light,
        heavy_transform=heavy,
        minority_classes=[3],  # a fresh list per instance, as in the inline calls
    )
    if subset_df is not None:
        # The dataset is built over the full annotations first and then
        # narrowed by overwriting `.data`.
        # NOTE(review): assumes the dataset derives its length/indexing from
        # `.data` on every access rather than caching at init -- confirm in
        # dataset_class.py.
        dataset.data = subset_df.reset_index(drop=True)
    return dataset


# --- 3. Load Full Dataset (to get dataframe) ---
# No transforms are attached: this instance is only used to obtain the
# annotation table for splitting.
full_dataset = build_messidor_dataset(light=None, heavy=None)

# --- 4. Train/Test Split ---
df = full_dataset.data  # full annotation table exposed by the dataset

# Stratified 80/20 split on the grade column; fixed seed for reproducibility.
train_df, test_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df[LABEL_COLUMN],
    random_state=42
)

print(f"Training set size: {len(train_df)}")
print(f"Test set size: {len(test_df)}")

# --- 5. Create Train and Test Datasets Separately ---
train_dataset = build_messidor_dataset(
    light=light_transform,
    heavy=heavy_transform,
    subset_df=train_df,  # only the training rows
)

# test_transform is passed for BOTH transform slots, so the test pipeline is
# the same whichever slot the dataset applies to a given sample.
test_dataset = build_messidor_dataset(
    light=test_transform,
    heavy=test_transform,
    subset_df=test_df,  # only the test rows
)

# --- 6. Create DataLoaders ---
# The two loaders differ only in shuffling: on for training, off for testing.
train_loader = DataLoader(train_dataset, shuffle=True, **LOADER_KWARGS)

test_loader = DataLoader(test_dataset, shuffle=False, **LOADER_KWARGS)

In [None]:
import torch

def unnormalize(img, mean, std):
    """Undo per-channel normalization so an image can be visualized.

    Computes ``img * std + mean`` with the statistics broadcast over the two
    trailing dimensions.

    Args:
        img: Tensor whose third-from-last dimension is the channel axis,
            e.g. ``(C, H, W)`` or ``(N, C, H, W)``.
        mean: Per-channel means (sequence or tensor), one entry per channel.
        std: Per-channel standard deviations, same length as ``mean``.

    Returns:
        A new tensor with the same shape as ``img``; ``img`` is not modified.
    """
    # as_tensor places the statistics on the image's device (avoiding a
    # CPU/accelerator mismatch) and does not warn when a tensor is passed in.
    # reshape(-1, 1, 1) adapts to any channel count instead of assuming 3.
    mean = torch.as_tensor(mean, device=img.device).reshape(-1, 1, 1)
    std = torch.as_tensor(std, device=img.device).reshape(-1, 1, 1)
    return img * std + mean

def plot_batch_with_labels(imgs, labels, mean, std, class_names=None):
    """Show a batch of normalized images in a 4-column grid with colored titles.

    Each image is unnormalized with ``unnormalize(img, mean, std)``, drawn in
    its own subplot with axes hidden, and titled with its label. Titles are
    green when the label equals 0 and red otherwise.

    Args:
        imgs: Batch of channel-first image tensors, iterable along the first
            dimension (e.g. a ``(N, 3, H, W)`` tensor).
        labels: Per-image labels; ``labels[i].item()`` must yield the label
            value for image ``i``.
        mean: Per-channel means forwarded to ``unnormalize``.
        std: Per-channel standard deviations forwarded to ``unnormalize``.
        class_names: Optional mapping (or sequence) from label value to
            display name. When omitted, or when a label is not present in it,
            the title falls back to ``"Class <label>"``.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    n_imgs = len(imgs)
    if n_imgs == 0:
        # torch.stack rejects an empty list and a zero-height figure is
        # invalid, so there is nothing sensible to draw for an empty batch.
        return

    unnormalized_imgs = torch.stack([unnormalize(img, mean, std) for img in imgs])
    # matplotlib needs host memory without autograd history, and float RGB
    # data in [0, 1]; clamping here yields the same picture imshow would
    # produce by clipping, without the per-image clipping warning.
    unnormalized_imgs = unnormalized_imgs.detach().cpu().clamp(0, 1)

    n_cols = 4
    n_rows = (n_imgs + n_cols - 1) // n_cols  # ceiling division

    plt.figure(figsize=(4 * n_cols, 4 * n_rows))

    for idx in range(n_imgs):
        img = unnormalized_imgs[idx]
        label = labels[idx].item()

        plt.subplot(n_rows, n_cols, idx + 1)
        # Channel-first (C, H, W) -> channel-last (H, W, C) for imshow.
        plt.imshow(img.permute(1, 2, 0).numpy())
        plt.axis('off')

        # Color coding: healthy = green, diseased = red
        color = 'green' if label == 0 else 'red'

        label_name = f"Class {label}"
        if class_names is not None:
            try:
                label_name = class_names[label]
            except (KeyError, IndexError):
                # Unknown label: keep the generic title rather than aborting
                # the whole plot.
                pass

        plt.title(label_name, color=color, fontsize=12)

    plt.tight_layout()
    plt.show()

# Per-channel statistics handed to the plotting helper so it can reverse the
# normalization before display.
# NOTE(review): these look like the widely used ImageNet statistics and are
# presumably the same values the imported transforms normalize with -- confirm
# against transforms.py.
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# Label value -> title text for each subplot (keys 0 through 3).
class_names = dict(enumerate(["No DR", "Mild DR", "Moderate DR", "Severe DR"]))

# Fetch the first batch yielded by the training loader and display it.
imgs, labels = next(iter(train_loader))
plot_batch_with_labels(imgs, labels, mean, std, class_names=class_names)

# unnormalized_imgs = torch.stack([unnormalize(img, mean, std) for img in imgs])

# grid_img = torchvision.utils.make_grid(unnormalized_imgs, nrow=4)
# plt.figure(figsize=(10, 10))
# plt.imshow(grid_img.permute(1, 2, 0))
# plt.title(f"Train Batch Labels: {labels.tolist()}")
# plt.axis('off')
# plt.show()