#### d_model: the number of expected features in the input (required).
#### nhead: the number of heads in the multi-head attention models (required).
#### dropout: the dropout value (default=0.1).
#### activation: the activation function of the intermediate layer, can be a string ("relu" or "gelu") or a unary callable. Default: relu

In [None]:
import torch
import math
from torchvision import models
from torch.nn import MultiheadAttention
from transformers.models.deprecated.transfo_xl.modeling_transfo_xl import PositionwiseFF

In [None]:
# Encoder architecture for image features
class Encoder(torch.nn.Module):
    """Frozen pretrained VGG16 feature extractor followed by a trainable linear projection.

    Produces one ``embed_size``-dimensional vector per input image.
    """

    def __init__(self, embed_size):
        """
        :param embed_size: size of the embedding vector used to represent the image features
        """
        super(Encoder, self).__init__()
        # Load the pretrained VGG16 backbone
        image_encoder = models.vgg16(pretrained=True)
        # Freeze all VGG parameters so we don't backprop through the pre-trained backbone
        for param in image_encoder.parameters():
            param.requires_grad_(False)

        # Keep the whole convolutional stack (including its final max-pool) plus VGG's
        # adaptive average pool, and drop only the classification head. The adaptive pool
        # outputs a fixed [512, 7, 7] map, so the flattened size equals
        # classifier[0].in_features (512 * 7 * 7). The previous `features[:-1]` slice removed
        # the last max-pool and produced [512, 14, 14] for 224x224 inputs, which did not
        # match the linear layer below.
        modules = list(image_encoder.features) + [image_encoder.avgpool]
        self.image_encoder = torch.nn.Sequential(*modules)

        # Project the flattened VGG features to the desired embedding dimension
        self.embed = torch.nn.Linear(image_encoder.classifier[0].in_features, embed_size)

    def forward(self, image):
        """
        :param image: Tensor, shape [batch_size, 3, H, W] (e.g. 224 x 224)
        :return: Tensor, shape [batch_size, embed_size]
        """
        # The backbone is frozen, so no autograd graph is needed for it
        with torch.no_grad():
            features = self.image_encoder(image)

        # Flatten [batch_size, 512, 7, 7] -> [batch_size, 512 * 7 * 7] for the projection
        features = torch.flatten(features, start_dim=1)
        return self.embed(features)
                

In [None]:
# Positional Encoding [credit: https://medium.com/@hunter-j-phillips/positional-encoding-7a93db4109e6]
class PositionalEncoding(torch.nn.Module):
    """Add fixed sinusoidal position encodings to a sequence of embeddings, then apply dropout."""

    def __init__(self, d_model, dropout=0.1, max_len=5000, batch_first=False):
        """
        :param d_model: embedding dimension of the inputs
        :param dropout: dropout probability applied after adding the encodings
        :param max_len: longest sequence length that can be encoded
        :param batch_first: if True, inputs are [batch_size, seq_len, d_model];
            otherwise (default) [seq_len, batch_size, d_model]
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        self.batch_first = batch_first

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # When d_model is odd there is one fewer odd column than even column
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        :param x: Tensor, shape [seq_len, batch_size, d_model] (or [batch_size, seq_len, d_model]
            when constructed with batch_first=True)
        :return: Tensor with the same shape as x
        """
        if self.batch_first:
            # [1, seq_len, d_model] broadcasts over the batch dimension
            x = x + self.pe[:, : x.size(1)]
        else:
            # [seq_len, 1, d_model] broadcasts over the batch dimension; indexing by
            # x.size(1) here would wrongly pick encodings by batch index
            x = x + self.pe[0, : x.size(0)].unsqueeze(1)
        return self.dropout(x)

