In [15]:
##Environment Setup
!pip install medmnist monai torch torchvision scikit-learn

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torchvision.transforms import ToTensor, Lambda, Compose
from torchvision.models import resnet50
from monai.networks.nets import resnet as monai_resnet
from medmnist import INFO, Evaluator
import medmnist
import random
from collections import defaultdict






In [16]:
# All 12 2D datasets
#Datasets Config

dataset_names_2d = [
    'breastmnist', 'retinamnist', 'pneumoniamnist', 'dermamnist',
    'organamnist', 'organcmnist', 'organsmnist', 'bloodmnist',
    'octmnist', 'chestmnist', 'pathmnist', 'tissuemnist'
]
dataset_names_3d = [
    'organmnist3d', 'nodulemnist3d', 'fracturemnist3d',
    'synapsemnist3d', 'vesselmnist3d', 'adrenalmnist3d'
]
all_datasets = dataset_names_2d + dataset_names_3d




In [17]:
##Unified DataLoader with Temperature-Scaled Sampling
def get_dataloader(name, split, batch_size):
    dsinfo = INFO[name]
    DataClass = getattr(medmnist, dsinfo['python_class'])
    # 2D grayscale becomes 3 channel
    n_channels = dsinfo.get('n_channels', 1)
    if name in dataset_names_2d:
        if n_channels == 1:
            transform = Compose([ToTensor(), Lambda(lambda x: x.repeat(3, 1, 1))])
        else:
            transform = ToTensor()
    else:
        transform = Lambda(lambda x: torch.tensor(x).float())
    dataset = DataClass(split=split, transform=transform, download=True)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size,
        shuffle=(split=='train'), num_workers=2, drop_last=False
    )
    n_classes = dsinfo.get('n_classes', dsinfo.get('num_classes', len(dsinfo.get('label', {}))))
    n_samples = dsinfo['n_samples'][split]
    return loader, n_classes, n_samples

dataloaders_train = {}
dataloaders_test = {}
dataset_specs = {}
dataset_sizes = {}

for name in all_datasets:
    train_loader, n_classes, n_train = get_dataloader(name, 'train', 32)
    test_loader, _, _ = get_dataloader(name, 'test', 32)
    dataloaders_train[name] = train_loader
    dataloaders_test[name] = test_loader
    dataset_specs[name] = {'num_classes': n_classes, 'is_3d': name in dataset_names_3d}
    dataset_sizes[name] = n_train




In [18]:
#Model Architecture - Unified Multi-Task Network
class Adapter(nn.Module):
    def __init__(self, dim, hidden_dim=64, is_3d=False):
        super().__init__()
        conv = nn.Conv3d if is_3d else nn.Conv2d
        self.adapter = nn.Sequential(
            conv(dim, hidden_dim, 1),
            nn.ReLU(inplace=True),
            conv(hidden_dim, dim, 1)
        )
    def forward(self, x):
        return self.adapter(x) + x

class MedFusionNet(nn.Module):
    def __init__(self, dataset_specs):
        super().__init__()
        self.encoder_2d = resnet50(weights='IMAGENET1K_V1')
        self.encoder_2d = nn.Sequential(*list(self.encoder_2d.children())[:-2])
        m3d = monai_resnet.resnet50(spatial_dims=3, n_input_channels=1, num_classes=1)
        self.encoder_3d = nn.Sequential(*list(m3d.children())[:-2])
        self.adapters = nn.ModuleDict()
        self.heads = nn.ModuleDict()
        self.dataset_specs = dataset_specs
        for ds, spec in dataset_specs.items():
            is_3d = spec['is_3d']
            self.adapters[ds] = Adapter(2048, 64, is_3d)
            self.heads[ds] = nn.Linear(2048, spec['num_classes'])
    def forward(self, x, ds_name):
        is_3d = self.dataset_specs[ds_name]['is_3d']
        if is_3d:
            feats = self.encoder_3d(x)
            feats = self.adapters[ds_name](feats)
            pooled = F.adaptive_avg_pool3d(feats, 1).flatten(1)
        else:
            feats = self.encoder_2d(x)
            feats = self.adapters[ds_name](feats)
            pooled = F.adaptive_avg_pool2d(feats, 1).flatten(1)
        return self.heads[ds_name](pooled)




