In [None]:
# cuBLAS workspace setting. PyTorch documents this variable as required (CUDA >= 10.2)
# for deterministic cuBLAS results once torch.use_deterministic_algorithms(True)
# is enabled, which set_seed() does further down in this notebook.
%env CUBLAS_WORKSPACE_CONFIG=:4096:8

In [None]:
# 1. Install required packages (if needed)
!pip install torchinfo

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision import models, transforms
import timm
from PIL import Image
import numpy as np
from tqdm import tqdm
import pandas as pd
import os
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchinfo import summary
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import os
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import json
import random
import math
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,confusion_matrix

In [None]:
def set_seed(seed: int):
    """Make the Python, NumPy and PyTorch random number generators reproducible.

    Args:
        seed: Value handed to every generator.

    In addition to seeding, cuDNN is switched to deterministic mode with its
    benchmark autotuner disabled, and deterministic algorithms are enabled
    whenever this torch build exposes ``torch.use_deterministic_algorithms``.
    """
    # Seed each RNG in turn: stdlib, NumPy, torch CPU, then every CUDA device.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Deterministic cuDNN kernels and no autotuning (may slow things down).
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # The switch is looked up dynamically because older torch builds lack it.
    enable_determinism = getattr(torch, "use_deterministic_algorithms", None)
    if enable_determinism is not None:
        enable_determinism(True)
# os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# choose your seed
# SEED is the notebook-wide seed: besides being applied here, it is the default
# `seed` of split_data() and is embedded in MODEL_NAME for the saved checkpoints.
SEED = 2506
# Seed every RNG once, before any data is shuffled or any model is created.
set_seed(SEED)

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [None]:
# Placeholder dataset roots; replace with real paths before running.
# get_data() looks inside each root for the sub-folders "lesion", "normal",
# "variation in normal" and "red lesion".
train_dir = "path_to_train_directory"
test_dir =  "path_to_test_directory"

In [None]:
def load_image(path, size=(224, 224)):
    """Read an image file and return it as a channel-first float tensor.

    Args:
        path: Path of the image file to read.
        size: Target size passed to ``cv2.resize``; defaults to ``(224, 224)``,
            the value this function always used before the parameter existed.

    Returns:
        A float32 ``torch.Tensor`` in RGB channel order with the channel axis
        moved to the front (``permute(2, 0, 1)``) and values divided by 255.

    Raises:
        ValueError: If OpenCV cannot read or decode ``path``.
    """
    img = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising; without
    # this check the problem only surfaces as an opaque error inside cvtColor.
    if img is None:
        raise ValueError(f"Could not read image file: {path!r}")

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads colour images as BGR
    img = cv2.resize(img, size)
    # HWC -> CHW, then scale to floats by dividing by 255
    return torch.tensor(img).permute(2, 0, 1).float() / 255.0

In [None]:
def get_data(root_directory):
    """Load every image found under a fixed set of label folders.

    The folders "lesion", "normal", "variation in normal" and "red lesion" are
    looked up directly under ``root_directory``. Folders that do not exist are
    skipped. Any label other than 'normal' is treated as 'abnormal'.

    Args:
        root_directory: Directory that contains the label sub-folders.

    Returns:
        dict mapping each image path to a dict with the keys
        ``'img'`` (the tensor returned by ``load_image``),
        ``'label'`` (0 for normal, 1 for abnormal) and
        ``'label_name'`` ('normal' or 'abnormal').
    """
    image_label_dict = {}
    for label in ["lesion", "normal", "variation in normal", "red lesion"]:
        label_dir = os.path.join(root_directory, label)
        print('Loading Images from:',label_dir)
        if not os.path.isdir(label_dir):
            continue

        is_normal = label == 'normal'
        # os.listdir returns entries in arbitrary order. Sorting fixes the dict's
        # insertion order, which the seeded shuffle in split_data depends on.
        for image_file in tqdm(sorted(os.listdir(label_dir))):
            image_path = os.path.join(label_dir, image_file)
            if not os.path.isfile(image_path):
                continue
            image_label_dict[image_path] = {
                'img': load_image(image_path),
                'label': 0 if is_normal else 1,
                'label_name': 'normal' if is_normal else 'abnormal',
            }
    return image_label_dict

