In [54]:
from datasets import load_dataset

# Load from Huggingface datasets module
data = load_dataset("cnn_dailymail", "3.0.0")
train = data["train"]["highlights"]
valid = data["validation"]["highlights"]

Found cached dataset cnn_dailymail (/Users/vik/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
100%|██████████| 3/3 [00:00<00:00, 188.23it/s]


In [55]:
STOP_TOKEN = 2
START_TOKEN = 1
UNK_TOKEN = 0
CHUNK_LENGTH = 12
VOCAB_SIZE = 1000
TRAIN_LENGTH = 10000
VALID_LENGTH = 500

In [3]:
from torchtext.data.functional import generate_sp_model
with open("tokens.txt", "w+") as f:
    f.write("\n".join(train) + "\n".join(valid))

generate_sp_model("tokens.txt", vocab_size=VOCAB_SIZE, model_prefix="cnn")

sentencepiece_trainer.cc(177) LOG(INFO) Running command: --input=tokens.txt --model_prefix=cnn --vocab_size=1000 --model_type=unigram
sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : 
trainer_spec {
  input: tokens.txt
  input_format: 
  model_prefix: cnn
  model_type: UNIGRAM
  vocab_size: 1000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  treat_whitespace_as_suffix: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extremely_large_corpus: 0
  hard_vocab_limit: 1
  use_all_vocab: 0
  unk_id: 0
  bos_id: 1
  eos_id: 2
  pad_id: -1
  unk_piece: <unk>
  bos_piece: <s>
  eos_piece: </s>
  pad_piece: <pad>
  unk_surface:  ⁇ 
}
norm

In [56]:
import sentencepiece as spm
import numpy as np
from torchtext.data.functional import load_sp_model, sentencepiece_numericalizer
import torch

sp_base = spm.SentencePieceProcessor(model_file="cnn.model")
sp_model = load_sp_model("cnn.model")
encoding_generator = sentencepiece_numericalizer(sp_model)

def chunk_text(texts):
    return [(t[:CHUNK_LENGTH], t[CHUNK_LENGTH:(2 * CHUNK_LENGTH)]) for t in texts if (CHUNK_LENGTH * 2) < len(t)]

def decode_ids(ids):
    if isinstance(ids, torch.Tensor):
        ids = [int(i) for i in list(ids.numpy())]
    return sp_base.decode(ids)

def encode(tokens):
    mat = np.zeros((len(tokens), VOCAB_SIZE))
    for i in range(len(tokens)):
        mat[i,tokens[i]] = 1
    return mat

train_ids = list(encoding_generator(train[:TRAIN_LENGTH]))
valid_ids = list(encoding_generator(valid[:VALID_LENGTH]))
train_ids = chunk_text(train_ids)
valid_ids = chunk_text(valid_ids)

In [57]:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch import nn
import math

class CNNDataset(Dataset):
    def __init__(self, data):
        self.dataset = data

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        x = torch.tensor(encode(self.dataset[idx][0])).float()
        y = torch.tensor(encode(self.dataset[idx][1])).float()
        return x, y

In [58]:
BATCH_SIZE = 1

train_dataset = CNNDataset(train_ids)
valid_dataset = CNNDataset(valid_ids)

train = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)
valid = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [59]:
class Encoder(nn.Module):
    def __init__(self, input_units=VOCAB_SIZE, hidden_units=512, sequence_len=CHUNK_LENGTH):
        super(Encoder, self).__init__()
        self.input_units = input_units
        self.hidden_units = hidden_units
        self.sequence_len = sequence_len

        k = 1/math.sqrt(hidden_units)
        self.input_weight = nn.Parameter(torch.rand(input_units, hidden_units) * 2 * k - k)

        self.hidden_weight = nn.Parameter(torch.rand(hidden_units, hidden_units) * 2 * k - k)
        self.hidden_bias = nn.Parameter(torch.rand(1, hidden_units) * 2 * k - k)

    def forward(self, x):
        hiddens = [torch.zeros((1, self.hidden_units))]
        for j in range(self.sequence_len):
            prev_hidden_index = max(0, j-1)

            input_x = x[j,:] @ self.input_weight
            hiddens.append(torch.tanh(input_x + hiddens[prev_hidden_index] @ self.hidden_weight + self.hidden_bias))
        return torch.cat(hiddens[1:], dim=0)

