In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, accuracy_score


import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import ReduceLROnPlateau


# -----------------------------
# DEVICE
# -----------------------------
# Prefer the GPU when PyTorch can see one; everything below is moved to `device`.
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_ok else torch.device("cpu")
print("Using device:", device)


# -----------------------------
# DATA TRANSFORMS
# -----------------------------
# Standard ImageNet per-channel mean / std, shared by both pipelines.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _tensor_and_normalize():
    """Return the common tail of both pipelines: PIL -> tensor, then normalize."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ]


# Random augmentations are applied only during training, after resizing.
_train_augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomAffine(0, translate=(0.1, 0.1)),
]

train_tfms = transforms.Compose(
    [transforms.Resize((224, 224))] + _train_augmentations + _tensor_and_normalize()
)

# Validation: deterministic resize + normalize only.
val_tfms = transforms.Compose(
    [transforms.Resize((224, 224))] + _tensor_and_normalize()
)


# -----------------------------
# DATA LOADERS
# -----------------------------
DATA_DIR = "/content/data"
train_ds = datasets.ImageFolder(os.path.join(DATA_DIR, "train"), transform=train_tfms)
val_ds   = datasets.ImageFolder(os.path.join(DATA_DIR, "val"), transform=val_tfms)

# ImageFolder derives label indices from the sorted sub-folder names of EACH split
# independently, so a missing or extra folder in one split would silently shift
# every label index. Fail loudly instead of training/evaluating on mismatched labels.
if train_ds.classes != val_ds.classes:
    raise ValueError(
        f"train/val class folders differ: {train_ds.classes} vs {val_ds.classes}"
    )

# Pinned host memory only speeds up host->GPU copies; on CPU-only runs PyTorch
# warns and ignores it, so enable it only when training on CUDA.
_pin_memory = device.type == "cuda"
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=2, pin_memory=_pin_memory)
val_loader   = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=2, pin_memory=_pin_memory)

# NOTE(review): the rest of the script treats class index 0 as "no OA" and
# indices 1-4 as KL grades 1-4 — confirm the folder names sort that way.
class_names = train_ds.classes
print("Classes:", class_names)


# -----------------------------
# MODEL FACTORY
# -----------------------------
def _build_finetune_resnet18(num_outputs, lr=1e-4):
    """Build an ImageNet-pretrained ResNet-18 for partial fine-tuning.

    Every backbone parameter is frozen except those in ``layer3`` and ``layer4``
    (the last two residual stages). ``fc`` is replaced by a freshly initialised
    ``num_outputs``-way linear layer, which is trainable by default. The model is
    moved to the module-level ``device``.

    Args:
        num_outputs: number of logits produced by the new ``fc`` head.
        lr: Adam learning rate for the trainable parameters.

    Returns:
        ``(model, criterion, optimizer, scheduler)`` where the scheduler is a
        ``ReduceLROnPlateau`` in ``'min'`` mode (halve LR after 5 flat epochs).
    """
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

    # Freeze everything, then re-enable gradients on the last two stages only.
    for p in model.parameters():
        p.requires_grad = False
    for p in list(model.layer3.parameters()) + list(model.layer4.parameters()):
        p.requires_grad = True

    model.fc = nn.Linear(model.fc.in_features, num_outputs)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    # Only hand the optimizer parameters that will actually receive gradients.
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
    return model, criterion, optimizer, scheduler


# -----------------------------
# STAGE 1: OA vs NO OA
# -----------------------------
binary_model, binary_loss, binary_opt, binary_scheduler = _build_finetune_resnet18(2)


# -----------------------------
# STAGE 2: SEVERITY (KL 1-4)
# -----------------------------
severity_model, severity_loss, severity_opt, severity_scheduler = _build_finetune_resnet18(4)


# -----------------------------
# TRAINING FUNCTION
# -----------------------------
def train_model(model, loader, criterion, optimizer, scheduler=None, epochs=75, is_binary=True,
                device=None):
    """Train ``model`` in place for ``epochs`` passes over ``loader``.

    Integer labels from ``loader`` are interpreted as grades where 0 means
    "no OA" (assumes the class at index 0 is the no-OA grade — TODO confirm
    against the dataset folder names).

    Args:
        model: network to train; put into ``train()`` mode at the start of each epoch.
        loader: iterable of ``(imgs, labels)`` batches with integer ``labels``.
        criterion: loss called as ``criterion(logits, targets)``; the logged epoch
            loss assumes it reduces by mean over the batch.
        optimizer: optimizer over the trainable parameters of ``model``.
        scheduler: optional scheduler whose ``step`` takes a metric (e.g.
            ``ReduceLROnPlateau``); stepped once per epoch with the mean TRAINING loss.
        epochs: number of passes over ``loader``.
        is_binary: if True, targets are ``labels > 0`` (OA vs no OA). If False,
            label-0 samples are dropped and the remaining labels are shifted down by
            one so grades 1..K become classes 0..K-1; all-zero batches are skipped.
        device: where to move each batch. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        List with one training accuracy per epoch, computed over the samples that
        were actually used for training in that epoch.

    Raises:
        ValueError: if an epoch yields no usable samples (e.g. severity mode with
            no label > 0 anywhere in ``loader``).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    all_acc = []
    for epoch in range(epochs):
        model.train()
        correct, total, running_loss = 0, 0, 0.0
        for imgs, labels in loader:
            if is_binary:
                targets = (labels > 0).long()
            else:
                # Drop no-OA samples BEFORE the host->device copy so only the
                # images that are actually trained on get transferred.
                mask = labels > 0
                if not mask.any():
                    continue
                imgs = imgs[mask]
                targets = (labels[mask] - 1).long()
            imgs = imgs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            batch_size = targets.size(0)
            correct += (outputs.argmax(1) == targets).sum().item()
            total += batch_size
            # Undo the mean reduction so the epoch loss is a per-sample average.
            running_loss += loss.item() * batch_size

        if total == 0:
            mode = "binary" if is_binary else "severity"
            raise ValueError(
                f"Epoch {epoch+1}/{epochs}: no usable training samples in {mode} mode"
            )

        epoch_acc = correct / total
        epoch_loss = running_loss / total
        all_acc.append(epoch_acc)
        print(f"Epoch {epoch+1}/{epochs} | Loss: {epoch_loss:.4f} | Accuracy: {epoch_acc:.4f}")

        # Scheduler step (ReduceLROnPlateau expects a metric)
        if scheduler:
            scheduler.step(epoch_loss)

    return all_acc