In [None]:
# Load both datasets into memory; get_data decodes every image via load_image
# as it goes. The validation split is carved out of train_samples_all later on.
train_samples_all = get_data(train_dir)
test_samples  = get_data(test_dir)

In [None]:
# Number of training images loaded (before the train/validation split).
len(train_samples_all)

In [None]:
def split_data(samples_dict, val_ratio=0.2, seed=SEED):
    """Randomly split ``samples_dict`` into training and validation dicts.

    Args:
        samples_dict: Mapping of sample key to sample; only the keys are shuffled.
        val_ratio: Fraction of the samples placed in the validation split. The
            validation size is ``int(len(samples_dict) * val_ratio)``.
        seed: Seed for the shuffle; the same seed always yields the same split
            for the same key order.

    Returns:
        ``(train_split, val_split)``, two dicts holding the same value objects
        as ``samples_dict``.

    Raises:
        ValueError: If ``val_ratio`` is outside ``[0, 1]``.
    """
    if not 0.0 <= val_ratio <= 1.0:
        raise ValueError(f"val_ratio must be within [0, 1], got {val_ratio}")

    sample_keys = list(samples_dict.keys())
    # A private Random instance produces the same shuffle as reseeding the
    # global generator, without resetting global RNG state as a side effect.
    random.Random(seed).shuffle(sample_keys)

    val_size = int(len(sample_keys) * val_ratio)
    val_keys = sample_keys[:val_size]
    train_keys = sample_keys[val_size:]

    train_split = {k: samples_dict[k] for k in train_keys}
    val_split = {k: samples_dict[k] for k in val_keys}

    return train_split, val_split

# Apply the split
# Uses the defaults of split_data: val_ratio=0.2 and seed=SEED.
train_samples, val_samples = split_data(train_samples_all)

In [None]:
# Report the size of each split (train, val, test) and the overall total.
split_sizes = [len(split) for split in (train_samples, val_samples, test_samples)]
print(*split_sizes)
print('Total samples:', sum(split_sizes))

In [None]:
class BatchGenerator:
    """Re-iterable, fixed-order batcher over an in-memory sample dict."""

    def __init__(self, image_label_dict, batch_size):
        """
        Args:
            image_label_dict: Mapping of image path to a dict with the keys
                'img' (tensor), 'label' and 'label_name'.
            batch_size: Number of samples per yielded batch.
        """
        self.image_label_dict = image_label_dict
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches, rounded up so a trailing partial batch is counted.
        return math.ceil(len(self.image_label_dict) / self.batch_size)

    def _collate(self, paths):
        """Assemble the batch tuple for the given image paths."""
        entries = [self.image_label_dict[p] for p in paths]
        return (
            paths,
            [entry['label_name'] for entry in entries],
            torch.stack([entry['img'] for entry in entries]),
            torch.tensor([entry['label'] for entry in entries], dtype=torch.float),
        )

    def __iter__(self):
        """Yield ``(paths, label_names, images, labels)`` in dict order.

        ``images`` is the stacked image tensor and ``labels`` a float tensor.
        """
        pending = []
        for img_path in self.image_label_dict:
            pending.append(img_path)
            if len(pending) == self.batch_size:
                yield self._collate(pending)
                pending = []

        # Whatever is left over forms the final, smaller batch.
        if pending:
            yield self._collate(pending)

In [None]:
# One batch generator per split, all sharing the same batch size.
BATCH_SIZE = 32
train_gen, val_gen, test_gen = (
    BatchGenerator(split, batch_size=BATCH_SIZE)
    for split in (train_samples, val_samples, test_samples)
)

In [None]:
# Batches per epoch for the train / validation / test generators.
tuple(len(gen) for gen in (train_gen, val_gen, test_gen))

