In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import math

class MultiHeadSelfAttention(nn.Module):
    """
    Multi-Head Self-Attention Module
    Args:
        d_model: Total dimension of the model.
        d_embed: Embedding dimension.
        num_head: Number of attention heads.
        dropout: Dropout rate for attention scores.
        bias: Whether to include bias in linear projections.
    """
    def __init__(self, d_model, d_embed, num_head, dropout=0.0, bias=True): # infer d_k, d_v, d_q from d_model
        super().__init__()  # Missing in the original implementation
        assert d_model % num_head == 0, "d_model must be divisible by num_head"
        self.d_model = d_model
        
        self.d_embed = d_embed
        self.num_head = num_head
        self.d_head=d_model//num_head
        self.dropout_rate = dropout  # Store dropout rate separately

        # linear transformations
        self.q_proj = nn.Linear(d_embed, d_model, bias=bias)
        self.k_proj = nn.Linear(d_embed, d_model, bias=bias)
        self.v_proj = nn.Linear(d_embed, d_model, bias=bias)
        self.output_proj = nn.Linear(d_model, d_model, bias=bias)

        # Dropout layer
        self.dropout = nn.Dropout(p=dropout)

        # Initiialize scaler
        self.scaler = float(1.0 / math.sqrt(self.d_head)) # Store as float in initialization
        

    def forward(self, sequence, att_mask=None):
        batch_size, seq_len, embed_dim = sequence.size()
        
        # Linear projections and reshape for multi-head
        Q_state = self.q_proj(sequence) #[batch_size, seq_len, d_model=num_head * d_head]
        K_state = self.k_proj(sequence)
        V_state = self.v_proj(sequence)
        
        Q_state = Q_state.view(batch_size, seq_len, self.num_head, self.d_head).transpose(1,2) #[batch_size, self.num_head, seq_len, self.d_head]
        K_state = K_state.view(batch_size, seq_len, self.num_head, self.d_head).transpose(1,2)
        V_state = V_state.view(batch_size, seq_len, self.num_head, self.d_head).transpose(1,2)
    
       
        # Scale Q by 1/sqrt(d_k)
        Q_state = Q_state * self.scaler
    
    
        # Compute attention matrix: QK^T
        att_matrix = torch.matmul(Q_state, K_state.transpose(-1,-2)) 

    
        # apply attention mask to attention matrix
        if att_mask is not None and not isinstance(att_mask, torch.Tensor):
            raise TypeError("att_mask must be a torch.Tensor")

        if att_mask is not None:
            
            # Expand mask for multi-head attention
            # [batch_size, seq_len] -> [batch_size, 1, 1, seq_len]
            att_mask = att_mask.unsqueeze(1).unsqueeze(2)
            att_matrix = att_matrix.masked_fill(att_mask == 0, float('-inf'))
        
        # apply softmax to the last dimension to get the attention score: softmax(QK^T)
        att_score = F.softmax(att_matrix, dim = -1)
    
        # apply drop out to attention score
        att_score = self.dropout(att_score)
    
        # get final output: softmax(QK^T)V
        att_output = torch.matmul(att_score, V_state)
    
        # concatinate all attention heads
        att_output = att_output.contiguous().view(batch_size, seq_len, self.num_head*self.d_head) 
    
        # final linear transformation to the concatenated output
        att_output = self.output_proj(att_output)

        return att_output

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class FFN(nn.Module):
    """
    Position-wise Feed-Forward Networks
    This consists of two linear transformations with a ReLU activation in between.
    
    FFN(x) = max(0, xW1 + b1 )W2 + b2
    d_model: embedding dimension (e.g., 512)
    d_ff: feed-forward dimension (e.g., 2048)
    
    """
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.d_model=d_model
        self.d_ff= d_ff
        
        # Linear transformation y = xW+b
        self.fc1 = nn.Linear(self.d_model, self.d_ff, bias = True)
        self.fc2 = nn.Linear(self.d_ff, self.d_model, bias = True)
        
        # for potential speed up
        # Pre-normalize the weights (can help with training stability)
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)


    def forward(self, input):
        # check input and first FF layer dimension matching
        batch_size, seq_length, d_input = input.size()
        assert self.d_model == d_input, "d_model must be the same dimension as the input"

        # First linear transformation followed by ReLU
        # There's no need for explicit torch.max() as F.relu() already implements max(0,x)
        f1 = F.relu(self.fc1(input))

        # max(0, xW_1 + b_1)W_2 + b_2 
        f2 =  self.fc2(f1)

        return f2

        

In [5]:
net = FFN(  d_model = 512,  d_ff =2048)
print(net)

