In [None]:
!pip install img2dataset

In [None]:
# Location of the caption metadata JSON; listed below to confirm the file exists.
path_json = '../input/guie-laion5b-dataset/GUIE_laion5b_dataset_en.json'
!ls $path_json

In [None]:
from typing import Dict, List, Callable
from pathlib import Path
from PIL import Image

import pandas as pd
import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [None]:
# Load the caption metadata and expose the two caption columns under the
# names `text` / `text_en`, dropping the original column names.
df = pd.read_json(path_json)
df = (
    df
    .assign(text=df['caption'], text_en=df['caption_en'])
    .drop(columns=['caption', 'caption_en'])
)

# Keep only the first 500 records and store them as a JSON list of records.
df_head = df.head(500)
df_head.to_json('./fixed_df.json', orient='records')

In [None]:
# Folder that receives the downloaded files; created if it does not exist yet.
out_dir = Path('./dataset')
out_dir.mkdir(exist_ok=True)

# Download the images listed in fixed_df.json into `out_dir` as individual
# files, using the `text_en` column as the caption and without resizing.
!img2dataset "fixed_df.json" --input_format="json" --caption_col="text_en" --output_folder=$out_dir --processes_count=1 --output_format="files" --resize_mode="no"

In [None]:
def get_image_files_dict(base_path: Path) -> Dict:
    """Map file stem -> path for every image found recursively under `base_path`.

    Files with the extensions png, jpg, jpeg and bmp are collected, in that
    order; when two images share a stem, the one found later wins.
    """
    stem_to_path: Dict = {}
    for pattern in ("**/*.png", "**/*.jpg", "**/*.jpeg", "**/*.bmp"):
        for candidate in base_path.glob(pattern):
            stem_to_path[candidate.stem] = candidate
    return stem_to_path


def get_text_files_dict(base_path: Path) -> Dict:
    """Map file stem -> path for every `.txt` file found recursively under `base_path`."""
    stem_to_path = {}
    for caption_path in base_path.glob("**/*.txt"):
        stem_to_path[caption_path.stem] = caption_path
    return stem_to_path


def get_shared_stems(image_files_dict: Dict, text_files_dict: Dict) -> List:
    """Return the stems that have both an image file and a text file.

    The result is sorted so that a given index always refers to the same
    sample; a plain `list(set)` has no stable order between interpreter runs,
    which made index-based lookups such as `dataset[16]` non-reproducible.
    """
    return sorted(image_files_dict.keys() & text_files_dict.keys())


class TextImageDataset(Dataset):
    """Dataset for text-image pairs.

    Scans `root` recursively for image files and `.txt` caption files and
    pairs them by file stem. Only stems that have both an image and a caption
    are exposed as samples.

    Args:
        root: Directory to scan (a `str` or `Path`; it is wrapped in `Path`).
        preprocess: Optional callable applied to the RGB `PIL.Image` of a sample.
        tokenizer: Optional callable applied to the caption string of a sample.
    """

    def __init__(
            self,
            root: str,
            preprocess: Callable = None,
            tokenizer: Callable = None,
    ):
        super().__init__()
        self.root = Path(root)
        self.preprocess = preprocess
        self.tokenizer = tokenizer

        self.image_files_dict = get_image_files_dict(self.root)
        self.text_files_dict = get_text_files_dict(self.root)
        self.shared_stems = get_shared_stems(self.image_files_dict, self.text_files_dict)

    def __len__(self):
        """Number of stems that have both an image and a caption file."""
        return len(self.shared_stems)

    def get_caption(self, text_file: Path):
        """Read a caption file and return its content without surrounding whitespace."""
        # Decode explicitly as UTF-8: without `encoding`, `open` falls back to
        # the platform locale and can fail or garble non-ASCII captions.
        with open(text_file, 'r', encoding='utf-8') as f:
            return f.read().strip()

    def __getitem__(self, i: int):
        """Return `(image, text)` for sample `i`, after the optional preprocess/tokenizer."""
        stem = self.shared_stems[i]

        # read image; the context manager closes the underlying file as soon
        # as the pixel data has been converted
        image_file = self.image_files_dict[stem]
        with Image.open(image_file) as raw_image:
            image = raw_image.convert("RGB")

        # read text
        text_file = self.text_files_dict[stem]
        text = self.get_caption(text_file)

        # preprocess image and text
        if self.preprocess:
            image = self.preprocess(image)

        if self.tokenizer:
            text = self.tokenizer(text)
        return image, text

