## Attention Seq2Seq with Batching + Stacking + Residual Links

* **Task**: toy "translation" task --- translating a list of letters (from A to H) to the next-letter-list (e.g. ['A', 'B', 'C'] translates as ['B', 'C', 'D']. 
* **Type**: Luong et al. (2016). No bidirection, but has stacking and residual links (following Prakash et al. (2016)). Clear-to-the-boot step-by-step demo.
* **PyTorch Version**: 0.3.1
* **Rant**: showy people on Github write convoluted tutorial code (although efficient, sophisticated and all). Doesn't help for beginners at all! This tutorial tells you all you need to know!!

In [1]:
from __future__ import division

import unicodedata
import string
import re
import random
import time
import math
import numpy as np

from io import open

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F

### Data Prep

In [2]:
class Indexer:
    """Token-Index mapping."""
    
    def __init__(self, name):
        """
        Args:
            name: name of the indexer.
        """
        self.name = name
        self.word2index = {"SOS": 0, "EOS": 1} # str -> int
        self.index2word = {0: "SOS", 1: "EOS"}
        self.word2count = {"SOS": 0, "EOS": 0} # str -> int
        self.nWords = 2  # Count SOS and EOS
    
    def add_sentence(self, sentence):
        """Add a sentence to the dictionary.
        
        Args:
            sentence: a list of tokens (in string).
        """
        for word in sentence:
            self.add_word(word)

    def add_word(self, word):
        """Add a word to the dictionary.
        
        Args:
            word: a token (in string).
        """
        if word not in self.word2index:
            self.word2index[word] = self.nWords
            self.word2count[word] = 1
            self.index2word[self.nWords] = word
            self.nWords += 1
        else:
            self.word2count[word] += 1 
            
    def get_index(self, word):
        """Word->Index lookup.
        
        Args:
            word: a token (string).
        Returns:
            The index of the word.
        """
        return self.word2index[word] if word in self.word2index else -1
    
    def get_word(self, index):
        """Index->Word lookup.
        
        Args:
            index: index of a token.
        Returns:
            The token under the index. -1 if the index is out of bound.
        """
        return self.index2word[index] if index<self.nWords else -1
    
    def get_sentence_index(self, sentence):
        """Words->Indexs lookup.
        
        Args:
            sentence: a list of token (string).
        Returns:
            A list of indices.
        """
        return [self.get_index(word) for word in sentence]
    
    def get_sentence_word(self, indexSentence):
        """Indexs->Words lookup.
        
        Args:
            indexSentence: a list of indices.
        Returns:
            A list of tokens (string).
        """
        return [self.get_word(index) for index in indexSentence]

In [3]:
# Toy data generation
#   vocab -> A to I
#   length -> 3 to 8
#   task -> translate for the next letter (e.g. A -> B)

VOCAB = [chr(i) for i in range(65,74)] # 'A' -> 'I'
FROM_LEN, TO_LEN = 3, 8
MAX_LENGTH = TO_LEN + 2
SOS, EOS = 'SOS', 'EOS'
INDEXER = Indexer('LetterTranslator')
DATA_SIZE = 3000

def translate_word(word):
    """Find the next letter.
    
    Args:
        word: a letter word (e.g. 'A').
    Returns:
        The next letter to word.
    """
    return VOCAB[VOCAB.index(word)+1]

def translate_sent(sent):
    """Find the next-letter translation of a sentence.
    
    Args:
        sent: a list of letter words.
    Returns:
        The next letters.
    """
    return [translate_word(word) for word in sent]

def generate_pair():
    """Randomly generate a pair of sentences (arg1 translates to arg2).
    
    Returns:
        randInput: a list of letter words.
        randTarget: a list of translation letter words of randInput.
        randInputLen, randTargetLen: lengths of the lists above.
    """
    randInput = list(np.random.choice(VOCAB[:-1], size=random.randint(FROM_LEN,TO_LEN)))
    randTarget = translate_sent(randInput)
    randInputLen, randTargetLen = len(randInput), len(randTarget)
    return randInput, randTarget+[str('EOS')], \
           randInputLen, randTargetLen+1
        # str(): default is utf-8

def generate_data():
    """Randomly generate a set of pairs of sentences (arg1 translates to arg2).
    
    Returns:
        pairs: a pair of lists, where each is a list of token indices.
        lengths: lengths of the corresponding lists in pairs.
    """
    pairs, lengths = [], []
    for _ in range(DATA_SIZE):
        randInput,randTarget,randInputLen,randTargetLen = generate_pair()
        INDEXER.add_sentence(randInput)
        INDEXER.add_sentence(randTarget)
        pairs.append([INDEXER.get_sentence_index(randInput),
                      INDEXER.get_sentence_index(randTarget)])
        
            # convert sentences to <mt,bc> shape.
            # here bc=1.
        lengths.append([randInputLen,randTargetLen])
    return pairs, lengths

In [4]:
pairs, lengths = generate_data()

In [5]:
BATCH_SIZE = 5

class DataIterator:
    """Data feeder by batch."""
    
    def __init__(self, pairs, lengths):
        self.pairs = pairs
        self.lengths = lengths
        self.size = len(pairs)
        self.indices = range(self.size)
        
    def _get_padded_sentence(self, index, maxSentLen, maxTargetLen):
        """Pad a sentence pair by EOS (pad both to the largest length of respective batch).
        
        Args:
            index: index of a sentence & length pair in self.pairs, self.lengths.
            maxSentLen: the length of the longest source sentence.
            maxTargetLen: the length of the longest target sentence.
        Returns:
            padded source sentence (list), its length (int), 
            padded target sentence (list), its length (int).
        """
        sent1,sent2 = self.pairs[index][0], self.pairs[index][1]
        length1,length2 = self.lengths[index][0], self.lengths[index][1]
        paddedSent1 = sent1[:maxSentLen] if length1>maxSentLen else sent1+[INDEXER.get_index('EOS')]*(maxSentLen-length1)
        paddedSent2 = sent2[:maxTargetLen] if length2>maxTargetLen else sent2+[INDEXER.get_index('EOS')]*(maxTargetLen-length2)
        return paddedSent1,length1,paddedSent2,length2
    
    def random_batch(self, batchSize=BATCH_SIZE):
        """Random batching.
        
        Args:
            batchSize: size of a batch of sentence pairs and respective lengths.
        Returns:
            the batch of source sentence (Variable(torch.LongTensor())),
            the lengths of source sentences (numpy.array())
            and the same for target sentences and lengths.
        """
        batchIndices = np.random.choice(self.indices, size=batchSize, replace=False)
        batchSents,batchTargets,batchSentLens,batchTargetLens = [], [], [], []
        maxSentLen, maxTargetLen = np.array([self.lengths[index] for index in batchIndices]).max(axis=0)
        for index in batchIndices:
            paddedSent1,length1,paddedSent2,length2 = self._get_padded_sentence(index, maxSentLen, maxTargetLen)
            batchSents.append(paddedSent1)
            batchTargets.append(paddedSent2)
            batchSentLens.append(length1)
            batchTargetLens.append(length2)
        batchIndices = range(batchSize) # reindex from 0 for sorting.
        batchIndices = [i for i,l in sorted(zip(batchIndices,batchSentLens),key=lambda p:p[1],reverse=True)]
        batchSents = Variable(torch.LongTensor(np.array(batchSents)[batchIndices])).transpose(0,1) # <bc,mt> -> <mt,bc>
        batchTargets = Variable(torch.LongTensor(np.array(batchTargets)[batchIndices])).transpose(0,1)
        batchSentLens = np.array(batchSentLens)[batchIndices]
        batchTargetLens = np.array(batchTargetLens)[batchIndices]
        return batchSents, batchSentLens, batchTargets, batchTargetLens


In [6]:
dataIter = DataIterator(pairs, lengths)
a1,a2,b1,b2 = dataIter.random_batch(2)
print a1, a2, a1.size(0)
print b1, b2, b1.size(0)

Variable containing:
  3   8
 10   8
  4   9
  4   8
  2   2
  9   9
[torch.LongTensor of size 6x2]
 [6 6] 6
Variable containing:
    5     9
    2     9
    6    10
    6     9
    4     4
   10    10
    1     1
[torch.LongTensor of size 7x2]
 [7 7] 7


### Model

In [134]:
# Seq2Seq (batch) with attention, similar to Luong et al. (2016)
#   Comment notation: mt = max-time; bc = batch-size; h = hidden-size.

HIDDEN_SIZE = 20
N_LAYERS = 2

class EncoderRNN(nn.Module):
    """Simple GRU encoder."""
    
    def __init__(self, inputSize, hiddenSize, nLayers=N_LAYERS, dropout=0.1):
        # inputSize: vocabulary size.
        # hiddenSize: size for both embedding and GRU hidden.
        super(EncoderRNN, self).__init__()
        self.inputSize = inputSize
        self.hiddenSize = hiddenSize
        self.nLayers = nLayers
        self.dropout = dropout
        self.embedding = nn.Embedding(inputSize, hiddenSize)
        self.gru = nn.GRU(hiddenSize, hiddenSize, nLayers, dropout)
    
    def forward(self, inputs, inputsLen, hidden=None):
        # inputs: <mt,bc>
        # inputsLen: <bc,> (a list).
        # hidden: <n_layer*n_direction,bc,h>
        embedded = self.embedding(inputs) # <mt,bc,h>
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, inputsLen)
            # 'packed' has a 'data' and a 'batch_sizes' field.
            #   'data' is a <sum(len),h> matrix (len is real lengths, not padded).
            #   'batch_sizes' has the number of non-zero batches at each time-step.
            # e.g. for this 'inputs'
            #    2     1     3     0     2
            #    6     8     1     6     2
            #    0     7     0     8     8
            #    6     4     2     1     1
            #    1     8     1     1     1
            #    6     1     1     1     1
            #    0     1     1     1     1
            #    1     1     1     1     1
            #    1     1     1     1     1
            #    1     1     1     1     1  
            # 'data' = 22 = 7+5+4+3+3 (1's are pads corresponding to 'EOS').
            # 'batch_sizes' = [5, 5, 5, 3, 2, 1, 1].
        outputs,hidden = self.gru(packed, hidden)
            # outputs: same format as 'packed'.
            # hidden: <n_layer*n_direction,bc,h>
        outputs, outputsLen = torch.nn.utils.rnn.pad_packed_sequence(outputs)
            # outputs: <mt,bc,h>
            # outputsLen: same as the 'batch_sizes' field of 'packed'.   
        return outputs, hidden


