Now you can run the previous cell to import `torch` and `make_dot`.

Replicate original script

In [2]:
import tensorflow as tf
## reduce GPU
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

In [3]:
!pip install medmnist==3.0.1 \
    torchattacks

Collecting medmnist==3.0.1
  Downloading medmnist-3.0.1-py3-none-any.whl.metadata (13 kB)
Collecting torchattacks
  Downloading torchattacks-3.5.1-py3-none-any.whl.metadata (927 bytes)
Collecting fire (from medmnist==3.0.1)
  Downloading fire-0.7.0.tar.gz (87 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting requests~=2.25.1 (from torchattacks)
  Downloading requests-2.25.1-py2.py3-none-any.whl.metadata (4.2 kB)
Collecting chardet<5,>=3.0.2 (from requests~=2.25.1->torchattacks)
  Downloading chardet-4.0.0-py2.py3-none-any.whl.metadata (3.5 kB)
Collecting idna<3,>=2.5 (from requests~=2.25.1->torchattacks)
  Downloading idna-2.10-py2.py3-none-any.whl.metadata (9.1 kB)
Collecting urllib3<1.27,>=1.21.1 (from requests~=2.25.1->torchattacks)
  Downloading urllib3-1.26.20-py2.py3-none-any.whl.metadata (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [4]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

import torchvision.utils
from torchvision import models
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torchsummary import summary

from tqdm import tqdm
import medmnist
from medmnist import INFO, Evaluator

import torchattacks
from torchattacks import PGD, FGSM

In [4]:
print("PyTorch", torch.__version__)
print("Torchvision", torchvision.__version__)
print("Torchattacks", torchattacks.__version__)
print("Numpy", np.__version__)
print("Medmnist", medmnist.__version__)

PyTorch 2.6.0+cu124
Torchvision 0.21.0+cu124
Torchattacks 3.5.1
Numpy 2.0.2
Medmnist 3.0.1


##Dataset

In [None]:
data_flag = 'retinamnist'
# [tissuemnist, pathmnist, chestmnist, dermamnist, octmnist,
# pnemoniamnist, retinamnist, breastmnist, bloodmnist, tissuemnist, organamnist, organcmnist, organsmnist]
download = True

NUM_EPOCHS = 10
BATCH_SIZE = 5
lr = 0.005

info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])

DataClass = getattr(medmnist, info['python_class'])

print("number of channels : ", n_channels)
print("number of classes : ", n_classes)

number of channels :  3
number of classes :  5


In [None]:
from torchvision.transforms.transforms import Resize
# preprocessing
train_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.Lambda(lambda image: image.convert('RGB')),
    torchvision.transforms.AugMix(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5])
])
test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.Lambda(lambda image: image.convert('RGB')),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5], std=[.5])
])

# load the data
train_dataset = DataClass(split='train', transform=train_transform, download=download)
test_dataset = DataClass(split='test', transform=test_transform, download=download)

# pil_dataset = DataClass(split='train', download=download)

# encapsulate data into dataloader form
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
train_loader_at_eval = data.DataLoader(dataset=train_dataset, batch_size=2*BATCH_SIZE, shuffle=False)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=2*BATCH_SIZE, shuffle=False)

100%|██████████| 3.29M/3.29M [00:00<00:00, 4.51MB/s]


In [None]:
print(train_dataset)
print("===================")
print(test_dataset)

Dataset RetinaMNIST of size 28 (retinamnist)
    Number of datapoints: 1080
    Root location: /root/.medmnist
    Split: train
    Task: ordinal-regression
    Number of channels: 3
    Meaning of labels: {'0': '0', '1': '1', '2': '2', '3': '3', '4': '4'}
    Number of samples: {'train': 1080, 'val': 120, 'test': 400}
    Description: The RetinaMNIST is based on the DeepDRiD challenge, which provides a dataset of 1,600 retina fundus images. The task is ordinal regression for 5-level grading of diabetic retinopathy severity. We split the source training set with a ratio of 9:1 into training and validation set, and use the source validation set as the test set. The source images of 3×1,736×1,824 are center-cropped and resized into 3×28×28.
    License: CC BY 4.0