In [None]:
# Image preprocessing: resize, centre-crop to 224 and convert to a tensor.
trans = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor()
])

# `trans` is passed as the `preprocess` callable; no tokenizer is given, so
# captions are returned as plain strings.
dataset = TextImageDataset(out_dir, trans)
# Smoke test: fetch only the first sample to check that the dataset is readable.
for i,_ in enumerate(dataset):
    print(i)
    break

In [None]:
# Display one sample (index 16) with its caption as the figure title.
img, caption = dataset[16]

plt.figure()
plt.title(caption)
# Move axis 0 to the end before plotting — presumably channel-first -> channel-last; confirm.
plt.imshow(img.numpy().transpose(1, 2, 0))
plt.axis('off')

In [None]:
# Batch the dataset in groups of 16 and print the image tensor shape and the
# captions of the first batch only.
dataloader = DataLoader(dataset, batch_size=16)

for idx,(img, caption) in enumerate(dataloader):
    print(f'For {idx}', img.shape, caption, end='\n\n')
    break
# #   img = img.squeeze().numpy().transpose(1, 2, 0)
#     print(img.shape)
#     caption = caption[0]
    
#     plt.figure()
#     plt.title(caption)
#     plt.imshow(img)
#     plt.show()
#     break

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import requests
import io
import time
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from torchvision import transforms
from PIL import Image

In [None]:
# Reload the caption metadata, keep only the URL and English caption columns,
# and restrict it to the first 20,000 rows.
df = pd.read_json('/kaggle/input/guie-laion5b-dataset/GUIE_laion5b_dataset_en.json')
df = df.loc[:, ['url', 'caption_en']]
df = df.iloc[:20_000]

# Building vocabulary

PAD_token = 0   # Used for padding short sentences
SOS_token = 1   # Start-of-sentence token
EOS_token = 2   # End-of-sentence token
CLS_token = 3   # Classification token


class Vocabulary():
    """Word-level vocabulary: maps words to integer ids and back, with counts.

    Ids 0..3 are reserved for the special tokens PAD, SOS, EOS and CLS; real
    words receive consecutive ids starting right after them.

    Attributes:
        name: Free-form label for this vocabulary.
        word2index: word -> id.
        word2count: word -> number of times the word was added.
        index2word: id -> word (includes the special tokens).
        num_words: Total number of ids in use, special tokens included.
        num_sentences: Number of sentences passed to `add_sentence`.
        longest_sentence: Largest word count seen in a single sentence.
    """

    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS", CLS_token: "CLS"}
        # Start counting after ALL reserved ids. Starting at 3 handed the CLS
        # id to the first real word and overwrote "CLS" in index2word.
        self.num_words = len(self.index2word)
        self.num_sentences = 0
        self.longest_sentence = 0

    def add_word(self, word):
        """Register one occurrence of `word`, assigning a new id on first sight."""
        if word not in self.word2index:
            # First entry of word into vocabulary
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            # Word exists; increase word count
            self.word2count[word] += 1

    def add_sentence(self, sentence):
        """Add every whitespace-separated word of `sentence` to the vocabulary."""
        # split() (any whitespace, no empty strings) matches how captions are
        # split when they are converted to ids for training; split(' ') left
        # empty-string "words" and kept tab/newline-joined words together.
        words = sentence.split()
        for word in words:
            self.add_word(word)
        self.longest_sentence = max(self.longest_sentence, len(words))
        # Counted once per sentence (previously incremented once per word).
        self.num_sentences += 1

    def to_word(self, index):
        """Return the word stored for `index`; raises KeyError for unknown ids."""
        return self.index2word[index]

    def to_index(self, word):
        """Return the id of `word`; raises KeyError for unknown words."""
        return self.word2index[word]


