In [None]:
import os
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
import seaborn as sns
from tqdm import tqdm
import torch
import torch.nn as nn
import torchtext
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader,Dataset
import torchvision.models as models
from torchvision.models import densenet201
from torch.autograd import Variable
from PIL import Image
from torchtext.data.utils import get_tokenizer
import cv2
from textwrap import wrap
from collections import Counter 
import pickle
import gc
import random
import spacy 
from torch.nn.utils.rnn import pad_sequence

# Seed every RNG this notebook draws from (Python, NumPy and PyTorch) so that
# DataFrame sampling, weight initialisation, DataLoader shuffling and random
# crops start from the same state on each fresh run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Reading Data

In [None]:
# Root folder of the Flickr8k dataset on Kaggle.
PATH = '/kaggle/input/flickr8k'

# captions.txt is parsed as a CSV; later cells read its `image` and
# `caption` columns.
data = pd.read_csv(f'{PATH}/captions.txt')
print(data.shape)
data.head()

# Visualization

In [None]:
def load_image(image_path):
    """Read an image from disk as an RGB array scaled by 1/255.

    Args:
        image_path: path of the image file to read.

    Returns:
        The decoded image with channels reordered from OpenCV's BGR to RGB
        and every value divided by 255.

    Raises:
        FileNotFoundError: if OpenCV could not read the file.
    """
    img = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising, which
    # would otherwise surface as an opaque error inside cvtColor.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img / 255

