- [Decoder for Text Generation](https://machinelearningmastery.com/beam-search-decoder-natural-language-processing/)
- [Beam search using heap](https://geekyisawesome.blogspot.com/2016/10/using-beam-search-to-generate-most.html)
- [BeamSearch](https://yashk2810.github.io/Image-Captioning-using-InceptionV3-and-Beam-Search/)

# MobileNetV2

In [48]:
"""
Creates a MobileNetV2 model as defined in the paper: M. Sandler, 
A. Howard, M. Zhu, A. Zhmoginov, L.-C. Chen. "MobileNetV2: Inverted 
Residuals and Linear Bottlenecks.", arXiv:1801.04381, 2018."

Code reference: https://github.com/tonylins/pytorch-mobilenet-v2
ImageNet pretrained weights: https://drive.google.com/file/d/1jlto6HRVD3ipNkAl1lNhDbkBp7HylaqR
"""
import math
import torch
import torch.nn as nn



def conv_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        interverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        assert input_size % 32 == 0
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
        self.features = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                if i == 0:
                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
                else:
                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, n_class),
        )

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = x.mean(3).mean(2)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
                
def MobileNet(pretrained=True, **kwargs):
    """
    Constructs a MobileNet V2 model.
    
    Parameters
    ----------
    pretrained: bool, use ImageNet pretrained model or not.
    n_class: int, 1000 classes in ImageNet data.
    weight_file: str, path to pretrained weights
    """
    weight_file = kwargs.pop('weight_file', '')
    model = MobileNetV2(**kwargs)
    if pretrained:
        state_dict = torch.load(weight_file)
        model.load_state_dict(state_dict)
    return model

# Load weights pretrained on ImageNet data using function
model = MobileNet(pretrained=True, n_class=1000, weight_file='./mobilenet_v2.pth.tar')

# Encoder

In [49]:
class EncoderCNN(nn.Module):
    """
    Convolutional Neural Network (MobileNetV2) that encodes input image 
    into encoded feature representations.
    """
    def __init__(self, weight_file, feature_size=14, tune_layer=None, finetune=False):
        """
        Parameters
        ----------
        weight_file: str, path to MobileNetV2 pretrained weights.
        feature_size: int, encoded feature map size to be used.
        tune_layer: int, tune layers from this layer onwards. For
            MobileNetV2 select integer from 0 (early) to 18 (final)
        finetune: bool, fine tune layers
        """
        super(EncoderCNN, self).__init__()
        self.weight_file = weight_file
        self.feature_size = feature_size
        self.tune_layer = tune_layer
        self.finetune = finetune
        self.pretrained = True
        
        # MobileNetV2 pretrained on ImageNet
        cnn = MobileNet(pretrained=self.pretrained, weight_file=self.weight_file)
        
        # Remove classification layer
        modules = list(cnn.children())[:-1]
        self.cnn = nn.Sequential(*modules)
        
        # Resize feature maps to fixed size to allow input images of variable size
        self.adaptive_pool = nn.AdaptiveAvgPool2d((self.feature_size, self.feature_size))
        
        # Fine-tune
        self.fine_tune()
        
    def forward(self, images):
        """
        Parameters
        ----------
        images: PyTorch tensor, size: [M, 3, H, W]
        """
        features = self.cnn(images) # size: [M, 1280, H/32, W/32]
        features = self.adaptive_pool(features) # size: [M, 1280, fs, fs]
        features = features.permute(0, 2, 3, 1) # size: [M, fs, fs, 1280]
        return features
    
    def fine_tune(self):
        """
        Fine-tuning CNN.
        """
        # Disable gradient computation
        for param in self.cnn.parameters():
            param.requires_grad = False
            
        # Enable gradient computation for few layers
        for child in list(self.cnn.children())[0][self.tune_layer:]:
            for param in child.parameters():
                param.requires_grad = self.finetune

# Attention Mechanism

In [50]:
class AttentionMechanism(nn.Module):
    """
    Attention Mechanism.
    """
    def __init__(self, encoder_size, decoder_size, attention_size):
        """
        Parameters
        ----------
        encoder_size: int, number of channels in encoder CNN output feature
            map (for MobileNetV2 it is 1280)
        decoder_size: int, number of features in the hidden state, i.e. LSTM 
            output size
        attention_size: int, size of MLP used to compute attention scores
        """
        super(AttentionMechanism, self).__init__()
        self.encoder_size = encoder_size
        self.decoder_size = decoder_size
        self.attention_size = attention_size
        
        # Linear layer to transform encoded features to attention size
        self.encoder_attn = nn.Linear(in_features=self.encoder_size, 
                                      out_features=self.attention_size)
        
        # Linear layer to transform decoders (hidden state) output to attention size
        self.decoder_attn = nn.Linear(in_features=self.decoder_size, 
                                      out_features=self.attention_size)
        
        # ReLU non-linearity
        self.relu = nn.ReLU()
        
        # Linear layer to compute attention scores at time t for L locations
        self.fc_attn = nn.Linear(in_features=self.attention_size, out_features=1)
        
        # Softmax layer to compute attention weights
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, encoder_out, decoder_out):
        """
        Parameters
        ----------
        encoder_out: PyTorch tensor, size: [M, L, D] where, L is feature
            map locations, and D is channels of encoded CNN feature map.
        decoder_out: PyTorch tensor, size: [M, h], where h is hidden
            dimension of the previous step output from decoder
            
        NOTE: M is batch size. k is attention size (see comments)
        
        Returns
        -------
        attn_weighted_encoding: PyTorch tensor, size: [M, D], attention weighted 
            annotation vector
        alpha: PyTorch tensor, size: [M, L], attention weights 
        """
        enc_attn = self.encoder_attn(encoder_out)  # size: [M, L, k]
        dec_attn = self.decoder_attn(decoder_out)  # size: [M, k]
        
        enc_dec_sum = enc_attn + dec_attn.unsqueeze(1)  # size: [M, L, k]
        
        # Compute attention scores for L locations at time t (Paper eq: 4)
        attn_scores = self.fc_attn(self.relu(enc_dec_sum))  # size: [M, L]
        
        # Compute for each location the probability that location i is the right 
        # place to focus for generating next word (Paper eq: 5)
        alpha = self.softmax(attn_scores.squeeze(2))  # size: [M, L]
        
        # Compute attention weighted annotation vector (Paper eq: 6)
        attn_weighted_encoding = torch.sum(encoder_out * alpha.unsqueeze(2), dim=1) # size: [M, D]
        
        return attn_weighted_encoding, alpha

