In [1]:
####################
### REQUIREMENTS ###
####################

from __future__ import unicode_literals, print_function, division
from io import open
from itertools import islice # For reading only a part of the data file
from collections import OrderedDict # For defining a variable-length nn.Sequential()
from collections import defaultdict # Used in readData()
import unicodedata
import string
import re
import random
import math # For math.ceil() in readLine()

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

#from google.colab import drive
#drive.mount('/content/gdrive/')
#!mkdir -p /content/MyDrive
#!mount --bind /content/gdrive/My\ Drive /content/MyDrive

In [8]:
#########################
### CONTROL CONSTANTS ###
#########################

# Major.
DS_SIZE = 100000 # The number of sentences in the dataset to be fetched. After changing, set LOAD_DATASET_FROM_FILE = False.
PATH = '.\\AML\\' # WINDOWS STYLE. The working directory with all files (above all, bytenet.train.tar & bucketised_batches_0).
LOAD_DATASET_FROM_FILE = True # Read the training pairs either from bucketised_batches_0 (True), or from train.(en|de).
LOAD_MODEL_FROM_FILE = False # Training only. Read the model from bytenet.train.tar (True) or start training from scratch.
torch.backends.cudnn.enabled = True # cuDNN is NVIDIA's library of primitives, primarily for CNNs. 9x speed up.
TEACHER_FORCING = True # Both can be used. No TEACHER_FORCING results in a massive performance hit. Also decides whether readLine() prepends SOS_token to targets.
# NOTE: the second assignment overrides the first, so as written everything runs on the CPU.
# Remove (or comment out) the "cpu" line to run on the GPU.
DEVICE = torch.device("cuda")
DEVICE = torch.device("cpu")
INSPECT_VISUALLY = False # Evaluation only. If True, evaluateIters() also decodes and prints translations.
LIMIT_OUTPUT_PER_BATCH = 1 # With INSPECT_VISUALLY, only this many pairs of every batch are evaluated and printed.

# Minor.
# NOTE: the log file is opened (and truncated, mode='w') even when LOGGING is False, and is never closed in this notebook.
LOG = open(PATH + 'log', mode='w') # For debugging purposes. Mainly outputs sizes of tensors in intermediate computations.
LOGGING = False # But has not been used in practice since very early stages.
# NOTE(review): this is a plain Python variable; CUDA presumably reads CUDA_LAUNCH_BLOCKING from the process
# environment, in which case this assignment alone has no effect -- verify.
CUDA_LAUNCH_BLOCKING = 1 # Used for more meaningful GPU error messages.
START_ITER = 0 # If not LOAD_MODEL_FROM_FILE, you can start training it from batch START_ITER. Not particularly useful.

In [9]:
##########################
### DATA-PREPROCESSING ###
##########################

# BATCHING

PAD_TO = 50 # During training the input and target sequences will be padded to the nearest multiple of PAD_TO for efficient batching.

# (source, target) pairs are accumulated in the bucketised_tensor_pairs dictionary under the key
# (padded source length, padded target length); both lengths are multiples of PAD_TO. Whenever a bucket holds k pairs
# with k * max(source length, target length) >= MAX_BATCH_SIZE, the bucket is converted into a batch tensor and flushed.
# After all these steps, bucketised_tensor_pairs is likely still non-empty. All buckets satisfying the same inequality
# with MIN_BATCH_SIZE instead are also turned into batches. All others are discarded.

MIN_BATCH_SIZE = 2000
MAX_BATCH_SIZE = 3000

# VOCABULARY

# The special tokens are single control characters written as octal escapes ('\2' is chr(2), ..., '\7' is chr(7)).
SOS_token = '\2'
EOS_token = '\3'
PSC_token = '\4' # Padding Sequence Character (for Dynamic Unfolding)
NAC_token = '\5' # Not A Character (for batching: pad the source and target to the multiple of PAD_TO)
UCF_token = '\7' # Unknown Character Found (replace all unknown characters with this token)

# Below, the last unicode symbols are 8 letters with a diaeresis (ä ï ö ü Ä Ï Ö Ü) and the 2 Eszetts (ß ẞ).
# The position of a character in all_letters is its index in the model's vocabulary (see letterToIndex).
all_letters = SOS_token + EOS_token + UCF_token + PSC_token + NAC_token + string.printable + "£€°\u00E4\u00EF\u00F6\u00FC\u00DF\u00C4\u00CF\u00D6\u00DC\u1E9E"
n_letters = len(all_letters)

# TEXT PROCESSING

