In [22]:
import torch
from torch import nn
import sys
import os
import math
sys.path.append(os.path.abspath("../../data"))

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 1
SP_VOCAB_SIZE = 1000
TRAIN_SIZE = 500

In [23]:
from text_data import CNNDatasetWrapper

class Wrapper(CNNDatasetWrapper):
    split_lengths = [TRAIN_SIZE, math.floor(TRAIN_SIZE * .1), 100]
    x_length = 15
    target_length = 15

wrapper = Wrapper(SP_VOCAB_SIZE, DEVICE)

datasets = wrapper.generate_datasets(BATCH_SIZE)
train = datasets["train"]
valid = datasets["validation"]

Found cached dataset cnn_dailymail (/Users/vik/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
100%|██████████| 3/3 [00:00<00:00, 140.63it/s]
sentencepiece_trainer.cc(177) LOG(INFO) Running command: --input=tokens.txt --model_prefix=cnn_dailymail --vocab_size=1000 --model_type=unigram
sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : 
trainer_spec {
  input: tokens.txt
  input_format: 
  model_prefix: cnn_dailymail
  model_type: UNIGRAM
  vocab_size: 1000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  treat_whitespace_as_suffix: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_pie

In [24]:
class Encoder(nn.Module):
    def __init__(self, input_units, hidden_units, output_units):
        super(Encoder, self).__init__()
        self.input_units = input_units
        self.hidden_units = hidden_units

        k = math.sqrt(1/hidden_units)
        self.input_weight = nn.Parameter(torch.rand(input_units, hidden_units) * 2 * k - k)

        self.hidden_weight = nn.Parameter(torch.rand(hidden_units, hidden_units) * 2 * k - k)
        self.hidden_bias = nn.Parameter(torch.rand(1, hidden_units) * 2 * k - k)

        self.output_weight = nn.Parameter(torch.rand(hidden_units, output_units) * 2 * k - k)
        self.output_bias = nn.Parameter(torch.rand(1, output_units) * 2 * k - k)

    def forward(self, x, prev_hidden):
        # Compute the regular RNN forward pass
        # Input times weights
        input_x = x @ self.input_weight
        # Sum input with previous hidden state, and add nonlinearity
        # Tanh prevents gradients exploding
        hidden_x = torch.tanh(input_x + prev_hidden @ self.hidden_weight + self.hidden_bias)

        # Compute output vector
        output_y = hidden_x @ self.output_weight + self.output_bias
        return hidden_x, output_y

In [25]:
class Decoder(nn.Module):
    def __init__(self, input_units, hidden_units, output_units=wrapper.vocab_size):
        super(Decoder, self).__init__()
        self.input_units = input_units
        self.hidden_units = hidden_units
        self.output_units = output_units

        k = math.sqrt(1/hidden_units)
        self.hidden_attention_weight = nn.Parameter(torch.rand(hidden_units, hidden_units) * 2 * k - k)
        self.context_attention_weight = nn.Parameter(torch.rand(input_units, hidden_units) * 2 * k - k)
        self.attention_weight = nn.Parameter(torch.rand(1, hidden_units) * 2 * k - k)

        self.context_hidden_weight = nn.Parameter(torch.rand(hidden_units * 2, hidden_units) * 2 * k - k)
        self.hidden_weight = nn.Parameter(torch.rand(hidden_units, hidden_units) * 2 * k - k)
        self.hidden_bias = nn.Parameter(torch.rand(1, hidden_units) * 2 * k - k)

        self.output_weight = nn.Parameter(torch.rand(hidden_units, output_units) * 2 * k - k)
        self.output_bias = nn.Parameter(torch.rand(1, output_units) * 2 * k - k)

    def forward(self, prev_y, prev_hidden, context):
        # Compute attention between the encoder hidden states and the previous decoder hidden state
        context_attns = context @ self.context_attention_weight
        cross = torch.tanh(context_attns + prev_hidden @ self.hidden_attention_weight)
        attention = cross @ self.attention_weight.T

        # Compute probability for each encoder hidden state, and use it to weight and sum the states
        probs = torch.softmax(attention, 1).reshape(context.shape[0])
        positional_context = torch.sum(torch.diag(probs) @ context, dim=0).reshape(1, self.input_units)

        # Compute a regular rnn.  Cat the context vector and the previous y state.
        input_x = torch.cat([positional_context, prev_y], dim=1) @ self.context_hidden_weight
        hidden_x = torch.tanh(input_x + prev_hidden @ self.hidden_weight + self.hidden_bias)

        output_y = hidden_x @ self.output_weight + self.output_bias
        return hidden_x, output_y

In [26]:
class Network(nn.Module):
    def __init__(self, in_sequence_len, out_sequence_len, hidden_units=512, embedding_len=wrapper.vocab_size):
        super(Network, self).__init__()
        self.in_sequence_len = in_sequence_len
        self.out_sequence_len = out_sequence_len
        self.hidden_units = hidden_units
        self.embedding_len = embedding_len

        self.embedding = nn.Embedding(embedding_len, hidden_units)
        self.encoder = Encoder(input_units=hidden_units, hidden_units=hidden_units, output_units=embedding_len)

        self.decoder = Decoder(input_units=hidden_units, hidden_units=hidden_units, output_units=embedding_len)

    def forward(self, x, y):
        embedded = self.embedding(x)

        # Encode the input sequence
        enc_hiddens = torch.zeros((1, self.hidden_units))
        enc_outputs = torch.zeros((1, self.embedding_len))
        for j in range(self.in_sequence_len):
            hidden, output = self.encoder(embedded[j,:], enc_hiddens[j])
            enc_hiddens = torch.cat((enc_hiddens, hidden), dim=0)
            enc_outputs = torch.cat((enc_outputs, output), dim=0)

        # Decode to the output sequence
        # Pass in context
        context = enc_hiddens[1:,:]
        dec_hiddens = torch.zeros(1, self.hidden_units)
        dec_outputs = torch.zeros((1, self.embedding_len))
        for j in range(self.out_sequence_len):
            # Use either the actual previous y (from the input), or the generated y if the input sequence is shorter than the generation steps.
            if y.shape[0] > j:
                prev_y = y[j]
            else:
                prev_y = dec_outputs[j,:]
                prev_y = prev_y.argmax(dim=1).int()

            prev_y = prev_y.unsqueeze(0)
            prev_y = self.embedding(prev_y)
            hidden, output = self.decoder(prev_y, dec_hiddens[j,:], context)
            dec_hiddens = torch.cat((dec_hiddens, hidden), dim=0)
            dec_outputs = torch.cat((dec_outputs, output), dim=0)

        return dec_hiddens[1:], dec_outputs[1:]

In [27]:
from tqdm.auto import tqdm

device = torch.device("cpu")
model = Network(wrapper.x_length, wrapper.y_length, hidden_units=512).to(device)
loss_fn = nn.CrossEntropyLoss(ignore_index=wrapper.pad_token)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

In [28]:
EPOCHS = 1000
for epoch in range(EPOCHS):
    # Run over the training examples
    train_loss = 0
    for batch, (sequence, target, prev_target) in tqdm(enumerate(train)):
        optimizer.zero_grad()
        hidden, pred = model(sequence[0,:], prev_target[0,:])

        pred = pred.reshape(wrapper.y_length, wrapper.vocab_size)
        target = target.reshape(wrapper.y_length)
        loss = loss_fn(pred, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    if epoch % 10 == 0:
        # Compute validation loss.  Unless you have a lot of training data, it won't be able to generalize.
        valid_loss = 0
        with torch.no_grad():
            for batch, (sequence, target, prev_target) in enumerate(valid):
                # Only feed in the first token of the actual target
                hidden, pred = model(sequence[0,:], prev_target[0,:])
                pred = pred.reshape(wrapper.y_length, wrapper.vocab_size)
                target = target.reshape(wrapper.y_length)
                loss = loss_fn(pred, target)
                valid_loss += loss.item()

        print(f"Epoch {epoch} train loss: {train_loss / len(train)} valid loss: {valid_loss / len(valid)}")

448it [00:24, 18.45it/s]


Epoch 0 train loss: 5.999296477862766 valid loss: 5.748046959147734


448it [00:24, 18.51it/s]
448it [00:24, 18.50it/s]
448it [00:24, 18.47it/s]
448it [00:24, 18.28it/s]
448it [00:24, 18.25it/s]
448it [00:24, 18.42it/s]
448it [00:24, 18.51it/s]
448it [00:24, 18.62it/s]
448it [00:24, 18.57it/s]
448it [00:24, 18.28it/s]


Epoch 10 train loss: 3.1196700052491257 valid loss: 5.839749308193431


448it [00:24, 18.42it/s]
448it [00:23, 18.88it/s]
448it [00:22, 19.82it/s]
448it [00:22, 19.56it/s]
448it [00:22, 19.90it/s]
448it [00:22, 20.09it/s]
448it [00:22, 19.84it/s]
448it [00:23, 19.03it/s]
448it [00:27, 16.31it/s]
448it [00:27, 16.18it/s]


Epoch 20 train loss: 1.3020785356472646 valid loss: 6.145435006010766


448it [00:27, 16.51it/s]
448it [00:27, 16.55it/s]
448it [00:27, 16.25it/s]
448it [00:26, 16.72it/s]
448it [00:27, 16.35it/s]
448it [00:27, 16.28it/s]
448it [00:27, 16.28it/s]
448it [00:27, 16.25it/s]
448it [00:27, 16.27it/s]
448it [00:27, 16.30it/s]


Epoch 30 train loss: 0.5017049135806572 valid loss: 6.383245907577813


448it [00:27, 16.41it/s]
448it [00:27, 16.35it/s]
448it [00:27, 16.34it/s]
448it [00:27, 16.51it/s]
448it [00:27, 16.52it/s]
448it [00:27, 16.58it/s]
448it [00:27, 16.36it/s]
448it [00:27, 16.15it/s]
448it [00:27, 16.34it/s]
448it [00:26, 16.65it/s]


Epoch 40 train loss: 0.24164675797302543 valid loss: 6.541052986593807


448it [00:26, 16.65it/s]
448it [00:27, 16.53it/s]
448it [00:27, 16.49it/s]
448it [00:27, 16.30it/s]
448it [00:27, 16.39it/s]
448it [00:27, 16.34it/s]
448it [00:27, 16.30it/s]
448it [00:27, 16.19it/s]
448it [00:27, 16.23it/s]
448it [00:27, 16.18it/s]


Epoch 50 train loss: 0.14702079623072808 valid loss: 6.659904620226691


448it [00:27, 16.32it/s]
448it [00:27, 16.31it/s]
448it [00:27, 16.09it/s]
448it [00:27, 16.49it/s]
448it [00:27, 16.45it/s]
448it [00:27, 16.35it/s]
448it [00:27, 16.01it/s]
448it [00:28, 15.72it/s]
448it [00:28, 15.88it/s]
448it [00:28, 15.58it/s]


Epoch 60 train loss: 0.10320704718469642 valid loss: 6.745182869481106


448it [00:25, 17.71it/s]
448it [00:22, 19.52it/s]
448it [00:23, 19.48it/s]
448it [00:23, 19.46it/s]
448it [00:23, 19.32it/s]
448it [00:23, 19.27it/s]
448it [00:23, 19.39it/s]
448it [00:23, 19.28it/s]
448it [00:23, 19.35it/s]
448it [00:23, 19.36it/s]


Epoch 70 train loss: 0.07870655508512366 valid loss: 6.816830148883894


448it [00:23, 19.32it/s]
448it [00:23, 19.20it/s]
448it [00:23, 19.17it/s]
448it [00:23, 18.93it/s]
448it [00:23, 18.89it/s]
448it [00:23, 18.97it/s]
448it [00:23, 18.81it/s]
448it [00:24, 18.62it/s]
448it [00:23, 18.75it/s]
260it [00:13, 18.95it/s]


KeyboardInterrupt: 