# Prerequisites

In [0]:
#@title Imports { form-width: "200px" }

# from __future__ import unicode_literals, print_function, division


from io import open
from itertools import islice

import os
import re
import math
import time
import numpy as np
import string
import random
import logging
import unicodedata


import torch
from torch import nn
from torch import optim


# Log record layout: "<timestamp> <LEVEL>: <message>".
LOG_FORMAT = '%(asctime)s %(levelname)s: %(message)s'
# 24-hour clock (%H). The 12-hour %I without an AM/PM marker made morning and
# afternoon timestamps indistinguishable.
LOG_DATE_FORMAT = '%Y-%m-%d %H:%M:%S'

logging.basicConfig(
    format=LOG_FORMAT,
    datefmt=LOG_DATE_FORMAT,
    level=logging.INFO)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [0]:
#@title Paths { form-width: "200px" }

# Mount point of Google Drive inside the Colab runtime.
GDRIVE_PATH = "/content/drive"
# Project directory inside the drive (placeholder - adjust to your layout).
WORK_DIR = os.path.join(GDRIVE_PATH, "path/to/project/directory/in/gdrive")
# Directory holding the datasets.
DATA_DIR = os.path.join(WORK_DIR, "data")


In [0]:
#@title Language { form-width: "200px" }

# Start of sentence marker.
SOS_token = 0
SOS_char = '^'

# End of sentence marker.
EOS_token = 1
EOS_char = '$'

# Padding character for dynamic unfolding.
PAD_token = 2
PAD_char = '@'


class Lang:
    """Describes a language.

    The available characters in the language are gathered and indexed by this
    class.

    Attributes:
        name: The name given to the language.
        char2index: Maps a character to its index. Only PAD_char is
            pre-registered; the SOS/EOS characters are not.
        char2count: Maps a character to the number of times it was added.
        index2char: Maps an index to its character; pre-populated with the
            three special characters.
        n_chars: Number of known indices, including the special characters.
    """

    def __init__(self, name):
        self.name = name
        self.char2index = {PAD_char: PAD_token}
        self.char2count = {}
        self.index2char = {
            SOS_token: SOS_char,
            EOS_token: EOS_char,
            PAD_token: PAD_char,
        }
        # The count includes the special characters.
        self.n_chars = 3

    def addSentence(self, sentence):
        """Registers every character of `sentence` (see addChar)."""
        for char in sentence:
            self.addChar(char)

    def addChar(self, char):
        """Registers one occurrence of `char`, indexing it if it is new."""
        if char not in self.char2index:
            self.char2index[char] = self.n_chars
            self.index2char[self.n_chars] = char
            self.n_chars += 1
        # PAD_char is present in char2index from the start but has no entry in
        # char2count, so the count must not assume the key already exists.
        self.char2count[char] = self.char2count.get(char, 0) + 1


In [0]:
#@title Language Encoding Functions { form-width: "200px" }

def indexesFromSentence(lang, sentence):
    """Converts a sentence to a list of the indices of its characters.

    Raises a KeyError for a character missing from `lang.char2index`.
    """
    indexes = []
    for char in sentence:
        indexes.append(lang.char2index[char])
    return indexes

def tensorFromSentence(lang, sentence, is_float=False):
    """Converts a sentence to a column tensor of character indices.

    EOS_token is appended after the sentence's own indices, so the result has
    the shape (len(sentence) + 1, 1). Depending on the use-case, the indices
    can be typed as floats (is_float=True) or as integers (the default).
    """
    indexes = indexesFromSentence(lang, sentence) + [EOS_token]
    if is_float:
        dtype = torch.float
    else:
        dtype = torch.long
    flat = torch.tensor(indexes, dtype=dtype, device=device)
    return flat.view(-1, 1)

def tensorsFromPair(experiment, pair):
    """Converts a sentence pair to a pair of index tensors.

    pair[0] is encoded with the experiment's input language and pair[1] with
    its output language; the result is (input_tensor, target_tensor).
    """
    source_tensor = tensorFromSentence(experiment.input_lang, pair[0])
    return source_tensor, tensorFromSentence(experiment.output_lang, pair[1])

def senteceFromTensor(lang, tensor):
    """Converts a tensor of character indices back to a sentence.

    Leading and trailing SOS/EOS characters are stripped from the result.
    """
    chars = []
    for index in tensor:
        chars.append(lang.index2char[index.item()])
    markers = SOS_char + EOS_char
    return ''.join(chars).strip(markers)


# Input Processing

In [0]:
#@title Parsing { form-width: "200px" }

