In [1]:
# ============================================================
# 05_Final_CNN_Model.ipynb  (exported as .py style)
# Final multiclass CNN for Derm7pt — ResNet50 + Aug-D + Focal Loss
# ============================================================

# ============================
# Block 1 — Imports & Config
# ============================
import os
import random
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    classification_report,
    confusion_matrix,
    balanced_accuracy_score
)

import matplotlib.pyplot as plt
import seaborn as sns

import torchvision.transforms as T
import torchvision.models as models

# -------- Paths (adapt if needed) --------
ROOT_DIR    = r"C:\Users\anama\Documents\Group_8"
DATASET_DIR = os.path.join(ROOT_DIR, "Dataset", "DERM7PT")
META_CSV    = os.path.join(DATASET_DIR, "meta", "meta.csv")
IMAGES_DIR  = os.path.join(DATASET_DIR, "images")

print("META_CSV:", META_CSV)
print("IMAGES_DIR:", IMAGES_DIR)

# -------- Device + Seeds --------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# ============================
# Block 2 — Load & Clean Metadata
# ============================
df = pd.read_csv(META_CSV)
print("Loaded meta.csv — shape:", df.shape)

# Drop clearly useless columns if present
df = df.drop(columns=["case_num", "case_id", "notes"], errors="ignore")

# Build full image path (assuming 'derm' column holds filenames)
if "derm" not in df.columns:
    raise ValueError("Expected a 'derm' column in meta.csv for image filenames.")

df["derm_fullpath"] = df["derm"].apply(
    lambda x: os.path.join(IMAGES_DIR, str(x))
)

df["derm_exists"] = df["derm_fullpath"].apply(os.path.exists)
missing = df[~df["derm_exists"]]

print("Images missing on disk:", len(missing))
if len(missing) > 0:
    print("Sample missing rows:")
    print(missing[["derm", "diagnosis"]].head())

# Keep only rows with existing images and non-null diagnosis
df = df[df["derm_exists"]].copy()
df = df[~df["diagnosis"].isna()].copy()
df.reset_index(drop=True, inplace=True)

print("After cleaning — shape:", df.shape)
print("Sample rows:")
print(df[["diagnosis", "derm"]].head())

# ============================
# Block 3 — Encode Labels
# ============================
le = LabelEncoder()
df["label"] = le.fit_transform(df["diagnosis"])
class_names = list(le.classes_)
NUM_CLASSES = len(class_names)

print("\nLabel encoding:")
for i, cls in enumerate(class_names):
    print(f"{i:2d} -> {cls}")
print("NUM_CLASSES:", NUM_CLASSES)

# ============================
# Block 4 — Stratified Train/Val/Test Split (70/15/15)
# ============================
train_df, temp_df = train_test_split(
    df,
    test_size=0.30,
    stratify=df["label"],
    random_state=SEED
)

val_df, test_df = train_test_split(
    temp_df,
    test_size=0.50,
    stratify=temp_df["label"],
    random_state=SEED
)

print("\n===== SPLIT SIZES =====")
print(f"Train: {len(train_df)}")
print(f"Val:   {len(val_df)}")
print(f"Test:  {len(test_df)}")

print("\nClass distribution (train):")
print(train_df["diagnosis"].value_counts())

# ============================
# Block 5 — Augmentation (Aug-D) & Datasets
# ============================
IMG_SIZE = 256  # slightly larger than 224 to capture more detail

