In [1]:
import torch

# Toy embedding table: one row per token of the example sentence
# (6 rows), each row a 3-dimensional embedding vector.
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)



In [2]:
# Second row of `inputs` (index 1), i.e. the embedding labelled x^2.
x_2 = inputs[1]
# Embedding size: number of columns of `inputs` (3 for the tensor above).
d_in = inputs.shape[1]
# Size of the projected query/key/value vectors.
d_out = 2

In [3]:
# Fix the RNG so the three random projection matrices are reproducible.
# The draw order (query, key, value) determines which values each matrix
# receives, so it must not be reordered.
torch.manual_seed(123)
# Each matrix has shape (d_in, d_out); requires_grad=False because this
# walkthrough only inspects the forward computation and does not train.
W_query = torch.nn.Parameter(torch.rand(d_in,d_out),requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in,d_out),requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in,d_out),requires_grad=False)

In [48]:
# Project every input row with each weight matrix:
# (rows of inputs, d_in) @ (d_in, d_out) -> (rows of inputs, d_out).
queries = inputs @ W_query
keys = inputs @ W_key
values = inputs @ W_value


# Query vector of the second token (row index 1).
query_2 = queries[1]

In [49]:
# Unnormalised attention scores of token 2: the dot product of its query
# with every key, giving one score per input row.
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


In [51]:
# All pairwise scores at once: entry [i, j] is query i dotted with key j.
# Row 1 of this matrix is the attn_scores_2 vector computed earlier.
attn_scores = queries @ keys.T
print(attn_scores)

tensor([[0.9231, 1.3545, 1.3241, 0.7910, 0.4032, 1.1330],
        [1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
        [1.2544, 1.8284, 1.7877, 1.0654, 0.5508, 1.5238],
        [0.6973, 1.0167, 0.9941, 0.5925, 0.3061, 0.8475],
        [0.6114, 0.8819, 0.8626, 0.5121, 0.2707, 0.7307],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3937, 1.0996]])


### Why divide by sqrt(dimension)?


The softmax function is sensitive to the magnitudes of its inputs. When the inputs are large, the differences between the exponential values of each input become much more pronounced. This causes the softmax output to become "peaky," where the highest value receives almost all the probability mass, and the rest receive very little.

In attention mechanisms, particularly in transformers, if the dot products between query and key vectors become too large (for example, if every score were multiplied by 8), the attention scores can become very large. This results in a very sharp softmax distribution, making the model overly confident in one particular "key." Such sharp distributions can make learning unstable, because a saturated softmax passes back only very small gradients.

### But why the square root?

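The dot product of a query $q$ and a key $k$ is a sum of $d$ products, $q \cdot k = \sum_{i=1}^{d} q_i k_i$. If the components are independent with mean 0 and variance 1, each product $q_i k_i$ also has mean 0 and variance 1, and the variances of the $d$ terms add up.

The variance of the dot product is therefore $d$: it grows linearly with the dimension.

Dividing the dot product by $\sqrt{d}$ divides its variance by $d$, which keeps the variance close to 1 regardless of the dimension.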

In [58]:
import numpy as np

def compute_variance(dim, num_trials=1000, seed=None):
    """Estimate the variance of q.k before and after dividing by sqrt(dim).

    Draws ``num_trials`` independent pairs of standard-normal vectors
    ``q`` and ``k`` of length ``dim``, computes their dot products, and
    measures the sample variance of the raw and of the scaled dot products.

    Args:
        dim: Length of each random vector; must be at least 1.
        num_trials: Number of (q, k) pairs to sample; must be at least 1.
        seed: Optional integer seed. When given, a private
            ``np.random.RandomState(seed)`` is used so the result is
            reproducible and NumPy's global RNG is left untouched. When
            ``None`` (the default), NumPy's global RNG is used.

    Returns:
        Tuple ``(variance_before_scaling, variance_after_scaling)``.

    Raises:
        ValueError: If ``dim`` or ``num_trials`` is smaller than 1.
    """
    if dim < 1 or num_trials < 1:
        raise ValueError("dim and num_trials must both be at least 1")

    rng = np.random if seed is None else np.random.RandomState(seed)

    # One draw of shape (trial, q-or-k, component) replaces the Python loop;
    # the q and k of a trial sit next to each other along axis 1.
    samples = rng.standard_normal((num_trials, 2, dim))
    q = samples[:, 0]
    k = samples[:, 1]

    # Row-wise dot product: one scalar per trial.
    dot_products = np.einsum("td,td->t", q, k)

    # Scale the dot products by sqrt(dim)
    scaled_dot_products = dot_products / np.sqrt(dim)

    variance_before_scaling = np.var(dot_products)
    variance_after_scaling = np.var(scaled_dot_products)

    return variance_before_scaling, variance_after_scaling

# For dimension 5
variance_before_5, variance_after_5 = compute_variance(5)
print(f"Variance before scaling (dim=5): {variance_before_5}")
print(f"Variance after scaling (dim=5): {variance_after_5}")

# For dimension 20
variance_before_100, variance_after_100 = compute_variance(100)
print(f"Variance before scaling (dim=100): {variance_before_100}")
print(f"Variance after scaling (dim=100): {variance_after_100}")



Variance before scaling (dim=5): 5.106730729143345
Variance after scaling (dim=5): 1.0213461458286688
Variance before scaling (dim=100): 103.50966809563538
Variance after scaling (dim=100): 1.0350966809563538


In [52]:
# Key dimension: size of the last axis of `keys`, used as the softmax scale.
d_k = keys.shape[-1]

In [56]:
# Scale token 2's scores by sqrt(d_k), then normalise them with softmax so
# the weights are non-negative and sum to 1.
attn_weights_2 =torch.softmax(attn_scores_2/d_k**0.5 , dim=-1)

attn_weights_2

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])

In [62]:
# Context vector of token 2: the attention-weighted sum of all value rows.
attn_weights_2 @ values

tensor([0.3061, 0.8210])

## Implementing a Compact Self Attention Python Class

In [87]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention with raw weight matrices.

    The query, key and value projections are plain ``nn.Parameter`` matrices
    of shape ``(d_in, d_out)`` filled by ``torch.rand``.

    Args:
        inputs: Tensor whose last axis is the embedding dimension. It fixes
            ``d_in`` and is the default input of ``forward``.
        d_out: Size of the projected query/key/value vectors (default 2).
    """

    def __init__(self, inputs, d_out=2):
        super().__init__()
        self.inputs = inputs
        # The last axis is the embedding size for both (tokens, d_in) and
        # batched (batch, tokens, d_in) tensors.
        self.d_in = inputs.shape[-1]
        self.d_out = d_out
        # Creation order (query, key, value) is kept so that a given seed
        # assigns the same random values to each matrix.
        self.W_query = nn.Parameter(torch.rand(self.d_in, self.d_out))
        self.W_key = nn.Parameter(torch.rand(self.d_in, self.d_out))
        self.W_value = nn.Parameter(torch.rand(self.d_in, self.d_out))

    def forward(self, x=None):
        """Compute one context vector per token.

        Args:
            x: Optional tensor of shape ``(..., tokens, d_in)``. When omitted,
                the tensor passed to the constructor is used.

        Returns:
            Tensor of shape ``(..., tokens, d_out)``.
        """
        if x is None:
            x = self.inputs

        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        # Swap only the last two axes: `.T` reverses every axis and is
        # deprecated for tensors that are not 2-D.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_k) so the softmax input keeps a variance near 1.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )

        context_vect = attn_weights @ values
        return context_vect

In [88]:
# Re-seed before constructing the module so its randomly initialised
# weight matrices are reproducible.
torch.manual_seed(123)

sa_v1 = SelfAttention_v1(inputs)

In [89]:
# Call the module itself instead of .forward(): nn.Module.__call__ runs any
# registered hooks and then dispatches to forward().
print(sa_v1())

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [80]:
# Display the input embeddings again for reference.
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [93]:
import torch.nn as nn

class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    The query, key and value projections are ``nn.Linear(d_in, d_out)``
    layers, so their weights use ``nn.Linear``'s own initialisation and an
    optional bias.

    Args:
        inputs: Tensor whose last axis is the embedding dimension. It fixes
            ``d_in`` and is the default input of ``forward``.
        qkv_bias: Whether the three linear layers carry a bias term.
        d_out: Size of the projected query/key/value vectors (default 2).
    """

    def __init__(self, inputs, qkv_bias=False, d_out=2):
        super().__init__()
        self.inputs = inputs
        # The last axis is the embedding size for both (tokens, d_in) and
        # batched (batch, tokens, d_in) tensors.
        self.d_in = inputs.shape[-1]
        self.d_out = d_out
        # Creation order (query, key, value) is kept so that a given seed
        # assigns the same random values to each layer.
        self.W_query = nn.Linear(self.d_in, self.d_out, bias=qkv_bias)
        self.W_key = nn.Linear(self.d_in, self.d_out, bias=qkv_bias)
        self.W_value = nn.Linear(self.d_in, self.d_out, bias=qkv_bias)

    def forward(self, x=None):
        """Compute one context vector per token.

        Args:
            x: Optional tensor of shape ``(..., tokens, d_in)``. When omitted,
                the tensor passed to the constructor is used.

        Returns:
            Tensor of shape ``(..., tokens, d_out)``.
        """
        if x is None:
            x = self.inputs

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Swap only the last two axes: `.T` reverses every axis and is
        # deprecated for tensors that are not 2-D.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_k) so the softmax input keeps a variance near 1.
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )

        context_vect = attn_weights @ values
        return context_vect

In [94]:
# Re-seed so the nn.Linear layers are initialised reproducibly.
torch.manual_seed(123)

# NOTE(review): this rebinds `sa_v1` to a SelfAttention_v2 instance, so the
# earlier SelfAttention_v1 object is lost and the name is misleading.
# Consider renaming to `sa_v2` together with every later use.
sa_v1 = SelfAttention_v2(inputs)

In [95]:
# `sa_v1` holds the SelfAttention_v2 instance created in the previous cell.
# Call the module itself instead of .forward() so nn.Module.__call__ (and
# any registered hooks) runs.
sa_v1()

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)