Med-Vit-3D

In [None]:
import tensorflow as tf
# Switch every GPU TensorFlow can see to on-demand memory allocation.
# NOTE(review): TensorFlow is not used anywhere else in the visible cells;
# presumably this only keeps TF from reserving GPU memory that the PyTorch
# code needs -- confirm it is still required.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
!pip install medmnist==3.0.1 \
    torchattacks

In [None]:
pip install torchio

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

import torchvision.utils
from torchvision import models
import torchvision.datasets as dsets
from torchsummary import summary

from tqdm import tqdm
import medmnist
from medmnist import INFO, Evaluator

import torchattacks
from torchattacks import PGD, FGSM
from torch.utils.data import Subset, random_split, DataLoader
from torchvision import transforms
import torchvision
import random
import torch
import pandas as pd
from torchvision.transforms.transforms import Resize

from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
from sklearn.metrics import roc_auc_score, accuracy_score
import pandas as pd
import time

import torchio as tio

In [None]:
# Print the versions of the main libraries so a run can be reproduced later.
print("PyTorch", torch.__version__)
print("Torchvision", torchvision.__version__)
print("Torchattacks", torchattacks.__version__)
print("Numpy", np.__version__)
print("Medmnist", medmnist.__version__)

In [None]:
### Downsample the training split (default: 1000 samples)

import torchio as tio   # for 3D augmentations

def med3dfulldata(DataClass, BATCH_SIZE=15, downsample=1000):
    """Build train / train-eval / test loaders for a 3D dataset class.

    Args:
        DataClass: dataset class, called as
            ``DataClass(split=..., transform=..., download=True)``.
        BATCH_SIZE: batch size of the training loader; both evaluation
            loaders use twice this value.
        downsample: when truthy and smaller than the training split, only a
            random subset of this many training samples is used.

    Returns:
        ``(train_loader, train_loader_at_eval, test_loader)``.
        ``train_loader`` yields augmented samples; ``train_loader_at_eval``
        yields the same training samples without augmentation.
    """
    # Safe base transform (keeps shape as 1xDxHxW)
    def base_transform(x):
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)
        x = x.float()
        if x.ndim == 3:   # (D,H,W) -> add channel
            x = x.unsqueeze(0)
        return x  # (1,D,H,W)

    # TorchIO augmentations for training.
    # NOTE(review): per the TorchIO docs a 2-tuple is a (min, max) range, so
    # degrees=(0, 15) / translation=(0, 5) sample one-sided values rather than
    # symmetric ones -- confirm this is intended.
    train_transform = tio.Compose([
        tio.RandomFlip(axes=(0, 1, 2), flip_probability=0.5),
        tio.RandomAffine(scales=(0.9, 1.1), degrees=(0, 15), translation=(0, 5)),
        tio.RandomNoise(mean=0, std=0.05),
        tio.RandomBiasField(coefficients=0.3),
        tio.Lambda(base_transform)  # final tensor conversion
    ])

    # Only tensor conversion for evaluation
    test_transform = tio.Lambda(base_transform)

    # The training split is loaded twice: once with augmentation for
    # optimisation and once without, so that "train at eval" metrics are not
    # computed on randomly perturbed inputs.
    train_dataset = DataClass(split='train', transform=train_transform, download=True)
    train_eval_dataset = DataClass(split='train', transform=test_transform, download=True)
    test_dataset = DataClass(split='test', transform=test_transform, download=True)
    # NOTE(review): split='val' is not loaded here; only three loaders are
    # returned, so there is no slot for a validation loader.

    # Optional downsampling; the same indices are used for both training views
    if downsample and downsample < len(train_dataset):
        indices = random.sample(range(len(train_dataset)), downsample)
        train_dataset = Subset(train_dataset, indices)
        train_eval_dataset = Subset(train_eval_dataset, indices)

    # Dataloaders
    train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    train_loader_at_eval = data.DataLoader(dataset=train_eval_dataset, batch_size=2*BATCH_SIZE, shuffle=False)
    test_loader = data.DataLoader(dataset=test_dataset, batch_size=2*BATCH_SIZE, shuffle=False)
    return train_loader, train_loader_at_eval, test_loader




# --- Helper to create a single batch with all classes ---
def get_one_batch_with_all_classes(dataset, num_classes, batch_size=32):
    """Collect leading samples of ``dataset`` until every class has appeared.

    The dataset is read in order, ``batch_size`` samples at a time, and whole
    batches are accumulated until ``num_classes`` distinct label values have
    been seen.

    Args:
        dataset: indexable dataset yielding ``(input, label)`` pairs.
        num_classes: number of distinct label values that must be present.
        batch_size: chunk size used while scanning the dataset.

    Returns:
        A one-element list ``[(inputs, targets)]`` holding the concatenated
        samples, so it can be iterated like a loader with a single batch.

    Raises:
        ValueError: if the dataset is exhausted before all classes were seen.
    """
    loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)
    collected_inputs, collected_targets = [], []
    seen_classes = set()

    for x, y in loader:
        collected_inputs.append(x)
        collected_targets.append(y)
        # reshape(-1) always yields a 1-D tensor; squeeze() would collapse a
        # one-sample batch to a scalar whose tolist() is not iterable.
        seen_classes.update(y.reshape(-1).tolist())

        if len(seen_classes) == num_classes:
            merged_inputs = torch.cat(collected_inputs, dim=0)
            merged_targets = torch.cat(collected_targets, dim=0)
            return [(merged_inputs, merged_targets)]
    raise ValueError(f"Could not find all {num_classes} classes in the dataset.")
