In [2]:
import torch
import torch.nn as nn
# Six toy 3-dimensional token embeddings for "Your journey starts with one step",
# one row per token (x^1 ... x^6).
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your     (x^1)
        [0.55, 0.87, 0.66],  # journey  (x^2)
        [0.57, 0.85, 0.64],  # starts   (x^3)
        [0.22, 0.58, 0.33],  # with     (x^4)
        [0.77, 0.25, 0.10],  # one      (x^5)
        [0.05, 0.80, 0.55],  # step     (x^6)
    ]
)

# Stack the same sequence twice along a new leading axis (torch.stack's default
# dim=0), giving a batch of two identical examples: (batch, tokens, features).
batchs = torch.stack([inputs, inputs])
batchs.shape

torch.Size([2, 6, 3])

In [5]:
class casualattention(nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention.

    Each token may attend only to itself and to earlier tokens: scores for
    later positions are set to -inf before the softmax, so they get zero weight.

    Args:
        din: feature size of each input token.
        dout: size of the query/key/value projections (and of the output).
        biasbool: whether the query/key/value ``nn.Linear`` layers use a bias.
        batchsize: side length of the precomputed causal mask, i.e. the maximum
            number of tokens a sequence may have. Despite its name this acts as
            a context length, not a batch size.
        dropoutsize: dropout probability applied to the attention weights.
    """

    def __init__(self, din, dout, biasbool, batchsize, dropoutsize):
        super().__init__()
        # Creation order is kept as-is so seeded weight initialisation is unchanged.
        self.W_query = nn.Linear(din, dout, bias=biasbool)
        self.W_keys = nn.Linear(din, dout, bias=biasbool)
        self.W_values = nn.Linear(din, dout, bias=biasbool)
        self.dropout = nn.Dropout(dropoutsize)
        # Ones strictly above the diagonal mark the "future" positions to hide.
        # A buffer moves with the module (.to(device)) and is stored in the
        # state_dict, but is not a trainable parameter.
        self.register_buffer('mask', torch.triu(torch.ones(batchsize, batchsize), diagonal=1))

    def forward(self, x):
        """Apply causal self-attention to ``x``.

        Args:
            x: tensor of shape ``(..., num_tokens, din)``; any number of leading
                batch dimensions (including none) is accepted.

        Returns:
            Context vectors of shape ``(..., num_tokens, dout)``.

        Raises:
            ValueError: if ``num_tokens`` exceeds the mask size given at
                construction (``batchsize``).
        """
        num_tokens = x.shape[-2]
        max_tokens = self.mask.shape[-1]
        if num_tokens > max_tokens:
            raise ValueError(
                f"sequence has {num_tokens} tokens but the causal mask only "
                f"supports up to {max_tokens}"
            )

        query = self.W_query(x)
        keys = self.W_keys(x)
        values = self.W_values(x)

        # (..., T, dout) @ (..., dout, T) -> (..., T, T) raw attention scores.
        scores = query @ keys.transpose(-2, -1)
        # The top-left T x T corner of the mask handles sequences shorter than the max.
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        scores = scores.masked_fill(causal_mask, -torch.inf)
        # Scale by sqrt(d_k) before the softmax; -inf entries become weight 0.
        attn_weights = torch.softmax(scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        return attn_weights @ values
    
class Multiheadattention(nn.Module):
    """Multi-head causal attention made of independent ``casualattention`` heads.

    Every head sees the same input; their outputs are concatenated along the
    last dimension and passed through a final linear projection.

    Args:
        din: feature size of each input token.
        dout: output size of each individual head.
        biasbool: whether each head's query/key/value projections use a bias.
        batchsize: forwarded to every head as the size of its causal mask
            (effectively the maximum number of tokens).
        dropoutsize: dropout probability each head applies to its attention weights.
        noofattentionhead: number of heads.

    The output's last dimension is ``dout * noofattentionhead``.
    """

    def __init__(self, din, dout, biasbool, batchsize, dropoutsize, noofattentionhead):
        super().__init__()
        head_list = []
        for _ in range(noofattentionhead):
            head_list.append(casualattention(din, dout, biasbool, batchsize, dropoutsize))
        self.heads = nn.ModuleList(head_list)
        # Built after the heads so the seeded parameter-init order is unchanged.
        combined_width = dout * noofattentionhead
        self.out_proj = nn.Linear(combined_width, combined_width)

    def forward(self, x):
        """Run every head on ``x``, concatenate the results, and project them.

        ``x`` is passed unchanged to each head (in this notebook it is
        ``(batch, num_tokens, din)``).
        """
        per_head_outputs = [attention_head(x) for attention_head in self.heads]
        merged = torch.cat(per_head_outputs, dim=-1)
        return self.out_proj(merged)


# Seed before building the module so the weight initialisation is reproducible.
torch.manual_seed(123)

# Number of tokens per sequence; sizes each head's causal mask.
context_length = batchs.shape[1]
ca = Multiheadattention(
    din=3,
    dout=2,
    biasbool=False,
    batchsize=context_length,
    dropoutsize=0.0,
    noofattentionhead=8,
)

context_vecs = ca(batchs)

print(context_vecs)
print(f"context_vecs.shape: {context_vecs.shape}")

        



tensor([[[ 0.1232, -0.1066, -0.0376, -0.1082, -0.3488,  0.0211,  0.3412,
           0.1926, -0.3484,  0.5276,  0.3979,  0.1163,  0.3184,  0.0403,
           0.0790, -0.4035],
         [ 0.0984, -0.0307, -0.0094, -0.1888, -0.4576,  0.0386,  0.2738,
           0.3218, -0.4608,  0.6029,  0.3603,  0.2969,  0.4167,  0.1128,
           0.1826, -0.4163],
         [ 0.0886, -0.0022, -0.0031, -0.2085, -0.4895,  0.0468,  0.2533,
           0.3586, -0.4961,  0.6282,  0.3520,  0.3553,  0.4418,  0.1322,
           0.2140, -0.4146],
         [ 0.1011, -0.0076, -0.0316, -0.1693, -0.4635,  0.0505,  0.2437,
           0.3211, -0.4247,  0.5862,  0.3282,  0.3417,  0.3826,  0.1484,
           0.1847, -0.3842],
         [ 0.0954, -0.0044, -0.0319, -0.1285, -0.4194,  0.0414,  0.2410,
           0.2809, -0.3754,  0.5705,  0.3284,  0.3530,  0.3389,  0.1275,
           0.1464, -0.3400],
         [ 0.1050, -0.0089, -0.0432, -0.1365, -0.4345,  0.0466,  0.2355,
           0.2888, -0.3688,  0.5581,  0.3132,  0.343