In [3]:
# Based on this paper - https://arxiv.org/pdf/1706.03762.pdf
# Might want to move layer norm inside the residual block - https://arxiv.org/pdf/2002.04745.pdf
# Layer normalization - https://arxiv.org/pdf/1607.06450.pdf
#!pip install torch torchtext sentencepiece datasets wandb

In [4]:
import torch
from torch import nn
import sys
import os
import math
sys.path.append(os.path.abspath("../../data"))
sys.path.append(os.path.abspath("../../nnets"))
from net_utils import get_module_list

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
SP_VOCAB_SIZE = 5000
TRAIN_SIZE = 5000

In [5]:
from text_data import Opus100DatasetWrapper

class Wrapper(Opus100DatasetWrapper):
    split_lengths = [TRAIN_SIZE, math.floor(TRAIN_SIZE * .1), 100]
    x_length = 40
    target_length = 40

wrapper = Wrapper(SP_VOCAB_SIZE)
datasets = wrapper.generate_datasets(BATCH_SIZE)
train = datasets["train"]
valid = datasets["validation"]

Found cached dataset opus100 (/Users/vik/.cache/huggingface/datasets/opus100/en-es/0.0.0/256f3196b69901fb0c79810ef468e2c4ed84fbd563719920b1ff1fdc750f7704)
100%|██████████| 3/3 [00:00<00:00, 18.89it/s]
sentencepiece_trainer.cc(177) LOG(INFO) Running command: --input=tokens.txt --model_prefix=opus100 --vocab_size=5000 --model_type=unigram
sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : 
trainer_spec {
  input: tokens.txt
  input_format: 
  model_prefix: opus100
  model_type: UNIGRAM
  vocab_size: 5000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  treat_whitespace_as_suffix: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extre

In [6]:
class MultiHeadAttention(nn.Module):
    def __init__(self, input_units, attention_heads, mask=False):
        super(MultiHeadAttention, self).__init__()
        self.input_units = input_units
        self.attention_heads = attention_heads
        self.head_units = int(input_units/attention_heads)
        self.mask = mask

        self.output_proj = nn.Linear(self.attention_heads * self.head_units, input_units)
        self.query_proj = nn.Linear(input_units, self.attention_heads * self.head_units)
        self.key_proj = nn.Linear(input_units, self.attention_heads * self.head_units)
        self.value_proj = nn.Linear(input_units, self.attention_heads * self.head_units)

    def forward(self, queries, keys, values):
        # convert to 4d tensor with batch_size, attn_heads, seq_len, embedding_dim
        proj_queries = self.query_proj(queries).view(queries.shape[0], queries.shape[1], self.attention_heads, self.head_units)
        proj_queries = proj_queries.swapaxes(1,2)

        proj_keys = self.key_proj(keys).view(keys.shape[0], keys.shape[1], self.attention_heads, self.head_units)
        proj_keys = proj_keys.swapaxes(1,2)

        proj_values = self.value_proj(values).view(values.shape[0], values.shape[1], self.attention_heads, self.head_units)
        proj_values = proj_values.swapaxes(1,2)

        attention = proj_queries @ torch.transpose(proj_keys, -1, -2)
        if self.mask:
            # Prevent decoder queries from looking at tokens that come after
            # Do this by setting attention to negative infinity, so it is softmaxed to zero in the next step
            mask = torch.full((attention.shape[-2], attention.shape[-1]), -torch.inf, device=DEVICE)
            attention += torch.triu(mask, diagonal=1)

        # Softmax on last dimension
        # Sequence-wise softmax, so attention between one sequence and other sequences sums to 1
        attention = torch.softmax(attention, dim=-1)
        weighted_values = attention @ proj_values

        # Swap attention head and sequence axis, then reshape to batch, seq, embedding
        weighted_values = weighted_values.swapaxes(1,2).reshape(queries.shape[0], queries.shape[1], -1)
        weighted_values = self.output_proj(weighted_values)
        return weighted_values

In [7]:
class EncoderBlock(nn.Module):
    def __init__(self, input_units, attention_heads, hidden_units=2048, dropout_p=.1):
        super(EncoderBlock, self).__init__()
        self.input_units = input_units
        self.attention_heads = attention_heads
        self.hidden_units = hidden_units

        self.mha = MultiHeadAttention(self.input_units, self.attention_heads)
        self.dropouts = get_module_list(2, nn.Dropout, dropout_p)
        self.linear1 = nn.Linear(self.input_units, hidden_units)
        self.linear2 = nn.Linear(hidden_units, self.input_units)
        self.relu = nn.ReLU()
        self.lns = get_module_list(2, nn.LayerNorm, self.input_units)

    def forward(self, x):
        weighted_values = self.dropouts[0](self.mha(x, x, x))
        x = self.lns[0](x + weighted_values)

        reprojected = self.dropouts[1](self.linear2(self.relu(self.linear1(x))))
        x = self.lns[1](x + reprojected)
        return x

In [8]:
class DecoderBlock(nn.Module):
    def __init__(self, input_units, attention_heads, hidden_units=2048, dropout_p=.1):
        super(DecoderBlock, self).__init__()
        self.input_units = input_units
        self.attention_heads = attention_heads
        self.hidden_units = hidden_units

        self.in_attn = MultiHeadAttention(self.input_units, self.attention_heads, mask=True)
        self.context_attn = MultiHeadAttention(self.input_units, self.attention_heads)
        self.dropouts = get_module_list(3, nn.Dropout, dropout_p)
        self.linear1 = nn.Linear(self.input_units, hidden_units)
        self.linear2 = nn.Linear(hidden_units, self.input_units)
        self.relu = nn.ReLU()
        self.lns = get_module_list(3, nn.LayerNorm, self.input_units)

    def forward(self, x, context):
        weighted_values = self.dropouts[0](self.in_attn(x, x, x))
        x = self.lns[0](x + weighted_values)

        decoder_values = self.dropouts[1](self.context_attn(x, context, context))
        x = self.lns[1](x + decoder_values)

        reprojected = self.dropouts[2](self.linear2(self.relu(self.linear1(x))))
        x = self.lns[2](x + reprojected)
        return x

In [9]:
class Transformer(nn.Module):
    def __init__(self, input_units, hidden_units, attention_heads, max_len=256, blocks=1):
        super(Transformer, self).__init__()
        self.input_units = input_units
        self.hidden_units = hidden_units
        self.attention_heads = attention_heads
        self.blocks = blocks

        k = math.sqrt(1/self.hidden_units)
        self.output_embedding = nn.Linear(hidden_units, input_units)
        self.embedding = nn.Embedding(input_units, hidden_units)
        self.dropouts = get_module_list(2, nn.Dropout, .1)
        self.encoders = get_module_list(self.blocks, EncoderBlock, hidden_units, attention_heads)
        self.decoders = get_module_list(self.blocks, DecoderBlock, hidden_units, attention_heads)
        self.pos_encoding = self.encoding(max_len, self.hidden_units).to(DEVICE)
        # TODO: init params with xavier uniform distro


    def forward(self, x, y, enc_outputs=None):
        if enc_outputs is None:
            # 3D with batch, seq, embeddings
            # TODO: Tie input and output embedding weights
            enc_outputs = self.dropouts[0](self.embedding(x) + self.pos_encoding[:x.shape[1]])

            for i in range(self.blocks):
                enc_outputs = self.encoders[i](enc_outputs)

        dec_outputs = self.dropouts[1](self.embedding(y) + self.pos_encoding[:y.shape[1]])
        for i in range(self.blocks):
            dec_outputs = self.decoders[i](dec_outputs, enc_outputs)

        token_vectors = self.output_embedding(dec_outputs)
        return token_vectors, enc_outputs

    def encoding(self, seq_len, embed_len):
        encodings = torch.zeros((seq_len, embed_len))
        for i in range(seq_len):
            all = torch.exp(torch.arange(0, embed_len, 2) * (-math.log(10000.0) / embed_len))
            encodings[i, 0::2] = torch.sin(i * all)
            encodings[i, 1::2] = torch.cos(i * all)
        return encodings

In [10]:
def generate(sequence, pred, target, wrapper):
    prompts = wrapper.decode_batch(sequence.cpu())
    texts = wrapper.decode_batch(torch.argmax(pred, dim=2).cpu())
    correct_texts = wrapper.decode_batch(target.cpu())

    displays = []
    for p, t, ct in zip(prompts, texts, correct_texts):
        displays.append(f"{p} | {ct} | {t}")
    return displays

def trim_padding(batch, pad_token, other_seq=None):
    least_padding = (batch == pad_token).sum(axis=1).min()
    if other_seq is not None:
        least_padding = min(least_padding, (other_seq == pad_token).sum(axis=1).min())
    if least_padding == 0:
        return batch
    return batch[:,:-least_padding]

In [15]:
from tqdm.auto import tqdm
import wandb

wandb.init(project="transformer", notes="fix module lists two blocks", name="fixed-module-lists-two")

# TODO: Profile and improve perf - https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
model = Transformer(wrapper.vocab_size, 512, 8, blocks=2).to(DEVICE)
loss_fn = nn.CrossEntropyLoss(ignore_index=wrapper.pad_token)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
wandb.watch(model, log_freq=100)

0,1
loss,█▆▄▃▂▁
valid_loss,▁

0,1
loss,0.08905
valid_loss,0.21476


[]

In [16]:
EPOCHS = 25
DISPLAY_BATCHES = 2
OUT_SEQUENCE_LEN = wrapper.y_length
PRINT_VALID = True

for epoch in range(EPOCHS):
    # Run over the training examples
    train_loss = 0
    match_pct = 0
    for batch, (sequence, target, prev_target) in tqdm(enumerate(train)):
        optimizer.zero_grad(set_to_none=True)
        pred, _ = model(trim_padding(sequence, wrapper.pad_token).to(DEVICE), trim_padding(prev_target, wrapper.pad_token, other_seq=target).to(DEVICE))

        # If you use a batch, need to reshape pred to be batch * sequence, embedding_len to be compatible
        # Similar reshape with target to be batch * sequence vector of class indices
        loss = loss_fn(pred.reshape(-1, pred.shape[-1]), trim_padding(target, wrapper.pad_token, other_seq=prev_target).reshape(-1).to(DEVICE))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    with torch.no_grad():
        mean_loss = train_loss / len(train) / BATCH_SIZE
        wandb.log({"loss": mean_loss})
        print(f"Epoch {epoch} train loss: {mean_loss}")
        sents = generate(sequence, pred, target, wrapper)
        for sent in sents[:DISPLAY_BATCHES]:
            print(sent)

        if PRINT_VALID and epoch % 10 ==0:
            # Compute validation loss.  Unless you have a lot of training data, the validation loss won't decrease.
            valid_loss = 0
            # Deactivate dropout layers
            model.eval()
            for batch, (sequence, target, prev_target) in tqdm(enumerate(valid)):
                # Inference token by tokens
                sequence = sequence.to(DEVICE)
                outputs = prev_target[:,0].unsqueeze(1).to(DEVICE)
                enc_outputs = None
                # TODO: Investigate memory leak with valid generation
                for i in range(OUT_SEQUENCE_LEN):
                    pred, enc_outputs = model(sequence, outputs, enc_outputs=enc_outputs)
                    last_output = torch.argmax(pred, dim=2)
                    outputs = torch.cat((outputs, last_output[:,-1:]), dim=1)
                loss = loss_fn(pred.reshape(-1, pred.shape[-1]), target.view(-1).to(DEVICE))
                valid_loss += loss.item()
            mean_loss = valid_loss / len(valid) / BATCH_SIZE
            wandb.log({"valid_loss": mean_loss})
            print(f"Valid loss: {mean_loss}")
            sents = generate(sequence, pred, target, wrapper)
            for sent in sents[:DISPLAY_BATCHES]:
                print(sent)
            # Reactivate dropout
            model.train()

127it [00:21,  5.94it/s]


Epoch 0 train loss: 0.16896558793510977
Por este motivo nuestros nano-recubrimientos finales son 100% impermeables. | This is why our final nano-coatings are 100% waterproof. | A is the the thes asssss.sssss
No, no es posible. | No, it's not possible. | I, no.t,.,


15it [00:12,  1.23it/s]


Valid loss: 0.2157400498787562
Lo juro. | I swear. | I'm a s.
Una trucha arco iris. | What is it? Rainbow trout. | I'm a nice.


127it [00:21,  6.02it/s]


Epoch 1 train loss: 0.1417849435229001
¿Lo ves, estúpida? | See, you fucking cow? | What you you's you??
Pero informar sobre genómica implica más que tan solo escribir sobre ciencia y medicina. | But reporting on genomics involves more than just writing about science and medicine. | The that is the thes ofp..., and than that have. theeach.ing.. non.


127it [00:21,  6.05it/s]


Epoch 2 train loss: 0.12492617747680408
Yo no quiero jugar. | I don't wanna play the game. | I''t wantna.. one.
Negativo por negativo es igual a positivo. | A negative times a negative is a positive. | Ida,.. s  little.


127it [00:21,  6.01it/s]


Epoch 3 train loss: 0.1110113746537937
Corte. | Cut. | I..
- Ya lo veo claro. Ud. está muerta. | It's a cinch. | -'s thelwayslock'.


127it [00:20,  6.11it/s]


Epoch 4 train loss: 0.09955680264732031
Detener el rompimiento. | Stop the breakup. | This it d of
- Pellízcame, pellízcame. | Pinch me. Pinch me. | -y,


127it [00:20,  6.25it/s]


Epoch 5 train loss: 0.08915272818540963
Por lo que se dice en el instituto... Chris Valley y Chris Hoffman... son uña y carne. | Word around school is Chris Valley and Chris Hoffman are attached at the hip. | Theas to theinging part Valley the peace Valleyonpthth at at letha to the t.
Me tiraron del pelo. | They, like, were pulling my hair out, you know? | You' no him you waitingl.. s. you..


127it [00:21,  5.85it/s]


Epoch 6 train loss: 0.08002431148032504
¿Quién se benefició? | Who benefited? | What???
Mantenlos cerrados. | Keep 'em closed. | Iy mights r.


87it [00:14,  5.94it/s]


KeyboardInterrupt: 

In [None]:
from torchinfo import summary

print(summary(model))

In [None]:
from torch.profiler import profile, record_function, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU], record_shapes=True, ) as prof:
    model(sequence.to(DEVICE), prev_target.to(DEVICE))

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))