<a href="https://colab.research.google.com/github/Rakshithbodakuntla/scaled_dot_product_attention/blob/main/scaled_dot_product_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import math
import torch
import torch.nn.functional as F


def scaled_dot_product_attention(Q, K, V, mask=None):
    """Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Only the last two dimensions are treated as (sequence, feature); the
    matrix products broadcast over any leading dimensions.

    Args:
        Q: Queries, expected shape (batch, T, d_k).
        K: Keys, expected shape (batch, T, d_k).
        V: Values, expected shape (batch, T, d_k).
        mask: Optional tensor broadcastable to the score shape
            (batch, T, T). Entries equal to 0 (or False) are masked out;
            every other entry is kept.

    Returns:
        A tuple ``(output, attn_weights, scores, scaled_scores)``:

        * output: ``attn_weights @ V``.
        * attn_weights: softmax of the scaled scores over the key axis.
          A query row whose keys are all masked gets all-zero weights
          (and therefore an all-zero output row) instead of NaN.
        * scores: raw ``Q @ K^T`` before scaling and masking.
        * scaled_scores: ``scores / sqrt(d_k)``, with masked positions
          set to ``-inf`` when a mask is given.
    """
    # Unscaled similarity between every query and every key.
    scores = Q @ K.transpose(-2, -1)  # (B, T, T)

    # Dividing by sqrt(d_k) keeps the score magnitude independent of the
    # feature size so the softmax does not saturate for large d_k.
    d_k = Q.size(-1)
    scaled_scores = scores / math.sqrt(d_k)

    if mask is None:
        attn_weights = F.softmax(scaled_scores, dim=-1)  # (B, T, T)
    else:
        blocked = mask == 0
        scaled_scores = scaled_scores.masked_fill(blocked, float("-inf"))
        attn_weights = F.softmax(scaled_scores, dim=-1)  # (B, T, T)

        # softmax over a row that is entirely -inf evaluates to NaN, which
        # would poison the output. Such a query has nothing to attend to,
        # so give it zero weight everywhere.
        fully_masked = blocked.all(dim=-1, keepdim=True)
        attn_weights = attn_weights.masked_fill(fully_masked, 0.0)

    output = attn_weights @ V  # (B, T, d_k)

    return output, attn_weights, scores, scaled_scores


def main():
    """Run a seeded demo of scaled dot-product attention and print results.

    Prints the raw scores, scaled scores, attention weights and outputs of
    batch element 0, then contrasts the softmax of the raw scores with the
    softmax of the scaled scores.
    """
    torch.manual_seed(0)

    shape = (2, 4, 8)  # (batch, sequence length, feature size)
    # Drawn in the order Q, K, V so the seeded values are reproducible.
    Q, K, V = (torch.randn(*shape) for _ in range(3))

    output, weights, raw, scaled = scaled_dot_product_attention(Q, K, V)

    # All reporting below looks at the first batch element only.
    raw0 = raw[0]
    scaled0 = scaled[0]

    report = (
        ("Raw scores (before scaling) for batch 0:\n", raw0),
        ("\nScaled scores for batch 0:\n", scaled0),
        ("\nAttention weights for batch 0 (row sums ~1):\n", weights[0]),
        ("\nOutput vectors for batch 0:\n", output[0]),
    )
    for heading, tensor in report:
        print(heading, tensor)

    # Softmax stability check: the same scores pushed through softmax with
    # and without the 1/sqrt(d_k) scaling, to compare how peaked each
    # distribution is.
    softmax_raw = F.softmax(raw0, dim=-1)
    softmax_scaled = F.softmax(scaled0, dim=-1)

    print("\nUnscaled softmax row 0:", softmax_raw[0])
    print("Scaled softmax row 0:", softmax_scaled[0])

    peak_raw = raw0.max().item()
    peak_scaled = scaled0.max().item()
    print("\nUnscaled max score:", peak_raw)
    print("Scaled max score:", peak_scaled)


# Run the demo only when this file is executed directly, not when it is
# imported as a module.
if __name__ == "__main__":
    main()


Raw scores (before scaling) for batch 0:
 tensor([[-1.2955,  1.0110, -1.1253, -4.9447],
        [ 1.8294,  0.7432, -0.5131, -5.1092],
        [-3.2314,  1.6361, -0.0782, -2.5119],
        [ 0.3663, -1.4356,  0.8898, -4.7521]])

Scaled scores for batch 0:
 tensor([[-0.4580,  0.3574, -0.3979, -1.7482],
        [ 0.6468,  0.2628, -0.1814, -1.8064],
        [-1.1425,  0.5785, -0.0276, -0.8881],
        [ 0.1295, -0.5076,  0.3146, -1.6801]])

Attention weights for batch 0 (row sums ~1):
 tensor([[0.2175, 0.4916, 0.2310, 0.0599],
        [0.4537, 0.3090, 0.1982, 0.0390],
        [0.0915, 0.5115, 0.2790, 0.1180],
        [0.3453, 0.1826, 0.4155, 0.0565]])

Output vectors for batch 0:
 tensor([[-0.2786, -0.0528, -0.5245, -0.2117,  0.1071, -0.2984,  0.2336, -0.1146],
        [-0.4687, -0.1608, -0.5746,  0.1081,  0.2751, -0.0187,  0.7547, -0.4147],
        [-0.1701, -0.0369, -0.5337, -0.2441,  0.0010, -0.3409, -0.0189,  0.0410],
        [-0.4162, -0.1179, -0.3925,  0.4595,  0.2722,  0.1509,  0.5