In [1]:
import torch
from torch import tensor
from math import sqrt
from torch.nn.functional import softmax

def par_attention(queries: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, dim: int) -> torch.Tensor:
    """Causal (masked) scaled dot-product attention.

    Computes softmax(mask(Q @ K^T) / sqrt(dim)) @ V, where the causal mask
    prevents query position i from attending to any key position j > i.

    Args:
        queries: (..., seq_len, d_k) tensor, one query vector per row.
        keys: (..., seq_len, d_k) tensor, one key vector per row.
        values: (..., seq_len, d_v) tensor, one value vector per row.
        dim: the query/key dimensionality d_k, used for the 1 / sqrt(d_k) scaling.

    Returns:
        (..., seq_len, d_v) tensor whose row i is the attention-weighted sum of
        value rows 0..i.

    Leading batch dimensions (...) are optional; plain 2-D inputs work as well.
    """
    # transpose(-2, -1) rather than .T: .T reverses *all* dimensions (and is
    # deprecated for ndim != 2), which would break batched inputs.
    raw_weights = queries @ keys.transpose(-2, -1)

    # Lower-triangular causal mask over the last two dims: position i may only
    # attend to positions <= i. Masked scores become -inf -> weight 0 after softmax.
    mask = torch.tril(torch.ones_like(raw_weights, dtype=torch.bool))
    raw_weights = raw_weights.masked_fill(~mask, float('-inf'))

    # Normalise over the key axis (last dim) so each query's weights sum to 1.
    scaled_weights = softmax(raw_weights / sqrt(dim), dim=-1)

    # Row i of the result is sum_j weights[i, j] * values[j]. The matmul is
    # equivalent to broadcast-multiply + sum over the key axis, but does not
    # materialise a (seq, seq, d_v) temporary and accepts non-contiguous `values`.
    return scaled_weights @ values

The goal of this notebook is to implement an attention block and then build out the rest of the transformer architecture. I'll start with a single-headed version, then build out multi-headedness (which shouldn't be a lot of additional work, I think). The original paper uses a model dimensionality of 512 and 8 heads, each of which works on 512 / 8 = 64 dimensions. In my case, my embedding function produces 300-dimensional vectors, so I think I'll experiment with 6 heads of 50 dimensions each.

An attention head includes:
1. Separate, learned linear projections for the Q, K, and V vectors. Each projection is basically a feed-forward layer without a nonlinearity (i.e. a single `nn.Linear`), and it reduces the dimensionality of each input vector from model_dim to model_dim / h.
2. The scaled dot product attention function.

Then, a multi-head attention block contains:
1. Some number of attention heads.
2. A concatenation step. This takes the h vectors of length model_dim / h output by the attention heads (one per head, for each position) and concatenates them back into a single vector of length model_dim.
3. A learned linear projection. My intuition is that this projection "blends" the h concatenated vectors into a more meaningful and cohesive whole.

The structure of a whole transformer layer / block differs between encoder, decoder, and decoder-only architectures. In the original paper's decoder block, a transformer layer includes two separate attention blocks, one of which (encoder-decoder or "cross" attention) draws its queries from the previous decoder layer while its keys and values are drawn from the encoder output. A decoder-only architecture has no encoder, so I don't think there's any meaningful or useful analogue for this block; instead I'll be using only a single (masked self-)attention block per transformer layer.

With all that said, my transformer layers will include:
1. Masked multi-head attention - i.e. the attention block outlined above. The input is a sequence of model_dim length vectors, which come from either the previous transformer layer or, for the first transformer layer, the positional encoding function. These are projected into Q, K, V vectors by the linear projections in the attention heads.
2. Residual connection followed by layer normalization, defined as: LayerNorm(x + Sublayer(x))
3. Feed-forward block. This is two linear layers with a single ReLU in between. In the paper, the hidden dimension between the two layers (2048) is four times the model dimension (512). I'll experiment with something similar.
4. Another residual connection.