Dataset RetinaMNIST of size 28 (retinamnist)
    Number of datapoints: 400
    Root location: /root/.medmnist
    Split: test
    Task: ordinal-regression
    Number of channels: 3
    Meaning of labels: {'0': '0', '1': '1', '2'

## Model loader

5

In [None]:
from MedViT import MedViT_small, MedViT_base, MedViT_large

model = MedViT_small(num_classes = n_classes)
# model = MedViT_small(num_classes = n_classes).cuda()
#model = MedViT_base(num_classes = n_classes).cuda()
#model = MedViT_large(num_classes = n_classes).cuda()



initialize_weights...


## Train

In [None]:
# define loss function and optimizer
if task == "multi-label, binary-class":
    criterion = nn.BCEWithLogitsLoss()
else:
    criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

In [None]:
# train

for epoch in range(NUM_EPOCHS):
    train_correct = 0
    train_total = 0
    test_correct = 0
    test_total = 0
    print('Epoch [%d/%d]'% (epoch+1, NUM_EPOCHS))
    model.train()
    for inputs, targets in tqdm(train_loader):
        # inputs, targets = inputs.cuda(), targets.cuda()
        # forward + backward + optimize
        optimizer.zero_grad()
        outputs = model(inputs)

        if task == 'multi-label, binary-class':
            targets = targets.to(torch.float32)
            loss = criterion(outputs, targets)
        else:
            targets = targets.squeeze().long()
            loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

Epoch [1/10]


100%|██████████| 216/216 [23:26<00:00,  6.51s/it]


Epoch [2/10]


 56%|█████▋    | 122/216 [12:59<10:00,  6.39s/it]


KeyboardInterrupt: 

In [None]:
from medmnist import INFO, Evaluator

In [None]:
# evaluation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def test(split, train_loader_at_eval, test_loader):
    model.eval()
    y_true = torch.tensor([]).cuda()
    y_score = torch.tensor([]).cuda()

    data_loader = train_loader_at_eval if split == 'train' else test_loader

    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = model(inputs)

            if task == 'multi-label, binary-class':
                targets = targets.to(torch.float32)
                outputs = outputs.softmax(dim=-1)
            else:
                targets = targets.squeeze().long()
                outputs = outputs.softmax(dim=-1)
                targets = targets.float().resize_(len(targets), 1)

            y_true = torch.cat((y_true, targets), 0)
            y_score = torch.cat((y_score, outputs), 0)

        y_true = y_true.cpu().numpy()
        y_score = y_score.detach().cpu().numpy()

        evaluator = Evaluator(data_flag, split)
        metrics = evaluator.evaluate(y_score)

        print('%s  auc: %.3f  acc:%.3f' % (split, *metrics))
        return metrics #, y_true, y_score


# print('==> Evaluating ...')
# # test('train')
# metrics, y_true, y_score = test('test')

In [None]:
def training_and_record(NUM_EPOCHS, model = model):
  history = {
      "train_auc": [],
      "train_acc": [],
      "val_auc": [],
      "val_acc": [],
      "train_loss": [],
  }
  for epoch in range(NUM_EPOCHS):
    train_correct = 0
    train_total = 0
    test_correct = 0
    test_total = 0
    print('Epoch [%d/%d]'% (epoch+1, NUM_EPOCHS))
    model = model.to(device)
    model.train()
    for inputs, targets in tqdm(train_loader):
        inputs, targets = inputs.cuda(), targets.cuda()
        # forward + backward + optimize
        optimizer.zero_grad()
        outputs = model(inputs)

        if task == 'multi-label, binary-class':
            targets = targets.to(torch.float32)
            loss = criterion(outputs, targets)
        else:
            targets = targets.squeeze().long()
            loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()
    torch.cuda.empty_cache()
    ## logging accuracy and AUC
    train_metrics = test('train')
    val_metrics = test('test')
    print(train_metrics)
    history["train_auc"].append(train_metrics.AUC)
    history["train_acc"].append(train_metrics.ACC)
    history["val_auc"].append(val_metrics.AUC)
    history["val_acc"].append(val_metrics.ACC)
    history["train_loss"].append(loss.item())
  return history

In [None]:
history = training_and_record(NUM_EPOCHS= 1, model = model)

Epoch [1/1]


100%|██████████| 216/216 [14:23<00:00,  4.00s/it]


train  auc: 0.714  acc:0.481
test  auc: 0.710  acc:0.512
Metrics(AUC=np.float64(0.7144169878694029), ACC=0.48055555555555557)


In [None]:
print("Is built with CUDA:", tf.test.is_built_with_cuda())
print("Is GPU available:", tf.config.list_logical_devices('GPU'))
print("GPU device name:", tf.test.gpu_device_name())

Is built with CUDA: True
Is GPU available: [LogicalDevice(name='/device:GPU:0', device_type='GPU')]
GPU device name: /device:GPU:0


In [None]:
history

{'train_auc': [np.float64(0.7144169878694029)],
 'train_acc': [0.48055555555555557],
 'val_auc': [np.float64(0.7099973013334173)],
 'val_acc': [0.5125],
 'train_loss': [2.034088373184204]}

### Team7: Modify script to continue train by loading history

In [None]:
# evaluation
import pandas as pd
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def test(split, train_loader_at_eval, test_loader):
    model.eval()
    y_true = torch.tensor([]).to(device)
    y_score = torch.tensor([]).to(device)

    data_loader = train_loader_at_eval if split == 'train' else test_loader

    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)

            if task == 'multi-label, binary-class':
                targets = targets.to(torch.float32)
                outputs = outputs.softmax(dim=-1)
            else:
                targets = targets.squeeze().long()
                outputs = outputs.softmax(dim=-1)
                targets = targets.float().resize_(len(targets), 1)

            y_true = torch.cat((y_true, targets), 0)
            y_score = torch.cat((y_score, outputs), 0)

        y_true = y_true.cpu().numpy()
        y_score = y_score.detach().cpu().numpy()

        evaluator = Evaluator(data_flag, split)
        metrics = evaluator.evaluate(y_score)

        print('%s  auc: %.3f  acc:%.3f' % (split, *metrics))
        return metrics #, y_true, y_score