def unicodeToAscii(s):
    """Strips combining marks (accents) from a Unicode string.

    The string is NFD-decomposed and every character of category 'Mn'
    (non-spacing mark) is dropped. Thanks to
    https://stackoverflow.com/a/518232/2809427
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [c for c in decomposed if unicodedata.category(c) != 'Mn']
    return ''.join(kept)

def normalizeString(s):
    """Lowercases, trims and simplifies a sentence.

    After accent removal, a space is inserted before each of '.', '!' and '?',
    and every run of characters other than ASCII letters and those three
    punctuation marks is replaced by a single space.
    """
    cleaned = unicodeToAscii(s.lower().strip())
    cleaned = re.sub(r"([.!?])", r" \1", cleaned)
    return re.sub(r"[^a-zA-Z.!?]+", r" ", cleaned)


In [0]:
#@title Filtering { form-width: "200px" }

# Characters a sentence must not contain: the three special marker characters
# plus '#' and '&'.
BAD_CHARS = {SOS_char, EOS_char, PAD_char, '#', '&'}

def noSpecialChars(string):
    """Returns True iff `string` contains none of the BAD_CHARS."""
    return BAD_CHARS.isdisjoint(string)

def filterPair(piar):
    """Returns True iff both sentences of the pair are free of BAD_CHARS."""
    if not noSpecialChars(piar[0]):
        return False
    return noSpecialChars(piar[1])

def filterPairs(pairs):
    """Returns the list of pairs accepted by filterPair, in their original order."""
    return list(filter(filterPair, pairs))


In [0]:
#@title Read Tatoeba Data { form-width: "200px" }

def readTatoebaLangs(num=0):
    """Reads the en-fr data used in the tutorial.

    Args:
        num: Number of pairs to sample from the file. 0 keeps every pair.

    Returns:
        (input_lang, output_lang, pairs) where the languages are fresh, still
        empty Lang objects named "en" and "fr", and every pair is the list of
        normalized tab-separated fields of one line of the file.

    Raises:
        ValueError: if `num` exceeds the number of pairs in the file (raised
            by random.sample).
    """
    logging.info("Reading lines...")

    # Read the file and split into lines. The context manager makes sure the
    # file handle is closed once the content was read.
    path = os.path.join(DATA_DIR, "Tatoeba/eng-fra.txt")
    with open(path, encoding='utf-8') as fd:
        lines = fd.read().strip().split('\n')

    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]

    if num:
        # We don't want to take the first num pairs, because the data might be
        # somewhat sorted. But we do want experiments to be reproducible.
        # So we use a constant seed. Note that this re-seeds the global
        # `random` generator.
        random.seed(1234)
        pairs = random.sample(pairs, num)

    input_lang = Lang("en")
    output_lang = Lang("fr")

    return input_lang, output_lang, pairs


In [0]:
#@title Read WMT Data { form-width: "200px" }

def readWMTLangs(num=0):
    """Reads the en-de WMT training data.

    Args:
        num: Number of line pairs to read, starting after the skipped prefix.
            0 reads every remaining line, matching readTatoebaLangs where 0
            also means "no limit" (previously 0 silently produced no pairs).

    Returns:
        (input_lang, output_lang, pairs) where the languages are fresh, still
        empty Lang objects named "en" and "de", and every pair is a tuple of
        the normalized lines found at the same position in the two files.
    """
    logging.info("Reading lines...")

    # The first lines in WMT are messed up.
    start = 1000
    # islice treats a stop of None as "until the iterator is exhausted".
    stop = start + num if num else None

    def readLines(filename):
        # Reads, strips and normalizes lines [start, stop) of a data file.
        with open(os.path.join(DATA_DIR, filename), encoding='utf-8') as fd:
            return [normalizeString(l.strip()) for l in islice(fd, start, stop)]

    l1_lines = readLines("WMT/train.en.txt")
    l2_lines = readLines("WMT/train.de.txt")

    # Match up the lines of the two files by position.
    pairs = list(zip(l1_lines, l2_lines))

    input_lang = Lang("en")
    output_lang = Lang("de")

    return input_lang, output_lang, pairs


In [0]:
#@title Prepare Data { form-width: "200px" }

def prepareData(use_WMT, num=0):
    """Reads a dataset, filters it and builds the character vocabularies.

    Args:
        use_WMT: If true the WMT reader is used, otherwise the Tatoeba one.
        num: Forwarded to the reader as the number of pairs to read.

    Returns:
        (input_lang, output_lang, pairs): the two Lang objects populated with
        the characters of the surviving pairs, and the pairs that passed
        filterPairs.
    """
    prep_data_func = readWMTLangs if use_WMT else readTatoebaLangs
    input_lang, output_lang, pairs = prep_data_func(num)

    logging.info("Read %s sentence pairs", len(pairs))
    pairs = filterPairs(pairs)
    logging.info("Trimmed to %s sentence pairs", len(pairs))
    logging.info("Counting chars...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    logging.info("Counted chars:")
    logging.info("%s %s", input_lang.name, input_lang.n_chars)
    logging.info("%s %s", output_lang.name, output_lang.n_chars)
    if pairs:
        # Log one random pair as a sanity check of the parsed data.
        logging.info(random.choice(pairs))
    else:
        # random.choice raises IndexError on an empty sequence.
        logging.warning("No sentence pairs left after filtering.")
    return input_lang, output_lang, pairs


# ByteNet

## Layers

In [0]:
#@title Layer Norm { form-width: "200px" }

class LayerNorm(nn.Module):
    """Layer normalisation across the channels of a (batch, channels, seq_len)
    input, applied separately at every sequence position of every datapoint.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.layernorm = nn.LayerNorm(channels)

    def forward(self, inputs):
        # The PyTorch Layer Norm only works on the last dimensions, so the
        # channels are moved to the end for the normalisation and moved back
        # afterwards.
        channels_last = inputs.transpose(1, 2)
        normalised = self.layernorm(channels_last)
        return normalised.transpose(1, 2)


In [0]:
#@title Residual Block { form-width: "200px" }

class ResidualBlock(nn.Module):
    """A ByteNet residual block operating on (batch, 2d, seq_len) tensors.

    The block normalises and reduces the input to d channels, applies a
    dilated convolution of kernel size k (causal when `masked`), expands back
    to 2d channels and adds the block's input to the result. The sequence
    length is preserved.
    """

    def __init__(self, d, k, dialation, masked=False):
        super().__init__()
        self.d = d  # Embedding size
        self.k = k  # Kernel size
        self.dialation = dialation
        self.masked = masked

        # Stage 1: 2d -> d channels with a kernel of size 1.
        self.lnorm1 = LayerNorm(2 * d)
        self.relu1 = nn.ReLU()
        self.conv_2d_to_d = nn.Conv1d(2 * d, d, kernel_size=1)
        # Stage 2: the (explicitly padded) dilated convolution.
        self.lnorm2 = LayerNorm(d)
        self.relu2 = nn.ReLU()
        self.pad = nn.ConstantPad1d(self.padding(), 0.)
        self.conv = nn.Conv1d(d, d, k, dilation=dialation)
        # Stage 3: d -> 2d channels with a kernel of size 1.
        self.lnorm3 = LayerNorm(d)
        self.relu3 = nn.ReLU()
        self.conv_d_to_2d = nn.Conv1d(d, 2 * d, kernel_size=1)

    def padding(self):
        """Returns the (left, right) zero padding that keeps the output of the
        dilated convolution as long as its input."""
        total = self.dialation * (self.k - 1)
        if self.masked:
            # For a masked convolution (i.e. a causal convolution) all the
            # padding must be placed on the left.
            return (total, 0)
        # Otherwise split evenly; an odd remainder goes to the left.
        right = total // 2
        return (total - right, right)

    def forward(self, inputs):
        # d kernels of size 1x2d reduce the number of channels from 2d to d.
        out = self.conv_2d_to_d(self.relu1(self.lnorm1(inputs)))
        out = self.relu2(self.lnorm2(out))

        logging.debug("before main conv: %s", out.size())
        out = self.conv(self.pad(out))
        logging.debug("after main conv: %s", out.size())

        # 2d kernels of size 1xd expand the number of channels from d to 2d.
        out = self.conv_d_to_2d(self.relu3(self.lnorm3(out)))
        # The residual connection.
        out += inputs
        return out


In [0]:
#@title Residual Series { form-width: "200px" }

class ResidualSeries(nn.Module):
    """A series of residual blocks, with the dilation increasing from block to
    block by a factor of 2.

    Example: if num_blocks=3 we will have dilations 1, 2, 4.
    """

    def __init__(self, d, k, num_blocks, masked=False):
        super().__init__()
        self.d = d
        self.k = k
        self.num_blocks = num_blocks
        self.masked = masked

        self.blocks = nn.Sequential(*[
            ResidualBlock(d, k, 2 ** exponent, masked)
            for exponent in range(num_blocks)])

    def forward(self, inputs):
        return self.blocks(inputs)


## CNN Variant

In [0]:
#@title Encoder { form-width: "200px" }