# Turn a Unicode string to plain ASCII, thanks to https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    """Normalise one raw corpus line to the characters contained in all_letters.

    Steps:
      1. Undo the HTML entities and the ##...## placeholders found in the corpus and drop newlines.
      2. Decompose the string (NFD) and strip every combining mark except the diaeresis U+0308.
      3. Keep a diaeresis only when it directly follows one of 'AIOUaiou'; otherwise drop it.
         This also covers a diaeresis at the very start of the string, which has no base letter.
      4. Replace every remaining character that is not in all_letters with UCF_token and
         recompose (NFC), so that e.g. 'a' + U+0308 becomes the single character 'ä'.

    Args:
        s: the raw text of one line.

    Returns:
        The normalised string. An empty (or newline-only) input yields '' instead of raising.
    """
    html_subs = {'&quot;': '"', '&apos;': '\'', '&amp;': '&', '&#91;': '[', '&#93;': ']', '&lt;': '<',  '&gt;': '>', '&#124;': '|'}

    for sub in html_subs:
        s = s.replace(sub, html_subs[sub])
    s = s.replace(' ##AT##-##AT## ', '-').replace('##STAR##','*').replace('##UNDERSCORE##','_').replace('##AT##','@').replace('\n', '')

    # NFD splits accented characters into a base character followed by combining marks.
    # unicodedata.category(c) == 'Mn' <==> Unicode Non-spacing Mark
    NFD = [c for c in unicodedata.normalize('NFD', s)
           if (unicodedata.category(c) != 'Mn' or c == '\u0308')] # Leave the umlaut!!!

    # A diaeresis is kept only directly after a letter it can be recombined with below.
    NFD_filter_umlauts = []
    previous = None
    for c in NFD:
        if c != '\u0308' or (previous is not None and previous in 'AIOUaiou'):
            NFD_filter_umlauts.append(c)
        previous = c

    return unicodedata.normalize('NFC', ''.join([c if c in all_letters or c == '\u0308'
                    else UCF_token # Insert a UCF_token in place of c otherwise
                    for c in NFD_filter_umlauts]))

def letterToIndex(letter):
    """Return the position of `letter` in all_letters; asserts that it occurs there."""
    index = all_letters.find(letter)
    assert index != -1
    return index

# This function can be a startling experiment. I decided to keep 2 different paddings to convey different ideas and
# hopefully aid the model in translation. The input to the encoder does not get a SOS_token, but gets the padding
# immediately afterwards (20% of its original length, as specified by the ByteNet paper), THEN the EOS_token (hoping that
# putting EOS AFTER the padding will help the network to understand where approximately it should stop predicting the
# output), and then NAC_tokens to pad the sequence to a multiple of PAD_TO.
# Although this means that I use 2 different paddings (1 for Dynamic Unfolding from the paper, and 1 for batching), those
# paddings are inherently different: the 1st one means there is no information supplied, but an encoder output is expected
# in order to help the decoder in translation, whereas the 2nd one bears no meaning: there is no new information, and
# no output is expected.
def readLine(line, type):
    """Turn one raw line into a padded string ready for bucketing.

    Args:
        line: the raw line as read from the corpus file.
        type: 'enc' for encoder input; any other value is treated as decoder target.

    Returns:
        For 'enc': normalised text + PSC_token padding + EOS_token. The amount of PSC padding
        is 20% of the length of the raw, unprocessed `line`, rounded down.
        Otherwise: (SOS_token if TEACHER_FORCING) + normalised text + EOS_token.
        In both cases the result is right-padded with NAC_token up to a multiple of PAD_TO.
    """
    text = unicodeToAscii(line.strip())
    if type == 'enc':
        sequence = text + PSC_token * int(0.2 * len(line)) + EOS_token
    else:
        prefix = SOS_token if TEACHER_FORCING else ''
        sequence = prefix + text + EOS_token
    padded_length = math.ceil(len(sequence) / PAD_TO) * PAD_TO
    return sequence.ljust(padded_length, NAC_token)

# Turn a bucket of (source, target) strings into two <BATCH_SIZE x line_length> tensors of letter indices.
def bucketToBatch(bucket):
    """Convert a bucket of padded string pairs into a pair of index tensors on DEVICE.

    Args:
        bucket: a non-empty sequence of (source, target) string pairs. All sources must have
            the same length and all targets must have the same length (readData guarantees
            this by keying buckets on the padded lengths).

    Returns:
        (source_tensor, target_tensor): int64 tensors of shape (len(bucket), source length)
        and (len(bucket), target length) holding letterToIndex of every character.

    Raises:
        IndexError: if the bucket is empty.
    """
    if not bucket:
        raise IndexError('bucketToBatch() needs a non-empty bucket')

    # Collect the indices in plain Python lists and build each tensor in a single call,
    # instead of writing into the tensor one scalar at a time.
    source_rows = [[letterToIndex(char) for char in pair[0]] for pair in bucket]
    target_rows = [[letterToIndex(char) for char in pair[1]] for pair in bucket]

    source_tensor = torch.tensor(source_rows, dtype=torch.int64, device=DEVICE)
    target_tensor = torch.tensor(target_rows, dtype=torch.int64, device=DEVICE)

    return (source_tensor, target_tensor)

def readData(prefix, lang1, lang2, N = None):
    """Read a parallel corpus, bucket it by padded length and cache the batches on disk.

    The files PATH + '<prefix>.<lang1>' (sources) and PATH + '<prefix>.<lang2>' (targets) are
    read line by line in parallel. A bucket is turned into a batch as soon as
    (number of pairs) * max(source length, target length) >= MAX_BATCH_SIZE; buckets left over
    at the end are kept if the same product is >= MIN_BATCH_SIZE and discarded otherwise.

    Args:
        prefix: common file name prefix, e.g. 'train'.
        lang1: file extension of the source language.
        lang2: file extension of the target language.
        N: read only the first N line pairs; None reads the files completely.

    Returns:
        None. The list of (source_tensor, target_tensor) batches is written to
        PATH + 'bucketised_batches_0'. Nothing is written when no batch was produced, so an
        existing cache file is left untouched in that case.

    Note:
        The progress messages are computed from the global DS_SIZE, not from N.
    """
    bucketised_batches = []
    bucketised_tensor_pairs = defaultdict(list)

    # Both files are closed again as soon as reading is finished (or fails).
    with open(PATH + '%s.%s' % (prefix, lang1), encoding='utf-8') as lines1, \
         open(PATH + '%s.%s' % (prefix, lang2), encoding='utf-8') as lines2:
        # Read in only the first N lines of the train files. N = None means reading the entire file.
        for i, (line1, line2) in enumerate(islice(zip(lines1, lines2), N), start=1):
            if DS_SIZE >= 100 and i % ((DS_SIZE // 100) * 5) == 0:
                print("%d%% of the datafile was read." % (i / (DS_SIZE // 100)))
            tmp1 = readLine(line1, 'enc')
            tmp2 = readLine(line2, 'dec')
            key = (len(tmp1), len(tmp2))
            cur_bucket = bucketised_tensor_pairs[key]
            cur_bucket.append((tmp1, tmp2))
            if len(cur_bucket) * max(key) >= MAX_BATCH_SIZE:
                bucketised_batches.append(bucketToBatch(cur_bucket))
                del bucketised_tensor_pairs[key]

    # Leftover buckets: keep the reasonably full ones, drop the rest.
    for key, bucket in bucketised_tensor_pairs.items():
        if len(bucket) * max(key[0], key[1]) >= MIN_BATCH_SIZE:
            bucketised_batches.append(bucketToBatch(bucket))

    if bucketised_batches:
        torch.save(bucketised_batches, PATH + 'bucketised_batches_0')

# Rebuild the cached batches from train.en / train.de only when requested, then always load them from the cache file.
# NOTE: readData() writes the cache only if it produced at least one batch; otherwise an older cache file (if any) is loaded.
if not LOAD_DATASET_FROM_FILE:
    readData('train', 'en','de', DS_SIZE)
bucketised_batches = torch.load(PATH + 'bucketised_batches_0', map_location=DEVICE)

In [10]:
########################
### MODEL DEFINITION ###
########################

import numpy as np
from torch.autograd import Variable
import copy

# CONVOLUTION PARAMETER DEFINITIONS

# wdt = width; str = stride; pad = padding; dil = dilation
# Keyword arguments for nn.Conv1d. Both settings keep the sequence length unchanged.
conv1x1 = {'kernel_size': 1, 'stride': 1, 'padding': 0, 'dilation': 1}
conv1xk = {'kernel_size': 3, 'stride': 1, 'padding': 1} # Keep wdt odd and pad to (wdt-1)/2 (ResBlock multiplies the padding by its dilation). On changing wdt MAKE SURE TO TEST masking in ResBlock::forward!!!



# ENCODER & DECODER

# Layer norm across the dimension of channels. We generally store tensors as: dim 0 -> batch; dim 1 -> channel;
# dim 2 -> character; since this order is assumed by convolutions. However, nn.LayerNorm normalises over the last
# dimension, and calculating layer norm along the dimension of characters would cause leakage from future characters
# to the previous ones (because for each channel LayerNorm would assign its values based on the values of all
# characters in this channel indiscriminately).
# Therefore, we need a different LayerNorm that first transposes its input before normalising it.
class ChannelLayerNorm(nn.Module):
    """Layer normalisation over dim 1 (channels) of a (batch, channel, character) tensor.

    nn.LayerNorm(channels) normalises the last dimension, so the input is transposed to
    (batch, character, channel), normalised, and transposed back. Every character position
    is therefore normalised on its own, using only its own channel values.
    """

    def __init__(self, channels):
        super(ChannelLayerNorm, self).__init__()
        self.channels = channels
        # The LayerNorm stays wrapped in an nn.Sequential called trueBN, so its parameters
        # keep the names 'trueBN.0.weight' / 'trueBN.0.bias'.
        # A similar trick for BatchNorm itself would require knowing the number of characters ahead of time,
        # which is feasible for compression, but infeasible for translation.
        self.trueBN = nn.Sequential(nn.LayerNorm(channels))

    def forward(self, batch):
        channels_last = batch.transpose(1, 2) # dim 1 -> character, dim 2 -> channel
        normalised = self.trueBN(channels_last)
        return normalised.transpose(1, 2) # back to dim 1 -> channel, dim 2 -> character

class ResBlock(nn.Module):
    """One residual block: three (norm -> ReLU -> convolution) stages plus a skip connection.

    The stages are a 1x1 convolution down to `channels`, a dilated width-k convolution, and a
    1x1 convolution back to the input width. The padding of the dilated convolution is scaled
    with the dilation, so the output has the same shape as the input and can be added to it.

    [self.type == 'enc'] <=> Encoder; everything else is Decoder. Decoder blocks operate on
    2 * channels (target embedding stacked on the encoder output) and zero the "future" taps
    of the dilated convolution before every forward pass to keep it causal.

    Args:
        channels: base number of channels.
        dilation: dilation of the middle convolution.
        type: 'enc' for an encoder block, anything else for a decoder block.
    """

    def __init__(self, channels, dilation, type):
        super(ResBlock, self).__init__()
        self.type = type
        channel_factor = 2 if type != 'enc' else 1
        conv1xk_dil = copy.deepcopy(conv1xk)
        conv1xk_dil['padding'] *= dilation
        conv1xk_dil['dilation'] = dilation

        # Note that the size of padding in conv1xk is chosen so that the size of the output of the block is equal to the size of the input.
        self.resblock = nn.Sequential(OrderedDict([
            ('norm_1', ChannelLayerNorm(channel_factor * channels)),
            ('relu_1', nn.ReLU()),
            ('conv1x1_1', nn.Conv1d(channel_factor * channels, channels, **conv1x1)),
            ('norm_2', ChannelLayerNorm(channels)),
            ('relu_2', nn.ReLU()),
            ('conv1xd_2', nn.Conv1d(channels, channels, **conv1xk_dil)),
            ('norm_3', ChannelLayerNorm(channels)),
            ('relu_3', nn.ReLU()),
            ('conv1x1_3', nn.Conv1d(channels, channel_factor * channels, **conv1x1))
         ]))

    def forward(self, input):
        """Return resblock(input) + input; decoder blocks first re-apply the causal mask."""
        # In the decoder we have to mask future tokens. It would make sense to mask the future target tokens (otherwise
        # we feed into the network what we want to get from it) ONLY, but leave the encoder output unmasked (like it is
        # done, e.g., in Transformers).
        # If we mask both the target and the encoder output tokens, this creates a bottleneck: during testing, the output
        # of the network is heavily dependent upon the first character it predicts (because it is used for predicting all
        # other characters). So we really want to predict it correctly. However, this character itself is predicted based
        # on SOS_token and just 1 (as opposed to r, the size of the receptive field) encoder output.
        # We probably need only the 6-7 next characters to predict this 1st character, but they now all have to be
        # encapsulated in the 1st encoder output vector of size 1 x EMBED_DIM. We thus create a bottleneck that is less
        # severe than, but still noticeable compared to, seq2seq, where the entire sentence had to be encapsulated in
        # such a vector.

        # The authors, however, do not provide any explanation as to how they solve this issue. The implementation would
        # be quite non-trivial: we need to prevent the leakage from the target tokens' channels into the channels of the
        # encoder outputs (because those would be allowed to leak to previous tokens). This requires modifying LayerNorms
        # to serve channels 0 ... d-1 and d ... 2d-1 separately, as well as convolutions (in PyTorch convolutions allow
        # for grouping, which solves this issue). Sadly, I will have to ignore this issue in my implementation: it was
        # not found until 26th April, after reviewing the architecture from Attention Is All You Need.
        if self.type != 'enc':
            # Zero every kernel tap to the right of the centre, for ALL input channels (target and encoder output alike).
            # This is repeated on every call because an optimizer step may have made those taps non-zero again.
            # The write goes through .data, i.e. it is invisible to autograd, and happens in place on whatever
            # device the weight lives on.
            weight = self.resblock._modules['conv1xd_2'].weight
            wdt = weight.size(2)
            weight.data[:, :, wdt // 2 + 1:wdt].zero_()

        output = self.resblock(input)

        if LOGGING:
            print('def ResBlock::forward(self, input)', file=LOG)
            print(input.size(), file=LOG)
            print(output.size(), file=LOG, end='\n')

        return output + input

class CNN(nn.Module):
    """Encoder or decoder: an embedding followed by res_sets * res_blocks residual blocks.

    Within every set, block number b uses dilation 2 ** b. A decoder (type != 'enc')
    additionally ends with a 1x1 convolution, a ReLU, a masked width-k convolution onto
    n_letters channels and a LogSoftmax over the channel dimension.

    Args:
        input_dim: vocabulary size of the embedding.
        channels: embedding size and base number of channels.
        res_sets: number of sets of residual blocks.
        res_blocks: number of residual blocks per set.
        type: 'enc' for the encoder, anything else for the decoder.
    """

    def __init__(self, input_dim, channels, res_sets, res_blocks, type):
        super(CNN, self).__init__()
        channel_factor = 2 if type != 'enc' else 1
        self.type = type
        self.embed = nn.Embedding(input_dim, channels)
        self.channels = channels

        layers = OrderedDict()
        for res_set in range(res_sets):
            for res_block in range(res_blocks):
                layers['res_set_' + str(res_set) + '|res_block_' + str(res_block)] = ResBlock(channels, 2 ** res_block, type)

        if type != 'enc':
            layers['fin|conv1x1'] = nn.Conv1d(channel_factor * channels, channels, **conv1x1)
            layers['fin|relu'] = nn.ReLU()
            layers['fin|conv1xd'] = nn.Conv1d(channels, n_letters, **conv1xk)
            layers['fin|logsoftmax'] = nn.LogSoftmax(dim=1)
        self.CNN = nn.Sequential(layers)

    def forward(self, source=None, target=None, encoder_output=None):
        """Run the network.

        Encoder: pass `source`, a (batch, length) tensor of letter indices; the result has
        `channels` channels per position.
        Decoder: pass `target` (letter indices) and `encoder_output`; the result holds
        log-probabilities over n_letters for every position of `target`. Encoder outputs beyond
        the target length are ignored; missing ones are treated as zeros.
        """
        # After the embedding we transpose dimensions 1 and 2 to obtain:
        # dim = 0 is the batch, dim = 1 is the channel, dim = 2 is the character in a sequence.
        if self.type == 'enc':
            if LOGGING:
                print('def CNN::forward(self, source, target, encoder_output) self.type == enc', file=LOG)
                print('source', file=LOG)
                print(source.size(), file=LOG, end='\n')

            emb = torch.transpose(self.embed(source), 1, 2)
        else:
            if LOGGING:
                print('def CNN::forward(self, source, target, encoder_output) self.type == dec', file=LOG)
                print('target', file=LOG)
                print(target.size(), file=LOG)
                print('encoder_output', file=LOG)
                print(encoder_output.size(), file=LOG)

            tmp = torch.transpose(self.embed(target), 1, 2)

            if LOGGING:
                print('embedding of the target', file=LOG)
                print(tmp.size(), file=LOG, end='\n')

            # Allocate the stacked input next to the embedding (same device and dtype), so the
            # module keeps working wherever it has been moved to.
            emb = tmp.new_zeros((tmp.size(0), 2 * tmp.size(1), tmp.size(2)))
            out_length = min(tmp.size(2), encoder_output.size(2))
            # The top self.channels channels will be the target sequence. The bottom self.channels ones will be the encoder output.
            emb[:,:self.channels,:] = tmp
            emb[:,self.channels:2*self.channels,0:out_length] = encoder_output[:,:,0:out_length]

            # Keep the final width-k convolution causal: zero the taps right of the centre.
            # Repeated on every call because an optimizer step may have changed them; the write
            # goes through .data and is therefore invisible to autograd.
            fin_weight = self.CNN._modules['fin|conv1xd'].weight
            wdt = fin_weight.size(2)
            fin_weight.data[:, :, wdt // 2 + 1:wdt].zero_()
        return self.CNN(emb)

In [12]:
################
### TRAINING ###
################

# TIMING [COPY-PASTED FROM THE PYTORCH TUTORIAL]

import time
import math

def asMinutes(s):
    """Format a duration of `s` seconds as '<minutes>m <seconds>s' (fractions are cut off)."""
    minutes = math.floor(s / 60)
    seconds = s - minutes * 60
    return '%dm %ds' % (minutes, seconds)

def timeSince(since, percent):
    """Return '<elapsed> (- <remaining>)' for work started at `since`.

    Args:
        since: start time as returned by time.time().
        percent: fraction of the work done so far (must be non-zero); the remaining time is
            extrapolated linearly as elapsed / percent - elapsed.
    """
    elapsed = time.time() - since
    remaining = elapsed / percent - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(remaining))

# TRAIN

EMBED_DIM = 200 # Passed to CNN(...) as `channels`: the embedding size and the base number of channels.

# Constants from the paper.

ADAM_LEARNING_RATE = 0.0003 # Learning rate of both Adam optimizers; also re-applied after loading a checkpoint.
RES_SETS = 6        # Number of sets of residual blocks
RES_BLOCKS = 5      # Number of residual blocks per set (block b of a set uses dilation 2 ** b)

# Define 1 GD step. W/o Teacher Forcing
def train_no_TF(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion):
    """Perform one gradient step WITHOUT teacher forcing.

    The decoder is fed its own greedy predictions: starting from SOS_token, for every
    char_index the decoder is run on the first char_index predicted characters, its last
    output is scored against target_tensor[:, char_index - 1], the gradients of that loss are
    accumulated, and the argmax becomes the next input character.

    Args:
        input_tensor: (batch, source length) letter indices.
        target_tensor: (batch, target length) letter indices, expected WITHOUT a leading
            SOS_token (readLine omits it when TEACHER_FORCING is False).
        encoder, decoder: the two CNN modules.
        encoder_optimizer, decoder_optimizer: their optimizers.
        criterion: loss taking (log-probabilities, target indices).

    Returns:
        A 0-dim tensor: the sum of the per-character losses divided by the target length.
    """
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    batch_size = target_tensor.size(0)
    target_length = target_tensor.size(1)
    device = target_tensor.device

    encoder_output = encoder(source=input_tensor)

    output = torch.zeros(batch_size, target_length, dtype=torch.int64, device=device, requires_grad=False)
    output[:, 0] = letterToIndex(SOS_token)
    # A tensor accumulator, so the result is well defined even if the loop below never runs.
    loss = torch.zeros((), device=device)

    for char_index in range(1, target_length):
        decoder_output = decoder(target=output[:,:char_index], encoder_output=encoder_output[:,:,:char_index])
        loss_per_char = criterion(decoder_output[:,:,char_index-1], target_tensor[:,char_index-1])
        loss += loss_per_char.detach()

        if LOGGING:
            print('In def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)', file=LOG)
            print('char_index = %d' % char_index, file=LOG)
            print(output.size(), file=LOG)
            print(decoder_output.size(), file=LOG)

        # The encoder graph is shared by all characters, hence retain_graph=True.
        loss_per_char.backward(retain_graph=True)
        # torch.argmax works on any device (a NumPy round trip would fail for CUDA tensors).
        output[:,char_index] = torch.argmax(decoder_output[:,:,char_index-1].detach(), dim=1)

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss / target_length

# Define 1 GD step. W/ Teacher Forcing
def train_TF(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion):
    """Perform one gradient step WITH teacher forcing.

    The decoder sees the complete ground-truth target at once; its output at position p is
    scored against the target character at position p + 1.

    Args:
        input_tensor: letter indices of the source batch.
        target_tensor: (batch, target length) letter indices of the target batch.
        encoder, decoder: the two networks; called as encoder(source=...) and
            decoder(target=..., encoder_output=...).
        encoder_optimizer, decoder_optimizer: their optimizers.
        criterion: loss taking (log-probabilities, target indices).

    Returns:
        The detached loss of this batch.
    """
    optimizers = (encoder_optimizer, decoder_optimizer)
    for optimizer in optimizers:
        optimizer.zero_grad()

    encoded = encoder(source=input_tensor)
    predictions = decoder(target=target_tensor, encoder_output=encoded) # Teacher forcing

    loss = criterion(predictions[:, :, :-1], target_tensor[:, 1:])
    loss.backward()

    for optimizer in optimizers:
        optimizer.step()

    return loss.detach()

# Define the actual training.
def trainIters(encoder, decoder, n_iters, print_every=1000, save_every=10000):
    """Train encoder and decoder for n_iters batches, printing progress and saving checkpoints.

    Batches are taken cyclically from the global bucketised_batches. If LOAD_MODEL_FROM_FILE is
    set, model and optimizer states are first restored from PATH + 'bytenet.train.tar' and the
    iteration counter continues from the stored value; otherwise it starts at START_ITER.

    Args:
        encoder, decoder: the two CNN modules to train.
        n_iters: number of batches to train on.
        print_every: a progress line with the average loss is printed whenever the iteration
            number is a multiple of this value.
        save_every: a checkpoint is written whenever the iteration number is a multiple of
            this value.
    """
    start = time.time()
    print_loss_total = 0  # Sum of the losses since the last progress line
    print_loss_count = 0  # Number of batches in that sum
    start_iter = START_ITER

    # Choose whether to use teacher forcing or not.
    train = train_TF if TEACHER_FORCING else train_no_TF

    encoder_optimizer = optim.Adam(encoder.parameters(), ADAM_LEARNING_RATE)
    decoder_optimizer = optim.Adam(decoder.parameters(), ADAM_LEARNING_RATE)

    if LOAD_MODEL_FROM_FILE:
        checkpoint = torch.load(PATH + 'bytenet.train.tar', map_location=DEVICE)
        start_iter = checkpoint['iter']
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        encoder_optimizer.load_state_dict(checkpoint['encoder_optimizer_state_dict'])
        decoder_optimizer.load_state_dict(checkpoint['decoder_optimizer_state_dict'])

        # The restored optimizer state carries the old learning rate; enforce the current one.
        for param_group in encoder_optimizer.param_groups:
            param_group['lr'] = ADAM_LEARNING_RATE
        for param_group in decoder_optimizer.param_groups:
            param_group['lr'] = ADAM_LEARNING_RATE

        encoder.to(DEVICE)
        decoder.to(DEVICE)

        encoder.train()
        decoder.train()

        print('Loaded the model from %sbytenet.train.tar, which has already been trained for %d iterations.' % (PATH, start_iter))

    criterion = nn.NLLLoss()

    for iteration in range(start_iter + 1, start_iter + n_iters + 1):
        input_tensor, target_tensor = bucketised_batches[iteration % len(bucketised_batches)]
        if LOGGING:
            print('def trainIters(encoder, decoder, n_iters, print_every, plot_every, save_every, learning_rate)', file=LOG)
            print('iter = %d' % iteration, file=LOG)
            print(input_tensor.size(), file=LOG)
            print(target_tensor.size(), file=LOG, end='\n')

        loss = train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        print_loss_count += 1

        if iteration % print_every == 0:
            # Average over the batches actually accumulated: after resuming from a checkpoint the
            # first window can be shorter than print_every.
            print_loss_avg = print_loss_total / print_loss_count
            print_loss_total = 0
            print_loss_count = 0
            real_iter = iteration - start_iter
            print('%s (%d %d%%) %.4f' % (timeSince(start, real_iter / n_iters), iteration, real_iter / n_iters * 100, print_loss_avg),end=(' ' if iteration % save_every == 0 else '\n'))

        if iteration % save_every == 0:
            torch.save({
                'iter': iteration,
                'encoder_state_dict': encoder.state_dict(),
                'decoder_state_dict': decoder.state_dict(),
                'encoder_optimizer_state_dict': encoder_optimizer.state_dict(),
                'decoder_optimizer_state_dict': decoder_optimizer.state_dict(),
            }, PATH + 'bytenet.train.tar')
            print('Saved!')
    
# Build a fresh encoder / decoder pair on DEVICE and train it. With save_every=2000 a checkpoint is written to
# PATH + 'bytenet.train.tar' every 2000 iterations, so the run can be interrupted and resumed via LOAD_MODEL_FROM_FILE.
encoder = CNN(n_letters, EMBED_DIM, RES_SETS, RES_BLOCKS, 'enc').to(DEVICE)
decoder = CNN(n_letters, EMBED_DIM, RES_SETS, RES_BLOCKS, 'dec').to(DEVICE)

trainIters(encoder, decoder, 1000000, print_every=300, save_every=2000)

In [None]:
# Assumes Windows OS. On Linux, change "copy /y" to "cp -f".
!copy /y {PATH + 'bytenet.train.tar'} {PATH + 'bytenet.eval.tar'}

In [None]:
##################
### EVALUATION ###
##################

from operator import itemgetter

# TIMING

import time
import math

def asMinutes(s):
    """Format a duration of `s` seconds as '<minutes>m <seconds>s' (fractions are cut off).

    Same definition as in the training cell (presumably repeated so this cell can run without it).
    """
    whole_minutes = math.floor(s / 60)
    rest = s - whole_minutes * 60
    return '%dm %ds' % (whole_minutes, rest)

def timeSince(since, percent):
    """Return '<elapsed> (- <remaining>)' for work started at `since` (a time.time() value).

    `percent` is the non-zero fraction of the work done so far; the remaining time is
    extrapolated linearly. Same definition as in the training cell.
    """
    elapsed = time.time() - since
    estimated_rest = elapsed / percent - elapsed
    return '%s (- %s)' % (asMinutes(elapsed), asMinutes(estimated_rest))

# EVALUATE

OVER_LENGTH = 50 # evaluate() keeps decoding for this many positions beyond the length of the reference target.

# The architecture constants are repeated from the training cell. NOTE(review): they presumably have to match the
# values the checkpoint was trained with, otherwise load_state_dict() in evaluateIters() should reject it -- verify.
EMBED_DIM = 200

# Constants from the paper.

RES_SETS = 6        # Number of sets of residual blocks
RES_BLOCKS = 5      # Number of residual blocks per set

# Define 1 evaluation step.
def evaluate(input_tensor, target_tensor, encoder, decoder, criterion):
    """Greedily decode a batch WITHOUT teacher forcing and score it against the target.

    Starting from SOS_token, the decoder is repeatedly run on its own previous predictions for
    target length + OVER_LENGTH positions.

    Args:
        input_tensor: (batch, source length) letter indices.
        target_tensor: (batch, target length) letter indices, with or without a leading
            SOS_token (detected from the first column).
        encoder, decoder: the two CNN modules.
        criterion: loss taking (log-probabilities, target indices).

    Returns:
        (output, loss): `output` is a (batch, target length + OVER_LENGTH) tensor of letter
        indices whose column 0 is SOS_token; `loss` is a 0-dim tensor holding the summed
        per-position criterion values divided by the number of positions scored.
    """
    batch_size = target_tensor.size(0)
    target_length = target_tensor.size(1)
    device = target_tensor.device
    sos_index = letterToIndex(SOS_token)

    encoder_output = encoder(source=input_tensor)

    output = torch.zeros(batch_size, target_length + OVER_LENGTH, dtype=torch.int64, device=device, requires_grad=False)
    output[:, 0] = sos_index

    # `output` always starts with SOS_token, so output[:, k] is the k-th character AFTER SOS.
    # If the targets also start with SOS_token (datasets built with TEACHER_FORCING), the
    # reference for output[:, k] is target[:, k]; otherwise it is target[:, k - 1].
    target_offset = 1 if bool((target_tensor[:, 0] == sos_index).all()) else 0

    # A tensor accumulator, so the result is well defined even if nothing gets scored.
    loss = torch.zeros((), device=device)
    scored_positions = 0

    for char_index in range(1, target_length + OVER_LENGTH):
        decoder_output = decoder(target=output[:,:char_index], encoder_output=encoder_output[:,:,:char_index])
        step_scores = decoder_output[:,:,char_index-1]

        target_index = char_index - 1 + target_offset
        if target_index < target_length:
            loss += criterion(step_scores, target_tensor[:,target_index]).detach()
            scored_positions += 1

        # We can't halt on EOS_token, because we have many sentences in a batch, and those are not obliged to be translated
        # into sequences of exactly the same length. However, we want to stop printing the output after EOS_token.
        output[:,char_index] = torch.argmax(step_scores, dim=1)

    return output, loss / max(scored_positions, 1)

# Define the actual evaluation.
def _printDecoded(prefix, indices, hidden_tokens):
    """Print `prefix` followed by the characters for `indices`, stopping at the first EOS_token.

    Each index selects a character of all_letters; characters listed in `hidden_tokens` are skipped.
    """
    visible = []
    for index in indices:
        char = all_letters[int(index)]
        if char == EOS_token:
            break
        if char not in hidden_tokens:
            visible.append(char)
    print(prefix + ''.join(visible))

def evaluateIters(encoder, decoder, print_every, process_every = 1):
    """Load the evaluation checkpoint and compute the teacher-forced loss over the dataset.

    The weights are read from PATH + 'bytenet.eval.tar' and both modules are switched to
    evaluation mode. For every processed batch the summed NLL of the teacher-forced
    predictions, divided by (target length - 1), is accumulated.

    With INSPECT_VISUALLY, only the first LIMIT_OUTPUT_PER_BATCH pairs of each batch are
    used, and for each of them four lines are printed: the source ('> '), the reference
    ('= '), the translation without teacher forcing ('<-TF ') and the teacher-forced
    prediction ('<+TF ').

    Args:
        encoder, decoder: CNN modules with the architecture of the checkpoint.
        print_every: a progress line is printed every print_every processed batches.
        process_every: only every process_every-th batch is processed.

    Returns:
        (print_loss_total, num_of_pairs): the accumulated loss and the number of pairs it
        was computed on.
    """
    start = time.time()
    print_loss_total = 0
    num_of_pairs = 0

    checkpoint = torch.load(PATH + 'bytenet.eval.tar', map_location=DEVICE)
    start_iter = checkpoint['iter']
    encoder.load_state_dict(checkpoint['encoder_state_dict'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])

    encoder.to(DEVICE)
    decoder.to(DEVICE)

    encoder.eval()
    decoder.eval()

    print('Loaded the model from %sbytenet.eval.tar, which has already been trained for %d iterations.' % (PATH, start_iter))
    criterion = nn.NLLLoss(reduction='sum')

    for batch, (input_tensor, target_tensor) in enumerate(bucketised_batches):
        if batch % process_every != 0: continue

        if INSPECT_VISUALLY:
            input_tensor  = input_tensor[:LIMIT_OUTPUT_PER_BATCH,:]
            target_tensor = target_tensor[:LIMIT_OUTPUT_PER_BATCH,:]

        encoder_output = encoder(source=input_tensor)
        decoder_output = decoder(target=target_tensor, encoder_output=encoder_output)
        target_length = target_tensor.size(1) - 1
        print_loss_total += criterion(decoder_output[:,:,:-1], target_tensor[:,1:]) / target_length
        num_of_pairs += input_tensor.size(0)

        if batch % (print_every * process_every) == 0:
            print('%s (%d %d%%) %d pairs processed so far; ' % (timeSince(start, num_of_pairs / DS_SIZE), batch, batch / len(bucketised_batches) * 100, num_of_pairs),end='')
            print('the average loss so far: %.4f.' % (print_loss_total / num_of_pairs))

        if not INSPECT_VISUALLY: continue

        print('Bucket #{0}'.format(batch))
        output, loss = evaluate(input_tensor, target_tensor, encoder, decoder, criterion)
        output_TF = torch.argmax(decoder_output[:,:,:-1].detach(), dim=1)

        print(loss)

        for source,target,result,result_TF in zip(input_tensor,target_tensor,output,output_TF):
            # The tensors hold indices into all_letters.
            _printDecoded('> ', source, (SOS_token, EOS_token, PSC_token))
            _printDecoded('= ', target, (SOS_token, EOS_token))
            _printDecoded('<-TF ', result, (SOS_token, EOS_token))
            _printDecoded('<+TF ', result_TF, (SOS_token, EOS_token))
        print()

    return print_loss_total, num_of_pairs

# The computed training error assumes teacher forcing (i.e., we only measure the loss when, at each step, the network is
# given the ground-truth previous characters of the translation to predict the next character). Calculating the training
# error without teacher forcing would take around a month on my hardware.
# Everything runs under torch.no_grad(): no gradients are needed for evaluation. The freshly built modules only provide
# the architecture; evaluateIters() loads their weights from the checkpoint.
with torch.no_grad():
    encoder = CNN(n_letters, EMBED_DIM, RES_SETS, RES_BLOCKS, 'enc').to(DEVICE)
    decoder = CNN(n_letters, EMBED_DIM, RES_SETS, RES_BLOCKS, 'dec').to(DEVICE)

    print_loss_total, num_of_pairs = evaluateIters(encoder, decoder, 100)

In [None]:
# Summary of the evaluation run above: the number of pairs processed and the accumulated loss per pair.
print(num_of_pairs)
print(print_loss_total/num_of_pairs)

In [None]:
# Total number of trainable parameters of the encoder and the decoder together.
sum(p.numel() for p in encoder.parameters() if p.requires_grad) + \
sum(p.numel() for p in decoder.parameters() if p.requires_grad)