# Sequence Models and LSTM Networks - Probabilistic Model with Pyro and Penn Treebank
Sequence Tagger: https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html<br>
Bayesian NN: https://github.com/paraschopra/bayesian-neural-network-mnist/blob/master/bnn.ipynb<br>
Penn Treebank: http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.9.8216&rep=rep1&type=pdf

https://gist.github.com/williamFalcon/f27c7b90e34b4ba88ced042d9ef33edd <br>
https://stanford.edu/~shervine/blog/pytorch-how-to-generate-data-parallel

In [None]:
import numpy as np
import nltk
from nltk.corpus import treebank
import os
import codecs

In [None]:
# Fetch the NLTK data used below: the Penn Treebank sample and the mapping
# to the universal POS tagset (required by tagset='universal').
nltk.download('treebank')
nltk.download('universal_tagset')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Seed torch's global RNG so parameter initialisation is reproducible.
# NOTE(review): Pyro samples through torch's RNG, so this presumably also
# fixes SVI sampling — confirm for the installed Pyro version.
torch.manual_seed(1)

In [None]:
from IPython.display import clear_output

In [None]:
import pyro
from pyro.distributions import Normal, Categorical
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# Enable Pyro's distribution argument/shape validation, and start from an
# empty parameter store so params from earlier runs of this notebook are dropped.
pyro.enable_validation(True)
pyro.clear_param_store()

# An LSTM for Part-of-Speech Tagging

### Load Data

In [None]:
# Penn Treebank sample, POS-tagged with the universal tagset. Each sentence
# is a sequence of (word, tag) pairs (see format_sequence below).
sentences = treebank.tagged_sents(tagset='universal')

In [None]:
# Keep only the first `samples` sentences of the corpus.
samples = 5000
sentences = sentences[:samples]

In [None]:
def format_sequence(seq):
    """
    Formats penn treebank POS format into tuple ([tokens], [POS])

    Args:
        seq: Iterable of ``(token, tag)`` pairs for one sentence, e.g. an
            element of ``treebank.tagged_sents()``. May be a one-shot
            iterator such as a generator.

    Returns:
        Tuple ``(tokens, tags)`` of two equal-length lists.
    """
    tokens = []
    tags = []
    # Single pass: two separate comprehensions would exhaust a one-shot
    # iterator on the first pass and silently return an empty tag list.
    for pair in seq:
        tokens.append(pair[0])
        tags.append(pair[1])
    return (tokens, tags)

In [None]:
# Convert every sentence from [(word, tag), ...] into ([words], [tags]).
sentences = list(map(format_sequence, sentences))

In [None]:
def data_vocab(sentences):
    """Build a word -> id vocabulary from ``(tokens, tags)`` sentence pairs.

    Ids are assigned densely from 0 in order of first appearance; only the
    token list (element 0) of each sentence is read.
    """
    vocab = {}
    every_token = (token for sentence in sentences for token in sentence[0])
    for token in every_token:
        # len(vocab) is evaluated before insertion, so a new token gets the next free id
        vocab.setdefault(token, len(vocab))
    return vocab

In [None]:
# Vocab of input data (this will likely be a subset of any word embedding array)
# NOTE(review): this rebinds the name `data_vocab` from the builder function to
# the resulting dict, so re-running this cell raises TypeError ('dict' object is
# not callable). In the code shown, this dict is not used again; `word_to_ix`
# is built separately and is what the rest of the notebook uses.
data_vocab = data_vocab(sentences)

## Prepare data

In [None]:
def prepare_sequence(seq, to_ix, unk_ix=None):
    """Encode a sequence of tokens (or tags) as a LongTensor of ids.

    Args:
        seq: Iterable of hashable items (words or tags).
        to_ix: Mapping from item to integer id (e.g. word_to_ix, tag_to_ix).
        unk_ix: Optional id used for items missing from ``to_ix``. When None
            (the default), a missing item raises ``KeyError``.

    Returns:
        1-D ``torch.long`` tensor with one id per item.

    Raises:
        KeyError: If an item is not in ``to_ix`` and ``unk_ix`` is None.
    """
    if unk_ix is None:
        idxs = [to_ix[w] for w in seq]
    else:
        # Map out-of-vocabulary items (e.g. test-only words) to a shared id.
        idxs = [to_ix.get(w, unk_ix) for w in seq]
    return torch.tensor(idxs, dtype=torch.long)

