In [1]:
# Based on this paper - https://arxiv.org/pdf/1706.03762.pdf
# Might want to move layer norm inside the residual block - https://arxiv.org/pdf/2002.04745.pdf
# Layer normalization - https://arxiv.org/pdf/1607.06450.pdf
# TODO: Investigate learning rate warmup - https://arxiv.org/abs/2002.04745
#!pip install torch torchtext sentencepiece datasets wandb

In [2]:
import numpy as np
import torch
from torch import nn
import sys
import os
import math
sys.path.append(os.path.abspath("../../data"))
sys.path.append(os.path.abspath("../../nnets"))
from net_utils import get_module_list

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
SP_VOCAB_SIZE = 5000
TRAIN_SIZE = 5000

In [3]:
from wrapper import OpusBooksDataset
from padding import PaddingSampler, pad_collate
from torch.utils.data import DataLoader

train_wrapper = OpusBooksDataset(tokenizer_vocab=SP_VOCAB_SIZE, download_split_pct="5%")
train_dataset = train_wrapper.process_dataset()
train_dataset = train_dataset.train_test_split(test_size=0.1)

train_split = train_dataset["train"]
train = DataLoader(train_split, batch_size=BATCH_SIZE, sampler=PaddingSampler(train_split["en_lens"]), collate_fn=pad_collate)

valid_split = train_dataset["test"]
valid = DataLoader(valid_split, batch_size=BATCH_SIZE, sampler=PaddingSampler(valid_split["en_lens"]), collate_fn=pad_collate)

Found cached dataset opus_books (/Users/vik/.cache/huggingface/datasets/opus_books/en-es/1.0.0/e8f950a4f32dc39b7f9088908216cd2d7e21ac35f893d04d39eb594746af2daf)


Combining:   0%|          | 0/5 [00:00<?, ?ba/s]






Tokenizing:   0%|          | 0/5 [00:00<?, ?ba/s]

Filtering:   0%|          | 0/5 [00:00<?, ?ba/s]

Filtering:   0%|          | 0/5 [00:00<?, ?ba/s]

Splitting:   0%|          | 0/5 [00:00<?, ?ba/s]

In [4]:
class MultiHeadAttention(nn.Module):
    def __init__(self, input_units, attention_heads, mask=False):
        super(MultiHeadAttention, self).__init__()
        self.input_units = input_units
        self.attention_heads = attention_heads
        self.head_units = int(input_units/attention_heads)
        self.mask = mask

        k = math.sqrt(1/self.input_units)
        self.in_proj_weight = nn.Parameter(torch.rand(3, input_units, self.attention_heads * self.head_units) * 2 * k - k)
        self.in_proj_bias = nn.Parameter(torch.rand(3, input_units) * 2 * k - k)

        self.out_proj_weight = nn.Parameter(torch.rand(self.attention_heads * self.head_units, input_units) * 2 * k - k)
        self.out_proj_bias = nn.Parameter(torch.rand(input_units) * 2 * k - k)

    def forward(self, queries, keys, values):
        # convert to 4d tensor with batch_size, attn_heads, seq_len, embedding_dim
        proj_queries = torch.einsum("...se, eo->...so", queries, self.in_proj_weight[0]) + self.in_proj_bias[0]
        proj_queries = proj_queries.view(queries.shape[0], queries.shape[1], self.attention_heads, self.head_units).swapaxes(1,2)

        proj_keys = torch.einsum("...se, eo->...so", keys, self.in_proj_weight[1]) + self.in_proj_bias[1]
        proj_keys = proj_keys.view(keys.shape[0], keys.shape[1], self.attention_heads, self.head_units).swapaxes(1,2)

        proj_values = torch.einsum("...se, eo->...so", values, self.in_proj_weight[2]) + self.in_proj_bias[2]
        proj_values = proj_values.view(values.shape[0], values.shape[1], self.attention_heads, self.head_units).swapaxes(1,2)

        attention = torch.einsum("baqh, bahk->baqk", proj_queries, torch.transpose(proj_keys, -1, -2)) / np.sqrt(proj_keys.shape[-1])
        if self.mask:
            # Prevent decoder queries from looking at tokens that come after
            # Do this by setting attention to negative infinity, so it is softmaxed to zero in the next step
            mask = torch.full((attention.shape[-2], attention.shape[-1]), -torch.inf, device=DEVICE)
            attention += torch.triu(mask, diagonal=1)

        # Softmax on last dimension
        # Sequence-wise softmax, so attention between one sequence and other sequences sums to 1
        attention = torch.softmax(attention, dim=-1)
        weighted_values = torch.einsum("baqk, bake->baqe", attention, proj_values)

        # Swap attention head and sequence axis, then reshape to batch, seq, embedding
        weighted_values = weighted_values.swapaxes(1,2).reshape(queries.shape[0], queries.shape[1], -1)
        weighted_values = torch.einsum("...se, eo->...so", weighted_values, self.out_proj_weight) + self.out_proj_bias
        return weighted_values

