### Transformer Decoder

<div align="center">
  <img src="https://res.cloudinary.com/edlitera/image/upload/c_fill,f_auto/v1680629118/blog/gz5ccspg3yvq4eo6xhrr" alt="Transformer Decoder" width="300">
</div>

In [7]:
import torch
import torch.nn as nn

In [None]:
# Importing classes from the respective notebooks
%run 4_Multihead_Attention.ipynb
%run 5_FeedForward.ipynb

**Attention**:
- `Self-Attention Layer`: Lets each decoder position attend to itself and to all earlier positions; a causal mask hides the later ones
- `Cross-Attention Layer`: Lets each decoder position attend to the encoder's output
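
The difference between the two layers shows up in the shape of the attention scores. The sketch below is a simplified illustration, assuming a single head and no learned projections (the `MultiHeadAttention` class from the other notebook is not reproduced here); the names `tgt` and `memory` are placeholders for the decoder states and the encoder output.

```python
import torch

batch_size, tgt_len, src_len, embed_dim = 2, 10, 12, 64

tgt = torch.randn(batch_size, tgt_len, embed_dim)     # decoder states (placeholder)
memory = torch.randn(batch_size, src_len, embed_dim)  # encoder output (placeholder)

# Self-attention: queries and keys both come from the decoder.
self_scores = tgt @ tgt.transpose(-2, -1) / embed_dim ** 0.5

# Cross-attention: queries come from the decoder, keys from the encoder.
cross_scores = tgt @ memory.transpose(-2, -1) / embed_dim ** 0.5

print(self_scores.shape)   # torch.Size([2, 10, 10])
print(cross_scores.shape)  # torch.Size([2, 10, 12])
```

Self-attention scores are square (`tgt_len x tgt_len`), which is why a causal mask of that shape applies to them, while cross-attention scores are `tgt_len x src_len`.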

In [9]:
class DecoderLayer(nn.Module):
    def __init__(self, embed_dim, num_heads, ff_hidden_dim, dropout=0.1):
        """
        Single transformer decoder layer with masked self-attention, 
        cross-attention, and feed-forward network.
        
        Args:
            embed_dim: Dimension of embeddings
            num_heads: Number of attention heads
            ff_hidden_dim: Hidden dimension of feed-forward network
            dropout: Dropout probability
        """
        super().__init__()
        self.self_attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        
        self.cross_attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        
        self.ff = FeedForward(embed_dim, ff_hidden_dim, dropout)
        self.norm3 = nn.LayerNorm(embed_dim)
        
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, self_attn_mask=None, cross_attn_mask=None):
        """
        Forward pass of decoder layer.
        
        Args:
            x: Input tensor [batch_size, seq_len, embed_dim]
            enc_output: Encoder output [batch_size, enc_seq_len, embed_dim]
            self_attn_mask: Mask for self-attention (usually causal mask)
            cross_attn_mask: Mask for cross-attention
            
        Returns:
            x: Output tensor [batch_size, seq_len, embed_dim]
        """
        # Self-attention block (pre-norm): LayerNorm, masked self-attention, then residual add
        residual = x
        x = self.norm1(x)
        x = residual + self.dropout(self.self_attn(x, x, x, self_attn_mask))
        
        # Cross-attention block (pre-norm): LayerNorm, attention over enc_output, then residual add
        residual = x
        x = self.norm2(x)
        x = residual + self.dropout(self.cross_attn(x, enc_output, enc_output, cross_attn_mask))
        
        # Feed-forward block (pre-norm): LayerNorm, feed-forward network, then residual add
        residual = x
        x = self.norm3(x)
        x = residual + self.dropout(self.ff(x))
        
        return x

In [10]:
# Example

embed_dim = 64
num_heads = 8
ff_hidden_dim = 256
dropout = 0.1

decoder_layer = DecoderLayer(embed_dim, num_heads, ff_hidden_dim, dropout)

In [11]:
# input tensors for the forward pass
batch_size = 2
seq_len = 10
enc_seq_len = 12

# Random input tensors for the example
x = torch.randn(batch_size, seq_len, embed_dim)  # Decoder input (e.g., previous tokens embeddings)
enc_output = torch.randn(batch_size, enc_seq_len, embed_dim)  # Encoder output (e.g., encoder's final hidden states)

In [19]:
# Causal mask for self-attention: True marks the future positions that must be hidden
self_attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()  # True strictly above the diagonal

self_attn_mask

tensor([[False,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False, False,  True,  True,  True],
        [False, False, False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False, False, False,  True],
        [False, False, False, False, False, False, False, False, False, False]])
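
A boolean mask like this is commonly applied by filling the masked scores with negative infinity before the softmax, so that those positions receive zero weight. The sketch below assumes that `True` means "hide this position"; whether `MultiHeadAttention` from the other notebook follows the same convention is not shown here, so treat that as an assumption. The `scores` tensor is a random stand-in for the scaled query-key products.

```python
import torch

seq_len = 4
scores = torch.randn(seq_len, seq_len)  # stand-in for Q @ K^T / sqrt(d_k)

# True marks positions to hide (strictly above the diagonal).
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

# Masked scores become -inf, so softmax assigns them a weight of exactly 0.
masked_scores = scores.masked_fill(causal_mask, float("-inf"))
weights = torch.softmax(masked_scores, dim=-1)

print(weights)  # upper triangle is 0, each row sums to 1
```

Because the diagonal is never masked, every row keeps at least one finite score and the softmax stays well defined.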

In [20]:
# Make the mask compatible with the batch size by unsqueezing it
self_attn_mask = self_attn_mask.unsqueeze(0).expand(batch_size, -1, -1)  # Shape [batch_size, seq_len, seq_len]
self_attn_mask.shape

torch.Size([2, 10, 10])

In [21]:
cross_attn_mask = None  # No encoder positions are masked here; a mask could be passed to hide specific encoder tokens
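
When the source batch contains padding, a cross-attention mask can be built from the source token ids. This is a minimal sketch under several assumptions: `pad_id = 0` and the `src_tokens` values are made up for illustration, `True` means "hide this position" (as in the causal mask), and the mask has shape `[batch_size, seq_len, enc_seq_len]`, mirroring the self-attention mask. The result is stored under a separate name so it does not overwrite `cross_attn_mask`.

```python
import torch

pad_id = 0    # hypothetical padding token id
seq_len = 10  # decoder length, as in the example

# Hypothetical source token ids: batch of 2, encoder length 12.
src_tokens = torch.tensor([
    [5, 8, 3, 9, 4, 7, 2, 6, 1, 3, 0, 0],  # last two positions are padding
    [4, 2, 7, 7, 9, 1, 0, 0, 0, 0, 0, 0],  # last six positions are padding
])

# True marks encoder positions that hold padding.
src_pad = src_tokens.eq(pad_id)  # [batch_size, enc_seq_len]

# Repeat the same row for every decoder position (expand creates a view, not a copy).
example_cross_attn_mask = src_pad.unsqueeze(1).expand(-1, seq_len, -1)

print(example_cross_attn_mask.shape)  # torch.Size([2, 10, 12])
```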

In [22]:
# Forward pass through the decoder layer
output = decoder_layer(x, enc_output, self_attn_mask=self_attn_mask, cross_attn_mask=cross_attn_mask)

print(output.shape)  # Should be [batch_size, seq_len, embed_dim]

torch.Size([2, 10, 64])


### Attention:
- **Self-Attention (in the Decoder)**: Each token in the target sequence (tgt) attends to itself and to earlier tokens, never to future ones. The **causal mask** enforces this by hiding every position that comes after the current token.
- **Cross-Attention (in the Decoder)**: Each token in the target sequence (tgt) attends to the source sequence (src). Any source token may be visible, so no causal mask is needed; the only positions that should be masked are **padding tokens** in the source.
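
The effect of the causal mask can be checked directly: changing a later token should leave the outputs at earlier positions untouched. The sketch below uses a simplified single-head attention without learned projections (the helper `causal_self_attention` is hypothetical and not part of the notebook's classes), which is enough to show the property.

```python
import torch

torch.manual_seed(0)
seq_len, dim = 5, 8

def causal_self_attention(x):
    """Single-head self-attention with a causal mask (no learned projections)."""
    scores = x @ x.transpose(-2, -1) / dim ** 0.5
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    weights = torch.softmax(scores.masked_fill(mask, float("-inf")), dim=-1)
    return weights @ x

x = torch.randn(seq_len, dim)
x_changed = x.clone()
x_changed[-1] = torch.randn(dim)  # perturb only the last token

out = causal_self_attention(x)
out_changed = causal_self_attention(x_changed)

# Earlier positions never see the last token, so their outputs are unchanged.
print(torch.allclose(out[:-1], out_changed[:-1]))  # True
# The last position attends to itself, so its output changes.
print(torch.allclose(out[-1], out_changed[-1]))    # False
```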