In [19]:
#Uncertainty-Weighted Multi-Task Loss
class UncertaintyWeightedLoss(nn.Module):
    def __init__(self, dataset_names):
        super().__init__()
        self.log_vars = nn.ParameterDict({
            name: nn.Parameter(torch.zeros(1)) for name in dataset_names
        })
    def forward(self, loss, ds_name):
        precision = torch.exp(-self.log_vars[ds_name])
        weighted_loss = precision * loss + self.log_vars[ds_name]
        return weighted_loss


In [20]:
from torch.cuda.amp import GradScaler, autocast

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MedFusionNet(dataset_specs).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)
uncertainty_loss = UncertaintyWeightedLoss(all_datasets).to(device)
scaler = GradScaler()

config = {
    '2d': {'epochs': 10, 'batch_size': 64, 'accumulation_steps': 2},
    '3d': {'epochs': 15, 'batch_size': 32, 'accumulation_steps': 4}
}
multilabel_datasets = ['chestmnist']

def compute_sampling_probs(dataset_sizes, temperature=0.5):
    sizes = np.array([dataset_sizes[ds] for ds in all_datasets])
    probs = sizes ** temperature
    probs = probs / probs.sum()
    return {ds: p for ds, p in zip(all_datasets, probs)}

sampling_probs = compute_sampling_probs(dataset_sizes, temperature=0.5)
print("Sampling probabilities computed!")

num_epochs = 15
steps_per_epoch = 100

for epoch in range(1, num_epochs + 1):
    model.train()
    epoch_losses = defaultdict(list)
    optimizer.zero_grad()
    step_count = 0
    for step in range(steps_per_epoch):
        ds_name = np.random.choice(all_datasets, p=list(sampling_probs.values()))
        is_3d = dataset_specs[ds_name]['is_3d']
        cfg = config['3d'] if is_3d else config['2d']
        # Skip if dataset type has finished its epochs
        if is_3d and epoch > cfg['epochs']: continue
        if not is_3d and epoch > cfg['epochs']: continue
        # Use prebuilt dataloader
        loader = dataloaders_train[ds_name]
        try: x, y = next(iter(loader))
        except: continue
        if x.shape[0] < 2: continue
        x = x.to(device)
        if ds_name in multilabel_datasets:
            y = y.to(device).float()
            with autocast():
                logits = model(x, ds_name)
                loss = F.binary_cross_entropy_with_logits(logits, y)
                weighted_loss = uncertainty_loss(loss, ds_name) / cfg['accumulation_steps']
        else:
            y = y.to(device).view(-1)
            if x.shape[0] != y.shape[0]: continue
            with autocast():
                logits = model(x, ds_name)
                loss = F.cross_entropy(logits, y)
                weighted_loss = uncertainty_loss(loss, ds_name) / cfg['accumulation_steps']
        scaler.scale(weighted_loss).backward()
        if (step + 1) % cfg['accumulation_steps'] == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        epoch_losses[ds_name].append(loss.item())
        step_count += 1
    print(f"Epoch {epoch}/{num_epochs} (steps: {step_count})")
    for ds in sorted(all_datasets):
        is_3d = dataset_specs[ds]['is_3d']
        max_ep = config['3d']['epochs'] if is_3d else config['2d']['epochs']
        if epoch <= max_ep and ds in epoch_losses and len(epoch_losses[ds]) > 0:
            print(f"  {ds}: {np.mean(epoch_losses[ds]):.4f}")
print("\n✓ Training complete!")
torch.save(model.state_dict(), 'medfusionnet_model.pth')
print("Model saved as 'medfusionnet_model.pth'")


Sampling probabilities computed!


  scaler = GradScaler()
  with autocast():
  with autocast():


Epoch 1/15 (steps: 100)
  adrenalmnist3d: 0.7921
  bloodmnist: 2.0414
  breastmnist: 0.8185
  chestmnist: 0.4834
  dermamnist: 1.5693
  fracturemnist3d: 1.1032
  nodulemnist3d: 1.2979
  octmnist: 1.1666
  organamnist: 2.3932
  organcmnist: 2.1573
  organmnist3d: 2.5559
  organsmnist: 2.1440
  pathmnist: 2.1407
  pneumoniamnist: 0.6589
  retinamnist: 1.8071
  synapsemnist3d: 0.6375
  tissuemnist: 1.9374
  vesselmnist3d: 0.6156
