In [1]:
# A compact self-attention class
!pip install torch



In [2]:
import torch
import torch.nn as nn

In [3]:
class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention with trainable projections.

    The query, key and value projections are plain ``nn.Parameter`` matrices
    of shape ``(d_in, d_out)`` initialised with ``torch.rand``.

    Args:
        d_in: Size of the last dimension of the input passed to ``forward``.
        d_out: Size of the last dimension of the queries, keys, values and
            of the returned context vectors.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Creation order (query, key, value) decides which random draws each
        # matrix receives under a fixed seed, so it is kept stable.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute one context vector per input row.

        Args:
            x: Tensor of shape ``(..., n, d_in)``. A plain 2-D ``(n, d_in)``
                input works as before; any leading dimensions are treated
                as independent batches.

        Returns:
            Tensor of shape ``(..., n, d_out)``: for each row, the
            softmax-weighted sum of the value vectors of all ``n`` rows.
        """
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        # Swap only the last two axes. `.T` would reverse *all* axes, which is
        # wrong (and deprecated) for inputs with leading batch dimensions.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before the softmax over the key axis.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [4]:
# Feature sizes shared by the cells below:
#   d_in  -> input dimension
#   d_out -> output dimension
d_in, d_out = 4, 6

In [6]:
# Create sample input of shape [3, d_in]: 3 samples, each with d_in features.
# The values come from a dedicated, seeded generator so that every run of the
# notebook draws the same inputs; the global RNG state is left untouched.
inputs = torch.randn(3, d_in, generator=torch.Generator().manual_seed(123))

In [7]:
# Seed the global RNG before constructing the module, because its weight
# matrices are filled with torch.rand at construction time.
torch.manual_seed(123)

# Build the attention layer with the notebook's dimensions and print the
# context vectors it produces for the sample input.
sa_v1 = SelfAttention_v1(d_in=d_in, d_out=d_out)
print(sa_v1(inputs))

tensor([[ 0.1996,  0.3756,  0.5111, -1.0476, -0.9225, -0.2788],
        [ 0.1823,  0.4863,  0.4892, -1.2474, -1.0161, -0.2384],
        [ 2.0155,  1.0874,  0.9103,  3.2482,  2.5879,  1.5314]],
       grad_fn=<MmBackward0>)
