# import

In [4]:
import os
import os.path
import random

import numpy as np
import pandas as pd
from typing import Dict, List
import torch
from torch.utils.data import Dataset
import tensorflow as tf

In [2]:
def load_embedding(embedding_path):
    raw_dataset = tf.data.TFRecordDataset([embedding_path])
    for raw_record in raw_dataset.take(1):
        example = tf.train.Example()
        example.ParseFromString(raw_record.numpy())
        embedding_feature = example.features.feature['embedding']
        embedding_values = embedding_feature.float_list.value
    return torch.tensor(embedding_values)

In [5]:
class MIMIC_Embed_Dataset(Dataset):

    pathologies = [
        "Enlarged Cardiomediastinum",
        "Cardiomegaly",
        "Lung Opacity",
        "Lung Lesion",
        "Edema",
        "Consolidation",
        "Pneumonia",
        "Atelectasis",
        "Pneumothorax",
        "Pleural Effusion",
        "Pleural Other",
        "Fracture",
        "Support Devices",
    ]

    split_ratio = [0.8, 0.1, 0.1]

    def __init__(
        self,
        embedpath,
        csvpath,
        metacsvpath,
        views=["PA"],
        data_aug=None,
        seed=0,
        unique_patients=True,
        mode=["train", "valid", "test"][0],
    ):

        super().__init__()
        np.random.seed(seed) 

        self.pathologies = sorted(self.pathologies)

        self.mode = mode
        self.embedpath = embedpath
        self.data_aug = data_aug
        self.csvpath = csvpath
        self.csv = pd.read_csv(self.csvpath)
        self.metacsvpath = metacsvpath
        self.metacsv = pd.read_csv(self.metacsvpath)

        self.csv = self.csv.set_index(["subject_id", "study_id"])
        self.metacsv = self.metacsv.set_index(["subject_id", "study_id"])

        self.csv = self.csv.join(self.metacsv).reset_index()

        # Keep only the desired view
        self.csv["view"] = self.csv["ViewPosition"]
        self.limit_to_selected_views(views)

        if unique_patients:
            self.csv = self.csv.groupby("subject_id").first().reset_index()

        n_row = self.csv.shape[0]

        # spit data to one of train valid test
        if self.mode == "train":
            self.csv = self.csv[: int(n_row * self.split_ratio[0])]
        elif self.mode == "valid":
            self.csv = self.csv[
                int(n_row * self.split_ratio[0]) : int(
                    n_row * (self.split_ratio[0] + self.split_ratio[1])
                )
            ]
        elif self.mode == "test":
            self.csv = self.csv[-int(n_row * self.split_ratio[-1]) :]
        else:
            raise ValueError(
                f"attr:mode has to be one of [train, valid, test] but your input is {self.mode}"
            )

        # Get our classes.
        healthy = self.csv["No Finding"] == 1
        labels = []
        for pathology in self.pathologies:
            if pathology in self.csv.columns:
                self.csv.loc[healthy, pathology] = 0
                mask = self.csv[pathology]

            labels.append(mask.values)
        self.labels = np.asarray(labels).T
        self.labels = self.labels.astype(np.float32)

        # Make all the -1 values into nans to keep things simple
        self.labels[self.labels == -1] = np.nan

        # Rename pathologies
        self.pathologies = list(
            np.char.replace(self.pathologies, "Pleural Effusion", "Effusion")
        )

        # add consistent csv values

        # offset_day_int
        self.csv["offset_day_int"] = self.csv["StudyDate"]

        # patientid
        self.csv["patientid"] = self.csv["subject_id"].astype(str)

    def string(self):
        return self.__class__.__name__ + " num_samples={} views={}".format(
            len(self), self.views,
        )

    def limit_to_selected_views(self, views):
        """This function is called by subclasses to filter the
        images by view based on the values in .csv['view']
        """
        if type(views) is not list:
            views = [views]
        if '*' in views:
            # if you have the wildcard, the rest are irrelevant
            views = ["*"]
        self.views = views

        # missing data is unknown
        self.csv.view.fillna("UNKNOWN", inplace=True)

        if "*" not in views:
            self.csv = self.csv[self.csv["view"].isin(self.views)]  # Select the view

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {}
        sample["idx"] = idx
        sample["lab"] = self.labels[idx]

        subjectid = str(self.csv.iloc[idx]["subject_id"])
        studyid = str(self.csv.iloc[idx]["study_id"])
        dicom_id = str(self.csv.iloc[idx]["dicom_id"])


        #data_aug
        embed_file = os.path.join(
            self.embedpath,
            "p" + subjectid[:2],
            "p" + subjectid,
            "s" + studyid,
            dicom_id + ".tfrecord",
        )
        sample["embedding"] = load_embedding(embed_file)
        #sample["embedding"] = embed_file

        return sample

