<a href="https://colab.research.google.com/github/ShruthiVidya-git/MultimodalContrastiveLearning/blob/main/Sample/Sample.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

> Executable notebook that runs the trained vision encoder on a sample image.

This code is adapted from the GLoRIA repository.

In [1]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.21.2-py3-none-any.whl (4.7 MB)
[K     |████████████████████████████████| 4.7 MB 4.3 MB/s 
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 41.6 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.9.0-py3-none-any.whl (120 kB)
[K     |████████████████████████████████| 120 kB 58.3 MB/s 
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.9.0 tokenizers-0.12.1 transformers-4.21.2


In [98]:
import sys, torch.nn as nn, torch, torchvision, pickle, os, pathlib, re, numpy as np, pandas as pd, glob, gc,numpy as np, pandas as pd, random, os, warnings, cv2
from torchvision import models as models_2d
from sklearn import metrics
from tqdm import tqdm
from torch.autograd import Variable
from transformers import AutoTokenizer, BertModel, AutoModel

# Seed every RNG the notebook relies on (PyTorch, NumPy, Python) so re-runs are reproducible.
# torch.manual_seed is called first so that the cell's last expression returns None
# and no Generator repr is displayed.
SEED = 500
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [100]:
class VisionTransformer(nn.Module):
    """Minimal ViT image encoder producing a global (CLS-token) image embedding.

    The image is split into non-overlapping ``patch_size`` x ``patch_size``
    patches. Each patch is linearly projected to ``embed_dim``. The sequence is
    prefixed with a learnable CLS token, summed with learnable positional
    embeddings, and passed through ``num_layers`` pre-norm transformer blocks.
    The CLS-token output is the image embedding.

    NOTE: all ``num_layers`` blocks share ONE set of weights (``layer_norm_1``,
    ``attn``, ``layer_norm_2``, ``linear``), applied repeatedly. The
    state_dict keys depend on this layout, so splitting it into per-layer
    modules would break loading existing checkpoints. ``mlp_head`` is not used
    by ``encode``. It is kept so the state_dict keys (checked strictly by
    ``load_state_dict`` by default) stay the same.
    """

    def __init__(self):
        super(VisionTransformer, self).__init__()

        # Specifications for the ViT.
        self.patch_size = 32
        self.num_channels = 3
        self.num_heads = 8
        self.embed_dim = 768
        self.hidden_dim = 512
        # The positional-embedding table is sized for 256x256 inputs.
        self.num_patches = (256 // self.patch_size) ** 2
        self.num_layers = 6
        # Dropout rate; kept local so ``self.dropout`` only ever names the module.
        dropout_p = 0.1

        # Layers / networks.
        self.input_layer = nn.Linear(self.num_channels * (self.patch_size ** 2), self.embed_dim)

        self.layer_norm_1 = nn.LayerNorm(self.embed_dim)
        self.attn = nn.MultiheadAttention(self.embed_dim, self.num_heads,
                                          dropout=dropout_p)
        self.layer_norm_2 = nn.LayerNorm(self.embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(self.embed_dim, self.hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout_p),
            nn.Linear(self.hidden_dim, self.embed_dim),
            nn.Dropout(dropout_p))

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.embed_dim)
        )
        self.dropout = nn.Dropout(dropout_p)

        # Parameters / embeddings.
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, 1 + self.num_patches, self.embed_dim))

    def img_to_patch(self, x, patch_size = 32, flatten_channels=True):
        """Split a batch of images into non-overlapping square patches.

        Args:
            x: image tensor of shape [B, C, H, W]. H and W must be divisible by
                ``patch_size``; otherwise the reshape raises.
            patch_size: side length of each square patch.
            flatten_channels: if True, flatten each patch to a vector of
                length C * patch_size**2.

        Returns:
            [B, H'*W', C*p*p] if ``flatten_channels`` else [B, H'*W', C, p, p],
            where H' = H // patch_size and W' = W // patch_size.
        """
        B, C, H, W = x.shape
        x = x.reshape(B, C, H // patch_size, patch_size, W // patch_size, patch_size)
        x = x.permute(0, 2, 4, 1, 3, 5)  # [B, H', W', C, p_H, p_W]
        x = x.flatten(1, 2)              # [B, H'*W', C, p_H, p_W]
        if flatten_channels:
            x = x.flatten(2, 4)          # [B, H'*W', C*p_H*p_W]
        return x

    def encode(self, x):
        """Encode a batch of images into global (CLS-token) embeddings.

        Args:
            x: float tensor [B, num_channels, H, W]. The number of patches
                (H*W / patch_size**2) must not exceed ``num_patches``.

        Returns:
            Tensor [B, embed_dim]: the CLS-token output after the last block.
        """
        # Preprocess input: patches -> linear projection.
        x = self.img_to_patch(x, self.patch_size)
        B, T, _ = x.shape
        x = self.input_layer(x)

        # Add CLS token and positional encoding.
        cls_token = self.cls_token.repeat(B, 1, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.pos_embedding[:, :T + 1]

        # Apply the transformer. ``self.attn`` is not batch_first, so switch
        # to [T+1, B, embed_dim].
        x = self.dropout(x)
        x = x.transpose(0, 1)
        # The same shared pre-norm block is applied num_layers times.
        for _ in range(self.num_layers):
            inp_x = self.layer_norm_1(x)
            x = x + self.attn(inp_x, inp_x, inp_x)[0]
            x = x + self.linear(self.layer_norm_2(x))

        # Position 0 along the sequence axis is the CLS token.
        return x[0]

    def get_global_similarities( self, img_emb_g, text_emb_g):
        """Scaled cosine similarity between every image and every text embedding.

        Args:
            img_emb_g: [N, D] image embeddings.
            text_emb_g: [M, D] text embeddings.

        Returns:
            CPU float32 tensor [N, M] equal to 10 * cosine similarity. Inputs
            are detached, so no gradient flows back through the result.
        """
        img_emb_g = img_emb_g.detach().cpu().float()
        text_emb_g = text_emb_g.detach().cpu().float()
        # Row-normalize, then take pairwise dot products. An all-zero row stays
        # zero and therefore yields similarity 0.
        img_unit = nn.functional.normalize(img_emb_g, dim=1)
        text_unit = nn.functional.normalize(text_emb_g, dim=1)
        return (img_unit @ text_unit.T) * 10

In [101]:
if __name__=="__main__":
    # Paths to the artifacts uploaded to the Colab runtime.
    checkpoints_path = '/content/checkpoint_state_dict.pt'
    img_path = '/content/Sample Image.jpg'
    prompts_path = '/content/class_prompts_embeddings.pickle'

    # Build the model and load the trained weights. map_location lets a
    # GPU-saved checkpoint load on a CPU-only runtime.
    vit = VisionTransformer()
    checkpoint = torch.load(checkpoints_path, map_location="cpu")
    vit.load_state_dict(checkpoint)
    # Inference mode: disables the dropout layers so the prediction is deterministic.
    vit.eval()

    # Read, resize & min-max normalize the image to [0, 1].
    sample = cv2.imread(img_path)
    if sample is None:
        # cv2.imread returns None instead of raising on a missing/unreadable file.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    sample_img = cv2.resize(sample, (256, 256), interpolation=cv2.INTER_CUBIC)
    pixel_min, pixel_max = np.min(sample_img), np.max(sample_img)
    if pixel_max > pixel_min:
        sample_img = (sample_img - pixel_min) / (pixel_max - pixel_min)
    else:
        # Constant image: avoid 0/0 (NaN) and use an all-zero image instead.
        sample_img = np.zeros_like(sample_img, dtype=np.float64)
    # NOTE(review): cv2 returns [H, W, C] arrays in BGR order. reshape() does not
    # move the channel axis; it reinterprets memory, interleaving pixels across
    # "channels". The usual conversion is .permute(2, 0, 1), but the checkpoint
    # may have been trained with this exact preprocessing. Confirm against the
    # training pipeline before changing it.
    sample_img = torch.reshape(torch.tensor(sample_img), (3, 256, 256)).unsqueeze(0).float()

    # Get the global image embedding (no gradients needed for inference).
    with torch.no_grad():
        img_g = vit.encode(sample_img)

    # Read class prompt embeddings from the pickle file.
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(prompts_path, "rb") as f:
        promp_embeddings = pickle.load(f)

    # Presumably a 1-D embedding tensor; unsqueeze makes it a [1, D] batch.
    text_g = promp_embeddings['global_embed']['Cardiomegaly'].unsqueeze(0)
    similarity = vit.get_global_similarities(img_g, text_g)
    threshold = 0
    # similarity is [1, 1] here (one image vs. one class prompt).
    if similarity.item() > threshold:
        print('Predicted class is Cardiomegaly ! ')
    else:
        print('No Findings')

Predicted class is Cardiomegaly ! 
