In [None]:
# Based on this paper - https://arxiv.org/pdf/1706.03762.pdf
# Might want to move layer norm inside the residual block - https://arxiv.org/pdf/2002.04745.pdf
# Layer normalization - https://arxiv.org/pdf/1607.06450.pdf
#!pip install torch torchtext sentencepiece datasets

In [None]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import sys
import os
import math
sys.path.append(os.path.abspath("../../data"))

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
SP_VOCAB_SIZE = 5000
TRAIN_SIZE = 5000

In [None]:
from text_data import Opus100DatasetWrapper

class Wrapper(Opus100DatasetWrapper):
    split_lengths = [TRAIN_SIZE, math.floor(TRAIN_SIZE * .1), 100]
    x_length = 40
    target_length = 40

wrapper = Wrapper(SP_VOCAB_SIZE)
datasets = wrapper.generate_datasets(BATCH_SIZE)
train = datasets["train"]
valid = datasets["validation"]

In [296]:
class MultiHeadAttention(nn.Module):
    def __init__(self, input_units, attention_heads, mask=False):
        super(MultiHeadAttention, self).__init__()
        self.input_units = input_units
        self.attention_heads = attention_heads
        self.head_units = int(input_units/attention_heads)
        self.mask = mask

        k = math.sqrt(1/self.head_units)
        self.query_weights = nn.Parameter(torch.rand(self.attention_heads, input_units, self.head_units) * 2 * k - k)
        self.key_weights = nn.Parameter(torch.rand(self.attention_heads, input_units, self.head_units) * 2 * k - k)
        self.value_weights = nn.Parameter(torch.rand(self.attention_heads, input_units, self.head_units) * 2 * k - k)
        self.final_weight = nn.Parameter(torch.rand(self.attention_heads * self.head_units, input_units) * 2 * k - k)
        self.input_bias = nn.Parameter(torch.ones(3, self.attention_heads, 1, self.head_units) * 2 * k - k)
        self.output_bias = nn.Parameter(torch.ones(1, self.input_units) * 2 * k - k)

    def forward(self, queries, keys, values):
        # convert to 4d tensor with batch_size, attn_heads, seq_len, embedding_dim
        exp_queries = queries.unsqueeze(1).expand(-1, self.attention_heads, -1,-1)
        exp_keys = keys.unsqueeze(1).expand(-1, self.attention_heads, -1,-1)
        exp_values = values.unsqueeze(1).expand(-1, self.attention_heads, -1,-1)

        proj_queries = torch.einsum("base, aeh->bash", exp_queries, self.query_weights) + self.input_bias[0]
        proj_keys = torch.einsum("base, aeh->bash", exp_keys, self.key_weights) + self.input_bias[1]
        proj_values = torch.einsum("base, aeh->bash", exp_values, self.value_weights) + self.input_bias[2]

        attention = torch.einsum("bash, bahk->bask", proj_queries, torch.transpose(proj_keys, -1, -2)) / np.sqrt(self.head_units)
        if self.mask:
            # Prevent decoder queries from looking at tokens that come after
            # Do this by setting attention to negative infinity, so it is softmaxed to zero in the next step
            mask = torch.zeros((attention.shape[-2], attention.shape[-1]), device=DEVICE)
            mask_indices = np.triu_indices(attention.shape[-2], k=1, m=attention.shape[-1])
            mask[mask_indices] = -torch.inf
            attention += mask

        # Softmax on last dimension
        # Sequence-wise softmax, so attention between one sequence and other sequences sums to 1
        attention = torch.softmax(attention, dim=-1)
        weighted_values = torch.einsum("bash, bahe->base", attention, proj_values)

        # Swap attention head and sequence axis, then reshape to batch, seq, embedding
        weighted_values = weighted_values.swapaxes(1,2).reshape(queries.shape[0], queries.shape[1], -1)
        final_values = torch.einsum("bse, eo->bso", weighted_values, self.final_weight) + self.output_bias
        return final_values