class LinearAttention(nn.Module):
    """Basic linear attention layer."""
    
    def __init__(self, hiddenSize):
        super(LinearAttention, self).__init__()
        self.hiddenSize = hiddenSize
        self.attention = nn.Linear(hiddenSize, hiddenSize)
    
    def forward(self, hidden, encoderOutput):
        # hidden: <1,bc,h>
        # encoderOutput: <mt,bc,h>
        encoderOutputLen, batchSize = encoderOutput.size(0), encoderOutput.size(1)
        encoderOutputLen = len(encoderOutput)
        attentionEnergies = Variable(torch.zeros(batchSize, encoderOutputLen)) # <bc,mt>
        for b in range(batchSize):
            for i in range(encoderOutputLen):
                attentionEnergies[b,i] = self.score(hidden[:,b],encoderOutput[i,b].unsqueeze(0))
                    # hidden[:,b] selects a <1,h> from <1,bc,h>
                    # encoderOutput[i,b] selects a <h,> from <mt,bc,h>
                    #   then unsqueeze(0) to add a first dimension to make <1,h>
                    # score thus takes <1,h> and <1,h>
        return F.softmax(attentionEnergies, dim=-1).unsqueeze(1)
            # first softmax along the mt dimension of <bc,mt>,
            # then unsqueeze(1) to make <bc,1,mt>, technical convenience.
        
    def score(self, hidden, encoderOutput):
            # hidden: <bc=1,h>
            # encoderOutput: <bc=1,h> (1 time step).
        energy = self.attention(encoderOutput)
            # linear attention: <bc,h> * <h,h> -> <bc,h>   
        energy = hidden.dot(energy)
            # dot: <bc,h> * <bc,h> -> <bc,h>
            # .dot smartly find fitting dimensions.
        return energy
    
