In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import albumentations as alb
from albumentations.pytorch import ToTensorV2
import timm
import cv2
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from collections import Counter


In [5]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device

device(type='cuda')

In [None]:
class ImageCaptioner(nn.Module):
    """Caption generator: a CNN image encoder feeding a Transformer decoder.

    The image is encoded by an EfficientNet-B0 from ``timm``, flattened, and
    projected to ``model_dim``. That single vector is used as a length-1
    "memory" sequence which the decoder cross-attends to while it processes
    the caption tokens under a causal mask.

    Args:
        context_length: Number of rows in the positional-embedding table,
            i.e. the longest caption (in tokens) that can be embedded.
        vocab_size: Size of the token vocabulary (embedding rows and output
            logits per position).
        num_blocks: Number of stacked Transformer decoder layers.
        model_dim: Embedding / hidden width used throughout the decoder.
        num_heads: Attention heads per decoder layer.
        dropout_prob: Dropout probability inside each decoder layer.
    """

    def __init__(self, context_length, vocab_size, num_blocks, model_dim, num_heads, dropout_prob):
        super().__init__()
        # NOTE(review): created with timm's default head, so the encoder output
        # is whatever that head emits; passing num_classes=0 would expose the
        # pooled features instead - confirm which one is intended.
        self.cnn_encoder = timm.create_model('efficientnet_b0', pretrained=True)

        # Probe the encoder once with a dummy image to discover its output
        # width. The probe runs in eval mode so the pretrained BatchNorm
        # running statistics are not shifted by an all-zero input; the previous
        # train/eval state is restored afterwards.
        probe_image = torch.zeros(1, 3, 224, 224)
        was_training = self.cnn_encoder.training
        self.cnn_encoder.eval()
        with torch.no_grad():
            cnn_output = self.cnn_encoder(probe_image)
        self.cnn_encoder.train(was_training)
        # Flatten exactly as forward() does so the widths always agree.
        in_features = cnn_output.flatten(1).shape[1]
        self.project = nn.Linear(in_features, model_dim)

        self.word_embeddings = nn.Embedding(vocab_size, model_dim)
        self.pos_embeddings = nn.Embedding(context_length, model_dim)

        block = nn.TransformerDecoderLayer(
            model_dim,
            num_heads,
            2 * model_dim,
            dropout=dropout_prob,
            batch_first=True,
            norm_first=True,
        )
        # Stack `num_blocks` copies of the layer. This must be the
        # TransformerDecoder container, not another TransformerDecoderLayer.
        self.blocks = nn.TransformerDecoder(block, num_blocks)

        self.vocab_projection = nn.Linear(model_dim, vocab_size)

    def forward(self, images, true_labels):
        """Compute next-token logits for every caption position.

        Args:
            images: Batch of images accepted by the CNN encoder; the batch
                size must match ``true_labels``.
            true_labels: Integer token ids of shape ``(B, T)``.

        Returns:
            Tensor of shape ``(B, T, vocab_size)`` with unnormalised logits.
        """
        B, T = true_labels.shape
        # Build helper tensors on the same device as the inputs instead of
        # relying on a notebook-level global.
        run_device = true_labels.device

        tok_embedded = self.word_embeddings(true_labels)
        positions = torch.arange(T, device=run_device)
        pos_embedded = self.pos_embeddings(positions)
        total_embeddings = tok_embedded + pos_embedded  # input to blocks, (B, T, D)

        # The CNN is used as a fixed feature extractor: no gradients flow into it.
        with torch.no_grad():
            image_features = self.cnn_encoder(images).flatten(1)
        # The projection stays OUTSIDE no_grad so that it is actually trained.
        encoded_image = self.project(image_features)

        # Treat the image as a one-element memory sequence: (B, 1, D).
        img_for_attention = encoded_image.unsqueeze(1)

        # Causal/subsequent mask: position t may only attend to positions <= t.
        attention_mask = nn.Transformer.generate_square_subsequent_mask(T).to(run_device)
        block_output = self.blocks(total_embeddings, img_for_attention, tgt_mask=attention_mask)

        vocabulary_vector = self.vocab_projection(block_output)  # B,T,V

        return vocabulary_vector

In [None]:
# Captions file to read; the path is relative to the current working directory.
caption_filename = 'captions.txt'
# NOTE(review): `missing` starts as an empty string; its use is not visible
# in this part of the notebook - confirm what it is meant to accumulate.
missing = ''

# Read the whole file into a list of lines (each line keeps its trailing newline).
with open(caption_filename) as captions:
    lines = captions.readlines()