<a href="https://www.kaggle.com/code/prasannakasar/image-captioning?scriptVersionId=217505520" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
import pandas as pd
pd.set_option('display.max_colwidth', None)

In [2]:
df = pd.read_csv("/kaggle/input/flickr8k/captions.txt", sep=",")

In [3]:
df.head(3)

Unnamed: 0,image,caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set of stairs in an entry way .
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .


In [4]:
df['cleaned_caption'] = df['caption'].apply(lambda sentence : [word.lower() for word in sentence.split(" ") if word.isalpha()])
df['cleaned_caption'] = df['cleaned_caption'].apply(lambda lis : [word for word in lis if len(word) > 1])
df['cleaned_caption'] = df['cleaned_caption'].apply(lambda lis : ['<start>'] + lis + ['<end>'])

In [5]:
df.head(3)

Unnamed: 0,image,caption,cleaned_caption
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set of stairs in an entry way .,"[<start>, child, in, pink, dress, is, climbing, up, set, of, stairs, in, an, entry, way, <end>]"
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .,"[<start>, girl, going, into, wooden, building, <end>]"
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .,"[<start>, little, girl, climbing, into, wooden, playhouse, <end>]"


In [6]:
df['seq_len'] = df['cleaned_caption'].apply(lambda x : len(x))
max_len = df['seq_len'].max()

In [7]:
df['cleaned_caption'].apply(len).idxmax()

8049

In [8]:
df['cleaned_caption'] = df['cleaned_caption'].apply(lambda lis : lis + ['<pad>'] * (max_len - len(lis)))

In [9]:
df.head(3)

Unnamed: 0,image,caption,cleaned_caption,seq_len
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set of stairs in an entry way .,"[<start>, child, in, pink, dress, is, climbing, up, set, of, stairs, in, an, entry, way, <end>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>]",16
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .,"[<start>, girl, going, into, wooden, building, <end>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>]",7
2,1000268201_693b08cb0e.jpg,A little girl climbing into a wooden playhouse .,"[<start>, little, girl, climbing, into, wooden, playhouse, <end>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>]",8


In [10]:
word_list = []
df['cleaned_caption'].apply(lambda lis: [word_list.append(word) for word in lis])