class LuongDecoderRNN(nn.Module):
    """Luong attention."""
    
    def __init__(self, hiddenSize, outputSize, nLayers=N_LAYERS, dropout=0.1, residual=True):
        super(LuongDecoderRNN, self).__init__()
        self.hiddenSize = hiddenSize
        self.outputSize = outputSize
        self.nLayers = nLayers
        self.dropout = dropout
        self.residual = residual
        self.embedding = nn.Embedding(outputSize, hiddenSize)
        self.gru = nn.GRU(2*hiddenSize, hiddenSize, nLayers, dropout) 
        self.out = nn.Linear(2*hiddenSize, outputSize)
            # inputSize doubles because concatted context of same hiddenSize.
        self.linearAttention = LinearAttention(hiddenSize)
        
    def forward(self, inputs, hidden, context, encoderOutput):
            # inputs: <bc,>
            # hidden: <n_layer*n_direction,bc,h>
            # context: <bc,h>
            # encoderOutput: <mt,bc,h>  
        batchSize = inputs.size(0)
        embedded = self.embedding(inputs).view(1,batchSize,self.hiddenSize) # <mt=1,bc,h>
        inputs = torch.cat((embedded,context.unsqueeze(0)),2)
            # unsqueeze: <bc,h> -> <mt=1,bc,h>
            # concat: <mt,bc,h> & <mt,bc,h> @2 -> <mt,bc,2h>
        output, hidden = self.gru(inputs, hidden)
            # IN: <mt=1,bc,2h>, <n_layer*n_direction,bc,h>
            # OUT: <mt=1,bc,h>, <n_layer*n_direction,bc,h> 
        hidden = hidden+embedded if self.residual else hidden
            # broachcast addition: <n_layer*n_direction,bc,h> + <mt=1,bc,h>
            #   = <n_layer*n_direction,bc,h>.
        attentionWeights = self.linearAttention(output,
                                                encoderOutput)
            # squeeze: <mt=1,bc,h> -> <bc,h>
            # attentionWeights: <bc=1,1,mt>
        context = attentionWeights.bmm(encoderOutput.transpose(0,1))
            # transpose: <mt,bc,h> -> <bc,mt,h>
            # bmm (batched matrix multiplication): 
            #   <bc,1,mt> & <bc,mt,h> -> <bc,1,h>
        output = output.squeeze(0)
        context = context.squeeze(1)
            # output squeeze: <mt=1,bc=1,h> -> <bc,h>
            # context squeeze: <bc=1,1,h> -> <bc,h>
        output = F.log_softmax(F.tanh(self.out(torch.cat((output,context),1))),dim=-1)
            # concat: <bc,h> & <bc,h> @1 -> <bc,2h>
            # linear->tahn/out: <bc,2h> * <2h,vocab> -> <bc,vocab>
            # softmax: along dim=-1, i.e. vocab.
        return output, hidden, context, attentionWeights
            # full output for visualization:
            #   output: <bc,vocab>
            #   hidden: <n_layer*n_direction,bc,h>
            #   context: <bc,h>
            #   attentionWeights: <bc,1,mt>