In [None]:
def load_or_initialize_model(model_class, model_name, optimizer_class, lr, momentum):
    model_dir = "./history_record"
    os.makedirs(model_dir, exist_ok=True)

    model_path = os.path.join(model_dir, f"{model_name}.pth")
    history_path = os.path.join(model_dir, f"{model_name}.csv")

    model = model_class().to(device)
    optimizer = optimizer_class(model.parameters(), lr=lr, momentum=momentum)
    start_epoch = 0
    best_val_auc = 0
    history = {
        "train_auc": [], "train_acc": [],
        "val_auc": [], "val_acc": [],
        "train_loss": []
    }

    if os.path.exists(model_path) and os.path.exists(history_path):
        print(f"Loading existing model: {model_name}")
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        history = pd.read_csv(history_path).to_dict(orient='list')
        start_epoch = len(history["train_loss"])
        best_val_auc = max(history["val_auc"]) if history["val_auc"] else 0

    return model, optimizer, history, start_epoch, best_val_auc

In [None]:
def training_and_record(model_class, model_name, NUM_EPOCHS, lr, momentum, train_loader, train_loader_at_eval, test_loader):
    model, optimizer, history, start_epoch, best_val_auc = load_or_initialize_model(
        model_class, model_name, optimizer_class=torch.optim.SGD, lr=lr, momentum=momentum
    )

    for epoch in range(start_epoch, start_epoch + NUM_EPOCHS):
        print(f'\nEpoch [{epoch + 1}/{start_epoch + NUM_EPOCHS}]')
        model.train()

        for inputs, targets in tqdm(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            if task == 'multi-label, binary-class':
                targets = targets.float()
                loss = criterion(outputs, targets)
            else:
                targets = targets.squeeze().long()
                loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

        torch.cuda.empty_cache()

        # Logging
        train_metrics = test('train', train_loader_at_eval, test_loader)
        val_metrics = test('test', train_loader_at_eval, test_loader)

        history["train_auc"].append(train_metrics.AUC)
        history["train_acc"].append(train_metrics.ACC)
        history["val_auc"].append(val_metrics.AUC)
        history["val_acc"].append(val_metrics.ACC)
        history["train_loss"].append(loss.item())

        # Save best model
        if val_metrics.AUC > best_val_auc:
            best_val_auc = val_metrics.AUC
            print("📌 New best AUC — saving model")
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, f"./history_record/{model_name}.pth")

        pd.DataFrame(history).to_csv(f"./history_record/{model_name}.csv", index=False)

    print("✅ Training complete.")
    return history