In [None]:
# Smoke test: walk through every training batch once and move it to `device`.
# After the loop, pth / label_name / imgs / labels hold the LAST batch, which
# the next cell inspects.
for pth,label_name,imgs, labels in tqdm(train_gen, desc='Train', leave=False):
        imgs, labels = imgs.to(device), labels.to(device)

In [None]:
# First sample of the batch left over from the loop above:
# path, label name, image tensor shape and numeric label.
pth[0],label_name[0],imgs[0].shape,labels[0]

# Model

In [None]:
# model = timm.create_model('inception_v4', pretrained=True)

In [None]:
# parameter_list = [param for param in model.features[-1].parameters()]
# parameter_list

In [None]:
# for feature_layer in model.features:
#     print(feature_layer)
# for p in model.named_parameters():
#     print(p[0])

In [None]:
def get_model():
    """Build the Inception-V4 classifier used for training.

    Creates timm's ``inception_v4`` with ``pretrained=True``, replaces its head
    with a 2-class classifier (average pooling), moves the model to the global
    ``device`` and leaves only the classifier parameters trainable.

    Returns:
        The model, on ``device``, with every parameter frozen except those of
        ``model.get_classifier()``.
    """
    net = timm.create_model('inception_v4', pretrained=True)

    # timm models expose reset_classifier() for swapping the head.
    net.reset_classifier(num_classes=2, global_pool='avg')
    net = net.to(device)

    # Freeze everything first ...
    for param in net.parameters():
        param.requires_grad = False

    # ... then re-enable gradients for the classification head only.
    head = net.get_classifier()
    for param in head.parameters():
        param.requires_grad = True

    return net

In [None]:
# model = get_model()
# summary(model, input_size=(32, 3, 224, 224))

In [None]:
# for i in model.named_parameters():
#   print(i[0])

