In [1]:
import torch
import torch.nn.functional as F
import math

In [2]:
def scaled_dot_product_attention(Q, K, V):
    """
    Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Q, K, V shapes: (batch, seq_len, d_k)

    Returns a 4-tuple, in this order:
        output        -- attention weights multiplied by V
        attn_weights  -- softmax of the scaled scores over the last dim
        scores        -- raw Q K^T products (before scaling)
        scaled_scores -- scores divided by sqrt(d_k)
    """
    # d_k is read from the last dimension of Q.
    scale = math.sqrt(Q.size(-1))

    # Raw similarity of every query with every key (swap K's last two dims).
    raw = Q @ K.transpose(-2, -1)

    # Divide by sqrt(d_k) before normalising.
    scaled = raw / scale

    # Normalise along the last dim so each row of weights sums to 1.
    weights = scaled.softmax(dim=-1)

    # Weighted combination of the value vectors.
    context = weights @ V

    return context, weights, raw, scaled


In [3]:
# Problem dimensions for the demo.
batch = 1
seq_len = 5
d_k = 4

# Seed the global RNG so that Restart Kernel -> Run All draws the same
# Q, K, V every time; without it the printed results change on each run.
SEED = 0
torch.manual_seed(SEED)

# random Q, K, V
Q = torch.randn(batch, seq_len, d_k)
K = torch.randn(batch, seq_len, d_k)
V = torch.randn(batch, seq_len, d_k)

# Keep every intermediate so the later cells can inspect them.
output, attn_weights, scores, scaled_scores = scaled_dot_product_attention(Q, K, V)


A. Attention Weight Matrix

In [4]:
# Display the attention weight matrix returned by the attention call.
weights_heading = "Attention Weights (Softmax):\n"
print(weights_heading, attn_weights)


Attention Weights (Softmax):
 tensor([[[0.0953, 0.0583, 0.1108, 0.1139, 0.6217],
         [0.7757, 0.0969, 0.0894, 0.0135, 0.0246],
         [0.3436, 0.0518, 0.1023, 0.2252, 0.2771],
         [0.1064, 0.0805, 0.0728, 0.5460, 0.1942],
         [0.0990, 0.1313, 0.1625, 0.2648, 0.3423]]])


B. Output Vectors

In [5]:
# Display the attention output tensor returned by the attention call.
output_heading = "\nOutput Vectors:\n"
print(output_heading, output)



Output Vectors:
 tensor([[[-0.5819,  0.1432, -0.1883, -0.1889],
         [-0.1706, -0.4788, -1.4440, -0.5544],
         [-0.3814, -0.2530, -0.5871, -0.1538],
         [-0.2771, -0.2069,  0.0550,  0.1046],
         [-0.2146,  0.2131, -0.2903, -0.1948]]])


C. Softmax Stability Check

In [6]:
print("\nSoftmax Stability Check:")

# Largest absolute score before and after the sqrt(d_k) scaling step.
max_raw_score = scores.abs().max().item()
print("Max |raw scores| before scaling :", max_raw_score)

max_scaled_score = scaled_scores.abs().max().item()
print("Max |scaled scores| after scaling:", max_scaled_score)



Softmax Stability Check:
Max |raw scores| before scaling : 6.129274368286133
Max |scaled scores| after scaling: 3.0646371841430664