In [None]:
n_classes = len(info['label'])
n_classes

5

In [None]:
history = training_and_record(
    model_class=MedViT_small,
    model_name="MedViT_retinamnist",
    NUM_EPOCHS=2,
    lr=0.001,
    momentum=0.9,
    train_loader=train_loader,
    train_loader_at_eval=train_loader_at_eval,
    test_loader=test_loader
)

initialize_weights...

Epoch [1/2]


100%|██████████| 216/216 [19:32<00:00,  5.43s/it]


train  auc: 0.456  acc:0.180
test  auc: 0.481  acc:0.170
📌 New best AUC — saving model

Epoch [2/2]


100%|██████████| 216/216 [19:21<00:00,  5.38s/it]


train  auc: 0.454  acc:0.180
test  auc: 0.481  acc:0.170
✅ Training complete.


## Team7: Now fit into our dataset

In [None]:
from step1Preprocessing import DATALOAD
xu = DATALOAD("mri", load_mode= "pytorch", train_control_sample=100,val_control_sample=20 )
xu.train_ds, xu.val_ds, xu.train_loader_at_eval

(<torch.utils.data.dataloader.DataLoader at 0x7b513964e450>,
 <torch.utils.data.dataloader.DataLoader at 0x7b5139677050>,
 <torch.utils.data.dataloader.DataLoader at 0x7b5139677110>)

In [None]:
# define loss function and optimizer
data_flag = 'retinamnist'
# [tissuemnist, pathmnist, chestmnist, dermamnist, octmnist,
# pnemoniamnist, retinamnist, breastmnist, bloodmnist, tissuemnist, organamnist, organcmnist, organsmnist]
download = True

NUM_EPOCHS = 10
BATCH_SIZE = 5
lr = 0.005

info = INFO[data_flag]
task = info['task']

if task == "multi-label, binary-class":
    criterion = nn.BCEWithLogitsLoss()
else:
    criterion = nn.CrossEntropyLoss()

optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9)

In [None]:
len(xu.train_ds), len(xu.val_ds), len(xu.train_loader_at_eval)

(10, 10, 1)

In [None]:
n_classes = len(xu.train_ds.dataset.dataset.classes)
n_classes

4

In [None]:
from MedViT import MedViT_small, MedViT_base, MedViT_large
model = MedViT_small(num_classes = n_classes)

initialize_weights...


In [10]:
def load_or_initialize_model(model_class, model_name, optimizer_class, lr, momentum):
    model_dir = "./history_record"
    os.makedirs(model_dir, exist_ok=True)

    model_path = os.path.join(model_dir, f"{model_name}.pth")
    history_path = os.path.join(model_dir, f"{model_name}.csv")

    model = model_class().to(device)
    optimizer = optimizer_class(model.parameters(), lr=lr, momentum=momentum)
    start_epoch = 0
    best_val_auc = 0
    history = {
        "train_auc": [], "train_acc": [],
        "val_auc": [], "val_acc": [],
        "train_loss": []
    }

    if os.path.exists(model_path) and os.path.exists(history_path):
        print(f"Loading existing model: {model_name}")
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        history = pd.read_csv(history_path).to_dict(orient='list')
        start_epoch = len(history["train_loss"])
        best_val_auc = max(history["val_auc"]) if history["val_auc"] else 0

    return model, optimizer, history, start_epoch, best_val_auc

