In [2]:
import os
import numpy as np
import torch
import cv2
import re
from tqdm import tqdm
from collections import Counter

In [3]:
torch.__version__

'1.9.0'

In [2]:
vocab_freq = Counter()
img_captions_dict = {}
with open('Flickr8k_text/Flickr8k.token.txt', 'r') as f:
    lines = f.readlines()
    for line in lines:
        img_id = line.split('\t')[0].split('#')[0]
        caption = line.split('\t')[1].strip().strip('.').strip()
        clean_caption = re.sub("[^A-Za-z']+", ' ', caption.replace('<br />', ' ')).lower()
        vocab_freq.update(clean_caption.split())
        if img_id not in img_captions_dict:
            img_captions_dict[img_id] = []
        img_captions_dict[img_id].append(clean_caption)

In [3]:
words = [w for w in vocab_freq.keys() if vocab_freq[w] > 5]

In [4]:
word_map = {word:i+3 for i, word in enumerate(words)}

In [5]:
word_map['<start>'] = 0
word_map['<end>'] = 1
word_map['<unk>'] = 2

In [7]:
def fetch_encoded_img_captions(file, img_captions_dict, word_map, max_len = 50):
    img_arrays = []
    encoded_captions = []
    with open(file, 'r') as f:
        img_ids = f.readlines()
        for img_id in img_ids:
            img_id = img_id.strip()
            if img_id not in img_captions_dict:
                continue
            img = cv2.imread('Flicker8k_Dataset/' + img_id)
            img = cv2.resize(img, (256, 256))
            img_arrays.append(img)
            caption_list = img_captions_dict[img_id]
            for caption in caption_list:
                caption = caption.split()
                if len(caption) > max_len:
                    caption = caption[:max_len]
                enc_caption = [word_map['<start>']] + [word_map.get(word, word_map['<unk>']) for word in caption] \
                              + [word_map['<end>']]
                encoded_captions.append(enc_caption)
    return img_arrays, encoded_captions
                
            
    
    

In [8]:
train_imgs, train_captions = fetch_encoded_img_captions('Flickr8k_text/Flickr_8k.trainImages.txt', 
                                                                              img_captions_dict, word_map)

In [9]:
assert len(train_imgs)*5 == len(train_captions)

In [10]:
train_imgs[0].shape

(256, 256, 3)

In [16]:
valid_imgs, valid_captions = fetch_encoded_img_captions('Flickr8k_text/Flickr_8k.devImages.txt', 
                                                                              img_captions_dict, word_map)
test_imgs, test_captions = fetch_encoded_img_captions('Flickr8k_text/Flickr_8k.testImages.txt', 
                                                                              img_captions_dict, word_map)

In [17]:
train_imgs = train_imgs[:500]
train_captions = train_captions[:2500]

valid_imgs = valid_imgs[:100]
valid_captions = valid_captions[:100]

In [18]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
import torchvision
from torchvision import models, transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [19]:
class Encoder(nn.Module):
    
    def __init__(self):
        super(Encoder, self).__init__()
        
        resnet = torchvision.models.resnet101(pretrained=True)
        for param in resnet.parameters():
            param.requires_grad_(False)
            
        layers = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*layers)
        
    def forward(self, x):
        
        x = self.resnet(x)
        
        batch_size, feature_length, size_1,  size_2 = x.size()
        x = x.permute(0, 2, 3, 1)
        x = x.view(batch_size, size_1*size_2, feature_length)
        return x
    
class Attention(nn.Module):
    
    def __init__(self, feature_size, hidden_size, output_size = 1):
        super(Attention, self).__init__()
        
        self.feature_size = feature_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        
        self.Wa = nn.Linear(self.feature_size, self.hidden_size)
        self.Ua = nn.Linear(self.hidden_size, self.hidden_size)
        self.Va = nn.Linear(self.hidden_size, self.output_size)
        
    
    def forward(self, features, decoder_hidden_state):
        
        decoder_hidden_state = decoder_hidden_state.unsqueeze(1)
        att1 = self.Wa(features)
        att2 = self.Ua(decoder_hidden_state)
        
        atten = torch.tanh(att1 + att2)
        att_score = self.Va(atten)
        att_weight = nn.functional.softmax(att_score)
        
        att_sum = torch.sum(att_weight*features, dim=1)
        att_weight = att_weight.squeeze(dim=2)
        
        return att_sum, att_weight
        
        
    