In [None]:
# Train/Test split: the first 80% of sentences train, the rest test (no shuffling).
split_ratio = 0.80
_split_at = int(len(sentences) * split_ratio)
training_data, test_data = sentences[:_split_at], sentences[_split_at:]

In [None]:
# Report the dataset and split sizes.
print(f'Dataset Size: {len(sentences)} | Training Set Size: {len(training_data)} | Test Set Size: {len(test_data)}')

In [None]:
# Word -> index lookup, ids in first-seen order. It is built over *all*
# sentences (not only training_data), so test words are always in the vocab.
word_to_ix = {}
for sent, tags in sentences:
    for word in sent:
        word_to_ix.setdefault(word, len(word_to_ix))

In [None]:
# Create tag-index lookups: tag -> id (first-seen order) and its inverse id -> tag.
tag_to_ix = {}
for _, tags in sentences:
    for tag in tags:
        tag_to_ix.setdefault(tag, len(tag_to_ix))

ix_to_tag = dict(zip(tag_to_ix.values(), tag_to_ix.keys()))

In [None]:
# Vocabulary sizes; these become LSTMTagger's vocab_size and tagset_size.
print(f'Word dictionary size: {len(word_to_ix)}')
print(f'Tag dictionary size: {len(tag_to_ix)}')

### Create LSTM model

In [None]:
class LSTMTagger(nn.Module):
    """Single-layer LSTM part-of-speech tagger.

    Maps a 1-D tensor of word indices (one sentence) to per-token
    log-probabilities over the tag set, shape ``(len(sentence), tagset_size)``.

    Args:
        embedding_dim: Size of each word embedding vector.
        hidden_dim: Number of LSTM hidden units.
        vocab_size: Number of rows in the embedding table.
        tagset_size: Number of output tags.
        batch_size: Stored on the instance; ``forward`` processes one
            sentence at a time and does not use it.
        pretrained_embeddings: Optional NumPy array (or tensor) of shape
            ``(vocab_size, embedding_dim)``. When given, it is copied into the
            embedding table and the table is frozen (``requires_grad=False``).

    Raises:
        ValueError: If ``pretrained_embeddings`` does not have shape
            ``(vocab_size, embedding_dim)``.
    """

    def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size, batch_size=32, pretrained_embeddings=None):
        super(LSTMTagger, self).__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size

        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        if pretrained_embeddings is not None:
            expected = (vocab_size, embedding_dim)
            actual = tuple(pretrained_embeddings.shape)
            # copy_ broadcasts, so e.g. a single row would otherwise be copied
            # silently into every vocabulary entry; insist on an exact match.
            if actual != expected:
                raise ValueError(
                    f'pretrained_embeddings has shape {actual}, expected {expected}')
            self.word_embeddings.weight.data.copy_(torch.as_tensor(pretrained_embeddings))
            self.word_embeddings.weight.requires_grad = False

        self.hidden = self.init_hidden()

        # The LSTM takes word embeddings as inputs, and outputs hidden states
        # with dimensionality hidden_dim.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)    # , batch_first=True

        # The linear layer that maps from hidden state space to tag space
        self.out = nn.Linear(hidden_dim, tagset_size)

    def forward(self, sentence):
        """Return log-softmax tag scores for a 1-D tensor of word indices."""
        # Fresh zero state per sentence so nothing carries over between sentences.
        self.hidden = self.init_hidden()

        embeds = self.word_embeddings(sentence)
        # The LSTM input is (seq_len, batch=1, embedding_dim).
        lstm_out, _ = self.lstm(embeds.view(len(sentence), 1, -1), self.hidden)
        tag_space = self.out(lstm_out.view(len(sentence), -1))
        tag_scores = F.log_softmax(tag_space, dim=1)
        return tag_scores

    def init_hidden(self):
        """Return a zeroed ``(h_0, c_0)`` pair, each of shape ``(1, 1, hidden_dim)``.

        Shape is (num_layers=1, batch=1, hidden_dim). The tensors take their
        dtype and device from the embedding table.
        """
        weight = self.word_embeddings.weight
        return (weight.new_zeros(1, 1, self.hidden_dim),
                weight.new_zeros(1, 1, self.hidden_dim))