# Decoder

In [51]:
class DecoderAttentionRNN(nn.Module):
    """
    RNN (LSTM) decoder to decode encoded images and generate sequences.
    """
    def __init__(self, encoder_size, decoder_size, attention_size, embedding_size, vocab_size, dropout_prob=0.5):
        """
        encoder_size: int, number of channels in encoder CNN output feature
            map (for MobileNetV2 it is 1280)
        decoder_size: int, number of features in the hidden state, i.e. LSTM 
            output size
        attention_size: int, size of MLP used to compute attention scores
        embedding_size: int, size of embedding
        vocab_size: int, vocabulary size
        dropout: float, dropout probability
        """
        super(DecoderAttentionRNN, self).__init__()
        self.encoder_size = encoder_size
        self.decoder_size = decoder_size
        self.attention_size = attention_size
        self.embedding_size = embedding_size
        self.vocab_size = vocab_size
        self.dropout_prob = dropout_prob
        
        # Create attention mechanism
        self.attention = AttentionMechanism(self.encoder_size, self.decoder_size, self.attention_size)
        
        # Create embedding layer
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)  # size: [V, E]
        
        # Create dropout module
        self.dropout = nn.Dropout(p=self.dropout_prob)
        
        # Create LSTM cell (uses for loop) for decoding
        self.rnn = nn.LSTMCell(input_size=self.embedding_size + self.encoder_size, 
                               hidden_size=self.decoder_size, bias=True)
        
        # MLPs for LSTM cell's initial states
        self.init_h = nn.Linear(self.encoder_size, self.decoder_size)
        self.init_c = nn.Linear(self.encoder_size, self.decoder_size)
        
        # MLP to compute beta (gating scalar, paper section 4.2.1)
        self.f_beta = nn.Linear(self.decoder_size, 1) # scalar
        
        # Sigmoid to compute beta
        self.sigmoid = nn.Sigmoid()
        
        # FC layer to compute scores over vocabulary
        self.fc = nn.Linear(self.decoder_size, self.vocab_size)
        
    def init_lstm_states(self, encoder_out):
        """
        Initialize LSTM's initial hidden and cell memory states based on encoded
        feature representation. NOTE: Encoded feature map locations mean is used.
        """
        # Compute mean of encoder output locations
        mean_encoder_out = torch.mean(encoder_out, dim=1)  # size: [M, L, D] -> [M, D]
        
        # Initialize LSTMs hidden state
        h0 = self.init_h(mean_encoder_out)  # size: [M, h]
        
        # Initialize LSTMs cell memory state
        c0 = self.init_c(mean_encoder_out)  # size: [M, h]
        
        return h0, c0
    
    def forward(self, encoder_out, encoded_captions, caption_lengths):
        """
        Parameters
        ----------
        encoder_out: PyTorch tensor, size: [M, fs, fs, D] where, fs is feature
            map size, and D is channels of encoded CNN feature map.
        encoded_captions: PyTorch long tensor
        caption_lengths: PyTorch tensor
        """
        batch_size = encoder_out.size(0)
        
        # Flatten encoded feature maps from size [M, fs, fs, D] to [M, L, D]
        encoder_out = encoder_out.view(batch_size, -1, self.encoder_size)
        num_locations = encoder_out.size(1)
        
        # Sort caption lengths in descending order
        caption_lengths, sorted_idx = torch.sort(caption_lengths.squeeze(1), dim=0, 
                                                 descending=True)
        
        # Compute decode lengths to decode. Sequence generation ends when <END> token
        # is generated. A typical caption is [<START>, ..., <END>, <PAD>, <PAD>], caption
        # lengths only considers [<START>, ..., <END>], so when <END> is generated there
        # is no need to decode further. Decode lengths = caption lengths - 1
        decode_lengths = (caption_lengths - 1).tolist()
        
        # Sort encoded feature maps and captions as per caption lengths. REASON: Since a 
        # batch contains different caption lengths (and decode lengths). At each time step 
        # up to max decode length T in a batch we need to apply attention mechanism to only 
        # those images in batch whose decode length is greater than current time step
        encoder_out = encoder_out[sorted_idx]
        encoded_captions = encoded_captions[sorted_idx]
        
        # Get embeddings for encoded captions
        embeddings = self.embedding(encoded_captions) # size: [M, T, E]
        
        # Initialize LSTM's states
        h, c = self.init_lstm_states(encoder_out) # sizes: [M, h], [M, h]
        
        # Compute max decode length
        T = int(max(decode_lengths))
        
        # Create placeholders to store predicted scores and alphas (alphas for doubly stochastic attention)
        pred_scores = torch.zeros(batch_size, T, self.vocab_size) # size: [M, T, V]
        alphas = torch.zeros(batch_size, T, num_locations) # size: [M, T, L]
        
        # Decoding step: (1) Compute attention weighted encoding and attention weights
        # using encoder output, and initial hidden state; (2) Generate a new encoded word
        for t in range(T):
            # Compute batch size at step t (At step t how many decoding lengths are greater than t)
            batch_size_t = sum([dl > t for dl in decode_lengths])
            
            # Encoder output and encoded captions are already sorted by caption lengths
            # in descending order. So based on the number of decoding lengths that are 
            # greater than current t, extract data from encoded output and initial hidden state
            # as input to attention mechanism. 
            attn_weighted_encoding, alpha = self.attention(encoder_out[:batch_size_t],
                                                           h[:batch_size_t])
                        
            # Compute gating scalar beta (paper section: 4.2.1)
            beta_t = self.sigmoid(self.f_beta(h[:batch_size_t])) # size: [M, 1]
                        
            # Multiply gating scalar beta to attention weighted encoding
            context_vector = beta_t * attn_weighted_encoding  # size: [M, D]
                        
            # Concatenate embeddings and context vector, size: [M, E] and [M, D] -> [M, E+D]
            concat_input = torch.cat([embeddings[:batch_size_t, t, :], context_vector], dim=1) # size: [M, E+D]
                        
            # LSTM input states from time step t-1
            previous_states = (h[:batch_size_t], c[:batch_size_t])
                        
            # Generate decoded word
            h, c = self.rnn(concat_input, previous_states)
            
            # Compute scores over vocabulary
            scores = self.fc(self.dropout(h)) # size: [M, V]
            
            # Populate placeholders for predicted scores and alphas
            pred_scores[:batch_size_t, t, :] = scores
            alphas[:batch_size_t, t, :] = alpha # alpha size: [M, L]
            
        return pred_scores, encoded_captions, decode_lengths, alphas, sorted_idx

