In [None]:
# Shell out to nvidia-smi to confirm a CUDA GPU is visible to this kernel;
# later cells move the model and every batch to the GPU with .cuda().
!nvidia-smi

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision.utils
from torchvision import models
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torchsummary import summary
from engine import train_one_epoch, evaluate


## Model

In [None]:
from MedViT import MedViT_small, MedViT_base, MedViT_large

In [None]:
# MedViT variant to fine-tune; must be one of 'small', 'base' or 'large'.
model_name = 'large'

In [None]:
# Constructor and pretrained checkpoint for each supported MedViT variant
# (the "_im1k" suffix presumably denotes ImageNet-1k pretraining).
_MEDVIT_VARIANTS = {
    'small': (MedViT_small, './checkpoints/MedViT_small_im1k.pth'),
    'base': (MedViT_base, './checkpoints/MedViT_base_im1k.pth'),
    'large': (MedViT_large, './checkpoints/MedViT_large_im1k.pth'),
}

# Fail fast on a typo instead of leaving `model` / `checkpoint` undefined,
# which would only surface later as a confusing NameError.
if model_name not in _MEDVIT_VARIANTS:
    raise ValueError(
        f"Unknown model_name {model_name!r}; expected one of {sorted(_MEDVIT_VARIANTS)}"
    )

model_ctor, checkpoint_path = _MEDVIT_VARIANTS[model_name]
model = model_ctor()
# Load tensors onto the CPU first; the model is moved to the GPU in a later cell,
# so there is no need to materialise a second copy of the weights on the device.
checkpoint = torch.load(checkpoint_path, map_location='cpu')

In [None]:
# Copy the pretrained weights (stored under the checkpoint's 'model' key) into
# the freshly constructed network. strict loading is PyTorch's default, so any
# key mismatch is reported as an error rather than silently ignored.
model.load_state_dict(state_dict=checkpoint['model'])

In [None]:
# Swap the pretrained classification head for a freshly initialised 2-way
# linear layer for binary fine-tuning. in_features is read from the existing
# head rather than hard-coded (previously 1024), so this keeps working if a
# variant's feature width differs. Assumes proj_head[0] is an nn.Linear
# (it must expose `in_features`) — TODO confirm for every variant.
model.proj_head[0] = torch.nn.Linear(
    in_features=model.proj_head[0].in_features,
    out_features=2,
    bias=True,
)

In [None]:
# Place all parameters and buffers on the current CUDA device.
model = model.to('cuda')

## Dataset

In [None]:
from tqdm import tqdm
import numpy as np
import torch
import time
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms

import medmnist
from medmnist import INFO, Evaluator

In [None]:
import random

# Training hyper-parameters.
NUM_EPOCHS = 100
BATCH_SIZE = 16
lr = 0.0005      # SGD learning rate
n_classes = 2    # binary classification

# Seed the Python, NumPy and PyTorch RNGs so shuffling and random augmentation
# are repeatable across runs. Note: this cell runs after the classification
# head was created, so the head's random initialisation is not covered, and
# bit-exact GPU reproducibility would additionally need cuDNN determinism flags.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
from torchvision.transforms.transforms import Resize
# preprocessing
from timm.data import create_transform
# Training-time augmentation pipeline from timm at 224x224 input. The extra
# knobs are colour-jitter strength (0.4) and random erasing, applied with
# probability 0.25 in timm's 'pixel' mode.
TRAIN_AUG_KWARGS = dict(color_jitter=0.4, re_prob=0.25, re_mode='pixel')
transform = create_transform(input_size=224, is_training=True, **TRAIN_AUG_KWARGS)

In [None]:
def load_data(root, transform, num_classes=2):
    """Load one dataset split stored in torchvision ImageFolder layout.

    Args:
        root: directory containing one sub-directory per class.
        transform: transform applied to every loaded image.
        num_classes: number of classes expected under ``root``
            (defaults to 2 for this binary task).

    Returns:
        The ``ImageFolder`` dataset for ``root``.

    Raises:
        ValueError: if the number of classes found differs from ``num_classes``.
    """
    from torchvision.datasets.folder import ImageFolder

    dataset = ImageFolder(root, transform=transform)
    # Explicit check instead of `assert`: asserts are stripped under `python -O`
    # and gave no hint about which folder or which classes were wrong.
    found = len(dataset.class_to_idx)
    if found != num_classes:
        raise ValueError(
            f"expected {num_classes} classes under {root!r}, "
            f"found {found}: {sorted(dataset.class_to_idx)}"
        )
    return dataset


In [None]:

