# ViT fine tuning with medical image dataset

This demo is going to use the SPMS vs RRMS npy dataset to fine tuning. 


In [40]:
from tqdm import tqdm 
from random import random

import numpy as np  
import torch   
import torch.nn as nn  
from transformers import ViTModel, ViTConfig  
from torchvision.transforms import v2
from torch.optim import Adam  
from torch.utils.data import Dataset, DataLoader

# SPMS vs RRMS
train_img_path = "data/input_data/SPMS_RRMS_CONTROL/SPMSvsRRMS_img_train.npy"
train_lbl_path = "data/input_data/SPMS_RRMS_CONTROL/SPMSvsRRMS_lbl_train.npy"
test_img_path = "data/input_data/SPMS_RRMS_CONTROL/SPMSvsRRMS_img_test.npy"
test_lbl_path = "data/input_data/SPMS_RRMS_CONTROL/SPMSvsRRMS_lbl_test.npy"


In [41]:
# Define the Torch Dataset 
class MedicalImageDataset(Dataset):
    def __init__(self, img_path, lbl_path, transform = None, distribution = None):
        # read the npy file get all the data
        self.images = np.load(img_path).astype("float16")
        self.labels = np.load(lbl_path).flatten()
        self.height = self.images.shape[1]
        self.width = self.images.shape[2]
        
        self.default_transforms = v2.Compose([
            v2.Resize((224,224), antialias=True),
            transform(),
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]) if transform != None else v2.Compose([
            v2.Resize((224,224), antialias=True),
            v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # print(self.default_transforms)
        # print(self.images.shape, self.labels.shape)


        unique, counts = np.unique(self.labels, return_counts=True)
        self.num_of_patients = [9,10,10]
        # base on the distribution filter some part of the data out
        if distribution:
            img0 = self.images[0:counts[0]]
            img1 = self.images[counts[0]:counts[0]+counts[1]+1]
            # print(img0.shape, img1.shape,counts)

            img0 = self.resampled_by_patient(img0, self.num_of_patients[0], distribution[0])
            img1 = self.resampled_by_patient(img1, self.num_of_patients[1], distribution[1])

            self.images = np.concatenate((img0, img1), axis=0)
            self.labels= np.concatenate((np.full((img0.shape[0],), 0, dtype=int),
                                       np.full((img1.shape[0],), 1, dtype=int)), 
                                       axis=0).flatten()
        


    #  (has to be overwrote) define a iterator how to get the image and label by giving index
    def __getitem__(self,i):
        img = self.default_transforms(torch.from_numpy(self.images[i]).reshape((3, self.height, self.width)))
        return img, self.labels[i]
    
    # (has to be overwrote) return the total number of samples from the dataset
    def __len__(self):
        return self.labels.shape[0]

    def resampled_by_patient(self, imgs, total_num, remain_num):
        
        num_slice_per_p = int(imgs.shape[0]/total_num)
        picked_index = random.sample(range(total_num), total_num-remain_num)
        picked_index.sort()
        picked_index.reverse()
        for i in picked_index:
            start = i * num_slice_per_p
            end = start + num_slice_per_p
            imgs = np.delete(imgs, range(start,end), axis=0)
        return imgs 


In [42]:
train_dataset = MedicalImageDataset(train_img_path, train_lbl_path)
# print(torch.max(train_dataset.__getitem__(0)[0]))

Compose(
      Resize(size=[224, 224], interpolation=InterpolationMode.BILINEAR, antialias=True)
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
)
(1916, 256, 256, 3) (1916,)
tensor(-0.6714, dtype=torch.float16)


In [43]:
class ViT(nn.Module):  
  
    def __init__(self, config=ViTConfig(), num_labels=2,  
        model_checkpoint='google/vit-base-patch16-224-in21k'):  
        
        super(ViT, self).__init__()  
        
        self.vit = ViTModel.from_pretrained(model_checkpoint, add_pooling_layer=False)  
        self.classifier = (  
        nn.Linear(config.hidden_size, num_labels)  
        )  
    
    def forward(self, x):  
        
        x = self.vit(x)['last_hidden_state']  
        # Use the embedding of [CLS] token  
        output = self.classifier(x[:, 0, :])  
  
        return output

In [47]:
def model_train(dataset, epochs, learning_rate, bs):
    use_cuda = torch.cuda.is_available()
    use_mps = torch.backends.mps.is_available()
    device = torch.device("cuda" if use_cuda else ("mps" if use_mps else "cpu"))

    # Load nodel, loss function, and optimizer
    model = ViT().to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = Adam(model.parameters(), lr=learning_rate)

    # Load batch image
    train_dataset = MedicalImageDataset(train_img_path, train_lbl_path)
    train_dataloader = DataLoader(train_dataset, batch_size=bs, shuffle=True)

    # Fine tuning loop
    for i in range(epochs):
        total_acc_train = 0
        total_loss_train = 0.0

        for train_image, train_label in tqdm(train_dataloader):
            output = model(train_image.to(device))
            loss = criterion(output, train_label.to(device))
            acc = (output.argmax(dim=1) == train_label.to(device)).sum().item()
            total_acc_train += acc
            total_loss_train += loss.item()

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        print(f'Epochs: {i + 1} | Loss: {total_loss_train / len(train_dataset): .3f} | Accuracy: {total_acc_train / len(train_dataset): .3f}')

    return model


In [45]:
EPOCHS = 10 
LEARNING_RATE = 1e-4 
BATCH_SIZE = 8

In [48]:
trained_model = model_train(train_dataset, EPOCHS, LEARNING_RATE, BATCH_SIZE)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTModel: ['pooler.dense.bias', 'pooler.dense.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Compose(
      Resize(size=[224, 224], interpolation=InterpolationMode.BILINEAR, antialias=True)
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False)
)
(1916, 256, 256, 3) (1916,)


100%|██████████| 240/240 [01:36<00:00,  2.49it/s]


Epochs: 1 | Loss:  0.082 | Accuracy:  0.597


100%|██████████| 240/240 [01:31<00:00,  2.62it/s]


Epochs: 2 | Loss:  0.068 | Accuracy:  0.680


100%|██████████| 240/240 [01:31<00:00,  2.62it/s]


Epochs: 3 | Loss:  0.052 | Accuracy:  0.773


100%|██████████| 240/240 [01:30<00:00,  2.65it/s]


Epochs: 4 | Loss:  0.044 | Accuracy:  0.824


100%|██████████| 240/240 [01:30<00:00,  2.65it/s]


Epochs: 5 | Loss:  0.038 | Accuracy:  0.844


100%|██████████| 240/240 [01:30<00:00,  2.65it/s]


Epochs: 6 | Loss:  0.036 | Accuracy:  0.852


100%|██████████| 240/240 [01:31<00:00,  2.62it/s]


Epochs: 7 | Loss:  0.035 | Accuracy:  0.863


100%|██████████| 240/240 [01:32<00:00,  2.61it/s]


Epochs: 8 | Loss:  0.030 | Accuracy:  0.872


100%|██████████| 240/240 [01:32<00:00,  2.60it/s]


Epochs: 9 | Loss:  0.027 | Accuracy:  0.888


100%|██████████| 240/240 [01:31<00:00,  2.63it/s]

Epochs: 10 | Loss:  0.037 | Accuracy:  0.847





In [None]:
def predict(img):  
  
    use_cuda = torch.cuda.is_available()  
    device = torch.device("cuda" if use_cuda else "cpu")  
    transform = v2.Compose([  
    v2.ToTensor(),  
    v2.Resize((224, 224)),  
    v2.Normalize(mean=[0.5, 0.5, 0.5],  
    std=[0.5, 0.5, 0.5])  
    ])  
    
    img = transform(img)  
    output = trained_model(img.unsqueeze(0).to(device))  
    prediction = output.argmax(dim=1).item()  
  
    return id2label[prediction]

