In [25]:
import json
import os
import random

import torch, torch.nn as nn
from torch import nn, optim, tensor
from torch.autograd import Variable
from torch.nn.utils.rnn import pad_sequence

# Load the preprocessed spoken-BNC sentences and build the vocabulary / id maps.
# The path can be overridden via the BNC_SPOKEN_PATH environment variable so the
# notebook can run on machines other than the original author's.
filepath = os.environ.get(
    "BNC_SPOKEN_PATH",
    "/Users/willi/OneDrive/Documents/GitHub/LING-111-Project/preprocessed_data/bnc_spoken_output_flattened.json",
)
N_SENTENCES = 1000  # only the first N_SENTENCES sentences are used (keeps training quick)

# Explicit UTF-8: the platform default encoding (e.g. cp1252 on Windows) can fail
# on non-ASCII characters in the corpus.
with open(filepath, encoding="utf-8") as infile:
    input_data = json.load(infile)
# NOTE(review): assumes the JSON is a list of sentences, each a list of word strings.
text = input_data[:N_SENTENCES]

seq_len = len(text)  # NOTE: this is the number of sentences, not a per-sentence length
batch_size = 1
embedding_size = 50  # NOTE(review): not referenced by the visible model code
hidden_size = 100

# Collect every word type in one pass (in-place update, not a new set per sentence).
vocab_set = set()
for sentence in text:
    vocab_set.update(sentence)
vocab_set.update(["<unk>", "<pad>"])

# Sorted so word ids are identical across kernel restarts: iteration order of a set
# of strings depends on the per-process hash seed. Assumes all tokens are strings.
vocab = sorted(vocab_set)
vocab_size = len(vocab)
output_size = vocab_size
word_ids = {word: id_ for id_, word in enumerate(vocab)}
ids_word = {id_: word for word, id_ in word_ids.items()}
padding_idx = word_ids["<pad>"]
max_len = max((len(sentence) for sentence in text), default=0)

unk_idx = word_ids["<unk>"]
indexed_sentences = [
    torch.tensor([word_ids.get(word, unk_idx) for word in sentence], dtype=torch.long)
    for sentence in text
]

# Shape [num_sentences, max_len], right-padded with the <pad> id.
text_tensor = pad_sequence(indexed_sentences, batch_first=True, padding_value=padding_idx)

In [26]:
class biRNNLM(nn.Module):
    """Bidirectional RNN language model that predicts each interior token from both sides.

    Words are embedded and run through a single-layer bidirectional RNN. The forward
    state at position t (which has seen tokens 0..t) is concatenated with the backward
    state at position t+2 (which has seen tokens t+2..T-1) to predict token t+1.
    Because neither direction has seen the target, the model cannot copy it.

    Args:
        vocab: Sequence of vocabulary words; a word's id is its index in ``vocab``.
            Should contain ``"<unk>"``, which out-of-vocabulary words are mapped to.
        hidden_size: Hidden units per RNN direction.
        freeze_embeddings: Currently ignored; the embedding is trained.
        recurrent_activation: RNN nonlinearity, ``"tanh"`` (default) or ``"relu"``.
        recurrent_layers: Currently ignored; the RNN always has one layer. In a stacked
            bidirectional RNN, upper-layer forward states see lower-layer backward states
            (i.e. future tokens), which would leak the target through the staggering.
        recurrent_bidirectional: Currently ignored; the staggering needs both directions,
            so the RNN is always bidirectional.
        embedding_dim: Size of the word embeddings (default 10).
    """

    def __init__(self, vocab, hidden_size, freeze_embeddings=True,
                 recurrent_activation="tanh", recurrent_layers=1, recurrent_bidirectional=False,
                 embedding_dim=10):
        super(biRNNLM, self).__init__()

        self.vocab = vocab
        self.hidden_size = hidden_size
        # Derive ids from the vocab passed in rather than from notebook globals.
        self.word_ids = {word: id_ for id_, word in enumerate(vocab)}
        self.unk_index = self.word_ids.get("<unk>")
        vocab_size = len(vocab)

        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Sequence-first input of shape [T, batch, embedding_dim].
        self.bi_rnn = nn.RNN(input_size=embedding_dim, hidden_size=hidden_size, num_layers=1,
                             nonlinearity=recurrent_activation, batch_first=False,
                             bidirectional=True)

        # Input is one forward state concatenated with one backward state.
        self.linear = nn.Linear(hidden_size * 2, vocab_size)

        self.loss_function = nn.CrossEntropyLoss()

    def forward(self, text=None, seq_lengths=None):
        """Return the mean cross-entropy of predicting a sentence's interior tokens.

        Args:
            text: Sequence of word strings for one sentence. Words missing from the
                vocabulary are mapped to ``"<unk>"``.
            seq_lengths: Unused; kept for backward compatibility.

        Returns:
            Scalar loss tensor averaged over the ``len(text) - 2`` predicted tokens
            (every token except the first and last).

        Raises:
            ValueError: If ``text`` is None or has fewer than 3 tokens. With fewer than
                3 tokens there is nothing to predict, and the loss would be NaN.
        """
        if text is None:
            raise ValueError("forward() requires a sentence (a sequence of words)")
        if len(text) < 3:
            raise ValueError(
                "sentence needs at least 3 tokens to predict anything, got {}".format(len(text)))

        # Shape [T, 1]: sequence-first with a batch of one.
        indices = torch.tensor(
            [self.word_ids.get(word, self.unk_index) for word in text],
            dtype=torch.long,
        ).unsqueeze(1)
        word_embeddings = self.embedding(indices)      # [T, 1, E]
        bi_output, _ = self.bi_rnn(word_embeddings)    # [T, 1, 2H]: forward | backward

        # Stagger: forward state at t with backward state at t+2 predicts token t+1.
        h = self.hidden_size
        forward_output = bi_output[:-2, :, :h]
        backward_output = bi_output[2:, :, h:]
        staggered_output = torch.cat((forward_output, backward_output), dim=-1)  # [T-2, 1, 2H]

        logits = self.linear(staggered_output.squeeze(1))  # [T-2, vocab_size]
        target = indices[1:-1, 0]                          # interior tokens only, [T-2]
        return self.loss_function(logits, target)

