In [None]:
from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dist
from copy import deepcopy
from typing import List, Tuple
import time
import numpy as np
from data import load_ndfa, load_brackets
import logging
import random
import pandas as pd
import matplotlib.pyplot as plt 

# Timestamped log lines; DEBUG level so the per-sample logging.debug lines
# emitted while generating sequences are shown.
logging.basicConfig(format='%(asctime)s - %(message)s', level='DEBUG')
RANDOM_SEED = 10

# Data / batching.
DATASET = 'ndfa'
MAX_CHARS_PER_BATCH = 10_000

# Softmax temperature used when sampling from the trained model.
TEMPERATURE = 1.0

# Evaluation settings forwarded to main() / the sampling helpers.
EVAL_N_SAMPLES = 10
EVAL_SEQ_MAX_LEN = 20

class HParams:
    """Network and optimisation hyper-parameters used for the ndfa run."""

    # Size of each token embedding vector.
    EMB_DIM = 32
    # Hidden-state width of the LSTM.
    N_HIDDEN = 16
    # Number of stacked LSTM layers.
    N_LAYERS = 1
    # Full passes over the batch list.
    EPOCHS = 3
    # Adam learning rate.
    LEARNING_RATE = 0.01


# Dataset name -> loader function; each loader is called as loader(n=...) and
# returns (sequences, (i2w, w2i)).
dataset_loaders = {
    'ndfa': load_ndfa,
    'brackets': load_brackets,
}

# Seed every RNG the notebook draws from: numpy (batch shuffling), torch
# (weight init, categorical sampling) and `random` (batch previews).
# Seeding only numpy left training and sampling non-reproducible.
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

