In [1]:
from data_rnn import load_ndfa
import matplotlib.pyplot as plt
import numpy as np
from time import time

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributions as dist

In [None]:
n = 150_000  # number of sequences requested from load_ndfa
x_train, (i2w, w2i) = load_ndfa(n)

In [None]:
# Show both directions of the vocabulary mapping (token -> id, id -> token).
for label, mapping in (('Dictionary', w2i), ('Index', i2w)):
    print(f'{label}:{mapping}')

In [None]:
def print_sequence(x_train, i, vocab=None):
    """Print sequence number `i` of `x_train`, decoded back into its tokens.

    :param x_train: indexable collection of token-id sequences.
    :param i: index of the sequence to print (also shown, right-aligned, in the label).
    :param vocab: id -> token lookup; defaults to the notebook-level `i2w`.
    """
    # Resolve the global lazily so callers passing `vocab` don't need `i2w` to exist.
    id_to_token = i2w if vocab is None else vocab
    # Use a distinct name for the token id so it no longer shadows the sequence index `i`.
    decoded = ''.join(id_to_token[token_id] for token_id in x_train[i])
    print(f"Sequence #{str(i).rjust(6, ' ')}: {decoded}")

In [None]:
# Show ten randomly chosen training sequences.
random_ids = np.random.randint(n, size=10)
for i in random_ids:
    print_sequence(x_train, i)

In [None]:
# Display one raw (token-id) training sequence. The index is hard-coded, so this
# cell requires n > 74191 (n is set in the data-loading cell).
x_train[74191]

In [None]:
def batch_length(batch):
    """Return the length of the longest sequence in `batch` (ValueError if `batch` is empty)."""
    return max(map(len, batch))

In [None]:
def add_padding(seq, amt=1):
    """Append `amt` `.pad` token ids to `seq` in place and return the same list.

    A non-positive `amt` leaves `seq` unchanged.
    """
    seq.extend(w2i['.pad'] for _ in range(amt))
    return seq

In [None]:
def add_start(seq):
    """Prepend the `.start` token id to `seq` in place and return the same list."""
    start_token = w2i['.start']
    seq.insert(0, start_token)
    return seq

In [None]:
def add_end(seq):
    """Append the `.end` token id to `seq` in place and return the same list."""
    end_token = w2i['.end']
    seq.append(end_token)
    return seq

In [None]:
def preprocess_batch(batch):
    """Turn a batch of token-id sequences into padded (inputs, targets) tensors.

    Every sequence becomes ``.start + seq + .end``, right-padded with ``.pad`` up to
    ``max_len + 2`` tokens, where ``max_len`` is the longest sequence in the batch.
    ``targets`` is ``inputs`` shifted one step left with a trailing ``.pad``, so
    ``targets[:, t]`` is the token that follows ``inputs[:, t]``.

    :param batch: non-empty sequence of token-id sequences (lists, tuples or 1-D
        arrays). The caller's sequences are not modified.
    :return: ``(inputs, targets)``, both ``torch.long`` tensors of shape
        ``(len(batch), max_len + 2)``.
    :raises ValueError: if `batch` is empty (raised by `batch_length`).
    """
    pad_id = w2i['.pad']
    max_len = batch_length(batch)

    rows = []
    for raw_seq in batch:
        # list() copies, so the helpers' in-place edits never touch `batch`.
        seq = add_start(list(raw_seq))
        seq = add_end(seq)
        seq = add_padding(seq, amt=max_len + 2 - len(seq))
        rows.append(seq)

    inputs = torch.tensor(rows, dtype=torch.long)
    # Shift left by one and fill the freed last column with the real `.pad` id
    # (a literal 0 would only be right if `.pad` happened to map to 0).
    targets = F.pad(inputs[:, 1:], (0, 1), value=pad_id)
    return inputs, targets

In [None]:
# def batch_generator(data, max_number_of_tokens=128):
#     total_tokens = 0
#     for i, seq in enumerate(data):
#         batch_start = i
#         if (total_tokens + len(data) < max_number_of_tokens):
#             total_tokens += len(seq)
#             batch_end = i + len(data)
#         batch = data[batch_start:batch_end]
#         yield preprocess_batch(batch)
#         yield(batch)

