In [1]:
import pandas as pd
import numpy as np

import os

import torch
import torch.nn as nn
from torchvision import transforms
import torchvision.models as models
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

from tqdm import tqdm

import matplotlib.pyplot as plt
import time

import spacy
from PIL import Image

import warnings
warnings.filterwarnings('ignore')

In [2]:
data_path = '/kaggle/input/flickr8k/captions.txt'
df = pd.read_csv(data_path)
df.head()

Unnamed: 0,image,caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .
3,1000268201_693b08cb0e.jpg,A little girl climbing the stairs to her playh...
4,1000268201_693b08cb0e.jpg,A little girl in a pink dress going into a woo...


In [3]:
max_len = 0

for caption in df['caption'].to_list():
    if len(caption.split(' '))>max_len:
        max_len = len(caption.split(' '))

print(max_len)

38


In [4]:
spacy_eng = spacy.load("en_core_web_sm")

In [5]:
class Vocabulary:
    
    def __init__(self, freq_threshold):
        self.itos = {0:'<PAD>',1:'<SOS>',2:'<EOS>',3:'<UNK>'}
        self.stoi = {'<PAD>':0,'<SOS>':1,'<EOS>':2,'<UNK>':3}
        
        self.freq_threshold = freq_threshold
    
    def __len__(self):
        return len(self.stoi)

    def tokenize_caption(self,text):
        return [token.text.lower() for token in spacy_eng.tokenizer(text)]
    
    def build_vocabulary(self, sentence_list):
        frequencies = {}
        idx = 4
        
        for sentence in sentence_list:
            for word in self.tokenize_caption(sentence):
                if word not in frequencies:
                    frequencies[word] = 1
                else:
                    frequencies[word] += 1
                
                if frequencies[word] == self.freq_threshold and len(word)>1 and word.isalpha():
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1
                    
    
    def convert_to_vector(self, text):
        tokenized_text = self.tokenize_caption(text)
        return [self.stoi[token] if token in self.stoi else self.stoi['<UNK>'] for token in tokenized_text]

In [6]:
class FlickrDataset(Dataset):
    
    def __init__(self, root_dir, captions_file, transform = None, freq_threshold = 10):
        self.root_dir = root_dir
        self.transform = transform
        
        self.df = pd.read_csv(captions_file)
        self.df = self.df[:4000]
        
        self.imgs = self.df['image']
        self.captions = self.df['caption']
        
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions)
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        caption = self.captions[index]
        img_id = self.imgs[index]
        img = Image.open(os.path.join(self.root_dir,img_id)).convert('RGB')
        
        if self.transform is not None:
            img = self.transform(img)
        
        caption_vector = [self.vocab.stoi['<SOS>']]
        caption_vector += self.vocab.convert_to_vector(caption)
        caption_vector.append(self.vocab.stoi['<EOS>'])
        
        return img, torch.tensor(caption_vector)

class MyCollate:
    
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx
        
    def __call__(self, batch):
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs,dim = 0)
        targets = [item[1] for item in batch]
        targets = pad_sequence(targets, batch_first = False, padding_value = self.pad_idx)
        return imgs,targets

In [7]:
class MyCollate:
    
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx
        
    def __call__(self, batch):
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs,dim = 0)
        targets = [item[1] for item in batch]
        targets = pad_sequence(targets, batch_first = False, padding_value = self.pad_idx)
        return imgs,targets

In [8]:
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [9]:
class EncoderCNN(nn.Module):
    
    def __init__(self, embed_size, train_CNN = False):
        super(EncoderCNN, self).__init__()
        
        self.train_CNN = train_CNN
        
        self.inception = models.inception_v3(pretrained = True, aux_logits = True)
        self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, images):
        features = self.inception(images)[0]
        return self.dropout(self.relu(features))

class DecoderRNN(nn.Module):
    
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(DecoderRNN, self).__init__()
        
        self.embed = nn.Embedding(vocab_size,embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, features, captions):
        embeddings = self.embed(captions)
        embeddings = torch.cat((features.unsqueeze(0), embeddings),dim = 0)
        hiddens,_ = self.lstm(embeddings)
        outputs = self.linear(hiddens)
        return outputs