class EncoderByteNet(nn.Module):
    """Convolutional ByteNet encoder.

    Takes a (seq_len, 1) tensor of character indices, extends it with padding
    (dynamic unfolding), lifts it to 2d channels, runs it through the
    unmasked residual series and reduces the result to d channels. The output
    has the shape (1, d, unfolded_len).
    """

    def __init__(self, params, input_lang):
        super().__init__()
        logging.info("--Starting CNN encoder--")
        self.d = params.d  # Embedding dimension
        self.k = params.k  # Kernel size
        # Number of residual blocks in a series
        self.num_blocks = params.num_blocks
        # Number of series of residual blocks
        self.num_series = params.num_series
        # The rate at which the input is padded for dynamic unfolding.
        self.unfold_rate = params.unfold_rate
        self.input_lang = input_lang

        self.conv_in = nn.Conv1d(1, 2 * self.d, 1)
        self.blocks = nn.Sequential(*[
            ResidualSeries(self.d, self.k, self.num_blocks)
            for _ in range(self.num_series)])
        self.conv_out = nn.Conv1d(2 * self.d, self.d, 1)
        # Note: According to the paper, there are additional layers here:
        # a ReLU, another convolution and a softmax.
        # Since no information is given about the convolution (like kernel size
        # or dilation rate), I decided not to add it. Also, a softmax seems
        # somewhat illogical here, which makes me believe the additional layers
        # specified might be relevant only for the decoder. Adding just the
        # ReLU resulted in a much less successful model*. Hence, these layers
        # were skipped.
        # * When training with teacher forcing, the ReLU resulted in the encoder
        #   outputs being zeros. Without teacher forcing, the convergence rate
        #   was simply quicker without the ReLU.

    def unfold(self, input):
        """Appends floor((unfold_rate - 1) * length) padding entries to the
        input, which is converted to float.

        Note: the padding is built with tensorFromSentence, which also appends
        an EOS_token after the padding characters.
        """
        pad_len = math.floor((self.unfold_rate - 1.0) * input.size(0))
        pad_tensor = tensorFromSentence(
            self.input_lang, PAD_char * pad_len, True)
        return torch.cat((input.type(torch.float), pad_tensor), 0)

    def forward(self, input):
        logging.debug("--Encoding--")
        logging.debug("input: %s", input.size())

        logging.debug("out before unfold: %s", input.size())
        # (len, 1) -> (1, 1, unfolded_len): a single channel holding the raw
        # index values.
        out = self.unfold(input).permute(1, 0).unsqueeze(0)
        logging.debug("out before conv: %s", out.size())

        # The residual block expects the number of incoming channels to be
        # 2d, so we use 2d filters of size 1x1 to extend the number of
        # channels to 2d.
        out = self.conv_in(out)
        logging.debug("initial out: %s", out.size())

        # Goes through the configured series of residual blocks.
        out = self.blocks(out)

        # Reduces the number of channels to d.
        out = self.conv_out(out)
        logging.debug("final: %s", out.size())
        logging.debug("--Done Encoding--")
        return out


In [0]:
#@title Decoder { form-width: "200px" }

class DecoderByteNet(nn.Module):
    """Convolutional ByteNet decoder.

    Embeds the previously decoded characters, stacks the encoder outputs on
    top of the embeddings (giving 2d channels), runs the result through masked
    residual series and maps it to log-probabilities over the characters of
    the output language, one distribution per input position.
    """

    def __init__(self, params, output_lang):
        super().__init__()
        logging.info("--Starting CNN decoder--")
        self.d = params.d  # Embedding dimension
        self.k = params.k  # Kernel size
        # Number of residual blocks in a series
        self.num_blocks = params.num_blocks
        # Number of series of residual blocks
        self.num_series = params.num_series

        self.embedding = nn.Embedding(output_lang.n_chars, self.d)

        self.blocks = nn.Sequential(*[
            ResidualSeries(self.d, self.k, self.num_blocks, masked=True)
            for _ in range(self.num_series)])

        # After the residual blocks, the number of channels is 2d. If m is the
        # number of characters in the output language, the number of channels
        # should be changed to m, so that a softmax can give us the probability
        # of each of them being the next character.
        self.conv_out = nn.Conv1d(2 * self.d, output_lang.n_chars, 1)
        # Note: According to the paper, there are additional layers before the
        # softmax, but they were skipped. See reasoning in the documentation of
        # EncoderByteNet, where they were also skipped for similar reasons.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, context):
        # Inputs: the previously decoded characters (a 1-D index tensor).
        # Context: the respective outputs from the encoder, (1, d, length).
        logging.debug("--Decoding--")
        logging.debug("inputs: %s", inputs.size())
        logging.debug("context: %s", context.size())

        # If we are still decoding but there are no more outputs from the
        # encoder, pad accordingly on the right.
        missing = inputs.size(0) - context.size(2)
        if missing > 0:
            filler = torch.zeros(1, self.d, missing, device=device)
            context = torch.cat((context, filler), 2)
            logging.debug("padded context: %s", context.size())

        # Gets the embeddings of the current inputs as (1, d, n).
        embedding_tensor = self.embedding(inputs).transpose(0, 1).unsqueeze(0)
        logging.debug("embedding_tensor  before   cats: %s",
                      embedding_tensor.size())
        # Concatenates the outputs of the encoder to the embeddings.
        embedding_tensor = torch.cat((embedding_tensor, context), 1)
        logging.debug("embedding_tensor  with  context: %s",
                      embedding_tensor.size())

        out = self.conv_out(self.blocks(embedding_tensor))
        logging.debug("after conv_out: %s", out.size())
        # Gets the log-probability for each character in the output language.
        out = self.softmax(out)
        logging.debug("grad after: %s", out.grad_fn)

        logging.debug("--Done Decoding--")
        return out


## RNN Variant

In [0]:
#@title Recurrent ByteNet Encoder { form-width: "200px" }

