# CNN+LSTM

## Setting up the Kaggle API and downloading the dataset - Flickr30k

In [0]:
# Colab-only step: running this cell opens a file picker in the notebook.
# Select the kaggle.json API token downloaded from the Kaggle account
# settings page. The setup commands in the next cell copy the token from
# the current working directory, so it has to be uploaded first.
from google.colab import files
files.upload()

In [0]:
# Next, install the Kaggle API client (-q keeps the pip output short).
!pip install -q kaggle
# The Kaggle API client expects the token file to be in ~/.kaggle,
# so create that directory (if it is missing) and copy the token there.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/

# Make the token readable/writable by the owner only. This permissions
# change avoids a warning on Kaggle tool startup.
!chmod 600 ~/.kaggle/kaggle.json

### Downloading the Dataset

In [0]:
# Download the hsankesara/flickr-image-dataset archive from Kaggle into the
# current working directory (needs the API token installed in ~/.kaggle).
!kaggle datasets download -d hsankesara/flickr-image-dataset

In [0]:
# Unzip the downloaded archive in place. The following cells expect this to
# create the flickr30k_images/ directory tree.
!unzip flickr-image-dataset.zip

In [0]:
# The archive is no longer needed once it has been extracted.
!rm flickr-image-dataset.zip
# Remove the nested flickr30k_images directory and the results.csv that sit
# inside the image folder -- presumably so that listing
# flickr30k_images/flickr30k_images (as the dataset class does) only returns
# image files. The captions are read from flickr30k_images/results.csv.
!rm -r flickr30k_images/flickr30k_images/flickr30k_images
!rm flickr30k_images/flickr30k_images/results.csv

## Importing Modules

In [0]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, Subset
!pip install torchtext==0.5.0
import torchtext
from torchtext.data import get_tokenizer, Field
from torchtext.data.metrics import bleu_score
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer
import os
import gc
import math
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import models

### Loading Flickr30k and Preprocessing

In [0]:
data_path = 'flickr30k_images/flickr30k_images'
csv_path = 'flickr30k_images/results.csv'
!mkdir pretrained_models
save_path = 'pretrained_models'

In [0]:
# Fixing a data issue: for row 19999 the caption text is stored in the
# ' comment_number' column (from character 4 onwards) instead of in
# ' comment', so both cells of that row are rewritten by hand.
results = pd.read_csv(csv_path, sep='|')
fix_19999 = results.loc[19999, ' comment_number']
# Assign with a single .loc[row, column] call. The chained form
# results.loc[19999][col] = value may write to a temporary copy of the row,
# in which case the fix is silently lost.
results.loc[19999, ' comment_number'] = ' 4'
results.loc[19999, ' comment'] = fix_19999[4:]
# Order by comment number first and image name second, so that every
# comment-number group lists the images in file-name order.
results = results.sort_values(by=[' comment_number', 'image_name'])
results

