In [4]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell executes, so edits to the local `merge` package
# are picked up without restarting the kernel.
# Re-running this cell in a live kernel is harmless; IPython only prints a
# notice that the extension is already loaded.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
import torch
from merge.modules.config import TransformerConfig
from merge.modules.transformer import Transformer


torch.Size([1, 2, 8])

In [21]:
from dataclasses import dataclass
from typing import Optional

import torch

from merge.modules.attention import GQA

@dataclass
class TransformerConfig:
    """Minimal stand-in config carrying only the fields set for the GQA test.

    NOTE(review): this definition shadows the ``TransformerConfig`` imported
    from ``merge.modules.config`` in an earlier cell. Keep the field names in
    sync with the project config, or switch the test to the real class.

    Attributes:
        num_heads: Number of query heads.
        num_kv_heads: Number of key/value heads; must divide ``num_heads``.
        d_model: Size of the last dimension of the layer input.
        attention_bias: Passed through to the attention layer (presumably
            toggles bias terms on its projections -- confirm in ``GQA``).
        attention_dropout: Passed through to the attention layer (presumably a
            dropout probability -- confirm in ``GQA``).
        pos_encoding_type: Optional positional-encoding selector; ``None``
            leaves it unset.

    Raises:
        ValueError: If either head count is not positive, or ``num_heads`` is
            not a multiple of ``num_kv_heads``.
    """

    num_heads: int = 8
    num_kv_heads: int = 2
    d_model: int = 512
    attention_bias: bool = False
    attention_dropout: float = 0.1
    # The default is None, so the annotation has to admit None as well.
    pos_encoding_type: Optional[str] = None

    def __post_init__(self) -> None:
        # Fail at construction time instead of deep inside the attention layer.
        if self.num_heads <= 0 or self.num_kv_heads <= 0:
            raise ValueError(
                "num_heads and num_kv_heads must be positive, got "
                f"num_heads={self.num_heads}, num_kv_heads={self.num_kv_heads}"
            )
        if self.num_heads % self.num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({self.num_heads}) must be divisible by "
                f"num_kv_heads ({self.num_kv_heads})"
            )

# Smoke test for the GQA layer: run one forward pass on random input and check
# the output and projection shapes against the config.

# Seed the global RNG so weight init, the random input and dropout are
# repeatable under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Test dimensions
batch_size = 2
seq_length = 16
config = TransformerConfig()

# Initialize GQA layer
gqa = GQA(config)

# Create dummy input tensor of shape (batch, seq, d_model)
x = torch.randn(batch_size, seq_length, config.d_model)

# Causal mask of shape (1, 1, seq, seq): 0.0 on and below the diagonal,
# -10000.0 above it (presumably added to the attention scores inside GQA --
# confirm against the layer implementation).
mask = torch.tril(torch.ones((1, 1, seq_length, seq_length)))
mask = torch.where(mask == 1.0, 0.0, -10000.0)

# Only shapes are inspected, so skip autograd bookkeeping.
with torch.no_grad():
    # Forward pass
    output = gqa(x, mask, pos_info=None)
    # Intermediate projections, used to verify the k/v head grouping
    q = gqa.W_Q(x)
    kv = gqa.W_KV(x)
k, v = kv.chunk(2, dim=-1)

# Verification
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
assert output.shape == x.shape, "Output shape doesn't match input shape"

# Expected projection sizes: Q keeps d_model; K and V each get
# head_dim * num_kv_heads and are packed side by side in W_KV's output.
head_dim = config.d_model // config.num_heads
kv_dim = head_dim * config.num_kv_heads
expected_q_shape = (batch_size, seq_length, config.d_model)
expected_kv_shape = (batch_size, seq_length, 2 * kv_dim)

print("\nProjection shapes:")
print(f"Q projection: {q.shape}")  # Should be [batch, seq, d_model]
print(f"KV projection: {kv.shape}")  # Should be [batch, seq, 2 * (d_model//num_heads * num_kv_heads)]
assert q.shape == expected_q_shape, f"Q projection shape {tuple(q.shape)} != {expected_q_shape}"
assert kv.shape == expected_kv_shape, f"KV projection shape {tuple(kv.shape)} != {expected_kv_shape}"
assert k.shape == v.shape == (batch_size, seq_length, kv_dim), "K and V halves must have equal shapes"

# Verify the grouping ratio
assert gqa.num_heads % gqa.num_kv_heads == 0, "Number of heads must be divisible by number of KV heads"
print(f"\nGrouping ratio (queries per k/v): {gqa.num_heads // gqa.num_kv_heads}")

print("\nTest passed!")

Input shape: torch.Size([2, 16, 512])
Output shape: torch.Size([2, 16, 512])

Projection shapes:
Q projection: torch.Size([2, 16, 512])
KV projection: torch.Size([2, 16, 256])

Grouping ratio (queries per k/v): 4

Test passed!
