In [1]:
import sys
import os

project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)


In [2]:
import logging
import os
import random
from datetime import datetime

import numpy as np
import torch
from torch import optim
from torch.cuda.amp import GradScaler

from open_clip import get_tokenizer
from open_clip.transform import image_transform_v2, PreprocessCfg
import webdataset as wds
from torch.utils.data import DataLoader
import torchvision.transforms as T
import torch.nn.functional as F
from torchvision.datasets import Imagenette, ImageFolder, DTD
import torch.nn as nn

from open_clip import create_model_from_pretrained, get_tokenizer, create_model_and_transforms
from src.clipn import CLIPNAdapter, AltCLIPNAdapter

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image, ImageFile
import pandas as pd
from tqdm import tqdm
from sklearn import metrics
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
from scipy import interpolate

import medmnist
from medmnist import INFO

# Allows loading of truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [3]:
model, preprocess = create_model_from_pretrained('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
tokenizer = get_tokenizer('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
checkpoint = torch.load("/repo/ALTMEDCLIPN/checkpoints/PMC_model_cosine/latest_checkpoint.pth", weights_only=False)["model_state_dict"]

model = AltCLIPNAdapter(
    model,
    tokenizer,
    "text.transformer.encoder"
)

model.load_state_dict(checkpoint)

<All keys matched successfully>

In [4]:
MEDMNIST_2D_DATASETS = {
    'pathmnist': medmnist.PathMNIST,
    'chestmnist': medmnist.ChestMNIST,
    'dermamnist': medmnist.DermaMNIST,
    'octmnist': medmnist.OCTMNIST,
    'pneumoniamnist': medmnist.PneumoniaMNIST,
    'breastmnist': medmnist.BreastMNIST,
    'bloodmnist': medmnist.BloodMNIST,
    'tissuemnist': medmnist.TissueMNIST,
    'organamnist': medmnist.OrganAMNIST,
    'organcmnist': medmnist.OrganCMNIST,
    'organsmnist': medmnist.OrganSMNIST
}

print("Available MedMNIST 2D datasets:")
for name in MEDMNIST_2D_DATASETS.keys():
    ds_info = INFO[name]
    print(ds_info["python_class"])
    print(ds_info["task"])
    print(ds_info["label"])
    print()
    # print(ds_info)

Available MedMNIST 2D datasets:
PathMNIST
multi-class
{'0': 'adipose', '1': 'background', '2': 'debris', '3': 'lymphocytes', '4': 'mucus', '5': 'smooth muscle', '6': 'normal colon mucosa', '7': 'cancer-associated stroma', '8': 'colorectal adenocarcinoma epithelium'}

ChestMNIST
multi-label, binary-class
{'0': 'atelectasis', '1': 'cardiomegaly', '2': 'effusion', '3': 'infiltration', '4': 'mass', '5': 'nodule', '6': 'pneumonia', '7': 'pneumothorax', '8': 'consolidation', '9': 'edema', '10': 'emphysema', '11': 'fibrosis', '12': 'pleural', '13': 'hernia'}

DermaMNIST
multi-class
{'0': 'actinic keratoses and intraepithelial carcinoma', '1': 'basal cell carcinoma', '2': 'benign keratosis-like lesions', '3': 'dermatofibroma', '4': 'melanoma', '5': 'melanocytic nevi', '6': 'vascular lesions'}

OCTMNIST
multi-class
{'0': 'choroidal neovascularization', '1': 'diabetic macular edema', '2': 'drusen', '3': 'normal'}

PneumoniaMNIST
binary-class
{'0': 'normal', '1': 'pneumonia'}

BreastMNIST
binary-

In [5]:
class MedMNISTDataset(Dataset):
    def __init__(self, dataset_name, split='test', transform=None, class_subset=None):
        """
        Args:
            dataset_name: Name of the MedMNIST dataset
            split: 'train', 'val', or 'test'
            transform: Transform to apply to images
            class_subset: List of class indices to include (for novel class detection)
        """
        self.transform = transform
        self.dataset_name = dataset_name
        
        # Load the dataset
        dataset_class = MEDMNIST_2D_DATASETS[dataset_name]
        dataset = dataset_class(split=split, size=224)
        
        self.images = dataset.imgs
        self.labels = dataset.labels.squeeze()
        
        # Get dataset info
        self.info = INFO[dataset_name]
        self.task_type = self.info['task'] if 'multi-label' not in self.info['task'] else "multi-label"
        
        # Get class names from dataset info
        self.class_to_idx = {name: int(idx) for idx, name in self.info['label'].items()}
        self.idx_to_class = {int(idx): name for idx, name in self.info['label'].items()}
        self.class_names = [self.idx_to_class[idx] for idx in range(len(self.idx_to_class))]

        
        # Handle different label formats based on task type
        if self.task_type == 'multi-label':
            # For multi-label (like ChestMNIST), labels are binary vectors
            # Convert to proper format and handle class filtering differently
            if class_subset is not None:
                # For multi-label OOD, we need to select samples that have 
                # at least one positive label in the subset
                mask = np.any(self.labels[:, class_subset] == 1, axis=1)
                self.images = self.images[mask]
                self.labels = self.labels[mask]
                # Keep only the selected classes in labels
                self.labels = self.labels[:, class_subset]
                self.class_names = [self.class_names[i] for i in class_subset]
                self.n_classes = len(class_subset)
            else:
                self.n_classes = len(self.labels)
        else:
            # For single-label tasks (multi-class, binary, ordinal)
            if class_subset is not None:
                mask = np.isin(self.labels, class_subset)
                self.images = self.images[mask]
                self.labels = self.labels[mask]
                self.class_names = [self.class_names[i] for i in class_subset]
                self.n_classes = len(class_subset)
            else:
                self.n_classes = len(self.labels)
        
        # Create classes attribute with actual class names for zero-shot
        self.classes = self.class_names
        
        print(f"Loaded {dataset_name} {split} split:")
        print(f"  Task type: {self.task_type}")
        print(f"  Samples: {len(self.images)}")
        print(f"  Classes: {self.n_classes}")
        print(f"  Class names: {self.class_names}")

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Convert numpy array to PIL Image
        img_array = self.images[idx]
        # Handle different channel configurations
        if img_array.shape[-1] == 1 or len(img_array.shape) == 2:  # Grayscale
            img_array = np.repeat(img_array, 3, axis=-1)  # Convert to RGB
        elif img_array.shape[-1] == 3:  # Already RGB
            pass
        else:
            raise ValueError(f"Unexpected number of channels: {img_array.shape[-1]}")
            
        # Convert to PIL Image
        image = Image.fromarray(img_array.astype(np.uint8))
        
        if self.transform:
            image = self.transform(image)

        label = self.labels[idx]
        
        return image, label
    
    # Helper methods for zero-shot evaluation
    def get_class_name(self, idx):
        """Get class name from index"""
        if self.task_type == 'multi-label':
            # For multi-label, return list of active class names
            if isinstance(idx, (list, np.ndarray)):
                return [self.class_names[i] for i, active in enumerate(idx) if active == 1]
            else:
                return [self.class_names[i] for i in range(len(idx)) if idx[i] == 1]
        else:
            return self.class_names[idx] if idx < len(self.class_names) else f"class_{idx}"
    
    def is_multi_label(self):
        """Check if this is a multi-label dataset"""
        return self.task_type == 'multi-label'

In [6]:
def merge_yes_no_feature(dataset, model, tokenizer, device):
    txt = []
    N = len(dataset.classes)
    model.to(device)
    model.eval()
    if N:
        with open("/repo/MEDCLIPN/src/prompt/empty_prompt.txt") as f:
            prompt_lis = f.readlines()
        num_prom = len(prompt_lis)
    for idx in range(num_prom):
        for name in dataset.classes:
            txt.append(tokenizer(prompt_lis[idx].replace("\n", "").format(name)).unsqueeze(0))
    txt = torch.cat(txt, dim=0)
    txt = txt.reshape(num_prom, len(dataset.classes), -1)
    text_inputs = txt.to(device)
    
    text_yes_ttl = torch.zeros(len(dataset.classes), 512).to(device)
    text_no_ttl = torch.zeros(len(dataset.classes), 512).to(device)
    
    with torch.no_grad():
        for i in range(num_prom):
            text_yes_i = model.encode_text(text_inputs[i], normalize=True)
            text_no_i = model.encode_text_no(text_inputs[i])
            text_no_i = F.normalize(text_no_i, dim=-1)
            
            text_yes_ttl += text_yes_i
            text_no_ttl += text_no_i
            
    return F.normalize(text_yes_ttl, dim=-1), F.normalize(text_no_ttl, dim=-1)

class ViT_Classifier(torch.nn.Module):
    def __init__(self, image_encoder, classification_head_yes, classification_head_no):
        super().__init__()
        self.image_encoder = image_encoder
        flag = True
        self.fc_yes = nn.Parameter(classification_head_yes, requires_grad=flag)    # num_classes  num_feat_dimension
        self.fc_no = nn.Parameter(classification_head_no, requires_grad=flag)      # num_classes  num_feat_dimension
        self.scale = 100. # this is from the parameter of logit scale in CLIPN
        
    def set_frozen(self, module):
        for module_name in module.named_parameters():
            module_name[1].requires_grad = False
    def set_learnable(self, module):
        for module_name in module.named_parameters():
            module_name[1].requires_grad = True
            
    def forward(self, x):
        inputs = self.image_encoder(x)
        inputs_norm = F.normalize(inputs, dim=-1)
        fc_yes = F.normalize(self.fc_yes, dim=-1)
        fc_no = F.normalize(self.fc_no, dim=-1)
        
        logits_yes = self.scale * inputs_norm @ fc_yes.T 
        logits_no = self.scale * inputs_norm @ fc_no.T
        return logits_yes, logits_no, inputs

In [7]:
to_np = lambda x: x.detach().cpu().numpy()

def max_logit_score(logits):
    """
    MaxLogit: Uses the maximum logit value as confidence score
    Higher values indicate higher confidence (in-distribution)
    """
    return to_np(torch.max(logits, -1)[0])

def msp_score(logits):
    """
    Maximum Softmax Probability (MSP): Uses maximum predicted probability
    Higher values indicate higher confidence (in-distribution)
    """
    prob = torch.softmax(logits, -1)
    return to_np(torch.max(prob, -1)[0])

def energy_score(logits):
    """
    Energy Score: Uses log-sum-exp of logits as energy
    Higher values indicate higher confidence (in-distribution)
    """
    return to_np(torch.logsumexp(logits, -1))

def ctw_score(logits, logits_no):
    """
    Class-wise Temperature Weighting (CTW) - CLIPN specific method
    Uses the yes/no probability for the predicted class
    Higher values indicate higher confidence (in-distribution)
    """
    idex = torch.argmax(logits, -1).unsqueeze(-1)
    yesno = torch.softmax(torch.cat([logits.unsqueeze(-1), logits_no.unsqueeze(-1)], -1), dim=-1)[:, :, 0]
    yesno_s = torch.gather(yesno, dim=1, index=idex)
    return to_np(yesno_s)

def atd_score(logits, logits_no):
    """
    Attention-based Temperature Distribution (ATD) - CLIPN specific method
    Weighted average of yes/no probabilities using softmax weights
    Higher values indicate higher confidence (in-distribution)
    """
    yesno = torch.softmax(torch.cat([logits.unsqueeze(-1), logits_no.unsqueeze(-1)], -1), dim=-1)[:, :, 0]
    return to_np((yesno * torch.softmax(logits, -1)).sum(1))

In [8]:
def maybe_dictionarize(batch):
    if isinstance(batch, dict):
        return batch

    if len(batch) == 2:
        batch = {'images': batch[0], 'labels': batch[1]}
    elif len(batch) == 3:
        batch = {'images': batch[0], 'labels': batch[1], 'metadata': batch[2]}
    else:
        raise ValueError(f'Unexpected number of elements: {len(batch)}')

    return batch

def cal_auc_fpr(ind_conf, ood_conf):
    conf = np.concatenate((ind_conf, ood_conf))
    ind_indicator = np.concatenate((np.ones_like(ind_conf), np.zeros_like(ood_conf)))
    auroc = metrics.roc_auc_score(ind_indicator, conf)
    fpr,tpr,thresh = roc_curve(ind_indicator, conf, pos_label=1)
    fpr = float(interpolate.interp1d(tpr, fpr)(0.95))
    return auroc, fpr

def evaluate_classification_accuracy(dataset_loader, model, task_type):
    """
    Evaluate classification accuracy on original dataset
    """
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    all_logits = []
    all_labels = []
    
    with torch.no_grad():
        for batch in tqdm(dataset_loader, desc="Evaluating classification"):
            batch = maybe_dictionarize(batch)
            inputs = batch["images"].to(device)
            labels = batch['labels'].to(device)
            
            logits, _, _ = model(inputs)
            
            all_logits.append(logits.detach().cpu())
            all_labels.append(labels.detach().cpu())
    
    all_logits = torch.cat(all_logits)
    all_labels = torch.cat(all_labels)
    
    if task_type == 'multi-label':
        # Multi-label: use micro-averaged AUC
        probs = torch.sigmoid(all_logits)
        try:
            auc_micro = roc_auc_score(to_np(all_labels), to_np(probs), average='micro')
            auc_macro = roc_auc_score(to_np(all_labels), to_np(probs), average='macro')
            return {'micro_auc': auc_micro, 'macro_auc': auc_macro}
        except ValueError:
            # Handle case where some classes are missing
            return {'micro_auc': 0.0, 'macro_auc': 0.0}
    else:
        # Single-label: use accuracy
        preds = torch.argmax(all_logits, dim=1)
        accuracy = accuracy_score(to_np(all_labels), to_np(preds))
        return {'accuracy': accuracy}

def evaluate_ood_detection(id_loader, ood_loaders, model, methods=['MSP', 'MaxLogit', 'Energy', 'CTW', 'ATD']):
    """
    Evaluate OOD detection performance
    """
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    # Collect ID scores
    id_scores = {method: [] for method in methods}
    
    with torch.no_grad():
        for batch in tqdm(id_loader, desc="Collecting ID scores"):
            batch = maybe_dictionarize(batch)
            inputs = batch["images"].to(device)
            
            logits, logits_no, _ = model(inputs)
            
            if 'MaxLogit' in methods:
                id_scores['MaxLogit'].extend(max_logit_score(logits))
            if 'MSP' in methods:
                id_scores['MSP'].extend(msp_score(logits))
            if 'Energy' in methods:
                id_scores['Energy'].extend(energy_score(logits))
            if 'CTW' in methods:
                id_scores['CTW'].extend(ctw_score(logits, logits_no))
            if 'ATD' in methods:
                id_scores['ATD'].extend(atd_score(logits, logits_no))
    
    # Collect OOD scores and calculate metrics
    results = []
    
    for ood_name, ood_loader in ood_loaders.items():
        ood_scores = {method: [] for method in methods}
        
        with torch.no_grad():
            for batch in tqdm(ood_loader, desc=f"Collecting OOD scores: {ood_name}"):
                batch = maybe_dictionarize(batch)
                inputs = batch["images"].to(device)
                
                logits, logits_no, _ = model(inputs)
                
                if 'MaxLogit' in methods:
                    ood_scores['MaxLogit'].extend(max_logit_score(logits))
                if 'MSP' in methods:
                    ood_scores['MSP'].extend(msp_score(logits))
                if 'Energy' in methods:
                    ood_scores['Energy'].extend(energy_score(logits))
                if 'CTW' in methods:
                    ood_scores['CTW'].extend(ctw_score(logits, logits_no))
                if 'ATD' in methods:
                    ood_scores['ATD'].extend(atd_score(logits, logits_no))
        
        # Calculate metrics for each method
        for method in methods:
            try:
                auc, fpr95 = cal_auc_fpr(id_scores[method], ood_scores[method])
                results.append({
                    'method': method,
                    'ood_dataset': ood_name,
                    'auroc': auc,
                    'fpr95': fpr95
                })
            except Exception as e:
                print(f"Error calculating metrics for {method} on {ood_name}: {e}")
                results.append({
                    'method': method,
                    'ood_dataset': ood_name,
                    'auroc': 0.0,
                    'fpr95': 1.0
                })
    
    return results

def create_semantic_split(dataset_name, held_out_ratio=0.3):
    """
    Create ID/OOD splits for semantic shift detection
    """
    info = INFO[dataset_name]
    total_classes = len(info["label"])
    
    if total_classes < 3:
        return None, None
    
    all_classes = list(range(total_classes))
    np.random.seed(42)  # For reproducibility
    
    n_held_out = max(1, int(total_classes * held_out_ratio))
    held_out_classes = np.random.choice(all_classes, n_held_out, replace=False)
    id_classes = [c for c in all_classes if c not in held_out_classes]
    
    print(f"\n{dataset_name} Semantic Split:")
    print(f"  ID classes: {id_classes} ({len(id_classes)} classes)")
    print(f"  OOD classes: {held_out_classes.tolist()} ({len(held_out_classes)} classes)")
    
    return id_classes, held_out_classes.tolist()

def run_comprehensive_evaluation():
    """
    Run all three evaluation scenarios
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Remove retinamnist as requested
    eval_datasets = [name for name in MEDMNIST_2D_DATASETS.keys() if name != 'retinamnist']
    
    print("="*80)
    print("COMPREHENSIVE MEDMNIST CLIPN EVALUATION")
    print("="*80)
    
    # ===== SCENARIO 1: ORIGINAL DATASET CLASSIFICATION =====
    print("\n" + "="*60)
    print("SCENARIO 1: ORIGINAL DATASET CLASSIFICATION ACCURACY")
    print("="*60)
    
    classification_results = []
    
    for dataset_name in eval_datasets:
        print(f"\n--- Evaluating {dataset_name.upper()} ---")
        
        try:
            # Load dataset
            dataset = MedMNISTDataset(dataset_name, split='test', transform=preprocess)
            
            # Create classifier with dataset's classes
            text_yes, text_no = merge_yes_no_feature(dataset, model, tokenizer, device)
            classifier = ViT_Classifier(model.visual, text_yes, text_no)
            
            # Create data loader
            loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=1, pin_memory=True)
            
            # Evaluate classification
            results = evaluate_classification_accuracy(loader, classifier, dataset.task_type)
            results['dataset'] = dataset_name
            results['task_type'] = dataset.task_type
            results['n_classes'] = len(dataset.classes)
            results['n_samples'] = len(dataset)
            
            classification_results.append(results)
            
            print(f"Results for {dataset_name}:")
            for metric, value in results.items():
                if metric not in ['dataset', 'task_type', 'n_classes', 'n_samples']:
                    print(f"  {metric}: {value:.4f}")
                    
        except Exception as e:
            print(f"Error evaluating {dataset_name}: {e}")
            continue
    
    # ===== SCENARIO 2: SEMANTIC SHIFT DETECTION =====
    print("\n" + "="*60)
    print("SCENARIO 2: SEMANTIC SHIFT DETECTION")
    print("="*60)
    
    semantic_results = []
    
    for dataset_name in eval_datasets:
        print(f"\n--- Semantic Shift: {dataset_name.upper()} ---")
        
        try:
            # Create semantic split
            id_classes, ood_classes = create_semantic_split(dataset_name)
            
            if id_classes is None:
                print(f"Skipping {dataset_name}: insufficient classes for split")
                continue
            
            # Create ID and OOD datasets
            id_dataset = MedMNISTDataset(dataset_name, split='test', transform=preprocess, class_subset=id_classes)
            ood_dataset = MedMNISTDataset(dataset_name, split='test', transform=preprocess, class_subset=ood_classes)
            
            # Create classifier with only ID classes
            text_yes, text_no = merge_yes_no_feature(id_dataset, model, tokenizer, device)
            classifier = ViT_Classifier(model.visual, text_yes, text_no)
            
            # Create data loaders
            id_loader = DataLoader(id_dataset, batch_size=32, shuffle=False, num_workers=1, pin_memory=True)
            ood_loaders = {f"{dataset_name}_semantic_ood": DataLoader(ood_dataset, batch_size=32, shuffle=False, num_workers=1, pin_memory=True)}
            
            # Evaluate OOD detection
            ood_results = evaluate_ood_detection(id_loader, ood_loaders, classifier)
            
            for result in ood_results:
                result['scenario'] = 'semantic_shift'
                result['id_dataset'] = dataset_name
                semantic_results.append(result)
                
            print("Semantic Shift Results:")
            for result in ood_results:
                print(f"  {result['method']:<10}: AUROC={result['auroc']:.4f}, FPR95={result['fpr95']:.4f}")
                
        except Exception as e:
            print(f"Error in semantic shift evaluation for {dataset_name}: {e}")
            continue
    
    # ===== SCENARIO 3: MODALITY SHIFT DETECTION =====
    print("\n" + "="*60)
    print("SCENARIO 3: MODALITY SHIFT DETECTION")
    print("="*60)
    
    modality_results = []
    
    for id_dataset_name in eval_datasets:
        print(f"\n--- Modality Shift: {id_dataset_name.upper()} vs Others ---")
        
        try:
            # Load ID dataset
            id_dataset = MedMNISTDataset(id_dataset_name, split='test', transform=preprocess)
            
            # Create classifier with ID dataset classes
            text_yes, text_no = merge_yes_no_feature(id_dataset, model, tokenizer, device)
            classifier = ViT_Classifier(model.visual, text_yes, text_no)
            
            # Create ID loader
            id_loader = DataLoader(id_dataset, batch_size=32, shuffle=False, num_workers=1, pin_memory=True)
            
            # Create OOD loaders (all other datasets)
            ood_loaders = {}
            for ood_dataset_name in eval_datasets:
                if ood_dataset_name != id_dataset_name:
                    try:
                        ood_dataset = MedMNISTDataset(ood_dataset_name, split='test', transform=preprocess)
                        ood_loaders[ood_dataset_name] = DataLoader(ood_dataset, batch_size=32, shuffle=False, num_workers=1, pin_memory=True)
                    except Exception as e:
                        print(f"Warning: Could not load {ood_dataset_name}: {e}")
                        continue
            
            if not ood_loaders:
                print(f"No OOD datasets available for {id_dataset_name}")
                continue
            
            # Evaluate OOD detection
            ood_results = evaluate_ood_detection(id_loader, ood_loaders, classifier)
            
            for result in ood_results:
                result['scenario'] = 'modality_shift'
                result['id_dataset'] = id_dataset_name
                modality_results.append(result)
            
            print(f"Modality Shift Results (ID: {id_dataset_name}):")
            for result in ood_results:
                print(f"  {result['method']:<10} vs {result['ood_dataset']:<15}: AUROC={result['auroc']:.4f}, FPR95={result['fpr95']:.4f}")
                
        except Exception as e:
            print(f"Error in modality shift evaluation for {id_dataset_name}: {e}")
            continue
    
    # ===== FINAL SUMMARY =====
    print("\n" + "="*80)
    print("EVALUATION SUMMARY")
    print("="*80)
    
    # Classification Summary
    if classification_results:
        print("\n1. Classification Accuracy Summary:")
        for result in classification_results:
            print(f"  {result['dataset']:<15} ({result['task_type']:<12}): ", end="")
            if 'accuracy' in result:
                print(f"Accuracy = {result['accuracy']:.4f}")
            else:
                print(f"Micro-AUC = {result['micro_auc']:.4f}, Macro-AUC = {result['macro_auc']:.4f}")
    
    # Semantic Shift Summary
    if semantic_results:
        print("\n2. Semantic Shift Detection Summary (Average AUROC):")
        semantic_df = pd.DataFrame(semantic_results)
        avg_results = semantic_df.groupby(['method'])['auroc'].mean().sort_values(ascending=False)
        for method, avg_auroc in avg_results.items():
            print(f"  {method:<10}: {avg_auroc:.4f}")
    
    # Modality Shift Summary
    if modality_results:
        print("\n3. Modality Shift Detection Summary (Average AUROC):")
        modality_df = pd.DataFrame(modality_results)
        avg_results = modality_df.groupby(['method'])['auroc'].mean().sort_values(ascending=False)
        for method, avg_auroc in avg_results.items():
            print(f"  {method:<10}: {avg_auroc:.4f}")
    
    return classification_results, semantic_results, modality_results

# ===== EXECUTE EVALUATION =====
print("Starting comprehensive evaluation...")
print("This will evaluate:")
print("1. Classification accuracy on each dataset")
print("2. Semantic shift detection (novel classes within datasets)")
print("3. Modality shift detection (cross-dataset OOD)")
print("\nRunning evaluation...")

classification_results, semantic_results, modality_results = run_comprehensive_evaluation()


Starting comprehensive evaluation...
This will evaluate:
1. Classification accuracy on each dataset
2. Semantic shift detection (novel classes within datasets)
3. Modality shift detection (cross-dataset OOD)

Running evaluation...
COMPREHENSIVE MEDMNIST CLIPN EVALUATION

SCENARIO 1: ORIGINAL DATASET CLASSIFICATION ACCURACY

--- Evaluating PATHMNIST ---
Loaded pathmnist test split:
  Task type: multi-class
  Samples: 7180
  Classes: 7180
  Class names: ['adipose', 'background', 'debris', 'lymphocytes', 'mucus', 'smooth muscle', 'normal colon mucosa', 'cancer-associated stroma', 'colorectal adenocarcinoma epithelium']


Evaluating classification: 100%|██████████| 225/225 [00:22<00:00,  9.97it/s]


Results for pathmnist:
  accuracy: 0.3462

--- Evaluating CHESTMNIST ---
Loaded chestmnist test split:
  Task type: multi-label
  Samples: 22433
  Classes: 22433
  Class names: ['atelectasis', 'cardiomegaly', 'effusion', 'infiltration', 'mass', 'nodule', 'pneumonia', 'pneumothorax', 'consolidation', 'edema', 'emphysema', 'fibrosis', 'pleural', 'hernia']


Evaluating classification: 100%|██████████| 702/702 [01:10<00:00,  9.96it/s]


Results for chestmnist:
  micro_auc: 0.5699
  macro_auc: 0.5338

--- Evaluating DERMAMNIST ---
Loaded dermamnist test split:
  Task type: multi-class
  Samples: 2005
  Classes: 2005
  Class names: ['actinic keratoses and intraepithelial carcinoma', 'basal cell carcinoma', 'benign keratosis-like lesions', 'dermatofibroma', 'melanoma', 'melanocytic nevi', 'vascular lesions']


Evaluating classification: 100%|██████████| 63/63 [00:06<00:00, 10.02it/s]


Results for dermamnist:
  accuracy: 0.4494

--- Evaluating OCTMNIST ---
Loaded octmnist test split:
  Task type: multi-class
  Samples: 1000
  Classes: 1000
  Class names: ['choroidal neovascularization', 'diabetic macular edema', 'drusen', 'normal']


Evaluating classification: 100%|██████████| 32/32 [00:03<00:00,  9.93it/s]


Results for octmnist:
  accuracy: 0.2500

--- Evaluating PNEUMONIAMNIST ---
Loaded pneumoniamnist test split:
  Task type: binary-class
  Samples: 624
  Classes: 624
  Class names: ['normal', 'pneumonia']


Evaluating classification: 100%|██████████| 20/20 [00:02<00:00,  9.81it/s]


Results for pneumoniamnist:
  accuracy: 0.3782

--- Evaluating BREASTMNIST ---
Loaded breastmnist test split:
  Task type: binary-class
  Samples: 156
  Classes: 156
  Class names: ['malignant', 'normal, benign']


Evaluating classification: 100%|██████████| 5/5 [00:00<00:00,  8.44it/s]


Results for breastmnist:
  accuracy: 0.7372

--- Evaluating BLOODMNIST ---
Loaded bloodmnist test split:
  Task type: multi-class
  Samples: 3421
  Classes: 3421
  Class names: ['basophil', 'eosinophil', 'erythroblast', 'immature granulocytes(myelocytes, metamyelocytes and promyelocytes)', 'lymphocyte', 'monocyte', 'neutrophil', 'platelet']


Evaluating classification: 100%|██████████| 107/107 [00:10<00:00,  9.96it/s]


Results for bloodmnist:
  accuracy: 0.1432

--- Evaluating TISSUEMNIST ---


KeyboardInterrupt: 