# Helpers

In [52]:
import json

def read_json(json_path):
    with open(json_path, 'r') as j:
        json_data = json.load(j)
    return json_data

# Inference w/ Beam Search Version-1

In [53]:
import os
from queue import PriorityQueue

import numpy as np

import imageio
from skimage.transform import resize

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

import operator
import matplotlib.pyplot as plt

- https://github.com/DeepRNN/image_captioning/blob/master/base_model.py
- https://github.com/budzianowski/PyTorch-Beam-Search-Decoding

In [20]:
class BeamSearchNode(object):
    """
    Modified from: https://github.com/budzianowski/PyTorch-Beam-Search-Decoding 
    to include attention weights.
    """
    def __init__(self, lstm_states, alpha, prev_node, word_idx, log_prob, length):
        """
        lstm_states: tuple of PyTorch tensor output from LSTM decoder
        alpha: PyTorch tensor, attention weights
        prev_node: PriorityQueue object
        word_idx: int, word index in word2idx dictionary
        log_prob: float, log of softmax value
        length: int, length of sequence so far
        """
        self.lstm_states = lstm_states
        self.alpha = alpha
        self.prev_node = prev_node
        self.word_idx = word_idx
        self.log_prob = log_prob
        self.length = length
        
    def eval(self, gamma=1.0):
        """
        Compute priority number for the node
        """
        reward = 0
        return self.log_prob / float(self.length - 1 + 1e-6) + gamma * reward


class GenerateCaption(object):
    
    def __init__(self, config, beam_width=3):
        self.config = config
        self.beam_width = beam_width
        self.word2idx = self.read_json(self.config.word2idx_file)
        self.idx2word = {idx:word for word, idx in self.word2idx.items()}
        
        # Encoder
        self.encoder = EncoderCNN(weight_file=self.config.encoder_path)
        
        # Decoder encoder_size, decoder_size, attention_size, embedding_size, vocab_size
        decoder = DecoderAttentionRNN(encoder_size=self.config.encoder_size, 
                                      decoder_size=self.config.decoder_size, 
                                      attention_size=self.config.attention_size, 
                                      embedding_size=self.config.embedding_size, 
                                      vocab_size=len(self.word2idx))
        
        decoder.load_state_dict(torch.load(self.config.decoder_path))
        self.decoder = decoder
        
    # Helper Methods        
    def read_json(self, file):
        with open(file, 'r') as f:
            data = json.load(f)
        return data
    
    def read_preprocess_image(self, img_path):
        img = imageio.imread(img_path)
        
        # If image is gray scale then add channels
        if len(img.shape) == 2:
            img = img[:, :, np.newaxis]
            img = np.concatenate([img, img, img], axis=2)
            
        # Resize image
        img_resize = resize(img, (224, 224), mode='constant', anti_aliasing=True)
        img_resize = img_resize.transpose(2, 0, 1)  # PyTorch: [C, W, H]
        
        # Image tensor
        img = torch.FloatTensor(img_resize)

        # Normalize image
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
        
        transform = transforms.Compose([normalize])
        img = transform(img)  # size: [3, 224, 224]
        
        img = img.unsqueeze(0)  # size: [1, 3, 224, 224]
        
        return img
    
    def decode(self, encoder_output, h_prev, c_prev, embedding_t):
        
        # Attention weighted encoding and alpha
        attn_wtd_encoding, alpha = self.decoder.attention(encoder_output, h_prev)
        
        # Gating scalar beta
        beta_t = self.decoder.sigmoid(self.decoder.f_beta(h_prev))
        
        # Context vector
        context_vector = beta_t * attn_wtd_encoding
        
        # Concatenate init embedding with context vector
        concat_input = torch.cat([embedding_t, context_vector], dim=1) # size: [1, 1536]
        
        # Run RNN and compute scores
        h, c = self.decoder.rnn(concat_input, (h_prev, c_prev))
        scores = self.decoder.fc(h)  # size: [1, 9490]
        
        # Compute Log Softmax of scores
        log_probs = F.log_softmax(scores, dim=1)
        
        return h, c, log_probs, alpha
        
    def beam_search(self, img_path):
        """
        Modified from: https://github.com/budzianowski/PyTorch-Beam-Search-Decoding 
        to include attention weights for Show, Attend and Tell.
        """
        # Read input image
        img = self.read_preprocess_image(img_path)
        
        # Encode input image
        encoder_output = self.encoder(img)  # size: [1, 14, 14, 1280]
        encoder_output = encoder_output.view(1, -1, self.config.encoder_size)  # size: [1, 196, 1280]
        num_locations = encoder_output.size(1)  # 196
        
        # LSTM initial hidden states
        h, c = self.decoder.init_lstm_states(encoder_output)
        lstm_states = (h, c)
        
        # Decoder input: start with <START> token
        start_idx = torch.LongTensor([self.word2idx['<START>']])
        embedding_t = self.decoder.embedding(start_idx) # size: [1, 256]
        
        # 
        top_k = 1  # Number of sentences to generate
        end_nodes = []
        num_generate = min((top_k + 1), top_k - len(end_nodes))
        
        # Create starting node
        node = BeamSearchNode(lstm_states=lstm_states, alpha=None, prev_node=None, 
                              word_idx=start_idx, log_prob=0, length=1)
        
        # Create Queue that retrieves open entries in priority order (lowest first)
        nodes = PriorityQueue()
        
        # Start the queue
        nodes.put((-node.eval(), node))
        q_size = 1
        
        # Start beam search
        while True:
            
            # Give up when decoding takes too long
            if q_size > 100:
                break
                
            # Fetch the best node
            priority_number, best_node = nodes.get()
            
            h, c = best_node.lstm_states
            word_idx = best_node.word_idx
            embedding_t = self.decoder.embedding(word_idx)
            
            if best_node.word_idx.item() == self.word2idx['<END>'] and best_node.prev_node != None:
                end_nodes.append(priority_number, best_node)
                
                # If we reach maximum number of sentences to generate
                if len(end_nodes) >= num_generate:
                    break
                else:
                    continue
                    
            # Decode step
            h, c, log_probs, alpha = self.decode(encoder_output, h, c, embedding_t)
            k_log_probs, indices = torch.topk(log_probs, k=self.beam_width)
            
            next_nodes = []
            for b in range(self.beam_width):
                decoded_t = indices[0][b].view(1)
                log_prob = k_log_probs[0][b].item()
                node = BeamSearchNode(lstm_states=(h, c), alpha=alpha, prev_node=best_node, 
                                      word_idx=decoded_t, log_prob=best_node.log_prob + log_prob, 
                                      length=best_node.length + 1)
                priority_number = -node.eval()
                next_nodes.append((priority_number, node))
                
            # Put next nodes into queue
            for n in range(len(next_nodes)):
                priority_number, next_node = next_nodes[n]
                nodes.put((priority_number, next_node))
                
            # Increase queue size
            q_size += len(next_nodes) - 1
            
        # Choose beam width best paths and back trace
        if len(end_nodes) == 0:
            end_nodes = [nodes.get() for _ in range(top_k)]
            
        captions = []
        alphas = []
        for priority_number, best_node in sorted(end_nodes, key=operator.itemgetter(0)):
            _caption = []
            _alpha = []
            _caption.append(best_node.word_idx)
            _alpha.append(best_node.alpha)
            # Back trace
            while best_node.prev_node != None:
                best_node = best_node.prev_node
                _caption.append(best_node.word_idx)
                _alpha.append(best_node.alpha)
            _caption = _caption[::-1]
            _alpha = _alpha[::-1]
            captions.append(_caption)
            alphas.append(_alpha)
            
        caption = [c.item() for c in captions[0]]
        caption = [self.idx2word[idx] for idx in caption]
            
        return caption, [a for a in alphas[0]]
    