In [None]:
# Decoder Architecture
class TransformerDecoderLayer(torch.nn.Module):
    """Post-norm Transformer decoder layer: masked self-attention, cross-attention, feed-forward.

    Each sub-layer is applied as ``norm(x + dropout(sublayer(x)))``. Tensors are
    sequence-first ([seq_len, batch_size, d_model]), the default layout of
    ``MultiheadAttention`` as constructed here.
    """

    def __init__(self, d_model: int, n_head: int, d_ffn: int, dropout: float=0.1):
        """
        Decoder Layer Transformer Architecture
        :param d_model: dimension of the vectors flowing through the model (generally it is d_model = embed_size = 512)
        :param n_head: number of heads in the multi-head attention models (generally n_head = 8)
        :param d_ffn: dimension of the feed forward neural network (generally d_ffn = 2048)
        :param dropout: probability that a neuron will be turned off during training (generally dropout = 0.1)
        """
        super(TransformerDecoderLayer, self).__init__()

        # Masked Multi-head Self-Attention and its Layer Normalization
        self.self_attn = MultiheadAttention(d_model, n_head, dropout=dropout)
        self.self_attn_norm = torch.nn.LayerNorm(d_model)

        # Multi-head Cross-Attention (queries from tgt, keys/values from memory) and its Layer Normalization
        self.cross_attn = MultiheadAttention(d_model, n_head, dropout=dropout)
        self.cross_attn_norm = torch.nn.LayerNorm(d_model)

        # Position-wise Feed Forward Neural Network: Linear -> ReLU -> Dropout -> Linear.
        # Built locally instead of using Transfo-XL's PositionwiseFF, because that module
        # already applies its own residual connection and LayerNorm, which duplicated the
        # residual/normalization done in forward().
        self.ffn = torch.nn.Sequential(
            torch.nn.Linear(d_model, d_ffn),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(d_ffn, d_model),
        )
        self.ffn_norm = torch.nn.LayerNorm(d_model)

        # Dropout applied to each sub-layer output before the residual addition
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        """
        Forward pass of the Decoder Layer
        :param tgt: Embedding of the target sequence (shape: [tgt_len, batch_size, d_model])
        :param memory: Embedding of the source sequence (shape: [src_len, batch_size, d_model])
        :param tgt_mask: mask for self-attention over the target (shape: [tgt_len, tgt_len]),
            e.g. a causal mask
        :param memory_mask: mask for cross-attention (shape: [tgt_len, src_len])
        :return: tuple (output, self_attn_weights, cross_attn_weights) where output has shape
            [tgt_len, batch_size, d_model], self_attn_weights [batch_size, tgt_len, tgt_len] and
            cross_attn_weights [batch_size, tgt_len, src_len] (averaged over heads)
        """
        # Masked Multi-head Self-Attention + residual + norm
        attn_out, self_attn_weights = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask)
        tgt = self.self_attn_norm(tgt + self.dropout(attn_out))

        # Multi-head Cross-Attention over the encoder memory + residual + norm
        attn_out, cross_attn_weights = self.cross_attn(tgt, memory, memory, attn_mask=memory_mask)
        tgt = self.cross_attn_norm(tgt + self.dropout(attn_out))

        # Position-wise Feed Forward + residual + norm
        tgt = self.ffn_norm(tgt + self.dropout(self.ffn(tgt)))

        return tgt, self_attn_weights, cross_attn_weights
    
class Decoder(torch.nn.Module):
    """Stack of TransformerDecoderLayer modules followed by a projection to vocabulary logits."""

    def __init__(self, vocab_size, d_model, n_head, d_ffn, num_layers, max_seq_length, dropout=0.1):
        """
        Decoder Transformer Architecture
        :param vocab_size: size of the vocabulary
        :param d_model: dimension of the vectors flowing through the model (generally it is d_model = embed_size = 512)
        :param n_head: number of heads in the multi-head attention models (generally n_head = 8)
        :param d_ffn: dimension of the feed forward neural network (generally d_ffn = 2048)
        :param num_layers: number of decoder layers (generally num_layers = 6)
        :param max_seq_length: maximum sequence length (currently unused)
        :param dropout: probability that a neuron will be turned off during training (generally dropout = 0.1)
        """
        super(Decoder, self).__init__()
        self.layers = torch.nn.ModuleList(
            [TransformerDecoderLayer(d_model, n_head, d_ffn, dropout) for _ in range(num_layers)]
        )
        # NOTE: not used in forward(); kept so existing attribute access keeps working
        self.dropout = torch.nn.Dropout(dropout)

        # NOTE: no positional encoding is applied inside the decoder; any position
        # information must already be present in `tgt`.
        self.output_projection = torch.nn.Linear(d_model, vocab_size)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None):
        """
        Forward pass of the Decoder
        :param tgt: Embedding of the target sequence (shape: [tgt_len, batch_size, d_model])
        :param memory: Embedding of the source sequence (shape: [src_len, batch_size, d_model])
        :param tgt_mask: mask for the target sequence (shape: [tgt_len, tgt_len])
        :param memory_mask: mask for the source sequence (shape: [tgt_len, src_len])
        :return: vocabulary logits (shape: [tgt_len, batch_size, vocab_size])
        """
        for layer in self.layers:
            # A layer returns (output, attention weights, ...); only the output feeds the
            # next layer. Indexing [0] works regardless of how many weights are returned.
            tgt = layer(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)[0]
        return self.output_projection(tgt)