class DecoderAttention(nn.Module):
    
    def __init__(self, feature_size, emb_size, hidden_size, vocab_size, p = 0.5):
        
        super(DecoderAttention, self).__init__()
        
        self.feature_size = feature_size
        self.vocab_size = vocab_size
        self.embed_size = emb_size
        self.hidden_size = hidden_size
        self.p = p
        
        self.embeddings = nn.Embedding(vocab_size, emb_size)
        
        self.lstm = nn.LSTMCell(emb_size + feature_size, hidden_size)
        self.hidden = nn.Linear(feature_size, hidden_size)
        self.cell = nn.Linear(feature_size, hidden_size)
        
        self.fc = nn.Linear(hidden_size, vocab_size)
        
        self.dropout = nn.Dropout(p=p)
        
        self.attention = Attention(feature_size, hidden_size)
        
    
    def forward(self, features, captions):
        
        emb_captions = self.embeddings(captions)
        
        feature_size = features.size(1)
        batch_size = features.size(0)
        cap_len = captions.size(1)
        
        h, c = self.init_hidden(features)
        
        outputs = torch.zeros(batch_size, cap_len, self.vocab_size)
        att_weights = torch.zeros(batch_size, cap_len, feature_size)
        
        for i in range(cap_len):
            sample_prob = 0.0 if i == 0 else 0.5
            use_sampling = np.random.random() < sample_prob
            if use_sampling == False:
                word_embed = emb_captions[:,i,:]
            
            att_sum, att_weight = self.attention(features, h)
            
            lstm_input = torch.cat([word_embed, att_sum], 1)
            h, c = self.lstm(lstm_input, (h,c))
            h = self.dropout(h)
            output = self.fc(h)
            if use_sampling == True:
                
                scaled_output = output / 0.5
                scoring = nn.functional.log_softmax(scaled_output, dim=1)
                top_idx = scoring.topk(1)[1]
                word_embed = self.embeddings(top_idx).squeeze(1) 
            outputs[:, i, :] = output
            att_weights[:, i, :] = att_weight
        return outputs, att_weights
    
    
    def init_hidden(self, features):
        
        mean_features = torch.mean(features, dim=1)
        h0 = self.hidden(mean_features)
        c0 = self.cell(mean_features)
        return h0, c0
    
    def greedy_search(self, feature, max_cap_len = 50):
        
        weights = []
        sentence = []
        
        input_word = torch.tensor(0).unsqueeze(0)
        
        h, c = self.init_hidden(features)
        
        while True:
            emb_word = self.embeddings(input_word)
            att_sum, att_weight = self.attention(features, h)
            
            input_lstm = torch.cat([emb_word, att_sum], dim=1)
            h, c = self.lstm(input_lstm, (h, c))
            
            h = self.dropout(h)
            
            output = self.fc(h)
            
            weights.append(att_weight)
            scores = nn.functional.softmax(output, dim=1)
            idx = scoring[0].topk(1)[1]
            sentence.append(idx.item())
            
            input_word = idx
            if (len(sentence) >= max_sentence or idx == 1):
                break

        return sentence, weights
            
            
            
            
    