def visualize_data(df, max_images=15, image_dir="/kaggle/input/flickr8k/Images"):
    """Plot a grid of images with their (wrapped) captions as titles.

    Args:
        df: DataFrame with an `image` column (file names) and a `caption`
            column (text). Its index is reset, so any index is accepted.
        max_images: upper bound on the number of rows plotted; fewer are
            drawn when `df` is shorter.
        image_dir: directory that the file names in `df.image` are relative to.
    """
    df = df.reset_index(drop=True)
    n_images = min(max_images, len(df))

    # Five columns; at least five rows so the default layout is unchanged,
    # more rows only when more than 25 images are requested.
    n_cols = 5
    n_rows = max(5, -(-n_images // n_cols))

    plt.figure(figsize=(20, 20))
    for i in range(n_images):
        plt.subplot(n_rows, n_cols, i + 1)
        image = load_image(f"{image_dir}/{df.image[i]}")
        plt.imshow(image)
        # Wrap long captions at 20 characters so titles do not overlap.
        plt.title("\n".join(wrap(df.caption[i], 20)))
        plt.axis("off")
    # Spacing is a figure-level setting, so it is applied once after the loop.
    plt.subplots_adjust(hspace=0.7, wspace=0.3)

In [None]:
# Load a single sample image and display the array returned by load_image.
image_path = f'{PATH}/Images/1000268201_693b08cb0e.jpg'
image = load_image(image_path)
image

In [None]:
# Preview 15 randomly sampled caption rows together with their images.
sample_rows = data.sample(15)
visualize_data(sample_rows)

# Data Preprocessing

### Model 1

In [None]:
# def text_preprocessing(data):
#     data['caption'] = data['caption'].apply(lambda x: x.lower())
#     data['caption'] = data['caption'].apply(lambda x: x.replace("[^A-Za-z]",""))
#     data['caption'] = data['caption'].apply(lambda x: x.replace("\s+"," "))
#     data['caption'] = data['caption'].apply(lambda x: " ".join([word for word in x.split() if len(word)>1]))
#     data['caption'] = "startseq "+data['caption']+" endseq"
#     return data

# data = text_preprocessing(data)
# captions = data['caption'].tolist()
# captions[:10]

In [None]:
# tokenizer = get_tokenizer("basic_english")
# tokenized_captions = [tokenizer(caption) for caption in captions]

# vocab = torchtext.vocab.build_vocab_from_iterator(tokenized_captions)

# vocab_size = len(vocab) + 1

# max_length = max(len(tokens) for tokens in tokenized_captions)

# print("Vocabulary size:", vocab_size)
# print("Maximum sequence length:", max_length)

In [None]:
# images = data['image'].unique().tolist()
# nimages = len(images)

# split_index = round(0.85 * nimages)

# train_images = images[:split_index]
# val_images = images[split_index:]

# test = data[data['image'].isin(val_images)]
# train = data[data['image'].isin(train_images)]

# train.reset_index(inplace=True, drop=True)
# test.reset_index(inplace=True, drop=True)

# print(train.shape, test.shape)

### Model 2

In [None]:
# def remove_single_char_word(word_list):
#     lst = []
#     for word in word_list:
#         if len(word)>1:
#             lst.append(word)

#     return lst

In [None]:
# data['cleaned_caption'] = data['caption'].apply(lambda caption : ['<start>'] + [word.lower() if word.isalpha() else '' for word in caption.split(" ")] + ['<end>'])
# data['cleaned_caption']  = data['cleaned_caption'].apply(lambda x : remove_single_char_word(x))

In [None]:
# data['seq_len'] = data['cleaned_caption'].apply(lambda x : len(x))
# max_seq_len = data['seq_len'].max()
# print(max_seq_len)

In [None]:
# data.drop(['seq_len'], axis = 1, inplace = True)
# data['cleaned_caption'] = data['cleaned_caption'].apply(lambda caption : caption + ['<pad>']*(max_seq_len-len(caption)))

In [None]:
# display(data.head(2))

In [None]:
# word_list = data['cleaned_caption'].apply(lambda x : " ".join(x)).str.cat(sep = ' ').split(' ')
# word_dict = Counter(word_list)
# word_dict =  sorted(word_dict, key=word_dict.get, reverse=True)

In [None]:
# print(len(word_dict))
# print(word_dict[:5])

In [None]:
# vocab_size = len(word_dict)
# print(vocab_size)

In [None]:
# index_to_word = {index: word for index, word in enumerate(word_dict)}
# word_to_index = {word: index for index, word in enumerate(word_dict)}
# print(len(index_to_word), len(word_to_index))

In [None]:
# data['text_seq']  = data['cleaned_caption'].apply(lambda caption : [word_to_index[word] for word in caption] )

In [None]:
# display(data.head(2))

In [None]:
# data = data.sort_values(by = 'image')
# train = data.iloc[:int(0.9*len(data))]
# valid = data.iloc[int(0.9*len(data)):]

### Model 3

In [None]:
# spaCy English pipeline; this notebook only uses its tokenizer.
nlp = spacy.load("en_core_web_sm")

# Example: lower-cased tokens of a sample sentence.
text = "This is a good place to find a city"
[tok.text.lower() for tok in nlp.tokenizer(text)]

In [None]:
class Vocabulary:
    """Word <-> integer-id mapping built from a list of sentences.

    Ids 0-3 are reserved for the special tokens <PAD>, <SOS>, <EOS> and <UNK>.
    Every other word receives the next free id the first time its count
    reaches `freq_threshold`.
    """

    def __init__(self,freq_threshold):
        """
        Args:
            freq_threshold: number of occurrences a word needs before it is
                added to the vocabulary. Values <= 1 admit every word.
        """
        # int -> string for the pre-reserved tokens
        self.itos = {0:"<PAD>",1:"<SOS>",2:"<EOS>",3:"<UNK>"}

        # string -> int, the inverse of self.itos
        self.stoi = {v:k for k,v in self.itos.items()}

        self.freq_threshold = freq_threshold

    def __len__(self):
        """Number of entries in the vocabulary, special tokens included."""
        return len(self.itos)

    @staticmethod
    def tokenize(text):
        """Split `text` into lower-cased tokens using the spaCy tokenizer."""
        return [token.text.lower() for token in nlp.tokenizer(text)]

    def build_vocab(self, sentence_list):
        """Add every sufficiently frequent word of `sentence_list` to the vocabulary.

        Words are numbered in the order in which they reach the threshold.
        Calling this method again extends the vocabulary: existing words keep
        their ids and new words continue after the current highest id.
        Frequencies are counted per call.
        """
        frequencies = Counter()

        for sentence in sentence_list:
            for word in self.tokenize(sentence):
                frequencies[word] += 1

                # Register the word once, as soon as it is frequent enough.
                # The next id comes from the current size rather than a fixed
                # start value, so a repeated call cannot overwrite entries.
                if word not in self.stoi and frequencies[word] >= self.freq_threshold:
                    idx = len(self.itos)
                    self.stoi[word] = idx
                    self.itos[idx] = word

    def numericalize(self,text):
        """Return the list of ids for the tokens of `text`; unknown tokens map to <UNK>."""
        unk_idx = self.stoi["<UNK>"]
        return [self.stoi.get(token, unk_idx) for token in self.tokenize(text)]

In [None]:
# Smoke test of the Vocabulary class: with a threshold of 1 every token of
# the sample sentence enters the vocabulary.
v = Vocabulary(freq_threshold=1)

sample_sentences = ["This is a good place to find a city"]
v.build_vocab(sample_sentences)
print(v.stoi)

# Tokens that were not part of the sample sentence are mapped to <UNK>.
print(v.numericalize("This is a good place to find a city here!!"))

# Feature Extraction

### Model 1

In [None]:
# train_samples = len(train)
# print(train_samples)

In [None]:
# unq_train_imgs = train[['image']].drop_duplicates()
# unq_valid_imgs = valid[['image']].drop_duplicates()
# print(len(unq_train_imgs), len(unq_valid_imgs))

In [None]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(device)

In [None]:
# class extractImageFeatureResNetDataSet():
#     def __init__(self, data):
#         self.data = data 
#         self.scaler = transforms.Resize([224, 224])
#         self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                                      std=[0.229, 0.224, 0.225])
#         self.to_tensor = transforms.ToTensor()
#     def __len__(self):  
#         return len(self.data)

#     def __getitem__(self, idx):

#         image_name = self.data.iloc[idx]['image']
#         img_loc = '../input/flickr8k/Images/'+str(image_name)

#         img = Image.open(img_loc)
#         t_img = self.normalize(self.to_tensor(self.scaler(img)))

#         return image_name, t_img

In [None]:
# train_ImageDataset_ResNet = extractImageFeatureResNetDataSet(unq_train_imgs)
# train_ImageDataloader_ResNet = DataLoader(train_ImageDataset_ResNet, batch_size = 1, shuffle=False)

In [None]:
# valid_ImageDataset_ResNet = extractImageFeatureResNetDataSet(unq_valid_imgs)
# valid_ImageDataloader_ResNet = DataLoader(valid_ImageDataset_ResNet, batch_size = 1, shuffle=False)

In [None]:
# resnet18 = torchvision.models.resnet18(pretrained=True).to(device)
# resnet18.eval()
# list(resnet18._modules)

In [None]:
# resNet18Layer4 = resnet18._modules.get('layer4').to(device)

In [None]:
# def get_vector(t_img):
    
#     t_img = Variable(t_img)
#     my_embedding = torch.zeros(1, 512, 7, 7)
#     def copy_data(m, i, o):
#         my_embedding.copy_(o.data)
    
#     h = resNet18Layer4.register_forward_hook(copy_data)
#     resnet18(t_img)
    
#     h.remove()
#     return my_embedding

In [None]:
# extract_imgFtr_ResNet_train = {}
# for image_name, t_img in tqdm(train_ImageDataloader_ResNet):
#     t_img = t_img.to(device)
#     embdg = get_vector(t_img)
    
#     extract_imgFtr_ResNet_train[image_name[0]] = embdg

In [None]:
# a_file = open("./EncodedImageTrainResNet.pkl", "wb")
# pickle.dump(extract_imgFtr_ResNet_train, a_file)
# a_file.close()

In [None]:
# extract_imgFtr_ResNet_valid = {}
# for image_name, t_img in tqdm(valid_ImageDataloader_ResNet):
#     t_img = t_img.to(device)
#     embdg = get_vector(t_img)
 
#     extract_imgFtr_ResNet_valid[image_name[0]] = embdg

In [None]:
# a_file = open("./EncodedImageValidResNet.pkl", "wb")
# pickle.dump(extract_imgFtr_ResNet_valid, a_file)
# a_file.close()

### Model 2

In [None]:
# import torch
# import torchvision.transforms as transforms
# from torchvision.models import densenet201
# from PIL import Image
# import os
# from tqdm import tqdm

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model = densenet201(pretrained=True)
# fe = torch.nn.Sequential(*list(model.features.children())[:-1])
# fe.to(device)
# fe.eval()

# img_size = 224
# transform = transforms.Compose([
#     transforms.Resize((img_size, img_size)),
#     transforms.ToTensor(),
# ])

# features = {}
# image_path = "/kaggle/input/flickr8k/Images"  
# for image in tqdm(data['image'].unique().tolist()):
#     img = Image.open(os.path.join(image_path, image))
#     img = transform(img)
#     img = img.unsqueeze(0).to(device)
#     feature = fe(img).detach().cpu().numpy()
#     features[image] = feature

# Data Generation

### Model 1

In [None]:
# class CustomDataset(Dataset):
#     def __init__(self, df, X_col, y_col, directory, tokenizer, vocab_size, max_length, features):
#         self.df = df.copy()
#         self.X_col = X_col
#         self.y_col = y_col
#         self.directory = directory
#         self.tokenizer = tokenizer
#         self.vocab_size = vocab_size
#         self.max_length = max_length
#         self.features = features
        
#     def __len__(self):
#         return len(self.df)
    
#     def __getitem__(self, index):
#         image_path = self.df.iloc[index][self.X_col]
#         image_feature = torch.tensor(self.features[image_path][0], dtype=torch.float32)
        
#         caption = self.df.iloc[index][self.y_col]
#         caption_sequence = self.tokenizer.texts_to_sequences([caption])[0]
#         caption_input = []
#         target = []
#         for i in range(1, len(caption_sequence)):
#             in_seq, out_seq = caption_sequence[:i], caption_sequence[i]
#             in_seq = torch.tensor(pad_sequences([in_seq], maxlen=self.max_length)[0], dtype=torch.long)
#             out_seq = torch.tensor(to_categorical([out_seq], num_classes=self.vocab_size)[0], dtype=torch.float32)
#             caption_input.append(in_seq)
#             target.append(out_seq)
        
#         caption_input = torch.stack(caption_input)
#         target = torch.stack(target)
        
#         return image_feature, caption_input, target

In [None]:
# train_dataset = CustomDataset(df=train,X_col='image',y_col='caption',directory=image_path,
#                                       tokenizer=tokenizer,vocab_size=vocab_size,max_length=max_length,features=features)
# test_dataset = CustomDataset(df=test,X_col='image',y_col='caption',directory=image_path,
#                                       tokenizer=tokenizer,vocab_size=vocab_size,max_length=max_length,features=features)

# train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

### Model 2

In [None]:
# class FlickerDataSetResnet():
#     def __init__(self, data, pkl_file):
#         self.data = data
#         self.encodedImgs = pd.read_pickle(pkl_file)

#     def __len__(self):
#         return len(self.data)

#     def __getitem__(self, idx):
    
#         caption_seq = self.data.iloc[idx]['text_seq']
#         target_seq = caption_seq[1:]+[0]

#         image_name = self.data.iloc[idx]['image']
#         image_tensor = self.encodedImgs[image_name]
#         image_tensor = image_tensor.permute(0,2,3,1)
#         image_tensor_view = image_tensor.view(image_tensor.size(0), -1, image_tensor.size(3))

#         return torch.tensor(caption_seq), torch.tensor(target_seq), image_tensor_view

In [None]:
# train_dataset_resnet = FlickerDataSetResnet(train, 'EncodedImageTrainResNet.pkl')
# train_dataloader_resnet = DataLoader(train_dataset_resnet, batch_size=32, shuffle=True)

# valid_dataset_resnet = FlickerDataSetResnet(valid, 'EncodedImageValidResNet.pkl')
# valid_dataloader_resnet = DataLoader(valid_dataset_resnet, batch_size=32, shuffle=True)

### Model 3

In [None]:
class FlickrDataset(Dataset):
    """Image/caption pairs read from a captions CSV and an image folder.

    Each item is `(image, caption_ids)`, where `caption_ids` is a 1-D tensor
    holding <SOS>, the numericalized caption tokens, then <EOS>.
    """

    def __init__(self,root_dir,caption_file,transform=None,freq_threshold=5):
        """
        Args:
            root_dir: folder that the file names of the `image` column live in.
            caption_file: CSV file with an `image` and a `caption` column.
            transform: optional callable applied to each loaded PIL image.
            freq_threshold: minimum word frequency passed to Vocabulary.
        """
        self.root_dir = root_dir
        self.transform = transform

        # One row per (image, caption) pair.
        self.df = pd.read_csv(caption_file)
        self.imgs = self.df["image"]
        self.captions = self.df["caption"]

        # The vocabulary is built from every caption in the file.
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocab(self.captions.tolist())

    def __len__(self):
        """Number of (image, caption) rows."""
        return self.df.shape[0]

    def __getitem__(self,idx):
        """Load the image of row `idx` and encode its caption as a tensor of ids."""
        caption_text = self.captions[idx]
        image_file = os.path.join(self.root_dir, self.imgs[idx])
        picture = Image.open(image_file).convert("RGB")

        # Optional preprocessing / augmentation.
        if self.transform is not None:
            picture = self.transform(picture)

        # <SOS> + caption token ids + <EOS>
        token_ids = [
            self.vocab.stoi["<SOS>"],
            *self.vocab.numericalize(caption_text),
            self.vocab.stoi["<EOS>"],
        ]

        return picture, torch.tensor(token_ids)

In [None]:
#Initiate the Dataset and Dataloader

#setting the constants
data_location =  "../input/flickr8k"
BATCH_SIZE = 256
# BATCH_SIZE = 6
NUM_WORKER = 4

#defining the transform to be applied
transforms = T.Compose([
    T.Resize(226),
    T.RandomCrop(224),
    T.ToTensor(),
    # per-channel mean / std commonly used with ImageNet-pretrained models
    T.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))
])


