## Beam Search

In [1]:
import torch 
import torch.nn.functional as F

In [2]:
def beam_search(decoder, hidden, context, beam_width=3, max_length=10,
                start_token=0, eos_token=None):
    """Decode a token sequence with beam search and return the best one.

    At every step each live hypothesis is extended with its ``beam_width``
    most likely next tokens; the ``beam_width`` hypotheses with the highest
    cumulative log-probability are kept for the next step.

    Args:
        decoder: Callable invoked as ``decoder(decoder_input, hidden, context)``
            that returns ``(output, new_hidden)``. ``decoder_input`` is a
            one-element tensor holding the previous token id. ``output`` is
            normalised with ``log_softmax`` over ``dim=1`` and only row 0 is
            read, i.e. it is treated as ``(1, vocab_size)`` scores.
        hidden: Initial decoder state; passed to ``decoder`` untouched.
        context: Extra decoder argument passed through unchanged on each call.
        beam_width: Number of hypotheses kept after each step (must be >= 1).
        max_length: Maximum number of tokens to generate.
        start_token: Token id fed to the decoder on the first step.
        eos_token: Optional end-of-sequence token id. When given, a hypothesis
            ending in this token is no longer extended, and decoding stops
            early once every kept hypothesis has ended.

    Returns:
        The list of token ids of the highest-scoring hypothesis. The start
        token is not included; the ``eos_token`` is included if generated.

    Raises:
        ValueError: If ``beam_width`` is smaller than 1.
    """
    if beam_width < 1:
        raise ValueError("beam_width must be a positive integer")

    # Each beam is (token_sequence, cumulative_log_prob, hidden_state).
    # Log-probabilities are summed, so the neutral starting score is 0.0.
    beams = [([], 0.0, hidden)]

    for _ in range(max_length):
        candidates = []
        for seq, score, state in beams:
            # A finished hypothesis competes as-is and is not extended.
            if eos_token is not None and seq and seq[-1] == eos_token:
                candidates.append((seq, score, state))
                continue

            decoder_input = torch.tensor([seq[-1] if seq else start_token])
            output, next_state = decoder(decoder_input, state, context)
            log_probs = F.log_softmax(output, dim=1)

            # topk raises if k exceeds the vocabulary size, so clamp it.
            k = min(beam_width, log_probs.size(1))
            top_log_probs, top_indexes = torch.topk(log_probs, k)

            for log_prob, token in zip(top_log_probs[0].tolist(),
                                       top_indexes[0].tolist()):
                candidates.append((seq + [token], score + log_prob, next_state))

        # Higher cumulative log-probability means a more likely sequence.
        beams = sorted(candidates, key=lambda beam: beam[1], reverse=True)[:beam_width]

        if eos_token is not None and all(
            seq and seq[-1] == eos_token for seq, _, _ in beams
        ):
            break

    return beams[0][0]  # return the best sequence