# Multi-headed attnetion implementations

In [1]:
import torch
torch.manual_seed(123)
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f'Using device: {device}')
print(f'PyTorch version: {torch.__version__}')


batch_size= 8
context_len = 1024
embed_dim = 768
embeddings = torch.randn(batch_size,context_len,embed_dim,device=device)

Using device: mps
PyTorch version: 2.9.0


## 1. `CausalAttention` class from Ch 3.

In [2]:
import torch.nn as nn
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
         #define the weights
        self.W_q = nn.Linear(d_in, d_out,bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out,bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out,bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))  # New
    
    #now x has shape (batch, seq, feature)
    def forward(self, x):
        query = self.W_q(x)
        key = self.W_k(x)
        value = self.W_v(x)

        # want (batch,nseq, nseq) <- (batch,nseq,d_out) x (batch,d_out,nseq)
        #need transpose to swap last two rows 
        omega = query @ key.transpose(-2,-1)
        masked = omega.masked_fill(self.mask.bool(),-torch.inf)
        alpha = self.dropout(torch.softmax(masked/key.shape[-1]**0.5,dim=-1))
        
        #in value (batch,nseqj,d_out), alpha (batch,nseqi,nseqj)
        # out (batch,nseq,d_out) 
        # keeping the order value.trasponse(-2,-1) @ alpha.transpose(-2,-1)
        #does the right multiplication, but needs a final transpose(-2,-1) on the
        #output to move d_out to the back.
        #distributing this transpose over the matrix product just gives
        context_vector = alpha @ value
        return context_vector
    

In [3]:
class Ch03_MHA_Wrapper(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
             for _ in range(num_heads)]
        )
        self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)

    def forward(self, x):
        context_vec = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.out_proj(context_vec)
    
mha_ch03_wrapper = Ch03_MHA_Wrapper(
    d_in=embed_dim,
    d_out=embed_dim//12,
    context_length=context_len,
    dropout=0.0,
    num_heads=12,
    qkv_bias=False
).to(device)

out = mha_ch03_wrapper(embeddings)
print(out.shape)

torch.Size([8, 1024, 768])


## 2. `MultiHeadAttention` class from Ch 3 as `Ch03_MHA`.

In [4]:
class Ch03_MHA(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, \
            'd_out must be divisible by num_heads'

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_q = nn.Linear(d_in, d_out,bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out,bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out,bias=qkv_bias)
        self.register_buffer('mask', torch.triu(torch.ones(context_length,context_length),diagonal=1))
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_out, d_out)


    def forward(self,x):
        b, nseq, d_in = x.shape
        query = self.W_q(x)
        key = self.W_k(x)
        value = self.W_v(x)

        #now split along n_heads
        query = query.view(b,nseq,self.num_heads,self.head_dim)
        key = key.view(b,nseq,self.num_heads,self.head_dim)
        value = value.view(b,nseq,self.num_heads,self.head_dim)

        #move n_head dimension in front of n_seq dim
        query = query.transpose(1,2)
        key = key.transpose(1,2)
        value = value.transpose(1,2)

        #next is dot product attention scores, matmul over the head_dim dimension
        # (b,n_head,n_seq,d_head) x (b,n_head,d_head,n_seq)
        # output is (b,n_head,n_seq,n_seq')
        omega = query @ key.transpose(-2,-1)

        mask = self.mask.bool()[:nseq, :nseq]
        omega = omega.masked_fill(mask, -torch.inf)
        alpha = torch.softmax(omega/self.head_dim**0.5,dim = -1)
        alpha = self.dropout(alpha)
        # shapes are 
        # alpha ~ (b,n_head,n_seq,n_seq')
        # value ~ (b,n_head,n_seq,d_head) 
        #context_vec ~ (b,n_head,d_seq,d_head)
        context_vec = alpha @ value
        #now put n_head next to d_head to roll back up
        context_vec = context_vec.transpose(1,2)
        context_vec = context_vec.contiguous().view(b,nseq,self.d_out)

        #now it is (b,n_seq,d_out), same shape as x except d_in -> d_out
        context_vec = self.out_proj(context_vec)
        return context_vec

In [5]:
mha_ch03 = Ch03_MHA(
    d_in=embed_dim,
    d_out=embed_dim,
    context_length=context_len,
    dropout=0.0,
    num_heads=12,
    qkv_bias=False
).to(device)

out = mha_ch03(embeddings)
print(out.shape)

torch.Size([8, 1024, 768])