In [0]:
class Flickr30kDataset(Dataset):
    """Flickr30k images paired with one of their reference captions.

    Images are read from ``data_path`` in sorted file-name order. The
    captions come from the module-level ``results`` DataFrame, restricted to
    the rows whose ' comment_number' equals ``comment_num``. Image ``idx`` is
    paired with the ``idx``-th caption of that selection, so ``results`` is
    expected to be sorted by image name as well.

    Args:
        data_path: directory containing only the image files.
        transforms: callable applied to each PIL image (RGB) to produce the
            model input.
        comment_num: which of the five captions per image to use (0-4).

    Attributes set by ``_tokenizer``:
        caption_list: raw caption strings.
        token_list: LongTensor (num_captions, seq_len) of vocabulary indices,
            wrapped in <sos>/<eos> and padded to ``seq_len``.
        len_list: LongTensor with the number of tokens of each caption plus
            one (room for <sos> on the input side or <eos> on the target side).
        seq_len: padded sequence length as a Python int (longest caption plus
            <sos> and <eos>).
        field: torchtext ``Field`` holding the vocabulary.
    """

    def __init__(self, data_path, transforms, comment_num):
        self.data_path = data_path
        # Sorted so that the file order is deterministic and lines up with the
        # captions, which are ordered by image name.
        self.data_files = sorted(os.listdir(self.data_path))
        self.transforms = transforms
        assert 0 <= comment_num <= 4
        # The CSV stores the comment number with a leading space (' 0'..' 4').
        self.comment_num = ' ' + str(comment_num)
        self._tokenizer()

    def __getitem__(self, idx):
        """Return (transformed image, caption token ids, caption length)."""
        img_path = os.path.join(self.data_path, self.data_files[idx])
        # Close the file deterministically and force three channels: a
        # grayscale or palette image would otherwise break a 3-channel
        # Normalize in the transform pipeline.
        with Image.open(img_path) as img:
            inputs = self.transforms(img.convert('RGB'))
        captions = self.token_list[idx]
        cap_lens = self.len_list[idx]
        return inputs, captions, cap_lens

    def __len__(self):
        """Number of image files found in ``data_path``."""
        return len(self.data_files)

    def _tokenizer(self):
        """Tokenise the selected captions and convert them to padded index tensors."""
        selected = results[' comment_number'] == self.comment_num
        self.caption_list = results.loc[selected][' comment'].tolist()
        tokenizer = get_tokenizer("basic_english")
        self.token_list = [tokenizer(caption) for caption in self.caption_list]
        self.len_list = torch.tensor([len(token) for token in self.token_list])
        self.len_list += 1 # allow for <sos> or <eos>
        # Plain int (not a 0-dim tensor) so it can be used directly as a
        # length, a range() bound or a tensor size by callers.
        self.seq_len = int(self.len_list.max()) + 1
        self.field = Field(tokenize='spacy', tokenizer_language='en',
                           init_token='<sos>', eos_token='<eos>', lower=True, fix_length=self.seq_len)
        self.field.build_vocab(self.token_list)
        # process() pads and numericalises; it returns seq_len * num_captions,
        # so transpose to index by caption first.
        self.token_list = self.field.process(self.token_list)
        self.token_list = self.token_list.transpose(1, 0)