In [None]:
def batch_generator(data, batch_size=128):
    """Yield shuffled, preprocessed ``(inputs, targets)`` batches covering `data` once.

    :param data: indexable collection of token-id sequences (may be ragged).
    :param batch_size: number of sequences per batch; the final batch may be smaller.
    :yield: the ``(inputs, targets)`` pair produced by `preprocess_batch` for each batch.
    """
    # Index the original container rather than converting it with np.array: ragged
    # lists are rejected by modern numpy, and equal-length lists would become ndarray
    # rows that lack the list methods the preprocessing helpers rely on.
    order = np.random.permutation(len(data))
    # Stepping by batch_size never produces an empty trailing batch, even when
    # len(data) is an exact multiple of batch_size, and yields nothing for empty data.
    for start in range(0, len(data), batch_size):
        batch = [data[j] for j in order[start:start + batch_size]]
        yield preprocess_batch(batch)


In [None]:
# Preprocess the first five sequences and display the (inputs, targets) tensors;
# each target row should equal its input row shifted left by one token.
batch_t, target_t = preprocess_batch(x_train[:5])
batch_t, target_t

In [None]:
# Sanity check: run the generator over five sequences and print every batch.
for i, (x_batch, y_batch) in enumerate(batch_generator(x_train[:5])):
    print(i, x_batch, y_batch, sep='\n')

In [None]:
vocab_size = len(w2i)  # number of distinct tokens, special .pad/.start/.end markers included

In [None]:
# Model hyper-parameters.
embedding_size = 32   # dimension of each token embedding
hidden_size = 16      # LSTM hidden-state size
lstm_num_layers = 1   # number of stacked LSTM layers

In [None]:
class Net(nn.Module):
    """Token-level LSTM language model: embedding -> LSTM -> per-step linear head.

    Maps token ids of shape ``(batch, time)`` to unnormalised next-token logits of
    shape ``(batch, time, vocab_size)``.
    """

    def __init__(self,
                 vocab_size,
                 embedding_size,
                 hidden_size,
                 lstm_num_layers) -> None:
        super().__init__()

        self.embed = nn.Embedding(vocab_size, embedding_size)
        # batch_first=True: the LSTM consumes and returns (batch, time, features).
        self.lstm = nn.LSTM(embedding_size, hidden_size, lstm_num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, vocab_size)

    # `forward` must be a method of the class (not nested inside __init__), otherwise
    # nn.Module.__call__ falls back to the base forward and raises NotImplementedError.
    def forward(self, x):
        """Return raw logits ``(batch, time, vocab_size)`` for token ids ``x`` ``(batch, time)``.

        No softmax is applied: ``nn.CrossEntropyLoss`` and ``sample`` both take logits.
        """
        embedded = self.embed(x)
        lstm_output, _ = self.lstm(embedded)  # final (h_n, c_n) states are not needed
        return self.linear(lstm_output)

In [None]:
net = Net(
    vocab_size=vocab_size,
    embedding_size=embedding_size,
    hidden_size=hidden_size,
    lstm_num_layers=lstm_num_layers,
)

In [None]:
# Display the module summary (its embedding, LSTM and linear layers).
net


In [None]:
def sample(lnprobs, temperature=1.0):
    """
    Draw one index from the categorical distribution described by `lnprobs`.

    :param lnprobs: unnormalised log-probabilities (logits); expected to be 1-D,
        since the softmax is taken over dim 0.
    :param temperature: the logits are divided by this before the softmax. 1.0 samples
        the distribution as given; exactly 0.0 short-circuits to the arg-max (greedy).
    :return: the index of the chosen element, as a 0-dim tensor.
    """
    greedy = temperature == 0.0
    if not greedy:
        probabilities = F.softmax(lnprobs / temperature, dim=0)
        return dist.Categorical(probabilities).sample()
    return lnprobs.argmax()