'''
def data_input_loading(data_flag, BATCH_SIZE=15, lr=0.0005, NUM_EPOCHS=5, total_samples=1500, med3d = False):
    download = True
    info = INFO[data_flag]
    task = info['task']
    n_channels = info['n_channels']
    n_classes = len(info['label'])
    DataClass = getattr(medmnist, info['python_class'])

    if med3d:
      train_loader, train_loader_at_eval, test_loader = med3dfulldata(DataClass)
    else:
      # preprocessing
      train_transform = transforms.Compose([
          transforms.Resize(224),
          transforms.Lambda(lambda image: image.convert('RGB')),
          torchvision.transforms.AugMix(),
          transforms.ToTensor(),
          transforms.Normalize(mean=[.5], std=[.5])
      ])
      test_transform = transforms.Compose([
          transforms.Resize(224),
          transforms.Lambda(lambda image: image.convert('RGB')),
          transforms.ToTensor(),
          transforms.Normalize(mean=[.5], std=[.5])
      ])

      # Load full dataset first (train + test combined)
      full_dataset = DataClass(split='train', transform=train_transform, download=download)
      test_dataset_full = DataClass(split='test', transform=test_transform, download=download)
      combined_dataset = torch.utils.data.ConcatDataset([full_dataset, test_dataset_full])

      # Randomly select only total_samples items
      if total_samples > len(combined_dataset):
        total_samples = len(combined_dataset)
      indices = random.sample(range(len(combined_dataset)), total_samples)
      small_dataset = Subset(combined_dataset, indices)

      # Split into train / val / test (e.g., 70/15/15 split)
      train_size = int(0.7 * total_samples)
      val_size = int(0.15 * total_samples)
      test_size = total_samples - train_size - val_size
      train_dataset, val_dataset, test_dataset = random_split(small_dataset, [train_size, val_size, test_size])

      # Dataloaders
      train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
      train_loader_at_eval = get_one_batch_with_all_classes(train_dataset, n_classes, batch_size=2*BATCH_SIZE)
      test_loader = get_one_batch_with_all_classes(test_dataset, n_classes, batch_size=2*BATCH_SIZE)

    return data_flag, NUM_EPOCHS, BATCH_SIZE, lr, task, train_loader, train_loader_at_eval, test_loader, n_classes
  '''

import numpy as np
import random
import torch
from torch.utils.data import DataLoader, Subset, random_split, WeightedRandomSampler
import torchvision.transforms as transforms
import torchvision
import medmnist

# Balanced DataLoader helper
from torch.utils.data import WeightedRandomSampler

def make_hybrid_loader(dataset, batch_size, alpha=0.5):
    """Build a DataLoader that resamples ``dataset`` with class-aware weights.

    Each sample is drawn (with replacement) with weight ``count ** -alpha``,
    where ``count`` is the number of samples sharing its class:

    - alpha = 1.0 -> fully balanced (1/count)
    - alpha = 0.0 -> natural distribution
    - e.g., alpha=0.5 -> softer balance (1/sqrt(count))

    Args:
        dataset: indexable dataset whose items are ``(input, label)``; the
            label may be a plain int or a length-1 array/tensor.
        batch_size: batch size of the returned loader.
        alpha: balancing strength, see above.

    Returns:
        A ``DataLoader`` drawing ``len(dataset)`` samples per epoch.

    Raises:
        ValueError: if the dataset is empty or a label holds more than one
            value (multi-label targets have no single class to balance on).
    """
    n_samples = len(dataset)
    if n_samples == 0:
        raise ValueError("make_hybrid_loader needs a non-empty dataset")

    # NOTE: dataset[i] runs the dataset's transform, so this pass costs one
    # full read of the data.
    raw_labels = np.asarray([np.asarray(dataset[i][1]) for i in range(n_samples)])
    # Labels arrive as shape-(1,) arrays for some datasets; np.bincount only
    # accepts a flat integer vector.
    raw_labels = raw_labels.reshape(n_samples, -1)
    if raw_labels.shape[1] != 1:
        raise ValueError("make_hybrid_loader expects exactly one class label per sample")
    labels = raw_labels[:, 0].astype(np.int64)

    class_counts = np.bincount(labels)

    # hybrid weighting: count^(-alpha); classes with no samples keep weight 0
    # (they are never indexed) instead of triggering a divide-by-zero.
    class_weights = np.zeros(len(class_counts), dtype=np.float64)
    present = class_counts > 0
    class_weights[present] = class_counts[present].astype(np.float64) ** (-float(alpha))
    sample_weights = class_weights[labels].tolist()

    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)