### Trainer

In [131]:
def batch_cross_entropy(decoderOutputAll, targets, targetsLen, batchSize=BATCH_SIZE):
    # decoderOutputAll: <bc,mt,vocab> (transposed in train function).
    # targets: <bc,mt>
    # targetsLen: <bc,> (a list).
    logitsFlat = decoderOutputAll.view(-1, decoderOutputAll.size(-1))
        # <bc,mt,vocab> -> <bc*mt,vocab>
    logProbsFlat = F.log_softmax(logitsFlat,dim=-1)
        # <bc,mt,vocab>, with dim vocab has log probs.
    targetsFlat = targets.view(-1,1)
        # <bc,mt> -> <bc*mt,1>
    lossesFlat = -torch.gather(logProbsFlat,dim=1,index=targetsFlat)
        # <bc,mt,vocab> -> <bc*mt,1>
    losses = lossesFlat.view(*targets.size())
        # reshape: <bc*mt,1> -> <bc,mt>
    # Make a mask
    #   requires: lengths, maxLen
    maxLen = max(targetsLen)
    seqRange = torch.arange(maxLen).long()
        # generate a maxLen tensor of long type, <max-len,>
    seqRangeExpand = Variable(seqRange.unsqueeze(0).expand(batchSize,maxLen))
        # unsqueeze: <1,max-len>
        # expand: copy BATCH_SIZE times along first dimension
        #         second dim won't change as they are of the same length.
        #   e.g. for expand:
        #     >>> x = torch.Tensor([[1], [2], [3]])
        #     >>> x.size()
        #     torch.Size([3, 1])
        #     >>> x.expand(3, 4)
        #      1  1  1  1
        #      2  2  2  2
        #      3  3  3  3
        #     [torch.FloatTensor of size 3x4]
        #   finally we got <bc,max-len>.
    seqLenExpand = (Variable(torch.LongTensor(targetsLen)).unsqueeze(1).expand_as(seqRangeExpand))
        # unsqueeze: <bc,> -> <bc,1>
        # expand_as: <bc,1> -> <bc,max-len> 
    mask = seqRangeExpand < seqLenExpand
        # e.g. batch=2 case:
        #   seqRangeExpand is
        #     0 1 2 3
        #     0 1 2 3
        #   seqLenExpand is
        #     3 3 3 3 <= length of this sentence is 3
        #     2 2 2 2
        #   then we got a matrix that's elementwise results from the comparison.
        #     1 1 1 0
        #     1 1 0 0
        #   which means an elem=1 if it doesn't correspond to a padder.
    # Compute final loss
    losses = losses * mask.float() # zeroify all 0 elem in the mask.
    loss = losses.sum() / sum(targetsLen)
    return loss

