In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from Data.Library import Library
import matplotlib.pyplot as plt
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class RNN(nn.Module):
    """Token-level recurrent language model built from plain linear layers.

    At every time step the embedding of the current token is added to the
    previous hidden state and the sum is passed through
    ``fc1 -> ReLU -> fc2 -> ReLU`` to obtain the next hidden state. Every
    hidden state is then projected onto the vocabulary and normalised with a
    log-softmax.

    Args:
        vocab_size: Number of rows in the embedding table and width of the
            output layer.
        d_model: Size of the token embeddings and of the hidden state.
        d_internal: Width of the intermediate layer inside the recurrent step.
        device: Device on which every parameter of the model is placed.
    """

    def __init__(self, vocab_size, d_model, d_internal, device = torch.device('cpu')):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.d_internal = d_internal
        self.device = device

        # Define model shape.
        # The layers are created in the same order as before so that a fixed
        # random seed still produces the same initial weights. All of them,
        # including the embedding table, live on the same device.
        self.embeddings = nn.Embedding(self.vocab_size, self.d_model).to(self.device)
        self.fc1 = nn.Linear(self.d_model, self.d_internal).to(self.device)
        self.fc2 = nn.Linear(self.d_internal, self.d_model).to(self.device)
        self.output_layer = nn.Linear(self.d_model, self.vocab_size).to(self.device)
        self.log_softmax = nn.LogSoftmax(dim=-1)  # parameter-free, no device needed

    def forward(self, sequence):
        """Compute next-token log-probabilities for every position.

        Args:
            sequence: Integer tensor of token ids with shape
                ``(batch_size, seq_length)``. It may live on any device.

        Returns:
            Tensor of log-probabilities with shape
            ``(batch_size, vocab_size, seq_length)`` (class dimension second,
            the layout ``nn.NLLLoss`` takes for sequence targets), placed on
            the same device as ``sequence``.
        """
        batch_size, seq_length = sequence.shape

        # Run on wherever the weights currently are, so the model keeps working
        # after a later ``model.to(...)``.
        weights = self.embeddings.weight
        tokens = sequence.to(weights.device)

        hidden = torch.zeros(
            batch_size, self.d_model, dtype=weights.dtype, device=weights.device
        )
        hidden_states = torch.zeros(
            batch_size, seq_length, self.d_model, dtype=weights.dtype, device=weights.device
        )
        for idx in range(seq_length):
            # Recurrent step: add current token embedding to previous state,
            # then apply the two-layer recurrent block. The state stays on the
            # compute device, so there is no host<->device copy per token.
            hidden = self.embeddings(tokens[:, idx]) + hidden
            hidden = F.relu(self.fc1(hidden))
            hidden = F.relu(self.fc2(hidden))
            hidden_states[:, idx, :] = hidden

        # Construct output
        log_probs = self.log_softmax(self.output_layer(hidden_states))
        # Hand the result back on the caller's device (CPU inputs get CPU outputs).
        return log_probs.permute(0, 2, 1).to(sequence.device)
            

In [None]:
# Hyperparams
epochs = 128
lr = .001
seq_length=1024
batch_size= 32
d_model = 256
d_internal = 512
train_size = 2**20
encoding=1000
torch.manual_seed(0)

# Setup
# Use the GPU when there is one, otherwise fall back to the CPU so the cell
# still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
library = Library(encoding = encoding, train_size = train_size, streaming=False)

model = RNN(
    # NOTE(review): if `max_token_value` is the largest valid token id, the
    # model needs `max_token_value + 1` vocabulary entries - confirm against
    # the Library / tokenizer implementation.
    vocab_size=library.encoding.max_token_value,
    d_model = d_model,
    d_internal = d_internal,
    device=device
    )
loss_fn = nn.NLLLoss()
optim = torch.optim.Adam(model.parameters(), lr=lr)

