In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import json

# Reproducibility: pin the global NumPy and PyTorch RNG seeds before any
# weights are initialised or random tensors are drawn.
np.random.seed(42)
torch.manual_seed(42)

# All tensors and modules in this notebook are placed on the CPU.
device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


In [4]:
class MultiQueryAttention(nn.Module):
    """
    Multi-query attention: every head has its own queries, but all heads
    share one key projection and one value projection.

    Args:
        dim: Size of the input/output feature dimension. Must be a positive
            multiple of ``num_heads``.
        num_heads: Number of query heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``dim`` or ``num_heads`` is not positive, or if
            ``num_heads`` does not divide ``dim``.
    """

    def __init__(self, dim, num_heads=4, dropout=0.1):
        super().__init__()
        # Fail fast: otherwise a bad combination only surfaces inside
        # forward() as an opaque reshape error (or a ZeroDivisionError here).
        if dim <= 0 or num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be a positive multiple of num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Scale attention scores by 1/sqrt(head_dim) so gradients remain stable.
        self.scale = self.head_dim ** -0.5

        # Queries keep the full width: one head_dim-sized slice per head.
        self.q_proj = nn.Linear(dim, dim)

        # Keys and values are projected once, to a single head's width.
        self.k_proj = nn.Linear(dim, self.head_dim)
        self.v_proj = nn.Linear(dim, self.head_dim)

        self.proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply multi-query self-attention.

        Args:
            x: Tensor of shape (B, N, C) with C equal to ``dim``.

        Returns:
            Tensor of shape (B, N, C).
        """
        B, N, C = x.shape

        # (B, N, C) -> (B, N, H, D) -> (B, H, N, D)
        q = self.q_proj(x).reshape(B, N, self.num_heads, self.head_dim)
        q = q.permute(0, 2, 1, 3)

        # Shared K and V for all heads: (B, N, D)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Add a singleton head axis so K/V broadcast across the H query heads.
        k = k.unsqueeze(1)
        v = v.unsqueeze(1)

        # (B, H, N, D) @ (B, 1, D, N) -> (B, H, N, N)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)

        # (B, H, N, N) @ (B, 1, N, D) -> (B, H, N, D), then merge heads.
        x = (attn @ v)
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)

        return x

# Smoke test: push a random (batch=2, tokens=16, dim=64) tensor through a
# freshly built multi-query attention layer and report the resulting shapes.
print("Testing Multi-query Attention...")
mqa = MultiQueryAttention(dim=64, num_heads=4)

test_input = torch.randn(2, 16, 64)
test_output = mqa(test_input)
print(f"✔️ Input shape: {test_input.shape}")
print(f"✔️ Output shape: {test_output.shape}")

# Total number of scalar weights (and biases) held by the layer.
mqa_params = sum(weight.numel() for weight in mqa.parameters())
print(f"✔️ Multi - Query Attention parameters: {mqa_params:,}")

Testing Multi-query Attention...
✔️ Input shape: torch.Size([2, 16, 64])
✔️ Output shape: torch.Size([2, 16, 64])
✔️ Multi - Query Attention parameters: 10,400


In [6]:
from torchvision import datasets, transforms

class StandardAttention(nn.Module):
    """
    Standard multi-head self-attention: every head has its own query, key
    and value slice, produced by one fused linear projection.

    Args:
        dim: Size of the input/output feature dimension. Must be a positive
            multiple of ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``dim`` or ``num_heads`` is not positive, or if
            ``num_heads`` does not divide ``dim``.
    """

    def __init__(self, dim, num_heads=4, dropout=0.1):
        super().__init__()
        # Fail fast: otherwise a bad combination only surfaces inside
        # forward() as an opaque reshape error (or a ZeroDivisionError here).
        if dim <= 0 or num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be a positive multiple of num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Scale attention scores by 1/sqrt(head_dim).
        self.scale = self.head_dim ** -0.5

        # One fused projection producing Q, K and V for every head.
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply multi-head self-attention.

        Args:
            x: Tensor of shape (B, N, C) with C equal to ``dim``.

        Returns:
            Tensor of shape (B, N, C).
        """
        B, N, C = x.shape
        # (B, N, 3C) -> (B, N, 3, H, D) -> (3, B, H, N, D)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # (B, H, N, D) @ (B, H, D, N) -> (B, H, N, N)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)

        # Weighted sum of values, then merge the heads back into C features.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return x

# Build both attention variants with identical hyper-parameters so that the
# parameter counts are directly comparable.
std_attn = StandardAttention(dim=64, num_heads=4)
mq_attn = MultiQueryAttention(dim=64, num_heads=4)

# Total number of scalar weights (and biases) in each layer.
std_params = sum(p.numel() for p in std_attn.parameters())
mq_params = sum(p.numel() for p in mq_attn.parameters())

print("\n" + "="*60)
print("PARAMETER COMPARISON")
print("="*60)
print(f"Standard Attention: {std_params:,} parameters")
print(f"Multi-Query Attention: {mq_params:,} parameters")
# Absolute saving, and saving relative to the standard layer.
print(f"Reduction:              {std_params - mq_params:,} parameters ({(1 - mq_params/std_params)*100:.1f}% fewer)")
print("="*60)