In [135]:
def train_step(inputs, inputsLen, targets, targetsLen,
               encoder, decoder, 
               encoderOptim, decoderOptim,
               enforcingRatio, clip,
               batchSize=BATCH_SIZE):
    """One training step (on a **batch** of pairs of sentences)."""
    # Clear previous grads
    # WHY: Since the backward() function accumulates gradients, 
    #      and you don’t want to mix up gradients between minibatches, 
    #      you have to zero them out at the start of a new minibatch. 
    #      This is exactly like how a general (additive) accumulator 
    #      variable is initialized to 0 in code.
    encoderOptim.zero_grad()
    decoderOptim.zero_grad()
    # Set up loss
    loss = 0
    # Run encoder
    encoderHidden = None
    encoderOutput, encoderHidden = encoder(inputs, inputsLen, encoderHidden)    
    # Run decoder
    decoderInput = Variable(torch.LongTensor([INDEXER.get_index('SOS')]*batchSize))
    decoderContext = Variable(torch.zeros(batchSize,decoder.hiddenSize))
    decoderHidden = encoderHidden
    enforce = random.random() < enforcingRatio
    maxTargetLen = max(targetsLen)
    decoderOutputAll = Variable(torch.zeros(maxTargetLen,batchSize,decoder.outputSize))
        # <mt-max,bc,vocab>
    for di in range(maxTargetLen):
        decoderOutput,decoderHidden,decoderContext,attentionWeights = decoder(decoderInput,
                                                                              decoderHidden,
                                                                              decoderContext, 
                                                                              encoderOutput)
        decoderOutputAll[di] = decoderOutput
        if enforce:
            decoderInput = targets[di] # <== targets is <mt,bc>
        else:
            topValues,topIndices = decoderOutput.data.topk(1) # <bc,1>
            decoderInput = Variable(topIndices.squeeze()) # <bc,1> -> <bc,>
    # Sequence cross entropy
    loss = batch_cross_entropy(decoderOutputAll.transpose(0,1).contiguous(), 
                               targets.transpose(0,1).contiguous(), 
                               targetsLen)
    # Backprop
    loss.backward()
    torch.nn.utils.clip_grad_norm(encoder.parameters(), clip)
    torch.nn.utils.clip_grad_norm(decoder.parameters(), clip)
    encoderOptim.step()
    decoderOptim.step()
    return loss.data[0] / targetsLen