#testing the dataset class
dataset =  FlickrDataset(
    root_dir = '/kaggle/input/flickr8k/Images',
    caption_file = '/kaggle/input/flickr8k/captions.txt',
    transform=transforms
)


class CapsCollate:
    """Collate function that pads the captions of a batch to a common length.

    FlickrDataset returns caption tensors of different lengths, which the
    default DataLoader collation cannot stack. Images are stacked as usual and
    captions are padded with `pad_idx`.
    """

    def __init__(self, pad_idx, batch_first=False):
        self.pad_idx = pad_idx
        self.batch_first = batch_first

    def __call__(self, batch):
        imgs = torch.stack([item[0] for item in batch], dim=0)
        targets = pad_sequence(
            [item[1] for item in batch],
            batch_first=self.batch_first,
            padding_value=self.pad_idx,
        )
        return imgs, targets


#writing the dataloader
data_loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKER,
    shuffle=True,
    # batch_first=True because the training loop slices captions[:, 1:]
    collate_fn=CapsCollate(pad_idx=dataset.vocab.stoi["<PAD>"], batch_first=True)
)

#vocab_size
vocab_size = len(dataset.vocab)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

# Modeling 

In [None]:
# class EncoderCNN(nn.Module):
#     def __init__(self):
#         super(EncoderCNN, self).__init__()
#         resnet = models.resnet50(pretrained=True)
#         for param in resnet.parameters():
#             param.requires_grad_(False)
        
