In [1]:
!pip install -q medmnist pillow scipy tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, models
from medmnist import INFO
import numpy as np
from sklearn.metrics import roc_auc_score
import time
import os
import warnings
warnings.filterwarnings('ignore')

print("="*80)
print("üöÄ PRODUCTION: ALL 18 MEDMNIST DATASETS (2D+2.5D ResNet-50)")
print("="*80)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

benchmarks = {
    # 2D Datasets
    'pathmnist': 0.989, 'chestmnist': 0.773, 'dermamnist': 0.920, 'octmnist': 0.958,
    'pneumoniamnist': 0.962, 'retinamnist': 0.716, 'breastmnist': 0.866, 'bloodmnist': 0.998,
    'tissuemnist': 0.932, 'organamnist': 0.998, 'organcmnist': 0.993, 'organsmnist': 0.975,
    # 3D Datasets
    'adversarialmnist3d': 0.892, 'nodulemnist3d': 0.913, 'synapsemnist3d': 0.975,
    'fracturemnist3d': 0.871, 'spleenmnist3d': 0.973, 'abasemnist3d': 0.889
}

all_results = {}

def is_3d(ds_name):
    return ds_name.endswith('3d')

def is_multilabel(ds_name):
    """ChestMNIST is multi-label (14 diseases)"""
    return ds_name == 'chestmnist'

def safe_target_processing(target, multilabel=False):
    if multilabel:
        # Multi-label: keep as float, shape (B, num_classes)
        return target.float()
    else:
        # Single-label: squeeze to (B,)
        if target.ndim > 1:
            target = target.squeeze(-1)
        if target.ndim == 0:
            target = target.unsqueeze(0)
        return target.long()

def get_2d_transform():
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