class EncoderRNNByteNet(nn.Module):
    """Recurrent ByteNet encoder.

    Takes a (seq_len, 1) tensor of character indices, pads it (dynamic
    unfolding), embeds it and runs it through a bidirectional LSTM whose
    outputs are projected back to d features. Returns the encoding shaped
    (1, d, unfolded_len) together with the LSTM's final (hidden, cell) state.
    """

    def __init__(self, params, input_lang):
        super().__init__()
        logging.info("--Starting RNN encoder--")
        self.d = params.d  # Embedding dimension
        self.lstm_layers = params.lstm_layers
        self.unfold_rate = params.unfold_rate
        self.input_lang = input_lang

        self.embedding = nn.Embedding(input_lang.n_chars, self.d)
        self.lstm = nn.LSTM(self.d, self.d, self.lstm_layers,
                            dropout=params.dropout, bidirectional=True)
        # The bidirectional LSTM emits 2d features per position.
        self.lin_out = nn.Linear(2 * self.d, self.d)

    def unfold(self, inputs):
        """Appends floor((unfold_rate - 1) * length) padding entries to the
        inputs.

        Note: the padding is built with tensorFromSentence, which also appends
        an EOS_token after the padding characters.
        """
        pad_len = math.floor((self.unfold_rate - 1.0) * inputs.size(0))
        pad_tensor = tensorFromSentence(
            self.input_lang, PAD_char * pad_len, False)
        return torch.cat((inputs, pad_tensor), 0)

    def forward(self, inputs):
        logging.debug("--Encoding--")
        logging.debug("input: %s", inputs.size())

        logging.debug("out before unfold: %s", inputs.size())
        out = self.unfold(inputs)
        logging.debug("out before embedding: %s", out.size())
        # Gets the embeddings of the (padded) input.
        out = self.embedding(out)

        logging.debug("before LSTM: %s", out.size())
        out, hidden = self.lstm(out)
        logging.debug("after LSTM: %s", out.size())
        logging.debug("hidden: %s", hidden[0].size())
        logging.debug("cell: %s", hidden[1].size())
        out = self.lin_out(out)
        logging.debug("after lin_out: %s", out.size())
        # (seq_len, batch, d) -> (batch, d, seq_len)
        out = out.permute(1, 2, 0)
        logging.debug("after transpose: %s", out.size())

        logging.debug("--Done Encoding--")
        return out, hidden

In [0]:
#@title Recurrent ByteNet Decoder { form-width: "200px" }

class DecoderRNNByteNet(nn.Module):
    """Recurrent ByteNet decoder.

    Decodes one character per call: the embedding of the previous character is
    concatenated with the matching encoder output, fed through an LSTM and
    projected to log-probabilities over the output language's characters.
    """

    def __init__(self, params, output_lang):
        super().__init__()
        logging.info("--Starting RNN decoder--")
        self.d = params.d  # Embedding dimension
        # If an RNN encoder uses l LSTM layers, which are bidirectional, then it
        # produces 2l hidden states. These need to be fed to the decoder's LSTM
        # layers, hence it has 2l layers.
        self.lstm_layers = 2 * params.lstm_layers

        self.embedding = nn.Embedding(output_lang.n_chars, self.d)
        self.lstm = nn.LSTM(2 * self.d, self.d, self.lstm_layers,
                            dropout=params.dropout)
        self.lin_out = nn.Linear(self.d, output_lang.n_chars)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, encoder_context, decoder_context=None):
        # Returns (log-probabilities, new LSTM state); decoder_context is the
        # LSTM state to start from (None lets the LSTM use its default).
        logging.debug("--Decoding--")
        logging.debug("input: %s", input.size())
        logging.debug("encoder_context: %s", encoder_context.size())
        if decoder_context:
            logging.debug("hidden: %s", decoder_context[0].size())
            logging.debug("cell: %s", decoder_context[1].size())

        # If we are still decoding but there are no more outputs from the
        # encoder, pad accordingly on the right.
        if encoder_context.size(2) == 0:
            encoder_context = torch.zeros(1, self.d, 1, device=device)
            logging.debug("padded encdoer_context: %s",
                          encoder_context.size())
        # Gets the embeddings of the current input.
        embedding_tensor = self.embedding(input).unsqueeze(2)
        logging.debug("embedding_tensor before cats: %s",
                      embedding_tensor.size())
        # Concatenates the outputs of the encoder to the embeddings.
        embedding_tensor = torch.cat((embedding_tensor, encoder_context), 1)
        logging.debug("after concatenation, embedding_tensor size = %s",
                      embedding_tensor.size())

        # LSTM expects the dimensions (seq_len, batch, input_size).
        lstm_in = embedding_tensor.transpose(1, 2)
        logging.debug("before LSTM: %s", lstm_in.size())
        out, hidden = self.lstm(lstm_in, decoder_context)
        logging.debug("after LSTM: %s", out.size())

        out = self.lin_out(out)
        logging.debug("after lin_out: %s", out.size())
        # softmax expects (batch, input_size).
        out = out.transpose(1, 2).squeeze(2)
        logging.debug("before softmax: %s", out.size())
        # Gets the log-probability for each character in the output language.
        out = self.softmax(out)

        logging.debug("--Done Decoding--")
        return out, hidden

## Infrastructure

In [0]:
#@title Wrapper Functions { form-width: "200px" }

# These are meant to wrap operations used in training and testing, so that all
# variants of ByteNet may be used by the same function.

def getEncoderOutput(experiment, input_tensor):
    """Runs the experiment's encoder and returns an (output, hidden) pair.

    For an RNN encoder the encoder's own return value is passed through; for
    any other encoder its output is paired with None.
    """
    if not experiment.params.is_rnn_enc:
        return experiment.encoder(input_tensor), None
    return experiment.encoder(input_tensor)


def getDecoderOutput(experiment, inputs, context, hidden, parallel=False):
    """Runs the experiment's decoder and returns an (output, hidden) pair.

    An RNN decoder receives `hidden` and returns its new state. A CNN decoder
    has no state, so None is returned for it.
    """
    if experiment.params.is_rnn_dec:
        output, new_hidden = experiment.decoder(inputs, context, hidden)
        return output, new_hidden

    output = experiment.decoder(inputs, context)
    if not parallel:
        # Parallel means the decoder predicts all the characters in parallel,
        # hence all outputs are needed (possible only for a CNN decoder).
        # Otherwise prediction is char-by-char, hence only the final output is
        # relevant.
        output = output[:, :, -1]
    return output, None


def getDecoderContext(experiment, idx, encoder_output):
    """Selects the encoder outputs (along the last dimension) for decoding
    step `idx`: only position idx for an RNN decoder, positions 0..idx for a
    CNN decoder.
    """
    if experiment.params.is_rnn_dec:
        window = slice(idx, idx + 1)
    else:
        window = slice(0, idx + 1)
    return encoder_output[:, :, window]


def getDecoderInputs(experiment, idx, decoder_input):
    """Selects the decoder inputs for decoding step `idx` from the flattened
    `decoder_input`: only entry idx for an RNN decoder, entries 0..idx for a
    CNN decoder.
    """
    flat = decoder_input.flatten()
    if experiment.params.is_rnn_dec:
        return flat[idx:idx + 1]
    return flat[0:idx + 1]


In [0]:
#@title Parameters Container { form-width: "200px" }

