In [6]:
import torch
import torch.nn as nn 
from torch.nn import functional as F

In [8]:
# Display the installed PyTorch version so the outputs below can be tied to a specific release.
torch.__version__

'2.4.0'

In [42]:
n_emb = 6       # width of each input token embedding (in_features of every projection)
head_size = 1   # module-level value; Head and its callers pass head_size explicitly
block_size = 8  # side length of the causal mask, i.e. the longest sequence supported


class Head(nn.Module):
    '''One head of causal (masked) self-attention.

    Projects the input into key, query and value spaces of width
    ``head_size``, computes scaled dot-product attention in which every
    position may only attend to itself and to earlier positions, and
    returns the attention-weighted sum of the values.

    Args:
        head_size: output width of the key, query and value projections.

    Shapes:
        input:  (batch, T, n_emb); T must not exceed ``block_size`` because
                the mask buffer is only ``block_size x block_size``.
        output: (batch, T, head_size)
    '''

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_emb, head_size)
        self.query = nn.Linear(n_emb, head_size)
        self.value = nn.Linear(n_emb, head_size)

        # Lower-triangular matrix of ones used as the causal mask. Registered
        # as a buffer so it follows the module across devices without being a
        # trainable parameter. The name 'trill' is kept unchanged so that
        # existing state_dict keys stay valid.
        self.register_buffer('trill', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        '''Apply causal self-attention to ``x`` of shape (batch, T, n_emb).'''
        _, seq_len, _ = x.shape
        key = self.key(x)      # (batch, T, head_size)
        query = self.query(x)  # (batch, T, head_size)

        # Scale by 1/sqrt(d_k), where d_k is the key/query width (head_size),
        # not the input embedding width: the dot product sums d_k terms, so
        # this is the factor that keeps the variance of the scores near 1.
        scale = key.shape[-1] ** -0.5
        weight = query @ key.transpose(-2, -1) * scale  # (batch, T, T)

        # Future positions (above the diagonal) get -inf so that softmax
        # assigns them exactly zero probability.
        weight = weight.masked_fill(self.trill[:seq_len, :seq_len] == 0, float('-inf'))
        weight = F.softmax(weight, dim=-1)
        return weight @ self.value(x)  # (batch, T, head_size)


        
    

In [44]:
# Build a single attention head whose key/query/value projections are 2 wide,
# then echo it so the notebook shows the module structure.
h = Head(head_size=2)
h

Head(
  (key): Linear(in_features=6, out_features=2, bias=True)
  (query): Linear(in_features=6, out_features=2, bias=True)
  (value): Linear(in_features=6, out_features=2, bias=True)
)

In [46]:
# Smoke test on an all-zero batch of shape (3, 8, 6). With zero input every
# linear layer returns only its bias, so all positions share the same value
# vector and the attention-weighted average reproduces it at every position.
h(torch.zeros((3, 8, 6)))

tensor([[[-0.1966, -0.1795],
         [-0.1966, -0.1795],
         [-0.1966, -0.1795],
         [-0.1966, -0.1795],
         [-0.1966, -0.1795],
         [-0.1966, -0.1795],
         [-0.1966, -0.1795],
         [-0.1966, -0.1795]],

        [[-0.1966, -0.1795],
         [-0.1966, -0.1795],
         [-0.1966, -0.1795],
         [-0.1966, -0.1795],
         [-0.1966, -0.1795],
         [-0.1966, -0.1795],
         [-0.1966, -0.1795],
         [-0.1966, -0.1795]],

        [[-0.1966, -0.1795],
         [-0.1966, -0.1795],
         [-0.1966, -0.1795],
         [-0.1966, -0.1795],
         [-0.1966, -0.1795],
         [-0.1966, -0.1795],
         [-0.1966, -0.1795],
         [-0.1966, -0.1795]]], grad_fn=<UnsafeViewBackward0>)

In [38]:
class MultiHeadAttention(nn.Module):
    '''Several self-attention heads run in parallel, followed by a projection.

    Each head maps the input to ``head_size`` features. The per-head outputs
    are concatenated along the last dimension (``num_heads * head_size``
    features) and projected back to ``n_emb`` features by a linear layer.

    Args:
        head_size: output width of every individual head.
        num_heads: number of heads to run in parallel.
    '''

    def __init__(self, head_size, num_heads):
        # nn.Module must be initialised before any sub-module is assigned.
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The concatenated head outputs are num_heads * head_size wide, which
        # is what the projection has to accept; it maps back to n_emb.
        self.layer = nn.Linear(head_size * num_heads, n_emb)

    def forward(self, x):
        '''Run every head on ``x``, concatenate on the last dim, and project.'''
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.layer(out)


    
   
    