def get_3d_transform_2_5d():
    def preprocess_3d_to_2_5d(volume):
        if volume.ndim == 4 and volume.shape[0] == 1:
            volume = volume[0]
        if volume.ndim != 3:
            volume = volume.squeeze()
        D, H, W = volume.shape
        idx1, idx2, idx3 = int(D*0.25), int(D*0.50), int(D*0.75)
        slice1 = volume[idx1].astype(np.float32) / 255.0
        slice2 = volume[idx2].astype(np.float32) / 255.0
        slice3 = volume[idx3].astype(np.float32) / 255.0
        from scipy.ndimage import zoom
        zoom_factors = (224/H, 224/W)
        slice1 = zoom(slice1, zoom_factors, order=1)
        slice2 = zoom(slice2, zoom_factors, order=1)
        slice3 = zoom(slice3, zoom_factors, order=1)
        rgb_img = np.stack([slice1, slice2, slice3], axis=0)
        return torch.tensor(rgb_img, dtype=torch.float32)
    return transforms.Compose([
        transforms.Lambda(preprocess_3d_to_2_5d),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

def get_resnet50(num_classes):
    model = models.resnet50(weights='IMAGENET1K_V2')
    model.fc = nn.Linear(2048, num_classes)
    return model

def train_single_dataset(ds_name, epochs=5):
    print(f"\n{'='*70}")
    print(f"üî¨ ResNet-50 {'2.5D' if is_3d(ds_name) else '2D'}: {ds_name}")
    print(f"{'='*70}")

    try:
        info = INFO[ds_name]
        num_classes = len(info['label'])
        multilabel = is_multilabel(ds_name)
        if multilabel:
            print(f"‚ö†Ô∏è  Multi-label dataset: {num_classes} classes")
    except:
        print(f"‚ùå NOT FOUND: {ds_name}")
        return 0.5

    transform = get_3d_transform_2_5d() if is_3d(ds_name) else get_2d_transform()
    as_rgb = not is_3d(ds_name)

    try:
        module = __import__('medmnist', fromlist=[info['python_class']])
        DataClass = getattr(module, info['python_class'])
        train_ds = DataClass(split='train', transform=transform, download=True, as_rgb=as_rgb)
        val_ds = DataClass(split='val', transform=transform, download=True, as_rgb=as_rgb)
        test_ds = DataClass(split='test', transform=transform, download=True, as_rgb=as_rgb)
        print(f"‚úÖ Train: {len(train_ds):,} | Val: {len(val_ds):,} | Test: {len(test_ds):,}")
    except Exception as e:
        print(f"‚ùå Load error: {e}")
        return 0.5

    batch = 64
    train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=batch, shuffle=False, num_workers=2, pin_memory=True)

    model = get_resnet50(num_classes)
    model = model.to(device)
    print(f"‚úÖ Model: {sum(p.numel() for p in model.parameters()):,} params")

    for param in list(model.parameters())[:-60]:
        param.requires_grad = False

    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    best_auc = 0
    best_state = None
    start_time = time.time()

    for epoch in range(epochs):
        model.train()
        train_loss = 0

        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            target = safe_target_processing(target, multilabel)

            min_b = min(data.size(0), target.size(0))
            data, target = data[:min_b], target[:min_b]

            optimizer.zero_grad()
            output = model(data)

            # Different loss for multi-label vs single-label
            if multilabel:
                loss = F.binary_cross_entropy_with_logits(output, target)
            else:
                loss = F.cross_entropy(output, target)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item()

        scheduler.step()

        # Validation
        model.eval()
        val_preds, val_targets = [], []
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                target = safe_target_processing(target, multilabel)

                min_b = min(data.size(0), target.size(0))
                data, target = data[:min_b], target[:min_b]

                # Different prediction for multi-label vs single-label
                if multilabel:
                    prob = torch.sigmoid(model(data))
                else:
                    prob = F.softmax(model(data), dim=1)

                val_preds.extend(prob.cpu().numpy())
                val_targets.extend(target.cpu().numpy())

        val_preds = np.array(val_preds)
        val_targets = np.array(val_targets)

        try:
            if multilabel:
                val_auc = roc_auc_score(val_targets, val_preds, average='macro')
            elif num_classes > 2:
                val_auc = roc_auc_score(val_targets, val_preds, multi_class='ovr')
            else:
                val_auc = roc_auc_score(val_targets, val_preds[:, 1])
        except:
            val_auc = 0.5

        if val_auc > best_auc:
            best_auc = val_auc
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}

        elapsed = time.time() - start_time
        print(f"  Epoch {epoch+1}: Loss={train_loss/len(train_loader):.3f} | AUC={val_auc:.4f} | Best={best_auc:.4f} | {elapsed/60:.1f}m")

    # Test
    if best_state:
        model.load_state_dict(best_state)
    model.eval()

    test_preds, test_targets = [], []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            target = safe_target_processing(target, multilabel)

            min_b = min(data.size(0), target.size(0))
            data, target = data[:min_b], target[:min_b]

            if multilabel:
                prob = torch.sigmoid(model(data))
            else:
                prob = F.softmax(model(data), dim=1)

            test_preds.extend(prob.cpu().numpy())
            test_targets.extend(target.cpu().numpy())

    test_preds = np.array(test_preds)
    test_targets = np.array(test_targets)

    try:
        if multilabel:
            test_auc = roc_auc_score(test_targets, test_preds, average='macro')
        elif num_classes > 2:
            test_auc = roc_auc_score(test_targets, test_preds, multi_class='ovr')
        else:
            test_auc = roc_auc_score(test_targets, test_preds[:, 1])
    except:
        test_auc = 0.5

    total_time = time.time() - start_time
    bench = benchmarks.get(ds_name, 0)
    gap = test_auc - bench

    print(f"‚úÖ {ds_name} | AUC: {test_auc:.4f} | Bench: {bench:.3f} | Gap: {gap:+.4f}")
    all_results[ds_name] = test_auc
    return test_auc

# ALL 18 DATASETS
all_datasets = [
    # 2D
    'pathmnist','chestmnist','dermamnist','octmnist','pneumoniamnist','retinamnist',
    'breastmnist','bloodmnist','tissuemnist','organamnist','organcmnist','organsmnist',
    # 3D
    'adversarialmnist3d','nodulemnist3d','synapsemnist3d','fracturemnist3d','spleenmnist3d','abasemnist3d'
]

print("\nüöÄ FULL 2D AND 2.5D (MULTI-LABEL FIXED)")

for ds in all_datasets:
    if os.path.exists(f'{ds}_best.pth'):
        os.remove(f'{ds}_best.pth')

for ds_name in all_datasets:
    try:
        train_single_dataset(ds_name, epochs=5)
        torch.cuda.empty_cache()
    except Exception as e:
        print(f"‚ùå {ds_name} FAILED: {str(e)[:100]}")
        all_results[ds_name] = 0.5