## 3) `MultiHeadAttentionCombined` : an alternative multi-head attention with combined weights

In [6]:
class MultiHeadAttentionCombinedQKV(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, \
            'd_out must be divisible by num_heads'

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.register_buffer('mask', torch.triu(torch.ones(context_length,context_length),diagonal=1))
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_out, d_out)

    #(b, n_seq, 3 x d)
    def forward(self,x):
        b, nseq, emb_dim = x.shape
        qkv = self.W_qkv(x)
        #now split along qkv and n_heads
        #(b,nseq, 3, m_heads,d_out/n_heads )
        qkv = qkv.view(b,nseq,3,self.num_heads,self.head_dim)
        #put 3 at front and swap n_seq <-> n_heads, (3, b,n_heads,n_seq)
        qkv = qkv.permute(2, 0 ,3 ,1, 4)

        query, key, value = qkv.unbind(0)

        #next is dot product attention scores, matmul over the head_dim dimension
        # (b,n_head,n_seq,d_head) x (b,n_head,d_head,n_seq)
        # output is (b,n_head,n_seq,n_seq')
        omega = query @ key.transpose(-2,-1)

        mask = self.mask.bool()[:nseq, :nseq]
        omega = omega.masked_fill(mask, -torch.inf)
        alpha = torch.softmax(omega/self.head_dim**0.5,dim = -1)
        alpha = self.dropout(alpha)
        # shapes are 
        # alpha ~ (b,n_head,n_seq,n_seq')
        # value ~ (b,n_head,n_seq,d_head) 
        #context_vec ~ (b,n_head,d_seq,d_head)
        context_vec = alpha @ value
        #now put n_head next to d_head to roll back up
        context_vec = context_vec.transpose(1,2)
        context_vec = context_vec.contiguous().view(b,nseq,self.d_out)

        #now it is (b,n_seq,d_out), same shape as x except d_in -> d_out
        context_vec = self.out_proj(context_vec)
        return context_vec

mha_combined_qkv = MultiHeadAttentionCombinedQKV(
    d_in=embed_dim,
    d_out=embed_dim,
    context_length=context_len,
    dropout=0.0,
    num_heads=12,
    qkv_bias=False
).to(device)

out = mha_combined_qkv(embeddings)
print(out.shape)

torch.Size([8, 1024, 768])


## 4) Einsum!!

In [7]:
import math

class MHAEinsum(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, \
            'd_out must be divisible by num_heads'

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_q = nn.Parameter(torch.randn(d_in, d_out))
        self.W_k = nn.Parameter(torch.randn(d_in, d_out))
        self.W_v = nn.Parameter(torch.randn(d_in, d_out))

        if qkv_bias:
            self.bias_q = nn.Parameter(torch.zeros(d_out))
            self.bias_k = nn.Parameter(torch.zeros(d_out))
            self.bias_v = nn.Parameter(torch.zeros(d_out))

        else:
            self.register_parameter('bias_q', None)
            self.register_parameter('bias_k', None)
            self.register_parameter('bias_v', None)


        self.register_buffer('mask', torch.triu(torch.ones(context_length,context_length),diagonal=1))
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_out, d_out)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.W_q, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_k, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_v, a=math.sqrt(5))
        if self.bias_q is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_q)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias_q, -bound, bound)
            nn.init.uniform_(self.bias_k, -bound, bound)
            nn.init.uniform_(self.bias_v, -bound, bound)


    def forward(self,x):
        b, nseq, _ = x.shape

        Q = torch.einsum('bnd,di->bni',x,self.W_q)
        K = torch.einsum('bnd,di->bni',x,self.W_k)
        V = torch.einsum('bnd,di->bni',x,self.W_v)

        if self.bias_q is not None:
            Q += self.bias_q
            K += self.bias_k
            V += self.bias_v

        Q = Q.view(b,nseq,self.num_heads,self.head_dim)
        K = K.view(b,nseq,self.num_heads,self.head_dim)
        V = V.view(b,nseq,self.num_heads,self.head_dim)
        omega = torch.einsum('bnhd,bmhd->bhnm',Q,K)
        mask = self.mask.bool()[:nseq, :nseq]
        omega = omega.masked_fill(mask, -torch.inf)
        alpha = torch.softmax(omega/self.head_dim**0.5,dim = 1)
        alpha = self.dropout(alpha)

        context_vec = torch.einsum('bmhd,bhnm->bnhd',V,alpha)
        context_vec = context_vec.contiguous().view(b,nseq,self.d_out)

        #now it is (b,n_seq,d_out), same shape as x except d_in -> d_out
        context_vec = self.out_proj(context_vec)
        return context_vec
    
