# Import libraries

In [None]:
# --- Setup & imports ---
import pickle
import numpy as np
import random
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, models
from PIL import Image

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Load Data

In [None]:
# Reproducibility: seed the Python, NumPy, and PyTorch RNGs
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Integer seed handed to every generator.

    The CUDA generators are seeded as well, but only when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Seed every RNG before anything random happens in the notebook.
set_seed(42)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# --- Load data ---
# NOTE(review): pickle.load can execute arbitrary code while deserialising;
# only point this at data files from a trusted source.
data_dir = Path("../data")

with (data_dir / "train_data.pkl").open("rb") as f:
    train_data = pickle.load(f)

with (data_dir / "test_data.pkl").open("rb") as f:
    test_data = pickle.load(f)

# Shapes in the trailing comments are the ones noted by the original author.
train_images = train_data["images"]  # (1080, 28, 28, 3)
test_images = test_data["images"]  # (400, 28, 28, 3)
# Flatten the label array to 1-D and cast to int -> (1080,)
train_labels = train_data["labels"].reshape(-1).astype(int)

print("Train images:", train_images.shape)
print(
    "Train labels:",
    train_labels.shape,
    "unique:",
    np.unique(train_labels, return_counts=True),
)
print("Test images:", test_images.shape)


# Stratified Split Train/Validation

In [None]:
def stratified_train_val_split(y, val_ratio=0.2, seed=42):
    """Split sample indices into train/validation sets, class by class.

    For every distinct label, the indices of that class are shuffled and the
    first ``int(n_class * val_ratio)`` of them go to the validation set; the
    rest go to the training set. Both index arrays are shuffled once more
    before being returned.

    Args:
        y: Array-like of labels, one per sample.
        val_ratio: Fraction of each class assigned to validation, in [0, 1].
        seed: Seed for the private random generator used for shuffling.

    Returns:
        Tuple ``(train_idxs, val_idxs)`` of integer index arrays into ``y``.

    Raises:
        ValueError: If ``val_ratio`` lies outside [0, 1].
    """
    if not 0.0 <= val_ratio <= 1.0:
        raise ValueError(f"val_ratio must be within [0, 1], got {val_ratio!r}")

    y = np.asarray(y)
    # A private RandomState yields the same stream as np.random.seed(seed)
    # followed by np.random.shuffle, but leaves the global NumPy RNG untouched.
    rng = np.random.RandomState(seed)

    train_parts = []
    val_parts = []
    for c in np.unique(y):
        idxs = np.where(y == c)[0]
        rng.shuffle(idxs)
        n_val = int(len(idxs) * val_ratio)
        val_parts.append(idxs[:n_val])
        train_parts.append(idxs[n_val:])

    # np.concatenate rejects an empty list, so handle "no samples" explicitly.
    if not train_parts:
        return np.empty(0, dtype=np.intp), np.empty(0, dtype=np.intp)

    train_idxs = np.concatenate(train_parts)
    val_idxs = np.concatenate(val_parts)
    # Mix the classes so neither index array is ordered by label.
    rng.shuffle(train_idxs)
    rng.shuffle(val_idxs)

    return train_idxs, val_idxs

# Hold out 20% of every class for validation; the remainder is used for training.
split_indices = stratified_train_val_split(train_labels, val_ratio=0.2)
train_idx, val_idx = split_indices
print("Train idx:", train_idx.shape, "Val idx:", val_idx.shape)


# Data preprocessing

In [None]:
# Per-channel ImageNet mean/std used to normalise inputs for ResNet18
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Spatial size every image is resized to before entering the network.
_INPUT_SIZE = (224, 224)


def _tensor_and_normalize():
    """Return the final two steps shared by both pipelines: ToTensor + Normalize."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]


# Training pipeline: resize, random flip / rotation augmentation, then normalise.
train_transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        *_tensor_and_normalize(),
    ]
)

# Validation / test pipeline: no random augmentation, only resize + normalise.
val_test_transform = transforms.Compose(
    [
        transforms.Resize(_INPUT_SIZE),
        *_tensor_and_normalize(),
    ]
)

class RetinaDataset(Dataset):
    """Dataset over an in-memory image array with optional labels.

    Each item is converted to a PIL image, passed through ``transform`` when
    one is given, and returned either alone (no labels) or as
    ``(image, int_label)``.
    """

    def __init__(self, images, labels=None, transform=None):
        """
        images: np.array (N, H, W, C)
        labels: np.array (N,) or None for test
        transform: optional callable applied to the PIL image
        """
        self.transform = transform
        self.labels = labels
        self.images = images

    def __len__(self):
        """Number of samples, i.e. the length of the image array."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return the transformed image at ``idx``, with its label when available."""
        # (H, W, C) uint8 array -> PIL image, which the transform pipeline consumes
        sample = Image.fromarray(self.images[idx])

        if self.transform is not None:
            sample = self.transform(sample)

        if self.labels is not None:
            return sample, int(self.labels[idx])

        # Unlabelled (test) data: the image alone
        return sample

