#### Image captioning Model
In this we would train an image captioning model on top of our multilabel resenet. An try to train to predict the captions of the image.

#### Changes that need to be made to implement the model
1. We would need to change the dataloader to output the caption vector along with the images.
2. We would also need to train a maybe a recurrent neural network to output the captions.

1. First we will take the resnet output feature vector and we will pass it  as the starting hidden state of the RNN.
2. we will use word level embeddings for the inputs. 
3. We will also use start and end tokens.
4. Then we will output the probabilities of the next word.
5. Then we will jointly train the full network to generate the captions for the images.
6. For training the RNN we can keep a block size length of 5 words.

##### So lets start by modiying the dataloder to output the captions.

In [1]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
import torch.optim as optim
import torch.nn as nn
from torchinfo import summary
from tqdm import tqdm, trange
torch.random.manual_seed(42)
torch.cuda.manual_seed(42)
from torchvision.transforms import Compose, Resize, Normalize, CenterCrop
from torchvision.models import resnet50, ResNet50_Weights
from torch.utils.tensorboard import SummaryWriter
import datetime as dt
import json
import torchtext
from collections import OrderedDict
import nltk


In [56]:
image_captions_df = pd.read_parquet('F://coco/captions/image_captions_df_new.parquet')
image_captions_df

Unnamed: 0,image_id,file_name,captions,caption_token_set,idx_captions,tokens
0,57870,COCO_train2014_000000057870.jpg,"[[a, restaurant, has, modern, wooden, tables, ...","[arrangement, is, flower, on, table, blue, top...","[[27, 13, 25, 23, 9, 30, 12, 19], [27, 22, 13,...","[[a, restaurant, has, modern, wooden, tables, ..."
1,384029,COCO_train2014_000000384029.jpg,"[[a, man, preparing, desserts, in, a, kitchen,...","[types, desserts, person, is, kitchen, many, u...","[[27, 53, 56, 33, 29, 27, 35, 42, 29, 50], [27...","[[a, man, preparing, desserts, in, a, kitchen,..."
2,222016,COCO_train2014_000000222016.jpg,"[[a, big, red, telephone, booth, that, a, man,...","[inside, person, image, standing, is, an, that...","[[27, 66, 62, 67, 65, 61, 27, 53, 1, 59, 29], ...","[[a, big, red, telephone, booth, that, a, man,..."
3,520950,COCO_train2014_000000520950.jpg,"[[the, kitchen, is, full, of, spices, on, the,...","[is, on, that, kitchen, all, utilizes, with, i...","[[20, 35, 1, 73, 28, 80, 3, 20, 75], [27, 35, ...","[[the, kitchen, is, full, of, spices, on, the,..."
4,69675,COCO_train2014_000000069675.jpg,"[[a, child, and, woman, are, cooking, in, the,...","[child, woman, person, on, women, an, together...","[[27, 88, 12, 89, 94, 98, 29, 20, 99], [27, 89...","[[a, child, and, woman, are, cooking, in, the,..."
...,...,...,...,...,...,...
66658,360271,COCO_train2014_000000360271.jpg,"[[women, sitting, at, a, dinner, table, with, ...","[woman, another, dinner., sitting, is, watchin...","[[90, 215, 102, 27, 311, 4, 7, 140, 145, 102, ...","[[women, sitting, at, a, dinner, table, with, ..."
66659,471345,COCO_train2014_000000471345.jpg,"[[a, table, topped, with, two, white, plates, ...","[lobster,, rice., sitting, is, on, table, sit,...","[[27, 4, 283, 7, 103, 118, 427, 28, 97, 163, 1...","[[a, table, topped, with, two, white, plates, ..."
66660,444010,COCO_train2014_000000444010.jpg,"[[a, group, of, friends, sitting, down, at, a,...","[woman, jovial, sitting, table, group, togethe...","[[27, 362, 28, 2829, 215, 144, 102, 27, 4, 330...","[[a, group, of, friends, sitting, down, at, a,..."
66661,565004,COCO_train2014_000000565004.jpg,"[[wine, being, poured, into, a, glass, over, a...","[another, is, on, table, wine, with, red, glas...","[[418, 319, 6776, 161, 27, 686, 139, 27, 4], [...","[[wine, being, poured, into, a, glass, over, a..."


In [3]:
vocab_dict = {}
def make_vocab(token_list):
    for tokens in  token_list:
        for token in tokens:
            if token in vocab_dict:
                vocab_dict[token] +=1
            else:
                vocab_dict[token] = 1

image_captions_df['tokens'].apply(lambda x : make_vocab(x))
vocab_dict