class Config(object):
    def __init__(self):
        # Encoder parameters
        # ------------------
        self.encoder_path = '/home/ankoor/caption/mobilenet_v2.pth.tar'

        # Decoder parameters
        # ------------------
        self.decoder_path = '/home/ankoor/caption/checkpoints/DecoderAttentionLSTM.pth'
        self.encoder_size = 1280  # MobileNetV2 output channels (do not change!) 2048 for ResNet
        self.decoder_size = 512  # LSTM output size (hidden state vector size)
        self.attention_size = 512  # Size of MLP used to compute attention scores
        self.embedding_size = 256  # Word embedding size
        self.dropout_prob = 0.5

        # Word to index mapping
        # ---------------------
        self.word2idx_file = './WORD2IDX_COCO.json'

In [21]:
config = Config()
captioner = GenerateCaption(config)

caption, alphas = captioner.beam_search('./test.jpg')
print(' '.join(caption))

<START> credit furred furred furred streak streak stemware googles furred furred furred furred streak streak kisses furred furred streak streak kisses furred furred streak streak kisses furred furred streak streak kisses furred furred streak streak kisses furred furred furred streak streak kisses furred furred furred streak streak


### Scratch

In [None]:
class BeamSearchNode(object):
    """
    Ref: https://github.com/budzianowski/PyTorch-Beam-Search-Decoding
    """
    def __init__(self, hidden_state, prev_node, word_idx, log_prob, length):
        """
        hidden_state: tuple of PyTorch tensor or PyTorch tensor, output from 
            LSTM decoder
        prev_node: PriorityQueue object
        word_idx: int, word index in word2idx dictionary
        log_prob: float, log of softmax value
        length: int, length of sequence?
        """
        self.h = hidden_state
        self.prev_node = prev_node
        self.word_idx = word_idx
        self.log_prob = log_prob
        self.length = length
        
    def eval(self, alpha=1.0):
        """
        Length normalization?
        """
        reward = 0
        return self.log_prob / float(self.length - 1 + 1e-6) + alpha * reward

In [None]:
beam_width = 3  # Beam width

# Read word-index mapping
word2idx = read_json('./WORD2IDX_COCO.json')
idx2word = {idx: word for word, idx in word2idx.items()}

In [None]:
# Prepare input image
img_path = './example.jpg'
img = imageio.imread(img_path)

# If image is gray scale then add channels
if len(img.shape) == 2:
    img = img[:, :, np.newaxis]
    img = np.concatenate([img, img, img], axis=2)
    
# Resize image and return it
img_raw = resize(img, (224, 224), mode='constant', anti_aliasing=True)
img = img_raw.transpose(2, 0, 1)  # PyTorch: [C, W, H]

# Image tensor
img = torch.FloatTensor(img)

# Normalize image
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225])

transform = transforms.Compose([normalize])
IMG = transform(img)  # size: [3, 224, 224]

IMG = IMG.unsqueeze(0)  # size: [1, 3, 224, 224]
plt.imshow(img_raw)