# Build datasets using the same split as before
full_train_dataset = RetinaDataset(train_images, train_labels, transform=train_transform)

# Validation samples get their own dataset so they go through the
# augmentation-free transform instead of the training transform.
val_dataset_images = train_images[val_idx]
val_dataset_labels = train_labels[val_idx]

train_dataset = Subset(full_train_dataset, train_idx)
val_dataset = RetinaDataset(val_dataset_images, val_dataset_labels, transform=val_test_transform)
test_dataset = RetinaDataset(test_images, labels=None, transform=val_test_transform)

batch_size = 32

# Pinned host memory only speeds up host-to-GPU copies; requesting it on a
# CPU-only machine just makes DataLoader emit a warning, so enable it only
# when CUDA is actually available.
use_pinned_memory = torch.cuda.is_available()

# Only the training loader shuffles; val/test keep a fixed order so that
# predictions line up with the dataset order.
train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True, num_workers=0, pin_memory=use_pinned_memory)
val_loader   = DataLoader(val_dataset, batch_size=batch_size,
                          shuffle=False, num_workers=0, pin_memory=use_pinned_memory)
test_loader  = DataLoader(test_dataset, batch_size=batch_size,
                          shuffle=False, num_workers=0, pin_memory=use_pinned_memory)

# Number of batches per loader
print(len(train_loader), len(val_loader), len(test_loader))

# ResNet18 Model Building

In [None]:
from torch import nn
import torch

def create_resnet18(num_classes=5, pretrained=True, device=device):
    """Build a ResNet18 with a dropout + linear head, partially unfrozen.

    Args:
        num_classes: Output size of the new final linear layer.
        pretrained: When true, start from the ImageNet weights.
        device: Device the finished model is moved to.

    Returns:
        The model on ``device``. Only ``layer3``, ``layer4`` and the new
        ``fc`` head have ``requires_grad=True``; all other parameters are
        frozen.
    """
    from torchvision import models

    # Guard only the import: a missing weights enum means an older
    # torchvision, whereas any other failure (e.g. while building the model
    # or fetching weights) should surface instead of being silently retried
    # through the legacy `pretrained=` argument.
    try:
        from torchvision.models import ResNet18_Weights
    except ImportError:
        model = models.resnet18(pretrained=pretrained)
    else:
        weights = ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        model = models.resnet18(weights=weights)

    in_features = model.fc.in_features
    # Replace the classifier: dropout before the final linear layer
    model.fc = nn.Sequential(
        nn.Dropout(p=0.5),
        nn.Linear(in_features, num_classes),
    )

    # Freeze everything ...
    for p in model.parameters():
        p.requires_grad = False

    # ... then unfreeze the deeper part + head
    for trainable_module in (model.layer3, model.layer4, model.fc):
        for p in trainable_module.parameters():
            p.requires_grad = True

    return model.to(device)

