In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from Data.Library import Library
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
class RNN(nn.Module):
    """Simple additive recurrent language model.

    At every position the current token's embedding is added to the previous
    hidden state, then passed through a linear layer and a ReLU. The hidden
    states of all positions are afterwards fed through a feed-forward layer,
    a ReLU, an output projection of size ``vocab_size`` and a log-softmax.

    Args:
        vocab_size: Number of rows in the embedding table and size of the
            output distribution.
        d_model: Width of the embeddings and of the recurrent hidden state.
        d_internal: Width of the feed-forward layer.
        device: Device that holds all parameters and runs the computation.
    """

    def __init__(self, vocab_size, d_model, d_internal, device = torch.device('cpu')):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_internal = d_internal
        self.device = device

        # Define model shape
        self.embeddings = nn.Embedding(self.vocab_size, self.d_model)
        self.recurrent_layer = nn.Linear(self.d_model, self.d_model)
        self.recurrent_activation = nn.ReLU()
        self.feed_forward = nn.Linear(self.d_model, self.d_internal)
        self.feed_forward_activation = nn.ReLU()
        self.output_layer = nn.Linear(self.d_internal, self.vocab_size)
        self.log_softmax = nn.LogSoftmax(dim=-1)

        # Move every submodule at once so the embedding table lives on the
        # same device as the rest of the parameters.
        self.to(self.device)

    def forward(self, sequence):
        """Compute per-position log-probabilities over the vocabulary.

        Args:
            sequence: Integer token ids of shape ``(batch_size, seq_length)``.

        Returns:
            Tensor of shape ``(batch_size, vocab_size, seq_length)`` holding
            log-probabilities, moved to the CPU.
        """
        batch_size, seq_length = sequence.shape
        # Transfer the ids once; the whole recurrence then stays on-device
        # instead of bouncing the hidden state to the CPU every timestep.
        sequence = sequence.to(self.device)
        hidden = torch.zeros([batch_size, self.d_model], device=self.device)
        hidden_states = []
        for idx in range(seq_length):
            # Recurrent step: add the current token embedding to the previous
            # recurrent output, then apply the recurrent layer and activation.
            hidden = self.embeddings(sequence[:, idx]) + hidden
            hidden = self.recurrent_layer(hidden)
            hidden = self.recurrent_activation(hidden)
            hidden_states.append(hidden)

        if hidden_states:
            xs = torch.stack(hidden_states, dim=1)  # (batch, seq, d_model)
        else:
            # torch.stack rejects an empty list; keep the empty-sequence shape.
            xs = torch.zeros([batch_size, 0, self.d_model], device=self.device)

        # Construct output
        xs = self.feed_forward(xs)
        xs = self.feed_forward_activation(xs)
        xs = self.output_layer(xs)
        xs = self.log_softmax(xs)
        # Class dimension goes second, which is the layout nn.NLLLoss expects
        # for sequence targets; the result is handed back on the CPU.
        return xs.permute(0, 2, 1).to(torch.device('cpu'))

In [None]:
# Hyperparams
epochs = 16
lr = .0001
seq_length=512
batch_size= 512
d_model = 256
d_internal = 512
train_size = 2**12


# Setup
# Prefer Apple's Metal backend, but fall back to the CPU so the cell still
# runs on machines where MPS is not available.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
library = Library(encoding = 27, train_size = train_size, streaming=False)
model = RNN(
    # NOTE(review): if max_token_value is the largest valid token id (rather
    # than the vocabulary size), this should be max_token_value + 1, otherwise
    # the largest id indexes past the embedding table -- confirm against Library.
    vocab_size=library.encoding.max_token_value,
    d_model = d_model,
    d_internal = d_internal,
    device=device
    )
loss_fn = nn.NLLLoss()
optim = torch.optim.Adam(model.parameters(), lr=lr)

# Token-id buffers, filled one sequence per row. Allocated as integers so no
# float round trip / per-step cast is needed before the embedding lookup.
x_batch = torch.zeros([batch_size, seq_length], dtype=torch.long)
y_batch = torch.zeros([batch_size, seq_length], dtype=torch.long)
losses = torch.zeros(epochs)        # summed batch loss per epoch
perplexities = torch.zeros(epochs)  # value reported by the library per epoch
# Training
for epoch in range(epochs):
    dataloader = library.get_train_dataloader(seq_length+1)
    for idx, data in enumerate(dataloader):
        # Presumably `data` is one sequence of seq_length+1 token ids -- TODO confirm.
        # Inputs are all but the last token, targets are shifted by one.
        mod_idx = idx % batch_size
        x_batch[mod_idx] = data[:-1]
        y_batch[mod_idx] = data[1:]
        # A step is taken only once the buffers are full; a trailing partial
        # batch at the end of the dataloader is not used.
        if mod_idx == batch_size-1:
            # Update weights
            optim.zero_grad()
            y_pred = model(x_batch)
            loss = loss_fn(y_pred, y_batch)
            loss.backward()
            optim.step()
            # Accumulate a plain float: adding the loss tensor itself would
            # attach `losses` to the autograd graph of every step.
            batch_loss = loss.item()
            losses[epoch] += batch_loss
            print(f'loss:{batch_loss:.4f}')
    # Test
    perplexities[epoch] = library.calc_perplexity(model)
    print(f'perplexity:{perplexities[epoch]:.4f}')