In [1]:
import math
import torch
import torch.nn as nn
from einops import rearrange

In [2]:
class Attention(nn.Module):
    """Multi-head scaled dot-product attention (self- or cross-attention).

    Args:
        d_model: size of the model (must be divisible by ``num_heads``).
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
        bias: whether the Q/K/V/output projection layers have a bias term.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1, bias: bool = True):
        super(Attention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.scaling = math.sqrt(self.d_head)

        # Projection layers
        self.Q_proj = nn.Linear(d_model, d_model, bias=bias)
        self.K_proj = nn.Linear(d_model, d_model, bias=bias)
        self.V_proj = nn.Linear(d_model, d_model, bias=bias)
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)

        # Extra layers
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(p=dropout)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (b, l, d_model) -> (b, num_heads, l, d_head)."""
        b, l, _ = x.shape
        # Feature layout is "(h d)": head index is the slower-varying factor.
        return x.reshape(b, l, self.num_heads, self.d_head).transpose(1, 2)

    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (b, num_heads, l, d_head) -> (b, l, d_model)."""
        b, _, l, _ = x.shape
        return x.transpose(1, 2).reshape(b, l, self.d_model)

    def forward(self, sequence: torch.Tensor, key_values: torch.Tensor = None, attn_mask: torch.Tensor = None):
        """Self-attention (``key_values is None``) or cross-attention.

        Args:
            sequence: queries, shape (batch, seq_len_q, d_model).
            key_values: optional keys/values source, shape
                (batch, seq_len_k, d_model). When None, ``sequence`` is used
                (self-attention).
            attn_mask: optional *additive* mask added to the raw scores of
                shape (batch, num_heads, seq_len_q, seq_len_k) before the
                softmax, so it must broadcast to that shape (e.g. a
                (seq_len_q, seq_len_k) causal mask, or a (batch, 1, 1, seq_len_k)
                key-padding mask). Large negative values block attention.

        Returns:
            Tensor of shape (batch, seq_len_q, d_model).
        """
        # Check rank first so a 2-D input fails the assert instead of raising IndexError.
        assert sequence.dim() == 3 and sequence.shape[2] == self.d_model
        cross_attention = key_values is not None
        if cross_attention:
            assert key_values.dim() == 3 and key_values.shape[2] == self.d_model
        source = key_values if cross_attention else sequence

        # Projections, then split into heads: positions on dim -2, features on dim -1.
        Q = self._split_heads(self.Q_proj(sequence))  # b, h, lq, d_head
        K = self._split_heads(self.K_proj(source))    # b, h, lk, d_head
        V = self._split_heads(self.V_proj(source))    # b, h, lk, d_head

        # Score every query position against every key position, per head.
        attn_mat = Q @ K.transpose(-1, -2) / self.scaling  # b, h, lq, lk
        if attn_mask is not None:
            # Out-of-place so the mask may broadcast against the scores.
            attn_mat = attn_mat + attn_mask

        # Normalize over the key positions (last dim).
        attn_score = self.dropout(self.softmax(attn_mat))
        attn = attn_score @ V  # b, h, lq, d_head

        # Merge the heads and apply the final projection.
        return self.out_proj(self._merge_heads(attn))


# Smoke test: run self-attention on a random batch; the result keeps the
# (batch_size, seq_length, d_model) shape. The name `attn` ends up holding the output tensor.
d_model, batch_size, seq_length = 64, 4, 100
attn = Attention(d_model=d_model, num_heads=2)
x = torch.randn(batch_size, seq_length, d_model)

attn = attn(x)

In [3]:
attn.size()

torch.Size([4, 100, 64])

In [4]:
class FFN(nn.Module):
    """Feed-forward block applied at every position: Linear -> ReLU -> Linear.

    Args:
        d_model: input and output feature size.
        d_ff: hidden (inner) feature size.
    """

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.d_model = d_model
        expand = nn.Linear(d_model, d_ff)
        contract = nn.Linear(d_ff, d_model)
        self.net = nn.Sequential(expand, nn.ReLU(), contract)

    def forward(self, x):
        """Apply the network to ``x`` of shape (batch, seq_len, d_model); the shape is preserved."""
        assert x.dim() == 3 and x.size(-1) == self.d_model
        return self.net(x)
    

In [5]:
class Encoder(nn.Module):
    """One post-LayerNorm encoder layer:
        Self-Attention, Dropout, Residual, LayerNorm,
        FFN, Dropout, Residual, LayerNorm
    """
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1, bias: bool = True):
        super(Encoder, self).__init__()
        self.attn = Attention(d_model, num_heads, dropout, bias)
        self.dropout = nn.Dropout(p=dropout)
        self.attn_norm = nn.LayerNorm(d_model)
        self.ffn = FFN(d_model, d_ff)
        self.ff_norm = nn.LayerNorm(d_model)

    def forward(self, encoded_inputs: torch.Tensor, padding_mask: torch.Tensor = None):
        """Run one encoder layer.

        Args:
            encoded_inputs: (batch, seq_length, d_model).
            padding_mask: optional mask forwarded to the self-attention as
                ``attn_mask``; it must broadcast against that layer's
                attention scores (NOTE(review): confirm the shape callers pass).

        Returns:
            Tensor with the same shape as ``encoded_inputs``.
        """
        # Pass the mask by keyword: the second positional parameter of the
        # attention's forward is `key_values`, which would turn this into
        # cross-attention against the mask tensor.
        x = self.attn(encoded_inputs, attn_mask=padding_mask)
        x = self.dropout(x) + encoded_inputs
        x = self.attn_norm(x)

        x = self.dropout(self.ffn(x)) + x
        return self.ff_norm(x)

In [6]:
class Decoder(nn.Module):
    """One post-LayerNorm decoder layer:
       Self-Attention with causal mask, Dropout, Residual, LayerNorm
       Cross-Attention with padding mask, Dropout, Residual, LayerNorm
       FFN, Dropout, Residual, LayerNorm"""
    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1, bias: bool = True):
        super(Decoder, self).__init__()

        self.self_attention = Attention(d_model, num_heads, dropout, bias)
        self.self_norm = nn.LayerNorm(d_model)
        self.cross_attention = Attention(d_model, num_heads, dropout, bias)
        self.cross_norm = nn.LayerNorm(d_model)
        self.ffn = FFN(d_model, d_ff)
        self.ff_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, decoder_input: torch.Tensor, encoder_output: torch.Tensor, padding_mask: torch.Tensor = None):
        """Run one decoder layer.

        Args:
            decoder_input: (b, seq_len, d_model).
            encoder_output: (b, enc_seq_len, d_model).
            padding_mask: optional mask passed as ``attn_mask`` to the
                cross-attention, whose keys are the encoder positions
                (NOTE(review): confirm the shape callers pass broadcasts
                against the cross-attention scores).

        Returns:
            Tensor with the same shape as ``decoder_input``.
        """
        seq_len = decoder_input.shape[1]
        # Additive causal mask: 0 on/below the diagonal, -inf above it, so
        # position i only attends to positions <= i. It is built on the input's
        # device and dtype so adding it to the scores works on GPU and in
        # reduced precision.
        causal_mask = torch.triu(
            torch.full(
                (seq_len, seq_len),
                float("-inf"),
                device=decoder_input.device,
                dtype=decoder_input.dtype,
            ),
            diagonal=1,
        )

        x = self.self_attention(decoder_input, key_values=None, attn_mask=causal_mask)
        x = self.dropout(x) + decoder_input
        x = self.self_norm(x)

        y = self.cross_attention(x, key_values=encoder_output, attn_mask=padding_mask)
        y = self.dropout(y) + x
        y = self.cross_norm(y)

        z = self.ffn(y)
        z = self.dropout(z) + y
        z = self.ff_norm(z)
        return z
    
# Demo of an additive causal mask: -1e9 strictly above the diagonal, (-)0 elsewhere.
T = -(10 ** 9) * torch.ones(5, 5).triu(diagonal=1)
print(T)

tensor([[-0.0000e+00, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-0.0000e+00, -0.0000e+00, -1.0000e+09, -1.0000e+09, -1.0000e+09],
        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+09, -1.0000e+09],
        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+09],
        [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]])


In [None]:
class EncoderDecoder(nn.Module):
    """A stack of ``num_layers`` encoder layers followed by ``num_layers`` decoder layers.

    Every decoder layer cross-attends to the output of the *final* encoder layer.
    """
    def __init__(self, num_layers: int, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1, bias: bool = True):
        super(EncoderDecoder, self).__init__()

        # Keyword arguments: Encoder/Decoder take (d_model, num_heads, d_ff, ...),
        # and passing d_ff positionally in the num_heads slot is an easy mix-up.
        self.encoders = nn.ModuleList([
            Encoder(d_model, num_heads=num_heads, d_ff=d_ff, dropout=dropout, bias=bias)
            for _ in range(num_layers)
        ])
        self.decoders = nn.ModuleList([
            Decoder(d_model, num_heads=num_heads, d_ff=d_ff, dropout=dropout, bias=bias)
            for _ in range(num_layers)
        ])

    def forward(self, embedded_encoder_input: torch.Tensor, embedded_decoder_input: torch.Tensor, padding_mask: torch.Tensor = None):
        """Encode the source, then decode the target against the final encoding.

        Args:
            embedded_encoder_input: (b, enc_seq_len, d_model).
            embedded_decoder_input: (b, seq_len, d_model).
            padding_mask: optional mask over encoder positions. It is used by the
                encoder self-attention and by the decoder cross-attention.

        Returns:
            Decoder output, (b, seq_len, d_model).
        """
        # Run the whole encoder stack first.
        encoder_output = embedded_encoder_input
        for encoder in self.encoders:
            encoder_output = encoder(encoder_output, padding_mask)

        # Then every decoder layer attends to the final encoder output.
        decoder_output = embedded_decoder_input
        for decoder in self.decoders:
            decoder_output = decoder(decoder_output, encoder_output, padding_mask)

        return decoder_output