In [5]:
class EncoderBlock(nn.Module):
    def __init__(self, input_units, attention_heads, hidden_units=2048, dropout_p=.1):
        super(EncoderBlock, self).__init__()
        self.input_units = input_units
        self.attention_heads = attention_heads
        self.hidden_units = hidden_units

        self.mha = MultiHeadAttention(self.input_units, self.attention_heads)
        self.dropouts = get_module_list(2, nn.Dropout, dropout_p)
        self.linear1 = nn.Linear(self.input_units, hidden_units)
        self.linear2 = nn.Linear(hidden_units, self.input_units)
        self.relu = nn.ReLU()
        self.lns = get_module_list(2, nn.LayerNorm, self.input_units)

    def forward(self, x):
        weighted_values = self.dropouts[0](self.mha(x, x, x))
        x = self.lns[0](x + weighted_values)

        reprojected = self.dropouts[1](self.linear2(self.relu(self.linear1(x))))
        x = self.lns[1](x + reprojected)
        return x

In [6]:
class DecoderBlock(nn.Module):
    def __init__(self, input_units, attention_heads, hidden_units=2048, dropout_p=.1):
        super(DecoderBlock, self).__init__()
        self.input_units = input_units
        self.attention_heads = attention_heads
        self.hidden_units = hidden_units

        self.in_attn = MultiHeadAttention(self.input_units, self.attention_heads, mask=True)
        self.context_attn = MultiHeadAttention(self.input_units, self.attention_heads)
        self.dropouts = get_module_list(3, nn.Dropout, dropout_p)
        self.linear1 = nn.Linear(self.input_units, hidden_units)
        self.linear2 = nn.Linear(hidden_units, self.input_units)
        self.relu = nn.ReLU()
        self.lns = get_module_list(3, nn.LayerNorm, self.input_units)

    def forward(self, x, context):
        weighted_values = self.dropouts[0](self.in_attn(x, x, x))
        x = self.lns[0](x + weighted_values)

        decoder_values = self.dropouts[1](self.context_attn(x, context, context))
        x = self.lns[1](x + decoder_values)

        reprojected = self.dropouts[2](self.linear2(self.relu(self.linear1(x))))
        x = self.lns[2](x + reprojected)
        return x

In [7]:
class Transformer(nn.Module):
    def __init__(self, input_units, hidden_units, attention_heads, padding_idx, max_len=256, blocks=1):
        super(Transformer, self).__init__()
        self.input_units = input_units
        self.hidden_units = hidden_units
        self.attention_heads = attention_heads
        self.blocks = blocks

        self.output_embedding = nn.Linear(hidden_units, input_units)
        self.embedding = nn.Embedding(input_units, hidden_units, padding_idx=padding_idx)
        self.dropouts = get_module_list(2, nn.Dropout, .1)
        self.encoders = get_module_list(self.blocks, EncoderBlock, hidden_units, attention_heads)
        self.decoders = get_module_list(self.blocks, DecoderBlock, hidden_units, attention_heads)
        self.pos_encoding = self.encoding(max_len, self.hidden_units).to(DEVICE)


    def forward(self, x, y, enc_outputs=None):
        if enc_outputs is None:
            # 3D with batch, seq, embeddings
            # TODO: Tie input and output embedding weights
            enc_outputs = self.dropouts[0](self.embedding(x) + self.pos_encoding[:x.shape[1]])

            for i in range(self.blocks):
                enc_outputs = self.encoders[i](enc_outputs)

        dec_outputs = self.dropouts[1](self.embedding(y) + self.pos_encoding[:y.shape[1]])
        for i in range(self.blocks):
            dec_outputs = self.decoders[i](dec_outputs, enc_outputs)

        token_vectors = self.output_embedding(dec_outputs)
        return token_vectors, enc_outputs

    def encoding(self, seq_len, embed_len):
        encodings = torch.zeros((seq_len, embed_len))
        for i in range(seq_len):
            all = torch.exp(torch.arange(0, embed_len, 2) * (-math.log(10000.0) / embed_len))
            encodings[i, 0::2] = torch.sin(i * all)
            encodings[i, 1::2] = torch.cos(i * all)
        return encodings

In [8]:
def generate(sequence, pred, target, wrapper):
    prompts = wrapper.decode_ids(sequence.cpu())
    texts = wrapper.decode_ids(torch.argmax(pred, dim=2).cpu())
    correct_texts = wrapper.decode_ids(target.cpu())

    displays = []
    for p, t, ct in zip(prompts, texts, correct_texts):
        displays.append(f"{p} | {ct} | {t}")
    return displays

In [9]:
from tqdm.auto import tqdm
import wandb

wandb.init(project="transformer", notes="Baseline performance one block updated dataset", name="baseline-one-new-data")

# TODO: Profile and improve perf - https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
model = Transformer(SP_VOCAB_SIZE, 512, 8, blocks=1, padding_idx=train_wrapper.tokenizer.pad_token_id).to(DEVICE)
loss_fn = nn.CrossEntropyLoss(ignore_index=train_wrapper.tokenizer.pad_token_id)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
wandb.watch(model, log_freq=100)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

