In [None]:
import os
import random
from pathlib import Path

import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import timm
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from albumentations.pytorch import ToTensorV2
from PIL import Image
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix
from torchvision.models import efficientnet_b5, EfficientNet_B5_Weights

# Ensure reproducibility
d = 19
torch.manual_seed(d)
np.random.seed(d)


In [None]:
df = pd.DataFrame(columns=["tta_augmentation", "model_name", "auc", "auc_12", "auc_34", "auc_56", 
                           "accuracy", "accuracy_12", "accuracy_34", "accuracy_56"])


In [None]:
################### Model Definition ###################
class SkinCancerModel(nn.Module):
    def __init__(self, num_classes=1):
        super(SkinCancerModel, self).__init__()
        self.backbone = timm.create_model("swin_tiny_patch4_window7_224", pretrained=True, num_classes=0)
        feature_dim = self.backbone.num_features

        # Freeze backbone except for the last layer
        for param in self.backbone.parameters():
            param.requires_grad = False
        for param in self.backbone.layers[-1].parameters():
            param.requires_grad = True

        self.image_fc = nn.Linear(feature_dim, 512)
        self.metadata_fc = nn.Linear(3, 512)

        self.classifier = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)  
        )

    def forward(self, image, metadata):
        image_features = self.image_fc(self.backbone(image))
        metadata_features = self.metadata_fc(metadata)
        fused_features = torch.cat((image_features, metadata_features), dim=1)
        output = self.classifier(fused_features)
        return output.squeeze(1) 

################### Load Model ###################
def load_model(model: torch.nn.Module, model_path: str, device: torch.device):
    model_path = Path(model_path)
    if not model_path.exists():
        raise FileNotFoundError(f"Model file not found: {model_path}")
    print(f"Loading model from: {model_path}")
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model


In [None]:
################### Transformations ###################
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

tta_transforms = [
    A.HorizontalFlip(p=1),
    A.RandomBrightnessContrast(p=1, brightness_limit=0.5, contrast_limit=0.4),
    A.ElasticTransform(p=1),
    A.RandomScale(p=1, scale_limit=0.2),
    A.Rotate(p=1, limit=45)
]

def apply_tta(image, transform_fn):
    return transform_fn(image=image)['image']

################### Prediction with TTA ###################
def tta_predict(model, image_path, augmentations, skin_features, device):
    original_image = Image.open(image_path).convert("RGB")
    tta_preds = []
    model.eval()
    with torch.no_grad():
        transformed_image = transform(original_image).unsqueeze(0).to(device)
        original_pred = model(transformed_image, skin_features.to(device))
        tta_preds.append(torch.sigmoid(original_pred))

        for aug in augmentations:
            aug_image_np = apply_tta(np.array(original_image), aug)
            aug_image = Image.fromarray(aug_image_np)
            transformed_aug_image = transform(aug_image).unsqueeze(0).to(device)
            aug_pred = model(transformed_aug_image, skin_features.to(device))
            tta_preds.append(torch.sigmoid(aug_pred))

    return torch.mean(torch.stack(tta_preds), dim=0)

################### Model Evaluation Function ###################
def evaluate_model(model, test_metadata, device, skin_tone_column=None):
    if skin_tone_column:
        skin_tone_df = test_metadata[test_metadata[skin_tone_column] == 1].reset_index(drop=True)
    else:
        skin_tone_df = test_metadata.reset_index(drop=True)

    y_true, y_pred_classes, y_pred_probs = [], [], []

    for _, row in skin_tone_df.iterrows():
        image_path = row["DDI_path"]
        true_label = row["malignant"]
        skin_features = torch.tensor([
            row['skin_tone_12'],
            row['skin_tone_34'],
            row['skin_tone_56']
        ], dtype=torch.float32).unsqueeze(0).to(device)

        final_pred = tta_predict(model, image_path, tta_transforms, skin_features, device)
        predicted_prob = final_pred.item()
        predicted_class = 1 if predicted_prob >= 0.5 else 0

        y_true.append(true_label)
        y_pred_probs.append(predicted_prob)
        y_pred_classes.append(predicted_class)

    # Compute metrics
    accuracy = accuracy_score(y_true, y_pred_classes)
    auc = roc_auc_score(y_true, y_pred_probs) if len(set(y_true)) > 1 else None

    return accuracy, auc

################### Load Test Data and Model ###################
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SkinCancerModel().to(DEVICE)
model = load_model(model, model_path="models/improved_skin_cancer_v4.pth", device=DEVICE)
test_metadata = pd.read_csv("data/test_metadata.csv")

################### Global Evaluation ###################
global_acc, global_auc = evaluate_model(model, test_metadata, DEVICE)
print("\nGlobal Evaluation:")
print(f"Accuracy: {global_acc:.4f}")
print(f"AUC: {global_auc:.4f}" if global_auc is not None else "AUC: Not computed")

################### Per Skin Tone Evaluation ###################
skin_tone_columns = ["skin_tone_12", "skin_tone_34", "skin_tone_56"]
accuracies, aucs = {}, {}

for skin_tone_col in skin_tone_columns:
    acc, auc = evaluate_model(model, test_metadata, DEVICE, skin_tone_column=skin_tone_col)
    accuracies[skin_tone_col] = acc
    aucs[skin_tone_col] = auc

    print(f"\nEvaluation for {skin_tone_col}:")
    print(f"Accuracy: {acc:.4f}")
    print(f"AUC: {auc:.4f}" if auc is not None else "AUC: Not computed")

# Store results in DataFrame
df.loc[len(df)] = {
    "tta_augmentation": 1,
    "model_name": "Swin Transformer",
    "auc": global_auc,
    "auc_12": aucs["skin_tone_12"],
    "auc_34": aucs["skin_tone_34"],
    "auc_56": aucs["skin_tone_56"],
    "accuracy": global_acc,
    "accuracy_12": accuracies["skin_tone_12"],
    "accuracy_34": accuracies["skin_tone_34"],
    "accuracy_56": accuracies["skin_tone_56"]
}



In [None]:
# Bare last expression: lets the notebook render the results table.
df


Unnamed: 0,tta_augmentation,model_name,auc,auc_12,auc_34,auc_56,accuracy,accuracy_12,accuracy_34,accuracy_56
0,1,Swin Transformer,0.917333,0.882353,0.854167,0.985294,0.830769,0.909091,0.772727,0.857143
