In [14]:
#!pip install torch torchtext datasets sentencepiece numpy
from datasets import load_dataset

TRAIN_LENGTH = 10000
VALID_LENGTH = 250
SP_VOCAB_SIZE = 1000
STOP_TOKEN = SP_VOCAB_SIZE + 1
START_TOKEN = SP_VOCAB_SIZE + 2
VOCAB_SIZE = SP_VOCAB_SIZE + 3

CHUNK_LENGTH = 12

# Load from Huggingface datasets module
data = load_dataset("cnn_dailymail", "3.0.0")
train = data["train"]["highlights"][:TRAIN_LENGTH]
valid = data["validation"]["highlights"][:VALID_LENGTH]

Found cached dataset cnn_dailymail (/Users/vik/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
100%|██████████| 3/3 [00:00<00:00, 171.17it/s]


In [94]:
from torchtext.data.functional import generate_sp_model
with open("tokens.txt", "w+") as f:
    f.write("\n".join(train) + "\n".join(valid))

generate_sp_model("tokens.txt", vocab_size=SP_VOCAB_SIZE, model_prefix="cnn")

sentencepiece_trainer.cc(177) LOG(INFO) Running command: --input=tokens.txt --model_prefix=cnn --vocab_size=1000 --model_type=unigram
sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : 
trainer_spec {
  input: tokens.txt
  input_format: 
  model_prefix: cnn
  model_type: UNIGRAM
  vocab_size: 1000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  treat_whitespace_as_suffix: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 0
  bos_id: 1
  eos_id: 2
  pad_id: -1
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface:  ⁇ 
}
norm

In [15]:
import sentencepiece as spm
import numpy as np
from torchtext.data.functional import load_sp_model, sentencepiece_numericalizer
import torch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sp_base = spm.SentencePieceProcessor(model_file="cnn.model")
sp_model = load_sp_model("cnn.model")
encoding_generator = sentencepiece_numericalizer(sp_model)

def chunk_text(texts):
    return [(t[:CHUNK_LENGTH], t[CHUNK_LENGTH:(2 * CHUNK_LENGTH)]) for t in texts if (CHUNK_LENGTH * 2) < len(t)]

def decode_ids(ids):
    if isinstance(ids, torch.Tensor):
        ids = [int(i) for i in list(ids.numpy()) if int(i) not in [START_TOKEN, STOP_TOKEN]]
    return sp_base.decode(ids)

def decode_batch(id_tensor):
    decoded = []
    for i in range(id_tensor.shape[0]):
        decoded.append(decode_ids(id_tensor[i,:]))
    return decoded

def encode(tokens):
    mat = np.zeros((len(tokens), VOCAB_SIZE))
    for i in range(len(tokens)):
        mat[i,tokens[i]] = 1
    return mat

train_ids = list(encoding_generator(train))
valid_ids = list(encoding_generator(valid))
train_ids = chunk_text(train_ids)
valid_ids = chunk_text(valid_ids)

In [16]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import nn
import math

class CNNDataset(Dataset):
    def __init__(self, data):
        self.dataset = data

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        x = torch.tensor(self.dataset[idx][0]).int()
        y_list = [START_TOKEN] + self.dataset[idx][1] + [STOP_TOKEN]
        y = torch.tensor(encode(y_list)).float()
        return x.to(DEVICE), y.to(DEVICE)

In [17]:
BATCH_SIZE = 32

train_dataset = CNNDataset(train_ids)
valid_dataset = CNNDataset(valid_ids)

train = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
valid = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [18]:
class Encoder(nn.Module):
    def __init__(self, input_units, hidden_units, output_units):
        super(Encoder, self).__init__()
        self.input_units = input_units
        self.hidden_units = hidden_units

        k = 1/math.sqrt(hidden_units)
        self.input_weight = nn.Parameter(torch.rand(input_units, hidden_units) * 2 * k - k)

        self.hidden_weight = nn.Parameter(torch.rand(hidden_units, hidden_units) * 2 * k - k)
        self.hidden_bias = nn.Parameter(torch.rand(1, hidden_units) * 2 * k - k)

        self.output_weight = nn.Parameter(torch.rand(hidden_units, output_units) * 2 * k - k)
        self.output_bias = nn.Parameter(torch.rand(1, output_units) * 2 * k - k)

    def forward(self, x, prev_hidden):
        # Compute the regular RNN forward pass
        # Input times weights
        input_x = x @ self.input_weight
        # Sum input with previous hidden state, and add nonlinearity
        # Tanh prevents gradients exploding
        hidden_x = torch.tanh(input_x + prev_hidden @ self.hidden_weight + self.hidden_bias)
        # Compute output vector
        output_y = hidden_x @ self.output_weight + self.output_bias
        return hidden_x, output_y

In [37]:
class Decoder(nn.Module):
    def __init__(self, input_units, hidden_units, output_units=VOCAB_SIZE):
        super(Decoder, self).__init__()
        self.input_units = input_units
        self.hidden_units = hidden_units
        self.output_units = output_units

        k = 1/math.sqrt(hidden_units)
        self.hidden_attention_weight = nn.Parameter(torch.rand(hidden_units, hidden_units) * 2 * k - k)
        self.context_attention_weight = nn.Parameter(torch.rand(input_units, hidden_units) * 2 * k - k)
        self.attention_weight = nn.Parameter(torch.rand(1, hidden_units) * 2 * k - k)

        self.context_hidden_weight = nn.Parameter(torch.rand(input_units + output_units, hidden_units) * 2 * k - k)
        self.hidden_weight = nn.Parameter(torch.rand(hidden_units, hidden_units) * 2 * k - k)
        self.hidden_bias = nn.Parameter(torch.rand(1, hidden_units) * 2 * k - k)

        self.output_weight = nn.Parameter(torch.rand(hidden_units, output_units) * 2 * k - k)
        self.output_bias = nn.Parameter(torch.rand(1, output_units) * 2 * k - k)

    def forward(self, prev_y, prev_hidden, context):
        # Compute attention between the encoder hidden states and the previous decoder hidden state
        # Loop over batches
        # Swap batch and sequence
        """
        context_attn = torch.bmm(context.swapaxes(0,1), self.context_attention_weight.unsqueeze(0).repeat(context.shape[1],1,1))
        cross = torch.tanh(context_attn.swapaxes(0,1) + prev_hidden @ self.hidden_attention_weight)
        aw_unrolled = self.attention_weight.repeat(context.shape[1], 1,1).swapaxes(2,1)
        attention = torch.bmm(cross, aw_unrolled)
        positional_contexts = torch.zeros(1, context.shape[2], device=DEVICE)
        for i in range(context.shape[1]):
            # Compute probability for each encoder hidden state, and use it to weight and sum the states
            probs = torch.softmax(attention[i,:,:], 0).reshape(context.shape[0])
            positional_context = torch.sum(torch.diag(probs) @ context[:,i,:], dim=0).reshape(1, self.input_units)
            positional_contexts = torch.cat((positional_contexts, positional_context), dim=0)
        """
        positional_contexts = torch.zeros(1, context.shape[2]).to(DEVICE)
        for i in range(context.shape[1]):
            context_attn = context[:,i,:] @ self.context_attention_weight
            cross = torch.tanh(context_attn + prev_hidden[i,:].unsqueeze(0) @ self.hidden_attention_weight)
            attention = cross @ self.attention_weight.T

            # Compute probability for each encoder hidden state, and use it to weight and sum the states
            probs = torch.softmax(attention, 0).reshape(context.shape[0])
            positional_context = torch.sum(torch.diag(probs) @ context[:,i,:], dim=0).reshape(1, self.input_units)
            positional_contexts = torch.cat((positional_contexts, positional_context), dim=0)

        # Compute a regular rnn.  Cat the context vector and the previous y state.
        input_x = torch.cat([positional_contexts[1:], prev_y], dim=1) @ self.context_hidden_weight
        hidden_x = torch.tanh(input_x + prev_hidden @ self.hidden_weight + self.hidden_bias)

        output_y = hidden_x @ self.output_weight + self.output_bias
        return hidden_x, output_y

In [38]:
class Network(nn.Module):
    def __init__(self, in_sequence_len, out_sequence_len, hidden_units=512, embedding_len=VOCAB_SIZE):
        super(Network, self).__init__()
        self.in_sequence_len = in_sequence_len
        self.out_sequence_len = out_sequence_len
        self.hidden_units = hidden_units
        self.embedding_len = embedding_len

        self.enc_embedding = nn.Embedding(embedding_len, hidden_units)
        self.encoder = Encoder(input_units=hidden_units, hidden_units=hidden_units, output_units=embedding_len)

        self.dec_embedding= nn.Embedding(embedding_len, hidden_units)
        self.decoder = Decoder(input_units=hidden_units, hidden_units=hidden_units, output_units=embedding_len)

    def forward(self, x, y):
        batch_size = x.shape[0]
        # Move batch to the second dimension, so sequence comes first
        y = y.swapaxes(0,1)
        # Embed the input sequence to reduce dimensionality
        embedded = self.enc_embedding(x).swapaxes(0,1)

        # Encode the input sequence
        # Both tensors will have sequence then batch
        enc_hiddens = torch.zeros((1, batch_size, self.hidden_units), device=DEVICE)
        enc_outputs = torch.zeros((1, batch_size, self.embedding_len), device=DEVICE)
        for j in range(self.in_sequence_len):
            hidden, output = self.encoder(embedded[j,:,:], enc_hiddens[j,:,:])
            # Add first sequence axis
            hidden = hidden.unsqueeze(0)
            output = output.unsqueeze(0)
            enc_hiddens = torch.cat((enc_hiddens, hidden), dim=0)
            enc_outputs = torch.cat((enc_outputs, output), dim=0)

        # Decode to the output sequence
        # Pass in context
        context = enc_hiddens[1:,:,:]
        # Both tensors will have the first dimension be the sequence
        dec_hiddens = torch.zeros(1, batch_size, self.hidden_units, device=DEVICE)
        dec_outputs = torch.zeros((1, batch_size, self.embedding_len), device=DEVICE)
        for j in range(self.out_sequence_len):
            # Use either the actual previous y (from the input), or the generated y if the input sequence is shorter than the generation steps.
            prev_y = y[j,:,:] if y.shape[0] > j else torch.softmax(dec_outputs[j,:,:], dim=1)
            hidden, output = self.decoder(prev_y, dec_hiddens[j,:,:], context)
            # Add first sequence axis
            hidden = hidden.unsqueeze(0)
            output= output.unsqueeze(0)
            dec_hiddens = torch.cat((dec_hiddens, hidden), dim=0)
            dec_outputs = torch.cat((dec_outputs, output), dim=0)

        # Move batch back to axis 0
        out_hiddens = dec_hiddens[1:].swapaxes(0,1)
        out_output = dec_outputs[1:].swapaxes(0,1)
        return out_hiddens, out_output

def generate(sequence, target):
    _, pred = model(sequence, target[:,0,:].unsqueeze(1))
    prompts = decode_batch(sequence.cpu())
    texts = decode_batch(torch.argmax(pred, dim=2).cpu())
    correct_texts = decode_batch(torch.argmax(target, dim=2).cpu())

    displays = []
    for p, t, ct in zip(prompts, texts, correct_texts):
        displays.append(f"{p} | {ct} | {t}")
    return displays

In [39]:
from tqdm.auto import tqdm

model = Network(CHUNK_LENGTH, CHUNK_LENGTH + 2, hidden_units=512).to(DEVICE)
loss_fn = nn.CrossEntropyLoss(reduction="sum")
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

In [40]:
EPOCHS = 100
DISPLAY_BATCHES = 8

for epoch in range(EPOCHS):
    # Run over the training examples
    train_loss = 0
    for batch, (sequence, target) in tqdm(enumerate(train)):
        optimizer.zero_grad()
        hidden, pred = model(sequence, target)
        loss = loss_fn(pred, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Show text generated from training prompt as an example
    # Don't feed in all of the train y sequences, just the first token
    # The other y tokens will be predicted by the model and fed back in
    sents = generate(sequence[:DISPLAY_BATCHES], target[:DISPLAY_BATCHES])
    for sent in sents:
        print(sent)

    # Compute validation loss.  Unless you have a lot of training data, the validation loss won't decrease.
    valid_loss = 0
    with torch.no_grad():
        for batch, (sequence, target) in enumerate(valid):
            # Only feed in the first token of the actual target
            hidden, pred = model(sequence, target[:,0,:].unsqueeze(1))
            loss = loss_fn(pred, target)
            valid_loss += loss.item()

    print(f"Epoch {epoch} train loss: {train_loss / len(train_dataset)} valid loss: {valid_loss / len(valid_dataset)}")

0it [00:00, ?it/s]


RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [12, 512] but got: [32, 512].

In [10]:
import torch
attention_weight = torch.rand(1, 512)

nw = attention_weight.repeat(2,5,1).swapaxes(1,2)

In [11]:
nw.shape

torch.Size([2, 512, 5])