### Transformer Encoder

<div align="center">
  <img src="https://media.datacamp.com/legacy/v1704797298/image_3aa5aef3db.png" alt="Transformer Encoder" width="300">
</div>

In [4]:
import torch
import torch.nn as nn

In [None]:
# Importing classes from the respective notebooks
%run 4_Multihead_Attention.ipynb
%run 5_FeedForward.ipynb

In [5]:
class EncoderLayer(nn.Module):
    """One transformer encoder layer: self-attention followed by a feed-forward network.

    Each of the two sub-blocks follows the same recipe: LayerNorm is applied
    to the sub-block input, the sub-layer runs on the normalized tensor, its
    output goes through dropout, and the result is added back onto the
    un-normalized input (residual connection).
    """

    def __init__(self, embed_dim, num_heads, ff_hidden_dim, dropout=0.1):
        """
        Args:
            embed_dim: Dimension of embeddings
            num_heads: Number of attention heads
            ff_hidden_dim: Hidden dimension of feed-forward network
            dropout: Dropout probability
        """
        super().__init__()

        # Attention sub-block: normalization + multi-head self-attention
        self.self_attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)

        # Feed-forward sub-block: normalization + position-wise network
        self.ff = FeedForward(embed_dim, ff_hidden_dim, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)

        # A single dropout module is shared by both residual branches
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input tensor [batch_size, seq_len, embed_dim]
            mask: Optional attention mask, handed to the attention module as-is

        Returns:
            Output tensor [batch_size, seq_len, embed_dim]
        """
        # Self-attention: the normalized input serves as query, key and value
        normed = self.norm1(x)
        attn_out = self.self_attn(normed, normed, normed, mask)
        x = x + self.dropout(attn_out)

        # Feed-forward network applied to the normalized attention result
        ff_out = self.ff(self.norm2(x))
        return x + self.dropout(ff_out)

In [6]:
# Demo configuration for a quick shape check of a single encoder layer
batch_size, seq_len = 2, 10
embed_dim, num_heads = 512, 8
ff_hidden_dim = 2048
dropout = 0.1

# Dummy batch sampled from a standard normal: [batch_size, seq_len, embed_dim]
x = torch.randn(batch_size, seq_len, embed_dim)

# Build one layer and push the batch through it (no attention mask supplied)
encoder_layer = EncoderLayer(embed_dim, num_heads, ff_hidden_dim, dropout)
output = encoder_layer(x)

# The layer is expected to leave the tensor shape unchanged
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")

Input shape: torch.Size([2, 10, 512])
Output shape: torch.Size([2, 10, 512])
