In [None]:
"""
Transformer implementation, more or less following
https://huggingface.co/datasets/bird-of-paradise/transformer-from-scratch-tutorial/blob/main/Transformer_Implementation_Tutorial.ipynb
"""

In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


In [None]:
class TransformerAttention(nn.Module):
    """
    Multi-head scaled dot-product attention.

    For each head computes softmax(Q K^T / sqrt(d_head) + attn_mask) V, then
    concatenates the heads and applies a final linear projection.

    Args:
        - d_model: model dimension; must be divisible by num_heads
        - num_heads: number of attention heads
        - dropout: dropout probability applied to the attention weights
        - bias: whether the Q/K/V/output projections have a bias term

    Masks are additive floating-point tensors added to the raw scores before
    the softmax: 0 means "may attend", -inf means "must not attend". A mask
    must broadcast to [batch, num_heads, q_len, kv_len]; a 3-D mask
    [batch, q_len, kv_len] is expanded over the head dimension.
    Note: a query row whose keys are all -inf yields NaN after the softmax.
    """

    def __init__(self, d_model:int, num_heads:int = 1, dropout:float = 0.1, bias:bool = True):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        # Linear projections
        self.q_proj = nn.Linear(d_model, d_model, bias=bias)
        self.k_proj = nn.Linear(d_model, d_model, bias=bias)
        self.v_proj = nn.Linear(d_model, d_model, bias=bias)
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)

        self.dropout = nn.Dropout(p=dropout)
        # Per-head scale factor 1/sqrt(d_head); the scores are MULTIPLIED by it.
        self.scaling = 1.0 / math.sqrt(self.d_head)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """[batch, length, d_model] -> [batch, num_heads, length, d_head]."""
        batch, length, _ = x.shape
        return x.reshape(batch, length, self.num_heads, self.d_head).transpose(1, 2)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """[batch, num_heads, length, d_head] -> [batch, length, d_model]."""
        batch, _, length, _ = x.shape
        return x.transpose(1, 2).reshape(batch, length, self.d_model)

    def _attend(self, query_input: torch.Tensor, key_value: torch.Tensor,
                attn_mask: torch.Tensor = None) -> torch.Tensor:
        """Attention computation shared by self- and cross-attention.

        query_input: [batch, q_len, d_model]
        key_value:   [batch, kv_len, d_model]
        attn_mask:   optional additive float mask (see class docstring)
        returns:     [batch, q_len, d_model]
        """
        Q = self._split_heads(self.q_proj(query_input))
        K = self._split_heads(self.k_proj(key_value))
        V = self._split_heads(self.v_proj(key_value))

        # Heads sit before the sequence axis, so this contracts over d_head and
        # yields per-head scores of shape [batch, heads, q_len, kv_len].
        attn_score = (Q @ K.transpose(-1, -2)) * self.scaling
        if attn_mask is not None:
            # A bool mask would silently be added as 0/1, so reject it.
            assert attn_mask.is_floating_point(), "attn_mask must be an additive float mask"
            if attn_mask.dim() == 3:
                # [batch, q_len, kv_len] -> [batch, 1, q_len, kv_len] to broadcast over heads
                attn_mask = attn_mask.unsqueeze(1)
            assert attn_mask.dim() <= 4
            attn_score = attn_score + attn_mask
        attn_weights = F.softmax(attn_score, dim=-1)
        attn_weights = self.dropout(attn_weights)
        attn_output = attn_weights @ V

        # Merge the heads, then final projection
        return self.out_proj(self._merge_heads(attn_output))

    def self_attention(self, sequence:torch.Tensor, attn_mask:torch.Tensor = None):
        """Self-attention: queries, keys and values all come from `sequence`.

        sequence:  [batch, length, d_model]
        attn_mask: optional additive mask, e.g. a causal mask of shape
                   [1, 1, length, length] restricting each position to
                   itself and earlier tokens.
        returns:   [batch, length, d_model]
        """
        assert sequence.dim() == 3 and sequence.shape[-1] == self.d_model
        return self._attend(sequence, sequence, attn_mask)

    def cross_attention(self, sequence:torch.Tensor, key_value:torch.Tensor, attn_mask:torch.Tensor = None):
        """Cross-attention: queries from `sequence`, keys/values from `key_value`.

        sequence:  [batch_size, seq_len, d_model]
        key_value: [batch_size, kv_seq_len, d_model]
        attn_mask: optional additive mask broadcastable to
                   [batch_size, num_heads, seq_len, kv_seq_len]
        returns:   [batch_size, seq_len, d_model]
        """
        assert sequence.dim() == 3 and sequence.shape[-1] == self.d_model
        assert key_value.dim() == 3 and key_value.shape[-1] == self.d_model
        return self._attend(sequence, key_value, attn_mask)

    def forward(self, sequence, attn_mask=None):
        """Self-attention over `sequence` with an optional additive mask."""
        return self.self_attention(sequence, attn_mask=attn_mask)


