In [1]:
import torch
import torch.nn.functional as F
import math

In [2]:
def scaled_dot_product_attention(Q, K, V):
    """
    Q, K, V shapes: (batch, seq_len, d_k)
    """

    d_k = Q.size(-1)

    # Step 1: raw attention scores (QK^T)
    scores = torch.matmul(Q, K.transpose(-2, -1))

    # Step 2: scale scores
    scaled_scores = scores / math.sqrt(d_k)

    # Step 3: softmax to get weights
    attn_weights = F.softmax(scaled_scores, dim=-1)

    # Step 4: multiply by V
    output = torch.matmul(attn_weights, V)

    return output, attn_weights, scores, scaled_scores


In [5]:
batch = 1
seq_len = 5
d_k = 4

# random Q, K, V
Q = torch.randn(batch, seq_len, d_k)
K = torch.randn(batch, seq_len, d_k)
V = torch.randn(batch, seq_len, d_k)

output, attn_weights, scores, scaled_scores = scaled_dot_product_attention(Q, K, V)


A. Attention Weight Matrix

In [8]:
print("Attention Weights (Softmax):\n", attn_weights)


Attention Weights (Softmax):
 tensor([[[0.3996, 0.2274, 0.0969, 0.1289, 0.1472],
         [0.2049, 0.0494, 0.2094, 0.3574, 0.1789],
         [0.0943, 0.5423, 0.1192, 0.1276, 0.1166],
         [0.0767, 0.2918, 0.1180, 0.0897, 0.4238],
         [0.0494, 0.6280, 0.0460, 0.0420, 0.2346]]])


B. Output Vectors

In [11]:
print("\nOutput Vectors:\n", output)



Output Vectors:
 tensor([[[-0.5907, -0.4586,  0.1636,  0.9050],
         [-0.2145, -0.6599,  0.7463,  0.2710],
         [-0.8828,  0.0079,  0.4739,  1.0310],
         [-0.5885, -0.5099,  0.5524,  0.8580],
         [-1.0159, -0.0569,  0.4155,  1.3062]]])


C. Softmax Stability Check

In [14]:
print("\nSoftmax Stability Check:")
print("Max |raw scores| before scaling :", scores.abs().max().item())
print("Max |scaled scores| after scaling:", scaled_scores.abs().max().item())



Softmax Stability Check:
Max |raw scores| before scaling : 3.958254814147949
Max |scaled scores| after scaling: 1.9791274070739746
