<a href="https://colab.research.google.com/github/Pratyushk2003/albef-t2i-retrieval/blob/main/ALBEF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
!git clone https://github.com/salesforce/ALBEF.git

Cloning into 'ALBEF'...
remote: Enumerating objects: 353, done.[K
remote: Counting objects: 100% (145/145), done.[K
remote: Compressing objects: 100% (63/63), done.[K
remote: Total 353 (delta 87), reused 82 (delta 82), pack-reused 208[K
Receiving objects: 100% (353/353), 71.56 MiB | 35.01 MiB/s, done.
Resolving deltas: 100% (134/134), done.


In [25]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [13]:
import sys
sys.path.append('/content/ALBEF/')

In [None]:
!pip install timm

In [None]:
!pip install transformers==4.8.1

In [None]:
import re
file_path = '/kaggle/working/ALBEF/models/xbert.py'
with open(file_path, 'r') as file:
    file_content = file.read()
modified_content = re.sub(r'\btokenizer_class\b', 'processor_class', file_content)
with open(file_path, 'w') as file:
    file.write(modified_content)

In [53]:
from functools import partial
from models.vit import VisionTransformer
from models.xbert import BertConfig, BertModel
from models.tokenization_bert import BertTokenizer

import torch
from torch import nn
from torchvision import transforms

import json

class VL_Transformer_ITM(nn.Module):
    def __init__(self,
                 text_encoder = None,
                 config_bert = ''
                 ):
        super().__init__()

        bert_config = BertConfig.from_json_file(config_bert)
        self.visual_encoder = VisionTransformer(
            img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
            mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6))

        self.text_encoder = BertModel.from_pretrained(text_encoder, config=bert_config, add_pooling_layer=False)

        self.itm_head = nn.Linear(768, 2)


    def forward(self, image, text):
        image_embeds = self.visual_encoder(image)
        print("image embeddings",image_embeds)
        image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(image.device)

        output = self.text_encoder(text.input_ids,
                                attention_mask = text.attention_mask,
                                encoder_hidden_states = image_embeds,
                                encoder_attention_mask = image_atts,
                                return_dict = True,
                               )

        vl_embeddings = output.last_hidden_state[:,0,:]
        vl_output = self.itm_head(vl_embeddings)
        return output, image_embeds

In [54]:
from transformers import AutoTokenizer, VisionEncoderDecoderModel
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [55]:
model = VL_Transformer_ITM(text_encoder='bert-base-uncased', config_bert='/content/ALBEF/configs/config_bert.json')

In [56]:
import re

def pre_caption(caption,max_words=30):
    caption = re.sub(
        r"([,.'!?\"()*#:;~])",
        '',
        caption.lower(),
    ).replace('-', ' ').replace('/', ' ')

    caption = re.sub(
        r"\s{2,}",
        ' ',
        caption,
    )
    caption = caption.rstrip('\n')
    caption = caption.strip(' ')

    #truncate caption
    caption_words = caption.split(' ')
    if len(caption_words)>max_words:
        caption = ' '.join(caption_words[:max_words])
    return caption

In [57]:
from PIL import Image

import cv2
import numpy as np

from skimage import transform as skimage_transform
from scipy.ndimage import filters
from matplotlib import pyplot as plt

def getAttMap(img, attMap, blur = True, overlap = True):
    attMap -= attMap.min()
    if attMap.max() > 0:
        attMap /= attMap.max()
    attMap = skimage_transform.resize(attMap, (img.shape[:2]), order = 3, mode = 'constant')
    if blur:
        attMap = filters.gaussian_filter(attMap, 0.02*max(img.shape[:2]))
        attMap -= attMap.min()
        attMap /= attMap.max()
    cmap = plt.get_cmap('jet')
    attMapV = cmap(attMap)
    attMapV = np.delete(attMapV, 3, 2)
    if overlap:
        attMap = 1*(1-attMap**0.7).reshape(attMap.shape + (1,))*img + (attMap**0.7).reshape(attMap.shape+(1,)) * attMapV
    return attMap


normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

transform = transforms.Compose([
    transforms.Resize((384,384),interpolation=Image.BICUBIC),
    transforms.ToTensor(),
    normalize,
])

In [85]:
from PIL import Image
import requests
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
images = Image.open(requests.get(url, stream=True).raw)
text_input = "a photo of a cat"


In [87]:
import torchvision.transforms as transforms
image = transform(images).unsqueeze(0)
caption = 'a photo of a cat'
text = pre_caption(caption)
text_input = tokenizer(text, return_tensors="pt")


In [88]:
output, image_embed = model(image, text_input)

image embeddings tensor([[[ 0.4824,  0.9153, -1.9881,  ..., -0.4806,  0.0265, -0.2992],
         [ 1.0946, -0.0651, -2.6840,  ...,  0.1419, -0.5108, -0.9272],
         [ 1.3450,  0.0102, -2.8977,  ..., -0.1485, -0.8198, -1.2747],
         ...,
         [ 1.5173,  0.5864, -3.2647,  ..., -1.3810, -1.1761, -2.1298],
         [ 1.1281,  0.3195, -3.2699,  ..., -1.4792, -1.2343, -1.9161],
         [ 1.5978,  0.6517, -3.0029,  ..., -1.6326, -1.1732, -2.0038]]],
       grad_fn=<NativeLayerNormBackward0>)


In [121]:
text_embeddings  = output['last_hidden_state']

In [122]:
image_embed.shape

torch.Size([1, 577, 768])

In [123]:
text_embeddings.shape

torch.Size([1, 7, 768])

In [137]:
from sklearn.metrics.pairwise import cosine_similarity
similarity_scores = cosine_similarity(image_embed[0,0].detach().numpy().reshape(1, -1), text_embeddings[0,0].detach().numpy().reshape(1, -1))

In [138]:
similarity_scores

array([[0.02134769]], dtype=float32)