In [None]:
class FFN(nn.Module):
    """
    Position-wise feedforward module:
        Linear(d_model -> d_ff) -> ReLU -> Linear(d_ff -> d_model)
    applied independently at every sequence position.

    Args:
        - d_model: input/output dimension
        - d_ff: hidden dimension
    """
    def __init__(self, d_model:int, d_ff:int):
        super().__init__()
        self.d_model = d_model
        # nn.Sequential takes the modules as positional arguments, not a list.
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x):
        """Input: [batch_size, sequence_length, d_model] -> output of the same shape."""
        assert x.dim() == 3 and x.shape[-1] == self.d_model
        return self.net(x)

In [None]:
class Encoder(nn.Module):
    """
    A single post-LayerNorm Transformer encoder layer:
        - Self-attention
        - Dropout + residual
        - LayerNorm
        - Feedforward
        - Dropout + residual
        - LayerNorm

    Args:
        - d_model: model dimension
        - d_ff: feedforward hidden dimension
        - num_heads: number of attention heads
        - dropout: dropout probability (attention weights and sublayer outputs)
        - bias: whether the attention projections have a bias term
    """
    def __init__(self, d_model:int, d_ff:int, num_heads:int = 1, dropout:float = 0.1, bias:bool = True):
        super().__init__()
        self.d_model = d_model
        self.attn = TransformerAttention(d_model, num_heads, dropout, bias)
        self.attn_norm = nn.LayerNorm(d_model)
        self.ff = FFN(d_model, d_ff)
        self.ff_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, embedded_input: torch.Tensor, padding_mask: torch.Tensor = None):
        """
        embedded_input: [batch_size, sequence_length, d_model]
        padding_mask: optional additive mask for self-attention (0 = attend,
            -inf = ignore), e.g. [batch_size, 1, sequence_length] to hide padded
            key positions. A 3-D mask gets a head axis inserted at dim 1.
        returns: [batch_size, sequence_length, d_model]
        """
        assert embedded_input.dim() == 3 and embedded_input.shape[-1] == self.d_model
        if padding_mask is not None and padding_mask.dim() == 3:
            padding_mask = padding_mask.unsqueeze(1)  # broadcast over heads

        x = self.attn.self_attention(embedded_input, attn_mask=padding_mask)  # Attention
        x = self.attn_norm(self.dropout(x) + embedded_input)  # dropout, residual, norm
        x = self.ff_norm(self.dropout(self.ff(x)) + x)  # Feedforward, dropout, residual, norm
        return x