In [4]:
embedpath = "/d/hd04/armstrong/MIMIC/data/generalized-image-embeddings-for-the-mimic-chest-x-ray-dataset-1.0/files"
csvpath = "/d/hd04/armstrong/MIMIC/data/mimic-cxr-2.0.0-chexpert.csv"
metacsvpath = "/d/hd04/armstrong/MIMIC/data/mimic-cxr-2.0.0-metadata.csv"

dataset = MIMIC_Embed_Dataset(embedpath,csvpath,metacsvpath,mode = "train")

The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  self.csv.view.fillna("UNKNOWN", inplace=True)


In [5]:
sample = dataset[1000]
sample

2025-04-16 11:16:48.742988: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2256] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


{'idx': 1000,
 'lab': array([nan,  1., nan, nan, nan,  1., nan,  1., nan, nan, nan,  1.,  0.],
       dtype=float32),
 'embedding': tensor([-0.6009, -2.6448,  1.1589,  ...,  0.3574,  1.4271, -2.1083])}

# baseline

In [None]:
from torch.utils.data import DataLoader

N = 36000
subset = [dataset[i] for i in range(N)]
train_len = int(N * 0.8)
train_subset = subset[:train_len]
val_subset = subset[train_len:]

def collate_fn(batch):
    x = torch.stack([sample['embedding'] for sample in batch])
    y = torch.stack([torch.tensor(sample['lab']) for sample in batch])
    return {"x": x, "y": y}

train_loader = DataLoader(train_subset, batch_size=128, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_subset, batch_size=128, shuffle=False, collate_fn=collate_fn)

test_subset = [dataset[i] for i in range(N, N + 500)]
test_loader = DataLoader(test_subset, batch_size=128, shuffle=False, collate_fn=collate_fn)


In [9]:
from torch.utils.data import Subset, DataLoader

N = 36000
train_indices = list(range(0, int(N * 0.8)))
val_indices = list(range(int(N * 0.8), N))
test_indices = list(range(N, N + 500))

train_subset = Subset(dataset, train_indices)
val_subset = Subset(dataset, val_indices)
test_subset = Subset(dataset, test_indices)

train_loader = DataLoader(train_subset, batch_size=128, shuffle=True,
                          collate_fn=collate_fn, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_subset, batch_size=128, shuffle=False,
                        collate_fn=collate_fn, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_subset, batch_size=128, shuffle=False,
                         collate_fn=collate_fn, num_workers=4, pin_memory=True)


In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, f1_score, average_precision_score, accuracy_score
import pandas as pd
from torch.utils.tensorboard import SummaryWriter
import os
from collections import OrderedDict

# MLP
class FlexibleMLP(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim, norm='batch', dropout=0.3):
        super().__init__()
        layers = []
        prev_dim = input_dim
        for h_dim in hidden_dims:
            layers.append(nn.Linear(prev_dim, h_dim))
            if norm == 'batch':
                layers.append(nn.BatchNorm1d(h_dim))
            elif norm == 'layer':
                layers.append(nn.LayerNorm(h_dim))
            layers.append(nn.ReLU())
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            prev_dim = h_dim
        layers.append(nn.Linear(prev_dim, output_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

# Transformer
class TransformerMultiModalClassifier(nn.Module):
    def __init__(self, image_input_dim=1376, patch_size=8, embed_dim=256,
                 num_heads=4, num_layers=2, output_dim=13, dropout=0.1,
                 meta_dim=0, use_positional_encoding=True,
                 norm='batch'):
        super().__init__()
        assert image_input_dim % patch_size == 0, "input_dim must be divisible by patch_size"
        self.seq_len = image_input_dim // patch_size
        self.use_positional_encoding = use_positional_encoding

        self.patch_embed = nn.Linear(patch_size, embed_dim)
        if use_positional_encoding:
            self.pos_embedding = nn.Parameter(torch.randn(1, self.seq_len, embed_dim))

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads,
            dim_feedforward=embed_dim * 2, dropout=dropout, batch_first=True,
            norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        total_input_dim = embed_dim * self.seq_len + meta_dim
        self.head = nn.Sequential(
            nn.Linear(total_input_dim, 128),
            nn.BatchNorm1d(128) if norm == 'batch' else nn.LayerNorm(128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, output_dim)
        )

    def forward(self, x_img, x_meta=None):
        B = x_img.size(0)
        x = x_img.view(B, self.seq_len, -1)
        x = self.patch_embed(x)
        if self.use_positional_encoding:
            x = x + self.pos_embedding
        x = self.transformer(x)
        x_flat = x.flatten(1)
        if x_meta is not None:
            x = torch.cat([x_flat, x_meta], dim=1)
        else:
            x = x_flat
        return self.head(x)

# ResMLP
class ResidualMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=4, dropout=0.2):
        super().__init__()
        self.fc_in = nn.Linear(input_dim, hidden_dim)
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, hidden_dim)
            ) for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.fc_in(x)
        for layer in self.layers:
            x = x + layer(x)
        return self.fc_out(x)