def prepare_batches(x_train, i2w, w2i, max_chars_per_batch):
    """Turn token-id sequences into shuffled (input, one-hot target) batches.

    Each sequence is wrapped as ``[.start, *tokens, .end]``. The sequences are
    grouped by their framed length, so every batch holds sequences of a single
    length. Within a group, the batch size is ``max(1, max_chars_per_batch // length)``.

    The target at position ``t`` is the input token at ``t + 1`` (next-token
    prediction). The final position has no successor and gets ``.pad``.

    :param x_train: iterable of lists of token ids. It is NOT modified.
    :param i2w: index -> word vocabulary; its length is the one-hot width.
    :param w2i: word -> index mapping. It must contain '.pad', '.start' and '.end'.
    :param max_chars_per_batch: approximate token budget per batch.
    :return: list of ``(x, y_one_hot)`` tuples, shuffled with ``np.random``.
        ``x`` is a LongTensor of shape (batch, length). ``y_one_hot`` is a
        LongTensor of shape (batch, length, len(i2w)).
    """
    dict_size = len(i2w)
    pad_val = w2i['.pad']
    start_val = w2i['.start']
    end_val = w2i['.end']

    # Build framed copies rather than inserting into the caller's lists, so
    # calling this twice on the same data does not double the framing tokens.
    sizes = defaultdict(list)
    for x in x_train:
        framed = [start_val, *x, end_val]
        sizes[len(framed)].append(framed)

    batches = []
    for seq_len, seqs in sizes.items():
        x_tensor = torch.tensor(seqs, dtype=torch.long)
        n_seqs = x_tensor.shape[0]

        # Next-token targets: drop the first input column and append a pad column.
        pad_col = torch.full((n_seqs, 1), pad_val, dtype=torch.long)
        y_tensor = torch.cat([x_tensor[:, 1:], pad_col], dim=1)
        assert x_tensor.shape == y_tensor.shape

        # Guard against a zero batch size when a single sequence exceeds the budget.
        batch_size = max(1, max_chars_per_batch // seq_len)
        x_batches = torch.split(x_tensor, batch_size)
        y_batches = torch.split(y_tensor, batch_size)

        y_oh_batches = [F.one_hot(y, num_classes=dict_size) for y in y_batches]
        batches.extend(zip(x_batches, y_oh_batches))

    np.random.shuffle(batches)
    return batches

class recurNet(nn.Module):
    """Token language model: embedding -> LSTM -> linear projection to vocabulary logits.

    :param embedding_dim: width of the token embeddings.
    :param hidden_size: LSTM hidden-state width.
    :param vocab_size: number of token ids (rows of the embedding table and output logits).
    :param num_layers: number of stacked LSTM layers.
    """

    def __init__(self, embedding_dim=32, hidden_size=16, vocab_size=15, num_layers=1):
        super().__init__()
        self.embedding_dim, self.hidden_size, self.vocab_size = embedding_dim, hidden_size, vocab_size

        # Sub-modules are created in this order so random initialisation
        # consumes the torch RNG in the same sequence.
        self.layer1 = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.layer2 = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.layer3 = nn.Linear(hidden_size, vocab_size)

    def forward(self, input):
        """Map token ids of shape (batch, seq) to logits of shape (batch, seq, vocab_size)."""
        embedded = self.layer1(input)
        # nn.LSTM returns (per-step outputs, (h_n, c_n)); only the outputs are used.
        per_step, _ = self.layer2(embedded)
        return self.layer3(per_step)
        
def print_some_batches(batches, i2w):
    """Print one randomly chosen (input, target) row from every batch.

    The target row is decoded by taking the arg-max of each one-hot vector.
    """
    for x_batch, y_batch in batches:
        row = random.choice(range(len(x_batch)))
        input_words = decode(x_batch[row], i2w)
        target_words = decode(y_batch[row].argmax(1), i2w)
        print(f'Input | {" ".join(input_words)}')
        print(f'Output | {" ".join(target_words)}')

def train(model: nn.Module, epochs : int, batches: List[Tuple[torch.Tensor, torch.Tensor]], device: torch.device, lr: float, loss_print_freq: int=30):
    """Train ``model`` with Adam on next-token cross-entropy.

    :param model: module mapping (batch, seq) token ids to (batch, seq, vocab) logits.
    :param epochs: number of passes over ``batches``.
    :param batches: list of ``(x, one_hots)`` pairs, e.g. as produced by
        ``prepare_batches``. ``one_hots`` has shape (batch, seq, vocab).
    :param device: device the model and each batch are moved to.
    :param lr: Adam learning rate.
    :param loss_print_freq: log the mean loss every this many updates.
    :return: ``(data, model)``. ``data`` is a list of
        ``{'update', 'epoch', 'loss'}`` dicts, one per update.
    """
    model.to(device)
    model.train()

    criterion = nn.CrossEntropyLoss(reduction='sum')
    optimizer = optim.Adam(model.parameters(), lr)
    print_every = max(1, loss_print_freq)

    ts_train = time.perf_counter()

    # One record per optimiser update, for plotting training progress.
    data = []
    for epoch in range(epochs):
        ts = time.perf_counter()
        running_loss = 0.0
        n_since_print = 0

        for i, (x, one_hots) in enumerate(batches):
            x, one_hots = x.to(device), one_hots.to(device)

            optimizer.zero_grad()
            out = model(x)

            # CrossEntropyLoss treats dim 1 as the class dimension. Passing
            # (batch, seq, vocab) directly would softmax over sequence
            # positions. Flatten to (batch*seq, vocab) so it is over the vocabulary.
            vocab_size = out.shape[-1]
            logits_flat = out.reshape(-1, vocab_size)
            targets_flat = one_hots.reshape(-1, vocab_size).to(torch.float32)

            # Summed loss divided by the number of token positions (batch * seq).
            loss = criterion(logits_flat, targets_flat) / logits_flat.shape[0]
            loss.backward()
            optimizer.step()

            loss_val = loss.item()
            running_loss += loss_val
            n_since_print += 1

            if (i + 1) % print_every == 0:
                logging.info('[%d, %5d] loss: %.3f ' %
                    (epoch + 1, i + 1, running_loss / n_since_print))
                running_loss = 0.0
                n_since_print = 0
            data.append({'update': i, 'epoch': epoch, 'loss': loss_val})
        logging.info(f'Epoch took: {time.perf_counter()-ts:.2f}s')

    logging.info(f'Finished training. {epochs} epochs took: {time.perf_counter()-ts_train:.2f}s')
    return data, model

def sample(lnprobs, temperature=1.0):
    """
    Draw one index from the categorical distribution given by ``lnprobs``.

    :param lnprobs: tensor of unnormalised logits (presumably 1-D; softmax is over dim 0).
    :param temperature: softmax temperature. 1.0 samples from the given
        distribution; 0.0 deterministically returns the arg-max.
    :return: 0-d tensor holding the chosen index.
    """
    if temperature != 0.0:
        probs = F.softmax(lnprobs / temperature, dim=0)
        return dist.Categorical(probs).sample()
    return lnprobs.argmax()

    
def decode(seq, i2w):
    """Translate a sequence of token ids into their vocabulary words via ``i2w``."""
    words = []
    for token in seq:
        words.append(i2w[token])
    return words

def sample_model(model: nn.Module, seed_seq: List[int], end_token : int, device: torch.device, max_len:int=20, temperature: float =1.0):
    """Autoregressively extend ``seed_seq`` with tokens sampled from ``model``.

    Generation stops once ``end_token`` is produced or the sequence holds
    ``max_len`` tokens. A seed that is already that long is returned as a copy.
    The full prefix is re-fed to the model at every step.

    :param model: assumed to map a (1, len) LongTensor to (1, len, vocab) logits.
    :param seed_seq: non-empty list of token ids to continue. It is not modified.
    :param end_token: token id that terminates generation.
    :param device: device the model lives on.
    :param max_len: upper bound on the length of the returned sequence.
    :param temperature: forwarded to ``sample``.
    :return: new list ``seed_seq + generated tokens``.
    :raises ValueError: if ``seed_seq`` is empty.
    """
    if len(seed_seq) == 0:
        raise ValueError('seed_seq must contain at least one token (e.g. .start)')
    generated = list(seed_seq)

    with torch.no_grad():
        # Strict '<' so the result never exceeds max_len tokens.
        while len(generated) < max_len and generated[-1] != end_token:
            seed_tensor = torch.tensor(generated, dtype=torch.long, device=device).view(1, -1)

            logits = model(seed_tensor).squeeze(0).cpu()
            # Logits at the last position predict the next token.
            out_tok = int(sample(logits[-1, :], temperature))
            generated.append(out_tok)

    return generated

def delete_by_indices(lst, indices):
    """Return a new list with the items of ``lst`` at positions ``indices`` removed.

    Only non-negative positions are matched; negative entries in ``indices`` are ignored.
    """
    drop = frozenset(indices)
    kept = []
    for pos, item in enumerate(lst):
        if pos not in drop:
            kept.append(item)
    return kept

def _ndfa_rejection_reason(body, s_tok, words, forbidden):
    """Return None if ``body`` is a valid ndfa string, else a short rejection reason.

    Valid strings are ``s s`` or ``s w w ... w s``, where every ``w`` is the
    same entry of ``words``.
    """
    if len(body) < 2 or body[0] != s_tok or body[-1] != s_tok:
        return 'Not start / end with s'
    inner = body[1:-1]
    if any(tok in forbidden for tok in inner):
        return 'unk or start in middle'
    if not inner:
        return None
    if s_tok in inner:
        return 'rogue s spotted'
    word_len = len(words[0])
    if len(inner) % word_len != 0:
        return 'words not % 4'
    first = inner[:word_len]
    if first not in words:
        return 'unknown word'
    # Every subsequent chunk must repeat the first word; mixing words is invalid.
    for start in range(word_len, len(inner), word_len):
        if inner[start:start + word_len] != first:
            return 'different word or sth'
    return None


def eval_ndfa(samples, w2i, i2w):
    """Return the fraction of generated ndfa sequences that are well formed.

    For each sample, the leading token (the `.start` of the seed) is dropped,
    and a trailing `.end` is dropped if present. The remainder must be `s s`
    or `s` + one of `abc!`, `uvw!`, `klm!` repeated one or more times + `s`.
    Each stripped sample is printed, followed by the rejection reason when invalid.

    :param samples: list of token-id lists, e.g. from ``generate_samples``.
    :param w2i: word -> index mapping (must contain the word characters, 's' and '.end').
    :param i2w: index -> word mapping used for printing.
    :return: accuracy in [0, 1]; 0.0 for an empty ``samples``.
    """
    words = [[w2i[ch] for ch in word] for word in ('abc!', 'uvw!', 'klm!')]
    s_tok = w2i['s']
    end_tok = w2i['.end']
    forbidden = {w2i[tok] for tok in ('.unk', '.start') if tok in w2i}

    if len(samples) == 0:
        return 0.0

    correct = 0
    for seq in samples:
        body = list(seq[1:])
        # Only strip .end when it is actually there; a generation cut off by
        # max_len has no .end, and stripping would discard a real token.
        if body and body[-1] == end_tok:
            body.pop()
        print(" ".join(i2w[tok] for tok in body))

        reason = _ndfa_rejection_reason(body, s_tok, words, forbidden)
        if reason is None:
            correct += 1
        else:
            print(reason)

    return correct / len(samples)

def main(dataset: str, max_chars_per_batch: int, net_hparams:HParams, n_samples: int, max_len: int, temperature: float, device, loss_fig='q456.png', n_train: int = 150_000):
    """Load ``dataset``, train a ``recurNet`` on it and save a per-epoch loss plot.

    :param dataset: key into ``dataset_loaders``.
    :param max_chars_per_batch: token budget per batch, forwarded to ``prepare_batches``.
    :param net_hparams: class exposing EMB_DIM, N_HIDDEN, N_LAYERS, EPOCHS, LEARNING_RATE.
    :param n_samples: accepted for compatibility with existing callers; not used here.
    :param max_len: accepted for compatibility with existing callers; not used here.
    :param temperature: accepted for compatibility with existing callers; not used here.
    :param device: torch device to train on.
    :param loss_fig: path the mean-loss-per-epoch figure is written to.
    :param n_train: number of sequences requested from the dataset loader.
    :return: ``(model, w2i, i2w)``.
    """
    logging.info(f'Loading dataset: {dataset}')
    dataset_loader = dataset_loaders[dataset]
    x_train, (i2w, w2i) = dataset_loader(n=n_train)

    logging.info(f'Creating batches of max chars: {max_chars_per_batch}')
    batches = prepare_batches(x_train, i2w, w2i, max_chars_per_batch)

    logging.info(f'Training on: {device}')

    model = recurNet(
        embedding_dim=net_hparams.EMB_DIM,
        num_layers=net_hparams.N_LAYERS,
        hidden_size=net_hparams.N_HIDDEN,
        vocab_size=len(i2w)
    )

    data, model = train(
        model=model,
        epochs=net_hparams.EPOCHS,
        batches=batches,
        device=device,
        lr=net_hparams.LEARNING_RATE
    )

    df = pd.DataFrame(data)

    logging.info(f'saving progress to: {loss_fig}')
    # Draw on a fresh figure: plotting onto pyplot's "current" axes made a
    # second call (e.g. the brackets run) overlay its curve on the previous plot.
    fig, ax = plt.subplots()
    if not df.empty:
        df.groupby('epoch')['loss'].mean().plot(ax=ax)
    ax.set(title=f'{dataset}: mean training loss per epoch', xlabel='epoch', ylabel='loss')
    fig.savefig(loss_fig)

    return model, w2i, i2w

def generate_samples(seq, model, n_samples, device, max_len, temperature, w2i=None, i2w=None):
    """Draw ``n_samples`` independent continuations of ``seq`` from ``model``.

    :param seq: seed token-id list; it is copied and not modified.
    :param model: trained model passed to ``sample_model``.
    :param n_samples: number of continuations to generate.
    :param device: device the model lives on.
    :param max_len: maximum length of each generated sequence.
    :param temperature: sampling temperature.
    :param w2i: word -> index mapping. It is used to find '.end'. If omitted,
        it falls back to the notebook-global ``w2i`` (the original behaviour).
    :param i2w: index -> word mapping, used for debug logging. If omitted, it
        falls back to the notebook-global ``i2w``.
    :return: list of generated token-id lists.
    """
    # Explicit vocab arguments avoid depending on whatever the kernel last bound
    # to the globals. The fallback keeps existing positional callers working.
    if w2i is None:
        w2i = globals()['w2i']
    if i2w is None:
        i2w = globals()['i2w']
    end_token = w2i['.end']

    logging.debug(f'Sampling from model with init seed: {", ".join(decode(seq, i2w))}')

    samples = []
    for i in range(n_samples):
        out_seq = sample_model(model=model, seed_seq=seq, end_token=end_token, device=device, max_len=max_len, temperature=temperature)
        out_seq_str = decode(out_seq, i2w)
        logging.debug(f'Output-{i}: {", ".join(out_seq_str)} [len={len(out_seq)}]')
        samples.append(out_seq)
    return samples


In [None]:

# %%
# Bind the dataset for this section before it is read. The brackets cell
# further down rebinds DATASET, so assigning it only at the bottom of this
# cell made the preview below depend on which cells had run previously.
DATASET = 'ndfa'
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

dataset_loader = dataset_loaders[DATASET]
x_train, (i2w, w2i) = dataset_loader(n=150_000)

logging.info(f'Creating batches of max chars: {MAX_CHARS_PER_BATCH}')
batches = prepare_batches(x_train, i2w, w2i, MAX_CHARS_PER_BATCH)

# %%
# Eyeball one random (input, target) row from every batch.
print_some_batches(batches, i2w)
# %%
# Train on the ndfa data (main reloads the dataset and rebuilds the batches itself).
model, w2i, i2w = main(
    dataset=DATASET,
    max_chars_per_batch=MAX_CHARS_PER_BATCH,
    net_hparams=HParams,
    n_samples=EVAL_N_SAMPLES,
    max_len=EVAL_SEQ_MAX_LEN,
    temperature=TEMPERATURE,
    device=device
)


In [None]:

def eval_ndfa_on_test_sequences():
    """Sample continuations of hand-written ndfa prefixes and print their accuracy.

    Reads the notebook globals ``model``, ``device``, ``w2i`` and ``i2w``.
    """
    prefixes = [
        '.start s k l',
        '.start s a b c ! a',
    ]
    test_sequences = [[w2i[tok] for tok in prefix.split()] for prefix in prefixes]

    for seq in test_sequences:
        print(f'input | {" ".join(decode(seq,i2w))}')
        generated = generate_samples(seq, model, 10, device, EVAL_SEQ_MAX_LEN, TEMPERATURE)
        acc = eval_ndfa(generated, w2i, i2w)
        print(f'Accuracy: {acc}')

eval_ndfa_on_test_sequences()

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
DATASET = 'brackets'


class BracketsHParams(HParams):
    """Hyper-parameters for the brackets run (same network, many more epochs).

    A separate class instead of redefining ``HParams``: rebinding the global
    meant that re-running the ndfa cell afterwards silently trained for 50 epochs.
    """
    EMB_DIM = 32
    N_HIDDEN = 16
    N_LAYERS = 1
    EPOCHS = 50
    LEARNING_RATE = 0.01


model, w2i, i2w = main(
    dataset=DATASET,
    max_chars_per_batch=MAX_CHARS_PER_BATCH,
    net_hparams=BracketsHParams,
    n_samples=EVAL_N_SAMPLES,
    max_len=EVAL_SEQ_MAX_LEN,
    temperature=TEMPERATURE,
    device=device,
    loss_fig='brackets.png'
)

In [None]:
def eval_brackets_on_test_sequences(model, device):
    """Print sampled continuations of a hand-written brackets prefix.

    Reads the notebook globals ``w2i`` and ``i2w``. No accuracy metric is
    computed for this dataset.
    """
    prefixes = ['.start ( ( )']
    test_sequences = [[w2i[tok] for tok in prefix.split()] for prefix in prefixes]

    for seq in test_sequences:
        print(f'input | {" ".join(decode(seq,i2w))}')
        generated = generate_samples(seq, model, 10, device, EVAL_SEQ_MAX_LEN, TEMPERATURE)
        for out_seq in generated:
            print(' '.join(decode(out_seq,i2w)))

eval_brackets_on_test_sequences(model, device)
# %%
