# Self-Attention Mechanism Implementation
This notebook implements the self-attention mechanism with step-by-step calculations and visualizations based on class notes.

## 1. Import Required Libraries
Import NumPy, Matplotlib, and other necessary libraries for matrix operations and visualization.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from math import sqrt
import pandas as pd

# Set random seed for reproducibility
np.random.seed(42)

# Configure matplotlib for better plots
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 12

## 2. Define Input Embeddings
Create the embedding matrix E with one-hot encodings for tokens A and B as shown in the class notes example.
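
One way to build such a matrix from a token sequence is to index the rows of an identity matrix. The sketch below assumes a two-token vocabulary. The names `vocab`, `tokens`, and `E_onehot` are illustrative and do not come from the class notes.

```python
import numpy as np

# Hypothetical vocabulary: token -> position of the 1 in its one-hot vector
vocab = {"A": 0, "B": 1}
tokens = ["A", "A", "B", "A"]

# Row i of the identity matrix is the one-hot vector for index i
token_ids = [vocab[t] for t in tokens]
E_onehot = np.eye(len(vocab), dtype=int)[token_ids]

print(E_onehot)
```

This should reproduce the hand-written matrix E used in the next cell.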

In [None]:
# Define the sequence: A A B A (using one-hot encodings)
# A = [1, 0], B = [0, 1]
E = np.array([
    [1, 0],  # A
    [1, 0],  # A  
    [0, 1],  # B
    [1, 0]   # A
])

print("Input Embedding Matrix E:")
print(E)
print(f"Shape: {E.shape}")
print("\nSequence representation: A A B A")
print("A = [1, 0], B = [0, 1]")

## 3. Initialize Weight Matrices
Define the weight matrices Wq, Wk, and Wv for query, key, and value transformations using the values from class notes.

In [None]:
# Weight matrices from class notes
Wq = np.array([
    [0, 1],
    [0, 1]
])

Wk = np.array([
    [10, 0],
    [0, 10]
])

# For value matrix, we'll use the identity matrix initially
Wv = np.array([
    [1, 0],
    [0, 1]
])

print("Query Weight Matrix Wq:")
print(Wq)
print("Interpretation: Both rows of Wq are [0, 1], so every token's query looks for Bs")

print("\nKey Weight Matrix Wk:")
print(Wk)
print("Interpretation: 'Booster' matrix - amplifies the embeddings")

print("\nValue Weight Matrix Wv:")
print(Wv)
print("Interpretation: Identity matrix for this example")

## 4. Compute Query and Key Matrices
Calculate the matrices Q = E @ Wq, K = E @ Wk, and V = E @ Wv (matrix products, not element-wise) and display the results with an explanation of what each represents.
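
Because each row of E is one-hot, multiplying E by a weight matrix amounts to selecting one row of that matrix per token. The sketch below checks this for the notebook's matrices, restated here so the block runs on its own. `row_index` is an illustrative name.

```python
import numpy as np

E = np.array([[1, 0], [1, 0], [0, 1], [1, 0]])  # A A B A
Wq = np.array([[0, 1], [0, 1]])
Wk = np.array([[10, 0], [0, 10]])

# Position of the 1 in each one-hot row (0 for A, 1 for B)
row_index = E.argmax(axis=1)

# For one-hot inputs, E @ W picks row `row_index[i]` of W for token i
assert np.array_equal(E @ Wq, Wq[row_index])
assert np.array_equal(E @ Wk, Wk[row_index])

print(Wq[row_index])  # every query is [0, 1]
print(Wk[row_index])  # keys are the embeddings scaled by 10
```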

In [None]:
# Compute Query matrix
Q = E @ Wq
print("Query Matrix Q = E @ Wq:")
print(Q)
print("Interpretation: Every token is looking for Bs around")

# Compute Key matrix  
K = E @ Wk
print("\nKey Matrix K = E @ Wk:")
print(K)
print("Interpretation: Amplified embeddings for better distinction")

# Compute Value matrix
V = E @ Wv
print("\nValue Matrix V = E @ Wv:")
print(V)
print("Interpretation: Original embeddings (identity transformation)")

## 5. Calculate Attention Scores
Compute the score matrix S = Q @ K^T. Each element S[i, j] is the dot product of query i with key j, i.e. how strongly token i attends to token j before normalization.
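
To make the element-wise meaning concrete, the score matrix can also be filled in one dot product at a time. This is a minimal sketch that restates Q and K for this example. `S_loop` is an illustrative name, and the explicit loop is only for exposition.

```python
import numpy as np

Q = np.array([[0, 1], [0, 1], [0, 1], [0, 1]])
K = np.array([[10, 0], [10, 0], [0, 10], [10, 0]])

n = Q.shape[0]
S_loop = np.zeros((n, n), dtype=int)
for i in range(n):        # query position
    for j in range(n):    # key position
        S_loop[i, j] = np.dot(Q[i], K[j])

# The loop and the matrix product give the same scores
assert np.array_equal(S_loop, Q @ K.T)
print(S_loop)
```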

In [None]:
# Compute attention scores
S = Q @ K.T
print("Score Matrix S = Q @ K^T:")
print(S)

print("\nInterpretation of scores:")
print("- Rows represent queries (each token asking)")
print("- Columns represent keys (each token being asked about)")
print("- High scores indicate strong attention")

# Explain specific scores
print(f"\nExample: S[0,2] = {S[0,2]} - Token 1 (A) attending to Token 3 (B)")
print("All tokens score Token 3 (B) highest because every query looks for Bs")

# Display with labels
score_df = pd.DataFrame(S, 
                       index=['A₁', 'A₂', 'B₃', 'A₄'], 
                       columns=['A₁', 'A₂', 'B₃', 'A₄'])
print("\nScore Matrix with labels:")
print(score_df)

## 6. Apply Softmax to Get Attention Weights
Scale the score matrix by 1/sqrt(dk), then apply a row-wise softmax to get the attention weights A, so that each row is a probability distribution over the keys.
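
As a small worked example of what the scaling does, the sketch below applies softmax to one row of this example's score matrix, with and without dividing by sqrt(dk). `softmax_1d`, `w_raw`, and `w_scaled` are illustrative names.

```python
import numpy as np

row = np.array([0.0, 0.0, 10.0, 0.0])  # one row of S in this example
dk = 2

def softmax_1d(x):
    """Numerically stable softmax for a 1-D array."""
    e = np.exp(x - x.max())
    return e / e.sum()

w_raw = softmax_1d(row)                   # no scaling
w_scaled = softmax_1d(row / np.sqrt(dk))  # scaled by 1/sqrt(dk)

print("unscaled:", w_raw.round(4))
print("scaled:  ", w_scaled.round(4))
```

In both cases the B position takes almost all of the weight (roughly 0.9999 unscaled and roughly 0.997 scaled). The scaled version is slightly less peaked.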

In [None]:
# Define softmax function
def softmax(x, axis=-1):
    """Compute softmax values for array x along specified axis."""
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

# Apply scaling factor sqrt(dk)
dk = K.shape[1]  # dimension of keys
print(f"Key dimension dk = {dk}")
print(f"Scaling factor sqrt(dk) = {sqrt(dk)}")

# Scaled scores
S_scaled = S / sqrt(dk)
print(f"\nScaled Score Matrix S / sqrt(dk):")
print(S_scaled)

# Apply row-wise softmax to get attention weights
A = softmax(S_scaled, axis=1)
print(f"\nAttention Weight Matrix A (after softmax):")
print(A)

# Display with labels
attention_df = pd.DataFrame(A, 
                           index=['A₁', 'A₂', 'B₃', 'A₄'], 
                           columns=['A₁', 'A₂', 'B₃', 'A₄'])
print("\nAttention weights with labels:")
print(attention_df)

# Verify that rows sum to 1
print(f"\nRow sums (should be 1.0): {A.sum(axis=1)}")

## 7. Compute Self-Attention Output
Calculate the final output using the formula Attention(Q, K, V) = softmax(Q K^T / sqrt(dk)) V, where all products are matrix products, and show how it produces contextualized encodings.
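
The whole formula can be wrapped in one function. The sketch below is a minimal single-head version under this notebook's conventions (rows are tokens). The function name `scaled_dot_product_attention` and the `_check` variables are illustrative.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Return (output, weights) for single-head attention."""
    dk = K.shape[-1]
    scores = Q @ K.T / np.sqrt(dk)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ V, weights

# The notebook's example, restated so the block runs on its own
E = np.array([[1, 0], [1, 0], [0, 1], [1, 0]], dtype=float)
Wq = np.array([[0, 1], [0, 1]], dtype=float)
Wk = 10 * np.eye(2)
Wv = np.eye(2)

Z_check, A_check = scaled_dot_product_attention(E @ Wq, E @ Wk, E @ Wv)
print(Z_check.round(4))
```

Each output row is a weighted average of the value rows, so with almost all weight on the B token every row ends up close to [0, 1].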

In [None]:
# Compute self-attention output
Z = A @ V
print("Self-Attention Output Z = A @ V:")
print(Z)

print(f"\nOutput shape: {Z.shape}")
print("Same as input shape - preserves sequence length!")

# Compare with original embeddings
print("\nComparison with original embeddings:")
comparison_df = pd.DataFrame({
    'Token': ['A₁', 'A₂', 'B₃', 'A₄'],
    'Original_dim1': E[:, 0],
    'Original_dim2': E[:, 1], 
    'SelfAttn_dim1': Z[:, 0],
    'SelfAttn_dim2': Z[:, 1]
})
print(comparison_df)

print("\nInterpretation:")
print("- Original A tokens: [1, 0]")
print("- Original B token: [0, 1]") 
print("- After self-attention: every token's output is close to [0, 1], the B token's value vector")
print("- Each token now has a contextualized representation")

## 8. Visualize Attention Matrix
Create heatmaps and visualizations of the attention weights to show which tokens attend to which other tokens.
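
As a numeric companion to the heatmaps, the most-attended key for each query can be read off with `argmax`. The sketch below rebuilds this example's attention weights from its scores so it runs on its own. `labels` and `top_key` are illustrative names.

```python
import numpy as np

labels = ["A₁", "A₂", "B₃", "A₄"]

# Scores for this example: every query scores only the B key
S = np.tile(np.array([0.0, 0.0, 10.0, 0.0]), (4, 1))
scaled = S / np.sqrt(2)
A = np.exp(scaled - scaled.max(axis=1, keepdims=True))
A /= A.sum(axis=1, keepdims=True)

# For each query row, find the key with the largest weight
top_key = A.argmax(axis=1)
for i, k in enumerate(top_key):
    print(f"{labels[i]} attends most to {labels[k]} (weight {A[i, k]:.4f})")
```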

In [None]:
# Create attention heatmap
plt.figure(figsize=(12, 5))

# Plot 1: Attention weights heatmap
plt.subplot(1, 2, 1)
sns.heatmap(A, annot=True, fmt='.3f', cmap='Blues',
            xticklabels=['A₁', 'A₂', 'B₃', 'A₄'],
            yticklabels=['A₁', 'A₂', 'B₃', 'A₄'],
            cbar_kws={'label': 'Attention Weight'})
plt.title('Self-Attention Weights Matrix')
plt.xlabel('Keys (being attended to)')
plt.ylabel('Queries (attending from)')

# Plot 2: Score matrix heatmap
plt.subplot(1, 2, 2)
sns.heatmap(S, annot=True, fmt='.1f', cmap='Reds',
            xticklabels=['A₁', 'A₂', 'B₃', 'A₄'],
            yticklabels=['A₁', 'A₂', 'B₃', 'A₄'],
            cbar_kws={'label': 'Raw Score'})
plt.title('Raw Attention Scores Matrix')
plt.xlabel('Keys')
plt.ylabel('Queries')

plt.tight_layout()
plt.show()

# Attention pattern analysis
print("Attention Pattern Analysis:")
print("- All tokens pay high attention to B₃ (column 3)")
print("- This is because all queries are looking for B tokens")
print("- With these hand-picked weights, nearly all attention mass lands on the single B token")

In [None]:
# Visualize the transformation process
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Original embeddings
im1 = axes[0,0].imshow(E.T, cmap='RdYlBu', aspect='auto')
axes[0,0].set_title('Original Embeddings E')
axes[0,0].set_xlabel('Token Position')
axes[0,0].set_ylabel('Embedding Dimension')
axes[0,0].set_xticks(range(4))
axes[0,0].set_xticklabels(['A₁', 'A₂', 'B₃', 'A₄'])
plt.colorbar(im1, ax=axes[0,0])

# Attention weights
im2 = axes[0,1].imshow(A, cmap='Blues', aspect='auto')
axes[0,1].set_title('Attention Weights A')
axes[0,1].set_xlabel('Keys')
axes[0,1].set_ylabel('Queries')
axes[0,1].set_xticks(range(4))
axes[0,1].set_yticks(range(4))
axes[0,1].set_xticklabels(['A₁', 'A₂', 'B₃', 'A₄'])
axes[0,1].set_yticklabels(['A₁', 'A₂', 'B₃', 'A₄'])
plt.colorbar(im2, ax=axes[0,1])

# Self-attention output
im3 = axes[1,0].imshow(Z.T, cmap='RdYlBu', aspect='auto')
axes[1,0].set_title('Self-Attention Output Z')
axes[1,0].set_xlabel('Token Position')
axes[1,0].set_ylabel('Embedding Dimension')
axes[1,0].set_xticks(range(4))
axes[1,0].set_xticklabels(['A₁', 'A₂', 'B₃', 'A₄'])
plt.colorbar(im3, ax=axes[1,0])

# Comparison plot
x_pos = np.arange(4)
width = 0.35
axes[1,1].bar(x_pos - width/2, E[:, 0], width, label='Original dim1', alpha=0.7)
axes[1,1].bar(x_pos + width/2, Z[:, 0], width, label='Self-attn dim1', alpha=0.7)
axes[1,1].set_title('Dimension 1: Original vs Self-Attention')
axes[1,1].set_xlabel('Token Position')
axes[1,1].set_ylabel('Value')
axes[1,1].set_xticks(x_pos)
axes[1,1].set_xticklabels(['A₁', 'A₂', 'B₃', 'A₄'])
axes[1,1].legend()

plt.tight_layout()
plt.show()

## 9. Multi-Head Self-Attention Implementation
Extend the single-head implementation to multiple heads using different randomly initialized parameter matrices.
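
Before the full implementation, it may help to trace the shapes involved. The sketch below uses a stand-in input with `d_model = 4` (an assumption chosen so that each of the two heads gets `d_k = 2`), random untrained weights, and an assumed seed. None of these values come from the class notes.

```python
import numpy as np

rng = np.random.default_rng(0)  # assumed seed, for repeatability only
seq_len, d_model, num_heads = 4, 4, 2
d_k = d_model // num_heads

X = rng.standard_normal((seq_len, d_model))  # stand-in embeddings
head_outputs = []
for head in range(num_heads):
    # Each head gets its own projections of shape (d_model, d_k)
    Wq_h, Wk_h, Wv_h = (rng.standard_normal((d_model, d_k)) for _ in range(3))
    scores = (X @ Wq_h) @ (X @ Wk_h).T / np.sqrt(d_k)
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    head_outputs.append(weights @ (X @ Wv_h))  # (seq_len, d_k)

concat = np.concatenate(head_outputs, axis=1)  # (seq_len, num_heads * d_k)
W_o = rng.standard_normal((num_heads * d_k, d_model))
out = concat @ W_o                             # (seq_len, d_model)

print(head_outputs[0].shape, concat.shape, out.shape)
```

Concatenating the heads restores the model dimension, so the output projection `W_o` is square in this setup.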

In [None]:
def multi_head_attention(E, num_heads=2, d_model=2):
    """
    Implement multi-head self-attention
    
    Args:
        E: Input embeddings (seq_len, d_model)
        num_heads: Number of attention heads
        d_model: Model dimension
    
    Returns:
        output: Multi-head attention output
        attention_weights: List of attention weight matrices for each head
    """
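    # Guard against shape mismatches before splitting into heads
    assert E.shape[1] == d_model, "E must have d_model columns"
    assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
    d_k = d_model // num_heads  # dimension per head
    
    # Collect each head's output and attention weights
    heads_output = []
    attention_weights = []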
    
    print(f"Multi-Head Self-Attention with {num_heads} heads")
    print(f"d_model = {d_model}, d_k = {d_k}")
    
    for head in range(num_heads):
        print(f"\n--- Head {head + 1} ---")
        
        # Random weight matrices for this head
        Wq_h = np.random.randn(d_model, d_k) * 0.5
        Wk_h = np.random.randn(d_model, d_k) * 0.5  
        Wv_h = np.random.randn(d_model, d_k) * 0.5
        
        # Compute Q, K, V for this head
        Q_h = E @ Wq_h
        K_h = E @ Wk_h
        V_h = E @ Wv_h
        
        # Compute attention scores and weights
        S_h = Q_h @ K_h.T
        S_h_scaled = S_h / sqrt(d_k)
        A_h = softmax(S_h_scaled, axis=1)
        
        # Compute output for this head
        Z_h = A_h @ V_h
        
        heads_output.append(Z_h)
        attention_weights.append(A_h)
        
        print(f"Head {head + 1} attention weights:")
        print(A_h.round(3))
    
    # Concatenate all heads and apply output projection
    concatenated = np.concatenate(heads_output, axis=1)
    
    # Output projection (random matrix for demo)
    W_o = np.random.randn(num_heads * d_k, d_model) * 0.5
    output = concatenated @ W_o
    
    return output, attention_weights

# Apply multi-head attention
multi_head_output, multi_attention_weights = multi_head_attention(E, num_heads=2)

print(f"\nMulti-Head Attention Output:")
print(multi_head_output)
print(f"Shape: {multi_head_output.shape}")

In [None]:
# Visualize multi-head attention patterns
fig, axes = plt.subplots(1, len(multi_attention_weights), figsize=(15, 5))

for i, attn_weights in enumerate(multi_attention_weights):
    sns.heatmap(attn_weights, annot=True, fmt='.3f', cmap='viridis',
                xticklabels=['A₁', 'A₂', 'B₃', 'A₄'],
                yticklabels=['A₁', 'A₂', 'B₃', 'A₄'],
                ax=axes[i], cbar_kws={'label': 'Attention Weight'})
    axes[i].set_title(f'Head {i+1} Attention Weights')
    axes[i].set_xlabel('Keys')
    axes[i].set_ylabel('Queries')

plt.tight_layout()
plt.show()

# Compare single-head vs multi-head outputs
print("Comparison: Single-Head vs Multi-Head Self-Attention")
comparison_df = pd.DataFrame({
    'Token': ['A₁', 'A₂', 'B₃', 'A₄'],
    'Original_1': E[:, 0],
    'Original_2': E[:, 1],
    'SingleHead_1': Z[:, 0], 
    'SingleHead_2': Z[:, 1],
    'MultiHead_1': multi_head_output[:, 0],
    'MultiHead_2': multi_head_output[:, 1]
})
print(comparison_df.round(4))

print("\nKey Insights:")
print("1. Each head has its own weight matrices, so it produces its own attention pattern")
print("2. The heads here use random, untrained weights; in a trained model they can capture diverse relationships")
print("3. Different heads may focus on different aspects of the input")
print("4. The final output combines information from all heads")

## Summary

This notebook demonstrated the complete self-attention mechanism:

1. **Basic Attention**: Showed how queries attend to keys to find relevant information
2. **Self-Attention**: Applied attention within one sequence, where every token supplies its own query, key, and value
3. **Mathematical Framework**: Implemented the full formula: Attention(Q,K,V) = softmax(QK^T/√dk)V
4. **Visualization**: Created heatmaps to understand attention patterns
5. **Multi-Head Extension**: Showed how multiple heads with separate weight matrices produce different attention patterns

**Key Takeaways**:
- Self-attention produces contextualized representations of the same length as input
- The attention matrix shows which tokens influence each other
- Scaling by √dk keeps dot products from growing with the key dimension, which helps avoid saturated softmax outputs
- Multiple heads allow the model to attend to different types of relationships simultaneously
- This mechanism is fundamental to Transformer architectures
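
The scaling takeaway can be checked empirically. The sketch below assumes query and key entries are independent with zero mean and unit variance. That is an assumption made for illustration, not something this notebook's hand-picked matrices satisfy. Under it, the variance of a raw dot product grows roughly like dk, while the scaled version stays near 1. The seed, sample count, and `results` dictionary are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)  # assumed seed
n_samples = 20000
results = {}

for d_k in (2, 64):
    q = rng.standard_normal((n_samples, d_k))
    k = rng.standard_normal((n_samples, d_k))
    dots = np.sum(q * k, axis=1)  # one dot product per sample
    var_raw = dots.var()
    var_scaled = (dots / np.sqrt(d_k)).var()
    results[d_k] = (var_raw, var_scaled)
    print(f"d_k={d_k:3d}  var(q.k)={var_raw:7.2f}  var(q.k/sqrt(d_k))={var_scaled:.2f}")
```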