In [1]:
import torch

# Three-dimensional embeddings for the six tokens of
# "Your journey starts with one step" (row i is x^(i+1)).
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # x^1: "Your"
    [0.55, 0.87, 0.66],  # x^2: "journey"
    [0.57, 0.85, 0.64],  # x^3: "starts"
    [0.22, 0.58, 0.33],  # x^4: "with"
    [0.77, 0.25, 0.10],  # x^5: "one"
    [0.05, 0.80, 0.55],  # x^6: "step"
])

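### Each 3-dimensional input is projected into 2-dimensional query, key, and value vectors, so every weight matrix is 3×2. In GPT-style models d_in and d_out are usually the same; here they differ (d_in = 3, d_out = 2).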

In [2]:
x_2 = inputs[1]        # second input token ("journey")
d_in = inputs.size(1)  # input embedding size (3)
d_out = 2              # size of each query/key/value vector

## Initializing the query, key, and value weight matrices

In [3]:
torch.manual_seed(123)

# Draw the three (d_in x d_out) projection matrices in the order
# query, key, value; the order determines which random values each gets.
W_query, W_key, W_value = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

<div class="alert alert-block alert-info">
    
Note that we are setting requires_grad=False to reduce clutter in the outputs for
illustration purposes. 

If we were to use the weight matrices for model training, we
would set requires_grad=True to update these matrices during model training.

</div>

In [4]:
# Inspect the (d_in x d_out) = (3 x 2) query weight matrix.
W_query

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])

## Computing the query, key, and value vectors

In [5]:
# Project every input token into query, key, and value space.
query = torch.matmul(inputs, W_query)
key = torch.matmul(inputs, W_key)
value = torch.matmul(inputs, W_value)

In [6]:
# One query vector (row) per input token.
query

tensor([[0.2309, 1.0966],
        [0.4306, 1.4551],
        [0.4300, 1.4343],
        [0.2355, 0.7990],
        [0.2983, 0.6565],
        [0.2568, 1.0533]])

## Calculating attention scores

In [7]:
# Unnormalized scores: dot product of every query with every key.
attention_scores = torch.matmul(query, key.transpose(0, 1))

In [8]:
# Row i holds the scores of query i against every key.
attention_scores

tensor([[0.9231, 1.3545, 1.3241, 0.7910, 0.4032, 1.1330],
        [1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
        [1.2544, 1.8284, 1.7877, 1.0654, 0.5508, 1.5238],
        [0.6973, 1.0167, 0.9941, 0.5925, 0.3061, 0.8475],
        [0.6114, 0.8819, 0.8626, 0.5121, 0.2707, 0.7307],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3937, 1.0996]])

## Applying scaling: divide the scores by $\sqrt{d_k}$, the square root of the key dimension

In [9]:
# Scaling factor sqrt(d_k), where d_k is the size of each key vector.
sqrt_d_key = torch.tensor(key.shape[-1]).sqrt()

In [10]:
# sqrt(d_k) for d_k = d_out = 2.
sqrt_d_key

tensor(1.4142)

## Calculating the attention-weight matrix (scaled scores, then softmax)

In [11]:
# Recompute from `query` and `key` instead of dividing `attention_scores`
# in place, so re-running this cell does not apply the scaling twice.
attention_scores = (query @ key.T) / sqrt_d_key

In [12]:
# Normalize each row of scaled scores into weights that sum to 1.
attention_weights = attention_scores.softmax(dim=-1)

In [13]:
# Softmax over the last dimension: each row sums to 1.
attention_weights

tensor([[0.1551, 0.2104, 0.2059, 0.1413, 0.1074, 0.1799],
        [0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820],
        [0.1503, 0.2256, 0.2192, 0.1315, 0.0914, 0.1819],
        [0.1591, 0.1994, 0.1962, 0.1477, 0.1206, 0.1769],
        [0.1610, 0.1949, 0.1923, 0.1501, 0.1265, 0.1752],
        [0.1557, 0.2092, 0.2048, 0.1419, 0.1089, 0.1794]])

## Calculating the context vectors

In [14]:
# Each context vector is the attention-weighted sum of the value vectors.
context_vector = torch.matmul(attention_weights, value)

In [15]:
# One d_out-dimensional context vector per input token.
context_vector

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]])

# Making a simple class for self-attention

In [16]:
import torch.nn as nn
class Self_attention(nn.Module):
    """Single-head scaled dot-product self-attention with raw ``nn.Parameter`` weights.

    The three projection matrices are drawn with ``torch.rand`` in the order
    query, key, value. Seeding the global RNG (``torch.manual_seed``) before
    construction therefore makes the weights reproducible.

    Args:
        d_in: size of each input embedding (last dimension of ``x``).
        d_out: size of the query/key/value vectors and of the output.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Creation order fixes which random draws each matrix receives.
        self.W_query = torch.nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = torch.nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = torch.nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Return the context vectors for ``x``.

        Args:
            x: tensor of shape ``(num_tokens, d_in)``, or a batched
               ``(..., num_tokens, d_in)`` tensor.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        query = x @ self.W_query
        key = x @ self.W_key
        value = x @ self.W_value

        # Swap only the last two dims (``.T`` would reverse *all* dims of a
        # batched tensor), giving (..., num_tokens, num_tokens) scores.
        attention_scores = query @ key.transpose(-2, -1)
        attention_weights = torch.softmax(
            attention_scores / key.shape[-1] ** 0.5, dim=-1
        )

        return attention_weights @ value


In [17]:
torch.manual_seed(123)  # same seed (and draw order) as the manual weights above
sa_v1 = Self_attention(d_in=d_in, d_out=d_out)
print(sa_v1(x=inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


## We use `nn.Linear` instead of `nn.Parameter` because `nn.Linear` initializes its weights on its own and can add a bias term (controlled here by `qkv_bias`)

In [18]:
import torch.nn as nn
class Self_attentionV2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    ``nn.Linear`` initializes its own weights and, when ``qkv_bias=True``,
    adds a learnable bias to each projection. The layers are created in the
    order query, key, value, so seeding the global RNG before construction
    makes the weights reproducible.

    Args:
        d_in: size of each input embedding (last dimension of ``x``).
        d_out: size of the query/key/value vectors and of the output.
        qkv_bias: whether the query/key/value projections use a bias.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # Creation order fixes which random draws each layer receives.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Return the context vectors for ``x``.

        Args:
            x: tensor of shape ``(num_tokens, d_in)``, or a batched
               ``(..., num_tokens, d_in)`` tensor.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        query = self.W_query(x)
        key = self.W_key(x)
        value = self.W_value(x)

        # Swap only the last two dims (``.T`` would reverse *all* dims of a
        # batched tensor), giving (..., num_tokens, num_tokens) scores.
        attention_scores = query @ key.transpose(-2, -1)
        attention_weights = torch.softmax(
            attention_scores / key.shape[-1] ** 0.5, dim=-1
        )

        return attention_weights @ value


In [19]:
torch.manual_seed(789)  # makes nn.Linear's random weight initialization reproducible
sa_v2 = Self_attentionV2(d_in=d_in, d_out=d_out)
print(sa_v2(x=inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


# Applying Causal Attention Mechanism

In [20]:
# Re-display the token embeddings used for the causal-attention example.
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [21]:
# Reuse sa_v2's projection layers to get queries and keys for the causal example.
query, key = sa_v2.W_query(inputs), sa_v2.W_key(inputs)

attn_score = torch.matmul(query, key.T)

attn_score

tensor([[ 0.2899,  0.0716,  0.0760, -0.0138,  0.1344, -0.0511],
        [ 0.4656,  0.1723,  0.1751,  0.0259,  0.1771,  0.0085],
        [ 0.4594,  0.1703,  0.1731,  0.0259,  0.1745,  0.0090],
        [ 0.2642,  0.1024,  0.1036,  0.0186,  0.0973,  0.0122],
        [ 0.2183,  0.0874,  0.0882,  0.0177,  0.0786,  0.0144],
        [ 0.3408,  0.1270,  0.1290,  0.0198,  0.1290,  0.0078]],
       grad_fn=<MmBackward0>)

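## Possible data leakage: softmax is computed over all positions, so the future tokens we mask below still contribute to each row's normalization (renormalizing the masked rows at the end removes this influence)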

In [23]:
# Scaled softmax over the last (key) dimension, matching the dim=-1 convention
# used by the other softmax calls in this notebook; same as dim=1 for a 2-D tensor.
attn_weight = torch.softmax(attn_score / key.shape[-1] ** 0.5, dim=-1)

In [24]:
# Weights before masking: every token still attends to the tokens after it.
attn_weight

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

In [25]:
# Number of tokens (rows of the score matrix).
context_length = attn_score.size(0)

In [26]:
# An all-ones (context_length x context_length) matrix: the starting point for the mask.
torch.ones((context_length, context_length))

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])

In [28]:
# Lower-triangular mask: 1 on and below the diagonal, 0 above it (future tokens).
simple_mask = torch.ones(context_length, context_length).tril()

In [29]:
# Zero out the weights for future positions (entries above the diagonal).
masked_simple = torch.mul(attn_weight, simple_mask)

In [30]:
# After masking, the rows no longer sum to 1.
masked_simple

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)

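### We cannot apply softmax to `masked_simple` here: the masked entries are 0, and since $e^0 = 1$, softmax still gives them non-zero weight (see the output below). Instead, we renormalize each row by dividing by its sum.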

In [31]:
# Re-applying softmax: the masked zeros become exp(0) = 1 and regain weight.
masked_simple.softmax(dim=1)

tensor([[0.1951, 0.1610, 0.1610, 0.1610, 0.1610, 0.1610],
        [0.1914, 0.1842, 0.1561, 0.1561, 0.1561, 0.1561],
        [0.1861, 0.1792, 0.1793, 0.1518, 0.1518, 0.1518],
        [0.1789, 0.1753, 0.1753, 0.1736, 0.1484, 0.1484],
        [0.1736, 0.1708, 0.1708, 0.1695, 0.1707, 0.1446],
        [0.1712, 0.1666, 0.1666, 0.1646, 0.1666, 0.1644]],
       grad_fn=<SoftmaxBackward0>)

In [34]:
# Per-row total of the remaining (unmasked) weights, kept 2-D for broadcasting.
row_sum = masked_simple.sum(dim=1, keepdim=True)
row_sum

tensor([[0.1921],
        [0.3700],
        [0.5357],
        [0.6775],
        [0.8415],
        [1.0000]], grad_fn=<SumBackward1>)

In [36]:
# Renormalize so every row of the masked weights sums to 1 again.
norm_masked_simple = torch.div(masked_simple, row_sum)

In [38]:
# Sanity check: each row should now sum to 1.
torch.sum(norm_masked_simple, dim=1)

tensor([1., 1., 1., 1., 1., 1.], grad_fn=<SumBackward1>)