In [3]:
!pip install -q medmnist pillow

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, models
from medmnist import INFO
import numpy as np
from sklearn.metrics import roc_auc_score
import time

print("="*70)
print("üî• RTX 4060 + RESNET-50 + 3 EPOCHS + LIVE MONITORING (FIXED)")
print("="*70)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device} | VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")

benchmarks = {
    'pathmnist': 0.989, 'chestmnist': 0.773, 'dermamnist': 0.920, 'octmnist': 0.958,
    'pneumoniamnist': 0.962, 'retinamnist': 0.716, 'breastmnist': 0.866, 'bloodmnist': 0.998,
    'tissuemnist': 0.932, 'organamnist': 0.998, 'organcmnist': 0.993, 'organsmnist': 0.975
}

all_results = {}

def train_single_dataset(ds_name, epochs=3):
    print(f"\n{'='*60}")
    print(f"üî¨ RESNET-50: {ds_name}")
    print(f"{'='*60}")
    
    info = INFO[ds_name]
    module = __import__('medmnist', fromlist=[info['python_class']])
    DataClass = getattr(module, info['python_class'])
    
    # FIX 1: Define transform without ToPILImage (MedMNIST returns PIL)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    # FIX 2: Pass transform and as_rgb=True in __init__
    train_ds = DataClass(split='train', transform=transform, download=True, as_rgb=True)
    val_ds = DataClass(split='val', transform=transform, download=True, as_rgb=True) 
    test_ds = DataClass(split='test', transform=transform, download=True, as_rgb=True)
    
    num_classes = len(info['label'])
    print(f"üìä Train: {len(train_ds):,} | Val: {len(val_ds):,} | Test: {len(test_ds):,} | Classes: {num_classes}")
    
    # FIX 3: Pass dataset directly to DataLoader
    train_loader = DataLoader(train_ds, batch_size=96, shuffle=True, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=192, shuffle=False, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=192, shuffle=False, pin_memory=True)
    
    # RESNET-50 SETUP
    model = models.resnet50(weights='IMAGENET1K_V2')
    model.fc = nn.Linear(2048, num_classes)
    model = model.to(device)
    
    # Freeze early layers
    for param in list(model.parameters())[:-50]:
        param.requires_grad = False
    
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), 
                                 lr=1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    best_auc = 0
    start_time = time.time()
    
    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0
        train_batches = 0
        
        for data, target in train_loader:
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True).squeeze()
            # Handle single-value targets vs multi-label
            if target.ndim == 0: target = target.unsqueeze(0)
            
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target.long())
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            train_batches += 1
        
        avg_train_loss = train_loss / train_batches
        
        # Validation
        model.eval()
        val_preds, val_targets = [], []
        
        with torch.no_grad():
            for data, target in val_loader:
                data = data.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True).squeeze()
                output = model(data)
                prob = F.softmax(output, dim=1)
                val_preds.extend(prob.cpu().numpy())
                val_targets.extend(target.cpu().numpy())
        
        val_preds = np.array(val_preds)
        val_targets = np.array(val_targets)
        
        try:
            if num_classes > 2:
                val_auc = roc_auc_score(val_targets, val_preds, multi_class='ovr')
            else:
                val_auc = roc_auc_score(val_targets, val_preds[:, 1])
        except:
            val_auc = 0.5
        
        if val_auc > best_auc:
            best_auc = val_auc
            torch.save(model.state_dict(), f'{ds_name}_best.pth')
        
        scheduler.step()
        
        # LIVE MONITORING
        elapsed = time.time() - start_time
        vram = torch.cuda.memory_allocated()/1e9
        print(f"üìà Epoch {epoch+1} | Loss: {avg_train_loss:.3f} | Val AUC: {val_auc:.4f} | Best: {best_auc:.4f} | {elapsed/60:.1f}m | {vram:.1f}GB")
    
    # Test
    model.load_state_dict(torch.load(f'{ds_name}_best.pth'))
    model.eval()
    
    test_preds, test_targets = [], []
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True).squeeze()
            output = model(data)
            prob = F.softmax(output, dim=1)
            test_preds.extend(prob.cpu().numpy())
            test_targets.extend(target.cpu().numpy())
    
    test_preds = np.array(test_preds)
    test_targets = np.array(test_targets)
    
    try:
        if num_classes > 2:
            test_auc = roc_auc_score(test_targets, test_preds, multi_class='ovr')
        else:
            test_auc = roc_auc_score(test_targets, test_preds[:, 1])
    except:
        test_auc = 0.5
    
    total_time = time.time() - start_time
    bench = benchmarks.get(ds_name, 0)
    print(f"‚úÖ {ds_name} DONE | Test AUC: {test_auc:.4f} | Gap: {test_auc-bench:+.4f}")
    
    return test_auc