{'a': 551261,
 'restaurant': 1446,
 'has': 8483,
 'modern': 610,
 'wooden': 5196,
 'tables': 746,
 'and': 79581,
 'chairs': 1711,
 '.': 250427,
 'long': 1976,
 'table': 17829,
 'with': 86972,
 'rattan': 3,
 'rounded': 20,
 'back': 3201,
 'plant': 607,
 'on': 121090,
 'top': 12320,
 'of': 113888,
 'it': 14840,
 'surrounded': 1058,
 'flower': 729,
 'arrangement': 162,
 'in': 103832,
 'the': 110813,
 'middle': 2291,
 'for': 6175,
 'meetings': 1,
 'is': 55372,
 'adorned': 73,
 'blue': 7708,
 'accents': 45,
 'man': 41987,
 'preparing': 1260,
 'desserts': 165,
 'kitchen': 7968,
 'covered': 4000,
 'frosting': 242,
 'chef': 225,
 'decorating': 39,
 'many': 3804,
 'small': 11447,
 'pastries': 328,
 'baker': 20,
 'prepares': 434,
 'various': 1516,
 'types': 589,
 'baked': 290,
 'goods': 97,
 'close': 3863,
 'up': 11683,
 'person': 13988,
 'grabbing': 95,
 'pastry': 288,
 'container': 462,
 'hand': 2687,
 'touching': 228,
 'big': 3151,
 'red': 9863,
 'telephone': 180,
 'booth': 135,
 'that': 1572

In [4]:
chartoidx = {}
idxtochar = []
for i,word in enumerate(vocab_dict.keys()):
    chartoidx[word] = i
    idxtochar.append(word) 

In [5]:
chartoidx

{'a': 0,
 'restaurant': 1,
 'has': 2,
 'modern': 3,
 'wooden': 4,
 'tables': 5,
 'and': 6,
 'chairs': 7,
 '.': 8,
 'long': 9,
 'table': 10,
 'with': 11,
 'rattan': 12,
 'rounded': 13,
 'back': 14,
 'plant': 15,
 'on': 16,
 'top': 17,
 'of': 18,
 'it': 19,
 'surrounded': 20,
 'flower': 21,
 'arrangement': 22,
 'in': 23,
 'the': 24,
 'middle': 25,
 'for': 26,
 'meetings': 27,
 'is': 28,
 'adorned': 29,
 'blue': 30,
 'accents': 31,
 'man': 32,
 'preparing': 33,
 'desserts': 34,
 'kitchen': 35,
 'covered': 36,
 'frosting': 37,
 'chef': 38,
 'decorating': 39,
 'many': 40,
 'small': 41,
 'pastries': 42,
 'baker': 43,
 'prepares': 44,
 'various': 45,
 'types': 46,
 'baked': 47,
 'goods': 48,
 'close': 49,
 'up': 50,
 'person': 51,
 'grabbing': 52,
 'pastry': 53,
 'container': 54,
 'hand': 55,
 'touching': 56,
 'big': 57,
 'red': 58,
 'telephone': 59,
 'booth': 60,
 'that': 61,
 'standing': 62,
 'inside': 63,
 'phone': 64,
 'this': 65,
 'an': 66,
 'image': 67,
 'using': 68,
 'full': 69,
 'spic

In [6]:
len(chartoidx)

22329

In [7]:
idxtochar

['a',
 'restaurant',
 'has',
 'modern',
 'wooden',
 'tables',
 'and',
 'chairs',
 '.',
 'long',
 'table',
 'with',
 'rattan',
 'rounded',
 'back',
 'plant',
 'on',
 'top',
 'of',
 'it',
 'surrounded',
 'flower',
 'arrangement',
 'in',
 'the',
 'middle',
 'for',
 'meetings',
 'is',
 'adorned',
 'blue',
 'accents',
 'man',
 'preparing',
 'desserts',
 'kitchen',
 'covered',
 'frosting',
 'chef',
 'decorating',
 'many',
 'small',
 'pastries',
 'baker',
 'prepares',
 'various',
 'types',
 'baked',
 'goods',
 'close',
 'up',
 'person',
 'grabbing',
 'pastry',
 'container',
 'hand',
 'touching',
 'big',
 'red',
 'telephone',
 'booth',
 'that',
 'standing',
 'inside',
 'phone',
 'this',
 'an',
 'image',
 'using',
 'full',
 'spices',
 'rack',
 'counter',
 ',',
 'oven',
 'other',
 'accessories',
 'utilizes',
 'all',
 'its',
 'space',
 'pots',
 'pans',
 'display',
 'very',
 'stove',
 'shelf',
 'child',
 'woman',
 'are',
 'cooking',
 'glances',
 'at',
 'young',
 'girl',
 "'s",
 'stovetop',
 'food'

In [8]:
# glove_embedding = torchtext.vocab.GloVe(name='6B', dim='300')
idxtochar = idxtochar + ['<unk>','<start>','<end>']
chartoidx.update({'<unk>':22329,
                         '<start>':22330,
                         '<end>':22331})
print(len(chartoidx))

22332


In [9]:
def stoi(captions_list):
    caption_token_list = []
    for caption in captions_list:
        tokens_list = [chartoidx['<start>']] + []
        for token in caption:
            try:
                tokens_list.append(chartoidx[token])
            except:
                tokens_list.append(chartoidx['<unk>'])
        tokens_list = tokens_list + [chartoidx['<end>']]
        caption_token_list.append(tokens_list)
    return caption_token_list

image_captions_df['idx_tokens'] = image_captions_df['tokens'].apply(lambda x : stoi(x))
image_captions_df

Unnamed: 0,image_id,file_name,captions,caption_token_set,idx_captions,tokens,idx_tokens
0,57870,COCO_train2014_000000057870.jpg,"[[a, restaurant, has, modern, wooden, tables, ...","[arrangement, is, flower, on, table, blue, top...","[[27, 13, 25, 23, 9, 30, 12, 19], [27, 22, 13,...","[[a, restaurant, has, modern, wooden, tables, ...","[[22330, 0, 1, 2, 3, 4, 5, 6, 7, 8, 22331], [2..."
1,384029,COCO_train2014_000000384029.jpg,"[[a, man, preparing, desserts, in, a, kitchen,...","[types, desserts, person, is, kitchen, many, u...","[[27, 53, 56, 33, 29, 27, 35, 42, 29, 50], [27...","[[a, man, preparing, desserts, in, a, kitchen,...","[[22330, 0, 32, 33, 34, 23, 0, 35, 36, 23, 37,..."
2,222016,COCO_train2014_000000222016.jpg,"[[a, big, red, telephone, booth, that, a, man,...","[inside, person, image, standing, is, an, that...","[[27, 66, 62, 67, 65, 61, 27, 53, 1, 59, 29], ...","[[a, big, red, telephone, booth, that, a, man,...","[[22330, 0, 57, 58, 59, 60, 61, 0, 32, 28, 62,..."
3,520950,COCO_train2014_000000520950.jpg,"[[the, kitchen, is, full, of, spices, on, the,...","[is, on, that, kitchen, all, utilizes, with, i...","[[20, 35, 1, 73, 28, 80, 3, 20, 75], [27, 35, ...","[[the, kitchen, is, full, of, spices, on, the,...","[[22330, 24, 35, 28, 69, 18, 70, 16, 24, 71, 2..."
4,69675,COCO_train2014_000000069675.jpg,"[[a, child, and, woman, are, cooking, in, the,...","[child, woman, person, on, women, an, together...","[[27, 88, 12, 89, 94, 98, 29, 20, 99], [27, 89...","[[a, child, and, woman, are, cooking, in, the,...","[[22330, 0, 87, 6, 88, 89, 90, 23, 24, 35, 8, ..."
...,...,...,...,...,...,...,...
66658,360271,COCO_train2014_000000360271.jpg,"[[women, sitting, at, a, dinner, table, with, ...","[woman, another, dinner., sitting, is, watchin...","[[90, 215, 102, 27, 311, 4, 7, 140, 145, 102, ...","[[women, sitting, at, a, dinner, table, with, ...","[[22330, 100, 198, 92, 0, 278, 10, 11, 125, 12..."
66659,471345,COCO_train2014_000000471345.jpg,"[[a, table, topped, with, two, white, plates, ...","[lobster,, rice., sitting, is, on, table, sit,...","[[27, 4, 283, 7, 103, 118, 427, 28, 97, 163, 1...","[[a, table, topped, with, two, white, plates, ...","[[22330, 0, 10, 244, 11, 99, 118, 381, 18, 97,..."
66660,444010,COCO_train2014_000000444010.jpg,"[[a, group, of, friends, sitting, down, at, a,...","[woman, jovial, sitting, table, group, togethe...","[[27, 362, 28, 2829, 215, 144, 102, 27, 4, 330...","[[a, group, of, friends, sitting, down, at, a,...","[[22330, 0, 326, 18, 2253, 198, 133, 92, 0, 10..."
66661,565004,COCO_train2014_000000565004.jpg,"[[wine, being, poured, into, a, glass, over, a...","[another, is, on, table, wine, with, red, glas...","[[418, 319, 6776, 161, 27, 686, 139, 27, 4], [...","[[wine, being, poured, into, a, glass, over, a...","[[22330, 371, 285, 5020, 144, 0, 574, 131, 0, ..."


In [10]:
# def stoi(captions_list):
#     caption_token_list = []
#     for caption in captions_list:
#         tokens_list = []
#         for token in caption:
#             try:
#                 tokens_list.append(chartoidx[token])
#             except:
#                 tokens_list.append(chartoidx['<unk>'])
#         # tokens_list = tokens_list + [glove_embedding.stoi['<end>']]
#         caption_token_list.append(tokens_list)
#     return caption_token_list

# image_captions_df['idx_tokens'] = image_captions_df['tokens'].apply(lambda x : stoi(x))
# image_captions_df

In [11]:
image_captions_df['idx_tokens'][0]

[[22330, 0, 1, 2, 3, 4, 5, 6, 7, 8, 22331],
 [22330, 0, 9, 1, 10, 11, 12, 13, 14, 7, 8, 22331],
 [22330, 0, 9, 10, 11, 0, 15, 16, 17, 18, 19, 20, 11, 4, 7, 22331],
 [22330, 0, 9, 10, 11, 0, 21, 22, 23, 24, 25, 26, 27, 22331],
 [22330, 0, 10, 28, 29, 11, 4, 7, 11, 30, 31, 8, 22331]]

In [12]:
# # we would also need to define the <unk>, <strart> and <end> token embeddings as a vector of zeros.
# unk_vector = torch.zeros(1,300)
# start_vector = torch.zeros(1,300)
# end_vector = torch.zeros(1,300)

# glove_embedding.vectors = torch.cat((glove_embedding.vectors, unk_vector, start_vector, end_vector), dim = 0)
# print(glove_embedding.vectors.shape)


In [13]:
# glove_embedding.vectors.shape

In [14]:
manual_transforms = Compose([
                            Resize(size = (256,256)),
                            CenterCrop(size=(224,224)),
                            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

##### Now lets define the dataloader which would take in 

#### We would add the start token seperately and also the end token at the end

In [15]:
class CocoDataset(Dataset):
    def __init__(self,image_label_df, val_stride = 10, is_val_set_bool = False, test_data_set =False, test_stride = 5, transforms = None):
        
        self.transfrom = transforms
        self.is_val_set_bool = is_val_set_bool
        self.val_stride = val_stride
        self.image_label_df = image_label_df.copy()
        self.test_data_set = test_data_set
        self.test_stride = test_stride

        if self.test_data_set: #If we need only a small subset of the data to work with
            self.image_label_df = self.image_label_df[::test_stride].reset_index(drop = True)

        elif self.is_val_set_bool: # If we need only the validation data then return the validation data which is a subset of total data
            assert self.val_stride > 0
            self.image_label_df = self.image_label_df[::val_stride]

        elif self.val_stride > 0:  # Else if val_stride is greater than zero then return the remaining dataframe after removing 10% of data
            self.image_label_df = self.image_label_df.drop(index = list(range(0,len(self.image_label_df),self.val_stride)))
            
        else: # else train on the full dataset
            self.image_label_df = self.image_label_df
        
    def __len__(self):
        """ This method calculates the length of your data."""
        return len(self.image_label_df)

    # Now we will introduce function which gives us an image and its corresponding labels by id
    def get_image_by_id(self, img_id = None):
        """ This function returns an image by its id and the corresponding multiple labesl in a present/no_present binary format"""
        

        folder_path = 'F://coco/train2014/train2014/'

        if img_id == None:
            raise ValueError('Must provide IMAGE ID')

        else:
            row = self.image_label_df[self.image_label_df['image_id']==img_id]
            file_name = row['file_name'].values[0]

            captions_array = torch.tensor(row['idx_tokens'][0][0])

            # Get the image-data from file_name
            image_array = torchvision.io.read_image(folder_path + file_name)
            image_array = (image_array/255.0).to(torch.float32)
            image_array = self.transfrom(image_array).to(torch.float32)
        
            return image_array, captions_array


    # Now we would write the code for returning the image from index and its corresponding labels. Which would be used by the train/val dataloaders
    def __getitem__(self, ndx):

        """ This function takes in an index and returns the image and the labels of the image at that index"""
        folder_path = 'F://coco/train2014/train2014/'
        
        row = self.image_label_df.iloc[ndx]
        # print(row)

        # Now get the image_id and the file_name of the image
        image_id = row['image_id']
        file_name = row['file_name']
        
        # Now get the captions array from the dataframe
        captions_array = torch.tensor(row['idx_tokens'][0])
        # Now get the image_data from storage
        image_array = torchvision.io.read_image(folder_path + file_name)
        image_array = (image_array/255.0).to(torch.float32)
        image_array = self.transfrom(image_array).to(torch.float32)
        # print(captions_array)
        return (image_array, captions_array, image_id)

In [16]:
# Create the datasets
# train_coco = CocoDataset(image_label_df=image_label_df, val_stride=10, transforms=manual_transforms)
val_coco = CocoDataset(image_label_df=image_captions_df,is_val_set_bool=True, val_stride=200, transforms=manual_transforms)

# Create the dataloaders
# train_dataloader = DataLoader(dataset=train_coco, batch_size=128, pin_memory=True, drop_last=True, num_workers=4)
val_dataloader = DataLoader(dataset=val_coco, batch_size=1,pin_memory=True, drop_last = True)

# create a test dataset and a dataloader
test_coco = CocoDataset(image_label_df=image_captions_df, test_data_set=True, test_stride=50, transforms=manual_transforms)

test_coco_dataloader = DataLoader(dataset=test_coco, batch_size=1, pin_memory =  True,drop_last=True)

In [17]:
print(len(test_coco), len(val_coco))

1334 334


In [18]:
val_coco[0]

(tensor([[[ 1.9845,  2.0282,  2.1066,  ..., -1.5453, -1.4803, -1.4329],
          [ 2.0787,  2.1186,  2.1630,  ..., -1.5129, -1.5295, -1.3770],
          [ 2.1288,  2.1801,  2.2007,  ..., -1.5440, -1.5506, -1.3885],
          ...,
          [-0.8186, -0.9293, -1.2151,  ...,  1.1449,  1.2222,  1.1112],
          [-0.6762, -0.9331, -1.2721,  ...,  1.2070,  1.0772,  1.0296],
          [-0.9186, -0.8472, -1.1707,  ...,  1.1230,  0.9924,  1.0074]],
 
         [[ 1.4887,  1.5653,  1.6484,  ..., -1.5204, -1.5866, -1.5113],
          [ 1.5893,  1.6470,  1.7540,  ..., -1.5135, -1.5127, -1.5059],
          [ 1.6629,  1.7494,  1.8205,  ..., -1.5403, -1.5721, -1.5089],
          ...,
          [-1.5387, -1.5428, -1.6320,  ...,  1.4164,  1.5767,  1.4394],
          [-1.5362, -1.5023, -1.5565,  ...,  1.5152,  1.3464,  1.2955],
          [-1.3907, -1.5480, -1.5731,  ...,  1.5051,  1.3453,  1.2162]],
 
         [[ 0.9374,  0.8916,  1.0398,  ..., -1.2865, -1.2206, -1.2715],
          [ 0.9505,  0.9995,

##### Now we have created our test and val dataloaders. Now is the time to create the full convnet+RNN model

In [49]:
class CustomResNet(nn.Module):
    def __init__(self, pretrained_model = None):
        super().__init__()



        for parameter in pretrained_model.parameters():
            parameter.requires_grad = False

        
        pretrained_model.fc = nn.Linear(2048,256)

        
        self.backbone = nn.Sequential(pretrained_model)
        # Till the previous step we have a pretrained Resnet whose weights we have frozen. Now we will define the RNN

        # self.embedding_matrix = torch.randn(len(chartoidx), 50, requires_grad=True) / (len(chartoidx)**0.5)
        self.embedding_matrix = nn.Embedding(len(chartoidx), 50)
        self.rnn = nn.RNNCell(input_size=50, hidden_size = 256, nonlinearity='tanh')
        self.output_layer = nn.Linear(256,len(chartoidx))

        self.cross_entopy_loss = nn.CrossEntropyLoss()

    def forward(self, image, caption_list):
        out_backbone = self.backbone(image)

        hidden = out_backbone.squeeze(0)
        # print(f"the hidden shape from image is {hidden.shape}")
        i = 0
        # new_char_id = char_id
        loss_word_list = []
        current_word_id = caption_list[i]
        target_index = caption_list[i+1]
        while True:
            
            if i == 0:
                current_word_id = caption_list[0]
            target_index = caption_list[i+1]
            if idxtochar[target_index] == '<end>':
                break
            out_embedding = self.embedding_matrix(current_word_id)  # embedding_shape = (1,50)
            # print(out_embedding.shape)
            # print(out_embedding, out_embedding.requires_grad)
            # print(f"The out embedding shape is {out_embedding.shape}")
            # calculate the hidden state activation from the embedding and the previous hidden state and get the next hidden state
            hidden = self.rnn(out_embedding, hidden)   # hidden_shape = (1,1024)

            # Now dish out the output_logits of the current time step
            output_logits = self.output_layer(hidden)   # output_logits shape (1,len(vocab))
            # print(f"The shape of output logits is {output_logits.shape}")

            # # calculate the probabilities of the characters
            probs =  torch.softmax(output_logits, dim = 0)         # probs_shape = (1,len(vocab))

            # print(f"The shape of probs is {probs.shape}")
            # print(f"The shape of output probabilities is {probs.shape}")
            # # Now get the index of the larget probability.
            next_word_id = torch.argmax(probs, dim =0)
            # print(f"the next word id  is {idxtochar[next_word_id.item()]}, and id is {next_word_id.item()}")

            # Now we will take the next_word that the cell thinks will be and do two things.
            # First we use it and the target to calculate the loss
            current_word_loss = self.cross_entopy_loss(probs, caption_list[i+1])
            # append the current_word_loss to loss_word_list
            loss_word_list.append(current_word_loss)
            
            # Second we would use the index of the current most probable output word as an input of the next time-step
            current_word_id = next_word_id
            # increase the value of i by 1 so that we can get the target word for the next input word
            i+=1
            
        final_loss = torch.sum(torch.tensor(loss_word_list, requires_grad=True))
        return final_loss

In [50]:
res_50 = torchvision.models.resnet50(weights =ResNet50_Weights.DEFAULT)
cust_res = CustomResNet(pretrained_model=res_50)

In [51]:
print(cust_res)

CustomResNet(
  (backbone): Sequential(
    (0): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inpla

In [52]:
summary(cust_res)

Layer (type:depth-idx)                             Param #
CustomResNet                                       --
├─Sequential: 1-1                                  --
│    └─ResNet: 2-1                                 --
│    │    └─Conv2d: 3-1                            (9,408)
│    │    └─BatchNorm2d: 3-2                       (128)
│    │    └─ReLU: 3-3                              --
│    │    └─MaxPool2d: 3-4                         --
│    │    └─Sequential: 3-5                        (215,808)
│    │    └─Sequential: 3-6                        (1,219,584)
│    │    └─Sequential: 3-7                        (7,098,368)
│    │    └─Sequential: 3-8                        (14,964,736)
│    │    └─AdaptiveAvgPool2d: 3-9                 --
│    │    └─Linear: 3-10                           524,544
├─Embedding: 1-2                                   1,116,600
├─RNNCell: 1-3                                     78,848
├─Linear: 1-4                                      5,739,324
├─CrossEntr

In [53]:
# from tqdm.notebook import tqdm_notebook
# for i in range(1):
#     for image_batch, caption_batch, _ in tqdm_notebook(test_coco_dataloader):
#         image_batch.shape, caption_batch.shape
#     break

In [54]:
# Now lets create the training loop which would take in the image and the caption train on it word by word
def training_loop(epochs, train_dataloader, val_dataloader, model, optimizer):
    for epoch in range(epochs):
        loss_list = []
        for batch in tqdm(train_dataloader):
            image, caption, _ = batch
            # print(image.shape)
            # image = torch.unsqueeze(image, dim=0)
            # print(caption.shape)
            # image = image.to(device = torch.device('cuda')),.,
            # caption = caption.to(device = torch.device('cuda'))
            optimizer.zero_grad()
            output =  model(image, caption[0])
            # print(output, grad_info)
            # break
            loss_list.append(output.item())
            output.backward()
            # print(grad_b)
            optimizer.step()
            # print(output.item())
        print(f"The loss is {sum(loss_list)/len(loss_list)}")


In [55]:
adam_optimizer = optim.Adam(cust_res.parameters(), lr  = 3e-2)
training_loop(10, train_dataloader=test_coco_dataloader, val_dataloader=None, model = cust_res, optimizer=adam_optimizer)

100%|██████████| 1334/1334 [06:41<00:00,  3.32it/s]


The loss is 112.8016164513721


100%|██████████| 1334/1334 [05:05<00:00,  4.36it/s]


The loss is 112.8016164513721


100%|██████████| 1334/1334 [04:59<00:00,  4.45it/s]


The loss is 112.8016164513721


  1%|          | 9/1334 [00:02<06:16,  3.51it/s]


KeyboardInterrupt: 

In [None]:
p = torch.tensor([[1,2,3,4]])
x = torch.argmax(p, dim = 1)
x

tensor([3])

In [71]:
p = torch.tensor([1.0,2.0,2.0,4.0], requires_grad=True)
q = torch.tensor([3.0,4.0,5.0,6.0], requires_grad=True)
s  = torch.sum(p+q, dim = 0)
print(s)
s.backward()


tensor(27., grad_fn=<SumBackward1>)


In [60]:
loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)
output = loss(input, target)

In [62]:
print(input)

tensor([[-0.7649, -0.7517, -0.6849, -0.0564, -1.0063],
        [-0.8070,  0.1536, -0.1430,  0.6772,  0.7898],
        [ 0.4916, -0.5202, -0.5325,  0.3402, -0.7753]], requires_grad=True)


In [61]:
print(target)

tensor([0, 2, 3])


In [63]:
print(output)

tensor(1.6728, grad_fn=<NllLossBackward0>)


In [1]:
import pandas as pd
image_captions_df = pd.read_parquet('F://coco/captions/image_captions_df_new.parquet')
image_captions_df

Unnamed: 0,image_id,file_name,captions,caption_token_set,idx_captions,tokens
0,57870,COCO_train2014_000000057870.jpg,"[[a, restaurant, has, modern, wooden, tables, ...","[arrangement, is, flower, on, table, blue, top...","[[27, 13, 25, 23, 9, 30, 12, 19], [27, 22, 13,...","[[a, restaurant, has, modern, wooden, tables, ..."
1,384029,COCO_train2014_000000384029.jpg,"[[a, man, preparing, desserts, in, a, kitchen,...","[types, desserts, person, is, kitchen, many, u...","[[27, 53, 56, 33, 29, 27, 35, 42, 29, 50], [27...","[[a, man, preparing, desserts, in, a, kitchen,..."
2,222016,COCO_train2014_000000222016.jpg,"[[a, big, red, telephone, booth, that, a, man,...","[inside, person, image, standing, is, an, that...","[[27, 66, 62, 67, 65, 61, 27, 53, 1, 59, 29], ...","[[a, big, red, telephone, booth, that, a, man,..."
3,520950,COCO_train2014_000000520950.jpg,"[[the, kitchen, is, full, of, spices, on, the,...","[is, on, that, kitchen, all, utilizes, with, i...","[[20, 35, 1, 73, 28, 80, 3, 20, 75], [27, 35, ...","[[the, kitchen, is, full, of, spices, on, the,..."
4,69675,COCO_train2014_000000069675.jpg,"[[a, child, and, woman, are, cooking, in, the,...","[child, woman, person, on, women, an, together...","[[27, 88, 12, 89, 94, 98, 29, 20, 99], [27, 89...","[[a, child, and, woman, are, cooking, in, the,..."
...,...,...,...,...,...,...
66658,360271,COCO_train2014_000000360271.jpg,"[[women, sitting, at, a, dinner, table, with, ...","[woman, another, dinner., sitting, is, watchin...","[[90, 215, 102, 27, 311, 4, 7, 140, 145, 102, ...","[[women, sitting, at, a, dinner, table, with, ..."
66659,471345,COCO_train2014_000000471345.jpg,"[[a, table, topped, with, two, white, plates, ...","[lobster,, rice., sitting, is, on, table, sit,...","[[27, 4, 283, 7, 103, 118, 427, 28, 97, 163, 1...","[[a, table, topped, with, two, white, plates, ..."
66660,444010,COCO_train2014_000000444010.jpg,"[[a, group, of, friends, sitting, down, at, a,...","[woman, jovial, sitting, table, group, togethe...","[[27, 362, 28, 2829, 215, 144, 102, 27, 4, 330...","[[a, group, of, friends, sitting, down, at, a,..."
66661,565004,COCO_train2014_000000565004.jpg,"[[wine, being, poured, into, a, glass, over, a...","[another, is, on, table, wine, with, red, glas...","[[418, 319, 6776, 161, 27, 686, 139, 27, 4], [...","[[wine, being, poured, into, a, glass, over, a..."


In [2]:
image_captions_df['tokens'][0]

array([array(['a', 'restaurant', 'has', 'modern', 'wooden', 'tables', 'and',
              'chairs', '.'], dtype=object)                                 ,
       array(['a', 'long', 'restaurant', 'table', 'with', 'rattan', 'rounded',
              'back', 'chairs', '.'], dtype=object)                           ,
       array(['a', 'long', 'table', 'with', 'a', 'plant', 'on', 'top', 'of',
              'it', 'surrounded', 'with', 'wooden', 'chairs'], dtype=object),
       array(['a', 'long', 'table', 'with', 'a', 'flower', 'arrangement', 'in',
              'the', 'middle', 'for', 'meetings'], dtype=object)               ,
       array(['a', 'table', 'is', 'adorned', 'with', 'wooden', 'chairs', 'with',
              'blue', 'accents', '.'], dtype=object)                            ],
      dtype=object)

In [8]:
def find_shortest_tokens(token_list):
    min_len = 1000
    shortest_token = None
    for token in token_list:
        if len(token) <min_len:
            shortest_token = token
    return shortest_token

image_captions_df['shortest_token'] = image_captions_df['tokens'].apply(lambda x : find_shortest_tokens(x))
image_captions_df['shortest_token_len'] = image_captions_df['shortest_token'].apply(lambda x : len(x))
image_captions_df



Unnamed: 0,image_id,file_name,captions,caption_token_set,idx_captions,tokens,shortest_token,shortest_token_len
0,57870,COCO_train2014_000000057870.jpg,"[[a, restaurant, has, modern, wooden, tables, ...","[arrangement, is, flower, on, table, blue, top...","[[27, 13, 25, 23, 9, 30, 12, 19], [27, 22, 13,...","[[a, restaurant, has, modern, wooden, tables, ...","[a, table, is, adorned, with, wooden, chairs, ...",11
1,384029,COCO_train2014_000000384029.jpg,"[[a, man, preparing, desserts, in, a, kitchen,...","[types, desserts, person, is, kitchen, many, u...","[[27, 53, 56, 33, 29, 27, 35, 42, 29, 50], [27...","[[a, man, preparing, desserts, in, a, kitchen,...","[close, up, of, a, hand, touching, various, pa...",9
2,222016,COCO_train2014_000000222016.jpg,"[[a, big, red, telephone, booth, that, a, man,...","[inside, person, image, standing, is, an, that...","[[27, 66, 62, 67, 65, 61, 27, 53, 1, 59, 29], ...","[[a, big, red, telephone, booth, that, a, man,...","[a, man, using, a, phone, in, a, phone, booth, .]",10
3,520950,COCO_train2014_000000520950.jpg,"[[the, kitchen, is, full, of, spices, on, the,...","[is, on, that, kitchen, all, utilizes, with, i...","[[20, 35, 1, 73, 28, 80, 3, 20, 75], [27, 35, ...","[[the, kitchen, is, full, of, spices, on, the,...","[a, very, small, kitchen, with, a, stove, and,...",12
4,69675,COCO_train2014_000000069675.jpg,"[[a, child, and, woman, are, cooking, in, the,...","[child, woman, person, on, women, an, together...","[[27, 88, 12, 89, 94, 98, 29, 20, 99], [27, 89...","[[a, child, and, woman, are, cooking, in, the,...","[two, women, cooking, on, stove, in, a, kitche...",10
...,...,...,...,...,...,...,...,...
66658,360271,COCO_train2014_000000360271.jpg,"[[women, sitting, at, a, dinner, table, with, ...","[woman, another, dinner., sitting, is, watchin...","[[90, 215, 102, 27, 311, 4, 7, 140, 145, 102, ...","[[women, sitting, at, a, dinner, table, with, ...","[a, bunch, of, women, are, eating, at, a, table]",9
66659,471345,COCO_train2014_000000471345.jpg,"[[a, table, topped, with, two, white, plates, ...","[lobster,, rice., sitting, is, on, table, sit,...","[[27, 4, 283, 7, 103, 118, 427, 28, 97, 163, 1...","[[a, table, topped, with, two, white, plates, ...","[there, is, lobster, ,, rice, ,, and, a, salad...",16
66660,444010,COCO_train2014_000000444010.jpg,"[[a, group, of, friends, sitting, down, at, a,...","[woman, jovial, sitting, table, group, togethe...","[[27, 362, 28, 2829, 215, 144, 102, 27, 4, 330...","[[a, group, of, friends, sitting, down, at, a,...","[a, jovial, older, couple, and, a, young, woma...",14
66661,565004,COCO_train2014_000000565004.jpg,"[[wine, being, poured, into, a, glass, over, a...","[another, is, on, table, wine, with, red, glas...","[[418, 319, 6776, 161, 27, 686, 139, 27, 4], [...","[[wine, being, poured, into, a, glass, over, a...","[a, wine, glass, being, filled, with, red, win...",9


In [12]:
less_than_15_image_caption_df = image_captions_df[image_captions_df['shortest_token_len'] <= 15]

# This is a new

In [57]:
import pandas as pd
image_captions_df = pd.read_parquet('F://coco/captions/image_captions_df_less_than_15.parquet')
image_captions_df

Unnamed: 0,image_id,file_name,captions,caption_token_set,idx_captions,tokens,shortest_token,shortest_token_len
0,57870,COCO_train2014_000000057870.jpg,"[[a, restaurant, has, modern, wooden, tables, ...","[arrangement, is, flower, on, table, blue, top...","[[27, 13, 25, 23, 9, 30, 12, 19], [27, 22, 13,...","[[a, restaurant, has, modern, wooden, tables, ...","[a, table, is, adorned, with, wooden, chairs, ...",11
1,384029,COCO_train2014_000000384029.jpg,"[[a, man, preparing, desserts, in, a, kitchen,...","[types, desserts, person, is, kitchen, many, u...","[[27, 53, 56, 33, 29, 27, 35, 42, 29, 50], [27...","[[a, man, preparing, desserts, in, a, kitchen,...","[close, up, of, a, hand, touching, various, pa...",9
2,222016,COCO_train2014_000000222016.jpg,"[[a, big, red, telephone, booth, that, a, man,...","[inside, person, image, standing, is, an, that...","[[27, 66, 62, 67, 65, 61, 27, 53, 1, 59, 29], ...","[[a, big, red, telephone, booth, that, a, man,...","[a, man, using, a, phone, in, a, phone, booth, .]",10
3,520950,COCO_train2014_000000520950.jpg,"[[the, kitchen, is, full, of, spices, on, the,...","[is, on, that, kitchen, all, utilizes, with, i...","[[20, 35, 1, 73, 28, 80, 3, 20, 75], [27, 35, ...","[[the, kitchen, is, full, of, spices, on, the,...","[a, very, small, kitchen, with, a, stove, and,...",12
4,69675,COCO_train2014_000000069675.jpg,"[[a, child, and, woman, are, cooking, in, the,...","[child, woman, person, on, women, an, together...","[[27, 88, 12, 89, 94, 98, 29, 20, 99], [27, 89...","[[a, child, and, woman, are, cooking, in, the,...","[two, women, cooking, on, stove, in, a, kitche...",10
...,...,...,...,...,...,...,...,...
66657,53136,COCO_train2014_000000053136.jpg,"[[vegetable, and, rice, dish, served, in, a, w...","[bed, vegetables, on, white, top, with, pasta,...","[[2753, 12, 2778, 339, 995, 29, 27, 118, 358],...","[[vegetable, and, rice, dish, served, in, a, w...","[a, white, bowl, contains, shredded, cabbage, ...",9
66658,360271,COCO_train2014_000000360271.jpg,"[[women, sitting, at, a, dinner, table, with, ...","[woman, another, dinner., sitting, is, watchin...","[[90, 215, 102, 27, 311, 4, 7, 140, 145, 102, ...","[[women, sitting, at, a, dinner, table, with, ...","[a, bunch, of, women, are, eating, at, a, table]",9
66660,444010,COCO_train2014_000000444010.jpg,"[[a, group, of, friends, sitting, down, at, a,...","[woman, jovial, sitting, table, group, togethe...","[[27, 362, 28, 2829, 215, 144, 102, 27, 4, 330...","[[a, group, of, friends, sitting, down, at, a,...","[a, jovial, older, couple, and, a, young, woma...",14
66661,565004,COCO_train2014_000000565004.jpg,"[[wine, being, poured, into, a, glass, over, a...","[another, is, on, table, wine, with, red, glas...","[[418, 319, 6776, 161, 27, 686, 139, 27, 4], [...","[[wine, being, poured, into, a, glass, over, a...","[a, wine, glass, being, filled, with, red, win...",9


Now first what we need to do is to create a vocabulary and convert the tokens to integers so that we can get the embeddings.

In [58]:
def make_vocab(token_list):
    """ This function takes in a token list and adds the unique tokens to the vocabulary."""
    for tokens in  token_list:
        for token in tokens:
            if token in vocab_dict:
                vocab_dict[token] +=1
            else:
                vocab_dict[token] = 1
    return vocab_dict


def stoi(captions_list, chartoidx):
    """ Function to convert the string tokens to integers.
        Also adds the <start>, <end> & <unk> tokens to form the final list to be fed in the RNN
        INPUT --> list of captions with 
        OUPUT --> """
    caption_token_list = []
    for caption in captions_list:
        tokens_list = [chartoidx['<start>']] + []
        for token in caption:
            try:
                tokens_list.append(chartoidx[token])
            except:
                tokens_list.append(chartoidx['<unk>'])
        tokens_list = tokens_list + [chartoidx['<end>']]
        caption_token_list.append(tokens_list)
    return caption_token_list

In [59]:
vocab_dict = {}
image_captions_df['tokens'].apply(lambda x : make_vocab(x))

chartoidx = {}    # create the character to index dictionary for tokens and their integer representations.
idxtochar = []    # index to char list. Useful for decoding the output from the neural net. 
for i,word in enumerate(vocab_dict.keys()):
    chartoidx[word] = i
    idxtochar.append(word)

idxtochar = idxtochar + ['<unk>','<start>','<end>','<pad>']    # add additional tokens
chartoidx.update({'<unk>':21573,                       # add addtional tokens and their corresponding numbers 
                        '<start>':21574,
                        '<end>':21575,
                        '<pad>':21576})

# Add an idx_tokens column to the dataframe to house the integer representation of the token strings.
image_captions_df['idx_tokens'] = image_captions_df['tokens'].apply(lambda x : stoi(x, chartoidx))
image_captions_df

Unnamed: 0,image_id,file_name,captions,caption_token_set,idx_captions,tokens,shortest_token,shortest_token_len,idx_tokens
0,57870,COCO_train2014_000000057870.jpg,"[[a, restaurant, has, modern, wooden, tables, ...","[arrangement, is, flower, on, table, blue, top...","[[27, 13, 25, 23, 9, 30, 12, 19], [27, 22, 13,...","[[a, restaurant, has, modern, wooden, tables, ...","[a, table, is, adorned, with, wooden, chairs, ...",11,"[[21574, 0, 1, 2, 3, 4, 5, 6, 7, 8, 21575], [2..."
1,384029,COCO_train2014_000000384029.jpg,"[[a, man, preparing, desserts, in, a, kitchen,...","[types, desserts, person, is, kitchen, many, u...","[[27, 53, 56, 33, 29, 27, 35, 42, 29, 50], [27...","[[a, man, preparing, desserts, in, a, kitchen,...","[close, up, of, a, hand, touching, various, pa...",9,"[[21574, 0, 32, 33, 34, 23, 0, 35, 36, 23, 37,..."
2,222016,COCO_train2014_000000222016.jpg,"[[a, big, red, telephone, booth, that, a, man,...","[inside, person, image, standing, is, an, that...","[[27, 66, 62, 67, 65, 61, 27, 53, 1, 59, 29], ...","[[a, big, red, telephone, booth, that, a, man,...","[a, man, using, a, phone, in, a, phone, booth, .]",10,"[[21574, 0, 57, 58, 59, 60, 61, 0, 32, 28, 62,..."
3,520950,COCO_train2014_000000520950.jpg,"[[the, kitchen, is, full, of, spices, on, the,...","[is, on, that, kitchen, all, utilizes, with, i...","[[20, 35, 1, 73, 28, 80, 3, 20, 75], [27, 35, ...","[[the, kitchen, is, full, of, spices, on, the,...","[a, very, small, kitchen, with, a, stove, and,...",12,"[[21574, 24, 35, 28, 69, 18, 70, 16, 24, 71, 2..."
4,69675,COCO_train2014_000000069675.jpg,"[[a, child, and, woman, are, cooking, in, the,...","[child, woman, person, on, women, an, together...","[[27, 88, 12, 89, 94, 98, 29, 20, 99], [27, 89...","[[a, child, and, woman, are, cooking, in, the,...","[two, women, cooking, on, stove, in, a, kitche...",10,"[[21574, 0, 87, 6, 88, 89, 90, 23, 24, 35, 8, ..."
...,...,...,...,...,...,...,...,...,...
66657,53136,COCO_train2014_000000053136.jpg,"[[vegetable, and, rice, dish, served, in, a, w...","[bed, vegetables, on, white, top, with, pasta,...","[[2753, 12, 2778, 339, 995, 29, 27, 118, 358],...","[[vegetable, and, rice, dish, served, in, a, w...","[a, white, bowl, contains, shredded, cabbage, ...",9,"[[21574, 2147, 6, 2167, 297, 813, 23, 0, 118, ..."
66658,360271,COCO_train2014_000000360271.jpg,"[[women, sitting, at, a, dinner, table, with, ...","[woman, another, dinner., sitting, is, watchin...","[[90, 215, 102, 27, 311, 4, 7, 140, 145, 102, ...","[[women, sitting, at, a, dinner, table, with, ...","[a, bunch, of, women, are, eating, at, a, table]",9,"[[21574, 100, 198, 92, 0, 278, 10, 11, 125, 12..."
66660,444010,COCO_train2014_000000444010.jpg,"[[a, group, of, friends, sitting, down, at, a,...","[woman, jovial, sitting, table, group, togethe...","[[27, 362, 28, 2829, 215, 144, 102, 27, 4, 330...","[[a, group, of, friends, sitting, down, at, a,...","[a, jovial, older, couple, and, a, young, woma...",14,"[[21574, 0, 330, 18, 2209, 198, 133, 92, 0, 10..."
66661,565004,COCO_train2014_000000565004.jpg,"[[wine, being, poured, into, a, glass, over, a...","[another, is, on, table, wine, with, red, glas...","[[418, 319, 6776, 161, 27, 686, 139, 27, 4], [...","[[wine, being, poured, into, a, glass, over, a...","[a, wine, glass, being, filled, with, red, win...",9,"[[21574, 343, 285, 4815, 144, 0, 550, 131, 0, ..."


In [64]:
def convert_shortest_token(token_list,chartoidx):
    new_token_list = [chartoidx['<start>']]
    for token in token_list:
        new_token_list = new_token_list  + [chartoidx[token]]
    new_token_list = new_token_list + [chartoidx['<end>']]
    return new_token_list

In [65]:
image_captions_df['shortest_idx_tokens'] = image_captions_df['shortest_token'].apply(lambda x : convert_shortest_token(x, chartoidx))
image_captions_df


Unnamed: 0,image_id,file_name,captions,caption_token_set,idx_captions,tokens,shortest_token,shortest_token_len,idx_tokens,shortest_idx_tokens
0,57870,COCO_train2014_000000057870.jpg,"[[a, restaurant, has, modern, wooden, tables, ...","[arrangement, is, flower, on, table, blue, top...","[[27, 13, 25, 23, 9, 30, 12, 19], [27, 22, 13,...","[[a, restaurant, has, modern, wooden, tables, ...","[a, table, is, adorned, with, wooden, chairs, ...",11,"[[21574, 0, 1, 2, 3, 4, 5, 6, 7, 8, 21575], [2...","[21574, 0, 10, 28, 29, 11, 4, 7, 11, 30, 31, 8..."
1,384029,COCO_train2014_000000384029.jpg,"[[a, man, preparing, desserts, in, a, kitchen,...","[types, desserts, person, is, kitchen, many, u...","[[27, 53, 56, 33, 29, 27, 35, 42, 29, 50], [27...","[[a, man, preparing, desserts, in, a, kitchen,...","[close, up, of, a, hand, touching, various, pa...",9,"[[21574, 0, 32, 33, 34, 23, 0, 35, 36, 23, 37,...","[21574, 49, 50, 18, 0, 55, 56, 45, 42, 8, 21575]"
2,222016,COCO_train2014_000000222016.jpg,"[[a, big, red, telephone, booth, that, a, man,...","[inside, person, image, standing, is, an, that...","[[27, 66, 62, 67, 65, 61, 27, 53, 1, 59, 29], ...","[[a, big, red, telephone, booth, that, a, man,...","[a, man, using, a, phone, in, a, phone, booth, .]",10,"[[21574, 0, 57, 58, 59, 60, 61, 0, 32, 28, 62,...","[21574, 0, 32, 68, 0, 64, 23, 0, 64, 60, 8, 21..."
3,520950,COCO_train2014_000000520950.jpg,"[[the, kitchen, is, full, of, spices, on, the,...","[is, on, that, kitchen, all, utilizes, with, i...","[[20, 35, 1, 73, 28, 80, 3, 20, 75], [27, 35, ...","[[the, kitchen, is, full, of, spices, on, the,...","[a, very, small, kitchen, with, a, stove, and,...",12,"[[21574, 24, 35, 28, 69, 18, 70, 16, 24, 71, 2...","[21574, 0, 84, 41, 35, 11, 0, 85, 6, 0, 86, 18..."
4,69675,COCO_train2014_000000069675.jpg,"[[a, child, and, woman, are, cooking, in, the,...","[child, woman, person, on, women, an, together...","[[27, 88, 12, 89, 94, 98, 29, 20, 99], [27, 89...","[[a, child, and, woman, are, cooking, in, the,...","[two, women, cooking, on, stove, in, a, kitche...",10,"[[21574, 0, 87, 6, 88, 89, 90, 23, 24, 35, 8, ...","[21574, 99, 100, 90, 16, 85, 23, 0, 35, 101, 8..."
...,...,...,...,...,...,...,...,...,...,...
66657,53136,COCO_train2014_000000053136.jpg,"[[vegetable, and, rice, dish, served, in, a, w...","[bed, vegetables, on, white, top, with, pasta,...","[[2753, 12, 2778, 339, 995, 29, 27, 118, 358],...","[[vegetable, and, rice, dish, served, in, a, w...","[a, white, bowl, contains, shredded, cabbage, ...",9,"[[21574, 2147, 6, 2167, 297, 813, 23, 0, 118, ...","[21574, 0, 118, 472, 2315, 2299, 2146, 6, 2607..."
66658,360271,COCO_train2014_000000360271.jpg,"[[women, sitting, at, a, dinner, table, with, ...","[woman, another, dinner., sitting, is, watchin...","[[90, 215, 102, 27, 311, 4, 7, 140, 145, 102, ...","[[women, sitting, at, a, dinner, table, with, ...","[a, bunch, of, women, are, eating, at, a, table]",9,"[[21574, 100, 198, 92, 0, 278, 10, 11, 125, 12...","[21574, 0, 276, 18, 100, 89, 617, 92, 0, 10, 2..."
66660,444010,COCO_train2014_000000444010.jpg,"[[a, group, of, friends, sitting, down, at, a,...","[woman, jovial, sitting, table, group, togethe...","[[27, 362, 28, 2829, 215, 144, 102, 27, 4, 330...","[[a, group, of, friends, sitting, down, at, a,...","[a, jovial, older, couple, and, a, young, woma...",14,"[[21574, 0, 330, 18, 2209, 198, 133, 92, 0, 10...","[21574, 0, 5737, 98, 275, 6, 0, 93, 88, 1902, ..."
66661,565004,COCO_train2014_000000565004.jpg,"[[wine, being, poured, into, a, glass, over, a...","[another, is, on, table, wine, with, red, glas...","[[418, 319, 6776, 161, 27, 686, 139, 27, 4], [...","[[wine, being, poured, into, a, glass, over, a...","[a, wine, glass, being, filled, with, red, win...",9,"[[21574, 343, 285, 4815, 144, 0, 550, 131, 0, ...","[21574, 0, 343, 550, 285, 171, 11, 58, 343, 8,..."


First and foremost we would have to change the way that we are making the dataset because previously we were giving all the five caption and then taking the first from them. Now we would have to give only one caption.

In [None]:
class CocoDataset(Dataset):
    def __init__(self,image_label_df, val_stride = 10, is_val_set_bool = False, test_data_set =False, test_stride = 5, transforms = None):
        
        self.transfrom = transforms
        self.is_val_set_bool = is_val_set_bool
        self.val_stride = val_stride
        self.image_label_df = image_label_df.copy()
        self.test_data_set = test_data_set
        self.test_stride = test_stride

        if self.test_data_set: #If we need only a small subset of the data to work with
            self.image_label_df = self.image_label_df[::test_stride].reset_index(drop = True)

        elif self.is_val_set_bool: # If we need only the validation data then return the validation data which is a subset of total data
            assert self.val_stride > 0
            self.image_label_df = self.image_label_df[::val_stride]

        elif self.val_stride > 0:  # Else if val_stride is greater than zero then return the remaining dataframe after removing 10% of data
            self.image_label_df = self.image_label_df.drop(index = list(range(0,len(self.image_label_df),self.val_stride)))
            
        else: # else train on the full dataset
            self.image_label_df = self.image_label_df
        
    def __len__(self):
        """ This method calculates the length of your data."""
        return len(self.image_label_df)

    # Now we will introduce function which gives us an image and its corresponding labels by id
    def get_image_by_id(self, img_id = None):
        """ This function returns an image by its id and the corresponding multiple labesl in a present/no_present binary format"""
        

        folder_path = 'F://coco/train2014/train2014/'

        if img_id == None:
            raise ValueError('Must provide IMAGE ID')

        else:
            row = self.image_label_df[self.image_label_df['image_id']==img_id]
            file_name = row['file_name'].values[0]

            captions_array = torch.tensor(row['idx_tokens'][0][0])

            # Get the image-data from file_name
            image_array = torchvision.io.read_image(folder_path + file_name)
            image_array = (image_array/255.0).to(torch.float32)
            image_array = self.transfrom(image_array).to(torch.float32)
        
            return image_array, captions_array


    # Now we would write the code for returning the image from index and its corresponding labels. Which would be used by the train/val dataloaders
    def __getitem__(self, ndx):

        """ This function takes in an index and returns the image and the labels of the image at that index"""
        folder_path = 'F://coco/train2014/train2014/'
        
        row = self.image_label_df.iloc[ndx]
        # print(row)

        # Now get the image_id and the file_name of the image
        image_id = row['image_id']
        file_name = row['file_name']
        
        # Now get the captions array from the dataframe
        captions_array = torch.tensor(row['idx_tokens'][0])
        # Now get the image_data from storage
        image_array = torchvision.io.read_image(folder_path + file_name)
        image_array = (image_array/255.0).to(torch.float32)
        image_array = self.transfrom(image_array).to(torch.float32)
        # print(captions_array)
        return (image_array, captions_array, image_id)