In [1]:
import torch
import torch.nn as nn
import math

# =============================================
# Previous components (simplified)
# =============================================
class LayerNormalization(nn.Module):
    """Normalise the last dimension of a tensor, then apply a learnable scale and shift.

    Args:
        features: Length of the last dimension; sizes the ``alpha`` and ``bias`` vectors.
        eps: Constant added to the standard deviation before dividing, so a
            constant row does not divide by zero.
    """

    def __init__(self, features: int, eps: float = 1e-6):
        super().__init__()
        # Scale starts at 1 and shift at 0, so the layer begins as plain normalisation.
        self.alpha = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        """Return ``alpha * (x - mean) / (std + eps) + bias`` computed over the last dim."""
        # keepdim=True leaves a trailing axis of size 1 so both statistics broadcast against x.
        centre = torch.mean(x, dim=-1, keepdim=True)
        # torch.std defaults to the sample (Bessel-corrected) estimate; eps is added
        # to the standard deviation itself, not to the variance.
        spread = torch.std(x, dim=-1, keepdim=True)
        scaled = self.alpha * (x - centre)
        return scaled / (spread + self.eps) + self.bias

class ResidualConnection(nn.Module):
    """Pre-norm skip connection: ``x + dropout(sublayer(norm(x)))``.

    Args:
        features: Size of the last dimension, forwarded to ``LayerNormalization``.
        dropout: Drop probability applied to the sublayer output.
    """

    def __init__(self, features: int, dropout: float):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x, sublayer):
        """Normalise ``x``, run ``sublayer`` on it, apply dropout, and add ``x`` back.

        ``sublayer`` is any callable taking the normalised tensor; its result
        must be addable to ``x``.
        """
        normed = self.norm(x)
        transformed = sublayer(normed)
        return x + self.dropout(transformed)

class MultiHeadAttentionblock(nn.Module):
    """Multi-head scaled dot-product attention.

    Inputs are projected with bias-free linear layers, split into ``h`` heads
    of width ``d_k = d_model // h``, attended per head, merged back and passed
    through an output projection.

    Args:
        d_model: Width of the input and output feature dimension.
        h: Number of attention heads; must divide ``d_model``.
        dropout: Drop probability applied to the attention weights.

    After each ``forward`` call the attention weights of that call are kept on
    ``self.attention_scores`` with shape ``(batch, h, seq_q, seq_k)``. They are
    the weights *after* dropout, so in training mode individual entries may be
    zeroed or rescaled.
    """

    def __init__(self, d_model: int, h: int, dropout: float):
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, "d_model must be divisible by h"
        self.d_k = d_model // h

        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout):
        """Compute ``softmax(Q K^T / sqrt(d_k)) V``.

        Args:
            query, key, value: Tensors whose last dimension is ``d_k``.
            mask: ``None``, or a tensor broadcastable to the score tensor in
                which entries equal to 0 mark positions to ignore.
            dropout: ``None``, or a callable applied to the attention weights.

        Returns:
            ``(output, weights)`` where ``weights`` are the (post-dropout)
            attention weights and ``output = weights @ value``.
        """
        d_k = query.shape[-1]
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Use the most negative finite value of the score dtype rather than
            # a hard-coded -1e9: a raw score below -1e9 would otherwise let a
            # masked position win the softmax, and -1e9 is outside float16 range.
            # The fill stays in place so a mask that does not broadcast onto the
            # score shape is still rejected instead of silently expanding it.
            fill_value = torch.finfo(attention_scores.dtype).min
            attention_scores.masked_fill_(mask == 0, fill_value)
        attention_scores = attention_scores.softmax(dim=-1)
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        return attention_scores @ value, attention_scores

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v``; inputs are ``(batch, seq, d_model)``.

        Returns a tensor of shape ``(batch, seq_q, d_model)``.
        """
        # (batch, seq, d_model) -> (batch, seq, h, d_k) -> (batch, h, seq, d_k)
        query = self.w_q(q).view(q.shape[0], q.shape[1], self.h, self.d_k).transpose(1, 2)
        key = self.w_k(k).view(k.shape[0], k.shape[1], self.h, self.d_k).transpose(1, 2)
        value = self.w_v(v).view(v.shape[0], v.shape[1], self.h, self.d_k).transpose(1, 2)

        x, self.attention_scores = MultiHeadAttentionblock.attention(query, key, value, mask, self.dropout)
        # Merge heads: (batch, h, seq, d_k) -> (batch, seq, h * d_k). contiguous()
        # is needed because view() cannot operate on the transposed layout.
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
        return self.w_o(x)

class FeedForwardBlock(nn.Module):
    """Two-layer MLP applied to the last dimension: Linear -> ReLU -> Dropout -> Linear.

    Args:
        d_model: Input and output feature width.
        d_ff: Width of the hidden layer.
        dropout: Drop probability applied to the hidden activations.
    """

    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)  # expand: d_model -> d_ff
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)  # project back: d_ff -> d_model

    def forward(self, x):
        """Map ``(..., d_model)`` to ``(..., d_model)`` through the hidden layer."""
        hidden = self.linear_1(x).relu()
        hidden = self.dropout(hidden)
        return self.linear_2(hidden)

# =============================================
# EncoderBlock Implementation
# =============================================
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention then feed-forward, each inside a residual connection.

    Args:
        features: Size of the last dimension, forwarded to both residual connections.
        self_attention_block: Attention module called as ``block(q, k, v, mask)``.
        feed_forward_block: Module applied to the output of the attention stage.
        dropout: Drop probability used by both residual connections.
    """

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionblock,
                 feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        # Index 0 wraps self-attention, index 1 wraps the feed-forward block.
        self.residual_connections = nn.ModuleList(
            [ResidualConnection(features, dropout), ResidualConnection(features, dropout)]
        )

    def forward(self, x, src_mask):
        """Run ``x`` through both residual sub-layers and return the result."""
        attention_residual, feed_forward_residual = self.residual_connections

        def self_attend(normed):
            # Self-attention: the normalised input serves as query, key and value.
            return self.self_attention_block(normed, normed, normed, src_mask)

        x = attention_residual(x, self_attend)
        return feed_forward_residual(x, self.feed_forward_block)