def train(pairs, lengths,
          nEpochs=1, epochSize=100, lr=1e-4,
          enforcingRatio=0.5, clip=5.0, dropout=0.1, residual=True,
          printEvery=5):
    """Train multiple **batch** steps."""
    dataSize = len(pairs)
    encoder = EncoderRNN(INDEXER.nWords, HIDDEN_SIZE)
    decoder = LuongDecoderRNN(HIDDEN_SIZE, INDEXER.nWords, residual=residual)
    encoderOptim = optim.Adam(encoder.parameters(),lr)
    decoderOptim = optim.Adam(decoder.parameters(),lr)
    averageLoss = 0
    start = time.time()
    for e in range(nEpochs):
        epochLoss = 0
        for step in range(epochSize):
            inputs, inputsLen, targets, targetsLen = dataIter.random_batch()
            loss = train_step(inputs, inputsLen, targets, targetsLen,
                              encoder, decoder,
                              encoderOptim, decoderOptim,
                              enforcingRatio, clip) 
            if step!=0 and step%printEvery==0:
                print("Step %d average loss = %.4f (time: %.2f)" % (step, loss.mean(), # batch mean.
                                                                    time.time()-start)) 
                start = time.time()
            epochLoss += loss.mean()
        epochLoss /= epochSize
        averageLoss += epochLoss
        print("\nEpoch %d loss = %.4f\n" % (e+1,epochLoss))
    averageLoss /= nEpochs
    print("\nGrand average loss = %.4f\n" % averageLoss)
    return encoder, decoder
            # READ BATCH DATA

In [136]:
encoder, decoder = train(pairs, lengths, 
                         nEpochs=5, epochSize=1000,
                         printEvery=50)

Step 50 average loss = 0.3964 (time: 5.40)
Step 100 average loss = 0.3257 (time: 5.32)
Step 150 average loss = 0.4552 (time: 5.21)
Step 200 average loss = 0.3487 (time: 5.34)
Step 250 average loss = 0.4531 (time: 5.50)
Step 300 average loss = 0.3802 (time: 5.18)
Step 350 average loss = 0.3215 (time: 5.14)
Step 400 average loss = 0.4762 (time: 4.91)
Step 450 average loss = 0.3458 (time: 5.39)
Step 500 average loss = 0.3327 (time: 5.41)
Step 550 average loss = 0.3518 (time: 5.55)
Step 600 average loss = 0.3357 (time: 5.26)
Step 650 average loss = 0.3884 (time: 5.11)
Step 700 average loss = 0.4079 (time: 5.39)
Step 750 average loss = 0.3131 (time: 5.33)
Step 800 average loss = 0.3107 (time: 5.25)
Step 850 average loss = 0.3953 (time: 5.43)
Step 900 average loss = 0.2984 (time: 5.31)
Step 950 average loss = 0.3365 (time: 5.29)

Epoch 1 loss = 0.3734

Step 50 average loss = 0.3621 (time: 10.58)
Step 100 average loss = 0.3336 (time: 5.82)
Step 150 average loss = 0.3432 (time: 5.57)
Step 200 

### Evaluation

