# IMPORTS

In [None]:
import wget, os, gzip, pickle, random, re, sys
import numpy as np
import math
import copy

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import random_split
import torch.distributions as dist

import matplotlib.pyplot as plt
%matplotlib inline

# HELPER FUNCTIONS

In [None]:
# Download location and local file name of the IMDb dataset. The '{}'
# placeholder is filled with 'char' or 'word' by load_imdb.
IMDB_URL = 'http://dlvu.github.io/data/imdb.{}.pkl.gz'
IMDB_FILE = 'imdb.{}.pkl.gz'

# Special tokens. load_toy puts them at vocabulary indices 0-3, in this order.
PAD, START, END, UNK = '.pad', '.start', '.end', '.unk'

def load_imdb(final=False, val=5000, seed=0, voc=None, char=False):
    """
    Load the IMDb dataset, downloading the archive first when the local file
    is missing.

    :param final: If True, return the stored train/test split. Otherwise a
        validation set is carved out of the training data.
    :param val: Number of training instances moved to the validation set.
    :param seed: Seed passed to `random.seed` before the validation indices
        are drawn (this reseeds the global `random` module).
    :param voc: Optional vocabulary size. When smaller than the stored
        vocabulary, every token index >= voc is replaced by the '.unk' index.
    :param char: Load the character-level file instead of the word-level one.
    :return: (x_train, y_train), (x_val or x_test, y_val or y_test),
        (i2w, w2i), 2
    """
    level = 'char' if char else 'word'
    url, path = IMDB_URL.format(level), IMDB_FILE.format(level)

    if not os.path.exists(path):
        wget.download(url)

    # NOTE(review): unpickling a downloaded file can execute arbitrary code;
    # only use this with a source you trust.
    with gzip.open(path) as handle:
        sequences, labels, i2w, w2i = pickle.load(handle)

    if voc is not None and voc < len(i2w):
        # Truncate the vocabulary and map everything beyond it to '.unk'.
        i2w = i2w[:voc]
        w2i = {word: index for index, word in enumerate(i2w)}
        unk = w2i['.unk']

        sequences = {
            split: [[tok if tok < voc else unk for tok in seq] for seq in seqs]
            for split, seqs in sequences.items()
        }

    if final:
        return (sequences['train'], labels['train']), (sequences['test'], labels['test']), (i2w, w2i), 2

    # Draw the indices of the validation instances.
    random.seed(seed)
    held_out = set(random.sample(range(len(sequences['train'])), k=val))

    x_train, y_train = [], []
    x_val, y_val = [], []

    for index, (seq, label) in enumerate(zip(sequences['train'], labels['train'])):
        xs, ys = (x_val, y_val) if index in held_out else (x_train, y_train)
        xs.append(seq)
        ys.append(label)

    return (x_train, y_train), (x_val, y_val), (i2w, w2i), 2


def gen_sentence(sent, g):
    """
    Expand a template string until it contains no grammar symbols.

    A symbol is an underscore followed by zero or more lowercase letters.
    The leftmost symbol is repeatedly replaced by a random choice from its
    list of productions.

    :param sent: Template string, possibly containing symbols.
    :param g: Dict mapping each symbol to a list of replacement strings.
    :return: The fully expanded string.
    """
    pattern = '_[a-z]*'

    match = re.search(pattern, sent)
    while match is not None:
        start, end = match.span()
        expansion = random.choice(g[match.group()])
        sent = sent[:start] + expansion + sent[end:]
        match = re.search(pattern, sent)

    return sent

def gen_dyck(p):
    """
    Generate a random balanced bracket string that starts with '('.

    Starting at nesting depth 1, each step appends '(' with probability p
    and ')' otherwise, until the depth returns to 0.

    :param p: Probability of opening another bracket at each step.
    :return: The generated bracket string.
    :raises ValueError: If p >= 1, since the depth could then never return
        to 0 and the loop would not terminate.
    """
    if p >= 1:
        raise ValueError(f'p must be smaller than 1 for the string to terminate, got {p}')

    depth = 1  # number of currently unmatched '(' (was `open`, which shadowed the builtin)
    symbols = ['(']

    while depth > 0:
        if random.random() < p:
            symbols.append('(')
            depth += 1
        else:
            symbols.append(')')
            depth -= 1

    return ''.join(symbols)

