In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.models import efficientnet_b5, EfficientNet_B5_Weights
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import numpy as np
import os
from PIL import Image
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2

import random
# Single seed shared by torch, NumPy and the stdlib RNG so repeated runs of the
# notebook draw the same random numbers.
d = 55
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(d)


In [None]:
# Results table: one row per evaluated configuration, holding the overall
# AUC / accuracy plus one AUC / accuracy column per skin-tone group.
df = pd.DataFrame(
    columns=[
        "tta_augmentation",
        "model_name",
        "auc",
        "auc_12",
        "auc_34",
        "auc_56",
        "accuracy",
        "accuracy_12",
        "accuracy_34",
        "accuracy_56",
    ]
)

################# Data setup ###########################
class SkinDataset(Dataset):
    """Dataset yielding ``(image, metadata, label)`` triples listed in a CSV file.

    Every CSV row is expected to provide ``DDI_path``, the three
    ``skin_tone_*`` columns, the two ``Disease_Group_*`` columns and
    ``malignant``.

    Args:
        csv_file: Path (or buffer) handed to ``pandas.read_csv``.
        transform: Optional callable invoked as ``transform(image=ndarray)``
            that returns a mapping with an ``"image"`` entry.
    """

    def __init__(self, csv_file, transform=None):
        self.data = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load the sample at positional index ``idx``.

        Returns:
            tuple: the (optionally transformed) RGB image, a float tensor with
            the five metadata values, and a float scalar tensor with the
            ``malignant`` label.
        """
        record = self.data.iloc[idx]

        # Replace backslashes in the stored path with forward slashes before opening.
        path = record["DDI_path"].replace("\\", "/")
        image = np.array(Image.open(path).convert("RGB"))
        if self.transform:
            image = self.transform(image=image)["image"]

        # Metadata vector: skin-tone flags first, then disease-group flags.
        metadata_columns = (
            'skin_tone_12',
            'skin_tone_34',
            'skin_tone_56',
            'Disease_Group_Non_melanoma',
            'Disease_Group_melanoma',
        )
        metadata = torch.tensor(
            [record[name] for name in metadata_columns], dtype=torch.float
        )

        label = torch.tensor(record['malignant'], dtype=torch.float)
        return image, metadata, label


In [None]:
################### Feature-wise Linear Modulation (FiLM) ###################
class FiLM(nn.Module):
    """Feature-wise Linear Modulation: scale and shift features using metadata.

    A small MLP maps the metadata vector to ``2 * feature_dim`` values that are
    split into a per-feature scale (``gamma``) and shift (``beta``).

    Args:
        metadata_dim: Size of the metadata vector for each sample.
        feature_dim: Number of features (channels) being modulated.
    """

    def __init__(self, metadata_dim, feature_dim):
        super(FiLM, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(metadata_dim, 32),
            nn.ReLU(),
            nn.Linear(32, feature_dim * 2)
        )

    def forward(self, features, metadata):
        """Return ``features * gamma + beta``.

        Args:
            features: Tensor of shape ``(batch, feature_dim)`` or
                ``(batch, feature_dim, *spatial)``.
            metadata: Tensor of shape ``(batch, metadata_dim)``.

        Returns:
            Tensor with the same shape as ``features``.
        """
        gamma_beta = self.fc(metadata)
        gamma, beta = gamma_beta.chunk(2, dim=1)

        # gamma/beta are (batch, feature_dim). Append singleton dims so they
        # broadcast over any trailing spatial dims of `features`; for 2-D
        # features the loop does not run and behaviour is unchanged.
        for _ in range(features.dim() - gamma.dim()):
            gamma = gamma.unsqueeze(-1)
            beta = beta.unsqueeze(-1)

        return features * gamma + beta

############# Build our model ##################
class SkinCancerModel(nn.Module):
    """EfficientNet-B5 backbone whose output is FiLM-modulated by metadata.

    The backbone's own classifier is replaced by ``nn.Identity`` and its
    ``features`` parameters are frozen; only the FiLM block and the
    classification head keep ``requires_grad=True``.

    Args:
        num_classes: Output width of the final linear layer.
        metadata_dim: Length of the metadata vector passed to ``forward``
            (defaults to 5, the value that used to be hard-coded).
        pretrained: When True (default) the backbone is created with
            ``EfficientNet_B5_Weights.DEFAULT``. Pass False to build it
            without those weights, e.g. when a checkpoint is loaded into the
            model immediately afterwards.
    """

    def __init__(self, num_classes=1, metadata_dim=5, pretrained=True):
        super(SkinCancerModel, self).__init__()
        backbone_weights = EfficientNet_B5_Weights.DEFAULT if pretrained else None
        self.backbone = efficientnet_b5(weights=backbone_weights)

        # Freeze the convolutional feature extractor.
        for param in self.backbone.features.parameters():
            param.requires_grad = False

        # Remember the width the stock classifier consumed, then drop that
        # classifier so the backbone returns the features themselves.
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Identity()

        # Submodule names and the layout of `classifier` are left untouched so
        # state_dict checkpoints saved from the previous definition still load.
        self.film = FiLM(metadata_dim=metadata_dim, feature_dim=in_features)
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, image, metadata):
        """Compute logits for a batch of images and their metadata vectors.

        Args:
            image: Image batch accepted by the EfficientNet backbone.
            metadata: Tensor of shape ``(batch, metadata_dim)``.

        Returns:
            The classifier output with dimension 1 squeezed; ``squeeze(1)``
            only removes that dimension when it has size 1.
        """
        image_features = self.backbone(image)
        modulated_features = self.film(image_features, metadata)
        output = self.classifier(modulated_features)
        return output.squeeze(1)


################# Load Model Function ###################
def load_model(model: torch.nn.Module, model_path: str, device: torch.device):
    """Restore weights from ``model_path`` into ``model`` and prepare it for inference.

    Args:
        model: Module whose parameters match the saved state dict.
        model_path: File containing the state dict.
        device: Device used both as ``map_location`` and as the model's target.

    Returns:
        The same ``model`` instance, moved to ``device`` and set to eval mode.
    """
    # NOTE(review): torch.load unpickles the file; only use checkpoints from a
    # trusted source.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    # nn.Module.to() and .eval() both return the module itself.
    return model.to(device).eval()


In [None]:
# Normalisation statistics used by the base preprocessing pipeline.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Base preprocessing applied to every view before it is fed to the model:
# resize to 224x224, normalise, convert to a tensor.
transform = A.Compose(
    [
        A.Resize(224, 224),
        A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
        ToTensorV2(),
    ]
)

# Test-time augmentations; each one is always applied (p=1) and produces one
# additional view of the input image.
tta_transforms = [
    A.HorizontalFlip(p=1),
    A.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.4, p=1),
    A.ElasticTransform(p=1),
    A.RandomScale(scale_limit=0.2, p=1),
    A.Rotate(limit=45, p=1),
]

def apply_tta(image, transform_fn):
    """Apply ``transform_fn`` to ``image`` and return the transformed image.

    ``transform_fn`` is called with ``image`` as a keyword argument and must
    return a mapping; its ``'image'`` entry is returned.
    """
    result = transform_fn(image=image)
    return result['image']

################# Prediction Function with TTA ###################
def tta_predict(model, image_path, augmentations, metadata, device):
    """Average sigmoid predictions over the plain image and its augmented views.

    Args:
        model: Callable as ``model(image_batch, metadata)`` returning logits.
        image_path: Path of the image file to score.
        augmentations: Iterable of transforms accepted by ``apply_tta``; each
            contributes one extra view.
        metadata: Metadata tensor forwarded to the model for every view.
        device: Device on which inference runs.

    Returns:
        Mean, over all views, of ``sigmoid(model(...))``.
    """
    pil_image = Image.open(image_path).convert("RGB")

    def _score(view_np):
        # Base preprocessing -> single-item batch -> probability.
        batch = transform(image=view_np)["image"].unsqueeze(0).to(device)
        return torch.sigmoid(model(batch, metadata.to(device)))

    model.eval()
    with torch.no_grad():
        # The un-augmented image is always the first view.
        view_probs = [_score(np.array(pil_image))]
        for augmentation in augmentations:
            augmented_np = apply_tta(np.array(pil_image), augmentation)
            view_probs.append(_score(augmented_np))

    return torch.stack(view_probs).mean(dim=0)

################# Evaluation Function ###################
def evaluate_model(model, test_metadata, device, skin_tone_column=None):
    """Score ``model`` with test-time augmentation on all rows or one skin-tone group.

    Args:
        model: Model forwarded to ``tta_predict``.
        test_metadata: DataFrame with ``DDI_path``, ``malignant``, the three
            ``skin_tone_*`` columns and the two ``Disease_Group_*`` columns.
        device: Device on which inference runs.
        skin_tone_column: If given, only rows where this column equals 1 are
            evaluated; otherwise every row is used.

    Returns:
        tuple: ``(accuracy, auc)``. ``auc`` is ``None`` when the evaluated rows
        contain a single class. When no rows are selected, ``accuracy`` is NaN
        and ``auc`` is ``None``.
    """
    if skin_tone_column:
        selected = test_metadata[test_metadata[skin_tone_column] == 1]
    else:
        selected = test_metadata
    skin_tone_df = selected.reset_index(drop=True)

    # An empty group has no defined metrics; report that explicitly rather than
    # passing empty lists to the metric functions.
    if len(skin_tone_df) == 0:
        return float("nan"), None

    metadata_columns = [
        'skin_tone_12',
        'skin_tone_34',
        'skin_tone_56',
        'Disease_Group_Non_melanoma',
        'Disease_Group_melanoma',
    ]

    y_true, y_pred_classes, y_pred_probs = [], [], []

    for _, row in skin_tone_df.iterrows():
        # Normalise backslashes the same way SkinDataset.__getitem__ does, so
        # paths stored with Windows separators also resolve here.
        image_path = row["DDI_path"].replace("\\", "/")
        true_label = row["malignant"]
        metadata = torch.tensor(
            [row[column] for column in metadata_columns], dtype=torch.float32
        ).unsqueeze(0).to(device)

        final_pred = tta_predict(model, image_path, tta_transforms, metadata, device)
        predicted_prob = final_pred.item()
        # Threshold the averaged probability at 0.5.
        predicted_class = 1 if predicted_prob >= 0.5 else 0

        y_true.append(true_label)
        y_pred_probs.append(predicted_prob)
        y_pred_classes.append(predicted_class)

    # Compute metrics; AUC is skipped unless both classes are present.
    accuracy = accuracy_score(y_true, y_pred_classes)
    auc = roc_auc_score(y_true, y_pred_probs) if len(set(y_true)) > 1 else None

    return accuracy, auc


In [None]:
################## Load Test Data ##################
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): the checkpoint path is a plain string, not an f-string, so
# "{num_epochs}" is part of the file name exactly as written — confirm this
# matches the checkpoint on disk.
model = load_model(
    SkinCancerModel().to(DEVICE),
    "models/skin_cancer_model_melanoma_epochs={num_epochs}.pth",
    DEVICE,
)
test_metadata = pd.read_csv("data/test_metadata.csv")

################## Global Evaluation ##################
# Metrics over every row of the test set.
global_acc, global_auc = evaluate_model(model, test_metadata, DEVICE)
print("\nGlobal Evaluation:")
print(f"Accuracy: {global_acc:.4f}")
if global_auc is None:
    print("AUC: Not computed")
else:
    print(f"AUC: {global_auc:.4f}")

################## Per Skin Tone Evaluation ##################
# Metrics restricted to the rows flagged for each skin-tone group.
skin_tone_columns = ["skin_tone_12", "skin_tone_34", "skin_tone_56"]
accuracies, aucs = {}, {}

for skin_tone_col in skin_tone_columns:
    acc, auc = evaluate_model(model, test_metadata, DEVICE, skin_tone_column=skin_tone_col)
    accuracies[skin_tone_col], aucs[skin_tone_col] = acc, auc

    print(f"\nEvaluation for {skin_tone_col}:")
    print(f"Accuracy: {acc:.4f}")
    if auc is None:
        print("AUC: Not computed")
    else:
        print(f"AUC: {auc:.4f}")

# Store results in DataFrame (appends one row each time this cell runs).
results_row = {
    "tta_augmentation": 1,
    "model_name": "EfficientNet_B5 + FiLM",
    "auc": global_auc,
    "auc_12": aucs["skin_tone_12"],
    "auc_34": aucs["skin_tone_34"],
    "auc_56": aucs["skin_tone_56"],
    "accuracy": global_acc,
    "accuracy_12": accuracies["skin_tone_12"],
    "accuracy_34": accuracies["skin_tone_34"],
    "accuracy_56": accuracies["skin_tone_56"],
}
df.loc[len(df)] = results_row



In [None]:
# Display the results table accumulated by the evaluation cell.
df


Unnamed: 0,tta_augmentation,model_name,auc,auc_12,auc_34,auc_56,accuracy,accuracy_12,accuracy_34,accuracy_56
0,1,EfficientNet_B5 + FiLM,0.925333,0.811765,0.895833,0.955882,0.815385,0.772727,0.727273,0.904762
