In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import albumentations as alb
from albumentations.pytorch import ToTensorV2
import timm
import cv2
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from collections import Counter
import torchtext
from torchtext.data import get_tokenizer


In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [None]:
class ImageCaptioner(nn.Module):
    def __init__(self, context_length, vocab_size, num_blocks, model_dim, num_heads, dropout_prob):
        super().__init__()
        self.cnn_encoder = timm.create_model('efficientnet_b0', pretrained=True)
        test_image = torch.zeros(1,3,224,224)

        with torch.no_grad():
            cnn_output = self.cnn_encoder(test_image)
        in_features = cnn_output.shape[1]    
        self.project = nn.Linear(in_features, model_dim)

        self.word_embeddings = nn.Embedding(vocab_size, model_dim)
        self.pos_embeddings = nn.Embedding(context_length, model_dim)

        block = nn.TransformerDecoderLayer(model_dim, num_heads, 2*model_dim, dropout=dropout_prob, batch_first=True, norm_first =True)
        self.blocks = nn.TransformerDecoderLayer(block, num_blocks)

        self.vocab_projection = nn.Linear(model_dim, vocab_size)

        
    def forward(self, images, true_labels):
        tok_embedded = self.word_embeddings(true_labels)
        B,T = true_labels.shape
        positions = torch.arange(T).to(device)
        pos_embedded = self.pos_embeddings(positions)
        total_emebddings = tok_embedded + pos_embedded #input to blocks
        
        with torch.no_grad():
            encoded_image = self.project = (self.cnn_encoder(images).view(B,-1))
        
        img_for_attention = torch.unsqueeze(encoded_image, 1)

        #Causal/Subsequent Mask
        attention_mask = nn.Transformer.generate_square_subsequent_mask(T).to(device)
        block_output = self.blocks(total_emebddings, img_for_attention, tgt_mask=attention_mask)

        vocabulary_vector = self.vocab_projection(block_output) #B,T,V

        return vocabulary_vector

In [None]:
caption_filename = 'Flickr8k/captions.txt'
missing = '2258277193_586949ec62'

with open(caption_filename) as captions:
    lines = captions.readlines()

get_captions = {}
all_captions = []

for i in range(1,len(lines)):
    data = lines[i].rstrip('\n').split('.jpg,')
    img_name = data[0] + '.jpg'
    if img_name == missing:
        continue

    caption_list = get_captions.get(img_name, [])
    caption_list.append(data[1])
    get_captions[img_name] = caption_list
    all_captions.append(data[1])

In [28]:
print(len(all_captions))

40455


In [None]:
df = pd.Dataframe(columns=['filename', 'caption'])
df['filename'] = get_captions.keys()
df['caption'] = df['filename'].map(lambda filename: get_captions(filename))

vocab_frequency = Counter()
word_tokeniser = get_tokenizer('basic english')

for c in all_captions:
    vocab_frequency.update(word_tokeniser(c))

vocabulary_mapping = torchtext.vocab.vocab(vocab_frequency)
vocabulary_mapping.insert_token('<UNKNOWN>',0)
vocabulary_mapping.insert_token('<PAD>',1)
vocabulary_mapping.insert_token('<START>',2)
vocabulary_mapping.insert_token('<END>',3)
vocabulary_mapping.set_default_index(4)

In [None]:
context_length = 20

class ImageCaptionDataset(Dataset):
    def __init__(self, split):
        pass

    def __len__(self):
        pass

    def __getitem__(self, idx):
        image, captions = self.df.iloc[idx]
        encoded_captions = []
        for i,cap in enumerate(captions):
            splitted = word_tokeniser(cap)

            integers = [vocabulary_mapping[word] for word in splitted]
            integers = [2] + integers + [3]

            if len(integers) <= context_length:
                pads_to_add = context_length - len(integers)
                integers += [1] * pads_to_add
            else:
                integers = integers[:context_length - 1] + [3]
            
            encoded_captions.append(torch.tensor(integers, dtype=torch.long))

        random_idx = torch.randint(len(encoded_captions), (1,)).item()
        return image, encoded_captions[random_idx]