In [42]:
class CaptionDataset(Dataset):
    
    def __init__(self, imgs, captions, batch_size, train = True):
        self.imgs = imgs
        self.captions = captions
        self.size = len(self.captions)
        self.is_train = train
        self.batch_size = batch_size
        self.caption_lengths = [len(token) for token in self.captions]
        
        if train:
            self.transforms = A.Compose([ 
                A.RandomCrop(224, 224),                      # get 224x224 crop from random location
                A.HorizontalFlip(p=0.5),               # horizontally flip image with probability=0.5
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ToTensorV2()
                
            ])
        else:
            self.transforms = A.Compose([ 
                A.CenterCrop(224, 224),                   # smaller edge of image resized to 256
                A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ToTensorV2()
            ])
        
    def __len__(self):
        return self.size
    
    def __getitem__(self, idx):
        
        img = self.imgs[idx // 5]
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        transformed = self.transforms(image=img.astype(np.uint8))
        img = transformed['image']

        caption = torch.LongTensor(self.captions[idx])
        all_caption_per_image = []
        if self.is_train:
            return img, caption
        for k in range(idx // 5 , ((idx // 5) + 5)):
            all_caption_per_image.append(self.captions[k])
        print(all_caption_per_image)
        return img, caption, all_caption_per_image
    
    def get_indices(self):
        # randomly select the caption length from the list of lengths
        sel_length = np.random.choice(self.caption_lengths)
        all_indices = np.where([self.caption_lengths[i] == sel_length for i in np.arange(len(self.caption_lengths))])[0]
        # select m = batch_size captions from list above
        indices = list(np.random.choice(all_indices, size=self.batch_size))
        # return the caption indices of specified batch
        return indices
    

def get_loader(batch_size=1, train = True):
    
    if train:
        # Randomly sample a caption length and indices of that length
        dataset = CaptionDataset(train_imgs, train_captions,batch_size=batch_size, train = True)
        indices = dataset.get_indices()
        # Create and assign a batch sampler to retrieve a batch with the sampled indices
        # functionality from torch.utils
        initial_sampler = data.sampler.SubsetRandomSampler(indices=indices)
        data_loader = data.DataLoader(dataset=dataset,
                                      batch_sampler=data.sampler.BatchSampler(sampler=initial_sampler,
                                                                              batch_size=dataset.batch_size,
                                                                              drop_last=False))
    else:
        dataset = CaptionDataset(valid_imgs, valid_captions, batch_size=batch_size, train = False)
        indices = dataset.get_indices()
        # Create and assign a batch sampler to retrieve a batch with the sampled indices
        # functionality from torch.utils
        initial_sampler = data.sampler.SubsetRandomSampler(indices=indices)
        data_loader = data.DataLoader(dataset=dataset,
                                      batch_sampler=data.sampler.BatchSampler(sampler=initial_sampler,
                                                                              batch_size=dataset.batch_size,
                                                                              drop_last=False))
        


    return data_loader
        
        

In [43]:
valid_data_loader = get_loader(batch_size=batch_size, train = False)

total_step_valid = math.ceil(len(valid_data_loader.dataset.caption_lengths) / valid_data_loader.batch_sampler.batch_size)

In [44]:
next(iter(valid_data_loader))

99
19
12
99
20
9
99
21
12
99
22
14
99
23
13
[[0, 46, 47, 34, 35, 3, 490, 127, 5, 23, 474, 1], [0, 46, 219, 9, 572, 226, 23, 404, 1], [0, 46, 772, 30, 9, 157, 572, 122, 24, 23, 404, 1], [0, 46, 772, 971, 38, 351, 12, 572, 5, 56, 12, 23, 404, 1], [0, 46, 772, 5, 654, 655, 38, 62, 572, 44, 23, 404, 1]]
99
19
12
99
20
9
99
21
12
99
22
14
99
23
13
[[0, 46, 47, 34, 35, 3, 490, 127, 5, 23, 474, 1], [0, 46, 219, 9, 572, 226, 23, 404, 1], [0, 46, 772, 30, 9, 157, 572, 122, 24, 23, 404, 1], [0, 46, 772, 971, 38, 351, 12, 572, 5, 56, 12, 23, 404, 1], [0, 46, 772, 5, 654, 655, 38, 62, 572, 44, 23, 404, 1]]
9
1
11
9
2
9
9
3
12
9
4
12
9
5
16
[[0, 46, 403, 258, 38, 3, 386, 5, 3, 287, 1], [0, 46, 183, 258, 38, 3, 910, 386, 1], [0, 46, 63, 406, 5, 105, 1261, 34, 38, 3, 386, 1], [0, 46, 71, 406, 38, 3, 386, 17, 441, 3, 586, 1], [0, 3, 139, 5, 3, 153, 351, 8, 164, 133, 367, 572, 5, 23, 602, 1]]
90
18
12
90
19
12
90
20
9
90
21
12
90
22
14
[[0, 3, 40, 27, 227, 3, 26, 27, 5, 23, 212, 1], [0, 46, 47, 34, 35,

RuntimeError: each element in list of batch should be of equal size

In [21]:
import torch
import torch.nn as nn
from torchvision import transforms
import numpy as np
import sys
import os
import math

import torch.utils.data as data

import nltk
from nltk.translate.bleu_score import corpus_bleu
nltk.download('punkt')

[nltk_data] Downloading package punkt to
[nltk_data]     /Users/rahulbethavalli/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [22]:
#batch_size = 64          # batch size, change to 64
batch_size = 10
embed_size = 256           # dimensionality of image and word embeddings
hidden_size = 512          # number of features in hidden state of the RNN decoder
feature_size = 2048        # number of feature maps, produced by Encoder
num_epochs = 3             # number of training epochs
print_every = 100          # determines window for printing average loss

In [23]:
def train(encoder, decoder, optimizer, lossFn, data_loader, total_step):
    
    total_loss = 0.0
    
    for i in tqdm(range(1, total_step+1)):
        encoder.eval()
        decoder.train()
        
        indices = data_loader.dataset.get_indices()
        new_sampler = data.sampler.SubsetRandomSampler(indices=indices)
        data_loader.batch_sampler.sampler = new_sampler
        
        img, cap = next(iter(data_loader))
        cap_target = cap[:, 1:]
        cap_train = cap[:, :-1]
        
        encoder.zero_grad()
        decoder.zero_grad()
        
        img_features = encoder(img)
        out, att_weights = decoder(img_features, cap_train)
        
        loss = lossFn(out.view(-1, vocab_size), cap_target.reshape(-1))
        
        loss.backward()
        
        optimizer.step()
        
        total_loss += loss.item()
        
    avg_loss = total_loss / total_step
    print('Avg. Loss train: ', avg_loss)
    return


def validate(encoder, decoder, optimizer, lossFn, data_loader, total_step):
    
        epoch_loss = 0.0
        references = []
        hypothesis = []
        
        for i in range(1, total_step+1):
            # evaluation of encoder and decoder
            encoder.eval()
            decoder.eval()
            
            img, cap, caps_all = next(iter(data_loader))
            
            captions_target = cap[:, 1:]
            val_captions = cap[:, :-1]
            
            features_val = encoder(img)
            outputs_val, atten_weights_val = decoder(captions= val_captions,
                                             features = features_val)
            loss_val = lossFn(outputs_val.view(-1, vocab_size), 
                                 captions_target.reshape(-1))

            # preprocess captions and add them to the list
            
            references.append(caps_all)
            # get corresponding indicies from predictions
            # and form hypothesis from output
            terms_idx = torch.max(outputs_val, dim=2)[1]
            hyp_list = get_hypothesis(terms_idx, data_loader=data_loader)
            hypothesis.append(hyp_list)

            epoch_loss += loss_val.item()

            
        epoch_loss_avg = epoch_loss / total_step

        # prepare the proper shape for computing BLEU scores
        references = np.array(references).reshape(total_step*batch_size, -1)
        #hyps = np.array(hypothesis).reshape(total_step*batch_size, -1)
        hyps = np.concatenate(np.array(hypothesis))

        bleu_1 = corpus_bleu(references, hyps, weights = (1.0, 0, 0, 0))
        bleu_2 = corpus_bleu(references, hyps, weights = (0.5, 0.5, 0, 0))
        bleu_3 = corpus_bleu(references, hyps, weights = (1.0/3.0, 1.0/3.0, 1.0/3.0, 0))
        bleu_4 = corpus_bleu(references, hyps, weights = (0.25, 0.25, 0.25, 0.25))
        # append individual n_gram scores
        #bleu_score_list.append((bleu_1, bleu_2, bleu_3, bleu_4))

        epoch_stat = 'Avg. Loss valid: %.4f, \
        BLEU-1: %.2f, BLEU-2: %.2f, BLEU-3: %.2f, BLEU-4: %.2f' % (epoch_loss_avg, bleu_1, bleu_2, bleu_3, bleu_4)

        print('\r' + epoch_stat, end="")
        print('\r')
        
        return
        

In [24]:
index2vocab = {}

for key, value in word_map.items():
    index2vocab[value] = key

In [25]:
def get_hypothesis(terms_idx, data_loader):
    
    hypothesis_list = []

    for i in range(terms_idx.size(0)):
        words = [index2vocab.get(idx.item()) for idx in terms_idx[i]]
        words = [word for word in words if word not in (',', '.', '<end>')]
        hypothesis_list.append(words)
    return hypothesis_list

In [26]:
vocab_size = len(vocab_freq) + 2

encoder = Encoder()
decoder = DecoderAttention(feature_size = feature_size, 
                     emb_size = embed_size, 
                     hidden_size = hidden_size, 
                     vocab_size = vocab_size)

lossFn = nn.CrossEntropyLoss()

params = list(decoder.parameters())

optimizer = torch.optim.Adam(params, lr = 1e-4)

data_loader = get_loader(
                         batch_size=batch_size,
                         train = True)

total_step = math.ceil(len(data_loader.dataset.caption_lengths) / data_loader.batch_sampler.batch_size)

valid_data_loader = get_loader(batch_size=batch_size, train = False)

total_step_valid = math.ceil(len(valid_data_loader.dataset.caption_lengths) / valid_data_loader.batch_sampler.batch_size)

In [29]:
a,b,c = next(iter(valid_data_loader))

RuntimeError: each element in list of batch should be of equal size

In [27]:

for epoch in tqdm(range(num_epochs)):  
    print('Epoch: ', epoch)
    
#     train(encoder, decoder, optimizer, lossFn,
#           data_loader=data_loader, total_step=total_step)
    
    validate(encoder, decoder, optimizer, lossFn, 
             data_loader= valid_data_loader, total_step= total_step_valid)


  0%|                                                                                                                                                                                 | 0/3 [00:00<?, ?it/s]

Epoch:  0





RuntimeError: each element in list of batch should be of equal size