Implementing a compact self-attention Python class


In [1]:
import torch.nn as nn
import torch

class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention using raw weight matrices.

    The query, key and value projections are stored as ``nn.Parameter``
    matrices of shape ``(d_in, d_out)`` and applied with a plain matrix
    multiplication.

    Args:
        d_in: Size of each input embedding vector.
        d_out: Size of each projected query/key/value (and output) vector.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        # The creation order (query, key, value) fixes which random draws each
        # matrix receives, so it must stay the same for seeded reproducibility.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, with any number of
                leading batch dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        # Swap only the last two axes: identical to ``keys.T`` for 2-D input,
        # but also correct when leading batch dimensions are present.
        attn_scores = queries @ keys.transpose(-2, -1)  # omega
        # Scale by sqrt(key dimension) before normalising each row with softmax.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values

        return context_vec





In [2]:
# Toy 3-dimensional embeddings, one row per token of the example sentence.
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # x^1: your
        [0.55, 0.87, 0.66],  # x^2: journey
        [0.57, 0.85, 0.64],  # x^3: starts
        [0.22, 0.58, 0.33],  # x^4: with
        [0.77, 0.25, 0.10],  # x^5: one
        [0.05, 0.80, 0.55],  # x^6: step
    ]
)


In [3]:
d_out = 2  # size of each output (context) vector
d_in = inputs.shape[-1]  # size of each input embedding, read from the last axis (3)

In [4]:
# Seed first so the torch.rand-initialised weight matrices are reproducible.
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)


In [5]:
# Run all six input embeddings through sa_v1 and print the resulting context vectors.
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


*We can further improve this by introducing linear layers (`nn.Linear`).*

In [6]:
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    Args:
        d_in: Size of each input embedding vector.
        d_out: Size of each projected query/key/value (and output) vector.
        qkv_bias: Whether the query, key and value projections include a
            bias term. Defaults to ``False``.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        # The creation order (query, key, value) fixes which random draws each
        # layer receives, so it must stay the same for seeded reproducibility.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, with any number of
                leading batch dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Swap only the last two axes: identical to ``keys.T`` for 2-D input,
        # but also correct when leading batch dimensions are present.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(key dimension) before normalising each row with softmax.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values

        return context_vec

        

In [7]:
# Re-seed before constructing the module so its nn.Linear weights are reproducible.
torch.manual_seed(123)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))


tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)