The complete architecture:
1. Positional encoding function applied to input tokens/words. (update: in this case I'm not training any aspect of this as part of my model, so I think it makes sense to keep it outside of the model itself.)
2. Sequential transformer layers. In the paper, there are 6.
3. A linear layer that takes in all the vectors output by the final transformer layer (flattened together into a single vector) and has one output for each possible next word.
4. Softmax function over outputs gives us probabilities for next word, the final output of the network.

Note to self: It might be good to build an API that lets the user specify a custom/different embedding function, but that's probably not a priority.

In [2]:
import torch.nn as nn

class AttentionHead(nn.Module):
    """A single causal self-attention head.

    Holds three independent learned projections (model_dim -> vectors_dim)
    that map every input row to a query, a key and a value vector, which are
    then combined by `par_attention`. For simplicity, queries, keys and values
    all share the same dimensionality, `vectors_dim`.
    """

    def __init__(self, model_dim, vectors_dim):
        super().__init__()
        self.model_dim = model_dim
        self.vectors_dim = vectors_dim
        # Registration order (Q, K, V) fixes parameter-init order and state_dict keys.
        self.Q_proj = nn.Linear(model_dim, vectors_dim)
        self.K_proj = nn.Linear(model_dim, vectors_dim)
        self.V_proj = nn.Linear(model_dim, vectors_dim)

    def forward(self, x):
        """Run causal attention over the rows of `x`.

        Each row of `x` is the current (contextualised so far) vector for the
        token at that position; the projections turn it into Q/K/V before
        attention, scaled by sqrt(vectors_dim).
        """
        return par_attention(self.Q_proj(x), self.K_proj(x), self.V_proj(x), self.vectors_dim)

class MultiHeadAttention(nn.Module):
    """Multi-head causal self-attention.

    Runs `num_heads` independent AttentionHeads, each producing
    model_dim // num_heads features per position, concatenates their outputs
    back to model_dim features, and mixes them with a learned linear projection.

    Raises:
        ValueError: if num_heads is not positive or does not divide model_dim
            (otherwise the concatenated head width would not match `proj`).
    """

    def __init__(self, model_dim, num_heads):
        super().__init__()
        if num_heads <= 0 or model_dim % num_heads != 0:
            raise ValueError(
                f"model_dim ({model_dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        head_dim = model_dim // num_heads
        self.att_heads = nn.ModuleList([AttentionHead(model_dim, head_dim) for _ in range(num_heads)])
        self.proj = nn.Linear(model_dim, model_dim)

    def forward(self, x):
        """Apply all heads to `x` and return the projected concatenation."""
        head_outputs = [head(x) for head in self.att_heads]
        # Concatenate along the feature (last) axis. dim=1 is only the feature
        # axis for 2-D input; for batched (batch, seq, dim) input it would
        # concatenate along the sequence axis instead.
        x = torch.cat(head_outputs, dim=-1)
        return self.proj(x)
        
class TransformerLayer(nn.Module):
    """One decoder-only transformer layer (post-norm, as in the original paper).

    Sub-layers, each wrapped as LayerNorm(x + Sublayer(x)):
      1. masked multi-head self-attention;
      2. position-wise feed-forward block: Linear -> ReLU -> Linear.

    Args:
        model_dim: width of each token vector (input and output).
        num_heads: number of attention heads (must divide model_dim).
        ff_hidden_dim: hidden width of the feed-forward block.
    """

    def __init__(self, model_dim, num_heads, ff_hidden_dim):
        super().__init__()
        self.attention_block = MultiHeadAttention(model_dim, num_heads)
        # LayerNorm requires the size of the normalised (last) dimension;
        # nn.LayerNorm() with no arguments raises a TypeError.
        self.norm1 = nn.LayerNorm(model_dim)
        self.ff1 = nn.Linear(model_dim, ff_hidden_dim)
        self.ff_relu = nn.ReLU()
        self.ff2 = nn.Linear(ff_hidden_dim, model_dim)
        self.norm2 = nn.LayerNorm(model_dim)

    def forward(self, x):
        """Apply attention and feed-forward sub-layers with residual + LayerNorm."""
        # Out-of-place residual adds: an in-place `+=` on a sub-layer's output
        # breaks backward if that tensor is ever saved for gradient computation
        # (e.g. if a sub-layer is later changed to end in an activation).
        x = self.norm1(x + self.attention_block(x))

        ff_out = self.ff2(self.ff_relu(self.ff1(x)))
        x = self.norm2(x + ff_out)

        return x


In [3]:

class TransformerNetwork(nn.Module):
    """Decoder-only transformer that outputs a next-word probability distribution.

    Args:
        num_layers: number of stacked TransformerLayer blocks.
        model_dim: width of every token vector.
        att_heads: attention heads per layer (must divide model_dim).
        ff_hidden_dim: hidden width of each layer's feed-forward block.
        context_len: number of token vectors per input sequence; the final
            linear layer consumes all of them flattened into one vector.
        output_dict_size: vocabulary size, i.e. number of possible next words.
    """

    def __init__(self, num_layers, model_dim, att_heads, ff_hidden_dim, context_len, output_dict_size):
        super().__init__()
        self.trans_layers = nn.ModuleList([TransformerLayer(model_dim, att_heads, ff_hidden_dim) for _ in range(num_layers)])
        self.word_predictor = nn.Linear(model_dim * context_len, output_dict_size)

    def forward(self, x):
        """Predict the next word.

        Args:
            x: (context_len, model_dim) positionally-encoded token vectors, or a
               batch of them with shape (batch, context_len, model_dim).

        Returns:
            (output_dict_size,) -- or (batch, output_dict_size) -- probabilities
            over the vocabulary (each distribution sums to 1).
        """
        for layer in self.trans_layers:
            # Call the module itself (not .forward) so registered hooks run.
            x = layer(x)
        # Flatten the trailing (context_len, model_dim) axes into one vector per
        # sequence to match word_predictor's model_dim * context_len inputs.
        logits = self.word_predictor(x.flatten(start_dim=-2))
        # NOTE(review): nn.CrossEntropyLoss expects raw logits; if training with
        # it, return `logits` instead of the softmax probabilities.
        return torch.softmax(logits, dim=-1)



In [4]:
# Hyperparameters from the original paper: 6 layers, model_dim 512, 8 heads,
# feed-forward hidden size 2048. The per-head size (512 // 8 = 64) is derived
# inside MultiHeadAttention, so there are no separate key/value-dim arguments.
# NOTE(review): context_len and output_dict_size are not specified anywhere in
# this notebook yet; the values below are placeholders -- TODO set them from the
# real context window and vocabulary. word_predictor holds
# model_dim * context_len * output_dict_size weights, so keep these modest.
CONTEXT_LEN = 16
VOCAB_SIZE = 5_000

model = TransformerNetwork(
    num_layers=6,
    model_dim=512,
    att_heads=8,
    ff_hidden_dim=2048,
    context_len=CONTEXT_LEN,
    output_dict_size=VOCAB_SIZE,
)