In [60]:
class Decoder(nn.Module):
    def __init__(self, input_units=512, hidden_units=512, output_units=VOCAB_SIZE, in_sequence_len=CHUNK_LENGTH, out_sequence_len=CHUNK_LENGTH):
        super(Decoder, self).__init__()
        self.input_units = input_units
        self.hidden_units = hidden_units
        self.output_units = output_units
        self.in_sequence_len = in_sequence_len
        self.out_sequence_len = out_sequence_len

        k = 1/math.sqrt(hidden_units)
        self.hidden_attention_weight = nn.Parameter(torch.rand(hidden_units, hidden_units) * 2 * k - k)
        self.input_attention_weight = nn.Parameter(torch.rand(input_units, hidden_units) * 2 * k - k)
        self.attention_weight = nn.Parameter(torch.rand(1, hidden_units) * 2 * k - k)


        self.context_hidden_weight = nn.Parameter(torch.rand(input_units, hidden_units) * 2 * k - k)
        self.hidden_weight = nn.Parameter(torch.rand(hidden_units, hidden_units) * 2 * k - k)
        self.hidden_bias = nn.Parameter(torch.rand(1, hidden_units) * 2 * k - k)

        self.output_weight = nn.Parameter(torch.rand(hidden_units, output_units) * 2 * k - k)
        self.output_bias = nn.Parameter(torch.rand(1, output_units) * 2 * k - k)

    def forward(self, context):
        hiddens = [torch.zeros(1, self.hidden_units)]
        outputs = [torch.zeros(1, self.output_units)]
        context_attns = context @ self.input_attention_weight
        for i in range(self.out_sequence_len):
            prev_hidden_index = max(0, i-1)
            cross = torch.tanh(context_attns + hiddens[prev_hidden_index] @ self.hidden_attention_weight)
            attention = cross @ self.attention_weight.T

            probs = torch.softmax(attention, 0).reshape(self.in_sequence_len)
            positional_context = torch.sum(torch.diag(probs) @ context, dim=0).reshape(1, self.input_units)

            input_x = positional_context @ self.input_attention_weight

            hiddens.append(torch.tanh(input_x + hiddens[prev_hidden_index] @ self.hidden_weight + self.hidden_bias))
            outputs.append(hiddens[i] @ self.output_weight + self.output_bias)
        return torch.cat(hiddens[1:], dim=0), torch.cat(outputs[1:], dim=0)

In [61]:
class Network(nn.Module):
    def __init__(self, in_sequence_len, out_sequence_len):
        super(Network, self).__init__()
        self.encoder = Encoder(sequence_len=in_sequence_len, input_units=VOCAB_SIZE)
        self.decoder = Decoder(in_sequence_len=in_sequence_len, out_sequence_len=out_sequence_len, output_units=VOCAB_SIZE)

    def forward(self, x):
        hiddens = self.encoder(x)
        hiddens, outputs = self.decoder(hiddens)
        return hiddens, outputs

In [None]:
from statistics import mean
from tqdm.auto import tqdm

device = torch.device("cpu")
model = Network(CHUNK_LENGTH, CHUNK_LENGTH).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

EPOCHS = 5
for epoch in range(EPOCHS):
    for batch, (sequence, target) in tqdm(enumerate(train)):
        optimizer.zero_grad()

        sequence = sequence.to(device)
        hidden, pred = model(sequence[0,:,:])

        pred = pred.reshape(1, CHUNK_LENGTH, VOCAB_SIZE)
        loss = loss_fn(pred, target)
        loss.backward()
        optimizer.step()

    losses = []
    with torch.no_grad():
        for batch, (sequence, target) in enumerate(valid):
            sequence = sequence.to(device)
            hidden, pred = model(sequence[0,:,:])
            pred = pred.reshape(1, CHUNK_LENGTH, VOCAB_SIZE)
            loss = loss_fn(pred, target)
            losses.append(loss.item())

    prompt = decode_ids(torch.argmax(sequence[0,:,:], dim=1))
    text = decode_ids(torch.argmax(pred[0,:,:], dim=1))
    print(f"Epoch {epoch} valid loss: {mean(losses)}")
    print(f"Example: {prompt} | {text}")

6422it [03:02, 29.06it/s]