In [86]:
# In this notebook, you learn:
#
# 1) How to implement Beam Search for Neural Machine Translation.

In [87]:
# Resources to understand Beam Search:
#
# 1) https://www.youtube.com/watch?v=RLWuzLLSIgw
#       -- Intuitive and simple explanation of Beam Search algorithm.
# 2) https://www.youtube.com/watch?v=gb__z7LlN_4
#       -- Improvements to Beam Search.
# 3) https://www.youtube.com/watch?v=ZGUZwk7xIwk
#       -- Error analysis of Beam Search.

In [88]:
# Assuming you already know about Beam Search, this is an implementation of the basic version of Beam Search for
# Machine Translation. Let's go through the implementation logic at a high level before we dive into the code.
# 
# The input to beam search is a batch of source sequences -- [batch_size, source_sequence_length].
# Each 1D tensor in the batch represents a source sequence where each element is a token (from source vocabulary) 
# in the sequence.
#
# The output of beam search is a batch of target sequences (or translated sequences) -- [batch_size, target_sequence_length].
# Each 1D tensor in the batch represents a target sequence where each element is a token (from target vocabulary) 
# in the sequence.
#
# In Beam Search, at every position, instead of just predicting the token with the highest probability, we keep 
# track of top 'beam_width' number of tokens with the largest probabilities. So, we will have 'beam_width' 
# number of potential target sequences at each time step. The 'beam_width' is a hyperparameter that we need to 
# set before running the beam search algorithm.
#
# Now going into the details: at each time step, each source sequence can have multiple potential target
# sequences, since we keep track of 'beam_width' potential target sequences at each time step.
# So, along with each target sequence, we need to keep track of the index (in the batch) of the source
# sequence that this target sequence (prediction) belongs to. The state of each target sequence is stored in a
# 'SequenceState' object (below in the code).
#
# During the beam search, some of the target sequences may reach the end of the sequence (<eos> token predicted 
# for that target sequence) before others. We need to keep track of the target sequences that have reached the 
# end of the sequence. We store the target sequences that have reached the end of the sequence in the 
# 'complete_state' dictionary (below in the code). The key of the dictionary is the index of the source sequence 
# in the source batch and the value is the 'SequenceState' object that has reached the end of the sequence (tgt) 
# and has the highest probability among all the sequences that reached the end. The list of target sequences 
# that haven't reached the end of the sequence are stored in the 'running_state' list.
#
# The beam search algorithm is iterated 'TGT_SEQ_LIMIT' number of times. At each iteration, we predict the next 
# token for each target sequence in the 'running_state' list. The 'running_state' and 'complete_state' are 
# updated based on the predicted tokens.
#
# The state update logic is as follows:
# 1) If the target sequence has reached the end of the sequence, then we store the target sequence in the 
#    'complete_state' if this target sequence has the highest probability among all the target sequences that 
#    have reached the end of the sequence.
# 2) If the target sequence hasn't reached the end of the sequence, then this target sequence is a potential 
#    target sequence to be considered for the next iteration. For each (source sequence, target sequence) pair, 
#    we retrieve 'beam_width' number of potential tokens with the highest probability from the model. We will 
#    now have a list of target sequences for a specific source sequence. We will sort this list based on the 
#    probability of the target sequences and keep only the 'beam_width' number of target sequences with the 
#    highest probability for the next iteration.
# 3) If all the target sequences for a source sequence have reached the end of the sequence, then this source 
#    sequence will not be considered for the next iteration.
# 
# Once we have iterated 'TGT_SEQ_LIMIT' number of times, the 'complete_state' dictionary will hold the
# highest-probability target sequence for most (hopefully) of the source sequences in the batch.
# However, some of the source sequences may not have any target sequence in the 'complete_state' dictionary.
# For each such source sequence, we take the target sequence with the highest probability from the
# 'running_state' list and add it to the 'complete_state' dictionary.
#
# Finally, we will have the target sequences with the highest probability for each source sequence in the 
# batch stored in the 'complete_state' dictionary. We will return these target sequences as the output of the 
# beam search algorithm.
#
# Backtracking a bit, let's expand on the part where we have a 'running_state' and need to update it based
# on the predicted tokens. As noted above, the 'running_state' might contain multiple target sequences for a
# source sequence. We pass these (source sequence, target sequence) pairs to the decoder part of the
# model to get the probability distribution over the target vocabulary. So, we have to construct the
# input to the decoder from the 'running_state' list. The input to the decoder is a batch of source sequences
# and a batch of target sequences. So, we basically copy each source sequence as many times as it is used in
# the 'running_state' list to form the source sequence batch.
# -- This might sound a bit confusing, but you will (hopefully) understand it better when you see the code.
# The src_mask is constructed in a similar way. The target sequence batch is constructed by taking the
# target sequences from the 'running_state' list.
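
The high-level logic described above can be sketched end to end without a real model. The block below is a minimal, framework-free sketch and not the implementation built in the rest of this notebook: `Hypothesis`, `beam_search_sketch` and `toy_model` are hypothetical names introduced only for illustration, and `toy_model` is a hand-written stand-in that returns fixed log probabilities over a 4-token vocabulary.

```python
import math
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass
class Hypothesis:
    # Index of the source sequence in the batch.
    index: int
    # Target token ids predicted so far (starts with <sos>).
    tokens: Tuple[int, ...]
    # Sum of the log probabilities of the predicted tokens.
    log_prob: float


def beam_search_sketch(
    next_log_probs: Callable[[int, Tuple[int, ...]], List[float]],
    batch_size: int,
    beam_width: int,
    sos_id: int,
    eos_id: int,
    max_steps: int,
) -> Dict[int, Hypothesis]:
    complete: Dict[int, Hypothesis] = {}
    running: List[Hypothesis] = [Hypothesis(i, (sos_id,), 0.0) for i in range(batch_size)]
    for _ in range(max_steps):
        # New incomplete sequences, grouped by source sequence index.
        candidates: Dict[int, List[Hypothesis]] = {}
        for hyp in running:
            log_probs = next_log_probs(hyp.index, hyp.tokens)
            # Token ids of the 'beam_width' most probable next tokens.
            top_tokens = sorted(range(len(log_probs)), key=lambda t: log_probs[t], reverse=True)[:beam_width]
            for token in top_tokens:
                new_hyp = Hypothesis(hyp.index, hyp.tokens + (token,), hyp.log_prob + log_probs[token])
                if token == eos_id:
                    # Keep only the most probable complete sequence per source.
                    best = complete.get(hyp.index)
                    if best is None or new_hyp.log_prob > best.log_prob:
                        complete[hyp.index] = new_hyp
                else:
                    candidates.setdefault(hyp.index, []).append(new_hyp)
        # Prune every group to the 'beam_width' most probable sequences.
        running = []
        for group in candidates.values():
            group.sort(key=lambda h: h.log_prob, reverse=True)
            running.extend(group[:beam_width])
        if not running:
            break
    # Fallback: sources without a complete sequence get their best running one.
    for hyp in sorted(running, key=lambda h: h.log_prob, reverse=True):
        complete.setdefault(hyp.index, hyp)
    return complete


def toy_model(index: int, tokens: Tuple[int, ...]) -> List[float]:
    # Hypothetical stand-in for the decoder: favours <eos> (id 3) once
    # the sequence has 3 tokens, and token 1 before that.
    probs = [0.05, 0.05, 0.1, 0.8] if len(tokens) >= 3 else [0.1, 0.5, 0.3, 0.1]
    return [math.log(p) for p in probs]


result = beam_search_sketch(toy_model, batch_size=2, beam_width=2, sos_id=0, eos_id=3, max_steps=5)
print(result[0].tokens)  # (0, 1, 1, 3)
```

The sketch loops over sequences one at a time to keep the control flow visible. The notebook instead batches all running sequences into tensors, which is why it needs the source-copying step described above.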

In [1]:
import random
import torch

from dataclasses import dataclass
from torch import Tensor
from typing import Dict, List

In [2]:
# Size of the embeddings and other intermediate vector representations.
D_MODEL = 6
# Number of sequences in the batch.
BATCH_SIZE  = 2
# Number of tokens in the source sequences.
SRC_SEQ_LEN = 3
# Number of potential tokens to be considered for every position in the target sequences.
BEAM_WIDTH = 3
# Start of the sequence token id.
SOS_TOKEN_ID = 0
# End of the sequence token id.
EOS_TOKEN_ID = 3
# Maximum number of tokens allowed in the target sequences.
TGT_SEQ_LIMIT = 6
# Size of the target vocabulary.
TGT_VOCAB_SIZE = 10

#### Items useful for implementing Beam Search.

In [8]:
# Useful for storing the state of the beam search.

# Contains the state of the sequence that is predicted by the beam search.
@dataclass
class SequenceState:
    # Index of the source sequence for which the current target sequence has been predicted.
    index: int
    # Sequence of tokens in the target prediction. This is a 1D tensor.
    tokens: Tensor
    # Log of the probability that this is the translation for the source sequence.
    log_prob: float


# Holds the state of Beam search.
class SearchState:
    def __init__(self):
        # Holds the final translations (predictions) for the source sequences keyed by the source 
        # sequence index in the batch.
        self.complete_state: Dict[int, SequenceState] = {}
        # Holds the tgt sequences that are not complete (<eos> token not predicted yet) and are 
        # being predicted.
        self.running_state: List[SequenceState] = []

    def __repr__(self) -> str:
        return f"Complete State: {self.complete_state}\nRunning State: {self.running_state}"


In [9]:
# Generates a random mask.
def generate_random_mask(batch_size: int, seq_len: int) -> Tensor:
    src_mask = torch.randn(seq_len, seq_len) > 0 
    src_mask = src_mask.unsqueeze(0).unsqueeze(0)
    src_mask = src_mask.repeat(batch_size, 1, 1, 1)
    print("shape of src_mask: ", src_mask.shape)
    print("src_mask: \n", src_mask)
    print("-" * 150)
    return src_mask

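# Constructs a look-ahead mask in which position i may only attend to positions <= i (True = allowed).
def construct_look_ahead_mask(batch_size: int, seq_len: int) -> Tensor: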
    attention_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.uint8), diagonal=1)
    return (attention_mask == 0).unsqueeze(0).unsqueeze(0).repeat(batch_size, 1, 1, 1)
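
To make the shape and contents of the look-ahead mask concrete, here is a small sketch that builds the same kind of mask with `torch.tril` instead of `torch.triu`. The function name `look_ahead_mask_tril` is hypothetical, and the sketch assumes the convention visible in this notebook's printed masks, where `True` marks a position that may be attended to.

```python
import torch
from torch import Tensor


def look_ahead_mask_tril(batch_size: int, seq_len: int) -> Tensor:
    # Lower-triangular ones (including the diagonal) mark the allowed positions.
    allowed = torch.tril(torch.ones(seq_len, seq_len)).bool()
    # Add batch and head dimensions: [batch_size, 1, seq_len, seq_len].
    return allowed.unsqueeze(0).unsqueeze(0).repeat(batch_size, 1, 1, 1)


mask = look_ahead_mask_tril(batch_size=2, seq_len=3)
print(mask.shape)  # torch.Size([2, 1, 3, 3])
print(mask[0, 0])
```

Row i of `mask[0, 0]` is `True` up to and including column i, so the first token sees only itself and the last token sees the whole sequence.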

In [11]:
# Create an example tensor that serves as encoded source and corresponding mask for experimentation.
encoded_src = torch.arange(36, dtype=torch.float32).reshape(BATCH_SIZE, SRC_SEQ_LEN, D_MODEL)
print("shape of encoded_src:", encoded_src.shape)
print("encoded_src:\n", encoded_src)
print("-" * 150)
# Generating a random mask for experimentation. This might not be a valid mask in real scenarios.
src_mask = generate_random_mask(BATCH_SIZE, SRC_SEQ_LEN)

shape of encoded_src: torch.Size([2, 3, 6])
encoded_src:
 tensor([[[ 0.,  1.,  2.,  3.,  4.,  5.],
         [ 6.,  7.,  8.,  9., 10., 11.],
         [12., 13., 14., 15., 16., 17.]],

        [[18., 19., 20., 21., 22., 23.],
         [24., 25., 26., 27., 28., 29.],
         [30., 31., 32., 33., 34., 35.]]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape of src_mask:  torch.Size([2, 1, 3, 3])
src_mask: 
 tensor([[[[False,  True, False],
          [False,  True,  True],
          [False, False, False]]],


        [[[False,  True, False],
          [False,  True,  True],
          [False, False, False]]]])
------------------------------------------------------------------------------------------------------------------------------------------------------


In [12]:
# Create 'SearchState' object to save the state of the beam search algorithm.
beam_state = SearchState()
# At the start of the algorithm, every src sequence will only have a single tgt sequence with the <sos> token.
# The log probability of this sequence will be 0.0 since the probability of the <sos> token is 1.0 i.e., 
# log(1.0) = 0.0. The batch size (BATCH_SIZE defined above) is 2, which means there are 2 source sequences. So,
# we will have 2 target sequences in the running_state list.
beam_state.running_state = [SequenceState(index=idx, tokens=torch.tensor(data=[SOS_TOKEN_ID], dtype=torch.int32), log_prob=0.0) for idx in range(BATCH_SIZE)]
print("len(beam_state.running_state): ", len(beam_state.running_state))
print("beam_state.running_state: \n", beam_state.running_state)

len(beam_state.running_state):  2
beam_state.running_state: 
 [SequenceState(index=0, tokens=tensor([0], dtype=torch.int32), log_prob=0.0), SequenceState(index=1, tokens=tensor([0], dtype=torch.int32), log_prob=0.0)]
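
The reason for tracking `log_prob` rather than a raw probability can be seen in a short sketch. The token probabilities below are made up for illustration. Summing log probabilities gives the same ranking as multiplying probabilities, but it does not underflow for long sequences.

```python
import math

# Hypothetical per-token probabilities of a 3-token prediction.
token_probs = [0.5, 0.25, 0.8]

# Multiplying the probabilities directly ...
product = math.prod(token_probs)
# ... matches summing their logs, starting from log(1.0) = 0.0 for <sos>.
log_prob = 0.0
for p in token_probs:
    log_prob += math.log(p)
print(product, math.exp(log_prob))  # both approximately 0.1

# For long sequences the product underflows to 0.0, while the sum stays finite.
long_probs = [0.1] * 400
underflowed = math.prod(long_probs)
log_sum = sum(math.log(p) for p in long_probs)
print(underflowed)  # 0.0
print(log_sum)      # approximately -921.03
```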


Now, we will predict the next token for each sequence in the 'running_state' list. This step will be <br>
repeated until the \<eos\> token is predicted for all the sequences in the 'running_state' list or we <br>
predict the 'TGT_SEQ_LIMIT' number of tokens for each sequence in the 'running_state' list. Let's do <br>
2 iterations of the beam search algorithm and understand how the components work.

#### 1st Iteration of Beam Search starts here

In [13]:
# Prepare the tgt_batch and tgt_mask for the decoder. The tgt_batch is just the token sequences stacked
# together for all the sequences in the running_state list. The tgt_mask is a look-ahead mask, which stops
# each position from attending to later positions. We will not be using these tgt tensors in this notebook
# since we are not actually running the 'Decoder' but just creating random data that looks like the Decoder
# output. This cell only shows what the target_batch and target_mask inputs to the Decoder look like and how
# they are created from the running_state.
tgt_for_inference = torch.stack(tensors=[state.tokens for state in beam_state.running_state], dim=0) 
print("shape of tgt_for_inference: ", tgt_for_inference.shape)
print("tgt_for_inference: \n", tgt_for_inference)
print("-" * 150)
tgt_mask_for_inference = construct_look_ahead_mask(batch_size=len(beam_state.running_state), seq_len=tgt_for_inference.size(1))
print("shape of tgt_mask_for_inference: ", tgt_mask_for_inference.shape)
print("tgt_mask_for_inference: \n", tgt_mask_for_inference)

shape of tgt_for_inference:  torch.Size([2, 1])
tgt_for_inference: 
 tensor([[0],
        [0]], dtype=torch.int32)
------------------------------------------------------------------------------------------------------------------------------------------------------
shape of tgt_mask_for_inference:  torch.Size([2, 1, 1, 1])
tgt_mask_for_inference: 
 tensor([[[[True]]],


        [[[True]]]])


In [15]:
# Create a random tensor of predicted log probabilities. This stands in for the output of the 'TokenPredictor'
# module in the transformer and is used to identify the next token for each sequence in the 'running_state' list.
predicted_log_probabilities = torch.log_softmax(torch.rand(tgt_for_inference.size(0), tgt_for_inference.size(1), TGT_VOCAB_SIZE), dim=-1)
print("shape of predicted_log_probabilities: ", predicted_log_probabilities.shape)
print("predicted_log_probabilities: \n", predicted_log_probabilities)
print("-" * 150)
# We only need the probabilities for the last token in the sequence. So, we will only consider the last token
# probabilities for the calculation of the log probabilities for any sequence. Please notice that we get a 2D 
# tensor after slicing the predicted_log_probabilities tensor which is originally a 3D tensor.
predicted_log_probabilities = predicted_log_probabilities[:, -1, :]
print("shape of predicted_log_probabilities: ", predicted_log_probabilities.shape)
print("predicted_log_probabilities: \n", predicted_log_probabilities)

shape of predicted_log_probabilities:  torch.Size([2, 1, 10])
predicted_log_probabilities: 
 tensor([[[-2.9145, -2.5996, -2.9133, -1.9518, -2.5989, -1.9350, -2.1859,
          -2.1319, -2.1800, -2.1863]],

        [[-2.4938, -2.3863, -1.9947, -2.7851, -2.5315, -2.2796, -2.3758,
          -1.9841, -2.4892, -2.0220]]])
------------------------------------------------------------------------------------------------------------------------------------------------------
shape of predicted_log_probabilities:  torch.Size([2, 10])
predicted_log_probabilities: 
 tensor([[-2.9145, -2.5996, -2.9133, -1.9518, -2.5989, -1.9350, -2.1859, -2.1319,
         -2.1800, -2.1863],
        [-2.4938, -2.3863, -1.9947, -2.7851, -2.5315, -2.2796, -2.3758, -1.9841,
         -2.4892, -2.0220]])
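
As a quick sanity check on the cell above, the sketch below confirms two properties under the same setup (random scores passed through `torch.log_softmax`): exponentiating the log probabilities gives a distribution that sums to 1 over the vocabulary, and slicing with `[:, -1, :]` keeps one distribution per sequence. The tensor sizes and the seed are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)  # Only to make this sketch repeatable.
vocab_size = 10
# Stand-in for decoder scores: [num_sequences, seq_len, vocab_size].
scores = torch.rand(2, 4, vocab_size)
log_probs = torch.log_softmax(scores, dim=-1)

# Exponentiating recovers probabilities that sum to 1 over the vocabulary.
sums = log_probs.exp().sum(dim=-1)
print(torch.allclose(sums, torch.ones_like(sums)))  # True

# Keep only the distribution for the last position of every sequence.
last_step = log_probs[:, -1, :]
print(last_step.shape)  # torch.Size([2, 10])
```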


In [16]:
# Split the log probabilities tensor into a tuple of 1D tensors (one per row) so that we can iterate over the
# log probabilities and the corresponding 'SequenceState' objects in the 'running_state' list simultaneously
# using the 'zip' function.
new_tgt_probs = torch.unbind(predicted_log_probabilities, dim=0)
print("type(new_tgt_probs): ", type(new_tgt_probs))
print("len(new_tgt_probs): ", len(new_tgt_probs))
print("new_tgt_probs: \n", new_tgt_probs)

type(new_tgt_probs):  <class 'tuple'>
len(new_tgt_probs):  2
new_tgt_probs: 
 (tensor([-2.9145, -2.5996, -2.9133, -1.9518, -2.5989, -1.9350, -2.1859, -2.1319,
        -2.1800, -2.1863]), tensor([-2.4938, -2.3863, -1.9947, -2.7851, -2.5315, -2.2796, -2.3758, -1.9841,
        -2.4892, -2.0220]))
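
The selection of next-token candidates from these rows can be tried in isolation. The sketch below copies the log probabilities printed above (rounded to 4 decimals) and applies `topk` with `k=3` to each row. The variable name `top_per_row` is introduced here only for illustration.

```python
import torch

# Log probabilities copied from the printed output above (rounded to 4 decimals).
rows = torch.tensor([
    [-2.9145, -2.5996, -2.9133, -1.9518, -2.5989, -1.9350, -2.1859, -2.1319, -2.1800, -2.1863],
    [-2.4938, -2.3863, -1.9947, -2.7851, -2.5315, -2.2796, -2.3758, -1.9841, -2.4892, -2.0220],
])

top_per_row = []
for row in torch.unbind(rows, dim=0):
    # topk returns the k largest values and their positions (the token ids here).
    top_log_probs, top_tokens = row.topk(k=3)
    top_per_row.append(top_tokens.tolist())

print(top_per_row)  # [[5, 3, 7], [7, 2, 9]]
```

Because the values are log probabilities, the largest (least negative) entries correspond to the most probable tokens.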


In [17]:
# Update the beam_state object with the new target token predictions of the beam search algorithm.
new_running_state: List[SequenceState] = []
for new_tgt_prob, old_seq_state in zip(new_tgt_probs, beam_state.running_state):
    # Extract the top 3 tokens (BEAM_WIDTH) to be considered as the next token via beam search.
    top_probs, top_tokens = new_tgt_prob.topk(k=BEAM_WIDTH, dim=-1)
    print("top_probs: ", top_probs)
    print("top_tokens: ", top_tokens)
    print("-" * 150)
    # Iterate on each predicted token, create the sequence with this token appended and calculate
    # the probability of the new sequence (with token appended).
    for pred_prob, pred_token in zip(top_probs, top_tokens):
        print("pred_token: ", pred_token)
        print("old_seq_state tokens: ", old_seq_state.tokens)
        # Append the newly predicted token to the existing sequence of tokens.
        updated_token_seq = torch.cat(tensors=[old_seq_state.tokens, pred_token.unsqueeze(0).to(torch.int32)], dim=0)
        print("updated_token_seq: ", updated_token_seq)
        # The log probability of the extended sequence is the log probability of the old sequence added
        # to the log probability associated with the newly predicted token.
        updated_seq_prob = old_seq_state.log_prob + pred_prob.item()
        print("updated_seq_prob: ", updated_seq_prob)
        print("-" * 150)
        # Creates a new SequenceState object associated with the extended tgt sequence.
        new_state = SequenceState(index=old_seq_state.index, tokens=updated_token_seq, log_prob=updated_seq_prob)
        # In this 1st iteration we do not check for <eos> yet, so every extended sequence is added to the
        # new running state. Handling of <eos> is covered in the 2nd iteration.
        new_running_state.append(new_state)

print("len(new_running_state): ", len(new_running_state))
print("new_running_state: \n", new_running_state)

top_probs:  tensor([-1.9350, -1.9518, -2.1319])
top_tokens:  tensor([5, 3, 7])
------------------------------------------------------------------------------------------------------------------------------------------------------
pred_token:  tensor(5)
old_seq_state tokens:  tensor([0], dtype=torch.int32)
updated_token_seq:  tensor([0, 5], dtype=torch.int32)
updated_seq_prob:  -1.9349650144577026
------------------------------------------------------------------------------------------------------------------------------------------------------
pred_token:  tensor(3)
old_seq_state tokens:  tensor([0], dtype=torch.int32)
updated_token_seq:  tensor([0, 3], dtype=torch.int32)
updated_seq_prob:  -1.951758861541748
------------------------------------------------------------------------------------------------------------------------------------------------------
pred_token:  tensor(7)
old_seq_state tokens:  tensor([0], dtype=torch.int32)
updated_token_seq:  tensor([0, 7], dtype=torch.int32

In [19]:
# Update the beam state with the new running state.
beam_state.running_state = new_running_state
print(beam_state)

Complete State: {}
Running State: [SequenceState(index=0, tokens=tensor([0, 5], dtype=torch.int32), log_prob=-1.9349650144577026), SequenceState(index=0, tokens=tensor([0, 3], dtype=torch.int32), log_prob=-1.951758861541748), SequenceState(index=0, tokens=tensor([0, 7], dtype=torch.int32), log_prob=-2.1318984031677246), SequenceState(index=1, tokens=tensor([0, 7], dtype=torch.int32), log_prob=-1.9841415882110596), SequenceState(index=1, tokens=tensor([0, 2], dtype=torch.int32), log_prob=-1.994722604751587), SequenceState(index=1, tokens=tensor([0, 9], dtype=torch.int32), log_prob=-2.0220141410827637)]


#### 2nd Iteration of Beam Search starts here

In [20]:
# In the 1st iteration, we did not handle a lot of cases that we usually run into. Let's go through each of the
# cases now:
#
# 1) We can have multiple potential target sequences in the 'running_state' for the same source sequence for which 
#    we predict the next tokens.
#       -- This will affect the src_batch and src_mask since these are passed as inputs to the Decoder. 
#       -- The src_batch in this case needs to contain the same source sequence repeated as many times as the 
#          number of corresponding tgt sequences. Similarly, the src_mask needs to be updated.
# 2) Some of the newly predicted tokens can be '<eos>' token and these should not be added back to the new 
#    'running_state'.
#       -- If the newly predicted token is '<eos>' token, this token should be appended to the tgt sequence and 
#          added to the list of complete state objects.
#       -- Only if the predicted token is not '<eos>' token, it needs to be added back to the 'running_state' for 
#          next iteration of predictions.
# 3) The complete_state only needs to hold the sequence with the maximum probability at any point of time. So, we
#    need to find the complete tgt sequence with max log probability and save it to the complete state.
# 4) If we have more than 3 (BEAM_WIDTH) potential tgt sequences, we need to find the top 3 sequences sorted by
#    probability, add these 3 sequences to the 'running_state' and discard the remaining sequences.
#
# We will now see how to handle all these cases here as we go through the 2nd iteration of Beam Search.
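
Case 1 can be sketched on its own with `torch.index_select`. The example below assumes a hypothetical running state in which the source sequence at index 0 has 3 target sequences and the one at index 1 has only 1, to show that the number of copies per source does not have to be equal. The names `encoded` and `expanded` are used only in this sketch.

```python
import torch

# Stand-in for the encoder output: 2 source sequences, 3 tokens each, model size 6.
encoded = torch.arange(36, dtype=torch.float32).reshape(2, 3, 6)
# Hypothetical running state: 3 target sequences for source 0, 1 for source 1.
src_indices = torch.tensor([0, 0, 0, 1], dtype=torch.long)

# Row i of 'expanded' is a copy of encoded[src_indices[i]].
expanded = torch.index_select(encoded, dim=0, index=src_indices)
print(expanded.shape)  # torch.Size([4, 3, 6])
```

The first dimension of the result matches the number of running target sequences, so each target sequence is paired with the encoded source it belongs to.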

In [21]:
# Handling case 1:
#
# We need to create the 'encoded_src' that can be passed as input to the Decoder. In the beam search 
# 'running_state', we have 3 SequenceState objects for which the index is 0 i.e., we have 3 potential target 
# sequences for the 0th source sequence. So, we need to repeat the encoded tokens for the 0th source 
# sequence 3 times in the 'itr_2_src_for_inference' (variable below).
#
# Similarly we have 3 SequenceState objects for which the index is 1 i.e., we have 3 potential target sequences
# for the 1st source sequence. So, we again need to repeat the encoded tokens for the 1st sequence 3 times in 
# the 'itr_2_src_for_inference' (variable below).
#
# Please note that this number 3 being the same for both source sequences (index 0 and index 1) is just a
# coincidence and could be different. We have to find out how many tgt sequences are in the running_state for a
# source sequence and repeat the encoded src based on that.
#
# This gives all the source sequence indices from the running_state. 
src_indices = torch.tensor(data=[state.index for state in beam_state.running_state], dtype=torch.int32) 
print("src_indices shape: ", src_indices.shape)
print("src_indices: ", src_indices)
print("-" * 150)
# Now, we just create a new src based on the original encoded_src and the src_indices. This will basically
# take the tensors from original encoded_src and copy it as many times as specified by the src_indices tensor.
itr_2_src_for_inference = torch.index_select(input=encoded_src, dim=0, index=src_indices)
print("itr_2_src_for_inference shape: ", itr_2_src_for_inference.shape)
print("itr_2_src_for_inference: ", itr_2_src_for_inference)
print("-" * 150)
# Similarly we create the source mask for inference.
itr_2_src_mask_for_inference = torch.index_select(input=src_mask, dim=0, index=src_indices)
print("itr_2_src_mask_for_inference shape: ", itr_2_src_mask_for_inference.shape)
print("itr_2_src_mask_for_inference: ", itr_2_src_mask_for_inference)

src_indices shape:  torch.Size([6])
src_indices:  tensor([0, 0, 0, 1, 1, 1], dtype=torch.int32)
------------------------------------------------------------------------------------------------------------------------------------------------------
itr_2_src_for_inference shape:  torch.Size([6, 3, 6])
itr_2_src_for_inference:  tensor([[[ 0.,  1.,  2.,  3.,  4.,  5.],
         [ 6.,  7.,  8.,  9., 10., 11.],
         [12., 13., 14., 15., 16., 17.]],

        [[ 0.,  1.,  2.,  3.,  4.,  5.],
         [ 6.,  7.,  8.,  9., 10., 11.],
         [12., 13., 14., 15., 16., 17.]],

        [[ 0.,  1.,  2.,  3.,  4.,  5.],
         [ 6.,  7.,  8.,  9., 10., 11.],
         [12., 13., 14., 15., 16., 17.]],

        [[18., 19., 20., 21., 22., 23.],
         [24., 25., 26., 27., 28., 29.],
         [30., 31., 32., 33., 34., 35.]],

        [[18., 19., 20., 21., 22., 23.],
         [24., 25., 26., 27., 28., 29.],
         [30., 31., 32., 33., 34., 35.]],

        [[18., 19., 20., 21., 22., 23.],
       

In [22]:
# Prepare the tgt_batch and tgt_mask for the 2nd iteration of the decoder. The tgt_batch is just
# the token sequences stacked together for all the sequences in the running_state list. The tgt_mask
# is a look-ahead mask. We will not be using these tgt tensors since we are not actually running
# the Decoder but just creating random data that looks like the Decoder output. This cell
# only shows how the tgts are created for inference.
itr_2_tgt_for_inference = torch.stack(tensors=[state.tokens for state in beam_state.running_state], dim=0) 
print("shape of itr_2_tgt_for_inference: ", itr_2_tgt_for_inference.shape)
print("itr_2_tgt_for_inference: \n", itr_2_tgt_for_inference)
print("-" * 150)
itr_2_tgt_mask_for_inference = construct_look_ahead_mask(batch_size=len(beam_state.running_state), seq_len=itr_2_tgt_for_inference.size(1))
print("shape of itr_2_tgt_mask_for_inference: ", itr_2_tgt_mask_for_inference.shape)
print("itr_2_tgt_mask_for_inference: \n", itr_2_tgt_mask_for_inference)

shape of itr_2_tgt_for_inference:  torch.Size([6, 2])
itr_2_tgt_for_inference: 
 tensor([[0, 5],
        [0, 3],
        [0, 7],
        [0, 7],
        [0, 2],
        [0, 9]], dtype=torch.int32)
------------------------------------------------------------------------------------------------------------------------------------------------------
shape of itr_2_tgt_mask_for_inference:  torch.Size([6, 1, 2, 2])
itr_2_tgt_mask_for_inference: 
 tensor([[[[ True, False],
          [ True,  True]]],


        [[[ True, False],
          [ True,  True]]],


        [[[ True, False],
          [ True,  True]]],


        [[[ True, False],
          [ True,  True]]],


        [[[ True, False],
          [ True,  True]]],


        [[[ True, False],
          [ True,  True]]]])


In [23]:
# Create a random tensor of predicted log probabilities. This stands in for the output of the 'TokenPredictor'
# module in the transformer and is used to identify the next token for each sequence in the running_state list.
itr_2_predicted_log_probabilities = torch.log_softmax(torch.rand(itr_2_tgt_for_inference.size(0), itr_2_tgt_for_inference.size(1), TGT_VOCAB_SIZE), dim=-1)
print("shape of itr_2_predicted_log_probabilities: ", itr_2_predicted_log_probabilities.shape)
print("itr_2_predicted_log_probabilities: \n", itr_2_predicted_log_probabilities)
print("-" * 150)
# We only need the probabilities for the last token in the sequence. So, we will only consider the last token
# probabilities for the calculation of the log probabilities of the sequences. Please notice that we get a 
# 2D tensor after slicing the predicted_log_probabilities tensor which is originally a 3D tensor.
itr_2_predicted_log_probabilities = itr_2_predicted_log_probabilities[:, -1, :]
print("shape of itr_2_predicted_log_probabilities: ", itr_2_predicted_log_probabilities.shape)
print("itr_2_predicted_log_probabilities: \n", itr_2_predicted_log_probabilities)
print("-" * 150)

shape of itr_2_predicted_log_probabilities:  torch.Size([6, 2, 10])
itr_2_predicted_log_probabilities: 
 tensor([[[-2.6758, -2.1868, -2.6052, -2.3803, -2.0308, -1.9903, -2.3367,
          -2.5929, -2.5734, -1.9889],
         [-1.9587, -2.6052, -1.9610, -2.4634, -2.6421, -2.6777, -2.7243,
          -2.0137, -2.2799, -2.1379]],

        [[-1.9284, -2.5868, -2.1620, -2.2617, -2.6320, -2.7865, -2.2264,
          -2.2296, -2.4123, -2.1087],
         [-2.8518, -2.3136, -2.1585, -2.4839, -2.4208, -2.6184, -2.1139,
          -1.9845, -2.0909, -2.2941]],

        [[-2.5583, -2.0368, -2.6204, -1.8999, -2.4938, -2.2499, -2.4966,
          -2.5997, -2.4901, -1.9506],
         [-2.6331, -2.6461, -2.3672, -2.1944, -2.5626, -1.9010, -2.3785,
          -2.6150, -1.9536, -2.1378]],

        [[-1.9151, -2.1589, -2.2824, -2.7962, -1.9311, -2.6187, -2.6744,
          -2.1058, -2.5357, -2.4492],
         [-2.3691, -2.0651, -1.9704, -2.3472, -2.6806, -2.9178, -2.9207,
          -2.1407, -2.1183, -2.0379]],


In [24]:
# A given source sentence can have multiple potential target sequences if beam_width > 1. Here, we find the number
# of target sequences currently being used (in running_state) to predict the next token for each of the source
# sequences. We have 3 tgt sequences for the source sequence at index 0 and 3 tgt sequences for the source
# sequence at index 1. So, the counts output should be tensor([3, 3]).
_, itr_2_tgt_group_counts = torch.unique(input=torch.tensor(data=[seq_state.index for seq_state in beam_state.running_state], dtype=torch.int16), return_counts=True)
print("itr_2_tgt_group_counts shape: ", itr_2_tgt_group_counts.shape)
print("itr_2_tgt_group_counts: ", itr_2_tgt_group_counts)

itr_2_tgt_group_counts shape:  torch.Size([2])
itr_2_tgt_group_counts:  tensor([3, 3])
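
The group counts are more interesting when the groups are uneven. The sketch below uses made-up source indices in which the source at index 1 has already finished and no longer appears. It also walks the groups with a running start offset, which relies on the running sequences being ordered by source index, as they are in this notebook.

```python
import torch

# Hypothetical source indices of the running sequences. Source 1 has already
# finished, so it no longer appears; the entries are grouped by source index.
indices = torch.tensor([0, 0, 2, 2, 2])

# 'values' holds the sorted distinct indices, 'counts' the size of each group.
values, counts = torch.unique(indices, return_counts=True)
print(values.tolist(), counts.tolist())  # [0, 2] [2, 3]

# Walk the groups using a running start offset.
groups = []
start = 0
for size in counts.tolist():
    groups.append(indices[start:start + size].tolist())
    start += size
print(groups)  # [[0, 0], [2, 2, 2]]
```

The counts alone do not say which source a group belongs to. That information comes from `values` here, or from the `index` field of the `SequenceState` objects in the notebook.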


In [25]:
# Holds the running_state i.e., the SequenceState for the incomplete token sequences at the end of iteration 2.
itr_2_new_running_state: List[SequenceState] = []

In [26]:
# Handling cases 2, 3 and 4:
#
# For each group of tgt sequences that correspond to the same src sequence, we need to find the number of 
# running sequences and the complete sequences.
#
# Index to keep track of the start of the group of tgt sequences for a single source sequence.
start_index = 0
# Iterate on each group independently to process the group.
for tgt_group_size in itr_2_tgt_group_counts:
    print("start_index: ", start_index)
    print("tgt_group_size: ", tgt_group_size)
    print("-" * 150)
    # Extract the group of SequenceState objects corresponding to a single source sequence.
    old_tgt_state_group = beam_state.running_state[start_index: start_index + tgt_group_size.item()]
    # Extract the group of log probabilities for the corresponding tgt sequence token predictions and split the
    # tensor into a tuple of 1D tensors to be used with the 'zip' function below.
    new_tgt_prob_group = torch.unbind(itr_2_predicted_log_probabilities[start_index: start_index + tgt_group_size.item()], dim=0)
    print("old_tgt_state_group: ", old_tgt_state_group)
    print("new_tgt_prob_group: ", new_tgt_prob_group)
    print("-" * 150)
    # Holds all the new sequences formed after appending the token for the previous group of tgt sequences.
    running_beams: List[SequenceState] = []
    # Holds all the sequences for which the <eos> token has been predicted.
    complete_beams: List[SequenceState] = []
    # Iterate on the old SequenceState object and the corresponding next prediction probabilities for this 
    # tgt sequence.
    for old_seq_state, new_tgt_pred_probs in zip(old_tgt_state_group, new_tgt_prob_group):
        # Index of the source sentence for which the translations are being calculated.
        src_seq_idx = old_seq_state.index
        # Extract the top few tokens to be considered as the next token via beam search.
        top_probs, top_tokens = new_tgt_pred_probs.topk(k=BEAM_WIDTH, dim=-1)
        print("top_probs: ", top_probs)
        print("top_tokens: ", top_tokens)
        print("-" * 150)
        # Iterate on each predicted token, create the sequence with this token appended and calculate
        # the probability of the new sequence (with token appended).
        for pred_prob, pred_token in zip(top_probs, top_tokens):
            print("pred_prob: ", pred_prob)
            print("pred_token: ", pred_token)
            # Append the newly predicted token to the existing sequence of tokens.
            updated_token_seq = torch.cat(tensors=[old_seq_state.tokens, pred_token.unsqueeze(0).to(torch.int32)])
            print("updated_token_seq: ", updated_token_seq)
            # The log probability of the extended sequence is the log probability of the old sequence added
            # to the log probability associated with the newly predicted token.
            updated_seq_prob = old_seq_state.log_prob + pred_prob.item()
            print("updated_seq_prob: ", updated_seq_prob)
            # Creates a new SequenceState object associated with the extended tgt sequence.
            new_state = SequenceState(index=src_seq_idx, tokens=updated_token_seq, log_prob=updated_seq_prob)
            # THIS CONDITIONAL IF-ELSE BLOCK HANDLES CASE 2.
            if pred_token.item() == EOS_TOKEN_ID:
                # If the newly predicted token is <eos>, then this tgt sequence is complete and we add it
                # to the list of complete sequences for this specific src sequence.
                complete_beams.append(new_state)
            else:
                # If the newly predicted token is not <eos>, then this tgt sequence is not complete and
                # can be extended by predicting further tokens.
                running_beams.append(new_state)
            print("-" * 150)
    # If any of the extended tgt sequences ended in an <eos> token, those sequences have been removed from
    # the beam search, so we update the complete state for the corresponding src sequence accordingly.
    # THIS CONDITIONAL IF BLOCK HANDLES CASE 3.
    if len(complete_beams) > 0:
        # Index of the source sentence for which the translations are being calculated.
        src_seq_idx = complete_beams[0].index
        # Sort the completed sequences according to their probabilities in descending order.
        complete_beams.sort(key=lambda seq_state: seq_state.log_prob, reverse=True)
        print("complete_beams: ", complete_beams)
        if src_seq_idx in beam_state.complete_state:
            # If we found complete sequences before for this specific source sequence, we only store
            # the complete sequence for which the probability of occurrence is the highest.
            if beam_state.complete_state[src_seq_idx].log_prob < complete_beams[0].log_prob:
                beam_state.complete_state[src_seq_idx] = complete_beams[0]
        else:
            # If this is the first complete sequence we found, we just store this specific sequence.
            beam_state.complete_state[src_seq_idx] = complete_beams[0]        
        print("beam_state.complete_state: ", beam_state.complete_state)
        print("-" * 150)   
    # THIS CONDITIONAL IF BLOCK HANDLES CASE 4.
    if len(running_beams) > 0:              
        # Sort the running sequences according to their probabilities in descending order.
        running_beams.sort(key=lambda seq_state: seq_state.log_prob, reverse=True)
        print("running_beams: ", running_beams)
        # Add the running sequences for further predictions. Only add the first 'beam_width' number
        # of sequences. Slicing returns the whole list if it holds fewer than 'beam_width' sequences.
        itr_2_new_running_state.extend(running_beams[:BEAM_WIDTH])
        print("new_running_state: ", itr_2_new_running_state)
        print("-" * 150)
    # Update the start_index to the start of the next group of tgt sequences.
    start_index += tgt_group_size.item()
    print("complete_beams: ", complete_beams)
    print("running_beams: ", running_beams)
    print("-" * 150)

beam_state.running_state = itr_2_new_running_state

start_index:  0
tgt_group_size:  tensor(3)
------------------------------------------------------------------------------------------------------------------------------------------------------
old_tgt_state_group:  [SequenceState(index=0, tokens=tensor([0, 5], dtype=torch.int32), log_prob=-1.9349650144577026), SequenceState(index=0, tokens=tensor([0, 3], dtype=torch.int32), log_prob=-1.951758861541748), SequenceState(index=0, tokens=tensor([0, 7], dtype=torch.int32), log_prob=-2.1318984031677246)]
new_tgt_prob_group:  (tensor([-1.9587, -2.6052, -1.9610, -2.4634, -2.6421, -2.6777, -2.7243, -2.0137,
        -2.2799, -2.1379]), tensor([-2.8518, -2.3136, -2.1585, -2.4839, -2.4208, -2.6184, -2.1139, -1.9845,
        -2.0909, -2.2941]), tensor([-2.6331, -2.6461, -2.3672, -2.1944, -2.5626, -1.9010, -2.3785, -2.6150,
        -1.9536, -2.1378]))
------------------------------------------------------------------------------------------------------------------------------------------------------
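The expansion step above can be replayed on plain Python lists. This is a minimal sketch, not the notebook's implementation: `expand_group` is a hypothetical helper, sequences are `(tokens, log_prob)` tuples instead of `SequenceState` objects, and the values `BEAM_WIDTH = 3` and `EOS_TOKEN_ID = 3` are assumptions made for the illustration. The log probabilities are copied from the first sequence in the printed output above.

```python
BEAM_WIDTH = 3      # assumed value for this sketch
EOS_TOKEN_ID = 3    # assumed value for this sketch


def expand_group(old_states, next_log_probs, beam_width=BEAM_WIDTH, eos_id=EOS_TOKEN_ID):
    """Extend every (tokens, log_prob) pair with its top 'beam_width' tokens.

    Returns (complete, running), both sorted by log probability in descending
    order. Only the first 'beam_width' running candidates are kept.
    """
    complete, running = [], []
    for (tokens, log_prob), token_log_probs in zip(old_states, next_log_probs):
        # Token ids with the largest log probabilities (the top-k step).
        top_tokens = sorted(
            range(len(token_log_probs)),
            key=lambda token: token_log_probs[token],
            reverse=True,
        )[:beam_width]
        for token in top_tokens:
            # Log probabilities add up when a token is appended to a sequence.
            candidate = (tokens + [token], log_prob + token_log_probs[token])
            if token == eos_id:
                complete.append(candidate)
            else:
                running.append(candidate)
    complete.sort(key=lambda candidate: candidate[1], reverse=True)
    running.sort(key=lambda candidate: candidate[1], reverse=True)
    return complete, running[:beam_width]


old_states = [([0, 5], -1.9349650144577026)]
next_log_probs = [[-1.9587, -2.6052, -1.9610, -2.4634, -2.6421,
                   -2.6777, -2.7243, -2.0137, -2.2799, -2.1379]]
complete, running = expand_group(old_states, next_log_probs)
for tokens, log_prob in running:
    print(tokens, round(log_prob, 4))
```

Because none of the three best tokens is the assumed `<eos>` id, all three candidates stay in the running list and the complete list remains empty.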

In [27]:
# The above iterations are repeated until the <eos> token has been predicted for every sequence in the
# running_state list or 'TGT_SEQ_LIMIT' tokens have been predicted for each sequence in the running_state
# list.
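The two stopping conditions can be sketched as a single loop. This is an illustrative toy under stated assumptions, not the notebook's implementation: `toy_next_log_probs` is a hypothetical stand-in for the translation model, `run_search` is a hypothetical helper, the token ids (`<sos>` = 0, `<eos>` = 3) are assumed, and only one source sequence is decoded.

```python
import math

SOS_TOKEN_ID, EOS_TOKEN_ID = 0, 3   # assumed ids for this sketch
VOCAB_SIZE = 6                      # assumed toy vocabulary size


def toy_next_log_probs(tokens):
    """Hypothetical model: strongly favours <eos> once a sequence has 3 tokens."""
    weights = [1.0] * VOCAB_SIZE
    weights[EOS_TOKEN_ID] = 10.0 if len(tokens) >= 3 else 0.1
    total = sum(weights)
    return [math.log(weight / total) for weight in weights]


def run_search(beam_width, tgt_seq_limit):
    running = [([SOS_TOKEN_ID], 0.0)]
    best_complete = None
    steps = 0
    # Stop when no running sequence is left or the token limit is reached.
    while running and steps < tgt_seq_limit:
        new_running = []
        for tokens, log_prob in running:
            log_probs = toy_next_log_probs(tokens)
            top_tokens = sorted(range(VOCAB_SIZE), key=lambda t: log_probs[t], reverse=True)[:beam_width]
            for token in top_tokens:
                candidate = (tokens + [token], log_prob + log_probs[token])
                if token == EOS_TOKEN_ID:
                    # Keep only the most probable complete sequence.
                    if best_complete is None or candidate[1] > best_complete[1]:
                        best_complete = candidate
                else:
                    new_running.append(candidate)
        new_running.sort(key=lambda candidate: candidate[1], reverse=True)
        running = new_running[:beam_width]
        steps += 1
    return best_complete, running, steps


# Ends early: with a beam width of 1 the only running sequence predicts <eos>.
best, running, steps = run_search(beam_width=1, tgt_seq_limit=10)
print(best[0], running, steps)

# Ends at the limit: no <eos> was predicted within 2 steps.
best_limited, running_limited, steps_limited = run_search(beam_width=1, tgt_seq_limit=2)
print(best_limited, running_limited[0][0], steps_limited)
```

With a larger beam width in this toy, at least one non-`<eos>` candidate survives every step, so the loop only stops at the token limit. That is the situation the `TGT_SEQ_LIMIT` check exists for.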

#### Preparing the final predictions

In [28]:
# Once the <eos> token has been predicted for every running sequence or 'TGT_SEQ_LIMIT' iterations of token
# prediction are complete, the beam search algorithm ends. For each source sequence, complete_state holds the
# tgt sequence ending in the <eos> token with the maximum probability. If the <eos> token was never predicted
# for a source sequence, we take its incomplete sequence with the maximum probability from the running_state
# list and add it to complete_state.

In [29]:
# Let's manually create the complete_state and running_state for the beam_state object to show how the beam search
# algorithm ends.
beam_state_end = SearchState()

In [30]:
# We add the complete sequences for the source sequences 0, 2 and 4.
for idx in range(0, 6, 2):
    # Creating a random token sequence for the source sequence idx. Token sequence needs to start with <sos> (0) token
    # and end with <eos> (3) token.
    random_token_sequence = [0] + random.sample(range(4, 10), 4) + [3]
    beam_state_end.complete_state[idx] = SequenceState(index=idx, tokens=torch.tensor(data=random_token_sequence, dtype=torch.int32), log_prob=-random.uniform(0, 5))
    
print(beam_state_end.complete_state)

{0: SequenceState(index=0, tokens=tensor([0, 4, 5, 8, 7, 3], dtype=torch.int32), log_prob=-3.1855079746338193), 2: SequenceState(index=2, tokens=tensor([0, 8, 6, 4, 7, 3], dtype=torch.int32), log_prob=-3.9063676123287365), 4: SequenceState(index=4, tokens=tensor([0, 4, 6, 8, 7, 3], dtype=torch.int32), log_prob=-0.7895624838182663)}


In [31]:
# We add the running sequences for the source sequences 1, 3 and 5.
for idx in range(1, 6, 2):
    # Create three random token sequences for the source sequence idx. Each token sequence needs to start with
    # the <sos> (0) token and end with a token other than the <eos> (3) token.
    for _ in range(3):
        random_token_sequence = [0] + random.sample(list(range(4, 10)) + list(range(4, 10)), 7)
        beam_state_end.running_state.append(SequenceState(index=idx, tokens=torch.tensor(data=random_token_sequence, dtype=torch.int32), log_prob=-random.uniform(0, 5)))

# We also add a running tgt sequence for the source sequence 0. This tgt sequence is incomplete and has a higher log 
# probability than the complete sequence for the source sequence 0. However, this sequence should still not be used as a 
# final translation for the source sequence 0 since it is incomplete and we have other complete tgt sequences for 
# source sequence 0.
beam_state_end.running_state.append(SequenceState(index=0, tokens=torch.tensor(data=[0, 4, 5, 6, 7, 8, 9], dtype=torch.int32), log_prob=-0.001))
print(beam_state_end.running_state)

[SequenceState(index=1, tokens=tensor([0, 6, 7, 5, 4, 8, 9, 7], dtype=torch.int32), log_prob=-0.18177964969088856), SequenceState(index=1, tokens=tensor([0, 6, 7, 5, 8, 6, 4, 9], dtype=torch.int32), log_prob=-3.11712461127309), SequenceState(index=1, tokens=tensor([0, 6, 8, 9, 8, 4, 5, 7], dtype=torch.int32), log_prob=-1.617822470253099), SequenceState(index=3, tokens=tensor([0, 8, 9, 7, 4, 5, 6, 5], dtype=torch.int32), log_prob=-0.07893848570126316), SequenceState(index=3, tokens=tensor([0, 6, 8, 5, 9, 6, 9, 7], dtype=torch.int32), log_prob=-2.4689468108769446), SequenceState(index=3, tokens=tensor([0, 6, 9, 6, 9, 5, 5, 7], dtype=torch.int32), log_prob=-3.562795768979456), SequenceState(index=5, tokens=tensor([0, 7, 6, 8, 8, 9, 7, 5], dtype=torch.int32), log_prob=-4.363238021561321), SequenceState(index=5, tokens=tensor([0, 9, 8, 8, 9, 4, 5, 6], dtype=torch.int32), log_prob=-3.1369504026578054), SequenceState(index=5, tokens=tensor([0, 8, 8, 4, 7, 5, 6, 9], dtype=torch.int32), log_pro

In [32]:
# Now, let's go through the running_state to find the source sequences for which the <eos> token has not been predicted
# in any of the corresponding target sequences. For each such source sequence, we find the tgt sequence with the
# maximum probability in the running_state and add it to complete_state.
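The selection described above can also be sketched with a dictionary keyed by source index. This is an alternative formulation under stated assumptions, not the notebook's implementation: `SeqState` is a hypothetical stand-in for `SequenceState` that holds tokens in a plain list, and `fill_missing` is a hypothetical helper. Unlike a single pass that tracks one candidate at a time, this version does not rely on the running sequences of a source sequence being adjacent in the list.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class SeqState:
    """Hypothetical stand-in for the notebook's SequenceState."""
    index: int
    tokens: List[int]
    log_prob: float


def fill_missing(complete_state: Dict[int, SeqState], running_state: List[SeqState]) -> Dict[int, SeqState]:
    """Return complete_state extended with the best running sequence of every
    source index that has no complete sequence yet."""
    best_running: Dict[int, SeqState] = {}
    for state in running_state:
        # A complete sequence always wins over an incomplete one.
        if state.index in complete_state:
            continue
        current = best_running.get(state.index)
        if current is None or state.log_prob > current.log_prob:
            best_running[state.index] = state
    return {**complete_state, **best_running}


complete_state = {0: SeqState(0, [0, 4, 5, 8, 7, 3], -3.19)}
running_state = [
    SeqState(1, [0, 6, 7], -1.62),
    SeqState(3, [0, 8, 9], -2.47),
    SeqState(1, [0, 6, 8], -0.18),   # same source index as the first entry, not adjacent
    SeqState(3, [0, 6, 9], -0.08),
    SeqState(0, [0, 4, 5], -0.001),  # ignored: source 0 already has a complete sequence
]
final_state = fill_missing(complete_state, running_state)
for index in sorted(final_state):
    print(index, final_state[index])
```

The last running entry has the highest log probability of all, yet it is skipped because source sequence 0 already has a translation that ends in `<eos>`.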

In [33]:
dummy_tensor = torch.tensor(data=[SOS_TOKEN_ID], dtype=torch.int32)
print(dummy_tensor)

tensor([0], dtype=torch.int32)


In [34]:
# Holds the sequence with the maximum probability for a given source sequence.
max_prob_state = SequenceState(index=-1, tokens=dummy_tensor, log_prob=float('-inf'))
for state in beam_state_end.running_state:
    print("state: ", state)
    print("-" * 150)
    # If the source sequence of the current running sequence is not already in the complete sequences, 
    # that means we haven't found a complete sequence for this (identified by the index) source sequence 
    # yet. So, we hold this sequence in the max_prob_state as a potential complete sequence for the
    # source sequence identified by the index.
    if state.index not in beam_state_end.complete_state:
        # If we haven't found any complete sequence yet, we just store this sequence as the potential
        # complete sequence in the max_prob_state.
        if max_prob_state.index == -1:
            max_prob_state = state
        elif max_prob_state.index != state.index:
            # If all the running sequences have been looked at for the previous source sequence (identified 
            # by max_prob_state.index), then we just add the potential complete sequence to the list of
            # complete sequences.
            beam_state_end.complete_state[max_prob_state.index] = max_prob_state
            # We also store the new running sequence as a new potential running sequence for the new src
            # sequence (identified by the index state.index).
            max_prob_state = state
        else:
            # If a new potential complete sequence is found, we pick the one with maximum probability and
            # store it as the potential complete sequence.
            if max_prob_state.log_prob < state.log_prob:
                max_prob_state = state
        print("max_prob_state: ", max_prob_state)
# Add the left out potential sequence to the list of complete sequences.
if max_prob_state.index != -1:
    beam_state_end.complete_state[max_prob_state.index] = max_prob_state
print("-" * 150)
print("beam_state_end.complete_state: ", beam_state_end.complete_state)

state:  SequenceState(index=1, tokens=tensor([0, 6, 7, 5, 4, 8, 9, 7], dtype=torch.int32), log_prob=-0.18177964969088856)
------------------------------------------------------------------------------------------------------------------------------------------------------
max_prob_state:  SequenceState(index=1, tokens=tensor([0, 6, 7, 5, 4, 8, 9, 7], dtype=torch.int32), log_prob=-0.18177964969088856)
state:  SequenceState(index=1, tokens=tensor([0, 6, 7, 5, 8, 6, 4, 9], dtype=torch.int32), log_prob=-3.11712461127309)
------------------------------------------------------------------------------------------------------------------------------------------------------
max_prob_state:  SequenceState(index=1, tokens=tensor([0, 6, 7, 5, 4, 8, 9, 7], dtype=torch.int32), log_prob=-0.18177964969088856)
state:  SequenceState(index=1, tokens=tensor([0, 6, 8, 9, 8, 4, 5, 7], dtype=torch.int32), log_prob=-1.617822470253099)
---------------------------------------------------------------------------