def gen_ndfa(p):
    """
    Generate a string of the form 's' + word * k + 's'.

    The word is picked once from 'abc!', 'uvw!' and 'klm!'. Before each
    repetition a coin is flipped: with probability p the string is closed
    and returned, otherwise the word is appended once more.

    :param p: Probability of stopping at each step.
    :return: The generated string.
    """
    word = random.choice(['abc!', 'uvw!', 'klm!'])

    repeats = 0
    while not (random.random() < p):
        repeats += 1

    return 's' + word * repeats + 's'

def load_brackets(n=50_000, seed=0):
    """
    Build the character-level 'dyck' (bracket) dataset through load_toy.

    :param n: Number of sequences to generate.
    :param seed: Seed value forwarded to load_toy.
    :return: Whatever load_toy returns for name='dyck'.
    """
    return load_toy(
        n=n,
        char=True,
        seed=seed,
        name='dyck',
    )

def load_ndfa(n=50_000, seed=0):
    """
    Build the character-level 'ndfa' dataset through load_toy.

    :param n: Number of sequences to generate.
    :param seed: Seed value forwarded to load_toy.
    :return: Whatever load_toy returns for name='ndfa'.
    """
    return load_toy(
        n=n,
        char=True,
        seed=seed,
        name='ndfa',
    )

def load_toy(n=50_000, char=True, seed=0, name='lang'):
    """
    Generate a synthetic dataset and encode it as lists of token indices.

    :param n: Number of sequences to generate.
    :param char: If True, tokens are single characters; otherwise the
        sentences are split on whitespace.
    :param seed: Seed for the global `random` module. It is now actually
        used (it was previously ignored in favour of a hard-coded 0); the
        default of 0 reproduces the earlier behaviour.
    :param name: 'lang' (small context-free grammar), 'dyck' (brackets) or
        'ndfa'.
    :return: (sequences, (i2t, t2i)). `sequences` is sorted by sentence
        length, `i2t` lists the special tokens followed by the data tokens
        in sorted order, and `t2i` is the inverse mapping.
    :raises ValueError: If `name` is not one of the supported datasets.
    """
    random.seed(seed)

    if name == 'lang':
        toy = {
            '_s': ['_s _adv', '_np _vp', '_np _vp _prep _np', '_np _vp ( _prep _np )', '_np _vp _con _s' , '_np _vp ( _con _s )'],
            '_adv': ['briefly', 'quickly', 'impatiently'],
            '_np': ['a _noun', 'the _noun', 'a _adj _noun', 'the _adj _noun'],
            '_prep': ['on', 'with', 'to'],
            '_con' : ['while', 'but'],
            '_noun': ['mouse', 'bunny', 'cat', 'dog', 'man', 'woman', 'person'],
            '_vp': ['walked', 'walks', 'ran', 'runs', 'goes', 'went'],
            '_adj': ['short', 'quick', 'busy', 'nice', 'gorgeous']
        }
        sentences = [gen_sentence('_s', toy) for _ in range(n)]

    elif name == 'dyck':
        sentences = [gen_dyck(7./16.) for _ in range(n)]

    elif name == 'ndfa':
        sentences = [gen_ndfa(1./4.) for _ in range(n)]

    else:
        # ValueError is a subclass of Exception, so existing handlers still match.
        raise ValueError(name)

    sentences.sort(key=len)

    tokenized = [list(s) if char else s.split() for s in sentences]

    # Sort the vocabulary: iterating a set of strings directly yields an
    # order that changes between interpreter runs (hash randomization), which
    # would make the token indices irreproducible.
    tokens = sorted({tok for seq in tokenized for tok in seq})

    i2t = [PAD, START, END, UNK] + tokens
    t2i = {t: i for i, t in enumerate(i2t)}

    sequences = [[t2i[tok] for tok in seq] for seq in tokenized]

    return sequences, (i2t, t2i)

# MAIN CLASSES

