<a href="https://colab.research.google.com/github/Aiden-Ross-Dsouza/Natural-Language-Processing-IvLabs/blob/master/image_caption_generation/notebooks/Image_Captioning_using_transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Importing Dataset from Kaggle into Colab's Session Storage

In [None]:
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/

cp: cannot stat 'kaggle.json': No such file or directory


In [None]:
# Fetch the adityajn105/flickr8k dataset archive with the Kaggle CLI.
!kaggle datasets download -d adityajn105/flickr8k

Dataset URL: https://www.kaggle.com/datasets/adityajn105/flickr8k
License(s): CC0-1.0
Downloading flickr8k.zip to /content
100% 1.04G/1.04G [00:18<00:00, 55.4MB/s]
100% 1.04G/1.04G [00:18<00:00, 60.8MB/s]


# Extracting dataset from zipped folder

In [None]:
import zipfile

# The context manager closes the archive even if extraction fails part-way.
with zipfile.ZipFile('/content/flickr8k.zip', 'r') as zip_ref:
  zip_ref.extractall('/content')

# Importing Required Libraries

In [None]:
import os
import math
import torch
import random
import torchvision
import numpy as np
import pandas as pd
import torch.nn as nn
from PIL import Image
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from torchvision import models, transforms
from torch.utils.data import Dataset , DataLoader

In [None]:
# Prefer the first CUDA GPU when one is visible, otherwise fall back to the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)
print(f"Using {device} device")

Using cpu device


# Preprocess Images

In [None]:
# Resize and crop on the PIL image first, then convert to a float tensor and
# apply the ImageNet channel statistics that torchvision's pretrained ResNet
# weights were trained with.
preprocess = transforms.Compose([
                                 transforms.Resize(256),
                                 transforms.CenterCrop(224),
                                 transforms.ToTensor(),
                                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                      std=[0.229, 0.224, 0.225])
                                 ])

# Dataset

In [None]:
class Flickr(Dataset):
  """Image/caption dataset that keeps one caption per image.

  The captions CSV must have an ``image`` and a ``caption`` column. Only every
  ``captions_per_image``-th row is used (rows 0, k, 2k, ...).
  NOTE(review): this assumes the file lists ``captions_per_image`` consecutive
  rows for each image, as the original stride of 5 implies - confirm for other
  caption files.

  Args:
    img_path: directory that contains the image files.
    txt_path: path of the captions CSV.
    transform: optional callable applied to each PIL image.
    captions_per_image: row stride used to pick one caption per image.

  Attributes:
    img_lst: pandas Series of the selected image file names.
    cap_lst: tensor of token ids, one row per image, padded with the
      ``<PAD>`` id up to ``sent_len``.
    sent_len: length of the longest tokenised caption, including the
      ``<SOS>`` and ``<EOS>`` markers.
    vocab, vocab_size, word_to_index, index_to_word: vocabulary and lookups.
  """

  def __init__(self, img_path, txt_path, transform = None, captions_per_image = 5):
    self.img_path = img_path
    self.txt_path = txt_path
    self.transform = transform
    self.text = pd.read_csv(txt_path)
    self.imag_lst = self.text['image']
    self.capt_lst = self.text['caption']

    # Derive the row range from the file instead of hard-coding its length.
    first_rows = range(0, len(self.text), captions_per_image)
    self.img_lst = pd.Series([self.imag_lst[i] for i in first_rows])

    # Wrap every caption in start/end markers and split it on whitespace.
    tokenised = [(' <SOS> ' + self.capt_lst[i] + ' <EOS> ').split() for i in first_rows]
    self.sent_len = max((len(tokens) for tokens in tokenised), default=0)

    self.vocab = {w for tokens in tokenised for w in tokens}
    self.vocab.add('<PAD>')
    self.vocab_size = len(self.vocab)
    # Sort before numbering: plain set iteration order changes between
    # interpreter runs, which would give a saved model different token ids
    # every time the notebook is restarted.
    ordered_vocab = sorted(self.vocab)
    self.word_to_index = {w: idx for (idx, w) in enumerate(ordered_vocab)}
    self.index_to_word = {idx: w for (idx, w) in enumerate(ordered_vocab)}

    # Encode the tokens and right-pad every caption to the common length.
    pad_id = self.word_to_index['<PAD>']
    self.cap_lst = torch.tensor([
        [self.word_to_index[w] for w in tokens] + [pad_id] * (self.sent_len - len(tokens))
        for tokens in tokenised
    ])

  def __len__(self):
    """Return the number of images (one caption each)."""
    return len(self.img_lst)

  def __getitem__(self, index):
    """Return ``(image, caption_ids)`` for the given position.

    The image is opened from ``img_path``, converted to RGB and passed through
    ``transform`` when one was given.
    """
    imag_path = os.path.join(self.img_path, self.img_lst[index])
    img = Image.open(imag_path).convert("RGB")
    cap = self.cap_lst[index]
    if self.transform:
      img = self.transform(img)
    return (img,cap)