mha_einsum = MHAEinsum(
    d_in=embed_dim,
    d_out=embed_dim,
    context_length=context_len,
    dropout=0.0,
    num_heads=12,
    qkv_bias=False
).to(device)

out = mha_einsum(embeddings)
print(out.shape)

torch.Size([8, 1024, 768])


## 5) PyTorch's scaled dot product attention with FlashAttention

In [8]:
class MHAPyTorchScaledDotProduct(nn.Module):
    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):
        super().__init__()

        assert d_out % num_heads == 0, "d_out is indivisible by num_heads"

        self.num_heads = num_heads
        self.context_length = context_length
        self.head_dim = d_out // num_heads
        self.d_out = d_out

        self.W_qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.proj = nn.Linear(d_out, d_out)
        self.dropout = dropout

    def forward(self, x):
        batch_size, num_tokens, embed_dim = x.shape

        # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)
        qkv = self.W_qkv(x)
        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        query, key, value = qkv.unbind(0)
        use_dropout = 0. if not self.training else self.dropout

        context_vec = nn.functional.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=use_dropout, is_causal=True)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.transpose(1,2).contiguous().view(batch_size, num_tokens, self.d_out)

        context_vec = self.proj(context_vec)

        return context_vec
    
mha_pytorch_scaled = MHAPyTorchScaledDotProduct(
    d_in=embed_dim,
    d_out=embed_dim,
    context_length=context_len,
    dropout=0.0,
    num_heads=12,
    qkv_bias=False
).to(device)

out = mha_pytorch_scaled(embeddings)
print(out.shape)

torch.Size([8, 1024, 768])


## 6) PyTorch's scaled dot product attention w/o FlashAttention

In [9]:
class MHAPyTorchSDPAWithoutFlash(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, \
            'd_out must be divisible by num_heads'

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.context_length = context_length

        self.W_qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.register_buffer('mask', torch.triu(torch.ones(context_length,context_length),diagonal=1))
        self.dropout = dropout
        self.out_proj = nn.Linear(d_out, d_out)

    #(b, n_seq, 3 x d)
    def forward(self,x):
        b, num_tokens, emb_dim = x.shape
        qkv = self.W_qkv(x)
        qkv = qkv.view(b,num_tokens,3,self.num_heads,self.head_dim)
        qkv = qkv.permute(2, 0 ,3 ,1, 4)
        query, key, value = qkv.unbind(0)


        use_dropout = 0. if not self.training else self.dropout

        if self.context_length >= num_tokens:
            attn_mask = self.mask[:num_tokens, :num_tokens]
        else:
            attn_mask = self.mask[:self.context_length, :self.context_length]

        context_vec = nn.functional.scaled_dot_product_attention(query, key, value,attn_mask=attn_mask, dropout_p=use_dropout, is_causal=False)
        context_vec = context_vec.transpose(1,2)
        context_vec = context_vec.contiguous().view(b,num_tokens,self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec
mha_pytorch_sdpa_no_flash = MHAPyTorchSDPAWithoutFlash(
    d_in=embed_dim,
    d_out=embed_dim,
    context_length=context_len,
    dropout=0.0,
    num_heads=12,
    qkv_bias=False
).to(device)

out = mha_pytorch_sdpa_no_flash(embeddings)
print(out.shape)

torch.Size([8, 1024, 768])


## 7) PyTorch's torch.nn.MultiheadedAttention

In [10]:
import torch.nn as nn

class MHAPyTorchClass(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False, need_weights=True):
        super().__init__()
        self.context_length = context_length
        self.multiheaded_attn = nn.MultiheadAttention(
            embed_dim=d_out,
            num_heads=num_heads,
            dropout=dropout,
            bias=qkv_bias,
            add_bias_kv=qkv_bias,
            batch_first=True
        )

        self.need_weights = need_weights
        self.proj = nn.Linear(d_out,d_out)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1).bool())


    def forward(self,x):
        batch_size, num_tokens, _ = x.shape

        if self.context_length >= num_tokens:
            attn_mask = self.mask[:num_tokens,:num_tokens]
        else:
            attn_mask = self.mask[:self.context_length, :self.context_length]

        attn_output, _ = self.multiheaded_attn(x, x, x, attn_mask=attn_mask, need_weights=self.need_weights)
        output = self.proj(attn_output)
        return output