0        [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]
1        [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]
2        [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]
3        [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]
4        [None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, No

In [11]:
from collections import Counter
word_dict = Counter(word_list)

In [12]:
word_dict = sorted(word_dict, key=word_dict.get, reverse=True)

In [13]:
word_to_index = {word : index for index, word in enumerate(word_dict)}
index_to_word = {index : word for index, word in enumerate(word_dict)}

In [14]:
df['word_token'] = df['cleaned_caption'].apply(lambda lis : [word_to_index[word] for word in lis])

In [15]:
df.head(2)

Unnamed: 0,image,caption,cleaned_caption,seq_len,word_token
0,1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set of stairs in an entry way .,"[<start>, child, in, pink, dress, is, climbing, up, set, of, stairs, in, an, entry, way, <end>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>]",16,"[1, 41, 3, 89, 168, 6, 118, 52, 392, 11, 389, 3, 27, 5075, 690, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
1,1000268201_693b08cb0e.jpg,A girl going into a wooden building .,"[<start>, girl, going, into, wooden, building, <end>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>, <pad>]",7,"[1, 18, 311, 63, 192, 116, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"


In [16]:
df.drop(columns=['caption', 'cleaned_caption', 'seq_len'], inplace=True)
df.head(3)

Unnamed: 0,image,word_token
0,1000268201_693b08cb0e.jpg,"[1, 41, 3, 89, 168, 6, 118, 52, 392, 11, 389, 3, 27, 5075, 690, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
1,1000268201_693b08cb0e.jpg,"[1, 18, 311, 63, 192, 116, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"
2,1000268201_693b08cb0e.jpg,"[1, 39, 18, 118, 63, 192, 2402, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]"


In [17]:
train_size = int(0.8*len(df))
test_size = len(df) - train_size
#/kaggle/input/flickr8k/Images

In [18]:
import torch
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

class ImageDataset(Dataset):
    def __init__(self, img_dir, dataframe):
        self.img_dir = img_dir
        self.dataframe = dataframe
        self.scaler = transforms.Resize([224, 224])
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_name = self.dataframe.iloc[idx, 0]
        label = self.dataframe.iloc[idx, 1]
        
        img_path = os.path.join(self.img_dir, img_name)
        image = Image.open(img_path)
        t_img = self.normalize(self.to_tensor(self.scaler(image)))
        # t_img = torch.tensor(t_img)
        # label = torch.tensor(label)

        return t_img, torch.tensor(label)

In [19]:
img_dir = '/kaggle/input/flickr8k/Images'
dataset = ImageDataset(img_dir, df)

In [20]:
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [21]:
import torch.nn as nn
import torch.nn.functional as F

class CNNFeatureExtractor(nn.Module):
    def __init__(self):
        super(CNNFeatureExtractor, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding = 1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding = 1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 56 * 56, 256)


    def forward(self, x, inference=False):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # print(x.size())
        x = x.view(x.size(0), -1) 
        if inference:
            x = x.view(1, -1) 
        # x = x.unsqueeze(0)
        # print(x.size())
        x = F.relu(self.fc1(x))
        return x

In [22]:
class LSTMCaptionGenerator(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers):
        super(LSTMCaptionGenerator, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, features, captions):
        embeddings = self.embed(captions)
        embeddings = self.dropout(embeddings)
        # print(f"dim of features after unsqueezzing = ", features.unsqueeze(1).size())
        # print(f"dim of embeddings = ", embeddings.size())
        # repeated_features = features.unsqueeze(1).repeat(1, embeddings.size(1), 1)
        embeddings = torch.cat((features.unsqueeze(1), embeddings[:, :-1, :]), dim=1)
        # embeddings = torch.cat((repeated_features, embeddings), dim=1) 
        # print(f"embedding dim={embeddings.size()}")
        hiddens, _ = self.lstm(embeddings)
        outputs = self.linear(hiddens)
        return outputs

In [23]:
# class CNN_LSTM_model(nn.Module):  # Must inherit from nn.Module
#     def __init__(self, vocab_size, embed_size, hidden_size, num_layers, max_seq_length):
#         super(CNN_LSTM_model, self).__init__()
#         self.CNN_model = CNNFeatureExtractor()  # Initialize CNN feature extractor
#         self.LSTM_model = LSTMCaptionGenerator(vocab_size, embed_size, hidden_size, num_layers, max_seq_length)  # Initialize LSTM model

#     def forward(self, image, captions):
#         features = self.CNN_model(image)  # Get image features from CNN
#         outputs = self.LSTM_model(features, captions)  # Use image features and captions in LSTM
#         return outputs

In [24]:
import torch

vocab_size = len(word_dict)
embed_size = 256
hidden_size = 256
num_layers = 4
learning_rate = 0.001
max_seq_length = 33

device = ""
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

encoder = CNNFeatureExtractor().to(device)
decoder = LSTMCaptionGenerator(embed_size=embed_size, 
                       hidden_size=hidden_size,
                       vocab_size=vocab_size,
                       num_layers=num_layers).to(device)

In [25]:
index_to_word[0]

'<pad>'

In [26]:
from tqdm import tqdm

def train_and_test(num_epoch):
    #train
    optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=learning_rate)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    for epoch in range(num_epoch):
        encoder.train()
        decoder.train()
        total_loss_train = 0
        pbar = tqdm(train_loader)
        for image, caption in pbar:
            image = image.to(device)
            caption = caption.to(device)
            optimizer.zero_grad()
            features = encoder(image)
            predicted_caption = decoder(features, caption)
            # predicted_caption_idx = torch.argmax(predicted_caption, dim=-1)
            # predicted_caption = predicted_caption[: , 1: , :]
            # torch.set_printoptions(threshold=torch.inf)
            # print(predicted_caption_idx[:, :])
            # print(f"predicted cpation dim=", predicted_caption.size())
            # print(f"predicted_caption type", predicted_caption.type())
            # print(f"caption dim=", caption.size())
            predicted_caption = predicted_caption.reshape(-1, predicted_caption.size(-1))
            caption = caption.reshape(-1)

            loss = criterion(predicted_caption, caption)
            loss.backward()
            optimizer.step()
            total_loss_train += loss.item()

        print(f"Train Loss at {epoch+1} = {total_loss_train / len(train_loader)}")

        encoder.eval()
        decoder.eval()
        total_loss_test = 0
        # pbar = tqdm(test_loader)
        for image, caption in tqdm(test_loader):
            image = image.to(device)
            caption = caption.to(device)
            features = encoder(image)
            predicted_caption = decoder(features, caption)
            # predicted_caption = predicted_caption[: , 1: , :]
            # predicted_caption_idx = torch.argmax(predicted_caption, dim=-1)
            # print(f"predicted cpation dim=", predicted_caption.size())
            # print(f"predicted_caption type", predicted_caption.type())
            # print(f"caption dim=", caption.size())
            predicted_caption = predicted_caption.reshape(-1, predicted_caption.size(-1))
            caption = caption.reshape(-1)

            loss = criterion(predicted_caption, caption)
            total_loss_test += loss.item()

        print(f"Test Loss at {epoch+1} = {total_loss_test / len(test_loader)}")
            