# Strong but realistic dermoscopy augmentation
train_transform = T.Compose([
    T.RandomResizedCrop(IMG_SIZE, scale=(0.7, 1.0), ratio=(0.75, 1.33)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    T.RandomRotation(degrees=15),
    T.ColorJitter(
        brightness=0.2,
        contrast=0.2,
        saturation=0.2,
        hue=0.02
    ),
    T.GaussianBlur(kernel_size=3, sigma=(0.1, 1.0)),
    T.ToTensor(),
    T.RandomErasing(p=0.3),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

valid_transform = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

class DermDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = row["derm_fullpath"]
        label = int(row["label"])

        # Load image
        img = Image.open(img_path).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        return img, label

train_dataset = DermDataset(train_df, transform=train_transform)
val_dataset   = DermDataset(val_df,   transform=valid_transform)
test_dataset  = DermDataset(test_df,  transform=valid_transform)

print("\nDataset sizes:")
print(len(train_dataset), len(val_dataset), len(test_dataset))

# ============================
# Block 6 — DataLoaders
# ============================
BATCH_SIZE = 32

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

# ============================
# Block 7 — Class Weights (for Focal Loss)
# ============================
class_counts = train_df["label"].value_counts().sort_index()
print("\nClass counts (train):")
print(class_counts)

# Simple inverse-frequency weights
inv_freq = 1.0 / class_counts
weights = inv_freq / inv_freq.sum() * len(inv_freq)

class_weights_tensor = torch.tensor(
    weights.values,
    dtype=torch.float32,
    device=device
)

print("\nClass weights (normalized):")
print(class_weights_tensor)

# ============================
# Block 8 — Focal Loss
# ============================
class FocalLoss(nn.Module):
    """
    Multi-class Focal Loss with per-class alpha weights.
    """
    def __init__(self, alpha=None, gamma=2.0, reduction="mean"):
        super().__init__()
        self.alpha = alpha  # tensor [num_classes]
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        # logits: [B, C], targets: [B]
        log_probs = F.log_softmax(logits, dim=1)        # [B, C]
        probs = log_probs.exp()                         # [B, C]

        # Gather log-prob and prob of the true class
        targets = targets.long()
        log_p = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)  # [B]
        p = probs.gather(1, targets.unsqueeze(1)).squeeze(1)          # [B]

        if self.alpha is not None:
            alpha_t = self.alpha[targets]  # [B]
        else:
            alpha_t = torch.ones_like(p)

        loss = -alpha_t * (1 - p) ** self.gamma * log_p

        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        else:
            return loss

criterion = FocalLoss(alpha=class_weights_tensor, gamma=2.0)

# ============================
# Block 9 — Build ResNet50 Model
# ============================
# Use ImageNet-pretrained ResNet50
try:
    weights_enum = models.ResNet50_Weights.IMAGENET1K_V2
    base_model = models.resnet50(weights=weights_enum)
except AttributeError:
    # For older torchvision versions
    base_model = models.resnet50(pretrained=True)

in_features = base_model.fc.in_features
base_model.fc = nn.Linear(in_features, NUM_CLASSES)

model = base_model.to(device)

# Optimizer & scheduler
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

LR = 1e-4
EPOCHS = 60

optimizer = AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS)

print("\nModel built: ResNet50 with", NUM_CLASSES, "classes")

# ============================
# Block 10 — Training & Validation Loops
# ============================
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc

def eval_one_epoch(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    all_labels = []
    all_preds = []

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * images.size(0)
            preds = outputs.argmax(dim=1)

            correct += (preds == labels).sum().item()
            total += labels.size(0)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    bal_acc = balanced_accuracy_score(all_labels, all_preds)

    return epoch_loss, epoch_acc, bal_acc, np.array(all_labels), np.array(all_preds)

# ============================
# Block 11 — Train Model with Best-Checkpoint Saving
# ============================
history = {
    "train_loss": [],
    "val_loss": [],
    "train_acc": [],
    "val_acc": [],
    "val_bal_acc": []
}

best_bal_acc = 0.0
best_state_dict = None

print("\n==========================")
print(" TRAINING STARTED")
print("==========================")

for epoch in range(1, EPOCHS + 1):
    print(f"\nEpoch {epoch}/{EPOCHS}")

    train_loss, train_acc = train_one_epoch(
        model, train_loader, optimizer, criterion, device
    )
    val_loss, val_acc, val_bal_acc, _, _ = eval_one_epoch(
        model, val_loader, criterion, device
    )

    scheduler.step()

    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["train_acc"].append(train_acc)
    history["val_acc"].append(val_acc)
    history["val_bal_acc"].append(val_bal_acc)

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val   Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val BalAcc: {val_bal_acc:.4f}")

    # Save best model based on validation balanced accuracy
    if val_bal_acc > best_bal_acc:
        best_bal_acc = val_bal_acc
        best_state_dict = model.state_dict().copy()
        torch.save(best_state_dict, os.path.join(ROOT_DIR, "resnet50_augD_best.pth"))
        print(f"  -> New best model saved (Val BalAcc = {best_bal_acc:.4f})")

print("\n==========================")
print(" TRAINING FINISHED")
print("==========================")
print("Best Val Balanced Accuracy:", best_bal_acc)

# Load best weights before testing
if best_state_dict is not None:
    model.load_state_dict(best_state_dict)
    print("Best checkpoint loaded for testing.")

# ============================
# Block 12 — Test Evaluation
# ============================
test_loss, test_acc, test_bal_acc, y_true, y_pred = eval_one_epoch(
    model, test_loader, criterion, device
)

print("\n===== TEST RESULTS =====")
print(f"Test Loss:       {test_loss:.4f}")
print(f"Test Accuracy:   {test_acc:.4f}")
print(f"Test BalAcc:     {test_bal_acc:.4f}")

print("\nClassification Report (Test):")
print(classification_report(y_true, y_pred, target_names=class_names))

# Confusion matrix
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(12, 10))
sns.heatmap(
    cm,
    annot=False,
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names
)
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.title("Confusion Matrix — ResNet50 Aug-D (Test)")
plt.tight_layout()
plt.show()

# ============================
# Block 13 — Plot Learning Curves
# ============================
epochs_range = range(1, EPOCHS + 1)

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, history["train_loss"], label="Train Loss")
plt.plot(epochs_range, history["val_loss"],   label="Val Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Loss curves")
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(epochs_range, history["train_acc"],      label="Train Acc")
plt.plot(epochs_range, history["val_acc"],        label="Val Acc")
plt.plot(epochs_range, history["val_bal_acc"],    label="Val BalAcc")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Accuracy curves")
plt.legend()