# Run all 12
datasets_rn50 = [
    'bloodmnist', 'tissuemnist', 'pathmnist', 'organcmnist',
    'organamnist', 'chestmnist', 'pneumoniamnist', 'dermamnist',
    'breastmnist', 'organsmnist', 'octmnist', 'retinamnist'
]

print("\nüöÄ STARTING BATCH TRAINING\n")

for ds_name in datasets_rn50:
    try:
        auc = train_single_dataset(ds_name, epochs=3)
        all_results[ds_name] = auc
    except Exception as e:
        print(f"‚ùå {ds_name} FAILED: {e}")
        all_results[ds_name] = 0.5

# Final Table
print("\n" + "="*70)
print("üéØ FINAL RESULTS (3 EPOCHS)")
print("="*70)
print(f"{'Dataset':<15} {'Test AUC':<10} {'Benchmark':<10} {'Gap'}")
print("-"*60)

gaps = []
for ds in sorted(all_results.keys()):
    if ds in benchmarks:
        auc = all_results[ds]
        bench = benchmarks[ds]
        gap = auc - bench
        gaps.append(gap)
        print(f"{ds:<15} {auc:<10.4f} {bench:<10.3f} {gap:+.4f}")

if gaps:
    print("-"*60)
    print(f"üìä AVG GAP: {np.mean(gaps):+.4f}")
    print("="*70)


üî• RTX 4060 + RESNET-50 + 3 EPOCHS + LIVE MONITORING (FIXED)
Device: cuda | VRAM: 8.6GB

üöÄ STARTING BATCH TRAINING


üî¨ RESNET-50: bloodmnist
Using downloaded and verified file: C:\Users\User\.medmnist\bloodmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\bloodmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\bloodmnist.npz
üìä Train: 11,959 | Val: 1,712 | Test: 3,421 | Classes: 8


Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to C:\Users\User/.cache\torch\hub\checkpoints\resnet50-11ad3fa6.pth
100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 97.8M/97.8M [00:16<00:00, 6.25MB/s]


üìà Epoch 1 | Loss: 0.355 | Val AUC: 0.9972 | Best: 0.9972 | 0.8m | 0.4GB
üìà Epoch 2 | Loss: 0.092 | Val AUC: 0.9976 | Best: 0.9976 | 1.5m | 0.4GB
üìà Epoch 3 | Loss: 0.023 | Val AUC: 0.9980 | Best: 0.9980 | 2.2m | 0.4GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ bloodmnist DONE | Test AUC: 0.9980 | Gap: +0.0000

üî¨ RESNET-50: tissuemnist
Using downloaded and verified file: C:\Users\User\.medmnist\tissuemnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\tissuemnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\tissuemnist.npz
üìä Train: 165,466 | Val: 23,640 | Test: 47,280 | Classes: 8
üìà Epoch 1 | Loss: 1.042 | Val AUC: 0.9141 | Best: 0.9141 | 10.0m | 0.3GB
üìà Epoch 2 | Loss: 0.864 | Val AUC: 0.9231 | Best: 0.9231 | 20.0m | 0.3GB
üìà Epoch 3 | Loss: 0.618 | Val AUC: 0.9217 | Best: 0.9231 | 30.0m | 0.3GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ tissuemnist DONE | Test AUC: 0.9238 | Gap: -0.0082

üî¨ RESNET-50: pathmnist
Using downloaded and verified file: C:\Users\User\.medmnist\pathmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\pathmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\pathmnist.npz
üìä Train: 89,996 | Val: 10,004 | Test: 7,180 | Classes: 9
üìà Epoch 1 | Loss: 0.146 | Val AUC: 0.9994 | Best: 0.9994 | 5.4m | 0.3GB
üìà Epoch 2 | Loss: 0.045 | Val AUC: 0.9997 | Best: 0.9997 | 10.8m | 0.3GB
üìà Epoch 3 | Loss: 0.012 | Val AUC: 0.9999 | Best: 0.9999 | 16.2m | 0.3GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ pathmnist DONE | Test AUC: 0.9924 | Gap: +0.0034

