## "Show and Tell" to perform image captioning for the visually impaired

In [1]:
## Importing necessary libraries

import torch
import torch as tor
import torch.nn as nn
import torchvision as tv

In [2]:
class image_encoder(nn.Module):
    
    """
        Class to encode an image into an information-rich feature map.

        A pretrained ResNet-101 with its last two child modules removed is
        used as the backbone; its output is average-pooled to a fixed
        spatial size and returned channels-last.
    """
    
    def __init__(self,feature_map_size = 14):
        
        """
            Constructor of class: image_encoder.

            feature_map_size: height and width of the pooled feature map.
        """
        
        super().__init__()
        
        self.feature_map_size = feature_map_size
        
        ## load a pretrained resnet-101 model
        ## NOTE: newer torchvision releases prefer the `weights=` argument
        ## over `pretrained=`; kept as-is for compatibility with older ones.
        resnet101 = tv.models.resnet101(pretrained = True)
        
        ## exclude the last two children (global pooling + classification FC)
        encoder_modules = list(resnet101.children())[:-2]
        
        self.res_encoder = nn.Sequential(*encoder_modules)
        
        ## perform adaptive average pooling on the feature map so the output
        ## size does not depend on the input image size
        self.adapt_pool = nn.AdaptiveAvgPool2d((feature_map_size,feature_map_size))
        
    def forward(self,x):
        """
            Method to implement the forward propagation of the encoder system.

            x: batch of images, (batch_size, channels, height, width).

            Returns a tensor of shape
            (batch_size, feature_map_size, feature_map_size, encoder_channels).
        """
        out = self.res_encoder(x)      # (batch_size, C, H', W')
        out = self.adapt_pool(out)     # (batch_size, C, fms, fms)
        
        ## Move channels to the LAST axis: the decoder reads the channel
        ## count from size(-1) and flattens the spatial axes in between.
        ## (The previous permute(0,3,1,2) produced (B, W, C, H).)
        out = out.permute(0,2,3,1)     # (batch_size, fms, fms, C)
        
        return out

    def fine_tune(self, fine_tune=True):
        """
            Enable or disable gradient computation for the backbone.

            All backbone parameters are frozen first; the child modules from
            index 5 onward (the deeper convolutional stages) are then made
            trainable when fine_tune is True. The first children hold
            generic low-level filters and always stay frozen.

            fine_tune: whether the deeper stages should be trainable.
        """
        for param in self.res_encoder.parameters():
            param.requires_grad = False
        
        for child in list(self.res_encoder.children())[5:]:
            for param in child.parameters():
                param.requires_grad = fine_tune

In [3]:
class att_block(nn.Module):
    
    """
        Class to implement the attention block with the encoder framework.

        Scores every encoder location against the current decoder hidden
        state and returns the attention-weighted sum of the encoder features.
    """
    
    def __init__(self,encoder_dim, attention_dim,decoder_dim):
        
        """
            Constructor for the attention block.

            encoder_dim: feature size of each encoder location.
            attention_dim: size of the shared attention space.
            decoder_dim: size of the decoder hidden state.
        """
        
        super().__init__()
        
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)  # project encoder features
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)  # project decoder state
        self.full_att = nn.Linear(attention_dim, 1)  # one score per location
        
        ## forward() relies on these two; they were previously never created,
        ## which raised AttributeError on the first call.
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)  # normalise over the locations axis

    def forward(self,encoder_out, decoder_hidden):
        
        """
            Method to implement the forward propagation of the attention block.

            encoder_out: (batch_size, num_pixels, encoder_dim).
            decoder_hidden: (batch_size, decoder_dim).

            Returns (attention_weighted_encoding, alpha) with shapes
            (batch_size, encoder_dim) and (batch_size, num_pixels).
        """
        
        att_e = self.encoder_att(encoder_out)  # (batch_size, num_pixels, attention_dim)
        att_d = self.decoder_att(decoder_hidden)  # (batch_size, attention_dim)
        
        ## broadcast the decoder projection over every location before scoring
        att_full = self.full_att(self.relu(att_e + att_d.unsqueeze(1))).squeeze(2)  # (batch_size, num_pixels)
        alpha = self.softmax(att_full)  # weights sum to 1 per sample
        attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)  # (batch_size, encoder_dim)

        return attention_weighted_encoding, alpha

