In [None]:
#imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms.functional as TF
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
from tqdm import tqdm
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
import math
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
import spacy
from torch.nn.utils.rnn import pad_sequence

In [None]:
# Mount Google Drive (Colab only) so the dataset files under /content/drive are reachable.
from google.colab import drive
drive.mount(mountpoint="/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# English spaCy pipeline; only its tokenizer is used (by Vocabulary.tokenizer_eng).
spacy_eng = spacy.load(name="en_core_web_sm")

In [None]:
#vocabulary Class
class Vocabulary:
  """Bidirectional word <-> index mapping built from caption text.

  Indices 0-3 are reserved for the special tokens <PAD>, <SOS>, <EOS> and
  <UNK>. Every other word gets the next free index once it has been seen
  ``freq_threshold`` times in a single ``build_vocabulary`` call.
  """

  def __init__(self, freq_threshold):
    self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
    self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
    self.freq_threshold = freq_threshold

  def __len__(self):
    return len(self.itos)

  @staticmethod  # takes no `self`, so it can be called on the class or an instance
  def tokenizer_eng(text):
    """Tokenize ``text`` with the module-level spaCy tokenizer and lowercase each token."""
    return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

  def build_vocabulary(self, sentence_list):
    """Add every word seen at least ``freq_threshold`` times in ``sentence_list``.

    New words are appended after the current highest index, so calling this
    more than once extends the vocabulary instead of overwriting earlier
    entries. Word frequencies are counted per call.
    """
    frequencies = {}
    # Next free index: 4 on a fresh vocabulary (0-3 are the special tokens).
    idx = len(self.itos)

    for sentence in sentence_list:
      for word in self.tokenizer_eng(sentence):
        frequencies[word] = frequencies.get(word, 0) + 1

        # Register a word once it reaches the threshold, unless it is
        # already in the vocabulary (e.g. from a previous call).
        if frequencies[word] >= self.freq_threshold and word not in self.stoi:
          self.stoi[word] = idx
          self.itos[idx] = word
          idx += 1

  def numericalize(self, text):
    """Convert ``text`` to a list of token indices; unknown words map to <UNK>."""
    unk_idx = self.stoi["<UNK>"]
    return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]

In [None]:
#Dataset Class
class FlickrDataset(Dataset):
  """Dataset yielding ``(image, caption_indices)`` pairs from a captions CSV.

  ``captions_file`` is read with ``pd.read_csv`` and must contain an
  ``image`` column (file name relative to ``root_dir``) and a ``caption``
  column. A ``Vocabulary`` is built from all captions at construction time.
  """

  def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
    self.root_dir = root_dir
    self.transform = transform
    self.df = pd.read_csv(captions_file)

    table = self.df
    self.imgs = table["image"]
    self.captions = table["caption"]

    self.vocab = Vocabulary(freq_threshold)
    self.vocab.build_vocabulary(self.captions.tolist())

  def __len__(self):
    return self.df.shape[0]

  def __getitem__(self, index):
    """Return the RGB image (transformed if a transform is set) and its caption as
    a 1-D tensor of indices wrapped in <SOS> ... <EOS>."""
    image_path = os.path.join(self.root_dir, self.imgs[index])
    image = Image.open(image_path).convert("RGB")
    if self.transform is not None:
      image = self.transform(image)

    lookup = self.vocab.stoi
    token_ids = [lookup["<SOS>"], *self.vocab.numericalize(self.captions[index]), lookup["<EOS>"]]
    return image, torch.tensor(token_ids)

In [None]:
# Collate: stack the images and pad the variable-length captions
class MyCollate:
  """Collate ``(image, caption)`` pairs into one padded batch.

  Images are concatenated along a new leading batch dimension. Captions are
  right-padded with ``pad_idx`` to the longest caption in the batch and
  returned time-major, shaped ``(max_len, batch)`` (``batch_first=False``).
  """

  def __init__(self, pad_idx):
    self.pad_idx = pad_idx

  def __call__(self, batch):
    image_batch = torch.cat([sample[0].unsqueeze(0) for sample in batch], dim=0)
    caption_list = [sample[1] for sample in batch]
    padded_captions = pad_sequence(caption_list, batch_first=False, padding_value=self.pad_idx)
    return image_batch, padded_captions