def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With an ``optimizer`` the model is put in train mode and updated after
    every batch; without one it is put in eval mode and only evaluated (the
    caller decides whether gradients are tracked).

    Raises:
        ValueError: If the loader yields no samples.
    """
    is_training = optimizer is not None
    if is_training:
        model.train()
    else:
        model.eval()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for imgs, labels in loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        if is_training:
            optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        if is_training:
            loss.backward()
            optimizer.step()

        # Weight the batch-mean loss by batch size so that a smaller final
        # batch does not skew the epoch average.
        batch_n = labels.size(0)
        loss_sum += loss.item() * batch_n
        preds = outputs.argmax(dim=1)
        n_correct += (preds == labels).sum().item()
        n_seen += batch_n

    if n_seen == 0:
        raise ValueError("data loader yielded no samples; cannot compute epoch metrics")

    return loss_sum / n_seen, n_correct / n_seen


def train_model(model, train_loader, val_loader,
                num_epochs=30,
                lr_head=1e-3,
                lr_backbone=1e-4,
                weight_decay=1e-4,
                patience=6,
                device=device):
    """Train ``model`` with Adam and early stopping on validation accuracy.

    Trainable parameters whose name starts with ``"fc."`` form the head group
    (learning rate ``lr_head``); all other trainable parameters form the
    backbone group (``lr_backbone``). Frozen parameters are not optimised.

    Args:
        model: Network to train; expected to already live on ``device``.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        num_epochs: Maximum number of epochs.
        lr_head: Learning rate of the ``fc`` head group.
        lr_backbone: Learning rate of the remaining trainable parameters.
        weight_decay: Weight decay passed to Adam for both groups.
        patience: Number of consecutive epochs without a validation-accuracy
            improvement after which training stops.
        device: Device every batch is moved to.

    Returns:
        ``(model, history)``: the model with its best-validation-accuracy
        weights restored (when any epoch improved on 0.0), and a dict of
        per-epoch lists keyed ``epoch``, ``train_loss``, ``val_loss``,
        ``train_acc`` and ``val_acc``.

    Raises:
        ValueError: If either loader yields no samples.
    """
    criterion = nn.CrossEntropyLoss()

    # Split parameters: backbone vs head (fc)
    backbone_params, head_params = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if name.startswith("fc."):
            head_params.append(p)
        else:
            backbone_params.append(p)

    optimizer = torch.optim.Adam(
        [
            {"params": backbone_params, "lr": lr_backbone},
            {"params": head_params, "lr": lr_head},
        ],
        weight_decay=weight_decay,
    )

    history = {
        "epoch": [],
        "train_loss": [],
        "val_loss": [],
        "train_acc": [],
        "val_acc": [],
    }

    best_val_acc = 0.0
    best_state = None
    best_epoch = 0
    patience_counter = 0

    for epoch in range(num_epochs):
        # ----- Train -----
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer
        )

        # ----- Validation -----
        with torch.no_grad():
            val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

        history["epoch"].append(epoch)
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)

        print(
            f"Epoch {epoch:02d} | "
            f"train_loss={train_loss:.4f} val_loss={val_loss:.4f} | "
            f"train_acc={train_acc:.4f} val_acc={val_acc:.4f}"
        )

        # Early stopping on val_acc; the small margin ignores negligible gains
        if val_acc > best_val_acc + 1e-4:
            best_val_acc = val_acc
            best_epoch = epoch
            patience_counter = 0
            # Snapshot the weights on CPU so later updates cannot alter them
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(
                    f"Early stopping at epoch {epoch}, "
                    f"best val_acc={best_val_acc:.4f} (epoch {best_epoch})"
                )
                break

    # Roll back to the best snapshot (none exists if no epoch ever improved)
    if best_state is not None:
        model.load_state_dict(best_state)

    print(f"Best val_acc={best_val_acc:.4f} at epoch {best_epoch}")
    return model, history


# ---- Train once with a solid default config ----
model = create_resnet18(num_classes=5, pretrained=True)

# Hyper-parameters of this run, gathered in one place.
train_kwargs = {
    "num_epochs": 30,
    "lr_head": 5e-4,
    "lr_backbone": 5e-5,
    "weight_decay": 5e-4,
    "patience": 6,
    "device": device,
}

# train_model hands back the model plus the per-epoch metric history.
model, history = train_model(model, train_loader, val_loader, **train_kwargs)

## Plot Curves

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np # Needed for np.arange

def _plot_train_val_curves(frame, train_col, val_col, y_label, title, line_width):
    """Draw the train and val curve of one metric on a new figure and show it."""
    plt.figure(figsize=(10, 4))  # wide enough for the rotated epoch ticks
    for column, legend_name in ((train_col, "train"), (val_col, "val")):
        sns.lineplot(data=frame, x="epoch", y=column,
                     label=legend_name, linewidth=line_width)
    plt.xlabel("Epoch")
    plt.ylabel(y_label)
    plt.title(title)
    plt.legend()

    # X-axis ticks: one every second epoch, running slightly past the last one
    plt.xticks(np.arange(0, len(frame) + 2, 2))
    # Rotate ticks for readability if they overlap
    plt.xticks(rotation=45)

    plt.tight_layout()
    plt.show()


hist_df = pd.DataFrame(history)
sns.set_theme(style="whitegrid", rc={"grid.linestyle": "--", "grid.alpha": 0.6})

# --- Loss curves ---
_plot_train_val_curves(
    hist_df, "train_loss", "val_loss", "Loss", "ResNet18 – Loss Curves", 1.5
)

# --- Accuracy curves ---
_plot_train_val_curves(
    hist_df, "train_acc", "val_acc", "Accuracy", "ResNet18 – Accuracy Curves", 2.5
)

# Submission CSV

In [None]:
# Inference only: switch to eval mode and skip gradient tracking.
model.eval()

all_preds = []

with torch.no_grad():
    for imgs in test_loader:
        # The test dataset has no labels, so every batch is just an image tensor.
        logits = model(imgs.to(device))
        all_preds.append(logits.argmax(dim=1).cpu().numpy())

# Stitch the per-batch class predictions into one flat array.
y_test_pred = np.concatenate(all_preds)
print("Test preds shape:", y_test_pred.shape,
      "unique:", np.unique(y_test_pred, return_counts=True))

# IDs are 1-based and follow the (unshuffled) test order.
ids = np.arange(1, len(y_test_pred) + 1)
sub = pd.DataFrame({"ID": ids, "Label": y_test_pred})
sub.to_csv("m2_resnet18_regdrop_submission.csv", index=False)
sub.head()