üî¨ RESNET-50: organcmnist
Using downloaded and verified file: C:\Users\User\.medmnist\organcmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\organcmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\organcmnist.npz
üìä Train: 12,975 | Val: 2,392 | Test: 8,216 | Classes: 11
üìà Epoch 1 | Loss: 0.320 | Val AUC: 0.9980 | Best: 0.9980 | 0.8m | 0.4GB
üìà Epoch 2 | Loss: 0.061 | Val AUC: 0.9996 | Best: 0.9996 | 1.6m | 0.4GB
üìà Epoch 3 | Loss: 0.016 | Val AUC: 0.9997 | Best: 0.9997 | 2.4m | 0.4GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ organcmnist DONE | Test AUC: 0.9953 | Gap: +0.0023

üî¨ RESNET-50: organamnist
Using downloaded and verified file: C:\Users\User\.medmnist\organamnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\organamnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\organamnist.npz
üìä Train: 34,561 | Val: 6,491 | Test: 17,778 | Classes: 11
üìà Epoch 1 | Loss: 0.180 | Val AUC: 0.9996 | Best: 0.9996 | 2.2m | 0.4GB
üìà Epoch 2 | Loss: 0.033 | Val AUC: 0.9999 | Best: 0.9999 | 4.3m | 0.4GB
üìà Epoch 3 | Loss: 0.016 | Val AUC: 0.9999 | Best: 0.9999 | 6.6m | 0.4GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ organamnist DONE | Test AUC: 0.9983 | Gap: +0.0003

üî¨ RESNET-50: chestmnist
Using downloaded and verified file: C:\Users\User\.medmnist\chestmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\chestmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\chestmnist.npz
üìä Train: 78,468 | Val: 11,219 | Test: 22,433 | Classes: 14
‚ùå chestmnist FAILED: Expected floating point type for target with class probabilities, got Long

üî¨ RESNET-50: pneumoniamnist
Using downloaded and verified file: C:\Users\User\.medmnist\pneumoniamnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\pneumoniamnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\pneumoniamnist.npz
üìä Train: 4,708 | Val: 524 | Test: 624 | Classes: 2
üìà Epoch 1 | Loss: 0.164 | Val AUC: 0.9715 | Best: 0.9715 | 0.3m | 0.4GB
üìà Epoch 2 | Loss: 0.041 | Val AUC: 0.9903 | Best: 0.9903 | 0.6m | 0.4GB
üìà Epoch 3 | Loss: 0.029 | Val AUC: 0.9936 | Best: 0.9936

  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ pneumoniamnist DONE | Test AUC: 0.9667 | Gap: +0.0047

üî¨ RESNET-50: dermamnist
Using downloaded and verified file: C:\Users\User\.medmnist\dermamnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\dermamnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\dermamnist.npz
üìä Train: 7,007 | Val: 1,003 | Test: 2,005 | Classes: 7
üìà Epoch 1 | Loss: 0.796 | Val AUC: 0.9167 | Best: 0.9167 | 0.4m | 0.3GB
üìà Epoch 2 | Loss: 0.519 | Val AUC: 0.9251 | Best: 0.9251 | 0.9m | 0.3GB
üìà Epoch 3 | Loss: 0.228 | Val AUC: 0.9359 | Best: 0.9359 | 1.3m | 0.3GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ dermamnist DONE | Test AUC: 0.9281 | Gap: +0.0081

üî¨ RESNET-50: breastmnist
Using downloaded and verified file: C:\Users\User\.medmnist\breastmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\breastmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\breastmnist.npz
üìä Train: 546 | Val: 78 | Test: 156 | Classes: 2
üìà Epoch 1 | Loss: 0.546 | Val AUC: 0.8053 | Best: 0.8053 | 0.0m | 0.4GB
üìà Epoch 2 | Loss: 0.262 | Val AUC: 0.7251 | Best: 0.8053 | 0.1m | 0.4GB
üìà Epoch 3 | Loss: 0.102 | Val AUC: 0.8521 | Best: 0.8521 | 0.1m | 0.4GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ breastmnist DONE | Test AUC: 0.7149 | Gap: -0.1511

