<a href="https://colab.research.google.com/github/Zain506/Similarity/blob/main/notebooks/medclipsam.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## MedCLIP-SAM

[Research Paper](https://arxiv.org/pdf/2403.20253)



In [25]:
from datasets import load_dataset

ds = load_dataset("adishourya/MEDPIX-ClinQA")  # load the MEDPIX-ClinQA dataset from the Hugging Face Hub
print(ds)

DatasetDict({
    train: Dataset({
        features: ['image_id', 'mode', 'case_id', 'question', 'answer'],
        num_rows: 20500
    })
})


In [33]:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Processor: ", device)

Processor:  cuda


## The CLIP-style architecture has 3 components
1. Text encoder (PubMedBERT)
2. Image encoder (ViT)
3. Projection layers that map the encoded images and text into a shared embedding space

**BiomedCLIP, available through open_clip, contains all three, and it is the model that MedCLIP-SAM fine-tunes.**
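
Since the zero-shot cell further down loads general-domain `ViT-B-32` weights, it may help to see how BiomedCLIP itself could be loaded. This is a minimal sketch, not the MedCLIP-SAM authors' code. It assumes the hub identifier published on the BiomedCLIP model card, and that `transformers` is installed alongside `open_clip_torch`. The helper name `load_biomedclip` is hypothetical.

```python
# Hub identifier as published on the BiomedCLIP model card
# (an assumption here: it is not used anywhere else in this notebook).
BIOMEDCLIP_HUB_ID = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"


def load_biomedclip(device="cpu"):
    """Hypothetical helper: return (model, preprocess, tokenizer) for BiomedCLIP."""
    # Imported lazily so this cell can be defined before open_clip_torch is installed.
    import open_clip

    # Downloads the weights on first use; returns the model and its image transform.
    model, preprocess = open_clip.create_model_from_pretrained(BIOMEDCLIP_HUB_ID)
    tokenizer = open_clip.get_tokenizer(BIOMEDCLIP_HUB_ID)
    return model.to(device).eval(), preprocess, tokenizer
```

Calling `model, preprocess, tokenizer = load_biomedclip(device)` would give drop-in replacements for the three objects used in the zero-shot cell, although the probabilities it prints would differ from those shown for `ViT-B-32`.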

In [37]:
%%capture cap
%pip install open_clip_torch  # run cap.show() to see the captured install output

In [40]:
import torch
from PIL import Image
import open_clip

# NOTE: these are general-domain CLIP weights (LAION-2B), not BiomedCLIP
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
model = model.to(device)  # run on the device selected earlier
model.eval()  # model in train mode by default, impacts some models with BatchNorm or stochastic depth active
tokenizer = open_clip.get_tokenizer('ViT-B-32')


# TODO: handle a batch of images and construct the similarity matrix chunk by chunk
image = preprocess(ds["train"][0]["image_id"]).unsqueeze(0).to(device)  # add a batch dimension
text = tokenizer(["Subcranial Hematoma", "Intracranial Hematoma", "Perfectly healthy brain"]).to(device)

with torch.no_grad(), torch.autocast(device):
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # L2-normalise so that the dot product below is a cosine similarity
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Scale the similarities by 100, then softmax over the prompts
    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

print("Label probs:", text_probs.cpu())  # one probability per prompt, in the order given above

Label probs: tensor([[3.7976e-02, 9.6202e-01, 5.6641e-06]])
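
The comment in the cell above notes that this step should eventually handle a batch and build the similarity matrix piece by piece. One way to do that is sketched below. The helper names `encode_in_batches` and `similarity_matrix` are hypothetical, and two random linear layers stand in for the real encoders so that the example runs without downloading any weights.

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def encode_in_batches(encode_fn, inputs, batch_size=32):
    """Encode `inputs` chunk by chunk and return L2-normalised features."""
    chunks = []
    for start in range(0, inputs.shape[0], batch_size):
        feats = encode_fn(inputs[start:start + batch_size])
        chunks.append(F.normalize(feats.float(), dim=-1))
    return torch.cat(chunks, dim=0)


def similarity_matrix(encode_image, encode_text, images, texts, batch_size=32):
    """Return an [n_images, n_texts] matrix of cosine similarities."""
    image_features = encode_in_batches(encode_image, images, batch_size)
    text_features = encode_in_batches(encode_text, texts, batch_size)
    return image_features @ text_features.T


# Stand-in encoders (hypothetical): random linear maps into a shared 16-d space.
torch.manual_seed(0)
image_encoder = torch.nn.Linear(32, 16)
text_encoder = torch.nn.Linear(8, 16)
images = torch.randn(10, 32)  # 10 fake "images"
texts = torch.randn(4, 8)     # 4 fake "captions"

sim = similarity_matrix(image_encoder, text_encoder, images, texts, batch_size=3)
probs = (100.0 * sim).softmax(dim=-1)  # same scaling as in the cell above
print(sim.shape)
```

With an open_clip model, `model.encode_image` and `model.encode_text` could be passed as the two encoder callables, with `images` being a stack of preprocessed image tensors and `texts` the tokenizer output.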
