In [2]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import sys
import os
import math
sys.path.append(os.path.abspath("../../data"))

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Batch size handed to wrapper.generate_datasets() when the datasets are built.
BATCH_SIZE = 32
# Vocabulary size passed to the dataset wrapper's constructor.
SP_VOCAB_SIZE = 5000
# Base size for the dataset splits: Wrapper.split_lengths uses it as its first
# entry and 10% of it as its second.
TRAIN_SIZE = 5000

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from text_data import Opus100DatasetWrapper

class Wrapper(Opus100DatasetWrapper):
    """Opus100DatasetWrapper configured with this notebook's split sizes and sequence lengths."""
    # Split sizes: TRAIN_SIZE, 10% of TRAIN_SIZE (rounded down), and 100.
    # NOTE(review): presumably train / validation / test order — confirm against the base class.
    split_lengths = [TRAIN_SIZE, math.floor(TRAIN_SIZE * .1), 100]
    # NOTE(review): presumably the token lengths of source and target sequences — confirm.
    # The training cell reads `wrapper.y_length`, not `target_length`; verify which
    # attribute the base class actually uses.
    x_length = 40
    target_length = 40

# Build the wrapper with the configured vocabulary size, then create the batched datasets.
wrapper = Wrapper(SP_VOCAB_SIZE)
datasets = wrapper.generate_datasets(BATCH_SIZE)
# Keep handles to the two splits that the training cell iterates over.
train = datasets["train"]
valid = datasets["validation"]

