In [None]:
# Show the attached NVIDIA GPU(s); later cells move the model and every batch to CUDA via `.cuda()`.
!nvidia-smi

In [None]:
!git clone https://github.com/Omid-Nejati/MedVit.git

In [None]:
cd /content/MedViT

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision.utils
from torchvision import models
import torchvision.datasets as dsets
from torchsummary import summary

In [None]:
!pip install timm
!pip install einops

In [None]:
from MedViTWithAdapters import MedViTWithAdapters_small as tiny

In [None]:
# Build the backbone. NOTE(review): `tiny` is the import alias for
# MedViTWithAdapters_small, i.e. the *small* variant despite the name.
model = tiny()

In [None]:
# Display the first layer of the projection head (the layer replaced in the next
# cell) to check its input width before swapping it out.
model.proj_head[0]

In [None]:
# Replace the classification layer so the model emits 2 logits (a binary task).
# out_features must equal n_classes of the dataset chosen in the config cell below;
# change it if you switch data_flag.
model.proj_head[0] = torch.nn.Linear(
    # Reuse the existing layer's input width instead of hard-coding 1024.
    in_features=model.proj_head[0].in_features,
    out_features=2,
    bias=True,  # was misspelled `bais`, which raised TypeError
)

In [None]:
# Move all parameters and buffers onto the current CUDA device; batches are moved too in the loops below.
model = model.to('cuda')

### Dataset

In [None]:
 !pip install medmnist

In [None]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms

import medmnist
from medmnist import INFO , Evaluator

In [None]:
# Which MedMNIST dataset to train on. Other 2D flags: pathmnist, chestmnist,
# dermamnist, octmnist, pneumoniamnist, retinamnist, breastmnist, bloodmnist,
# tissuemnist, organamnist, organcmnist, organsmnist.
data_flag = 'breastmnist'
download = True  # passed as `download=` when the dataset splits are constructed

# Training hyper-parameters.
NUM_EPOCHS = 10
BATCH_SIZE = 10
lr = 0.005

# Metadata for the chosen dataset: task type (selects the loss below),
# channel count, and number of classes/labels.
info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])

# Dataset class looked up by name on the medmnist module.
DataClass = getattr(medmnist, info['python_class'])

In [None]:
from torchvision.transforms.transforms import Resize


def _to_rgb(image):
    """Return `image` converted to 3-channel RGB via PIL's `convert`."""
    return image.convert('RGB')


def _build_transform(augment):
    """Build the preprocessing pipeline; `augment=True` inserts AugMix.

    Order: resize (shorter side to 224), force RGB, optional AugMix,
    convert to a tensor, then normalize with a single mean/std of 0.5.
    """
    steps = [transforms.Resize(224), transforms.Lambda(_to_rgb)]
    if augment:
        steps.append(torchvision.transforms.AugMix())
    steps.extend([transforms.ToTensor(), transforms.Normalize(mean=[.5], std=[.5])])
    return transforms.Compose(steps)


train_transform = _build_transform(augment=True)
test_transform = _build_transform(augment=False)

# Load the train and test splits.
train_dataset = DataClass(split='train', transform=train_transform, download=download)
test_dataset = DataClass(split='test', transform=test_transform, download=download)

# Training batches are shuffled; the two evaluation loaders keep dataset order
# and use twice the training batch size.
eval_batch_size = 2 * BATCH_SIZE
train_loader = data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
train_loader_at_eval = data.DataLoader(dataset=train_dataset, batch_size=eval_batch_size, shuffle=False)
test_loader = data.DataLoader(dataset=test_dataset, batch_size=eval_batch_size, shuffle=False)
     

In [None]:
# Summaries of both splits, separated by a rule line.
print(train_dataset, test_dataset, sep="\n==================\n")

In [None]:
# Loss depends on the task: multi-label tasks score every label independently
# (BCEWithLogitsLoss); every other task uses CrossEntropyLoss.
criterion = (
    nn.BCEWithLogitsLoss()
    if task == "multi-label, binary-class"
    else nn.CrossEntropyLoss()
)

# SGD with momentum over all model parameters, at the configured learning rate.
optimizer = optim.SGD(model.parameters(), momentum=0.9, lr=lr)

In [None]:
# Train for NUM_EPOCHS passes over train_loader, printing the mean loss of each epoch.

# The replaced head must emit one logit per class/label of the chosen dataset.
assert model.proj_head[0].out_features == n_classes, (
    '%s has %d classes but proj_head[0] outputs %d logits; '
    'rebuild the head with out_features=n_classes'
    % (data_flag, n_classes, model.proj_head[0].out_features)
)

for epoch in range(NUM_EPOCHS):
    print('Epoch [%d/%d]' % (epoch + 1, NUM_EPOCHS))
    model.train()
    epoch_loss = 0.0
    for inputs, targets in tqdm(train_loader):
        inputs, targets = inputs.cuda(), targets.cuda()

        if task == 'multi-label, binary-class':
            # BCEWithLogitsLoss: float 0/1 targets with the same shape as the logits.
            targets = targets.to(torch.float32)
        else:
            # CrossEntropyLoss: integer class indices of shape (N,). Flattening also
            # handles a per-sample label column of shape (N, 1).
            # NOTE(review): confirm the label shape your medmnist version yields.
            targets = targets.reshape(-1).long()

        # forward + backward + optimize
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
    print('  mean train loss: %.4f' % (epoch_loss / len(train_loader)))

In [None]:
# evaluation

def test(split):
    """Score `model` on one split and print its AUC and accuracy.

    Args:
        split: 'train' evaluates on `train_loader_at_eval`; any other value
            evaluates on `test_loader`. The same string is passed to
            medmnist's `Evaluator`.

    Only the predicted scores are handed to `Evaluator.evaluate`, so the
    evaluator presumably fetches the ground-truth labels for `split` itself
    and pairs them with the scores by position. Both loaders used here are
    built with shuffle=False, which keeps that order intact.
    """
    model.eval()
    data_loader = train_loader_at_eval if split == 'train' else test_loader

    score_batches = []
    with torch.no_grad():
        for inputs, _ in data_loader:
            outputs = model(inputs.cuda())
            if task == 'multi-label, binary-class':
                # Independent per-label probabilities, matching the
                # BCEWithLogitsLoss used for this task during training.
                scores = outputs.sigmoid()
            else:
                # One probability distribution over the classes.
                scores = outputs.softmax(dim=-1)
            score_batches.append(scores.cpu())

    y_score = torch.cat(score_batches, 0).numpy()

    evaluator = Evaluator(data_flag, split)
    metrics = evaluator.evaluate(y_score)

    print('%s auc: %.3f acc: %.3f' % (split, *metrics))

# Report metrics on the training split first, then on the held-out test split.
print('==> Evaluating...')
for split_name in ('train', 'test'):
    test(split_name)