#         modules = list(resnet.children())[:-2]
#         self.resnet = nn.Sequential(*modules)
        

#     def forward(self, images):
#         features = self.resnet(images)                                    #(batch_size,2048,7,7)
#         features = features.permute(0, 2, 3, 1)                           #(batch_size,7,7,2048)
#         features = features.view(features.size(0), -1, features.size(-1)) #(batch_size,49,2048)
#         return features

In [None]:
# #Bahdanau Attention
# class Attention(nn.Module):
#     def __init__(self, encoder_dim,decoder_dim,attention_dim):
#         super(Attention, self).__init__()
        
#         self.attention_dim = attention_dim
        
#         self.W = nn.Linear(decoder_dim,attention_dim)
#         self.U = nn.Linear(encoder_dim,attention_dim)
        
#         self.A = nn.Linear(attention_dim,1)
        
        
        
        
#     def forward(self, features, hidden_state):
#         u_hs = self.U(features)     #(batch_size,num_layers,attention_dim)
#         w_ah = self.W(hidden_state) #(batch_size,attention_dim)
        
#         combined_states = torch.tanh(u_hs + w_ah.unsqueeze(1)) #(batch_size,num_layers,attemtion_dim)
        
#         attention_scores = self.A(combined_states)         #(batch_size,num_layers,1)
#         attention_scores = attention_scores.squeeze(2)     #(batch_size,num_layers)
        
        
#         alpha = F.softmax(attention_scores,dim=1)          #(batch_size,num_layers)
        