In [None]:
# Keyword arguments forwarded to the DataLoader.
params = dict(
    batch_size=80,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)

# Data Loader

In [None]:
# Build the caption dataset from the extracted archive and wrap it in a loader.
dataset = Flickr(
    img_path='/content/Images',
    txt_path='/content/captions.txt',
    transform=preprocess,
)
data = DataLoader(dataset, **params)

# Base Model

In [None]:
# ResNet-18 with torchvision's default pretrained weights.
baseModel = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 107MB/s]


In [None]:
# Keep only the first seven child modules of the network as the feature extractor.
backbone_layers = list(baseModel.children())[:7]
baseModel = nn.Sequential(*backbone_layers)

In [None]:
# Show the truncated backbone to check which layers were kept.
print(baseModel)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


# Positional Encoding

In [None]:
class PositionalEncoding(nn.Module):
  """Adds fixed sinusoidal position information to token embeddings.

  For position ``p`` and pair index ``i`` the table stores
  ``sin(p / 10000^(2i / dim_model))`` in column ``2i`` and the matching cosine
  in column ``2i + 1``. The table is registered as a buffer, so it follows the
  module across devices but is not a trainable parameter.

  Args:
    dim_model: embedding width (even or odd).
    max_len: number of positions to precompute.
  """

  def __init__(self, dim_model, max_len):
    super(PositionalEncoding, self).__init__()

    pos_encoding = torch.zeros((max_len, dim_model))
    pos_list = torch.arange(0, max_len, dtype=torch.float).view(-1, 1)
    div_term = torch.pow(10000, torch.arange(0, dim_model, 2, dtype=torch.float) / dim_model)
    angles = pos_list / div_term

    pos_encoding[:, 0::2] = torch.sin(angles)
    # An odd dim_model has one cosine column fewer than sine columns.
    pos_encoding[:, 1::2] = torch.cos(angles[:, :dim_model // 2])

    pos_encoding = torch.unsqueeze(pos_encoding, 0)
    self.register_buffer("pos_encoding", pos_encoding)

  def forward(self, tok_embedding: torch.Tensor) -> torch.Tensor:
    """Return ``tok_embedding`` plus the encodings of its first positions.

    Dimension 1 of ``tok_embedding`` is taken as the sequence length.

    Raises:
      ValueError: if the sequence is longer than the precomputed table.
    """
    seq_len = tok_embedding.size(1)
    max_len = self.pos_encoding.size(1)
    if seq_len > max_len:
      raise ValueError(f"sequence length {seq_len} exceeds max_len {max_len}")
    return tok_embedding + self.pos_encoding[:, :seq_len, :]

# Define Encoder

In [None]:
class EncoderCNN(nn.Module):
  """Turns a CNN feature map into a sequence of embedding vectors.

  Args:
    embedding_dim: width of each output vector.
    enc_cnn: module mapping an image batch to a feature map whose dimension 1
      is the channel axis.
    feature_dim: number of channels produced by ``enc_cnn``. The default of
      256 keeps the previous behaviour; pass another value when the backbone
      is cut at a different depth.
  """

  def __init__(self, embedding_dim, enc_cnn, feature_dim=256):
    super(EncoderCNN, self).__init__()
    self.encoder_cnn = enc_cnn
    self.fc = nn.Linear(feature_dim, embedding_dim)

  def forward(self, images):
    """Return features shaped (batch, positions, embedding_dim)."""
    features = self.encoder_cnn(images)
    # Collapse all dimensions after the channel axis into one "position" axis,
    # then move channels last so the linear layer acts on each position.
    features = torch.flatten(features, start_dim=2)
    features = features.permute(0, 2, 1)
    features = F.relu(self.fc(features))
    return features

# Define Decoder

In [None]:
class DecoderTrn(nn.Module):
  """Transformer decoder that predicts caption tokens from image features.

  Args:
    embedding_dim: width of token embeddings and of the decoder layers.
    vocab_size: number of distinct tokens.
    n_head: attention heads per decoder layer.
    numlayers: number of stacked decoder layers.
    max_len: longest target sequence the positional encoding supports.
  """

  def __init__(self,  embedding_dim, vocab_size, n_head, numlayers, max_len=50):
    super(DecoderTrn, self).__init__()
    self.embedding = nn.Embedding(vocab_size, embedding_dim)
    # The positional table must match the embedding width; it used to be fixed
    # at 256, which broke any other embedding_dim.
    self.positional_encoder = PositionalEncoding(dim_model=embedding_dim, max_len=max_len)
    self.decoder_layer = nn.TransformerDecoderLayer(d_model=embedding_dim, nhead=n_head,
                                                    dim_feedforward=1024, batch_first=True)
    self.transformer_decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=numlayers)
    self.linear = nn.Linear(embedding_dim, vocab_size)

  def forward(self, trg, mem, pad_idx):
    """Return log-probabilities over the vocabulary for every target position.

    Args:
      trg: token ids, indexed (batch, sequence).
      mem: encoder output attended to by the decoder.
      pad_idx: token id treated as padding in ``trg``.
    """
    tgt = self.embedding(trg)
    tgt = self.positional_encoder(tgt)
    # Causal mask: a position may only attend to itself and earlier positions.
    causal_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1), device=tgt.device)
    out = self.transformer_decoder(tgt, mem,
                                   tgt_mask = causal_mask,
                                   tgt_key_padding_mask = (trg == pad_idx))
    out = self.linear(out)
    return F.log_softmax(out, dim=-1)