In [None]:
# Load Encoder
weight_path = './mobilenet_v2.pth.tar'
ENCODER = EncoderCNN(weight_file=weight_path)

In [None]:
# Load Decoder
encoder_size = 1280  
decoder_size = 512
attention_size = 512
embedding_size = 256
vocab_size = len(word2idx)

weight_file = './checkpoints/DecoderAttentionLSTM.pth'
DECODER = DecoderAttentionRNN(encoder_size, decoder_size, attention_size, embedding_size, vocab_size)
decoder_state_dict = torch.load(weight_file)
DECODER.load_state_dict(decoder_state_dict)

In [None]:
# Encode input image
encoder_out = ENCODER(IMG)  # size: [1, 14, 14, 1280]
encoder_size = encoder_out.size(-1)
encoder_out = encoder_out.view(1, -1, encoder_size)  # size: [1, 196, 1280]
num_locations = encoder_out.size(1)  # 196

In [None]:
# LSTM initial hidden states
h, c = DECODER.init_lstm_states(encoder_out)
print('h shape: ', h.shape)
print('c shape: ', c.shape)

In [None]:
# Attention weighted encoding and alpha
attn_wtd_encoding, alpha = DECODER.attention(encoder_out, h)
print('attn_wtd_encoding shape: ', attn_wtd_encoding.shape)
print('alpha shape: ', alpha.shape)

In [None]:
# Gating scalar beta
beta_t = DECODER.sigmoid(DECODER.f_beta(h))
print('beta_t: ', beta_t)

# Context vector
context_vector = beta_t * attn_wtd_encoding
print('context vector shape: ', context_vector.shape)

In [None]:
# Decoder input: start with <START> token
start_idx = torch.LongTensor([word2idx['<START>']])
init_embedding = DECODER.embedding(start_idx) # size: [1, 256]

# Concatenate init embedding with context vector
concat_input = torch.cat([init_embedding, context_vector], dim=1) # size: [1, 1536]

In [None]:
# Run RNN and compute scores
h, c = DECODER.rnn(concat_input, (h, c))
scores = DECODER.fc(h)  # size: [1, 9490]

In [None]:
# Compute Log Softmax of scores
log_probs = F.log_softmax(scores, dim=1)

In [None]:
decoder_hidden = (h, c)

node = BeamSearchNode(hidden_state=decoder_hidden, 
                      prev_node=None,
                      word_idx=start_idx,
                      log_prob=0, 
                      length=1)

nodes = PriorityQueue()

# Start the queue
nodes.put((-node.eval(), node))
q_size = 1

In [None]:
# Fetch the best node
s, n = nodes.get()

In [None]:
print(s, n)

In [None]:
end_nodes = []
topK = 1
num_required = min((topK + 1), topK - len(end_nodes))
print('num_required: ', num_required)
init_idx = n.word_idx
init_h = n.h[0]

if n.word_idx.item() == word2idx['<START>'] and n.prev_node != None:
    end_nodes.append((s, n))
    
#     if len(end_nodes) >= num_required:
#         break
#     else:
#         continue


In [None]:
h, c = DECODER.rnn(concat_input, (h, c))
scores = DECODER.fc(h)
log_probs = F.log_softmax(scores, dim=1)

In [None]:
log_prob, indices = torch.topk(log_probs, k=beam_width)
next_nodes = []
print(indices.shape)
print(log_prob.shape)

In [None]:
indices[0], log_prob[0]

In [None]:
for i in range(beam_width):
    decoded_t = indices[0][i].view(1)
    print(decoded_t)
    log_p = log_prob[0][i].item()
    node = BeamSearchNode((h, c), n, decoded_t, n.log_prob + log_p, n.length + 1)
    s = -node.eval()
    next_nodes.append((s, node))

In [None]:
indices[0][i].view(1, -1).shape

In [None]:
start_idx.shape

In [None]:
# Put nodes in queue
for i in range(len(next_nodes)):
    s, nn = next_nodes[i]
    nodes.put((s, nn))

In [None]:
q_size += len(next_nodes) - 1

In [None]:
if len(end_nodes) == 0:
    end_nodes = [nodes.get() for _ in range(topK)]

In [None]:
caption = []

In [None]:
for lp, n in sorted(end_nodes, key=operator.itemgetter(0)):
    c = []
    c.append(n.word_idx)
    
    while n.prev_node != None:
        n = n.prev_node
        c.append(n.word_idx)
    
    c = c[::-1]
    caption.append(c)

In [None]:
caption

In [None]:
[t.item() for t in caption[0]]

In [None]:
[idx2word[idx] for idx in [t.item() for t in caption[0]]]

In [None]:
alpha.shape

# Inference w/ Beam Search Version-2 (not sure if works correctly)

In [None]:
import heapq

class Beam(object):
    """
    Modified to include alpha, original source: 
    https://geekyisawesome.blogspot.com/2016/10/using-beam-search-to-generate-most.html
    
    For comparison of prefixes, the tuple (log_prob, complete_seq) is used. This is so 
    that if 2 prefixes have equal log probabilities then a complete sequence is preferred
    over an incomplete one since (-1.5, False) < (-1.5, True)
    """
    def __init__(self, beam_width):
        self.heap = list()
        self.beam_width = beam_width
        
    def add(self, log_prob, complete, prefix, alpha, lstm_states):
        heapq.heappush(self.heap, (log_prob, complete, prefix, alpha, lstm_states))
        if len(self.heap) > self.beam_width:
            heapq.heappop(self.heap)
            
    def __iter__(self):
        return iter(self.heap)
    
class Config(object):
    def __init__(self):
        # Encoder parameters
        # ------------------
        self.encoder_path = './mobilenet_v2.pth.tar'

        # Decoder parameters
        # ------------------
        self.decoder_path = './checkpoints/DecoderAttentionLSTM.pth'
        self.encoder_size = 1280  # MobileNetV2 output channels (do not change!) 2048 for ResNet
        self.decoder_size = 512  # LSTM output size (hidden state vector size)
        self.attention_size = 512  # Size of MLP used to compute attention scores
        self.embedding_size = 256  # Word embedding size
        self.dropout_prob = 0.5

        # Word to index mapping
        # ---------------------
        self.word2idx_file = './WORD2IDX_COCO.json'

