# Model Experiments

Now that the dev set and model are working (see `dev.ipynb`), we test how modifications affect accuracy on a larger portion of the ISIC data. We still do not use the full set, so that training finishes in a reasonable amount of time.

The modifications to `BaselineModel` (EfficientNet-B3 backbone, defined in `baseline_model.py`) that we tested:
1. (ISIC) Metadata fusion: `FusionModel` (defined in `fusion_model.py`)
2. Classifier head changes to yield `MLPFusionModel`: replace final `Linear(in_features, num_classes)` with a small MLP: `Dropout → Linear → ReLU → Dropout → Linear` (class defined in `mlphead_model.py`)
3. MLP Model without metadata fusion: `BaselineMLPHead` (defined in `baseline_w_mlp.py`)
4. Vision transformer (Swin): `SwinClassifier` (defined in `swin.py`)
5. Ensembling pairs of the above models
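
The notebook does not spell out how pairs of models are ensembled. One common option, shown here purely as an assumption, is to average the per-class probabilities of the two models and take the argmax. The helpers `softmax` and `ensemble_pair` are hypothetical and work on plain lists of logits rather than on the project's model classes.

```python
import math

def softmax(logits):
    """Convert a list of raw scores into probabilities that sum to 1."""
    top = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def ensemble_pair(logits_a, logits_b, weight_a=0.5):
    """Weighted average of two models' class probabilities (assumed scheme)."""
    probs_a = softmax(logits_a)
    probs_b = softmax(logits_b)
    return [weight_a * pa + (1.0 - weight_a) * pb for pa, pb in zip(probs_a, probs_b)]

# Hypothetical logits for one image over 3 classes from two different models.
avg_probs = ensemble_pair([2.0, 1.0, 0.1], [0.5, 2.5, 0.2])
predicted_class = max(range(len(avg_probs)), key=avg_probs.__getitem__)
print(predicted_class, [round(p, 3) for p in avg_probs])
```

Averaging probabilities rather than raw logits keeps the two models on a comparable scale even when their logit magnitudes differ.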

A common subset size to test is 5-20% of data but this depends on budget and overfitting risk:
- 1%: sanity checks, dev loop testing
- 5%: quick turnarounds, still meaningful eval
- 20%: mid-size experiments, tune/fine-tune
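
For scale, the ISIC 2019 training set contains 25,331 images, so the `N_SAMPLES = 1000` used later in this notebook is roughly 4% of the data. The sketch below converts the fractions above into image counts; `subset_size` is a hypothetical helper, not part of the project code.

```python
ISIC_2019_TRAIN_IMAGES = 25_331  # size of the ISIC 2019 training set

def subset_size(fraction, total=ISIC_2019_TRAIN_IMAGES):
    """Number of images in a subset covering `fraction` of the data."""
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    return round(fraction * total)

for fraction in (0.01, 0.05, 0.20):
    print(f"{fraction:.0%}: {subset_size(fraction)} images")

# Fraction of the data covered by a 1000-image sample.
sample_fraction = 1000 / ISIC_2019_TRAIN_IMAGES
print(f"1000 images = {sample_fraction:.1%} of the training set")
```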



In [None]:
from google.colab import drive
drive.mount('/content/drive')

import sys
project_root = "/content/drive/My Drive/midas"
if project_root not in sys.path:
    sys.path.append(project_root)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# load packages
import os
import sys
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, recall_score
from sklearn.utils.class_weight import compute_class_weight
import numpy as np
from tqdm import tqdm
from models.fusion_model import FusionModel
import traceback
import matplotlib.pyplot as plt
from pathlib import Path

# load my scripts
project_root = "/content/drive/My Drive/midas"
sys.path.append(project_root)
from utils.dataloader import ISICDataset
from utils.preprocess import get_effnet_image_transform, get_swin_image_transform, load_and_split
from src.trainer import train_model

In [None]:
# required global params
N_SAMPLES = 1000
# USE_METADATA = True

### Load & Preprocess Data

ISIC images and metadata

TODO: make sure `use_metadata` toggle works.

Methodological Notes:
1. Some lesions have a lesion ID, which indicates that the dataset contains other images of the same lesion. Lesion-level leakage is a known concern in ISIC evaluations: (1) if one image of a lesion is in train and another is in validation, validation accuracy becomes inflated; (2) if you train on multiple images of the same lesion, the model risks memorizing that lesion's specific pattern.
2. To mitigate lesion-level leakage, we split the data by lesion, not by image: we group the data by `lesion_id` and assign entire lesion groups to either train or val. Each sample without a `lesion_id` is treated as its own group.
3. We drop `lesion_id` during preprocessing because it is not used in training (the train-val split has already been made at the lesion level). We do not yet drop `image` or `label`: `image` is needed to find the image file, and `label` is needed by the dataset to return the target.
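
The lesion-level split described in these notes can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the actual `load_and_split` implementation: the record layout (dicts with `image` and `lesion_id` keys), the name `split_by_lesion`, and the 80/20 ratio applied to groups are all hypothetical.

```python
import random
from collections import defaultdict

def split_by_lesion(records, val_fraction=0.2, seed=0):
    """Assign whole lesion groups to train or val so no lesion spans both."""
    groups = defaultdict(list)
    for rec in records:
        if rec["lesion_id"] is not None:
            key = ("lesion", rec["lesion_id"])
        else:
            # Images without a lesion_id form their own single-image group.
            key = ("image", rec["image"])
        groups[key].append(rec)

    keys = sorted(groups)  # fixed order before shuffling, for reproducibility
    random.Random(seed).shuffle(keys)
    n_val = max(1, round(val_fraction * len(keys)))
    val_keys, train_keys = keys[:n_val], keys[n_val:]

    train = [rec for key in train_keys for rec in groups[key]]
    val = [rec for key in val_keys for rec in groups[key]]
    return train, val

# Toy data: lesion "A" has two images, and one image has no lesion_id.
records = [
    {"image": "img1", "lesion_id": "A"},
    {"image": "img2", "lesion_id": "A"},
    {"image": "img3", "lesion_id": "B"},
    {"image": "img4", "lesion_id": None},
    {"image": "img5", "lesion_id": "C"},
]
train, val = split_by_lesion(records)
print(len(train), "train images,", len(val), "val images")
```

In a pandas/scikit-learn workflow, `sklearn.model_selection.GroupShuffleSplit` offers the same guarantee that a group never appears on both sides of the split.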


In [None]:
# Paths
base_dir = "/content/drive/My Drive/midas/data"
labels_path = os.path.join(base_dir, "ISIC_2019_Training_GroundTruth.csv")
meta_path = os.path.join(base_dir, "ISIC_2019_Training_Metadata.csv")
img_dir = os.path.join(base_dir, f"sample_{N_SAMPLES}")


# Load merged metadata + labels
train_df, val_df, index_to_label = load_and_split(labels_path, meta_path, n_samples=N_SAMPLES)
print('class distribution for train: ', train_df['label'].value_counts(normalize=True))

# Dataset + loaders
effnet_transform = get_effnet_image_transform()
# ---- baseline model -----
train_dataset_baseline = ISICDataset(img_dir, train_df, effnet_transform, index_to_label, use_metadata=False)
val_dataset_baseline = ISICDataset(img_dir, val_df, effnet_transform, index_to_label, use_metadata=False)
train_loader_baseline = DataLoader(train_dataset_baseline, batch_size=16, shuffle=True)
val_loader_baseline = DataLoader(val_dataset_baseline, batch_size=16, shuffle=False)

# ---- metadata fusion -----
train_dataset_fusion = ISICDataset(img_dir, train_df, effnet_transform, index_to_label, use_metadata=True)
val_dataset_fusion = ISICDataset(img_dir, val_df, effnet_transform, index_to_label, use_metadata=True)
train_loader_fusion = DataLoader(train_dataset_fusion, batch_size=16, shuffle=True)
val_loader_fusion = DataLoader(val_dataset_fusion, batch_size=16, shuffle=False)

# ----- swin -----
swin_transform = get_swin_image_transform()
train_swin_dataset = ISICDataset(img_dir, train_df, swin_transform, index_to_label, use_metadata=False)
val_swin_dataset = ISICDataset(img_dir, val_df, swin_transform, index_to_label, use_metadata=False)
train_loader_swin = DataLoader(train_swin_dataset, batch_size=16, shuffle=True)
val_loader_swin = DataLoader(val_swin_dataset, batch_size=16, shuffle=False)


Train split: 800 images, 764 unique lesions
Val split:   200 images, 192 unique lesions
class distribution for train:  label
1    0.48250
0    0.16750
2    0.15875
4    0.11375
3    0.04250
7    0.02000
6    0.00875
5    0.00625
Name: proportion, dtype: float64


### NOTE

Class balancing is not worth attempting right now, because the goal is to compare the fusion model against the baseline. Instead, we use a class-weighted loss (`class_weight='balanced'`) to emphasize rare classes without changing the image distribution. We decided not to rebalance the data here because the natural distribution is more representative of real-world imbalance, and we want to see whether fusion helps under real-world conditions.
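
For reference, scikit-learn's `compute_class_weight` (imported at the top of this notebook) in `'balanced'` mode weights each class by `n_samples / (n_classes * count)`. The sketch below reproduces that formula in plain Python, using the counts implied by the train proportions printed above (800 images). How `train_model` actually consumes the weights is not shown in this notebook, so `balanced_class_weights` is only an illustrative helper.

```python
def balanced_class_weights(counts):
    """Per-class weight n_samples / (n_classes * count), as in 'balanced' mode."""
    n_samples = sum(counts.values())
    n_classes = len(counts)
    return {label: n_samples / (n_classes * count) for label, count in counts.items()}

# Train-split counts implied by the printed proportions (800 images in total).
train_counts = {0: 134, 1: 386, 2: 127, 3: 34, 4: 91, 5: 5, 6: 7, 7: 16}
weights = balanced_class_weights(train_counts)

for label, weight in sorted(weights.items()):
    print(f"class {label}: weight {weight:.2f}")
```

With these counts the rarest class (label 5, five images) is weighted 20.0, while the most common class (label 1) is weighted about 0.26, so errors on rare classes contribute far more to the loss.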

In [None]:
# Confirm all 8 classes are present in train and val
train_classes = set(train_df['label'].unique())
val_classes = set(val_df['label'].unique())
missing_train = set(range(8)) - train_classes
missing_val = set(range(8)) - val_classes

print('Train class distribution:\n', train_df['label'].value_counts(normalize=True).sort_index())
print('Val class distribution:\n', val_df['label'].value_counts(normalize=True).sort_index())

if missing_train:
    print(f"⚠️ Train set is missing classes: {sorted(missing_train)}")
else:
    print("✅ Train set has all 8 classes.")

if missing_val:
    print(f"⚠️ Val set is missing classes: {sorted(missing_val)}")
else:
    print("✅ Val set has all 8 classes.")

Train class distribution:
 label
0    0.16750
1    0.48250
2    0.15875
3    0.04250
4    0.11375
5    0.00625
6    0.00875
7    0.02000
Name: proportion, dtype: float64
Val class distribution:
 label
0    0.195
1    0.495
2    0.115
3    0.050
4    0.095
5    0.010
6    0.020
7    0.020
Name: proportion, dtype: float64
✅ Train set has all 8 classes.
✅ Val set has all 8 classes.


#### Create random subset for dev

In [None]:
# import shutil

# image_ids = set(train_df['image']).union(val_df['image'])
# src_dir = "/content/drive/My Drive/midas/data/ISIC_2019_Training_Input"  # full ISIC image dir
# dst_dir = f"/content/drive/My Drive/midas/data/sample_{N_SAMPLES}"

# os.makedirs(dst_dir, exist_ok=True)

# for image_id in image_ids:
#     src_path = os.path.join(src_dir, f"{image_id}.jpg")
#     dst_path = os.path.join(dst_dir, f"{image_id}.jpg")
#     shutil.copyfile(src_path, dst_path)

## Metadata Fusion
Best at Epoch 4/5 (lowest validation loss): Train Loss = 1.1936, Val Loss = 1.6319, Val Acc = 0.6450

See logs for more details.
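
The summary line corresponds to the epoch with the lowest validation loss. The snippet below recomputes that from the per-epoch values logged for this run and shows that the highest validation accuracy falls on a different epoch; the variable names are illustrative only.

```python
# Per-epoch validation metrics copied from the training log of this run.
val_loss = [2.0117, 1.8879, 1.7103, 1.6319, 1.7892]
val_acc = [0.5450, 0.6150, 0.6350, 0.6450, 0.6650]

# Epochs are reported 1-indexed in the log.
best_by_loss = min(range(len(val_loss)), key=val_loss.__getitem__) + 1
best_by_acc = max(range(len(val_acc)), key=val_acc.__getitem__) + 1

print(f"lowest val loss: epoch {best_by_loss} ({val_loss[best_by_loss - 1]:.4f})")
print(f"highest val acc: epoch {best_by_acc} ({val_acc[best_by_acc - 1]:.4f})")
```

Because the two criteria disagree here, it helps to state which one is used whenever a "best" epoch is reported.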

In [None]:
from models.fusion_model import FusionModel

results_root = Path("/content/drive/MyDrive/midas/results")

# Get metadata dim from a sample
_, meta_sample, _ = next(iter(train_loader_fusion))
meta_dim = meta_sample.shape[1]

# Initialize model
model = FusionModel(num_metadata_features=meta_dim, num_classes=8)

# Define forward function for FusionModel
def fusion_forward(model, batch):
    imgs, metas, labels = batch
    return model(imgs, metas)

# Setup optimizer and scheduler
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', patience=2, factor=0.5
)

# Train the model (fusion loaders yield image, metadata, label)
train_model(
    model, train_loader_fusion, val_loader_fusion, train_df, index_to_label,
    forward_fn=fusion_forward,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs = 5,
    experiment_name="fusion_lr-1e-4_lr-scheduler",
    results_root=results_root
)


Downloading: "https://download.pytorch.org/models/efficientnet_b3_rwightman-b3899882.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b3_rwightman-b3899882.pth
100%|██████████| 47.2M/47.2M [00:00<00:00, 229MB/s]


CUDA available: True
Current device: 0


Train: 100%|██████████| 50/50 [10:01<00:00, 12.03s/it]
Val: 100%|██████████| 13/13 [02:27<00:00, 11.38s/it]



Epoch 1: Train Loss = 2.0352
Epoch 1: Val Loss = 2.0117, Val Acc = 0.5450
Classification Report:
              precision    recall  f1-score   support

         MEL       0.60      0.23      0.33        39
          NV       0.63      0.90      0.74        99
         BCC       0.33      0.30      0.32        23
          AK       0.20      0.30      0.24        10
         BKL       0.20      0.05      0.08        19
          DF       0.00      0.00      0.00         2
        VASC       0.00      0.00      0.00         4
         SCC       0.00      0.00      0.00         4

    accuracy                           0.55       200
   macro avg       0.25      0.22      0.21       200
weighted avg       0.49      0.55      0.49       200

Confusion Matrix:
[[ 9 23  2  3  1  0  1  0]
 [ 3 89  3  3  1  0  0  0]
 [ 0 11  7  4  1  0  0  0]
 [ 2  1  3  3  1  0  0  0]
 [ 1 11  4  1  1  1  0  0]
 [ 0  2  0  0  0  0  0  0]
 [ 0  4  0  0  0  0  0  0]
 [ 0  1  2  1  0  0  0  0]]


Train: 100%|██████████| 50/50 [00:28<00:00,  1.73it/s]
Val: 100%|██████████| 13/13 [00:04<00:00,  3.05it/s]



Epoch 2: Train Loss = 1.8074
Epoch 2: Val Loss = 1.8879, Val Acc = 0.6150
Classification Report:
              precision    recall  f1-score   support

         MEL       0.57      0.41      0.48        39
          NV       0.73      0.85      0.79        99
         BCC       0.43      0.78      0.55        23
          AK       0.50      0.40      0.44        10
         BKL       0.20      0.05      0.08        19
          DF       0.00      0.00      0.00         2
        VASC       0.00      0.00      0.00         4
         SCC       0.00      0.00      0.00         4

    accuracy                           0.61       200
   macro avg       0.30      0.31      0.29       200
weighted avg       0.57      0.61      0.58       200

Confusion Matrix:
[[16 16  4  2  1  0  0  0]
 [ 5 84  8  0  2  0  0  0]
 [ 1  3 18  1  0  0  0  0]
 [ 0  0  4  4  1  0  0  1]
 [ 5  7  4  1  1  0  1  0]
 [ 1  1  0  0  0  0  0  0]
 [ 0  3  1  0  0  0  0  0]
 [ 0  1  3  0  0  0  0  0]]


Train: 100%|██████████| 50/50 [00:30<00:00,  1.65it/s]
Val: 100%|██████████| 13/13 [00:04<00:00,  3.06it/s]



Epoch 3: Train Loss = 1.5365
Epoch 3: Val Loss = 1.7103, Val Acc = 0.6350
Classification Report:
              precision    recall  f1-score   support

         MEL       0.59      0.51      0.55        39
          NV       0.82      0.79      0.80        99
         BCC       0.45      0.87      0.60        23
          AK       1.00      0.20      0.33        10
         BKL       0.30      0.37      0.33        19
          DF       0.00      0.00      0.00         2
        VASC       0.00      0.00      0.00         4
         SCC       0.00      0.00      0.00         4

    accuracy                           0.64       200
   macro avg       0.40      0.34      0.33       200
weighted avg       0.65      0.64      0.62       200

Confusion Matrix:
[[20 10  3  0  6  0  0  0]
 [ 8 78  8  0  5  0  0  0]
 [ 1  0 20  0  2  0  0  0]
 [ 0  0  5  2  2  0  0  1]
 [ 4  3  4  0  7  0  1  0]
 [ 0  1  0  0  1  0  0  0]
 [ 0  3  1  0  0  0  0  0]
 [ 1  0  3  0  0  0  0  0]]


Train: 100%|██████████| 50/50 [00:29<00:00,  1.69it/s]
Val: 100%|██████████| 13/13 [00:05<00:00,  2.50it/s]



Epoch 4: Train Loss = 1.1936
Epoch 4: Val Loss = 1.6319, Val Acc = 0.6450
Classification Report:
              precision    recall  f1-score   support

         MEL       0.55      0.54      0.55        39
          NV       0.82      0.84      0.83        99
         BCC       0.64      0.70      0.67        23
          AK       0.40      0.40      0.40        10
         BKL       0.19      0.21      0.20        19
          DF       0.00      0.00      0.00         2
        VASC       0.33      0.25      0.29         4
         SCC       0.00      0.00      0.00         4

    accuracy                           0.65       200
   macro avg       0.37      0.37      0.37       200
weighted avg       0.63      0.65      0.64       200

Confusion Matrix:
[[21 11  0  1  6  0  0  0]
 [ 6 83  3  0  6  0  1  0]
 [ 1  0 16  3  3  0  0  0]
 [ 1  0  2  4  1  0  0  2]
 [ 8  4  1  1  4  0  1  0]
 [ 0  1  0  0  1  0  0  0]
 [ 0  2  1  0  0  0  1  0]
 [ 1  0  2  1  0  0  0  0]]


Train: 100%|██████████| 50/50 [00:29<00:00,  1.70it/s]
Val: 100%|██████████| 13/13 [00:04<00:00,  2.98it/s]



Epoch 5: Train Loss = 0.9107
Epoch 5: Val Loss = 1.7892, Val Acc = 0.6650
Classification Report:
              precision    recall  f1-score   support

         MEL       0.65      0.51      0.57        39
          NV       0.80      0.86      0.83        99
         BCC       0.55      0.78      0.64        23
          AK       0.21      0.30      0.25        10
         BKL       0.50      0.26      0.34        19
          DF       0.00      0.00      0.00         2
        VASC       0.50      0.50      0.50         4
         SCC       0.00      0.00      0.00         4

    accuracy                           0.67       200
   macro avg       0.40      0.40      0.39       200
weighted avg       0.65      0.67      0.65       200

Confusion Matrix:
[[20 12  1  5  1  0  0  0]
 [ 5 85  5  1  2  0  1  0]
 [ 1  1 18  2  1  0  0  0]
 [ 0  0  5  3  0  0  0  2]
 [ 4  5  2  2  5  0  1  0]
 [ 0  1  0  0  1  0  0  0]
 [ 0  2  0  0  0  0  2  0]
 [ 1  0  2  1  0  0  0  0]]

EarlyStopping: 

([2.0352233266830444,
  1.8073658347129822,
  1.5364780282974244,
  1.1935576689243317,
  0.9106791859865189],
 [2.0116628503799436,
  1.8878853845596313,
  1.710341911315918,
  1.6318917846679688,
  1.7891948318481445],
 'MyDrive/midas/results/fusion_lr-1e-4_lr-scheduler_20250604_042053/best_model.pt')

## MLPHead
MLP: Replace the final `Linear(in_features, num_classes)` with a small multi-layer perceptron (MLP) to add non-linearity and regularization: `Dropout → Linear → ReLU → Dropout → Linear`

Other specs: predict one of the eight lesion classes; still use metadata fusion (from stage 2)

Hypothesis: the MLP head will provide larger gains over the baseline than the fusion model did, because MLP heads have been shown to provide robust gains for biomedical image models (Gessert et al. 2020)

Result was best at Epoch 5/5: Train Loss = 1.6384, Val Loss = 1.7877, Val Acc = 0.5950
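
A minimal sketch of the head described above, assuming a PyTorch `nn.Sequential`. The actual `MLPFusionModel` lives in `mlphead_model.py` and is not reproduced here. The helper `make_mlp_head` and the input width of 1536 (the feature size of torchvision's EfficientNet-B3 before its classifier, ignoring any concatenated metadata features) are illustrative assumptions, while `hidden_dim=256` and the dropout rates `p1 = p2 = 0.5` mirror the arguments passed in the training cell.

```python
import torch
import torch.nn as nn

def make_mlp_head(in_features, num_classes, hidden_dim=256, p1=0.5, p2=0.5):
    """Dropout -> Linear -> ReLU -> Dropout -> Linear classifier head."""
    return nn.Sequential(
        nn.Dropout(p1),
        nn.Linear(in_features, hidden_dim),
        nn.ReLU(),
        nn.Dropout(p2),
        nn.Linear(hidden_dim, num_classes),
    )

head = make_mlp_head(in_features=1536, num_classes=8)
head.eval()  # disable dropout so the forward pass is deterministic

features = torch.randn(4, 1536)  # stand-in for a batch of backbone features
with torch.no_grad():
    logits = head(features)

print(logits.shape)  # one row of 8 class scores per image
```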

In [None]:
from models.mlphead_model import MLPFusionModel

results_root = Path("/content/drive/MyDrive/midas/results")

# Get metadata dim from a sample
_, meta_sample, _ = next(iter(train_loader_fusion))
meta_dim = meta_sample.shape[1]

# Initialize MLPFusionModel
model = MLPFusionModel(
    num_classes=8,
    metadata_dim=meta_dim,
    hidden_dim=256,
    p1=0.5,
    p2=0.5
)

# Define forward function for MLP model
def mlp_forward(model, batch):
    imgs, metas, labels = batch
    return model(imgs, metas)

# Setup optimizer and scheduler
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', patience=2, factor=0.5
)

# Train the model (fusion loaders yield image, metadata, label)
train_model(
    model, train_loader_fusion, val_loader_fusion, train_df, index_to_label,
    forward_fn=mlp_forward,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=5,
    experiment_name="mlp_fusion_lr-1e-4_hid-256",
    results_root=results_root
)

CUDA available: True
Current device: 0


Train: 100%|██████████| 50/50 [00:31<00:00,  1.59it/s]
Val: 100%|██████████| 13/13 [00:04<00:00,  3.03it/s]



Epoch 1: Train Loss = 2.0661
Epoch 1: Val Loss = 2.0635, Val Acc = 0.4900
Classification Report:
              precision    recall  f1-score   support

         MEL       0.08      0.03      0.04        39
          NV       0.56      0.95      0.70        99
         BCC       0.25      0.04      0.07        23
          AK       0.00      0.00      0.00        10
         BKL       0.12      0.11      0.11        19
          DF       0.00      0.00      0.00         2
        VASC       0.00      0.00      0.00         4
         SCC       0.00      0.00      0.00         4

    accuracy                           0.49       200
   macro avg       0.13      0.14      0.12       200
weighted avg       0.33      0.49      0.38       200

Confusion Matrix:
[[ 1 33  0  0  5  0  0  0]
 [ 3 94  1  0  1  0  0  0]
 [ 2 14  1  0  6  0  0  0]
 [ 3  5  1  0  1  0  0  0]
 [ 2 14  1  0  2  0  0  0]
 [ 0  2  0  0  0  0  0  0]
 [ 0  4  0  0  0  0  0  0]
 [ 1  2  0  0  1  0  0  0]]


Train: 100%|██████████| 50/50 [00:29<00:00,  1.68it/s]
Val: 100%|██████████| 13/13 [00:05<00:00,  2.55it/s]



Epoch 2: Train Loss = 2.0118
Epoch 2: Val Loss = 2.0268, Val Acc = 0.5150
Classification Report:
              precision    recall  f1-score   support

         MEL       0.21      0.08      0.11        39
          NV       0.59      0.92      0.72        99
         BCC       0.35      0.30      0.33        23
          AK       0.00      0.00      0.00        10
         BKL       0.18      0.11      0.13        19
          DF       0.00      0.00      0.00         2
        VASC       0.00      0.00      0.00         4
         SCC       0.00      0.00      0.00         4

    accuracy                           0.52       200
   macro avg       0.17      0.18      0.16       200
weighted avg       0.39      0.52      0.43       200

Confusion Matrix:
[[ 3 31  3  0  2  0  0  0]
 [ 3 91  3  0  2  0  0  0]
 [ 2 11  7  0  3  0  0  0]
 [ 4  2  3  0  1  0  0  0]
 [ 1 13  3  0  2  0  0  0]
 [ 0  2  0  0  0  0  0  0]
 [ 0  4  0  0  0  0  0  0]
 [ 1  1  1  0  1  0  0  0]]


Train: 100%|██████████| 50/50 [00:29<00:00,  1.70it/s]
Val: 100%|██████████| 13/13 [00:04<00:00,  3.10it/s]



Epoch 3: Train Loss = 1.8907
Epoch 3: Val Loss = 1.9889, Val Acc = 0.5600
Classification Report:
              precision    recall  f1-score   support

         MEL       0.28      0.18      0.22        39
          NV       0.67      0.93      0.78        99
         BCC       0.44      0.52      0.48        23
          AK       0.00      0.00      0.00        10
         BKL       0.10      0.05      0.07        19
          DF       0.00      0.00      0.00         2
        VASC       0.00      0.00      0.00         4
         SCC       0.00      0.00      0.00         4

    accuracy                           0.56       200
   macro avg       0.19      0.21      0.19       200
weighted avg       0.45      0.56      0.49       200

Confusion Matrix:
[[ 7 27  1  0  4  0  0  0]
 [ 3 92  3  0  1  0  0  0]
 [ 5  5 12  0  1  0  0  0]
 [ 1  1  6  0  2  0  0  0]
 [ 7  9  2  0  1  0  0  0]
 [ 1  1  0  0  0  0  0  0]
 [ 1  2  1  0  0  0  0  0]
 [ 0  1  2  0  1  0  0  0]]


Train: 100%|██████████| 50/50 [00:29<00:00,  1.69it/s]
Val: 100%|██████████| 13/13 [00:04<00:00,  2.72it/s]



Epoch 4: Train Loss = 1.7527
Epoch 4: Val Loss = 1.8900, Val Acc = 0.5750
Classification Report:
              precision    recall  f1-score   support

         MEL       0.38      0.49      0.43        39
          NV       0.79      0.82      0.81        99
         BCC       0.40      0.52      0.45        23
          AK       0.00      0.00      0.00        10
         BKL       0.17      0.16      0.16        19
          DF       0.00      0.00      0.00         2
        VASC       0.00      0.00      0.00         4
         SCC       0.00      0.00      0.00         4

    accuracy                           0.57       200
   macro avg       0.22      0.25      0.23       200
weighted avg       0.53      0.57      0.55       200

Confusion Matrix:
[[19 14  3  0  3  0  0  0]
 [12 81  4  0  2  0  0  0]
 [ 6  0 12  0  5  0  0  0]
 [ 0  0  6  0  4  0  0  0]
 [ 9  5  2  0  3  0  0  0]
 [ 1  1  0  0  0  0  0  0]
 [ 2  1  1  0  0  0  0  0]
 [ 1  0  2  0  1  0  0  0]]


Train: 100%|██████████| 50/50 [00:29<00:00,  1.72it/s]
Val: 100%|██████████| 13/13 [00:04<00:00,  2.87it/s]



Epoch 5: Train Loss = 1.6384
Epoch 5: Val Loss = 1.7877, Val Acc = 0.5950
Classification Report:
              precision    recall  f1-score   support

         MEL       0.42      0.51      0.46        39
          NV       0.79      0.83      0.81        99
         BCC       0.42      0.65      0.51        23
          AK       0.00      0.00      0.00        10
         BKL       0.17      0.11      0.13        19
          DF       0.00      0.00      0.00         2
        VASC       0.00      0.00      0.00         4
         SCC       0.00      0.00      0.00         4

    accuracy                           0.59       200
   macro avg       0.22      0.26      0.24       200
weighted avg       0.54      0.59      0.56       200

Confusion Matrix:
[[20 14  3  0  2  0  0  0]
 [10 82  5  0  2  0  0  0]
 [ 5  1 15  0  2  0  0  0]
 [ 0  0  6  0  4  0  0  0]
 [12  4  1  0  2  0  0  0]
 [ 0  1  1  0  0  0  0  0]
 [ 0  2  2  0  0  0  0  0]
 [ 1  0  3  0  0  0  0  0]]

Training comple

([2.066127824783325,
  2.0117931389808654,
  1.890748345851898,
  1.7526945281028747,
  1.6384464144706725],
 [2.0634827041625976,
  2.0267913866043092,
  1.988890323638916,
  1.8899679136276246,
  1.7876703643798828],
 'MyDrive/midas/results/mlp_fusion_lr-1e-4_hid-256_20250604_044430/best_model.pt')

## MLP Head, No Fusion

Highest validation accuracy at Epoch 4/5: Train Loss = 1.5203, Val Loss = 1.9416, Val Acc = 0.6350. The lowest validation loss was at Epoch 5 (Val Loss = 1.8536, Val Acc = 0.6300).


In [None]:
! touch '/content/drive/MyDrive/midas/models/baseline_w_mlp.py'

In [None]:
from models.baseline_w_mlp import BaselineMLPHead

results_root = Path("/content/drive/MyDrive/midas/results")

# Initialize model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BaselineMLPHead(num_classes=8).to(device)

def baseline_forward(model, batch):
    images, labels = batch
    images, labels = images.to(device), labels.long().to(device)
    return model(images)

# Setup optimizer and scheduler
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', patience=2, factor=0.5
)

# Run training
train_model(
    model, train_loader_baseline, val_loader_baseline, train_df, index_to_label,
    forward_fn=baseline_forward,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=5,
    experiment_name="baseline_mlp_head_lr-1e-4_scheduler",
    results_root=results_root
)

CUDA available: True
Current device: 0


Train: 100%|██████████| 50/50 [00:30<00:00,  1.66it/s]
Val: 100%|██████████| 13/13 [00:04<00:00,  2.66it/s]



Epoch 1: Train Loss = 2.0410
Epoch 1: Val Loss = 2.0492, Val Acc = 0.5750
Classification Report:
              precision    recall  f1-score   support

         MEL       0.64      0.23      0.34        39
          NV       0.64      0.93      0.76        99
         BCC       0.32      0.52      0.40        23
          AK       0.33      0.10      0.15        10
         BKL       0.50      0.05      0.10        19
          DF       0.00      0.00      0.00         2
        VASC       0.00      0.00      0.00         4
         SCC       0.00      0.00      0.00         4

    accuracy                           0.57       200
   macro avg       0.30      0.23      0.22       200
weighted avg       0.54      0.57      0.50       200

Confusion Matrix:
[[ 9 25  5  0  0  0  0  0]
 [ 0 92  6  0  1  0  0  0]
 [ 1  9 12  1  0  0  0  0]
 [ 0  2  7  1  0  0  0  0]
 [ 4 10  3  1  1  0  0  0]
 [ 0  2  0  0  0  0  0  0]
 [ 0  3  1  0  0  0  0  0]
 [ 0  1  3  0  0  0  0  0]]


Train: 100%|██████████| 50/50 [00:29<00:00,  1.68it/s]
Val: 100%|██████████| 13/13 [00:05<00:00,  2.58it/s]



Epoch 2: Train Loss = 1.8954
Epoch 2: Val Loss = 1.9194, Val Acc = 0.5800
Classification Report:
              precision    recall  f1-score   support

         MEL       0.53      0.26      0.34        39
          NV       0.68      0.91      0.78        99
         BCC       0.34      0.65      0.45        23
          AK       0.33      0.10      0.15        10
         BKL       0.00      0.00      0.00        19
          DF       0.00      0.00      0.00         2
        VASC       0.00      0.00      0.00         4
         SCC       0.00      0.00      0.00         4

    accuracy                           0.58       200
   macro avg       0.23      0.24      0.22       200
weighted avg       0.49      0.58      0.51       200

Confusion Matrix:
[[10 23  4  1  1  0  0  0]
 [ 1 90  8  0  0  0  0  0]
 [ 1  6 15  1  0  0  0  0]
 [ 1  1  7  1  0  0  0  0]
 [ 6  7  6  0  0  0  0  0]
 [ 0  2  0  0  0  0  0  0]
 [ 0  3  1  0  0  0  0  0]
 [ 0  1  3  0  0  0  0  0]]


Train: 100%|██████████| 50/50 [00:29<00:00,  1.67it/s]
Val: 100%|██████████| 13/13 [00:04<00:00,  3.03it/s]



Epoch 3: Train Loss = 1.6813
Epoch 3: Val Loss = 1.9145, Val Acc = 0.6100
Classification Report:
              precision    recall  f1-score   support

         MEL       0.50      0.56      0.53        39
          NV       0.79      0.83      0.81        99
         BCC       0.37      0.78      0.50        23
          AK       0.00      0.00      0.00        10
         BKL       0.00      0.00      0.00        19
          DF       0.00      0.00      0.00         2
        VASC       0.00      0.00      0.00         4
         SCC       0.00      0.00      0.00         4

    accuracy                           0.61       200
   macro avg       0.21      0.27      0.23       200
weighted avg       0.53      0.61      0.56       200

Confusion Matrix:
[[22 11  3  1  2  0  0  0]
 [ 9 82  8  0  0  0  0  0]
 [ 2  3 18  0  0  0  0  0]
 [ 0  0 10  0  0  0  0  0]
 [ 9  4  6  0  0  0  0  0]
 [ 1  1  0  0  0  0  0  0]
 [ 0  3  1  0  0  0  0  0]
 [ 1  0  3  0  0  0  0  0]]


Train: 100%|██████████| 50/50 [00:29<00:00,  1.69it/s]
Val: 100%|██████████| 13/13 [00:04<00:00,  2.65it/s]



Epoch 4: Train Loss = 1.5203
Epoch 4: Val Loss = 1.9416, Val Acc = 0.6350
Classification Report:
              precision    recall  f1-score   support

         MEL       0.56      0.59      0.57        39
          NV       0.82      0.82      0.82        99
         BCC       0.39      0.78      0.52        23
          AK       0.50      0.10      0.17        10
         BKL       0.33      0.21      0.26        19
          DF       0.00      0.00      0.00         2
        VASC       0.00      0.00      0.00         4
         SCC       0.00      0.00      0.00         4

    accuracy                           0.64       200
   macro avg       0.33      0.31      0.29       200
weighted avg       0.62      0.64      0.61       200

Confusion Matrix:
[[23  9  3  1  3  0  0  0]
 [ 9 81  6  0  3  0  0  0]
 [ 2  2 18  0  1  0  0  0]
 [ 0  0  8  1  1  0  0  0]
 [ 5  4  6  0  4  0  0  0]
 [ 0  1  1  0  0  0  0  0]
 [ 1  2  1  0  0  0  0  0]
 [ 1  0  3  0  0  0  0  0]]

EarlyStopping: 

Train: 100%|██████████| 50/50 [00:29<00:00,  1.70it/s]
Val: 100%|██████████| 13/13 [00:04<00:00,  2.90it/s]



Epoch 5: Train Loss = 1.2570
Epoch 5: Val Loss = 1.8536, Val Acc = 0.6300
Classification Report:
              precision    recall  f1-score   support

         MEL       0.61      0.56      0.59        39
          NV       0.82      0.81      0.82        99
         BCC       0.52      0.65      0.58        23
          AK       0.20      0.30      0.24        10
         BKL       0.27      0.32      0.29        19
          DF       0.00      0.00      0.00         2
        VASC       0.00      0.00      0.00         4
         SCC       0.00      0.00      0.00         4

    accuracy                           0.63       200
   macro avg       0.30      0.33      0.31       200
weighted avg       0.62      0.63      0.62       200

Confusion Matrix:
[[22 10  1  3  3  0  0  0]
 [ 8 80  4  2  4  0  1  0]
 [ 1  1 15  4  2  0  0  0]
 [ 0  0  3  3  4  0  0  0]
 [ 5  4  3  1  6  0  0  0]
 [ 0  1  1  0  0  0  0  0]
 [ 0  1  1  0  2  0  0  0]
 [ 0  0  1  2  1  0  0  0]]

Training comple

([2.0410257506370546,
  1.8953864908218383,
  1.6813093543052673,
  1.5203377771377564,
  1.2569697499275208],
 [2.0491921854019166,
  1.9193555974960328,
  1.9144776964187622,
  1.9415509986877442,
  1.853567886352539],
 'MyDrive/midas/results/baseline_mlp_head_lr-1e-4_scheduler_20250604_072658/best_model.pt')

## Swin (ViT)

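The best checkpoint was at epoch 4, which had the lowest validation loss: Train Loss = 1.2938, Val Loss = 1.2147, Val Acc = 0.5950. Epoch 5 reached a higher validation accuracy (0.6400) but also a higher validation loss (1.4529).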
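As a quick check of which epoch counts as "best", the sketch below ranks the epochs using the per-epoch validation values from the Swin training log in this section (rounded to four decimals). That `train_model` picks its checkpoint by lowest validation loss is an assumption inferred from the logs, not something confirmed from its source.

```python
# Per-epoch validation losses and accuracies copied from the Swin training log.
val_losses = [1.7667, 1.9208, 1.5921, 1.2147, 1.4529]
val_accs = [0.5250, 0.5100, 0.4300, 0.5950, 0.6400]

# Assumption: the "best" checkpoint is the one with the lowest validation loss.
best_by_loss = min(range(len(val_losses)), key=val_losses.__getitem__) + 1
# For comparison: the epoch that would win if we selected on accuracy instead.
best_by_acc = max(range(len(val_accs)), key=val_accs.__getitem__) + 1

print(f"best epoch by val loss: {best_by_loss}")
print(f"best epoch by val acc:  {best_by_acc}")
```

The two criteria disagree here (epoch 4 vs. epoch 5), so it is worth keeping in mind which one the saved `best_model.pt` reflects when comparing models.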

In [None]:
! touch '/content/drive/MyDrive/midas/models/swin.py'

Swin requires a different image transform than EfficientNet (224×224 input rather than 300×300), so we rebuild the datasets and loaders with it.

In [None]:
from utils.preprocess import get_swin_image_transform

# Paths
base_dir = "/content/drive/My Drive/midas/data"
labels_path = os.path.join(base_dir, "ISIC_2019_Training_GroundTruth.csv")
meta_path = os.path.join(base_dir, "ISIC_2019_Training_Metadata.csv")
img_dir = os.path.join(base_dir, f"sample_{N_SAMPLES}")

# Transform
swin_transform = get_swin_image_transform()

# Load merged metadata + labels
train_df, val_df, index_to_label = load_and_split(labels_path, meta_path, n_samples=N_SAMPLES)
print('class distribution for train: ', train_df['label'].value_counts(normalize=True))

# Dataset + loaders
train_swin_dataset = ISICDataset(img_dir, train_df, swin_transform, index_to_label, use_metadata=False)
val_swin_dataset = ISICDataset(img_dir, val_df, swin_transform, index_to_label, use_metadata=False)
train_swin_loader = DataLoader(train_swin_dataset, batch_size=16, shuffle=True)
val_swin_loader = DataLoader(val_swin_dataset, batch_size=16)



Train split: 800 images, 764 unique lesions
Val split:   200 images, 192 unique lesions
class distribution for train:  label
1    0.48250
0    0.16750
2    0.15875
4    0.11375
3    0.04250
7    0.02000
6    0.00875
5    0.00625
Name: proportion, dtype: float64


In [None]:
from models.swin import SwinClassifier

results_root = Path("/content/drive/MyDrive/midas/results")

# Initialize Swin
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SwinClassifier(
    num_classes=8
)

# Define forward function for Swin model
def swin_forward(model, batch):
    images, labels = batch
    return model(images)

# Setup optimizer and scheduler
optimizer = optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', patience=2, factor=0.5
)

# Train the model
train_model(
    model, train_swin_loader, val_swin_loader, train_df, index_to_label,
    forward_fn=swin_forward,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=5,
    experiment_name="swin_tiny_lr-1e-4_patch4_window7",
    results_root=results_root
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

CUDA available: True
Current device: 0


Train: 100%|██████████| 50/50 [00:25<00:00,  1.98it/s]
Val: 100%|██████████| 13/13 [00:04<00:00,  3.15it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 1: Train Loss = 1.9948
Epoch 1: Val Loss = 1.7667, Val Acc = 0.5250
Classification Report:
              precision    recall  f1-score   support

         MEL       0.42      0.69      0.52        39
          NV       0.78      0.70      0.74        99
         BCC       0.50      0.04      0.08        23
          AK       0.20      0.10      0.13        10
         BKL       0.17      0.11      0.13        19
          DF       0.00      0.00      0.00         2
        VASC       0.17      1.00      0.29         4
         SCC       0.25      0.25      0.25         4

    accuracy                           0.53       200
   macro avg       0.31      0.36      0.27       200
weighted avg       0.56      0.53      0.51       200

Confusion Matrix:
[[27 10  0  1  1  0  0  0]
 [21 69  0  0  4  0  5  0]
 [ 1  4  1  2  4  0 10  1]
 [ 4  0  1  1  1  0  1  2]
 [11  5  0  0  2  0  1  0]
 [ 0  0  0  0  0  0  2  0]
 [ 0  0  0  0  0  0  4  0]
 [ 1  0  0  1  0  0  1  1]]


Train: 100%|██████████| 50/50 [00:25<00:00,  1.99it/s]
Val: 100%|██████████| 13/13 [00:04<00:00,  3.17it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 2: Train Loss = 1.6610
Epoch 2: Val Loss = 1.9208, Val Acc = 0.5100
Classification Report:
              precision    recall  f1-score   support

         MEL       0.33      0.72      0.45        39
          NV       0.81      0.56      0.66        99
         BCC       0.48      0.70      0.57        23
          AK       0.00      0.00      0.00        10
         BKL       0.09      0.05      0.07        19
          DF       0.00      0.00      0.00         2
        VASC       1.00      0.50      0.67         4
         SCC       0.00      0.00      0.00         4

    accuracy                           0.51       200
   macro avg       0.34      0.32      0.30       200
weighted avg       0.55      0.51      0.50       200

Confusion Matrix:
[[28  7  3  0  1  0  0  0]
 [40 55  1  0  3  0  0  0]
 [ 2  0 16  0  5  0  0  0]
 [ 1  0  7  0  1  0  0  1]
 [13  4  1  0  1  0  0  0]
 [ 0  1  1  0  0  0  0  0]
 [ 0  1  1  0  0  0  2  0]
 [ 1  0  3  0  0  0  0  0]]


Train: 100%|██████████| 50/50 [00:23<00:00,  2.09it/s]
Val: 100%|██████████| 13/13 [00:04<00:00,  2.67it/s]



Epoch 3: Train Loss = 1.4598
Epoch 3: Val Loss = 1.5921, Val Acc = 0.4300
Classification Report:
              precision    recall  f1-score   support

         MEL       0.50      0.46      0.48        39
          NV       0.89      0.49      0.64        99
         BCC       0.25      0.26      0.26        23
          AK       0.30      0.80      0.43        10
         BKL       1.00      0.05      0.10        19
          DF       0.05      0.50      0.08         2
        VASC       0.67      0.50      0.57         4
         SCC       0.03      0.25      0.06         4

    accuracy                           0.43       200
   macro avg       0.46      0.41      0.33       200
weighted avg       0.69      0.43      0.48       200

Confusion Matrix:
[[18  4  3  3  0  4  0  7]
 [15 49 10  6  0  9  1  9]
 [ 0  0  6  6  0  4  0  7]
 [ 0  0  0  8  0  1  0  1]
 [ 3  2  2  2  1  3  0  6]
 [ 0  0  0  0  0  1  0  1]
 [ 0  0  2  0  0  0  2  0]
 [ 0  0  1  2  0  0  0  1]]


Train: 100%|██████████| 50/50 [00:24<00:00,  2.02it/s]
Val: 100%|██████████| 13/13 [00:04<00:00,  3.01it/s]



Epoch 4: Train Loss = 1.2938
Epoch 4: Val Loss = 1.2147, Val Acc = 0.5950
Classification Report:
              precision    recall  f1-score   support

         MEL       0.41      0.77      0.54        39
          NV       0.89      0.63      0.73        99
         BCC       0.72      0.57      0.63        23
          AK       0.45      0.50      0.48        10
         BKL       0.24      0.21      0.22        19
          DF       0.50      0.50      0.50         2
        VASC       0.57      1.00      0.73         4
         SCC       0.00      0.00      0.00         4

    accuracy                           0.59       200
   macro avg       0.47      0.52      0.48       200
weighted avg       0.66      0.59      0.61       200

Confusion Matrix:
[[30  4  0  2  3  0  0  0]
 [31 62  2  0  2  1  1  0]
 [ 1  0 13  1  5  0  2  1]
 [ 0  0  3  5  2  0  0  0]
 [ 9  4  0  1  4  0  0  1]
 [ 1  0  0  0  0  1  0  0]
 [ 0  0  0  0  0  0  4  0]
 [ 1  0  0  2  1  0  0  0]]


Train: 100%|██████████| 50/50 [00:24<00:00,  2.08it/s]
Val: 100%|██████████| 13/13 [00:03<00:00,  3.31it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



Epoch 5: Train Loss = 0.9222
Epoch 5: Val Loss = 1.4529, Val Acc = 0.6400
Classification Report:
              precision    recall  f1-score   support

         MEL       0.61      0.44      0.51        39
          NV       0.77      0.80      0.79        99
         BCC       0.50      0.70      0.58        23
          AK       0.36      0.50      0.42        10
         BKL       0.38      0.42      0.40        19
          DF       0.00      0.00      0.00         2
        VASC       1.00      0.75      0.86         4
         SCC       0.00      0.00      0.00         4

    accuracy                           0.64       200
   macro avg       0.45      0.45      0.44       200
weighted avg       0.63      0.64      0.63       200

Confusion Matrix:
[[17 13  1  2  6  0  0  0]
 [ 8 79  7  0  5  0  0  0]
 [ 0  2 16  5  0  0  0  0]
 [ 0  0  4  5  1  0  0  0]
 [ 3  6  1  1  8  0  0  0]
 [ 0  1  1  0  0  0  0  0]
 [ 0  0  1  0  0  0  3  0]
 [ 0  1  1  1  1  0  0  0]]


([1.994837338924408,
  1.6609626889228821,
  1.459804869890213,
  1.293847280740738,
  0.9221996504068375],
 [1.7667136907577514,
  1.920771417617798,
  1.592076382637024,
  1.214667751789093,
  1.4528960180282593],
 'MyDrive/midas/results/swin_tiny_lr-1e-4_patch4_window7_20250604_061219/best_model.pt')

## Ensemble Swin and our MLPFusion

In [None]:
! touch '/content/drive/MyDrive/midas/utils/eval.py'

In [None]:
from models.mlphead_model import MLPFusionModel
from models.swin import SwinClassifier
from utils.eval import ensemble_predict

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Load MLPFusionModel ---
meta_dim = 12  # required to match the saved checkpoint
mlp_model = MLPFusionModel(num_classes=8, metadata_dim=meta_dim, hidden_dim=256)
mlp_model.load_state_dict(torch.load(
    "/content/drive/MyDrive/midas/results/mlp_fusion_lr-1e-4_hid-256_20250604_044430/best_model.pt",
    map_location=device
))
mlp_model = mlp_model.to(device).eval()

# --- Load Swin Model ---
swin_model = SwinClassifier(num_classes=8)
swin_model.load_state_dict(torch.load(
    "/content/drive/MyDrive/midas/results/swin_tiny_lr-1e-4_patch4_window7_20250604_061219/best_model.pt",
    map_location=device
))
swin_model = swin_model.to(device).eval()

# --- Ensemble Predict (Requires two separate loaders) ---
# create loaders
transform_swin = get_swin_image_transform() # 224×224, no metadata
transform_mlp = get_effnet_image_transform() # 300×300, with metadata
val_dataset_swin = ISICDataset(img_dir, val_df, transform_swin, index_to_label, use_metadata=False)
val_dataset_mlp = ISICDataset(img_dir, val_df, transform_mlp, index_to_label, use_metadata=True)
val_loader_swin = DataLoader(val_dataset_swin, batch_size=16, shuffle=False)
val_loader_mlp = DataLoader(val_dataset_mlp, batch_size=16, shuffle=False)
# ensemble the models
models = [mlp_model, swin_model]
dataloaders = [val_loader_mlp, val_loader_swin]
final_preds, true_labels, acc, bma = ensemble_predict(
    models, dataloaders, device, weights=[0.3, 0.7]
)
print(f"✅ Ensemble Accuracy: {acc:.4f}")
print(f"📊 Balanced Multiclass Accuracy (BMA): {bma:.4f}")


✅ Ensemble Accuracy: 0.6250
📊 Balanced Multiclass Accuracy (BMA): 0.5119
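`ensemble_predict` lives in `utils/eval.py` and its implementation is not shown in this notebook. One common way to combine models with per-model weights such as `[0.3, 0.7]` is weighted soft voting: convert each model's logits to probabilities, take the weighted average, and predict the argmax. The sketch below illustrates that idea for a single sample; the helper names `softmax` and `weighted_soft_vote` are hypothetical and may not match what `ensemble_predict` actually does.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def weighted_soft_vote(logits_per_model, weights):
    """Combine per-model logits for ONE sample into (predicted class, combined probabilities)."""
    probs = [softmax(logits) for logits in logits_per_model]
    n_classes = len(probs[0])
    combined = [
        sum(w * p[c] for w, p in zip(weights, probs))
        for c in range(n_classes)
    ]
    pred = max(range(n_classes), key=combined.__getitem__)
    return pred, combined

# Toy example with 3 classes: model A favors class 0, model B favors class 1.
pred, combined = weighted_soft_vote(
    [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
    weights=[0.3, 0.7],
)
print(pred, [round(p, 3) for p in combined])
```

Because the weights sum to 1, the combined vector is still a probability distribution, and the more heavily weighted model wins this disagreement.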


Next, try ensembling `BaselineMLPHead` (MLP-Baseline) with Swin.

In [None]:
from models.baseline_w_mlp import BaselineMLPHead
from models.swin import SwinClassifier
from utils.eval import ensemble_predict

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Load BaselineMLPHead (image-only, no metadata) ---
mlp_model = BaselineMLPHead(num_classes=8, hidden_dim=256)
mlp_model.load_state_dict(torch.load(
    "/content/drive/MyDrive/midas/results/baseline_mlp_head_lr-1e-4_scheduler_20250604_072658/best_model.pt",
    map_location=device
))
mlp_model = mlp_model.to(device).eval()

# --- Load Swin Model ---
swin_model = SwinClassifier(num_classes=8)
swin_model.load_state_dict(torch.load(
    "/content/drive/MyDrive/midas/results/swin_tiny_lr-1e-4_patch4_window7_20250604_061219/best_model.pt",
    map_location=device
))
swin_model = swin_model.to(device).eval()

# --- Ensemble Predict (Requires two separate loaders) ---
# create loaders
transform_swin = get_swin_image_transform() # 224×224, no metadata
transform_mlp = get_effnet_image_transform() # 300×300, no metadata (BaselineMLPHead is image-only)
val_dataset_swin = ISICDataset(img_dir, val_df, transform_swin, index_to_label, use_metadata=False)
val_dataset_mlp = ISICDataset(img_dir, val_df, transform_mlp, index_to_label, use_metadata=False)
val_loader_swin = DataLoader(val_dataset_swin, batch_size=16, shuffle=False)
val_loader_mlp = DataLoader(val_dataset_mlp, batch_size=16, shuffle=False)
# ensemble the models
models = [mlp_model, swin_model]
dataloaders = [val_loader_mlp, val_loader_swin]
final_preds, true_labels, acc, bma = ensemble_predict(
    models, dataloaders, device, weights=[0.3, 0.7]
)
print(f"✅ Ensemble Accuracy: {acc:.4f}")
print(f"📊 Balanced Multiclass Accuracy (BMA): {bma:.4f}")


✅ Ensemble Accuracy: 0.6300
📊 Balanced Multiclass Accuracy (BMA): 0.5479


## Balanced multiclass accuracy (BMA)
We retrospectively computed BMA (macro-averaged recall, i.e. the mean of the per-class recalls) because it was the metric by which ISIC 2019 models were ranked.

Results:
1. Baseline: accuracy = 0.645, BMA = 0.464
2. Fusion: accuracy = 0.655, BMA = 0.438
3. MLP-Fusion: accuracy = 0.595, BMA = 0.262
4. Swin: accuracy = 0.595, BMA = 0.521
5. Ensemble (MLP-Fusion, Swin): accuracy = 0.625, BMA = 0.512
6. MLP-Baseline: accuracy = 0.630, BMA = 0.330
7. Ensemble (MLP-Baseline, Swin): accuracy = 0.630, BMA = 0.548

Swin has the highest BMA of any single model (0.521) despite lower overall accuracy than Baseline and Fusion, which indicates that it handles the poorly represented classes better than the EfficientNet-based models.
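To make the metric concrete, here is a minimal sketch that recomputes accuracy and BMA directly from a confusion matrix. It uses the Swin epoch-4 confusion matrix printed in the training log; the helper name `accuracy_and_bma` is ours for illustration and is not part of `utils/eval.py`.

```python
# Swin confusion matrix at epoch 4 (rows = true class, columns = predicted class),
# copied from the training log. Class order: MEL, NV, BCC, AK, BKL, DF, VASC, SCC.
confusion = [
    [30,  4,  0, 2, 3, 0, 0, 0],
    [31, 62,  2, 0, 2, 1, 1, 0],
    [ 1,  0, 13, 1, 5, 0, 2, 1],
    [ 0,  0,  3, 5, 2, 0, 0, 0],
    [ 9,  4,  0, 1, 4, 0, 0, 1],
    [ 1,  0,  0, 0, 0, 1, 0, 0],
    [ 0,  0,  0, 0, 0, 0, 4, 0],
    [ 1,  0,  0, 2, 1, 0, 0, 0],
]

def accuracy_and_bma(cm):
    """Return (overall accuracy, mean per-class recall) for a square confusion matrix."""
    total = sum(sum(row) for row in cm)
    correct = sum(cm[i][i] for i in range(len(cm)))
    # Recall of class i = correct predictions for i / number of true samples of i.
    # Classes with no true samples are skipped to avoid dividing by zero.
    recalls = [cm[i][i] / sum(cm[i]) for i in range(len(cm)) if sum(cm[i]) > 0]
    return correct / total, sum(recalls) / len(recalls)

acc, bma = accuracy_and_bma(confusion)
print(f"accuracy = {acc:.3f}, BMA = {bma:.3f}")  # accuracy = 0.595, BMA = 0.521
```

Each class contributes equally to BMA, so the 4 VASC images (recall 1.00) count as much as the 99 NV images (recall 0.63).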

In [None]:
from utils.eval import get_model_predictions
from models.baseline_model import BaselineModel
from models.fusion_model import FusionModel
from models.mlphead_model import MLPFusionModel
from models.swin import SwinClassifier

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
baseline_model = BaselineModel(num_classes=8)
baseline_model.load_state_dict(torch.load("/content/drive/MyDrive/midas/results/baseline_lr-1e-4_lr-scheduler_20250604_040009/best_model.pt"))
baseline_model = baseline_model.to(device).eval()

preds_base, labels_base = get_model_predictions(baseline_model, val_loader_baseline, device, is_fusion=False)
print("Baseline Accuracy:", accuracy_score(labels_base, preds_base))
print("Baseline BMA:", recall_score(labels_base, preds_base, average='macro'))


Baseline Accuracy: 0.645
Baseline BMA: 0.46439346154735167


In [None]:
meta_dim = 12
fusion_model = FusionModel(num_metadata_features=meta_dim, num_classes=8)
fusion_model.load_state_dict(torch.load("/content/drive/MyDrive/midas/results/fusion_lr-1e-4_lr-scheduler_20250604_005551/best_model.pt"))
fusion_model = fusion_model.to(device).eval()

preds_fusion, labels_fusion = get_model_predictions(fusion_model, val_loader_fusion, device, is_fusion=True)
print("Fusion Accuracy:", accuracy_score(labels_fusion, preds_fusion))
print("Fusion BMA:", recall_score(labels_fusion, preds_fusion, average='macro'))


Fusion Accuracy: 0.655
Fusion BMA: 0.4378181235875744


In [None]:
mlp_model = MLPFusionModel(num_classes=8, metadata_dim=12, hidden_dim=256)
mlp_model.load_state_dict(torch.load("/content/drive/MyDrive/midas/results/mlp_fusion_lr-1e-4_hid-256_20250604_044430/best_model.pt"))
mlp_model = mlp_model.to(device).eval()

preds_mlp, labels_mlp = get_model_predictions(mlp_model, val_loader_fusion, device, is_fusion=True)
print("MLP-Fusion Accuracy:", accuracy_score(labels_mlp, preds_mlp))
print("MLP-Fusion BMA:", recall_score(labels_mlp, preds_mlp, average='macro'))

MLP-Fusion Accuracy: 0.595
MLP-Fusion BMA: 0.26231755150519454


In [None]:
swin_model = SwinClassifier(num_classes=8)
swin_model.load_state_dict(torch.load("/content/drive/MyDrive/midas/results/swin_tiny_lr-1e-4_patch4_window7_20250604_061219/best_model.pt"))
swin_model = swin_model.to(device).eval()

preds_swin, labels_swin = get_model_predictions(swin_model, val_loader_swin, device, is_fusion=False)
print("Swin Accuracy:", accuracy_score(labels_swin, preds_swin))
print("Swin BMA:", recall_score(labels_swin, preds_swin, average='macro'))

Swin Accuracy: 0.595
Swin BMA: 0.5214046378234021


In [None]:
base_mlp_model = BaselineMLPHead(num_classes=8, hidden_dim=256)
base_mlp_model.load_state_dict(torch.load("/content/drive/MyDrive/midas/results/baseline_mlp_head_lr-1e-4_scheduler_20250604_072658/best_model.pt"))
base_mlp_model = base_mlp_model.to(device).eval()

preds_base_mlp, labels_base_mlp = get_model_predictions(base_mlp_model, val_loader_baseline, device, is_fusion=False)
print("MLP-Baseline Accuracy:", accuracy_score(labels_base_mlp, preds_base_mlp))
print("MLP-Baseline BMA:", recall_score(labels_base_mlp, preds_base_mlp, average='macro'))

MLP-Baseline Accuracy: 0.63
MLP-Baseline BMA: 0.33001834486388265


## Conclusion

Based on the experiments in this notebook, we will move forward with the ensemble `Swin + MLP-Baseline`. Overall accuracy is dominated by the well-represented classes (NV alone is about half of the validation set), whereas BMA weights every class equally and therefore rewards performance on the rare classes. `Swin + MLP-Baseline` has the highest BMA (0.548) while keeping accuracy (0.630) close to the most accurate single model (Fusion, 0.655), so it offers the best balance of the two. We treat BMA as the more important metric in a clinical setting.
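The gap between the two metrics is easiest to see with a degenerate case. The sketch below uses the validation-set class counts from the classification reports and a hypothetical classifier that always predicts the majority class; it is an illustration only, not one of the models trained here.

```python
# Validation-set class counts (the "support" column of the classification reports).
support = {"MEL": 39, "NV": 99, "BCC": 23, "AK": 10, "BKL": 19, "DF": 2, "VASC": 4, "SCC": 4}

# Hypothetical classifier that always predicts the majority class.
majority = max(support, key=support.get)
total = sum(support.values())

# Accuracy: only the majority-class samples are classified correctly.
accuracy = support[majority] / total

# BMA: recall is 1.0 for the majority class and 0.0 for every other class.
recalls = [1.0 if name == majority else 0.0 for name in support]
bma = sum(recalls) / len(recalls)

print(f"always-{majority}: accuracy = {accuracy:.3f}, BMA = {bma:.3f}")
# always-NV: accuracy = 0.495, BMA = 0.125
```

A model can therefore reach nearly 50% accuracy on this split while being useless for seven of the eight diagnoses, which is why the comparison leans on BMA.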


#### 🔬 Model Performance Summary (ISIC subset, n = 1000; metrics on the 200-image validation split)

| Model                              | Accuracy  | BMA       | Notes                                             |
|------------------------------------|-----------|-----------|---------------------------------------------------|
| Baseline (EffNet-B3)               | 0.645     | 0.464     | Standard classifier, no metadata                  |
| Fusion                             | 0.655     | 0.438     | Baseline + metadata (age, sex, site)              |
| MLP-Fusion                         | 0.595     | 0.262     | Fusion + MLP head (non-linear layers)             |
| MLP-Baseline                       | 0.630     | 0.330     | Baseline + MLP head (no metadata)                 |
| Swin Transformer                   | 0.595     | 0.521     | Vision Transformer, no metadata                   |
| Ensemble (Swin + MLP-Fusion)       | 0.625     | 0.512     | Weighted ensemble, models likely too correlated   |
| **Ensemble (Swin + MLP-Baseline)** | **0.630** | **0.548** | ✅ Best overall tradeoff between accuracy and BMA |