# Build the train / test / val datasets (the DataLoaders are created later).
# Evaluation splits must not be randomly augmented, otherwise val/test metrics
# become noisy and non-reproducible. timm's non-training pipeline
# (is_training=False) is used for val/test, and the augmenting `transform` is
# kept for the training split only.
eval_transform = create_transform(input_size=224, is_training=False)

train_dataset = load_data(root='DDI_data/train', transform=transform)
test_dataset = load_data(root='DDI_data/test', transform=eval_transform)
val_dataset = load_data(root='DDI_data/val', transform=eval_transform)

## Train

In [None]:
# The head emits one logit per class and targets are integer class indices,
# so multi-class cross-entropy is the matching loss (BCEWithLogitsLoss would
# instead need float targets shaped like the logits).
criterion = nn.CrossEntropyLoss()
# Plain SGD with momentum over every parameter of the network (full fine-tune).
optimizer = optim.SGD(params=model.parameters(), lr=lr, momentum=0.9)

In [None]:
def validate(model, data_loader, loss_fn=None, device='cuda'):
    """Evaluate ``model`` on ``data_loader`` without updating any weights.

    Args:
        model: network assumed to return logits of shape (batch, num_classes).
        data_loader: iterable of ``(inputs, targets)`` batches with integer labels.
        loss_fn: loss applied to ``(logits, targets)``; defaults to the
            notebook-level ``criterion``.
        device: device each batch is moved to (default ``'cuda'``, as before).

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch losses and the
        accuracy in percent.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    if loss_fn is None:
        loss_fn = criterion

    model.eval()  # evaluation behaviour for dropout / normalisation layers
    n_batches = 0
    total = 0
    correct = 0
    total_loss = 0.0
    with torch.no_grad():  # no autograd graph needed: saves memory and compute
        for inputs, targets in data_loader:
            inputs = inputs.to(device)
            # Flatten to (batch,) instead of squeeze(): squeeze() turns a
            # size-1 batch into a 0-d tensor that CrossEntropyLoss rejects, and
            # comparing predictions with un-flattened (batch, 1) labels would
            # broadcast to a (batch, batch) matrix and miscount accuracy.
            targets = targets.to(device).reshape(-1).long()
            outputs = model(inputs)
            total_loss += loss_fn(outputs, targets).item()
            predicted = outputs.argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
            n_batches += 1

    if total == 0:
        raise ValueError('validate(): data_loader yielded no samples')

    avg_loss = total_loss / n_batches
    accuracy = 100 * correct / total
    return avg_loss, accuracy

In [None]:
# Training order is reshuffled every epoch; val/test are visited in index order.
sampler_train = data.RandomSampler(data_source=train_dataset)
sampler_test = data.SequentialSampler(data_source=test_dataset)
sampler_val = data.SequentialSampler(data_source=val_dataset)

In [None]:
# Training: shuffled via sampler_train; the ragged final batch is dropped so
# every optimisation step sees exactly BATCH_SIZE samples.
data_loader_train = torch.utils.data.DataLoader(
    train_dataset, sampler=sampler_train,
    batch_size=BATCH_SIZE,
    drop_last=True,
)

# Evaluation loaders keep the final partial batch (drop_last=False): every
# val/test image must be scored, otherwise up to BATCH_SIZE - 1 samples are
# silently excluded from the metrics. NOTE: the last batch may therefore hold a
# single sample, so evaluation code must not squeeze() the batch dimension.
data_loader_val = torch.utils.data.DataLoader(
    val_dataset, sampler=sampler_val,
    batch_size=BATCH_SIZE,
    drop_last=False,
)

# The test loader must use the test split's own sampler: sampler_val yields
# indices over the *validation* split, which skips test images (or raises
# IndexError) whenever the two splits differ in size.
data_loader_test = torch.utils.data.DataLoader(
    test_dataset, sampler=sampler_test,
    batch_size=BATCH_SIZE,
    drop_last=False,
)

In [None]:
# Print the size of each split together with how many images each class has.
from collections import Counter

for split_name, split_loader in (
    ('Train', data_loader_train),
    ('Val', data_loader_val),
    ('Test', data_loader_test),
):
    split_dataset = split_loader.dataset
    print(f'{split_name} Size: {len(split_dataset)}')
    # ImageFolder exposes the integer label of every sample in `.targets` and
    # the class-name -> label mapping in `.class_to_idx`.
    label_counts = Counter(split_dataset.targets)
    per_class = ', '.join(
        f'{class_name}: {label_counts[label]}'
        for class_name, label in split_dataset.class_to_idx.items()
    )
    print(f'    {per_class}')


In [None]:
# Directory for run artefacts, relative to the current working directory.
output_dir = './output'

In [None]:
import matplotlib.pyplot as plt
from tqdm import tqdm

# Per-epoch loss history, plotted after training.
train_losses = []
val_losses = []

for epoch in range(NUM_EPOCHS):
    model.train()  # training behaviour for dropout / normalisation layers
    running_loss = 0.0
    n_batches = 0
    for inputs, targets in tqdm(data_loader_train, desc=f"Epoch {epoch+1} Training"):
        inputs = inputs.cuda()
        # Flatten to (batch,) instead of squeeze(): squeeze() would turn a
        # size-1 batch into a 0-d tensor, which CrossEntropyLoss rejects.
        targets = targets.cuda().reshape(-1).long()
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        # With drop_last=True an empty loader means the train split is smaller
        # than BATCH_SIZE; fail clearly instead of dividing by zero.
        raise RuntimeError(
            'training loader yielded no batches; is the train split smaller than BATCH_SIZE?'
        )
    avg_train_loss = running_loss / n_batches
    train_losses.append(avg_train_loss)

    val_loss, val_accuracy = validate(model, data_loader_val)
    val_losses.append(val_loss)

    print(f'Epoch [{epoch+1}/{NUM_EPOCHS}], Train Loss: {avg_train_loss:.4f}, Validation Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}%')

# Plot against 1-based epoch numbers so the x-axis matches the printed
# "Epoch [k/N]" labels.
epoch_numbers = range(1, len(train_losses) + 1)
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epoch_numbers, train_losses, label='Training Loss')
ax.plot(epoch_numbers, val_losses, label='Validation Loss')
ax.set_title('Loss vs. Epochs')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
ax.grid(True)
plt.show()


## Test

In [None]:
from sklearn.metrics import f1_score, roc_auc_score, classification_report, accuracy_score


def test_model(model, data_loader, device='cuda'):
    """Score a binary classifier on ``data_loader`` and print summary metrics.

    Args:
        model: network assumed to return (batch, 2) logits; column 1 is
            treated as the positive class.
        data_loader: iterable of ``(inputs, targets)`` batches with integer
            labels 0/1. If it has a ``.dataset.classes`` list of length 2, those
            names label the classification report.
        device: device each batch is moved to (default ``'cuda'``, as before).

    Returns:
        dict with ``'accuracy'`` (fraction in [0, 1]), ``'f1'``, ``'auc'``
        (NaN when the targets contain a single class) and ``'report'``.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    all_predictions = []
    all_targets = []
    all_probs = []

    with torch.no_grad():
        for inputs, targets in data_loader:
            outputs = model(inputs.to(device))
            # Probability of the positive class (index 1), needed for ROC AUC.
            probabilities = torch.softmax(outputs, dim=1)[:, 1]
            predicted = outputs.argmax(dim=1)
            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(targets.reshape(-1).cpu().numpy())
            all_probs.extend(probabilities.cpu().numpy())

    if not all_targets:
        raise ValueError('test_model(): data_loader yielded no samples')

    accuracy = accuracy_score(all_targets, all_predictions)  # fraction, not percent
    f1 = f1_score(all_targets, all_predictions, average='binary')
    # roc_auc_score raises ValueError when y_true contains only one class.
    if len(set(all_targets)) < 2:
        auc = float('nan')
    else:
        auc = roc_auc_score(all_targets, all_probs)

    class_names = list(getattr(getattr(data_loader, 'dataset', None), 'classes', None) or [])
    if len(class_names) != 2:
        class_names = ['Class 0', 'Class 1']
    # Fixed labels keep the report valid even if one class is absent from the split.
    report = classification_report(
        all_targets, all_predictions, labels=[0, 1], target_names=class_names
    )

    print(f'Test Accuracy: {accuracy * 100:.2f}%')
    print(f'F1 Score: {f1:.2f}')
    print(f'AUC ROC Score: {auc:.2f}')
    if auc != auc:  # NaN check
        print('    (AUC undefined: only one class present in the targets)')
    print('Classification Report:')
    print(report)

    return {'accuracy': accuracy, 'f1': f1, 'auc': auc, 'report': report}



In [None]:
# Score the fine-tuned model on the held-out test split (metrics are printed;
# the trailing ';' keeps any return value from being echoed as cell output).
test_model(model=model, data_loader=data_loader_test);

In [None]:
# Save the fine-tuned weights (state_dict only) under a name that records the
# MedViT variant and the dataset. The file goes into the configured
# `output_dir`, which is created if it does not exist yet.
os.makedirs(output_dir, exist_ok=True)
save_path = os.path.join(output_dir, f'fine_tuned_binary_MedViT_{model_name}_DDI.pth')
torch.save(model.state_dict(), save_path)
print(f'Saved fine-tuned weights to {save_path}')