In [0]:
# Pre-processing applied to every image: resize to the 224x224 input size,
# convert to a float tensor in [0, 1] and normalise each channel.
# The mean/std are the ImageNet statistics that the torchvision pretrained
# models (the ResNet50 encoder used below) were trained with; the previous
# values were the CIFAR-10 statistics, which do not match the encoder.
transforms_train = transforms.Compose([
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

In [0]:
# Build the full dataset, pairing every image with its caption number 3.
data_set = Flickr30kDataset(data_path=data_path, transforms=transforms_train, comment_num=3)

In [0]:
# Build the BLEU references: reference_corpus[i] is the list of the five
# tokenised captions of image i (same image order as the dataset).
# The captions are tokenised with the same "basic_english" tokenizer that
# produces the model vocabulary, so references and predictions share casing
# and token boundaries (a plain str.split() keeps capital letters that the
# lowercased predictions can never match).
reference_tokenizer = get_tokenizer("basic_english")
references_by_number = []
for num in range(5):
    caption_candidates = results.loc[results[' comment_number'] == ' '+str(num)][' comment'].tolist()
    print(caption_candidates[0])
    references_by_number.append([reference_tokenizer(x) for x in caption_candidates])
# Transpose from (5, num_images) to (num_images, 5) with zip instead of a
# NumPy object array: np.array() on ragged nested lists is rejected by
# recent NumPy versions, and a list supports the slicing used later on.
reference_corpus = [list(refs) for refs in zip(*references_by_number)]

### Splitting the Dataset

In [0]:
# Hold out the last 2000 samples: 1000 for validation followed by 1000 for
# testing; everything before them is used for training.
num_data = len(data_set)
idx = list(range(num_data))
train_set, vali_set, test_set = (
    Subset(data_set, idx[part])
    for part in (slice(None, -2000), slice(-2000, -1000), slice(-1000, None))
)
print(len(train_set), len(vali_set), len(test_set))

## Defining the Model

In [0]:
class Encoder(nn.Module):
    """Image encoder: wraps a CNN and applies a ReLU to its output."""

    def __init__(self, cnn):
        """Store ``cnn``, the module that maps an image batch to features."""
        super().__init__()
        self.cnn = cnn

    def forward(self, x):
        """Return the rectified CNN output for the batch ``x``."""
        return F.relu(self.cnn(x))

    def freeze_bottom(self):
        """Freeze the CNN except for its last three child modules."""
        self.freeze_all()
        trainable_children = list(self.cnn.children())[-3:]
        for child in trainable_children:
            for param in child.parameters():
                param.requires_grad = True

    def freeze_all(self):
        """Disable gradients for every parameter of the CNN."""
        for param in self.cnn.parameters():
            param.requires_grad = False

In [0]:
class Decoder(nn.Module):
    """LSTM caption decoder conditioned on a fixed-size image encoding.

    The image encoding initialises the LSTM hidden and cell state (so the
    hidden size equals ``enc_size``) and is also concatenated to the word
    embedding at every time step.

    Args:
        enc_size: size of the image encoding (also the LSTM hidden size).
        emb_size: size of the word embeddings.
        vocab_size: number of tokens in the vocabulary.
        num_layers: number of stacked LSTM layers; every layer starts from
            the image encoding.
        embedding_matrix: accepted for compatibility but currently unused;
            the embeddings are learned from scratch.

    All tensors are moved to the device that holds the decoder's parameters,
    so the module works on CPU as well as on GPU.
    """

    # Vocabulary index of '<sos>'. Assumes the torchtext Field ordering of
    # special tokens (<unk>=0, <pad>=1, <sos>=2, <eos>=3) -- TODO confirm if
    # the vocabulary is ever built differently.
    sos_index = 2

    def __init__(self, enc_size, emb_size, vocab_size, num_layers=1, embedding_matrix=None):
        super(Decoder, self).__init__()
        self.enc_size = enc_size
        self.hidden_size = enc_size  # the encoding is used as initial LSTM state
        self.emb_size = emb_size
        self.vocab_size = vocab_size
        self.num_layers = num_layers

        self.embed = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.emb_size)
        self.lstm = nn.LSTM(input_size=self.emb_size+self.enc_size, hidden_size=self.hidden_size, num_layers=self.num_layers)
        self.fc = nn.Linear(self.hidden_size, self.vocab_size)
        self._init_weights()

    def _init_weights(self):
        """Initialise embedding and output layer uniformly in [-0.1, 0.1], bias to 0."""
        self.embed.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def _initial_state(self, enc_out):
        """Return (h0, c0) built from ``enc_out`` of shape 1 * batch * enc_size."""
        h0 = enc_out.repeat(self.num_layers, 1, 1)  # num_layers * batch * enc_size
        return h0, h0.clone()

    def forward(self, enc_out, captions, caplens):
        """Teacher-forced pass over a batch sorted by decreasing caption length.

        Args:
            enc_out: batch_size * enc_size image encodings.
            captions: seq_len * batch_size input token ids.
            caplens: number of valid time steps per caption (descending).

        Returns:
            Logits of shape (sum(caplens), vocab_size): the packed sequence
            data, i.e. only the non-padded time steps.
        """
        device = self.embed.weight.device
        enc_out = enc_out.to(device).unsqueeze(0)  # 1 * batch_size * enc_size
        h0, c0 = self._initial_state(enc_out)
        embedded = self.embed(captions.to(device))  # seq_len * batch_size * emb_size
        features = enc_out.repeat(embedded.size(0), 1, 1)  # seq_len * batch_size * enc_size
        # pack_padded_sequence needs the lengths on the CPU.
        lengths = torch.as_tensor(caplens).cpu()
        packed = pack_padded_sequence(torch.cat((embedded, features), dim=2), lengths)
        packed_out, _ = self.lstm(packed, (h0, c0))
        return self.fc(packed_out.data)

    def greedy_pred(self, enc_out, pred_len):
        """Generate ``pred_len`` tokens per image, always taking the argmax.

        Args:
            enc_out: batch_size * enc_size image encodings.
            pred_len: number of tokens to generate.

        Returns:
            LongTensor pred_len * batch_size of token ids (without <sos>);
            empty along the first dimension when ``pred_len`` is not positive.
        """
        device = self.embed.weight.device
        enc_out = enc_out.to(device).unsqueeze(0)  # 1 * batch_size * enc_size
        batch_size = enc_out.size(1)
        h, c = self._initial_state(enc_out)
        next_in = torch.full((1, batch_size), self.sos_index, dtype=torch.long, device=device)
        steps = []
        for _ in range(int(pred_len)):
            step_in = torch.cat((self.embed(next_in), enc_out), dim=2)
            lstm_out, (h, c) = self.lstm(step_in, (h, c))
            next_in = self.fc(lstm_out).argmax(dim=2)  # 1 * batch_size
            steps.append(next_in)
        if not steps:
            return torch.empty((0, batch_size), dtype=torch.long, device=device)
        return torch.cat(steps, dim=0)

    def beam_search_pred(self, enc_outs, pred_len, beam_size=3):
        """Generate ``pred_len`` tokens per image with beam search.

        Each image is decoded separately, using the beam as the batch
        dimension. Beams are not stopped at <eos>; every beam is extended to
        ``pred_len`` tokens and the one with the highest total log-probability
        is returned.

        Args:
            enc_outs: batch_size * enc_size image encodings.
            pred_len: number of tokens to generate.
            beam_size: number of hypotheses kept at each step (must not
                exceed the vocabulary size).

        Returns:
            CPU LongTensor pred_len * batch_size of token ids (without <sos>).
        """
        device = self.embed.weight.device
        pred_len = int(pred_len)
        enc_outs = enc_outs.to(device)
        batch_size = enc_outs.size(0)
        outputs = torch.ones((pred_len, batch_size), dtype=torch.long)  # place holder for outputs
        for idx in range(batch_size):
            # 1 * beam_size * enc_size (the beam plays the role of the batch)
            enc_out = enc_outs[idx].reshape(1, 1, -1).repeat(1, beam_size, 1)
            h, c = self._initial_state(enc_out)
            k_words = torch.full((1, beam_size), self.sos_index, dtype=torch.long, device=device)
            seqs = k_words  # L * beam_size, grows by one row per step
            k_scores = torch.zeros(1, beam_size, 1, device=device)

            for step in range(pred_len):
                embedding = self.embed(k_words)  # 1 * beam_size * emb_size
                lstm_out, (h, c) = self.lstm(torch.cat((embedding, enc_out), dim=2), (h, c))
                scores = F.log_softmax(self.fc(lstm_out), dim=2)  # 1 * beam_size * vocab_size
                # Accumulate log-probabilities; squeeze only the leading dim
                # so that beam_size == 1 keeps its beam dimension.
                scores = (k_scores + scores).squeeze(0)  # beam_size * vocab_size
                if step == 0:
                    # All beams are identical at the first step, so expand
                    # only one of them to avoid duplicate hypotheses.
                    top_scores, top_flat = scores[0].topk(beam_size)
                else:
                    top_scores, top_flat = scores.reshape(-1).topk(beam_size)
                # Integer division: '/' on integer tensors is true division
                # in current PyTorch and would yield a float index.
                prev_idx = top_flat // self.vocab_size  # beam each winner extends
                next_idx = top_flat % self.vocab_size  # token appended to it
                seqs = torch.cat((seqs[:, prev_idx], next_idx.unsqueeze(0)), dim=0)
                # The LSTM state must follow the same re-ordering as the
                # sequences, otherwise a hypothesis continues from the state
                # of a different beam.
                h = h.index_select(1, prev_idx)
                c = c.index_select(1, prev_idx)
                k_scores = top_scores.view(1, beam_size, 1)
                k_words = next_idx.unsqueeze(0)
            best = k_scores.view(-1).argmax()
            outputs[:, idx] = seqs[1:, best]  # Don't include <sos>
        return outputs

## Training

In [0]:
# Optimisation settings.
epochs = 20
batch_size = 32
enc_lr = 1e-4
dec_lr = 1e-4
patience = 10  # epochs without improvement before early stopping

# Checkpoint files for the best encoder / decoder seen so far.
enc_save_path = os.path.join(save_path, 'best_enc_lstm_f30'+'_demo')
dec_save_path = os.path.join(save_path, 'best_dec_lstm_f30'+'_demo')

# Best validation accuracy and the epoch in which it was reached.
best_acc = 0
best_epoch = 0

# Model dimensions.
enc_size = 1000
emb_size = 300
vocab_size = len(data_set.field.vocab.itos)

In [0]:
# Only the training loader shuffles; validation and test keep the dataset
# order so that predictions can be matched to images by position.
loader_options = dict(batch_size=batch_size, pin_memory=True, num_workers=2)
train_loader = DataLoader(train_set, shuffle=True, **loader_options)
vali_loader = DataLoader(vali_set, shuffle=False, **loader_options)
test_loader = DataLoader(test_set, shuffle=False, **loader_options)

In [0]:
# Encoder: pretrained ResNet50 wrapped by Encoder, with all weights frozen
# (call freeze_bottom() instead to leave the last child modules trainable).
encoder = Encoder(models.resnet50(pretrained=True))
encoder.freeze_all() # Freeze all or bottom
encoder = encoder.cuda()
# Decoder: trained from scratch.
decoder = Decoder(enc_size=enc_size, emb_size=emb_size, vocab_size=vocab_size).cuda()
# Release whatever memory is no longer referenced after moving the models.
gc.collect()
torch.cuda.empty_cache()

In [0]:
class AverageMeter(object):
    """Keeps the most recent value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the last value, average, weighted sum and count back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted ``n`` times in the average."""
        self.val = val
        self.sum, self.count = self.sum + val * n, self.count + n
        self.avg = self.sum / self.count

In [0]:
def mask_accuracy(pred, targets, ignore_index=data_set.field.vocab.stoi['<pad>']):
    """Token-level accuracy that skips padding positions.

    pred: logit output, one row of class scores per target
    targets: labels
    ignore_index: exclude <pad> when calculating accuracy

    Returns a 0-dim float tensor on the device of ``pred``; 0 when every
    target is ``ignore_index`` (instead of NaN from a 0/0 division).
    """
    # Work on whatever device the predictions live on rather than forcing CUDA.
    targets = targets.to(pred.device)
    mask = targets.ne(ignore_index)
    kept_targets = targets[mask]
    if kept_targets.numel() == 0:
        return pred.new_zeros(())
    num_correct = pred[mask].argmax(dim=1).eq(kept_targets).sum()
    return num_correct.float() / kept_targets.size(0)

In [0]:
# Cross-entropy over the vocabulary; <pad> targets do not contribute.
criterion = nn.CrossEntropyLoss(
    ignore_index=data_set.field.vocab.stoi['<pad>'],
)
# Only the decoder is optimised. Uncomment the encoder optimiser (and its
# step() call in the training loop) to fine-tune unfrozen encoder layers.
#enc_optimizer = torch.optim.AdamW([p for p in encoder.parameters() if p.requires_grad], lr=enc_lr)
dec_optimizer = torch.optim.AdamW(params=decoder.parameters(), lr=dec_lr)

In [0]:
def _teacher_forced_batch(inputs, captions, caplens, track_encoder_grad):
    """Run encoder and decoder on one batch with teacher forcing.

    The batch is sorted by decreasing caption length (required for packing),
    the decoder is fed every token but the last, and the targets are the
    captions shifted by one position.

    Returns (outputs, targets): the decoder logits for the non-padded time
    steps and the matching packed target token ids.
    """
    inputs, captions = inputs.cuda(), captions.cuda()
    # No autograd graph is needed through a fully frozen encoder.
    with torch.set_grad_enabled(track_encoder_grad):
        enc_out = encoder(inputs)
    captions = captions.transpose(0, 1)  # seq_len * batch_size
    caplens_sorted, sort_id = caplens.sort(descending=True)
    captions_sorted = captions[:, sort_id]
    captions_input = captions_sorted[:-1, :]
    captions_target = captions_sorted[1:, :]
    outputs = decoder(enc_out[sort_id], captions_input, caplens_sorted)
    targets = pack_padded_sequence(captions_target, caplens_sorted).data
    return outputs, targets


# True when at least one encoder weight is left trainable.
encoder_trainable = any(p.requires_grad for p in encoder.parameters())

for epoch in range(epochs):
    # A fully frozen encoder stays in eval mode: in train mode its BatchNorm
    # layers would normalise with batch statistics and keep updating their
    # running averages, so the "frozen" features would differ between
    # training and evaluation and drift from the pretrained ones.
    encoder.train(encoder_trainable)
    decoder.train()
    train_loss = AverageMeter()
    vali_loss = AverageMeter()
    vali_acc = AverageMeter()
    for batch_index, (inputs, captions, caplens) in enumerate(train_loader):
        dec_optimizer.zero_grad()
        outputs, targets = _teacher_forced_batch(inputs, captions, caplens, encoder_trainable)
        loss = criterion(outputs, targets)
        loss.backward()
        # enc_optimizer.step()
        dec_optimizer.step()
        train_loss.update(loss.item(), inputs.size(0))
        if batch_index % 40 == 0:
            print("Batch: {}, loss {:.4f}.".format(batch_index, loss.item()))
    # Evaluation
    encoder.eval()
    decoder.eval()
    with torch.no_grad():
        for batch_index, (inputs, captions, caplens) in enumerate(vali_loader):
            outputs, targets = _teacher_forced_batch(inputs, captions, caplens, False)
            loss = criterion(outputs, targets)
            acc = mask_accuracy(outputs, targets)
            vali_loss.update(loss.item(), inputs.size(0))
            # Store a Python float so the meter (and best_acc) do not keep
            # CUDA tensors alive across epochs.
            vali_acc.update(acc.item(), inputs.size(0))
    print("Epoch: {}/{}, training loss: {:.4f}, vali loss: {:.4f}, vali acc: {:.4f}.".format(epoch, epochs, train_loss.avg, vali_loss.avg, vali_acc.avg))
    # Save best
    if vali_acc.avg > best_acc:
        best_acc = vali_acc.avg
        best_epoch = epoch
        torch.save(encoder.state_dict(), enc_save_path)
        torch.save(decoder.state_dict(), dec_save_path)
    # Early stopping
    if epoch - best_epoch >= patience:
        print("Early stopping")
        break

## Results

In [0]:
def token_sentence(decoder_out, itos):
    """Convert predicted token ids into caption strings.

    Args:
        decoder_out: integer tensor of shape seq_len * batch_size holding
            vocabulary indices.
        itos: index-to-token list of the vocabulary.

    Returns:
        One string per batch element: the tokens joined by single spaces and
        cut before the first '<eos>', without leading or trailing whitespace
        (an empty string when '<eos>' is the first token).
    """
    sentences = []
    for instance in decoder_out.transpose(1, 0).tolist():
        words = [itos[x] for x in instance]
        if '<eos>' in words:
            # Cut before '<eos>'. Cutting the token list (rather than the
            # joined string) leaves no dangling space at the end.
            words = words[:words.index('<eos>')]
        sentences.append(' '.join(words))
    return sentences

In [0]:
# Restore the best checkpoints written during training.
best_encoder_state = torch.load(enc_save_path)
best_decoder_state = torch.load(dec_save_path)
encoder.load_state_dict(best_encoder_state)
decoder.load_state_dict(best_decoder_state)
# Inference mode for both networks (disables dropout / BatchNorm updates).
encoder.eval()
decoder.eval()

In [0]:
itos = data_set.field.vocab.itos
# Generate as many tokens as the padded caption length of the dataset.
pred_len = data_set.seq_len
result_collection = []

# Predictions with greedy
with torch.no_grad():
    # Only the images are needed for inference, so the reference captions
    # are left on the CPU instead of being copied to the GPU every batch.
    for inputs, captions, caplens in test_loader:
        inputs = inputs.cuda()
        enc_outs = encoder(inputs)
        outputs = decoder.greedy_pred(enc_outs, pred_len)
        result_caption = token_sentence(outputs, itos)
        result_collection.extend(result_caption)

In [0]:
itos = data_set.field.vocab.itos
# Generate as many tokens as the padded caption length of the dataset.
pred_len = data_set.seq_len
result_collection_bs = []

# Predictions with beam search
with torch.no_grad():
    # Only the images are needed for inference, so the reference captions
    # are left on the CPU instead of being copied to the GPU every batch.
    for inputs, captions, caplens in test_loader:
        inputs = inputs.cuda()
        enc_outs = encoder(inputs)
        outputs = decoder.beam_search_pred(enc_outs, pred_len, beam_size=3)
        result_caption_bs = token_sentence(outputs, itos)
        result_collection_bs.extend(result_caption_bs)

In [0]:
def plotax(ax, i):
    """Draw the i-th test image on ``ax`` with its greedy caption as the title.

    The test images are the last 1000 files of the dataset, hence the
    ``-1000+i`` offset. The title is wrapped to lines of at most 32 characters.
    """
    image_file = os.path.join(data_path, data_set.data_files[-1000+i])
    ax.imshow(Image.open(image_file))
    ax.axis('off')
    title_lines = wrap(result_collection[i], 32)
    ax.set_title('\n'.join(title_lines), fontsize=10)

In [0]:
from textwrap import wrap
# Four stacked axes, one per example image.
fig, axs = plt.subplots(4, figsize=(3, 3))

# Stretch the figure vertically so the wrapped titles do not overlap.
fig.subplots_adjust(top=3)

# Test-set positions of the examples to show, top to bottom.
for axis, sample in zip(axs, (10, 20, 50, 14)):
    plotax(axis, sample)

In [0]:
# Visualize one test sample together with its reference caption and both
# predictions. i is the position within the test set (the last 1000 samples).
i = 10
sample_file = data_set.data_files[-1000+i]
plt.imshow(Image.open(os.path.join(data_path, sample_file)))
plt.axis('off')
plt.show()
for label, text in (
    ("Ground truth:", data_set.caption_list[-1000+i]),
    ("Prediction-greedy:", result_collection[i]),
    ("Prediction-beam search:", result_collection_bs[i]),
):
    print(label, text)

## BLEU Scores

In [0]:
# Bleu scores - Greedy search w.r.t. all candidates
# Tokenise with split() rather than split(' '): a caption that ends with a
# space (or is empty) would otherwise contribute an empty-string token and
# distort the n-gram precision. The candidate list is built once and reused.
greedy_candidates = [x.split() for x in result_collection]
test_references = reference_corpus[-1000:]  # references of the test images
uni_bleu, bi_bleu, tri_bleu, qua_bleu = (
    bleu_score(greedy_candidates, test_references, max_n=n, weights=[1/n]*n)
    for n in (1, 2, 3, 4)
)
uni_bleu, bi_bleu, tri_bleu, qua_bleu

In [0]:
# Bleu scores - Beam search w.r.t. all candidates
# Tokenise with split() rather than split(' '): a caption that ends with a
# space (or is empty) would otherwise contribute an empty-string token and
# distort the n-gram precision. The candidate list is built once and reused.
beam_candidates = [x.split() for x in result_collection_bs]
test_references = reference_corpus[-1000:]  # references of the test images
uni_bleu, bi_bleu, tri_bleu, qua_bleu = (
    bleu_score(beam_candidates, test_references, max_n=n, weights=[1/n]*n)
    for n in (1, 2, 3, 4)
)
uni_bleu, bi_bleu, tri_bleu, qua_bleu