In [None]:
## define new test function: Have to
from sklearn.metrics import roc_auc_score, accuracy_score
from collections import namedtuple
from sklearn.metrics import roc_auc_score, accuracy_score
from collections import namedtuple
import numpy as np
import pandas as pd

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Metrics = namedtuple("Metrics", ["AUC", "ACC"])

def evaluate_custom(y_true, y_score):
    Metrics = namedtuple("Metrics", ["AUC", "ACC"])
    y_true = np.array(y_true).reshape(-1)
    if y_score.shape[0] != y_true.shape[0]:
        raise ValueError("Mismatch between number of predictions and true labels")
    y_pred = y_score.argmax(axis=1)
    acc = accuracy_score(y_true, y_pred)
    unique_classes = np.unique(y_true)
    if len(unique_classes) < 2:
        print("Skipping AUC: Only one class in ground truth.")
        auc = -1  # or float('nan')
    else:
        auc = roc_auc_score(y_true, y_score, multi_class='ovr')

    return Metrics(AUC=auc, ACC=acc)

def test(split, train_loader_at_eval, test_loader):
    model.eval()
    y_true = torch.tensor([]).to(device)
    y_score = torch.tensor([]).to(device)

    data_loader = train_loader_at_eval if split == 'train' else test_loader

    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)

            if task == 'multi-label, binary-class':
                targets = targets.to(torch.float32)
                outputs = outputs.softmax(dim=-1)
            else:
                targets = targets.squeeze().long()
                outputs = outputs.softmax(dim=-1)
                targets = targets.float().resize_(len(targets), 1)

            y_true = torch.cat((y_true, targets), 0)
            y_score = torch.cat((y_score, outputs), 0)

        y_true = y_true.cpu().numpy()
        y_score = y_score.detach().cpu().numpy()

        metrics = evaluate_custom(y_true, y_score)
        print('%s  auc: %.3f  acc:%.3f' % (split, *metrics))
        return metrics #, y_true, y_score


In [None]:
def training_and_record(model_class, model_name, NUM_EPOCHS, lr, momentum, train_loader, train_loader_at_eval, test_loader):
    model, optimizer, history, start_epoch, best_val_auc = load_or_initialize_model(
        model_class, model_name, optimizer_class=torch.optim.SGD, lr=lr, momentum=momentum
    )

    for epoch in range(start_epoch, start_epoch + NUM_EPOCHS):
        print(f'\nEpoch [{epoch + 1}/{start_epoch + NUM_EPOCHS}]')
        model.train()

        for inputs, targets in tqdm(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)

            if task == 'multi-label, binary-class':
                targets = targets.float()
                loss = criterion(outputs, targets)
            else:
                targets = targets.squeeze().long()
                loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

        torch.cuda.empty_cache()

        # Logging
        train_metrics = test('train', train_loader_at_eval, test_loader)
        val_metrics = test('test', train_loader_at_eval, test_loader)

        history["train_auc"].append(train_metrics.AUC)
        history["train_acc"].append(train_metrics.ACC)
        history["val_auc"].append(val_metrics.AUC)
        history["val_acc"].append(val_metrics.ACC)
        history["train_loss"].append(loss.item())

        # Save best model
        if val_metrics.AUC > best_val_auc:
            best_val_auc = val_metrics.AUC
            print("📌 New best AUC — saving model")
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, f"./history_record/{model_name}.pth")

        pd.DataFrame(history).to_csv(f"./history_record/{model_name}.csv", index=False)

    print("✅ Training complete.")
    return history


In [None]:
history = training_and_record(
    model_class=MedViT_small,
    model_name="MedViT_mri",
    NUM_EPOCHS=2,
    lr=0.001,
    momentum=0.9,
    train_loader=xu.train_ds,
    train_loader_at_eval=xu.train_loader_at_eval,
    test_loader=xu.val_ds
)

initialize_weights...
Loading existing model: MedViT_mri

Epoch [3/4]


100%|██████████| 10/10 [01:52<00:00, 11.25s/it]


train  auc: 0.430  acc:0.150
test  auc: 0.463  acc:0.140

