In [1]:
import heapq
import hyperparams as hp
from datasets import load_dataset
from models import Encoder, Decoder, Seq2Seq
import torch.nn.functional as F
import torch

# Seed PyTorch's RNG before any weights are created: the encoder and decoder
# below are randomly initialised, so without a fixed seed every
# Restart & Run All would produce different scores and sequences.
SEED = 42
torch.manual_seed(SEED)

# batch_size=1 because the decoding code in this notebook handles a single
# sentence at a time.
# NOTE(review): device=-1 presumably selects the CPU (legacy torchtext
# convention) — confirm against load_dataset.
train_iter, val_iter, test_iter, DE, EN = load_dataset(batch_size=1, device=-1)

# Encoder and decoder share every hyperparameter; only the vocabulary they are
# sized for differs (DE on the source side, EN on the target side).
encoder = Encoder(
    source_vocab_size=len(DE.vocab),
    embed_dim=hp.embed_dim,
    hidden_dim=hp.hidden_dim,
    n_layers=hp.n_layers,
    dropout=hp.dropout,
)
decoder = Decoder(
    target_vocab_size=len(EN.vocab),
    embed_dim=hp.embed_dim,
    hidden_dim=hp.hidden_dim,
    n_layers=hp.n_layers,
    dropout=hp.dropout,
)

In [2]:
# Draw a single batch (one sentence pair, given batch_size=1) to experiment on.
batch = next(iter(train_iter))
# First position of the target tensor, kept as a length-1 slice so it can be
# fed to the decoder as-is. Presumably this is the start-of-sequence token —
# TODO confirm against the dataset's field definition.
start_token = batch.trg[0:1]

In [3]:
# Put the encoder in inference mode *before* the forward pass: it is built with
# dropout, and in training mode dropout would randomly perturb encoder_out and
# encoder_hidden on every run.
# (Assumes Encoder is a torch.nn.Module like Decoder — confirm in models.py.)
encoder.eval()
encoder_out, encoder_hidden = encoder(batch.src)
# The decoder is initialised from the last `decoder.n_layers` entries of the
# encoder's hidden state.
decoder_hidden = encoder_hidden[-decoder.n_layers:]
# Kept as the final expression so the cell still displays the decoder's layout.
decoder.eval()

Decoder(
  (embed): Embedding(10839, 256, padding_idx=1)
  (attention): LuongAttention(
    (W): Linear(in_features=512, out_features=512, bias=False)
  )
  (gru): GRU(768, 512, num_layers=2, dropout=0.2)
  (out): Linear(in_features=1024, out_features=10839, bias=True)
)

In [4]:
class _BeamNode(tuple):
    """A ``(score, sequence, hidden_state)`` tuple that is ordered by score only.

    Plain tuples compare element by element, so two entries with equal scores
    would fall through to comparing their sequences / hidden states. Those are
    tensors (or arbitrary objects) whose comparison raises, which makes the
    heap crash on ties. Ordering on the score alone avoids that while the entry
    still unpacks like a normal 3-tuple.
    """

    __slots__ = ()

    def __lt__(self, other):
        return bool(self[0] < other[0])

    def __le__(self, other):
        return bool(self[0] <= other[0])

    def __gt__(self, other):
        return bool(self[0] > other[0])

    def __ge__(self, other):
        return bool(self[0] >= other[0])


class Beam:
    """Bounded collection holding the ``beam_width`` highest-scoring hypotheses.

    Entries are stored in a min-heap (``self.heap``), so the lowest-scoring
    hypothesis sits at index 0 and is the one evicted when the beam overflows.
    Iteration and indexing expose the heap's internal order, which is *not*
    sorted by score.
    """

    def __init__(self, beam_width):
        """Create an empty beam that keeps at most ``beam_width`` entries."""
        self.heap = list()
        self.beam_width = beam_width

    def add(self, score, sequence, hidden_state):
        """Insert a hypothesis, evicting the lowest-scoring one if over capacity.

        Args:
            score: Value the hypotheses are ranked by (higher is better).
            sequence: Payload carried with the score; never compared.
            hidden_state: Payload carried with the score; never compared.
        """
        heapq.heappush(self.heap, _BeamNode((score, sequence, hidden_state)))
        if len(self.heap) > self.beam_width:
            heapq.heappop(self.heap)

    def __iter__(self):
        """Iterate over the stored ``(score, sequence, hidden_state)`` entries."""
        return iter(self.heap)

    def __len__(self):
        """Number of hypotheses currently held."""
        return len(self.heap)

    def __getitem__(self, idx):
        """Return the entry at heap position ``idx`` (index 0 is the lowest score)."""
        return self.heap[idx]


