<a href="https://colab.research.google.com/github/Zain506/MedCLIP-SAM/blob/main/notebooks/medclipsam.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## MedCLIP-SAM

[Research Paper](https://arxiv.org/pdf/2403.20253)



In [1]:
from datasets import load_dataset

# Download MedPix ClinQA from the Hugging Face Hub; `ds` maps split names (e.g. "train") to datasets.
ds = load_dataset(path="adishourya/MEDPIX-ClinQA")


In [2]:
SPLIT_SEED = 42  # fixed seed so the train/validation split is identical on every run
N_TRAIN = None   # set to an int (e.g. 1024) to train on a small subset while debugging

# Hold out 15% of the original "train" split for validation.
train_valid = ds["train"].train_test_split(test_size=0.15, seed=SPLIT_SEED)
training = train_valid["train"]
if N_TRAIN is not None:
  training = training.select(range(min(N_TRAIN, len(training))))
validation = train_valid["test"]
print(training)

Dataset({
    features: ['image_id', 'mode', 'case_id', 'question', 'answer'],
    num_rows: 17425
})


In [3]:
import torch

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"
print("Processor: ", device)

Processor:  cuda


## The three components of a CLIP-style architecture
1. Text encoder (PubMedBERT)
2. Image encoder (ViT)
3. A mapping that projects the encoded images and text into a shared embedding space

**BiomedCLIP, available through `open_clip`, contains all three components, and it is the model MedCLIP-SAM fine-tunes.**

In [4]:
%%capture cap
%pip install open_clip_torch # run ``cap.show()`` in order to see the output here

In [5]:
import torch
from PIL import Image
import open_clip

# NOTE(review): the markdown above says the aim is to fine-tune BiomedCLIP, but these settings
# load the general-domain LAION-2B ViT-B-32 CLIP checkpoint. Confirm which backbone is intended.
MODEL_NAME = 'ViT-B-32'
PRETRAINED = 'laion2b_s34b_b79k'

model, _, preprocess = open_clip.create_model_and_transforms(MODEL_NAME, pretrained=PRETRAINED)
# The model is left in train mode (the default); call model.eval() for inference-only use,
# which matters for models with BatchNorm or stochastic depth.
model = model.to(device)
# The tokenizer has to match the model's text tower, so it is looked up by the same name.
tokenizer = open_clip.get_tokenizer(MODEL_NAME)

In [6]:
import torch
from torch.utils.data import DataLoader
bsize: int = 10  # number of image/caption pairs per batch


def collate_fn(batch):
  """Convert a list of dataset rows into model-ready tensors.

  Each row's ``image_id`` entry is run through the CLIP ``preprocess`` transform
  and the results are stacked along a new leading batch axis. Each row's
  question/answer pair is formatted into one caption string, and all captions
  are tokenized together.

  Returns:
    (images, texts): the stacked image tensor and the tokenized captions, in row order.
  """
  pixel_batch = []
  captions = []
  for row in batch:
    pixel_batch.append(preprocess(row['image_id']))
    captions.append(f"Prompt: {row['question']} \nAnswer: {row['answer']}")
  return torch.stack(pixel_batch), tokenizer(captions)


# Dataset rows are not tensors yet, so the loader needs the custom collate_fn defined here.
# Every tensor it yields has the batch on axis 0.
train_loader = DataLoader(training, batch_size=bsize, shuffle=True, collate_fn=collate_fn)

In [37]:
import math

import torch
import torch.nn.functional as F


def _lossComp(sim: torch.Tensor, t: float, b: float) -> torch.Tensor:
  """One direction of the DHN-NCE loss from the MedCLIP-SAM paper, summed over the batch.

  Row i of ``sim`` compares anchor i with every candidate in the batch. The
  diagonal entry ``sim[i, i]`` is the matching (positive) pair and every
  off-diagonal entry is a negative. With ``s = sim / t``, each row contributes

      -log( exp(s_ii) / sum_{k != i} exp(s_ik) * W_ik ),
      W_ik = (B - 1) * exp(b * s_ik) / sum_{j != i} exp(b * s_ij).

  The positive is left out of the denominator ("decoupled"). When ``b > 0``,
  negatives that are closer to the anchor are up-weighted. ``b = 0`` gives
  W_ik = 1, i.e. plain decoupled InfoNCE.

  Args:
    sim: square (B, B) similarity matrix, not yet divided by ``t``. B is read
      from ``sim`` itself, so a short final batch is handled as well.
    t: temperature.
    b: hard-negative "hardness" parameter.

  Returns:
    Scalar tensor: the per-row losses summed over the batch.

  Raises:
    ValueError: if ``sim`` is not a square matrix with at least 2 rows.
  """
  if sim.dim() != 2 or sim.shape[0] != sim.shape[1]:
    raise ValueError(f"sim must be a square (B, B) matrix, got shape {tuple(sim.shape)}")
  n = sim.shape[0]
  if n < 2:
    raise ValueError("DHN-NCE needs a batch of at least 2 pairs (one positive and >= 1 negative per row)")
  logits = sim / t
  # True on the diagonal (the positive pairs). Masking to -inf *before* the softmax
  # gives the positive exactly zero weight, so W only spreads over the negatives.
  positive = torch.eye(n, device=sim.device).bool()
  # log W_ik = log(B - 1) + log_softmax_{k != i}(b * s_ik); the diagonal stays -inf.
  log_w = math.log(n - 1) + torch.log_softmax((b * logits).masked_fill(positive, float("-inf")), dim=-1)
  # log sum_{k != i} exp(s_ik) * W_ik, computed in log space for numerical stability.
  log_denominator = torch.logsumexp(logits + log_w, dim=-1)
  return torch.sum(log_denominator - logits.diagonal())


def Loss(sim: torch.Tensor, t, b1, b2) -> torch.Tensor:
  """Symmetric DHN-NCE loss: one direction over the rows of ``sim`` (hardness ``b1``)
  plus the other over its columns (hardness ``b2``).

  ``sim`` is the raw cross-modal similarity matrix. In this notebook its rows are
  images and its columns are texts, as in ``I @ T.T`` with L2-normalised embeddings.
  The temperature ``t`` is applied inside, so do not pre-divide ``sim`` by ``t``.
  """
  return _lossComp(sim, t, b1) + _lossComp(sim.T, t, b2)

In [54]:
# Iterate over images, texts in train_loader
import torch.nn.functional as F
import torch
from tqdm.notebook import tqdm
optimizer = torch.optim.AdamW(
    model.parameters(), # All parameters are trainable; to freeze part of the model, set requires_grad=False on them first
    lr = 5e-6,
    weight_decay = 0.01 # AdamW's decoupled weight decay (L2-style shrinkage), not L1 regularisation
)

t = 0.1   # temperature that scales the cosine similarities
b1 = 0.1  # hardness parameter for the rows of sim (image -> text)
b2 = 0.1  # hardness parameter for the columns of sim (text -> image)

epochs = 1
for epoch in range(epochs):
  for images, texts in tqdm(train_loader, desc=f"Epoch {epoch + 1}"):
    # Debug pass over a single batch: inspect the scaled similarities (A) and the
    # hard-negative weights (W). Nothing is optimised here, so skip autograd bookkeeping.
    with torch.no_grad():
      images = images.to(device)
      texts = texts.to(device)
      I = F.normalize(model.encode_image(images), p=2, dim=-1)
      T = F.normalize(model.encode_text(texts), p=2, dim=-1)
      sim = (I @ T.T) / t  # (B, B) image-by-text similarities; the diagonal holds the matching pairs
      A = torch.exp(sim)
      print("A: ", A)
      print(A.shape)
      # Hard-negative weights as in the paper's DHN-NCE loss: W_ik = (B-1) * softmax_{k != i}(b1 * s_ik / t).
      # The positive (diagonal) is masked to -inf *before* the softmax so it gets weight 0.
      # A fractional power such as sim ** b1 would be NaN for any negative similarity.
      n = sim.shape[0]
      positive = torch.eye(n, device=sim.device).bool()
      W = (n - 1) * torch.softmax((b1 * sim).masked_fill(positive, float("-inf")), dim=-1)
      print("W: ", W)
    # TODO: to actually train, compute loss = Loss(I @ T.T, t, b1, b2) with gradients enabled,
    # passing the *unscaled* similarities because Loss divides by t itself. Then run
    # optimizer.zero_grad(); loss.backward(); optimizer.step() and remove the break below.
    break

Epoch 1:   0%|          | 0/1743 [00:00<?, ?it/s]

A:  tensor([[24.4138, 10.4554, 12.2073, 18.1422, 10.7611, 23.0422, 17.0301,  6.3610,
         19.6024, 14.3857],
        [14.1356, 11.2177, 22.4982, 19.9671,  9.5874, 30.5742, 14.9694, 12.7460,
         13.9993, 13.3294],
        [14.2592, 10.6221, 18.8314, 23.7812,  9.5610, 19.8081, 18.5962, 11.2190,
         15.1303, 14.6585],
        [18.5039, 11.7955, 16.9714, 23.2924,  9.0552, 29.3055, 16.0602,  7.3933,
         13.4074, 12.2411],
        [13.7196, 11.8740, 12.8726, 17.7133,  9.2059, 17.8600, 18.2057,  7.2555,
          9.9457, 10.0960],
        [13.2472, 13.4415, 17.4430, 19.4257,  8.4657, 26.9604, 19.7382,  8.9739,
         13.9328, 14.8924],
        [12.2688,  9.6273, 18.1002, 20.7181,  9.2868, 19.4978, 13.3780, 10.2052,
         12.1526, 12.8853],
        [10.1821,  8.8881, 23.8240, 19.3660, 10.4198, 16.6380, 13.4612, 14.9733,
         10.7721, 12.9717],
        [22.6513, 13.1835, 21.2604, 26.5196, 11.2503, 32.2647, 20.7832, 11.5874,
         19.1846, 15.6024],
        [15.104