In [1]:
# https://fkodom.substack.com/p/transformers-from-scratch-in-pytorch
# scaled dot product attention
# Q, K, V are batches of matrices with shape (batch_size, seq_length, num_features)
# Multiplying the query (Q) by the transposed key (K) gives scores of shape (batch_size, seq_length, seq_length)
    # the scores tell us roughly how important each element in the sequence is to every other element
    # attention layer - decides what we pay attention to
    # the scores are normalized using softmax
# the resulting attention weights are applied to the value (V) array using matmul

In [4]:
from torch import Tensor
import torch.nn.functional as F

# omit Mask operation for simplicity
def scaled_dot_product_attention(query: Tensor, key: Tensor, value: Tensor) -> Tensor:
    """Compute softmax(Q K^T / sqrt(d)) V, without any masking.

    Args:
        query: tensor of shape (..., seq_q, d), where ``...`` is any number of batch
            dimensions, e.g. (batch_size, seq_length, num_features).
        key: tensor of shape (..., seq_k, d); its last dimension must equal ``query``'s.
        value: tensor of shape (..., seq_k, d_v).

    Returns:
        Tensor of shape (..., seq_q, d_v): for every query position, a weighted sum of
        the value vectors, weighted by the softmax of the scaled attention scores.
    """
    # Raw attention scores, shape (..., seq_q, seq_k). Unlike ``bmm`` (3-D only),
    # ``@`` accepts any number of leading batch dimensions and broadcasts them.
    scores = query @ key.transpose(-2, -1)
    # Divide by sqrt(d): dot products grow in magnitude with the feature size, which
    # pushes softmax into saturated regions where gradients become vanishingly small.
    scale = query.size(-1) ** 0.5
    # Normalize over the key axis so each query's weights sum to 1.
    weights = F.softmax(scores / scale, dim=-1)
    # Weighted sum of value vectors.
    return weights @ value
    

In [5]:
# multi-head attention is composed of several identical attention heads
# each attention head contains 3 linear layers (for Q, K and V) followed by scaled dot-product attention

import torch
from torch import nn

class AttentionHead(nn.Module):
    """A single attention head: project Q, K and V linearly, then attend.

    Args:
        dim_in: feature size of the incoming query/key/value tensors.
        dim_q: output size of the query projection.
        dim_k: output size of the key projection; the value projection uses it too.
            The query and key projections are multiplied together, so in practice
            ``dim_q`` must equal ``dim_k``.
    """

    def __init__(self, dim_in: int, dim_q: int, dim_k: int):
        super().__init__()
        # Projections are created in q, k, v order (keeps parameter init order stable).
        self.q = nn.Linear(dim_in, dim_q)
        self.k = nn.Linear(dim_in, dim_k)
        self.v = nn.Linear(dim_in, dim_k)

    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
        """Project each input with its own linear layer and apply scaled dot-product attention."""
        projected_query = self.q(query)
        projected_key = self.k(key)
        projected_value = self.v(value)
        return scaled_dot_product_attention(projected_query, projected_key, projected_value)