### Initialise the NN model

In [None]:
# Raw GloVe 300-d vectors. Not read anywhere in the code shown; kept for reference.
# NOTE(review): trimmed_emb.npz is presumably derived from this file and
# restricted to word_to_ix — confirm against the script that produced it.
path_to_embeddings = './data/embeddings/glove.6B.300d.txt'
path_to_trimmed_embeddings = './data/embeddings/trimmed_emb.npz'

In [None]:
# Load trimmed embeddings from disk.
# For a .npz archive np.load returns an NpzFile; arrays are read by key,
# e.g. pretrained_embeddings['embeddings'] below.
# NOTE(review): rows are assumed to follow word_to_ix's id order — confirm.
pretrained_embeddings = np.load(path_to_trimmed_embeddings)

In [None]:
# Check embedding shape: it should be (len(word_to_ix), EMBEDDING_DIM) to fit
# the nn.Embedding table that LSTMTagger copies it into.
pretrained_embeddings['embeddings'].shape

In [None]:
EMBEDDING_DIM = 300   # Glove 300; must equal the width of the pretrained embedding rows
HIDDEN_DIM = 32       # LSTM hidden units

In [None]:
# Deterministic baseline tagger using the pretrained embeddings (frozen inside
# LSTMTagger). The network outputs log-softmax scores, which is what NLLLoss expects.
lstm_net = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix), pretrained_embeddings=pretrained_embeddings['embeddings'])
loss_function = nn.NLLLoss()
optimizer = optim.SGD(lstm_net.parameters(), lr=0.1)

In [None]:
# Show the module structure.
print(lstm_net)

### Train standard NN model

In [None]:
# Train the deterministic tagger one sentence at a time with SGD and report
# the mean per-sentence NLL for each epoch (not just the last sentence's loss).
num_epochs = 5
for epoch in range(num_epochs):
    epoch_loss = 0.0
    for sentence, tags in training_data:
        # Step 1. PyTorch accumulates gradients, so clear them for each sentence.
        lstm_net.zero_grad()

        # Step 2. Turn the tokens and tags into tensors of indices.
        sentence_in = prepare_sequence(sentence, word_to_ix)
        targets = prepare_sequence(tags, tag_to_ix)

        # Step 3. Forward pass: per-token log-probabilities over tags.
        tag_scores = lstm_net(sentence_in)

        # Step 4. Compute the loss and gradients, then update the parameters.
        loss = loss_function(tag_scores, targets)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Guard against an empty training set.
    mean_loss = epoch_loss / max(len(training_data), 1)
    print(f'Epoch: {epoch} - Loss: {mean_loss:.4f}')

In [None]:
# helper function for deterministic nn inference
def tag_score_to_tag_name(tag_score, ix_to_tag):
    """Return the name of the highest-scoring tag in ``tag_score``.

    ``tag_score`` is a torch tensor; the arg-max is taken over all of its
    elements. Yields None when that index is not a key of ``ix_to_tag``.
    """
    best_index = int(torch.argmax(tag_score))
    return ix_to_tag.get(best_index)

In [None]:
# Single test example: a one-element slice, so it is still a list of (tokens, tags) pairs.
test_data_sm = test_data[:1]

In [None]:
# Inference with the deterministic tagger on the first test sentence:
# print each token with its predicted and gold tag.
with torch.no_grad():
    test_tokens, test_tags = test_data_sm[0]
    inputs = prepare_sequence(test_tokens, word_to_ix)
    tag_scores = lstm_net(inputs)

    print(f'{"Token":<20} {"Pred":<10} {"Actual":<10}')
    print(f'{"-----":<20} {"----":<10} {"------":<10}')
    # Walk the *test* sentence itself so each token lines up with its own
    # prediction row and gold tag.
    for token, scores, actual in zip(test_tokens, tag_scores, test_tags):
        print(f'{token:<20} {tag_score_to_tag_name(scores, ix_to_tag):<10} {actual:<10}')

### Initialise Probabilistic Pyro Model

Ref:<br>
- https://forum.pyro.ai/t/bayesian-rnn-nan-loss-issue/254

