In [1]:
import numpy as np
import torch
from torch import nn
import functorch
import sys
import os
import math
sys.path.append(os.path.abspath("../../data"))

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 8
SP_VOCAB_SIZE = 1000
VOCAB_SIZE=1002
CHUNK_LENGTH = 12

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from cnn_data import generate_data, decode_batch

train, valid, test = generate_data(train_length=1000, valid_length=250, vocab_size=SP_VOCAB_SIZE, batch_size=BATCH_SIZE, chunk_length=CHUNK_LENGTH, device=DEVICE)

Found cached dataset cnn_dailymail (/Users/vik/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
100%|██████████| 3/3 [00:00<00:00, 155.77it/s]
sentencepiece_trainer.cc(177) LOG(INFO) Running command: --input=tokens.txt --model_prefix=cnn --vocab_size=1000 --model_type=unigram
sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : 
trainer_spec {
  input: tokens.txt
  input_format: 
  model_prefix: cnn
  model_type: UNIGRAM
  vocab_size: 1000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  treat_whitespace_as_suffix: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_

In [46]:
class GRUEncoder(nn.Module):
    def __init__(self, input_units, hidden_units, output_units):
        super(GRUEncoder, self).__init__()
        self.input_units = input_units
        self.hidden_units = hidden_units

        k = 1/math.sqrt(hidden_units)
        self.input_weights = nn.Parameter(torch.rand(3, input_units, hidden_units) * 2 * k - k)

        self.hidden_weights = nn.Parameter(torch.rand(3, hidden_units, hidden_units) * 2 * k - k)
        self.hidden_biases = nn.Parameter(torch.rand(2, 1, hidden_units) * 2 * k - k)

        self.output_weight = nn.Parameter(torch.rand(hidden_units, output_units) * 2 * k - k)
        self.output_bias = nn.Parameter(torch.rand(1, output_units) * 2 * k - k)

    def forward(self, x, prev_hidden):
        # Compute the regular RNN forward pass

        # Compute update and reset gates for GRU
        update_gate = torch.sigmoid(x @ self.input_weights[0,] + prev_hidden @ self.hidden_weights[0,] + self.hidden_biases[0,])
        reset_gate = torch.sigmoid(x @ self.input_weights[1,] + prev_hidden @ self.hidden_weights[1,] + self.hidden_biases[1,])

        # This is a potential new state based on the reset gate
        proposed_state = torch.tanh(x @ self.input_weights[2,] + (prev_hidden * reset_gate) @ self.hidden_weights[2,])
        hidden_x = torch.tanh((1-update_gate) * prev_hidden + update_gate * proposed_state)
        # Compute output vector
        output_y = hidden_x @ self.output_weight + self.output_bias
        return hidden_x, output_y

In [47]:
class GRUDecoder(nn.Module):
    def __init__(self, input_units, context_units, hidden_units, output_units):
        super(GRUDecoder, self).__init__()
        self.input_units = input_units
        self.hidden_units = hidden_units
        self.output_units = output_units

        k = 1/math.sqrt(hidden_units)

        self.input_weights = nn.Parameter(torch.rand(3, input_units, hidden_units) * 2 * k - k)
        self.hidden_weights = nn.Parameter(torch.rand(3, hidden_units, hidden_units) * 2 * k - k)

        self.context_weights = nn.Parameter(torch.rand(3, context_units, hidden_units) * 2 * k - k)

        self.hidden_biases = nn.Parameter(torch.rand(2, 1, hidden_units) * 2 * k - k)

        self.hidden_attention_weight = nn.Parameter(torch.rand(hidden_units, hidden_units) * 2 * k - k)
        self.context_attention_weight = nn.Parameter(torch.rand(context_units, hidden_units) * 2 * k - k)
        self.attention_weight = nn.Parameter(torch.rand(1, hidden_units) * 2 * k - k)

        self.output_weight = nn.Parameter(torch.rand(hidden_units, output_units) * 2 * k - k)
        self.output_bias = nn.Parameter(torch.rand(1, output_units) * 2 * k - k)

        self.batched_diag = functorch.vmap(torch.diag)

    def forward(self, prev_y, prev_hidden, context):
        # Compute attention between the encoder hidden states and the previous decoder hidden state

        # Swap batch and sequence
        batch_size = context.shape[1]

        # Swap axes so the first dimension of context_attn is batch
        context_attn = torch.bmm(context.swapaxes(0,1), self.context_attention_weight.unsqueeze(0).expand(batch_size,-1,-1))
        # Swap back since prev_hidden is by batch.  This makes the first dim of cross sequence
        cross = torch.tanh(context_attn.swapaxes(0,1) + prev_hidden @ self.hidden_attention_weight)
        # This will be of dimension batch, sequence_length, 1
        attention = torch.bmm(cross.swapaxes(0,1), self.attention_weight.T.unsqueeze(0).expand(batch_size, -1, -1))
        # Drop the last singleton dimension
        attention = attention.squeeze(2)
        # Softmax the predictions
        probs = torch.softmax(attention, 0)
        diagonalized_probs = self.batched_diag(probs)
        positional_contexts = torch.sum(torch.bmm(diagonalized_probs, context.swapaxes(0,1)), dim=1).reshape(batch_size, self.input_units)

        # Compute GRU update and reset gates + proposed state and final hidden state
        update_gate = torch.sigmoid(prev_y @ self.input_weights[0,] + prev_hidden @ self.hidden_weights[0,] + self.hidden_biases[0,] + positional_contexts @ self.context_weights[0,])
        reset_gate = torch.sigmoid(prev_y @ self.input_weights[1,] + prev_hidden @ self.hidden_weights[1,] + self.hidden_biases[1,] + positional_contexts @ self.context_weights[1,])

        proposed_state = torch.tanh(prev_y @ self.input_weights[2,] + (prev_hidden * reset_gate) @ self.hidden_weights[2,] + positional_contexts @ self.context_weights[2,])

        hidden_x = torch.tanh((1-update_gate) * prev_hidden + update_gate * proposed_state)
        # Compute output based on hidden state
        output_y = hidden_x @ self.output_weight + self.output_bias
        return hidden_x, output_y

In [48]:
class Network(nn.Module):
    def __init__(self, in_sequence_len, out_sequence_len, hidden_units=512, embedding_len=VOCAB_SIZE):
        super(Network, self).__init__()
        self.in_sequence_len = in_sequence_len
        self.out_sequence_len = out_sequence_len
        self.hidden_units = hidden_units
        self.embedding_len = embedding_len

        self.embedding = nn.Embedding(embedding_len, hidden_units)
        self.encoder = GRUEncoder(input_units=hidden_units, hidden_units=hidden_units, output_units=embedding_len)
        self.decoder = GRUDecoder(input_units=hidden_units, context_units=hidden_units, hidden_units=hidden_units, output_units=embedding_len)

    def forward(self, x, y):
        batch_size = x.shape[0]
        # Move batch to the second dimension, so sequence comes first
        y = y.swapaxes(0,1)
        # Embed the input sequence to reduce dimensionality
        embedded = self.embedding(x).swapaxes(0,1)

        # Encode the input sequence
        # Both tensors will have sequence then batch
        enc_hiddens = torch.zeros((1, batch_size, self.hidden_units), device=DEVICE)
        enc_outputs = torch.zeros((1, batch_size, self.embedding_len), device=DEVICE)
        for j in range(self.in_sequence_len):
            hidden, output = self.encoder(embedded[j,:,:], enc_hiddens[j,:,:])
            # Add first sequence axis
            hidden = hidden.unsqueeze(0)
            output = output.unsqueeze(0)
            enc_hiddens = torch.cat((enc_hiddens, hidden), dim=0)
            enc_outputs = torch.cat((enc_outputs, output), dim=0)

        # Decode to the output sequence
        # Pass in context
        context = enc_hiddens[1:,:,:]
        # Both tensors will have the first dimension be the sequence
        dec_hiddens = torch.zeros(1, batch_size, self.hidden_units, device=DEVICE)
        dec_outputs = torch.zeros((1, batch_size, self.embedding_len), device=DEVICE)
        for j in range(self.out_sequence_len):
            # Use either the actual previous y (from the input), or the generated y if the input sequence is shorter than the generation steps.
            prev_y = y[j,:,:] if y.shape[0] > j else torch.softmax(dec_outputs[j,:,:], dim=1)
            # Run embedding over previous y state
            prev_y = prev_y.argmax(dim=1).int()
            prev_y = self.embedding(prev_y)
            hidden, output = self.decoder(prev_y, dec_hiddens[j,:,:], context)
            # Add first sequence axis
            hidden = hidden.unsqueeze(0)
            output= output.unsqueeze(0)
            dec_hiddens = torch.cat((dec_hiddens, hidden), dim=0)
            dec_outputs = torch.cat((dec_outputs, output), dim=0)

        # Move batch back to axis 0
        out_hiddens = dec_hiddens[1:].swapaxes(0,1)
        out_output = dec_outputs[1:].swapaxes(0,1)
        return out_hiddens, out_output

def generate(sequence, target):
    _, pred = model(sequence, target[:,0,:].unsqueeze(1))
    prompts = decode_batch(sequence.cpu(), vocab_size=VOCAB_SIZE)
    texts = decode_batch(torch.argmax(pred, dim=2).cpu(), vocab_size=VOCAB_SIZE)
    correct_texts = decode_batch(torch.argmax(target, dim=2).cpu(), vocab_size=VOCAB_SIZE)

    displays = []
    for p, t, ct in zip(prompts, texts, correct_texts):
        displays.append(f"{p} | {ct} | {t}")
    return displays

In [53]:
from tqdm.auto import tqdm
model = Network(CHUNK_LENGTH, CHUNK_LENGTH, hidden_units=512).to(DEVICE)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

In [54]:
EPOCHS = 2000
DISPLAY_BATCHES = 8

for epoch in range(EPOCHS):
    # Run over the training examples
    train_loss = 0
    for batch, (sequence, target) in tqdm(enumerate(train)):
        optimizer.zero_grad()
        forced_target = target
        # Alternate use of teacher forcing vs feeding back own inputs
        if np.random.randint(2) == 0:
            forced_target = target[:,0,:].unsqueeze(1)
        hidden, pred = model(sequence, forced_target)
        loss = loss_fn(pred, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    if epoch % 10 == 0:
        # Show text generated from training prompt as an example
        # Don't feed in all of the train y sequences, just the first token
        # The other y tokens will be predicted by the model and fed back in
        sents = generate(sequence[:DISPLAY_BATCHES], target[:DISPLAY_BATCHES])
        for sent in sents:
            print(sent)

        # Compute validation loss.  Unless you have a lot of training data, the validation loss won't decrease.
        valid_loss = 0
        with torch.no_grad():
            for batch, (sequence, target) in enumerate(valid):
                # Only feed in the first token of the actual target
                hidden, pred = model(sequence, target[:,0,:].unsqueeze(1))
                loss = loss_fn(pred, target)
                valid_loss += loss.item()
        print(f"Epoch {epoch} train loss: {train_loss} valid loss: {valid_loss}")

125it [00:27,  4.54it/s]


NEW: Georgian president criticizes Gorbach | ev for "vindicating lies and | ------------
19 schoolgirls and two adult | s die in primary school dormitor | ssssssssssss
U.S. intelligence point | s to Pakistan agents involved in attack on Indian | ssssssssssss
Mother of murdered schoolboy Damilol | a Taylor dies of suspected heart | aaaaaaaaaaaa
NEW: Pope Benedict  |  ⁇ VI arrives in Washington for six | be be be be be be be be be be be be
Eight Florida teens to be tried as ad | ults in videotaped beating case | ulululululululululululul
Judge on Heather Mills: Le | vel of premarital wealth " | veveveveveveveveveveveve
President Harding's illegiti | mate daughter was conceived on couch | mmmmmmmmmmmm
Epoch 0 train loss: 208.55217321777343 valid loss: 237.77848327159882


125it [00:28,  4.36it/s]
125it [00:28,  4.42it/s]
125it [00:27,  4.47it/s]
125it [00:28,  4.46it/s]
125it [00:28,  4.44it/s]
125it [00:28,  4.43it/s]
125it [00:29,  4.29it/s]
125it [00:29,  4.18it/s]
125it [00:27,  4.49it/s]
125it [00:27,  4.48it/s]


NEW: Georgian president criticizes Gorbach | ev for "vindicating lies and | evevevevevevevevevevevev
19 schoolgirls and two adult | s die in primary school dormitor | ssssssssssss
U.S. intelligence point | s to Pakistan agents involved in attack on Indian | ssssssssssss
Mother of murdered schoolboy Damilol | a Taylor dies of suspected heart | aaaaaaaaaaaa
NEW: Pope Benedict  |  ⁇ VI arrives in Washington for six |  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇ 
Eight Florida teens to be tried as ad | ults in videotaped beating case | ulululululululululululul
Judge on Heather Mills: Le | vel of premarital wealth " | veveveveveveveveveveveve
President Harding's illegiti | mate daughter was conceived on couch | mmmmmmmmmmmm
Epoch 10 train loss: 148.43473364257812 valid loss: 234.2144821882248


125it [00:27,  4.48it/s]
125it [00:28,  4.35it/s]
125it [00:27,  4.58it/s]
125it [00:27,  4.55it/s]
125it [00:28,  4.45it/s]
125it [00:27,  4.48it/s]
125it [00:27,  4.53it/s]
125it [00:27,  4.55it/s]
125it [00:27,  4.55it/s]
125it [00:27,  4.55it/s]


NEW: Georgian president criticizes Gorbach | ev for "vindicating lies and | evevevevevevevevevevevev
19 schoolgirls and two adult | s die in primary school dormitor | ssssssssssss
U.S. intelligence point | s to Pakistan agents involved in attack on Indian | ssssssssssss
Mother of murdered schoolboy Damilol | a Taylor dies of suspected heart | aaaaaaaaaaaa
NEW: Pope Benedict  |  ⁇ VI arrives in Washington for six |  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇ 
Eight Florida teens to be tried as ad | ults in videotaped beating case | ulululululululululululul
Judge on Heather Mills: Le | vel of premarital wealth " | veveveveveveveveveveveve
President Harding's illegiti | mate daughter was conceived on couch | mmmmmmmmmmmm
Epoch 20 train loss: 115.63843781280518 valid loss: 229.0219978094101


125it [00:27,  4.54it/s]
125it [00:27,  4.56it/s]
125it [00:27,  4.57it/s]
125it [00:27,  4.49it/s]
125it [00:28,  4.45it/s]
125it [00:27,  4.51it/s]
125it [00:27,  4.53it/s]
125it [00:28,  4.46it/s]
125it [00:28,  4.41it/s]
125it [00:27,  4.51it/s]


NEW: Georgian president criticizes Gorbach | ev for "vindicating lies and | evevevevevevevevevevevev
19 schoolgirls and two adult | s die in primary school dormitor | ssssssssssss
U.S. intelligence point | s to Pakistan agents involved in attack on Indian | ssssssssssss
Mother of murdered schoolboy Damilol | a Taylor dies of suspected heart | aaaaaaaaaaaa
NEW: Pope Benedict  |  ⁇ VI arrives in Washington for six |  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇ 
Eight Florida teens to be tried as ad | ults in videotaped beating case | ulululululululululululul
Judge on Heather Mills: Le | vel of premarital wealth " | veveveveveveveveveveveve
President Harding's illegiti | mate daughter was conceived on couch | mmmmmmmmmmmm
Epoch 30 train loss: 110.46447555923461 valid loss: 226.1087828874588


125it [00:27,  4.49it/s]
125it [00:27,  4.51it/s]
125it [00:28,  4.44it/s]
125it [00:27,  4.54it/s]
125it [00:27,  4.50it/s]
36it [00:08,  4.33it/s]


KeyboardInterrupt: 