In [0]:
# Imports in encoder file
import torch
import torch.nn as nn
import torchvision.models as models

# Imports in Utils
import os
from skimage import io
import torchvision.transforms as transforms
import torch.utils.data as datautil
import nltk
import pandas as pd
from PIL import Image


In [12]:
class VGGNetEncoder(nn.Module):
    """Image encoder built from torchvision's pretrained VGG16 feature stack.

    Only the ``features`` part of VGG16 is kept, minus its last layer (the
    final max pool), so the classifier head is not part of this module.
    Construction prints a summary of both the untouched and the trimmed network.
    """

    def __init__(self, encoded_img_size=14):
        """Load VGG16, report it, then trim it down to the convolutional stack.

        Args:
            encoded_img_size: Kept in the signature, but not read anywhere in
                this class.
        """
        super().__init__()

        # Start from the full pretrained VGG16 and report it as-is.
        # NOTE(review): `summary` is not imported in the visible cells --
        # presumably torchsummary.summary; confirm it is in scope.
        self.net = models.vgg16(pretrained=True)
        print("Original VGG16 summary")
        print(summary(self.net, input_size=(3, 224, 224)))

        # Drop the classifier and the trailing layer of `features`: the paper
        # uses the output of a lower convolutional layer.
        feature_layers = list(self.net.features.children())
        self.net = nn.Sequential(*feature_layers[:-1])
        print("Modified VGG summary")
        print(summary(self.net, input_size=(3, 224, 224)))

    def forward(self, x):
        """Run ``x`` through the trimmed network and return its output unchanged."""
        return self.net(x)
      
# Instantiate the encoder once: the constructor prints both summaries, and the
# bare expression makes the notebook display the resulting module.
VGGNetEncoder()

Original VGG16 summary
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReL

VGGNetEncoder(
  (net): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): C

In [0]:
# Goes into utils.py
def process_image(resize, crop_size, split):
    """Build the image preprocessing pipeline for one dataset split.

    Args:
        resize: Size handed to ``transforms.Resize`` before cropping.
        crop_size: Size of the crop taken from the resized image.
        split: ``"Train"`` for a random crop, ``"Val"`` for a center crop.

    Returns:
        A ``transforms.Compose`` that resizes, crops, converts the image to a
        tensor and normalizes it per channel.

    Raises:
        ValueError: If ``split`` is neither ``"Train"`` nor ``"Val"``.
    """
    # Fail early with a clear message; previously an unknown split fell through
    # to the return statement and raised UnboundLocalError.
    if split not in ("Train", "Val"):
        raise ValueError(
            "split must be 'Train' or 'Val', got {!r}".format(split))

    # Per-channel statistics used to normalize inputs for torchvision's
    # pretrained models.
    image_mean = (0.485, 0.456, 0.406)
    image_std = (0.229, 0.224, 0.225)

    # Training uses a random crop for augmentation; validation uses a
    # deterministic center crop.
    if split == "Train":
        crop = transforms.RandomCrop(crop_size)
    else:
        crop = transforms.CenterCrop(crop_size)

    # Both splits honour `resize`/`crop_size`. The validation branch used to
    # hard-code 256/224 through the deprecated `transforms.Scale`; passing
    # resize=256, crop_size=224 reproduces that pipeline.
    return transforms.Compose([transforms.Resize(resize),
                               crop,
                               transforms.ToTensor(),
                               transforms.Normalize(image_mean, image_std)])

In [0]:
# Goes into utils.py
# Load the input for the CNN from the csv
class Flickr8KDataset(datautil.Dataset):
    """Image/caption pairs read from a CSV file.

    Column 0 of the CSV is joined onto ``root_dir`` to locate the image and
    column 1 holds the caption text. Each caption is lower-cased, tokenized
    with NLTK and mapped to word ids through ``vocabulary``.
    """

    def __init__(self, csv_file, root_dir, vocabulary, max_caption_len, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            vocabulary (vocab object): Object with the word_to_ind mappings.
                It must support ``in`` and ``[]`` lookups and contain the
                ``<start>``, ``<end>``, ``<pad>`` and ``<unk>`` entries.
            max_caption_len (int): Number of word slots per caption; longer
                captions are truncated and shorter ones are padded.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.input_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        self.vocab = vocabulary
        # Stored so __getitem__ can pad; it used to be read as an undefined name.
        self.max_caption_len = max_caption_len

    def __len__(self):
        """Return the number of rows in the annotation CSV."""
        return len(self.input_frame)

    def _encode_caption(self, caption_tokens):
        """Turn word tokens into a fixed-length list of word ids.

        Layout: ``<start>``, up to ``max_caption_len`` word ids, ``<end>``,
        then ``<pad>`` ids, so every result has ``max_caption_len + 2`` entries
        and samples can be stacked into a batch.
        """
        vocab = self.vocab
        # Truncate so an over-long caption cannot produce a longer target.
        words = caption_tokens[:self.max_caption_len]
        word_ids = [vocab[word] if word in vocab else vocab['<unk>']
                    for word in words]
        padding = [vocab['<pad>']] * (self.max_caption_len - len(words))
        return [vocab['<start>']] + word_ids + [vocab['<end>']] + padding

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, target)``.

        ``image`` is the RGB image with ``transform`` applied when one was
        given; ``target`` is a 1-D long tensor of word ids.
        """
        img_path = os.path.join(self.root_dir, self.input_frame.iloc[idx, 0])
        # convert() returns a new image, so the file can be closed right away.
        with Image.open(img_path) as img_file:
            image = img_file.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        caption = self.input_frame.iloc[idx, 1]
        # Tokenize the words in the caption.
        caption_tokens = nltk.tokenize.word_tokenize(str(caption).lower())

        # Word ids are indices, so keep them as integers rather than floats.
        target = torch.tensor(self._encode_caption(caption_tokens), dtype=torch.long)
        return image, target

    
def load_dataset(input_csv, img_dir, vocab, max_caption_len, batch_size, shuffle, transform=None):
    """Create a ``DataLoader`` over the Flickr8K image/caption CSV.

    Args:
        input_csv: Path to the annotation CSV.
        img_dir: Directory containing the images named in the CSV.
        vocab: Word-to-id mapping handed to the dataset.
        max_caption_len: Caption length setting handed to the dataset.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader reshuffles the data every epoch.
        transform: Optional callable applied to each image by the dataset.
            Defaults to ``None`` so existing calls keep working.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping a ``Flickr8KDataset``.
    """
    flickr_data = Flickr8KDataset(input_csv, img_dir, vocab, max_caption_len,
                                  transform=transform)
    # Pass the dataset built above; the previous code referenced an undefined
    # name `dataset` here.
    data_loader = torch.utils.data.DataLoader(dataset=flickr_data,
                                              batch_size=batch_size,
                                              shuffle=shuffle)

    return data_loader