In [2]:
import sys

# Make the parent directory importable (for `llm_foundry`). Guarded so that
# re-running this cell does not keep appending duplicate ".." entries.
if ".." not in sys.path:
    sys.path.append("..")

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [6]:
from llm_foundry.model.layers import RMSNorm, apply_rope, compute_rope_params

In [7]:
# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [8]:
## Toy attention dimensions used throughout the notebook
batch_size = 2
seq_len = 4
d_in = 16
num_heads = 2
# Every head gets an equal slice of the projected width; fail here with a clear
# message instead of later inside a .view() call if the split is uneven.
assert d_in % num_heads == 0, f"d_in={d_in} is not divisible by num_heads={num_heads}"
head_dim = d_in // num_heads
d_out = d_in  # projections keep the model width, so d_out == num_heads * head_dim

In [17]:
## Dummy input: random "token embeddings" of shape (batch_size, seq_len, d_in)

# Reseed right before drawing so the input is reproducible on Restart & Run All.
torch.manual_seed(0)
x = torch.randn((batch_size, seq_len, d_in), device=device, dtype=torch.float32)
print(x.size())
print(x[0][0])  # embedding of the first token of the first sequence

torch.Size([2, 4, 16])
tensor([-1.1258, -1.1524, -0.2506, -0.4339,  0.8487,  0.6920, -0.3160, -2.1152,
         0.3223, -1.2633,  0.3500,  0.3081,  0.1198,  1.2377,  1.1168, -0.2473])


In [64]:
## Bias-free linear projections x -> Q, K, V (d_in -> d_out) and the output
## projection back from d_out to d_in.

# The generator yields the layers in query, key, value order, so the random
# initialisations are drawn in the same order as three separate assignments.
W_query, W_key, W_value = (
    nn.Linear(d_in, d_out, bias=False).to(device) for _ in range(3)
)

out_proj = nn.Linear(d_out, d_in, bias=False).to(device)

print(W_query.weight.shape)  # nn.Linear stores weight as (out_features, in_features)

torch.Size([16, 16])


In [39]:
## Project the input into queries, keys and values.
# Each result has shape (batch_size, seq_len, d_out).
Q, K, V = (projection(x) for projection in (W_query, W_key, W_value))

# One shape per line: Q, then K, then V
print(Q.shape, K.shape, V.shape, sep="\n")

torch.Size([2, 4, 16])
torch.Size([2, 4, 16])
torch.Size([2, 4, 16])


In [40]:
## Split the projected width into heads and move the head axis forward:
##   (batch_size, seq_len, d_out)
##   -> view      (batch_size, seq_len, num_heads, head_dim)
##   -> transpose (batch_size, num_heads, seq_len, head_dim)
# NOTE: this rebinds Q/K/V in place. Re-running the cell on its own would
# operate on the already-split tensors, so re-run from the projection cell.
Q, K, V = (
    t.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
    for t in (Q, K, V)
)

print("after reshape -> queries, keys, values:", Q.shape, K.shape, V.shape)

after reshape -> queries, keys, values: torch.Size([2, 2, 4, 8]) torch.Size([2, 2, 4, 8]) torch.Size([2, 2, 4, 8])


In [41]:
## Optional QK-norm: RMS-normalise the per-head query and key vectors.
## Given apply_qk_norm = True
apply_qk_norm = True

if apply_qk_norm:
    # Separate norm modules for queries and keys. Sharing one instance would tie
    # the query and key normalisation parameters together. This assumes RMSNorm
    # carries a learnable per-feature scale, as is standard (confirm in
    # llm_foundry.model.layers); with a fresh init the outputs are unchanged.
    # RMSNorm(head_dim) presumably normalises over the last axis (head_dim).
    q_norm = RMSNorm(head_dim, eps=1e-6).to(device)
    k_norm = RMSNorm(head_dim, eps=1e-6).to(device)
    Q = q_norm(Q)
    K = k_norm(K)
    print("Applied RMS Norm")
    print("after rms norm -> queries, keys:", Q.shape, K.shape)
else:
    print("skipping rms norm")

Applied RMS Norm
after rms norm -> queries, keys: torch.Size([2, 2, 4, 8]) torch.Size([2, 2, 4, 8])


In [43]:
### Applying RoPE to Q and K: precompute the per-position cos/sin tables
# Spec for this model: rope_theta = 100000.
# NOTE(review): no theta/base argument is passed below, so compute_rope_params
# uses whatever base it defaults to, which may not be 100000. Pass the intended
# base explicitly once the parameter name in llm_foundry.model.layers is confirmed.
cos, sin = compute_rope_params(head_dim=head_dim, context_length=seq_len)

# Send to device so the tables can be combined with Q and K
cos, sin = cos.to(device), sin.to(device)

# Expect one row per position and one column per head feature
assert cos.shape == sin.shape == (seq_len, head_dim), (cos.shape, sin.shape)
print("cos, sin:", cos.shape, sin.shape)

cos, sin: torch.Size([4, 8]) torch.Size([4, 8])


