Implementing a compact causal attention class

In [2]:
import torch
# Toy 3-dimensional embeddings, one row per token of the example sentence
# "your journey starts with one step".
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # x^1 -> "your"
        [0.55, 0.87, 0.66],  # x^2 -> "journey"
        [0.57, 0.85, 0.64],  # x^3 -> "starts"
        [0.22, 0.58, 0.33],  # x^4 -> "with"
        [0.77, 0.25, 0.10],  # x^5 -> "one"
        [0.05, 0.80, 0.55],  # x^6 -> "step"
    ]
)



In [1]:
import torch
import torch.nn as nn

class CasualAttention(nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention.

    Each position may attend only to itself and to earlier positions;
    scores for later positions are set to ``-inf`` before the softmax so
    they receive zero attention weight.

    Args:
        d_in: Size of the last dimension of the input embeddings.
        d_out: Size of the query/key/value projections and of each
            returned context vector.
        context_length: Longest sequence the module supports; the causal
            mask is built once as a ``context_length x context_length``
            matrix.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # The creation order of the projections is kept (query, key, value)
        # so that seeded weight initialisation stays reproducible.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # Strictly upper-triangular ones mark the "future" positions.
        # Registered as a buffer so it is part of the module state without
        # being a trainable parameter.
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        """Compute causally masked attention context vectors.

        Args:
            x: Tensor of shape ``(..., num_tokens, d_in)``. The usual case
                is ``(batch, num_tokens, d_in)``; an unbatched
                ``(num_tokens, d_in)`` input is accepted as well.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.

        Raises:
            ValueError: If ``x`` has fewer than two dimensions or if
                ``num_tokens`` exceeds the ``context_length`` the mask
                was built for.
        """
        if x.dim() < 2:
            raise ValueError(
                f"expected input of shape (..., num_tokens, d_in), "
                f"got {tuple(x.shape)}"
            )
        num_tokens = x.shape[-2]
        context_length = self.mask.shape[0]
        # Without this check a too-long sequence fails with an opaque
        # broadcasting error, or (for context_length == 1) is silently left
        # unmasked because the 1x1 mask broadcasts.
        if num_tokens > context_length:
            raise ValueError(
                f"sequence length {num_tokens} exceeds "
                f"context_length {context_length}"
            )

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Transposing the last two dims (rather than dims 1 and 2) lets the
        # same code handle any number of leading batch dimensions.
        attn_scores = queries @ keys.transpose(-2, -1)
        causal_mask = self.mask[:num_tokens, :num_tokens].bool()
        attn_scores = attn_scores.masked_fill(causal_mask, -torch.inf)

        # Scale by sqrt(d_k); softmax over the last dim normalises each
        # query's weights across the keys it is allowed to see.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        return context_vec




In [4]:
# Duplicate the example sequence along a new leading axis to form a
# batch of two identical sequences.
batch = torch.stack([inputs, inputs])
print(batch.shape, batch, sep="\n")

torch.Size([2, 6, 3])
tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])


In [5]:
# Seed the RNG first so the randomly initialised projection weights
# (and therefore the printed results) are reproducible.
torch.manual_seed(123)

# batch is laid out as (batch_size, num_tokens, embedding_dim).
context_length, d_in = batch.shape[1:]
d_out = 2

ca = CasualAttention(d_in, d_out, context_length, dropout=0.0)
context_vec = ca(batch)



In [6]:
# Expect (batch_size, num_tokens, d_out).
print(context_vec.size())

torch.Size([2, 6, 2])