In [4]:
# =============================================
# DEMO: Step-by-step execution
# =============================================
print("=== ENCODER BLOCK DEMO ===\n")

# Seed the RNG so the randomly initialised weights (and any tensors drawn
# afterwards) are the same on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Hyperparameters
batch_size = 1  # number of sentences
seq_len = 4     # number of tokens per sentence
d_model = 8     # feature dimension of each token
h = 2           # number of attention heads
d_ff = 16       # hidden width of the feed-forward block
dropout = 0.1   # dropout probability (only active in training mode)

# Create components
attention = MultiHeadAttentionblock(d_model, h, dropout)
feedforward = FeedForwardBlock(d_model, d_ff, dropout)
encoder_block = EncoderBlock(d_model, attention, feedforward, dropout)

# Switch to inference mode: nn.Dropout becomes the identity, so repeated passes
# over the same input give the same result and the stored attention weights are
# the plain softmax probabilities. The trailing ';' hides the module repr.
encoder_block.eval();

=== ENCODER BLOCK DEMO ===



In [5]:
# Input: (batch=1, seq=4, features=8), i.e. (batch_size, seq_len, d_model)
# Sentence example: "The cat sat here" (4 tokens)
# The values are random draws from torch.randn standing in for token embeddings.
x = torch.randn(batch_size, seq_len, d_model)
print("Step 1 - Input x:")
print("  Shape:", x.shape)  # (1, 4, 8)
print("  Sample x[0,0]:", x[0, 0])
print('X: ',x)
print()

Step 1 - Input x:
  Shape: torch.Size([1, 4, 8])
  Sample x[0,0]: tensor([-1.2552, -1.0697, -0.2918,  0.9712,  1.0733, -0.7998,  1.5604, -0.9575])
X:  tensor([[[-1.2552, -1.0697, -0.2918,  0.9712,  1.0733, -0.7998,  1.5604,
          -0.9575],
         [-0.4063, -1.2687, -0.6252, -0.3680, -0.3114, -0.7477, -0.1026,
          -1.1438],
         [-0.5951,  0.6209, -0.7472,  1.8960, -0.8180, -0.8053,  0.0325,
           0.6191],
         [ 0.3887, -0.2316,  0.4306, -0.3481, -0.1165,  0.7390,  0.6994,
          -0.2416]]])



In [6]:
# Source mask: shape (batch, 1, 1, seq_len) for broadcasting.
# The two singleton axes broadcast over the head and query-position axes of the
# (batch, h, seq_q, seq_k) score tensor, so the mask selects key positions.
# Assume last token is padding (mask it out)
src_mask = torch.ones(batch_size, 1, 1, seq_len)
src_mask[:, :, :, -1] = 0  # Mask last token (0 = ignore, 1 = attend)
print("Step 2 - Source mask:")
print("  Shape:", src_mask.shape)  # (1, 1, 1, 4) with batch_size = 1, seq_len = 4
print("  Mask[0]:", src_mask[0].squeeze())
print("  → Last token (padding) will be ignored")
print()