class ModelParameters:
    """Plain container for the hyper-parameters of a ByteNet model."""

    def __init__(self, is_rnn_enc, is_rnn_dec, embedding_dim, kernel=3,
                 num_blocks=5, num_series=1, lstm_layers=4, unfold_rate=1.2,
                 dropout=0.1):
        self.is_rnn_enc = is_rnn_enc
        # Based on the paper, we assume that a decoder cannot be a CNN when the
        # encoder is an RNN.
        if is_rnn_enc:
            self.is_rnn_dec = True
        else:
            self.is_rnn_dec = is_rnn_dec

        # Embedding dimension and kernel size.
        self.d = embedding_dim
        self.k = kernel

        # The number of residual blocks in a series and the number of such
        # series (both relevant only for CNNs).
        self.num_blocks = num_blocks
        self.num_series = num_series

        # The number of LSTM layers in the encoder, the decoder will have twice
        # as many (relevant only for RNNs).
        self.lstm_layers = lstm_layers
        self.dropout = dropout

        # The rate at which the input is padded for dynamic unfolding.
        self.unfold_rate = unfold_rate



In [0]:
#@title Experiment Container { form-width: "200px" }

class Experiment(object):
    """Bundles the data, the hyper-parameters and the model of one experiment.

    Constructing an Experiment reads the dataset (see getData). The encoder
    and decoder stay None until createModel or loadModel is called.
    """

    def __init__(self, name, data_size, use_WMT, params, criterion=nn.NLLLoss(),
                 learning_rate=0.0003, max_len=1000, beam_width=2,
                 beam_candidates=4):
        # The name also determines where the model is stored (see getPath).
        self.name = name
        # Number of source-target pairs to use.
        self.data_size = data_size
        # Whether to use the WMT dataset (otherwise Tatoeba is used).
        self.use_WMT = use_WMT
        # The parameters of the model.
        self.params = params

        # Note: the default criterion instance is created once, when the class
        # is defined, and shared by every experiment that does not pass one.
        self.criterion = criterion
        self.learning_rate = learning_rate
        # An upper bound on the length of decoder predictions. This is a
        # safeguard against infinite sequences.
        self.max_len = max_len

        # Beam search parameters.
        self.beam_width = beam_width
        self.beam_candidates = beam_candidates

        self.input_lang, self.output_lang, self.pairs = self.getData()
        self.encoder = None
        self.decoder = None

    def getPath(self):
        # Returns the path to the experiment.
        return os.path.join(WORK_DIR, "results", self.name)

    def getData(self):
        # Reads the data and returns the input and output languages, along with
        # the pairs of input-output sequences.
        return prepareData(self.use_WMT, self.data_size)

    def createModel(self):
        # Instantiates a fresh encoder and decoder of the variants selected by
        # the parameters and moves them to the active device.
        rnn_enc, rnn_dec = self.params.is_rnn_enc, self.params.is_rnn_dec
        encoder = EncoderRNNByteNet if rnn_enc else EncoderByteNet
        decoder = DecoderRNNByteNet if rnn_dec else DecoderByteNet

        self.encoder = encoder(self.params, self.input_lang).to(device)
        self.decoder = decoder(self.params, self.output_lang).to(device)

    def loadModel(self):
        # Creates the model and restores the weights written by saveModel.
        path = self.getPath()
        self.createModel()
        self.encoder.load_state_dict(torch.load(os.path.join(path, "encoder"),
                                                map_location=device))
        self.decoder.load_state_dict(torch.load(os.path.join(path, "decoder"),
                                                map_location=device))

    def saveModel(self):
        # Stores the weights of the encoder and the decoder under getPath().
        path = self.getPath()
        # The experiment's directory does not necessarily exist yet (e.g. on
        # the first save), and torch.save does not create it.
        os.makedirs(path, exist_ok=True)
        torch.save(self.encoder.state_dict(), os.path.join(path, "encoder"))
        torch.save(self.decoder.state_dict(), os.path.join(path, "decoder"))


# Training

In [0]:
#@title Single Parallel Training Iteration { form-width: "200px" }

# Parallel training is only possible for the CNN variant.
# This training method matches the description in the paper: the entire target
# sequence is fed to the decoder (except the last character) and the decoder
# produces exactly <target length> predictions simultaneously.
# This training method is quicker than predicting char-by-char. However, if the
# model fails to reach a perfect NLL of 0.0000, it performs very badly at test
# time, even on training samples.
# This is because even a single error (which is bound to happen eventually) can
# completely "throw off" the decoder, since the decoder is then fed this mistake
# as a previous target character, and it is not equipped to handle this
# situation.
# This is probably a consequence of the small datasets we use. To conclude, if
# the model is expected to reach an NLL of 0, then this method is recommended,
# since it is considerably quicker. Otherwise, to reach a more effective model,
# the non-parallel version is recommended.

def trainParallel(experiment, input_tensor, target_tensor,
                  encoder_optimizer, decoder_optimizer):
    """Performs one parallel (fully teacher-forced) training step on one pair.

    The decoder receives the SOS token followed by the target without its last
    entry, and predicts all target positions in a single pass. Only meaningful
    for a CNN decoder.

    Args:
        experiment: Provides the encoder, the decoder and the criterion.
        input_tensor: Source index tensor, as built by tensorFromSentence.
        target_tensor: Target index tensor, as built by tensorFromSentence.
        encoder_optimizer: Optimizer of the encoder's parameters.
        decoder_optimizer: Optimizer of the decoder's parameters.

    Returns:
        The loss of the step averaged over the target length.
    """
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    target_length = target_tensor.size(0)

    encoder_output, hidden = getEncoderOutput(experiment, input_tensor)
    logging.debug("encoder_output: %s", encoder_output.size())

    logging.debug("target_tensor: %s", target_tensor.size())
    target_with_sos = torch.cat((
        torch.tensor([[SOS_token]], device=device), target_tensor), 0)
    # Entries 0..target_length-1, i.e. SOS followed by everything but the last
    # target character. Asking for index target_length (as was done before)
    # also fed the final target character, yielding one extra prediction that
    # was computed and then ignored by the loss below.
    inputs = getDecoderInputs(experiment, target_length - 1, target_with_sos)
    logging.debug("inputs: %s", inputs.size())

    decoder_output, hidden = getDecoderOutput(
            experiment, inputs, encoder_output[:,:,:target_length], hidden,
            True)
    logging.debug("decoder_output: %s", decoder_output.size())
    # Position i of the decoder output is the prediction of target entry i.
    loss = 0
    for i in range(target_length):
        loss += experiment.criterion(decoder_output[:,:,i], target_tensor[i])

    loss.backward()

    nn.utils.clip_grad_norm_(experiment.encoder.parameters(), 2.0)
    nn.utils.clip_grad_norm_(experiment.decoder.parameters(), 2.0)

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length


In [0]:
#@title Single Training Iteration { form-width: "200px" }

# Controls how often we use real target values for training.
teacher_forcing_ratio = 0.5