# Initialize encoder and decoder

In [None]:
# Encoder and decoder share the same embedding width (256).
encoder = EncoderCNN(embedding_dim=256, enc_cnn=baseModel)
decoder = DecoderTrn(
    embedding_dim=256,
    vocab_size=dataset.vocab_size,
    n_head=8,
    numlayers=4,
)

In [None]:
# Move both networks to the selected device.
encoder, decoder = encoder.to(device), decoder.to(device)

In [None]:
# Freeze the CNN backbone so the optimizer never updates its weights.
for param in encoder.encoder_cnn.parameters():
  param.requires_grad_(False)

In [None]:
# Separate learning rates: the encoder's projection layer learns more slowly
# than the decoder.
param_groups = [
    {'params': encoder.fc.parameters(), 'lr': 1e-5},
    {'params': decoder.parameters(), 'lr': 5e-5},
]
optimizer = torch.optim.Adam(param_groups)

# Padding positions are excluded from the loss.
pad_index = dataset.word_to_index["<PAD>"]
criterion = nn.NLLLoss(ignore_index=pad_index)

# Train

In [None]:
def train_loop(enc, dec, opt, device, loss_fn, dataloader):
  """Run one training epoch with teacher forcing.

  Each batch provides images at index 0 and caption token ids at index 1. The
  decoder is fed every caption token except the last and is trained to predict
  the caption shifted left by one position.

  Side effect: updates the module-level ``curr_batch`` and ``total_loss``.

  Returns:
    (epoch_loss, epoch_perplexity): the mean batch loss and its exponential.
  """
  global curr_batch, total_loss
  enc.train()
  dec.train()
  curr_batch = 0
  total_loss = 0
  pad_ind = loss_fn.ignore_index  # decoder masks the same id the loss ignores

  for batch in dataloader:
    images = batch[0].to(device)
    captions = batch[1].to(device)

    memory = enc(images)

    decoder_input = captions[:, :-1]
    target = captions[:, 1:]

    # The loss expects the class dimension second: (batch, vocab, sequence).
    log_probs = dec(decoder_input, memory, pad_ind).permute(0, 2, 1)
    loss = loss_fn(log_probs, target)

    opt.zero_grad()
    loss.backward()
    opt.step()

    total_loss = total_loss + loss.item()
    curr_batch = curr_batch + 1

  epoch_loss = total_loss / len(dataloader)
  return epoch_loss, math.exp(epoch_loss)

