# Unzip data

In [1]:
from google.colab import drive

drive.mount('/content/drive/', force_remount=True)
data_folder = r'/content/drive/My Drive/data/'

Mounted at /content/drive/


In [2]:
!unzip '/content/drive/My Drive/data/flickr1k.zip' -d '/content/drive/My Drive/data/'

Archive:  /content/drive/My Drive/data/flickr1k.zip
replace /content/drive/My Drive/data/flickr1k/captions.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

# Module

In [2]:
import os
import numpy as np
import pandas as pd
import spacy
from PIL import Image

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

import torch.nn as nn
import torchvision.models as models
import torch.optim as optim
import torchvision.transforms as transforms

In [3]:
spacy_eng = spacy.load('en')


class Vocabulary:
  def __init__(self, freq_threshold):
    self.itos = {0: '<PAD>', 1: '<SOS>', 2: '<EOS>', 3: '<UNK>'}
    self.stoi = {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3}
    self.freq_threshold = freq_threshold

  def __len__(self):
    return len(self.itos)

  @staticmethod
  def tokenizer_eng(text):
    return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

  def build_vocabulary(self, sentence_list):
    frequencies = {}
    idx = 4

    for sentence in sentence_list:
      for word in self.tokenizer_eng(sentence):
        if word not in frequencies:
          frequencies[word] = 1
        else:
          frequencies[word] += 1

        if frequencies[word] == self.freq_threshold:
          self.stoi[word] = idx
          self.itos[idx] = word
          idx += 1

  def numericalize(self, text):
    tokenized_text = self.tokenizer_eng(text)

    return [
              self.stoi[token] if token in self.stoi else self.stoi['<UNK>'] 
              for token in tokenized_text
    ]


class FlickrDataset(Dataset):
  def __init__(self, root_dir, captions_file, transform, freq_threshold=5):
    self.root_dir = root_dir
    self.df = pd.read_csv(captions_file)
    self.transform = transform

    self.imgs = self.df['image']
    self.captions = self.df['caption']

    self.vocab = Vocabulary(freq_threshold)
    self.vocab.build_vocabulary(self.captions.tolist())

  def __len__(self):
    return len(self.df)

  def __getitem__(self, index):
    caption = self.captions[index]
    img_id = self.imgs[index]
    img = Image.open(os.path.join(self.root_dir, img_id)).convert('RGB')

    img = self.transform(img)
    
    numaricalized_caption = [self.vocab.stoi['<SOS>']]
    numaricalized_caption += self.vocab.numericalize(caption)
    numaricalized_caption.append(self.vocab.stoi['<EOS>'])

    return img, torch.tensor(numaricalized_caption)


class FlickrTest(Dataset):
  def __init__(self, root_dir, captions_file, transform, freq_threshold=5):
    self.root_dir = root_dir
    self.df = pd.read_csv(captions_file)
    unique_ids = self.df['caption'].unique()
    test_ids = unique_ids[np.random.choice(unique_ids.shape[0], 1, replace=False)]
    self.df = self.df[self.df['caption'].isin(test_ids)]
    self.df = self.df.reset_index()

    self.transform = transform

    self.imgs = self.df['image']
    self.captions = self.df['caption']
    print(self.captions)

  def __len__(self):
    return len(self.df)

  def __getitem__(self, index):
    caption = self.captions[index]
    img_id = self.imgs[index]
    img = Image.open(os.path.join(self.root_dir, img_id)).convert('RGB')

    img_transformed = self.transform(img)

    return np.array(img), img_transformed


class MyCollate:
  def __init__(self, pad_idx):
    self.pad_idx = pad_idx

  def __call__(self, batch):
    imgs = [item[0].unsqueeze(0) for item in batch]
    imgs = torch.cat(imgs, dim=0)
    targets = [item[1] for item in batch]
    targets = pad_sequence(targets, batch_first=False, padding_value=self.pad_idx)

    return imgs, targets

In [4]:
class EncoderCNN(nn.Module):
  def __init__(self, embed_size, train=False):
    super(EncoderCNN, self).__init__()
    self.train = train
    self.resnet = models.resnet18(pretrained=True)
    self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embed_size)
    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(0.5)

  def forward(self, images):
    features = self.resnet(images)

    for name, param in self.resnet.named_parameters():
      if 'fc.weight' in name or 'fc.bias' in name:
        param.requires_grad = True
      else:
        param.requires_grad = self.train

    return self.dropout(self.relu(features))

