In [2]:
import torch.nn as nn

In [3]:
class Causal_Attention(nn.Module):
    def __init__(self,d_in,d_out,context_length,dropout,q_bias=False):
        super().__init__()
        self.d_out=d_out
        self.W_query=nn.Linear(d_in,d_out,bias=q_bias)
        self.W_keys=nn.Linear(d_in,d_out,bias=q_bias)
        self.W_values=nn.Linear(d_in,d_out,bias=q_bias)
        self.dropout=nn.Dropout(dropout)
        self.register_buffer('mask',torch.triu(torch.ones(context_length,context_length),diagonal=1))
    #register buffers are used when those things are not to be included in training process :here we have to keep uppertriangular matrix fixed and need not be trained
    def forward(self,x):
        b,num_tokens,d_in=x.shape    #b=batches , num_tokens, d_in=dimension
        keys=self.W_keys(x)
        queries=self.W_query(x)
        values=self.W_values(x)
        
        attn_scores=queries @ keys.transpose(1,2)
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens,:num_tokens],-torch.inf)  #masked_fill_=in_place change
        attn_weights=torch.softmax(attn_scores/keys.shape[-1]**0.5,dim=-1)    #[:num_tokens] is used for cases where no of tokens are less than supported context size]
        attn_weights=self.dropout(attn_weights)
        context_vectors= attn_weights @ values
        return context_vectors

In [4]:
class Multihead_Attention(nn.Module):
    def __init__(self,d_in,d_out,context_length,num_heads,dropout,q_bias=False):
        super().__init__()
        self.heads=nn.ModuleList([Causal_Attention(d_in,d_out,context_length,dropout,q_bias)for _ in range(num_heads)])
        
    def forward(self,x):
        return torch.cat([head(x) for head in self.heads],dim=-1)
    

In [5]:
inputs=torch.tensor([[0.43,0.15,0.89], #your
                             [0.55,0.89,0.66], #journey
                             [0.57,0.85,0.64],  #starts
                             [0.22,0.58,0.33],  #with
                             [0.77,0.25,0.10],   #one
                             [0.05,0.80,0.55]])  #step
torch.manual_seed(123)
batch=torch.stack((inputs,inputs),0)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [6]:
torch.manual_seed(123)
context_length=batch.shape[1]
d_in=3
d_out=2
multi=Multihead_Attention(d_in,d_out,context_length,num_heads=2,dropout=0)
context_vectors=multi(batch)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [7]:
context_vectors.size()

torch.Size([2, 6, 4])

In [8]:
context_vectors

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5913,  0.0009,  0.5925,  0.3316],
         [-0.6325, -0.0663,  0.6223,  0.3896],
         [-0.5693, -0.0865,  0.5493,  0.3615],
         [-0.5540, -0.0999,  0.5332,  0.3448],
         [-0.5311, -0.1096,  0.5087,  0.3511]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5913,  0.0009,  0.5925,  0.3316],
         [-0.6325, -0.0663,  0.6223,  0.3896],
         [-0.5693, -0.0865,  0.5493,  0.3615],
         [-0.5540, -0.0999,  0.5332,  0.3448],
         [-0.5311, -0.1096,  0.5087,  0.3511]]], grad_fn=<CatBackward0>)

# with weight splits

In [49]:
class Multihead_Attention_V2(nn.Module):
    def __init__(self,d_in,d_out,context_length,num_heads,dropout,q_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"
        self.d_out=d_out
        self.num_heads=num_heads
        self.W_query=nn.Linear(d_in,d_out,bias=q_bias)
        self.W_keys=nn.Linear(d_in,d_out,bias=q_bias)
        self.W_values=nn.Linear(d_in,d_out,bias=q_bias)
        self.head_dim=d_out//num_heads
        self.out_proj=nn.Linear(d_out,d_out)
        self.dropout=nn.Dropout(dropout)
        self.register_buffer('mask',torch.triu(torch.ones(context_length,context_length),diagonal=1))
        
    
    def forward(self,x):
        b,num_tokens,d_in=x.shape
        keys=self.W_keys(x)
        queries=self.W_query(x)
        values=self.W_values(x)
        #now change dimensions
        
        keys=keys.view(b,num_tokens,self.num_heads,self.head_dim)
        queries=queries.view(b,num_tokens,self.num_heads,self.head_dim)
        values=values.view(b,num_tokens,self.num_heads,self.head_dim)
        #group as num_heads
        keys=keys.transpose(1,2)
        queries=queries.transpose(1,2)
        values=values.transpose(1,2)
        
        attn_scores=queries @ keys.transpose(2,3)
        
        mask_bool=self.mask.bool()[:num_tokens, :num_tokens]  #if num_tokens is less than specified context length
        
        attn_scores.masked_fill_(mask_bool,-torch.inf)
        
        attn_weights=torch.softmax(attn_scores/keys.shape[-1]**0.5,dim=-1)
        attn_weights=self.dropout(attn_weights)
        
        context_vectors=(attn_weights @ values).transpose(1,2)
        context_vectors=context_vectors.contiguous().view(b,num_tokens,self.d_out)
        context_vectors=self.out_proj(context_vectors)
        
        return context_vectors
        
        
        

# creating an instance of the class

In [50]:
torch.manual_seed(123)
inputs=torch.tensor([[0.43,0.18,0.89,0.55,0.87,0.66],
                    [0.57,0.85,0.64,0.22,0.58,0.33],
                    [0.77,0.25,0.10,0.05,0.80,0.55]])
batch=torch.stack((inputs,inputs),dim=0)
batch_size,context_length,d_in=batch.shape
d_out=6
mha_=Multihead_Attention_V2(d_in,d_out,context_length,num_heads=2,dropout=0.2)
context_vectors=mha_(batch)
context_vectors

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

tensor([[[ 0.0719, -0.2033, -0.1114,  0.0696, -0.4202, -0.3544],
         [ 0.0661, -0.0972,  0.0067, -0.0077, -0.3788, -0.3595],
         [ 0.1325, -0.0153,  0.1036, -0.0873, -0.2743, -0.2474]],

        [[ 0.1765, -0.0618,  0.0931,  0.0191, -0.3341, -0.2435],
         [ 0.0661, -0.0972,  0.0067, -0.0077, -0.3788, -0.3595],
         [ 0.1350, -0.0790, -0.0160, -0.0484, -0.2638, -0.2286]]],
       grad_fn=<ViewBackward0>)

In [36]:
batch.shape

torch.Size([2, 3, 6])

In [37]:
batch

tensor([[[0.4300, 0.1800, 0.8900, 0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400, 0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000, 0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1800, 0.8900, 0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400, 0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000, 0.0500, 0.8000, 0.5500]]])