print("\n" + "="*80)
print("üéØ FINAL RESULTS")
print("="*80)
print(f"{'Dataset':<20} {'AUC':<10} {'Benchmark':<10} {'Gap':<10}")
print("-"*60)
for ds in all_datasets:
    bench = benchmarks.get(ds, 0)
    auc = all_results.get(ds, 0.5)
    gap = auc - bench
    status = "‚úÖ" if gap > -0.1 else "üî¥"
    print(f"{ds:<20} {auc:<10.4f} {bench:<10.3f} {gap:+10.4f} {status}")

avg_auc = np.mean(list(all_results.values()))
print("="*80)
print(f"üìà Average AUC: {avg_auc:.4f}")
print("="*80)


üöÄ PRODUCTION: ALL 18 MEDMNIST DATASETS (2D+2.5D ResNet-50)
Device: cuda

üöÄ FULL 2D AND 2.5D (MULTI-LABEL FIXED)

üî¨ ResNet-50 2D: pathmnist
Using downloaded and verified file: C:\Users\User\.medmnist\pathmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\pathmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\pathmnist.npz
‚úÖ Train: 89,996 | Val: 10,004 | Test: 7,180
‚úÖ Model: 23,526,473 params
  Epoch 1: Loss=0.155 | AUC=0.9995 | Best=0.9995 | 4.9m
  Epoch 2: Loss=0.055 | AUC=0.9988 | Best=0.9995 | 9.8m
  Epoch 3: Loss=0.024 | AUC=0.9980 | Best=0.9995 | 14.6m
  Epoch 4: Loss=0.010 | AUC=0.9998 | Best=0.9998 | 19.5m
  Epoch 5: Loss=0.004 | AUC=0.9999 | Best=0.9999 | 24.3m
‚úÖ pathmnist | AUC: 0.9954 | Bench: 0.989 | Gap: +0.0064

üî¨ ResNet-50 2D: chestmnist
‚ö†Ô∏è  Multi-label dataset: 14 classes
Using downloaded and verified file: C:\Users\User\.medmnist\chestmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\chestmni

In [2]:
!pip install -q medmnist pillow scipy tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, models
from medmnist import INFO
import numpy as np
from sklearn.metrics import roc_auc_score
import time
import os
import warnings
warnings.filterwarnings('ignore')

print("="*80)
print("üöÄ PRODUCTION: ALL 18 MEDMNIST DATASETS (2D+2.5D ResNet-50)")
print("   CONFIG: 30 Epochs | Early Stopping (Patience=5)")
print("="*80)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

benchmarks = {
    # 2D Datasets
    'pathmnist': 0.989, 'chestmnist': 0.773, 'dermamnist': 0.920, 'octmnist': 0.958,
    'pneumoniamnist': 0.962, 'retinamnist': 0.716, 'breastmnist': 0.866, 'bloodmnist': 0.998,
    'tissuemnist': 0.932, 'organamnist': 0.998, 'organcmnist': 0.993, 'organsmnist': 0.975,
    # 3D Datasets
    'adversarialmnist3d': 0.892, 'nodulemnist3d': 0.913, 'synapsemnist3d': 0.975,
    'fracturemnist3d': 0.871, 'spleenmnist3d': 0.973, 'abasemnist3d': 0.889
}

all_results = {}

def is_3d(ds_name):
    return ds_name.endswith('3d')

def is_multilabel(ds_name):
    """ChestMNIST is multi-label (14 diseases)"""
    return ds_name == 'chestmnist'

def safe_target_processing(target, multilabel=False):
    if multilabel:
        return target.float()
    else:
        if target.ndim > 1:
            target = target.squeeze(-1)
        if target.ndim == 0:
            target = target.unsqueeze(0)
        return target.long()

def get_2d_transform():
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

