# ViT fine-tuning with a medical image dataset

This demo fine-tunes a Vision Transformer (ViT) on the SPMS vs. RRMS dataset, which is stored in `.npy` format.


In [11]:
import numpy as np  
import torch   
import torch.nn as nn  
from transformers import ViTModel, ViTConfig  
from torchvision import transforms  
from torch.optim import Adam  
from torch.utils.data import DataLoader  
from tqdm import tqdm 


In [16]:
# Probe the available accelerators once and report whether the MPS backend
# is usable.
use_mps = torch.backends.mps.is_available()
print(use_mps)
use_cuda = torch.cuda.is_available()

# Preference order: CUDA first, then MPS, otherwise fall back to the CPU.
if use_cuda:
    device = torch.device("cuda")
elif use_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)

True
mps


In [None]:
class ImageDataset(torch.utils.data.Dataset):
    """Map-style dataset over records that carry an ``'image'`` and a ``'label'``.

    ``input_data`` is indexed by integer position and every record is read
    through the keys ``'image'`` and ``'label'``. Images are passed through a
    fixed preprocessing pipeline on access; labels are returned untouched.
    """

    def __init__(self, input_data):
        """Store the records and build the image preprocessing pipeline."""
        self.input_data = input_data

        # Preprocessing stages, applied in this order to each image.
        to_tensor = transforms.ToTensor()
        resize = transforms.Resize((224, 224), antialias=True)
        # NOTE(review): three-element mean/std implies 3-channel images — confirm
        # against the data actually loaded.
        normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                         std=[0.5, 0.5, 0.5])
        self.transform = transforms.Compose([to_tensor, resize, normalize])

    def __len__(self):
        """Return the number of records."""
        return len(self.input_data)

    def get_images(self, idx):
        """Return the preprocessed image of the record at position ``idx``."""
        record = self.input_data[idx]
        return self.transform(record['image'])

    def get_labels(self, idx):
        """Return the label of the record at position ``idx`` as stored."""
        record = self.input_data[idx]
        return record['label']

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for the record at ``idx``."""
        return self.get_images(idx), self.get_labels(idx)


In [None]:
class ViT(nn.Module):
    """Pretrained ViT backbone followed by a linear classification head.

    Args:
        config: Retained for backward compatibility and not used. The head's
            input width is read from the configuration of the checkpoint that
            is actually loaded, so the head always matches the backbone.
        num_labels: Number of output classes of the linear head.
        model_checkpoint: Name or path handed to ``ViTModel.from_pretrained``.
    """

    def __init__(self, config=None, num_labels=20,
                 model_checkpoint='google/vit-base-patch16-224-in21k'):
        super().__init__()

        self.vit = ViTModel.from_pretrained(model_checkpoint, add_pooling_layer=False)

        # Size the head from the loaded backbone rather than from a separately
        # constructed config: a checkpoint whose hidden size differs from the
        # default ViTConfig would otherwise fail with a shape mismatch in
        # forward().
        hidden_size = self.vit.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, x):
        """Return class logits computed from the first token's embedding.

        Args:
            x: Batch of images in the format the backbone expects.

        Returns:
            The classifier output for each item in the batch.
        """
        hidden_states = self.vit(x)['last_hidden_state']
        # Use the embedding of [CLS] token (sequence position 0)
        cls_embedding = hidden_states[:, 0, :]
        return self.classifier(cls_embedding)

In [None]:
def model_train(dataset, epochs, learning_rate, bs):
    """Fine-tune a ``ViT`` classifier on ``dataset`` and return the model.

    Args:
        dataset: Records wrapped in an ``ImageDataset`` for training.
        epochs: Number of full passes over the data.
        learning_rate: Learning rate given to the Adam optimizer.
        bs: Batch size used by the data loader.

    Returns:
        The fine-tuned model, left on the device selected below.

    After every epoch the mean per-sample loss and the accuracy over the
    training set are printed.
    """
    use_cuda = torch.cuda.is_available()
    use_mps = torch.backends.mps.is_available()
    device = torch.device("cuda" if use_cuda else ("mps" if use_mps else "cpu"))

    # Load model, loss function, and optimizer
    model = ViT().to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = Adam(model.parameters(), lr=learning_rate)

    # Load batch image
    train_dataset = ImageDataset(dataset)
    train_dataloader = DataLoader(train_dataset, num_workers=1, batch_size=bs, shuffle=True)
    num_samples = len(train_dataset)

    # from_pretrained() hands back the backbone in evaluation mode; switch the
    # whole model to training mode so train-only layers behave as intended.
    model.train()

    # Fine tuning loop
    for epoch in range(epochs):
        total_acc_train = 0
        total_loss_train = 0.0

        for train_image, train_label in tqdm(train_dataloader):
            # Move each batch to the device once and reuse it below.
            train_image = train_image.to(device)
            train_label = train_label.to(device)

            optimizer.zero_grad()
            output = model(train_image)
            loss = criterion(output, train_label)
            loss.backward()
            optimizer.step()

            total_acc_train += (output.argmax(dim=1) == train_label).sum().item()
            # CrossEntropyLoss returns the batch *mean*; weight it by the batch
            # size so that dividing by the sample count below yields the true
            # per-sample average (the last batch may be smaller than bs).
            total_loss_train += loss.item() * train_label.size(0)

        print(f'Epochs: {epoch + 1} | Loss: {total_loss_train / num_samples: .3f} | Accuracy: {total_acc_train / num_samples: .3f}')

    return model


In [None]:
# Hyperparameters for the fine-tuning run.
EPOCHS: int = 10              # full passes over the training data
LEARNING_RATE: float = 1e-4   # learning rate for the Adam optimizer
BATCH_SIZE: int = 8           # samples per training batch

In [None]:
# Fine-tune on the 'train' split with the hyperparameters defined above.
# NOTE(review): `dataset` is not defined in the cells visible here —
# presumably it is loaded in another cell; confirm before running.
trained_model = model_train(
    dataset=dataset['train'],
    epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    bs=BATCH_SIZE,
)