In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from collections import Counter
from skimage import io, transform

from torchsummary import summary

import matplotlib.pyplot as plt # for plotting
import numpy as np

In [2]:
VOCAB = {}

### Image Transforms

In [3]:
class Rescale(object):
    """Rescale the image in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, output is
            matched to output_size. If int, smaller of image edges is matched
            to output_size keeping aspect ratio the same.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, image):
        h, w = image.shape[:2]
        if isinstance(self.output_size, int):
            if h > w:
                new_h, new_w = self.output_size * h / w, self.output_size
            else:
                new_h, new_w = self.output_size, self.output_size * w / h
        else:
            new_h, new_w = self.output_size

        new_h, new_w = int(new_h), int(new_w)
        img = transform.resize(image, (new_h, new_w))
        return img


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, image):
        # swap color axis because
        # numpy image: H x W x C
        # torch image: C X H X W
        image = image.transpose((2, 0, 1))
        return image


IMAGE_RESIZE = (256, 256)
# Sequentially compose the transforms
img_transform = transforms.Compose([Rescale(IMAGE_RESIZE), ToTensor()])


### Captions Preprocessing

In [4]:
class CaptionsPreprocessing:
    """Preprocess the captions, generate vocabulary and convert words to tensor tokens
    Args:
        captions_file_path (string): captions tsv file path
    """
    def __init__(self, captions_file_path):
        self.captions_file_path = captions_file_path

        # Read raw captions
        self.raw_captions_dict = self.read_raw_captions()

        # Preprocess captions
        self.captions_dict = self.process_captions()

        # Create vocabulary
        self.start = "<start>"
        self.end = "<end>"
        self.oov = "<unk>"
        self.pad = "<pad>"
        self.vocab = self.generate_vocabulary()
        self.word2index = self.convert_word2index()        
        self.index2word = self.convert_index2word()
        self.embed = nn.Embedding(len(self.vocab), embedding_dim)
        self.max_len_caption = 50

    def read_raw_captions(self):
        """
        Returns:
            Dictionary with raw captions list keyed by image ids (integers)
        """
        captions_dict = {}
        with open(self.captions_file_path, 'r', encoding='utf-8') as f:
            for img_caption_line in f.readlines():
                img_captions = img_caption_line.strip().split('\t')
                captions_dict[int(img_captions[0])] = img_captions[1:]

        return captions_dict 

    def process_captions(self):
        """
        Use this function to generate dictionary and other preprocessing on captions
        """

        raw_captions_dict = self.raw_captions_dict 
        
        # Do the preprocessing here                
        captions_dict = raw_captions_dict

        return captions_dict

 

    def generate_vocabulary(self):
        """
        Use this function to generate dictionary and other preprocessing on captions
        """
        captions_dict = self.captions_dict

        # Generate the vocabulary
        
        all_captions = ""        
        for cap_lists in captions_dict.values():
            all_captions += " ".join(cap_lists)
        all_captions = all_captions.lower().replace(".", "").split(" ")
        
        vocab = {self.pad :1, self.oov :1, self.start :1, self.end :1}
        vocab_update = Counter(all_captions) 
        vocab_update = {k:v for k,v in vocab_update.items() if v >= freq_threshold}
        vocab.update(vocab_update)
        vocab_size = len(vocab)   
        VOCAB.update(vocab)
        print("VOCAB SIZE =", vocab_size)
        return vocab
    
    def convert_word2index(self):
        word2index = {}
        vocab = self.vocab
        idx = 0
        for k, v in vocab.items():
            word2index[k] = idx
            idx +=1
        
        return word2index
    
    def convert_index2word(self):
        index2word = {}
        vocab = self.vocab
        idx = 0
        
        for k, v in vocab.items():
            index2word[idx] = k
            idx +=1
        
        return index2word

 

    def captions_transform(self, img_caption_list):
        """
        Use this function to generate tensor tokens for the text captions
        Args:
            img_caption_list: List of captions for a particular image
        """
        word2index = self.word2index
        vocab = self.vocab
        #index2word = self.index2word        
        embed = self.embed
        start = self.start
        end = self.end
        oov = self.oov
        max_len_caption = self.max_len_caption
                
        processed_list = list(map(lambda x: start + " "+ x + " " + end, img_caption_list))
        processed_list = list(map(lambda x: x.lower().replace(".", "").split(" "), processed_list))
        processed_list = list(map(lambda x: list(map(lambda y: word2index[y] if y in vocab else word2index[oov],x)),
                                  processed_list))
        processed_list = list(map(lambda x: x + ( [0] * int(max_len_caption - len(x)) ),processed_list))
        
        # Generate tensors
        processed_list = torch.LongTensor(processed_list)
        processed_captions = embed(processed_list)   
        #print(processed_captions)    
        
        #return torch.zeros(len(img_caption_list), 10)
        return processed_captions


CAPTIONS_FILE_PATH = '../data/train_cap64.tsv'
embedding_dim = 256
freq_threshold = 5
captions_preprocessing_obj = CaptionsPreprocessing(CAPTIONS_FILE_PATH)

VOCAB SIZE = 127


In [5]:
print(len(captions_preprocessing_obj.captions_dict))

64


### Dataset Class

In [6]:
class ImageCaptionsDataset(Dataset):

    def __init__(self, img_dir, captions_dict, img_transform=None, captions_transform=None):
        """
        Args:
            img_dir (string): Directory with all the images.
            captions_dict: Dictionary with captions list keyed by image ids (integers)
            img_transform (callable, optional): Optional transform to be applied
                on the image sample.

            captions_transform: (callable, optional): Optional transform to be applied
                on the caption sample (list).
        """
        self.img_dir = img_dir
        self.captions_dict = captions_dict
        self.img_transform = img_transform
        self.captions_transform = captions_transform

        self.image_ids = list(captions_dict.keys())

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_name = os.path.join(self.img_dir, 'image_{}.jpg'.format(self.image_ids[idx]))
        image = io.imread(img_name)
        captions = self.captions_dict[self.image_ids[idx]]

        if self.img_transform:
            image = self.img_transform(image)

        if self.captions_transform:
            captions = self.captions_transform(captions)

        sample = {'image': image, 'captions': captions}

        return sample

In [7]:
#ENCODER

class ResidualBlock(nn.Module):
    def __init__(self, channels, kernel_size, filters, stride=1):
        """
        Args:
            channels: Int: Number of Input channels to 1st convolutional layer
            kernel_size: integer, Symmetric Conv Window = (kernel_size, kernel_size)
            filters: python list of integers, defining the number of filters in the CONV layers of the main path
            stride: Tuple: (stride, stride)
        """
        super(ResidualBlock, self).__init__()
        F1, F2, F3 = filters
        #N, in_channels , H, W = shape
        kernel_size = (kernel_size, kernel_size)
        padding = (1,1)
        stride = (stride, stride)
        self.conv1 = nn.Conv2d(in_channels = channels, out_channels = F1, kernel_size=(1,1), stride=stride, padding=0)
        self.bn1 = nn.BatchNorm2d(F1)
        self.relu = nn.ReLU(inplace=True) 
        self.conv2 = nn.Conv2d(in_channels = F1, out_channels = F2, kernel_size=kernel_size, stride=stride, padding=padding)
        self.bn2 = nn.BatchNorm2d(F2)
        self.conv3 = nn.Conv2d(in_channels = F2, out_channels = F3, kernel_size=(1,1), stride=stride, padding=0)
        self.bn3 = nn.BatchNorm2d(F3)
        
    def forward(self, x):
        x_residual = x #backup x for residual connection
        
        #stage 1 main path
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        #print("RESI:", x.shape)
        
        #stage 2 main path
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        #print("RESI:", x.shape)
        
        #stage 3 main path
        x = self.conv3(x)
        x = self.bn3(x)
        #print("RESI:", x.shape)
        
        x += x_residual #add output with residual connection
        x = self.relu(x)
        return x
    
class ConvolutionalBlock(nn.Module):
    def __init__(self, channels, kernel_size, filters, stride=1):
        """
        Args:
            channels: Int: Number of Input channels to 1st convolutional layer
            kernel_size: integer, Symmetric Conv Window = (kernel_size, kernel_size)
            filters: python list of integers, defining the number of filters in the CONV layers of the main path
            stride: Tuple: (stride, stride)
        """
        super(ConvolutionalBlock, self).__init__()
        F1, F2, F3 = filters
        kernel_size = (kernel_size, kernel_size)
        padding = (1,1)
        stride = (stride, stride)
        
        self.conv1 = nn.Conv2d(in_channels = channels, out_channels = F1, kernel_size=(1,1), stride=stride, padding=0)
        self.bn1 = nn.BatchNorm2d(F1)
        self.relu = nn.ReLU(inplace=True) 
        self.conv2 = nn.Conv2d(in_channels = F1, out_channels = F2, kernel_size=kernel_size, stride=(1,1), padding=padding)
        self.bn2 = nn.BatchNorm2d(F2)
        self.conv3 = nn.Conv2d(in_channels = F2, out_channels = F3, kernel_size=(1,1), stride=(1,1), padding=0)
        self.bn3 = nn.BatchNorm2d(F3)
        self.conv4 = nn.Conv2d(in_channels = channels, out_channels = F3, kernel_size=(1,1), stride=stride, padding=0)
        self.bn4 = nn.BatchNorm2d(F3)
        
    def forward(self,x):
        x_residual = x #backup x for residual connection
        
        #stage 1 main path
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        #print("CONV:", x.shape)
        
        #stage 2 main path
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        #print("CONV:", x.shape)
        
        #stage 3 main path
        x = self.conv3(x)
        x = self.bn3(x)
        #print("CONV:", x.shape)
        
        #residual connection
        x_residual = self.conv4(x_residual)
        x_residual = self.bn4(x_residual)
        x += x_residual #add output with residual connection
        x = self.relu(x)
        return x
    
class ResNet50(nn.Module):
    def __init__(self, input_shape = (256, 256, 3), classes = 5):
        """
        It Implements Famous Resnet50 Architecture
        Args:
            input_shape(tuple):(callable, optional): dimensions of image sample
            classes(int):(callable, optional): Final output classes of softmax layer.
        """
        super(ResNet50, self).__init__()
        
        self.pad = nn.ZeroPad2d((1, 1, 3, 3))        
        ###STAGE1
        self.conv1 = nn.Conv2d(in_channels = 3, out_channels=64, kernel_size=(7,7), stride = (2,2), padding=1) # convolve each of our 3-channel images with 6 different 5x5 kernels, giving us 6 feature maps
        self.batch_norm1 = nn.BatchNorm2d(64) #BatchNorm
        self.pool1 = nn.MaxPool2d((3,3), stride=(2,2), padding=1, dilation=1)
        
        ###STAGE2 channels, kernel_size=3, filters, stride=1, stage
        self.conv_block1 = ConvolutionalBlock(channels = 64, kernel_size = 3, filters = [64, 64, 256],stride = 1)
        self.residual_block1 = ResidualBlock(channels = 256, kernel_size = 3, filters = [64, 64, 256])
        
        ###STAGE3 
        self.conv_block2 = ConvolutionalBlock(channels = 256, kernel_size = 3, filters = [128, 128, 512],stride = 2)
        self.residual_block2 = ResidualBlock(channels = 512, kernel_size = 3, filters = [128, 128, 512],)
        
        ###STAGE4 
        self.conv_block3 = ConvolutionalBlock(channels = 512, kernel_size = 3, filters = [256, 256, 1024], stride = 2)
        self.residual_block3 = ResidualBlock(channels = 1024, kernel_size = 3, filters = [256, 256, 1024])
        
        ###STAGE5 
        self.conv_block4 = ConvolutionalBlock(channels = 1024, kernel_size = 3, filters = [512, 512, 2048], stride = 2)
        self.residual_block4 = ResidualBlock(channels = 2048, kernel_size = 3, filters = [512, 512, 2048])
        
        self.adaptive_pool = nn.AdaptiveAvgPool2d(output_size = (1,1))
        self.fc1 = nn.Linear(in_features=2048, out_features=classes, bias = True)
        
        
    def forward(self, x):
        print("IP_SIZE:", x.shape)
        
        ###STAGE1        
        #print("\n STAGE1")
        x = self.conv1(x)
        x = self.batch_norm1(x)
        x = self.pool1(x)
        print("OP_STAGE1_SIZE:", x.shape)
        
        ###STAGE2 
        #print("\n STAGE2")
        x = self.conv_block1(x)
        x = self.residual_block1(x)
        x = self.residual_block1(x)
        print("OP_STAGE2_SIZE:", x.shape)
        
        ###STAGE3 
        #print("\n STAGE3")
        x = self.conv_block2(x)
        x = self.residual_block2(x)
        x = self.residual_block2(x)
        x = self.residual_block2(x)
        print("OP_STAGE3_SIZE:", x.shape)
        
        ###STAGE4  
        #print("\n STAGE4")
        x = self.conv_block3(x)
        x = self.residual_block3(x)
        x = self.residual_block3(x)
        x = self.residual_block3(x)
        x = self.residual_block3(x)
        x = self.residual_block3(x)
        print("OP_STAGE4_SIZE:", x.shape)
        
        ###STAGE5  
        #print("\n STAGE5")
        x = self.conv_block4(x)
        x = self.residual_block4(x)
        x = self.residual_block4(x)
        print("OP_STAGE5_SIZE:", x.shape)
        
        x = self.adaptive_pool(x)
        print("OP_ADAPTIVEPOOL_SHAPE", x.shape)
        
        x = x.view(x.size(0), -1) # Flatten Vector
        x = self.fc1(x)
        print("OP_FC1_SIZE:", x.shape)
        return x
        
        
class Encoder(nn.Module):    
    def __init__(self, embed_dim):
        """
        CNN ENCODER
        Args:
            embed_dim(int): embedding dimension ie output dimension of last FC Layer
        Returns:
            x: Feature vector of size(BatchSize, embed_dim)
        """
        super(Encoder, self).__init__()
        self.resnet50 = ResNet50(classes = embed_dim)
        
    def forward(self, x):
        return self.resnet50(x)
    
        
        

In [11]:
#DECODER

class Decoder(nn.Module):
    def __init__(self, embed_dim, lstm_hidden_size, lstm_layers = 1):
        """
        It Implements Famous Resnet50 Architecture
        Args:
            embed_dim(int): embedding dimension ie output dimension of last FC Layer
            lstm_hidden_size(int): size of hidden units of LSTM Cell
            lstm_layers(int, optional): Number of recurrent layers
        """
        super(Decoder, self).__init__()
        self.vocab_size = len(VOCAB)
        self.lstm = nn.LSTM(input_size = embed_dim, hidden_size = lstm_hidden_size,
                            num_layers = lstm_layers, batch_first = True)
        
        self.linear = nn.Linear(lstm_hidden_size, self.vocab_size)
        
    def forward(self, image_features, embedded_captions_list):
        N, I, L, E = embedded_captions_list.shape
        time_instances = L - 1
        print("Decoder Shape 1",image_features.shape, embedded_captions_list.shape)
        
        image_features = torch.Tensor.repeat_interleave(image_features, repeats=5 , dim=0)
        image_features = image_features.unsqueeze(1)
        embedded_captions_list = torch.reshape(embedded_captions_list, (N*I,L,E))
        print("Decoder Shape Processed",image_features.shape, embedded_captions_list.shape)
        
        input_lstm = torch.cat((image_features, embedded_captions_list), dim = 1)
        
        lstm_outputs, _ = self.lstm(input_lstm)        
        print("LSTM OP SHAPE", lstm_outputs.shape)
        lstm_outputs = self.linear(lstm_outputs)
        print("LSTM OP POST LINEAR SHAPE", lstm_outputs.shape)
        return lstm_outputs, input_lstm

In [12]:
class ImageCaptionsNet(nn.Module):
    def __init__(self):
        super(ImageCaptionsNet, self).__init__()

        # Define your architecture here
        
        ##CNN ENCODER RESNET-50
        
        self.Encoder = Encoder(embed_dim = embedding_dim)
        self.Decoder = Decoder(embedding_dim, units, 1)
        #self.Decoder = DecoderRNN(256, 512, len(captions_preprocessing_obj.vocab), 1)
        
        

    def forward(self, x):
        x = image_batch, captions_batch

        # Forward Propogation
        x = self.Encoder(image_batch)
        #print(x.shape)
        x = self.Decoder(x, captions_batch)
        return x
units = 512
net = ImageCaptionsNet()
net = net.double()
# If GPU training is required
# net = net.cuda()

### Training Loop

In [10]:
IMAGE_DIR = '../data/train/'

# Creating the Dataset
train_dataset = ImageCaptionsDataset(
    IMAGE_DIR, captions_preprocessing_obj.captions_dict, img_transform=img_transform,
    captions_transform=captions_preprocessing_obj.captions_transform
)

# Define your hyperparameters
NUMBER_OF_EPOCHS = 3
LEARNING_RATE = 1e-1
BATCH_SIZE = 8
NUM_WORKERS = 0 # Parallel threads for dataloading
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=LEARNING_RATE)

# Creating the DataLoader for batching purposes
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
import os
for epoch in range(NUMBER_OF_EPOCHS):
    for batch_idx, sample in enumerate(train_loader):
        net.zero_grad()

        image_batch, captions_batch = sample['image'], sample['captions']
        
        print("image_shape", image_batch.shape)
        print("batch_shape", captions_batch.shape)
        # If GPU training required
        # image_batch, captions_batch = image_batch.cuda(), captions_batch.cuda()

        output_captions, captions_batch = net((image_batch, captions_batch))
        print("size for loss", output_captions.shape, captions_batch.shape)
        loss = loss_function(output_captions, captions_batch)
        loss.backward()
        optimizer.step()
    print("Iteration: " + str(epoch + 1))

image_shape torch.Size([8, 3, 256, 256])
batch_shape torch.Size([8, 5, 50, 256])
IP_SIZE: torch.Size([8, 3, 256, 256])
OP_STAGE1_SIZE: torch.Size([8, 64, 63, 63])
OP_STAGE2_SIZE: torch.Size([8, 256, 63, 63])
OP_STAGE3_SIZE: torch.Size([8, 512, 32, 32])
OP_STAGE4_SIZE: torch.Size([8, 1024, 16, 16])
OP_STAGE5_SIZE: torch.Size([8, 2048, 8, 8])
OP_ADAPTIVEPOOL_SHAPE torch.Size([8, 2048, 1, 1])
OP_FC1_SIZE: torch.Size([8, 256])
Decoder Shape 1 torch.Size([8, 256]) torch.Size([8, 5, 50, 256])
Decoder Shape Processed torch.Size([40, 1, 256]) torch.Size([40, 50, 256])
LSTM OP SHAPE torch.Size([40, 51, 512])
LSTM OP POST LINEAR SHAPE torch.Size([40, 51, 256])
size for loss torch.Size([40, 51, 256]) torch.Size([40, 51, 256])


ValueError: Expected target size (40, 256), got torch.Size([40, 51, 256])

In [None]:
summary(net, (32,3,256,256))

In [None]:
[["abc"],["bcd"],["asdf"],["fsdfsd"],["fsdfd"]]

### Model Architecture


#### 3 - Building  first ResNet model (50 layers)
You now have the necessary blocks to build a very deep ResNet. The following figure describes in detail the architecture of this neural network. "ID BLOCK" in the diagram stands for "Identity block," and "ID BLOCK x3" means you should stack 3 identity blocks together.


The details of this ResNet-50 model are:
- Zero-padding pads the input with a pad of (3,3)
- Stage 1:
    - The 2D Convolution has 64 filters of shape (7,7) and uses a stride of (2,2). Its name is "conv1".
    - BatchNorm is applied to the 'channels' axis of the input.
    - MaxPooling uses a (3,3) window and a (2,2) stride.
- Stage 2:
    - The convolutional block uses three sets of filters of size [64,64,256], "f" is 3, "s" is 1 and the block is "a".
    - The 2 identity blocks use three sets of filters of size [64,64,256], "f" is 3 and the blocks are "b" and "c".
- Stage 3:
    - The convolutional block uses three sets of filters of size [128,128,512], "f" is 3, "s" is 2 and the block is "a".
    - The 3 identity blocks use three sets of filters of size [128,128,512], "f" is 3 and the blocks are "b", "c" and "d".
- Stage 4:
    - The convolutional block uses three sets of filters of size [256, 256, 1024], "f" is 3, "s" is 2 and the block is "a".
    - The 5 identity blocks use three sets of filters of size [256, 256, 1024], "f" is 3 and the blocks are "b", "c", "d", "e" and "f".
- Stage 5:
    - The convolutional block uses three sets of filters of size [512, 512, 2048], "f" is 3, "s" is 2 and the block is "a".
    - The 2 identity blocks use three sets of filters of size [512, 512, 2048], "f" is 3 and the blocks are "b" and "c".
- The 2D Average Pooling uses a window of shape (2,2) and its name is "avg_pool".
- The 'flatten' layer doesn't have any hyperparameters or name.
- The Fully Connected (Dense) layer reduces its input to the number of classes using a softmax activation. Its name should be `'fc' + str(classes)`.

**Exercise**: Implement the ResNet with 50 layers described in the figure above. We have implemented Stages 1 and 2. Please implement the rest. (The syntax for implementing Stages 3-5 should be quite similar to that of Stage 2.) Make sure you follow the naming convention in the text above. 

You'll need to use this function: 
- Average pooling [see reference](https://keras.io/layers/pooling/#averagepooling2d)

Here are some other functions we used in the code below:
- Conv2D: [See reference](https://keras.io/layers/convolutional/#conv2d)
- BatchNorm: [See reference](https://keras.io/layers/normalization/#batchnormalization) (axis: Integer, the axis that should be normalized (typically the features axis))
- Zero padding: [See reference](https://keras.io/layers/convolutional/#zeropadding2d)
- Max pooling: [See reference](https://keras.io/layers/pooling/#maxpooling2d)
- Fully connected layer: [See reference](https://keras.io/layers/core/#dense)
- Addition: [See reference](https://keras.io/layers/merge/#add)


In [None]:
a = [[1,2,3],[4,5,6]]
print(a.shape)
np.repeat(a,5)

In [17]:
t = torch.tensor([[[1, 2,3],[3, 4,3]],[[5, 6,4],[7, 8,7]]])
print(t)
print(t.shape)

print(t.unsqueeze(1))

t = torch.reshape(t, (4,3))
print(t.shape)
t.dim()

tensor([[[1, 2, 3],
         [3, 4, 3]],

        [[5, 6, 4],
         [7, 8, 7]]])
torch.Size([2, 2, 3])
tensor([[[[1, 2, 3],
          [3, 4, 3]]],


        [[[5, 6, 4],
          [7, 8, 7]]]])
torch.Size([4, 3])


2

In [13]:
x = torch.tensor([[1,2,3],[4,5,6]])
print(x.shape)
x = x.repeat_interleave(5, dim=0)
print(x, x.shape)

torch.Size([2, 3])
tensor([[1, 2, 3],
        [1, 2, 3],
        [1, 2, 3],
        [1, 2, 3],
        [1, 2, 3],
        [4, 5, 6],
        [4, 5, 6],
        [4, 5, 6],
        [4, 5, 6],
        [4, 5, 6]]) torch.Size([10, 3])


In [None]:
x.reshape((5,5))

In [None]:
x.view(-1,5)

In [None]:
# DECODER
"""
class Attention(nn.Module):
    def __init__(self, embed_dim, units):
        super(Attention, self).__init__()        
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(1)
        self.embed_dim = embed_dim
        self.units = units
        self.U = nn.Linear(units, units)
        self.W = nn.Linear(embed_dim, units)
        self.V = nn.Linear(units, 1)
    
    def forward(self, features, hidden_state):        
        W_s = self.W(features)
        U_hidden = self.U(hidden_state).unsqueeze(1)
        attention = self.V(self.tanh(W_s + U_h)).squeeze(2)
        attention = self.softmax(e)
        context = (img_features * attention.unsqueeze(2)).sum(1)
        return context, attention
    