def get_3d_transform_2_5d():
    def preprocess_3d_to_2_5d(volume):
        if volume.ndim == 4 and volume.shape[0] == 1:
            volume = volume[0]
        if volume.ndim != 3:
            volume = volume.squeeze()
        D, H, W = volume.shape
        idx1, idx2, idx3 = int(D*0.25), int(D*0.50), int(D*0.75)
        slice1 = volume[idx1].astype(np.float32) / 255.0
        slice2 = volume[idx2].astype(np.float32) / 255.0
        slice3 = volume[idx3].astype(np.float32) / 255.0
        from scipy.ndimage import zoom
        zoom_factors = (224/H, 224/W)
        slice1 = zoom(slice1, zoom_factors, order=1)
        slice2 = zoom(slice2, zoom_factors, order=1)
        slice3 = zoom(slice3, zoom_factors, order=1)
        rgb_img = np.stack([slice1, slice2, slice3], axis=0)
        return torch.tensor(rgb_img, dtype=torch.float32)
    return transforms.Compose([
        transforms.Lambda(preprocess_3d_to_2_5d),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

def get_resnet50(num_classes):
    model = models.resnet50(weights='IMAGENET1K_V2')
    model.fc = nn.Linear(2048, num_classes)
    return model

def train_single_dataset(ds_name, epochs=30, patience=5):
    print(f"\n{'='*70}")
    print(f"üî¨ ResNet-50 {'2.5D' if is_3d(ds_name) else '2D'}: {ds_name}")
    print(f"{'='*70}")

    try:
        info = INFO[ds_name]
        num_classes = len(info['label'])
        multilabel = is_multilabel(ds_name)
        if multilabel:
            print(f"‚ö†Ô∏è  Multi-label dataset: {num_classes} classes")
    except:
        print(f"‚ùå NOT FOUND: {ds_name}")
        return 0.5

    transform = get_3d_transform_2_5d() if is_3d(ds_name) else get_2d_transform()
    as_rgb = not is_3d(ds_name)

    try:
        module = __import__('medmnist', fromlist=[info['python_class']])
        DataClass = getattr(module, info['python_class'])
        train_ds = DataClass(split='train', transform=transform, download=True, as_rgb=as_rgb)
        val_ds = DataClass(split='val', transform=transform, download=True, as_rgb=as_rgb)
        test_ds = DataClass(split='test', transform=transform, download=True, as_rgb=as_rgb)
        print(f"‚úÖ Train: {len(train_ds):,} | Val: {len(val_ds):,} | Test: {len(test_ds):,}")
    except Exception as e:
        print(f"‚ùå Load error: {e}")
        return 0.5

    batch = 64
    train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=2, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=batch, shuffle=False, num_workers=2, pin_memory=True)

    model = get_resnet50(num_classes)
    model = model.to(device)
    
    # Freeze initial layers for faster convergence
    for param in list(model.parameters())[:-60]:
        param.requires_grad = False

    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    best_auc = 0
    best_state = None
    patience_counter = 0  # Initialize patience counter
    start_time = time.time()

    for epoch in range(epochs):
        model.train()
        train_loss = 0

        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            target = safe_target_processing(target, multilabel)

            min_b = min(data.size(0), target.size(0))
            data, target = data[:min_b], target[:min_b]

            optimizer.zero_grad()
            output = model(data)

            if multilabel:
                loss = F.binary_cross_entropy_with_logits(output, target)
            else:
                loss = F.cross_entropy(output, target)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item()

        scheduler.step()

        # Validation
        model.eval()
        val_preds, val_targets = [], []
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                target = safe_target_processing(target, multilabel)

                min_b = min(data.size(0), target.size(0))
                data, target = data[:min_b], target[:min_b]

                if multilabel:
                    prob = torch.sigmoid(model(data))
                else:
                    prob = F.softmax(model(data), dim=1)

                val_preds.extend(prob.cpu().numpy())
                val_targets.extend(target.cpu().numpy())

        val_preds = np.array(val_preds)
        val_targets = np.array(val_targets)

        try:
            if multilabel:
                val_auc = roc_auc_score(val_targets, val_preds, average='macro')
            elif num_classes > 2:
                val_auc = roc_auc_score(val_targets, val_preds, multi_class='ovr')
            else:
                val_auc = roc_auc_score(val_targets, val_preds[:, 1])
        except:
            val_auc = 0.5

        elapsed = time.time() - start_time
        print(f"  Epoch {epoch+1}/{epochs}: Loss={train_loss/len(train_loader):.3f} | AUC={val_auc:.4f} | Best={best_auc:.4f} | Patience={patience_counter}/{patience}")

        # EARLY STOPPING CHECK
        if val_auc > best_auc:
            best_auc = val_auc
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            patience_counter = 0  # Reset counter on improvement
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"üõë Early stopping triggered at epoch {epoch+1}")
                break

    # Test
    if best_state:
        model.load_state_dict(best_state)
    model.eval()

    test_preds, test_targets = [], []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            target = safe_target_processing(target, multilabel)

            min_b = min(data.size(0), target.size(0))
            data, target = data[:min_b], target[:min_b]

            if multilabel:
                prob = torch.sigmoid(model(data))
            else:
                prob = F.softmax(model(data), dim=1)

            test_preds.extend(prob.cpu().numpy())
            test_targets.extend(target.cpu().numpy())

    test_preds = np.array(test_preds)
    test_targets = np.array(test_targets)

    try:
        if multilabel:
            test_auc = roc_auc_score(test_targets, test_preds, average='macro')
        elif num_classes > 2:
            test_auc = roc_auc_score(test_targets, test_preds, multi_class='ovr')
        else:
            test_auc = roc_auc_score(test_targets, test_preds[:, 1])
    except:
        test_auc = 0.5

    bench = benchmarks.get(ds_name, 0)
    gap = test_auc - bench

    print(f"‚úÖ {ds_name} | Test AUC: {test_auc:.4f} | Bench: {bench:.3f} | Gap: {gap:+.4f}")
    all_results[ds_name] = test_auc
    return test_auc