In [137]:
class Evaluator:
    
    def __init__(self, encoder, decoder):
        self.encoder = encoder
        self.decoder = decoder
        self.encoder.dropout = 0.0
        self.decoder.dropout = 0.0

    def random_evaluation(self, pairs, lengths, batchSize=2):
        """Randomly pick batchSize sentences from a given corpus and translate.
        
        Args: 
            pairs, lengths: input data, elements in list(list(),list()).
            batchSize: size of pairs & lengths.
        """
        dataIter = DataIterator(pairs, lengths)
        inputs, inputsLen, targets, targetsLen = dataIter.random_batch(batchSize)
        # Run encoder
        encoderHidden = None
        encoderOutput, encoderHidden = self.encoder(inputs, inputsLen, encoderHidden)
        # Run decoder
        decoderInput = Variable(torch.LongTensor([INDEXER.get_index('SOS')]*batchSize))
        decoderContext = Variable(torch.zeros(batchSize,decoder.hiddenSize))
        decoderHidden = encoderHidden
        maxTargetLen = max(targetsLen)
        predictions = []
        for di in range(maxTargetLen):
            decoderOutput,decoderHidden,decoderContext,attentionWeights = self.decoder(decoderInput,
                                                                                       decoderHidden,
                                                                                       decoderContext, 
                                                                                       encoderOutput)
            topValues,topIndices = decoderOutput.data.topk(1) # <bc,1>
            decoderInput = Variable(topIndices.squeeze()) # <bc,1> -> <bc,>
            predictions.append(topIndices.view(-1).numpy())
        inputs = inputs.data.numpy().transpose()
        predictions = np.array(predictions).transpose() # <mt,bc> -> <bc,mt>
        targets = targets.data.numpy().transpose()
        for i,(input,pred,target) in enumerate(zip(inputs,predictions,targets)):
            print("Example %d" % (i+1))
            print("INPUT >> %s" % ' '.join(INDEXER.get_sentence_word(input)))
            print("PRED >> %s" % ' '.join(INDEXER.get_sentence_word(pred)))
            print("TRUE >> %s\n" % ' '.join(INDEXER.get_sentence_word(target))) 

    def evaluate_sentence(self, sent, maxLen=10):
        """Evaluate a given sentence.
        
        Args:
            sent: a sentence in string, where words are separated by whitespaces.
            maxLen: the threshold at which the decoder stops.
        """
        # Reformat data to the same as dataIter.random_batch(1)
        sent = sent.split()
        sentCode = INDEXER.get_sentence_index(sent)
        if any(i==-1 for i in sentCode):
            raise Exception("This sentence contains out of vocabulary words!")
        input = Variable(torch.LongTensor(sentCode)).view(-1,1)
        inputLen = np.array([len(sentCode)])
        # Run encoder
        encoderHidden = None
        encoderOutput, encoderHidden = self.encoder(input, inputLen, encoderHidden)
        # Run decoder
        decoderInput = Variable(torch.LongTensor([INDEXER.get_index('SOS')]*1))
        decoderContext = Variable(torch.zeros(1,decoder.hiddenSize))
        decoderHidden = encoderHidden
        pred = []
        for di in range(maxLen):
            decoderOutput,decoderHidden,decoderContext,attentionWeights = self.decoder(decoderInput,
                                                                                       decoderHidden,
                                                                                       decoderContext, 
                                                                                       encoderOutput)
            topValues,topIndices = decoderOutput.data.topk(1) # <bc,1>
            decoderInput = Variable(topIndices.squeeze()) # <bc,1> -> <bc,>
            predIndex = topIndices.view(-1).numpy()[0]
            if predIndex == INDEXER.get_index('EOS'):
                break
            pred.append(predIndex)
        print("INPUT >> %s" % ' '.join(sent))
        print("PRED >> %s\n" % ' '.join(INDEXER.get_sentence_word(pred)))

In [138]:
ev = Evaluator(encoder, decoder)

In [143]:
ev.evaluate_sentence('B B C E E')

INPUT >> B B C E E
PRED >> C C C F F

