In [None]:
'''This is a Simple Recurrent Network'''

from __future__ import print_function
import torch
import torch.nn as nn
import numpy as np
import matplotlib
%matplotlib inline
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram, linkage

### Load Data ###
def sentenceToTensor(tokens_list):
    # Map a tokenized sentence onto its vocabulary indices.
    #
    # Input
    #  tokens_list : python list of token strings, e.g. ['<SOS>','lion','eat','man','<EOS>'];
    #                every token must be a key of the module-level `token_to_index` dict
    # Output
    #  1D tensor with one integer index per token, in the same order as the input
    assert isinstance(tokens_list, list)
    return torch.tensor([token_to_index[word] for word in tokens_list])

# Load the corpus of simple sentences (one sentence per line, tokens separated by whitespace)
with open('data/elman_sentences.txt','r') as fid:
    lines = fid.readlines()
# Strip surrounding whitespace and drop blank lines, which would otherwise turn into
# empty '<SOS> <EOS>' training sentences.
sentences_str = [l.strip() for l in lines if l.strip()]
# Tokenize and wrap each sentence in start / end-of-sentence markers
sentences_tokens = [['<SOS>'] + s.split() + ['<EOS>'] for s in sentences_str]
# Vocabulary: every distinct token, sorted so that the index assignment is deterministic.
# (A set comprehension avoids the quadratic list copying of sum(list_of_lists, []).)
unique_tokens = sorted({token for sentence in sentences_tokens for token in sentence})
n_tokens = len(unique_tokens) # all words and special tokens
token_to_index = {t : i for i,t in enumerate(unique_tokens)}
index_to_token = dict(enumerate(unique_tokens))
training_pats = [sentenceToTensor(s) for s in sentences_tokens] # python list of 1D sentence tensors
ntrain = len(training_pats)
print('mapping unique tokens to integers: %s \n' % token_to_index)
print('example sentence as string: %s \n' % ' '.join(sentences_tokens[0]))
print('example sentence as tensor: %s \n' % training_pats[0])

### Build Neural Network ###
# Word-level SRN: it processes a single word (token) at a time
class SRN(nn.Module):
    # Simple recurrent network over word tokens.
    #
    # At every step the embedding of the current token is concatenated with the
    # previous hidden state, passed through a sigmoid layer to give the new hidden
    # state, and then mapped to log-probabilities over the vocabulary.

    def __init__(self, vocab_size, hidden_size):
        # vocab_size : number of tokens in vocabulary including special tokens <SOS> and <EOS>
        # hidden_size : dim of input embeddings and hidden layer
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.embed = nn.Embedding(vocab_size, hidden_size)
        # fc1 consumes [embedding ; previous hidden state], hence 2*hidden_size inputs
        self.fc1 = nn.Linear(2*hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_token_index, hidden_prev):
        # Input
        #    input_token_index: [integer tensor] index of current input token
        #        (0-d for a single token; a leading batch dimension also works)
        #    hidden_prev: [tensor whose last dim is hidden_size] hidden state from previous step,
        #        with the same leading dimensions as input_token_index
        # Output
        #    output: [last dim vocab_size] log-probability of emitting each output token
        #    hidden_curr : [last dim hidden_size] hidden state for current step
        input_embed = self.embed(input_token_index)
        # Concatenate along the feature axis so single tokens and batches both work
        combined = torch.cat((input_embed, hidden_prev), dim=-1)
        hidden_curr = torch.sigmoid(self.fc1(combined))
        # Explicit dim: normalize over the vocabulary axis (implicit dim is deprecated)
        output = nn.functional.log_softmax(self.fc2(hidden_curr), dim=-1)
        return output, hidden_curr

    def initHidden(self):
        # Returns length hidden_size 1D tensor of zeros,
        # created with the same dtype / device as the model parameters
        return self.embed.weight.new_zeros(self.hidden_size)

    def get_embeddings(self):
        # Returns [vocab_size x hidden_size] numpy array of input embeddings.
        # The array is a copy, so it does not change when the model is trained further.
        return self.embed.weight.detach().cpu().numpy().copy()
    
