<a href="https://colab.research.google.com/github/Tanish-Sarkar/Elite-Transformers/blob/main/Module2%20-%20Attention%20Deep%20Dive/Single_head_attention_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# 1. Single Scaled Dot-Product Attention

In [6]:
class ScaledDotProductAttention(nn.Module):
  """Single-head scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""

  def __init__(self):
    super().__init__()

  def forward(self, Q, K, V, mask=None):
    """
      Q, K, V: shape (batch, seq_len, d_k)   ← sometimes called (B, N, d)
      mask:    shape (batch, seq_len, seq_len) or broadcastable;
               positions where mask == 0 are excluded from attention.

      Returns:
        output:       attention-weighted sum of V, last dim taken from V
        attn_weights: (batch, query_len, key_len); each query row sums to 1
                      over the key axis.
    """
    d_k = Q.size(-1)

    # Step 1: Compute raw dot-product attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1))      # (B, T, T)

    # Step 2: Scale by sqrt(d_k) — prevents large values → softmax saturation
    scores = scores / math.sqrt(d_k)

    # Step 3: Apply mask (if any) — very important for decoder.
    # Masked positions become -inf so softmax assigns them exactly 0 weight.
    if mask is not None:
      scores = scores.masked_fill(mask == 0, float('-inf'))

    # Step 4: Softmax over the KEY axis (last dim) so that every query's
    # weights form a probability distribution. Normalizing over dim=1 would
    # instead normalize across queries and break the weighted sum below.
    attn_weights = F.softmax(scores, dim=-1)           # (B, T, T)

    # Step 5: Weighted sum of values
    output = torch.matmul(attn_weights, V)             # (B, T, d_k)

    return output, attn_weights

# 2. Test single head

In [8]:
# Seed the global RNG so the random Q/K/V below (and every result derived
# from them) are identical on each Restart Kernel → Run All.
torch.manual_seed(0)

# Batch size, sequence length, and per-head feature dimension for the demo.
B, T, d_k = 2, 8, 64

Q = torch.randn(B, T, d_k)
K = torch.randn(B, T, d_k)
V = torch.randn(B, T, d_k)

attn_single = ScaledDotProductAttention()
out_single, weight_single = attn_single(Q, K, V)

# Sanity checks: the output keeps the input shape, and there is one weight
# per (query, key) pair.
assert out_single.shape == (B, T, d_k)
assert weight_single.shape == (B, T, T)

print("Single Head Output Shape: ", out_single.shape)
print("Attention Weight Shape: ", weight_single.shape)

Single Head Output Shape:  torch.Size([2, 8, 64])
Attention Weight Shape:  torch.Size([2, 8, 8])


# 3. Causal Mask (for decoder – prevents seeing future tokens)

In [9]:
def generate_causal_mask(seq_len, device=None):
  """
   Returns lower-triangular causal mask (1 = attend, 0 = mask):
   row i has ones in columns 0..i, so position i can attend only to
   itself and earlier positions.

   seq_len: number of positions in the sequence
   device:  optional torch device to create the mask on
            (None = PyTorch's default device)

   Shape: (seq_len, seq_len)
  """
  mask = torch.tril(torch.ones(seq_len, seq_len, device=device))
  return mask

mask = generate_causal_mask(T)
# Report the actual sequence length rather than a hard-coded 8, so the
# message stays correct if T is changed.
print(f"\nCausal mask example (for seq_len={T}):\n", mask)

# Apply mask to single head
out_masked, weights_mask = attn_single(Q, K, V, mask=mask)

# Future positions (strictly above the diagonal) are filled with -inf before
# the softmax, so they must receive exactly zero attention weight.
assert torch.all(weights_mask.triu(diagonal=1) == 0)

print("\nMasked Output Shape: ", out_masked.shape)
print("Attention Weight Shape: ", weights_mask.shape)


Causal mask example (for seq_len=8):
 tensor([[1., 0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1.]])

Masked Output Shape:  torch.Size([2, 8, 64])
Attention Weight Shape:  torch.Size([2, 8, 8])