Epoch [4/4]


100%|██████████| 10/10 [01:51<00:00, 11.12s/it]


train  auc: 0.430  acc:0.150
test  auc: 0.463  acc:0.140
✅ Training complete.


**bold text**## MedVit3D
```
# This is formatted as code
```



In [5]:
data_flag = 'organmnist3d'
# [tissuemnist, pathmnist, chestmnist, dermamnist, octmnist,
# pnemoniamnist, retinamnist, breastmnist, bloodmnist, tissuemnist, organamnist, organcmnist, organsmnist]
download = True

NUM_EPOCHS = 10
BATCH_SIZE = 15
lr = 0.005

info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])

DataClass = getattr(medmnist, info['python_class'])

print("number of channels : ", n_channels)
print("number of classes : ", n_classes)

number of channels :  1
number of classes :  11


In [6]:
from torchvision.transforms.transforms import Resize
# preprocessing
transform = lambda x: torch.from_numpy(x).squeeze(1).float()
train_dataset = DataClass(split='train', transform=transform, download=True)
val_dataset = DataClass(split='val', transform=transform, download=True)
test_dataset = DataClass(split='test', transform=transform, download=True)


# pil_dataset = DataClass(split='train', download=download)

# encapsulate data into dataloader form
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
train_loader_at_eval = data.DataLoader(dataset=train_dataset, batch_size=2*BATCH_SIZE, shuffle=False)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=2*BATCH_SIZE, shuffle=False)


100%|██████████| 32.7M/32.7M [00:46<00:00, 699kB/s]


In [7]:
len(train_loader)


65

In [14]:
n_classes

11

In [8]:
from MedVit3D import MedViT3D_small
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MedViT3D_small(num_classes = n_classes).to(device)



In [6]:
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()
model.train()
for inputs, targets in train_loader:
    inputs = inputs.to(device)
    targets = targets.squeeze()  # Remove extra dimension
    if targets.ndim != 1:
        targets = targets.view(-1)  # Ensure shape is [B]
    targets = targets.long().to(device)

    optimizer.zero_grad()
    outputs = model(inputs)  # Shape: [B, num_classes]
    loss = criterion(outputs, targets)  # targets: [B]
    loss.backward()
    optimizer.step()

    print("Loss:", loss.item())
    break  # Only run 1 batch for testing


Loss: 2.3220160007476807


In [7]:
model.eval()
all_preds = []
all_labels = []
from sklearn.metrics import accuracy_score