In [None]:
class Prep_data():
    """
    Loads one of the supported datasets and cuts it into padded batches
    whose size is bounded by a token budget.

    After `load_data` the instance exposes `x_train` (list of index
    sequences), `i2w` / `w2i` (vocabulary) and `seeds` (generation prompts).
    """

    def __init__(self, dataset, n=50000, max_tokens=150):
        """
        :param dataset: One of 'ndfa', 'brackets', 'toy' or 'imdb'.
        :param n: Number of instances to generate for the synthetic datasets.
        :param max_tokens: A batch is closed as soon as it holds at least
            this many (unpadded) tokens.
        """
        self.dataset = dataset
        self.n_instances = n
        self.max_tokens = max_tokens
        # NOTE(review): instances / max_tokens, not the real batch count;
        # nothing in this class reads it.
        self.n_batches = math.ceil(self.n_instances/self.max_tokens)

    def load_data(self):
        """
        Load and shuffle the training sequences and set up the vocabulary
        and the generation seeds.

        :raises ValueError: If `self.dataset` is not a supported name.
        """
        if self.dataset == 'ndfa':
            self.x_train, (self.i2w, self.w2i) = load_ndfa(n=self.n_instances)
            start, s = self.w2i['.start'], self.w2i['s']
            self.seeds = [[start, s, self.w2i['a']],
                          [start, s, self.w2i['k']],
                          [start, s, self.w2i['u']],
                          [start, s],
                          [start]]

        elif self.dataset == 'brackets':
            self.x_train, (self.i2w, self.w2i) = load_brackets(n=self.n_instances)
            start, opening, closing = self.w2i['.start'], self.w2i['('], self.w2i[')']
            self.seeds = [[start, opening, opening, opening],
                          [start, opening, opening, closing],
                          [start, opening, opening],
                          [start, opening],
                          [start]]

        elif self.dataset == 'toy':
            self.x_train, (self.i2w, self.w2i) = load_toy(n=self.n_instances)
            # A single flat seed (not a list of seeds); consumers rely on this shape.
            self.seeds = [self.w2i['.start']]

        elif self.dataset == 'imdb':
            (self.x_train, y_train), (x_val, y_val), (self.i2w, self.w2i), self.numcls = load_imdb(final=False, char=True)
            self.seeds = [self.w2i['.start']]

        else:
            raise ValueError(f'unknown dataset: {self.dataset!r}')

        # The imdb loader does not honour `n`, so record the real size.
        self.n_instances = len(self.x_train)

        np.random.shuffle(self.x_train)

    def tokens_in_batch(self, fr, to):
        """
        Count the tokens of the training sequences with index in [fr, to).

        :param fr: First index (inclusive).
        :param to: Last index (exclusive).
        :return: Total number of tokens.
        """
        return sum(len(self.x_train[i]) for i in range(fr, to))

    def prep_batch(self, batch):
        """
        Wrap every sequence in '.start' / '.end' and right-pad with '.pad'.

        The input sequences are left untouched.

        :param batch: Non-empty list of index sequences.
        :return: (tensor, lengths): a LongTensor of shape
            (len(batch), longest sequence + 2) and a LongTensor with the
            unpadded length (including start/end) of every row.
        :raises ValueError: If `batch` is empty.
        """
        if len(batch) == 0:
            raise ValueError('cannot prepare an empty batch')

        start, end, pad = self.w2i['.start'], self.w2i['.end'], self.w2i['.pad']

        max_length = max(len(seq) for seq in batch) + 2
        new_batch = torch.full((len(batch), max_length), pad, dtype=torch.long)
        lengths = torch.empty(len(batch), dtype=torch.long)

        for row, seq in enumerate(batch):
            framed = [start, *seq, end]
            lengths[row] = len(framed)
            new_batch[row, :len(framed)] = torch.tensor(framed, dtype=torch.long)

        return new_batch, lengths

    def get_batches(self):
        """
        Load the data and split it into consecutive batches.

        Sequences are added to a batch until it holds at least `max_tokens`
        tokens or the data runs out. The target of a batch is the batch
        shifted one position to the left with a padding column appended.

        :return: (batches, lengths, targets), three lists of LongTensors.
        :raises ValueError: If `max_tokens` is not positive.
        """
        self.load_data()

        if self.max_tokens <= 0:
            raise ValueError(f'max_tokens must be positive, got {self.max_tokens}')

        # Use the real dataset size rather than the requested `n`, which can
        # differ (e.g. for imdb) and would otherwise index out of range.
        total = len(self.x_train)
        pad = self.w2i['.pad']

        batches, lengths, targets = [], [], []
        fr = 0

        while fr < total:
            # Running token count instead of re-summing the batch each step.
            to, n_tokens = fr, 0
            while n_tokens < self.max_tokens and to < total:
                n_tokens += len(self.x_train[to])
                to += 1

            batch, length = self.prep_batch(self.x_train[fr:to])
            batches.append(batch)
            lengths.append(length)

            pad_column = torch.full((batch.shape[0], 1), pad, dtype=torch.long)
            targets.append(torch.cat([batch[:, 1:], pad_column], dim=1))

            fr = to

        return batches, lengths, targets

