<a href="https://colab.research.google.com/github/Vishal-113/NLP4-/blob/main/PyTorch_Scaled_Dot_Product_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn.functional as F
import numpy as np
import sys

# Fix the RNG seed so each run draws the same random tensors.
torch.manual_seed(42)

# --- Hyperparameters ---
BATCH_SIZE = 2      # Number of sequences per batch (B)
SEQ_LEN = 5         # Sequence length (L)
D_K = 8             # Dimension of Q, K, V (d_k)
# Note: there is no NUM_HEADS axis in this demo; tensors are laid out as
# (batch, sequence position, feature), which keeps the printed output small.

# Banner describing the experiment; one print call, one line per argument.
print(
    "-" * 60,
    "Testing Scaled Dot-Product Attention with PyTorch.",
    f"Inputs: B={BATCH_SIZE}, L={SEQ_LEN}, D_K={D_K}",
    "-" * 60,
    sep="\n",
)

# 1. Create Random Input Tensors
# Q, K and V are drawn in that order (the order matters for reproducibility
# under the fixed seed); each has shape (batch_size, seq_len, d_k).
Q, K, V = (
    torch.randn(BATCH_SIZE, SEQ_LEN, D_K, dtype=torch.float32)
    for _ in range(3)
)

def scaled_dot_product_attention_pytorch(Q, K, V):
    """
    Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    As a side effect, prints the largest score before and after the
    1/sqrt(d_k) scaling so the effect of the scaling on the softmax input
    can be compared.

    Args:
        Q: Query tensor, shape (B, L, D_K).
        K: Key tensor, shape (B, L, D_K).
        V: Value tensor, shape (B, L, D_K).

    Returns:
        A tuple (output, attention_weights): the weighted values of shape
        (B, L, D_K) and the softmax weights of shape (B, L, L).
    """
    # sqrt(d_k) as a float32 scalar tensor; d_k is the size of Q's last axis.
    key_dim = Q.size(-1)
    root_dk = torch.tensor(key_dim, dtype=torch.float32).sqrt()

    # Query/key similarities: (B, L, D_K) @ (B, D_K, L) -> (B, L, L).
    keys_t = K.transpose(-1, -2)
    logits = Q @ keys_t

    # --- Softmax Stability Check 1: largest score before scaling ---
    print("\n[Softmax Stability Check 1]")
    print(f"Max Raw Score (Q*K^T): {logits.max().item():.4f} (Can be large, leading to vanishing gradients)")

    # --- Softmax Stability Check 2: largest score after scaling ---
    scaled_logits = logits / root_dk
    print(f"Max Scaled Score:      {scaled_logits.max().item():.4f} (Stabilized input for softmax)")

    # Normalise over the key axis (last dim) so each query row sums to 1.
    weights = F.softmax(scaled_logits, dim=-1)

    # Mix the value vectors with the attention weights -> (B, L, D_K).
    return torch.matmul(weights, V), weights

# Run the attention function
output_vectors, attn_weights = scaled_dot_product_attention_pytorch(Q, K, V)


# --- Print Results (Step 3) ---

# 3a. Attention Weight Matrix
# Attention matrix for the first element in the batch, as a NumPy array.
attn_matrix_b0 = attn_weights[0].detach().numpy()

print("\n" + "=" * 60)
print("3a. Attention Weight Matrix (Batch 0)")
# Report the shape of the matrix actually being printed rather than a
# hard-coded "(5, 5)", so the text stays correct if SEQ_LEN changes.
print(
    f"Weights shape: (L_query, L_key) = "
    f"({attn_matrix_b0.shape[0]}, {attn_matrix_b0.shape[1]}). Rows sum to 1."
)
print("=" * 60)

# Header: a 10-character label column followed by one 10-character column per key.
header = "{:<10}".format("Query\\Key") + "".join(f"{j:<10}" for j in range(SEQ_LEN))
print(header)
print("-" * len(header))

# One row per query position. The label is padded through the format spec so
# it always fills the same 10-character column as the header label.
for i, weight_row in enumerate(attn_matrix_b0):
    label = f"Q[{i}]:"
    print(f"{label:<10}" + "".join(f"{w:<10.4f}" for w in weight_row))

# 3b. Output Vectors
# Output vectors for the second element in the batch, as a NumPy array.
output_b1 = output_vectors[1].detach().numpy()

print("\n" + "=" * 60)
print("3b. Output Vectors (Batch 1)")
# Report the shape of the array actually being printed rather than a
# hard-coded "(5, 8)", so the text stays correct if SEQ_LEN or D_K change.
print(
    f"Output shape: (L, D_K) = ({output_b1.shape[0]}, {output_b1.shape[1]}). "
    "These are the contextualized embeddings."
)
print("=" * 60)

for i, vector in enumerate(output_b1):
    # Print only the first 4 dimensions of each vector to keep lines short.
    vector_snippet = ', '.join(f'{x:.4f}' for x in vector[:4])
    print(f"Output for Pos {i}: [{vector_snippet} ...]")

------------------------------------------------------------
Testing Scaled Dot-Product Attention with PyTorch.
Inputs: B=2, L=5, D_K=8
------------------------------------------------------------

[Softmax Stability Check 1]
Max Raw Score (Q*K^T): 4.8066 (Can be large, leading to vanishing gradients)
Max Scaled Score:      1.6994 (Stabilized input for softmax)

3a. Attention Weight Matrix (Batch 0)
Weights shape: (L_query, L_key) = (5, 5). Rows sum to 1.
Query\Key 0         1         2         3         4         
------------------------------------------------------------
Q[0]:     0.2815    0.1595    0.1294    0.3631    0.0665    
Q[1]:     0.1579    0.1584    0.2353    0.1935    0.2549    
Q[2]:     0.2155    0.5677    0.0260    0.1256    0.0651    
Q[3]:     0.1026    0.3303    0.0869    0.3324    0.1478    
Q[4]:     0.1031    0.1305    0.3847    0.0949    0.2868    

3b. Output Vectors (Batch 1)
Output shape: (L, D_K) = (5, 8). These are the contextualized embeddings.
Output fo