# A test bench for the functional behavior of a single attention head

The test is divided into the following stages:

1. Q/K/V input projection
2. $QK^T$
3. $S = \mathrm{softmax}\left(QK^T / \sqrt{d_h}\right)$
4. $SV$
5. Output projection

In [None]:
import torch
import torch.nn as nn
import math


# Entire attention functional test: a monolithic multi-head self-attention
# used as the reference for the per-stage modules below.
class TinySelfAttn(nn.Module):
    """Multi-head self-attention: QKV projection, scaled dot-product, output projection.

    Args:
        embed_dim: model width D; must be divisible by ``num_heads``.
        num_heads: number of attention heads H; each head gets D // H channels.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``embed_dim``.
    """

    def __init__(self, embed_dim=768, num_heads=12):
        super().__init__()
        # Validate explicitly (an ``assert`` is stripped under ``python -O`` and
        # would also accept a negative num_heads).
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

    def _shape(self, x, bsz, seq_len):
        # (B, S, D) -> (B, num_heads, S, head_dim)
        return x.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, attn_mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: input of shape (B, S, embed_dim).
            attn_mask: optional additive mask that broadcasts against the
                (B, H, S, S) scores (e.g. 0 = keep, -inf = block). It is cast to
                the score dtype/device, so a float32 mask works with a bf16 module.

        Returns:
            Tensor of shape (B, S, embed_dim).
        """
        bsz, seq_len, _ = x.size()
        q = self._shape(self.q_proj(x), bsz, seq_len)
        k = self._shape(self.k_proj(x), bsz, seq_len)
        v = self._shape(self.v_proj(x), bsz, seq_len)

        # scaled dot-product: (B, H, S, hd) @ (B, H, hd, S) -> (B, H, S, S)
        attn_scores = (q @ k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        if attn_mask is not None:
            # Out-of-place add (an in-place add cannot grow to a broadcast shape);
            # the cast keeps scores, probabilities and v in one dtype.
            attn_scores = attn_scores + torch.as_tensor(
                attn_mask, dtype=attn_scores.dtype, device=attn_scores.device
            )
        attn_probs = attn_scores.softmax(dim=-1)

        ctx = attn_probs @ v  # (B, heads, S, head_dim)
        # merge heads back: (B, S, H, head_dim) -> (B, S, embed_dim)
        ctx = ctx.transpose(1, 2).contiguous().view(bsz, seq_len, self.embed_dim)

        return self.out_proj(ctx)

# 1) QKV input projection
class AttnInputProj(nn.Module):
    """Stage 1: project ``x`` into per-head Q, K and V.

    Args:
        embed_dim: model width D; must be divisible by ``num_heads``.
        num_heads: number of heads H.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``embed_dim`` (otherwise the head split would only fail in forward).
    """

    def __init__(self, embed_dim=768, num_heads=12):
        super().__init__()
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)

    def _shape(self, x, bsz, seq_len):
        # (B, S, D) -> (B, H, S, head_dim)
        return x.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Return ``(q, k, v)``, each (B, H, S, head_dim), for ``x`` of shape (B, S, D)."""
        bsz, seq_len, _ = x.size()
        q = self._shape(self.q_proj(x), bsz, seq_len)
        k = self._shape(self.k_proj(x), bsz, seq_len)
        v = self._shape(self.v_proj(x), bsz, seq_len)
        return q, k, v


# 2) QK^T + scale (+ optional mask)
class AttnScores(nn.Module):
    """Stage 2: scaled scores ``QK^T / sqrt(head_dim)`` plus an optional additive mask.

    Args:
        head_dim: per-head channel count d_h used for the 1/sqrt(d_h) scale.
    """

    def __init__(self, head_dim):
        super().__init__()
        self.scale = 1.0 / math.sqrt(head_dim)

    def forward(self, q, k, attn_mask=None):
        """Return scores of shape (B, H, S, S) for ``q``, ``k`` of shape (B, H, S, head_dim).

        ``attn_mask`` (optional) must broadcast against the scores. It is cast to
        the score dtype/device: without the cast a float32 mask would promote
        bf16 scores to float32, and the later ``probs @ V`` matmul (which does
        not type-promote) would fail against a bf16 V.
        """
        # k^T over the last two dims
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # (B, H, S, S)
        if attn_mask is not None:
            scores = scores + torch.as_tensor(
                attn_mask, dtype=scores.dtype, device=scores.device
            )
        return scores