In [44]:
## Rotate queries and keys by position using the precomputed cos/sin tables.
## Values are not rotated; only Q and K take part in the position-dependent dot product.
Q, K = (apply_rope(t, cos, sin) for t in (Q, K))
print("Applied RoPE")
print("after rope -> queries, keys:", Q.shape, K.shape)

Applied RoPE
after rope -> queries, keys: torch.Size([2, 2, 4, 8]) torch.Size([2, 2, 4, 8])


In [45]:
## Pre-scale the queries by 1/sqrt(head_dim).
# Since (c * Q) @ K^T == c * (Q @ K^T), the 1/sqrt(d_k) factor only needs to be
# applied to one operand. Doing it on Q here means the score matmul needs no
# extra division.
scaling = 1.0 / (head_dim ** 0.5)
Q = Q.mul(scaling)

print("after scaling -> queries:", Q.shape)

after scaling -> queries: torch.Size([2, 2, 4, 8])


In [51]:
## Causal mask: True strictly above the main diagonal marks "future" keys that a
## query position must not attend to.
causal = torch.ones((seq_len, seq_len), device=device).triu(diagonal=1).bool()  # (seq_len, seq_len)
print("causal mask:", causal)

## Prepend singleton batch/head axes so the mask broadcasts against
## (batch_size, num_heads, seq_len, seq_len) scores.
causal = causal[None, None]  # (1, 1, seq_len, seq_len)
print("after unsqueeze -> causal:", causal.shape)

causal mask: tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])
after unsqueeze -> causal: torch.Size([1, 1, 4, 4])


In [54]:
## Raw attention scores: every (already scaled) query against every key,
## separately for each batch element and head.
attn_scores = torch.matmul(Q, K.transpose(-1, -2))  # (batch_size, num_heads, seq_len, seq_len)
print("attention scores shape:", attn_scores.shape)

## Future positions get -inf so the softmax assigns them exactly zero weight.
# The diagonal is never masked, so every row keeps at least one finite score.
attn_scores = attn_scores.masked_fill(causal, -float("inf"))
print("after mask -> attention scores:", attn_scores[0, 0, :, :])

attention scores shape: torch.Size([2, 2, 4, 4])
after mask -> attention scores: tensor([[ 1.5351e-01,        -inf,        -inf,        -inf],
        [-1.2653e+00, -7.5895e-01,        -inf,        -inf],
        [-1.1759e+00, -5.6546e-02,  6.2223e-01,        -inf],
        [ 1.4348e+00,  7.9986e-04, -9.4357e-01,  1.4716e+00]],
       grad_fn=<SliceBackward0>)


In [56]:
## Softmax over the key axis.
# dtype=torch.float32 casts the scores to float32 before the softmax, for
# numerical stability with low-precision inputs. The weights are then cast back
# to the working dtype of Q.
attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(Q.dtype)

print("after softmax -> attention weights:", attn_weights[0, 0, :, :])

after softmax -> attention weights: tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.3761, 0.6239, 0.0000, 0.0000],
        [0.0990, 0.3032, 0.5978, 0.0000],
        [0.4222, 0.1006, 0.0391, 0.4380]], grad_fn=<SliceBackward0>)


In [60]:
## Verify each attention row is a probability distribution (sums to 1)
print("Row sums (should be 1):", attn_weights[0, 0, :, :].sum(dim=-1))

# The print above only shows batch 0 / head 0, so check every batch, head and
# query position. assert_close picks tolerances appropriate to the dtype.
row_sums = attn_weights.sum(dim=-1)  # (batch_size, num_heads, seq_len)
torch.testing.assert_close(row_sums, torch.ones_like(row_sums))

Row sums (should be 1): tensor([1., 1., 1., 1.], grad_fn=<SumBackward1>)


In [None]:
## Context vectors: attention-weighted sum of the values, per head.
context = attn_weights @ V  # (batch_size, num_heads, seq_len, head_dim)

## Merge the heads back into a single feature axis:
##   (batch_size, num_heads, seq_len, head_dim)
##   -> transpose (batch_size, seq_len, num_heads, head_dim)
##   -> flatten   (batch_size, seq_len, num_heads * head_dim = d_out)
# The transposed tensor is not contiguous, so it is made contiguous before the view.
context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, d_out)

In [62]:
## Project the merged heads back to the model width:
## (batch_size, seq_len, d_out) -> (batch_size, seq_len, d_in)
out = out_proj(context)
print(f"final output: {out.shape}")

final output: torch.Size([2, 4, 16])


In [63]:
## Sanity check: the attention block maps (batch_size, seq_len, d_in) back to the same shape
print("Input x.shape:", x.shape)
print("Output out.shape:", out.shape)
assert out.shape == x.shape, f"shape mismatch: out {out.shape} vs x {x.shape}"
# inspect a small slice
print("out[0,0,:5]:", out[0,0,:5])

Input x.shape: torch.Size([2, 4, 16])
Output out.shape: torch.Size([2, 4, 16])
out[0,0,:5]: tensor([-0.3983, -0.1958,  0.2629,  0.5424, -0.0734], grad_fn=<SliceBackward0>)