# TransformerMLP 
class TransformerMLPClassifier(nn.Module):
    def __init__(self, input_dim=1376, output_dim=13, hidden_dim=512, dropout=0.1, d_token=64, n_layers=2, n_heads=4):
        super().__init__()
        self.linear_in = nn.Linear(input_dim, d_token)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_token, nhead=n_heads, dim_feedforward=hidden_dim, dropout=dropout, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.classifier = nn.Sequential(
            nn.LayerNorm(d_token),
            nn.Linear(d_token, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        x = self.linear_in(x).unsqueeze(1)
        x = self.transformer(x).squeeze(1)
        return self.classifier(x)


# ViT-like
class ViTLikeClassifier(nn.Module):
    def __init__(self, input_dim=1376, output_dim=13, patch_size=32, hidden_dim=512, dropout=0.1):
        super().__init__()
        assert input_dim % patch_size == 0, "input_dim must be divisible by patch_size"
        self.n_patches = input_dim // patch_size
        self.patch_embed = nn.Linear(patch_size, hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=4, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        B = x.shape[0]
        x = x.view(B, self.n_patches, -1)  # (B, num_patches, patch_size)
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = self.transformer(x)
        return self.classifier(x[:, 0])


# FTTransformer
class FTTransformerClassifier(nn.Module):
    def __init__(self, input_dim=1376, output_dim=13, hidden_dim=512, dropout=0.1):
        super().__init__()
        self.embed = nn.Linear(input_dim, hidden_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=8, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_dim))
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        B = x.shape[0]
        x = self.embed(x).unsqueeze(1)  # (B, 1, hidden_dim)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = self.transformer(x)
        return self.classifier(x[:, 0])

# Swin-MLP
class SwinMLPClassifier(nn.Module):
    def __init__(self, input_dim=1376, output_dim=13, hidden_dim=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(input_dim, hidden_dim)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.norm = nn.LayerNorm(hidden_dim)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        x = self.proj(x)
        x = x + self.mlp(self.norm(x))  # Swin-style MLP block with skip connection
        return self.classifier(x)


# SAINT
class SAINTClassifier(nn.Module):
    def __init__(self, input_dim=1376, output_dim=13, hidden_dim=512, dropout=0.1, n_heads=4):
        super().__init__()
        self.embed = nn.Linear(input_dim, hidden_dim)
        self.row_attn = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=n_heads, dropout=dropout, batch_first=True), num_layers=1)
        self.col_attn = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=n_heads, dropout=dropout, batch_first=True), num_layers=1)
        self.classifier = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, output_dim)
        )

    def forward(self, x):
        x = self.embed(x).unsqueeze(1)  # (B, 1, D)
        x = self.row_attn(x)
        x = self.col_attn(x)
        return self.classifier(x.squeeze(1))


