In [1]:
# Based on this paper - https://arxiv.org/pdf/1706.03762.pdf
# Might want to move layer norm inside the residual block - https://arxiv.org/pdf/2002.04745.pdf
# Layer normalization - https://arxiv.org/pdf/1607.06450.pdf
# TODO: Investigate learning rate warmup - https://arxiv.org/abs/2002.04745
#!pip install torch torchtext sentencepiece datasets wandb

In [2]:
import numpy as np
import torch
from torch import nn
import sys
import os
import math
import einops
import torch.nn.functional as F

sys.path.append(os.path.abspath("../../data"))
sys.path.append(os.path.abspath("../../nnets"))
from net_utils import get_module_list

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
SP_VOCAB_SIZE = 5000
TRAIN_SIZE = 5000

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from text_data import CNNDatasetDecoderOnly

class Wrapper(CNNDatasetDecoderOnly):
    split_lengths = [TRAIN_SIZE, math.floor(TRAIN_SIZE * .1), 100]
    x_length = 15
    target_length = 15

wrapper = Wrapper(SP_VOCAB_SIZE)
datasets = wrapper.generate_datasets(BATCH_SIZE)
train = datasets["train"]
valid = datasets["validation"]

Found cached dataset cnn_dailymail (/Users/vik/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/1b3c71476f6d152c31c1730e83ccb08bcf23e348233f4fcc11e182248e6bf7de)
100%|██████████| 3/3 [00:00<00:00, 47.49it/s]
sentencepiece_trainer.cc(177) LOG(INFO) Running command: --input=tokens.txt --model_prefix=cnn_dailymail --vocab_size=5000 --model_type=unigram
sentencepiece_trainer.cc(77) LOG(INFO) Starts training with : 
trainer_spec {
  input: tokens.txt
  input_format: 
  model_prefix: cnn_dailymail
  model_type: UNIGRAM
  vocab_size: 5000
  self_test_sample_size: 0
  character_coverage: 0.9995
  input_sentence_size: 0
  shuffle_input_sentence: 1
  seed_sentencepiece_size: 1000000
  shrinking_factor: 0.75
  max_sentence_length: 4192
  num_threads: 16
  num_sub_iterations: 2
  max_sentencepiece_length: 16
  split_by_unicode_script: 1
  split_by_number: 1
  split_by_whitespace: 1
  split_digits: 0
  treat_whitespace_as_suffix: 0
  required_chars: 
  byte_fallback: 0
  vocabulary_output_piec

In [4]:
# Add in ROPE embedding
class ROPE(nn.Module):
    def __init__(self, embedding_dim, seq_len):
        super(ROPE, self).__init__()
        self.embedding_dim = embedding_dim
        self.seq_len = seq_len

        self.cos_embeds = torch.zeros(seq_len, embedding_dim, device=DEVICE)
        self.sin_embeds = torch.zeros(seq_len, embedding_dim, device=DEVICE)

        embed_pos = 10000 ** (-2 * torch.ceil((torch.arange(0, embedding_dim) + 1) / 2) / embedding_dim)
        for i in range(0, seq_len):
            self.cos_embeds[i] = torch.cos((i + 1) * embed_pos)
            self.sin_embeds[i] = torch.sin((i + 1) * embed_pos)

        self.indices = torch.zeros(self.embedding_dim, device=DEVICE, dtype=torch.long)
        self.mask = torch.zeros(self.embedding_dim, device=DEVICE, dtype=torch.int)
        for i in range(0, embedding_dim, 2):
            self.indices[i] = i + 1
            self.indices[i+1] = i

            self.mask[i] = -1
            self.mask[i+1] = 1


    def rotate(self, x):
        return x[...,self.indices] * self.mask

    def forward(self, x):
        current_val = x * self.cos_embeds[:x.shape[-2],:]
        next_val = self.rotate(x) * self.sin_embeds[:x.shape[-2],:]
        return current_val + next_val


class MultiHeadAttention(nn.Module):
    def __init__(self, input_units, attention_heads, mask=False):
        super(MultiHeadAttention, self).__init__()
        self.input_units = input_units
        self.attention_heads = attention_heads
        self.head_units = int(input_units/attention_heads)
        self.mask = mask

        k = math.sqrt(1/self.input_units)
        # Drop bias
        # Single kv head
        self.kv_proj_weight = nn.Parameter(torch.rand(2, input_units, self.head_units) * 2 * k - k)
        self.q_proj_weight = nn.Parameter(torch.rand(input_units, self.attention_heads * self.head_units) * 2 * k - k)
        self.out_proj_weight = nn.Parameter(torch.rand(self.attention_heads * self.head_units, input_units) * 2 * k - k)

        # 1024 is max sequence length
        self.rope = ROPE(self.head_units, 1024)

    def forward(self, queries, keys, values):
        # convert to 4d tensor with batch_size, attn_heads, seq_len, embedding_dim
        proj_queries = torch.einsum("...se, eo->...so", queries, self.q_proj_weight)
        proj_queries = proj_queries.view(queries.shape[0], queries.shape[1], self.attention_heads, self.head_units).swapaxes(1,2)
        proj_queries = self.rope(proj_queries)

        proj_keys = torch.einsum("...se, eo->...so", keys, self.kv_proj_weight[0])
        proj_keys = proj_keys.view(keys.shape[0], keys.shape[1], self.head_units)
        proj_keys = self.rope(proj_keys)

        proj_values = torch.einsum("...se, eo->...so", values, self.kv_proj_weight[0])
        proj_values = proj_values.view(values.shape[0], values.shape[1], self.head_units)

        attention = torch.einsum("baqh, bhk->baqk", proj_queries, torch.transpose(proj_keys, -1, -2)) / np.sqrt(proj_keys.shape[-1])
        if self.mask:
            # Prevent decoder queries from looking at tokens that come after
            # Do this by setting attention to negative infinity, so it is softmaxed to zero in the next step
            mask = torch.full((attention.shape[-2], attention.shape[-1]), -torch.inf, device=DEVICE)
            attention += torch.triu(mask, diagonal=1)

        # Softmax on last dimension
        # Sequence-wise softmax, so attention between one sequence and other sequences sums to 1
        attention = torch.softmax(attention, dim=-1)
        weighted_values = torch.einsum("baqk, bke->baqe", attention, proj_values)

        # Swap attention head and sequence axis, then reshape to batch, seq, embedding
        weighted_values = weighted_values.swapaxes(1,2).reshape(queries.shape[0], queries.shape[1], -1)
        weighted_values = torch.einsum("...se, eo->...so", weighted_values, self.out_proj_weight)
        return weighted_values

In [5]:
class SwiGLU(nn.Module):
    def __init__(self, input_units, hidden_units):
        super(SwiGLU, self).__init__()
        self.linear1 = nn.Linear(input_units, hidden_units, bias=False)
        self.linear2 = nn.Linear(input_units, hidden_units, bias=False)
        self.linear3 = nn.Linear(hidden_units, input_units, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x1 = self.linear1(x)
        swish = self.sigmoid(x1) * x1
        x2 = self.linear2(x)
        swiglu = self.linear3(swish * x2)
        return swiglu

class DecoderBlock(nn.Module):
    def __init__(self, input_units, attention_heads, hidden_units=2048, dropout_p=.1):
        super(DecoderBlock, self).__init__()
        self.in_attn = MultiHeadAttention(input_units, attention_heads, mask=True)
        self.dropouts = get_module_list(2, nn.Dropout, dropout_p)
        # Drop bias
        self.lns = get_module_list(2, nn.LayerNorm, input_units)
        # Switch to swiglu from two linear layers
        self.swiglu = SwiGLU(input_units, hidden_units)

    def forward(self, x):
        weighted_values = self.dropouts[0](self.in_attn(x, x, x))
        # Pre normalization
        x = x + self.lns[0](weighted_values)

        reprojected = self.dropouts[1](self.swiglu(x))
        # Pre normalization
        x = x + self.lns[1](reprojected)
        return x

In [14]:
class Transformer(nn.Module):
    def __init__(self, input_units, hidden_units, attention_heads, max_len=256, blocks=1):
        super(Transformer, self).__init__()
        self.blocks = blocks
        self.dropouts = get_module_list(2, nn.Dropout, .1)
        self.decoders = get_module_list(blocks, DecoderBlock, hidden_units, attention_heads)

        self.embedding = nn.Parameter(torch.empty(input_units, hidden_units))
        nn.init.xavier_uniform_(self.embedding)
        self.input_units = input_units

    # Tie input output weights
    def embed(self, x, reverse=False):
        if reverse:
            return x @ self.embedding.T
        else:
            embedded = self.embedding[x.to(torch.long).view(-1)]
            return embedded.view(x.shape[0], x.shape[1], -1)


    def forward(self, x):
        dec_outputs = self.dropouts[1](self.embed(x))
        for i in range(self.blocks):
            dec_outputs = self.decoders[i](dec_outputs)

        token_vectors = self.embed(dec_outputs, reverse=True)
        return token_vectors

In [15]:
def generate(sequence, pred, target, wrapper):
    prompts = wrapper.decode_batch(sequence[:,:wrapper.x_length].cpu())
    texts = wrapper.decode_batch(torch.argmax(pred[:,wrapper.x_length:], dim=2).cpu())
    correct_texts = wrapper.decode_batch(target[:,wrapper.x_length:].cpu())

    displays = []
    for p, t, ct in zip(prompts, texts, correct_texts):
        displays.append(f"{p} | {ct} | {t}")
    return displays

In [16]:
from tqdm.auto import tqdm
import wandb
import time

wandb.init(project="decoder-only", notes="Custom embed different init", name="new-init")

# TODO: Profile and improve perf - https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
model = Transformer(wrapper.vocab_size, 512, 8, blocks=6).to(DEVICE)
loss_fn = nn.CrossEntropyLoss(ignore_index=wrapper.pad_token)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
wandb.watch(model, log_freq=100)

0,1
epoch_time,▁▃█
loss,█▂▁
valid_loss,▁

0,1
epoch_time,43.27716
loss,1.09793
valid_loss,1.69331


[]

In [None]:
EPOCHS = 100
DISPLAY_BATCHES = 2
OUT_SEQUENCE_LEN = wrapper.y_length
PRINT_VALID = True
ACCUMULATE_STEPS = 1

for epoch in range(EPOCHS):
    # Run over the training examples
    train_loss = 0
    match_pct = 0
    optimizer.zero_grad(set_to_none=True)
    start = time.time()
    for batch, (sequence, target, prev_target) in tqdm(enumerate(train)):
        pred = model(sequence.to(DEVICE))

        # If you use a batch, need to reshape pred to be batch * sequence, embedding_len to be compatible
        # Similar reshape with target to be batch * sequence vector of class indices
        cpred = pred[:, wrapper.x_length:]
        ctarget = target[:, wrapper.x_length:]
        loss = loss_fn(cpred.reshape(-1, cpred.shape[-1]), ctarget.reshape(-1).to(DEVICE))
        loss.backward()
        train_loss += loss.item()

        # Accumulate gradients
        # This seems to perform worse than no accumulation over a
        # small data set.  Test with larger set.
        if batch % ACCUMULATE_STEPS == 0:
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
    end = time.time()

    with torch.no_grad():
        mean_loss = train_loss / len(train) / BATCH_SIZE
        wandb.log({"loss": mean_loss, "epoch_time": end - start})
        print(f"Epoch {epoch} train loss: {mean_loss}")
        sents = generate(sequence, pred, target, wrapper)
        for sent in sents[:DISPLAY_BATCHES]:
            print(sent)

        if PRINT_VALID and epoch % 10 ==0:
            # Compute validation loss.  Unless you have a lot of training data, the validation loss won't decrease.
            valid_loss = 0
            # Deactivate dropout layers
            model.eval()
            for batch, (sequence, target, prev_target) in tqdm(enumerate(valid)):
                # Inference token by tokens
                outputs = sequence[:,:(wrapper.x_length + 1)].to(DEVICE)
                # TODO: Investigate memory leak with valid generation
                for i in range(OUT_SEQUENCE_LEN):
                    pred = model(outputs)
                    last_output = torch.argmax(pred, dim=2)
                    outputs = torch.cat((outputs, last_output[:,-1:]), dim=1)

                cpred = pred[:, wrapper.x_length:]
                ctarget = target[:, wrapper.x_length:]
                loss = loss_fn(cpred.reshape(-1, cpred.shape[-1]), ctarget.reshape(-1).to(DEVICE))
                valid_loss += loss.item()
            mean_loss = valid_loss / len(valid) / BATCH_SIZE
            wandb.log({"valid_loss": mean_loss})
            print(f"Valid loss: {mean_loss}")
            sents = generate(sequence, pred, target, wrapper)
            for sent in sents[:DISPLAY_BATCHES]:
                print(sent)
            # Reactivate dropout
            model.train()

145it [00:41,  3.51it/s]


Epoch 0 train loss: 0.21963264644145966
President Bush calls India's PM to push a proposed nuclear partner | ship . Indian government won confidence vote in face of anger over U. | . .sssss .s . .sss to
17-year-old shot by unknown assailant in A | thens suburb of Peristeri . Police said no officers | . . . . . . . . .s-s the- the


2it [00:01,  1.28it/s]


Valid loss: 0.20257370918989182
Timothy Stanley: GOP senators' letter to Iranian leaders seem | s extraordinary . But undermining a president's foreign policy | .sssssss of of of the the the the
Pope has talked of retirement before, but this time he says he | think papacy will end after no more than five years . Francis | .'''ssssssssss the


145it [00:40,  3.60it/s]


Epoch 1 train loss: 0.1904661668785687
Amnesty International: Taliban first targets unpopular landlords, bureaucrat | s . Taliban spokesman in Swat Valley calls Pakistani government as "un-I | . . NEW NEW NEW NEW than than .dd thanI1-
Comedian Chris Rock to release "Kill the Messenger" DVD January | 20 . There are no Barack Obama jokes, Rock says, | . . Ones "ssssssssss


145it [00:41,  3.48it/s]


Epoch 2 train loss: 0.16845574697543836
Paula Abdul told Ladies' Home Journal she had | painkiller addiction . She went to spa to wean  | tedsted . Hetent beed beeded
Three Austin cave explorers are safe and are out of the cave, officials | said . The University of Texas students went into Airman' | . . . They thed thesss thes the thes


145it [00:40,  3.60it/s]


Epoch 3 train loss: 0.14628591640242214
Rate your favorite U.S. city by taking the America' | s Favorite Cities survey . The survey ends on July | of .g ofs . C . . The 15ss of 15
Lifeway stores put Christian magazine behind counter . Magazine featured female | pastors on its cover . Lifeway has no respect for freedom of | ss . London . . Thes: saysss thes


145it [00:40,  3.59it/s]


Epoch 4 train loss: 0.12625501659409752
Fatty particles called triglycerides are | important for heart health . One in every three Americans has high trigly | important by the event . The in Iran 2007 suspect; inleeds
Guinea President Lansana Conte dies; 40 days of mourning | declared . Army captain says government institutions dissolve | sd . Theaptainaptains dis in diss ofc of of


145it [00:40,  3.60it/s]


Epoch 5 train loss: 0.1085823541057521
CNN's Sean Callebs talks about his month living | on a food-stamp budget . He lost some weight but learn | to a to .P .d . . Thees the wife to arm
NEW: Peterson to media on handcuffs, chains: " | I got the bling. Can't complain" Drew Peterson | re thinkt spokes Cant PetersonSts,, Peterson Peterson Peterson


145it [00:40,  3.54it/s]


Epoch 6 train loss: 0.09393615660996273
Israeli offensive caused $1.9 billion in economic destruction in Gaza | , official says . It could take a year for Gaza's economy to | . survey says . U has take for foreign has U iss for has
Woman organizes dinners at restaurants for people with food allergie | s . If you have a food allergy, call ahead and tell | s . A rev had have rev have want haves haves and have


145it [00:41,  3.50it/s]


Epoch 7 train loss: 0.08209123503545235
Teacher Tracey Wygal was morbidly ob | ese when she weighed 295 pounds . A doctor presc | whe where whe weigh weigh weigh weigh pounds . A weigh poundss weigh
"Several heavy metals" found in levels abo | ve safe drinking-water standards . TVA pledges cleanup; | health Dec Laden standards standards . Form Laden pledges pledge .;


145it [00:40,  3.58it/s]


Epoch 8 train loss: 0.07151736380725071
U.S. Marshals honor federal judge whose husband and | mother were killed . Family killed by man angry at judge's decision to | were killed killed . One killed man man angry fro css decision,
A large chunk of the Wilkins ice shelf | in Antarctica broke away last month . Only a narrow strip of ice | Antarctica Antarctica Antarcticaly and month . On On narrow narrow and of narrowice


145it [00:40,  3.58it/s]


Epoch 9 train loss: 0.06157481385202243
California attorney general's comments are prejudicial, lawyer says | . Brown saying too much about Anna Nicole Smith case, K | . Brownle it too Smithst made protect Nicole Smith is is S
Partygoer used make-up to darken skin, went as | escaped prisoner . Acting Immigration chief judged costume contest | escape prisoner prisoner prisoner prisonerand cost prisoner judge judge judge costdum contest


145it [00:41,  3.48it/s]


Epoch 10 train loss: 0.05273360897754801
D-Day soldiers remember the horrors of war and fallen comrad | es . One tells how he survived despite being wound | sho . One tells how he survived despite being wound
D.C. schools chief Michelle Rhee closed 23 schools, fire | d 36 principals in first year . "We are always go | d in,rincipals . year . . We are always a


2it [00:01,  1.18it/s]


Valid loss: 0.33604849874973297
Police raided Robert Durst's Houston condo, his lawyer | says . The millionaire real estate heir was arrested over the | ekoe . Ow wasn't in his four-
Actor Ashton Kutcher complained on Facebook that men' | s rooms don't have diapering tables . | this week . Bas him sope him to considering fro


145it [00:40,  3.58it/s]


Epoch 11 train loss: 0.045656734961887886
More than 30 people have died, 50 have been injured since Friday, | reports say . Indigenous people protest government plan to sell land to | ens say . Indiging G Indig protests to to report  to
Amnesty International: Taliban first targets unpopular landlords, bureaucrat | s . Taliban spokesman in Swat Valley calls Pakistani government as "un-I | s . Taliban spokesman in Swat Valley calls as government as "I-I


145it [00:40,  3.56it/s]


Epoch 12 train loss: 0.03944810434918979
"No Country for Old Men" wins four awards, including | best picture and director . Four acting awards go to Europeans: | women picture picture director . Mc Europeann awards were to Europeans go
Police say Bruce Jeffrey Pardo had hit list after divorce proceedings were | final . Original target was Pardo's ex-wif | final . Or Orn Or Ord 'sierifwal


145it [00:41,  3.54it/s]


Epoch 13 train loss: 0.033981164570512444
Shahid Afridi claims six victims to pave the way | for Pakistan to beat Australia . Pakistan reach required target to win first one | Pakistan beat . beat Australia . Pakistan- required target to win one
Lord Taylor of Blackburn and Lord Truscott barred for corruption charges | . They allegedly agreed to take cash to influence specific legislation | . They agree agree agreed to influence to to meet specific toing


145it [00:42,  3.42it/s]


Epoch 14 train loss: 0.02959378093224147
NEW: Preacher says he found funeral for mom, four kids " | difficult" 300 mourners attended service at Redlands Community | difficult" 300 mournes attended service at Redlands Community
Elkhart, Indiana, entrepreneurs sweating out tough economic times . | Several express doubt the federal government can get commerce moving | Several express doubt in federal fed can com commerce coming


86it [00:24,  3.44it/s]

In [None]:

from torchinfo import summary

print(summary(model))

In [None]:
from torch.profiler import profile, record_function, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU], record_shapes=True, ) as prof:
    model(sequence.to(DEVICE), prev_target.to(DEVICE))

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

In [None]:
from einops import rearrange

x = torch.rand(4, 3)
x[torch.tensor([0,1,0])]