In [1]:
import torch
import torch.nn as nn

## Simple implementation of Attention

In [6]:
class CausalSelfAttention(nn.Module):
    def __init__(self,d_in, d_out, context_length, dropout, qkb_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias= qkb_bias)
        self.W_key = nn.Linear(d_in,d_out,bias=qkb_bias)
        self.W_value = nn.Linear(d_in,d_out,bias=qkb_bias)

        self.dropout = nn.Dropout(dropout)

        self.register_buffer('mask', torch.triu(torch.ones(context_length,context_length),diagonal=1))

    def forward(self,x):
        b, n_tokens, d_in = x.shape

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1,2)

        attn_scores.masked_fill_(
            self.mask.bool()[:n_tokens, :n_tokens],
            -torch.inf
        )

        attn_weights = torch.softmax(
            attn_scores/self.d_out**0.5, dim=-1
        )

        context_vec = attn_weights = attn_weights @ values
        return context_vec
    


In [7]:
class MultiHeadAttentionWrapper(nn.Module):

    def __init__(self,d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()

        self.heads = nn.ModuleList(
            [CausalSelfAttention(d_in, d_out, context_length,dropout, qkv_bias )
            for _ in range(num_heads)]
        )

        self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)


    def forward(self,x):
        context_vec = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.out_proj(context_vec)

In [None]:
d_in = 256
max_length = 1024
num_heads = 2
d_out = d_in //num_heads

mha = MultiHeadAttentionWrapper(d_in, d_out, max_length,0.0,num_heads )

#from data loader
batch = input_embeddings

context_vecs = mha(batch)

print(context_vecs.shape)

NameError: name 'batch' is not defined

## Alternate Implementation - Parallel and Optimized

In [8]:
class MultiHeadAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias= False):
        super().__init__()

        assert d_out% num_heads == 0, 'd_out must be divisible by num_heads'

        #this d_out is larger than the above implementation, it's actually d_out*num_heads of old implementation
        self.d_out = d_out
        self.num_heads = num_heads

        self.head_dim = d_out//  num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, n_tokens, d_in = x.shape

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        #view them to add extra dimension
        keys = keys.view(b,n_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b,n_tokens, self.num_heads, self.head_dim)
        values = values.view(b,n_tokens, self.num_heads, self.head_dim)

        #transpose to get (b, num_heads, n_tokens, head_dim)

        keys = keys.transpose(1,2)
        queries = queries.transpose(1,2)
        values = values.transpose(1,2)

        #attn_score
        attn_scores = queries @ keys.transpose(2,3)

        #masked attn_score
        attn_scores.masked_fill_(
            self.mask.bool()[:n_tokens, :n_tokens],
            -torch.inf
        )

        #attn_weight
        attn_weights = torch.softmax(
            attn_scores/ keys.shape[-1]**0.5 , dim = -1
        )

        attn_weights = self.dropout(attn_weights)

        context_vecs = attn_weights @ values

        context_vecs = context_vecs.transpose(1,2)

        context_vecs = context_vecs.contiguous().view(b, n_tokens, self.d_out)

        context_vecs = self.out_proj(context_vecs)

        return context_vecs

        

