## Manual Self-Attention Implementation in PyTorch – Focusing on a Single Token's Perspective

In [None]:
import torch
import torch.nn as nn  # needed for the nn.Linear layers defined below

We create a random input token matrix `inputToken` with shape [5, 4]. Each row represents one token as a 4-dimensional embedding vector.

In [None]:
d_in = 4          # input embedding dimension
no_of_tokens = 5  # sequence length
d_out = 3         # hyperparameter: dimension of the query, key, and value vectors

inputToken = torch.rand(no_of_tokens,d_in)
print(inputToken)
print("Shape:", inputToken.shape)

tensor([[0.5159, 0.4220, 0.5786, 0.9455],
        [0.8057, 0.6775, 0.6087, 0.6179],
        [0.6932, 0.4354, 0.0353, 0.1908],
        [0.9268, 0.5299, 0.0950, 0.5789],
        [0.9131, 0.0275, 0.1634, 0.3009]])
Shape: torch.Size([5, 4])


**Example**: Picking One Token

In [None]:
# pick out a single token: index 2 is the third token, hence the "3" suffix used in later names
stoken = inputToken[2]
stoken.shape

torch.Size([4])

**Learnable Parameters:** Query, Key, and Value Matrices

In self-attention, every token is projected into three vectors, each by its own learnable linear layer:

**Query**: What the token is looking for in the other tokens

**Key**: What each token exposes for matching against queries (like tags or metadata)

**Value**: The actual content that gets aggregated

In [None]:
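# fix the random seed so the layer weights are reproducible
torch.manual_seed(123)
# define the query, key, and value projections (nn.Linear maps d_in -> d_out and includes a bias by default)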
QueryM = nn.Linear(d_in, d_out)
KeyM = nn.Linear(d_in, d_out)
ValueM = nn.Linear(d_in, d_out)

In [None]:
print("Query Matrix: ",QueryM)
print("------------")
print("Key Matrix: ",KeyM)
print("------------")
print("Value Matrix: ",ValueM)
print("------------")

Query Matrix:  Linear(in_features=4, out_features=3, bias=True)
------------
Key Matrix:  Linear(in_features=4, out_features=3, bias=True)
------------
Value Matrix:  Linear(in_features=4, out_features=3, bias=True)
------------
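As a hedged sketch of what each of these layers computes: `nn.Linear` stores a weight of shape `[d_out, d_in]` plus a bias vector and applies `x @ W.T + b`. The names `layer`, `x`, and `manual` below are illustrative and do not appear in the notebook.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
d_in, d_out = 4, 3

layer = nn.Linear(d_in, d_out)  # illustrative stand-in for QueryM / KeyM / ValueM
x = torch.rand(d_in)            # one token embedding

# Reproduce the layer by hand: multiply by the transposed weight, then add the bias.
manual = x @ layer.weight.T + layer.bias

same = torch.allclose(layer(x), manual, atol=1e-6)
weight_shape = tuple(layer.weight.shape)  # stored as (d_out, d_in)
bias_shape = tuple(layer.bias.shape)      # (d_out,)
print(same, weight_shape, bias_shape)
```

Because of the bias term, these layers are affine rather than purely linear maps; many attention implementations pass `bias=False`, which is a design choice this notebook does not make.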


### Single Token Attention

We transform one token into its **query**, **key**, and **value**.

In [None]:
query3 = QueryM(stoken)
key3 = KeyM(stoken)
value3 = ValueM(stoken)

In [None]:
query3

tensor([-0.5313, -0.5278, -0.2748], grad_fn=<ViewBackward0>)

In [None]:
# compute keys and values for every token; the single query is compared against all of them
keys = KeyM(inputToken)
values = ValueM(inputToken)

In [None]:
keys

tensor([[ 0.2236,  0.8145, -0.4259],
        [ 0.1462,  0.9094, -0.2659],
        [ 0.1122,  0.6939, -0.2615],
        [ 0.0270,  0.8234, -0.2469],
        [ 0.2751,  0.6120, -0.0675]], grad_fn=<AddmmBackward0>)

In [None]:
values

tensor([[ 0.5962,  0.2799, -0.5147],
        [ 0.7200,  0.4179, -0.5233],
        [ 0.4013,  0.6908, -0.3236],
        [ 0.6165,  0.7007, -0.5124],
        [ 0.4484,  0.7392, -0.4560]], grad_fn=<AddmmBackward0>)

### Attention Weights Calculation


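The dot product of the query with every key gives one raw **attention score** per token, measuring how strongly that token's key matches the query. The variable is already named `attention_weights` here, although the values only become weights after the softmax step.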

In [None]:
attention_weights = query3 @ keys.T
attention_weights

tensor([-0.4316, -0.4845, -0.3540, -0.3810, -0.4506],
       grad_fn=<SqueezeBackward4>)
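To make the matrix product concrete, the sketch below recomputes the scores in plain Python from the rounded values printed above. The helper `dot` is hypothetical, and because the inputs are rounded to four decimals the results should only agree with the printed tensor to about three decimal places.

```python
# Values copied (rounded to 4 decimals) from the printed tensors above.
query3 = [-0.5313, -0.5278, -0.2748]
keys = [
    [0.2236, 0.8145, -0.4259],
    [0.1462, 0.9094, -0.2659],
    [0.1122, 0.6939, -0.2615],
    [0.0270, 0.8234, -0.2469],
    [0.2751, 0.6120, -0.0675],
]

def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))

# One score per token: the query dotted with that token's key.
# This is what `query3 @ keys.T` computes in a single call.
scores = [dot(query3, k) for k in keys]
print(scores)
```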

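We scale the scores (dividing by **√d_in**, which is 2 here) to prevent overly large dot products, then apply softmax to normalize them into probabilities. Note that the standard scaled dot-product formulation divides by the square root of the key dimension, which corresponds to `d_out` in this notebook; the outputs shown here use √d_in.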

In [None]:
attention_weights = torch.softmax(attention_weights / d_in**0.5, dim=-1)
attention_weights

tensor([0.1988, 0.1936, 0.2067, 0.2039, 0.1969], grad_fn=<SoftmaxBackward0>)

In [None]:
torch.sum(attention_weights)

tensor(1.0000, grad_fn=<SumBackward0>)
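The scaling and softmax step can be checked by hand. The following sketch uses only the standard library and the rounded scores printed earlier; the `softmax` helper is hypothetical. It also computes the weights with √d_out as the divisor, purely for comparison with the more common convention.

```python
import math

scores = [-0.4316, -0.4845, -0.3540, -0.3810, -0.4506]  # printed raw scores
d_in = 4   # divisor used in this notebook: sqrt(d_in) = 2
d_out = 3  # key dimension, the divisor in the standard formulation

def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

weights_d_in = softmax([s / math.sqrt(d_in) for s in scores])
weights_d_out = softmax([s / math.sqrt(d_out) for s in scores])

print(weights_d_in)   # should agree with the printed weights to about 3 decimals
print(weights_d_out)  # slightly more peaked, since the divisor is smaller
```

With scores this close together, both choices give nearly uniform weights; the divisor matters more when the dot products are large.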

Multiplying the attention weights by the value vectors gives a weighted sum of the values. This is the attention output (the context vector) for token 3, i.e. `inputToken[2]`.

In [None]:
context_vector_3 = attention_weights @ values
context_vector_3

tensor([ 0.5549,  0.5678, -0.4649], grad_fn=<SqueezeBackward4>)
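As a rough hand-check of the weighted sum, the sketch below combines the rounded weights and value vectors printed above in plain Python. Since the inputs are rounded, the result should match the printed context vector only to about three decimal places.

```python
# Rounded values copied from the printed tensors above.
weights = [0.1988, 0.1936, 0.2067, 0.2039, 0.1969]
values = [
    [0.5962, 0.2799, -0.5147],
    [0.7200, 0.4179, -0.5233],
    [0.4013, 0.6908, -0.3236],
    [0.6165, 0.7007, -0.5124],
    [0.4484, 0.7392, -0.4560],
]

d_out = len(values[0])

# Each output component is the weight-averaged value of that component across all tokens.
context = [
    sum(w * v[j] for w, v in zip(weights, values))
    for j in range(d_out)
]
print(context)
```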

## Self-Attention as a PyTorch Class – Computing Attention Across All Tokens

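This class encapsulates the **entire** self-attention mechanism for a given input sequence. It computes queries, keys, and values for all tokens at once, so each row of the output is the context vector of one token.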

In [None]:
import torch.nn as nn
torch.manual_seed(123)

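class SelfAttention(nn.Module):
    def __init__(self, no_of_tokens, d_in, d_out):
        # no_of_tokens is accepted to match the call signature used here but is not needed
        super().__init__()
        self.d_in = d_in  # stored so forward() does not depend on a global variable
        self.QueryM = nn.Linear(d_in, d_out)
        self.KeyM   = nn.Linear(d_in, d_out)
        self.ValueM = nn.Linear(d_in, d_out)

    def forward(self, inputToken):
        # project every token into query, key, and value space
        queries = self.QueryM(inputToken)
        keys = self.KeyM(inputToken)
        values = self.ValueM(inputToken)
        # raw scores: one row per query token, one column per key token
        attention_scores = queries @ keys.T
        # scale by sqrt(d_in), as in the single-token walkthrough, then normalize each row
        attention_weights = torch.softmax(attention_scores / self.d_in**0.5, dim=-1)
        # each output row is a weighted sum of the value vectors
        context_vector = attention_weights @ values
        return context_vector, queries, keys, values, attention_weights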

You can now instantiate and run it:

In [None]:
att = SelfAttention(5, 4, 3)
context_vector, queries, keys, values, attention_weights = att(inputToken)
print("Context Vector:\n", context_vector)

Context Vector:
 tensor([[ 0.5522,  0.5712, -0.4637],
        [ 0.5531,  0.5700, -0.4640],
        [ 0.5549,  0.5678, -0.4649],
        [ 0.5535,  0.5694, -0.4642],
        [ 0.5536,  0.5687, -0.4642]], grad_fn=<MmBackward0>)
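The third row of this output, `[0.5549, 0.5678, -0.4649]`, matches the `context_vector_3` computed for the single token earlier. A minimal, self-contained sketch of that consistency check is given below; it uses fresh random inputs and illustrative names (`W_q`, `W_k`, `W_v`, `x`), so its numbers will differ from the notebook's, but the relationship between the batched and single-token computations should hold.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
d_in, d_out, n_tokens = 4, 3, 5

# Illustrative projection layers and input; not the notebook's objects.
W_q = nn.Linear(d_in, d_out)
W_k = nn.Linear(d_in, d_out)
W_v = nn.Linear(d_in, d_out)
x = torch.rand(n_tokens, d_in)

# All tokens at once, using the same sqrt(d_in) scaling as the notebook.
queries, keys, values = W_q(x), W_k(x), W_v(x)
weights = torch.softmax(queries @ keys.T / d_in**0.5, dim=-1)
context = weights @ values

# Single-token path for the token at index 2.
q2 = W_q(x[2])
w2 = torch.softmax(q2 @ keys.T / d_in**0.5, dim=-1)
c2 = w2 @ values

# Every row of the weight matrix is a probability distribution over the tokens.
rows_sum_to_one = torch.allclose(weights.sum(dim=-1), torch.ones(n_tokens))
# The batched result for row 2 agrees with the single-token result.
matches_single = torch.allclose(context[2], c2, atol=1e-6)
print(rows_sum_to_one, matches_single, tuple(context.shape))
```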