Epoch 2/15 (steps: 100)
  adrenalmnist3d: 0.9176
  bloodmnist: 1.6221
  chestmnist: 0.2248
  dermamnist: 1.5016
  fracturemnist3d: 1.5298
  nodulemnist3d: 0.8582
  octmnist: 1.0554
  organamnist: 1.6930
  organcmnist: 1.7181
  organmnist3d: 2.4992
  organsmnist: 1.7649
  pathmnist: 1.5982
  pneumoniamnist: 0.4012
  retinamnist: 1.5550
  tissuemnist: 1.5383
Epoch 3/15 (steps: 100)
  adrenalmnist3d: 0.6965
  bloodmnist: 1.3964
  chestmnist: 0.1965
  dermamnist: 1.2729
  octmnist: 0.9024
  organamnist: 1.1869
  organcmnist: 1.3889
  organmnist3d: 2.2756
  organsmnist

In [22]:
results_dict = {}
model.eval()
with torch.no_grad():
    for name in all_datasets:
        test_loader = dataloaders_test[name]
        all_probs, all_labels = [], []
        for x, y in test_loader:
            x = x.to(device)
            logits = model(x, name)
            if name == 'chestmnist':
                probs = torch.sigmoid(logits).cpu().numpy()
            else:
                probs = F.softmax(logits, dim=1).cpu().numpy()
            all_probs.append(probs)
            all_labels.append(y.numpy())
        all_probs = np.concatenate(all_probs)
        all_labels = np.concatenate(all_labels).flatten()
        evaluator = Evaluator(name, split='test')
        metrics = evaluator.evaluate(all_probs)

        # Robust extraction of auc regardless of return type
        if isinstance(metrics, tuple):
            metrics_dict, metric_names = metrics
            if isinstance(metrics_dict, dict):
                auc = metrics_dict.get('auc', 0)
            else:
                auc = metrics_dict
        elif isinstance(metrics, dict):
            auc = metrics.get('auc', 0)
        elif isinstance(metrics, (list, np.ndarray)):
            auc = metrics[1] if len(metrics) > 1 else metrics[0]
        else:
            auc = float(metrics)

        print(f"{name} Test AUC: {auc:.4f}")
        results_dict[name] = auc

print("\n=== Final Results ===")
for k, v in results_dict.items():
    print(f"{k}: {v:.4f}")
print(f"Average AUC: {np.mean(list(results_dict.values())):.4f}")


breastmnist Test AUC: 0.6784
retinamnist Test AUC: 0.5811
pneumoniamnist Test AUC: 0.8628
dermamnist Test AUC: 0.4221
organamnist Test AUC: 0.9084
organcmnist Test AUC: 0.9240
organsmnist Test AUC: 0.8847
bloodmnist Test AUC: 0.7551
octmnist Test AUC: 0.7086
chestmnist Test AUC: 0.5308
pathmnist Test AUC: 0.5470
tissuemnist Test AUC: 0.6444
organmnist3d Test AUC: 0.7614
nodulemnist3d Test AUC: 0.6731
fracturemnist3d Test AUC: 0.5640
synapsemnist3d Test AUC: 0.6029
vesselmnist3d Test AUC: 0.7111
adrenalmnist3d Test AUC: 0.7762

=== Final Results ===
breastmnist: 0.6784
retinamnist: 0.5811
pneumoniamnist: 0.8628
dermamnist: 0.4221
organamnist: 0.9084
organcmnist: 0.9240
organsmnist: 0.8847
bloodmnist: 0.7551
octmnist: 0.7086
chestmnist: 0.5308
pathmnist: 0.5470
tissuemnist: 0.6444
organmnist3d: 0.7614
nodulemnist3d: 0.6731
fracturemnist3d: 0.5640
synapsemnist3d: 0.6029
vesselmnist3d: 0.7111
adrenalmnist3d: 0.7762
Average AUC: 0.6964