In [None]:
# --- Training for one epoch -----------------------------------
def train_epoch(model, loader, optimizer,criterion, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network that maps an image batch to per-class logits.
        loader: Iterable yielding ``(paths, label_names, images, labels)``.
        optimizer: Optimiser stepped once per batch.
        criterion: Loss called as ``criterion(logits, labels)`` with integer labels.
        device: Device the images and labels are moved to.

    Returns:
        ``(avg_loss, avg_acc)``: the sample-weighted mean loss and the fraction
        of samples whose argmax prediction equals the label.
    """
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(loader, desc='Train', leave=False)
    for _paths, _names, images, labels in progress:
        images = images.to(device)
        labels = labels.to(device).long()

        optimizer.zero_grad()
        logits = model(images)                   # raw scores, one column per class
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so the epoch mean is per sample.
        batch_size = labels.size(0)
        running_loss += batch_loss.item() * batch_size
        n_correct += (logits.argmax(dim=1) == labels).sum().item()
        n_seen += batch_size

    return running_loss / n_seen, n_correct / n_seen

# --- Validation (no threshold sweep) --------------------------
def validate_epoch(model, loader,criterion, device):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network that maps an image batch to per-class logits.
        loader: Iterable yielding ``(paths, label_names, images, labels)``.
        criterion: Loss called as ``criterion(logits, labels)`` with integer labels.
        device: Device the images and labels are moved to.

    Returns:
        ``(avg_loss, avg_acc)``: the sample-weighted mean loss and the fraction
        of samples whose argmax prediction equals the label.
    """
    model.eval()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        progress = tqdm(loader, desc='Val', leave=False)
        for _paths, _names, images, labels in progress:
            images = images.to(device)
            labels = labels.to(device).long()

            logits = model(images)
            batch_loss = criterion(logits, labels)

            # Weight the batch loss by its size so the epoch mean is per sample.
            batch_size = labels.size(0)
            running_loss += batch_loss.item() * batch_size
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_seen += batch_size

    return running_loss / n_seen, n_correct / n_seen

In [None]:
model = get_model()

# Only parameters that get_model() left trainable are handed to the optimiser.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-3)

# Cut the learning rate by 10x after 5 epochs without a lower monitored value.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
criterion = nn.CrossEntropyLoss()

# Early stopping config
# (the training cell below tracks its own `best_loss`; this name is kept as is)
best_val_loss = float('inf')

In [None]:
# Output locations for checkpoints and the metrics log.
# NOTE(review): `dir` shadows the builtin; the name is kept because other cells
# may rely on it.
dir = "CNN Results/Inception-v4/Best Model"
# torch.save and open() do not create missing parent folders, so make sure the
# output directory exists before the first checkpoint is written.
os.makedirs(dir, exist_ok=True)
MODEL_NAME = f"Inception-v4_{SEED}"
model_path      = os.path.join(dir, f"{MODEL_NAME}.pth")                    # final-epoch weights
best_model_path = os.path.join(dir, f"{MODEL_NAME}_best.pth")               # lowest-val-loss weights
metrics_path    = os.path.join(dir, f"{MODEL_NAME}_training_metrics.json")  # per-epoch history

# Early-stopping state: stop after `patience` epochs without a new best val loss.
best_loss = float('inf')
patience, epochs_no_improve = 20, 0
train_loss_history, train_acc_history = [], []
val_loss_history, val_acc_history = [], []
lr_history = []
num_epochs = 100
for epoch in range(1,num_epochs+1):
    train_loss, train_acc = train_epoch(model, train_gen, optimizer, criterion, device)
    val_loss, val_acc = validate_epoch(model, val_gen, criterion, device)
    # The scheduler monitors the validation loss.
    scheduler.step(val_loss)

    # Record metrics
    train_loss_history.append(train_loss)
    train_acc_history.append(train_acc)
    val_loss_history.append(val_loss)
    val_acc_history.append(val_acc)
    # Read after scheduler.step(), so this is the LR the next epoch will use.
    lr_history.append(optimizer.param_groups[0]['lr'])
    print(f"Epoch {epoch:02d}: "
          f"Train Loss={train_loss:.4f}, Acc={train_acc*100:.2f}% | "
          f"Val Loss={val_loss:.4f}, Acc={val_acc*100:.2f}% | "
          f"LR={lr_history[-1]:.6f}")

    # Early stopping and best model save
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), best_model_path)
        epochs_no_improve = 0
        print("  → New best model saved")
    else:
        epochs_no_improve += 1
        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch}")
            break

# Save final model and metrics
torch.save(model.state_dict(), model_path)
metrics = {
    "train_loss": [float(x) for x in train_loss_history],
    "train_accuracy": [float(x) for x in train_acc_history],
    "val_loss": [float(x) for x in val_loss_history],
    "val_accuracy": [float(x) for x in val_acc_history],
    "learning_rate": [float(x) for x in lr_history]
}
with open(metrics_path, 'w') as f:
    json.dump(metrics, f, indent=4)

print(f"Training complete. Model saved to:\n  best → {best_model_path}\n  final → {model_path}\nMetrics written to {metrics_path}")

In [None]:
# matplotlib.pyplot (plt) and json are already imported in the notebook's import cell.

# Load metrics
with open(metrics_path, 'r') as f:
    metrics = json.load(f)

# NOTE: these names rebind the per-epoch scalars left over from the training
# cell to the full per-epoch history lists.
train_loss = metrics['train_loss']
val_loss = metrics['val_loss']
train_acc = metrics['train_accuracy']
val_acc = metrics['val_accuracy']
lr = metrics['learning_rate']
epochs = range(1, len(train_loss) + 1)

# Loss (left panel) and accuracy (right panel) side by side.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Plot Loss
loss_ax.plot(epochs, train_loss, label='Train Loss', marker='o')
loss_ax.plot(epochs, val_loss, label='Validation Loss', marker='o')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Loss over Epochs')
loss_ax.legend()
loss_ax.grid(True)

# Plot Accuracy
acc_ax.plot(epochs, train_acc, label='Train Accuracy', marker='o')
acc_ax.plot(epochs, val_acc, label='Validation Accuracy', marker='o')
acc_ax.set_xlabel('Epochs')
acc_ax.set_ylabel('Accuracy')
acc_ax.set_title('Accuracy over Epochs')
acc_ax.legend()
acc_ax.grid(True)

