# CS 598 Deep Learning for Healthcare: Project Report

# Mount Notebook to Google Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%pip install torchcontrib pytorch_lightning timm tensorboard torchmetrics gdown kaggle

# Github URL

https://github.com/ShineyAji/DLH_CASS

# Video Presentation URL

https://drive.google.com/file/d/1DLBOyd5vJFx-TnHBvq-T7f7Sfhgl_erD/view?usp=drive_link

# Introduction

The paper we selected for this project is: Efficient Representation Learning for Healthcare with Cross-Architectural Self-Supervision by Singh and Cirrone (2023).


*   **Background:**\
    In healthcare and biomedical applications, the primary challenges are the substantial computational demands and the limited availability of data. Representation learning can improve deep learning models by learning valuable priors from limited medical data. Despite their benefits, existing representation learning and self-supervised methods are computationally intensive and require multiple GPU servers running for extended periods, which restricts their accessibility for general practitioners.
    
    Labeling medical data requires specialized domain knowledge and can be expensive and time-consuming. Releasing medical datasets for research is also restricted by privacy and regulatory constraints. Moreover, understanding of some diseases is limited, either because the disease is emerging or because systematic data collection mechanisms are lacking. All these conditions lead to a scarcity of medical data, both labeled and unlabeled. Although reducing the number of pretraining epochs can mitigate the computational problem, and a small batch size can mitigate the data scarcity problem, state-of-the-art self-supervised techniques suffer a significant drop in performance when used with small batch sizes and fewer epochs.

*   **CASS:**\
    The paper proposes Cross-Architectural Self-Supervision (CASS), a self-supervised learning approach that combines a CNN and a Transformer in a response-based Siamese contrastive method. This Siamese cross-architecture technique combines CNNs and Transformers without any change to their architectures, helping both of them learn better representations. An image is passed through a common set of augmentations, and the augmented image is passed simultaneously through a CNN and a Transformer to create positive pairs. The output logits from the CNN and the Transformer are then used to compute a cosine similarity loss. CASS learns more predictive data representations in limited-data scenarios where a Transformer-only model cannot find them.

    The authors observed that self-supervised pre-training outperforms transfer learning in all cases by a margin. They also observed that CASS improves on the classification performance of existing state-of-the-art self-supervised methods. They concluded that with CASS, researchers can begin medical image analysis even with a small fraction of the overall dataset, or when only a small portion of it is labeled.

*   **Datasets:**\
    We use two of the datasets used by the authors in the paper: the Brain Tumor MRI Classification dataset, which consists of 7,023 black-and-white MRI images, and the ISIC 2019 dataset, which contains 25,331 images of skin lesions.

    **In this notebook, we use only the Brain Tumor MRI Classification dataset.** All data preprocessing in this notebook is done for the Brain Tumor MRI dataset only. We added another notebook (DLH_CASS_ISIC2009.ipynb) to the GitHub repository with the dataset preprocessing and model training code for the ISIC 2019 dataset.




# Download data from Kaggle

In [None]:
#downloading the data from kaggle into colab
!gdown --id '18WnN-JkA_hv_Q3TWtZcnuk6euc2YwZ3D' --output folder_structure.png #folder structure from kaggle
!gdown --id '12rHOJgPKyvbM-sEL2IbQth1lAIXT6hdW' --output kaggle.json
!rm -rf ~/.kaggle
!mkdir ~/.kaggle
! cp /content/kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d masoudnickparvar/brain-tumor-mri-dataset
!rm -rf BrainMRI
!mkdir BrainMRI
!unzip brain-tumor-mri-dataset.zip -d BrainMRI

# Scope of Reproducibility:


1.   Hypothesis 1: If we evaluate CASS and other self-supervised methods on the given datasets, CASS will outperform the existing self-supervised techniques.
2.   Hypothesis 2: When trained with fewer epochs or smaller batch sizes, existing methods will show a performance drop, whereas CASS will be robust to these changes.
3.   Hypothesis 3: The training time of CASS will be shorter than that of other self-supervised methods. *With our limited computing resources, we will not be able to exactly reproduce the performance results of the original paper, but we can test this hypothesis in relative terms.*


# CASS Implementation

We implement the CASS code and train the model on the Brain Tumor MRI dataset in this notebook.

**Code Source:** We reproduced the model source code from the original paper's repository (https://github.com/pranavsinghps1/CASS), with the changes needed for our dataset.

**Changes Incorporated:** The authors' source repository contains code for the MedMNIST and ISIC 2019 datasets, so we wrote all of the data preprocessing needed to fit the Brain MRI dataset into the designed model ourselves. We also encountered an error in the loss calculation caused by an issue in the source code, which we fixed.


**Changes planned to implement:** We plan to implement the changes needed for ablation studies; these will be captured in the final report.

##  Dependencies

In [None]:
# import  packages you need
from google.colab import drive
import os
import numpy as np
import pytorch_lightning as pl
import torch
import pandas as pd
import timm
import math
import gdown
import torch.nn as nn
from tqdm import tqdm
from PIL import Image
from sklearn.model_selection import KFold
from torchvision import transforms as tsfm
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torchcontrib.optim import SWA
from torchmetrics import Metric
from torch.utils.tensorboard import SummaryWriter
from sklearn.model_selection import train_test_split
from google.colab.patches import cv2_imshow
import cv2

In [None]:
class CFG:
    label_num2str = {0: 'glioma',
                     1: 'meningioma',
                     2: 'notumor',
                     3: 'pituitary'
                     }
    label_str2num = {'glioma': 0,
                     'meningioma':1,
                     'notumor':2,
                     'pituitary':3
                     }
    fl_alpha = 1.0  # alpha of focal_loss
    fl_gamma = 2.0  # gamma of focal_loss
    cls_weight = [0.2, 0.5970802919708029, 1.0, 0.25255474452554744]
    cnn_name='resnet50'
    vit_name='vit_base_patch16_384'
    seed = 77
    num_classes = 4
    batch_size = 16
    t_max = 16
    lr = 1e-3
    min_lr = 1e-6
    n_fold = 6
    num_workers = 2 # In GCP terminal, we ran this with 8 workers
    accum_grad_batch = 1
    early_stop_delta = 1e-7
    gpu_idx = 0
    device = torch.device(f'cuda:{gpu_idx}' if torch.cuda.is_available() else 'cpu')
    gpu_list = [gpu_idx]

##  Data
This section covers the raw data (Brain Tumor MRI images), descriptive statistics, and data preprocessing.
  * This dataset contains 7k+ human brain MRI images and can be accessed at https://www.kaggle.com/datasets/masoudnickparvar/brain-tumor-mri-dataset.
  * When training the model on GCP, we downloaded the data from Kaggle by running the following API command in the terminal: \
       **API Command:** *kaggle datasets download -d masoudnickparvar/brain-tumor-mri-dataset*
  * This dataset has 7,023 human brain MRI images classified into 4 classes: glioma, meningioma, no tumor, and pituitary. The curator created predefined training and testing splits: 5,712 images for training and 1,311 images for testing.
  * Using `label_str2num`, we map each class name to an integer index and encode that index as a one-hot vector.

This collection of MRI images combines multiple datasets, so image sizes vary throughout the dataset. We therefore resized all images to a uniform size of 384×384 pixels.
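The loading cells below build each one-hot label by filling a zero tensor by hand. As a minimal sketch of the same class-to-one-hot mapping, `torch.nn.functional.one_hot` can do it in one call; `to_one_hot` is a hypothetical helper name, not part of the original code, and the mapping is copied from `CFG.label_str2num`.

```python
import torch
import torch.nn.functional as F

# Class-name -> index mapping, copied from CFG.label_str2num
label_str2num = {'glioma': 0, 'meningioma': 1, 'notumor': 2, 'pituitary': 3}

def to_one_hot(class_names, num_classes=4):
    """Hypothetical helper: map class names to a [N, num_classes] float one-hot tensor."""
    idx = torch.tensor([label_str2num[name] for name in class_names])
    return F.one_hot(idx, num_classes=num_classes).float()

labels = to_one_hot(['glioma', 'pituitary', 'notumor'])
print(labels)
```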


In [None]:
# This image shows the structure of the training and testing data: one folder each for training and testing,
# each containing one subfolder per class, named after the class of the images inside it.
img = cv2.imread('/content/folder_structure.png')
cv2_imshow(img)

In [None]:
# Fetch the training data from the path where the Kaggle download was extracted.
# The folder structure is shown above. Each class folder contains more than 50 images, and gdown can
# download at most 50 files per folder, so we could not use gdown to get the files into the Colab notebook.

train_data_dir = 'BrainMRI/Training'
train_filepaths = []
train_labels = []
all_train_img_labels_ts = []

folds = os.listdir(train_data_dir)
for fold in folds:  # each 'fold' is a class folder, e.g. 'glioma'
    if not fold.startswith('.'):  # skip hidden entries such as .DS_Store
        foldpath = os.path.join(train_data_dir, fold)
        filelist = os.listdir(foldpath)
        for file in filelist:
            fpath = os.path.join(foldpath, file)
            train_filepaths.append(fpath)
            train_labels.append(fold)
            # map the class name to its integer index, then encode that index as a one-hot vector
            tmp_label = torch.zeros([CFG.num_classes], dtype=torch.float)
            k = int(CFG.label_str2num[fold])
            tmp_label[k] = 1.0
            all_train_img_labels_ts.append(tmp_label)


ts_label_list = [tensor.tolist() for tensor in all_train_img_labels_ts]
# Concatenate data paths with labels into one dataframe
Fseries = pd.Series(train_filepaths, name= 'filepaths')
Lseries = pd.DataFrame({'labels': ts_label_list})
train_df = pd.concat([Fseries, Lseries], axis= 1)

In [None]:
# Fetch the testing data from the path where the Kaggle download was extracted.
# The folder structure is shown above. Each class folder contains more than 50 images, and gdown can
# download at most 50 files per folder, so we could not use gdown to get the files into the Colab notebook.

test_data_dir = 'BrainMRI/Testing'
test_filepaths = []
test_labels = []
all_test_img_labels_ts = []

folds = os.listdir(test_data_dir)
for fold in folds:  # each 'fold' is a class folder, e.g. 'glioma'
    if not fold.startswith('.'):  # skip hidden entries such as .DS_Store
        foldpath = os.path.join(test_data_dir, fold)
        filelist = os.listdir(foldpath)
        for file in filelist:
            fpath = os.path.join(foldpath, file)
            test_filepaths.append(fpath)
            test_labels.append(fold)
            # map the class name to its integer index, then encode that index as a one-hot vector
            tmp_label = torch.zeros([CFG.num_classes], dtype=torch.float)
            k = int(CFG.label_str2num[fold])
            tmp_label[k] = 1.0
            all_test_img_labels_ts.append(tmp_label)

ts_label_list = [tensor.tolist() for tensor in all_test_img_labels_ts]
Fseries = pd.Series(test_filepaths, name='filepaths')
Lseries = pd.DataFrame({'labels': ts_label_list})
ts_df = pd.concat([Fseries, Lseries], axis=1)
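The training and testing loading cells above are identical except for their root directory. A minimal sketch of a shared helper follows; `build_split_df` is a hypothetical name, it stores one-hot labels as plain lists like `train_df` and `ts_df` do, and it sorts directory listings because `os.listdir` returns entries in arbitrary order. It is exercised on a tiny temporary folder tree rather than the real data.

```python
import os
import tempfile
import pandas as pd

label_str2num = {'glioma': 0, 'meningioma': 1, 'notumor': 2, 'pituitary': 3}

def build_split_df(data_dir, label_map=label_str2num):
    """Hypothetical helper: scan data_dir/<class_name>/<image> and return a DataFrame
    with a 'filepaths' column and a 'labels' column of one-hot lists."""
    rows = []
    for class_name in sorted(os.listdir(data_dir)):
        class_dir = os.path.join(data_dir, class_name)
        # skip hidden entries (e.g. .DS_Store) and anything that is not a folder
        if class_name.startswith('.') or not os.path.isdir(class_dir):
            continue
        one_hot = [0.0] * len(label_map)
        one_hot[label_map[class_name]] = 1.0
        for fname in sorted(os.listdir(class_dir)):
            rows.append({'filepaths': os.path.join(class_dir, fname),
                         'labels': list(one_hot)})
    return pd.DataFrame(rows, columns=['filepaths', 'labels'])

# Usage on a tiny fake directory tree with empty placeholder files
with tempfile.TemporaryDirectory() as root:
    for cls, n in [('glioma', 2), ('notumor', 1)]:
        os.makedirs(os.path.join(root, cls))
        for i in range(n):
            open(os.path.join(root, cls, f'{i}.jpg'), 'w').close()
    demo_df = build_split_df(root)
print(demo_df)
```

On the real data, it would be called as `build_split_df('BrainMRI/Training')` and `build_split_df('BrainMRI/Testing')`.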

In [None]:
print("*************Training Dataset Statistics**************")
print("Num of images in train dataset:",len(train_df))
train_glioma_count = sum(1 for label in train_df['labels'] if label[0] == 1.0)
train_meningioma_count = sum(1 for label in train_df['labels'] if label[1] == 1.0)
train_notumor_count = sum(1 for label in train_df['labels'] if label[2] == 1.0)
train_pituitary_count = sum(1 for label in train_df['labels'] if label[3] == 1.0)
print("Total Number of images with glioma in training dataset:", train_glioma_count)
print("Total Number of images with meningioma in training dataset:", train_meningioma_count)
print("Total Number of images with notumor in training dataset:", train_notumor_count)
print("Total Number of images with pituitary in training dataset:", train_pituitary_count)
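The four counting expressions above scan the label column once per class. A sketch of the same per-class count in a single pass, assuming the `labels` column holds one-hot lists as built above (`class_counts` and `toy_df` are hypothetical names, and `toy_df` is made-up data, not dataset statistics):

```python
import numpy as np
import pandas as pd

label_num2str = {0: 'glioma', 1: 'meningioma', 2: 'notumor', 3: 'pituitary'}

def class_counts(df):
    """Count images per class from a DataFrame whose 'labels' column holds one-hot lists."""
    one_hot = np.stack(df['labels'].to_numpy())   # shape [N, num_classes]
    class_idx = np.argmax(one_hot, axis=1)          # integer class per row
    counts = np.bincount(class_idx, minlength=len(label_num2str))
    return {label_num2str[i]: int(c) for i, c in enumerate(counts)}

toy_df = pd.DataFrame({'labels': [[1.0, 0, 0, 0], [0, 0, 1.0, 0], [1.0, 0, 0, 0]]})
print(class_counts(toy_df))
```

In the notebook this would be `class_counts(train_df)`, `class_counts(valid_df)` and `class_counts(test_df)`.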


In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
seed_everything(77)
cfg=CFG()

Here we split the provided testing data into two equal halves: one for validation and one for testing.

In [None]:
valid_df, test_df = train_test_split(ts_df,  train_size= 0.5, shuffle= True, random_state= 123)
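The split above shuffles `ts_df` but does not stratify it, so the class proportions of the validation and test halves may differ by chance. If matching proportions are wanted, scikit-learn's `train_test_split` accepts a `stratify` argument. A minimal sketch on a made-up frame (`toy_ts_df` and `strata` are hypothetical names, not the notebook's data):

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Toy stand-in for ts_df: 8 images of class 0 and 4 of class 2, with one-hot list labels
toy_ts_df = pd.DataFrame({
    'filepaths': [f'img_{i}.jpg' for i in range(12)],
    'labels': [[1.0, 0, 0, 0]] * 8 + [[0, 0, 1.0, 0]] * 4,
})
# Integer class per row (position of the 1.0), used only to stratify the split
strata = toy_ts_df['labels'].apply(lambda v: v.index(1.0))
valid_toy, test_toy = train_test_split(toy_ts_df, train_size=0.5, shuffle=True,
                                       random_state=123, stratify=strata)
print(strata.loc[valid_toy.index].value_counts().to_dict())
print(strata.loc[test_toy.index].value_counts().to_dict())
```

Each half keeps the 2:1 class ratio of the toy frame.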

In [None]:
print("*************Validation Dataset Statistics**************")
print("Num of images in validation dataset:",len(valid_df))
valid_glioma_count = sum(1 for label in valid_df['labels'] if label[0] == 1.0)
valid_meningioma_count = sum(1 for label in valid_df['labels'] if label[1] == 1.0)
valid_notumor_count = sum(1 for label in valid_df['labels'] if label[2] == 1.0)
valid_pituitary_count = sum(1 for label in valid_df['labels'] if label[3] == 1.0)
print("Total Number of images with glioma in validation dataset:", valid_glioma_count)
print("Total Number of images with meningioma in validation dataset:", valid_meningioma_count)
print("Total Number of images with notumor in validation dataset:", valid_notumor_count)
print("Total Number of images with pituitary in validation dataset:", valid_pituitary_count)

print("*************Testing Dataset Statistics**************")
print("Num of images in test dataset:",len(test_df))
test_glioma_count = sum(1 for label in test_df['labels'] if label[0] == 1.0)
test_meningioma_count = sum(1 for label in test_df['labels'] if label[1] == 1.0)
test_notumor_count = sum(1 for label in test_df['labels'] if label[2] == 1.0)
test_pituitary_count = sum(1 for label in test_df['labels'] if label[3] == 1.0)
print("Total Number of images with glioma in testing dataset:", test_glioma_count)
print("Total Number of images with meningioma in testing dataset:", test_meningioma_count)
print("Total Number of images with notumor in testing dataset:", test_notumor_count)
print("Total Number of images with pituitary in testing dataset:", test_pituitary_count)

In [None]:
train_image_list = train_filepaths
train_label_list = all_train_img_labels_ts
valid_image_list = valid_df['filepaths'].tolist()
valid_label_list = valid_df['labels'].tolist()
all_valid_label_list = [torch.tensor(sublist) for sublist in valid_label_list]
test_image_list = test_df['filepaths'].tolist()
test_label_list = test_df['labels'].tolist()
all_test_label_list = [torch.tensor(sublist) for sublist in test_label_list]

In [None]:
"""
Define image transformation
"""
DATASET_IMAGE_MEAN = (0.485, 0.456, 0.406)
DATASET_IMAGE_STD = (0.229, 0.224, 0.225)

train_transform = tsfm.Compose([tsfm.Resize((384,384)),
                                tsfm.RandomApply([tsfm.ColorJitter(0.2, 0.2, 0.2),tsfm.RandomPerspective(distortion_scale=0.2),], p=0.3),
                                tsfm.RandomApply([tsfm.ColorJitter(0.2, 0.2, 0.2),tsfm.RandomAffine(degrees=10),], p=0.3),
                                tsfm.RandomVerticalFlip(p=0.3),
                                tsfm.RandomHorizontalFlip(p=0.3),
                                tsfm.ToTensor(),
                                tsfm.Normalize(DATASET_IMAGE_MEAN, DATASET_IMAGE_STD), ])

valid_transform = tsfm.Compose([tsfm.Resize((384,384)),
                                tsfm.ToTensor(),
                                tsfm.Normalize(DATASET_IMAGE_MEAN, DATASET_IMAGE_STD), ])

test_transform = tsfm.Compose([tsfm.Resize((384,384)),
                                tsfm.ToTensor(),
                                tsfm.Normalize(DATASET_IMAGE_MEAN, DATASET_IMAGE_STD), ])

In [None]:
"""
Define dataset class
"""
class Dataset(Dataset):
    def __init__(self, cfg, image: list, labels: list, transform=None):

        self.transform = transform
        self.image = image
        self.labels = labels

    def __len__(self):
        return len(self.image)

    def __getitem__(self, idx):

        img_path = self.image[idx]
        img = Image.open(img_path).convert('RGB')
        img_ts = self.transform(img)
        label_ts = self.labels[idx]

        return img_ts, label_ts

##   Model

CASS combines CNNs and Transformers without any change to their architectures, helping both of them learn better representations. The architecture of CASS is shown below:
\
![picture](https://drive.google.com/uc?/export=view&id=1NSE7agnZrrTYpu-aNAzrhJ8SYIyM-oaN)

In the figure, R denotes ResNet50 (the CNN) and T denotes the Transformer used (ViT, the Vision Transformer); X is the input image, which becomes X’ after augmentations are applied. The output logits from the CNN and the Transformer are then used to compute a cosine similarity loss. CASS applies only one set of augmentations to create X’, and X’ is passed through both arms to compute the loss.
![picture](https://drive.google.com/uc?/export=view&id=16U9D1v0kO0bixk3y07RXAWluq7oppI-A)

R and T represent the embeddings from the CNN and the Transformer, respectively. Both architectures use the same optimizer and learning rate schedule settings. This experiment used stochastic weight averaging (SWA) with the Adam optimizer and a learning rate of 1e-3. The learning rate followed a cosine schedule with a maximum of 16 iterations (`T_max=16`) and a minimum value of 1e-6.
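Writing $r$ and $t$ for the CNN and Transformer outputs for the same augmented image $X'$, the loss computed by `loss_fn` normalizes both vectors and takes 2 minus twice their dot product, which equals the squared Euclidean distance between the unit vectors:

$$
\mathcal{L}(r, t) = 2 - 2\,\frac{r^\top t}{\lVert r\rVert_2\,\lVert t\rVert_2}
= \left\lVert \frac{r}{\lVert r\rVert_2} - \frac{t}{\lVert t\rVert_2} \right\rVert_2^2
$$

This holds because, for unit vectors $\hat r$ and $\hat t$, $\lVert \hat r - \hat t\rVert_2^2 = \lVert\hat r\rVert_2^2 + \lVert\hat t\rVert_2^2 - 2\,\hat r^\top \hat t = 2 - 2\,\hat r^\top\hat t$. The loss therefore ranges from 0 (outputs pointing in the same direction) to 4 (opposite directions), and the training loop averages it over the batch.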

**CNN and Transformer:** The CNN and Transformer are pretrained models from the timm library. Details are available at https://github.com/huggingface/pytorch-image-models

In [None]:
model_cnn = timm.create_model(cfg.cnn_name, pretrained=True)
model_vit = timm.create_model(cfg.vit_name, pretrained=True)
model_cnn.to(device)
model_vit.to(device)

In [None]:
# Define the self-supervised training loop with both the CNN and the Transformer (ViT).
# criterion_* are kept for interface compatibility but unused: the SSL objective is loss_fn.
def ssl_train_model(train_loader,model_vit,criterion_vit,optimizer_vit,scheduler_vit,model_cnn,criterion_cnn,optimizer_cnn,scheduler_cnn,num_epochs):
    writer = SummaryWriter()
    model_cnn.train()
    model_vit.train()
    for i in tqdm(range(num_epochs)):
        for img, _ in train_loader:  # labels are ignored during self-supervision
            img = img.to(device)
            # clear gradients from the previous step; otherwise they accumulate across batches
            model_cnn.zero_grad()
            model_vit.zero_grad()
            # the same augmented batch X' goes through both arms
            pred_vit = model_vit(img)
            pred_cnn = model_cnn(img)
            loss = loss_fn(pred_vit, pred_cnn).mean()
            loss.backward()
            optimizer_cnn.step()
            optimizer_vit.step()
            scheduler_cnn.step()
            scheduler_vit.step()
        print('For -', i, 'Loss:', loss.item())
        writer.add_scalar("Self-Supervised Loss/train", loss.item(), i)
    writer.flush()

In [None]:
# Loss between the L2-normalized outputs of the two arms: 2 - 2 * cosine similarity, per sample
def loss_fn(x, y):
    x =  torch.nn.functional.normalize(x, dim=-1, p=2)
    y =  torch.nn.functional.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)
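A quick numerical check of `loss_fn`, with random tensors standing in for real model outputs. Both timm backbones, created with their default heads, return 1000 ImageNet logits, so the two arms' outputs have matching shapes. The check confirms that the loss equals 2 minus twice the cosine similarity, and also the squared distance between the normalized outputs.

```python
import torch
import torch.nn.functional as F

def loss_fn(x, y):
    # same computation as the notebook's loss_fn: 2 - 2 * cosine similarity per sample
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)

torch.manual_seed(0)
# Random stand-ins for the 1000-way logits of resnet50 and vit_base_patch16_384
pred_cnn = torch.randn(16, 1000)
pred_vit = torch.randn(16, 1000)

loss = loss_fn(pred_vit, pred_cnn)
cos = F.cosine_similarity(pred_vit, pred_cnn, dim=-1)
sq_dist = (F.normalize(pred_vit, dim=-1) - F.normalize(pred_cnn, dim=-1)).pow(2).sum(dim=-1)
print(loss.shape, loss.mean().item())
```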

**Focal Loss:** Since medical datasets commonly have class imbalance, the authors addressed it by applying class-distribution-normalized Focal Loss, as described by [Lin et al. (2017)](https://arxiv.org/pdf/1708.02002).
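As implemented in the `FocalLoss` cell, each class logit $z_{ic}$ goes through a sigmoid (one binary problem per class, not a softmax). With one-hot targets $y_{ic}\in\{0,1\}$, class weights $w_c$ (`CFG.cls_weight`), $\alpha$ = `fl_alpha` and $\gamma$ = `fl_gamma`, the loss over a batch of $N$ samples and $C$ classes is (ignoring the small $\epsilon$ added for numerical stability):

$$
p_{ic} = \sigma(z_{ic}), \qquad
p_{t,ic} = y_{ic}\,p_{ic} + (1 - y_{ic})(1 - p_{ic}),
$$

$$
\mathcal{L}_{\text{focal}} = \frac{1}{NC} \sum_{i=1}^{N} \sum_{c=1}^{C} -\,w_c\,\alpha\,(1 - p_{t,ic})^{\gamma} \log p_{t,ic}
$$

With $\gamma = 0$, $\alpha = 1$ and all $w_c = 1$, this reduces to the per-class binary cross-entropy; $\gamma > 0$ down-weights well-classified entries, where $p_{t,ic}$ is close to 1.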

In [None]:
"""
Define Focal-Loss
"""

class FocalLoss(nn.Module):
    """
    The focal loss for fighting against class-imbalance
    """
    def __init__(self, alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = 1e-12  # prevent training from Nan-loss error
        self.cls_weights = torch.tensor([CFG.cls_weight],dtype=torch.float, requires_grad=False, device=CFG.device)

    def forward(self, logits, target):
        """
        logits & target should be tensors with shape [batch_size, num_classes]
        """
        probs = torch.sigmoid(logits)
        one_subtract_probs = 1.0 - probs
        # add epsilon
        probs_new = probs + self.epsilon
        one_subtract_probs_new = one_subtract_probs + self.epsilon
        # calculate focal loss

        log_pt = target * torch.log(probs_new) + (1.0 - target) * torch.log(one_subtract_probs_new)
        pt = torch.exp(log_pt)
        focal_loss = -1.0 * (self.alpha * (1 - pt) ** self.gamma) * log_pt
        focal_loss = focal_loss * self.cls_weights
        return torch.mean(focal_loss)
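A standalone sketch that mirrors `FocalLoss.forward` as a plain function (`focal_loss` is a hypothetical name; it takes the class weights as an argument instead of reading `CFG`). It checks two properties: with $\gamma = 0$ the loss matches PyTorch's `binary_cross_entropy_with_logits`, and with $\gamma = 2$ every term is scaled down by $(1 - p_t)^2$, so the loss is smaller.

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, target, alpha=1.0, gamma=2.0, cls_weights=None, eps=1e-12):
    """Plain-function copy of FocalLoss.forward: per-class sigmoid, mean over batch and classes."""
    probs = torch.sigmoid(logits)
    log_pt = target * torch.log(probs + eps) + (1.0 - target) * torch.log(1.0 - probs + eps)
    pt = torch.exp(log_pt)
    loss = -alpha * (1.0 - pt) ** gamma * log_pt
    if cls_weights is not None:
        loss = loss * cls_weights  # [1, C] weights broadcast over the batch
    return loss.mean()

torch.manual_seed(0)
logits = torch.randn(8, 4)
target = F.one_hot(torch.randint(0, 4, (8,)), num_classes=4).float()

bce = F.binary_cross_entropy_with_logits(logits, target)
fl_gamma0 = focal_loss(logits, target, gamma=0.0)
fl_gamma2 = focal_loss(logits, target, gamma=2.0)
print(bce.item(), fl_gamma0.item(), fl_gamma2.item())
```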

In [None]:
"""
Define F1 score metric
"""
class MyF1Score(Metric):
    def __init__(self, cfg, threshold: float = 0.5, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.cfg = cfg
        self.threshold = threshold
        self.add_state("tp", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("fp", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("fn", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        assert preds.shape == target.shape
        preds_str_batch = self.num_to_str(torch.sigmoid(preds))
        target_str_batch = self.num_to_str(target)
        tp, fp, fn = 0, 0, 0
        for pred_str_list, target_str_list in zip(preds_str_batch, target_str_batch):
            for pred_str in pred_str_list:
                if pred_str in target_str_list:
                    tp += 1
                if pred_str not in target_str_list:
                    fp += 1

            for target_str in target_str_list:
                if target_str not in pred_str_list:
                    fn += 1
        self.tp += tp
        self.fp += fp
        self.fn += fn

    def compute(self):
        # To report recall instead of F1, use the commented-out line below.
        f1 = 2.0 * self.tp / (2.0 * self.tp + self.fn + self.fp)
        #rec = self.tp/(self.tp + self.fn)
        return f1

    def num_to_str(self, ts: torch.Tensor) -> list:
        batch_bool_list = (ts > self.threshold).detach().cpu().numpy().tolist()
        batch_str_list = []
        for one_sample_bool in batch_bool_list:
            lb_str_list = [self.cfg.label_num2str[lb_idx] for lb_idx, bool_val in enumerate(one_sample_bool) if bool_val]
            batch_str_list.append(lb_str_list)
        return batch_str_list
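`MyF1Score` thresholds `sigmoid(preds)` and the targets at 0.5, then counts TP, FP and FN over every (sample, class) pair, which is a micro-averaged F1. A vectorized sketch of the same counting with boolean tensors (`micro_f1` is a hypothetical helper; unlike `compute()`, it returns 0.0 instead of NaN when there are no predicted or true positives at all):

```python
import torch

def micro_f1(logits, target, threshold=0.5):
    """Micro-averaged F1 over all (sample, class) cells, thresholding like MyF1Score."""
    pred = torch.sigmoid(logits) > threshold
    true = target > threshold
    tp = (pred & true).sum().item()
    fp = (pred & ~true).sum().item()
    fn = (~pred & true).sum().item()
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

logits = torch.tensor([[ 3.0, -3.0, -3.0, -3.0],   # predicts glioma (correct)
                       [-3.0,  2.0,  2.0, -3.0],   # predicts meningioma (correct) and notumor (wrong)
                       [-3.0, -3.0, -3.0, -3.0]])  # predicts nothing (misses pituitary)
target = torch.tensor([[1.0, 0, 0, 0], [0, 1.0, 0, 0], [0, 0, 0, 1.0]])
print(micro_f1(logits, target))  # TP=2, FP=1, FN=1 -> 4/6
```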

**Optimizer:** Stochastic weight averaging (SWA) with the Adam optimizer and a learning rate of 1e-3 was used. More details about SWA are in [Stochastic Weight Averaging](https://arxiv.org/abs/1803.05407).

In [None]:
optimizer_cnn = SWA(torch.optim.Adam(model_cnn.parameters(), lr= 1e-3))
optimizer_vit = SWA(torch.optim.Adam(model_vit.parameters(), lr= 1e-3))
scheduler_cnn = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_cnn,
                                                                    T_max=16,
                                                                    eta_min=1e-6)
scheduler_vit = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_vit,
                                                                    T_max=16,
                                                                    eta_min=1e-6)

criterion_vit = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
criterion_cnn = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
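To our understanding of torchcontrib's documentation, wrapping an optimizer as `SWA(base_opt)` without `swa_start`, `swa_freq` and `swa_lr` selects its manual mode, in which weights are only averaged when `update_swa()` is called and only copied into the model by `swap_swa_sgd()`. Neither call appears in this notebook, so the averaging may not actually take place. One alternative is PyTorch's built-in `torch.optim.swa_utils.AveragedModel`; this sketch on a toy linear layer (not the CASS backbones) shows the equal-weight running average it keeps.

```python
import torch
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel

model = nn.Linear(4, 2, bias=False)
swa_model = AveragedModel(model)  # keeps a running average of model's parameters

# Pretend the weights after three "epochs" were all 1s, then all 3s, then all 5s
for value in (1.0, 3.0, 5.0):
    with torch.no_grad():
        model.weight.fill_(value)
    swa_model.update_parameters(model)  # equal-weight running mean of the snapshots

print(swa_model.module.weight)  # every entry is the mean, 3.0
```

In a real run one would call `swa_model.update_parameters(model_cnn)` once per epoch after a warm-up period and, because ResNet50 contains BatchNorm layers, refresh their statistics with `torch.optim.swa_utils.update_bn(train_loader, swa_model, device=device)` before evaluating the averaged model.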

##   Training

**Computational Requirement:**

Original Paper: The authors of the original paper ran all experiments on a single NVIDIA RTX8000 GPU with 48 GB of video memory, 2 CPU cores, and 64 GB of system RAM, using their internal clusters.

What we used: We used a Google Cloud Platform (GCP) VM with a single NVIDIA T4 GPU, 16 CPU cores, and 60 GB of RAM (3.75 GB per core). With the free credits, we could not access multiple GPUs or high-performance GPU virtual machines on GCP, so our performance and runtime numbers will differ from those of the original paper.


**Reduced Label Learning:** Training on the full set of labels was computationally expensive, so in this draft notebook we use a reduced set of labels.

We also performed training with 100% of the labels, and that trained model file was downloaded here as well. Change the value of `x` in the code below to change the percentage of labels used for training.


**Runtime:**

|Labels used|Epochs|Runtime|
| --- | --- | --- |
|100%|100|22 h 12 min|
|10%|100|2 h 17 min|
|10%|20|23 min 24 s|
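Dividing the runtimes in the table by their number of epochs makes them easier to compare. Per epoch, the 100% run is roughly ten times slower than the 10% runs, in line with its roughly ten times larger training set. The numbers below come only from the table; `to_seconds` is a small illustrative helper.

```python
def to_seconds(h=0, m=0, s=0):
    return 3600 * h + 60 * m + s

runs = [  # (labels used, epochs, total runtime in seconds), taken from the table above
    ('100%', 100, to_seconds(h=22, m=12)),
    ('10%', 100, to_seconds(h=2, m=17)),
    ('10%', 20, to_seconds(m=23, s=24)),
]
per_epoch = {(frac, ep): total / ep for frac, ep, total in runs}
for (frac, ep), sec in per_epoch.items():
    print(f'{frac:>4} labels, {ep:>3} epochs: {sec:.1f} s/epoch')
```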



In [None]:
import random
random.seed(77)
x=0.1 #currently set to use 10% of the labels for reduced label training
onep=random.sample(range(0, len(train_image_list)), int(len(train_image_list)*x))
all_train_image_list = [train_image_list[idx] for idx in onep]
all_train_label_list = [train_label_list[idx] for idx in onep]

In [None]:
train_dataset = Dataset(CFG, all_train_image_list, all_train_label_list, train_transform)
valid_dataset = Dataset(CFG, valid_image_list, all_valid_label_list, valid_transform)
train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=CFG.num_workers)
valid_loader = DataLoader(valid_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers)

In [None]:
#Train SSL
print('Training CASS-1Epoch for Draft')

# The draft runs only 1 epoch to show that the code works; on GCP we trained the model for 20 and 100 epochs
ssl_train_model(train_loader,model_vit,criterion_vit,optimizer_vit,scheduler_vit,model_cnn,criterion_cnn,optimizer_cnn,scheduler_cnn,num_epochs=1)

# Ran this in GCP with 20 epochs and 100 epochs
#ssl_train_model(train_loader,model_vit,criterion_vit,optimizer_vit,scheduler_vit,model_cnn,criterion_cnn,optimizer_cnn,scheduler_cnn,num_epochs=100)

#Saving SSL Models
print('Saving CASS-1Epoch for Draft')
torch.save(model_cnn,'/content/drive/My Drive/DLH_CASS/BrainMRI/cass-mri-draft.pt') # save the model in your work directory
torch.save(model_vit,'/content/drive/My Drive/DLH_CASS/BrainMRI/cass-mri-vit-draft.pt') # save the model in your work directory

In [None]:
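# Load the models from your work directory.
# weights_only=False is needed to unpickle whole model objects (not state_dicts) on PyTorch 2.6+
model_cnn=torch.load('/content/drive/My Drive/DLH_CASS/BrainMRI/cass-mri-draft.pt', weights_only=False)
model_vit=torch.load('/content/drive/My Drive/DLH_CASS/BrainMRI/cass-mri-vit-draft.pt', weights_only=False)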

**Supervised fine-tuning:**

**Hyperparams:** For supervised fine-tuning, we used the Adam optimizer with a cosine-annealed learning rate starting at 3e-4. Since medical datasets commonly have some class imbalance, class-distribution-normalized Focal Loss was applied to handle it. We fine-tuned for 50 epochs.



In [None]:
print('Fine-tuning CASS-CNN-T')
# Replace the 1000-way ImageNet head with a 4-class head (ResNet50 features are 2048-d)
model_cnn.fc = nn.Linear(in_features=2048, out_features=4, bias=True)
model_cnn.to(device)
criterion = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
metric = MyF1Score(cfg)
val_metric = MyF1Score(cfg)
optimizer = torch.optim.Adam(model_cnn.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.t_max, eta_min=cfg.min_lr)
best = 0
best_val = 0
last_loss = math.inf
counter = 0  # consecutive validation checks with increasing loss
writer = SummaryWriter()
#for epoch in range(50):
for epoch in range(2):  # set to 2 epochs for the draft run
    model_cnn.train()
    metric.reset()  # F1 is accumulated per epoch
    for images, label in train_loader:
        images = images.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        pred_ts = model_cnn(images)
        loss = criterion(pred_ts, label)
        metric.update(pred_ts.detach(), label)
        loss.backward()
        optimizer.step()
        scheduler.step()
    train_score = metric.compute()
    logs = {'train_loss': loss.item(), 'f1': train_score, 'lr': optimizer.param_groups[0]['lr']}
    writer.add_scalar("Supervised-CNN Loss/train", loss.item(), epoch)
    writer.add_scalar("Supervised-CNN F1/train", train_score, epoch)
    for name, weight in model_cnn.named_parameters():
        writer.add_histogram(name, weight, epoch)
        if weight.grad is not None:
            writer.add_histogram(f'{name}.grad', weight.grad, epoch)
    print(logs)
    if best < train_score:
        with torch.no_grad():
            best = train_score
            model_cnn.eval()
            val_metric.reset()
            total_loss = 0
            for images, label in valid_loader:
                images = images.to(device)
                label = label.to(device)
                pred_ts = model_cnn(images)
                val_metric.update(pred_ts, label)
                total_loss += criterion(pred_ts, label).detach()
            avg_loss = total_loss / len(valid_loader)  # mean over validation batches
            print('Val Loss:', avg_loss)
            val_score = val_metric.compute()
            print('CNN Validation Score:', val_score)
            writer.add_scalar("CNN Supervised F1/Validation", val_score, epoch)
            counter = counter + 1 if avg_loss > last_loss else 0
            last_loss = avg_loss
            if counter > 5:
                print('Early Stopping!')
                break
            if val_score > best_val:
                best_val = val_score
                print('Saving')
                torch.save(model_cnn,'/content/drive/My Drive/DLH_CASS/BrainMRI/CASS-Draft-CNN-part-ft.pt') #save in your work directory
writer.flush()

In [None]:
print('Fine-tuning CASS-ViT')
# Replace the 1000-way ImageNet head with a 4-class head (ViT-Base embeddings are 768-d)
model_vit.head = nn.Linear(in_features=768, out_features=4, bias=True)
model_vit.to(device)
criterion = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
metric = MyF1Score(cfg)
val_metric = MyF1Score(cfg)
optimizer = torch.optim.Adam(model_vit.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.t_max, eta_min=cfg.min_lr)
writer = SummaryWriter()
best = 0
best_val = 0
last_loss = math.inf
counter = 0  # consecutive validation checks with increasing loss
#for epoch in range(50):
for epoch in range(2):  # set to 2 epochs for the draft run
    model_vit.train()
    metric.reset()  # F1 is accumulated per epoch
    for images, label in train_loader:
        images = images.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        pred_ts = model_vit(images)
        loss = criterion(pred_ts, label)
        metric.update(pred_ts.detach(), label)
        loss.backward()
        optimizer.step()
        scheduler.step()
    train_score = metric.compute()
    logs = {'train_loss': loss.item(), 'f1': train_score, 'lr': optimizer.param_groups[0]['lr']}
    writer.add_scalar("Supervised-ViT Loss/train", loss.item(), epoch)
    writer.add_scalar("Supervised-ViT F1/train", train_score, epoch)
    for name, weight in model_vit.named_parameters():
        writer.add_histogram(name, weight, epoch)
        if weight.grad is not None:
            writer.add_histogram(f'{name}.grad', weight.grad, epoch)
    print(logs)
    if best < train_score:
        with torch.no_grad():
            best = train_score
            model_vit.eval()
            val_metric.reset()
            total_loss = 0
            for images, label in valid_loader:
                images = images.to(device)
                label = label.to(device)
                pred_ts = model_vit(images)
                val_metric.update(pred_ts, label)
                total_loss += criterion(pred_ts, label).detach()
            avg_loss = total_loss / len(valid_loader)  # mean over validation batches
            print('Val Loss:', avg_loss)
            val_score = val_metric.compute()
            print('ViT Validation Score:', val_score)
            writer.add_scalar("ViT Supervised F1/Validation", val_score, epoch)
            counter = counter + 1 if avg_loss > last_loss else 0
            last_loss = avg_loss
            if counter > 5:
                print('Early Stopping!')
                break
            if val_score > best_val:
                best_val = val_score
                print('Saving')
                torch.save(model_vit,
                    '/content/drive/My Drive/DLH_CASS/BrainMRI/CASS-Draft-ViT-part-ft.pt') # save the fine-tuned ViT in your work directory
writer.flush()

## Evaluation

**Metrics Description:**

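We used the F1 score as our metric. `MyF1Score` sums true positives (TP), false positives (FP), and false negatives (FN) over all samples and classes (micro-averaging) and computes

$$F_1 = \frac{2 \cdot \text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}} = \frac{2\,TP}{2\,TP + FP + FN}$$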
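A worked example with made-up counts (TP = 6, FP = 2, FN = 4; not results from this notebook) shows that the precision/recall form and the count form of F1 agree. Exact fractions avoid floating-point rounding.

```python
from fractions import Fraction

tp, fp, fn = 6, 2, 4  # hypothetical counts
precision = Fraction(tp, tp + fp)   # 3/4
recall = Fraction(tp, tp + fn)      # 3/5
f1_from_pr = 2 * precision * recall / (precision + recall)
f1_from_counts = Fraction(2 * tp, 2 * tp + fp + fn)
print(f1_from_pr, f1_from_counts)  # 2/3 2/3
```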

In [None]:
test_dataset = Dataset(CFG, test_image_list, all_test_label_list, test_transform)
# No shuffling and no dropped batches at test time, so every test image is scored exactly once
test_loader = DataLoader(test_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers, drop_last=False)

In [None]:
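# The checkpoint is a full pickled model, so weights_only=False is needed on PyTorch >= 2.6
model = torch.load('/content/drive/My Drive/DLH_CASS/BrainMRI/CASS-Draft-CNN-part-ft.pt', weights_only=False)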

In [None]:
criterion = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
metric = MyF1Score(cfg)
model.to(device)  # move the model once instead of on every batch
model.eval()

total_loss = 0.0
with torch.no_grad():
    for images, label in test_loader:
        images = images.to(device)
        label = label.to(device)
        pred_ts = model(images)
        total_loss += criterion(pred_ts, label).item()
        metric(pred_ts, label)  # accumulate F1 statistics across batches
test_score = metric.compute()
logs = {'test_loss': total_loss / len(test_loader), 'f1': test_score}
print(logs)

In [None]:
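# The checkpoint is a full pickled model, so weights_only=False is needed on PyTorch >= 2.6
model = torch.load('/content/drive/My Drive/DLH_CASS/BrainMRI/CASS-Draft-ViT-part-ft.pt', weights_only=False)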

In [None]:
criterion = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
metric = MyF1Score(cfg)
model.to(device)  # move the model once instead of on every batch
model.eval()

total_loss = 0.0
with torch.no_grad():
    for images, label in test_loader:
        images = images.to(device)
        label = label.to(device)
        pred_ts = model(images)
        total_loss += criterion(pred_ts, label).item()
        metric(pred_ts, label)  # accumulate F1 statistics across batches
test_score = metric.compute()
logs = {'test_loss': total_loss / len(test_loader), 'f1': test_score}
print(logs)

# Results

The original paper reports that CASS is more robust to changes in batch size and pretraining epochs than existing self-supervised methods. The authors reported results with 1%, 10%, and 100% labeled data and 100 training epochs, and compared them with other state-of-the-art models to show the efficiency of CASS.

We trained with 10% of the labeled data in two configurations: 20 epochs of self-supervised pretraining followed by 10 epochs of fine-tuning, and 100 epochs of self-supervised pretraining followed by 50 epochs of fine-tuning. We load those models below and compute their test F1 scores.


**Result Metrics:**

Metrics for batch size 16, 10% labeled data, and 100 epochs on the BrainMRI classification dataset:
![picture](https://drive.google.com/uc?/export=view&id=1wGB1VrTJOk1yjJ-WvNMmVbCA4t6IxVLV)

In [None]:
# CASS CNN checkpoint: 10% labels, 20 pretraining epochs, 10 fine-tuning epochs (test F1 score: 0.5440)
!gdown '1sQMmCZr5wzbxG0QrtiIhJPoTW_GnrdtY' --output CASSCNN10R20E10ft.pt

# The checkpoint is a full pickled model, so weights_only=False is needed on PyTorch >= 2.6
model = torch.load('/content/CASSCNN10R20E10ft.pt', weights_only=False)
criterion = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
metric = MyF1Score(cfg)
model.to(device)  # move the model once instead of on every batch
model.eval()

total_loss = 0.0
with torch.no_grad():
    for images, label in test_loader:
        images = images.to(device)
        label = label.to(device)
        pred_ts = model(images)
        total_loss += criterion(pred_ts, label).item()
        metric(pred_ts, label)  # accumulate F1 statistics across batches
test_score = metric.compute()
logs = {'test_loss': total_loss / len(test_loader), 'f1': test_score}
print(logs)

In [None]:
# CASS CNN checkpoint: 10% labels, 100 pretraining epochs, 50 fine-tuning epochs (test F1 score: 0.6920)
!gdown '1ywhZ2FhKW3IGsgUdFuIvOOHHOsOIrCm2' --output CASSCNN10R100E50ft.pt

# The checkpoint is a full pickled model, so weights_only=False is needed on PyTorch >= 2.6
model = torch.load('/content/CASSCNN10R100E50ft.pt', weights_only=False)
criterion = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
metric = MyF1Score(cfg)
model.to(device)  # move the model once instead of on every batch
model.eval()

total_loss = 0.0
with torch.no_grad():
    for images, label in test_loader:
        images = images.to(device)
        label = label.to(device)
        pred_ts = model(images)
        total_loss += criterion(pred_ts, label).item()
        metric(pred_ts, label)  # accumulate F1 statistics across batches
test_score = metric.compute()
logs = {'test_loss': total_loss / len(test_loader), 'f1': test_score}
print(logs)



---


We also trained the model on 100% of the data for 100 epochs and observed a vanishing gradient problem.

Training on 100% of the data for 100 epochs took 22 hours on a single NVIDIA T4 GPU (3.5 GB memory) with a 16-core CPU and 60 GB of RAM.

---



## Model comparison

### Runtime Comparison:

The table below compares self-supervised pretraining time for DINO ([Caron et al. (2021)](https://arxiv.org/abs/2104.14294)) and the original CASS implementation, as reported by the authors for 100 epochs on a single RTX8000 GPU, with our reproduced CASS run on a single NVIDIA T4 GPU on GCP.

| Dataset | DINO | CASS (original) | CASS (reproduced) |
| --- | --- | --- | --- |
| BrainMRI | 26 h 21 m | 7 h 11 m | 22 h 12 m |

Even with our limited computational resources, our reproduced CASS pretraining time (22 h 12 m) was shorter than the DINO pretraining time reported by the original authors (26 h 21 m), although the two were measured on different GPUs.
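The ratios behind this comparison can be checked with a few lines of arithmetic on the times in the table. This is only a sketch: because the runs used different GPUs (RTX8000 vs. T4), the ratios are not a like-for-like measure of algorithmic speed. `to_minutes` is a hypothetical helper.

```python
def to_minutes(hours: int, minutes: int) -> int:
    """Convert an 'H h M m' duration into total minutes."""
    return hours * 60 + minutes


# BrainMRI pretraining times (100 epochs) from the table above
dino = to_minutes(26, 21)             # DINO, as reported by the original authors
cass_original = to_minutes(7, 11)     # CASS, as reported by the original authors
cass_reproduced = to_minutes(22, 12)  # CASS, our reproduction on a single T4

print(f"DINO / CASS (reproduced):            {dino / cass_reproduced:.2f}x")
print(f"CASS (reproduced) / CASS (original): {cass_reproduced / cass_original:.2f}x")
```

This prints 1.19x and 3.09x: our T4 run took roughly three times as long as the authors' CASS run, while the authors' DINO time was still about 1.19 times our CASS time.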

### Metrics Comparison:

The original paper compares CASS with other state-of-the-art techniques such as DINO and transfer learning. To obtain these results, the authors pretrained and fine-tuned the model for 100 epochs.

![picture](https://drive.google.com/uc?/export=view&id=1mDix1rSuNpTFtRo_2ekmM0MAQEp1Sfps)


### Our Test Result:

| Technique | Backbone | F1 Score (1% labels) | F1 Score (10% labels) |
| --- | --- | --- | --- |
| CASS | ResNet-50 | 0.3968 | 0.6920 |
| CASS | ViT-B/16 | 0.2318 | 0.6164 |

We could not reproduce the authors' exact scores because of our computational limits, but with limited resources we still obtained scores that are comparable to those of the other models.

## Ablations


### Change in Epochs:


We studied the effect of longer pretraining on the BrainMRI dataset by varying the number of epochs and observing how performance changes.

Performance comparison across different numbers of epochs with 10% labeled data:

|Epoch| F1-Score|
|---|---|
|20|0.5440|
|50|0.6426|
|100|0.6920|

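The relative improvements in this table can be computed as below. Note that the runs also differ in fine-tuning length (the 20-epoch model was fine-tuned for 10 epochs and the 100-epoch model for 50), so the gains should not be attributed to pretraining length alone. `relative_gains` is a hypothetical helper.

```python
def relative_gains(scores: dict, baseline_key) -> dict:
    """Percentage change of each score relative to scores[baseline_key]."""
    base = scores[baseline_key]
    return {key: (value - base) / base * 100 for key, value in sorted(scores.items())}


# F1 scores for 10% labeled data, keyed by number of epochs (from the table above)
f1_by_epochs = {20: 0.5440, 50: 0.6426, 100: 0.6920}

for epochs, gain in relative_gains(f1_by_epochs, baseline_key=20).items():
    print(f"{epochs:>3} epochs: F1 = {f1_by_epochs[epochs]:.4f} ({gain:+.1f}% vs. 20 epochs)")
```

Going from 20 to 50 epochs adds 0.0986 F1, while going from 50 to 100 epochs adds only 0.0494, which suggests diminishing returns from longer training on this dataset.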


In [None]:
# CASS CNN checkpoint: 10% labels, 100 pretraining epochs, 50 fine-tuning epochs (test F1 score: 0.6920)
!gdown '1ywhZ2FhKW3IGsgUdFuIvOOHHOsOIrCm2' --output CASSCNN10R100E50ft.pt

# The checkpoint is a full pickled model, so weights_only=False is needed on PyTorch >= 2.6
model = torch.load('/content/CASSCNN10R100E50ft.pt', weights_only=False)
criterion = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
metric = MyF1Score(cfg)
model.to(device)  # move the model once instead of on every batch
model.eval()

total_loss = 0.0
with torch.no_grad():
    for images, label in test_loader:
        images = images.to(device)
        label = label.to(device)
        pred_ts = model(images)
        total_loss += criterion(pred_ts, label).item()
        metric(pred_ts, label)  # accumulate F1 statistics across batches
test_score = metric.compute()
logs = {'test_loss': total_loss / len(test_loader), 'f1': test_score}
print(logs)

In [None]:
# CASS CNN checkpoint: 10% labels, 20 pretraining epochs, 10 fine-tuning epochs (test F1 score: 0.5440)
!gdown '1sQMmCZr5wzbxG0QrtiIhJPoTW_GnrdtY' --output CASSCNN10R20E10ft.pt

# The checkpoint is a full pickled model, so weights_only=False is needed on PyTorch >= 2.6
model = torch.load('/content/CASSCNN10R20E10ft.pt', weights_only=False)
criterion = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
metric = MyF1Score(cfg)
model.to(device)  # move the model once instead of on every batch
model.eval()

total_loss = 0.0
with torch.no_grad():
    for images, label in test_loader:
        images = images.to(device)
        label = label.to(device)
        pred_ts = model(images)
        total_loss += criterion(pred_ts, label).item()
        metric(pred_ts, label)  # accumulate F1 statistics across batches
test_score = metric.compute()
logs = {'test_loss': total_loss / len(test_loader), 'f1': test_score}
print(logs)

### Change in Batch Size:

We trained the model with batch sizes of 16, 8, and 4; the results below suggest that the model is robust to changes in batch size.

![picture](https://drive.google.com/uc?/export=view&id=1hdikyfvPxAahWEPgwJaeugw6ntWV4DBD)
![picture](https://drive.google.com/uc?/export=view&id=1v4XENTuyg8rbvAQBKvK1nFaiMLeLmrFW)
![picture](https://drive.google.com/uc?/export=view&id=1PNSbofP-jy_UMEE90oEzUlrNauuWfGxU)




In [None]:
# CASS CNN checkpoint: 10% labels, 100 pretraining epochs, 50 fine-tuning epochs, batch size 16
!gdown '1ywhZ2FhKW3IGsgUdFuIvOOHHOsOIrCm2' --output CASSCNN10R100E50ft.pt

# The checkpoint is a full pickled model, so weights_only=False is needed on PyTorch >= 2.6
model = torch.load('/content/CASSCNN10R100E50ft.pt', weights_only=False)
criterion = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
metric = MyF1Score(cfg)
model.to(device)  # move the model once instead of on every batch
model.eval()

total_loss = 0.0
with torch.no_grad():
    for images, label in test_loader:
        images = images.to(device)
        label = label.to(device)
        pred_ts = model(images)
        total_loss += criterion(pred_ts, label).item()
        metric(pred_ts, label)  # accumulate F1 statistics across batches
test_score = metric.compute()
logs = {'test_loss': total_loss / len(test_loader), 'f1': test_score}
print(logs)

In [None]:
# CASS CNN checkpoint: 10% labels, 100 pretraining epochs, 50 fine-tuning epochs, batch size 8
!gdown '17e_fF0nQc2YT0sJh-UFjHHI-ehENreLN' --output CASSCNN8BAT10R100E50ft.pt

# The checkpoint is a full pickled model, so weights_only=False is needed on PyTorch >= 2.6
model = torch.load('/content/CASSCNN8BAT10R100E50ft.pt', weights_only=False)
criterion = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
metric = MyF1Score(cfg)
model.to(device)  # move the model once instead of on every batch
model.eval()

total_loss = 0.0
with torch.no_grad():
    for images, label in test_loader:
        images = images.to(device)
        label = label.to(device)
        pred_ts = model(images)
        total_loss += criterion(pred_ts, label).item()
        metric(pred_ts, label)  # accumulate F1 statistics across batches
test_score = metric.compute()
logs = {'test_loss': total_loss / len(test_loader), 'f1': test_score}
print(logs)

In [None]:
# CASS CNN checkpoint: 10% labels, 100 pretraining epochs, 50 fine-tuning epochs, batch size 4
!gdown '1K2hgb4bXkI7sppGru4hs3w5OQ8-2JBLy' --output CASSCNN4BAT10R100E50ft.pt

# The checkpoint is a full pickled model, so weights_only=False is needed on PyTorch >= 2.6
model = torch.load('/content/CASSCNN4BAT10R100E50ft.pt', weights_only=False)
criterion = FocalLoss(cfg.fl_alpha, cfg.fl_gamma)
metric = MyF1Score(cfg)
model.to(device)  # move the model once instead of on every batch
model.eval()

total_loss = 0.0
with torch.no_grad():
    for images, label in test_loader:
        images = images.to(device)
        label = label.to(device)
        pred_ts = model(images)
        total_loss += criterion(pred_ts, label).item()
        metric(pred_ts, label)  # accumulate F1 statistics across batches
test_score = metric.compute()
logs = {'test_loss': total_loss / len(test_loader), 'f1': test_score}
print(logs)

# Discussion

The code in the GitHub repository is reproducible with a few changes on our side: data preprocessing for our dataset and fixes for errors we encountered during execution. Due to hardware limitations, we could not achieve the self-supervised pretraining time reported by the authors in the paper. We also observed a vanishing gradient issue while training for 100 epochs.

![picture](https://drive.google.com/uc?/export=view&id=1o5JrUp7w-n8TxBWa8q28OU6hieuYY7u-)

For the ablations, we trained the model with different numbers of epochs and different batch sizes. These studies suggest that CASS is robust to such changes, as the authors report.

For others trying to reproduce this work: matching the authors' exact results requires comparable computational resources. However, the code can be run with fewer epochs and a small subset of the data on the GCP free trial or Google Colab Pro. The original code includes data processing steps only for MedMNIST and ISIC 2019; if you use another dataset, you need to write the preprocessing yourself, as we did for the BrainMRI classification dataset. There are also a few errors related to the optimizer and other logic in the end-to-end (e2e) fine-tuning code that need to be fixed; this notebook contains our corrected, runnable version.
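As a possible starting point for a new dataset, the sketch below builds the image-path and label lists that this notebook's `Dataset(CFG, image_list, label_list, transform)` class appears to take, assuming a folder-per-class layout (`root/<class_name>/<image>`). The Kaggle BrainMRI download is assumed to use this layout under its `Training/` and `Testing/` folders; check your copy. It also draws a per-class random fraction of the images, which is one way to build a reduced-label split such as 10%, not necessarily how the splits in this notebook were made. `build_image_label_lists`, `stratified_fraction`, and the example path are hypothetical.

```python
import random
from pathlib import Path

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def build_image_label_lists(root):
    """Collect image paths and integer labels from root/<class_name>/<image>.

    Class indices follow the sorted folder names, so they are stable across runs.
    """
    root = Path(root)
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    image_paths, labels = [], []
    for idx, name in enumerate(class_names):
        for path in sorted((root / name).iterdir()):
            if path.suffix.lower() in IMAGE_EXTENSIONS:  # skip non-image files
                image_paths.append(str(path))
                labels.append(idx)
    return image_paths, labels, class_names


def stratified_fraction(image_paths, labels, fraction, seed=0):
    """Randomly keep about `fraction` of the images of each class (at least one)."""
    rng = random.Random(seed)
    by_class = {}
    for path, label in zip(image_paths, labels):
        by_class.setdefault(label, []).append(path)
    kept_paths, kept_labels = [], []
    for label in sorted(by_class):
        paths = by_class[label]
        k = max(1, round(len(paths) * fraction))
        for path in sorted(rng.sample(paths, k)):
            kept_paths.append(path)
            kept_labels.append(label)
    return kept_paths, kept_labels


# Hypothetical usage (the path is an assumption):
# paths, labels, classes = build_image_label_lists('/content/brain-tumor-mri/Training')
# paths_10, labels_10 = stratified_fraction(paths, labels, fraction=0.10, seed=42)
```

Sampling within each class keeps the class proportions of the reduced split close to those of the full training set, which matters when some tumor classes are rarer than others.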





# References

1. Pranav Singh and Jacopo Cirrone, "Efficient Representation Learning for Healthcare with Cross-Architectural Self-Supervision," *Proceedings of Machine Learning Research* 219:1–36, 2023. https://proceedings.mlr.press/v219/singh23a/singh23a.pdf

2. Original source code: https://github.com/pranavsinghps1/CASS

3. Kaggle source data for BrainMRI: https://www.kaggle.com/datasets/masoudnickparvar/brain-tumor-mri-dataset

4. Kaggle source data for ISIC 2019: https://www.kaggle.com/datasets/andrewmvd/isic-2019