In [297]:
class EncoderBlock(nn.Module):
    def __init__(self, input_units, attention_heads, hidden_units=2048, dropout_p=.1):
        super(EncoderBlock, self).__init__()
        self.input_units = input_units
        self.attention_heads = attention_heads
        self.hidden_units = hidden_units

        self.mha = MultiHeadAttention(self.input_units, self.attention_heads)
        self.dropouts = nn.ModuleList(nn.Dropout(dropout_p) for _ in range(2))
        self.linear1 = nn.Linear(self.input_units, hidden_units)
        self.linear2 = nn.Linear(hidden_units, self.input_units)
        self.relu = nn.ReLU()
        self.lns = nn.ModuleList(nn.LayerNorm(self.input_units) for _ in range(2))

    def forward(self, x):
        weighted_values = self.dropouts[0](self.mha(x, x, x))
        attn_output = self.lns[0](x + weighted_values)

        reprojected = self.dropouts[1](self.linear2(self.relu(self.linear1(attn_output))))
        block_output = self.lns[1](attn_output + reprojected)
        return block_output

In [298]:
class DecoderBlock(nn.Module):
    def __init__(self, input_units, attention_heads, hidden_units=2048, dropout_p=.1):
        super(DecoderBlock, self).__init__()
        self.input_units = input_units
        self.attention_heads = attention_heads
        self.hidden_units = hidden_units

        self.in_attn = MultiHeadAttention(self.input_units, self.attention_heads, mask=True)
        self.context_attn = MultiHeadAttention(self.input_units, self.attention_heads)
        self.dropouts = nn.ModuleList([nn.Dropout(dropout_p) for _ in range(3)])
        self.linear1 = nn.Linear(self.input_units, hidden_units)
        self.linear2 = nn.Linear(hidden_units, self.input_units)
        self.relu = nn.ReLU()
        self.lns = nn.ModuleList(nn.LayerNorm(self.input_units) for _ in range(3))

    def forward(self, queries, context):
        weighted_values = self.dropouts[0](self.in_attn(queries, queries, queries))
        attn_output = self.lns[0](queries + weighted_values)

        decoder_values = self.dropouts[1](self.context_attn(attn_output, context, context))
        decoder_output = self.lns[1](attn_output + decoder_values)

        reprojected = self.dropouts[2](self.linear2(self.relu(self.linear1(decoder_output))))
        block_output = self.lns[2](decoder_output + reprojected)
        return block_output

In [299]:
class Transformer(nn.Module):
    def __init__(self, input_units, hidden_units, attention_heads, enc_sequence_len, blocks=1):
        super(Transformer, self).__init__()
        self.input_units = input_units
        self.hidden_units = hidden_units
        self.attention_heads = attention_heads
        self.blocks = blocks

        k = math.sqrt(1/self.hidden_units)
        self.output_embedding = nn.Parameter(torch.rand(hidden_units, input_units) * 2 * k - k)
        self.output_bias = nn.Parameter(torch.rand(1, input_units) * 2 * k - k)
        self.embedding = nn.Embedding(input_units, hidden_units)
        self.dropouts = nn.ModuleList(nn.Dropout(.1) for _ in range(2))
        self.encoders = nn.ModuleList(EncoderBlock(hidden_units, attention_heads) for _ in range(self.blocks))
        self.decoders = nn.ModuleList(DecoderBlock(hidden_units, attention_heads) for _ in range(self.blocks))
        self.pos_encoding = self.encoding(enc_sequence_len, self.hidden_units).to(DEVICE)


    def embed(self, indices, reverse=False):
        if reverse:
            return torch.einsum("bse, eo->bso", indices, self.embedding.T)
        else:
            one_hot = F.one_hot(indices.to(torch.long), num_classes=self.input_units).to(torch.float)
            return torch.einsum("bso, oe->bse", one_hot, self.embedding)

    def forward(self, x, y, enc_outputs=None):
        if enc_outputs is None:
            # 3D with batch, seq, embeddings
            # TODO: Tie input and output embedding weights
            enc_outputs = self.dropouts[0](self.embedding(x) + self.pos_encoding[:x.shape[1]])

            for i in range(self.blocks):
                enc_outputs = self.encoders[i](enc_outputs)

        dec_outputs = self.dropouts[1](self.embedding(y) + self.pos_encoding[:y.shape[1]])
        for i in range(self.blocks):
            dec_outputs = self.decoders[i](dec_outputs, enc_outputs)

        token_vectors = torch.einsum("bse, eo->bso", dec_outputs, self.output_embedding) + self.output_bias
        return token_vectors, enc_outputs

    def encoding(self, seq_len, embed_len):
        encodings = torch.zeros((seq_len, embed_len))
        for i in range(seq_len):
            all = torch.exp(torch.arange(0, embed_len, 2) * (-math.log(10000.0) / embed_len))
            encodings[i, 0::2] = torch.sin(i * all)
            encodings[i, 1::2] = torch.cos(i * all)
        return encodings