[34m[1mwandb[0m: Currently logged in as: [33mvikp[0m. Use [1m`wandb login --relogin`[0m to force relogin


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[]

In [None]:
EPOCHS = 100
DISPLAY_BATCHES = 2
PRINT_VALID = True

for epoch in range(EPOCHS):
    # Run over the training examples
    train_loss = 0
    match_pct = 0
    for batch in tqdm(train):
        optimizer.zero_grad(set_to_none=True)
        sequence = batch["en_ids"].to(torch.long)
        target = batch["es_ids"].to(torch.long)
        # Setup target properly
        prev_target = torch.roll(target, 1, -1)
        prev_target[:,0] = train_wrapper.tokenizer.bos_token_id

        pred, _ = model(sequence.to(DEVICE), prev_target.to(DEVICE))

        # If you use a batch, need to reshape pred to be batch * sequence, embedding_len to be compatible
        # Similar reshape with target to be batch * sequence vector of class indices
        loss = loss_fn(pred.reshape(-1, pred.shape[-1]), target.reshape(-1).to(DEVICE))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    with torch.no_grad():
        mean_loss = train_loss / len(train) / BATCH_SIZE
        wandb.log({"loss": mean_loss})
        print(f"Epoch {epoch} train loss: {mean_loss}")
        sents = generate(sequence, pred, target, train_wrapper)
        for sent in sents[:DISPLAY_BATCHES]:
            print(sent)

        if PRINT_VALID and epoch % 10 ==0:
            # Compute validation loss.  Unless you have a lot of training data, the validation loss won't decrease.
            valid_loss = 0
            # Deactivate dropout layers
            model.eval()
            for batch in tqdm(valid):
                # Inference token by tokens
                sequence = batch["en_ids"].to(torch.long)
                target = batch["es_ids"].to(torch.long)
                # Setup target properly
                prev_target = torch.roll(target, 1, -1)
                prev_target[:,0] = train_wrapper.tokenizer.bos_token_id
                outputs = prev_target[:,0].unsqueeze(1).to(DEVICE)
                enc_outputs = None
                # TODO: Investigate memory leak with valid generation
                for i in range(target.shape[1]):
                    pred, enc_outputs = model(sequence, outputs, enc_outputs=enc_outputs)
                    last_output = torch.argmax(pred, dim=2)
                    outputs = torch.cat((outputs, last_output[:,-1:]), dim=1)
                loss = loss_fn(pred.reshape(-1, pred.shape[-1]), target.view(-1).to(DEVICE))
                valid_loss += loss.item()
            mean_loss = valid_loss / len(valid) / BATCH_SIZE
            wandb.log({"valid_loss": mean_loss})
            print(f"Valid loss: {mean_loss}")
            sents = generate(sequence, pred, target, train_wrapper)
            for sent in sents[:DISPLAY_BATCHES]:
                print(sent)
            # Reactivate dropout
            model.train()

  0%|          | 0/77 [00:00<?, ?it/s]

Epoch 0 train loss: 0.22905117582965207
And with this pleasing anticipation, she sat down to reconsider the past, recall the words and endeavour to comprehend all the feelings of Edward; and, of course, to reflect on her own with discontent. | Y con este agradable vaticinio se sentó a reconsiderar el pasado, recordar las palabras e intentar comprender los sentimientos de Edward; y, por supuesto, a reflexionar sobre su propio descontento. | - de de de de de de que de de de de de de de de que de de que de de de de que de de de de que de de que que de que de que la que de de de de que de de de de de
She vowed at first she would never trim me up a new bonnet, nor do any thing else for me again, so long as she lived; but now she is quite come to, and we are as good friends as ever. | Primero juró que nunca más volvería a arreglarme ninguna toca nueva ni jamás haría ninguna otra cosa por mí; pero ahora ya se ha aplacado y estamos tan amigas como siempre. | - de de que que de de de de de de q

  0%|          | 0/9 [00:00<?, ?it/s]

Valid loss: 0.21411480340692732
ONE person I was sure would represent me as capable of any thing-- What I felt was dreadful!--My resolution was soon made, and at eight o'clock this morning I was in my carriage. | ¡Lo que sentí fue atroz! Rápidamente tomé una decisión, y hoy a las ocho de la mañana ya me encontraba en mi carruaje. | - de que de que de que de que de que de que de que de que de que de que de que de que de que de que de que de que de que de que de que de que de que de que de que de que de
The son, a steady respectable young man, was amply provided for by the fortune of his mother, which had been large, and half of which devolved on him on his coming of age. | El hijo, un joven serio y respetable, tenía el futuro asegurado por la fortuna de su madre, que era cuantiosa, y de cuya mitad había entrado en posesión al cumplir su mayoría de edad. | - de que de que de que de que de que de que de que de que de que de que de que de que de que de que de que de que de que de que de qu

  0%|          | 0/77 [00:00<?, ?it/s]

In [None]:
from torchinfo import summary

print(summary(model))

In [None]:
from torch.profiler import profile, record_function, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU], record_shapes=True, ) as prof:
    model(sequence.to(DEVICE), prev_target.to(DEVICE))

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))