In [None]:
%load_ext autoreload
%autoreload 2

#### Data

MedMNIST2D

In [None]:
# Install the MedMNIST package (uncomment to run). Prefer `%pip` over `!pip` so the package
# is installed into the environment of the running kernel.
# %pip install medmnist

Libraries

In [None]:
from tqdm import tqdm
import warnings
warnings.simplefilter("ignore")

import os
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

import torch
import torchvision.transforms as transforms
from utils import load_data_and_data_loaders
from transformers import CLIPProcessor, CLIPModel


import medmnist
from medmnist import INFO

Configuration

In [None]:
data_flag = 'octmnist'
download = True

BATCH_SIZE = 128
NUM_EPOCHS = 50

info = INFO[data_flag]
label_dict = info["label"]
task = info['task']
n_channels = info['n_channels']
n_classes = len(label_dict)

DataClass = getattr(medmnist, info['python_class'])

# Text-prompt template per dataset; datasets without an entry fall back to the bare label name.
_PROMPT_TEMPLATES = {
    "pathmnist": "a {} tissue sample",
    "chestmnist": "an X-Ray image with {} disease",
    "octmnist": "a Retinal OCT image with {} disease",
}
_prompt_template = _PROMPT_TEMPLATES.get(data_flag, "{}")
# One prompt per class, in the iteration order of label_dict.
label_choices = [_prompt_template.format(label) for label in label_dict.values()]

In [None]:
# Preprocessing.
# NOTE(review): data_transform is not passed to load_data_and_data_loaders below — confirm it is still needed.
data_transform = transforms.Compose([transforms.ToTensor()])

# Datasets and their loaders for the train / validation / test splits (last return value unused).
(
    training_data,
    validation_data,
    test_data,
    training_loader,
    validation_loader,
    test_loader,
    _,
) = load_data_and_data_loaders("BLOCK", data_flag, BATCH_SIZE)

##### CLIP - Low Rank Adaptation

LoRA (Low-Rank Adaptation) is a popular, lightweight fine-tuning technique that drastically reduces the number of trainable parameters. It keeps the pretrained weights frozen and injects small trainable low-rank matrices into selected layers; only these new weights are trained.

Libraries

In [None]:
import random
from peft import LoraConfig, get_peft_model
from torchmetrics.functional.classification import multiclass_accuracy, multiclass_auroc

Utility functions

In [None]:
def print_trainable_parameters(model):
    """Print how many of ``model``'s parameter elements are trainable.

    Every tensor yielded by ``model.named_parameters()`` contributes ``numel()``
    elements to the total; it also counts as trainable when ``requires_grad`` is True.

    Args:
        model: any object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``
            or a PEFT-wrapped model).
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        n_elements = param.numel()
        all_param += n_elements
        if param.requires_grad:
            trainable_params += n_elements
    # A parameter-less model would otherwise raise ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct:.2f}")

Model - LoRA adapter

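LoRA adapters are inserted into the following layers of both the vision and the text encoder:
* Attention query projection (`q_proj`)
* Attention value projection (`v_proj`)
* First MLP linear layer (`fc1`)
* Second MLP linear layer (`fc2`)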

In [None]:
# Baseline: the pretrained CLIP model before any adapters are added (all parameters trainable).
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
print_trainable_parameters(model)

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Reload the base model so every run of this cell starts from the untouched pretrained weights.
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

# LoRA adapters (rank 32) on the attention query/value projections and the MLP layers.
config = LoraConfig(
    r=32,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules=["q_proj", "v_proj", "fc1", "fc2"],
)
lora_model = get_peft_model(model, config)
print_trainable_parameters(lora_model)

In [None]:
def _lora_clip_inputs(text, images):
    """Run the CLIP ``processor`` on prompts + an image batch and move every tensor to ``device``.

    Uses the notebook globals ``processor`` and ``device``. ``do_rescale=False`` presumes the
    loaders already yield images scaled to [0, 1] (ToTensor-style) — TODO confirm.
    """
    inputs = processor(text=text, images=images, return_tensors="pt",
                       do_rescale=False,
                       do_center_crop=False,
                       padding=True)
    return {name: tensor.to(device) for name, tensor in inputs.items()}


def _lora_train_epoch(model, optimizer, params, loss_img, loss_txt):
    """Run one contrastive training epoch over ``training_loader``; return the mean batch loss.

    Each image is paired with the prompt of its own class from ``label_choices``. Batches
    that raise are reported and skipped (best-effort, as before) and excluded from the mean.
    Returns NaN if no batch succeeded.
    """
    model.train()
    running_loss = 0.0
    n_ok_batches = 0
    for images, labels in tqdm(training_loader, total=len(training_loader)):
        optimizer.zero_grad()
        # Assumes labels are (B, 1) class indices and label_dict keys are "0".."n-1" in order
        # (true for MedMNIST INFO), so label_choices[i] is the prompt for class i.
        text_labels = [label_choices[int(label[0])] for label in labels]
        try:
            inputs = _lora_clip_inputs(text_labels, images)
            outputs = model(**inputs)
            logits_per_image = outputs.logits_per_image  # image -> text similarity scores
            logits_per_text = outputs.logits_per_text    # text -> image similarity scores

            # Image i is paired with text i. Duplicate prompts in a batch produce identical
            # logits, so this symmetric loss still rewards matching the correct class.
            ground_truth = torch.arange(logits_per_image.shape[0], dtype=torch.long, device=device)
            loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
            loss.backward()
            torch.nn.utils.clip_grad_norm_(params, 5.0)
            optimizer.step()
        except Exception as e:
            print("Unable to train due to: ", e)
        else:
            # .item() detaches the value so each batch's autograd graph can be freed.
            running_loss += loss.item()
            n_ok_batches += 1
    return running_loss / n_ok_batches if n_ok_batches else float("nan")


def _lora_validate(model, top_k=3):
    """Score ``validation_loader`` against all class prompts; return (top-1 acc, top-k acc, k).

    Metrics are computed once over the whole split (not averaged per batch). ``k`` is
    capped at the number of classes so datasets with fewer than ``top_k`` classes still work.
    """
    model.eval()
    n_classes = len(label_choices)
    top_k = min(top_k, n_classes)
    batch_probs, batch_targets = [], []
    for images, labels in tqdm(validation_loader, total=len(validation_loader)):
        inputs = _lora_clip_inputs(label_choices, images)
        with torch.no_grad():
            outputs = model(**inputs)
        image_features = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
        text_features = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)
        batch_probs.append((100.0 * image_features @ text_features.T).softmax(dim=-1))
        # reshape(-1) (unlike squeeze()) stays 1-D even for a batch of one image.
        batch_targets.append(labels.reshape(-1).long().to(device))

    probs = torch.cat(batch_probs)
    targets = torch.cat(batch_targets)
    acc_top1 = multiclass_accuracy(probs, targets, num_classes=n_classes).item()
    acc_topk = multiclass_accuracy(probs, targets, top_k=top_k, num_classes=n_classes).item()
    return acc_top1, acc_topk, top_k


def lora_tune(model, num_epochs: int = 10):
    """Fine-tune the trainable (e.g. LoRA) parameters of a CLIP model with a contrastive loss.

    After each epoch, the model is evaluated on ``validation_loader`` against the class
    prompts in ``label_choices``, the same prompts used during training and in ``evaluate_clip``.

    Relies on notebook globals: ``training_loader``, ``validation_loader``, ``processor``,
    ``device`` and ``label_choices``.

    Args:
        model: CLIP-like model (e.g. a PEFT-wrapped ``CLIPModel``); only parameters with
            ``requires_grad=True`` are optimised.
        num_epochs: number of training epochs (defaults to 10, the previous hard-coded value).
    """
    torch.manual_seed(0)
    random.seed(0)
    np.random.seed(0)

    # Symmetric CLIP loss: image->text and text->image cross-entropy.
    loss_img = torch.nn.CrossEntropyLoss()
    loss_txt = torch.nn.CrossEntropyLoss()

    # Only parameters left trainable (the adapters, for a PEFT model) are optimised.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=1e-5, weight_decay=0.00001)

    for epoch in range(num_epochs):
        print(f"Epoch: {epoch}")
        train_loss = _lora_train_epoch(model, optimizer, params, loss_img, loss_txt)
        acc_top1, acc_topk, top_k = _lora_validate(model)

        print(f"Epoch {epoch} train loss: {train_loss}")
        print(f"Top {top_k} Accuracy: {acc_topk * 100:.2f}%.")
        print(f"Top 1 Accuracy: {acc_top1 * 100:.2f}%.")

In [None]:
# Fine-tune the LoRA-wrapped CLIP model.
lora_tune(lora_model)

Save model

In [None]:

# Persist the fine-tuned weights as a state_dict. This matches the "model_state_dict" key,
# avoids pickling the whole module object, and can be restored with load_state_dict().
weights_path = Path(f"./lora_model/{data_flag}")
weights_path.mkdir(parents=True, exist_ok=True)
torch.save({"model_state_dict": lora_model.state_dict()}, weights_path / "model.pt")

##### CLIP - Attention to Detail

* **TODO:** attention visualization maps

In [None]:
# Evaluation setup: CLIP processor, device handle, and the LoRA model switched to eval mode.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
lora_model.eval()

Evaluate LoRA model

In [None]:
def evaluate_clip(model, data_loader, save_dir, preds_filename="val_preds.npy"):
    """Evaluate a CLIP model by scoring every image against the class prompts in ``label_choices``.

    The softmax over the scaled cosine similarities between image and prompt embeddings
    is used as each image's class-probability vector. Top-1 / top-k accuracy and AUROC
    (torchmetrics defaults for averaging) are computed once over the whole loader rather
    than averaged per batch. Per-batch averaging is skewed by a small final batch, and
    AUROC is ill-defined on batches missing some classes.

    Uses the notebook globals ``processor``, ``device`` and ``label_choices``.

    Args:
        model: CLIP-like model returning ``image_embeds`` / ``text_embeds``; it is put
            into eval mode.
        data_loader: iterable of ``(images, labels)`` batches. Labels are flattened, which
            assumes one integer class index per image (TODO confirm for multi-label
            datasets such as chestmnist).
        save_dir: directory where the probability matrix is saved (created if missing).
        preds_filename: name of the saved ``.npy`` file. It defaults to the historical
            ``"val_preds.npy"`` even when another split is evaluated.

    Returns:
        dict with ``"top1_accuracy"``, ``"top{k}_accuracy"`` (k = min(3, n_classes)) and
        ``"auroc"`` as Python floats.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()  # disable dropout (including LoRA dropout) during evaluation
    n_classes = len(label_choices)
    top_k = min(3, n_classes)  # top-3 is undefined for datasets with fewer than 3 classes

    batch_probs, batch_targets = [], []
    for images, labels in tqdm(data_loader):
        inputs = processor(text=label_choices, images=images, return_tensors="pt",
                           do_rescale=False,
                           do_center_crop=False,
                           padding=True)
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        image_features = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
        text_features = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)
        batch_probs.append((100.0 * image_features @ text_features.T).softmax(dim=-1))
        # reshape(-1) (unlike squeeze()) stays 1-D even for a batch of one image.
        batch_targets.append(labels.reshape(-1).long().to(device))

    if not batch_probs:
        raise ValueError("data_loader yielded no batches")

    similarity = torch.cat(batch_probs)
    target = torch.cat(batch_targets)

    acc_top1 = multiclass_accuracy(similarity, target, num_classes=n_classes).item()
    acc_topk = multiclass_accuracy(similarity, target, top_k=top_k, num_classes=n_classes).item()
    auc = multiclass_auroc(similarity, target, num_classes=n_classes).item()
    print(f"Top {top_k} Accuracy: {acc_topk * 100:.2f}%.")
    print(f"Top 1 Accuracy: {acc_top1 * 100:.2f}%.")
    print(f"AUROC: {auc * 100:.2f}%")

    os.makedirs(save_dir, exist_ok=True)
    np.save(os.path.join(save_dir, preds_filename), similarity.cpu().numpy())

    return {"top1_accuracy": acc_top1, f"top{top_k}_accuracy": acc_topk, "auroc": auc}


In [None]:
# Evaluate the fine-tuned LoRA model on the test split; predictions are saved under ./lora_model/<data_flag>/preds.
evaluate_clip(lora_model, test_loader, f"./lora_model/{data_flag}/preds")

Evaluate Zero-shot CLIP model

##### CLIP - Zero-shot classification

* Prompt:
    * pathmnist: "a {label} tissue sample"
    * octmnist:  "a Retinal OCT image with {label} disease"

* class_label := INFO[data_flag]["label"].values()

In [None]:
# Zero-shot CLIP class scores saved per split.
# Produced by clip_zero_shot_classification.py.
zero_shot_dir = Path("./results/zero-shot")
train_preds, val_preds, test_preds = (
    np.load(zero_shot_dir / f"{split_name}_{data_flag}.npy")
    for split_name in ("train", "val", "test")
)

Metrics

In [None]:
from medmnist.evaluator import getACC, getAUC

In [None]:
def evaluation(true, preds, info, split: str="Train"):
    """Print MedMNIST accuracy and ROC AUC for one data split and return them.

    Args:
        true: ground-truth labels in the form accepted by ``medmnist.evaluator.getACC`` / ``getAUC``.
        preds: per-class scores, in the same row order as ``true``.
        info: MedMNIST ``INFO[data_flag]`` entry; only ``info["task"]`` is read.
        split: split name, used to label the printed output.

    Returns:
        dict with keys ``"acc"`` and ``"auc"``.
    """
    acc = getACC(true, preds, task=info["task"])
    auc = getAUC(true, preds, task=info["task"])

    print(f"[{split}] Accuracy: {acc:.4f}")
    print(f"[{split}] ROC AUC score: {auc:.4f}")
    return {"acc": acc, "auc": auc}

evaluation(test_data.labels.squeeze(), test_preds, info, split="Test")