üî¨ RESNET-50: organsmnist
Using downloaded and verified file: C:\Users\User\.medmnist\organsmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\organsmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\organsmnist.npz
üìä Train: 13,932 | Val: 2,452 | Test: 8,827 | Classes: 11
üìà Epoch 1 | Loss: 0.512 | Val AUC: 0.9927 | Best: 0.9927 | 0.9m | 0.4GB
üìà Epoch 2 | Loss: 0.173 | Val AUC: 0.9950 | Best: 0.9950 | 1.8m | 0.4GB
üìà Epoch 3 | Loss: 0.069 | Val AUC: 0.9949 | Best: 0.9950 | 2.6m | 0.4GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ organsmnist DONE | Test AUC: 0.9801 | Gap: +0.0051

üî¨ RESNET-50: octmnist
Using downloaded and verified file: C:\Users\User\.medmnist\octmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\octmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\octmnist.npz
üìä Train: 97,477 | Val: 10,832 | Test: 1,000 | Classes: 4
üìà Epoch 1 | Loss: 0.353 | Val AUC: 0.9668 | Best: 0.9668 | 6.0m | 0.4GB
üìà Epoch 2 | Loss: 0.234 | Val AUC: 0.9708 | Best: 0.9708 | 11.8m | 0.4GB
üìà Epoch 3 | Loss: 0.116 | Val AUC: 0.9738 | Best: 0.9738 | 17.5m | 0.4GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ octmnist DONE | Test AUC: 0.9560 | Gap: -0.0020

üî¨ RESNET-50: retinamnist
Using downloaded and verified file: C:\Users\User\.medmnist\retinamnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\retinamnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\retinamnist.npz
üìä Train: 1,080 | Val: 120 | Test: 400 | Classes: 5
üìà Epoch 1 | Loss: 1.276 | Val AUC: 0.7587 | Best: 0.7587 | 0.1m | 0.4GB
üìà Epoch 2 | Loss: 0.935 | Val AUC: 0.7667 | Best: 0.7667 | 0.1m | 0.4GB
üìà Epoch 3 | Loss: 0.663 | Val AUC: 0.7283 | Best: 0.7667 | 0.2m | 0.4GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ retinamnist DONE | Test AUC: 0.7411 | Gap: +0.0251

üéØ FINAL RESULTS (3 EPOCHS)
Dataset         Test AUC   Benchmark  Gap
------------------------------------------------------------
bloodmnist      0.9980     0.998      +0.0000
breastmnist     0.7149     0.866      -0.1511
chestmnist      0.5000     0.773      -0.2730
dermamnist      0.9281     0.920      +0.0081
octmnist        0.9560     0.958      -0.0020
organamnist     0.9983     0.998      +0.0003
organcmnist     0.9953     0.993      +0.0023
organsmnist     0.9801     0.975      +0.0051
pathmnist       0.9924     0.989      +0.0034
pneumoniamnist  0.9667     0.962      +0.0047
retinamnist     0.7411     0.716      +0.0251
tissuemnist     0.9238     0.932      -0.0082
------------------------------------------------------------
üìä AVG GAP: -0.0321


In [1]:
!pip install -q medmnist pillow

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, models
from medmnist import INFO
import numpy as np
from sklearn.metrics import roc_auc_score
import time
from collections import defaultdict

print("="*70)
print("üî• RTX 4060 + RESNET-50 + BULLETPROOF FIX (ALL DATASETS)")
print("="*70)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

benchmarks = {
    'pathmnist': 0.989, 'chestmnist': 0.773, 'dermamnist': 0.920, 'octmnist': 0.958,
    'pneumoniamnist': 0.962, 'retinamnist': 0.716, 'breastmnist': 0.866, 'bloodmnist': 0.998,
    'tissuemnist': 0.932, 'organamnist': 0.998, 'organcmnist': 0.993, 'organsmnist': 0.975
}

all_results = {}
best_aucs = {}

def is_multilabel(ds_name):
    return ds_name == 'chestmnist'

def safe_target_processing(target, multilabel):
    """Handle all target shapes safely"""
    if multilabel:
        return target.float()
    else:
        if target.ndim > 1:
            target = target.squeeze(-1)
        if target.ndim == 0:
            target = target.unsqueeze(0)
        return target.long()