#         attention_weights = features * alpha.unsqueeze(2)  #(batch_size,num_layers,features_dim)
#         attention_weights = attention_weights.sum(dim=1)   #(batch_size,num_layers)
        
#         return alpha,attention_weights
        

In [None]:
# #Attention Decoder
# class DecoderRNN(nn.Module):
#     def __init__(self,embed_size, vocab_size, attention_dim,encoder_dim,decoder_dim,drop_prob=0.3):
#         super().__init__()
        
#         #save the model param
#         self.vocab_size = vocab_size
#         self.attention_dim = attention_dim
#         self.decoder_dim = decoder_dim
        
#         self.embedding = nn.Embedding(vocab_size,embed_size)
#         self.attention = Attention(encoder_dim,decoder_dim,attention_dim)
        
        
#         self.init_h = nn.Linear(encoder_dim, decoder_dim)  
#         self.init_c = nn.Linear(encoder_dim, decoder_dim)  
#         self.lstm_cell = nn.LSTMCell(embed_size+encoder_dim,decoder_dim,bias=True)
#         self.f_beta = nn.Linear(decoder_dim, encoder_dim)
        
        
#         self.fcn = nn.Linear(decoder_dim,vocab_size)
#         self.drop = nn.Dropout(drop_prob)
        
        
    
#     def forward(self, features, captions):
        
#         #vectorize the caption
#         embeds = self.embedding(captions)
        
#         # Initialize LSTM state
#         h, c = self.init_hidden_state(features)  # (batch_size, decoder_dim)
        
#         #get the seq length to iterate
#         seq_length = len(captions[0])-1 #Exclude the last one
#         batch_size = captions.size(0)
#         num_features = features.size(1)
        
#         preds = torch.zeros(batch_size, seq_length, self.vocab_size).to(device)
#         alphas = torch.zeros(batch_size, seq_length,num_features).to(device)
                
#         for s in range(seq_length):
#             alpha,context = self.attention(features, h)
#             lstm_input = torch.cat((embeds[:, s], context), dim=1)
#             h, c = self.lstm_cell(lstm_input, (h, c))
                    
#             output = self.fcn(self.drop(h))
            
#             preds[:,s] = output
#             alphas[:,s] = alpha  
        
        
#         return preds, alphas
    
#     def generate_caption(self,features,max_len=20,vocab=None):
#         # Inference part
#         # Given the image features generate the captions
        
#         batch_size = features.size(0)
#         h, c = self.init_hidden_state(features)  # (batch_size, decoder_dim)
        
#         alphas = []
        
#         #starting input
#         word = torch.tensor(vocab.stoi['<SOS>']).view(1,-1).to(device)
#         embeds = self.embedding(word)

        
#         captions = []
        
#         for i in range(max_len):
#             alpha,context = self.attention(features, h)
            
            
#             #store the apla score
#             alphas.append(alpha.cpu().detach().numpy())
            
#             lstm_input = torch.cat((embeds[:, 0], context), dim=1)
#             h, c = self.lstm_cell(lstm_input, (h, c))
#             output = self.fcn(self.drop(h))
#             output = output.view(batch_size,-1)
        
            
#             #select the word with most val
#             predicted_word_idx = output.argmax(dim=1)
            
#             #save the generated word
#             captions.append(predicted_word_idx.item())
            
#             #end if <EOS detected>
#             if vocab.itos[predicted_word_idx.item()] == "<EOS>":
#                 break
            
#             #send generated word as the next caption
#             embeds = self.embedding(predicted_word_idx.unsqueeze(0))
        
#         #covert the vocab idx to words and return sentence
#         return [vocab.itos[idx] for idx in captions],alphas
    
    
#     def init_hidden_state(self, encoder_out):
#         mean_encoder_out = encoder_out.mean(dim=1)
#         h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
#         c = self.init_c(mean_encoder_out)
#         return h, c

