# Imports

In [1]:
# Re-import edited modules (e.g. the project's models.mrvit / dataset.dataset)
# before each cell executes, so code changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import random
import sys

from tqdm.auto import tqdm

import numpy as np
import h5py

import matplotlib.pyplot as plt

from sklearn.metrics import classification_report, accuracy_score, roc_auc_score
from sklearn.utils import class_weight

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms

from torch.optim.lr_scheduler import MultiStepLR, StepLR

In [3]:
!pip install torchinfo
from torchinfo import summary



In [4]:
!pip install einops



In [5]:
# Colab only: mount Google Drive so the project folder under
# /content/drive/MyDrive (used in the next cell) is reachable.
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [6]:
# Work from the project root so the `models` and `dataset` packages imported
# in the next cell resolve. The Colab Drive path is the default; set the
# MRVIT_DIR environment variable to run the notebook from another checkout.
PROJECT_DIR = os.environ.get("MRVIT_DIR", "/content/drive/MyDrive/MRViT")
os.chdir(PROJECT_DIR)

In [7]:
from models.mrvit import MultiResViT
from dataset.dataset import get_datasets

In [8]:
# Seed Python's `random`, NumPy and PyTorch (CPU and every CUDA device) with a
# single value, and force deterministic cuDNN kernels, so runs are repeatable.
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
# Seeds every visible GPU (a superset of torch.cuda.manual_seed); it is a
# lazy no-op on machines without CUDA.
torch.cuda.manual_seed_all(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False  # autotuning may pick non-deterministic kernels

In [9]:
USE_GPU = True

# Use the first CUDA device only when GPU use is requested AND one is present;
# otherwise fall back to the CPU.
use_cuda = USE_GPU and torch.cuda.is_available()
device = "cuda:0" if use_cuda else "cpu"
print('using device: cuda' if use_cuda else 'using device: cpu')

using device: cuda


# Dataset

In [10]:
# Build the train / validation / test splits with the project's dataset helper.
splits = get_datasets()
train_dataset, valid_dataset, test_dataset = splits

<KeysViewHDF5 ['label_test', 'label_train', 'test', 'test_small', 'train', 'train_small']>
torch.Size([400, 112, 112, 3]) torch.Size([400, 28, 28, 3]) torch.Size([400, 1])
1079 1080


In [11]:
# Mini-batch size for the training loader, and the number of target classes.
BATCH_SIZE, num_classes = 64, 5

In [12]:
# Training batches are reshuffled every epoch.
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# Validation runs as one full-size batch; order does not matter for the
# metrics, so it is not shuffled. max(1, ...) keeps batch_size valid if empty.
valid_loader = data.DataLoader(dataset=valid_dataset, batch_size=max(1, len(valid_dataset)), shuffle=False)
# Test set in at most 4 near-equal batches: ceil-division avoids the extra tiny
# trailing batch that floor-division produces (e.g. 1079 -> 4x269 + 3).
test_loader = data.DataLoader(dataset=test_dataset, batch_size=max(1, -(-len(test_dataset) // 4)))

In [13]:
# Sanity-check a single training batch by printing the shapes of its first
# three entries (small images, large images, labels — see the sizes printed below).
first_batch = next(iter(train_loader))
for part in first_batch[:3]:
    print(part.shape)

torch.Size([64, 3, 28, 28])
torch.Size([64, 3, 112, 112])
torch.Size([64, 1])


# Merged Model

In [14]:
# Smoke-test the merged model on random inputs before training. The model is
# built for the same number of classes as the training model, and the input
# sizes match a real batch (printed above): the small-image branch gets
# 3x28x28 images and the large-image branch gets 3x112x112.
mrv = MultiResViT(nclasses=num_classes)
mrv.eval()

img_s = torch.randn(2, 3, 28, 28)
img_b = torch.randn(2, 3, 112, 112)

with torch.no_grad():
    pred = mrv(img_s, img_b)

print(pred)
# The training cells feed the output to CrossEntropyLoss and argmax(dim=1),
# so one logit per class and sample is expected.
assert pred.shape == (2, num_classes), f"unexpected output shape {tuple(pred.shape)}"

AttributeError: ignored

# Training


In [None]:
# Class-balanced loss weights ('balanced': n_samples / (n_classes * count_c)),
# one entry per label present in the training set.
# Labels are flattened to 1-D; they are likely stored as (N, 1) — see the
# [64, 1] label batches printed above.
train_labels = train_dataset.get_labels().numpy().ravel()
weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_labels),  # keyword-only in scikit-learn >= 1.0
    y=train_labels,
)
# CrossEntropyLoss(weight=...) needs exactly one weight per class; fail early
# with a clear message if a class is missing from the training labels.
assert len(weights) == num_classes, (
    f"got {len(weights)} class weights for {num_classes} classes"
)
weights = torch.as_tensor(weights, dtype=torch.float32, device=device)

In [None]:
# Display the per-class loss weights (one entry per class label).
weights

In [None]:
epochs = 50
lr = 7.5e-3

# The model to be trained and tested; nn.Module.to() returns the module itself.
model = MultiResViT(nclasses=num_classes).to(device)

criterion = nn.CrossEntropyLoss(weight=weights)  # class-weighted cross-entropy
optimizer = optim.Adam(model.parameters(), lr=lr)
# Default gamma=0.1: the learning rate is divided by 10 at epochs 20 and 40.
scheduler = MultiStepLR(optimizer, milestones=[20, 40])

In [None]:
# Show torchinfo's per-layer parameter summary of the model to be trained.
summary(model)

In [None]:
def run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over `loader` and return (mean batch loss, mean batch accuracy).

    With an `optimizer` the model is put in train mode and updated after every
    batch; without one it is put in eval mode and gradients are disabled.
    Each batch is indexed as [small images, large images, labels].
    """
    training = optimizer is not None
    # Explicitly set the mode every time: the test cell below calls model.eval(),
    # and validation must not run with train-mode layers active.
    model.train(training)
    loss_sum = 0.0
    acc_sum = 0.0

    with torch.set_grad_enabled(training):
        for batch_data in tqdm(loader):
            images_s = batch_data[0].to(device)
            images_l = batch_data[1].to(device)
            label = batch_data[2].to(device).flatten()

            output = model(images_s, images_l)
            loss = criterion(output.float(), label)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # .item() yields a detached Python float; summing the tensor itself
            # would keep every batch's autograd graph alive for the whole epoch.
            loss_sum += loss.item()
            acc_sum += accuracy_score(label.cpu().numpy(), output.argmax(dim=1).cpu().numpy())

    n_batches = len(loader)
    return loss_sum / n_batches, acc_sum / n_batches


# `history` stores the per-epoch MEAN batch loss / accuracy as plain floats,
# i.e. the same numbers that are printed, so it can be plotted directly.
keys = ['epochs', 'loss', 'acc', 'val_loss', 'val_acc']
history = {key: [] for key in keys}

for epoch in range(epochs):
    epoch_loss, epoch_accuracy = run_epoch(model, train_loader, criterion, optimizer)
    epoch_val_loss, epoch_val_accuracy = run_epoch(model, valid_loader, criterion)

    scheduler.step()

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} "
        f"- val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )
    history['epochs'].append(epoch + 1)
    history['loss'].append(epoch_loss)
    history['acc'].append(epoch_accuracy)
    history['val_loss'].append(epoch_val_loss)
    history['val_acc'].append(epoch_val_accuracy)

In [None]:
def _as_floats(values):
    """Convert history entries (numbers or 0-d tensors, possibly on GPU or
    requiring grad) to plain floats that matplotlib can plot."""
    return [float(v) for v in values]


fig, (ax_train, ax_val) = plt.subplots(1, 2, figsize=(12, 4))

ax_train.plot(history['epochs'], _as_floats(history['loss']), label='loss')
ax_train.plot(history['epochs'], _as_floats(history['acc']), label='accuracy')
ax_train.set(xlabel='Epochs', title='Training Visualisation')
ax_train.legend()

ax_val.plot(history['epochs'], _as_floats(history['val_loss']), label='Val loss')
ax_val.plot(history['epochs'], _as_floats(history['val_acc']), label='Val accuracy')
ax_val.set(xlabel='Epochs', title='Validation Visualisation')
ax_val.legend()

plt.show()

In [None]:
test_acc = 0
test_auc = 0
preds = []

with torch.no_grad():
    model.eval()
    for batch_data in tqdm(test_loader):
        images_s = batch_data[0].to(device)
        images_l = batch_data[1].to(device)
        label = batch_data[2].to(device)

        test_output = model(images_s, images_l)
        
        acc = accuracy_score(label.detach().cpu().float(), test_output.argmax(dim=1).detach().cpu().float()) 
        
        if num_classes == 2:
            auc = roc_auc_score(label.detach().cpu().float(), test_output.argmax(dim=1).detach().cpu().float(), average = "weighted")
        else:
            auc = roc_auc_score(label.detach().cpu().float().flatten(), F.softmax(test_output, dim = 1).detach().cpu().float(), multi_class = "ovr", average = "weighted")
       
        preds.append(test_output.argmax(dim=1).detach().cpu().numpy())

        test_acc += acc
        test_auc += auc

preds = [a.squeeze().tolist() for a in preds]

In [None]:
print(f"ACC Model: {test_acc / len(test_loader)}, AUC Model: {test_auc / len(test_loader)}")

In [None]:
preds = [item for sublist in preds for item in sublist]
preds = np.array([np.array(y) for y in preds], dtype=float)

In [None]:
print(classification_report(test_dataset.get_labels(), preds.flatten()))