# Full pipeline with augmentation + balancing
def data_input_loading(data_flag, BATCH_SIZE=15, lr=0.0005, NUM_EPOCHS=5, total_samples=1500, med3d=False):
    """Create the data loaders for one MedMNIST dataset.

    Args:
        data_flag: key into ``INFO`` naming the dataset.
        BATCH_SIZE: training batch size (evaluation loaders use twice this).
        lr: learning rate; returned unchanged.
        NUM_EPOCHS: number of epochs; returned unchanged.
        total_samples: 2D branch only -- number of samples drawn from the
            combined train+test pool before the 70/15/15 split.
        med3d: when true, delegate loader construction to ``med3dfulldata``.

    Returns:
        ``(data_flag, NUM_EPOCHS, BATCH_SIZE, lr, task, train_loader,
        train_loader_at_eval, test_loader, n_classes)``.
    """
    download = True
    info = INFO[data_flag]
    task = info['task']
    n_classes = len(info['label'])
    DataClass = getattr(medmnist, info['python_class'])

    if med3d:
        # keep med3d as is (need TorchIO/MONAI for 3D aug).
        # BATCH_SIZE must be forwarded, otherwise med3dfulldata silently falls
        # back to its own default and the value returned below is misleading.
        train_loader, train_loader_at_eval, test_loader = med3dfulldata(DataClass, BATCH_SIZE=BATCH_SIZE)

    else:
        # Augmentation for training
        train_transform = transforms.Compose([
            transforms.Resize(224),
            transforms.Lambda(lambda image: image.convert('RGB')),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
            transforms.RandomAffine(
                degrees=0,
                translate=(0.1, 0.1),
                scale=(0.9, 1.1)
            ),
            transforms.ColorJitter(
                brightness=0.2,
                contrast=0.2,
                saturation=0.2,
                hue=0.1
            ),
            transforms.AugMix(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
        ])

        # No augmentation for test/val
        test_transform = transforms.Compose([
            transforms.Resize(224),
            transforms.Lambda(lambda image: image.convert('RGB')),
            transforms.ToTensor(),
            transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
        ])

        # Pool the official train and test splits.
        # NOTE(review): every item keeps the transform of the split it came
        # from, so after the random split below the training subset contains
        # un-augmented test-split items and the test subset contains augmented
        # train-split items -- confirm this mixing is intended.
        full_dataset = DataClass(split='train', transform=train_transform, download=download)
        test_dataset_full = DataClass(split='test', transform=test_transform, download=download)
        combined_dataset = torch.utils.data.ConcatDataset([full_dataset, test_dataset_full])

        # Randomly select only total_samples items
        total_samples = min(total_samples, len(combined_dataset))
        indices = random.sample(range(len(combined_dataset)), total_samples)
        small_dataset = Subset(combined_dataset, indices)

        # Split into train / val / test (the val part is currently not used)
        train_size = int(0.7 * total_samples)
        val_size = int(0.15 * total_samples)
        test_size = total_samples - train_size - val_size
        train_dataset, _, test_dataset = random_split(small_dataset, [train_size, val_size, test_size])

        # Balanced training loader
        train_loader = make_hybrid_loader(train_dataset, BATCH_SIZE)

        # Regular loaders for eval
        train_loader_at_eval = DataLoader(train_dataset, batch_size=2*BATCH_SIZE, shuffle=False)
        test_loader = DataLoader(test_dataset, batch_size=2*BATCH_SIZE, shuffle=False)

    return data_flag, NUM_EPOCHS, BATCH_SIZE, lr, task, train_loader, train_loader_at_eval, test_loader, n_classes



In [None]:
## loading model
from MedVit3D import MedViT3D_small
from MedViT import MedViT_small

In [None]:
## loading history
from sklearn.metrics import roc_auc_score, accuracy_score
import pandas as pd
import time

def getPrecision(y_true, y_score, task, threshold=0.5):
    """Precision of thresholded / arg-max predictions.

    :param y_true: ground-truth labels (squeezed before use)
    :param y_score: predicted scores (squeezed before use)
    :param task: task string of the current dataset
    :param threshold: cut-off applied to scores for the multi-label and binary tasks
    :return: macro precision for multi-label and multi-class tasks,
        binary precision for the binary task
    """
    labels = y_true.squeeze()
    scores = y_score.squeeze()

    if task == "multi-label, binary-class":
        predictions = (scores > threshold).astype(int)
        averaging = "macro"
    elif task == "binary-class":
        # with per-class scores, the last column is the one that is thresholded
        positive_scores = scores[:, -1] if scores.ndim == 2 else scores
        predictions = (positive_scores > threshold).astype(int)
        averaging = "binary"
    else:
        predictions = np.argmax(scores, axis=-1)
        averaging = "macro"

    return precision_score(labels, predictions, average=averaging, zero_division=0)


def getRecall(y_true, y_score, task, threshold=0.5):
    """Recall of thresholded / arg-max predictions.

    :param y_true: ground-truth labels (squeezed before use)
    :param y_score: predicted scores (squeezed before use)
    :param task: task string of the current dataset
    :param threshold: cut-off applied to scores for the multi-label and binary tasks
    :return: binary recall for the binary task, macro recall otherwise
    """
    truth = y_true.squeeze()
    score = y_score.squeeze()

    if task == "binary-class":
        # with per-class scores, only the last column is thresholded
        if score.ndim == 2:
            score = score[:, -1]
        hard = (score > threshold).astype(int)
        return recall_score(truth, hard, average="binary", zero_division=0)

    if task == "multi-label, binary-class":
        hard = (score > threshold).astype(int)
    else:
        hard = np.argmax(score, axis=-1)
    return recall_score(truth, hard, average="macro", zero_division=0)


def getF1(y_true, y_score, task, threshold=0.5):
    """F1 score of thresholded / arg-max predictions.

    :param y_true: ground-truth labels (squeezed before use)
    :param y_score: predicted scores (squeezed before use)
    :param task: task string of the current dataset
    :param threshold: cut-off applied to scores for the multi-label and binary tasks
    :return: binary F1 for the binary task, macro F1 otherwise
    """
    gold = y_true.squeeze()
    prob = y_score.squeeze()

    is_binary = task == "binary-class"
    if is_binary and prob.ndim == 2:
        # per-class scores: threshold the last column only
        prob = prob[:, -1]

    if is_binary or task == "multi-label, binary-class":
        decided = (prob > threshold).astype(int)
    else:
        decided = np.argmax(prob, axis=-1)

    mode = "binary" if is_binary else "macro"
    return f1_score(gold, decided, average=mode, zero_division=0)

# evaluation
def getAUC(y_true, y_score, task):
    """ROC-AUC of the predicted scores.

    Multi-label tasks average the per-label AUC; multi-class tasks average
    the one-vs-rest AUC over the score columns; the binary task scores the
    last column (or the only one).

    :param y_true: the ground truth labels, shape: (n_samples, n_labels) or (n_samples,) if n_labels==1
    :param y_score: the predicted score of each class,
    shape: (n_samples, n_labels) or (n_samples, n_classes) or (n_samples,) if n_labels==1 or n_classes==1
    :param task: the task of current dataset
    """
    labels = y_true.squeeze()
    scores = y_score.squeeze()

    if task == "binary-class":
        if scores.ndim == 2:
            scores = scores[:, -1]
        else:
            assert scores.ndim == 1
        return roc_auc_score(labels, scores)

    n_columns = scores.shape[1]
    if task == "multi-label, binary-class":
        per_column = [
            roc_auc_score(labels[:, col], scores[:, col])
            for col in range(n_columns)
        ]
    else:
        # one-vs-rest: class `col` against everything else
        per_column = [
            roc_auc_score((labels == col).astype(float), scores[:, col])
            for col in range(n_columns)
        ]
    return sum(per_column) / n_columns


def getACC(y_true, y_score, task, threshold=0.5):
    """Accuracy of thresholded / arg-max predictions.

    Multi-label tasks report the mean of the per-label accuracies.

    :param y_true: the ground truth labels, shape: (n_samples, n_labels) or (n_samples,) if n_labels==1
    :param y_score: the predicted score of each class,
    shape: (n_samples, n_labels) or (n_samples, n_classes) or (n_samples,) if n_labels==1 or n_classes==1
    :param task: the task of current dataset
    :param threshold: the threshold for multilabel and binary-class tasks
    """
    labels = y_true.squeeze()
    scores = y_score.squeeze()

    if task == "multi-label, binary-class":
        decided = scores > threshold
        n_labels = labels.shape[1]
        per_label = [
            accuracy_score(labels[:, col], decided[:, col])
            for col in range(n_labels)
        ]
        return sum(per_label) / n_labels

    if task == "binary-class":
        if scores.ndim == 2:
            scores = scores[:, -1]
        else:
            assert scores.ndim == 1
        return accuracy_score(labels, scores > threshold)

    return accuracy_score(labels, np.argmax(scores, axis=-1))

def test(data_loader, model, criterion, task):
    """Evaluate ``model`` on every batch of ``data_loader``.

    Relies on the notebook-level ``device`` for tensor placement.

    Args:
        data_loader: iterable of ``(inputs, targets)`` batches.
        model: network to evaluate; switched to eval mode here.
        criterion: loss applied to the raw model outputs (logits).
        task: task string of the current dataset.

    Returns:
        ``(auc, acc, avg_loss, precision, recall, f1)`` where ``avg_loss`` is
        the mean of the per-batch losses.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()
    is_multi_label = task == 'multi-label, binary-class'
    # Batches are collected in lists and concatenated once at the end;
    # re-concatenating a growing tensor every step copies it each time.
    target_batches, score_batches = [], []
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)

            if is_multi_label:
                targets = targets.to(torch.float32)
                loss = criterion(outputs, targets)
                # Labels are independent, so each logit gets its own sigmoid;
                # a softmax would force the label scores to sum to one.
                outputs = outputs.sigmoid()
            else:
                # reshape(-1) keeps a batch dimension even for one sample,
                # where squeeze() would produce a 0-dim tensor.
                targets = targets.reshape(-1).long()
                loss = criterion(outputs, targets)
                outputs = outputs.softmax(dim=-1)
                targets = targets.float().reshape(-1, 1)

            target_batches.append(targets)
            score_batches.append(outputs)
            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("test() received an empty data_loader")

    y_true = torch.cat(target_batches, 0).cpu().numpy()
    y_score = torch.cat(score_batches, 0).detach().cpu().numpy()

    auc = getAUC(y_true, y_score, task)
    acc = getACC(y_true, y_score, task)
    avg_loss = total_loss / num_batches

    precision = getPrecision(y_true, y_score, task)
    recall = getRecall(y_true, y_score, task)
    f1 = getF1(y_true, y_score, task)

    return auc, acc, avg_loss, precision, recall, f1


def load_or_initialize_model(model_class, model_name, optimizer_class, lr, momentum, n_classes, task):
    """Create the model and optimizer, resuming from disk when possible.

    The network is ``MedViT3D_small`` when ``model_name`` contains "3d"
    (case-insensitive) and ``MedViT_small`` otherwise; ``model_class`` and
    ``task`` are accepted but not used here. Saved state is read from
    ``./history_record/{model_name}.pth`` and ``.csv`` and is only used when
    both files exist. Relies on the notebook-level ``device``.

    Args:
        model_class: unused.
        model_name: file stem of the checkpoint / history and architecture selector.
        optimizer_class: called as ``optimizer_class(params, lr=..., momentum=...)``.
        lr: learning rate for the optimizer.
        momentum: momentum for the optimizer.
        n_classes: passed to the model as ``num_classes``.
        task: unused.

    Returns:
        ``(model, optimizer, history, start_epoch, best_val_auc)``.
    """
    model_dir = "./history_record"
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, f"{model_name}.pth")
    history_path = os.path.join(model_dir, f"{model_name}.csv")

    # Architecture is picked from the name, not from model_class
    wants_3d = "3d" in model_name.lower()
    architecture = MedViT3D_small if wants_3d else MedViT_small
    model = architecture(num_classes=n_classes).to(device)

    optimizer = optimizer_class(model.parameters(), lr=lr, momentum=momentum)

    can_resume = os.path.exists(model_path) and os.path.exists(history_path)
    if not can_resume:
        empty_history = {
            "train_auc": [], "train_acc": [],
            "val_auc": [], "val_acc": [],
            "train_loss": [], "val_loss": [],
            "epoch_time": []
        }
        return model, optimizer, empty_history, 0, 0

    print(f"Loading existing model: {model_name}")
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    history = pd.read_csv(history_path).to_dict(orient='list')
    # One history row per finished epoch, so its length is the next epoch index
    start_epoch = len(history["train_loss"])
    recorded_auc = history["val_auc"]
    best_val_auc = max(recorded_auc) if recorded_auc else 0

    return model, optimizer, history, start_epoch, best_val_auc
def training_and_record(model_class,
                        model_name,
                        NUM_EPOCHS, lr,
                        momentum, train_loader,
                        train_loader_at_eval,
                        test_loader,
                        n_classes,
                        task,
                        steps):
    """Train for ``NUM_EPOCHS`` more epochs and record per-epoch metrics.

    The model, SGD optimizer and any earlier history come from
    ``load_or_initialize_model``. After every epoch the model is evaluated on
    ``train_loader_at_eval`` and ``test_loader``, the history CSV is rewritten
    and, when the AUC on ``test_loader`` improves, a checkpoint is saved.
    Relies on the notebook-level ``device``.

    Args:
        model_class: forwarded to ``load_or_initialize_model``.
        model_name: file stem under ``./history_record`` for checkpoint and CSV.
        NUM_EPOCHS: number of epochs to run in this call.
        lr: SGD learning rate.
        momentum: SGD momentum.
        train_loader: batches used for optimisation.
        train_loader_at_eval: batches used for the ``train_*`` metrics.
        test_loader: batches used for the ``val_*`` metrics.
        n_classes: number of output classes.
        task: task string of the current dataset.
        steps: currently unused; kept for caller compatibility.

    Returns:
        The history dict (metric name -> list with one value per epoch).
    """
    model, optimizer, history, start_epoch, best_val_auc = load_or_initialize_model(
        model_class, model_name, optimizer_class=torch.optim.SGD, lr=lr,
        momentum=momentum, n_classes=n_classes, task=task
    )
    is_multi_label = task == "multi-label, binary-class"
    criterion = nn.BCEWithLogitsLoss() if is_multi_label else nn.CrossEntropyLoss()

    final_epoch = start_epoch + NUM_EPOCHS
    for epoch in range(start_epoch, final_epoch):
        print(f'\nEpoch [{epoch + 1}/{final_epoch}]')
        start_time = time.time()
        model.train()

        for inputs, targets in tqdm(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            if is_multi_label:
                if targets.shape == outputs.shape:
                    # Already one 0/1 entry per label: only the dtype changes.
                    # One-hot encoding these would add a dimension and break
                    # the loss.
                    targets = targets.to(torch.float32)
                else:
                    # Class indices: expand to [B, n_classes] float
                    targets = torch.nn.functional.one_hot(
                        targets.reshape(-1).long(), num_classes=n_classes
                    ).float()
            else:
                # reshape(-1) keeps the batch dimension for a one-sample
                # batch, where squeeze() would produce a 0-dim tensor.
                targets = targets.reshape(-1).long()

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

        torch.cuda.empty_cache()

        # Metrics come from a full evaluation pass rather than from the
        # running training loss.
        # NOTE(review): the "val_*" numbers and the best-checkpoint choice
        # below are computed on test_loader -- confirm that selecting the
        # model on the test split is intended.
        train_auc, train_acc, train_loss, _, _, _ = test(train_loader_at_eval, model, criterion, task)
        val_auc, val_acc, val_loss, precision, recall, f1 = test(test_loader, model, criterion, task)

        print (f"Val ACC: {val_acc} | Val AUC: {val_auc} | Precision: {precision} | Recall {recall} | F1 Score: {f1}")

        history["train_auc"].append(train_auc)
        history["train_acc"].append(train_acc)
        history["train_loss"].append(train_loss)
        history["val_auc"].append(val_auc)
        history["val_acc"].append(val_acc)
        history["val_loss"].append(val_loss)

        # Log epoch time
        epoch_time = time.time() - start_time
        history["epoch_time"].append(epoch_time)
        print(f"Epoch {epoch+1} finished in {epoch_time:.2f} seconds")

        # Save best model
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            print("📌 New best AUC — saving model")
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, f"./history_record/{model_name}.pth")

        # The CSV is rewritten every epoch while the checkpoint only changes
        # on improvement, so a resumed run restarts from the best weights.
        pd.DataFrame(history).to_csv(f"./history_record/{model_name}.csv", index=False)

    print("✅ Training complete.")
    return history



In [None]:
## Define same parameters
# Hyper-parameters shared by every experiment cell below.
NUM_EPOCHS = 50
BATCH_SIZE = 42
lr = 0.0005
# `device` is read as a global by test(), load_or_initialize_model() and
# training_and_record().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Experiment: MedViT3D on synapsemnist3d.
# model_name is also the file stem of the checkpoint / history CSV under
# ./history_record, so re-running this cell resumes from files saved under
# the same name.
data_flag = "synapsemnist3d"
med3d = True
model_name = f"MedViT3D_{data_flag}_aug_1"
## loading data
# data_input_loading returns data_flag / NUM_EPOCHS / BATCH_SIZE / lr exactly
# as passed in, so this unpacking rebinds those globals to their own values.
(data_flag,
 NUM_EPOCHS,
 BATCH_SIZE,
 lr,
 task,
 train_loader,
 train_loader_at_eval,
 test_loader,
 n_classes) = data_input_loading(data_flag=data_flag,
                                NUM_EPOCHS = NUM_EPOCHS,
                                BATCH_SIZE=BATCH_SIZE,
                                lr=lr,
                                med3d= med3d)
## train and record
# NOTE: training_and_record currently uses neither `steps` nor `model_class`;
# the architecture is chosen from "3d" appearing in model_name.
history = training_and_record(
    model_class=MedViT3D_small,
    model_name= model_name,
    NUM_EPOCHS=NUM_EPOCHS,
    lr=lr,
    momentum=0.9,
    train_loader=train_loader,
    train_loader_at_eval=train_loader_at_eval,
    test_loader=test_loader,
    n_classes = n_classes,
    task = task,
    steps = len(train_loader)
)

In [None]:
# List the metric names recorded in the history returned by training_and_record.
history.keys()

In [None]:
# Best accuracy and best AUC over all recorded epochs (not necessarily the
# same epoch). The val_* series are computed on test_loader.
print (max(history['val_acc']), max(history['val_auc']))

In [None]:
# Re-evaluate the most recently trained model.
# training_and_record only returns the history dict, so the network is
# rebuilt here; load_or_initialize_model restores the saved weights when both
# ./history_record/{model_name}.pth and .csv exist (otherwise the model is
# freshly initialised and the numbers below are meaningless).
model, _, _, _, _ = load_or_initialize_model(
    MedViT3D_small, model_name, optimizer_class=torch.optim.SGD,
    lr=lr, momentum=0.9, n_classes=n_classes, task=task
)

# Same loss selection as in training_and_record
if task == "multi-label, binary-class":
    criterion = nn.BCEWithLogitsLoss()
else:
    criterion = nn.CrossEntropyLoss()

# test() returns (auc, acc, avg_loss, precision, recall, f1)
train_auc, train_acc, train_loss, _, _, _ = test(train_loader_at_eval, model, criterion, task)
val_auc, val_acc, val_loss, precision, recall, f1 = test(test_loader, model, criterion, task)

print (f"Model: {model_name} | Epochs: {NUM_EPOCHS} | Batch Size: {BATCH_SIZE}")
print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Train AUC: {train_auc:.4f}")
print(f"Val Acc:   {val_acc:.4f} | Val AUC:   {val_auc:.4f}")
print(f"Precision:   {precision:.4f} | Recall:   {recall:.4f} | F1 Score:   {f1:.4f}")

In [None]:
# Experiment: MedViT3D on organmnist3d.
# model_name is also the file stem of the checkpoint / history CSV under
# ./history_record, so re-running this cell resumes from files saved under
# the same name.
data_flag = "organmnist3d"
med3d = True
model_name = f"MedViT3D_{data_flag}"
## loading data
# data_input_loading returns data_flag / NUM_EPOCHS / BATCH_SIZE / lr exactly
# as passed in, so this unpacking rebinds those globals to their own values.
(data_flag,
 NUM_EPOCHS,
 BATCH_SIZE,
 lr,
 task,
 train_loader,
 train_loader_at_eval,
 test_loader,
 n_classes) = data_input_loading(data_flag=data_flag,
                                NUM_EPOCHS = NUM_EPOCHS,
                                BATCH_SIZE=BATCH_SIZE,
                                lr=lr,
                                med3d= med3d)
## train and record
# NOTE: training_and_record currently uses neither `steps` nor `model_class`;
# the architecture is chosen from "3d" appearing in model_name.
history = training_and_record(
    model_class=MedViT3D_small,
    model_name= model_name,
    NUM_EPOCHS=NUM_EPOCHS,
    lr=lr,
    momentum=0.9,
    train_loader=train_loader,
    train_loader_at_eval=train_loader_at_eval,
    test_loader=test_loader,
    n_classes = n_classes,
    task = task,
    steps = len(train_loader)
)

In [None]:
# Experiment: MedViT3D on adrenalmnist3d.
# model_name is also the file stem of the checkpoint / history CSV under
# ./history_record, so re-running this cell resumes from files saved under
# the same name.
data_flag = "adrenalmnist3d"
med3d = True
model_name = f"MedViT3D_{data_flag}_v3"
## loading data
# data_input_loading returns data_flag / NUM_EPOCHS / BATCH_SIZE / lr exactly
# as passed in, so this unpacking rebinds those globals to their own values.
(data_flag,
 NUM_EPOCHS,
 BATCH_SIZE,
 lr,
 task,
 train_loader,
 train_loader_at_eval,
 test_loader,
 n_classes) = data_input_loading(data_flag=data_flag,
                                NUM_EPOCHS = NUM_EPOCHS,
                                BATCH_SIZE=BATCH_SIZE,
                                lr=lr,
                                med3d= med3d)
## train and record
# NOTE: training_and_record currently uses neither `steps` nor `model_class`;
# the architecture is chosen from "3d" appearing in model_name.
history = training_and_record(
    model_class=MedViT3D_small,
    model_name= model_name,
    NUM_EPOCHS=NUM_EPOCHS,
    lr=lr,
    momentum=0.9,
    train_loader=train_loader,
    train_loader_at_eval=train_loader_at_eval,
    test_loader=test_loader,
    n_classes = n_classes,
    task = task,
    steps = len(train_loader)
)

In [None]:
# List the metric names recorded in the history returned by training_and_record.
history.keys()

In [None]:
# Best accuracy and best AUC over all recorded epochs (not necessarily the
# same epoch). The val_* series are computed on test_loader.
print (max(history['val_acc']), max(history['val_auc']))

In [None]:
# Per-epoch AUC series (computed on test_loader by training_and_record).
history['val_auc']

In [None]:
# Experiment: MedViT3D on fracturemnist3d.
# model_name is also the file stem of the checkpoint / history CSV under
# ./history_record, so re-running this cell resumes from files saved under
# the same name.
data_flag = "fracturemnist3d"
med3d = True
model_name = f"MedViT3D_{data_flag}_100_1"
## loading data
# data_input_loading returns data_flag / NUM_EPOCHS / BATCH_SIZE / lr exactly
# as passed in, so this unpacking rebinds those globals to their own values.
(data_flag,
 NUM_EPOCHS,
 BATCH_SIZE,
 lr,
 task,
 train_loader,
 train_loader_at_eval,
 test_loader,
 n_classes) = data_input_loading(data_flag=data_flag,
                                NUM_EPOCHS = NUM_EPOCHS,
                                BATCH_SIZE=BATCH_SIZE,
                                lr=lr,
                                med3d= med3d)
## train and record
# NOTE: training_and_record currently uses neither `steps` nor `model_class`;
# the architecture is chosen from "3d" appearing in model_name.
history = training_and_record(
    model_class=MedViT3D_small,
    model_name= model_name,
    NUM_EPOCHS=NUM_EPOCHS,
    lr=lr,
    momentum=0.9,
    train_loader=train_loader,
    train_loader_at_eval=train_loader_at_eval,
    test_loader=test_loader,
    n_classes = n_classes,
    task = task,
    steps = len(train_loader)
)

In [None]:
# List the metric names recorded in the history returned by training_and_record.
history.keys()

In [None]:
# Best AUC and best accuracy over all recorded epochs (not necessarily the
# same epoch). The val_* series are computed on test_loader.
print (max(history['val_auc']), max(history['val_acc']))

In [None]:
# Per-epoch AUC series (computed on test_loader by training_and_record).
history['val_auc']

In [None]:
# Experiment: MedViT3D on vesselmnist3d.
# model_name is also the file stem of the checkpoint / history CSV under
# ./history_record, so re-running this cell resumes from files saved under
# the same name.
data_flag = "vesselmnist3d"
med3d = True
model_name = f"MedViT3D_{data_flag}"
## loading data
# data_input_loading returns data_flag / NUM_EPOCHS / BATCH_SIZE / lr exactly
# as passed in, so this unpacking rebinds those globals to their own values.
(data_flag,
 NUM_EPOCHS,
 BATCH_SIZE,
 lr,
 task,
 train_loader,
 train_loader_at_eval,
 test_loader,
 n_classes) = data_input_loading(data_flag=data_flag,
                                NUM_EPOCHS = NUM_EPOCHS,
                                BATCH_SIZE=BATCH_SIZE,
                                lr=lr,
                                med3d= med3d)
## train and record
# NOTE: training_and_record currently uses neither `steps` nor `model_class`;
# the architecture is chosen from "3d" appearing in model_name.
history = training_and_record(
    model_class=MedViT3D_small,
    model_name= model_name,
    NUM_EPOCHS=NUM_EPOCHS,
    lr=lr,
    momentum=0.9,
    train_loader=train_loader,
    train_loader_at_eval=train_loader_at_eval,
    test_loader=test_loader,
    n_classes = n_classes,
    task = task,
    steps = len(train_loader)
)

In [None]:
# Experiment: MedViT3D on synapsemnist3d (saved under a different model_name
# than the "_aug_1" run earlier in the notebook).
# model_name is also the file stem of the checkpoint / history CSV under
# ./history_record, so re-running this cell resumes from files saved under
# the same name.
data_flag = "synapsemnist3d"
med3d = True
model_name = f"MedViT3D_{data_flag}"
## loading data
# data_input_loading returns data_flag / NUM_EPOCHS / BATCH_SIZE / lr exactly
# as passed in, so this unpacking rebinds those globals to their own values.
(data_flag,
 NUM_EPOCHS,
 BATCH_SIZE,
 lr,
 task,
 train_loader,
 train_loader_at_eval,
 test_loader,
 n_classes) = data_input_loading(data_flag=data_flag,
                                NUM_EPOCHS = NUM_EPOCHS,
                                BATCH_SIZE=BATCH_SIZE,
                                lr=lr,
                                med3d= med3d)
## train and record
# NOTE: training_and_record currently uses neither `steps` nor `model_class`;
# the architecture is chosen from "3d" appearing in model_name.
history = training_and_record(
    model_class=MedViT3D_small,
    model_name= model_name,
    NUM_EPOCHS=NUM_EPOCHS,
    lr=lr,
    momentum=0.9,
    train_loader=train_loader,
    train_loader_at_eval=train_loader_at_eval,
    test_loader=test_loader,
    n_classes = n_classes,
    task = task,
    steps = len(train_loader)
)

 MedVit2D

In [None]:
## Grad CAM for model explanation

In [None]:
!pip install torchcam
# Alternative explainability library; both installs run, but the cells visible below only use torchcam
!pip install captum

In [None]:
from torchcam.methods import GradCAM
from torchcam.utils import overlay_mask
from torchvision.transforms.functional import to_pil_image

# Attach a Grad-CAM extractor to the layer named "ltb1".
# NOTE(review): `model` must already exist as a notebook global; none of the
# training cells visible here assign it (training_and_record returns only the
# history dict) -- confirm where it is defined before running this cell.
cam_extractor = GradCAM(model, target_layer="ltb1")

In [None]:
model.eval()

# Select a sample from test loader
inputs, targets = next(iter(test_loader))
inputs = inputs.to(device)
# reshape(-1) keeps a 1-D label vector even for a one-sample batch
targets = targets.reshape(-1).long().to(device)

# Forward pass on the first sample only. Grad-CAM back-propagates from the
# class score, so this must NOT run under torch.no_grad().
input_tensor = inputs[:1]  # assumes [1, 1, D, H, W] -- TODO confirm channel count
output = model(input_tensor)
pred_class = output.argmax(dim=1)

# Extract CAM for the first image
cam = cam_extractor(pred_class[0].item(), output)

# Bring the CAM to [1, 1, D, H, W] and resample it to the input resolution,
# so that slice indices of the input and of the CAM refer to the same plane.
cam_volume = cam[0].detach().float()
if cam_volume.dim() == 3:
    cam_volume = cam_volume.unsqueeze(0).unsqueeze(0)
elif cam_volume.dim() == 4:
    cam_volume = cam_volume.unsqueeze(0)
cam_volume = torch.nn.functional.interpolate(
    cam_volume, size=input_tensor.shape[2:], mode="trilinear", align_corners=False
).squeeze().cpu()  # [D, H, W]

# Process image and heatmap
input_image = input_tensor[0].detach().cpu().squeeze()  # [D, H, W]
slice_idx = input_image.shape[0] // 2
slice_img = input_image[slice_idx]
slice_cam = cam_volume[slice_idx]

# Normalize and overlay
from torchvision.transforms.functional import to_pil_image
from torchcam.utils import overlay_mask
import matplotlib.pyplot as plt

# The epsilon guards against a constant slice (max == min)
norm_slice = (slice_img - slice_img.min()) / (slice_img.max() - slice_img.min() + 1e-6)
norm_cam = (slice_cam - slice_cam.min()) / (slice_cam.max() - slice_cam.min() + 1e-6)

# overlay_mask blends an RGB image with a float ("F" mode) mask
overlay = overlay_mask(
    to_pil_image(norm_slice).convert("RGB"),
    to_pil_image(norm_cam, mode="F"),
    alpha=0.6
)

plt.imshow(overlay)
plt.title(f"Grad-CAM (Pred: {pred_class[0].item()}, True: {targets[0].item()})")
plt.axis('off')
plt.show()


In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image
from torchcam.methods import GradCAM
from torchcam.utils import overlay_mask

# === Setup ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval().to(device)

# === Choose target layer for GradCAM ===
target_layer = 'ltb3.conv.0'  # Change if needed
cam_extractor = GradCAM(model, target_layer=target_layer)

# === Load 1 test sample ===
inputs, targets = next(iter(test_loader))
input_tensor = inputs[0].unsqueeze(0).to(device)  # [1, 1, D, H, W]
true_label = targets[0].item()

# === Forward pass & CAM extraction (gradients must stay enabled) ===
with torch.set_grad_enabled(True):
    scores = model(input_tensor)
    pred_class = scores.argmax(dim=1).item()
    cams = cam_extractor(pred_class, scores)  # list of CAMs

# === Get the CAM tensor ===
cam = cams[0].detach().float()  # Could be [1, D, H, W] or [D, H, W]
if cam.dim() == 3:
    cam = cam.unsqueeze(0).unsqueeze(0)  # -> [1, 1, D, H, W]
elif cam.dim() == 4:
    cam = cam.unsqueeze(0)               # -> [1, 1, D, H, W]
# Else: already fine

# === Interpolate CAM to input shape ===
target_shape = input_tensor.shape[2:]  # [D, H, W]
cam_upsampled = F.interpolate(cam, size=target_shape, mode="trilinear", align_corners=False)

cam_volume = cam_upsampled.squeeze().cpu()  # [D, H, W]

# === Get input volume ===
input_volume = input_tensor.detach().squeeze().cpu()  # [D, H, W]

# === Pick a middle slice ===
slice_idx = input_volume.shape[0] // 2
input_slice = input_volume[slice_idx]  # [H, W]
cam_slice = cam_volume[slice_idx]      # [H, W]

# === Normalize both slices ===
input_norm = (input_slice - input_slice.min()) / (input_slice.max() - input_slice.min() + 1e-6)
cam_norm = (cam_slice - cam_slice.min()) / (cam_slice.max() - cam_slice.min() + 1e-6)

# === Convert to PIL images ===
# The normalised slices are the ones that must be converted: to_pil_image
# scales float tensors as if they were in [0, 1], and overlay_mask expects
# the mask as a float ("F" mode) image.
input_pil = to_pil_image(input_norm)
cam_pil = to_pil_image(cam_norm, mode="F")

# Convert grayscale input to RGB for overlay
input_pil = input_pil.convert("RGB")

# === Overlay CAM ===
overlay = overlay_mask(input_pil, cam_pil, alpha=0.5)

# === Plot ===
plt.figure(figsize=(6, 6))
plt.imshow(overlay)
plt.title(f"Grad-CAM Overlay (slice {slice_idx}) | Pred: {pred_class} | GT: {true_label}")
plt.axis("off")
plt.show()


In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image
from torchcam.methods import GradCAM
from torchcam.utils import overlay_mask

# === Prepare your model ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The inputs are moved to `device` below, so the model has to live there too
model.eval().to(device)

# Choose a target layer name from your model -- ltb3 is good
target_layer = "ltb3"

# Initialize GradCAM
cam_extractor = GradCAM(model, target_layer=target_layer)

# === Pick one test sample ===
inputs, targets = next(iter(test_loader))
inputs = inputs.to(device)
input_tensor = inputs[0].unsqueeze(0)  # [1, C, D, H, W]
target_class = targets[0].item()

# === Forward pass and CAM generation (gradients must stay enabled) ===
output = model(input_tensor)
pred_class = output.argmax(dim=1).item()

# Generate CAM
cams = cam_extractor(pred_class, output)
cam = cams[0].detach().float()

if cam.dim() == 3:
    cam = cam.unsqueeze(0).unsqueeze(0)  # -> [1, 1, D, H, W]
elif cam.dim() == 4:
    cam = cam.unsqueeze(0)               # -> [1, 1, D, H, W]
# === Interpolate CAM to match input size ===
target_shape = input_tensor.shape[2:]  # [D, H, W]
cam_upsampled = F.interpolate(cam, size=target_shape, mode="trilinear", align_corners=False)

cam_volume = cam_upsampled.squeeze().cpu()  # [D, H, W]


# === Normalize input volume ===
# The epsilon guards against a constant volume / slice (max == min)
input_volume = input_tensor.detach().squeeze().cpu()  # [D, H, W]
input_volume = (input_volume - input_volume.min()) / (input_volume.max() - input_volume.min() + 1e-6)

# === Choose center slice ===
slice_idx = input_volume.shape[0] // 2
input_slice = input_volume[slice_idx]  # [H, W]
cam_slice = cam_volume[slice_idx]      # [H, W]

# Normalize CAM slice
cam_slice = (cam_slice - cam_slice.min()) / (cam_slice.max() - cam_slice.min() + 1e-6)

# === Convert to PIL images ===
# These conversions were missing, so the overlay silently reused whatever
# input_pil / cam_pil an earlier cell had left behind (or failed with a
# NameError). overlay_mask expects the mask as a float ("F" mode) image.
input_pil = to_pil_image(input_slice)
cam_pil = to_pil_image(cam_slice, mode="F")

# Convert grayscale input to RGB for overlay
input_pil = input_pil.convert("RGB")

# === Overlay CAM ===
overlay = overlay_mask(input_pil, cam_pil, alpha=0.5)

# === Plot result ===
plt.figure(figsize=(6, 6))
plt.imshow(overlay)
plt.title(f"Grad-CAM (target: {target_class}, pred: {pred_class})")
plt.axis('off')
plt.show()