# 3) softmax over the last (key) dimension
class AttnWeights(nn.Module):
    """Stage 3: turn attention scores into probabilities.

    Each row of the (B, H, S, S) score tensor is normalized with a softmax over
    its last dimension, so every query's weights over the keys sum to 1. The
    module holds no parameters, so the inherited ``nn.Module.__init__`` suffices.
    """

    def forward(self, scores):
        probabilities = torch.softmax(scores, dim=-1)
        return probabilities  # same shape as scores


# 4) context = weights @ V, then merge the heads
class AttnContext(nn.Module):
    """Stage 4: weight V by the attention probabilities and merge the heads.

    Args:
        embed_dim: model width D; must be divisible by ``num_heads``.
        num_heads: number of heads H.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``embed_dim`` (the merged width H * head_dim would otherwise differ from D).
    """

    def __init__(self, embed_dim=768, num_heads=12):
        super().__init__()
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

    def forward(self, attn_probs, v):
        """Return the merged context of shape (B, S, D).

        Args:
            attn_probs: (B, H, S, S) attention probabilities.
            v: (B, H, S, head_dim) values.
        """
        ctx = torch.matmul(attn_probs, v)  # (B, H, S, head_dim)
        bsz, seq_len = ctx.size(0), ctx.size(2)
        # (B, H, S, head_dim) -> (B, S, H, head_dim) -> (B, S, D)
        return ctx.transpose(1, 2).contiguous().view(
            bsz, seq_len, self.num_heads * self.head_dim
        )


# 5) final output projection back to the model width
class AttnOutputProj(nn.Module):
    """Stage 5: the final affine map applied to the merged heads.

    Wraps one D -> D ``nn.Linear`` (with bias) stored as ``out_proj``, so the
    module's parameters are ``out_proj.weight`` and ``out_proj.bias``.
    """

    def __init__(self, embed_dim=768):
        super().__init__()
        linear = nn.Linear(in_features=embed_dim, out_features=embed_dim, bias=True)
        self.out_proj = linear

    def forward(self, x):
        """Project ``x`` of shape (B, S, D) to a tensor of shape (B, S, D)."""
        projected = self.out_proj(x)
        return projected


# Load the weights of the specific attention layer under test

In [2]:
from safetensors.torch import load_file

ALL_WEIGHTS_PATH = "../weights/downloads/model.safetensors"
state = load_file(ALL_WEIGHTS_PATH)

prefix = "model.vlm_with_expert.vlm.model.vision_model.encoder.layers.11.self_attn"

# Keep only this layer's parameters and strip "<prefix>." from their names,
# e.g. "<prefix>.q_proj.weight" -> "q_proj.weight".
layer_weights = {k[len(prefix) + 1:]: v for k, v in state.items() if k.startswith(prefix + ".")}
if not layer_weights:
    # Fail here with a clear message rather than a KeyError in a later cell.
    raise KeyError(f"no tensors found under prefix {prefix!r} in {ALL_WEIGHTS_PATH}")
print(layer_weights.keys())

for key, tensor in layer_weights.items():
    print(f"{key}: {tensor.shape}, dtype={tensor.dtype}")

dict_keys(['k_proj.bias', 'k_proj.weight', 'out_proj.bias', 'out_proj.weight', 'q_proj.bias', 'q_proj.weight', 'v_proj.bias', 'v_proj.weight'])
k_proj.bias: torch.Size([768]), dtype=torch.bfloat16
k_proj.weight: torch.Size([768, 768]), dtype=torch.bfloat16
out_proj.bias: torch.Size([768]), dtype=torch.bfloat16
out_proj.weight: torch.Size([768, 768]), dtype=torch.bfloat16
q_proj.bias: torch.Size([768]), dtype=torch.bfloat16
q_proj.weight: torch.Size([768, 768]), dtype=torch.bfloat16
v_proj.bias: torch.Size([768]), dtype=torch.bfloat16
v_proj.weight: torch.Size([768, 768]), dtype=torch.bfloat16


# Load these weights into the full (monolithic) attention module for an end-to-end test (as opposed to the per-stage unit tests)

In [3]:
attn = TinySelfAttn(embed_dim=768, num_heads=12)
# Cast the module to bf16 so its parameters match the checkpoint dtype
# (every tensor listed above is torch.bfloat16).
attn = attn.to(torch.bfloat16)

# Load all projection tensors at once. load_state_dict copies into the existing
# parameters without tracking gradients and, with strict=True (the default),
# raises on missing/unexpected keys or any shape mismatch — per-tensor copy_
# calls would silently broadcast a wrongly shaped source and ignore extra keys.
attn.load_state_dict({name: t.to(torch.bfloat16) for name, t in layer_weights.items()})