Step 2 - Source mask:
  Shape: torch.Size([2, 1, 1, 4])
  Mask[0]: tensor([1., 1., 1., 0.])
  → Last token (padding) will be ignored



In [None]:






# ============================================
# MANUAL STEP-BY-STEP
# ============================================
# Run in inference mode so nn.Dropout is the identity. With dropout active the
# manual walkthrough and the end-to-end call would each draw a different random
# dropout mask and could not be compared; the stored attention weights would
# also be the post-dropout values instead of the softmax probabilities.
encoder_block.eval()

print("Step 3 - First Residual Connection (Self-Attention)")
print("="*60)

# Save original for comparison
x_original = x.clone()

# Apply first residual (attention)
print("  3a. LayerNorm(x)")
norm1 = encoder_block.residual_connections[0].norm(x)
print("    Normalized x[0,0]:", norm1[0, 0])
print()

print("  3b. MultiHeadAttention(Q=x, K=x, V=x, mask)")
attn_out = encoder_block.self_attention_block(norm1, norm1, norm1, src_mask)
print("    Attention output shape:", attn_out.shape)  # (batch_size, seq_len, d_model) = (1, 4, 8)
print("    Attention out[0,0]:", attn_out[0, 0])
print()

print("  3c. Dropout + Add residual")
# In eval mode the dropout call returns its input unchanged.
x_after_attn = x + encoder_block.residual_connections[0].dropout(attn_out)
print("    After residual shape:", x_after_attn.shape)
print("    x_after_attn[0,0]:", x_after_attn[0, 0])
print()

print("Step 4 - Second Residual Connection (FeedForward)")
print("="*60)

print("  4a. LayerNorm(x_after_attn)")
norm2 = encoder_block.residual_connections[1].norm(x_after_attn)
print("    Normalized[0,0]:", norm2[0, 0])
print()

print("  4b. FeedForward")
ff_out = encoder_block.feed_forward_block(norm2)
print("    FF output shape:", ff_out.shape)  # (batch_size, seq_len, d_model) = (1, 4, 8)
print("    FF out[0,0]:", ff_out[0, 0])
print()

print("  4c. Dropout + Add residual")
x_final_manual = x_after_attn + encoder_block.residual_connections[1].dropout(ff_out)
print("    Final output shape:", x_final_manual.shape)
print("    x_final[0,0]:", x_final_manual[0, 0])
print()

# ============================================
# END-TO-END
# ============================================
print("Step 5 - END-TO-END: encoder_block(x, src_mask)")
print("="*60)

x_direct = x_original.clone()
output = encoder_block(x_direct, src_mask)
print("  Output shape:", output.shape)  # (batch_size, seq_len, d_model) = (1, 4, 8)
print("  Output[0,0]:", output[0, 0])
# The walkthrough above repeats exactly what EncoderBlock.forward does, so the
# two results must agree; fail loudly if they ever drift apart.
assert torch.allclose(x_final_manual, output), "manual walkthrough diverged from encoder_block(x, src_mask)"
print("  Manual walkthrough matches end-to-end output ✅")
print()

# ============================================
# Visualize information flow
# ============================================
print("Step 6 - INFORMATION FLOW")
print("="*60)
print("Token 0 (The)  before encoder:", x_original[0, 0, :3])
print("Token 0 (The)  after attention:", x_after_attn[0, 0, :3])
print("Token 0 (The)  after FF:", output[0, 0, :3])
print()
print("✅ Each token now contains:")
print("   - Context from ALL other tokens (via self-attention)")
print("   - Position-specific transformations (via feedforward)")
print()

# ============================================
# Check mask effect
# ============================================
print("Step 7 - MASK EFFECT (last token should receive little attention)")
print("="*60)
# attention_scores holds the weights of the most recent forward call, which is
# the end-to-end pass from Step 5.
attn_scores = encoder_block.self_attention_block.attention_scores
print("Attention scores shape:", attn_scores.shape)  # (batch_size, h, seq_len, seq_len) = (1, 2, 4, 4)
print("Token 0 attention to all tokens (Head 0):")
print("  To token 0:", attn_scores[0, 0, 0, 0].item())
print("  To token 1:", attn_scores[0, 0, 0, 1].item())
print("  To token 2:", attn_scores[0, 0, 0, 2].item())
print("  To token 3 (masked):", attn_scores[0, 0, 0, 3].item())
# Every query position, in every head, should put (almost) no weight on the masked key.
assert attn_scores[..., -1].max().item() < 1e-6, "masked key position still receives attention"
print("  → Token 3 (padding) receives near-zero attention ✅")