- `loc` = mean, `scale` = standard deviation
- mu = 0, sigma = 1 -> standard (unit) Gaussian distribution, N(0, 1)

Refs:<br>
- https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html
- https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html

See fix for issues relating to lognorm: https://github.com/pyro-ppl/pyro/issues/1409 

In [None]:
# Re-initialise lstm_net (in case it was being used previously), so the Bayesian
# model/guide below wrap a freshly initialised network.
lstm_net = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix), pretrained_embeddings=pretrained_embeddings['embeddings'])

In [None]:
def model(input, target):
    """Pyro model: unit-Normal priors over the LSTM and output-layer weights.

    Draws one network from the priors, runs it on ``input``, and scores the
    observed tag indices ``target`` under a per-token Categorical likelihood.

    Args:
        input: 1-D LongTensor of word indices for one sentence.
        target: 1-D LongTensor of gold tag indices (observed).
    """
    # The pretrained, frozen word embeddings get no prior and stay
    # deterministic; every parameter listed here becomes a random variable.
    stochastic_names = (
        'lstm.weight_ih_l0',
        'lstm.weight_hh_l0',
        'lstm.bias_ih_l0',
        'lstm.bias_hh_l0',
        'out.weight',
        'out.bias',
    )
    params_by_name = dict(lstm_net.named_parameters())

    priors = {}
    for param_name in stochastic_names:
        template = params_by_name[param_name]
        # Element-wise N(0, 1); all dims form one event
        # (2 for weight matrices, 1 for bias vectors).
        unit_normal = Normal(loc=torch.zeros_like(template),
                             scale=torch.ones_like(template))
        priors[param_name] = unit_normal.independent(template.dim())

    # Lift the module parameters to random variables and draw one network.
    sampled_net = pyro.random_module("module", lstm_net, priors)()

    # The network emits log-softmax scores, which Categorical accepts as logits;
    # independent(1) treats the whole sentence as a single event.
    log_probs = sampled_net(input)
    pyro.sample("obs", Categorical(logits=log_probs).independent(1), obs=target)

In [None]:
softplus = torch.nn.Softplus()

def guide(input, target):
    """Pyro guide: a learnable diagonal Normal over each stochastic weight.

    For each parameter that carries a prior in ``model``, this registers a
    mean ("*_mu") and an unconstrained scale ("*_sigma") in Pyro's param
    store. The scale is passed through softplus to keep it positive. The
    guide then samples and returns one network. ``input`` and ``target``
    are unused; they only mirror the model's signature.
    """
    # (module parameter, pyro.param site for the mean, pyro.param site for the scale)
    variational_sites = (
        ('lstm.weight_ih_l0', "lstm_w_ih_l0_mu", "lstm_w_ih_l0_sigma"),
        ('lstm.weight_hh_l0', "lstm_w_hh_l0_mu", "lstm_w_hh_l0_sigma"),
        ('lstm.bias_ih_l0', "lstm_b_ih_l0_mu", "lstm_b_ih_l0_sigma"),
        ('lstm.bias_hh_l0', "lstm_b_hh_l0_mu", "lstm_b_hh_l0_sigma"),
        ('out.weight', "out_w_mu", "out_w_sigma"),
        ('out.bias', "out_b_mu", "out_b_sigma"),
    )
    params_by_name = dict(lstm_net.named_parameters())

    posteriors = {}
    for param_name, mu_site, sigma_site in variational_sites:
        template = params_by_name[param_name]
        # Random initial values; pyro.param only uses them when the site is first created.
        mu_init = torch.randn_like(template)
        sigma_init = torch.randn_like(template)
        mu = pyro.param(mu_site, mu_init)
        sigma = softplus(pyro.param(sigma_site, sigma_init))
        posteriors[param_name] = Normal(loc=mu, scale=sigma).independent(template.dim())

    return pyro.random_module("module", lstm_net, posteriors)()

In [None]:
# Stochastic variational inference over the model/guide pair defined earlier,
# using Pyro's Adam (lr=0.01) and the Trace_ELBO loss.
inference = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())

In [None]:
# TODO: update to use batches; atm it's a single sample per step... very slow