PARAMETER COMPARISON
Standard Attention: 16,640 parameters
Multi-Query Attention: 10,400 paramters
Reduction:              6,240 parameters (37.5% fewer)


In [None]:
class PatchEmbedding(nn.Module):
    """Cut an image batch into non-overlapping patches and embed each one.

    A convolution whose kernel size and stride both equal ``patch_size``
    maps every patch to a single ``embed_dim``-dimensional vector.

    Args:
        img_size: Side length used to compute ``num_patches``.
        patch_size: Side length of one square patch.
        in_channels: Number of channels in the input images.
        embed_dim: Length of each patch embedding.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=2, embed_dim=64):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed a batched image tensor (B, in_channels, H, W).

        Returns:
            Tensor of shape (B, num_tokens, embed_dim).
        """
        # (B, embed_dim, H', W'): one feature vector per patch location.
        feature_map = self.proj(x)
        # Collapse the spatial grid into a token axis, then put tokens
        # before features: (B, embed_dim, H'*W') -> (B, H'*W', embed_dim).
        return feature_map.flatten(2).transpose(1, 2)

class TransformerBlock(nn.Module):
    """Pre-norm transformer block with a selectable attention layer.

    Each of the two sub-layers (attention, then MLP) is applied to a
    layer-normalised copy of its input and added back as a residual.

    Args:
        dim: Feature dimension of the tokens.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        dropout: Dropout probability used in the attention layer and the MLP.
        use_mqa: Use ``MultiQueryAttention`` when truthy, otherwise
            ``StandardAttention``.
    """

    def __init__(self, dim, num_heads=4, mlp_ratio=4, dropout=0.1, use_mqa=True):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)

        # Pick the attention implementation, then build it with shared args.
        attention_cls = MultiQueryAttention if use_mqa else StandardAttention
        self.attn = attention_cls(dim, num_heads, dropout)

        self.norm2 = nn.LayerNorm(dim)

        hidden_width = int(dim * mlp_ratio)
        feed_forward = [
            nn.Linear(dim, hidden_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Run the attention and MLP sub-layers, each with a residual add."""
        attended = self.attn(self.norm1(x))
        x = x + attended
        transformed = self.mlp(self.norm2(x))
        return x + transformed

# Status message: the PatchEmbedding and TransformerBlock classes are now defined.
print("✓ Building blocks defined!")

In [None]:
class VisionTransformerTracker(nn.Module):
    """Vision transformer for tracking, with optional Multi-Query Attention.

    Pipeline: patch embedding -> learned positional embedding -> dropout ->
    ``depth`` transformer blocks -> layer norm -> mean over tokens -> small
    MLP head that emits two values per sample.

    Args:
        img_size: Image side length, forwarded to ``PatchEmbedding``.
        patch_size: Patch side length, forwarded to ``PatchEmbedding``.
        in_channels: Number of input image channels.
        embed_dim: Token feature dimension.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden-width multiplier per block.
        dropout: Dropout probability for the embeddings and every block.
        use_mqa: Forwarded to every ``TransformerBlock`` to select the
            attention type.
    """

    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_channels=2,
        embed_dim=64,
        depth=4,
        num_heads=4,
        mlp_ratio=4,
        dropout=0.1,
        use_mqa=True
    ):
        super().__init__()

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)

        # Learned positional embedding, one vector per patch, starting at zero.
        pos_shape = (1, self.patch_embed.num_patches, embed_dim)
        self.pos_embed = nn.Parameter(torch.zeros(*pos_shape))
        self.dropout = nn.Dropout(dropout)

        # Stack of identically configured transformer blocks.
        self.blocks = nn.ModuleList()
        for _ in range(depth):
            self.blocks.append(
                TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout, use_mqa=use_mqa)
            )

        self.norm = nn.LayerNorm(embed_dim)

        # Regression head: embed_dim -> embed_dim // 2 -> 2 outputs.
        half_dim = embed_dim // 2
        self.head = nn.Sequential(
            nn.Linear(embed_dim, half_dim),
            nn.GELU(),
            nn.Linear(half_dim, 2),
        )

    def forward(self, x):
        """Map an image batch to a (B, 2) prediction.

        The two outputs are presumably an (x, y) position; confirm against
        the training targets.
        """
        tokens = self.patch_embed(x) + self.pos_embed
        tokens = self.dropout(tokens)

        for layer in self.blocks:
            tokens = layer(tokens)

        # Normalise, then average over the token axis to get one vector
        # per sample before the head.
        pooled = self.norm(tokens).mean(dim=1)
        return self.head(pooled)

# Build the tracker with multi-query attention in every transformer block
# and move it to the configured device.
print("Creating Multi-Query Attention model...")
mq_model = VisionTransformerTracker(
    img_size=32, patch_size=4, embed_dim=64, depth=4, num_heads=4, use_mqa=True
)
mq_model = mq_model.to(device)

# Total number of scalar weights (and biases) in the full model.
mq_params = sum(weight.numel() for weight in mq_model.parameters())
print(f"✔️ Multi-Query model: {mq_params:,} parameters")

# Smoke test: a random batch of four 2-channel 32x32 images.
test_input = torch.randn(4, 2, 32, 32).to(device)
test_output = mq_model(test_input)
print(f"✔️ Output shape: {test_output.shape}")