### Train Neural Network ###
def train(seq_tensor, rnn):
    # Run one sentence through the SRN and compute its next-token prediction loss.
    # The token at step t is the input and the token at step t+1 is the target, so
    # the first token is only ever an input and the last token is only ever a target.
    # Gradients are cleared here; the backward pass and the optimizer step are left
    # to the caller.
    #
    # Input
    #   seq_tensor: [1D tensor] sentence as token indices
    #   rnn : instance of SRN class
    # Output
    #   loss : [scalar tensor] value of the module-level `criterion` over all prediction steps
    rnn.train()
    state = rnn.initHidden()
    n_steps = len(seq_tensor) - 1
    log_probs = torch.zeros(n_steps, rnn.vocab_size)
    rnn.zero_grad()
    inputs, targets = seq_tensor[:-1], seq_tensor[1:]
    for step, token in enumerate(inputs):
        # row `step` holds the predicted distribution over the token that follows `token`
        log_probs[step], state = rnn(token, state)
    return criterion(log_probs, targets)

# Main training loop
# Hyperparameters
nepochs = 10  # full sweeps over the training sentences
nhidden = 20  # size of the embeddings and of the SRN hidden layer
n = len(training_pats)

# Model, optimizer (AdamW at its default learning rate of 0.001) and loss
rnn = SRN(n_tokens, nhidden)
optimizer = torch.optim.AdamW(rnn.parameters(), weight_decay=0.04)
criterion = nn.NLLLoss()

for epoch in range(nepochs):
    train_error = 0
    # visit the sentences in a fresh random order on every epoch
    for pat_idx in np.random.permutation(n):
        loss = train(training_pats[pat_idx], rnn)  # train() also zeroes the gradients
        loss.backward()
        optimizer.step()
        train_error += loss.item()
    # mean per-sentence loss for this epoch
    print(train_error / float(n))
    
### Analyze the SRN Internal Representation ###
def plot_dendo(X, names, exclude=('<SOS>','<EOS>')):
    #  Show hierarchical clustering of vectors as a dendrogram
    #
    # Input
    #  X : numpy tensor [nitem x dim] such that each row is a vector to be clustered
    #  names : [length nitem] list of item names
    #  exclude: collection of names we want to exclude
    #           (immutable default, so it cannot be modified across calls)
    names = np.array(names)
    # boolean mask selecting the rows / labels that are kept
    include = np.array([myname not in exclude for myname in names], dtype=bool)
    # single-linkage clustering; optimal_ordering reorders leaves so neighbours are similar
    linked = linkage(X[include], 'single', optimal_ordering=True)
    plt.figure(1, figsize=(20,6))
    dendrogram(linked, labels=names[include], color_threshold=0, leaf_font_size=18)
    plt.show()

# Cluster the learned input embeddings, one leaf per vocabulary token
plot_dendo(X=rnn.get_embeddings(), names=unique_tokens)

### Generate Sentences ###
def generate(rnn, maxlen=4):
    # Sample one sentence from the SRN, token by token, and print it.
    #
    # Generation starts from the <SOS> token; each sampled token is fed back in as the
    # next input. Sampling stops once <EOS> has been emitted or after maxlen tokens.
    #
    # Input
    #   rnn : instance of SRN class
    #   maxlen : maximum number of tokens to sample
    # Output
    #   None (the space-joined sampled tokens are printed)
    hidden_prev = rnn.initHidden()
    sentence = []
    # look the start token up in the vocabulary rather than hard-coding its index
    word_index = torch.tensor(token_to_index['<SOS>'])
    with torch.no_grad():  # sampling only, no gradients needed
        for _ in range(maxlen):
            output_tens, hidden_prev = rnn(word_index, hidden_prev)
            # the SRN outputs log-probabilities, so they must be passed as logits
            # (the first positional argument of Categorical is `probs`)
            cat = torch.distributions.categorical.Categorical(logits=output_tens)
            word_index = cat.sample()
            token = index_to_token[word_index.item()]
            sentence.append(token)
            if token == '<EOS>':
                # <EOS> is never an input during training, so do not continue past it
                break
    print(' '.join(sentence))
    
# Print ten independently sampled sentences
for _ in range(10):
    generate(rnn)