# 1. Load Data

In [80]:
from pycocotools.coco import COCO
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import random
import math
import torch
import torchvision.transforms as transforms
import torch.nn as nn
from nltk.tokenize import word_tokenize

In [81]:
hyperparameters = {
    'batch_size': 16,
    'embedding_dim': 256,
    'lstm_out_dim': 512,
    'hidden_size': 100,
    'epochs': 150,
    'learning_rate': 0.001
}

PADDING_TOKEN = '<PAD>'
UNKNOWN_TOKEN = '<UNK>'

In [3]:
BASE_PATH = '/scratch/lt2316-h18-resources/coco/'

In [4]:
coco_captions = COCO(BASE_PATH + 'annotations/captions_train2017.json')
coco_instances = COCO(BASE_PATH + 'annotations/instances_train2017.json')

loading annotations into memory...
Done (t=1.19s)
creating index...
index created!
loading annotations into memory...
Done (t=21.80s)
creating index...
index created!


In [66]:
class Sampler():
    def __init__(self, coco_captions, number_of_samples=100, train_split=0.8, val_split=0.05, test_split=0.15) -> None:
        samples = []
        transform = transforms.ToTensor()
        
        random_images = random.sample(list(coco_captions.imgs.values()), number_of_samples)
        for image_info in random_images:
            image = Image.open(BASE_PATH + 'train2017/' + image_info['file_name']).resize((100,100)).convert('RGB')
            samples.extend([{
                'image': transform(image),
                'caption': annotation['caption']          
            } for annotation in coco_captions.imgToAnns[image_info['id']]])

        train_border = int(train_split * number_of_samples)
        val_border = int((train_split + val_split) * number_of_samples)

        self.train_samples = samples[:train_border]
        self.val_samples = samples[train_border:val_border]
        self.test_samples = samples[val_border:]


In [73]:
class COCO_Dataset(Dataset):
    def __init__(self, samples) -> None:
        super().__init__()
        
        self.max_length_context = -1
        vocab = {PADDING_TOKEN, UNKNOWN_TOKEN}
        for sample in samples:
            split_caption = word_tokenize(sample['caption'])
            vocab.update(split_caption)
            self.max_length_context = max(self.max_length_context, len(split_caption))

        self.vocab = {word: index for index, word in enumerate(list(vocab))}
        self.samples = []
        for sample in samples:
            split_caption = word_tokenize(sample['caption'])
            padded_context = [self.get_encoded_word(word) for word in split_caption]
            padded_context.extend([self.get_encoded_word(PADDING_TOKEN)] * (self.max_length_context - len(split_caption)))
            self.samples.append({
                'image': sample['image'],
                'caption': sample['caption'],
                'encoded_caption': torch.tensor(padded_context)
            })

    def __getitem__(self, index):
        return self.samples[index]
    
    def __len__(self):
        return len(self.samples)

    def get_encoded_word(self, word):
        if word in self.vocab:
            return self.vocab[word]
        else:
            return self.vocab[UNKNOWN_TOKEN]

    def get_vocab_size(self):
        return len(self.vocab)

In [74]:
sampler = Sampler(coco_captions)

train_dataloader = DataLoader(COCO_Dataset(sampler.train_samples),
                              batch_size=hyperparameters['batch_size'],
                              shuffle=True)
val_dataloader = DataLoader(COCO_Dataset(sampler.val_samples),
                            batch_size=hyperparameters['batch_size'],
                            shuffle=True)
test_dataloader = DataLoader(COCO_Dataset(sampler.test_samples),
                             batch_size=hyperparameters['batch_size'],
                             shuffle=True)


In [75]:
train_dataloader.dataset[0]