In [None]:
class GenerateCaption(object):
    
    def __init__(self, config, beam_width=3):
        self.config = config
        self.beam_width = beam_width
        self.word2idx = self.read_json(self.config.word2idx_file)
        self.idx2word = {idx:word for word, idx in self.word2idx.items()}
        
        # Encoder
        self.encoder = EncoderCNN(weight_file=self.config.encoder_path)
        
        # Decoder encoder_size, decoder_size, attention_size, embedding_size, vocab_size
        decoder = DecoderAttentionRNN(encoder_size=self.config.encoder_size, 
                                      decoder_size=self.config.decoder_size, 
                                      attention_size=self.config.attention_size, 
                                      embedding_size=self.config.embedding_size, 
                                      vocab_size=len(self.word2idx))
        
        decoder.load_state_dict(torch.load(self.config.decoder_path))
        self.decoder = decoder
        
    # Helper Methods        
    def read_json(self, file):
        with open(file, 'r') as f:
            data = json.load(f)
        return data
    
    def read_preprocess_image(self, img_path):
        img = imageio.imread(img_path)
        
        # If image is gray scale then add channels
        if len(img.shape) == 2:
            img = img[:, :, np.newaxis]
            img = np.concatenate([img, img, img], axis=2)
            
        # Resize image
        img_resize = resize(img, (224, 224), mode='constant', anti_aliasing=True)
        img_resize = img_resize.transpose(2, 0, 1)  # PyTorch: [C, W, H]
        
        # Image tensor
        img = torch.FloatTensor(img_resize)

        # Normalize image
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
        
        transform = transforms.Compose([normalize])
        img = transform(img)  # size: [3, 224, 224]
        
        img = img.unsqueeze(0)  # size: [1, 3, 224, 224]
        
        return img
    
    def decode(self, encoder_output, h_prev, c_prev, embedding_t):
        
        # Attention weighted encoding and alpha
        attn_wtd_encoding, alpha = self.decoder.attention(encoder_out, h_prev)
        
        # Gating scalar beta
        beta_t = self.decoder.sigmoid(self.decoder.f_beta(h_prev))
        
        # Context vector
        context_vector = beta_t * attn_wtd_encoding
        
        # Concatenate init embedding with context vector
        concat_input = torch.cat([embedding_t, context_vector], dim=1) # size: [1, 1536]
        
        # Run RNN and compute scores
        h, c = self.decoder.rnn(concat_input, (h_prev, c_prev))
        scores = self.decoder.fc(h)  # size: [1, 9490]
        
        # Compute Log Softmax of scores
        log_probs = F.log_softmax(scores, dim=1)
        
        return h, c, log_probs, alpha
        
    def beam_search(self, img_path, clip_len=16):
        # Read and preprocess input image
        img = self.read_preprocess_image(img_path)
        
        # Encode input image
        encoder_output = self.encoder(img)  # size: [1, 14, 14, 1280]
        encoder_output = encoder_output.view(1, -1, self.config.encoder_size)  # size: [1, 196, 1280]
        num_locations = encoder_output.size(1) 
        
        # LSTM initial hidden states
        init_h, init_c = self.decoder.init_lstm_states(encoder_out)
        lstm_states = (init_h, init_c)
        
        # Decoder input: start with <START> token
        prefix = torch.LongTensor([self.word2idx['<START>']])
        alpha = torch.ones(1, num_locations)
        
        prev_beam = Beam(beam_width=self.beam_width)
        prev_beam.add(0.0, False, [prefix], [alpha], [lstm_states])
        
        while True:
            curr_beam = Beam(beam_width=self.beam_width)
            
            for prefix_log_prob, complete, prefix, alpha, lstm_states in prev_beam:
                if complete == True:
                    curr_beam.add(prefix_log_prob, True, prefix, alpha, lstm_states)
                else:
                    # Decode
                    embedding_t = self.decoder.embedding(prefix[-1]) # size: [1, 256]
                    h_t, c_t, log_probs_t, alpha_t = self.decode(encoder_output, *lstm_states[-1], embedding_t)
                    k_log_probs, k_word_idx = torch.topk(log_probs_t, k=self.beam_width)
                    
                    for i in range(self.beam_width):
                        if k_word_idx[0][i] == self.word2idx['<END>']:
                            curr_beam.add(prefix_log_prob + k_log_probs[0][i].item(), True, prefix, alpha, lstm_states)
                        else:
                            new_prefix = torch.LongTensor([k_word_idx[0][i].item()])
                            curr_beam.add(prefix_log_prob + k_log_probs[0][i].item(), False, prefix + [new_prefix],
                                          alpha + [alpha_t], lstm_states + [(h_t, c_t)])
                            
            best_log_prob, best_complete, best_prefix, best_alpha, _ = max(curr_beam)
            if best_complete == True or len(best_prefix)-1 == clip_len:
                return best_prefix[1:], best_alpha[1:]
            
            prev_beam = curr_beam

In [58]:
import heapq

class Beam(object):
    """
    Modified to include alpha, original source: 
    https://geekyisawesome.blogspot.com/2016/10/using-beam-search-to-generate-most.html
    
    For comparison of prefixes, the tuple (log_prob, complete_seq) is used. This is so 
    that if 2 prefixes have equal log probabilities then a complete sequence is preferred
    over an incomplete one since (-1.5, False) < (-1.5, True)
    """
    def __init__(self, beam_width):
        self.heap = list()
        self.beam_width = beam_width
        
    def add(self, prob, complete, prefix, alpha, lstm_states):
        heapq.heappush(self.heap, (prob, complete, prefix, alpha, lstm_states))
        if len(self.heap) > self.beam_width:
            heapq.heappop(self.heap)
            
    def __iter__(self):
        return iter(self.heap)