In [None]:
class decoder_attention(nn.Module):
    
    """
    Class to implement the complete network with the attention based encoder framework and the decoder pipeline.

    At every time step the decoder attends over the encoder feature map,
    gates the attended features, feeds them together with the previous word
    embedding into an LSTM cell and scores the vocabulary.
    """
    
    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, encoder_dim=2048, dropout=0.5):
        
        """
            Constructor for decoder_attention.

            attention_dim: size of the attention space.
            embed_dim: size of the word embeddings.
            decoder_dim: size of the LSTM hidden/cell state.
            vocab_size: number of words in the vocabulary.
            encoder_dim: feature size of each encoder location.
            dropout: dropout probability applied before the output layer.
        """
        
        super().__init__()

        self.encoder_dim = encoder_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size

        self.attention = att_block(encoder_dim,attention_dim,decoder_dim) 

        self.embedding = nn.Embedding(vocab_size, embed_dim)  
        ## self.dropout is the Dropout module (as before); the raw probability
        ## is no longer stored under the same name first.
        self.dropout = nn.Dropout(p=dropout)
        self.decode_step = nn.LSTMCell(embed_dim + encoder_dim, decoder_dim, bias=True)  
        self.init_h = nn.Linear(encoder_dim, decoder_dim)  # initial hidden state
        self.init_c = nn.Linear(encoder_dim, decoder_dim)  # initial cell state
        self.f_beta = nn.Linear(decoder_dim, encoder_dim)  # gate over attended features
        self.sigmoid = nn.Sigmoid()
        self.fc = nn.Linear(decoder_dim, vocab_size)  # vocabulary scores
        self.init_weights()  

    def init_weights(self):
        
        """
            Initialization of weights.

            Embedding and output weights are drawn uniformly from
            [-0.1, 0.1]; the output bias is zeroed.
        """
        
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def load_pretrained_embeddings(self, embeddings):
        
        """
            Loading the pre-trained embeddings.

            embeddings: tensor that replaces the embedding weight matrix.
        """
        
        self.embedding.weight = nn.Parameter(embeddings)

    def init_hidden_state(self, encoder_out):
        
        """
            Initialization of the hidden state.

            encoder_out: (batch_size, num_pixels, encoder_dim).

            Returns (h, c), each (batch_size, decoder_dim), computed from the
            encoder features averaged over the locations axis.
        """
        
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self.init_h(mean_encoder_out)  # (batch_size, decoder_dim)
        c = self.init_c(mean_encoder_out)
        return h, c

    def forward(self, encoder_out, encoded_captions, caption_lengths):

        """
            Method to implement the forward propagation of the entire network.

            encoder_out: encoder features with the channel axis last,
                (batch_size, ..., encoder_dim).
            encoded_captions: word indices, (batch_size, max_caption_length).
            caption_lengths: caption lengths, (batch_size, 1).

            Returns (predictions, encoded_captions, decode_lengths, alphas,
            sort_ind). The batch axis of the returned tensors follows the
            length-sorted order given by sort_ind.
        """
        
        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        vocab_size = self.vocab_size

        # Flatten image
        encoder_out = encoder_out.view(batch_size, -1, encoder_dim)  # (batch_size, num_pixels, encoder_dim)
        num_pixels = encoder_out.size(1)

        # Sort input data by decreasing lengths, so that the samples still
        # being decoded at step t are always the first batch_size_t rows
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        encoder_out = encoder_out[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        # Embedding
        embeddings = self.embedding(encoded_captions)  # (batch_size, max_caption_length, embed_dim)

        # Initialize LSTM state
        h, c = self.init_hidden_state(encoder_out)  # (batch_size, decoder_dim)

        # One fewer decoding step than tokens per caption
        # (presumably because the last token has no successor to predict)
        decode_lengths = (caption_lengths - 1).tolist()
        max_decode_length = max(decode_lengths)

        # Create tensors to hold word prediction scores and alphas.
        # new_zeros allocates on encoder_out's device/dtype, so no global
        # `torch` or `device` name is needed here.
        predictions = encoder_out.new_zeros(batch_size, max_decode_length, vocab_size)
        alphas = encoder_out.new_zeros(batch_size, max_decode_length, num_pixels)

        for t in range(max_decode_length):
            # number of captions that still have a token to predict at step t
            batch_size_t = sum(length > t for length in decode_lengths)
            attention_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
                                                                h[:batch_size_t])
            gate = self.sigmoid(self.f_beta(h[:batch_size_t]))  # gating scalar, (batch_size_t, encoder_dim)
            attention_weighted_encoding = gate * attention_weighted_encoding
            h, c = self.decode_step(
                torch.cat([embeddings[:batch_size_t, t, :], attention_weighted_encoding], dim=1),
                (h[:batch_size_t], c[:batch_size_t]))  # (batch_size_t, decoder_dim)
            preds = self.fc(self.dropout(h))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            alphas[:batch_size_t, t, :] = alpha

        return predictions, encoded_captions, decode_lengths, alphas, sort_ind