{'image': tensor([[[0.4745, 0.4706, 0.4078,  ..., 0.4471, 0.4353, 0.4588],
          [0.4588, 0.4667, 0.4275,  ..., 0.4471, 0.4392, 0.4588],
          [0.4667, 0.4706, 0.4118,  ..., 0.4588, 0.4431, 0.4588],
          ...,
          [0.5176, 0.6392, 0.7255,  ..., 0.7529, 0.7529, 0.7373],
          [0.5137, 0.6157, 0.7255,  ..., 0.7529, 0.7529, 0.7490],
          [0.5255, 0.6314, 0.7255,  ..., 0.7412, 0.7373, 0.7412]],
 
         [[0.4549, 0.4549, 0.4039,  ..., 0.4471, 0.4392, 0.4549],
          [0.4471, 0.4549, 0.4157,  ..., 0.4431, 0.4392, 0.4588],
          [0.4549, 0.4588, 0.3961,  ..., 0.4588, 0.4471, 0.4588],
          ...,
          [0.5490, 0.5765, 0.6039,  ..., 0.5490, 0.5412, 0.5255],
          [0.5490, 0.5647, 0.6118,  ..., 0.5451, 0.5373, 0.5333],
          [0.5608, 0.5725, 0.5961,  ..., 0.5373, 0.5333, 0.5294]],
 
         [[0.4706, 0.4627, 0.4078,  ..., 0.4471, 0.4471, 0.4627],
          [0.4471, 0.4549, 0.4275,  ..., 0.4510, 0.4471, 0.4667],
          [0.4588, 0.4627, 0.40

# 2. Model

In [127]:
class CaptionEncoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, out_dim, padding_idx) -> None:
        super(CaptionEncoder, self).__init__()

        self.embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.rnn = nn.LSTM(embedding_dim, out_dim, num_layers=1, bidirectional=True, batch_first=True)

    def forward(self, caption_batch):
        embeddings = self.embeddings(caption_batch)
        _, (h_n, _) = self.rnn(embeddings)

        return h_n

In [128]:
class ImageEncoder(nn.Module):
    def __init__(self) -> None:
        super(ImageEncoder, self).__init__()
        
        self.image_encoder = nn.Sequential(
            nn.Conv2d(3, 3, 3),
            nn.BatchNorm2d(3),
            nn.MaxPool2d(2,2),
            nn.Tanh()
        )
        

    def forward(self, image_batch):
        return self.image_encoder(image_batch)


In [147]:
class CaptionEvaluator(nn.Module):
    def __init__(self, caption_encoder, image_encoder, hidden_size) -> None:
        super(CaptionEvaluator, self).__init__()

        self.caption_encoder = caption_encoder
        self.image_encoder = image_encoder

        self.classifier = nn.Sequential(
            nn.Linear(100 * 100 * 3, hidden_size),
            nn.Dropout(0.05),
            nn.Tanh(),
            nn.Linear(hidden_size, int(hidden_size/2)),
            nn.Tanh(),
            nn.Linear(int(hidden_size/2), 1),
            nn.Sigmoid()
        )

    def forward(self, image_batch, caption_batch):
        encoded_caption = self.caption_encoder(caption_batch)
        flattened_caption = torch.flatten(encoded_caption.permute(1,0,2), 1)

        encoded_image = self.image_encoder(image_batch)
        flattened_image = torch.flatten(encoded_image, 1)

        concatenated_encoding = torch.cat((flattened_caption, flattened_image), 1)
        print(concatenated_encoding.shape)
        


In [148]:
caption_encoder = CaptionEncoder(train_dataloader.dataset.get_vocab_size(),
                                 hyperparameters['embedding_dim'],
                                 hyperparameters['lstm_out_dim'],
                                 train_dataloader.dataset.get_encoded_word(PADDING_TOKEN))

image_encoder = ImageEncoder()
caption_evaluator = CaptionEvaluator(caption_encoder,
                                     image_encoder, 
                                     hyperparameters['hidden_size'])




In [149]:
for i, batch in enumerate(train_dataloader):
#    image = image_encoder(batch['image'])
    output = caption_evaluator(batch['image'], batch['encoded_caption'])
    break

torch.Size([16, 7203])
torch.Size([16, 1024])
torch.Size([16, 8227])
