In [1]:
#!pip install torch torchtext sentencepiece datasets
# Try opus books dataset for translation - https://huggingface.co/datasets/opus_books

In [2]:
import numpy as np
import torch
from torch import nn
import functorch
import sys
import os
import math
sys.path.append(os.path.abspath("../../data"))

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
SP_VOCAB_SIZE = 1000
TRAIN_SIZE = 500

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from text_data import CNNDatasetWrapper

class Wrapper(CNNDatasetWrapper):
    split_lengths = [TRAIN_SIZE, math.floor(TRAIN_SIZE * .1), 100]
    x_length = 15
    target_length = 15

wrapper = Wrapper(SP_VOCAB_SIZE, DEVICE)

datasets = wrapper.generate_datasets(BATCH_SIZE)
train = datasets["train"]
valid = datasets["validation"]

Found cached dataset cnn_dailymail (/Users/vik/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
100%|██████████| 3/3 [00:00<00:00, 152.11it/s]
sentencepiece_trainer.cc(177) LOG(INFO) Running command: --input=tokens.txt --model_prefix=cnn_dailymail --vocab_size=1000 --model_type=unigram
sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : 
trainer_spec {
  input: tokens.txt
  input_format: 
  model_prefix: cnn_dailymail
  model_type: UNIGRAM
  vocab_size: 1000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  treat_whitespace_as_suffix: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_pie

In [4]:
class GRUCell(nn.Module):
    def __init__(self, input_units, hidden_units, output_units):
        super(GRUCell, self).__init__()
        self.input_units = input_units
        self.hidden_units = hidden_units
        self.output_units = output_units

        k = math.sqrt(1/hidden_units)
        self.input_weights = nn.Parameter(torch.rand(3, input_units, hidden_units) * 2 * k - k)
        self.input_biases = nn.Parameter(torch.rand(3, 1, hidden_units) * 2 * k - k)

        self.hidden_weights = nn.Parameter(torch.rand(3, hidden_units, hidden_units) * 2 * k - k)
        self.hidden_biases = nn.Parameter(torch.rand(3, 1, hidden_units) * 2 * k - k)

    def forward(self, x, prev_hidden):
        # Compute the regular RNN forward pass
        # Compute update and reset gates for GRU
        reset_gate = torch.sigmoid(x @ self.input_weights[0,] + self.input_biases[0,] + prev_hidden @ self.hidden_weights[0,] + self.hidden_biases[0,])
        update_gate = torch.sigmoid(x @ self.input_weights[1,] + self.input_biases[1,] + prev_hidden @ self.hidden_weights[1,] + self.hidden_biases[1,])
        new_gate = torch.tanh(x @ self.input_weights[2,] + self.input_biases[2,] + torch.mul(reset_gate, prev_hidden @ self.hidden_weights[2,] + self.hidden_biases[2,]))

        hidden_x = torch.mul((1 - update_gate), new_gate) + torch.mul(update_gate, new_gate)
        return hidden_x

In [5]:
class EncoderDecoder(nn.Module):
    def __init__(self, in_sequence_len, out_sequence_len, embedding_len, hidden_units=512, layers=2):
        super(EncoderDecoder, self).__init__()
        self.in_sequence_len = in_sequence_len
        self.out_sequence_len = out_sequence_len
        self.hidden_units = hidden_units
        self.embedding_len = embedding_len
        self.layers = layers

        self.embedding = nn.Embedding(embedding_len, hidden_units)
        self.encoders = nn.ModuleList([GRUCell(input_units=hidden_units, hidden_units=hidden_units, output_units=hidden_units) for _ in range(layers)])
        self.decoders = nn.ModuleList([GRUCell(input_units=hidden_units * 2, hidden_units=hidden_units, output_units=hidden_units) for _ in range(layers)])

        self.linear = nn.Linear(in_features=hidden_units, out_features=embedding_len)

        k = math.sqrt(1/hidden_units)
        self.hidden_attention_weight = nn.Parameter(torch.rand(hidden_units, hidden_units) * 2 * k - k)
        self.context_attention_weight = nn.Parameter(torch.rand(hidden_units, hidden_units) * 2 * k - k)
        self.attention_weight = nn.Parameter(torch.rand(1, hidden_units) * 2 * k - k)
        self.batched_diag = functorch.vmap(torch.diag)

    def attention(self, context, prev_hidden, batch_size):
        # Swap axes so the first dimension of context_attn is batch
        context_attn = torch.bmm(context.swapaxes(0,1), self.context_attention_weight.unsqueeze(0).expand(batch_size,-1,-1))
        # Swap back since prev_hidden is by batch.  This makes the first dim of cross sequence
        cross = torch.tanh(context_attn.swapaxes(0,1) + prev_hidden @ self.hidden_attention_weight)
        # This will be of dimension batch, sequence_length, 1
        attention = torch.bmm(cross.swapaxes(0,1), self.attention_weight.T.unsqueeze(0).expand(batch_size, -1, -1))
        # Drop the last singleton dimension
        attention = attention.squeeze(2)
        # Softmax the predictions across each batch
        probs = torch.softmax(attention, 1)
        diagonalized_probs = self.batched_diag(probs)
        positional_contexts = torch.sum(torch.bmm(diagonalized_probs, context.swapaxes(0,1)), dim=1).reshape(batch_size, self.hidden_units)
        return positional_contexts

    def forward(self, x, y):
        batch_size = x.shape[0]
        # Move batch to the second dimension, so sequence comes first
        y = y.swapaxes(0,1)
        # Embed the input sequence to reduce dimensionality
        embedded = self.embedding(x).swapaxes(0,1)

        # Encode the input sequence
        # Both tensors will have sequence then batch
        enc_hiddens = torch.zeros((1, self.layers, batch_size, self.hidden_units), device=DEVICE)
        for j in range(self.in_sequence_len):
            seq_enc_hiddens = embedded[j,:].unsqueeze(0)
            for i in range(self.layers):
                hidden = self.encoders[i](seq_enc_hiddens[i,], enc_hiddens[j,i])
                # Add first sequence axis
                hidden = hidden.unsqueeze(0)
                seq_enc_hiddens = torch.cat((seq_enc_hiddens, hidden), dim=0)

            enc_hiddens = torch.cat((enc_hiddens, seq_enc_hiddens[1:].unsqueeze(0)), dim=0)

        # Decode to the output sequence
        # Pass in context
        context = enc_hiddens[1:,-1,:,:]
        # Both tensors will have the first dimension be the sequence
        dec_hiddens = torch.zeros(1, self.layers, batch_size, self.hidden_units, device=DEVICE)
        outputs = torch.zeros(1, batch_size, self.embedding_len, device=DEVICE)
        for j in range(self.out_sequence_len):
            # Use either the actual previous y (from the input), or the generated y if the input sequence is shorter than the generation steps.
            if y.shape[0] > j:
                prev_y = y[j,:]
            else:
                prev_y = outputs[j,:,:]
                prev_y = prev_y.argmax(dim=1).int()

            # Run embedding over previous y state
            prev_y = self.embedding(prev_y)
            seq_dec_hiddens = prev_y.unsqueeze(0)
            for i in range(self.layers):
                positional_context = self.attention(context, dec_hiddens[j,i,], batch_size)
                hidden = self.decoders[i](torch.cat((seq_dec_hiddens[i,], positional_context), dim=1), dec_hiddens[j,i,],)
                # Add first sequence axis
                hidden = hidden.unsqueeze(0)
                seq_dec_hiddens = torch.cat((seq_dec_hiddens, hidden), dim=0)

            # Swap sequence and batch axes to apply linear transform, then swap back
            prev_output = self.linear(seq_dec_hiddens[-1].unsqueeze(0).swapaxes(0,1)).swapaxes(0,1)
            outputs = torch.cat((outputs, prev_output), dim=0)
            dec_hiddens = torch.cat((dec_hiddens, seq_dec_hiddens[1:].unsqueeze(0)), dim=0)

        # Move batch back to axis 0
        out_hiddens = dec_hiddens[1:,-1,:,:].swapaxes(0,1)
        out_output = outputs[1:,].swapaxes(0,1)
        return out_output, out_hiddens

def generate(sequence, prev_target, target, wrapper):
    pred, _ = model(sequence, prev_target[:,0].unsqueeze(1))
    prompts = wrapper.decode_batch(sequence.cpu())
    texts = wrapper.decode_batch(torch.argmax(pred, dim=2).cpu())
    correct_texts = wrapper.decode_batch(target.cpu())

    displays = []
    for p, t, ct in zip(prompts, texts, correct_texts):
        displays.append(f"{p} | {ct} | {t}")
    return displays

In [6]:
from tqdm.auto import tqdm
model = EncoderDecoder(wrapper.x_length, wrapper.y_length, hidden_units=512, layers=1, embedding_len=wrapper.vocab_size).to(DEVICE)
loss_fn = nn.CrossEntropyLoss(ignore_index=wrapper.pad_token)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [None]:
EPOCHS = 1000
DISPLAY_BATCHES = 8

for epoch in range(EPOCHS):
    # Run over the training examples
    train_loss = 0
    match_pct = 0
    for batch, (sequence, target, prev_target) in tqdm(enumerate(train)):
        optimizer.zero_grad()
        forced_target = prev_target
        # Alternate use of teacher forcing vs feeding back own inputs
        if np.random.randint(2) == 0:
            forced_target = prev_target[:,0].unsqueeze(1)
        pred, hidden = model(sequence, forced_target)

        # Need to reshape pred to be batch * sequence, embedding_len to be compatible
        # Similar reshape with target to be batch * sequence vector of class indices
        loss = loss_fn(pred.reshape(-1, pred.shape[-1]), target.view(-1))
        loss.backward()
        optimizer.step()
        train_loss += loss.item() / BATCH_SIZE
        match_pct += torch.sum(target == torch.argmax(pred, 2)) / (wrapper.y_length * BATCH_SIZE)

    with torch.no_grad():
        if epoch % 10 == 0:
            # Show text generated from training prompt as an example
            # Don't feed in all of the train y sequences, just the first token
            # The other y tokens will be predicted by the model and fed back in
            sents = generate(sequence[:DISPLAY_BATCHES], prev_target[:DISPLAY_BATCHES], target[:DISPLAY_BATCHES], wrapper)
            for sent in sents:
                print(sent)

            # Compute validation loss.  Unless you have a lot of training data, the validation loss won't decrease.
            valid_loss = 0
            for batch, (sequence, target, prev_target) in enumerate(valid):
                # Only feed in the first token of the actual target
                pred, hidden = model(sequence, prev_target[:,0].unsqueeze(1))
                loss = loss_fn(pred.reshape(-1, pred.shape[-1]), target.view(-1))
                valid_loss += loss.item() / BATCH_SIZE

            print(f"Epoch {epoch} train loss: {train_loss / len(train)} match_pct: {match_pct / len(train)} valid loss: {valid_loss / len(valid)}")

14it [00:05,  2.58it/s]


Arrests made after ATF agents set | up a sting operation . Affidavit: AT | 
NEW: New video proves Bhutto shot, widow | er says . Doctors claim Pakistani police prevented | 
This month Art of Life looks at motor | bike, planes, DJs and Rock id | 
NEW: Bridge reopens Friday morn | ing after highway engineers give OK . Four killed | 
Moore criticized a report Gupta did on | CNN Monday on "Sicko" Gupta's report question | i upmergencys underlso protest Argentinlesgain helpScause policyra women
NEW: British counter-terrorism expert | s re-inspect Bhutto's vehicle . NEW: Bhutto | 
NEW: 17-year-old alleged shooter appears in | Miami courtroom . NEW: Eric Rive | 
Rinko Kikuchi was Oscar- | nominated for her performance in the | 
Epoch 0 train loss: 0.21559224277734756 match_pct: 0.005859375 valid loss: 0.21120067685842514


14it [00:05,  2.72it/s]
14it [00:05,  2.69it/s]
14it [00:05,  2.56it/s]
14it [00:05,  2.70it/s]
14it [00:05,  2.76it/s]
14it [00:05,  2.72it/s]
14it [00:05,  2.67it/s]
14it [00:05,  2.69it/s]
14it [00:05,  2.64it/s]
14it [00:05,  2.56it/s]


Fire chief: "I think it's a miracle | that we haven't seen some serious injuries" Officials | ss
Stolen art can be lost for decades . S | oft targets like museums  | ssssssssssssssss
At least 10 people have died in torrential rain | s in Ecuador, officials say . Authorities say the | ss
Savers at leading UK mortgage bank | lined up to empty their accounts . N | ss
"Inspiring Impressionism" | looks at Old Masters, other influences on | ss
NEW: Woman says husband didn't show up | at a party on suspected date of killing . N | ss
Sheriff: Possible tornado caused | heavy damage in Prosperity, South Carolin | ss
President Bush says Tony Snow "will battle c | ancer and win" Job of press secre | sss
Epoch 10 train loss: 0.17134868140731538 match_pct: 0.0736607164144516 valid loss: 0.17988649755716324


14it [00:05,  2.66it/s]
14it [00:05,  2.68it/s]
14it [00:05,  2.69it/s]
14it [00:05,  2.64it/s]
14it [00:05,  2.65it/s]
14it [00:05,  2.67it/s]
14it [00:05,  2.60it/s]
14it [00:05,  2.45it/s]
14it [00:05,  2.54it/s]
14it [00:05,  2.61it/s]


NEW: Tennessee man describes diving to the flo | or as his house blows away . The tornad | ssssss
Atlanta surpasses LA, Philadel | phia as city with most bank heists . FBI says it | sassss
Erik Prince: "There was definite | ly incoming small arms fire from insurg | sssssss
Red Cross says it became aware of the re | lationship 10 days ago . Relationship allegedly | ssssssssssssssss
Woman sentenced to 200 lashes and six months | in jail under Islamic law . Judge more than doubled 19- | sa . N
BUPA was founded in 1947 in response | to plans to establish the NHS . | ssssss
Mom thinks girl was abused while in the care of a | baby sitter, attorney says . Mother had no | sa . N
NEW: 17-year-old alleged shooter appears in | Miami courtroom . NEW: Eric Rive | sa . N
Epoch 20 train loss: 0.16391602903604507 match_pct: 0.0890066996216774 valid loss: 0.17916938662528992


14it [00:05,  2.61it/s]
14it [00:05,  2.58it/s]
14it [00:05,  2.66it/s]
14it [00:05,  2.63it/s]
14it [00:05,  2.65it/s]
14it [00:05,  2.51it/s]
14it [00:05,  2.58it/s]
14it [00:05,  2.66it/s]
14it [00:05,  2.62it/s]
14it [00:05,  2.68it/s]


National Hurricane Center director B | ill Proenza has left his position . Nearly | saged to eeee . The
The car was driving on a runway at 3:4 | 5 a.m. when the driver hit the brak | the red to r
Lance corporal due to give birth at | any time, sheriff says . Marine's car found Monday at | ssssssss . The
Iraqi forces detain the suspected leader of a terrorist ce | ll network . Cell is believed to be fund | to red to rs . NEW:
Woman, boyfriend arrested after a tip led | to search . Police believe child found dead in box is Riley | sing to ru . He
Remittances to Mexico fell $1 | 00 million in January, according to Bank of | s,,,,,,,
Phrase in Obama speech similar | to that of Massachusetts Gov. | ssssssssss
NEW: Police say they have more than one confes | sion in the case . NEW: Investigation | ssssssss .
Epoch 30 train loss: 0.15292869401829584 match_pct: 0.1294642835855484 valid loss: 0.1812349185347557


14it [00:05,  2.63it/s]
14it [00:05,  2.60it/s]
14it [00:05,  2.62it/s]
14it [00:05,  2.66it/s]
14it [00:05,  2.65it/s]
14it [00:05,  2.67it/s]
14it [00:05,  2.61it/s]
14it [00:05,  2.65it/s]
14it [00:05,  2.61it/s]
14it [00:05,  2.65it/s]


University of Memphis athlete Taylor | Bradford, 21, was shot Sept | ,,,,,,,,,,,,
Experts say Lewis Hamilton is set to earn more than | Michael Schumacher . David Beck | to  to  to  to  . . . . . . . .
NEW: NFL chief, Atlanta Falcons | owner critical of Michael Vick's conduct . N | s,,,,,,,,,,,,, . .
President Bush will have a routine colonos | copy Saturday . While he's anesthetized | s a rs a rs a r
President Bush to address the Veterans of | Foreign Wars on Wednesday . Bush to say that withdraw | ssssssssss . . .
"Desperate Housewives" actress E | va Longoria Parker gives pee | seareeeeeeeeee
U.N. agency appeals for medical reset | tlement of Palestinians in Iraq camps . About | sssssssssssss .
NEW: 17-year-old alleged shooter appears in | Miami courtroom . NEW: Eric Rive | toingrrrrrrr
Epoch 40 train loss: 0.13671733226094926 match_pct: 0.189453125 valid loss: 0.182972714304924


14it [00:05,  2.53it/s]
9it [00:03,  2.72it/s]

In [None]:
pred.shape

In [None]:
target.shape