<h1 style="text-align: center; text-decoration: underline ;">
    Self-Attention from Scratch
</h1>


##  Imports and data initialization

In [185]:
import torch
import torch.nn.functional as F
from torch import nn
from profiler_utils import record_time_function
torch.set_printoptions(sci_mode=False)

In [186]:
torch.manual_seed(74)
B,T,C = 32,256,384 
x = torch.randn(B,T,C)
x.shape

torch.Size([32, 256, 384])

## V1 : naive implementation 

> ### <u> Intuition </u>
> - We simply **average over the current token and all previous tokens, separately for each example in the batch**,  
> and use that average as the context for **predicting the next one**.  
>
> - This approach is quite **lossy**, since the average discards the order and identity of the tokens, but it is still a **good starting point**:  
> the information that seems lost can be brought back later.
> 
> <br>

In [187]:
@record_time_function(runs=5)
def naive_implementation():
    xbow = torch.zeros((B,T,C))
    for b in range(B):
        for t in range(T):
            xprev = x[b,:t+1]
            xbow[b,t] = torch.mean(xprev,0)
    return xbow


In [188]:
naive_implementation()

[CUDA] naive_implementation (defined in 2484760147.py): 465.6422 ms (avg over 5 runs)


tensor([[[    -0.4402,     -1.2658,     -1.8452,  ...,     -0.8260,
              -0.9302,     -0.8619],
         [     0.3373,     -0.9448,     -0.3450,  ...,      0.5698,
              -0.9247,     -0.2553],
         [     0.3851,     -0.0229,     -0.2522,  ...,     -0.2147,
              -0.8270,     -0.0406],
         ...,
         [     0.0849,     -0.0396,     -0.0593,  ...,     -0.0612,
               0.0360,     -0.0222],
         [     0.0793,     -0.0407,     -0.0584,  ...,     -0.0613,
               0.0311,     -0.0216],
         [     0.0758,     -0.0367,     -0.0510,  ...,     -0.0634,
               0.0361,     -0.0229]],

        [[     0.0895,      0.4426,     -0.2971,  ...,      0.0378,
              -1.9672,     -0.9947],
         [     0.0705,      0.3697,     -0.2142,  ...,     -0.6332,
              -0.5950,     -0.0885],
         [    -0.2398,      0.4682,     -0.0545,  ...,      0.0834,
              -0.4633,     -0.7432],
         ...,
         [     0.0444,   

## V2 : Efficient Averaging using tril and matmul 

> ### <u>Intuition</u>
> - `tril` gives us the **lower triangular** part of a matrix.  
> - When we row-normalize a `tril` matrix full of ones, each row becomes a uniform average over the positions up to and including that row:
> 
> <div style="text-align: center;">
> <pre>
>  [1,    0,    0   ]
>  [0.5,  0.5,  0   ]
>  [0.33, 0.33, 0.33]
> </pre>
> </div>
> 
> - The **bottom-left triangle** now contains **weights** that sum to 1 in each row,  
> while the top-right triangle is zeros.  
> 
> - If we now multiply this matrix by `x` using `@` <u>(the matrix multiplication operator)</u>, row `t` of the result is the average of tokens `0..t`, which is exactly the running average computed by the loops in V1.
> 
> <br>
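
To make this concrete, here is a minimal sketch on a toy input. The names `T_small`, `x_small`, `weights` and `running_mean` are illustrative only and are not used elsewhere in the notebook.

```python
import torch

T_small = 3
# Toy input: one batch element, 3 tokens, 2 channels.
x_small = torch.tensor([[[1.0, 2.0],
                         [3.0, 4.0],
                         [5.0, 6.0]]])

# Lower-triangular matrix of ones, then divide each row by its row sum.
weights = torch.tril(torch.ones(T_small, T_small))
weights = weights / weights.sum(dim=1, keepdim=True)

# (T, T) @ (B, T, C) broadcasts over the batch dimension and gives (B, T, C).
running_mean = weights @ x_small

print(weights)
print(running_mean)
```

Up to floating-point rounding, row `t` of `running_mean` is the mean of tokens `0..t`: `[1, 2]`, then `[2, 3]`, then `[3, 4]`.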


In [189]:
@record_time_function(runs=100)
def efficient_averaging():
    wei = torch.tril(torch.ones(T,T))
    wei = wei / wei.sum(1, keepdim=True)
    # wei is (T, T) and x is (B, T, C): wei is broadcast across the batch dimension
    # (B = 32 here), so every example in the batch is multiplied by the same normalized matrix.
    xbow2 = wei @ x
    return xbow2

In [190]:
efficient_averaging()

[CUDA] efficient_averaging (defined in 576775067.py): 4.4733 ms (avg over 100 runs)


tensor([[[    -0.4402,     -1.2658,     -1.8452,  ...,     -0.8260,
              -0.9302,     -0.8619],
         [     0.3373,     -0.9448,     -0.3450,  ...,      0.5698,
              -0.9247,     -0.2553],
         [     0.3851,     -0.0229,     -0.2522,  ...,     -0.2147,
              -0.8270,     -0.0406],
         ...,
         [     0.0849,     -0.0396,     -0.0593,  ...,     -0.0612,
               0.0360,     -0.0222],
         [     0.0793,     -0.0407,     -0.0584,  ...,     -0.0613,
               0.0311,     -0.0216],
         [     0.0758,     -0.0367,     -0.0510,  ...,     -0.0634,
               0.0361,     -0.0229]],

        [[     0.0895,      0.4426,     -0.2971,  ...,      0.0378,
              -1.9672,     -0.9947],
         [     0.0705,      0.3697,     -0.2142,  ...,     -0.6332,
              -0.5950,     -0.0885],
         [    -0.2398,      0.4682,     -0.0545,  ...,      0.0834,
              -0.4633,     -0.7432],
         ...,
         [     0.0444,   

## V3 : Adding Softmax to the implementation 

In [None]:
@record_time_function(runs=100)
def softmax_averaging():
    tril = torch.tril(torch.ones(T,T))
    wei = torch.zeros((T,T))
    wei = wei.masked_fill(tril == 0 , float('-inf'))
    wei = F.softmax(wei,dim=-1)
    xbow3 = wei @ x 
    return xbow3

In [175]:
softmax_averaging()

time taken by softmax_normalizing: 4.7276 ms


tensor([[[    -0.4402,     -1.2658,     -1.8452,  ...,     -0.8260,
              -0.9302,     -0.8619],
         [     0.3373,     -0.9448,     -0.3450,  ...,      0.5698,
              -0.9247,     -0.2553],
         [     0.3851,     -0.0229,     -0.2522,  ...,     -0.2147,
              -0.8270,     -0.0406],
         ...,
         [     0.0849,     -0.0396,     -0.0593,  ...,     -0.0612,
               0.0360,     -0.0222],
         [     0.0793,     -0.0407,     -0.0584,  ...,     -0.0613,
               0.0311,     -0.0216],
         [     0.0758,     -0.0367,     -0.0510,  ...,     -0.0634,
               0.0361,     -0.0229]],

        [[     0.0895,      0.4426,     -0.2971,  ...,      0.0378,
              -1.9672,     -0.9947],
         [     0.0705,      0.3697,     -0.2142,  ...,     -0.6332,
              -0.5950,     -0.0885],
         [    -0.2398,      0.4682,     -0.0545,  ...,      0.0834,
              -0.4633,     -0.7432],
         ...,
         [     0.0444,   
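
The softmax version produces the same weights as the row-normalized `tril` matrix because `exp(0) = 1` for every visible position and `exp(-inf) = 0` for every masked one, so each row ends up uniform over the visible positions. A minimal sketch of that check follows; `T_small`, `mask`, `wei_avg`, `scores` and `wei_softmax` are illustrative names, not part of the notebook.

```python
import torch
import torch.nn.functional as F

T_small = 4
mask = torch.tril(torch.ones(T_small, T_small))

# V2-style weights: divide each row of the mask by its row sum.
wei_avg = mask / mask.sum(dim=1, keepdim=True)

# V3-style weights: start from all-zero scores, set future positions to -inf,
# then let softmax turn each row into a probability distribution.
scores = torch.zeros(T_small, T_small)
scores = scores.masked_fill(mask == 0, float("-inf"))
wei_softmax = F.softmax(scores, dim=-1)

print(wei_softmax)
print(torch.allclose(wei_avg, wei_softmax))
```

The point of this formulation is that the zero scores are only a placeholder: once they are replaced by data-dependent scores, the same mask and softmax still apply.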

## V4 : Complete self-attention

> In V3, `wei` gives equal weight to every visible token. In practice this is not what we want: a token may find some earlier tokens more relevant than others. So we make each token emit a **query** vector (what am I looking for?) and a **key** vector (what do I contain?). The dot product of one token's query with another token's key gives the affinity between the two, and these affinities replace the zeros in `wei`. It is called **self**-attention because the keys, queries and values are all computed from the same input `x`.
> 
><br>

In [178]:
@record_time_function(runs=100)
def self_attn():
    head_size = 16 
    key  = nn.Linear(C,head_size, bias = False)
    query  = nn.Linear(C,head_size, bias = False)
    value  = nn.Linear(C,head_size, bias = False)

    k = key(x)    # (B, T, head_size): what each token contains
    q = query(x)  # (B, T, head_size): what each token is looking for
    # Instead of aggregating the raw tokens, aggregate a learned value per token;
    # this also puts the output in the head_size dimension.
    v = value(x)  # (B, T, head_size)

    # (B, T, T): dot product of every query with every key (not scaled by head_size**-0.5 here)
    wei = q @ k.transpose(-2,-1)

    tril = torch.tril(torch.ones(T,T))
    wei = wei.masked_fill(tril == 0 , float('-inf')) # hide future tokens
    wei = F.softmax(wei,dim=-1) # normalize each row so it sums to 1

    xbow4 = wei @ v # (B, T, head_size)
    return xbow4


In [180]:
self_attn()

time taken by self_attn: 4.9463 ms


tensor([[[-1.1047, -0.0934,  0.4096,  ..., -0.9040, -0.0155,  0.9516],
         [ 0.0608, -0.3139,  0.4058,  ...,  0.2095,  0.0601, -0.1607],
         [ 0.7137, -0.0043, -0.2764,  ..., -0.7193, -0.3178, -0.0263],
         ...,
         [ 0.0494,  0.0107,  0.0033,  ..., -0.0117, -0.1115, -0.1293],
         [ 0.1249, -0.1412,  0.1612,  ..., -0.0303, -0.1349, -0.3617],
         [-0.0850, -0.1028, -0.2304,  ..., -0.0188,  0.0623, -0.0946]],

        [[ 0.1180,  1.0841,  0.4757,  ...,  0.6120,  0.0767, -0.5927],
         [-1.1127,  0.6219,  0.3231,  ..., -0.0248, -0.0586,  0.7324],
         [-0.2012,  0.4842,  0.4768,  ...,  0.0085,  0.0180,  0.0816],
         ...,
         [ 0.0130,  0.1304,  0.0526,  ...,  0.0698, -0.0454, -0.1068],
         [-0.0121,  0.1925,  0.3428,  ..., -0.1335, -0.0814, -0.3686],
         [-0.0101,  0.1342,  0.1216,  ..., -0.0432,  0.0320, -0.0691]],

        [[-0.0757,  0.7129,  1.2230,  ...,  0.1801,  0.0882, -1.0942],
         [ 0.0562,  0.4837,  0.1702,  ...,  0
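
To see that the weights are now data dependent rather than uniform, it can help to print `wei` for a tiny input. This is a sketch under assumed toy sizes; every `*_small` name is hypothetical and chosen only for this illustration.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
B_small, T_small, C_small, head_size_small = 1, 4, 8, 4
x_small = torch.randn(B_small, T_small, C_small)

key = nn.Linear(C_small, head_size_small, bias=False)
query = nn.Linear(C_small, head_size_small, bias=False)

with torch.no_grad():
    k = key(x_small)                      # (B, T, head_size)
    q = query(x_small)                    # (B, T, head_size)
    scores = q @ k.transpose(-2, -1)      # (B, T, T) query-key affinities
    mask = torch.tril(torch.ones(T_small, T_small))
    scores = scores.masked_fill(mask == 0, float("-inf"))
    wei = F.softmax(scores, dim=-1)

# Each row sums to 1, entries above the diagonal are 0,
# and the visible entries within a row are generally not equal.
print(wei[0])
```

The first row is always `[1, 0, 0, 0]` because the first token can only attend to itself; later rows distribute their weight unevenly according to the query-key dot products.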

> ## NOTES
> - Attention is just a communication mechanism; it can be applied to any arbitrary directed graph.
> - Attention has no notion of space. It operates over a set of vectors, which is why positional encodings are needed.
> - There is no communication across the batch dimension; each example in the batch is attended over independently.
> - Encoder blocks, unlike decoder blocks, are not causal: there is no triangular mask, so every token can attend to every other token, including future ones. GPT is a decoder-only architecture. The attention mechanism itself works the same way in both.
> - Attention comes in several types. In cross-attention the queries are computed from the current nodes (`x`), while the keys and values come from another source, such as an encoder's output.
> - The scores need to be normalized (scaled by `head_size**-0.5`) before the softmax. Otherwise their variance grows with `head_size` and the softmax saturates toward an almost one-hot vector, so each token attends to essentially a single token instead of spreading its attention across tokens.
>
><br>
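
Several of these notes can be illustrated with one small function. The sketch below is not the notebook's implementation: `attention_sketch`, `x_demo`, `memory` and the other names are hypothetical, and `memory` merely stands in for an external source such as an encoder's output.

```python
import torch
import torch.nn.functional as F
from torch import nn


def attention_sketch(q_src, kv_src, query, key, value, causal):
    """Scaled dot-product attention: q_src supplies queries, kv_src keys and values."""
    q = query(q_src)   # (B, Tq, head_size)
    k = key(kv_src)    # (B, Tk, head_size)
    v = value(kv_src)  # (B, Tk, head_size)
    scores = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, Tq, Tk)
    if causal:
        # Decoder-style mask: position t may only look at positions <= t.
        mask = torch.tril(torch.ones(q.shape[1], k.shape[1]))
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return weights @ v, weights


torch.manual_seed(0)
C_demo, head_demo = 8, 4
query = nn.Linear(C_demo, head_demo, bias=False)
key = nn.Linear(C_demo, head_demo, bias=False)
value = nn.Linear(C_demo, head_demo, bias=False)

x_demo = torch.randn(2, 5, C_demo)   # current nodes
memory = torch.randn(2, 7, C_demo)   # hypothetical external source

with torch.no_grad():
    # Decoder-style self-attention: causal mask, everything computed from x_demo.
    dec_out, dec_w = attention_sketch(x_demo, x_demo, query, key, value, causal=True)
    # Encoder-style self-attention: no mask, every token sees every token.
    enc_out, enc_w = attention_sketch(x_demo, x_demo, query, key, value, causal=False)
    # Cross-attention: queries from x_demo, keys and values from memory.
    cross_out, cross_w = attention_sketch(x_demo, memory, query, key, value, causal=False)

print(dec_out.shape, enc_out.shape, cross_out.shape, cross_w.shape)
```

In the cross-attention call the weight matrix is rectangular, with one row per query token and one column per `memory` token. Because every operation is batched, changing one example in the batch leaves the outputs for the other examples unchanged.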

## Why normalizing before softmax matters

In [None]:
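head_size = 16  # defined here because it is only a local variable inside self_attn
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
wei = q @ k.transpose(-2,-1)  # unscaled scores
k2 = torch.randn(B,T,head_size)
q2 = torch.randn(B,T,head_size)
wei2 = q2 @ k2.transpose(-2,-1) * head_size**-0.5  # scaled by 1/sqrt(head_size)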

In [10]:
k.var() , k2.var()

(tensor(1.0069), tensor(0.9967))

In [11]:
q.var() ,  q2.var()

(tensor(0.9302), tensor(0.8952))

In [12]:
wei.var() , wei2.var()

(tensor(15.6977), tensor(1.0194))

In [13]:
print("Softmax(wei)  [0][0]:")
print(torch.softmax(wei, dim=-1)[0][0])   # almost one-hot like results

print("\nSoftmax(wei2) [0][0]:")
print(torch.softmax(wei2, dim=-1)[0][0])  # better distribution

Softmax(wei)  [0][0]:
tensor([    0.0000,     0.8555,     0.1444,     0.0000,     0.0000,     0.0000,
            0.0000,     0.0001])

Softmax(wei2) [0][0]:
tensor([0.2348, 0.0406, 0.1533, 0.0552, 0.0287, 0.3193, 0.0635, 0.1047])
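
The reason behind these numbers: each score is a sum of `head_size` products of independent unit-variance entries, so its variance is roughly `head_size`, and multiplying by `head_size**-0.5` brings it back to roughly 1. The sketch below repeats the variance check and then shows the softmax sharpening on a fixed vector; all names ending in `_demo`, as well as `logits`, `soft` and `sharp`, are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
head_size_demo = 16
q_demo = torch.randn(4, 64, head_size_demo)
k_demo = torch.randn(4, 64, head_size_demo)

raw = q_demo @ k_demo.transpose(-2, -1)   # variance close to head_size_demo
scaled = raw * head_size_demo ** -0.5     # variance close to 1
print(raw.var().item(), scaled.var().item())

# The same logits, once as they are and once multiplied by 8:
# larger magnitudes push the softmax toward a one-hot vector.
logits = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])
soft = torch.softmax(logits, dim=-1)
sharp = torch.softmax(logits * 8, dim=-1)
print(soft)
print(sharp)
```

In `sharp`, most of the probability mass sits on the largest logit, while `soft` stays fairly spread out.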