In [None]:
class RNN(nn.Module):
    """
    LSTM language model (embedding -> LSTM -> linear) that also owns its
    training data: the constructor builds a Prep_data loader and stores the
    prepared batches in `X`, `lengths` and `Y`.
    """

    def __init__(self, dataset, hidden_layer_size=16, embedding_dimension=32, num_layers=1, max_tokens=150, n=150000):
        """
        :param dataset: Dataset name understood by Prep_data.
        :param hidden_layer_size: Size of the LSTM hidden state.
        :param embedding_dimension: Size of the token embeddings.
        :param num_layers: Number of stacked LSTM layers.
        :param max_tokens: Token budget per batch, forwarded to Prep_data.
        :param n: Number of instances, forwarded to Prep_data.
        """
        super(RNN, self).__init__()
        self.data_loader = Prep_data(dataset, n=n, max_tokens=max_tokens)
        self.X, self.lengths, self.Y = self.data_loader.get_batches()
        # NOTE(review): padded target positions are not excluded from the
        # loss (no ignore_index) - confirm this is intended.
        self.criterion = nn.CrossEntropyLoss()
        self.hidden_layer_size, self.num_layers = hidden_layer_size, num_layers

        vocab_size = len(self.data_loader.i2w)
        self.emb = nn.Embedding(vocab_size, embedding_dimension, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dimension, self.hidden_layer_size, num_layers=num_layers)
        self.fc = nn.Linear(self.hidden_layer_size, vocab_size)

    def forward(self, x, length):
        """
        Compute next-token logits for a padded batch.

        :param x: LongTensor of token indices, shape (batch, time).
        :param length: Unpadded length of every row (list or tensor).
        :return: Logits of shape (batch, max(length), vocabulary size).
        """
        # Follow the input's device instead of a notebook-level global.
        target_device = x.device
        state_shape = (self.num_layers, len(x), self.hidden_layer_size)

        # NOTE(review): the initial state is drawn at random on every call,
        # so repeated evaluation of the same input is not deterministic.
        hidden_cell = (torch.randn(*state_shape).to(target_device),
                       torch.randn(*state_shape).to(target_device))

        if torch.is_tensor(length):
            # pack_padded_sequence expects the lengths on the CPU.
            length = length.cpu()

        x = self.emb(x)
        x = nn.utils.rnn.pack_padded_sequence(x, length, batch_first=True, enforce_sorted=False)

        x, hidden_cell = self.lstm(x, hidden_cell)
        x, sizes = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)

        return self.fc(x)

    def train(self, optimizer=True, device=None, epochs=3, max_length=50):
        """
        Run the training loop, or toggle training mode like nn.Module.train.

        This method shadows `nn.Module.train(mode)`, which `eval()` relies
        on. To keep both uses working, a boolean first argument is forwarded
        to the base class; anything else is treated as the optimizer.

        :param optimizer: Optimizer over this model's parameters, or a bool
            to switch training/evaluation mode.
        :param device: Device to train on; defaults to the device the
            parameters currently live on.
        :param epochs: Number of passes over the stored batches.
        :param max_length: Maximum number of tokens sampled per generated
            sequence after each epoch.
        :return: The model itself when used as a mode switch, otherwise the
            per-batch losses smoothed with a length-50 moving average.
        """
        if isinstance(optimizer, bool):
            return super().train(optimizer)

        return self._fit(optimizer, device, epochs, max_length)

    def _fit(self, optimizer, device, epochs, max_length):
        """Optimise on the stored batches and report loss and samples per epoch."""
        if device is None:
            device = next(self.parameters()).device
        self.device = device

        batch_losses = []
        dataset = list(zip(self.X, self.lengths, self.Y))

        for epoch in range(epochs):
            np.random.shuffle(dataset)

            for x, length, y in dataset:
                inputs, labels = x.to(device), y.to(device)
                optimizer.zero_grad()

                outputs = self.forward(inputs, length)
                # CrossEntropyLoss wants the class dimension second.
                loss = self.criterion(outputs.permute(0, 2, 1), labels)
                loss.backward()

                nn.utils.clip_grad_value_(self.parameters(), 1)
                optimizer.step()

            with torch.no_grad():
                epochloss = 0.0
                for x, length, y in dataset:
                    inputs, labels = x.to(device), y.to(device)
                    outputs = self.forward(inputs, length)
                    # NOTE(review): the criterion already averages; the extra
                    # division by the batch size is kept to preserve the
                    # reported numbers.
                    batchloss = (self.criterion(outputs.permute(0, 2, 1), labels) / len(x)).item()
                    batch_losses.append(batchloss)
                    epochloss += batchloss

            print(f'EPOCH {epoch} ||')
            print(f'\tTRAIN: {epochloss/len(dataset)} ||')

            for _ in range(10):
                self.generate(max_length)
            print()

        if not batch_losses:
            return np.array([])

        return np.convolve(np.array(batch_losses), np.ones(50), 'valid') / 50

    def sample(self, lnprobs, temperature=1.0):
        """
        Sample an element from a categorical distribution
        :param lnprobs: Outcome logits
        :param temperature: Sampling temperature. 1.0 follows the given
        distribution, 0.0 returns the maximum probability element.
        :return: The index of the sampled element.
        """
        if temperature == 0.0:
            return lnprobs.argmax()

        p = F.softmax(lnprobs / temperature, dim=0)
        cd = dist.Categorical(p)

        return cd.sample()

    def generate(self, max_length=50):
        """
        Sample a sequence from the model and print it.

        A prompt is taken from the data loader's seeds and extended one token
        at a time until '.end' is sampled or `max_length` tokens were added.

        :param max_length: Maximum number of tokens to sample.
        """
        seeds = self.data_loader.seeds
        if isinstance(seeds[0], (list, tuple)):
            # Pick by index: np.random.choice cannot handle a ragged list of lists.
            seed = list(seeds[np.random.randint(len(seeds))])
        else:
            # `seeds` is itself a single flat prompt.
            seed = list(seeds)

        # Works before training too, when `self.device` has not been set yet.
        device = next(self.parameters()).device
        end_token = self.data_loader.w2i['.end']

        with torch.no_grad():

            for _ in range(max_length):
                logits = self.forward(torch.LongTensor([seed]).to(device), [len(seed)])
                token = self.sample(logits[-1, -1, :])
                seed.append(token.item())

                if token == end_token:
                    break

        print(''.join([self.data_loader.i2w[i] for i in seed]))
    

