In [1]:
import torch
import torch.nn as nn

In [15]:
class MultiheadCausalAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a causal mask.

    The input is projected to queries, keys and values of width ``d_out``,
    split into ``n_heads`` heads of ``d_out // n_heads`` features each, and
    every token attends only to itself and to earlier tokens.

    Args:
        d_in: Size of the last dimension of the input tensor.
        d_out: Total width of the projections (summed over all heads) and of
            the output; must be divisible by ``n_heads``.
        context_length: Side length of the causal mask, i.e. the largest
            number of tokens ``forward`` accepts.
        dropout: Dropout probability applied to the attention weights.
        n_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value projections carry a bias term.

    Raises:
        AssertionError: If ``d_out`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_in, d_out, context_length, dropout, n_heads, qkv_bias=False):
        super().__init__()

        assert d_out % n_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = n_heads
        self.head_dim = d_out // n_heads
        self.dropout = nn.Dropout(dropout)
        self.query_w = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.key_w = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.value_w = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        # Ones strictly above the diagonal mark the "future" positions.
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

    def forward(self, x):
        """Compute causal multi-head self-attention.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.

        Raises:
            ValueError: If ``num_tokens`` exceeds the ``context_length`` the
                module was built with.
        """
        b, num_tokens, _ = x.shape

        # Without this check a too-long sequence either fails deep inside
        # masked_fill_ with an opaque shape error or, when the mask is 1x1,
        # broadcasts an all-zero mask and silently drops causality.
        max_tokens = self.mask.shape[0]
        if num_tokens > max_tokens:
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {max_tokens}"
            )

        query = self.query_w(x).view(b, num_tokens, self.num_heads, self.head_dim)
        key = self.key_w(x).view(b, num_tokens, self.num_heads, self.head_dim)
        value = self.value_w(x).view(b, num_tokens, self.num_heads, self.head_dim)

        # (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        key = key.transpose(1, 2)
        query = query.transpose(1, 2)
        value = value.transpose(1, 2)

        # Per-head dot products: (b, num_heads, num_tokens, num_tokens)
        attn_scores = query @ key.transpose(2, 3)

        # -inf scores become exactly zero weight after the softmax.
        mask_bool = self.mask[:num_tokens, :num_tokens].bool()
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        attent_weights = torch.softmax(attn_scores / key.shape[-1] ** 0.5, dim=-1)
        attent_weights = self.dropout(attent_weights)

        # Back to (b, num_tokens, num_heads, head_dim), then merge the heads.
        context_vec = (attent_weights @ value).transpose(1, 2)
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)

        return self.out_proj(context_vec)
       
       

In [30]:
# Toy data: three tokens, each described by a 12-value vector.
inputs = torch.tensor([
    [0.43, 0.15, 0.89, 0.55, 0.87, 0.66, 0.57, 0.85, 0.64, 0.22, 0.58, 0.33],  # Row 1
    [0.57, 0.85, 0.64, 0.22, 0.58, 0.33, 0.77, 0.25, 0.10, 0.05, 0.80, 0.55],  # Row 2
    [0.77, 0.25, 0.10, 0.05, 0.80, 0.55, 0.57, 0.85, 0.64, 0.22, 0.58, 0.33],  # Row 3
])

# Stack the same sequence twice along a new leading axis to form a batch of two.
batch = torch.stack([inputs, inputs])
print(batch.shape)

# Derive the model dimensions from the data; the output width mirrors the input width.
batch_size, context_length, din = batch.shape
dout = din

# Dropout disabled; 12 heads over a width of 12.
mul_cls = MultiheadCausalAttention(
    d_in=din,
    d_out=dout,
    context_length=context_length,
    dropout=0.0,
    n_heads=12,
)


torch.Size([2, 3, 12])


In [31]:
# Show the stacked batch as the cell output.
batch

tensor([[[0.4300, 0.1500, 0.8900, 0.5500, 0.8700, 0.6600, 0.5700, 0.8500,
          0.6400, 0.2200, 0.5800, 0.3300],
         [0.5700, 0.8500, 0.6400, 0.2200, 0.5800, 0.3300, 0.7700, 0.2500,
          0.1000, 0.0500, 0.8000, 0.5500],
         [0.7700, 0.2500, 0.1000, 0.0500, 0.8000, 0.5500, 0.5700, 0.8500,
          0.6400, 0.2200, 0.5800, 0.3300]],

        [[0.4300, 0.1500, 0.8900, 0.5500, 0.8700, 0.6600, 0.5700, 0.8500,
          0.6400, 0.2200, 0.5800, 0.3300],
         [0.5700, 0.8500, 0.6400, 0.2200, 0.5800, 0.3300, 0.7700, 0.2500,
          0.1000, 0.0500, 0.8000, 0.5500],
         [0.7700, 0.2500, 0.1000, 0.0500, 0.8000, 0.5500, 0.5700, 0.8500,
          0.6400, 0.2200, 0.5800, 0.3300]]])

In [32]:
# Run the attention module on the batch and show the result as the cell output.
mul_cls(batch)

tensor([[[ 2.7085e-01, -7.0683e-02,  1.9019e-02,  2.1527e-01,  1.4085e-01,
          -2.3308e-01,  5.1715e-01, -1.5791e-01, -4.4830e-02,  3.2062e-01,
          -5.2556e-01,  1.9533e-01],
         [ 2.6795e-01, -8.3251e-02,  3.5241e-05,  2.1007e-01,  1.0036e-01,
          -3.0398e-01,  4.4630e-01, -3.0354e-01, -1.0211e-01,  3.5549e-01,
          -5.4674e-01,  2.8070e-01],
         [ 2.8022e-01, -7.2653e-02, -2.0956e-02,  1.8526e-01,  9.9690e-02,
          -2.8955e-01,  4.2378e-01, -3.2395e-01, -5.1590e-02,  3.7418e-01,
          -4.8803e-01,  2.4294e-01]],

        [[ 2.7085e-01, -7.0683e-02,  1.9019e-02,  2.1527e-01,  1.4085e-01,
          -2.3308e-01,  5.1715e-01, -1.5791e-01, -4.4830e-02,  3.2062e-01,
          -5.2556e-01,  1.9533e-01],
         [ 2.6795e-01, -8.3251e-02,  3.5241e-05,  2.1007e-01,  1.0036e-01,
          -3.0398e-01,  4.4630e-01, -3.0354e-01, -1.0211e-01,  3.5549e-01,
          -5.4674e-01,  2.8070e-01],
         [ 2.8022e-01, -7.2653e-02, -2.0956e-02,  1.8526e-01,  

In [28]:
# Add up the element counts of every parameter that is excluded from
# gradient tracking, printing each such parameter's size along the way.
ok = 0
for p in mul_cls.parameters():
    if p.requires_grad:
        continue  # parameters that track gradients are not counted here
    ok += p.numel()
    print(p.numel())
print('feff', ok)

36
36
36
36
6
feff 150