# One SVI step per training sentence; reports the mean step loss per epoch.
num_iterations = 10
loss = 0
for j in range(num_iterations):
    loss = 0
    for sentence, tags in training_data:
        sentence_in = prepare_sequence(sentence, word_to_ix)
        targets = prepare_sequence(tags, tag_to_ix)
        
        # Calculate loss and take gradient step
        loss += inference.step(sentence_in, targets)
    
    # Mean of the values returned by inference.step over this epoch's sentences.
    total_epoch_loss_train = loss / len(training_data)
    
    # NOTE(review): `j % 1 == 0` is always true, so this prints every epoch.
    if j % 1 == 0:
        clear_output(wait=True)
        print(f'Epoch {j} - Loss {total_epoch_loss_train:0.4f}')

In [None]:
def predict(x, num_samples, guide_fn=None, verbose=False):
    """Tag one sentence by averaging networks sampled from the guide.

    Draws ``num_samples`` networks from the variational guide and runs each
    on ``x``. It averages their per-token tag *probabilities*, which is the
    Monte-Carlo estimate of the posterior predictive. Averaging the networks'
    log-probabilities instead would give a geometric mean, which can pick a
    different tag. The result is the arg-max tag per token.

    Args:
        x: 1-D LongTensor of word indices for one sentence.
        num_samples: Number of networks to sample; must be >= 1.
        guide_fn: Callable ``(input, target) -> network``. Defaults to the
            module-level ``guide``.
        verbose: If True, print the per-sample probabilities and their mean.

    Returns:
        NumPy integer array with one predicted tag index per token.

    Raises:
        ValueError: If ``num_samples`` < 1.
    """
    if num_samples < 1:
        raise ValueError(f'num_samples must be >= 1, got {num_samples}')
    if guide_fn is None:
        guide_fn = guide

    # Initialise the set of probabilistic models for inference.
    sampled_models = [guide_fn(None, None) for _ in range(num_samples)]

    with torch.no_grad():
        # Each network returns log-softmax scores; softmax maps them (or raw
        # logits) to probabilities so the average is a proper mixture.
        yhats = [torch.softmax(net(x), dim=-1) for net in sampled_models]
        mean = torch.mean(torch.stack(yhats), 0)

    if verbose:
        print(f'\nyhats:\n{yhats}')
        print(f'\nMean:\n{mean}')
    return np.argmax(mean.numpy(), axis=1)

In [None]:
# helper function
def tag_score_to_tag_name(tag_score, ix_to_tag):
    """Convert a vector of tag scores into the name of the highest-scoring tag.

    Args:
        tag_score: Scores for one token. Accepts a torch tensor, a NumPy array
            or scalar, or any array-like NumPy can convert (e.g. a list of
            floats). The arg-max is taken over all elements.
        ix_to_tag: Mapping from tag index to tag name.

    Returns:
        The tag name for the arg-max index, or None if that index is not a
        key of ``ix_to_tag``.
    """
    if torch.is_tensor(tag_score):
        best_ix = torch.argmax(tag_score).item()
    else:
        # NumPy arrays, NumPy scalars and plain sequences all go through np.argmax.
        best_ix = int(np.argmax(np.asarray(tag_score)))
    return ix_to_tag.get(best_ix)

In [None]:
# Evaluate the Bayesian tagger on the first test sentence only (one-element slice).
test_data_sm = test_data[:1]

In [None]:
# Predictions: tag each sentence in test_data_sm with the Bayesian model
# (averaging num_samples sampled networks) and report token-level accuracy.
num_samples = 10
correct = 0
total = 0
for sentence, tags in test_data_sm:
    sentence_in = prepare_sequence(sentence, word_to_ix)
    print(sentence, tags)

    # Convert tags into their indexes in tag dictionary
    tag_indices = np.array([tag_to_ix.get(tag) for tag in tags])

    predicted = predict(sentence_in, num_samples)
    total += len(tags)
    correct += (predicted == tag_indices).sum()

    # token | predicted tag | gold tag (`!s` so a missing name prints as None)
    for token, pred_ix, actual in zip(sentence, predicted, tags):
        print(f'{token:<10} {ix_to_tag.get(pred_ix)!s:<10} {actual}')
    print('\n')

# Avoid ZeroDivisionError when there are no test tokens.
if total:
    print(f'Accuracy: {correct/total * 100:0.1f}%')
else:
    print('Accuracy: n/a (no test tokens)')