In [27]:
train_and_test(50)

100%|██████████| 506/506 [04:55<00:00,  1.71it/s]


Train Loss at 1 = 5.3459152402613945


100%|██████████| 127/127 [00:54<00:00,  2.31it/s]


Test Loss at 1 = 5.07773768432497


100%|██████████| 506/506 [04:38<00:00,  1.81it/s]


Train Loss at 2 = 4.792318034077821


100%|██████████| 127/127 [00:52<00:00,  2.42it/s]


Test Loss at 2 = 4.487927143967997


100%|██████████| 506/506 [04:03<00:00,  2.08it/s]


Train Loss at 3 = 4.310532506275554


100%|██████████| 127/127 [00:52<00:00,  2.40it/s]


Test Loss at 3 = 4.190126674381767


100%|██████████| 506/506 [04:22<00:00,  1.93it/s]


Train Loss at 4 = 4.087788904608474


100%|██████████| 127/127 [00:53<00:00,  2.40it/s]


Test Loss at 4 = 4.057096004486084


100%|██████████| 506/506 [04:01<00:00,  2.09it/s]


Train Loss at 5 = 3.971968069378095


100%|██████████| 127/127 [00:55<00:00,  2.28it/s]


Test Loss at 5 = 3.985963151210875


100%|██████████| 506/506 [04:06<00:00,  2.05it/s]


Train Loss at 6 = 3.8937674850343242


100%|██████████| 127/127 [00:52<00:00,  2.42it/s]


Test Loss at 6 = 3.9461279966699796


100%|██████████| 506/506 [05:09<00:00,  1.63it/s]


Train Loss at 7 = 3.833530293151795


100%|██████████| 127/127 [01:14<00:00,  1.71it/s]


Test Loss at 7 = 3.9137140503079872


100%|██████████| 506/506 [04:25<00:00,  1.90it/s]


Train Loss at 8 = 3.782816171646118


100%|██████████| 127/127 [00:51<00:00,  2.46it/s]


Test Loss at 8 = 3.8844170401415488


100%|██████████| 506/506 [04:00<00:00,  2.11it/s]


Train Loss at 9 = 3.7427800559243667


100%|██████████| 127/127 [00:51<00:00,  2.46it/s]


Test Loss at 9 = 3.865437631531963


100%|██████████| 506/506 [04:00<00:00,  2.11it/s]


Train Loss at 10 = 3.7057978989107334


100%|██████████| 127/127 [00:53<00:00,  2.36it/s]


Test Loss at 10 = 3.8541412822843535


100%|██████████| 506/506 [05:40<00:00,  1.49it/s]