class GenerateCaption(object):
    
    def __init__(self, config, beam_width=3):
        self.config = config
        self.beam_width = beam_width
        self.word2idx = self.read_json(self.config.word2idx_file)
        self.idx2word = {idx:word for word, idx in self.word2idx.items()}
        
        # Encoder
        self.encoder = EncoderCNN(weight_file=self.config.encoder_path)
        
        # Decoder encoder_size, decoder_size, attention_size, embedding_size, vocab_size
        decoder = DecoderAttentionRNN(encoder_size=self.config.encoder_size, 
                                      decoder_size=self.config.decoder_size, 
                                      attention_size=self.config.attention_size, 
                                      embedding_size=self.config.embedding_size, 
                                      vocab_size=len(self.word2idx))
        
        decoder.load_state_dict(torch.load(self.config.decoder_path))
        self.decoder = decoder
        
    # Helper Methods        
    def read_json(self, file):
        with open(file, 'r') as f:
            data = json.load(f)
        return data
    
    def read_preprocess_image(self, img_path):
        img = imageio.imread(img_path)
        
        # If image is gray scale then add channels
        if len(img.shape) == 2:
            img = img[:, :, np.newaxis]
            img = np.concatenate([img, img, img], axis=2)
            
        # Resize image
        img_resize = resize(img, (224, 224), mode='constant', anti_aliasing=True)
        img_resize = img_resize.transpose(2, 0, 1)  # PyTorch: [C, W, H]
        
        # Image tensor
        img = torch.FloatTensor(img_resize)

        # Normalize image
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                         std=[0.229, 0.224, 0.225])
        
        transform = transforms.Compose([normalize])
        img = transform(img)  # size: [3, 224, 224]
        
        img = img.unsqueeze(0)  # size: [1, 3, 224, 224]
        
        return img
    
    def decode(self, encoder_output, h_prev, c_prev, embedding_t):
        
        # Attention weighted encoding and alpha
        attn_wtd_encoding, alpha = self.decoder.attention(encoder_output, h_prev)
        
        # Gating scalar beta
        beta_t = self.decoder.sigmoid(self.decoder.f_beta(h_prev))
        
        # Context vector
        context_vector = beta_t * attn_wtd_encoding
        
        # Concatenate init embedding with context vector
        concat_input = torch.cat([embedding_t, context_vector], dim=1) # size: [1, 1536]
        
        # Run RNN and compute scores
        h, c = self.decoder.rnn(concat_input, (h_prev, c_prev))
        scores = self.decoder.fc(h)  # size: [1, 9490]
        
        # Compute Log Softmax of scores
        probs = F.softmax(scores, dim=1)
        
        return h, c, probs, alpha
        
    def beam_search(self, img_path, clip_len=20):
        # Read and preprocess input image
        img = self.read_preprocess_image(img_path)
        
        # Encode input image
        encoder_output = self.encoder(img)  # size: [1, 14, 14, 1280]
        encoder_output = encoder_output.view(1, -1, self.config.encoder_size)  # size: [1, 196, 1280]
        num_locations = encoder_output.size(1) 
        
        # LSTM initial hidden states
        init_h, init_c = self.decoder.init_lstm_states(encoder_output)
        lstm_states = (init_h, init_c)
        
        # Decoder input: start with <START> token
        prefix = torch.LongTensor([self.word2idx['<START>']])
        alpha = torch.ones(1, num_locations)
        
        prev_beam = Beam(beam_width=self.beam_width)
        prev_beam.add(1.0, False, [prefix], [alpha], [lstm_states])
        
        while True:
            curr_beam = Beam(beam_width=self.beam_width)
            
            for prefix_prob, complete, prefix, alpha, lstm_states in prev_beam:
                if complete == True:
                    curr_beam.add(prefix_prob, True, prefix, alpha, lstm_states)
                else:
                    # Decode
                    embedding_t = self.decoder.embedding(prefix[-1]) # size: [1, 256]
                    h_t, c_t, probs_t, alpha_t = self.decode(encoder_output, *lstm_states[-1], embedding_t)
                    k_probs, k_word_idx = torch.topk(probs_t, k=self.beam_width)
                    
                    for i in range(self.beam_width):
                        if k_word_idx[0][i] == self.word2idx['<END>']:
                            curr_beam.add(prefix_prob * k_probs[0][i].item(), True, prefix, alpha, lstm_states)
                        else:
                            new_prefix = torch.LongTensor([k_word_idx[0][i].item()])
                            curr_beam.add(prefix_prob * k_probs[0][i].item(), False, prefix + [new_prefix],
                                          alpha + [alpha_t], lstm_states + [(h_t, c_t)])
                            
            best_prob, best_complete, best_prefix, best_alpha, _ = max(curr_beam)
            if best_complete == True or len(best_prefix)-1 == clip_len:
                return best_prefix[1:], best_alpha[1:]
            
            prev_beam = curr_beam
            
    def greedy_search(self, img_path, max_length=20):
        # Read input image
        img = self.read_preprocess_image(img_path)
        
        # Encode input image
        encoder_output = self.encoder(img)  # size: [1, 14, 14, 1280]
        
        # Flatten the encoded feature map
        encoder_output = encoder_output.view(1, -1, self.config.encoder_size)  # size: [1, 196, 1280]
        
        # LSTM initial hidden states
        h, c = self.decoder.init_lstm_states(encoder_output)
        
        # Decoder input: start with <START> token
        word_idx = torch.LongTensor([self.word2idx['<START>']])
        embedding_t = self.decoder.embedding(word_idx) # size: [1, 256]
        
        # Decode
        encoded_caption = []
        caption_alphas = []
        for t in range(max_length):
            h, c, log_probs, alpha = self.decode(encoder_output, h, c, embedding_t)
            top_log_prob, top_idx = torch.topk(log_probs, k=1, dim=1)
            encoded_caption.append(top_idx.item())
            caption_alphas.append(alpha)
            embedding_t = self.decoder.embedding(torch.LongTensor([top_idx.item()]))
            
        caption = [self.idx2word[i] for i in encoded_caption]
            
        return caption, caption_alphas