In [300]:
def generate(sequence, pred, target, wrapper):
    prompts = wrapper.decode_batch(sequence.cpu())
    texts = wrapper.decode_batch(torch.argmax(pred, dim=2).cpu())
    correct_texts = wrapper.decode_batch(target.cpu())

    displays = []
    for p, t, ct in zip(prompts, texts, correct_texts):
        displays.append(f"{p} | {ct} | {t}")
    return displays

In [301]:
from tqdm.auto import tqdm

# TODO: Profile and improve perf - https://pytorch.org/tutorials/recipes/recipes/profiler_recipe.html
model = Transformer(wrapper.vocab_size, 512, 8, wrapper.y_length, blocks=1).to(DEVICE)
loss_fn = nn.CrossEntropyLoss(ignore_index=wrapper.pad_token)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [None]:
EPOCHS = 100
DISPLAY_BATCHES = 2
OUT_SEQUENCE_LEN = wrapper.y_length
PRINT_VALID = True

for epoch in range(EPOCHS):
    # Run over the training examples
    train_loss = 0
    match_pct = 0
    for batch, (sequence, target, prev_target) in tqdm(enumerate(train)):
        optimizer.zero_grad(set_to_none=True)
        pred, _ = model(sequence.to(DEVICE), prev_target.to(DEVICE))

        # If you use a batch, need to reshape pred to be batch * sequence, embedding_len to be compatible
        # Similar reshape with target to be batch * sequence vector of class indices
        loss = loss_fn(pred.reshape(-1, pred.shape[-1]), target.view(-1).to(DEVICE))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    with torch.no_grad():
        print(f"Epoch {epoch} train loss: {train_loss / len(train) / BATCH_SIZE}")
        sents = generate(sequence, pred, target, wrapper)
        for sent in sents[:DISPLAY_BATCHES]:
            print(sent)

        if PRINT_VALID and epoch % 10 ==0:
            # Compute validation loss.  Unless you have a lot of training data, the validation loss won't decrease.
            valid_loss = 0
            # Deactivate dropout layers
            model.eval()
            for batch, (sequence, target, prev_target) in tqdm(enumerate(valid)):
                # Inference token by tokens
                sequence = sequence.to(DEVICE)
                outputs = prev_target[:,0].unsqueeze(1).to(DEVICE)
                enc_outputs = None
                # TODO: Investigate memory leak with valid generation
                for i in range(OUT_SEQUENCE_LEN):
                    pred, enc_outputs = model(sequence, outputs, enc_outputs=enc_outputs)
                    last_output = torch.argmax(pred, dim=2)
                    outputs = torch.cat((outputs, last_output[:,-1:]), dim=1)
                loss = loss_fn(pred.reshape(-1, pred.shape[-1]), target.view(-1).to(DEVICE))
                valid_loss += loss.item()
            print(f"Valid loss: {valid_loss / len(valid) / BATCH_SIZE}")
            sents = generate(sequence, pred, target, wrapper)
            for sent in sents[:DISPLAY_BATCHES]:
                print(sent)
            # Reactivate dropout
            model.train()

127it [00:26,  4.81it/s]


Epoch 0 train loss: 0.17914133806397595
48:14 Porque este Dios es Dios nuestro eternalmente y para siempre: El nos capitaneará hasta la muerte. | PS 48:14 For this God is our God forever and ever. He will be our guide even to death. | The the the the..............
- Lo hiciste, eres la traga de la clase. | - You got a 4.0 GPA. | - I',......