In [None]:
# Image Captioning Model
class ImageClassificationModel(torch.nn.Module):
    """Image captioning model: image encoder + Transformer decoder producing token logits.

    Despite the class name, the output is a per-position vocabulary distribution for a
    caption, not an image class.
    """

    def __init__(self, encoder: Encoder, decoder: Decoder, image_size: int, vocab_size: int, max_seq_length: int,
                 d_model: int=512, n_head: int=8, d_ffn: int=2048, dropout: float=0.1):
        """
        Image Captioning Model
        :param encoder: Encoder model (maps images to [batch_size, d_model])
        :param decoder: Decoder model (built with the same d_model)
        :param image_size: expected input image side length (stored only)
        :param vocab_size: size of the vocabulary
        :param max_seq_length: maximum caption length (stored only)
        :param d_model: embedding dimension; must match the encoder/decoder
        :param n_head: number of attention heads (stored only)
        :param d_ffn: feed-forward dimension (stored only)
        :param dropout: dropout probability (stored only)
        """
        super(ImageClassificationModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.vocab_size = vocab_size
        self.max_seq_length = max_seq_length
        self.d_model = d_model
        self.n_head = n_head
        self.d_ffn = d_ffn
        self.dropout = dropout
        self.image_size = image_size
        # Maps caption token ids to d_model vectors; the decoder works on embeddings
        self.embedding = torch.nn.Embedding(vocab_size, d_model)

    def forward(self, image, tgt, tgt_mask=None):
        """
        Forward pass of the Image Captioning Model
        :param image: Tensor, shape [batch_size, 3, 224, 224]
        :param tgt: token ids, shape [seq_len, batch_size] (integer tensor), or already-embedded
            targets of shape [seq_len, batch_size, d_model] (floating tensor, passed through)
        :param tgt_mask: mask for the target sequence (shape: [seq_len, seq_len]), e.g. causal
        :return: Tensor, shape [seq_len, batch_size, vocab_size]
        """
        # Encode the image; a [batch_size, d_model] result becomes a single memory token
        # [1, batch_size, d_model], matching the decoder's sequence-first layout
        memory = self.encoder(image)
        if memory.dim() == 2:
            memory = memory.unsqueeze(0)

        # Token ids must be embedded before attention; floating inputs are assumed embedded
        if not torch.is_floating_point(tgt):
            tgt = self.embedding(tgt)

        return self.decoder(tgt, memory, tgt_mask=tgt_mask)

    @classmethod
    def create_model(cls, device, vocab_size, max_seq_length, d_model=512, n_head=8, d_ffn=2048, dropout=0.1,
                     num_layers=6):
        """
        Create the Image Captioning Model (callable on the class or on an instance)
        :param device: device to run the model on
        :param vocab_size: size of the vocabulary
        :param max_seq_length: maximum caption length
        :param d_model: dimension of the vectors flowing through the model (generally it is d_model = embed_size = 512)
        :param n_head: number of heads in the multi-head attention models (generally n_head = 8)
        :param d_ffn: dimension of the feed forward neural network (generally d_ffn = 2048)
        :param dropout: probability that a neuron will be turned off during training (generally dropout = 0.1)
        :param num_layers: number of decoder layers (default 6)
        :return: Image Captioning Model moved to `device`
        """
        # The encoder output size must equal the decoder's d_model
        encoder = Encoder(embed_size=d_model)
        decoder = Decoder(vocab_size, d_model, n_head, d_ffn, num_layers=num_layers,
                          max_seq_length=max_seq_length, dropout=dropout)
        model = cls(encoder, decoder, image_size=224, vocab_size=vocab_size, max_seq_length=max_seq_length,
                    d_model=d_model, n_head=n_head, d_ffn=d_ffn, dropout=dropout)
        return model.to(device)
    
        
        