class CNNtoRNN(nn.Module):
    
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(CNNtoRNN, self).__init__()
        
        self.encoderCNN = EncoderCNN(embed_size)
        self.decoderRNN = DecoderRNN(embed_size, hidden_size, vocab_size, num_layers)
    
    def forward(self, images ,captions):
        features = self.encoderCNN(images)
        outputs = self.decoderRNN(features,captions)
        return outputs
    
    def caption_image(self, image, vocabulary, max_length = 38):
        result_caption = []
        
        with torch.no_grad():
            
            x = self.encoderCNN(image).unsqueeze(0)
            states = None
            
            for _ in range(max_length):
                
                hidden, states = self.decoderRNN.lstm(x,states)
                output = self.decoderRNN.linear(hidden.squeeze(0))
                predicted = torch.argmax(output)
                result_caption.append(predicted.item())
                
                x = self.decoderRNN.embed(predicted).unsqueeze(0) # Prepare the output to input

                if vocabulary.itos[predicted.item()] == '<EOS>':
                    break
            
        return [vocabulary.itos[idx] for idx in result_caption]

In [10]:
transform = transforms.Compose(
        [
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

dir_path = '/kaggle/input/flickr8k/Images'

dataset = FlickrDataset(root_dir = dir_path, captions_file = '/kaggle/input/flickr8k/captions.txt', transform=transform)

pad_idx = dataset.vocab.stoi['<PAD>']

loader = DataLoader(
    dataset=dataset,
    batch_size=64,
    shuffle=True,
    pin_memory=True,
    collate_fn=MyCollate(pad_idx=pad_idx)
)

In [17]:
load_model = False
save_model = False
train_CNN = False

embed_size = 256
hidden_size = 256
vocab_size = len(dataset.vocab)
num_layers = 1
learning_rate = 0.001
num_epochs = 100

In [18]:
model = CNNtoRNN(embed_size, hidden_size, vocab_size, num_layers).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=dataset.vocab.stoi['<PAD>'])
optimizer = optim.Adam(model.parameters(), lr = learning_rate)

In [19]:
# Only fine tune the last layer of CNNs
for name, param in model.encoderCNN.inception.named_parameters():
    if 'fc.weight' in name or 'fc.bias' in name:
        param.requires_grad = True
    else:
        param.requires_grad = train_CNN

In [20]:
model.train()

for epoch in range(num_epochs):

    start_time = time.time()

    for idx, (imgs, caption) in tqdm(enumerate(loader),total = len(loader),leave=False):
        imgs = imgs.to(device)
        caption = caption.to(device)

        outputs = model(imgs, caption[:-1])
        loss = criterion(
            outputs.reshape(-1, outputs.shape[2]), caption.reshape(-1)
        )

        optimizer.zero_grad()
        loss.backward(loss)
        optimizer.step()
    
    time_taken = time.time() - start_time
    print(f'Epoch - {epoch+1}/{num_epochs}, Loss - {loss}, Time Taken - {time_taken/60}')

                                               

Epoch - 1/100, Loss - 3.011389970779419, Time Taken - 0.5257211208343506


                                               

Epoch - 2/100, Loss - 2.826110601425171, Time Taken - 0.5111935853958129


                                               

Epoch - 3/100, Loss - 2.7210309505462646, Time Taken - 0.5518576184908549


                                               

Epoch - 4/100, Loss - 2.6435046195983887, Time Taken - 0.5184993982315064


                                               

Epoch - 5/100, Loss - 2.363044023513794, Time Taken - 0.51342613697052


                                               

Epoch - 6/100, Loss - 2.498910903930664, Time Taken - 0.5078112880388895


                                               

Epoch - 7/100, Loss - 2.316697359085083, Time Taken - 0.5184372385342916


                                               

Epoch - 8/100, Loss - 2.127866506576538, Time Taken - 0.5158935228983561


                                               

Epoch - 9/100, Loss - 2.251150369644165, Time Taken - 0.5259422024091085


                                               

Epoch - 10/100, Loss - 2.173001527786255, Time Taken - 0.5096389770507812


                                               

Epoch - 11/100, Loss - 2.207843065261841, Time Taken - 0.5101566195487977


                                               

Epoch - 12/100, Loss - 1.962660551071167, Time Taken - 0.5173877835273742


                                               

Epoch - 13/100, Loss - 2.1058735847473145, Time Taken - 0.5207783102989196


                                               

Epoch - 14/100, Loss - 2.1517724990844727, Time Taken - 0.5213412761688232


                                               

Epoch - 15/100, Loss - 2.0235517024993896, Time Taken - 0.5467546423276265


                                               

Epoch - 16/100, Loss - 1.9980510473251343, Time Taken - 0.5349823395411174


                                               

Epoch - 17/100, Loss - 1.8803588151931763, Time Taken - 0.5244185527165731


                                               

Epoch - 18/100, Loss - 1.9383585453033447, Time Taken - 0.5101554830869038


                                               

Epoch - 19/100, Loss - 1.7664666175842285, Time Taken - 0.5185644189516704


                                               

Epoch - 20/100, Loss - 1.9337255954742432, Time Taken - 0.5122064868609111


                                               

Epoch - 21/100, Loss - 1.9099761247634888, Time Taken - 0.5179528951644897


                                               

Epoch - 22/100, Loss - 1.8798208236694336, Time Taken - 0.5243759751319885


                                               

Epoch - 23/100, Loss - 1.702060341835022, Time Taken - 0.5243890802065532


                                               

Epoch - 24/100, Loss - 1.7923259735107422, Time Taken - 0.5145089070002238


                                               

Epoch - 25/100, Loss - 1.6404690742492676, Time Taken - 0.5082331856091817


                                               

Epoch - 26/100, Loss - 1.6463148593902588, Time Taken - 0.5143419186274211


                                               

Epoch - 27/100, Loss - 1.6857341527938843, Time Taken - 0.5004232168197632


                                               

Epoch - 28/100, Loss - 1.570685625076294, Time Taken - 0.527384614944458


                                               

Epoch - 29/100, Loss - 1.4115209579467773, Time Taken - 0.5172228535016378


                                               

Epoch - 30/100, Loss - 1.4979698657989502, Time Taken - 0.5287404696146647


                                               

Epoch - 31/100, Loss - 1.490797758102417, Time Taken - 0.5079222242037456


                                               

Epoch - 32/100, Loss - 1.4893399477005005, Time Taken - 0.5104196945826213


                                               

Epoch - 33/100, Loss - 1.409144639968872, Time Taken - 0.509884794553121


                                               

Epoch - 34/100, Loss - 1.387158751487732, Time Taken - 0.5346662441889445


                                               

Epoch - 35/100, Loss - 1.3343455791473389, Time Taken - 0.520171320438385


                                               

Epoch - 36/100, Loss - 1.296584129333496, Time Taken - 0.5154239455858867


                                               

Epoch - 37/100, Loss - 1.3303614854812622, Time Taken - 0.5217941125233968


                                               

Epoch - 38/100, Loss - 1.3659212589263916, Time Taken - 0.5153010884920756


                                               

Epoch - 39/100, Loss - 1.2681390047073364, Time Taken - 0.5449888825416564


                                               

Epoch - 40/100, Loss - 1.2420268058776855, Time Taken - 0.5193374713261922


                                               

Epoch - 41/100, Loss - 1.1816495656967163, Time Taken - 0.5261100808779399


                                               

Epoch - 42/100, Loss - 1.1995494365692139, Time Taken - 0.5365758021672566


                                               

Epoch - 43/100, Loss - 1.2003501653671265, Time Taken - 0.5228456815083822


                                               

Epoch - 44/100, Loss - 1.2177228927612305, Time Taken - 0.512045160929362


                                               

Epoch - 45/100, Loss - 1.1640293598175049, Time Taken - 0.5105058471361796


                                               

Epoch - 46/100, Loss - 1.1494355201721191, Time Taken - 0.5063590963681539


                                               

Epoch - 47/100, Loss - 1.0668768882751465, Time Taken - 0.5190868417421977


                                               

Epoch - 48/100, Loss - 1.0815491676330566, Time Taken - 0.5140002290407817


                                               

Epoch - 49/100, Loss - 1.077699065208435, Time Taken - 0.5218424836794535


                                               

Epoch - 50/100, Loss - 1.0743443965911865, Time Taken - 0.5076488018035888


                                               

Epoch - 51/100, Loss - 1.014880657196045, Time Taken - 0.5277623732884725


                                               

Epoch - 52/100, Loss - 0.868918776512146, Time Taken - 0.5165462414423625


                                               

Epoch - 53/100, Loss - 0.9907122850418091, Time Taken - 0.5107636411984762


                                               

Epoch - 54/100, Loss - 0.9634251594543457, Time Taken - 0.5213566303253174


                                               

Epoch - 55/100, Loss - 0.9039758443832397, Time Taken - 0.5210961103439331


                                               

Epoch - 56/100, Loss - 0.9539134502410889, Time Taken - 0.506603475411733


                                               

Epoch - 57/100, Loss - 0.9162119626998901, Time Taken - 0.5290569345156352


                                               

Epoch - 58/100, Loss - 0.9062133431434631, Time Taken - 0.5102260867754619


                                               

Epoch - 59/100, Loss - 0.9059495329856873, Time Taken - 0.5180261850357055


                                               

Epoch - 60/100, Loss - 0.9196451306343079, Time Taken - 0.5305232564608257


                                               

Epoch - 61/100, Loss - 0.9247432947158813, Time Taken - 0.5192461371421814


                                               

Epoch - 62/100, Loss - 0.9060032963752747, Time Taken - 0.5152511477470398


                                               

Epoch - 63/100, Loss - 0.7966647744178772, Time Taken - 0.5573592782020569


                                               

Epoch - 64/100, Loss - 0.7898986339569092, Time Taken - 0.5229081670443217


                                               

Epoch - 65/100, Loss - 0.8242116570472717, Time Taken - 0.4986429214477539


                                               

Epoch - 66/100, Loss - 0.8118194341659546, Time Taken - 0.507461682955424


                                               

Epoch - 67/100, Loss - 0.8341889381408691, Time Taken - 0.525814151763916


                                               

Epoch - 68/100, Loss - 0.7701032161712646, Time Taken - 0.5155320485432943


                                               

Epoch - 69/100, Loss - 0.8133742213249207, Time Taken - 0.5433371623357137


                                               

Epoch - 70/100, Loss - 0.7093269228935242, Time Taken - 0.52491375207901


                                               

Epoch - 71/100, Loss - 0.7740673422813416, Time Taken - 0.5313960512479147


                                               

Epoch - 72/100, Loss - 0.7236576080322266, Time Taken - 0.5087087074915568


                                               

Epoch - 73/100, Loss - 0.7922447919845581, Time Taken - 0.5055837233861288


                                               

Epoch - 74/100, Loss - 0.8055093884468079, Time Taken - 0.510057799021403


                                               

Epoch - 75/100, Loss - 0.7342091798782349, Time Taken - 0.5133309920628866


                                               

Epoch - 76/100, Loss - 0.8308306336402893, Time Taken - 0.5127762794494629


                                               

Epoch - 77/100, Loss - 0.7229975461959839, Time Taken - 0.5208070874214172


                                               

Epoch - 78/100, Loss - 0.7249361276626587, Time Taken - 0.5462879498799642


                                               

Epoch - 79/100, Loss - 0.7406023740768433, Time Taken - 0.5273615002632142


                                               

Epoch - 80/100, Loss - 0.6735209226608276, Time Taken - 0.5068888147672017


                                               

Epoch - 81/100, Loss - 0.7345246076583862, Time Taken - 0.5083995699882508


                                               

Epoch - 82/100, Loss - 0.6994234323501587, Time Taken - 0.5174456397692363


                                               

Epoch - 83/100, Loss - 0.6603811383247375, Time Taken - 0.514319634437561


                                               

Epoch - 84/100, Loss - 0.6760568618774414, Time Taken - 0.5243392546971639


                                               

Epoch - 85/100, Loss - 0.6266647577285767, Time Taken - 0.5149665355682373


                                               

Epoch - 86/100, Loss - 0.6410995721817017, Time Taken - 0.5066098173459371


                                               

Epoch - 87/100, Loss - 0.6532481908798218, Time Taken - 0.5206032872200013


                                               

Epoch - 88/100, Loss - 0.6565370559692383, Time Taken - 0.5120899677276611


                                               

Epoch - 89/100, Loss - 0.6268224716186523, Time Taken - 0.5174934228261312


                                               

Epoch - 90/100, Loss - 0.6322027444839478, Time Taken - 0.5027933319409689


                                               

Epoch - 91/100, Loss - 0.6586633920669556, Time Taken - 0.5103160262107849


                                               

Epoch - 92/100, Loss - 0.6719108819961548, Time Taken - 0.5016773541768392


                                               

Epoch - 93/100, Loss - 0.6519559621810913, Time Taken - 0.5158118565877279


                                               

Epoch - 94/100, Loss - 0.6438457369804382, Time Taken - 0.5233089764912923


                                               

Epoch - 95/100, Loss - 0.6296696662902832, Time Taken - 0.5100956360499064


                                               

Epoch - 96/100, Loss - 0.6131618022918701, Time Taken - 0.5241391698519389


                                               

Epoch - 97/100, Loss - 0.6917980313301086, Time Taken - 0.5312514185905457


                                               

Epoch - 98/100, Loss - 0.6214895844459534, Time Taken - 0.5116109212239583


                                               

Epoch - 99/100, Loss - 0.6185522079467773, Time Taken - 0.506682030359904


                                               

Epoch - 100/100, Loss - 0.6077899932861328, Time Taken - 0.5273542324701945




In [23]:
torch.save(model,'/kaggle/working/entire_model.pth')