# ALL 18 DATASETS
all_datasets = [
    # 2D
    'pathmnist','chestmnist','dermamnist','octmnist','pneumoniamnist','retinamnist',
    'breastmnist','bloodmnist','tissuemnist','organamnist','organcmnist','organsmnist',
    # 3D
    'adversarialmnist3d','nodulemnist3d','synapsemnist3d','fracturemnist3d','spleenmnist3d','abasemnist3d'
]

print("\nüöÄ STARTING TRAINING LOOP")

for ds in all_datasets:
    if os.path.exists(f'{ds}_best.pth'):
        os.remove(f'{ds}_best.pth')

for ds_name in all_datasets:
    try:
        train_single_dataset(ds_name, epochs=30, patience=5)
        torch.cuda.empty_cache()
    except Exception as e:
        print(f"‚ùå {ds_name} FAILED: {str(e)[:100]}")
        all_results[ds_name] = 0.5

print("\n" + "="*80)
print("üéØ FINAL RESULTS (30 Epochs + Early Stopping)")
print("="*80)
print(f"{'Dataset':<20} {'AUC':<10} {'Benchmark':<10} {'Gap':<10}")
print("-"*60)
for ds in all_datasets:
    bench = benchmarks.get(ds, 0)
    auc = all_results.get(ds, 0.5)
    gap = auc - bench
    status = "‚úÖ" if gap > -0.1 else "üî¥"
    print(f"{ds:<20} {auc:<10.4f} {bench:<10.3f} {gap:+10.4f} {status}")

avg_auc = np.mean(list(all_results.values()))
print("="*80)
print(f"üìà Average AUC: {avg_auc:.4f}")
print("="*80)


üöÄ PRODUCTION: ALL 18 MEDMNIST DATASETS (2D+2.5D ResNet-50)
   CONFIG: 30 Epochs | Early Stopping (Patience=5)
Device: cuda

üöÄ STARTING TRAINING LOOP

üî¨ ResNet-50 2D: pathmnist
Using downloaded and verified file: C:\Users\User\.medmnist\pathmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\pathmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\pathmnist.npz
‚úÖ Train: 89,996 | Val: 10,004 | Test: 7,180
  Epoch 1/30: Loss=0.155 | AUC=0.9981 | Best=0.0000 | Patience=0/5
  Epoch 2/30: Loss=0.057 | AUC=0.9997 | Best=0.9981 | Patience=0/5
  Epoch 3/30: Loss=0.034 | AUC=0.9997 | Best=0.9997 | Patience=0/5
  Epoch 4/30: Loss=0.026 | AUC=0.9996 | Best=0.9997 | Patience=1/5
  Epoch 5/30: Loss=0.017 | AUC=0.9998 | Best=0.9997 | Patience=2/5
  Epoch 6/30: Loss=0.015 | AUC=0.9976 | Best=0.9998 | Patience=0/5
  Epoch 7/30: Loss=0.013 | AUC=0.9998 | Best=0.9998 | Patience=1/5
  Epoch 8/30: Loss=0.010 | AUC=0.9997 | Best=0.9998 | Patience=0/5
  Epoch 9/3