def train(experiment, input_tensor, target_tensor,
          encoder_optimizer, decoder_optimizer):
    """Performs one char-by-char training step on a single pair.

    With probability teacher_forcing_ratio the true target characters are fed
    back to the decoder for the whole sequence; otherwise its own top
    predictions are. Returns the loss averaged over the target length.
    """
    for optimizer in (encoder_optimizer, decoder_optimizer):
        optimizer.zero_grad()

    target_length = target_tensor.size(0)

    encoder_output, hidden = getEncoderOutput(experiment, input_tensor)
    logging.debug("encoder_output: %s", encoder_output.size())
    if hidden:
        logging.debug("hidden[0]: %s", hidden[0].size())
        logging.debug("hidden[1]: %s", hidden[1].size())

    # The decoder input starts as SOS and grows by one character per step.
    decoder_input = torch.tensor([[SOS_token]], device=device)
    use_teacher_forcing = random.random() < teacher_forcing_ratio
    loss = 0

    for step in range(target_length):
        decoder_output, hidden = getDecoderOutput(
            experiment,
            getDecoderInputs(experiment, step, decoder_input),
            getDecoderContext(experiment, step, encoder_output),
            hidden)
        logging.debug("decoder_output: %s", decoder_output.size())
        loss += experiment.criterion(decoder_output, target_tensor[step])

        if use_teacher_forcing:
            # Feed the true character.
            next_input = target_tensor[step].unsqueeze(0)
        else:
            # Feed the decoder's own best guess, cut off from the graph.
            next_input = decoder_output.topk(1)[1].detach()
        decoder_input = torch.cat((decoder_input, next_input))

    loss.backward()

    for module in (experiment.encoder, experiment.decoder):
        nn.utils.clip_grad_norm_(module.parameters(), 2.0)

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length


## Multiple Iterations

In [0]:
#@title Track Training Progress { form-width: "200px" }

def asMinutes(s):
    """Formats a duration given in seconds as '<minutes>m <seconds>s'."""
    minutes = math.floor(s / 60)
    seconds = s - minutes * 60
    return '%dm %ds' % (minutes, seconds)


def timeSince(since, percent):
    """Formats the time elapsed since `since` and the estimated time left.

    `percent` is the fraction of the work completed so far; the estimate
    assumes the remaining work progresses at the same rate.
    """
    elapsed = time.time() - since
    remaining = elapsed / (percent) - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))


In [0]:
#@title Train by Iteration { form-width: "200px" }

def trainIters(experiment, n_iters, print_every, pairs=None, with_dups=True,
               parallel=False):
    """Trains the experiment's model for n_iters single-pair iterations.

    Args:
        experiment: The experiment whose encoder and decoder are trained.
        n_iters: Number of training iterations (one pair each).
        print_every: Log the average loss every this many iterations.
        pairs: Pairs to sample from; when empty or None the experiment's own
            pairs are used.
        with_dups: Sample the training pairs with replacement (True) or
            without (False; n_iters must then not exceed the number of pairs).
        parallel: Use trainParallel instead of train (CNN decoder only).

    Returns:
        The sum of the per-iteration losses, or 0 when parallel training was
        requested for an RNN decoder.
    """
    if parallel and experiment.params.is_rnn_dec:
        logging.error("Parallel training is not possible for an RNN decoder.")
        return 0
    train_func = trainParallel if parallel else train

    if not pairs: pairs = experiment.pairs
    # Chooses a sampling function that allows or disallows duplicates (i.e.
    # replacement) according to with_dups.
    sample = random.choices if with_dups else random.sample

    start = time.time()
    print_loss = 0  # Resets every print_every
    total_loss = 0

    lr = experiment.learning_rate
    encoder_optimizer = optim.Adam(
        experiment.encoder.parameters(), lr=lr, eps=1e-6)
    decoder_optimizer = optim.Adam(
        experiment.decoder.parameters(), lr=lr, eps=1e-6)

    training_pairs = [tensorsFromPair(experiment, pair)
                      for pair in sample(pairs, k=n_iters)]

    for iteration, (input_tensor, target_tensor) in enumerate(
            training_pairs, 1):
        loss = train_func(experiment, input_tensor, target_tensor,
                          encoder_optimizer, decoder_optimizer)
        print_loss += loss
        total_loss += loss
        # The loss is a float; formatting it with %d would truncate it.
        logging.debug("loss: %.4f", loss)

        if iteration % print_every != 0: continue
        print_loss_avg = print_loss / print_every
        print_loss = 0
        progress = iteration / n_iters
        logging.info('%s (%d %d%%) %.4f',
                     timeSince(start, progress), iteration,
                     progress * 100, print_loss_avg)

    return total_loss



In [0]:
#@title Train by Epoch { form-width: "200px" }

def trainEpochs(experiment, n_epochs, print_every, pairs=[], parallel=False):
    """Trains for n_epochs passes, each visiting every pair exactly once.

    Each epoch is delegated to trainIters with sampling without replacement.
    Note that trainIters builds new Adam optimizers on every call, so the
    optimizer state restarts at each epoch.

    Args:
        experiment: Provides params and data_size (plus whatever trainIters
            reads).
        n_epochs: Number of epochs to run.
        print_every: Log the average per-iteration loss every this many epochs.
        pairs: Explicit pairs to train on; when empty, trainIters falls back to
            experiment.pairs and the epoch size is experiment.data_size.
        parallel: Forwarded to trainIters. Rejected for an RNN decoder.

    Returns:
        The loss summed over all epochs, or 0 if the parallel/RNN combination
        was rejected.
    """
    if parallel and experiment.params.is_rnn_dec:
        logging.error("Parallel training is not possible for an RNN decoder.")
        return 0

    start = time.time()
    # With explicit pairs, an epoch must be exactly len(pairs) iterations:
    # sampling data_size items without replacement would raise when pairs is
    # shorter and skip items when it is longer.
    # NOTE(review): the default path assumes experiment.data_size equals
    # len(experiment.pairs) -- confirm against Experiment.
    num_pairs = len(pairs) if pairs else experiment.data_size
    iters_in_cycle = print_every*num_pairs

    print_loss = 0  # Resets every print_every
    total_loss = 0

    for epoch in range(1, n_epochs+1):
        # An epoch is equivalent to iterations where each pair is chosen exactly
        # once (so n_iters=num_pairs and we disallow duplicates).
        # We set print_every=num_pairs+1 to prevent printing.
        loss = trainIters(
            experiment, num_pairs, num_pairs+1, pairs, False, parallel)
        print_loss += loss
        total_loss += loss

        if epoch % print_every != 0: continue
        print_loss_avg = print_loss / iters_in_cycle
        print_loss = 0

        logging.info('%s (%d %d%%) %.4f',
                     timeSince(start, epoch/n_epochs), epoch,
                     (epoch/n_epochs)*100, print_loss_avg)

    # Mirrors trainIters, which also reports its accumulated loss.
    return total_loss


# Evaluation and Testing

In [0]:
#@title Beam Node { form-width: "200px" }