Train Loss at 11 = 3.674525800429785


100%|██████████| 127/127 [00:57<00:00,  2.21it/s]


Test Loss at 11 = 3.838547695340134


100%|██████████| 506/506 [04:38<00:00,  1.82it/s]


Train Loss at 12 = 3.647269994373849


100%|██████████| 127/127 [01:08<00:00,  1.85it/s]


Test Loss at 12 = 3.8338638534696083


100%|██████████| 506/506 [04:27<00:00,  1.89it/s]


Train Loss at 13 = 3.6211145951342676


100%|██████████| 127/127 [00:57<00:00,  2.20it/s]


Test Loss at 13 = 3.8255153340617505


100%|██████████| 506/506 [04:19<00:00,  1.95it/s]


Train Loss at 14 = 3.600320716149251


100%|██████████| 127/127 [00:55<00:00,  2.31it/s]


Test Loss at 14 = 3.8226386768611396


100%|██████████| 506/506 [04:14<00:00,  1.99it/s]


Train Loss at 15 = 3.5791043043136597


100%|██████████| 127/127 [00:54<00:00,  2.32it/s]


Test Loss at 15 = 3.820892491678553


100%|██████████| 506/506 [04:04<00:00,  2.07it/s]


Train Loss at 16 = 3.5611710967753716


100%|██████████| 127/127 [00:55<00:00,  2.28it/s]


Test Loss at 16 = 3.8194726808803288


100%|██████████| 506/506 [04:05<00:00,  2.06it/s]


Train Loss at 17 = 3.5420904315036275


100%|██████████| 127/127 [00:53<00:00,  2.36it/s]


Test Loss at 17 = 3.8231660812858523


100%|██████████| 506/506 [04:32<00:00,  1.86it/s]


Train Loss at 18 = 3.526468230801609


100%|██████████| 127/127 [00:54<00:00,  2.35it/s]


Test Loss at 18 = 3.8265029802097112


100%|██████████| 506/506 [04:14<00:00,  1.99it/s]


Train Loss at 19 = 3.5094908249708032


100%|██████████| 127/127 [00:55<00:00,  2.29it/s]


Test Loss at 19 = 3.823519438270509


100%|██████████| 506/506 [04:38<00:00,  1.82it/s]


Train Loss at 20 = 3.4956508333032783


100%|██████████| 127/127 [00:54<00:00,  2.32it/s]


Test Loss at 20 = 3.8284668622054454


100%|██████████| 506/506 [04:09<00:00,  2.03it/s]


Train Loss at 21 = 3.4817769527435303


100%|██████████| 127/127 [00:53<00:00,  2.36it/s]


Test Loss at 21 = 3.8298593637511487


100%|██████████| 506/506 [04:02<00:00,  2.08it/s]


Train Loss at 22 = 3.4702470962237935


100%|██████████| 127/127 [00:53<00:00,  2.38it/s]


Test Loss at 22 = 3.836007354766365


100%|██████████| 506/506 [04:00<00:00,  2.10it/s]


Train Loss at 23 = 3.4573966679365737


100%|██████████| 127/127 [00:52<00:00,  2.41it/s]


Test Loss at 23 = 3.845181022103377


100%|██████████| 506/506 [04:20<00:00,  1.95it/s]


Train Loss at 24 = 3.4463153314213506


100%|██████████| 127/127 [00:54<00:00,  2.34it/s]


Test Loss at 24 = 3.8461742419896163


100%|██████████| 506/506 [04:14<00:00,  1.99it/s]


Train Loss at 25 = 3.4368580867179297


100%|██████████| 127/127 [00:53<00:00,  2.37it/s]


Test Loss at 25 = 3.852085442054929


100%|██████████| 506/506 [04:05<00:00,  2.06it/s]


Train Loss at 26 = 3.426118088333974


100%|██████████| 127/127 [00:53<00:00,  2.38it/s]


Test Loss at 26 = 3.8588587261560394


100%|██████████| 506/506 [04:03<00:00,  2.08it/s]


