<a href="https://www.kaggle.com/code/aisuko/residual-encoder-attention?scriptVersionId=164087786" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Overview

The **Residual Encoder Attention block** is a key component of Transformer models, which are widely used in natural language processing tasks. It processes an input sequence by applying **self-attention** followed by a **position-wise feed-forward network**, each wrapped in a **residual connection**. For more detail, see the notebook [Transformer From Scratch](https://www.kaggle.com/code/aisuko/transformer-from-scratch#Residual-Connection).

The **residual connections** help to mitigate the problem of vanishing gradients during training, making it easier to train deep models. The **layer normalization**, applied here before each sub-layer, helps to stabilize the learning process.
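
As a sketch in equation form, the block can be written as follows. The symbols $x'$, $y$, $\mathrm{Attn}$, $\mathrm{MLP}$ and $\mathrm{LN}$ are notation introduced here for illustration; they correspond to the `forward` method of `ResidualEncoderAttentionBlock` in the Implementation section.

$$
\begin{aligned}
x' &= x + \mathrm{Attn}(\mathrm{LN}(x)) \\
y &= x' + \mathrm{MLP}(\mathrm{LN}(x'))
\end{aligned}
$$

Differentiating the first line gives $\frac{\partial x'}{\partial x} = I + \frac{\partial\,\mathrm{Attn}(\mathrm{LN}(x))}{\partial x}$. The identity term $I$ gives gradients a direct path through the block even when the sub-layer's own Jacobian is small, which is why stacking many such blocks remains trainable.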

This block is typically used as a building block for larger models, with several such blocks stacked together.


# Implementation

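This is a simple implementation that wraps PyTorch's built-in `nn.MultiheadAttention` for multi-head self-attention. With its default `batch_first=False`, that module expects input of shape `(seq_len, batch, n_state)`.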

In [1]:
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, n_state, n_head):
        super(MultiHeadSelfAttention, self).__init__()
        self.n_head=n_head
        self.attention=nn.MultiheadAttention(n_state, n_head)
    
    def forward(self, x):
        return self.attention(x,x,x)[0]


class MLP(nn.Module):
    """
    MLP is a simple implementation of a feed-forward neural network (also known as a multi-layer perceptron)
    with two linear layers and a ReLU activation function.
    """
    def __init__(self, n_state):
        super(MLP, self).__init__()
        self.fc1=nn.Linear(n_state, n_state)
        self.fc2=nn.Linear(n_state, n_state)
    
    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))


class ResidualEncoderAttentionBlock(nn.Module):
    def __init__(self, n_state, n_head):
        super(ResidualEncoderAttentionBlock, self).__init__()
        self.attn=MultiHeadSelfAttention(n_state, n_head)
        self.attn_ln=nn.LayerNorm(n_state)
        self.mlp=MLP(n_state)
        self.mlp_ln=nn.LayerNorm(n_state)

    def forward(self, x):
        x=x+self.attn(self.attn_ln(x))
        x=x+self.mlp(self.mlp_ln(x))
        return x
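
The overview mentions that several such blocks are usually stacked. A minimal sketch of that idea follows. `PreNormBlock` and `build_encoder` are hypothetical names introduced here, and `PreNormBlock` is a compact stand-in for `ResidualEncoderAttentionBlock` so that the snippet runs on its own; in the notebook the original class could be used instead.

```python
import torch
import torch.nn as nn


class PreNormBlock(nn.Module):
    """Compact stand-in for ResidualEncoderAttentionBlock (hypothetical name)."""

    def __init__(self, n_state, n_head):
        super().__init__()
        self.attn = nn.MultiheadAttention(n_state, n_head)
        self.attn_ln = nn.LayerNorm(n_state)
        self.mlp = nn.Sequential(
            nn.Linear(n_state, n_state),
            nn.ReLU(),
            nn.Linear(n_state, n_state),
        )
        self.mlp_ln = nn.LayerNorm(n_state)

    def forward(self, x):
        # Self-attention sub-layer with a residual connection.
        h = self.attn_ln(x)
        x = x + self.attn(h, h, h)[0]  # [0] keeps the output, drops the weights
        # Feed-forward sub-layer with a residual connection.
        x = x + self.mlp(self.mlp_ln(x))
        return x


def build_encoder(n_state, n_head, n_layer):
    """Stack n_layer blocks; each block maps a tensor to a tensor of the same shape."""
    return nn.Sequential(*[PreNormBlock(n_state, n_head) for _ in range(n_layer)])


encoder = build_encoder(n_state=4, n_head=2, n_layer=3)
x = torch.randn(3, 2, 4)  # (seq_len, batch, n_state)
y = encoder(x)
print(y.shape)  # torch.Size([3, 2, 4])
```

Because every block preserves the input shape, `nn.Sequential` is enough to chain them; a real encoder would typically add embeddings and positional information before the stack, which this sketch leaves out.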

# Testing

This test case creates a `ResidualEncoderAttentionBlock` with **4 hidden units** and **2 attention heads**, and an input tensor of shape `(3, 2, 4)`. Because `nn.MultiheadAttention` defaults to `batch_first=False`, this is read as **a sequence length of 3**, **a batch of 2 sequences**, and **4 features** per position. The test then passes the tensor through the block and checks that the output has the same shape as the input.

If the block is implemented correctly, this test should pass.

In [2]:
import torch

torch.manual_seed(123)

def test_res_en_atten_block():
    # Initialize a block with 4 hidden units and 2 attention heads
    block=ResidualEncoderAttentionBlock(n_state=4, n_head=2)
    print(block)
    
    # Create an input of shape (seq_len=3, batch=2, n_state=4);
    # nn.MultiheadAttention defaults to batch_first=False
    x=torch.randn(3, 2, 4)
    print(f'Tensor Size {x.size()}')
    
    # Process the batch through the block
    y=block(x)
    
    # Check that the output has the same shape as the input
    assert y.shape==x.shape

test_res_en_atten_block()

ResidualEncoderAttentionBlock(
  (attn): MultiHeadSelfAttention(
    (attention): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=4, out_features=4, bias=True)
    )
  )
  (attn_ln): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
  (mlp): MLP(
    (fc1): Linear(in_features=4, out_features=4, bias=True)
    (fc2): Linear(in_features=4, out_features=4, bias=True)
  )
  (mlp_ln): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
)
Tensor Size torch.Size([3, 2, 4])
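
A shape check alone does not show which axis is treated as the batch. The sketch below, which is an addition and not part of the original notebook, uses `nn.MultiheadAttention` directly with the same sizes. It perturbs only the second sequence along dimension 1 and confirms that the output for the first sequence is unchanged, which is the behaviour expected if dimension 1 is the batch axis.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same sizes as the test above; batch_first is left at its default (False).
attn = nn.MultiheadAttention(embed_dim=4, num_heads=2)

x = torch.randn(3, 2, 4)  # read as (seq_len=3, batch=2, embed_dim=4)
out, weights = attn(x, x, x)
print(out.shape)      # torch.Size([3, 2, 4])
print(weights.shape)  # torch.Size([2, 3, 3]), i.e. (batch, target_len, source_len)

# Change only the second sequence (index 1 along dim 1).
x_perturbed = x.clone()
x_perturbed[:, 1, :] += 1.0
out_perturbed, _ = attn(x_perturbed, x_perturbed, x_perturbed)

# Sequences in a batch do not attend to each other, so the first
# sequence's output should be unaffected by the change.
first_unchanged = torch.allclose(out[:, 0], out_perturbed[:, 0], atol=1e-6)
assert first_unchanged
```

If the data is laid out as `(batch, seq_len, features)` instead, passing `batch_first=True` to `nn.MultiheadAttention` is the usual way to match that layout.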