KeyboardInterrupt: 

In [3]:
!pip install -q medmnist pillow scipy tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, models
from medmnist import INFO
import numpy as np
from sklearn.metrics import roc_auc_score
import time
import os
import warnings
warnings.filterwarnings('ignore')

print("="*80)
print("üöÄ OPTIMAL: ACS (All Slices) for 3D + ResNet-50 for 2D")
print("   CONFIG: 30 Epochs | Early Stopping (Patience=5)")
print("="*80)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

benchmarks = {
    'pathmnist': 0.989, 'chestmnist': 0.773, 'dermamnist': 0.920, 'octmnist': 0.958,
    'pneumoniamnist': 0.962, 'retinamnist': 0.716, 'breastmnist': 0.866, 'bloodmnist': 0.998,
    'tissuemnist': 0.932, 'organamnist': 0.998, 'organcmnist': 0.993, 'organsmnist': 0.975,
    'adrenalmnist3d': 0.889, 'fracturemnist3d': 0.871, 'nodulemnist3d': 0.913, 
    'organmnist3d': 0.995, 'synapsemnist3d': 0.975, 'vesselmnist3d': 0.899
}

all_results = {}

def is_3d(ds_name):
    return ds_name.endswith('3d')

def is_multilabel(ds_name):
    return ds_name == 'chestmnist'

def safe_target_processing(target, multilabel=False):
    if multilabel:
        return target.float()
    else:
        if target.ndim > 1:
            target = target.squeeze(-1)
        if target.ndim == 0:
            target = target.unsqueeze(0)
        return target.long()

def get_2d_transform():
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

def preprocess_3d_acs(volume):
    """ACS: Process ALL slices, not just 3"""
    if volume.ndim == 4 and volume.shape[0] == 1:
        volume = volume[0]
    if volume.ndim != 3:
        volume = volume.squeeze()
    
    D, H, W = volume.shape
    volume = volume.astype(np.float32) / 255.0
    
    from scipy.ndimage import zoom
    zoom_factors = (1.0, 224/H, 224/W)  # Keep all slices, resize H,W
    resized = zoom(volume, zoom_factors, order=1)
    
    # Convert to RGB by repeating: (D,224,224) -> (D,3,224,224)
    slices = []
    for i in range(resized.shape[0]):
        slice_rgb = np.stack([resized[i]]*3, axis=0)  # (3,224,224)
        slices.append(slice_rgb)
    
    return torch.tensor(np.array(slices), dtype=torch.float32)  # (D,3,224,224)