def train_single_dataset(ds_name, epochs=3):
    print(f"\n{'='*60}")
    print(f"üî¨ RESNET-50: {ds_name}")
    print(f"{'='*60}")
    
    info = INFO[ds_name]
    module = __import__('medmnist', fromlist=[info['python_class']])
    DataClass = getattr(module, info['python_class'])
    
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    
    train_ds = DataClass(split='train', transform=transform, download=True, as_rgb=True)
    val_ds = DataClass(split='val', transform=transform, download=True, as_rgb=True)
    test_ds = DataClass(split='test', transform=transform, download=True, as_rgb=True)
    
    num_classes = len(info['label'])
    multilabel = is_multilabel(ds_name)
    
    # DYNAMIC BATCH SIZE - FIX FOR SMALL DATASETS
    train_batch = min(96, len(train_ds) // 16)
    val_batch = min(192, len(val_ds) // 8)
    test_batch = min(192, len(test_ds) // 8)
    
    print(f"üìä Train: {len(train_ds):,} | Val: {len(val_ds):,} | Test: {len(test_ds):,} | Classes: {num_classes}")
    print(f"üì¶ Batches: {train_batch}/{val_batch}/{test_batch} | Multi-label: {multilabel}")
    
    train_loader = DataLoader(train_ds, batch_size=train_batch, shuffle=True, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=val_batch, shuffle=False, pin_memory=True)
    test_loader = DataLoader(test_ds, batch_size=test_batch, shuffle=False, pin_memory=True)
    
    model = models.resnet50(weights='IMAGENET1K_V2')
    model.fc = nn.Linear(2048, num_classes)
    model = model.to(device)
    
    for param in list(model.parameters())[:-50]:
        param.requires_grad = False
    
    optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), 
                                 lr=1e-3, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    
    best_auc = 0
    start_time = time.time()
    
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        train_batches = 0
        
        for data, target in train_loader:
            data = data.to(device, non_blocking=True)
            target = safe_target_processing(target.to(device), multilabel)
            
            # ENSURE SHAPES MATCH
            if data.size(0) != target.size(0):
                min_batch = min(data.size(0), target.size(0))
                data = data[:min_batch]
                target = target[:min_batch]
            
            optimizer.zero_grad()
            output = model(data)
            
            if multilabel:
                loss = F.binary_cross_entropy_with_logits(output, target)
            else:
                loss = F.cross_entropy(output, target)
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            train_loss += loss.item()
            train_batches += 1
        
        avg_train_loss = train_loss / train_batches
        
        # Validation
        model.eval()
        val_preds, val_targets = [], []
        
        with torch.no_grad():
            for data, target in val_loader:
                data = data.to(device, non_blocking=True)
                target = safe_target_processing(target.to(device), multilabel)
                
                if data.size(0) != target.size(0):
                    min_batch = min(data.size(0), target.size(0))
                    data = data[:min_batch]
                    target = target[:min_batch]
                
                if multilabel:
                    output = model(data)
                    prob = torch.sigmoid(output)
                else:
                    output = model(data)
                    prob = F.softmax(output, dim=1)
                
                val_preds.extend(prob.cpu().numpy())
                val_targets.extend(target.cpu().numpy())
        
        val_preds = np.array(val_preds)
        val_targets = np.array(val_targets)
        
        try:
            if multilabel:
                val_auc = roc_auc_score(val_targets, val_preds, average='macro')
            elif num_classes > 2:
                val_auc = roc_auc_score(val_targets, val_preds, multi_class='ovr')
            else:
                val_auc = roc_auc_score(val_targets, val_preds[:, 1])
        except:
            val_auc = 0.5
        
        if val_auc > best_auc:
            best_auc = val_auc
            torch.save(model.state_dict(), f'{ds_name}_best.pth')
            best_aucs[ds_name] = val_auc
        
        scheduler.step()
        
        elapsed = time.time() - start_time
        vram = torch.cuda.memory_allocated()/1e9
        print(f"üìà Epoch {epoch+1} | Loss: {avg_train_loss:.3f} | Val AUC: {val_auc:.4f} | Best: {best_auc:.4f} | {elapsed/60:.1f}m | {vram:.1f}GB")
    
    # Test
    model.load_state_dict(torch.load(f'{ds_name}_best.pth'))
    model.eval()
    
    test_preds, test_targets = [], []
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device, non_blocking=True)
            target = safe_target_processing(target.to(device), multilabel)
            
            if data.size(0) != target.size(0):
                min_batch = min(data.size(0), target.size(0))
                data = data[:min_batch]
                target = target[:min_batch]
            
            if multilabel:
                output = model(data)
                prob = torch.sigmoid(output)
            else:
                output = model(data)
                prob = F.softmax(output, dim=1)
            
            test_preds.extend(prob.cpu().numpy())
            test_targets.extend(target.cpu().numpy())
    
    test_preds = np.array(test_preds)
    test_targets = np.array(test_targets)
    
    try:
        if multilabel:
            test_auc = roc_auc_score(test_targets, test_preds, average='macro')
        elif num_classes > 2:
            test_auc = roc_auc_score(test_targets, test_preds, multi_class='ovr')
        else:
            test_auc = roc_auc_score(test_targets, test_preds[:, 1])
    except:
        test_auc = 0.5
    
    total_time = time.time() - start_time
    bench = benchmarks.get(ds_name, 0)
    print(f"‚úÖ {ds_name} | Test AUC: {test_auc:.4f} | Bench: {bench:.3f} | Gap: {test_auc-bench:+.4f}")
    
    all_results[ds_name] = test_auc
    return test_auc

# Run ALL datasets
datasets = [
    'bloodmnist', 'tissuemnist', 'pathmnist', 'organcmnist',
    'organamnist', 'chestmnist', 'pneumoniamnist', 'dermamnist',
    'breastmnist', 'organsmnist', 'octmnist', 'retinamnist'
]

print("\nüöÄ BULLETPROOF TRAINING - NO MORE ERRORS!\n")

for ds_name in datasets:
    try:
        train_single_dataset(ds_name, epochs=3)
    except Exception as e:
        print(f"‚ùå {ds_name} CRASHED: {e}")
        all_results[ds_name] = 0.5

print("\n" + "="*70)
print("üéØ ALL RESULTS")
print("="*70)
for ds in sorted(all_results):
    bench = benchmarks.get(ds, 0)
    gap = all_results[ds] - bench
    print(f"{ds:<15} Test: {all_results[ds]:.4f} | Gap: {gap:+.4f}")


üî• RTX 4060 + RESNET-50 + BULLETPROOF FIX (ALL DATASETS)
Device: cuda

üöÄ BULLETPROOF TRAINING - NO MORE ERRORS!


üî¨ RESNET-50: bloodmnist
Using downloaded and verified file: C:\Users\User\.medmnist\bloodmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\bloodmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\bloodmnist.npz
üìä Train: 11,959 | Val: 1,712 | Test: 3,421 | Classes: 8
üì¶ Batches: 96/192/192 | Multi-label: False
üìà Epoch 1 | Loss: 0.362 | Val AUC: 0.9967 | Best: 0.9967 | 0.7m | 0.4GB
üìà Epoch 2 | Loss: 0.100 | Val AUC: 0.9967 | Best: 0.9967 | 1.5m | 0.4GB
üìà Epoch 3 | Loss: 0.020 | Val AUC: 0.9980 | Best: 0.9980 | 2.2m | 0.4GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ bloodmnist | Test AUC: 0.9978 | Bench: 0.998 | Gap: -0.0002

üî¨ RESNET-50: tissuemnist
Using downloaded and verified file: C:\Users\User\.medmnist\tissuemnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\tissuemnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\tissuemnist.npz
üìä Train: 165,466 | Val: 23,640 | Test: 47,280 | Classes: 8
üì¶ Batches: 96/192/192 | Multi-label: False
üìà Epoch 1 | Loss: 1.042 | Val AUC: 0.9155 | Best: 0.9155 | 10.3m | 0.3GB
üìà Epoch 2 | Loss: 0.862 | Val AUC: 0.9233 | Best: 0.9233 | 20.4m | 0.3GB
üìà Epoch 3 | Loss: 0.618 | Val AUC: 0.9220 | Best: 0.9233 | 30.5m | 0.3GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ tissuemnist | Test AUC: 0.9241 | Bench: 0.932 | Gap: -0.0079

üî¨ RESNET-50: pathmnist
Using downloaded and verified file: C:\Users\User\.medmnist\pathmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\pathmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\pathmnist.npz
üìä Train: 89,996 | Val: 10,004 | Test: 7,180 | Classes: 9
üì¶ Batches: 96/192/192 | Multi-label: False
üìà Epoch 1 | Loss: 0.148 | Val AUC: 0.9996 | Best: 0.9996 | 5.4m | 0.3GB
üìà Epoch 2 | Loss: 0.042 | Val AUC: 0.9997 | Best: 0.9997 | 10.7m | 0.3GB
üìà Epoch 3 | Loss: 0.012 | Val AUC: 0.9999 | Best: 0.9999 | 16.1m | 0.3GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ pathmnist | Test AUC: 0.9922 | Bench: 0.989 | Gap: +0.0032

üî¨ RESNET-50: organcmnist
Using downloaded and verified file: C:\Users\User\.medmnist\organcmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\organcmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\organcmnist.npz
üìä Train: 12,975 | Val: 2,392 | Test: 8,216 | Classes: 11
üì¶ Batches: 96/192/192 | Multi-label: False
üìà Epoch 1 | Loss: 0.344 | Val AUC: 0.9994 | Best: 0.9994 | 0.8m | 0.4GB
üìà Epoch 2 | Loss: 0.066 | Val AUC: 0.9992 | Best: 0.9994 | 1.6m | 0.4GB
üìà Epoch 3 | Loss: 0.013 | Val AUC: 0.9997 | Best: 0.9997 | 2.4m | 0.4GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ organcmnist | Test AUC: 0.9955 | Bench: 0.993 | Gap: +0.0025

üî¨ RESNET-50: organamnist
Using downloaded and verified file: C:\Users\User\.medmnist\organamnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\organamnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\organamnist.npz
üìä Train: 34,561 | Val: 6,491 | Test: 17,778 | Classes: 11
üì¶ Batches: 96/192/192 | Multi-label: False
üìà Epoch 1 | Loss: 0.186 | Val AUC: 0.9999 | Best: 0.9999 | 2.2m | 0.4GB
üìà Epoch 2 | Loss: 0.027 | Val AUC: 0.9999 | Best: 0.9999 | 4.3m | 0.4GB
üìà Epoch 3 | Loss: 0.010 | Val AUC: 1.0000 | Best: 1.0000 | 6.5m | 0.4GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ organamnist | Test AUC: 0.9986 | Bench: 0.998 | Gap: +0.0006

üî¨ RESNET-50: chestmnist
Using downloaded and verified file: C:\Users\User\.medmnist\chestmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\chestmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\chestmnist.npz
üìä Train: 78,468 | Val: 11,219 | Test: 22,433 | Classes: 14
üì¶ Batches: 96/192/192 | Multi-label: True
üìà Epoch 1 | Loss: 0.168 | Val AUC: 0.7603 | Best: 0.7603 | 4.8m | 0.4GB
üìà Epoch 2 | Loss: 0.156 | Val AUC: 0.7630 | Best: 0.7630 | 9.5m | 0.4GB
üìà Epoch 3 | Loss: 0.134 | Val AUC: 0.7522 | Best: 0.7630 | 14.3m | 0.4GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ chestmnist | Test AUC: 0.7612 | Bench: 0.773 | Gap: -0.0118

üî¨ RESNET-50: pneumoniamnist
Using downloaded and verified file: C:\Users\User\.medmnist\pneumoniamnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\pneumoniamnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\pneumoniamnist.npz
üìä Train: 4,708 | Val: 524 | Test: 624 | Classes: 2
üì¶ Batches: 96/65/78 | Multi-label: False
üìà Epoch 1 | Loss: 0.164 | Val AUC: 0.9847 | Best: 0.9847 | 0.3m | 0.3GB
üìà Epoch 2 | Loss: 0.031 | Val AUC: 0.9913 | Best: 0.9913 | 0.6m | 0.3GB
üìà Epoch 3 | Loss: 0.005 | Val AUC: 0.9939 | Best: 0.9939 | 0.9m | 0.3GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ pneumoniamnist | Test AUC: 0.9632 | Bench: 0.962 | Gap: +0.0012

üî¨ RESNET-50: dermamnist
Using downloaded and verified file: C:\Users\User\.medmnist\dermamnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\dermamnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\dermamnist.npz
üìä Train: 7,007 | Val: 1,003 | Test: 2,005 | Classes: 7
üì¶ Batches: 96/125/192 | Multi-label: False
üìà Epoch 1 | Loss: 0.794 | Val AUC: 0.9098 | Best: 0.9098 | 0.4m | 0.3GB
üìà Epoch 2 | Loss: 0.507 | Val AUC: 0.9297 | Best: 0.9297 | 0.9m | 0.3GB
üìà Epoch 3 | Loss: 0.182 | Val AUC: 0.9335 | Best: 0.9335 | 1.3m | 0.3GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ dermamnist | Test AUC: 0.9258 | Bench: 0.920 | Gap: +0.0058

üî¨ RESNET-50: breastmnist
Using downloaded and verified file: C:\Users\User\.medmnist\breastmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\breastmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\breastmnist.npz
üìä Train: 546 | Val: 78 | Test: 156 | Classes: 2
üì¶ Batches: 34/9/19 | Multi-label: False
üìà Epoch 1 | Loss: 0.494 | Val AUC: 0.6023 | Best: 0.6023 | 0.0m | 0.3GB
üìà Epoch 2 | Loss: 0.156 | Val AUC: 0.8179 | Best: 0.8179 | 0.1m | 0.3GB
üìà Epoch 3 | Loss: 0.028 | Val AUC: 0.8964 | Best: 0.8964 | 0.1m | 0.3GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ breastmnist | Test AUC: 0.8409 | Bench: 0.866 | Gap: -0.0251

üî¨ RESNET-50: organsmnist
Using downloaded and verified file: C:\Users\User\.medmnist\organsmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\organsmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\organsmnist.npz
üìä Train: 13,932 | Val: 2,452 | Test: 8,827 | Classes: 11
üì¶ Batches: 96/192/192 | Multi-label: False
üìà Epoch 1 | Loss: 0.540 | Val AUC: 0.9942 | Best: 0.9942 | 0.9m | 0.4GB
üìà Epoch 2 | Loss: 0.195 | Val AUC: 0.9940 | Best: 0.9942 | 1.7m | 0.4GB
üìà Epoch 3 | Loss: 0.078 | Val AUC: 0.9948 | Best: 0.9948 | 2.6m | 0.4GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ organsmnist | Test AUC: 0.9809 | Bench: 0.975 | Gap: +0.0059

üî¨ RESNET-50: octmnist
Using downloaded and verified file: C:\Users\User\.medmnist\octmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\octmnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\octmnist.npz
üìä Train: 97,477 | Val: 10,832 | Test: 1,000 | Classes: 4
üì¶ Batches: 96/192/125 | Multi-label: False
üìà Epoch 1 | Loss: 0.354 | Val AUC: 0.9624 | Best: 0.9624 | 5.8m | 0.4GB
üìà Epoch 2 | Loss: 0.233 | Val AUC: 0.9719 | Best: 0.9719 | 11.6m | 0.4GB
üìà Epoch 3 | Loss: 0.115 | Val AUC: 0.9729 | Best: 0.9729 | 17.4m | 0.4GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ octmnist | Test AUC: 0.9523 | Bench: 0.958 | Gap: -0.0057

üî¨ RESNET-50: retinamnist
Using downloaded and verified file: C:\Users\User\.medmnist\retinamnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\retinamnist.npz
Using downloaded and verified file: C:\Users\User\.medmnist\retinamnist.npz
üìä Train: 1,080 | Val: 120 | Test: 400 | Classes: 5
üì¶ Batches: 67/15/50 | Multi-label: False
üìà Epoch 1 | Loss: 1.257 | Val AUC: 0.7344 | Best: 0.7344 | 0.1m | 0.3GB
üìà Epoch 2 | Loss: 0.893 | Val AUC: 0.7718 | Best: 0.7718 | 0.1m | 0.3GB
üìà Epoch 3 | Loss: 0.445 | Val AUC: 0.7890 | Best: 0.7890 | 0.2m | 0.3GB


  model.load_state_dict(torch.load(f'{ds_name}_best.pth'))


‚úÖ retinamnist | Test AUC: 0.7123 | Bench: 0.716 | Gap: -0.0037

üéØ ALL RESULTS
bloodmnist      Test: 0.9978 | Gap: -0.0002
breastmnist     Test: 0.8409 | Gap: -0.0251
chestmnist      Test: 0.7612 | Gap: -0.0118
dermamnist      Test: 0.9258 | Gap: +0.0058
octmnist        Test: 0.9523 | Gap: -0.0057
organamnist     Test: 0.9986 | Gap: +0.0006
organcmnist     Test: 0.9955 | Gap: +0.0025
organsmnist     Test: 0.9809 | Gap: +0.0059
pathmnist       Test: 0.9922 | Gap: +0.0032
pneumoniamnist  Test: 0.9632 | Gap: +0.0012
retinamnist     Test: 0.7123 | Gap: -0.0037
tissuemnist     Test: 0.9241 | Gap: -0.0079