In [None]:
class Decoder(nn.Module):
    """
    A single post-LayerNorm Transformer decoder layer:
        - Causal self-attention, dropout + residual, LayerNorm
        - Cross-attention over the encoder output, dropout + residual, LayerNorm
        - Feedforward, dropout + residual, LayerNorm

    Args:
        - d_model: model dimension
        - d_ff: feedforward hidden dimension
        - num_heads: number of attention heads
        - dropout: dropout probability (attention weights and sublayer outputs)
        - bias: whether the attention projections have a bias term
    """
    def __init__(self, d_model:int, d_ff:int, num_heads:int = 1, dropout:float = 0.1, bias:bool = True):
        # Must run before any submodule is assigned, otherwise nn.Module raises.
        super().__init__()
        self.d_model = d_model
        self.self_attn = TransformerAttention(d_model, num_heads, dropout, bias)
        self.self_norm = nn.LayerNorm(d_model)
        self.cross_attn = TransformerAttention(d_model, num_heads, dropout, bias)
        self.cross_norm = nn.LayerNorm(d_model)
        self.ffn = FFN(d_model, d_ff)
        self.ffn_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, embed_input:torch.Tensor, cross_input:torch.Tensor, padding_mask:torch.Tensor = None):
        """
        embed_input: Decoder input sequence [batch_size, seq_len, d_model]
        cross_input: Encoder output sequence [batch_size, encoder_seq_len, d_model]
        padding_mask: optional additive mask for cross-attention (0 = attend,
            -inf = ignore), e.g. [batch_size, seq_len, encoder_seq_len] or
            [batch_size, 1, encoder_seq_len]. A 3-D mask gets a head axis at dim 1.

        The causal mask for self-attention is built internally.
        returns: [batch_size, seq_len, d_model]
        """
        assert embed_input.dim() == 3 and embed_input.shape[-1] == self.d_model
        assert cross_input.dim() == 3 and cross_input.shape[-1] == self.d_model
        seq_len = embed_input.shape[1]

        # Additive causal mask: 0 on/below the diagonal, -inf strictly above it,
        # so position i only attends to positions <= i. Shape [1, 1, seq_len, seq_len].
        causal_mask = torch.full(
            (seq_len, seq_len), float("-inf"),
            dtype=embed_input.dtype, device=embed_input.device,
        ).triu(diagonal=1)
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)

        if padding_mask is not None and padding_mask.dim() == 3:
            padding_mask = padding_mask.unsqueeze(1)  # broadcast over heads

        x = self.self_attn.self_attention(embed_input, attn_mask=causal_mask)  # batch_size, seq_len, d_model
        x = self.self_norm(self.dropout(x) + embed_input)  # dropout, residual, norm

        y = self.cross_attn.cross_attention(x, key_value=cross_input, attn_mask=padding_mask)
        y = self.cross_norm(self.dropout(y) + x)  # dropout, residual, norm

        z = self.ffn_norm(self.dropout(self.ffn(y)) + y)  # ffn, dropout, residual, norm
        return z

In [None]:
class TransformerEncoderDecoder(nn.Module):
    """
    Stack of `num_layers` Encoder layers followed by `num_layers` Decoder layers.

    Inputs are expected to be already embedded (this module has no token
    embedding, positional encoding or output projection).

    Args:
        - num_layers: number of encoder layers and of decoder layers
        - d_model: model dimension
        - d_ff: feedforward hidden dimension
        - num_heads: number of attention heads
        - dropout: dropout probability
        - bias: whether the attention projections have a bias term
        - device: device the layers are moved to
    """
    def __init__(self, num_layers:int, d_model:int, d_ff:int,
                 num_heads:int = 1, dropout:float = 0.1, bias:bool = True, device:str = "cpu"):
        super().__init__()
        self.device = device
        self.encoders = nn.ModuleList([Encoder(d_model, d_ff, num_heads, dropout, bias) for _ in range(num_layers)]).to(device)
        self.decoders = nn.ModuleList([Decoder(d_model, d_ff, num_heads, dropout, bias) for _ in range(num_layers)]).to(device)

    def forward(self, encoder_input, decoder_input, padding_mask=None):
        """
        encoder_input: [batch_size, src_len, d_model]
        decoder_input: [batch_size, tgt_len, d_model]
        padding_mask: optional additive mask over SOURCE positions (0 = attend,
            -inf = padding). It is used both for encoder self-attention and for
            decoder cross-attention, so it must broadcast over the query axis,
            e.g. [batch_size, 1, src_len].
        returns: [batch_size, tgt_len, d_model]
        """
        # Run the full encoder stack first; every decoder layer cross-attends to
        # the FINAL encoder output rather than an intermediate layer's output.
        memory = encoder_input
        for encoder in self.encoders:
            memory = encoder(memory, padding_mask)

        decoder_output = decoder_input
        for decoder in self.decoders:
            decoder_output = decoder(decoder_output, memory, padding_mask)

        return decoder_output