Train Loss at 27 = 3.4160123347293716


100%|██████████| 127/127 [00:53<00:00,  2.36it/s]


Test Loss at 27 = 3.8673124257035143


100%|██████████| 506/506 [04:05<00:00,  2.06it/s]


Train Loss at 28 = 3.406587752900105


100%|██████████| 127/127 [00:52<00:00,  2.40it/s]


Test Loss at 28 = 3.878535235021997


100%|██████████| 506/506 [04:00<00:00,  2.10it/s]


Train Loss at 29 = 3.3983711824115557


100%|██████████| 127/127 [00:53<00:00,  2.39it/s]


Test Loss at 29 = 3.876858853918361


100%|██████████| 506/506 [03:59<00:00,  2.11it/s]


Train Loss at 30 = 3.3903236422142964


100%|██████████| 127/127 [00:51<00:00,  2.45it/s]


Test Loss at 30 = 3.892255340035506


100%|██████████| 506/506 [03:59<00:00,  2.12it/s]


Train Loss at 31 = 3.3814202683716426


100%|██████████| 127/127 [00:52<00:00,  2.41it/s]


Test Loss at 31 = 3.8951041210354784


100%|██████████| 506/506 [03:57<00:00,  2.13it/s]


Train Loss at 32 = 3.374358931077799


100%|██████████| 127/127 [00:52<00:00,  2.40it/s]


Test Loss at 32 = 3.9033796505665217


100%|██████████| 506/506 [04:00<00:00,  2.11it/s]


Train Loss at 33 = 3.367836571022456


100%|██████████| 127/127 [00:52<00:00,  2.43it/s]


Test Loss at 33 = 3.9097246635617235


100%|██████████| 506/506 [04:01<00:00,  2.09it/s]


Train Loss at 34 = 3.359071752770616


100%|██████████| 127/127 [00:53<00:00,  2.38it/s]


Test Loss at 34 = 3.9197261990524654


100%|██████████| 506/506 [04:02<00:00,  2.09it/s]


Train Loss at 35 = 3.35284007773569


100%|██████████| 127/127 [00:57<00:00,  2.22it/s]


Test Loss at 35 = 3.9262761337550605


100%|██████████| 506/506 [04:08<00:00,  2.04it/s]


Train Loss at 36 = 3.345324566712964


100%|██████████| 127/127 [00:54<00:00,  2.32it/s]


Test Loss at 36 = 3.9321714855554535


100%|██████████| 506/506 [04:05<00:00,  2.06it/s]


Train Loss at 37 = 3.3376206126608867


100%|██████████| 127/127 [00:52<00:00,  2.41it/s]


Test Loss at 37 = 3.9429080185927745


100%|██████████| 506/506 [04:02<00:00,  2.08it/s]


Train Loss at 38 = 3.3334750813457807


100%|██████████| 127/127 [00:53<00:00,  2.40it/s]


Test Loss at 38 = 3.9494286484605685


100%|██████████| 506/506 [04:00<00:00,  2.10it/s]


Train Loss at 39 = 3.3262265755725


100%|██████████| 127/127 [00:53<00:00,  2.39it/s]


Test Loss at 39 = 3.9571772034712662


100%|██████████| 506/506 [04:04<00:00,  2.07it/s]


Train Loss at 40 = 3.321773088967847


100%|██████████| 127/127 [00:54<00:00,  2.34it/s]


Test Loss at 40 = 3.9648964911933957


100%|██████████| 506/506 [04:08<00:00,  2.04it/s]


Train Loss at 41 = 3.315984040380938


100%|██████████| 127/127 [00:54<00:00,  2.33it/s]


Test Loss at 41 = 3.9748095245811883


100%|██████████| 506/506 [04:10<00:00,  2.02it/s]


Train Loss at 42 = 3.3096957512995


100%|██████████| 127/127 [00:53<00:00,  2.38it/s]


Test Loss at 42 = 3.9757681700188345


100%|██████████| 506/506 [04:02<00:00,  2.08it/s]


Train Loss at 43 = 3.304362772952898