In [None]:
def fit(enc, dec, opt, epochs, device, loss_fn, dataloader):
  """Train for ``epochs`` epochs and collect the per-epoch metrics.

  Side effect: updates the module-level ``curr_epoch`` and ``train_loss`` and
  prints one progress line per epoch.

  Returns:
    (train_loss_list, train_perplexity_list): one entry per epoch.
  """
  global curr_epoch, train_loss
  curr_epoch, train_loss = 0, 0
  train_loss_list, train_perplexity_list = [], []

  print("Training Transformer Model")
  for epoch in range(epochs):
    train_loss, train_perplexity = train_loop(enc, dec, opt, device, loss_fn, dataloader)
    train_loss_list.append(train_loss)
    train_perplexity_list.append(train_perplexity)
    curr_epoch = epoch + 1
    print(f"Epoch: {epoch+1}, Training loss: {train_loss:.4f}, Training perplexity: {train_perplexity:.4f}")

  return train_loss_list, train_perplexity_list

In [None]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
# Train for 25 epochs; keep the per-epoch loss and perplexity for plotting.
loss_func, perplexity = fit(
    enc=encoder,
    dec=decoder,
    opt=optimizer,
    epochs=25,
    device=device,
    loss_fn=criterion,
    dataloader=data,
)

Training Transformer Model


RuntimeError: mat1 and mat2 shapes cannot be multiplied (15680x1024 and 2048x256)

# Train Loss and Perplexity

In [None]:
# Per-epoch training loss, drawn as red dotted line with circle markers.
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(loss_func, 'o:r')
loss_ax.set_title("Loss")
plt.show()

In [None]:
# Per-epoch training perplexity, drawn as red dotted line with circle markers.
ppl_fig, ppl_ax = plt.subplots()
ppl_ax.plot(perplexity, 'o:r')
ppl_ax.set_title("Perplexity")
plt.show()

# Test

In [None]:
def predict(enc, dec, img, max_len=dataset.sent_len, SOS_token=dataset.word_to_index["<SOS>"], EOS_token=dataset.word_to_index["<EOS>"]):
  """Greedily decode a caption for a single image.

  Starting from ``SOS_token``, the most likely next token is appended at each
  step until ``EOS_token`` is produced or ``max_len`` tokens were generated.

  Args:
    enc: image encoder.
    dec: caption decoder, called as ``dec(tokens, memory, pad_idx)``.
    img: preprocessed image batch holding exactly one image.
    max_len: maximum number of tokens to generate after the start token.
    SOS_token: id of the start-of-sentence token.
    EOS_token: id of the end-of-sentence token.

  Returns:
    The generated tokens as a list of words, including the start marker and,
    when it was reached, the end marker.
  """
  enc.eval(), dec.eval()
  pad_idx = dataset.word_to_index["<PAD>"]
  y_input = torch.tensor([[SOS_token]], dtype=torch.long, device=device)

  with torch.no_grad():
    # The image features do not depend on the generated tokens, so the
    # encoder runs once instead of once per decoding step.
    X = enc(img)

    for _ in range(max_len):
      pred = dec(y_input, X, pad_idx)

      # Most likely token at the last position, kept as a (1, 1) tensor.
      next_item = pred[:, -1, :].argmax(dim=-1, keepdim=True)
      y_input = torch.cat((y_input, next_item), dim=1)

      if next_item.item() == EOS_token:
        break

  token_ids = y_input.view(-1).tolist()
  sentence = [dataset.index_to_word[token_id] for token_id in token_ids]
  return sentence

In [None]:
# Pick a random image from the dataset; the upper bound follows the dataset
# size instead of a hard-coded count.
n = random.randint(0, len(dataset) - 1)
path = os.path.join('/content/Images', dataset.img_lst[n])
imag = Image.open(path).convert("RGB")

# Add the batch dimension the encoder expects and generate the caption.
img = preprocess(imag)
img = img.to(device).unsqueeze(0)
caption = predict(encoder, decoder, img)

# Show the original (unprocessed) image with the generated tokens as its title.
plt.imshow(imag)
plt.title("_".join(caption))
plt.axis('off')
plt.show()