class BeamSearchNode(object):
    """A partial hypothesis (one path) explored during beam search."""

    def __init__(self, prev_node, decoded, outputs, log_prob, length):
        # The node this one was expanded from (None for the root).
        self.prev_node = prev_node
        # All the chars (in index form) decoded along the path to this node.
        self.decoded = decoded
        # The probability vectors output by the decoder along this path.
        self.outputs = outputs
        # The log probability of this path.
        self.log_prob = log_prob
        # The length of the path.
        self.length = length

        # Keeps eval() finite for the root node, where length - 1 == 0.
        self.eps = 1e-6

    def eval(self):
        """Returns the path score: log probability normalised by the number of
        decoding steps (length - 1), so longer paths are not penalised merely
        for accumulating more negative terms."""
        return self.log_prob / float(self.length - 1 + self.eps)

    def __lt__(self, other):
        # Nodes travel in (score, node) tuples that get sorted; when two scores
        # tie, tuple comparison falls through to the nodes themselves. Without
        # an ordering that raises TypeError, so order nodes by their score.
        if not isinstance(other, BeamSearchNode):
            return NotImplemented
        return bool(self.eval() < other.eval())


In [0]:
#@title Loss Measuring Function { form-width: "200px" }

def getLoss(prediction, target, criterion):
    """Measures the loss of a predicted sequence against the true target.

    Args:
        prediction: Sequence of per-step decoder output distributions.
        target: Sequence (tensor) of target character indices.
        criterion: Callable (distribution, target_char) -> loss for one step.

    Returns:
        The summed per-step loss divided by the target length.

    Raises:
        ValueError: If either the prediction or the target is empty.
    """
    trgt_length = len(target)
    pred_length = len(prediction)
    if trgt_length == 0 or pred_length == 0:
        raise ValueError(
            "getLoss needs a non-empty prediction and target "
            "(got lengths %d and %d)" % (pred_length, trgt_length))

    # Rendering tensors for the debug log is costly, so only do it when debug
    # logging is actually enabled.
    debug_enabled = logging.getLogger().isEnabledFor(logging.DEBUG)

    loss = 0
    # The predicted length could be different from the true target length.
    # We measure the error up to the maximum of the two, using EOS padding:
    # past the end of the shorter sequence, its last element is reused.
    # When the prediction is shorter, this means taking the distribution for
    # the last character, which is the step where it predicted an EOS (unless
    # decoding was cut off at the maximum length).
    # NOTE(review): assumes the target's last element is EOS -- confirm
    # against tensorFromSentence.
    for i in range(max(trgt_length, pred_length)):
        pred_dist = prediction[min(i, pred_length-1)]
        trgt_char = target[min(i, trgt_length-1)]
        if debug_enabled:
            logging.debug("pred_dist: %s", pred_dist)
            logging.debug("trgt_char: %s", trgt_char.item())
        loss += criterion(pred_dist, trgt_char)
    # Normalised by the target length, not the number of compared steps, so
    # overly long predictions are penalised.
    return loss / trgt_length


In [0]:
#@title Test { form-width: "200px" }

def test(experiment, input_tensor, target_tensor):
    """Greedily decodes input_tensor and scores the result against the target.

    At every step the single most likely character is chosen and appended to
    the decoder input. Decoding stops at the first EOS or after
    experiment.max_len steps.

    Args:
        experiment: Provides max_len, output_lang and criterion, and is passed
            through to the encoder/decoder helpers.
        input_tensor: The encoded source sentence.
        target_tensor: The encoded reference sentence, used only for the loss.

    Returns:
        A (sentence, loss) tuple: the decoded sentence as rendered by
        senteceFromTensor and the loss computed by getLoss.
    """
    with torch.no_grad():
        encoder_output, hidden = getEncoderOutput(experiment, input_tensor)
        # The decoded sequence so far; starts with SOS and never receives the
        # final EOS (the loop breaks before appending it).
        decoder_input = torch.tensor([[SOS_token]], device=device)
        prediction = []

        for i in range(experiment.max_len):
            inputs = getDecoderInputs(experiment, i, decoder_input)
            context = getDecoderContext(experiment, i, encoder_output)

            decoder_output, hidden = getDecoderOutput(
                experiment, inputs, context, hidden)
            prediction.append(decoder_output)

            # Autograd is disabled in this scope, so no .data / .detach() is
            # needed to cut the graph.
            _, topi = decoder_output.topk(1)
            if topi.item() == EOS_token:
                break

            decoder_input = torch.cat((decoder_input, topi))

        sentence = senteceFromTensor(experiment.output_lang, decoder_input)
        return sentence, getLoss(prediction, target_tensor,
                                 experiment.criterion)

In [0]:
#@title Test with Beam Search { form-width: "200px" }

def addNodes(node, decoder_output, candidates, next_nodes, beam_width):
    """Expands a beam node with its most likely next characters.

    For each of the top characters in decoder_output a child node is created.
    Children ending in EOS are appended to candidates (finished hypotheses);
    all others are appended to next_nodes (to be expanded further). Both lists
    are modified in place and hold (score, node) tuples.

    Args:
        node: The BeamSearchNode being expanded.
        decoder_output: The decoder's output for this node; indexed as a
            single-row 2-D tensor. Presumably log-probabilities -- TODO confirm.
        candidates: List receiving finished hypotheses.
        next_nodes: List receiving hypotheses still open for expansion.
        beam_width: Maximum number of children to create.
    """
    # topk raises when asked for more entries than the dimension holds, so
    # never request more children than there are characters.
    k = min(beam_width, decoder_output.size(-1))
    log_prob, chars = decoder_output.topk(k)
    # Arguments are passed separately so tensors are only rendered when debug
    # logging is enabled.
    logging.debug("considering: %s", chars)
    logging.debug("with log probs: %s", log_prob)
    for i in range(k):
        decoded = torch.cat((node.decoded, chars[:,i]))
        new_node = BeamSearchNode(
            node, decoded,
            node.outputs + [decoder_output],
            node.log_prob+log_prob[0][i], node.length+1)
        if chars[0,i] == EOS_token:
            candidates.append((new_node.eval(), new_node))
            logging.debug("added candidate: %s", decoded)
        else:
            next_nodes.append((new_node.eval(), new_node))
            logging.debug("added next node: %s", decoded)

def chooseNodes(next_nodes, beam_width):
    """Selects the best-scoring nodes to keep in the beam.

    Args:
        next_nodes: List of (score, node) tuples; sorted in place by ascending
            score.
        beam_width: Maximum number of nodes to keep.

    Returns:
        Up to beam_width (score, node) tuples, best score first.
    """
    # Sort on the score only: comparing whole tuples would fall through to
    # comparing the nodes themselves whenever two scores tie.
    next_nodes.sort(key=lambda scored: scored[0])
    prev_nodes = []
    logging.debug("possible next nodes: %s", next_nodes)
    for i in range(min(beam_width, len(next_nodes))):
        logging.debug("promoted node %d", i)
        # The list is ascending, so the i-th best node sits at index -(i+1).
        # (Index -i would select index 0, the worst node, for i == 0.)
        prev_nodes.append(next_nodes[-(i + 1)])
    return prev_nodes