100%|██████████| 127/127 [00:52<00:00,  2.41it/s]


Test Loss at 43 = 3.9881104296586645


100%|██████████| 506/506 [04:02<00:00,  2.08it/s]


Train Loss at 44 = 3.300120659496473


100%|██████████| 127/127 [00:52<00:00,  2.41it/s]


Test Loss at 44 = 3.9946260827732836


100%|██████████| 506/506 [04:02<00:00,  2.09it/s]


Train Loss at 45 = 3.293940256235628


100%|██████████| 127/127 [00:53<00:00,  2.36it/s]


Test Loss at 45 = 4.002445498789389


100%|██████████| 506/506 [04:09<00:00,  2.02it/s]


Train Loss at 46 = 3.290160084430408


100%|██████████| 127/127 [00:53<00:00,  2.36it/s]


Test Loss at 46 = 4.002375996957614


100%|██████████| 506/506 [04:08<00:00,  2.04it/s]


Train Loss at 47 = 3.285427805463316


100%|██████████| 127/127 [00:55<00:00,  2.31it/s]


Test Loss at 47 = 4.015650683500636


100%|██████████| 506/506 [04:05<00:00,  2.06it/s]


Train Loss at 48 = 3.2808797547939736


100%|██████████| 127/127 [00:53<00:00,  2.36it/s]


Test Loss at 48 = 4.025014727134404


100%|██████████| 506/506 [04:03<00:00,  2.08it/s]


Train Loss at 49 = 3.2770369180106362


100%|██████████| 127/127 [00:53<00:00,  2.37it/s]


Test Loss at 49 = 4.03145112015131


100%|██████████| 506/506 [04:03<00:00,  2.08it/s]


Train Loss at 50 = 3.2722767648018394


100%|██████████| 127/127 [00:53<00:00,  2.37it/s]

Test Loss at 50 = 4.035590235642561





In [28]:
torch.save(encoder, "encoder_20_epochs.pth")
torch.save(decoder, "decoder_20_epochs.pth")

In [29]:
# encoder = torch.load("/kaggle/input/image_cpationing/pytorch/default/1/encoder_20_epochs.pth")
# decoder = torch.load("/kaggle/input/image_cpationing/pytorch/default/1/decoder_20_epochs.pth")

In [30]:
def load_img(idx):
    
    scaler = transforms.Resize([224, 224])
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    to_tensor = transforms.ToTensor()
    
    img_name = df.iloc[idx, 0]
    label = df.iloc[idx, 1]
        
    img_path = os.path.join(img_dir, img_name)
    image = Image.open(img_path)
    t_img = normalize(to_tensor(scaler(image)))
    # t_img = torch.tensor(t_img)
    # label = torch.tensor(label)
    return t_img, label

In [31]:
def inference(image, max_seq_len=50):
    result_caption = []
    image = image.to(device)
    # print(img.size())
    x = encoder(image, inference=True).unsqueeze(1)
    # print(f"expected dim of input for lstm={x.size()}")
    states = None
            
    for _ in range(max_seq_len):
        hidden, states = decoder.lstm(x, states)
        output = decoder.linear(hidden)
        # print(f"output dim={output.size()}")
        predicted = output.argmax(2)
        # print(f"predicted token={predicted.item()}")
        result_caption.append(predicted.item())
        x = decoder.embed(predicted)
        # print(f"next intput size={x.size()}")
    
        if predicted.item() == 2:
            break
    
    res = [index_to_word[idx] for idx in result_caption]
    print(res)

In [32]:
img, caption = load_img(0)
res = inference(img)
expected = [index_to_word[idx] for idx in caption]
print(expected)
print(res)

['<start>', '<start>', '<start>', 'two', 'two', 'two', 'young', 'boy', 'in', 'black', 'dog', 'jumps', 'to', 'the', 'street', '<end>']
['<start>', 'child', 'in', 'pink', 'dress', 'is', 'climbing', 'up', 'set', 'of', 'stairs', 'in', 'an', 'entry', 'way', '<end>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
None