class Decoder(nn.Module):
    def __init__(self, embed_dim, lstm_hidden_size, lstm_layers = 1):
        """'''
        It Implements Famous Resnet50 Architecture
        Args:
            embed_dim(int): embedding dimension ie output dimension of last FC Layer
            lstm_hidden_size(int): size of hidden units of LSTM Cell
            lstm_layers(int, optional): Number of recurrent layers
        '''"""
        super(Decoder, self).__init__()
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        
        self.embed_dim = embed_dim
        self.lstm_hidden_size = lstm_hidden_size
        self.vocab_size = len(VOCAB)
        
        self.hidden_state = nn.Linear(embed_dim, lstm_hidden_size)
        self.cell_state = nn.Linear(embed_dim, lstm_hidden_size)
        
        self.attention = Attention(embed_dim, lstm_hidden_size)
        self.beta = nn.Linear(lstm_hidden_size, embed_dim)        
        
        # lstm cell
        self.lstm = nn.LSTMCell(input_size=embed_dim+lstm_hidden_size, hidden_size=lstm_hidden_size)
        self.model_output = nn.Linear(lstm_hidden_size, self.vocab_size)
        self.dropout = nn.Dropout()
        
    def forward(self, image_features, embedded_captions_list):
        N, I, L, E = embedded_captions_list.shape
        time_instances = L - 1
        print("Decoder Shape 1",image_features.shape, embedded_captions_list.shape)
        
        image_features = torch.Tensor.repeat_interleave(image_features, repeats=5 , dim=0)
        embedded_captions_list = torch.reshape(embedded_captions_list, (N*I,L,E))
        print("Decoder Shape Processed",image_features.shape, embedded_captions_list.shape)
        
        hidden_state, cell_state = self.initialize_lstm_state(image_features)
        
        alphas = torch.zeros(batch_size, max_timespan, image_features.size(1))
        predictions = torch.zeros(N, time_instances, self.vocab_size)
                
        for t in range(time_instances):
            #DO PROCESSING HERE WITH LSTM
            
            context, alpha = self.attention(image_features, hidden_state)
            gated_context = self.sigmoid(self.f_beta(hidden_state)) * context
            
            input_lstm = torch.cat((embedded_captions_list, gated_context), dim=1)
            
            hidden_state, cell_state = self.lstm(input_lstm, (hidden_state, cell_state))
            
            out = self.model_output(self.dropout(hidden_state))
            
            preds[:, t] = output
            alphas[:, t] = alpha
        
        return predictions, alphas, embedded_captions_list
    
    def initialize_lstm_state(self, img_features):
        avg_features = img_features.mean(dim=0)

        c = self.cell_state(avg_features)
        c = self.tanh(c)

        h = self.hidden_state(avg_features)
        h = self.tanh(h)

        return h, c"""

In [None]:
'''class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers=1):
        super(DecoderRNN, self).__init__()
        
        # define the properties
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        
        # lstm cell
        self.lstm_cell = nn.LSTMCell(input_size=embed_size, hidden_size=hidden_size)
    
        # output fully connected layer
        self.fc_out = nn.Linear(in_features=self.hidden_size, out_features=self.vocab_size)
    
        # embedding layer
        self.embed = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.embed_size)
    
        # activations
        self.softmax = nn.Softmax(dim=1)
    
    def forward(self, features, captions):
        
        # batch size
        batch_size = features.size(0)
        
        # init the hidden and cell states to zeros
        hidden_state = torch.zeros((batch_size, self.hidden_size))
        cell_state = torch.zeros((batch_size, self.hidden_size))
    
        # define the output tensor placeholder
        outputs = torch.empty((batch_size, captions.size(1), self.vocab_size))

        # embed the captions
        #captions_embed = self.embed(captions)
        captions_embed = captions
        
        # pass the caption word by word
        for t in range(captions.size(1)):

            # for the first time step the input is the feature vector
            if t == 0:
                hidden_state, cell_state = self.lstm_cell(features, (hidden_state, cell_state))
                
            # for the 2nd+ time step, using teacher forcer
            else:
                hidden_state, cell_state = self.lstm_cell(captions_embed[:, t, :], (hidden_state, cell_state))
            
            # output of the attention mechanism
            out = self.fc_out(hidden_state)
            
            # build the output tensor
            outputs[:, t, :] = out
    
        return outputs'''