with torch.no_grad():
    for inputs, targets in test_loader:
        inputs = inputs.to(device)
        targets = targets.squeeze()
        if targets.ndim != 1:
            targets = targets.view(-1)
        targets = targets.long().to(device)

        outputs = model(inputs)  # logits
        preds = torch.argmax(outputs, dim=1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(targets.cpu().numpy())

# ✅ Calculate Accuracy
acc = accuracy_score(all_labels, all_preds)
print(f"Test Accuracy: {acc:.4f}")

Test Accuracy: 0.1131


In [22]:
## train full batch
from sklearn.metrics import accuracy_score, roc_auc_score
import numpy as np
import torch

def evaluate(model, loader, device):
    model.eval()
    total_loss = 0
    all_labels = []
    all_outputs = []

    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.squeeze().long().to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            total_loss += loss.item()

            all_outputs.append(outputs.softmax(dim=1).cpu().numpy())
            all_labels.append(targets.cpu().numpy())

    # Flatten
    all_preds = np.concatenate(all_outputs, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    # Metrics
    acc = accuracy_score(all_labels, all_preds.argmax(axis=1))
    try:
        auc = roc_auc_score(all_labels, all_preds, multi_class='ovr')
    except:
        auc = -1  # fallback if AUC fails (e.g., single class present)

    avg_loss = total_loss / len(loader)
    return avg_loss, acc, auc
model.train()
total_loss = 0

for inputs, targets in train_loader:
    inputs = inputs.to(device)
    targets = targets.squeeze().long().to(device)

    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    total_loss += loss.item()

train_loss = total_loss / len(train_loader)
val_loss, val_acc, val_auc = evaluate(model, test_loader, device)
_, train_acc, train_auc = evaluate(model, train_loader, device)

print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Train AUC: {train_auc:.4f}")
print(f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f} | Val AUC:   {val_auc:.4f}")


NameError: name 'optimizer' is not defined

In [27]:
## train multiple epoch

In [23]:
def load_or_initialize_model(model_class, model_name, optimizer_class, lr, momentum, n_classes):
    model_dir = "./history_record"
    os.makedirs(model_dir, exist_ok=True)

    model_path = os.path.join(model_dir, f"{model_name}.pth")
    history_path = os.path.join(model_dir, f"{model_name}.csv")

    model = MedViT3D_small(num_classes = n_classes).to(device)
    optimizer = optimizer_class(model.parameters(), lr=lr, momentum=momentum)
    start_epoch = 0
    best_val_auc = 0
    history = {
        "train_auc": [], "train_acc": [],
        "val_auc": [], "val_acc": [],
        "train_loss": []
    }

    if os.path.exists(model_path) and os.path.exists(history_path):
        print(f"Loading existing model: {model_name}")
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        history = pd.read_csv(history_path).to_dict(orient='list')
        start_epoch = len(history["train_loss"])
        best_val_auc = max(history["val_auc"]) if history["val_auc"] else 0

    return model, optimizer, history, start_epoch, best_val_auc

In [28]:
criterion = nn.CrossEntropyLoss()
def training_and_record(model_class, model_name, NUM_EPOCHS, lr, momentum, train_loader, train_loader_at_eval, test_loader, n_classes):
    model, optimizer, history, start_epoch, best_val_auc = load_or_initialize_model(
        model_class, model_name, optimizer_class=torch.optim.SGD, lr=lr, momentum=momentum, n_classes= n_classes
    )

    for epoch in range(start_epoch, start_epoch + NUM_EPOCHS):
        print(f'\nEpoch [{epoch + 1}/{start_epoch + NUM_EPOCHS}]')
        model.train()
        total_loss = 0
        for inputs, targets in tqdm(train_loader):
            inputs = inputs.to(device)
            targets = targets.squeeze().long().to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            # print("Unique labels:", targets.unique())
            # print("Targets shape:", targets.shape)
            # print("Output shape:", outputs.shape)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        torch.cuda.empty_cache()

        # Logging
        train_loss = total_loss / len(train_loader)
        val_loss, val_acc, val_auc = evaluate(model, test_loader, device)
        _, train_acc, train_auc = evaluate(model, train_loader, device)

        history["train_auc"].append(train_auc)
        history["train_acc"].append(train_acc)
        history["val_auc"].append(val_auc)
        history["val_acc"].append(val_acc)
        history["train_loss"].append(train_loss)

        # Save best model
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            print("📌 New best AUC — saving model")
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()
            }, f"./history_record/{model_name}.pth")

        pd.DataFrame(history).to_csv(f"./history_record/{model_name}.csv", index=False)

    print("✅ Training complete.")
    return history


In [30]:
## Train môre epoch and recording.
import os
import pandas as pd
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
history = training_and_record(
    model_class=MedViT3D_small,
    model_name="MedViT3D_organmnist3d",
    NUM_EPOCHS=10,
    lr=0.001,
    momentum=0.9,
    train_loader=train_loader,
    train_loader_at_eval=train_loader_at_eval,
    test_loader=test_loader,
    n_classes = n_classes
)

Loading existing model: MedViT3D_organmnist3d

Epoch [3/12]


100%|██████████| 65/65 [02:13<00:00,  2.05s/it]


📌 New best AUC — saving model

Epoch [4/12]


100%|██████████| 65/65 [02:10<00:00,  2.00s/it]


📌 New best AUC — saving model

Epoch [5/12]


100%|██████████| 65/65 [02:07<00:00,  1.97s/it]


KeyboardInterrupt: 


Epoch [1/2]


100%|██████████| 195/195 [01:44<00:00,  1.86it/s]


Done first eopoch
going to test function
Done y_true,m and y_score
