<a href="https://colab.research.google.com/github/G0nkly/pytorch_sandbox/blob/main/vits/vlms/nanoVLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
## Imports

In [None]:
import math, random
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt

In [None]:
## Variables

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMG_SIZE = 32  # default side length (pixels) of the square canvas drawn by draw_sample
EMBED_DIM = 32  # default embedding width of both encoders
ATTENTION_HEADS = 4  # default number of heads in the text encoder's attention layer
BATCH_SIZE = 12  # batch size handed to both DataLoaders
EPOCHS = 10  # number of passes made by the training loop
LR = 3e-4  # learning rate given to Adam
TEMPERATURE = 0.07  # default `temperature` argument of clip_loss

In [None]:
## Synthetic Dataset

In [None]:
# Attribute vocabularies. ShapesDataset enumerates every (color, shape, position)
# combination and uses "<color> <shape> <position>" as the caption.
colors = ["red", "green", "blue", "yellow", "purple", "orange", "pink", "brown", "gray"]
shapes = ["square", "circle", "triangle"]
positions = ["left", "center", "right", "top", "bottom", "top-left", "top-right", "bottom-left", "bottom-right"]

In [None]:
### Drawing image shapes

In [None]:
def draw_sample(color, shape, position, img_size=IMG_SIZE):
  """Draw one filled, black-outlined shape on a white square canvas.

  Args:
    color: fill color passed straight to PIL (e.g. "red").
    shape: "square" or "circle"; any other value draws a triangle.
    position: placement keyword. A value containing "left"/"right" puts the shape in
      the left/right half horizontally; a value containing "top"/"bottom" puts it in
      the upper/lower half vertically. An axis that is not mentioned (e.g. "center",
      or the x-axis of "top") uses the centered middle half of the drawable area.
    img_size: side length of the canvas in pixels.

  Returns:
    The rendered RGB PIL image.
  """
  img = Image.new("RGB", (img_size, img_size), "white")
  draw = ImageDraw.Draw(img)
  margin = 6
  w = h = img_size - 2 * margin

  # Horizontal span. The substring test also matches compound names such as
  # "top-left", so no separate branches are needed for them.
  if "left" in position:
    x0, x1 = margin, margin + w // 2
  elif "right" in position:
    x0, x1 = margin + w // 2, img_size - margin
  else:
    # Centered: middle half of the width, mirroring the vertical fallback below
    # so centered shapes keep the same aspect ratio as the others.
    x0, x1 = margin + w // 4, margin + 3 * w // 4

  # Vertical span, same scheme as above.
  if "top" in position:
    y0, y1 = margin, margin + h // 2
  elif "bottom" in position:
    y0, y1 = margin + h // 2, img_size - margin
  else:
    y0, y1 = margin + h // 4, margin + 3 * h // 4

  if shape == "square":
    draw.rectangle([x0, y0, x1, y1], fill=color, outline="black")
  elif shape == "circle":
    draw.ellipse([x0, y0, x1, y1], fill=color, outline="black")
  else:
    # Triangle: apex at the top-center of the box, base along its bottom edge.
    draw.polygon([((x1+x0)//2, y0), (x0, y1), (x1, y1)], fill=color, outline="black")

  return img

In [None]:
## Class for building the dataset

In [None]:
class ShapesDataset(Dataset):
  """Map-style dataset of synthetic shape images paired with text captions.

  One sample is generated for every (color, shape, position) combination of the
  module-level `colors`, `shapes` and `positions` lists. Images are stored as
  float tensors in channel-first layout scaled to [0, 1]; captions are the
  strings "<color> <shape> <position>".

  Attributes set by __init__:
    images: list of image tensors.
    captions: list of caption strings, aligned with `images`.
    vocab: list of tokens, "[CLS]" first, then the caption words in sorted order.
    word2idx: dict mapping each token in `vocab` to its index.
  """

  def __init__(self):
    super().__init__()
    self.images = []
    self.captions = []

    for c in colors:
      for s in shapes:
        for p in positions:
          img = draw_sample(c, s, p)
          cap = f"{c} {s} {p}"
          # np.array copies the pixels so the tensor is built from writable memory
          # (np.asarray on a PIL image may expose a read-only buffer, which
          # torch.from_numpy warns about).
          pixels = torch.from_numpy(np.array(img))
          # HWC uint8 -> CHW float in [0, 1]
          self.images.append(pixels.permute(2,0,1).float()/255.0)
          self.captions.append(cap)

    self.vocab, self.word2idx = self.build_vocab(self.captions)

  def build_vocab(self, texts):
    """Return (vocab, word2idx) for the whitespace-separated words in `texts`.

    "[CLS]" always gets index 0; the remaining words follow in sorted order.
    """
    words = sorted({w for t in texts for w in t.split()})
    vocab = ["[CLS]"] + words
    w2i = {w:i for i,w in enumerate(vocab)}
    return vocab, w2i

  def __len__(self):
    """Number of (image, caption) samples."""
    return len(self.images)

  def encode_text(self, text):
    """Encode `text` as a 1-D long tensor: the [CLS] id followed by one id per word.

    Raises:
      KeyError: if a word of `text` is not in the vocabulary.
    """
    toks = [self.word2idx["[CLS]"]] + [self.word2idx[w] for w in text.split()]
    return torch.tensor(toks, dtype=torch.long)

  def __getitem__(self,  idx):
    """Return (image tensor, encoded caption tensor, raw caption string) for sample `idx`."""
    return self.images[idx], self.encode_text(self.captions[idx]), self.captions[idx]

In [None]:
## Create Dataset

In [None]:
# Build every sample once; the vocabulary size is needed to size the text encoder's embedding table.
full_ds = ShapesDataset()
VOCAB_SIZE = len(full_ds.vocab)
print(VOCAB_SIZE)
print(full_ds.vocab)

In [None]:
## Train - Val data creation

In [None]:
# Random 80/20 split; the validation set takes whatever remains after rounding down.
train_size = int(0.8 * len(full_ds))
val_size = len(full_ds) - train_size
train_ds, val_ds = torch.utils.data.random_split(full_ds, [train_size, val_size])

In [None]:
## Dataloader

In [None]:
# Only the training data is shuffled; the validation order stays fixed.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
## Display a sample data point

In [None]:
# Pull one training batch and pick a random sample from it.
imgs, encoded_caps, _ = next(iter(train_loader))
idx = random.randint(0, len(imgs) - 1)
img = (imgs[idx].permute(1,2,0).numpy() * 255).astype(np.uint8) # CHW float in [0, 1] -> HWC uint8 for imshow

# Decode the caption: map token ids back to words, drop the [CLS] marker and
# join with spaces so the words do not run together in the title.
caption_tokens = encoded_caps[idx].tolist()
caption_words = [full_ds.vocab[token] for token in caption_tokens]
caption = " ".join(word for word in caption_words if word != "[CLS]")

plt.figure(figsize=(2.5, 2.5))
plt.imshow(img)
plt.title(caption, fontsize=8)
plt.axis("off")
plt.show()

In [None]:
## Image Encoder

In [None]:
class ImageEncoder(nn.Module):
  """Small CNN that turns a batch of 3-channel images into unit-length embeddings.

  Four stride-2 3x3 convolutions (3 -> 32 -> 64 -> 128 -> 256 channels, ReLU between
  them) are followed by a mean over the two spatial axes, a linear projection to
  `embed_dim`, LayerNorm and L2 normalization.
  """

  def __init__(self, embed_dim = EMBED_DIM):
    super().__init__()
    channels = [3, 32, 64, 128, 256]
    layers = []
    for c_in, c_out in zip(channels[:-1], channels[1:]):
      layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
      layers.append(nn.ReLU())
    layers.pop()  # no activation after the final convolution
    self.convolutions = nn.Sequential(*layers)

    self.projection = nn.Linear(channels[-1], embed_dim)
    self.norm = nn.LayerNorm(embed_dim)

  def forward(self, x):
    """Encode image batch `x` and return one L2-normalized vector per image."""
    feature_map = self.convolutions(x)
    pooled = feature_map.mean(dim=(2, 3))  # average over the spatial axes
    embedding = self.norm(self.projection(pooled))
    return F.normalize(embedding, dim=-1)

In [None]:
## Text Encoder

In [None]:
class TextEncoder(nn.Module):
  """Single attention layer that turns token-id sequences into unit-length embeddings.

  Token and learned position embeddings are summed, passed through one
  multi-head self-attention layer, and the output at sequence position 0 is
  projected, layer-normalized and L2-normalized.

  Args:
    embed_dim: width of the embeddings and of the output vector.
    num_heads: number of attention heads.
    vocab_size: number of rows in the token embedding table.
    context_window: number of rows in the position embedding table.
  """

  def __init__(self, embed_dim = EMBED_DIM, num_heads = ATTENTION_HEADS, vocab_size = VOCAB_SIZE, context_window = 4):
    super().__init__()
    self.token_embedding = nn.Embedding(vocab_size, embed_dim)
    self.position_embedding = nn.Embedding(context_window, embed_dim)
    self.mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
    self.projection = nn.Linear(embed_dim, embed_dim)
    self.norm = nn.LayerNorm(embed_dim)

  def forward(self, toks):
    """Encode the 2-D token-id batch `toks` and return one L2-normalized vector per row."""
    batch, seq_len = toks.shape
    pos_ids = torch.arange(seq_len, device=toks.device).unsqueeze(0).expand(batch, seq_len)
    embedded = self.token_embedding(toks) + self.position_embedding(pos_ids)
    attended, _ = self.mha(embedded, embedded, embedded)  # attention weights are not needed
    first_position = attended[:, 0]  # summary is read from sequence position 0
    return F.normalize(self.norm(self.projection(first_position)), dim=-1)

In [None]:
## CLIP loss

In [None]:
def clip_loss(img_emb, txt_emb, temperature = TEMPERATURE):
  """Symmetric contrastive loss between paired image and text embeddings.

  Row i of `img_emb` is treated as the match for row i of `txt_emb`; every other
  row in the batch acts as a negative.

  Args:
    img_emb: 2-D tensor of image embeddings, one row per sample.
    txt_emb: 2-D tensor of text embeddings with the same number of rows.
    temperature: divisor applied to the similarity matrix before the softmax;
      smaller values sharpen the distribution.

  Returns:
    Scalar tensor: the mean of the image-to-text and text-to-image cross-entropies.
  """
  # Scale the similarities by the temperature; without this the parameter was ignored.
  logits = (img_emb @ txt_emb.t()) / temperature
  targets = torch.arange(img_emb.size(0), device = img_emb.device)
  loss_i = F.cross_entropy(logits, targets)      # each image picks its caption
  loss_t = F.cross_entropy(logits.t(), targets)  # each caption picks its image
  return ((loss_i + loss_t) / 2.0)

In [None]:
## Model, data, optimizer

In [None]:
# Size the text encoder from the dataset that is actually in use. The value is
# passed explicitly because TextEncoder's default `vocab_size` was captured when
# the class was defined and does not follow later changes to VOCAB_SIZE.
VOCAB_SIZE = len(full_ds.vocab)
img_enc = ImageEncoder().to(device)
txt_enc = TextEncoder(vocab_size=VOCAB_SIZE).to(device)
# A single optimizer updates both encoders jointly.
params = list(img_enc.parameters()) + list(txt_enc.parameters())
optimizer = torch.optim.Adam(params, lr=LR)

In [None]:
## Embeddings before training

In [None]:
def show_image(t, title=None):
  """Show a channel-first image tensor with values in [0, 1] as a small matplotlib figure.

  Args:
    t: CPU image tensor; its axes are reordered from (C, H, W) to (H, W, C) for display.
    title: optional text drawn above the image.
  """
  pixels = t.permute(1, 2, 0).numpy() * 255
  plt.figure(figsize=(2.2, 2.2))
  plt.axis("off")
  if title:
    plt.title(title, fontsize=8)
  plt.imshow(pixels.astype(np.uint8))
  plt.show()

In [None]:
# Snapshot the embeddings of one random sample with the untrained encoders.
img_enc.eval(); txt_enc.eval()

with torch.no_grad():
  # Select a random index; it is kept in `random_idx` so the same sample can be
  # looked up again after training.
  random_idx = random.randrange(len(full_ds))
  sample_img, sample_toks, sample_cap = full_ds[random_idx]
  # Add a batch axis of size 1 before feeding the encoders.
  sample_img = sample_img.unsqueeze(0).to(device)
  sample_toks = sample_toks.unsqueeze(0).to(device)
  pre_train_img_emb = img_enc(sample_img).squeeze(0).cpu().numpy()
  pre_train_txt_emb = txt_enc(sample_toks).squeeze(0).cpu().numpy()

# Display the selected image and its caption
print(f"Sample image and caption for embeddings visualization: '{sample_cap}'")
show_image(sample_img.squeeze(0).cpu())

def plot_embedding(embedding, title):
  """Render an embedding vector as a one-row heat map.

  Args:
    embedding: array that is reshaped to a single row for display.
    title: text drawn above the heat map.
  """
  plt.figure(figsize=(8, 1))
  as_row = embedding.reshape(1, -1)
  plt.imshow(as_row, aspect="auto", cmap="viridis")
  plt.title(title)
  plt.axis("off")
  plt.show()

# Heat maps of the untrained image and text embeddings of the sample.
plot_embedding(pre_train_img_emb, "Pre-Training Image Embedding")
plot_embedding(pre_train_txt_emb, "Pre-Training Text Embedding")

In [None]:
## Training loop

In [None]:
# Train both encoders jointly with the contrastive loss and report the
# per-sample average loss on the training and validation sets after each epoch.
best_val = float("inf")

for epoch in range(1, EPOCHS + 1):
  img_enc.train(); txt_enc.train()
  total, n_train = 0.0, 0

  for imgs, toks, _ in train_loader:
    imgs = imgs.to(device); toks = toks.to(device)
    optimizer.zero_grad(set_to_none=True)
    ie = img_enc(imgs); te = txt_enc(toks)
    loss = clip_loss(ie, te)
    loss.backward()
    optimizer.step()
    # Weight the batch-mean loss by the batch size so partial batches count correctly.
    total += loss.item() * imgs.size(0)
    n_train += imgs.size(0)
  # Divide by the number of samples actually seen; the last batch may be smaller
  # than BATCH_SIZE, so len(train_loader) * BATCH_SIZE would over-count.
  train_loss = total / n_train

  # quick val
  img_enc.eval(); txt_enc.eval()
  with torch.no_grad():
    vtotal, n = 0.0, 0
    for imgs, toks, _ in val_loader:
      imgs = imgs.to(device); toks = toks.to(device)
      vtotal += clip_loss(img_enc(imgs), txt_enc(toks)).item()*imgs.size(0)
      n += imgs.size(0)
    val_loss = vtotal / n

  print(f"Epoch {epoch:02d} | train {train_loss:.4f} | val {val_loss:.4f}")
  best_val = min(best_val, val_loss)  # lowest validation loss seen so far

In [None]:
## Embeddings after training

In [None]:
# Recompute the embeddings of the previously selected sample with the trained encoders.
img_enc.eval(); txt_enc.eval()
with torch.no_grad():
  # use the same random index as before training
  sample_img, sample_toks, sample_cap = full_ds[random_idx]
  # Add a batch axis of size 1 before feeding the encoders.
  sample_img = sample_img.unsqueeze(0).to(device)
  sample_toks = sample_toks.unsqueeze(0).to(device)

  post_train_img_emb = img_enc(sample_img).squeeze(0).cpu().numpy()
  post_train_txt_emb = txt_enc(sample_toks).squeeze(0).cpu().numpy()

# Display the same image and caption as before training
print(f"Sample image and caption for embeddings visualization: '{sample_cap}'")
show_image(sample_img.squeeze(0).cpu())

plot_embedding(post_train_img_emb, "Post-Training Image Embedding")
plot_embedding(post_train_txt_emb, "Post-Training Text Embedding")

In [None]:
# Both encoders L2-normalize their output, so these dot products are cosine similarities.
print("\nDot product between image and text embeddings:")
print(f" Before training: {np.dot(pre_train_img_emb, pre_train_txt_emb):.4f}")
print(f" After training: {np.dot(post_train_img_emb, post_train_txt_emb):.4f}")

In [None]:
## Build text bank for retrieval on val set

In [None]:
# Embed the whole validation set once; the retrieval helpers read val_imgs,
# val_caps, img_emb and txt_emb as globals.
img_enc.eval(); txt_enc.eval()
with torch.no_grad():
  val_imgs, val_toks, val_caps = [], [], []
  for imgs, toks, caps in val_loader:
    val_imgs.append(imgs); val_toks.append(toks); val_caps += list(caps)
  # Concatenate the per-batch tensors so row i of each tensor lines up with val_caps[i].
  val_imgs = torch.cat(val_imgs).to(device)
  val_toks = torch.cat(val_toks).to(device)
  img_emb = img_enc(val_imgs)
  txt_emb = txt_enc(val_toks)

In [None]:
## Retriever helper functions

In [None]:
def topk_text_for_images(k=3, idxs=None):
  """Print the k best-scoring validation captions for each selected image, then show the image.

  Args:
    k: number of captions to list per image.
    idxs: iterable of validation indices; when None, one index is drawn at random.
  """
  if idxs is None:
    idxs = np.random.choice(len(val_caps), size=1, replace=False)
  scores = (img_emb @ txt_emb.t()).softmax(dim=1) # similarity as softmax over captions
  for img_idx in idxs:
    top_caption_ids = scores[img_idx].topk(k).indices.tolist()
    print(f"\nImage {img_idx} best captions:")
    for cap_idx in top_caption_ids:
      print("  -", val_caps[cap_idx])
    show_image(val_imgs[img_idx].cpu())

def topk_images_for_text(k=3, idxs=None):
  """Show the k best-scoring validation images for each selected caption.

  Args:
    k: number of images to show per caption.
    idxs: iterable of validation indices; when None, one index is drawn at random.
  """
  if idxs is None:
    idxs = np.random.choice(len(val_caps), size=1, replace=False)
  scores = (txt_emb @ img_emb.t()).softmax(dim=1) # similarity as softmax over images
  for txt_idx in idxs:
    top_image_ids = scores[txt_idx].topk(k).indices.tolist()
    print(f"\nText '{val_caps[txt_idx]}' best images:")
    for img_idx in top_image_ids:
      show_image(val_imgs[img_idx].cpu(), title=f"match {val_caps[img_idx]}")

In [None]:
# Retrieval demo: best caption for one random image, then best image for one random caption.
topk_text_for_images(k=1)
topk_images_for_text(k=1)