In [None]:
## Initialize data loaders
##
## NOTE(review): this cell cannot run as written and needs confirming/fixing:
##   * PyTorch exposes its data utilities as `torch.utils.data`; there is no
##     `torch.data.utils` namespace.
##   * `dataset(...)` does not name a concrete dataset class; the intended
##     class is not determinable from this notebook -- TODO confirm.
##   * `train_dir` / `val_dir` are not defined in the visible cells.
##   * `trainloader` / `valloader` are not referenced by the visible training
##     cell, which builds its own `train_loader` / `val_loader`.

# Datasets for the training and validation splits
trainset = torch.data.utils.dataset(train_dir)
valset = torch.data.utils.dataset(val_dir)

# Batches of 32 samples, reshuffled every epoch (both splits)
trainloader = torch.data.utils.DataLoader(trainset,batch_size = 32,shuffle = True)
valloader = torch.data.utils.DataLoader(valset,batch_size = 32,shuffle = True)

In [None]:
# Caption decoder, plus an Adam optimizer over its trainable parameters only.
decoder = decoder_attention(
    attention_dim=attention_dim,
    embed_dim=emb_dim,
    decoder_dim=decoder_dim,
    vocab_size=len(word_map),
    dropout=dropout,
)
decoder_optimizer = torch.optim.Adam(
    params=(p for p in decoder.parameters() if p.requires_grad),
    lr=decoder_lr,
)

# Image encoder; it only gets an optimizer when it is being fine-tuned.
encoder = image_encoder()
encoder.fine_tune(fine_tune_encoder)
if fine_tune_encoder:
    encoder_optimizer = torch.optim.Adam(
        params=(p for p in encoder.parameters() if p.requires_grad),
        lr=encoder_lr,
    )
else:
    encoder_optimizer = None

In [None]:
# Move both networks to the target device.
decoder = decoder.to(device)
encoder = encoder.to(device)

# Loss function
criterion = nn.CrossEntropyLoss().to(device)

# Custom dataloaders
# `transforms` is never imported in this notebook, so the transform classes
# are reached through the existing `tv` (torchvision) alias instead.
# The mean/std values are the ImageNet channel statistics that torchvision's
# pretrained models expect.
normalize = tv.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

train_loader = torch.utils.data.DataLoader(
    CaptionDataset(data_folder, data_name, 'TRAIN', transform=tv.transforms.Compose([normalize])),
    batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)

val_loader = torch.utils.data.DataLoader(
    CaptionDataset(data_folder, data_name, 'VAL', transform=tv.transforms.Compose([normalize])),
    batch_size=batch_size, shuffle=True, num_workers=workers, pin_memory=True)

# Epochs
# Main loop: one training pass and one validation pass per epoch, starting
# at `start_epoch` and running up to (not including) `epochs`.
# NOTE(review): in the lines visible here `epochs_since_improvement` is never
# updated and `recent_bleu4` is never used -- confirm that the bookkeeping
# (best-score tracking, checkpointing) exists further down.
for epoch in range(start_epoch, epochs):

    # Decay learning rate if there is no improvement for 8 consecutive epochs, and terminate training after 20
    if epochs_since_improvement == 20:
        break
    # Fires at every non-zero multiple of 8 stalled epochs; 0.8 is passed to
    # the helper (presumably a multiplicative shrink factor -- verify against
    # adjust_learning_rate).
    if epochs_since_improvement > 0 and epochs_since_improvement % 8 == 0:
        adjust_learning_rate(decoder_optimizer, 0.8)
        # the encoder optimizer is None unless the encoder is being fine-tuned
        if fine_tune_encoder:
            adjust_learning_rate(encoder_optimizer, 0.8)

    # One epoch's training
    train(train_loader=train_loader,
          encoder=encoder,
          decoder=decoder,
          criterion=criterion,
          encoder_optimizer=encoder_optimizer,
          decoder_optimizer=decoder_optimizer,
          epoch=epoch)

    # One epoch's validation; the returned value is kept as this epoch's score
    # (presumably BLEU-4, going by the variable name)
    recent_bleu4 = validate(val_loader=val_loader,
                            encoder=encoder,
                            decoder=decoder,
                            criterion=criterion)