# Run the full attention module on a random bf16 input

In [4]:
B, S, D = 1, 50, 768  # batch size, sequence length, embedding dimension
num_heads = 12  # number of attention heads; `attn` above was built with 12
# Fix the RNG so the random inputs here and in later cells are reproducible.
torch.manual_seed(0)
x = torch.randn(B, S, D, dtype=torch.bfloat16)
# Inference only: no autograd graph is needed for the test bench.
with torch.no_grad():
    out = attn(x)
print(out.shape)  # (1, 50, 768)



torch.Size([1, 50, 768])


# Compute only the linear projections of Q, K, and V

In [8]:
x_new = torch.randn(B, S, D, dtype=torch.bfloat16)
print(f"type of x_new: {x_new.dtype}")
attn_input_proj = AttnInputProj(embed_dim=D, num_heads=num_heads).to(torch.bfloat16)
# Load just the Q/K/V projection weights and biases; out_proj belongs to the
# final stage. strict loading checks that all six tensors are present and shaped right.
qkv_weights = {
    name: t.to(torch.bfloat16)
    for name, t in layer_weights.items()
    if name.startswith(("q_proj.", "k_proj.", "v_proj."))
}
attn_input_proj.load_state_dict(qkv_weights)

# Call the module (not .forward) so any registered hooks run; no autograd needed.
with torch.no_grad():
    q, k, v = attn_input_proj(x_new)

print(f"Q shape: {q.shape}, and type: {q.dtype}")
print(f"K shape: {k.shape}, and type: {k.dtype}")
print(f"V shape: {v.shape}, and type: {v.dtype}")




type of x_new: torch.bfloat16
Q shape: torch.Size([1, 12, 50, 64]), and type: torch.bfloat16
K shape: torch.Size([1, 12, 50, 64]), and type: torch.bfloat16
V shape: torch.Size([1, 12, 50, 64]), and type: torch.bfloat16


# Now use these projections to compute $QK^T/\sqrt{d_h}$

In [16]:
attn_scores = AttnScores(head_dim=D // num_heads).to(torch.bfloat16)
# Optional additive causal mask: 0 on/below the diagonal, -inf above it (the
# same construction as nn.Transformer.generate_square_subsequent_mask).
# It is built for experimentation but deliberately NOT applied below.
mask = torch.triu(torch.full((S, S), float("-inf"), dtype=torch.bfloat16), diagonal=1)  # (S, S)
with torch.no_grad():
    output_scores = attn_scores(q, k)  # no mask provided
print(f"Output scores shape: {output_scores.shape}, and type: {output_scores.dtype}")

Output scores shape: torch.Size([1, 12, 50, 50]), and type: torch.bfloat16


# Test Softmax

In [None]:
# Use a distinct instance name: rebinding `AttnWeights` to an instance would
# shadow the class and break re-running this cell.
attn_weights = AttnWeights()
with torch.no_grad():
    output_weights = attn_weights(output_scores)
print(f"Output weights shape: {output_weights.shape}, and type: {output_weights.dtype}")

Output weights shape: torch.Size([1, 12, 50, 50]), and type: torch.bfloat16


# Test SV

In [9]:
# Distinct instance name so the `AttnContext` class is not shadowed on re-run.
attn_context = AttnContext(embed_dim=D, num_heads=num_heads).to(torch.bfloat16)
with torch.no_grad():
    output_context = attn_context(output_weights, v)
print(f"Output context shape: {output_context.shape}, and type: {output_context.dtype}")

Output context shape: torch.Size([1, 50, 768]), and type: torch.bfloat16


In [10]:
# Distinct instance name so the `AttnOutputProj` class is not shadowed on re-run.
attn_output_proj = AttnOutputProj(embed_dim=D).to(torch.bfloat16)
attn_output_proj.load_state_dict(
    {name: t.to(torch.bfloat16) for name, t in layer_weights.items() if name.startswith("out_proj.")}
)
with torch.no_grad():
    output_final = attn_output_proj(output_context)
    # Cross-check: the staged pipeline (run on x_new with no mask) must
    # reproduce the monolithic TinySelfAttn loaded with the same weights.
    reference_final = attn(x_new)
torch.testing.assert_close(output_final, reference_final)

# This is the final output of the attention mechanism
print(f"Output final shape: {output_final.shape}, and type: {output_final.dtype}")


Output final shape: torch.Size([1, 50, 768]), and type: torch.bfloat16