In [6]:
# multi-head attention layer
# combines num_heads separate attention heads with a linear output layer
# each attention head computes its own query, key and value arrays and applies scaled dot-product attention
# this means each head can attend to a different part of the input sequence, independently of the others
# increasing the number of attention heads lets the model "pay attention" to more parts of the sequence at once
class MultiHeadAttention(nn.Module):
    """Run ``num_heads`` independent attention heads and merge them with a linear layer.

    Args:
        num_heads: number of ``AttentionHead`` modules.
        dim_in: feature size of the inputs, and of the merged output.
        dim_q: per-head query projection size.
        dim_k: per-head key/value projection size; each head outputs ``dim_k`` features.
    """

    def __init__(self, num_heads: int, dim_in: int, dim_q: int, dim_k: int):
        super().__init__()
        head_list = []
        for _ in range(num_heads):
            head_list.append(AttentionHead(dim_in, dim_q, dim_k))
        self.heads = nn.ModuleList(head_list)
        # Head outputs are concatenated along the feature axis, then projected back to dim_in.
        concat_width = num_heads * dim_k
        self.linear = nn.Linear(concat_width, dim_in)

    def forward(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor:
        """Apply every head to the same inputs, concatenate their outputs, and project."""
        head_outputs = [head(query, key, value) for head in self.heads]
        merged = torch.cat(head_outputs, dim=-1)
        return self.linear(merged)

In [7]:
# position encoding
# MultiHeadAttention has no trainable components that operate over the sequence dimension
# Everything operates over the feature dimension, so it is independent of sequence length (and of token order)
# We therefore add positional information so the model knows the relative position of data points in the input sequences

def position_encoding(seq_len: int, dim_model: int, device: torch.device = torch.device("cpu"), dtype: torch.dtype = torch.float32) -> Tensor:
    """Sinusoidal position encoding from "Attention Is All You Need".

        PE[pos, 2i]   = sin(pos / 10000^(2i / dim_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / dim_model))

    Args:
        seq_len: number of positions to encode.
        dim_model: number of feature channels.
        device: device on which the encoding is created.
        dtype: floating-point dtype of the encoding.

    Returns:
        Tensor of shape (1, seq_len, dim_model), which broadcasts against a
        (batch, seq_len, dim_model) input.
    """
    # Positions 0..seq_len-1 laid out along the sequence axis: shape (1, seq_len, 1).
    pos = torch.arange(seq_len, dtype=dtype, device=device).reshape(1, -1, 1)
    # Integer channel indices along the feature axis: shape (1, 1, dim_model).
    channel = torch.arange(dim_model, device=device).reshape(1, 1, -1)
    # Channels 2i and 2i+1 share the exponent 2i / dim_model, so each sin/cos pair has
    # the same frequency. That pairing is what lets PE[pos + k] be written as a fixed
    # linear map (a rotation) of PE[pos], i.e. what encodes relative position.
    even_channel = (channel - channel % 2).to(dtype)
    phase = pos / (1e4 ** (even_channel / dim_model))
    # Sine on even channels, cosine on odd channels.
    return torch.where(channel % 2 == 0, torch.sin(phase), torch.cos(phase))

In [12]:
# the transformer uses an encoder-decoder architecture
# the encoder (left side of the usual architecture diagram) processes the input sequence and returns a feature vector (or memory vector)
# the decoder processes the target sequence and incorporates information from the encoder memory
# the output of the decoder is the model's prediction

In [None]:
# each layer in the encoder and decoder contains a fully connected feed-forward network
    # consisting of two linear transformations with a ReLU activation in between
    # in the original paper the input and output dimension is 512 and the inner layer has dimension 2048

In [15]:
def feed_forward(dim_input: int = 512, dim_feedforward: int = 2048) -> nn.Module:
    """Build the position-wise feed-forward network: Linear -> ReLU -> Linear.

    Args:
        dim_input: feature size of the input and of the output.
        dim_feedforward: width of the hidden layer.

    Returns:
        An ``nn.Sequential`` mapping (..., dim_input) to (..., dim_input).
    """
    expand = nn.Linear(dim_input, dim_feedforward)
    activation = nn.ReLU()
    project = nn.Linear(dim_feedforward, dim_input)
    return nn.Sequential(expand, activation, project)

In [16]:
# which kind of normalization is used?
# the output of each sub-layer is LayerNorm(x + Sublayer(x)), where Sublayer(x) is the function implemented by the sub-layer itself
# dropout is applied to the output of each sub-layer before it is added to the sub-layer input and normalized

In [25]:
class Residual(nn.Module):
    """Post-norm residual wrapper: LayerNorm(x + Dropout(sublayer(*tensors))).

    Args:
        sublayer: module to wrap; it receives every tensor passed to ``forward``.
        dimension: feature size normalized by the LayerNorm.
        dropout: dropout probability applied to the sublayer output.
    """

    def __init__(self, sublayer: nn.Module, dimension: int, dropout: float = 0.1):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(dimension)
        self.dropout = nn.Dropout(dropout)

    def forward(self, *tensors: Tensor) -> Tensor:
        # The first tensor is the residual source (the query, for attention sublayers,
        # matching MultiHeadAttention's (query, key, value) signature).
        skip = tensors[0]
        transformed = self.sublayer(*tensors)
        return self.norm(skip + self.dropout(transformed))

In [19]:
class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a feed-forward network.

    Both sub-layers are wrapped in ``Residual`` (add, dropout, LayerNorm).

    Args:
        dim_model: feature size of the input and output.
        num_heads: number of attention heads. Each head uses
            ``max(dim_model // num_heads, 1)`` features. With 512 and 6 heads that is 85,
            so the concatenated heads are 510 wide before projection back to 512.
        dim_feedforward: hidden size of the feed-forward sub-layer.
        dropout: dropout probability for both residual wrappers.
    """

    def __init__(self, dim_model: int = 512, num_heads: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1,):
        super().__init__()
        head_dim = max(dim_model // num_heads, 1)
        self_attention = MultiHeadAttention(num_heads, dim_model, head_dim, head_dim)
        self.attention = Residual(self_attention, dimension=dim_model, dropout=dropout)
        ffn = feed_forward(dim_model, dim_feedforward)
        self.feed_forward = Residual(ffn, dimension=dim_model, dropout=dropout)

    def forward(self, src: Tensor) -> Tensor:
        """Self-attend over ``src`` (query = key = value), then apply the feed-forward block."""
        attended = self.attention(src, src, src)
        return self.feed_forward(attended)

class TransformerEncoder(nn.Module):
    """Stack of encoder layers, applied after adding sinusoidal position encoding.

    Args:
        num_layers: number of ``TransformerEncoderLayer`` blocks.
        dim_model: feature size of the input and output.
        num_heads: attention heads per layer.
        dim_feedforward: hidden size of each feed-forward sub-layer.
        dropout: dropout probability passed to every layer.
    """

    def __init__(self, num_layers: int = 6, dim_model: int = 512, num_heads: int = 8, dim_feedforward: int = 2048, dropout: float = 0.1,):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                TransformerEncoderLayer(dim_model, num_heads, dim_feedforward, dropout)
                for _ in range(num_layers)
            ]
        )

    def forward(self, src: Tensor) -> Tensor:
        """Encode ``src`` into a memory tensor of the same shape.

        ``src`` is read as (batch, seq_len, dim_model): the sequence length comes from
        dim 1 and the feature size from dim 2. The caller's tensor is not modified.
        """
        seq_len, dimension = src.size(1), src.size(2)
        # Build the encoding on src's device and cast it to src's dtype so GPU and
        # reduced-precision inputs work. The add is deliberately out of place: an
        # in-place `+=` would overwrite the caller's tensor, and it raises for leaf
        # tensors that require grad.
        pe = position_encoding(seq_len, dimension, device=src.device).to(dtype=src.dtype)
        src = src + pe
        for layer in self.layers:
            src = layer(src)
        return src

In [27]:
class TransformerDecoderLayer(nn.Module):
    """One decoder layer: self-attention, encoder-decoder attention, then feed-forward.

    Every sub-layer is wrapped in ``Residual`` (add, dropout, LayerNorm).

    Args:
        dim_model: feature size of the target and memory tensors.
        num_heads: attention heads per attention sub-layer; each head uses
            ``max(dim_model // num_heads, 1)`` features.
        dim_feedforward: hidden size of the feed-forward sub-layer.
        dropout: dropout probability for all three residual wrappers.
    """

    def __init__(self, dim_model: int = 512, num_heads: int = 6, dim_feedforward: int = 2048, dropout: float = 0.1):
        super().__init__()
        head_dim = max(dim_model // num_heads, 1)
        self.attention_1 = Residual(
            MultiHeadAttention(num_heads, dim_model, head_dim, head_dim),
            dimension=dim_model,
            dropout=dropout,
        )
        self.attention_2 = Residual(
            MultiHeadAttention(num_heads, dim_model, head_dim, head_dim),
            dimension=dim_model,
            dropout=dropout,
        )
        self.feed_forward = Residual(
            feed_forward(dim_model, dim_feedforward),
            dimension=dim_model,
            dropout=dropout,
        )

    def forward(self, tgt: Tensor, memory: Tensor) -> Tensor:
        # Self-attention over the target sequence.
        # NOTE(review): no causal mask is applied, so each target position can attend to
        # later positions. That is fine for this shape demo but not for autoregressive
        # training; a mask would need support in scaled_dot_product_attention.
        after_self = self.attention_1(tgt, tgt, tgt)
        # Cross-attention: queries come from the decoder, keys/values from the encoder memory.
        after_cross = self.attention_2(after_self, memory, memory)
        return self.feed_forward(after_cross)

class TransformerDecoder(nn.Module):
    """Stack of decoder layers with position encoding, a final linear layer and softmax.

    Args:
        num_layers: number of ``TransformerDecoderLayer`` blocks.
        dim_model: feature size of the target, the memory and the output.
        num_heads: attention heads per attention sub-layer.
        dim_feedforward: hidden size of each feed-forward sub-layer.
        dropout: dropout probability passed to every layer.
    """

    def __init__(self, num_layers: int = 6, dim_model: int = 512, num_heads: int = 8, dim_feedforward: int = 2048, dropout: float = 0.1,):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                TransformerDecoderLayer(dim_model, num_heads, dim_feedforward, dropout)
                for _ in range(num_layers)
            ]
        )
        self.linear = nn.Linear(dim_model, dim_model)

    def forward(self, tgt: Tensor, memory: Tensor) -> Tensor:
        """Decode ``tgt`` against the encoder ``memory``.

        ``tgt`` is read as (batch, seq_len, dim_model): the sequence length comes from
        dim 1 and the feature size from dim 2. The caller's tensor is not modified.
        """
        seq_len, dimension = tgt.size(1), tgt.size(2)
        # Build the encoding on tgt's device and cast it to tgt's dtype. The add is
        # deliberately out of place: an in-place `+=` would overwrite the caller's
        # tensor, and it raises for leaf tensors that require grad.
        pe = position_encoding(seq_len, dimension, device=tgt.device).to(dtype=tgt.dtype)
        tgt = tgt + pe
        for layer in self.layers:
            tgt = layer(tgt, memory)

        # Project, then softmax over the last (feature) axis so each output vector sums to 1.
        # NOTE(review): for token prediction this would usually project to a vocabulary
        # size and return logits for a cross-entropy loss; confirm the intended output.
        return torch.softmax(self.linear(tgt), dim=-1)
        
        

In [21]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer assembled from TransformerEncoder and TransformerDecoder.

    Args:
        num_encoder_layers: number of encoder layers.
        num_decoder_layers: number of decoder layers.
        dim_model: feature size shared by encoder, decoder and their inputs.
        num_heads: attention heads per layer (encoder and decoder).
        dim_feedforward: hidden size of each feed-forward sub-layer.
        dropout: dropout probability passed to encoder and decoder.
        activation: accepted but not used anywhere in this class.
            NOTE(review): the default ``nn.ReLU()`` instance is created once, when the
            class is defined, and would be shared between instances if it were ever used.
    """

    def __init__(
        self,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        dim_model: int = 512,
        num_heads: int = 6,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: nn.Module = nn.ReLU(),
    ):
        super().__init__()
        # Encoder and decoder share every setting except their depth.
        common = dict(
            dim_model=dim_model,
            num_heads=num_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.encoder = TransformerEncoder(num_layers=num_encoder_layers, **common)
        self.decoder = TransformerDecoder(num_layers=num_decoder_layers, **common)

    def forward(self, src: Tensor, tgt: Tensor) -> Tensor:
        """Encode ``src`` into memory, then decode ``tgt`` against that memory."""
        memory = self.encoder(src)
        return self.decoder(tgt, memory)
        

In [28]:
# Shape smoke test on random inputs laid out as (batch, seq_len, dim_model).
torch.manual_seed(0)  # reproducible weights and inputs on Restart & Run All
src = torch.randn(64, 32, 512)  # source sequence: 32 positions
tgt = torch.randn(64, 16, 512)  # target sequence: 16 positions
model = Transformer().eval()  # eval mode disables dropout
with torch.no_grad():  # no training here, so skip building the autograd graph
    out = model(src, tgt)
# The decoder output keeps the target's batch and sequence sizes.
assert out.shape == (64, 16, 512), out.shape
print(out.shape)

torch.Size([64, 16, 512])