In [None]:
def predict(dataset, model, seq, temperature=1.0, max_length=20):
    """Autoregressively extend a token sequence with `model`.

    :param dataset: vocabulary source. Either an object exposing `.i2w` and `.w2i`,
        or an ``(i2w, w2i)`` tuple in the form returned by `load_ndfa`.
    :param model: the model we sample from; maps token ids ``(1, time)`` to logits
        ``(1, time, vocab)``.
    :param seq: the sequence of tokens (strings) we want to complete; not modified.
    :param temperature: sampling temperature, forwarded to `sample`.
    :param max_length: we stop if we sample an `.end` token, or after max_length tokens.
    :return: the generated tokens (strings), including the final `.end` if one was sampled.
    """
    if hasattr(dataset, 'i2w') and hasattr(dataset, 'w2i'):
        id_to_token, token_to_id = dataset.i2w, dataset.w2i
    else:
        id_to_token, token_to_id = dataset

    model.eval()
    # Build inputs on whatever device the model lives on (cpu / cuda / mps).
    model_device = next(model.parameters()).device
    # Encode the whole prompt; this is a fresh list, so the caller's `seq` is untouched.
    context = [token_to_id[token] for token in seq]
    pred = []
    with torch.no_grad():
        for _ in range(max_length):
            x = torch.tensor([context], dtype=torch.long, device=model_device)
            logits = model(x)
            last_token_logits = logits[0, -1]
            j = int(sample(last_token_logits, temperature))
            token = id_to_token[j]
            pred.append(token)
            if token == '.end':
                break
            # Feed the sampled token back in so the next step conditions on it.
            context.append(j)
    return pred

In [None]:
# Prefer CUDA, then Apple-silicon MPS, then CPU. `torch.has_mps` is deprecated;
# torch.backends.mps.is_available() is the supported availability check.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
net.to(device)
print(f"Using {device} device")

In [None]:
def train(model, dataset, epochs=3, batch_size=128, learning_rate=0.001, log_every=20):
    """Train a next-token language model on `dataset` with Adam and cross-entropy.

    After every epoch a short sample is generated from the prompt ``.start a b``
    and printed.

    :param model: module mapping token ids ``(batch, time)`` to logits
        ``(batch, time, vocab)``, such as `Net`. Batches are moved to the device the
        model's parameters live on.
    :param dataset: token-id sequences (e.g. `x_train`), batched with `batch_generator`.
    :param epochs: number of passes over `dataset`.
    :param batch_size: sequences per batch.
    :param learning_rate: Adam learning rate.
    :param log_every: print and record the running loss every this many batches.
    :return: dict with ``'loss_history'`` (mean loss of each `log_every`-batch window)
        and ``'loss_train'`` (mean batch loss per epoch).
    """
    import types

    model_device = next(model.parameters()).device
    # CrossEntropyLoss applies log-softmax itself, so the model must return raw logits.
    # Padding positions carry no signal, so they are excluded from the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=w2i['.pad'])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # `predict` reads `.i2w` / `.w2i` from its first argument.
    vocab = types.SimpleNamespace(i2w=i2w, w2i=w2i)

    metrics = {
        'loss_history': [],
        'loss_train': []
    }

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}\n-------------------------------")
        start_time = time()
        running_loss = 0.0
        total_loss = 0.0
        n_batches = 0

        model.train()
        for batch_idx, (inputs, targets) in enumerate(batch_generator(dataset, batch_size=batch_size)):
            inputs, targets = inputs.to(model_device), targets.to(model_device)

            logits = model(inputs)  # (batch, time, vocab)
            # CrossEntropyLoss expects the class dimension second: (batch, vocab, time).
            loss = criterion(logits.transpose(1, 2), targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss
            total_loss += batch_loss
            n_batches += 1
            if batch_idx % log_every == log_every - 1:  # every `log_every` mini-batches
                elapsed = time() - start_time
                print(f'[{epoch + 1}, {batch_idx + 1:5d}] loss: {running_loss / log_every:.3f} time: {elapsed:.3f}')
                metrics['loss_history'].append(running_loss / log_every)
                running_loss = 0.0
        # Average of per-batch mean losses (not divided by the number of sequences).
        metrics['loss_train'].append(total_loss / max(n_batches, 1))

        print("Predicting:")
        model.eval()
        seq = ['.start', 'a', 'b']
        print(predict(vocab, model, seq, max_length=20))

    return metrics