def train_model(model, train_loader, val_loader, pathologies, test_loader=None,
                epochs=10, lr=1e-3, device=None, early_stop_patience=3,
                return_metrics=False, model_name="Model"):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss(reduction='none')
    writer = SummaryWriter(log_dir=f'runs/{model_name}')

    train_losses, val_aucs = [], []
    best_auc, patience, best_model = 0, 0, None
    best_scores = {}

    print(f"\nTraining [{model_name}]...")
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0
        for batch in train_loader:
            x = batch['x'].to(device)
            y = batch['y'].to(device)
            mask = ~torch.isnan(y)
            y_clean = torch.nan_to_num(y, nan=0.0)

            logits = model(x)
            loss_all = loss_fn(logits, y_clean)
            loss = (loss_all * mask).sum() / mask.sum()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        train_losses.append(avg_loss)
        writer.add_scalar("Loss/train", avg_loss, epoch)
        print(f"Epoch {epoch}/{epochs} | Train Loss: {avg_loss:.4f}")

        model.eval()
        y_true, y_prob = [], []
        with torch.no_grad():
            for batch in val_loader:
                x = batch['x'].to(device)
                y = batch['y'].cpu().numpy()
                logits = model(x).cpu()
                probs = torch.sigmoid(logits).numpy()
                y_true.append(y)
                y_prob.append(probs)

        y_true = np.concatenate(y_true, axis=0)
        y_prob = np.concatenate(y_prob, axis=0)

        aucs, f1s, maps, accs = [], [], [], []
        for i, name in enumerate(pathologies):
            mask = ~np.isnan(y_true[:, i])
            if mask.sum() < 10:
                continue
            y_true_i, y_prob_i = y_true[mask, i], y_prob[mask, i]
            y_pred_i = (y_prob_i > 0.5).astype(int)

            aucs.append(roc_auc_score(y_true_i, y_prob_i))
            f1s.append(f1_score(y_true_i, y_pred_i))
            maps.append(average_precision_score(y_true_i, y_prob_i))
            accs.append(accuracy_score(y_true_i, y_pred_i))

        val_auc = np.mean(aucs)
        val_aucs.append(val_auc)
        writer.add_scalar("AUC/val", val_auc, epoch)
        print(f"Validation AUC: {val_auc:.4f}")

        if val_auc > best_auc:
            best_auc = val_auc
            patience = 0
            best_model = model.state_dict()
            best_scores = {
                'AUC': aucs,
                'F1': f1s,
                'mAP': maps,
                'Accuracy': accs
            }
            torch.save(best_model, f"{model_name}_best.pt")
        else:
            patience += 1
            if patience >= early_stop_patience:
                print("Early stopping triggered!")
                break

    writer.close()

    if best_model:
        model.load_state_dict(best_model)

    metrics_df = pd.DataFrame(best_scores, index=pathologies[:len(best_scores['AUC'])]).round(4)

    if test_loader:
        print(f"\nEvaluating [{model_name}] on Test Set...")
        model.eval()
        y_true, y_prob = [], []
        with torch.no_grad():
            for batch in test_loader:
                x = batch['x'].to(device)
                y = batch['y'].cpu().numpy()
                logits = model(x).cpu()
                probs = torch.sigmoid(logits).numpy()
                y_true.append(y)
                y_prob.append(probs)

        y_true = np.concatenate(y_true, axis=0)
        y_prob = np.concatenate(y_prob, axis=0)

        test_aucs, test_f1s, test_maps, test_accs = [], [], [], []
        for i, name in enumerate(pathologies):
            mask = ~np.isnan(y_true[:, i])
            if mask.sum() < 10:
                continue
            y_true_i, y_prob_i = y_true[mask, i], y_prob[mask, i]
            y_pred_i = (y_prob_i > 0.5).astype(int)
            test_aucs.append(roc_auc_score(y_true_i, y_prob_i))
            test_f1s.append(f1_score(y_true_i, y_pred_i))
            test_maps.append(average_precision_score(y_true_i, y_prob_i))
            test_accs.append(accuracy_score(y_true_i, y_pred_i))

        test_df = pd.DataFrame({
            'AUC': test_aucs,
            'F1': test_f1s,
            'MAP': test_maps,
            'Accuracy': test_accs
        }, index=pathologies[:len(test_aucs)]).round(4)
        print(test_df)

    if return_metrics:
        return train_losses, val_aucs, metrics_df
    else:
        display(metrics_df)