14it [00:11,  1.18it/s]

In [None]:
from torchinfo import summary

print(summary(model))

In [None]:
from torch.profiler import profile, record_function, ProfilerActivity

with profile(activities=[ProfilerActivity.CPU], record_shapes=True, ) as prof:
    model(sequence.to(DEVICE), prev_target.to(DEVICE))

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

In [149]:
queries = model.embedding(sequence.to(DEVICE))
context = model.embedding(prev_target.to(DEVICE))
attention_heads = 8
query_weights = model.decoders[0].in_attn.query_weights
key_weights = model.decoders[0].in_attn.key_weights
value_weights = model.decoders[0].in_attn.value_weights

# convert to 4d tensor with batch_size, attn_heads, seq_len, embedding_dim
exp_queries = queries.unsqueeze(1).expand(-1, attention_heads, -1,-1)
exp_context = context.unsqueeze(1).expand(-1, attention_heads, -1,-1)
proj_queries = torch.einsum("base, aeh->bash", exp_queries, query_weights)
# Transpose keys
proj_keys = torch.einsum("base, aeh->bahs", exp_context, key_weights)
proj_values = torch.einsum("base, aeh->bash", exp_context, value_weights)

attention = torch.einsum("bash, bahk->bask", proj_queries, proj_keys) / np.sqrt(64)

# Prevent decoder queries from looking at tokens that come after
# Do this by setting attention to negative infinity, so it is softmaxed to zero in the next step
mask = torch.zeros((attention.shape[-2], attention.shape[-1]), device=DEVICE)
mask_indices = np.triu_indices(attention.shape[-2], k=1, m=attention.shape[-1])
mask[mask_indices] = -torch.inf
attention += mask

# Softmax on last dimension
# Sequence-wise softmax, so attention between one sequence and other sequences sums to 1
attention = torch.softmax(attention, dim=-1)

weighted_values = torch.einsum("bash, bahe->base", attention, proj_values)

# Swap attention head and sequence axis, then reshape to batch, seq, embedding
weighted_values = weighted_values.swapaxes(1,2).reshape(queries.shape[0], queries.shape[1], -1)

In [150]:
manual = torch.softmax((exp_queries[0,1] @ query_weights[1] @ (exp_context[0,1] @ key_weights[1]).T / np.sqrt(64)) + mask, -1) @ (exp_context[0,1] @ value_weights[1])

torch.allclose(weighted_values[0,1,64:128], manual[1])

True

In [153]:
weighted_values.shape

torch.Size([32, 40, 512])

In [160]:
torch.softmax((exp_queries[0,1] @ query_weights[1] @ (exp_context[0,1] @ key_weights[1]).T / np.sqrt(64)) + mask, -1)[30]

tensor([1.1553e-03, 5.5856e-04, 7.2043e-02, 5.1467e-04, 7.6951e-03, 9.0066e-03,
        9.1353e-02, 1.6192e-01, 1.0041e-05, 1.4945e-03, 2.5468e-04, 4.5007e-02,
        1.9007e-03, 8.3639e-04, 1.1511e-02, 2.1658e-03, 1.6677e-03, 3.5177e-01,
        1.6545e-05, 2.3807e-02, 1.9574e-02, 1.9574e-02, 1.9574e-02, 1.9574e-02,
        1.9574e-02, 1.9574e-02, 1.9574e-02, 1.9574e-02, 1.9574e-02, 1.9574e-02,
        1.9574e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
       grad_fn=<SelectBackward0>)

In [175]:
((exp_queries[0,1] @ query_weights[1] @ (exp_context[0,1] @ key_weights[1]).T / np.sqrt(64)) + mask)[8]

tensor([-1.5150, -2.2417,  2.6179, -2.3235,  0.3813,  0.5386,  2.8554,  3.4278,
        -6.2604,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
           -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
           -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
           -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
           -inf], grad_fn=<SelectBackward0>)

In [162]:
sequence[0]

tensor([  32, 1206, 2243,    6, 2425,   10, 5000, 5000, 5000, 5000, 5000, 5000,
        5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000,
        5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000, 5000,
        5000, 5000, 5000, 5000], dtype=torch.int32)