# TRAIN A MODEL

In [None]:
# Small smoke-test configuration: a 2-layer LSTM on 1000 'toy' sequences.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

mod = RNN(
    'toy',
    hidden_layer_size=32,
    embedding_dimension=32,
    num_layers=2,
    max_tokens=400,
    n=1000,
)
mod.to(device)

optimizer = optim.Adam(mod.parameters(), lr = 0.00025)

# Cell output: the number of prepared batches.
len(mod.X)

In [None]:
# Train the model from the previous cell for 15 epochs; samples printed
# after each epoch are capped at 150 generated tokens.
loss = mod.train(
    optimizer,
    device,
    epochs=15,
    max_length=150,
)

# NDFA

In [None]:
# Train five models on the 'ndfa' data and plot mean +/- one standard
# deviation of their loss curves.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

ndfa_runs = []
for i in range(5):
    print()
    print(f'REPETITION: {i}')
    mod = RNN('ndfa', hidden_layer_size=16, embedding_dimension=32, num_layers=1, max_tokens=500, n=15000)
    mod.to(device)
    optimizer = optim.Adam(mod.parameters(), lr = 0.0005)
    ndfa_runs.append(mod.train(optimizer, device, epochs=15, max_length=150))

# Curves are cut to a common length (at most 6000 points). A fixed length
# alone would raise on assignment whenever a run produced fewer points.
n_points = min(6000, min(len(run) for run in ndfa_runs))
ndfa_loss = np.stack([run[:n_points] for run in ndfa_runs])