In [None]:
# class EncoderDecoder(nn.Module):
#     def __init__(self,embed_size, vocab_size, attention_dim,encoder_dim,decoder_dim,drop_prob=0.3):
#         super().__init__()
#         self.encoder = EncoderCNN()
#         self.decoder = DecoderRNN(
#             embed_size=embed_size,
#             vocab_size = vocab_size,
#             attention_dim=attention_dim,
#             encoder_dim=encoder_dim,
#             decoder_dim=decoder_dim
#         )
        
#     def forward(self, images, captions):
#         features = self.encoder(images)
#         outputs = self.decoder(features, captions)
#         return outputs

In [None]:
# embed_size=300
# vocab_size = vocab_size
# attention_dim=256
# encoder_dim=2048
# decoder_dim=512
# learning_rate = 3e-4

In [None]:
# model = EncoderDecoder(
#     embed_size=300,
#     vocab_size = vocab_size,
#     attention_dim=256,
#     encoder_dim=2048,
#     decoder_dim=512
# ).to(device)

# criterion = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# def save_model(model,num_epochs):
#     model_state = {
#         'num_epochs':num_epochs,
#         'embed_size':embed_size,
#         'vocab_size':len(dataset.vocab),
#         'attention_dim':attention_dim,
#         'encoder_dim':encoder_dim,
#         'decoder_dim':decoder_dim,
#         'state_dict':model.state_dict()
#     }

#     torch.save(model_state,'attention_model_state.pth')

In [None]:
# num_epochs = 25
# print_every = 100

# for epoch in range(1,num_epochs+1):   
#     for idx, (image, captions) in enumerate(iter(train_dataloader)):
#         image,captions = image.to(device),captions.to(device)

#         # Zero the gradients.
#         optimizer.zero_grad()

#         # Feed forward
#         outputs,attentions = model(image, captions)

#         # Calculate the batch loss.
#         targets = captions[:,1:]
#         loss = criterion(outputs.view(-1, vocab_size), targets.reshape(-1))
        
#         # Backward pass.
#         loss.backward()

#         # Update the parameters in the optimizer.
#         optimizer.step()

#         if (idx+1)%print_every == 0:
#             print("Epoch: {} loss: {:.5f}".format(epoch,loss.item()))
            
            
#             #generate the caption
#             model.eval()
#             with torch.no_grad():
#                 dataiter = iter(train_dataloader)
#                 img,_ = next(dataiter)
#                 features = model.encoder(img[0:1].to(device))
#                 caps,alphas = model.decoder.generate_caption(features,vocab=vocab)
#                 caption = ' '.join(caps)
#                 show_image(img[0],title=caption)
                
#             model.train()
        
#     #save the latest model
#     save_model(model,epoch)

In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision.models as models
from torch.utils.data import DataLoader,Dataset
import torchvision.transforms as T

In [None]:
class EncoderCNN(nn.Module):
    """Frozen ResNet-50 feature extractor.

    The last two children of the torchvision ResNet-50 are dropped, so the
    network outputs a spatial feature map, which `forward` flattens into a
    sequence of per-location feature vectors.
    """

    def __init__(self):
        super().__init__()
        resnet = models.resnet50(pretrained=True)
        # The backbone is used as a fixed feature extractor.
        for param in resnet.parameters():
            param.requires_grad_(False)

        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)
        self.resnet.eval()

    def train(self, mode=True):
        """Switch mode for the module but keep the frozen backbone in eval mode.

        Without this, `model.train()` would put the BatchNorm layers of the
        backbone back into training mode: they would normalise with batch
        statistics and keep updating their running mean/variance even though
        all weights are frozen.
        """
        super().train(mode)
        self.resnet.eval()
        return self

    def forward(self, images):
        """Return features of shape (batch_size, num_locations, channels).

        For 224x224 inputs the ResNet-50 map is (batch_size, 2048, 7, 7),
        giving an output of (batch_size, 49, 2048).
        """
        features = self.resnet(images)                                       #(batch_size,C,H,W)
        features = features.permute(0, 2, 3, 1)                              #(batch_size,H,W,C)
        # reshape (not view): the permuted tensor is not guaranteed to be
        # viewable for every memory layout.
        features = features.reshape(features.size(0), -1, features.size(-1)) #(batch_size,H*W,C)
        return features

