In [2]:

import torch.nn as nn
import torch

In [14]:


class CausalAttention(nn.Module):
    """Single-head scaled dot-product self-attention with a causal mask.

    Each position may attend only to itself and to earlier positions: scores
    for later positions are set to ``-inf`` before the softmax, so their
    attention weights become 0.

    Args:
        d_in: size of the last dimension of the input tensor.
        d_out: size of the query/key/value projections (and of the output).
        context_length: maximum number of tokens supported; fixes the size of
            the precomputed mask buffer.
        dropout_prob: dropout probability applied to the attention weights
            (only active while the module is in training mode).
        qkv_bias: whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout_prob, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Strictly upper-triangular ones: entry (i, j) is 1 when j > i, i.e. a
        # "future" token that position i must not see. Registered as a buffer so
        # it moves with the module on .to(device) and is saved in state_dict.
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        """Compute causal context vectors.

        Args:
            x: tensor whose last two dimensions are (num_tokens, d_in), e.g.
               batched (batch, num_tokens, d_in) or unbatched (num_tokens, d_in).

        Returns:
            Tensor of shape (..., num_tokens, d_out).

        Raises:
            ValueError: if num_tokens exceeds the context_length the mask was
                built for.
        """
        num_tokens = x.shape[-2]
        context_length = self.mask.shape[0]
        if num_tokens > context_length:
            # Without this check the sliced mask would be too small and
            # masked_fill would fail with an opaque broadcasting error.
            raise ValueError(
                f"input has {num_tokens} tokens but context_length is {context_length}"
            )

        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Swap the last two axes so batched and unbatched inputs both work.
        attn_scores = queries @ keys.transpose(-2, -1)
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)

        # Scale by sqrt(d_k) before the softmax to keep the logits well-conditioned.
        attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim=-1)
        attn_weights = self.dropout(attn_weights)
        return attn_weights @ values





In [15]:

# Toy input: 6 tokens, each represented by a 3-dimensional embedding.
x = torch.tensor([[0.43, 0.15, 0.89],
                 [0.55, 0.87, 0.66],
                 [0.57, 0.85, 0.66],
                 [0.22, 0.58, 0.66],
                 [0.77, 0.25, 0.10],
                 [0.05, 0.80, 0.55]])

# Stack two identical copies to form a batch of shape (2, 6, 3).
inputs = torch.stack((x, x), dim=0)

print("input : \n", inputs)
print("input shape : ", inputs.shape)


d_in = inputs.shape[-1]
d_out = 2

context_length = inputs.shape[1]
print("context length : ", context_length)

# Weight init and dropout are random; seed so Restart & Run All reproduces the output.
torch.manual_seed(123)
ca = CausalAttention(d_in=d_in, d_out=d_out, context_length=context_length, dropout_prob=0.2)
print("causal attention for input pairs : \n", ca)
# NOTE: a fresh nn.Module is in training mode, so dropout is active here and can
# zero attention weights (possibly a whole output row). Call ca.eval() to disable it.
print("context vector for input pairs : \n", ca(inputs))

input : 
 tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6600],
         [0.2200, 0.5800, 0.6600],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6600],
         [0.2200, 0.5800, 0.6600],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])
input shape :  torch.Size([2, 6, 3])
context length :  6
causal attention for input pairs : 
 CausalAttention(
  (W_query): Linear(in_features=3, out_features=2, bias=False)
  (W_key): Linear(in_features=3, out_features=2, bias=False)
  (W_value): Linear(in_features=3, out_features=2, bias=False)
  (dropout): Dropout(p=0.2, inplace=False)
)
context vector for input pairs : 
 tensor([[[ 3.3758e-01, -6.0836e-01],
         [ 0.0000e+00,  0.0000e+00],
         [ 1.1897e-01, -5.3597e-01],
         [ 1.0012e-02, -4.9149e-01],
         [-1.1757e-05, -6.6738e-01],
   