In [1]:
####################
### REQUIREMENTS ###
####################

from __future__ import unicode_literals, print_function, division
from io import open
from itertools import islice # For reading only a part of the data file
from collections import OrderedDict # For defining a variable-length nn.Sequential()
from collections import defaultdict # Used in readData()
import unicodedata
import string
import re
import random
import math # For math.ceil() in readLine()

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

#from google.colab import drive
#drive.mount('/content/gdrive/')

In [2]:
#########################
### CONTROL CONSTANTS ###
#########################

import os  # Used for a portable PATH and to export CUDA_LAUNCH_BLOCKING.

# Major.
TRAIN_DATA = 25*(10**6) # The number of bytes to train on.
# Directory holding the dataset, checkpoints and log. os.path.join uses the host OS separator;
# the trailing '' keeps a final separator because PATH is concatenated directly with file names.
PATH = os.path.join('.', 'AML', 'compr', '')
PRED_IN = 100 # Section 5. The number of characters to predict from.
PRED_OUT = 400 # Section 5. The number of characters to predict.
LOAD_DATASET_FROM_FILE = True
LOAD_MODEL_FROM_FILE = True
torch.backends.cudnn.enabled = True
#DEVICE = torch.device("cuda")
DEVICE = torch.device("cpu")
INSPECT_VISUALLY = True
LIMIT_OUTPUT_PER_BATCH = 1

# Minor.
LOG = open(PATH + 'log', 'w', encoding='utf-8') # Debug log target; truncated on every run.
LOGGING = False
# CUDA_LAUNCH_BLOCKING is read by the CUDA runtime from the process environment; a Python
# variable on its own has no effect. Export it (unless the user already set it) before any CUDA work.
CUDA_LAUNCH_BLOCKING = 1
os.environ.setdefault('CUDA_LAUNCH_BLOCKING', str(CUDA_LAUNCH_BLOCKING))
START_ITER = 0 # Iteration counter used by trainIters when not resuming from a checkpoint.

In [None]:
##########################
### DATA-PREPROCESSING ###
##########################

BATCH_SIZE = 2  # Number of (source, target) pairs packed into one batch.

# Reading is simple here: no text cleaning is needed and every sequence has the same length.

def bucketToBatch(bucket, source_len=None, target_len=None, device=None):
    """Pack a list of (source, target) byte pairs into two zero-padded int64 tensors.

    Args:
        bucket: sequence of pairs; pair[0] is the source and pair[1] the target, each an
            iterable of ints (readData passes `bytes` objects).
        source_len: width of the source tensor. Defaults to PRED_IN + PRED_OUT - 1, because
            the source drops the last byte of each chunk.
        target_len: width of the target tensor. Defaults to PRED_OUT.
        device: device for the returned tensors. Defaults to DEVICE.

    Returns:
        (source_tensor, target_tensor) with shapes (len(bucket), source_len) and
        (len(bucket), target_len). Rows shorter than the width are right-padded with 0.
    """
    if source_len is None:
        source_len = PRED_IN + PRED_OUT - 1
    if target_len is None:
        target_len = PRED_OUT
    if device is None:
        device = DEVICE

    source_tensor = torch.zeros(len(bucket), source_len, dtype=torch.int64, device=device)
    target_tensor = torch.zeros(len(bucket), target_len, dtype=torch.int64, device=device)
    # Copy a whole row per tensor op instead of one Python-level assignment per byte.
    for row, pair in enumerate(bucket):
        source, target = pair[0], pair[1]
        source_tensor[row, :len(source)] = torch.tensor(list(source), dtype=torch.int64)
        target_tensor[row, :len(target)] = torch.tensor(list(target), dtype=torch.int64)

    return (source_tensor, target_tensor)