In [None]:
#Additive (Bahdanau-style) attention
class Attention(nn.Module):
    """Scores each encoder location against the decoder state and pools the features.

    `forward` returns the softmax weights over locations together with the
    weighted sum of the encoder features.
    """

    def __init__(self, encoder_dim,decoder_dim,attention_dim):
        super().__init__()

        self.attention_dim = attention_dim

        # Projections of the decoder state (W) and encoder features (U) into
        # the shared attention space, and the scoring layer (A).
        self.W = nn.Linear(decoder_dim,attention_dim)
        self.U = nn.Linear(encoder_dim,attention_dim)
        self.A = nn.Linear(attention_dim,1)

    def forward(self, features, hidden_state):
        """
        Args:
            features: (batch_size, num_locations, encoder_dim)
            hidden_state: (batch_size, decoder_dim)

        Returns:
            alpha: (batch_size, num_locations) attention weights, summing to 1
                over the locations.
            context: (batch_size, encoder_dim) attention-weighted feature sum.
        """
        projected_features = self.U(features)                 #(batch_size,num_locations,attention_dim)
        projected_hidden = self.W(hidden_state).unsqueeze(1)  #(batch_size,1,attention_dim)

        # The hidden state is broadcast over all locations.
        energy = torch.tanh(projected_features + projected_hidden)

        scores = self.A(energy).squeeze(2)                    #(batch_size,num_locations)
        alpha = F.softmax(scores,dim=1)

        context = (features * alpha.unsqueeze(2)).sum(dim=1)  #(batch_size,encoder_dim)

        return alpha,context
        

In [None]:
#Attention Decoder
class DecoderRNN(nn.Module):
    """LSTM-cell caption decoder that attends over encoder features at every step."""

    def __init__(self,embed_size, vocab_size, attention_dim,encoder_dim,decoder_dim,drop_prob=0.3):
        """
        Args:
            embed_size: size of the word embeddings.
            vocab_size: number of entries in the vocabulary (output classes).
            attention_dim: size of the attention projection space.
            encoder_dim: size of each encoder feature vector.
            decoder_dim: size of the LSTM hidden and cell state.
            drop_prob: dropout probability applied before the output layer.
        """
        super().__init__()

        #save the model param
        self.vocab_size = vocab_size
        self.attention_dim = attention_dim
        self.decoder_dim = decoder_dim

        self.embedding = nn.Embedding(vocab_size,embed_size)
        self.attention = Attention(encoder_dim,decoder_dim,attention_dim)

        # Linear maps from the mean encoder feature to the initial LSTM state.
        self.init_h = nn.Linear(encoder_dim, decoder_dim)
        self.init_c = nn.Linear(encoder_dim, decoder_dim)
        # Each step consumes the word embedding concatenated with the context.
        self.lstm_cell = nn.LSTMCell(embed_size+encoder_dim,decoder_dim,bias=True)
        # Not used by forward/generate_caption; kept so the module's
        # state_dict keys stay unchanged.
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)

        self.fcn = nn.Linear(decoder_dim,vocab_size)
        self.drop = nn.Dropout(drop_prob)

    def forward(self, features, captions):
        """Teacher-forced decoding.

        Args:
            features: (batch_size, num_features, encoder_dim) encoder output.
            captions: (batch_size, seq_len) token ids.

        Returns:
            preds: (batch_size, seq_len - 1, vocab_size) scores; step `s` is
                produced from caption token `s`.
            alphas: (batch_size, seq_len - 1, num_features) attention weights.
        """
        #vectorize the caption
        embeds = self.embedding(captions)

        # Initialize LSTM state
        h, c = self.init_hidden_state(features)  # (batch_size, decoder_dim)

        # The last token is never fed as input, hence seq_len - 1 steps.
        seq_length = captions.size(1) - 1
        batch_size = captions.size(0)
        num_features = features.size(1)

        # Allocate on the device of the inputs rather than a notebook-global
        # `device`, so the module works wherever it has been moved.
        out_device = features.device
        preds = torch.zeros(batch_size, seq_length, self.vocab_size, device=out_device)
        alphas = torch.zeros(batch_size, seq_length, num_features, device=out_device)

        for s in range(seq_length):
            alpha,context = self.attention(features, h)
            lstm_input = torch.cat((embeds[:, s], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))

            output = self.fcn(self.drop(h))

            preds[:,s] = output
            alphas[:,s] = alpha

        return preds, alphas

    def generate_caption(self,features,max_len=20,vocab=None):
        """Greedy decoding of a caption for a single image.

        Args:
            features: (1, num_features, encoder_dim) encoder output.
            max_len: maximum number of generated tokens.
            vocab: object with `stoi` and `itos` mappings (required).

        Returns:
            (words, alphas): the generated words (the final "<EOS>" included
            when it was produced) and one attention-weight array per step.

        Raises:
            ValueError: if `vocab` is missing or `features` holds more than
                one image.

        Dropout is only disabled when the module is in eval mode; callers
        should call `.eval()` first.
        """
        if vocab is None:
            raise ValueError("generate_caption requires a vocabulary (vocab=...)")

        batch_size = features.size(0)
        if batch_size != 1:
            # The loop below relies on .item(), i.e. on exactly one sequence.
            raise ValueError("generate_caption expects features for a single image")

        h, c = self.init_hidden_state(features)  # (batch_size, decoder_dim)

        alphas = []

        #starting input
        word = torch.tensor(vocab.stoi['<SOS>'], device=features.device).view(1,-1)
        embeds = self.embedding(word)

        captions = []

        for _ in range(max_len):
            alpha,context = self.attention(features, h)

            #store the alpha score
            alphas.append(alpha.cpu().detach().numpy())

            lstm_input = torch.cat((embeds[:, 0], context), dim=1)
            h, c = self.lstm_cell(lstm_input, (h, c))
            output = self.fcn(self.drop(h))
            output = output.view(batch_size,-1)

            #select the word with the highest score
            predicted_word_idx = output.argmax(dim=1)
            token_id = predicted_word_idx.item()

            #save the generated word
            captions.append(token_id)

            #end if <EOS> detected
            if vocab.itos[token_id] == "<EOS>":
                break

            #feed the generated word back in as the next input
            embeds = self.embedding(predicted_word_idx.unsqueeze(0))

        #convert the vocab idx to words and return sentence
        return [vocab.itos[idx] for idx in captions],alphas

    def init_hidden_state(self, encoder_out):
        """Derive the initial (h, c) from the mean of the encoder features."""
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c