def beamTest(experiment, input_tensor, target_tensor):
    """Decodes input_tensor with beam search and scores the best hypothesis.

    Args:
        experiment: Provides beam_width, beam_candidates, max_len, output_lang
            and criterion, and is passed through to the encoder/decoder
            helpers.
        input_tensor: The encoded source sentence.
        target_tensor: The encoded reference sentence, used only for the loss.

    Returns:
        A (sentence, loss) tuple for the highest-scoring hypothesis.
    """
    with torch.no_grad():
        encoder_output, encoder_hidden = getEncoderOutput(
            experiment, input_tensor)
        decoder_input = torch.tensor([SOS_token], device=device)

        beam_width = experiment.beam_width
        candidates_num = experiment.beam_candidates
        candidates, next_nodes = [], []

        initial_node = BeamSearchNode(None, decoder_input, [], 0, 1)
        prev_nodes = [(initial_node.eval(), initial_node)]

        # Decoder state returned when each node of the previous step was
        # expanded. Every hypothesis must continue from the state of its own
        # path (its parent), not from whichever node was processed last.
        parent_hidden = {}

        idx = 0
        # NOTE(review): with `<=` the search keeps going until MORE than
        # beam_candidates finished hypotheses exist -- confirm this is intended.
        while (prev_nodes
               and len(candidates) <= candidates_num
               and idx < experiment.max_len):
            logging.debug("have %d prev_nodes", len(prev_nodes))
            produced_hidden = {}
            for _, node in prev_nodes:
                # addNodes links each child to the node it was expanded from
                # through prev_node; the root continues from the encoder state.
                if node.prev_node is None:
                    hidden = encoder_hidden
                else:
                    hidden = parent_hidden[node.prev_node]

                inputs = getDecoderInputs(experiment, idx, node.decoded)
                context = getDecoderContext(experiment, idx, encoder_output)
                logging.debug("inputs: %s", inputs)

                decoder_output, hidden = getDecoderOutput(
                    experiment, inputs, context, hidden)
                produced_hidden[node] = hidden
                logging.debug("decoder_output: %s", decoder_output)
                addNodes(node, decoder_output, candidates, next_nodes,
                         beam_width)

            prev_nodes = chooseNodes(next_nodes, beam_width)
            parent_hidden = produced_hidden
            next_nodes = []
            idx += 1

        if not candidates:
            # It's possible all predictions are longer than the maximum allowed
            # length, so no true candidates exist.
            # In this case, we use the incomplete predictions instead.
            logging.debug("no candidates found.")
            candidates = prev_nodes

        # Pick by score only; sorting the (score, node) tuples would try to
        # compare the nodes themselves when two scores tie.
        prediction = max(candidates, key=lambda scored: scored[0])[1]
        sentence = senteceFromTensor(experiment.output_lang, prediction.decoded)
        loss = getLoss(prediction.outputs, target_tensor, experiment.criterion)
        return sentence, loss


In [0]:
#@title Test and Evaluate Randomly { form-width: "200px" }

def testAndEval(experiment, use_beam, num=10, print_every=0, pairs=[]):
    if not pairs: pairs = experiment.pairs
    if not print_every: print_every = max(1, num // 10)
    eval_func = beamTest if use_beam else test
    test_pairs = random.sample(pairs, num)
    
    total_loss = 0
    for i, (src, dst) in enumerate(test_pairs):
        input_tensor = tensorFromSentence(experiment.input_lang, src)
        output_tensor = tensorFromSentence(experiment.output_lang, dst)
        prediction, loss = eval_func(experiment, input_tensor, output_tensor)
        total_loss += loss
        
        if (i % print_every) == 0:
            print('>', src)
            print('=', dst)
            print('<', prediction)
            print('Loss: %f\n' % loss)
    
    print('Average loss: %f' % (total_loss / num))


# Run Pipeline

In [0]:
#@title Mount Google Drive { form-width: "200px" }

# google.colab is only available inside the Colab runtime, so this cell is
# expected to fail elsewhere.
from google.colab import drive
# Mounts Google Drive at GDRIVE_PATH; WORK_DIR and DATA_DIR are built on top of
# this path.
drive.mount(GDRIVE_PATH)

In [0]:
#@title Initialise Experiment { form-width: "200px" }

# Model hyperparameters. ModelParameters is defined elsewhere in this notebook.
# is_rnn_dec=False is required for parallel training: trainIters and
# trainEpochs refuse parallel=True when the decoder is an RNN.
params = ModelParameters(
    is_rnn_enc=False,
    is_rnn_dec=False,
    embedding_dim=100,
    num_series=1,
    lstm_layers=8,
)

# The experiment shared by every cell below. data_size is used by trainEpochs
# as the number of iterations in one epoch; learning_rate is the Adam learning
# rate used by trainIters.
experiment = Experiment(
    name="experiment_name",
    data_size=1000,
    use_WMT=False,
    params=params,
    learning_rate=0.0003
)

In [0]:
#@title Train (by Iteration) { form-width: "200px" }

# Starts from a new model; swap the two lines below to resume from a saved one
# instead.
experiment.createModel()
# experiment.loadModel()
num_sets = 4

# Each set trains on 1000 pairs sampled with replacement (trainIters' default),
# logs the average loss every 100 iterations and then saves the model.
# Note: trainIters creates new Adam optimizers on every call, so optimizer
# state is not carried over from one set to the next.
for i in range(1, num_sets+1):
    logging.info("--Training set %d--" % i)
    trainIters(experiment, 1000, print_every=100)
    experiment.saveModel()


In [0]:
#@title Train (by Epoch) { form-width: "200px" }

# Resumes from the saved model; swap the two lines below to start from a new
# one instead.
# experiment.createModel()
experiment.loadModel()
num_sets = 8

# Each set runs 50 epochs, logs the average loss every 5 epochs and then saves
# the model. parallel=True only works with a non-RNN decoder: otherwise
# trainEpochs logs an error and returns without training.
for i in range(1, num_sets+1):
    logging.info("--Training set %d--" % i)
    trainEpochs(experiment, 50, print_every=5, parallel=True)
    experiment.saveModel()



In [0]:
#@title Evaluate { form-width: "200px" }

# Loads the saved model and evaluates it on one randomly chosen pair with
# greedy decoding (use_beam=False), printing the pair, the prediction and its
# loss.
experiment.loadModel()
testAndEval(experiment, False, num=1, print_every=1)