plt.tight_layout()
plt.show()


META_CSV: C:\Users\anama\Documents\Group_8\Dataset\DERM7PT\meta\meta.csv
IMAGES_DIR: C:\Users\anama\Documents\Group_8\Dataset\DERM7PT\images
Using device: cuda
Loaded meta.csv — shape: (1011, 19)
Images missing on disk: 0
After cleaning — shape: (1011, 18)
Sample rows:
              diagnosis            derm
0  basal cell carcinoma  NEL/Nel026.jpg
1  basal cell carcinoma  NEL/Nel028.jpg
2  basal cell carcinoma  NEL/Nel033.jpg
3  basal cell carcinoma  NEL/Nel035.jpg
4  basal cell carcinoma  NEL/Nel037.jpg

Label encoding:
 0 -> basal cell carcinoma
 1 -> blue nevus
 2 -> clark nevus
 3 -> combined nevus
 4 -> congenital nevus
 5 -> dermal nevus
 6 -> dermatofibroma
 7 -> lentigo
 8 -> melanoma
 9 -> melanoma (0.76 to 1.5 mm)
10 -> melanoma (in situ)
11 -> melanoma (less than 0.76 mm)
12 -> melanoma (more than 1.5 mm)
13 -> melanoma metastasis
14 -> melanosis
15 -> miscellaneous
16 -> recurrent nevus
17 -> reed or spitz nevus
18 -> seborrheic keratosis
19 -> vascular lesion
NUM_CLASSES: 

ValueError: The least populated class in y has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2.