In [5]:
class DecoderRNN(nn.Module):
  def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
    super(DecoderRNN, self).__init__()
    self.embedding = nn.Embedding(vocab_size, embed_size)
    self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
    self.linear = nn.Linear(hidden_size, vocab_size)
    self.dropout = nn.Dropout(0.5)

  def forward(self, features, captions):
    embeddings = self.dropout(self.embedding(captions))
    embeddings = torch.cat((features.unsqueeze(0), embeddings), dim=0)
    hiddens, _ = self.lstm(embeddings)
    outputs = self.linear(hiddens)

    return outputs

In [6]:
class CNNtoRNN(nn.Module):
  def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
    super(CNNtoRNN, self).__init__()
    self.encoder = EncoderCNN(embed_size)
    self.decoder = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)

  def forward(self, images, captions):
    features = self.encoder(images)
    outputs = self.decoder(features, captions)
    return outputs

# Training

In [7]:
transform = transforms.Compose(
  [
   transforms.Resize((256,256)),
   transforms.CenterCrop(224),
   transforms.ToTensor(),
   transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                        std=[0.229, 0.224, 0.225])
  ]
)

In [8]:
trainset = FlickrDataset(root_dir='/content/drive/My Drive/data/flickr1k/images', 
                         captions_file='/content/drive/My Drive/data/flickr1k/captions.csv', 
                         transform=transform)
pad_idx = trainset.vocab.stoi['<PAD>']
train_loader = DataLoader(
    dataset=trainset,
    batch_size=32,
    shuffle=True,
    collate_fn=MyCollate(pad_idx=pad_idx)
  )

testset = FlickrTest(root_dir='/content/drive/My Drive/data/flickr1k/images', 
                     captions_file='/content/drive/My Drive/data/flickr1k/captions.csv', 
                     transform=transform)
test_loader = DataLoader(
    dataset=testset,
    batch_size=32,
    shuffle=False
  )

0    a dog jumping over a small wall at a beach nea...
Name: caption, dtype: object


In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  
embed_size = 256
hidden_size = 256
vocab_size = len(trainset.vocab)
num_layers = 1
learning_rate = 3e-4
num_epochs = 10

In [10]:
losses = list()
model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)

criterion = nn.CrossEntropyLoss(ignore_index=trainset.vocab.stoi['<PAD>'])
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(1, num_epochs+1):
    
    for i_step, (img, caption) in enumerate(train_loader):

        model.zero_grad()

        img = img.to(device)
        captions_target = caption[1:].to(device)
        captions_train = caption[:-1].to(device)

        outputs = model(img, captions_train)

        loss = criterion(outputs[1:].view(-1, vocab_size), captions_target.contiguous().view(-1))

        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        stats = 'Epoch [%d/%d], Step [%d], Loss: %.4f' % (epoch, num_epochs, i_step, loss.item())

        print('\r' + stats, end='')

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Epoch [10/10], Step [156], Loss: 2.7136

In [None]:
# testset = FlickrTest(root_dir='/content/drive/My Drive/data/flickr1k/images', 
#                      captions_file='/content/drive/My Drive/data/flickr1k/captions.csv', 
#                      transform=transform)
# test_loader = DataLoader(
#     dataset=testset,
#     batch_size=32,
#     shuffle=False
#   )

In [11]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
with torch.no_grad():
  image_caption = []
  for (img, img_transformed) in test_loader:
    img_transformed = img_transformed.to(device)

    x = model.encoder(img_transformed).unsqueeze(0)
    states = None

    for _ in range(20):
      hiddens, states = model.decoder.lstm(x, states)
      output = model.decoder.linear(hiddens.squeeze(0))
      predicted = output.argmax(1)
      
      image_caption.append([x.item() for x in predicted])
      x = model.decoder.embedding(predicted).unsqueeze(0)

In [15]:
[trainset.vocab.itos[idx] for idx in np.array(image_caption)[:,0].tolist()]

['a',
 'man',
 'in',
 'a',
 'red',
 'jacket',
 'is',
 'standing',
 'on',
 'a',
 '<UNK>',
 '.',
 '<EOS>',
 'a',
 '<UNK>',
 '.',
 '<EOS>',
 '<EOS>',
 '.',
 '<EOS>']