def readData(enwik_file):
    """Read up to TRAIN_DATA bytes of PATH + enwik_file into teacher-forcing batches.

    The file is consumed in consecutive, non-overlapping chunks of PRED_IN + PRED_OUT bytes.
    Each complete chunk becomes a (source, target) pair:
      * source = the chunk without its last byte;
      * target = the last PRED_OUT bytes of the chunk.
    Pairs are grouped BATCH_SIZE at a time and converted with bucketToBatch(); the final
    batch may be smaller. Reading stops early if the file ends before TRAIN_DATA bytes.

    Side effect: the batch list is written with torch.save to PATH + 'enwik_batches_0'.

    Returns:
        list of (source_tensor, target_tensor) tuples.
    """
    batches = []
    tmp_bucket = []

    # We read the binary file in chunks of chunk_size bytes.
    chunk_size = PRED_IN + PRED_OUT # 500 in the paper
    to_read = TRAIN_DATA // chunk_size

    with open(PATH + '%s' % enwik_file, 'rb') as enwik:
        for seq in range(to_read):
            if to_read >= 100 and seq % ((to_read // 100) * 5) == 0:
                print("%d%% of the datafile was read." % (seq / (to_read // 100)))
            tmp1 = enwik.read(PRED_IN)
            tmp2 = enwik.read(PRED_OUT)
            if len(tmp2) < PRED_OUT:
                # The file has ended: this partial chunk is dropped and every later read would
                # return nothing, so stop instead of looping over the remaining iterations.
                break
            # Thanks to a hint from Candidate #1041040: it is reasonable to approach this task similarly to translation.
            # Here we take PRED_IN initial characters (instead of 1 SOS_token in translation), we need to predict
            # PRED_OUT characters, and we use teacher forcing. Therefore, (numbering from 0) we are interested in outputs
            # corresponding to characters PRED_IN - 1 (the last character of the source) till PRED_IN + PRED_OUT - 2
            # (the 2nd to last character in the target; the character PRED_IN + PRED_OUT - 1 is the last one and is thus
            # not used in prediction).
            tmp_bucket.append((tmp1 + tmp2[:-1], tmp2))
            if len(tmp_bucket) >= BATCH_SIZE:
                batches.append(bucketToBatch(tmp_bucket))
                tmp_bucket = []

    if tmp_bucket:
        batches.append(bucketToBatch(tmp_bucket))

    torch.save(batches, PATH + 'enwik_batches_0')

    return batches

# Either build the batches from the raw enwik9 file (readData also caches them to disk),
# or load the cached copy written by an earlier run.
if LOAD_DATASET_FROM_FILE:
    batches = torch.load(PATH + 'enwik_batches_0', map_location=DEVICE)
else:
    # readData already returns the batches on DEVICE; reloading them from disk is redundant.
    batches = readData('enwik9')

In [None]:
########################
### MODEL DEFINITION ###
########################

import numpy as np
from torch.autograd import Variable
import copy

# CONVOLUTION PARAMETER DEFINITIONS

# Pointwise convolution: mixes channels only, sequence length is unchanged.
conv1x1 = dict(kernel_size=1, stride=1, padding=0, dilation=1)
# Width-k convolution. Keep kernel_size odd and padding = (kernel_size - 1) / 2 so the output
# length equals the input length (ResBlock multiplies the padding by the dilation).
# If kernel_size changes, MAKE SURE TO RE-TEST the causal masking in
# MultiplicativeUnit.forward and CNN.forward, which zero the taps right of the kernel centre.
conv1xk = dict(kernel_size=3, stride=1, padding=1)

# Residual blocks take and return CHANNEL_FACTOR * channels channels. Do NOT touch this value!
CHANNEL_FACTOR = 2
EMBED_DIM = 256

# ENCODER & DECODER

class ChannelLayerNorm(nn.Module):
    """LayerNorm over the channel dimension of a (batch, channels, length) tensor.

    nn.LayerNorm normalises the last dimension, so the input is transposed to
    (batch, length, channels), normalised, and transposed back. Each position is thus
    normalised across its channels. Unlike a BatchNorm over positions, this does not
    require knowing the sequence length ahead of time.
    """
    def __init__(self, channels):
        super(ChannelLayerNorm, self).__init__()
        self.channels = channels
        # `trueBN` is a historical name (it wraps a LayerNorm). It is kept unchanged because it
        # is part of the state_dict keys of saved checkpoints.
        self.trueBN = nn.Sequential(nn.LayerNorm(channels))

    def forward(self, batch):
        channels_last = batch.transpose(1, 2)
        normalised = self.trueBN(channels_last)
        return normalised.transpose(1, 2)

class MultiplicativeUnit(nn.Module):
    """Multiplicative unit (https://arxiv.org/pdf/1610.00527.pdf): g1 * tanh(g2 * h + g3 * u).

    Each of the four paths is Conv1d(channels -> channels) -> ChannelLayerNorm -> activation.
    Paths 0-2 use a sigmoid (gates g1, g2, g3); path 3 uses tanh (update u).

    Args:
        conv: keyword arguments for every nn.Conv1d (kernel_size, padding, dilation, ...).
        type: '1x1' for pointwise convolutions. Any other value makes forward() zero the
            kernel taps right of the centre, so the convolution only looks at the past.
        channels: number of input and output channels.
    """
    def __init__(self, conv, type, channels):
        super(MultiplicativeUnit, self).__init__()
        self.type = type

        # The sub-module keys ('conv_1', 'norm_2') appear in checkpoint state_dicts and must
        # not be renamed. The 'relu_1' key is a historical misnomer for the Sigmoid/Tanh.
        def make_path(activation):
            return nn.Sequential(OrderedDict([
                ('conv_1', nn.Conv1d(channels, channels, **conv)),
                ('norm_2', ChannelLayerNorm(channels)),
                ('relu_1', activation)
            ]))

        # Build the paths in the original order (3 sigmoid gates, then tanh) so parameter
        # initialisation consumes the RNG identically.
        self.MU_Path = nn.ModuleList([make_path(nn.Sigmoid()) for _ in range(3)])
        self.MU_Path.append(make_path(nn.Tanh()))

    def forward(self, input):
        if self.type != '1x1':
            # Zero (in place) the taps right of the kernel centre, i.e. those that would read
            # future positions. The optimizer may have moved them, so this runs on every call.
            for path in self.MU_Path:
                weight = path._modules['conv_1'].weight
                wdt = weight.size(2)
                weight.data[:, :, wdt // 2 + 1:].zero_()

        # Variable names correspond to the names in https://arxiv.org/pdf/1610.00527.pdf.
        g1 = self.MU_Path[0](input)
        g2 = self.MU_Path[1](input)
        g3 = self.MU_Path[2](input)
        u  = self.MU_Path[3](input)
        return g1 * torch.tanh(g2 * input + g3 * u)

class ResBlock(nn.Module):
    """Residual block with input/output width CHANNEL_FACTOR * channels.

    Pipeline: LayerNorm -> ReLU -> 1x1 conv (down to `channels`) -> LayerNorm -> ReLU ->
    dilated 1xk MultiplicativeUnit -> 1x1 MultiplicativeUnit -> 1x1 conv (back up).
    The block output is added to its input.

    Args:
        channels: bottleneck width.
        dilation: dilation of the 1xk unit; its padding is scaled by the same factor so the
            sequence length is preserved.
        type: stored as self.type; not read by this block.
    """
    def __init__(self, channels, dilation, type):
        super(ResBlock, self).__init__()
        self.type = type
        conv1xk_dil = dict(conv1xk)  # copy so the shared conv1xk dict is not mutated
        conv1xk_dil['padding'] *= dilation
        conv1xk_dil['dilation'] = dilation

        # The padding in conv1xk is chosen so that the block's output length equals its input length.
        # The keys below are part of checkpoint state_dicts; do not rename or reorder them.
        self.resblock = nn.Sequential(OrderedDict([
            ('norm_1', ChannelLayerNorm(CHANNEL_FACTOR * channels)),
            ('relu_1', nn.ReLU()),
            ('conv1x1_1', nn.Conv1d(CHANNEL_FACTOR * channels, channels, **conv1x1)),
            ('norm_2', ChannelLayerNorm(channels)),
            ('relu_2', nn.ReLU()),
            ('MU_1', MultiplicativeUnit(conv1xk_dil, '1xk', channels)),
            ('MU_2', MultiplicativeUnit(conv1x1, '1x1', channels)),
            ('conv1x1_2', nn.Conv1d(channels, CHANNEL_FACTOR * channels, **conv1x1))
         ]))

    def forward(self, input):
        output = self.resblock(input)
        if LOGGING:
            # Log only tensors that actually exist. The former step1/step2/step3 names were
            # undefined and raised NameError whenever LOGGING was enabled.
            print('def ResBlock::forward(self, input)', file=LOG)
            print(input.size(), file=LOG)
            print(output.size(), file=LOG, end='\n')
        return output + input

class CNN(nn.Module):
    """Decoder-only, causal byte-level CNN returning per-position log-probabilities.

    forward(source) takes int64 indices of shape (batch, length). It embeds them, runs them
    through res_sets x res_blocks ResBlocks (dilation 2 ** block index within each set), and
    returns log-probabilities of shape (batch, input_dim, length).

    Args:
        input_dim: vocabulary size (256 for raw bytes); used by the embedding and output conv.
        channels: bottleneck width; the embedding has CHANNEL_FACTOR * channels dimensions.
        res_sets: number of residual sets; defaults to the global RES_SETS at construction time.
        res_blocks: residual blocks per set; defaults to the global RES_BLOCKS at construction time.
    """
    def __init__(self, input_dim, channels = EMBED_DIM, res_sets=None, res_blocks=None):
        super(CNN, self).__init__()
        # Previously this stored (and forwarded) the builtin `type` by accident. Nothing reads
        # it, so use a descriptive tag instead.
        self.type = 'dec'
        self.embed = nn.Embedding(input_dim, CHANNEL_FACTOR * channels)
        self.channels = channels
        if res_sets is None:
            res_sets = RES_SETS
        if res_blocks is None:
            res_blocks = RES_BLOCKS

        # Layer names are part of checkpoint state_dict keys; keep them unchanged.
        layers = OrderedDict()
        for res_set in range(res_sets):
            for res_block in range(res_blocks):
                layers['res_set_' + str(res_set) + '|res_block_' + str(res_block)] = ResBlock(channels, 2 ** res_block, self.type)

        layers['fin|conv1x1'] = nn.Conv1d(CHANNEL_FACTOR * channels, channels, **conv1x1)
        layers['fin|ReLU'] = nn.ReLU()
        layers['fin|conv1xd'] = nn.Conv1d(channels, input_dim, **conv1xk)
        layers['fin|logsoftmax'] = nn.LogSoftmax(dim=1)

        self.CNN = nn.Sequential(layers)

    def forward(self, source):
        if LOGGING:
            print('def CNN::forward(self, source)', file=LOG)
            print('source', file=LOG)
            print(source.size(), file=LOG, end='\n')

        emb = torch.transpose(self.embed(source), 1, 2)  # (batch, channels, length)

        # Zero, in place, the taps of the final 1xk conv right of the kernel centre, so output
        # position t never sees input positions after t.
        final_weight = self.CNN._modules['fin|conv1xd'].weight
        wdt = final_weight.size(2)
        final_weight.data[:, :, wdt // 2 + 1:].zero_()
        return self.CNN(emb)

In [None]:
################
### TRAINING ###
################

# TIMING [COPY-PASTED FROM THE PYTORCH TUTORIAL]

import time
import math

def asMinutes(s):
    """Format a duration of `s` seconds as '<minutes>m <seconds>s', both truncated to integers."""
    whole_minutes = math.floor(s / 60)
    leftover_seconds = s - whole_minutes * 60
    return '{}m {}s'.format(whole_minutes, int(leftover_seconds))

def timeSince(since, percent):
    """Return '<elapsed> (- <estimated remaining>)', given a start timestamp and the fraction done."""
    elapsed = time.time() - since
    remaining = elapsed / percent - elapsed
    return '{} (- {})'.format(asMinutes(elapsed), asMinutes(remaining))

# Constants from the paper.

ADAM_LEARNING_RATE = 5e-5  # Adam step size.
RES_SETS, RES_BLOCKS = 6, 5  # Sets of residual blocks, and residual blocks per set.

# Define 1 GD step. Teacher Forcing
def train_TF(input_tensor, target_tensor, decoder, decoder_optimizer, criterion):
    """Run one teacher-forced optimisation step and return the detached loss.

    The decoder is fed the whole source at once. Only its last target_tensor.size(1) output
    positions are scored against the target. For data built by readData
    (source = PRED_IN + PRED_OUT - 1 bytes, target = PRED_OUT bytes), these are positions
    PRED_IN - 1 onward, exactly as before, but the offset no longer depends on a global.

    Args:
        input_tensor: (batch, source_length) int64 source bytes.
        target_tensor: (batch, target_length) int64 bytes to predict.
        decoder: module returning (batch, classes, source_length) scores — assumed aligned
            position-for-position with the source, as the CNN in this notebook is.
        decoder_optimizer: optimizer over the decoder's parameters.
        criterion: loss taking (batch, classes, target_length) scores and the target.

    Returns:
        The loss tensor, detached from the autograd graph.
    """
    decoder_optimizer.zero_grad()

    decoder_output = decoder(source=input_tensor)

    target_length = target_tensor.size(1)
    loss = criterion(decoder_output[:, :, -target_length:], target_tensor)
    loss.backward()

    decoder_optimizer.step()

    return loss.detach()

# Define the actual training.
def trainIters(decoder, n_iters, print_every=1000, save_every=10000):
    """Train `decoder` for `n_iters` teacher-forced Adam steps, cycling over the global `batches`.

    If LOAD_MODEL_FROM_FILE is set, the model state, optimizer state and iteration counter are
    resumed from PATH + 'bytenet_dec.train.tar'. The learning rate is then reset to
    ADAM_LEARNING_RATE. Iterations are numbered start_iter + 1 .. start_iter + n_iters.
    When the iteration number is a multiple of print_every, the mean loss of the last
    print_every steps is printed. When it is a multiple of save_every, a checkpoint is written
    back to the same file.
    """
    start = time.time()
    print_loss_total = 0  # Reset every print_every
    start_iter = START_ITER

    decoder_optimizer = optim.Adam(decoder.parameters(), ADAM_LEARNING_RATE)

    if LOAD_MODEL_FROM_FILE:
        checkpoint = torch.load(PATH + 'bytenet_dec.train.tar', map_location=DEVICE)
        start_iter = checkpoint['iter']
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        decoder_optimizer.load_state_dict(checkpoint['decoder_optimizer_state_dict'])

        # Override the learning rate stored in the checkpoint with the current setting.
        for param_group in decoder_optimizer.param_groups:
            param_group['lr'] = ADAM_LEARNING_RATE

        del checkpoint # Significantly reduces memory consumption
        decoder.to(DEVICE)
        decoder.train()

        print('Loaded the model from %sbytenet_dec.train.tar, which has already been trained for %d iterations.' % (PATH, start_iter))

    criterion = nn.NLLLoss()

    for it in range(start_iter + 1, start_iter + n_iters + 1):
        # The batch is picked from the global iteration number, so resuming continues the cycle.
        input_tensor, target_tensor = batches[it % len(batches)]

        loss = train_TF(input_tensor, target_tensor, decoder, decoder_optimizer, criterion)
        # (The former `plot_loss_total += loss` used an uninitialised local and raised
        # UnboundLocalError on the first iteration; the value was never used, so it is gone.)
        print_loss_total += loss

        if it % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            real_iter = it - start_iter
            print('%s (%d %d%%) %.4f' % (timeSince(start, real_iter / n_iters), it, real_iter / n_iters * 100, print_loss_avg),end=('\n' if it % save_every != 0 else ' '))

        if it % save_every == 0:
            torch.save({
                'iter': it,
                'decoder_state_dict': decoder.state_dict(),
                'decoder_optimizer_state_dict': decoder_optimizer.state_dict(),
            }, PATH + 'bytenet_dec.train.tar')
            print ('Saved!')

# input_dim = 256: one embedding row / output class per possible byte value.
decoder = CNN(input_dim=256).to(DEVICE)
trainIters(decoder, n_iters=1000000, print_every=600, save_every=2000)

In [None]:
# Snapshot the latest training checkpoint for evaluation. trainIters keeps overwriting
# bytenet_dec.train.tar, while evaluateIters reads bytenet_dec.eval.tar.
# shutil works on every OS, unlike the Windows-only `copy /y` shell command used before.
import shutil
shutil.copyfile(PATH + 'bytenet_dec.train.tar', PATH + 'bytenet_dec.eval.tar')

In [None]:
##################
### EVALUATION ###
##################

from operator import itemgetter

# TIMING

import time
import math

def asMinutes(s):
    """Render `s` seconds as '<m>m <s>s', with whole (truncated) minutes and seconds."""
    minutes = math.floor(s / 60)
    return '%dm %ds' % (minutes, s - minutes * 60)

def timeSince(since, percent):
    """Elapsed time since `since`, plus a linear estimate of the time left.

    `percent` is the fraction of the work done so far (it must be non-zero).
    """
    s = time.time() - since
    total_estimate = s / percent
    return '%s (- %s)' % (asMinutes(s), asMinutes(total_estimate - s))

# EVALUATE

# The embedding width is deliberately NOT reassigned here. CNN's default `channels=EMBED_DIM`
# was bound when the class was defined, so changing EMBED_DIM now would not affect the model.
# The loaded checkpoint must match the width used in training; pass `channels=` to CNN
# explicitly to evaluate a model of a different width.

# Constants from the paper.

RES_SETS = 6        # Number of sets of residual blocks
RES_BLOCKS = 5      # Number of residual blocks per set

# Define 1 evaluation step.
def evaluate(input_tensor, target_tensor, decoder, criterion):
    """Generate the target greedily (no teacher forcing) and score every target byte.

    The prompt is the first prompt_length bytes of input_tensor. Each later position is then
    filled with the decoder's argmax prediction, one forward pass per byte. prompt_length is
    derived as source_length - target_length + 1, which equals PRED_IN for readData's layout
    (source = PRED_IN + PRED_OUT - 1 bytes, target = PRED_OUT bytes).

    Args:
        input_tensor: (batch, source_length) int64 source bytes.
        target_tensor: (batch, target_length) int64 bytes to predict.
        decoder: returns (batch, classes, length) log-probabilities for a (batch, length) input.
        criterion: loss applied to one position at a time, e.g. nn.NLLLoss(reduction='sum').

    Returns:
        (output, loss). `output` has input_tensor's shape: the prompt followed by the greedy
        predictions. `loss` is the criterion summed over ALL target positions (previously the
        last target byte was skipped), divided by target_length.
    """
    total_length = input_tensor.size(1)
    target_length = target_tensor.size(1)
    prompt_length = total_length - target_length + 1
    loss = torch.zeros((), device=input_tensor.device)

    output = torch.zeros_like(input_tensor)
    output[:, :prompt_length] = input_tensor[:, :prompt_length]

    for char_index in range(target_length):
        decoder_output = decoder(source=output)
        # Output at `position` predicts the byte after it, i.e. target[:, char_index].
        position = prompt_length - 1 + char_index
        step_scores = decoder_output[:, :, position]
        loss += criterion(step_scores, target_tensor[:, char_index]).detach()

        # The last target byte is only scored: it has no slot in `output`, because it is
        # never fed back into the decoder.
        write_pos = prompt_length + char_index
        if write_pos < total_length:
            output[:, write_pos] = torch.argmax(step_scores, dim=1)

    return output, loss.detach() / max(target_length, 1)

# Define the actual evaluation.
def evaluateIters(decoder, print_every, process_every = 1):
    """Load PATH + 'bytenet_dec.eval.tar' into `decoder` and report teacher-forced loss over `batches`.

    Only every process_every-th batch is evaluated, and progress is printed every print_every
    evaluated batches. When INSPECT_VISUALLY is set, only the first LIMIT_OUTPUT_PER_BATCH rows
    of each batch are used. For each of those rows the following are printed:
      * the source prompt;
      * the target;
      * the free-running prediction (<-TF), i.e. the prompt followed by the greedy bytes;
      * the teacher-forced prediction (<+TF).

    Returns:
        (print_loss_total, num_of_pairs): the sum over evaluated pairs of each pair's NLL
        divided by the target length, and the number of pairs evaluated.
    """
    start = time.time()
    print_loss_total = 0
    num_of_pairs = 0

    checkpoint = torch.load(PATH + 'bytenet_dec.eval.tar', map_location=DEVICE)
    start_iter = checkpoint['iter']
    decoder.load_state_dict(checkpoint['decoder_state_dict'])
    del checkpoint # Free the checkpoint copy of the weights.

    decoder.to(DEVICE)

    decoder.eval()

    print('Loaded the model from %sbytenet_dec.eval.tar, which has already been trained for %d iterations.' % (PATH, start_iter))
    criterion = nn.NLLLoss(reduction='sum')

    for batch, (input_tensor, target_tensor) in enumerate(batches):
        if batch % process_every != 0: continue

        if INSPECT_VISUALLY:
            input_tensor  = input_tensor[:LIMIT_OUTPUT_PER_BATCH,:]
            target_tensor = target_tensor[:LIMIT_OUTPUT_PER_BATCH,:]

        decoder_output = decoder(source=input_tensor)
        scored_output = decoder_output[:,:,PRED_IN-1:]  # positions whose next byte is in the target
        target_length = target_tensor.size(1)
        print_loss_total += criterion(scored_output, target_tensor) / target_length
        num_of_pairs += input_tensor.size(0)
        if batch % (print_every * process_every) == 0:
            # ETA from the fraction of batches iterated so far. Skipped batches cost almost
            # nothing. Counting pairs against TRAIN_DATA was wrong whenever rows were limited
            # or batches skipped.
            progress = (batch + 1) / len(batches)
            print('%s (%d %d%%) %d pairs processed so far; ' % (timeSince(start, progress), batch, batch / len(batches) * 100, num_of_pairs),end='')
            print('the average loss so far: %.4f.' % (print_loss_total / num_of_pairs))

        if not INSPECT_VISUALLY: continue

        print('Bucket #{0}'.format(batch))
        output, loss = evaluate(input_tensor, target_tensor, decoder, criterion)
        output_TF = torch.argmax(scored_output.detach(), dim=1)

        for source,target,result,result_TF in zip(input_tensor,target_tensor,output,output_TF):
            # Print this row's own prompt (previously row 0's prompt was printed for every row).
            print('> ',end='')
            print(''.join([chr(x) for x in source[:PRED_IN]]))

            print('= ',end='')
            print(''.join([chr(x) for x in target]))

            print('<-TF ',end='')
            print(''.join([chr(x) for x in result]))

            print('<+TF ',end='')
            print(''.join([chr(x) for x in result_TF]))

            print('\n###\n###\n###\n')

    return print_loss_total, num_of_pairs

In [None]:
# Inference only: disable autograd bookkeeping for the whole evaluation pass.
with torch.no_grad():
    decoder = CNN(input_dim=256).to(DEVICE)
    print_loss_total, num_of_pairs = evaluateIters(decoder, print_every=1)

In [None]:
# Total number of scalar parameters in the model (displayed as the cell output).
sum(map(torch.numel, decoder.parameters()))