In [None]:
def plot_model_comparison(loss_dict, auc_dict):
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    for name, losses in loss_dict.items():
        plt.plot(losses, label=name)
    plt.title("Training Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.legend()

    plt.subplot(1, 2, 2)
    for name, aucs in auc_dict.items():
        plt.plot(aucs, label=name)
    plt.title("Validation AUC")
    plt.xlabel("Epoch")
    plt.ylabel("AUC")
    plt.grid(True)
    plt.legend()

    plt.tight_layout()
    plt.show()

model_constructors = OrderedDict({
    "MLP": lambda hidden_dim, dropout, **kwargs: FlexibleMLP(input_dim=1376, hidden_dims=[hidden_dim]*2, output_dim=13, norm='batch', dropout=dropout),
    "Transformer": lambda hidden_dim, dropout, **kwargs: TransformerMultiModalClassifier(norm='batch', dropout=dropout),
    "ViTLike": lambda hidden_dim, dropout, **kwargs: ViTLikeClassifier(input_dim=1376, output_dim=13, hidden_dim=hidden_dim, dropout=dropout),
    "FTTransformer": lambda hidden_dim, dropout, **kwargs: FTTransformerClassifier(input_dim=1376, output_dim=13, hidden_dim=hidden_dim, dropout=dropout),
    "ResMLP": lambda hidden_dim, dropout: ResidualMLP(input_dim=1376, hidden_dim=hidden_dim, output_dim=13, dropout=dropout),
    "TransformerMLP": lambda hidden_dim, dropout, d_token, n_layers, n_heads: TransformerMLPClassifier(input_dim=1376, output_dim=13, hidden_dim=hidden_dim, dropout=dropout, d_token=d_token, n_layers=n_layers, n_heads=n_heads),
    "SwinMLP": lambda hidden_dim, dropout, **kwargs: SwinMLPClassifier(input_dim=1376, output_dim=13, hidden_dim=hidden_dim, dropout=dropout),
    "SAINT": lambda hidden_dim, dropout, n_heads, **kwargs: SAINTClassifier(input_dim=1376, output_dim=13, hidden_dim=hidden_dim, dropout=dropout, n_heads=n_heads),

})

In [None]:
import optuna
from sklearn.metrics import roc_auc_score, f1_score, average_precision_score, accuracy_score
from collections import defaultdict

from optuna.pruners import MedianPruner


def objective(trial, name, constructor, train_loader, val_loader, test_loader, pathologies, epochs=10):
    hidden_dim = trial.suggest_categorical("hidden_dim", [256, 512, 1024])
    dropout = trial.suggest_float("dropout", 0.1, 0.5)

    config = {"hidden_dim": hidden_dim, "dropout": dropout}
    if "TransformerMLP" in name:
        config["d_token"] = trial.suggest_categorical("d_token", [64, 128])
        config["n_layers"] = trial.suggest_int("n_layers", 2, 4)
        config["n_heads"] = trial.suggest_categorical("n_heads", [2, 4, 8])
    elif "SAINT" in name:
        config["n_heads"] = trial.suggest_categorical("n_heads", [2, 4, 8])

    model = constructor(**config)
    model_name = f"{name}_optuna_trial{trial.number}"

    _, auc, _ = train_model(
        model, train_loader, val_loader,
        test_loader=test_loader,
        pathologies=pathologies,
        epochs=epochs,
        return_metrics=True,
        model_name=model_name
    )

    return np.mean(auc)


def run_model_sweep_optuna(model_constructors, train_loader, val_loader, test_loader, pathologies, epochs=10, n_trials=20):
    all_results = []
    loss_dict, auc_dict = defaultdict(list), defaultdict(list)
    best_configs = {}

    for idx, (name, constructor) in enumerate(model_constructors.items(), 1):
        print(f"\n🔍 Optimizing {name}...  (Model {idx}/{len(model_constructors)})")

        study = optuna.create_study(
            direction="maximize",
            pruner=MedianPruner(n_startup_trials=5, n_warmup_steps=5)
        )
        study.optimize(
            lambda trial: objective(trial, name, constructor, train_loader, val_loader, test_loader, pathologies, epochs),
            n_trials=n_trials
        )

        best_config = study.best_params
        best_configs[name] = best_config

        model = constructor(**best_config)
        model_name = f"{name}_optuna"
        loss, auc, df = train_model(
            model,
            train_loader,
            val_loader,
            test_loader=test_loader,
            pathologies=pathologies,
            epochs=epochs,
            return_metrics=True,
            model_name=model_name
        )
        loss_dict[model_name] = loss
        auc_dict[model_name] = auc
        df['model'] = model_name

        all_results.append(df.reset_index())

    summary_df = pd.concat(all_results, axis=0).rename(columns={"index": "class"}).reset_index(drop=True)

    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    for name, losses in loss_dict.items():
        plt.plot(losses, label=name)
    plt.title("Training Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.legend()

    plt.subplot(1, 2, 2)
    for name, aucs in auc_dict.items():
        plt.plot(aucs, label=name)
    plt.title("Validation AUC")
    plt.xlabel("Epoch")
    plt.ylabel("AUC")
    plt.grid(True)
    plt.legend()

    plt.tight_layout()
    plt.show()

    avg_df = summary_df.groupby("model")[["AUC", "F1", "mAP", "Accuracy"]].mean().round(4)
    best_model = avg_df["AUC"].idxmax()

    print("\n📊 Mean Metrics Per Model:")
    display(avg_df.style.highlight_max(color='lightgreen', axis=0))
    print(f"\n🏆 Best Model by AUC: {best_model}")

    for name, cfg in best_configs.items():
        print(f"🔧 Best config for {name}: {cfg}")

    return summary_df, loss_dict, auc_dict, best_model


summary_df, loss_dict, auc_dict, best_model = run_model_sweep_optuna(
    model_constructors, train_loader, val_loader, test_loader, dataset.pathologies
)

print("\n✅ Sweep Summary:")
display(summary_df)


In [None]:
! pip install optuna