fig.tight_layout()
plt.show()

# Plot Learning Rate
lr_fig, lr_ax = plt.subplots(figsize=(6, 4))
lr_ax.plot(epochs, lr, label='Learning Rate', color='purple', marker='o')
lr_ax.set_xlabel('Epochs')
lr_ax.set_ylabel('LR')
lr_ax.set_title('Learning Rate over Epochs')
lr_ax.grid(True)
plt.show()

# Testing

In [None]:
def get_model():
    """Build the Inception-V4 architecture used for evaluation.

    Creates timm's ``inception_v4`` with ``pretrained=True``, replaces its head
    with a 2-class classifier (average pooling) and moves it to the global
    ``device``. Unlike the training-time version, no parameters are frozen here.

    NOTE(review): this redefinition shadows the ``get_model`` from the Model
    section; any later training run started after this cell gets a model whose
    backbone is fully trainable.

    Returns:
        The model on ``device``.
    """
    net = timm.create_model('inception_v4', pretrained=True)

    # timm models expose reset_classifier() for swapping the head.
    net.reset_classifier(num_classes=2, global_pool='avg')

    return net.to(device)

In [None]:
# Rebuild the architecture, then overwrite its weights with the best checkpoint.
model = get_model()
seed = SEED
# NOTE: this rebinds model_path (previously the final-epoch checkpoint path)
# to the best-checkpoint path.
model_path = best_model_path
# map_location remaps the stored tensors onto `device` while loading.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust
# (consider weights_only=True if the installed torch supports it — TODO confirm).
model.load_state_dict(torch.load(model_path, map_location=device))

In [None]:
# Evaluate the loaded model on the test set. This mirrors validate_epoch but
# also collects the per-sample labels and predictions for the metrics cell.
model.eval()
y_true = []  # ground-truth class indices, in test_gen order
y_pred = []  # argmax predictions, aligned with y_true
total_loss, total_correct, total_samples = 0.0, 0, 0
criterion = nn.CrossEntropyLoss()
with torch.no_grad():
    for pth,label_names,images, labels in tqdm(test_gen, desc='Test', leave=False):
        # BatchGenerator yields float labels; CrossEntropyLoss needs integer class indices.
        images, labels = images.to(device), labels.to(device).long()
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Weight the batch loss by its size so the final mean is per sample.
        total_loss += loss.item() * labels.size(0)
        preds = outputs.argmax(dim=1)
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())
        total_correct += (preds == labels).sum().item()
        total_samples += labels.size(0)

avg_loss = total_loss / total_samples
avg_acc  = total_correct / total_samples
print(f"Test Loss={avg_loss:.4f}, Acc={avg_acc*100:.2f}%")

In [None]:
# Basic metrics
# Class 1 (abnormal) is the positive class for precision / recall / F1,
# which is scikit-learn's default pos_label.
accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)

# --- Confusion matrix ---
# Pin the label set so the matrix is always 2x2. Without it, a test run in
# which only one class occurs yields a 1x1 matrix and the 4-way unpack fails.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
tn, fp, fn, tp = cm.ravel()

# Specificity = TN / (TN + FP)
specificity = tn / (tn + fp) if (tn + fp) != 0 else 0.0

print(f"\nEvaluation Metrics:")
print(f"Accuracy : {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Sensitivity : {recall:.4f}")
print(f"Specificity : {specificity:.4f}")
print(f"F1 Score : {f1:.4f}")

# Tick labels in class-index order (0, 1).
# NOTE(review): get_data names class 1 'abnormal' (it also covers
# "variation in normal"); "lesion" is kept here as the original display label.
# NOTE: this rebinds `labels`, which held the last test batch's label tensor.
labels = ["normal", "lesion"]
plt.figure(figsize=(5, 4))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix")
plt.tight_layout()
plt.show()