In [None]:
#Loader
def get_loader(
    root_folder,
    annotation_file,
    transform,
    batch_size=8,
    num_workers=8,
    shuffle=True,
    pin_memory=True,
):
  """Build a ``FlickrDataset`` and a ``DataLoader`` over it.

  Captions in each batch are padded with the vocabulary's ``<PAD>`` index
  by ``MyCollate``.

  Returns:
    ``(loader, dataset)``. The dataset is returned as well so callers can
    reach its vocabulary.
  """
  dataset = FlickrDataset(root_folder, annotation_file, transform=transform)
  collate = MyCollate(pad_idx=dataset.vocab.stoi["<PAD>"])
  loader_options = dict(
      batch_size=batch_size,
      num_workers=num_workers,
      shuffle=shuffle,
      pin_memory=pin_memory,
      collate_fn=collate,
  )
  return DataLoader(dataset, **loader_options), dataset

In [None]:
#Encoder class. Uses a pretrained Inception v3 backbone
class EncoderCNN(nn.Module):
  """Image encoder: Inception v3 with its classifier replaced by an embedding layer.

  Args:
    embed_size: size of the feature vector produced per image.
    train_CNN: if False (default) only the new ``fc`` head is trainable and
      the pretrained backbone is frozen.
  """

  def __init__(self, embed_size, train_CNN= False):
    super(EncoderCNN, self).__init__()
    self.train_CNN = train_CNN
    self.inception = models.inception_v3(pretrained = True, aux_logits = True)
    # Disable the auxiliary classifier output.
    self.inception.aux_logits = False
    self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(0.2)
    # Freeze at construction time. Otherwise the first backward pass (which
    # runs before forward() ever touched requires_grad) computes gradients for,
    # and the first optimizer step updates, the whole pretrained backbone.
    self._set_backbone_trainable()

  def _set_backbone_trainable(self):
    """Make the ``fc`` head trainable, and the rest of Inception only if ``train_CNN``."""
    for name, param in self.inception.named_parameters():
      # startswith avoids also matching e.g. "AuxLogits.fc.weight".
      param.requires_grad = name.startswith("fc.") or bool(self.train_CNN)

  def forward(self, images):
    # Applied before the backbone runs, not after, so that the graph built
    # below already respects the current value of self.train_CNN.
    self._set_backbone_trainable()
    features = self.inception(images)
    return self.dropout(self.relu(features))


In [None]:
#Decoder model with attention
class DecoderRNNAttn(nn.Module):
    """LSTM caption decoder with a per-step attention-style gate.

    Works on time-major tensors, matching ``MyCollate`` (``batch_first=False``):
    ``captions`` is ``(seq_len, batch)`` and the output is
    ``(seq_len, batch, vocab_size)``.

    The image feature vector is the first LSTM input, and the caption tokens
    follow shifted right by one step. Output position ``t`` therefore predicts
    ``captions[t]`` from the image and ``captions[:t]`` (teacher forcing), and
    the output stays aligned with ``captions`` for the loss.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNNAttn, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        # Time-major (no batch_first) to match the padded captions from MyCollate.
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.attention = nn.Linear(hidden_size, hidden_size)
        self.attention_combine = nn.Linear(hidden_size * 2, hidden_size)
        # Normalise over the hidden units of each step. Normalising over the
        # time axis would leak future tokens into earlier predictions, and
        # normalising over the batch axis would mix different samples.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, features, captions):
        """Compute next-token logits with teacher forcing.

        Args:
            features: image features, ``(batch, embed_size)``.
            captions: token indices, ``(seq_len, batch)``.

        Returns:
            Logits shaped ``(seq_len, batch, vocab_size)``.
        """
        # The last token is only ever a target, never an input.
        embeddings = self.embed(captions[:-1])
        inputs = torch.cat((features.unsqueeze(0), embeddings), dim=0)
        lstm_out, _ = self.lstm(inputs)
        attention_weights = self.softmax(self.attention(lstm_out))
        attention_applied = attention_weights * lstm_out
        attention_combined = torch.cat((lstm_out, attention_applied), dim=2)
        output = self.attention_combine(attention_combined)
        return self.fc(output)

In [None]:
#CNN to RNN connect
class CNNtoRNN(nn.Module):
  """Captioning model: EncoderCNN image features decoded by DecoderRNNAttn."""

  def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
    super().__init__()
    self.encoderCNN = EncoderCNN(embed_size)
    self.decoderRNN = DecoderRNNAttn(embed_size, hidden_size, vocab_size, num_layers)

  def forward(self, images, captions):
    """Encode ``images`` and return the decoder's output for ``captions``."""
    return self.decoderRNN(self.encoderCNN(images), captions)