In [None]:
class EncoderDecoder(nn.Module):
    """Full captioning model: EncoderCNN features fed into the attention DecoderRNN."""

    def __init__(self,embed_size, vocab_size, attention_dim,encoder_dim,decoder_dim,drop_prob=0.3):
        """
        All arguments are forwarded to DecoderRNN. `vocab_size` is taken from
        the argument (not from a global dataset object), so the model can be
        rebuilt from saved hyper-parameters alone.
        """
        super().__init__()
        self.encoder = EncoderCNN()
        self.decoder = DecoderRNN(
            embed_size=embed_size,
            vocab_size=vocab_size,
            attention_dim=attention_dim,
            encoder_dim=encoder_dim,
            decoder_dim=decoder_dim,
            drop_prob=drop_prob
        )

    def forward(self, images, captions):
        """Encode `images` and return whatever the decoder returns for `captions`."""
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

In [None]:
# Model dimensions.
embed_size, attention_dim = 300, 256
encoder_dim, decoder_dim = 2048, 512

# Output size of the decoder: one class per vocabulary entry.
vocab_size = len(dataset.vocab)

# Step size for the Adam optimizer.
learning_rate = 3e-4

In [None]:
# Build the model from the hyper-parameter variables defined above instead of
# repeating their values as literals, so the architecture always matches the
# metadata that save_model stores alongside the weights.
model = EncoderDecoder(
    embed_size=embed_size,
    vocab_size=vocab_size,
    attention_dim=attention_dim,
    encoder_dim=encoder_dim,
    decoder_dim=decoder_dim
).to(device)

# Padded positions must not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi["<PAD>"])
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
def save_model(model,num_epochs,path='attention_model_state.pth'):
    """Save the model weights together with the hyper-parameters used to build it.

    Args:
        model: module whose `state_dict()` is stored.
        num_epochs: number of epochs completed so far (stored as metadata).
        path: destination file; defaults to 'attention_model_state.pth' in
            the working directory.

    The hyper-parameters are read from the notebook-level variables
    `embed_size`, `attention_dim`, `encoder_dim`, `decoder_dim` and `dataset`.
    """
    model_state = {
        'num_epochs':num_epochs,
        'embed_size':embed_size,
        'vocab_size':len(dataset.vocab),
        'attention_dim':attention_dim,
        'encoder_dim':encoder_dim,
        'decoder_dim':decoder_dim,
        'state_dict':model.state_dict()
    }

    torch.save(model_state,path)

In [None]:
num_epochs = 25
print_every = 100


def show_image(img, title=None):
    """Display a normalized (3, H, W) image tensor with an optional title.

    The normalization applied by the dataset transform is undone first so the
    picture shows with natural colours.
    """
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    # Work on a CPU copy so the training batch itself is left untouched.
    img = (img.detach().cpu() * std + mean).clamp(0, 1)
    plt.imshow(img.permute(1, 2, 0).numpy())
    if title is not None:
        plt.title(title)
    plt.axis("off")
    plt.show()


for epoch in range(1,num_epochs+1):
    for idx, (image, captions) in enumerate(data_loader):
        image,captions = image.to(device),captions.to(device)

        # Zero the gradients.
        optimizer.zero_grad()

        # Feed forward
        outputs,attentions = model(image, captions)

        # Calculate the batch loss. The targets are the captions shifted by
        # one position (the leading <SOS> is dropped), matching the decoder,
        # which predicts token s+1 from token s.
        targets = captions[:,1:]
        loss = criterion(outputs.view(-1, outputs.size(-1)), targets.reshape(-1))

        # Backward pass.
        loss.backward()

        # Update the parameters in the optimizer.
        optimizer.step()

        if (idx+1)%print_every == 0:
            print("Epoch: {} loss: {:.5f}".format(epoch,loss.item()))

            #generate a caption for one image as a qualitative check
            model.eval()
            with torch.no_grad():
                # Reuse the first image of the current (shuffled) batch rather
                # than starting a new DataLoader iterator for a single image.
                img = image[0:1]
                features = model.encoder(img)
                caps,alphas = model.decoder.generate_caption(features,vocab=dataset.vocab)
                caption = ' '.join(caps)
                show_image(img[0],title=caption)

            model.train()

    #save the latest model
    save_model(model,epoch)