def beamsearch(decoder_topk, beam_size=2, maxlen=20):
    """Decode the current notebook ``batch`` with beam search.

    The search starts from ``batch.trg[0:1]`` with ``decoder_hidden`` (both
    notebook-level variables) and extends every hypothesis for exactly
    ``maxlen`` steps. There is no end-of-sequence handling, so all returned
    sequences have ``maxlen + 1`` entries.

    Args:
        decoder_topk: Callable ``(last_word, hidden_state, k)`` returning
            ``(probs, words, new_hidden)`` for the ``k`` best continuations.
            If it is not callable, the notebook-level ``get_next`` is used.
        beam_size: Number of hypotheses kept per step and number of
            continuations requested for each of them.
        maxlen: Number of decoding steps.

    Returns:
        ``(best_score, best_sequence)``: the highest-scoring hypothesis, whose
        score is the product of its per-step probabilities (starting at 1.0).

    Raises:
        ValueError: If ``beam_size`` is smaller than 1.
    """
    if beam_size < 1:
        raise ValueError("beam_size must be at least 1")

    expand = decoder_topk if callable(decoder_topk) else get_next

    beam = Beam(beam_size)  # starting layer in the search tree
    beam.add(1.0, batch.trg[0:1], decoder_hidden)  # root hypothesis
    for _ in range(maxlen):
        next_beam = Beam(beam_size)
        # Each layer holds at most beam_size hypotheses.
        for score, sequence, hidden_state in beam:
            next_probs, next_words, next_hidden = expand(sequence[-1:],
                                                         hidden_state,
                                                         beam_size)
            for i in range(beam_size):
                # Build every candidate from the *parent* score and sequence.
                # Reassigning `score` / `sequence` here would make candidate i
                # inherit the tokens and probabilities of candidates 0..i-1.
                candidate_score = score * next_probs[i]
                candidate_sequence = torch.cat([sequence, next_words[i]])
                next_beam.add(candidate_score, candidate_sequence, next_hidden)

        # Move down one layer (to the next word in the sequence).
        beam = next_beam

    # Rank on the score only, so ties never fall through to comparing tensors.
    best_score, best_sequence, _ = max(beam, key=lambda node: node[0])
    return best_score, best_sequence

In [5]:
def get_next(last_word, hidden_state, k=3):
    """Return the ``k`` most likely next words for a partial sequence.

    Runs one step of the notebook-level ``decoder`` against ``encoder_out``.

    Args:
        last_word: Last item of the sequence decoded so far.
        hidden_state: Decoder hidden state that produced that sequence.
        k: Number of candidates to return (default 3).

    Returns:
        A ``(probabilities, next_words, hidden_state)`` tuple:
        ``probabilities`` holds the ``k`` softmax probabilities as a 1-D
        tensor, ``next_words`` holds the matching word indices reshaped to
        ``(k, 1, 1)``, and ``hidden_state`` is the decoder's updated state.
    """
    # The decoder's third return value is not needed here.
    output, hidden_state, _ = decoder(last_word, encoder_out, hidden_state)
    next_word_probs = F.softmax(output, dim=2)
    probabilities, next_words = next_word_probs.topk(k)
    # view(k) rather than squeeze(): squeeze() would collapse the k == 1 case
    # to a 0-dim tensor, which callers cannot index with [i].
    return probabilities.view(k).data, next_words.view(k, 1, 1), hidden_state

In [7]:
# get_next is the top-k expansion function defined above. The previously passed
# name `decoder_topk` is not defined anywhere in this notebook and raised a
# NameError on a fresh kernel.
best_score, best_seq = beamsearch(get_next, beam_size=3)
best_score, best_seq

## As a decoding helper

In [23]:
from beamsearch import BeamHelper

In [28]:
# Wrap the encoder/decoder pair built above in a single Seq2Seq model.
seq2seq = Seq2Seq(encoder, decoder)
# Decoding helper imported from the beamsearch module, configured like the
# hand-rolled search above (beam of 3, 20 steps).
beam_helper = BeamHelper(beam_size=3, maxlen=20)

In [29]:
# Time one call of the model on the sampled source sentence with the current beam_helper.
%time seq2seq(batch.src, beam_helper)

CPU times: user 1.3 s, sys: 0 ns, total: 1.3 s
Wall time: 327 ms


(3.087069415258347e-79, Variable containing:
      2
    662
   6344
   2850
   4007
   2850
   6170
   6170
   6170
   7784
   8470
   1751
   1751
   3824
    487
  10693
    959
   2850
   2850
   3897
   3897
 [torch.LongTensor of size 21x1])

## Increased beam size comes at a cost
With `beam_size` equal to the output vocabulary size, beam search becomes breadth-first search; with `beam_size == 1` it reduces to greedy search.

In [30]:
# Same decode with a wider beam (5 instead of 3) to see how the run time changes.
beam_helper = BeamHelper(beam_size=5, maxlen=20)
%time seq2seq(batch.src, beam_helper)

CPU times: user 1.97 s, sys: 10.6 ms, total: 1.98 s
Wall time: 499 ms


(3.087069415258347e-79, Variable containing:
      2
    662
   6344
   2850
   4007
   2850
   6170
   6170
   6170
   7784
   8470
   1751
   1751
   3824
    487
  10693
    959
   2850
   2850
   3897
   3897
 [torch.LongTensor of size 21x1])

In [31]:
# Same decode with a beam of 20 to see how the run time changes.
beam_helper = BeamHelper(beam_size=20, maxlen=20)
%time seq2seq(batch.src, beam_helper)

CPU times: user 7.58 s, sys: 51.2 ms, total: 7.63 s
Wall time: 1.92 s


(3.087069415258347e-79, Variable containing:
      2
    662
   6344
   2850
   4007
   2850
   6170
   6170
   6170
   7784
   8470
   1751
   1751
   3824
    487
  10693
    959
   2850
   2850
   3897
   3897
 [torch.LongTensor of size 21x1])