FFN(
  (fc1): Linear(in_features=512, out_features=2048, bias=True)
  (fc2): Linear(in_features=2048, out_features=512, bias=True)
)


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerEncoder(nn.Module):
    """
    Encoder layer of the Transformer
    Sublayers: MultiHeadSlefAttention
              FNN
    Args:
            d_model: 512 model hidden dimension
            d_embed: 512 embedding dimension, same as d_model in transformer framework
            d_ff: 2048 hidden dimension of the feed forward network
            num_head: 8 Number of attention heads.
            dropout:  0.1 dropout rate 
            
            bias: Whether to include bias in linear projections.
              
    """

    def __init__(
        self, d_model, d_embed, d_ff,
        num_head, dropout=0.1,
        bias=True
    ):
        super().__init__()
        self.d_model = d_model
        self.d_embed = d_embed
        self.d_ff = d_ff


        # attention sublayer
        self.att = MultiHeadSelfAttention(
            d_model = d_model,
            d_embed = d_embed,
            num_head = num_head,
            dropout = dropout,
            bias = bias
        )
        
        # FFN sublayer
        self.ffn = FFN(
            d_model = d_model,
            d_ff = d_ff
        )

        
        # Dropout layer
        self.dropout = nn.Dropout(p=dropout)

        # layer-normalization layer
        self.LayerNorm_att = nn.LayerNorm(self.d_model)
        self.LayerNorm_ffn = nn.LayerNorm(self.d_model)


    def forward(self, embed_input, att_mask):
        
        ## First sublayer: self attion 
        # After embedding and positional encoding, input sequence feed into attention sublayer
        att_sublayer = self.att(sequence = embed_input, att_mask = att_mask)  # [batch_size, sequence_length, d_model]
        # apply dropout before layer normalization for each sublayer
        att_sublayer = self.dropout(att_sublayer)
        # Residual layer normalization
        att_normalized = self.LayerNorm_att(embed_input + att_sublayer)         # [batch_size, sequence_length, d_model]
        
        # Second sublayer: FFN
        ffn_sublayer = self.ffn(att_normalized)                                 # [batch_size, sequence_length, d_model]
        ffn_sublayer = self.dropout(ffn_sublayer)
        ffn_normalized = self.LayerNorm_att(att_sublayer + ffn_sublayer )       # [batch_size, sequence_length, d_model]
    

        return ffn_normalized

In [9]:
net = TransformerEncoder( d_model = 512,  d_embed = 258, d_ff =2048, num_head=8, dropout=0.1, bias=True )
print(net)

TransformerEncoder(
  (att): MultiHeadSelfAttention(
    (q_proj): Linear(in_features=258, out_features=512, bias=True)
    (k_proj): Linear(in_features=258, out_features=512, bias=True)
    (v_proj): Linear(in_features=258, out_features=512, bias=True)
    (output_proj): Linear(in_features=512, out_features=512, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (ffn): FFN(
    (fc1): Linear(in_features=512, out_features=2048, bias=True)
    (fc2): Linear(in_features=2048, out_features=512, bias=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (LayerNorm_att): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (LayerNorm_ffn): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
)


In [11]:
def test_transformer_encoder():
    # Set random seed for reproducibility
    torch.manual_seed(42)
    
    # Test parameters
    batch_size = 32
    seq_length = 20
    d_embed = 512
    d_model = 512
    d_ff = 2048
    num_heads = 8
    
    # Initialize the transformer encoder
    encoder = TransformerEncoder(
        d_model=d_model,
        d_embed=d_embed,
        d_ff=d_ff,
        num_head=num_heads,
        dropout=0.1
    )
    
    # Set to evaluation mode to disable dropout
    encoder.eval()
    
    # Create input sequence - using ones instead of random values
    # for easier interpretation of attention patterns
    input_sequence = torch.ones(batch_size, seq_length, d_embed)
    
    # Create attention mask
    attention_mask = torch.ones(batch_size, seq_length)
    attention_mask[:, 15:] = 0  # Mask last 5 positions
    
    # Store attention patterns
    attention_patterns = []
    
    # Define hook to capture attention scores
    def attention_hook(module, input, output):
        # We want to capture the attention scores before they're processed further
        # This assumes your attention module returns the attention scores
        attention_patterns.append(output)
    
    # Register the hook on the attention computation
    encoder.att.register_forward_hook(attention_hook)
    
    # Perform forward pass
    with torch.no_grad():
        output = encoder(input_sequence, attention_mask)
    
    # Basic shape tests
    expected_shape = (batch_size, seq_length, d_model)
    assert output.shape == expected_shape, f"Expected shape {expected_shape}, got {output.shape}"
    
    # Print output statistics
    print("\nOutput Statistics:")
    print(f"Mean: {output.mean():.4f}")
    print(f"Std: {output.std():.4f}")
    print(f"Min: {output.min():.4f}")
    print(f"Max: {output.max():.4f}")
    
    # Analyze attention patterns
    if attention_patterns:
        attention_output = attention_patterns[0]
        # Look at the attention patterns for unmasked vs masked positions
        unmasked_attention = output[:, :15, :].abs().mean()
        masked_attention = output[:, 15:, :].abs().mean()
        
        print("\nAttention Analysis:")
        print(f"Unmasked positions mean: {unmasked_attention:.4f}")
        print(f"Masked positions mean: {masked_attention:.4f}")
        
        # Note: We expect masked positions to still have values due to residual connections,
        # but their patterns should be different from unmasked positions
        print("\nIs the masking working?", "Yes" if unmasked_attention != masked_attention else "No")
    
    # Check for any NaN or infinite values
    assert torch.isfinite(output).all(), "Output contains NaN or infinite values"
    
    print("\nAll tests passed successfully!")
    return output, attention_patterns

# Run the test
output, attention_patterns = test_transformer_encoder()


Output Statistics:
Mean: 0.0000
Std: 1.0000
Min: -4.0672
Max: 3.6722

Attention Analysis:
Unmasked positions mean: 0.8028
Masked positions mean: 0.8009

Is the masking working? Yes

All tests passed successfully!