def get_3d_transform_acs():
    return transforms.Compose([
        transforms.Lambda(preprocess_3d_acs),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

class ResNet50_ACS(nn.Module):
    """ResNet-50 with ACS: processes all slices and aggregates"""
    def __init__(self, num_classes):
        super().__init__()
        self.resnet = models.resnet50(weights='IMAGENET1K_V2')
        self.resnet.fc = nn.Identity()  # Remove final FC
        self.fc = nn.Linear(2048, num_classes)
        
    def forward(self, x):
        # x shape: (B, D, 3, 224, 224) for 3D or (B, 3, 224, 224) for 2D
        if x.ndim == 5:  # 3D data
            B, D, C, H, W = x.shape
            x = x.view(B*D, C, H, W)  # (B*D, 3, 224, 224)
            features = self.resnet(x)  # (B*D, 2048)
            features = features.view(B, D, -1)  # (B, D, 2048)
            features = features.mean(dim=1)  # (B, 2048) - average pooling
        else:  # 2D data
            features = self.resnet(x)  # (B, 2048)
        
        return self.fc(features)

def get_resnet50(num_classes, is_3d_data):
    if is_3d_data:
        return ResNet50_ACS(num_classes)
    else:
        model = models.resnet50(weights='IMAGENET1K_V2')
        model.fc = nn.Linear(2048, num_classes)
        return model

def train_single_dataset(ds_name, epochs=30, patience=5):
    print(f"\n{'='*70}")
    print(f"üî¨ ResNet-50 {'ACS' if is_3d(ds_name) else '2D'}: {ds_name}")
    print(f"{'='*70}")
    
    try:
        info = INFO[ds_name]
        num_classes = len(info['label'])
        multilabel = is_multilabel(ds_name)
        if multilabel:
            print(f"‚ö†Ô∏è  Multi-label: {num_classes} classes")
    except:
        print(f"‚ùå NOT FOUND: {ds_name}")
        return 0.5
    
    transform = get_3d_transform_acs() if is_3d(ds_name) else get_2d_transform()
    as_rgb = not is_3d(ds_name)
    
    try:
        module = __import__('medmnist', fromlist=[info['python_class']])
        DataClass = getattr(module, info['python_class'])
        train_ds = DataClass(split='train', transform=transform, download=True, as_rgb=as_rgb)
        val_ds = DataClass(split='val', transform=transform, download=True, as_rgb=as_rgb)
        test_ds = DataClass(split='test', transform=transform, download=True, as_rgb=as_rgb)
        print(f"‚úÖ Train: {len(train_ds):,} | Val: {len(val_ds):,} | Test: {len(test_ds):,}")
    except Exception as e:
        print(f"‚ùå Load error: {e}")
        return 0.5
    
    # Smaller batch for 3D (more memory)
    batch = 16 if is_3d(ds_name) else 64
    train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=0, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=0, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=batch, shuffle=False, num_workers=0, pin_memory=True)
    
    model = get_resnet50(num_classes, is_3d(ds_name))
    model = model.to(device)
    print(f"‚úÖ Model: {sum(p.numel() for p in model.parameters()):,} params")
    
    # Freeze early layers
    for param in list(model.parameters())[:-60]:
        param.requires_grad = False
    
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    best_auc = 0
    best_state = None
    patience_counter = 0
    start_time = time.time()
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            target = safe_target_processing(target, multilabel)
            
            min_b = min(data.size(0), target.size(0))
            data, target = data[:min_b], target[:min_b]
            
            optimizer.zero_grad()
            output = model(data)
            
            if multilabel:
                loss = F.binary_cross_entropy_with_logits(output, target)
            else:
                loss = F.cross_entropy(output, target)
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            train_loss += loss.item()
        
        scheduler.step()
        
        model.eval()
        val_preds, val_targets = [], []
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                target = safe_target_processing(target, multilabel)
                
                min_b = min(data.size(0), target.size(0))
                data, target = data[:min_b], target[:min_b]
                
                if multilabel:
                    prob = torch.sigmoid(model(data))
                else:
                    prob = F.softmax(model(data), dim=1)
                
                val_preds.extend(prob.cpu().numpy())
                val_targets.extend(target.cpu().numpy())
        
        val_preds = np.array(val_preds)
        val_targets = np.array(val_targets)
        
        try:
            if multilabel:
                val_auc = roc_auc_score(val_targets, val_preds, average='macro')
            elif num_classes > 2:
                val_auc = roc_auc_score(val_targets, val_preds, multi_class='ovr')
            else:
                val_auc = roc_auc_score(val_targets, val_preds[:, 1])
        except:
            val_auc = 0.5
        
        elapsed = time.time() - start_time
        print(f"  Epoch {epoch+1}/{epochs}: Loss={train_loss/len(train_loader):.3f} | AUC={val_auc:.4f} | Best={best_auc:.4f}")
        
        # EARLY STOPPING
        if val_auc > best_auc:
            best_auc = val_auc
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"  üõë Early stopping triggered at epoch {epoch+1}")
                break
    
    if best_state:
        model.load_state_dict(best_state)
    model.eval()
    
    test_preds, test_targets = [], []
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            target = safe_target_processing(target, multilabel)
            
            min_b = min(data.size(0), target.size(0))
            data, target = data[:min_b], target[:min_b]
            
            if multilabel:
                prob = torch.sigmoid(model(data))
            else:
                prob = F.softmax(model(data), dim=1)
            
            test_preds.extend(prob.cpu().numpy())
            test_targets.extend(target.cpu().numpy())
    
    test_preds = np.array(test_preds)
    test_targets = np.array(test_targets)
    
    try:
        if multilabel:
            test_auc = roc_auc_score(test_targets, test_preds, average='macro')
        elif num_classes > 2:
            test_auc = roc_auc_score(test_targets, test_preds, multi_class='ovr')
        else:
            test_auc = roc_auc_score(test_targets, test_preds[:, 1])
    except:
        test_auc = 0.5
    
    total_time = time.time() - start_time
    bench = benchmarks.get(ds_name, 0)
    gap = test_auc - bench
    
    print(f"‚úÖ {ds_name} | AUC: {test_auc:.4f} | Bench: {bench:.3f} | Gap: {gap:+.4f}")
    all_results[ds_name] = test_auc
    return test_auc

