In [None]:
#!pip install torch torchtext sentencepiece datasets
# Try opus books dataset for translation - https://huggingface.co/datasets/opus_books

In [8]:
import numpy as np
import torch
from torch import nn
import functorch
import sys
import os
import math
sys.path.append(os.path.abspath("../../data"))

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
SP_VOCAB_SIZE = 1000
CHUNK_LENGTH = 10
Y_CHUNK_LENGTH = 5
TRAIN_SIZE = 500

In [9]:
from text_data import CNNDatasetWrapper

class Wrapper(CNNDatasetWrapper):
    split_lengths = [TRAIN_SIZE, math.floor(TRAIN_SIZE * .1), 100]
    x_length = CHUNK_LENGTH
    target_length = Y_CHUNK_LENGTH

wrapper = Wrapper(SP_VOCAB_SIZE, DEVICE)

datasets = wrapper.generate_datasets(BATCH_SIZE)
train = datasets["train"]
valid = datasets["validation"]
test = datasets["test"]

Found cached dataset cnn_dailymail (/Users/vik/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
100%|██████████| 3/3 [00:00<00:00, 158.95it/s]
sentencepiece_trainer.cc(177) LOG(INFO) Running command: --input=tokens.txt --model_prefix=cnn_dailymail --vocab_size=1000 --model_type=unigram
sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : 
trainer_spec {
  input: tokens.txt
  input_format: 
  model_prefix: cnn_dailymail
  model_type: UNIGRAM
  vocab_size: 1000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  treat_whitespace_as_suffix: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_pie

In [60]:
class EncoderDecoder(nn.Module):
    def __init__(self, in_sequence_len, out_sequence_len, embedding_len, hidden_units=512, layers=2):
        super(EncoderDecoder, self).__init__()
        self.in_sequence_len = in_sequence_len
        self.out_sequence_len = out_sequence_len
        self.hidden_units = hidden_units
        self.embedding_len = embedding_len
        self.layers = layers

        self.embedding = nn.Embedding(embedding_len, hidden_units)
        self.encoder = nn.GRU(input_size=hidden_units, hidden_size=hidden_units, num_layers=layers)
        self.decoder = nn.GRU(input_size=hidden_units * 2, hidden_size=hidden_units, num_layers=layers)
        self.linear = nn.Linear(in_features=hidden_units, out_features=embedding_len)

    def forward(self, x, y):
        batch_size = x.shape[0]
        # Move batch to the second dimension, so sequence comes first
        y = y.swapaxes(0,1)
        # Embed the input sequence to reduce dimensionality
        embedded = self.embedding(x).swapaxes(0,1)

        # Encode the input sequence
        # Both tensors will have sequence then batch
        initial_hidden = torch.zeros((1 * self.layers, batch_size, self.hidden_units), device=DEVICE)
        enc_output, enc_hidden = self.encoder(embedded, initial_hidden)

        # Decode to the output sequence
        # Pass in context
        context = enc_output
        # Both tensors will have the first dimension be the sequence
        dec_hiddens = torch.zeros(1, batch_size, self.hidden_units, device=DEVICE)
        dec_outputs = torch.zeros((1, batch_size, self.hidden_units), device=DEVICE)
        for j in range(self.out_sequence_len):
            # Use either the actual previous y (from the input), or the generated y if the input sequence is shorter than the generation steps.
            prev_y = y[j,:,:] if y.shape[0] > j else torch.softmax(dec_outputs[j,:,:], dim=1)
            # Run embedding over previous y state
            prev_y = prev_y.argmax(dim=1).int()
            prev_y = self.embedding(prev_y)

            output, hidden = self.decoder(torch.cat((prev_y, context[-1]), dim=1).unsqueeze(0), dec_hiddens[j,].unsqueeze(0),)
            dec_hiddens = torch.cat((dec_hiddens, hidden), dim=0)
            dec_outputs = torch.cat((dec_outputs, output), dim=0)

        # Move batch back to axis 0
        out_hiddens = dec_hiddens[1:,:,:].swapaxes(0,1)
        out_output = self.linear(dec_outputs[1:,:,:].swapaxes(0,1))
        return out_output, out_hiddens

def generate(sequence, target, wrapper):
    pred, _ = model(sequence, target[:,0,:].unsqueeze(1))
    prompts = wrapper.decode_batch(sequence.cpu())
    texts = wrapper.decode_batch(torch.argmax(pred, dim=2).cpu())
    correct_texts = wrapper.decode_batch(torch.argmax(target, dim=2).cpu())

    displays = []
    for p, t, ct in zip(prompts, texts, correct_texts):
        displays.append(f"{p} | {ct} | {t}")
    return displays

In [61]:
from tqdm.auto import tqdm
model = EncoderDecoder(wrapper.x_length, wrapper.target_length, hidden_units=512, layers=1, embedding_len=wrapper.vocab_size).to(DEVICE)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [None]:
EPOCHS = 1000
DISPLAY_BATCHES = 8

for epoch in range(EPOCHS):
    # Run over the training examples
    train_loss = 0
    match_pct = 0
    for batch, (sequence, target, prev_target) in tqdm(enumerate(train)):
        optimizer.zero_grad()
        forced_target = prev_target
        # Alternate use of teacher forcing vs feeding back own inputs
        #if np.random.randint(2) == 0:
        #    forced_target = prev_target[:,0,:].unsqueeze(1)
        pred, hidden = model(sequence, forced_target)
        loss = loss_fn(pred, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        match_pct += torch.sum(torch.argmax(target, 2) == torch.argmax(pred, 2)) / (Y_CHUNK_LENGTH * BATCH_SIZE)

    with torch.no_grad():
        if epoch % 10 == 0:
            # Show text generated from training prompt as an example
            # Don't feed in all of the train y sequences, just the first token
            # The other y tokens will be predicted by the model and fed back in
            sents = generate(sequence[:DISPLAY_BATCHES], prev_target[:DISPLAY_BATCHES], wrapper)
            for sent in sents:
                print(sent)
            """
            # Compute validation loss.  Unless you have a lot of training data, the validation loss won't decrease.
            valid_loss = 0
            for batch, (sequence, target, prev_target) in enumerate(valid):
                # Only feed in the first token of the actual target
                pred, hidden = model(sequence, prev_target[:,0,:].unsqueeze(1))
                loss = loss_fn(pred, target)
                valid_loss += loss.item()
            print(f"Epoch {epoch} train loss: {train_loss} valid loss: {valid_loss}")
            """
            print(f"Epoch {epoch} train loss: {train_loss} match_pct: {match_pct / len(train)}")

16it [00:01, 12.82it/s]


Top seed Andy Roddick | reaches last | ver becomeci summ summ
The central part of Iowa | is the state' | verver refugee wantR
Humanitarian group says Shi | ite, S | ver cancer Brit schoolac
Sharon Long is forens | ic artist who | verver refugeedridless
NEW: Barack Obama makes f | un of Hillary | verver refugee Netanyahuless
Francesco Totti thre | atens to | verndan refugee Netanyahuoli
Pakistan rejects fear | s its nu | verver children seekd
NEW: British counter-terrorism | experts | ververndanarges Du
Epoch 0 train loss: 0.12685480900108814 match_pct: 0.0007812500116415322


16it [00:01, 13.01it/s]
16it [00:01, 12.95it/s]
16it [00:01, 12.76it/s]
16it [00:01, 12.92it/s]
16it [00:01, 12.59it/s]
16it [00:01, 12.70it/s]
16it [00:01, 12.75it/s]
16it [00:01, 12.90it/s]
16it [00:01, 12.90it/s]
16it [00:01, 12.75it/s]


Police said man lost his balance | on an es | li presidentbb to
I-Reporters share | tales of | Lockhartiadss
Atlanta surpasses LA | , Phil | liutebybyrik
NEW: British counter-terrorism | experts | li weapon J Ans
You can enjoy the same | posh place | Lockhartbdss
Glasgow derby betwee | n Celtic | li turny Put A
Orange jailed in Alabama | in 196 | record It disds
World No. 3 Novak Djo | kovic | record record dis twoed
Epoch 10 train loss: 0.06087294267490506 match_pct: 0.2660156190395355


16it [00:01, 12.89it/s]
16it [00:01, 12.84it/s]
16it [00:01, 12.58it/s]
16it [00:01, 12.80it/s]
16it [00:01, 12.76it/s]
16it [00:01, 12.90it/s]
6it [00:00, 13.16it/s]