In [59]:
config = Config()
captioner = GenerateCaption(config)

caption, alphas = captioner.greedy_search('./test.jpg')

print(' '.join(caption))

credit furred furred furred streak streak stemware googles furred furred furred furred streak streak kisses furred furred streak streak kisses


In [62]:
caption, alphas = captioner.beam_search('./test.jpg')
captioner.idx2word[6427]

'elements'

### Scratch

In [None]:
def decode(decoder, encoder_output, h_prev, c_prev, embedding_t):
    # Attention weighted encoding and alpha
    attn_wtd_encoding, alpha = decoder.attention(encoder_output, h_prev)

    # Gating scalar beta
    beta_t = decoder.sigmoid(decoder.f_beta(h_prev))

    # Context vector
    context_vector = beta_t * attn_wtd_encoding
    print(context_vector.shape, embedding_t.shape)

    # Concatenate init embedding with context vector
    concat_input = torch.cat([embedding_t, context_vector], dim=1) # size: [1, 1536]

    # Run RNN and compute scores
    h, c = decoder.rnn(concat_input, (h_prev, c_prev))
    scores = decoder.fc(h)  # size: [1, 9490]

    # Compute Log Softmax of scores
    log_probs = F.log_softmax(scores, dim=1)

    return h, c, log_probs, alpha

In [None]:
beam_width = 3  # Beam width

# Read word-index mapping
word2idx = read_json('./WORD2IDX_COCO.json')
idx2word = {idx: word for word, idx in word2idx.items()}

# Prepare input image
img_path = './example.jpg'
img = imageio.imread(img_path)

# If image is gray scale then add channels
if len(img.shape) == 2:
    img = img[:, :, np.newaxis]
    img = np.concatenate([img, img, img], axis=2)
    
# Resize image and return it
img_raw = resize(img, (224, 224), mode='constant', anti_aliasing=True)
img = img_raw.transpose(2, 0, 1)  # PyTorch: [C, W, H]

# Image tensor
img = torch.FloatTensor(img)

# Normalize image
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                                 std=[0.229, 0.224, 0.225])

transform = transforms.Compose([normalize])
IMG = transform(img)  # size: [3, 224, 224]

IMG = IMG.unsqueeze(0)  # size: [1, 3, 224, 224]
plt.imshow(img_raw)

In [None]:
# Load Encoder
weight_path = './mobilenet_v2.pth.tar'
ENCODER = EncoderCNN(weight_file=weight_path)

# Load Decoder
encoder_size = 1280  
decoder_size = 512
attention_size = 512
embedding_size = 256
vocab_size = len(word2idx)

weight_file = './checkpoints/DecoderAttentionLSTM.pth'
DECODER = DecoderAttentionRNN(encoder_size, decoder_size, attention_size, embedding_size, vocab_size)
decoder_state_dict = torch.load(weight_file)
DECODER.load_state_dict(decoder_state_dict)

# Encode input image
encoder_out = ENCODER(IMG)  # size: [1, 14, 14, 1280]
encoder_size = encoder_out.size(-1)
encoder_out = encoder_out.view(1, -1, encoder_size)  # size: [1, 196, 1280]
num_locations = encoder_out.size(1)  # 196

In [None]:
# LSTM initial hidden states 
init_h, init_c = DECODER.init_lstm_states(encoder_out)
lstm_states = (init_h, init_c)

# Decoder input: start with <START> token
prefix = torch.LongTensor([word2idx['<START>']])
alpha = torch.ones(1, num_locations)

In [None]:
# Decoder input: start with <START> token
prefix = torch.LongTensor([word2idx['<START>']])
alpha = torch.ones(1, num_locations)

In [None]:
# Create beam
prev_beam = Beam(beam_width=beam_width)

In [None]:
# Add initial data to beam
prev_beam.add(0.0, False, [prefix], [alpha], [lstm_states])

In [None]:
for prefix_log_prob, complete, prefix, alpha, lstm_states in prev_beam:
    print(prefix_log_prob)
    print(complete)
    print(prefix)
    print(prefix[-1].shape)
    print(alpha[-1].shape)
    print(lstm_states[-1][0].shape)
    print(lstm_states[-1][1].shape)
    print('----' * 10)

In [None]:
# While loop
curr_beam = Beam(beam_width=beam_width)

for prefix_log_prob, complete, prefix, alpha, lstm_states in prev_beam:
    if complete == True:
        curr_beam.add(prefix_log_prob, True, prefix, alpha, lstm_states)
    else:
        embedding_t = DECODER.embedding(prefix[-1])
        h_t, c_t, log_probs_t, alpha_t = decode(DECODER, encoder_out, *lstm_states[-1], embedding_t)
        k_log_probs, k_word_idx = torch.topk(log_probs_t, k=beam_width)
        
        for i in range(beam_width):
            if k_word_idx[0][i] == word2idx['<END>']:
                curr_beam.add(prefix_log_prob + k_log_probs[0][i].item(), True, prefix, alpha, lstm_states)
            else:
                curr_beam.add(prefix_log_prob + k_log_probs[0][i].item(), False, prefix + [k_word_idx[0][i].view(1)],
                              alpha + [alpha_t], lstm_states + [(h_t, c_t)])
            
best_log_prob, best_complete, best_prefix, best_alpha, _ = max(curr_beam)

prev_beam = curr_beam

In [None]:
for prefix_log_prob, complete, prefix, alpha, lstm_states in prev_beam:
    print(prefix_log_prob)
    print(complete)
    print(prefix)
    print(prefix[-1])
    print(alpha[-1].shape)
    print(lstm_states[-1][0].shape)
    print(lstm_states[-1][1].shape)
    print('----' * 10)

In [None]:
' '.join([idx2word[i.item()] for i in best_prefix])