all_datasets = [
    'adrenalmnist3d', 'fracturemnist3d', 'nodulemnist3d', 
    'organmnist3d', 'synapsemnist3d', 'vesselmnist3d',
    'pathmnist', 'chestmnist', 'dermamnist', 'octmnist', 'pneumoniamnist', 'retinamnist',
    'breastmnist', 'bloodmnist', 'tissuemnist', 'organamnist', 'organcmnist', 'organsmnist'
]

print("\nüöÄ ACS (All Slices) for 3D | 30 EPOCHS (Patience=5) | 3D FIRST ‚Üí 2D AFTER\n")

for ds in all_datasets:
    if os.path.exists(f'{ds}_best.pth'):
        os.remove(f'{ds}_best.pth')

for ds_name in all_datasets:
    try:
        train_single_dataset(ds_name, epochs=30, patience=5)
        torch.cuda.empty_cache()
    except Exception as e:
        print(f"‚ùå {ds_name} FAILED: {str(e)[:100]}")
        all_results[ds_name] = 0.5

print("\n" + "="*80)
print("üéØ FINAL RESULTS (ACS for 3D)")
print("="*80)
print(f"{'Dataset':<20} {'AUC':<10} {'Benchmark':<10} {'Gap':<10}")
print("-"*60)
print("\nüìä 3D RESULTS (ACS - All Slices)")
print("-"*60)
for ds in all_datasets[:6]:
    bench = benchmarks.get(ds, 0)
    auc = all_results.get(ds, 0.5)
    gap = auc - bench
    status = "‚úÖ" if gap > -0.1 else "üî¥"
    print(f"{ds:<20} {auc:<10.4f} {bench:<10.3f} {gap:+10.4f} {status}")
print("\nüìä 2D RESULTS")
print("-"*60)
for ds in all_datasets[6:]:
    bench = benchmarks.get(ds, 0)
    auc = all_results.get(ds, 0.5)
    gap = auc - bench
    status = "‚úÖ" if gap > -0.1 else "üî¥"
    print(f"{ds:<20} {auc:<10.4f} {bench:<10.3f} {gap:+10.4f} {status}")
avg_auc = np.mean(list(all_results.values()))
print("="*80)
print(f"üìà Average AUC: {avg_auc:.4f}")
print("="*80)


üöÄ OPTIMAL: ACS (All Slices) for 3D + ResNet-50 for 2D
   CONFIG: 30 Epochs | Early Stopping (Patience=5)
Device: cuda

üöÄ ACS (All Slices) for 3D | 30 EPOCHS (Patience=5) | 3D FIRST ‚Üí 2D AFTER


üî¨ ResNet-50 ACS: adrenalmnist3d
Using downloaded and verified file: C:\Users\User\.medmnist\adrenalmnist3d.npz
Using downloaded and verified file: C:\Users\User\.medmnist\adrenalmnist3d.npz
Using downloaded and verified file: C:\Users\User\.medmnist\adrenalmnist3d.npz
‚úÖ Train: 1,188 | Val: 98 | Test: 298
‚úÖ Model: 23,512,130 params
  Epoch 1/30: Loss=0.564 | AUC=0.6896 | Best=0.0000
  Epoch 2/30: Loss=0.470 | AUC=0.8678 | Best=0.6896
  Epoch 3/30: Loss=0.438 | AUC=0.8947 | Best=0.8678
  Epoch 4/30: Loss=0.402 | AUC=0.8858 | Best=0.8947
  Epoch 5/30: Loss=0.390 | AUC=0.8618 | Best=0.8947
  Epoch 6/30: Loss=0.380 | AUC=0.8708 | Best=0.8947
  Epoch 7/30: Loss=0.340 | AUC=0.8319 | Best=0.8947
  Epoch 8/30: Loss=0.284 | AUC=0.8600 | Best=0.8947
  üõë Early stopping triggered at epoch 8