In [None]:
# Seed the RNGs so weight initialisation (and therefore the training curve) is
# reproducible under Restart & Run All.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

model = biRNNLM(vocab=vocab, hidden_size=hidden_size)
optimizer = optim.Adam(model.parameters())

def train(model=model, optimizer=optimizer, epochs=10, print_every=1,
          validation_data=None, training_data=None):
    """Train ``model`` one sentence per step and periodically log the mean training loss.

    Args:
        model: Module whose ``model(sentence)`` returns a scalar loss tensor.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Number of passes over the training sentences.
        print_every: Log every ``print_every`` epochs; ``None`` disables logging.
        validation_data: If given, passed to ``test(model, validation_data)``.
            NOTE(review): ``test`` is not defined in this cell; its result is
            expected to have a ``"loss"`` key.
        training_data: Sentences to train on; defaults to the notebook-level ``text``.
    """
    if training_data is None:
        training_data = text
    # The model predicts interior tokens only, so a sentence needs at least 3 tokens;
    # shorter ones would produce an empty (NaN) loss and poison the running average.
    min_tokens = 3
    sentences = [sentence for sentence in training_data if len(sentence) >= min_tokens]

    current_loss = 0.0
    steps_since_log = 0
    for epoch in range(epochs):
        model.train()  # in case validation left the model in eval mode
        for sentence in sentences:  # each sentence is one optimisation step
            model.zero_grad()

            loss = model(sentence)
            loss.backward()
            optimizer.step()

            if print_every is not None:
                current_loss += loss.item()
                steps_since_log += 1

        # Log performance
        if print_every is not None and (epoch + 1) % print_every == 0:
            # Average over every step since the last log (print_every epochs' worth),
            # not over a single epoch.
            mean_loss = current_loss / steps_since_log if steps_since_log else float("nan")
            log_message = ('| epoch {:3d} | train loss {:6.3f} |'
                           .format(epoch + 1, mean_loss))
            if validation_data is not None:
                validation_performance = test(model, validation_data)
                log_message += ' valid loss {loss:6.3f} |'.format(**validation_performance)
            print(log_message)

            # Reset trackers after logging
            current_loss = 0.0
            steps_since_log = 0
train(model, optimizer, epochs=10)  # ten passes over the training sentences

| epoch   1 | train loss  5.797 |
| epoch   2 | train loss  4.864 |
| epoch   3 | train loss  4.332 |
| epoch   4 | train loss  3.863 |