def caption_image(self, image, vocabulary, max_len=30):
  """Greedily decode a caption for one image.

  Written as a free function that takes the captioning model as ``self``;
  call it as ``caption_image(model, image, vocabulary)``.

  Args:
    self: model exposing ``encoderCNN`` and ``decoderRNN`` (as CNNtoRNN does).
    image: one preprocessed image as a batch of size 1 — presumably
      ``(1, 3, H, W)``; TODO confirm. Only batch size 1 works because tokens
      are read back with ``.item()``.
    vocabulary: object with an ``itos`` index -> word mapping.
    max_len: maximum number of tokens to generate (word limit, default 30).

  Returns:
    The generated words, ending with "<EOS>" if it was produced.
  """
  result_caption = []
  decoder = self.decoderRNN
  was_training = self.training
  self.eval()  # no dropout, BatchNorm uses running stats while decoding
  try:
    with torch.no_grad():
      # The image features are fed as the first LSTM input.
      x = self.encoderCNN(image).unsqueeze(0)
      states = None

      for _ in range(max_len):
        hiddens, states = decoder.lstm(x, states)
        # Same attention/combination path as DecoderRNNAttn.forward.
        weights = decoder.softmax(decoder.attention(hiddens))
        combined = torch.cat((hiddens, weights * hiddens), dim=2)
        output = decoder.fc(decoder.attention_combine(combined)).squeeze(0)
        predicted = output.argmax(1)  # greedy: most likely token

        result_caption.append(predicted.item())
        if vocabulary.itos[predicted.item()] == "<EOS>":
          break
        x = decoder.embed(predicted).unsqueeze(0)
  finally:
    self.train(was_training)  # restore the caller's train/eval mode

  return [vocabulary.itos[idx] for idx in result_caption]  # convert to words

In [None]:
#training
# Resize, take random 299x299 crops, and normalise each channel to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((356, 356)),
        transforms.RandomCrop((299, 299)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Defined before the loader so that the loader actually uses it.
batch_size = 8

train_loader, dataset = get_loader(
    root_folder = "/content/drive/MyDrive/Image Captioning/flickr8k/images",
    annotation_file = "/content/drive/MyDrive/Image Captioning/flickr8k/captions.txt",
    transform = transform,
    batch_size = batch_size,
    num_workers = 2,
)

torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

embed_size = 256
hidden_size = 256
vocab_size = len(dataset.vocab)
num_layers = 1
learning_rate = 1e-4
num_epochs = 10
max_len = 30  # caption length limit for inference (caption_image's max_len)
# For synchronous CUDA error reporting, set CUDA_LAUNCH_BLOCKING=1 as an
# environment variable before CUDA is initialised; a Python variable of that
# name has no effect.
writer = SummaryWriter("runs/flickr")
step = 0

model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
criterion = nn.CrossEntropyLoss(ignore_index = dataset.vocab.stoi["<PAD>"])
optimizer = optim.Adam(model.parameters(), lr = learning_rate)

model.train()

try:
  for epoch in range(num_epochs):
    for idx, (imgs, captions) in enumerate(train_loader):
      imgs = imgs.to(device)
      captions = captions.to(device)  # (seq_len, batch), padded with <PAD>

      outputs = model(imgs, captions)  # (seq_len, batch, vocab_size)
      loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))

      writer.add_scalar("Training loss", loss.item(), global_step = step)
      step += 1

      optimizer.zero_grad()
      # Scalar loss: no gradient argument. Passing `loss` here would scale
      # every gradient by the current loss value.
      loss.backward()
      optimizer.step()
finally:
  writer.flush()  # make sure logged scalars reach disk even if training stops early