# Download the dataset (flickr8k)

In [1]:
!python -m spacy download en_core_web_sm


Collecting en-core-web-sm==3.8.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl (12.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.8/12.8 MB[0m [31m124.6 MB/s[0m eta [36m0:00:00[0m
[?25h[38;5;2m✔ Download and installation successful[0m
You can now load the package via spacy.load('en_core_web_sm')
[38;5;3m⚠ Restart to reload dependencies[0m
If you are in a Jupyter or Colab notebook, you may need to restart Python in
order to load all the package's dependencies. You can do this by selecting the
'Restart kernel' or 'Restart runtime' option.


In [2]:
!pip install -q kaggle


In [3]:
# Upload your kaggle.json API key.
# The result is bound to a name on purpose: a bare `files.upload()` as the
# last expression of the cell makes the notebook display the returned dict,
# which holds the raw bytes of kaggle.json (username and API key) and would
# be stored in the notebook's saved output.
from google.colab import files
_kaggle_upload = files.upload()

Saving kaggle.json to kaggle.json


{'kaggle.json': b'{"username":"<REDACTED>","key":"<REDACTED>"}'}

In [4]:
# Move kaggle.json and set permissions
!mkdir -p ~/.kaggle
!mv kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [5]:
# Download dataset directly to /content
!kaggle datasets download -d adityajn105/flickr8k -p /content --unzip

Dataset URL: https://www.kaggle.com/datasets/adityajn105/flickr8k
License(s): CC0-1.0
Downloading flickr8k.zip to /content
 96% 0.99G/1.04G [00:05<00:00, 143MB/s]
100% 1.04G/1.04G [00:05<00:00, 202MB/s]


## We want to convert text --> numerical values

1.   We need a vocabulary mapping each word to an index
2.   We need to set up a PyTorch Dataset to load the data
3.   We need to pad every batch (all examples in a batch must have the same seq_len) and set up the DataLoader



In [68]:
import os
import pandas as pd
import spacy # for tokenizer
import torch
from torch.nn.utils.rnn import pad_sequence # pad batch
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision.transforms as transforms

In [69]:
# Load the spaCy English pipeline once at module level; Vocabulary.tokenizer_eng
# uses its tokenizer to split captions into tokens.
spacy_eng = spacy.load("en_core_web_sm")


class Vocabulary:
  """Two-way mapping between lower-cased tokens and integer indices.

  Indices 0-3 are reserved for the special tokens <PAD>, <SOS>, <EOS> and
  <UNK>. Every other token receives the next free index the first time its
  count reaches ``freq_threshold`` while building the vocabulary.

  Attributes:
    itos: dict mapping index -> token string.
    stoi: dict mapping token string -> index.
    freq_threshold: occurrence count at which a token is admitted.
  """

  def __init__(self, freq_threshold):
    self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
    self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
    self.freq_threshold = freq_threshold

  def __len__(self):
    """Return the number of indexed tokens, special tokens included."""
    return len(self.itos)

  @staticmethod
  def tokenizer_eng(text):
    """Split ``text`` with the module-level spaCy tokenizer and lower-case each token."""
    return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

  def build_vocabulary(self, sentence_list):
    """Index every token whose count reaches ``freq_threshold``.

    Counts are accumulated over ``sentence_list`` only (they are not kept
    between calls). Tokens are indexed in the order in which they hit the
    threshold. Because counts start at 1 and the check is an equality test,
    a threshold below 1 admits no token at all.

    Args:
      sentence_list: iterable of caption strings.
    """
    frequencies = {}
    # Continue after the highest index already handed out instead of always
    # restarting at 4; otherwise a second call would overwrite existing
    # entries of `itos` and leave two words sharing one index.
    idx = len(self.itos)

    for sentence in sentence_list:
      for word in self.tokenizer_eng(sentence):
        frequencies[word] = frequencies.get(word, 0) + 1

        # `word not in self.stoi` keeps an already indexed token (from an
        # earlier call) on its original index.
        if frequencies[word] == self.freq_threshold and word not in self.stoi:
          self.stoi[word] = idx
          self.itos[idx] = word
          idx += 1

  def numericalize(self, text):
    """Tokenize ``text`` and return the list of token indices.

    Tokens that are not in the vocabulary map to the <UNK> index.
    """
    unk_idx = self.stoi["<UNK>"]
    return [self.stoi.get(token, unk_idx) for token in self.tokenizer_eng(text)]


class FlickrDataset(Dataset):
  """Image/caption pairs read from a captions CSV and an image folder.

  The CSV is expected to provide an ``image`` column (file names relative to
  ``root_dir``) and a ``caption`` column. A Vocabulary is built from all
  captions when the dataset is constructed.
  """

  def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
    self.root_dir = root_dir
    self.df = pd.read_csv(captions_file)
    self.transform = transform

    # Column views used for per-item lookup.
    self.imgs, self.captions = self.df["image"], self.df["caption"]

    # The vocabulary is derived from every caption in the file.
    self.vocab = Vocabulary(freq_threshold)
    self.vocab.build_vocabulary(self.captions.tolist())

  def __len__(self):
    """Number of (image, caption) rows in the CSV."""
    return len(self.df)

  def __getitem__(self, index):
    """Return ``(image, caption_tensor)`` for row ``index``.

    The image is opened as RGB and passed through ``transform`` when one was
    given. The caption is numericalized and wrapped in <SOS> ... <EOS>.
    """
    caption_text = self.captions[index]
    image_path = os.path.join(self.root_dir, self.imgs[index])
    img = Image.open(image_path).convert("RGB")

    if self.transform is not None:
      img = self.transform(img)

    stoi = self.vocab.stoi
    token_ids = [
        stoi["<SOS>"],
        *self.vocab.numericalize(caption_text),
        stoi["<EOS>"],
    ]

    return img, torch.tensor(token_ids)


class MyCollate:
  """Collate function that batches images and pads captions to equal length."""

  def __init__(self, pad_idx):
    # Index used to fill caption positions past each caption's own length.
    self.pad_idx = pad_idx

  def __call__(self, batch):
    """Turn a list of ``(image, caption)`` pairs into two batch tensors.

    Returns:
      imgs: the images joined along a new leading batch dimension.
      targets: the captions padded with ``pad_idx``; ``batch_first=False``
        puts the sequence dimension first.
    """
    imgs = torch.cat([sample[0].unsqueeze(0) for sample in batch], dim=0)
    targets = pad_sequence(
        [sample[1] for sample in batch],
        batch_first=False,
        padding_value=self.pad_idx,
    )
    return imgs, targets

In [70]:
def get_loader(
    root_folder,
    annotation_file,
    transform,
    batch_size=32,
    num_workers=8,
    shuffle=True,
    pin_memory=True
):
  """Build a FlickrDataset and a DataLoader over it.

  Args:
    root_folder: directory containing the image files.
    annotation_file: path to the captions CSV.
    transform: callable applied to every loaded image.
    batch_size, num_workers, shuffle, pin_memory: forwarded to DataLoader.

  Returns:
    ``(loader, dataset)`` — the dataset is returned as well so callers can
    reach its vocabulary.
  """
  dataset = FlickrDataset(root_folder, annotation_file, transform=transform)

  # Captions are padded with the vocabulary's <PAD> index.
  collate = MyCollate(pad_idx=dataset.vocab.stoi["<PAD>"])

  loader_options = dict(
      batch_size=batch_size,
      num_workers=num_workers,
      shuffle=shuffle,
      pin_memory=pin_memory,
      collate_fn=collate,
  )
  return DataLoader(dataset=dataset, **loader_options), dataset

# Construct the model

In [71]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

In [78]:
class EncoderCNN(nn.Module):
    """EfficientNet-B0 image encoder producing an ``embed_size`` feature vector.

    The pretrained classifier's final Linear layer is replaced by a new
    ``Linear(in_features, embed_size)``. That new layer is always trainable;
    every other parameter is trainable only when ``train_CNN`` is True.

    Args:
        embed_size: size of the output feature vector.
        train_CNN: whether to fine-tune the pretrained backbone.
    """

    def __init__(self, embed_size, train_CNN=False):
        super(EncoderCNN, self).__init__()
        self.train_CNN = train_CNN

        # Load pretrained EfficientNet-B0. `pretrained=True` is deprecated in
        # torchvision's multi-weight API, so request the ImageNet weights via
        # `weights=` when the enum exists and fall back to the legacy flag on
        # older torchvision releases.
        weights_enum = getattr(models, "EfficientNet_B0_Weights", None)
        if weights_enum is not None:
            self.efficientnet = models.efficientnet_b0(
                weights=weights_enum.IMAGENET1K_V1
            )
        else:
            self.efficientnet = models.efficientnet_b0(pretrained=True)

        # Replace the classifier layer with a new Linear layer
        in_features = self.efficientnet.classifier[1].in_features
        self.efficientnet.classifier[1] = nn.Linear(in_features, embed_size)

        # The new head always trains; the rest follows `train_CNN`.
        head_params = ("classifier.1.weight", "classifier.1.bias")
        for name, param in self.efficientnet.named_parameters():
            param.requires_grad = name in head_params or train_CNN

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)

    def forward(self, images):
        """Encode ``images`` and apply ReLU followed by dropout."""
        features = self.efficientnet(images)
        return self.dropout(self.relu(features))


In [79]:
class DecoderRNN(nn.Module):
    """LSTM caption decoder conditioned on an image feature vector.

    Args:
        embed_size: token-embedding size (also the LSTM input size, so the
            image feature must have this size too).
        hidden_size: LSTM hidden-state size.
        vocab_size: number of tokens scored at each step.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)
        self.lstm = nn.LSTM(
            input_size=embed_size, hidden_size=hidden_size, num_layers=num_layers
        )
        self.linear = nn.Linear(in_features=hidden_size, out_features=vocab_size)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, features, captions):
        """Score the vocabulary at every step of the sequence.

        The image feature is prepended along dim 0 as the first LSTM input,
        so the output has one more step than ``captions``.
        """
        token_embeddings = self.dropout(self.embed(captions))
        image_step = features.unsqueeze(0)
        lstm_input = torch.cat((image_step, token_embeddings), dim=0)
        hidden_states, _ = self.lstm(lstm_input)
        return self.linear(hidden_states)


In [80]:
class CNNtoRNN(nn.Module):
  """Captioning model: an EncoderCNN feeding a DecoderRNN."""

  def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
    super().__init__()
    self.encoderCNN = EncoderCNN(embed_size)
    self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

  def forward(self, images, captions):
    """Encode ``images`` and decode them together with ``captions``."""
    return self.decoderRNN(self.encoderCNN(images), captions)

  def caption_image(self, image, vocabulary, max_length=50):
    """Greedily decode a caption for a single image.

    Starting from the encoded image, the highest-scoring token is picked at
    each step and fed back as the next input, until "<EOS>" is produced or
    ``max_length`` tokens have been generated. Uses ``.item()`` on the
    prediction, so ``image`` must be a batch of exactly one.

    Returns:
      list of token strings looked up in ``vocabulary.itos``.
    """
    decoder = self.decoderRNN
    tokens = []

    with torch.no_grad():
      step_input = self.encoderCNN(image).unsqueeze(0)
      lstm_state = None

      for _ in range(max_length):
        hiddens, lstm_state = decoder.lstm(step_input, lstm_state)
        scores = decoder.linear(hiddens.squeeze(0))
        next_idx = scores.argmax(1)

        token = vocabulary.itos[next_idx.item()]
        tokens.append(token)

        # The token just predicted becomes the next LSTM input.
        step_input = decoder.embed(next_idx).unsqueeze(0)

        if token == "<EOS>":
          break

    return tokens


# Utils

In [81]:
def print_examples(model, device, dataset, examples=None):
    """Print reference and generated captions for a set of example images.

    The model is switched to eval mode while captioning and put back into
    train mode afterwards, even if captioning an example raises.

    Args:
        model: object exposing ``eval()``, ``train()`` and
            ``caption_image(image, vocabulary)``.
        device: device the image tensors are moved to.
        dataset: object whose ``vocab`` is passed to ``caption_image``.
        examples: optional iterable of ``(image_path, reference_caption)``
            pairs; defaults to the five images under ``test_examples/``.
    """
    if examples is None:
        examples = (
            ("test_examples/dog.jpg", "Dog on a beach by the ocean"),
            ("test_examples/child.jpg", "Child holding red frisbee outdoors"),
            ("test_examples/bus.png", "Bus driving by parked cars"),
            ("test_examples/boat.png", "A small boat in the ocean"),
            ("test_examples/horse.png", "A cowboy riding a horse in the desert"),
        )

    # NOTE(review): this resizes straight to 224x224, whereas train() uses
    # Resize(256) + CenterCrop(224) — confirm the mismatch is intended.
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ]
    )

    model.eval()
    try:
        for number, (image_path, reference) in enumerate(examples, start=1):
            # The context manager closes the image file once it is converted.
            with Image.open(image_path) as raw_image:
                test_img = transform(raw_image.convert("RGB")).unsqueeze(0)

            print(f"Example {number} CORRECT: {reference}")
            print(
                f"Example {number} OUTPUT: "
                + " ".join(model.caption_image(test_img.to(device), dataset.vocab))
            )
    finally:
        model.train()


def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Announce the save on stdout, then serialize ``state`` to ``filename``."""
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)


def load_checkpoint(checkpoint, model, optimizer):
    """Restore model and optimizer state from a checkpoint dict.

    Args:
        checkpoint: mapping with the keys "state_dict", "optimizer" and "step".
        model: object whose ``load_state_dict`` receives ``checkpoint["state_dict"]``.
        optimizer: object whose ``load_state_dict`` receives ``checkpoint["optimizer"]``.

    Returns:
        The value stored under "step".
    """
    print("=> Loading checkpoint")
    for target, key in ((model, "state_dict"), (optimizer, "optimizer")):
        target.load_state_dict(checkpoint[key])
    return checkpoint["step"]

# Training

In [82]:
def train():
  """Train the CNNtoRNN captioning model on the Flickr8k data under /content.

  At the start of every epoch the example captions are printed and, when
  ``save_model`` is set, a checkpoint is written; a final checkpoint is
  written after the last epoch. The per-batch training loss is logged to
  TensorBoard under ``runs/flickr``.
  """

  transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),  # Crop the center to 224x224
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
    ])

  train_loader, dataset = get_loader(
      root_folder="/content/Images",
      annotation_file="/content/captions.txt",
      transform=transform,
      num_workers=2,
  )

  torch.backends.cudnn.benchmark = True
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  load_model = False
  save_model = True

  # Hyperparameters
  embed_size = 256
  hidden_size = 256
  vocab_size = len(dataset.vocab)
  num_layers = 1
  learning_rate = 3e-4
  num_epochs = 100

  # tensorboard
  writer = SummaryWriter("runs/flickr")
  step = 0

  # initialize model, loss etc
  model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
  # Padded caption positions do not contribute to the loss.
  criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
  optimizer = optim.Adam(model.parameters(), lr=learning_rate)

  def _save_current_state():
    # Snapshot model, optimizer and the global step counter.
    save_checkpoint({
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": step,
    })

  if load_model:
    # map_location lets a checkpoint written on GPU be restored on CPU too.
    step = load_checkpoint(
        torch.load("my_checkpoint.pth.tar", map_location=device), model, optimizer
    )

  model.train()
  try:
    for _ in range(num_epochs):

      print_examples(model, device, dataset)

      if save_model:
        _save_current_state()

      for imgs, captions in tqdm(train_loader, total=len(train_loader), leave=False):
        imgs = imgs.to(device)
        captions = captions.to(device)

        # The decoder prepends the image feature as an extra first step, so
        # feeding captions[:-1] yields one prediction per token of `captions`.
        outputs = model(imgs, captions[:-1])
        # (seq_len, N, vocabulary_size), (seq_len, N)
        loss = criterion(
            outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
        )
        writer.add_scalar("Training loss", loss.item(), global_step=step)
        step += 1

        optimizer.zero_grad()
        # backward() takes no argument here: passing the loss as `gradient`
        # would multiply every gradient by the loss value.
        loss.backward()
        optimizer.step()

    # Checkpoints above are taken at the start of an epoch, so without this
    # the weights learned during the last epoch would never be saved.
    if save_model:
      _save_current_state()
  finally:
    # Flush and release the TensorBoard event file.
    writer.close()

In [None]:
train()



Example 1 CORRECT: Dog on a beach by the ocean
Example 1 OUTPUT: mouths quad tags offering climbing horizon beverage attacking snowbank dirty mobile workers club punk lifts alley mats squats strange chess hats clean video bricks portrait fireworks musical arena diner surfers surfers called viz landing sells paintball bag bag bag show set monument machines bread carried give navy clothing lockers topless
Example 2 CORRECT: Child holding red frisbee outdoors
Example 2 OUTPUT: shows sheer leaves station handle meadow wilderness above seaweed little move jet genocide catching bent clothing lockers topless drops paintball posts canvas close mountain lifting girl blows monitor big 6 teenager jackson tournament incoming mouths runner fetches kicks person blocks side carpeted helping lap park tv move side pointy sprinkler
Example 3 CORRECT: Bus driving by parked cars
Example 3 OUTPUT: croquet bagpipe posts ramps balancing amid it uphill logo paintings guarding opens using grey bicycler enclosu



Example 1 CORRECT: Dog on a beach by the ocean
Example 1 OUTPUT: <SOS> a man in a red shirt is standing on a <UNK> . <EOS>
Example 2 CORRECT: Child holding red frisbee outdoors
Example 2 OUTPUT: <SOS> a man in a red shirt is standing on a <UNK> . <EOS>
Example 3 CORRECT: Bus driving by parked cars
Example 3 OUTPUT: <SOS> a man in a red shirt is standing on a <UNK> . <EOS>
Example 4 CORRECT: A small boat in the ocean
Example 4 OUTPUT: <SOS> a man in a red shirt is standing on a <UNK> . <EOS>
Example 5 CORRECT: A cowboy riding a horse in the desert
Example 5 OUTPUT: <SOS> a man in a red shirt is standing on a <UNK> . <EOS>
=> Saving checkpoint




Example 1 CORRECT: Dog on a beach by the ocean
Example 1 OUTPUT: <SOS> a dog is running through the grass . <EOS>
Example 2 CORRECT: Child holding red frisbee outdoors
Example 2 OUTPUT: <SOS> a man in a blue shirt is standing on a bench . <EOS>
Example 3 CORRECT: Bus driving by parked cars
Example 3 OUTPUT: <SOS> a man in a blue shirt is standing on a bench . <EOS>
Example 4 CORRECT: A small boat in the ocean
Example 4 OUTPUT: <SOS> a man in a blue shirt is standing on a bench . <EOS>
Example 5 CORRECT: A cowboy riding a horse in the desert
Example 5 OUTPUT: <SOS> a man in a blue shirt is standing on a bench . <EOS>
=> Saving checkpoint




Example 1 CORRECT: Dog on a beach by the ocean
Example 1 OUTPUT: <SOS> a dog is running through the grass . <EOS>
Example 2 CORRECT: Child holding red frisbee outdoors
Example 2 OUTPUT: <SOS> a boy in a red shirt is standing on a bench . <EOS>
Example 3 CORRECT: Bus driving by parked cars
Example 3 OUTPUT: <SOS> a man in a blue shirt and a black shirt and a black shirt and a woman in a white shirt and a black hat and a woman in a white shirt and a black hat and a woman in a white shirt and a black hat and a
Example 4 CORRECT: A small boat in the ocean
Example 4 OUTPUT: <SOS> a man in a blue shirt is standing on a bench . <EOS>
Example 5 CORRECT: A cowboy riding a horse in the desert
Example 5 OUTPUT: <SOS> a man in a black shirt and a black shirt and a black shirt and a woman in a blue shirt and a white shirt and a black hat and a woman in a white shirt and a black hat and a woman in a white shirt and a
=> Saving checkpoint




Example 1 CORRECT: Dog on a beach by the ocean
Example 1 OUTPUT: <SOS> a dog is running through the grass . <EOS>
Example 2 CORRECT: Child holding red frisbee outdoors
Example 2 OUTPUT: <SOS> a little girl in a blue shirt is playing with a ball in the grass . <EOS>
Example 3 CORRECT: Bus driving by parked cars
Example 3 OUTPUT: <SOS> a man in a red shirt and a black hat and a woman in a black shirt and a black hat and a woman in a white shirt and a black hat and a woman in a black shirt and a black hat and a black hat and a
Example 4 CORRECT: A small boat in the ocean
Example 4 OUTPUT: <SOS> a man in a blue shirt is riding a bike . <EOS>
Example 5 CORRECT: A cowboy riding a horse in the desert
Example 5 OUTPUT: <SOS> a man in a black shirt and a black hat and a black hat and a woman in a black jacket and a black hat and a black hat and a woman in a black jacket and a black hat and a black hat and a woman in a
=> Saving checkpoint




Example 1 CORRECT: Dog on a beach by the ocean
Example 1 OUTPUT: <SOS> a dog is running through the grass . <EOS>
Example 2 CORRECT: Child holding red frisbee outdoors
Example 2 OUTPUT: <SOS> a young boy in a blue shirt is playing with a ball in a pool . <EOS>
Example 3 CORRECT: Bus driving by parked cars
Example 3 OUTPUT: <SOS> a man in a black shirt and a hat and a woman in a black jacket and a black hat and a black hat and a black hat and a woman in a black jacket and black pants is standing on a bench . <EOS>
Example 4 CORRECT: A small boat in the ocean
Example 4 OUTPUT: <SOS> a man in a blue shirt is riding a bicycle . <EOS>
Example 5 CORRECT: A cowboy riding a horse in the desert
Example 5 OUTPUT: <SOS> a man and a woman are sitting on a bench . <EOS>
=> Saving checkpoint




Example 1 CORRECT: Dog on a beach by the ocean
Example 1 OUTPUT: <SOS> a dog is running through the grass . <EOS>
Example 2 CORRECT: Child holding red frisbee outdoors
Example 2 OUTPUT: <SOS> a young boy in a red shirt and a white shirt and a blue shirt and a red shirt is playing with a ball in a field . <EOS>
Example 3 CORRECT: Bus driving by parked cars
Example 3 OUTPUT: <SOS> a man in a white shirt and a hat is standing on a bench . <EOS>
Example 4 CORRECT: A small boat in the ocean
Example 4 OUTPUT: <SOS> a man in a blue shirt is riding a bike on a dirt path . <EOS>
Example 5 CORRECT: A cowboy riding a horse in the desert
Example 5 OUTPUT: <SOS> a man and a woman are sitting on a bench . <EOS>
=> Saving checkpoint


 17%|█▋        | 221/1265 [00:43<03:03,  5.68it/s]