# Token ids are integers: store the batch buffers as int64 directly instead of
# as floats that have to be converted with `.long()` on every step.
x_batch = torch.zeros([batch_size, seq_length-1], dtype=torch.long)
y_batch = torch.zeros([batch_size, seq_length-1], dtype=torch.long)
losses = torch.zeros(epochs)        # per-epoch sum of the batch losses
perplexities = torch.zeros(epochs)  # per-epoch value returned by calc_perplexity
print('Training')
# Training
for epoch in range(epochs):
    dataloader = library.get_train_dataloader(seq_length)
    # The dataloader yields one sequence at a time; a batch is assembled row by
    # row and a weight update happens once all `batch_size` rows are filled.
    for idx, data in enumerate(dataloader):
        mod_idx = idx % batch_size
        if data.shape[0] != seq_length:
            break # End of usable dataloader
        x_batch[mod_idx] = data[:-1]  # inputs: every token but the last
        y_batch[mod_idx] = data[1:]   # targets: the same tokens shifted by one
        if mod_idx == batch_size-1:
            # Update weights
            optim.zero_grad()
            y_pred = model(x_batch)
            loss = loss_fn(y_pred, y_batch)
            # Accumulate a plain Python number. Adding the loss tensor itself
            # would attach every batch's autograd history to `losses`.
            losses[epoch] += loss.item()
            print(f'{epoch}:{idx+1}:{losses[epoch]}', end='\r')
            loss.backward()
            optim.step()
    # Test
    # The evaluation result is only stored and printed, never back-propagated,
    # so no autograd graph is needed here.
    with torch.no_grad():
        perplexities[epoch] = library.calc_perplexity(model)
    torch.save(model.state_dict(), f'Models/{encoding}.pkl')
    print(f'\n{epoch}:{perplexities[epoch]:.4f}')

Training
0:22688:2.3323722016905337e+20
0:293.0223
1:22688:3770.5834960937555
1:257.8623
3:22688:3674.4191894531255
3:236.0574
4:22688:3654.1682128906255
4:231.1693
5:22688:3639.7399902343755
5:227.8131
6:22688:3628.6945800781255
6:225.1880
7:22688:3619.5678710937555
7:223.1555
8:22688:3612.0295410156255
8:221.6118
9:1568:248.16328430175786

In [None]:
def shannon(model, encoder, length=100):
    """Sample text from ``model`` one token at a time.

    Generation starts from the encoding of ``'['``. At each step the model is
    run on everything generated so far, the distribution at the last position
    is turned into probabilities and one token is drawn from it.

    Args:
        model: Callable mapping a ``(batch, seq)`` tensor of token ids to
            log-probabilities of shape ``(batch, vocab, seq)``.
        encoder: Tokenizer with ``encode(str)`` and ``decode(ids)``.
            ``decode`` receives a flat list of ints.
        length: Number of tokens to sample after the start token.

    Returns:
        Whatever ``encoder.decode`` returns for the start token(s) followed by
        the ``length`` sampled tokens.
    """
    tokens = torch.LongTensor(encoder.encode('[')).unsqueeze(0)  # (1, n_start)
    # Sampling never back-propagates, so skip building an autograd graph.
    with torch.no_grad():
        for _ in range(length):
            # Probabilities for the token following the last position.
            probs = torch.exp(model(tokens))[0, :, -1]
            next_token = torch.distributions.categorical.Categorical(probs=probs).sample()
            # Append along the sequence axis. The default dim=0 would stack the
            # sample as a new batch row, so the model would never see it as
            # context.
            next_token = next_token.view(1, 1).to(tokens.device)
            tokens = torch.cat((tokens, next_token), dim=1)
    return encoder.decode(tokens[0].tolist())
# Sample text from the checkpoint written by the training cell.
# The tokenizer has to match the vocabulary the model was trained on, so the
# Library is built from the same `encoding` value that names the checkpoint
# file (previously a hard-coded '50k', which can disagree with `encoding`).
library = Library(encoding = encoding, train_size = train_size, streaming=False)
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a machine
# without CUDA; load_state_dict then copies the values onto the model's device.
model.load_state_dict(torch.load( f'Models/{encoding}.pkl', map_location='cpu', weights_only=True))
shannon(model, library.encoding)

In [8]:
def get_n_params(model):
    """Return the total number of scalar parameters in ``model``.

    Args:
        model: Object exposing ``parameters()`` that yields tensors
            (e.g. an ``nn.Module``).

    Returns:
        int: Sum of the element counts of all parameters; ``0`` when the model
        has no parameters.
    """
    # `numel()` is the product of a tensor's dimensions, which the previous
    # hand-written loop computed through a local that shadowed the `nn` alias.
    return sum(param.numel() for param in model.parameters())
# Display the total number of scalar parameters of the object currently bound to `model`.
get_n_params(model)

388712

In [16]:
import numpy as np
# Scratch calculation: e raised to the power (4.337 / 127).
# NOTE(review): the notebook does not say what 4.337 and 127 stand for -
# name these constants or explain where they come from.
np.exp(4.337/127)

1.0347393986809896