t = np.arange(1, n_points + 1)
ndfa_mean, ndfa_std = ndfa_loss.mean(axis=0), ndfa_loss.std(axis=0)

plt.plot(t, ndfa_mean, lw=2, label='LSTM (mean over runs)', color='blue')
plt.fill_between(t, ndfa_mean + ndfa_std,
                 ndfa_mean - ndfa_std,
                 facecolor='blue', alpha=0.5)

plt.xlabel('Batch')
plt.ylabel('Cross Entropy Loss')
plt.legend()
plt.grid()

#plt.savefig('ndfa')
plt.show()

# BRACKETS

In [None]:
# Train five models on the 'brackets' data and plot mean +/- one standard
# deviation of their loss curves.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

bracket_runs = []
for i in range(5):
    print()
    print(f'REPETITION: {i}')
    mod = RNN('brackets', hidden_layer_size=16, embedding_dimension=32, num_layers=1, max_tokens=400, n=15000)
    mod.to(device)
    optimizer = optim.Adam(mod.parameters(), lr = 0.0005)
    bracket_runs.append(mod.train(optimizer, device, epochs=15, max_length=150))

# Curves are cut to a common length (at most 4650 points). A fixed length
# alone would raise on assignment whenever a run produced fewer points.
n_points = min(4650, min(len(run) for run in bracket_runs))
bracket_loss = np.stack([run[:n_points] for run in bracket_runs])

t = np.arange(1, n_points + 1)
bracket_mean, bracket_std = bracket_loss.mean(axis=0), bracket_loss.std(axis=0)

plt.plot(t, bracket_mean, lw=2, label='LSTM (mean over runs)', color='blue')
plt.fill_between(t, bracket_mean + bracket_std,
                 bracket_mean - bracket_std,
                 facecolor='blue', alpha=0.5)

plt.xlabel('Batch')
plt.ylabel('Cross Entropy Loss')
plt.legend()
plt.grid()

#plt.savefig('bracket')
plt.show()

# TOY

In [None]:
# Train three models on the 'toy' data and plot mean +/- one standard
# deviation of their loss curves.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

toy_runs = []
for i in range(3):
    print()
    print(f'REPETITION: {i}')
    mod = RNN('toy', hidden_layer_size=32, embedding_dimension=32, num_layers=2, max_tokens=400, n=20000)
    mod.to(device)
    optimizer = optim.Adam(mod.parameters(), lr = 0.00025)
    toy_runs.append(mod.train(optimizer, device, epochs=15, max_length=150))

# Curves are cut to a common length (at most 31800 points). A fixed length
# alone would raise on assignment whenever a run produced fewer points.
n_points = min(31800, min(len(run) for run in toy_runs))
toy_loss = np.stack([run[:n_points] for run in toy_runs])

t = np.arange(1, n_points + 1)
toy_mean, toy_std = toy_loss.mean(axis=0), toy_loss.std(axis=0)

plt.plot(t, toy_mean, lw=2, label='LSTM (mean over runs)', color='blue')
plt.fill_between(t, toy_mean + toy_std,
                 toy_mean - toy_std,
                 facecolor='blue', alpha=0.5)

plt.xlabel('Batch')
plt.ylabel('Cross Entropy Loss')
plt.legend()
plt.grid()

#plt.savefig('toy')
plt.show()