# -----------------------------
# TRAIN BOTH MODELS
# -----------------------------
# Both stages share the same training length.
NUM_EPOCHS = 75

print("=== Training Binary Model (OA vs NO OA) ===")
bin_train_acc = train_model(
    binary_model,
    train_loader,
    binary_loss,
    binary_opt,
    scheduler=binary_scheduler,
    epochs=NUM_EPOCHS,
    is_binary=True,
)

print("\n=== Training Severity Model (KL 1-4) ===")
sev_train_acc = train_model(
    severity_model,
    train_loader,
    severity_loss,
    severity_opt,
    scheduler=severity_scheduler,
    epochs=NUM_EPOCHS,
    is_binary=False,
)


# -----------------------------
# EVALUATION
# -----------------------------
binary_model.eval()
severity_model.eval()

# Final grade space: 0 ("no OA" from stage 1) plus one grade per severity logit.
num_grades = severity_model.fc.out_features + 1
grade_ids = list(range(num_grades))

y_true, y_pred = [], []

with torch.no_grad():
    for imgs, labels in val_loader:
        imgs = imgs.to(device)

        # Stage 1: OA vs no OA for the whole batch.
        bin_pred = binary_model(imgs).argmax(1)

        # Default every sample to grade 0; OA-positive samples are overwritten below.
        batch_pred = torch.zeros_like(bin_pred)
        oa_mask = bin_pred != 0
        if oa_mask.any():
            # Stage 2 runs once on the OA-positive subset instead of image-by-image;
            # severity classes 0..K-1 map back to grades 1..K.
            batch_pred[oa_mask] = severity_model(imgs[oa_mask]).argmax(1) + 1

        y_pred.extend(batch_pred.cpu().tolist())
        y_true.extend(labels.tolist())


acc = accuracy_score(y_true, y_pred)
# Pin the label set so the matrix is always num_grades x num_grades even when a
# grade never occurs in y_true or y_pred (otherwise indexing cm below can fail).
cm = confusion_matrix(y_true, y_pred, labels=grade_ids)
print("Final Validation Accuracy:", acc)


plt.figure(figsize=(6,6))
plt.imshow(cm, cmap='Blues')
plt.title("Confusion Matrix – Two-Stage OA Model")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.colorbar()
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        plt.text(j, i, cm[i, j], ha="center", va="center")
plt.show()