mha_pytorch_class_default = MHAPyTorchClass(
    d_in=embed_dim,
    d_out=embed_dim,
    context_length=context_len,
    dropout=0.0,
    num_heads=12,
    qkv_bias=False
).to(device)

out = mha_pytorch_class_default(embeddings)
print(out.shape)

torch.Size([8, 1024, 768])


## 8) PyTorch's MultiheadAttention with scaled_dot_product_attention

In [11]:
mha_pytorch_class_noweights = MHAPyTorchClass(
    d_in=embed_dim,
    d_out=embed_dim,
    context_length=context_len,
    dropout=0.0,
    num_heads=12,
    qkv_bias=False,
    need_weights=False # NEW!
).to(device)

out = mha_pytorch_class_noweights(embeddings)
print(out.shape)

torch.Size([8, 1024, 768])


## 9) PyTorch's FlexAttention

In [12]:
from packaging.version import parse as parse_version

def normalize_version(version):
    parsed_version = parse_version(version)
    return parse_version(f"{parsed_version.major}.{parsed_version.minor}.{parsed_version.micro}")

current_version = normalize_version(torch.__version__)
MIN_TORCH_VERSION = "2.5.0"
required_version = parse_version(MIN_TORCH_VERSION)

In [13]:
if current_version >= required_version and torch.cuda.is_available():
    from torch.nn.attention.flex_attention import flex_attention, create_block_mask


def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx



In [14]:
class MHAPyTorchFlexAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, \
            'd_out must be divisible by num_heads'
        
        self.context_length = context_length
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=context_length, KV_LEN=context_length)
        self.dropout = dropout
        self.proj = nn.Linear(d_out, d_out)

    #(b, n_seq, 3 x d)
    def forward(self,x):
        b, nseq, emb_dim = x.shape
        qkv = self.W_qkv(x)
        qkv = qkv.view(b,nseq,3,self.num_heads,self.head_dim)
        qkv = qkv.permute(2, 0 ,3 ,1, 4)
        query, key, value = qkv.unbind(0)

        if self.context_length >= num_tokens:
            attn_mask = self.block_mask[:num_tokens,:num_tokens]
        else:
            attn_mask = self.block_mask[:self.context_length, :self.context_length]

        context_vec = flex_attention(queries, keys, values, block_mask=attn_mask)
        context_vec = context_vec.transpose(1,2)
        context_vec = context_vec.contiguous().view(b,num_tokens,self.d_out)
        context_vec = self.proj(context_vec)
        
        return context_vec

In [15]:
if current_version >= required_version and torch.cuda.is_available():

    mha_pytorch_flex = MHAPyTorchFlexAttention(
        d_in=embed_dim,
        d_out=embed_dim,
        context_length=context_len,
        dropout=0.0,
        num_heads=12,
        qkv_bias=False
    ).to(device)

    out = mha_pytorch_flex(embeddings)
    print(out.shape)

## Speed comparison

In [16]:
torch.manual_seed(123)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"PyTorch version: {torch.__version__}")
print(f"Running on {device}")

PyTorch version: 2.9.0
Running on cpu


In [17]:
## 1) CausalAttention MHA wrapper class from chapter 3
%timeit mha_ch03_wrapper(embeddings)

62.7 ms ± 392 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [17]:
## 2) The multi-head attention class from chapter 3
%timeit mha_ch03(embeddings)

63.9 ms ± 650 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [18]:
## 3) An alternative multi-head attention with combined weights
%timeit mha_combined_qkv(embeddings)

64.1 ms ± 181 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [19]:
## 4) Multi-head attention using Einstein summation
%timeit mha_einsum(embeddings)

73.2 ms ± 5.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [18]:
## 5) Multi-head attention with PyTorch's scaled dot product attention
%timeit mha_pytorch_scaled(embeddings)

85 ms ± 276 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [None]:
## 6) PyTorch's scaled dot product attention without FlashAttention
%timeit mha_pytorch_sdpa_no_flash(embeddings)

99.5 ms ± 790 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
## 7) Using PyTorch's torch.nn.MultiheadAttention
%timeit mha_pytorch_class_default(embeddings)

198 ms ± 3.52 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
## 8) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`
%timeit mha_pytorch_class_noweights(embeddings)

168 ms ± 2.63 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [None]:
## 9) Using PyTorch's FlexAttention

# Requires PyTorch 2.5.0 or newer and currently only supports CUDA PyTorch
%timeit mha_pytorch_flex(embeddings)