clipVocab = Vocabulary('CLIP')

In [None]:
# Feed every English caption into the vocabulary, logging progress whenever
# the row label is a multiple of 5000.
for row_label, caption_text in df['caption_en'].items():
    clipVocab.add_sentence(caption_text)
    if row_label % 5000 == 0:
        print(f'Reached {row_label}')

# The dataframe is no longer needed once the vocabulary is built.
del df
print(f'{clipVocab.num_words} in Vocabulary')

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
import torch.nn.functional as F


class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The last dimension of the inputs (`embed_size`) is split into `heads`
    chunks of `head_dim = embed_size // heads` features. The same three
    bias-free `head_dim -> head_dim` projections are applied to every head,
    and the concatenated head outputs pass through a final linear layer.

    Args:
        embed_size: Size of the last input dimension; must be divisible by `heads`.
        heads: Number of attention heads.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), 'Embed size needs to be divisible by number of heads'

        # Each head will get keys, values and queries
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        """Attend from `query` positions to `keys`/`values` positions.

        Args:
            values: Tensor of shape (N, value_len, embed_size).
            keys: Tensor of shape (N, key_len, embed_size).
            query: Tensor of shape (N, query_len, embed_size).
            mask: Optional boolean tensor broadcastable to
                (N, heads, query_len, key_len); True marks key positions that
                must NOT be attended to. Pass None to disable masking.

        Returns:
            Tensor of shape (N, query_len, embed_size).
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into heads, then project each head.
        values = self.values(values.reshape(N, value_len, self.heads, self.head_dim))
        keys = self.keys(keys.reshape(N, key_len, self.heads, self.head_dim))
        queries = self.queries(query.reshape(N, query_len, self.heads, self.head_dim))

        # (N, query_len, heads, head_dim) x (N, key_len, heads, head_dim)
        #   -> (N, heads, query_len, key_len)
        energy = torch.einsum('nqhd,nkhd->nhqk', queries, keys)

        # Each dot product runs over head_dim features, so that is the size to
        # normalise by (sqrt(d_k)); dividing by sqrt(embed_size) over-flattened
        # the softmax as soon as more than one head was used.
        energy = energy / (self.head_dim ** 0.5)

        if mask is not None:
            # Most negative finite value of the current dtype: gives the masked
            # keys zero weight and also works in half precision, where a
            # literal such as -1e28 is not representable.
            energy = energy.masked_fill(mask, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=-1)

        # (N, heads, query_len, key_len) x (N, value_len, heads, head_dim)
        #   -> (N, query_len, heads, head_dim) -> (N, query_len, embed_size)
        out = torch.einsum('nhql,nlhd->nqhd', attention, values)
        out = out.reshape(N, query_len, self.heads * self.head_dim)

        return self.fc_out(out)


class TransformerBlock(nn.Module):
    """Attention sub-layer followed by a feed-forward sub-layer.

    Each sub-layer output is added to its input (residual connection), layer
    normalised and then passed through dropout.
    """

    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super().__init__()
        hidden_size = forward_expansion * embed_size

        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, embed_size),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        """Run the block; `query` is the tensor used for the first residual connection."""
        attended = self.attention(value, key, query, mask)
        x = self.dropout(self.norm1(attended + query))
        return self.dropout(self.norm2(self.feed_forward(x) + x))

class Encoder(nn.Module):
    """Text encoder: token + learned position embeddings, a stack of
    TransformerBlocks, and a linear projection of the first token.

    Args:
        src_vocab_size: Number of rows in the word embedding table.
        embed_size: Embedding / hidden size used throughout the stack.
        num_layers: Number of TransformerBlocks.
        heads: Attention heads per block.
        device: Stored on `self.device` for callers that read it; `forward`
            no longer depends on it and follows the device of its input.
        forward_expansion: Width multiplier of the feed-forward sub-layer.
        dropout: Dropout probability.
        max_length: Maximum sequence length (rows in the position table).
    """

    def __init__(
        self,
        src_vocab_size,
        embed_size,
        num_layers,
        heads,
        device,
        forward_expansion,
        dropout,
        max_length
    ):
        super().__init__()
        self.embed_size = embed_size
        self.device = device

        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.positional_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion
                )
                for _ in range(num_layers)
            ]
        )
        self.final_text_embed = 256
        self.dropout = nn.Dropout(dropout)
        self.mlp = nn.Sequential(
            nn.Linear(embed_size, self.final_text_embed)    # Only using linear because we handle softmax outside the transformer
        )

    def forward(self, x, mask):
        """Encode a batch of token ids.

        Args:
            x: Integer tensor of shape (N, seq_length); seq_length must not
                exceed `max_length`.
            mask: Attention mask forwarded unchanged to every block (or None).

        Returns:
            Tensor of shape (N, 256): the projection of the first token.
        """
        N, seq_length = x.shape

        # Build the position ids directly on the input's device. Moving them to
        # the constructor's `device` string crashed whenever that string did
        # not match where the model and data actually live (e.g. the 'cuda'
        # default on a CPU-only machine).
        positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)

        out = self.word_embedding(x) + self.positional_embedding(positions)     # (N, seq_length, embed_size)
        out = self.dropout(out)

        for layer in self.layers:
            out = layer(out, out, out, mask)    # self-attention: same tensor as value, key and query

        # The first token (CLS) summarises the sentence; project it to the shared space.
        return self.mlp(out[:, 0])

class ViT(nn.Module):
    """Vision transformer image encoder.

    The image is cut into `num_of_patches x num_of_patches` patches; each
    flattened patch is linearly projected to `embed_size`, a learned class
    token is prepended, fixed sinusoidal position embeddings are added, and
    the sequence runs through a stack of TransformerBlocks. The class token of
    the final layer is projected to a 256-dimensional output.

    Args:
        chw: (channels, height, width) of the input images.
        num_of_patches: Number of patches along each spatial axis.
        forward_expansion: Width multiplier of the feed-forward sub-layers.
        embed_size: Token size inside the transformer.
        num_layers: Number of TransformerBlocks.
        heads: Attention heads per block.
        dropout: Dropout probability.
        device: Accepted for backward compatibility and otherwise unused; the
            model runs on whatever device its parameters were moved to.
    """

    def __init__(
        self,
        chw,
        num_of_patches,
        forward_expansion=4,
        embed_size=64,
        num_layers=6,
        heads=2,
        dropout=0,
        device='cuda'
    ):
        super().__init__()

        self.num_of_patches = num_of_patches
        self.chw = chw
        self.patch_size_h = self.chw[1] // num_of_patches
        self.patch_size_w = self.chw[2] // num_of_patches
        self.hidden_size = embed_size

        # Flattened patch (channels * patch_h * patch_w) -> hidden_size
        self.fc1 = nn.Linear(self.chw[0] * self.patch_size_h * self.patch_size_w, self.hidden_size)

        self.class_token = nn.Parameter(torch.rand(1, self.hidden_size))

        # Fixed position table, one row per patch plus one for the class token.
        # Kept as a frozen Parameter so it still follows `.to(...)` and appears
        # in the state dict under the same name. `as_tensor(...).clone()`
        # avoids the copy-construct warning of `torch.tensor(<tensor>)`.
        pos_table = get_positional_embeddings(self.num_of_patches ** 2 + 1, self.hidden_size)
        self.pos_embed = nn.Parameter(torch.as_tensor(pos_table).clone().detach(), requires_grad=False)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion
                )
                for _ in range(num_layers)
            ]
        )
        self.final_out = 256
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.final_out),
        )

    def forward(
        self,
        image_inp
    ):
        """Encode a batch of images of shape (N, C, H, W) into (N, 256) features."""
        img_height, img_width = image_inp.shape[-2], image_inp.shape[-1]
        assert img_width % self.num_of_patches == 0, 'Image width not divisible by number of patches'
        assert img_height % self.num_of_patches == 0, 'Image height not divisible by number of patches'

        # Follow the model's own parameters rather than a notebook-level
        # `device` global, so the module works wherever it has been moved.
        model_device = self.fc1.weight.device

        patches = make_patches(image_inp, self.num_of_patches).to(device=model_device)  # (N, patches, C * patch_h * patch_w)
        tokens = self.fc1(patches)                                                      # (N, patches, hidden_size)

        # Prepend the class token to every sequence: (N, patches + 1, hidden_size)
        class_tokens = self.class_token.unsqueeze(0).expand(tokens.shape[0], -1, -1)
        tokens = torch.cat((class_tokens, tokens), dim=1)

        # The position table broadcasts over the batch dimension.
        out = tokens + self.pos_embed

        # Chain the blocks: each one consumes the previous block's output.
        # Previously every block was fed the original tokens, so all blocks
        # except the last had no effect on the result. No mask is used.
        for layer in self.layers:
            out = layer(out, out, out, None)

        # Use the class token as the image summary and project it.
        return self.mlp(out[:, 0])

In [None]:
class ClipTransformer(nn.Module):
    """Text tower: prepends a CLS token, builds the padding mask and runs the Encoder.

    Args:
        src_vocab_size: Vocabulary size passed to the Encoder.
        src_pad_idx: Token id used for padding; those positions are masked.
        forward_expansion: Width multiplier of the feed-forward sub-layers.
        embed_size, num_layers, heads, dropout, max_length: Passed to the Encoder.
            `max_length` must cover the input length plus the prepended CLS token.
        device: Passed to the Encoder and stored on `self.device`; `forward`
            itself follows the device of the module's parameters.
        cls_idx: Token id prepended to every sequence (defaults to 3, the
            value that was previously hard-coded).
    """

    def __init__(
        self,
        src_vocab_size,
        src_pad_idx,
        forward_expansion,
        embed_size=768,
        num_layers=6,
        heads=8,
        dropout=0,
        device='cuda',
        max_length=76,
        cls_idx=3
    ):
        super().__init__()

        self.encoder = Encoder(
            src_vocab_size,
            embed_size,
            num_layers,
            heads,
            device,
            forward_expansion,
            dropout,
            max_length
        )

        self.src_pad_idx = src_pad_idx
        self.cls_idx = cls_idx
        self.device = device

    def make_src_mask(self, src):
        """Return a (N, 1, 1, seq_len) boolean mask that is True at padding positions."""
        # The mask is created on the same device as `src`, so no transfer is needed.
        return (src == self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def forward(self, src):
        """Encode token ids of shape (N, seq_len) into (N, hidden_dim) text features."""
        # Use the device the parameters actually live on instead of a
        # notebook-level global, which may be undefined or out of date.
        model_device = next(self.parameters()).device
        src = src.to(device=model_device)

        # Prepend the CLS token to every sentence: (N, seq_len + 1)
        cls_column = torch.full((src.shape[0], 1), self.cls_idx, dtype=src.dtype, device=model_device)
        src = torch.cat((cls_column, src), dim=1)

        # True wherever the token is padding, False for real tokens.
        src_mask = self.make_src_mask(src)
        return self.encoder(src, src_mask)

In [None]:
def make_patches(image_inp, num_of_patches):
    """Cut a batch of images into a grid of flattened patches.

    Args:
        image_inp: Tensor of shape (N, C, H, W).
        num_of_patches: Number of patches along each spatial axis.

    Returns:
        Tensor of shape (N, num_of_patches ** 2, C * patch_h * patch_w) with
        patch_h = H // num_of_patches and patch_w = W // num_of_patches.
        Patches are ordered row by row; each patch is flattened channel-first.
        Rows/columns that do not fill a whole patch are dropped. The result
        lives on the same device as `image_inp` (non-float inputs are
        converted to the default float dtype).
    """
    n, c, h, w = image_inp.shape
    patch_size_h = h // num_of_patches
    patch_size_w = w // num_of_patches

    # Drop any remainder so the image splits into an exact grid.
    grid = image_inp[:, :, :patch_size_h * num_of_patches, :patch_size_w * num_of_patches]

    # (N, C, rows, patch_h, cols, patch_w) -> (N, rows, cols, C, patch_h, patch_w)
    # -> (N, rows * cols, C * patch_h * patch_w). One tensor operation replaces
    # the per-image / per-patch Python loops and the CPU-side buffer they
    # filled, so GPU inputs are no longer copied back to the host.
    patches = (
        grid.reshape(n, c, num_of_patches, patch_size_h, num_of_patches, patch_size_w)
        .permute(0, 2, 4, 1, 3, 5)
        .reshape(n, num_of_patches ** 2, c * patch_size_h * patch_size_w)
    )

    if not patches.is_floating_point():
        patches = patches.to(torch.get_default_dtype())
    return patches

def get_positional_embeddings(sequence_length, depth):
    """Build a (sequence_length, depth) table of sinusoidal position embeddings.

    Even columns hold sin(pos / 10000 ** (col / depth)); each odd column holds
    the cosine of the same angle as the even column directly before it.
    """
    table = torch.ones(sequence_length, depth)

    for position in range(sequence_length):
        for column in range(depth):
            is_even = column % 2 == 0
            # Odd columns share the frequency of the preceding even column.
            paired_column = column if is_even else column - 1
            angle = position / (10000 ** (paired_column / depth))
            table[position][column] = np.sin(angle) if is_even else np.cos(angle)
    return table

In [None]:
import os
import gc
# CUDA debugging switches for this process.
# NOTE(review): these are set after `torch` was already imported in an earlier
# cell — confirm that they still take effect at this point.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['TORCH_USE_CUDA_DSA'] = '1'

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(clipVocab.num_words)

# Pass the detected device to the constructors as well: both default to
# 'cuda', and the text model stores that value and uses it when building
# position ids and masks, which fails on a CPU-only machine.
model1 = ViT((3, 224, 224), 8, device=device).to(device=device)   # (chw, num_of_patches)
model2 = ClipTransformer(clipVocab.num_words, 0, 4, device=device).to(device=device)   # (Vocab_size, pad_idx, forward_expansion)
device

In [None]:
# Cross-entropy is applied to the similarity matrix and to its transpose in the
# training loop.
criterion = nn.CrossEntropyLoss()
# A single Adam optimizer (default hyper-parameters) updates both encoders,
# each registered as its own parameter group.
optimizer = Adam([
    {'params': model1.parameters()},
    {'params': model2.parameters()}
])

In [None]:
!pip install GPUtil
from GPUtil import showUtilization as gpu_usage
from numba import cuda

def free_gpu_cache():                   
    """Call `torch.cuda.empty_cache()`; invoked after each optimisation step in the training loop."""
    torch.cuda.empty_cache()

In [None]:
num_epochs = 32
max_seq_length = 75     # caption length in tokens, SOS/EOS included (the text model prepends CLS itself)
temperature = 0.07
T_max = 32


def _fit_to_length(token_ids, length):
    """Return `token_ids` truncated (ending in EOS) or right-padded with PAD to exactly `length`."""
    if len(token_ids) >= length:
        token_ids = token_ids[:length]
        token_ids[-1] = EOS_token
        return token_ids
    return token_ids + [PAD_token] * (length - len(token_ids))


for epoch in range(num_epochs):
    for batch_idx, (img, captions) in enumerate(dataloader):
        sent = [caption.split() for caption in captions]

        try:
            tokens = [
                [SOS_token] + [clipVocab.to_index(word) for word in each_sent] + [EOS_token]
                for each_sent in sent
            ]
        except KeyError:
            # A caption contains a word that is missing from the vocabulary:
            # skip this batch. Only KeyError is caught so that genuine bugs
            # are no longer silently swallowed.
            continue

        # Build the id tensor directly as int64 (no float round trip).
        tokens = torch.tensor(
            [_fit_to_length(token, max_seq_length) for token in tokens],
            dtype=torch.int64,
            device=device,
        )
        img = img.to(device=device)

        I_f = model1(img)       # image features
        T_f = model2(tokens)    # text features

        # NOTE(review): np.exp(0.07) is ~1.07, so this barely rescales the
        # similarities — confirm whether dividing by `temperature` was intended.
        logits = (T_f @ I_f.T) * np.exp(temperature)

        img_t = I_f @ I_f.T
        text_t = T_f @ T_f.T

        targets = F.softmax(
            (img_t + text_t) / 2 * np.exp(temperature), dim=-1
        )
#         targets = torch.eye(logits.shape[0]).to(device=device)
        # `softs` must be computed before it is printed; printing it first
        # raised NameError on the very first batch of a fresh kernel.
        softs = F.softmax(logits, dim=1)
        print('This is what we have \n', softs, 'This is what we\'re going for', targets)

        img_loss = criterion(logits, targets)
        text_loss = criterion(logits.T, targets.T)

        loss = (img_loss + text_loss)/2.0

        optimizer.zero_grad()
        print(f'Loss: {loss}')
        loss.backward()
        optimizer.step()

        free_gpu_cache()
        # NOTE(review): stops after the first successfully processed batch of
        # every epoch — looks like a debugging leftover; remove this `break`
        # to train on the whole dataset.
        break
    print(f'Reached epoch: {epoch+1}/{num_epochs}')

In [None]:
def find_img_matches(inp_sent):
    """Scan the dataset for the images that best match a text query.

    The query is converted to token ids with `clipVocab`, encoded with
    `model2`, and compared with the `model1` features of every image in
    `dataset`. Each time an image beats the best score seen so far, its
    dataset position is printed and the image is shown with its own caption.

    Args:
        inp_sent: Query sentence; every whitespace-separated word must be in
            the vocabulary.

    Returns:
        None. Prints 'Key error!' and returns early if a word is unknown.
    """
    max_seq_length = 75

    dataloader = DataLoader(dataset, batch_size=16)
    try:
        tokens = [1] + [clipVocab.to_index(word) for word in inp_sent.split()] + [2]   # SOS ... EOS
    except KeyError:
        print('Key error!')
        return

    # Truncate (keeping EOS last) or pad with 0 so the sequence has exactly max_seq_length ids.
    if len(tokens) >= max_seq_length:
        tokens = tokens[:max_seq_length]
        tokens[-1] = 2
    else:
        tokens = tokens + [0] * (max_seq_length - len(tokens))

    tokens = torch.tensor(tokens, dtype=torch.int64, device=device).unsqueeze(0)

    # Start below any possible score: similarities can be negative, so starting
    # at 0 could reject every image.
    highest_score = float('-inf')

    # Inference only: no gradients are needed for the text or the image features.
    with torch.no_grad():
        T_f = model2(tokens)    # (1, feature_dim)

        for idx, (img, caption) in enumerate(dataloader):
            I_f = model1(img.to(device=device))     # (batch, feature_dim)

            # One score per image in the batch. `(I_f @ T_f.T)` has shape
            # (batch, 1); indexing row 0 only ever looked at the first image.
            scores = ((I_f @ T_f.T) * np.exp(temperature))[:, 0]
            batch_best = int(torch.argmax(scores))
            batch_score = float(scores[batch_best])

            if batch_score > highest_score:
                highest_score = batch_score
                print(f'Got match at {idx * dataloader.batch_size + batch_best}')

                img_prev = img[batch_best].cpu().numpy().transpose(1, 2, 0)
                plt.figure()
                # Title with the caption of the matched image, not the whole batch's captions.
                plt.title(caption[batch_best])
                plt.imshow(img_prev)
                plt.show()

In [None]:
find_img_matches('Gladiator')