Found cached dataset opus100 (/Users/vik/.cache/huggingface/datasets/opus100/en-es/0.0.0/256f3196b69901fb0c79810ef468e2c4ed84fbd563719920b1ff1fdc750f7704)
100%|██████████| 3/3 [00:00<00:00, 223.38it/s]
sentencepiece_trainer.cc(177) LOG(INFO) Running command: --input=tokens.txt --model_prefix=opus100 --vocab_size=5000 --model_type=unigram
sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : 
trainer_spec {
  input: tokens.txt
  input_format: 
  model_prefix: opus100
  model_type: UNIGRAM
  vocab_size: 5000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  treat_whitespace_as_suffix: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piece_score: 1
  train_extr

In [4]:
def generate(sequence, pred, target, wrapper):
    """Format one ``"prompt | reference | prediction"`` string per batch row.

    Args:
        sequence: Tensor of source token ids, decoded as the prompt.
        pred: Model scores; the predicted token ids are the argmax along dim 2
            (presumably shaped (batch, sequence, vocab) — confirm with caller).
        target: Tensor of reference token ids, decoded as the correct text.
        wrapper: Object exposing ``decode_batch(tensor)`` that returns one
            decoded text per row.

    Returns:
        A list of display strings, one per row, in batch order.
    """
    # Every tensor is moved to the CPU before it is handed to the decoder.
    source_texts = wrapper.decode_batch(sequence.cpu())
    predicted_ids = pred.argmax(dim=2).cpu()
    predicted_texts = wrapper.decode_batch(predicted_ids)
    reference_texts = wrapper.decode_batch(target.cpu())

    return [
        f"{source} | {reference} | {predicted}"
        for source, predicted, reference in zip(source_texts, predicted_texts, reference_texts)
    ]

In [47]:
class TransformerNetwork(nn.Module):
    """Encoder-decoder transformer mapping token ids to per-position vocabulary scores.

    Source and target tokens share one embedding table; the decoder output is
    projected back to vocabulary size by a linear layer.

    Args:
        input_units: Vocabulary size (embedding rows and output scores per position).
        hidden_units: Model width used by the embedding and ``nn.Transformer``.
        attention_heads: Number of attention heads passed to ``nn.Transformer``.
        blocks: Number of encoder layers and of decoder layers.
    """

    def __init__(self, input_units, hidden_units, attention_heads, blocks=1):
        super().__init__()
        self.transformer = nn.Transformer(
            hidden_units,
            attention_heads,
            num_encoder_layers=blocks,
            num_decoder_layers=blocks,
            batch_first=True,
        )
        self.embedding = nn.Embedding(input_units, hidden_units)
        self.output_embedding = nn.Linear(hidden_units, input_units)

    def forward(self, x, y):
        """Score every target position given the source and the preceding target tokens.

        Args:
            x: Source token ids, shape (batch, source_len).
            y: Decoder input token ids, shape (batch, target_len).

        Returns:
            Tensor of shape (batch, target_len, input_units) with unnormalised scores.
        """
        embed_x = self.embedding(x)
        embed_y = self.embedding(y)

        # Causal mask: position i may only attend to positions <= i. Without it the
        # decoder can read the very tokens it is asked to predict during teacher
        # forcing, so the training loss is not meaningful for generation.
        target_len = y.shape[1]
        causal_mask = torch.triu(
            torch.full(
                (target_len, target_len),
                float("-inf"),
                device=embed_y.device,
                dtype=embed_y.dtype,
            ),
            diagonal=1,
        )

        # NOTE(review): no positional encoding is added to the embeddings and no
        # padding masks are passed; consider adding both.
        output = self.transformer(embed_x, embed_y, tgt_mask=causal_mask)
        token_vectors = self.output_embedding(output)
        return token_vectors

In [48]:
from tqdm.auto import tqdm
import wandb

# NOTE(review): wandb.init is commented out, yet wandb.watch() below and wandb.log()
# in the training cell are still called. Presumably a run was started in an earlier
# execution of this kernel — confirm this cell works after Restart & Run All.
#wandb.init(project="transformer", notes="Pytorch 2 blocks", name="pt-2-blocks")

# TODO: Profile and improve perf - https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
# Model: vocabulary-sized embedding, width 512, 8 attention heads, one encoder and one decoder block.
model = TransformerNetwork(wrapper.vocab_size, 512, 8, blocks=1).to(DEVICE)
# Positions whose target equals wrapper.pad_token are excluded from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=wrapper.pad_token)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Register the model with wandb (log_freq=100).
wandb.watch(model, log_freq=100)

[]

In [34]:
# Training configuration for this cell.
EPOCHS = 25
DISPLAY_BATCHES = 2  # number of decoded examples printed after each report
OUT_SEQUENCE_LEN = wrapper.y_length  # number of greedy decoding steps during validation
PRINT_VALID = True

for epoch in range(EPOCHS):
    # Run over the training examples
    train_loss = 0
    # Wrap the loader itself (not enumerate(...)) so tqdm can read its length and show a total.
    for sequence, target, prev_target in tqdm(train):
        optimizer.zero_grad(set_to_none=True)
        pred = model(sequence.to(DEVICE), prev_target.to(DEVICE))

        # The loss expects (N, classes) scores and (N,) class indices, so flatten the
        # batch and sequence dimensions of both pred and target.
        loss = loss_fn(pred.reshape(-1, pred.shape[-1]), target.view(-1).to(DEVICE))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    with torch.no_grad():
        # loss_fn uses the default 'mean' reduction, so every batch loss is already an
        # average; only average over the number of batches (no extra / BATCH_SIZE).
        mean_loss = train_loss / len(train)
        wandb.log({"loss": mean_loss})
        print(f"Epoch {epoch} train loss: {mean_loss}")
        # Show a few decoded examples from the last training batch.
        sents = generate(sequence, pred, target, wrapper)
        for sent in sents[:DISPLAY_BATCHES]:
            print(sent)

        if PRINT_VALID and epoch % 10 == 0:
            # Compute validation loss.  Unless you have a lot of training data, the validation loss won't decrease.
            valid_loss = 0
            # Deactivate dropout layers
            model.eval()
            try:
                for sequence, target, prev_target in tqdm(valid):
                    # Inference token by token: start from the first decoder input token
                    # and repeatedly append the greedy prediction for the last position.
                    sequence = sequence.to(DEVICE)
                    outputs = prev_target[:, 0].unsqueeze(1).to(DEVICE)
                    # TODO: Investigate memory leak with valid generation
                    for step in range(OUT_SEQUENCE_LEN):
                        pred = model(sequence, outputs)
                        last_output = torch.argmax(pred, dim=2)
                        outputs = torch.cat((outputs, last_output[:, -1:]), dim=1)
                    # pred from the final step holds scores for every generated position.
                    loss = loss_fn(pred.reshape(-1, pred.shape[-1]), target.view(-1).to(DEVICE))
                    valid_loss += loss.item()
            finally:
                # Reactivate dropout even if validation is interrupted (e.g. Ctrl-C),
                # so a later training run does not silently stay in eval mode.
                model.train()
            mean_loss = valid_loss / len(valid)
            wandb.log({"valid_loss": mean_loss})
            print(f"Valid loss: {mean_loss}")
            sents = generate(sequence, pred, target, wrapper)
            for sent in sents[:DISPLAY_BATCHES]:
                print(sent)

10it [00:03,  3.25it/s]


KeyboardInterrupt: 

In [42]:
from torch.profiler import profile, record_function, ProfilerActivity

# Profile a single forward pass, collecting CPU activity and operator input shapes.
# NOTE(review): `sequence` and `prev_target` are whatever the training cell's loop
# last assigned, so this cell only works after that cell has run.
with profile(activities=[ProfilerActivity.CPU], record_shapes=True, ) as prof:
    model(sequence.to(DEVICE), prev_target.to(DEVICE))

# Print the ten rows with the highest total CPU time.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

STAGE:2023-01-23 07:59:04 7619:35806559 ActivityProfilerController.cpp:294] Completed Stage: Warm Up
STAGE:2023-01-23 07:59:04 7619:35806559 ActivityProfilerController.cpp:300] Completed Stage: Collection


---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
---------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
              aten::dropout         0.22%     229.000us        46.05%      46.972ms       4.697ms            10  
           aten::bernoulli_        39.33%      40.113ms        39.33%      40.113ms       4.011ms            10  
               aten::linear         0.27%     277.000us        21.67%      22.100ms       2.009ms            11  
                  aten::bmm        19.11%      19.491ms        19.89%      20.290ms       2.899ms             7  
                aten::addmm        12.03%      12.271ms        13.80%      14.080ms       2.011ms             7  
               aten::matmul         0.06%      62.000us         7.00%       7.144ms     

In [49]:
from torchinfo import summary

# Print torchinfo's per-layer parameter summary for the model.
print(summary(model))

Layer (type:depth-idx)                                                 Param #
TransformerNetwork                                                     --
├─Transformer: 1-1                                                     --
│    └─TransformerEncoder: 2-1                                         --
│    │    └─ModuleList: 3-1                                            3,152,384
│    │    └─LayerNorm: 3-2                                             1,024
│    └─TransformerDecoder: 2-2                                         --
│    │    └─ModuleList: 3-3                                            4,204,032
│    │    └─LayerNorm: 3-4                                             1,024
├─Embedding: 1-2                                                       2,560,512
├─Linear: 1-3                                                          2,565,513
Total params: 12,484,489
Trainable params: 12,484,489
Non-trainable params: 0
