In [26]:
import torch
import torch.nn.functional as F
import math

In [27]:
def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Compute scaled dot-product attention: softmax(Q Kᵀ / sqrt(d_k)) V.

    Any leading dimensions (e.g. batch, or batch and heads) are handled by
    torch.matmul's batched semantics; query and key lengths may differ.

    Args:
        Q: Query tensor   (..., q_len, d_k)
        K: Key tensor     (..., k_len, d_k)
        V: Value tensor   (..., k_len, d_v)
        mask: Optional tensor broadcastable to (..., q_len, k_len).
            A boolean mask marks the positions that may be attended to
            (True = keep, False = block). A mask of any other dtype is
            added to the scaled scores before the softmax. Defaults to
            None, which applies no masking.

    Returns:
        Tuple (output, attention_weights):
            output:            (..., q_len, d_v)
            attention_weights: (..., q_len, k_len), softmax over the key axis.

    Note:
        A query row whose keys are all blocked by a boolean mask has no
        valid softmax and comes out as NaN.
    """

    # 1. QKᵀ → similarity score between every query and every key
    scores = torch.matmul(Q, K.transpose(-2, -1))

    # 2. Scale by sqrt(d_k) so the score magnitude does not grow with d_k
    d_k = Q.size(-1)
    scores = scores / math.sqrt(d_k)

    # 3. Optional masking, applied before the softmax so that blocked
    #    positions receive exactly zero probability
    if mask is not None:
        if mask.dtype == torch.bool:
            scores = scores.masked_fill(~mask, float("-inf"))
        else:
            scores = scores + mask

    # 4. Softmax over the key axis → attention probabilities
    attention_weights = F.softmax(scores, dim=-1)

    # 5. Weighted sum of the values
    output = torch.matmul(attention_weights, V)

    return output, attention_weights

In [28]:
# Dimensions of the toy attention example.
batch, seq_len = 1, 3   # batch size and sequence length
d_k = d_v = 4           # query/key width and value width

In [29]:
# Seed the global RNG so that Restart & Run All reproduces the same Q, K, V
# (and therefore the same attention weights and output).
torch.manual_seed(0)

# Random query, key and value tensors for the demo.
Q = torch.randn(batch, seq_len, d_k)
K = torch.randn(batch, seq_len, d_k)
V = torch.randn(batch, seq_len, d_v)

In [30]:
# Run attention on the sample tensors; the function returns the attended
# values together with the attention probabilities.
output, weights = scaled_dot_product_attention(Q=Q, K=K, V=V)

# Print the attention distribution first, then the resulting output.
print("Attention Weights:\n", weights)
print("\nOutput:\n", output)

Attention Weights:
 tensor([[[0.2166, 0.2821, 0.5013],
         [0.0669, 0.1602, 0.7729],
         [0.6856, 0.2844, 0.0300]]])

Output:
 tensor([[[-0.5421, -0.8788,  0.0876,  0.5341],
         [-0.8720, -1.3406,  0.3212,  1.1088],
         [ 0.1682,  0.2629, -0.2504, -0.3174]]])
