# Sample Data 

In [1]:
import torch
import tiktoken 
from torch.utils.data import DataLoader, Dataset

In [2]:
# Stand-in for the embedding of a single sentence: a random (5, 3) tensor,
# i.e. 5 rows with 3 features each.
# NOTE(review): no seed is set before this call, so the values differ on every fresh run.
inputs = torch.rand(5,3)
print(inputs)

tensor([[0.7621, 0.9305, 0.4437],
        [0.9226, 0.6790, 0.3216],
        [0.8928, 0.1642, 0.7812],
        [0.1020, 0.6221, 0.3596],
        [0.7363, 0.5734, 0.9516]])


## Simple Self-Attention Mechanism

In [3]:
# Raw attention scores: dot product between every pair of rows of `inputs`,
# giving a (5, 5) matrix.
attention_score = inputs @ inputs.T

In [4]:
# Inspect the score matrix: one row and one column per row of `inputs`.
print(attention_score.shape)
print(attention_score)

torch.Size([5, 5])
tensor([[1.6435, 1.4776, 1.1798, 0.8162, 1.5169],
        [1.4776, 1.4156, 1.1864, 0.6322, 1.3747],
        [1.1798, 1.1864, 1.4343, 0.4741, 1.4949],
        [0.8162, 0.6322, 0.4741, 0.5267, 0.7740],
        [1.5169, 1.3747, 1.4949, 0.7740, 1.7765]])


In [5]:
# Softmax over dim=1 turns each row of scores into weights that sum to 1.
attention_weights = torch.softmax(attention_score,dim=1)

In [6]:
# Display the normalized attention weights.
attention_weights

tensor([[0.2636, 0.2233, 0.1658, 0.1152, 0.2322],
        [0.2488, 0.2339, 0.1860, 0.1068, 0.2245],
        [0.1938, 0.1951, 0.2499, 0.0957, 0.2656],
        [0.2353, 0.1958, 0.1671, 0.1762, 0.2256],
        [0.2166, 0.1878, 0.2118, 0.1030, 0.2807]])

In [7]:
# Sanity check: summing along dim=1 should give 1 for every row.
torch.sum(attention_weights,dim=1,keepdim=True)

tensor([[1.0000],
        [1.0000],
        [1.0000],
        [1.0000],
        [1.0000]])

In [8]:
# Each context vector is a weighted sum of the rows of `inputs`,
# using one row of `attention_weights` as the weights.
context_vector = attention_weights @ inputs

In [9]:
# Display the context vectors; (5, 5) @ (5, 3) gives the same shape as `inputs`.
context_vector

tensor([[0.7375, 0.6289, 0.5806],
        [0.7476, 0.6161, 0.5829],
        [0.7561, 0.5656, 0.6311],
        [0.6932, 0.6183, 0.5760],
        [0.7447, 0.5889, 0.6262]])

## Simple Self-Attention with Trainable Weights

We use three trainable weight matrices: query, key, and value.

The projections can change the feature dimension from d_in to any d_out we choose.

    d_in = inputs.shape[-1] and d_out = the desired output size (e.g. 128; the code below uses 8)
    
W_query.shape = (d_in x d_out)

W_key.shape = (d_in x d_out)

W_value.shape = (d_in x d_out)

In [10]:
# Input feature size is read from the data; the projection size is chosen freely (8 here).
d_in = inputs.shape[-1]
d_out = 8

In [24]:
# Seed the RNG so the three weight matrices below are reproducible.
torch.manual_seed(123)
# Each matrix has shape (d_in, d_out); wrapping in Parameter marks it as trainable.
W_query = torch.nn.Parameter(torch.rand(d_in,d_out))
W_key = torch.nn.Parameter(torch.rand(d_in,d_out))
W_value = torch.nn.Parameter(torch.rand(d_in,d_out))

In [25]:
# Confirm the weight shape is (d_in, d_out).
W_query.shape

torch.Size([3, 8])

In [26]:
# Project every row of `inputs` with the query weights: (5, d_in) @ (d_in, d_out).
queries = inputs @ W_query
queries

tensor([[0.4495, 1.4367, 0.6547, 1.4572, 0.5055, 1.1065, 0.6810, 0.8879],
        [0.4363, 1.2359, 0.5691, 1.3142, 0.3941, 1.1237, 0.5456, 0.6834],
        [0.3872, 1.2268, 0.5749, 1.2435, 0.7453, 1.2693, 0.6713, 0.9251],
        [0.1873, 0.8021, 0.3592, 0.7352, 0.3615, 0.4240, 0.4397, 0.6137],
        [0.4364, 1.5842, 0.7297, 1.5295, 0.9101, 1.3152, 0.8879, 1.2409]],
       grad_fn=<MmBackward0>)

In [27]:
# Project every row of `inputs` with the key weights: (5, d_in) @ (d_in, d_out).
keys = inputs @ W_key
keys

tensor([[1.0576, 1.2614, 1.2652, 1.4892, 0.9615, 0.5904, 1.0312, 1.3157],
        [0.8715, 1.1612, 1.0240, 1.3988, 1.0302, 0.4440, 1.0912, 1.2296],
        [0.6806, 0.8848, 0.7532, 1.5973, 1.1754, 0.4964, 1.1814, 1.1991],
        [0.6208, 0.5896, 0.7614, 0.7075, 0.3053, 0.4113, 0.3383, 0.6150],
        [1.0007, 1.1101, 1.1619, 1.8067, 1.1447, 0.7241, 1.1688, 1.4002]],
       grad_fn=<MmBackward0>)

In [28]:
# Project every row of `inputs` with the value weights: (5, d_in) @ (d_in, d_out).
values = inputs @ W_value
values

tensor([[1.0316, 1.7072, 1.1795, 0.7497, 1.0151, 1.7037, 0.7740, 0.5364],
        [1.0986, 1.5399, 1.1081, 0.7692, 0.9671, 1.4956, 0.7223, 0.4026],
        [1.1105, 1.2778, 0.9913, 0.9067, 0.6565, 1.3423, 1.1267, 0.6877],
        [0.3233, 0.8458, 0.5366, 0.2812, 0.4277, 0.9028, 0.3991, 0.3949],
        [1.0919, 1.6196, 1.1675, 0.9261, 0.7988, 1.7326, 1.2298, 0.8863]],
       grad_fn=<MmBackward0>)

In [29]:
# Score every query against every key.
# NOTE: this rebinds `attention_score`, replacing the unprojected scores computed earlier.
attention_score = queries @ keys.T
attention_score

tensor([[ 8.2958,  7.6157,  7.4103,  4.0416,  8.8571],
        [ 7.2019,  6.5772,  6.4034,  3.5502,  7.7037],
        [ 7.9116,  7.2915,  7.1765,  3.8269,  8.5160],
        [ 4.6181,  4.2860,  4.1728,  2.1939,  4.9176],
        [ 9.8606,  9.1228,  8.9507,  4.7251, 10.5758]], grad_fn=<MmBackward0>)

In [30]:
# Size of the key vectors, used to scale the scores.
d_k = keys.shape[-1]

# Divide the scores by sqrt(d_k) before the softmax; dim=1 normalizes each row.
attention_weight = torch.softmax((attention_score/d_k**0.5),dim=1)
attention_weight

tensor([[0.2526, 0.1986, 0.1847, 0.0561, 0.3080],
        [0.2484, 0.1992, 0.1873, 0.0683, 0.2967],
        [0.2470, 0.1984, 0.1905, 0.0583, 0.3059],
        [0.2337, 0.2078, 0.1996, 0.0992, 0.2598],
        [0.2534, 0.1952, 0.1837, 0.0412, 0.3263]], grad_fn=<SoftmaxBackward0>)

In [31]:
# Weighted sum of the value rows.
# NOTE: this rebinds `context_vector` from the earlier unprojected section.
context_vector = attention_weight @ values
context_vector

tensor([[1.0383, 1.5193, 1.0908, 0.8106, 0.8398, 1.5596, 0.9482, 0.6376],
        [1.0292, 1.5086, 1.0825, 0.8033, 0.8341, 1.5484, 0.9394, 0.6322],
        [1.0371, 1.5152, 1.0884, 0.8101, 0.8369, 1.5557, 0.9485, 0.6374],
        [1.0067, 1.4785, 1.0602, 0.7844, 0.8191, 1.5164, 0.9149, 0.6157],
        [1.0496, 1.5315, 1.1006, 0.8206, 0.8450, 1.5731, 0.9620, 0.6464]],
       grad_fn=<MmBackward0>)

## Self-Attention class

In [36]:
class SelfAttention(torch.nn.Module):
    """Single-head scaled dot-product self-attention with trainable projections.

    Args:
        d_in: Size of the last dimension of the input.
        d_k: Output size of the query, key and value projections.
        qkv_bias: Whether the three linear projections use a bias term.
    """

    def __init__(self, d_in, d_k, qkv_bias=False):
        super().__init__()
        # Size the projections from the constructor argument `d_k`. Reading a
        # notebook-level `d_out` here would ignore the argument and fail with a
        # NameError on a kernel where that global has not been defined.
        self.W_query = torch.nn.Linear(d_in, d_k, bias=qkv_bias)
        self.W_key = torch.nn.Linear(d_in, d_k, bias=qkv_bias)
        self.W_value = torch.nn.Linear(d_in, d_k, bias=qkv_bias)

    def forward(self, x):
        """Return the context vectors for `x`.

        Args:
            x: Tensor whose last dimension is `d_in`, shaped (num_tokens, d_in)
                or (..., num_tokens, d_in).

        Returns:
            Tensor with the same leading dimensions as `x` and last dimension `d_k`.
        """
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Swapping the last two dims equals `.T` for 2-D input and also works
        # when leading batch dimensions are present.
        attention_score = queries @ keys.transpose(-2, -1)
        # Scale by sqrt of the key size before normalizing over the key axis.
        attention_weight = torch.softmax(attention_score / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attention_weight @ values
        return context_vec

In [37]:
# Re-display the inputs that are fed to the module below.
inputs

tensor([[0.7621, 0.9305, 0.4437],
        [0.9226, 0.6790, 0.3216],
        [0.8928, 0.1642, 0.7812],
        [0.1020, 0.6221, 0.3596],
        [0.7363, 0.5734, 0.9516]])

In [38]:
# Same sizes as in the manual walkthrough: feature size from the data, projection size 8.
d_in = inputs.shape[-1]
d_out = 8

# Seed before construction so the Linear layers get reproducible initial weights.
torch.manual_seed(123)
# `d_out` is passed positionally as the constructor's second argument.
ob = SelfAttention(d_in,d_out)
ob(inputs)

tensor([[ 0.6793,  0.2955, -0.5741,  0.0871,  0.0271, -0.3884, -0.4559,  0.6883],
        [ 0.6807,  0.2957, -0.5745,  0.0888,  0.0285, -0.3896, -0.4556,  0.6891],
        [ 0.6843,  0.2954, -0.5765,  0.0954,  0.0381, -0.3936, -0.4518,  0.6917],
        [ 0.6773,  0.2943, -0.5735,  0.0872,  0.0302, -0.3876, -0.4531,  0.6869],
        [ 0.6825,  0.2953, -0.5759,  0.0927,  0.0353, -0.3919, -0.4527,  0.6906]],
       grad_fn=<MmBackward0>)

## Causal Self-Attention

In [39]:
# Random batch of shape (4, 5, 3): 4 sequences, 5 tokens each, 3 features per token.
batch_input = torch.rand(4,5,3)
batch_input

tensor([[[0.1459, 0.0969, 0.7076],
         [0.5112, 0.7050, 0.0114],
         [0.4702, 0.8526, 0.7320],
         [0.5183, 0.5983, 0.4527],
         [0.2251, 0.3111, 0.1955]],

        [[0.9153, 0.7751, 0.6749],
         [0.1166, 0.8858, 0.6568],
         [0.8459, 0.3033, 0.6060],
         [0.9882, 0.8363, 0.9010],
         [0.3950, 0.8809, 0.1084]],

        [[0.5432, 0.2185, 0.3834],
         [0.3720, 0.5374, 0.9551],
         [0.7475, 0.4979, 0.8549],
         [0.2438, 0.7577, 0.4536],
         [0.4130, 0.5585, 0.1170]],

        [[0.5578, 0.6681, 0.9275],
         [0.3443, 0.6800, 0.9998],
         [0.2855, 0.9753, 0.2518],
         [0.7204, 0.6959, 0.6397],
         [0.8954, 0.2979, 0.6314]]])

In [53]:
class CasualSelfAttention(torch.nn.Module):
    """Single-head self-attention in which a token never attends to later tokens.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Output size of the query, key and value projections.
        dropout: Dropout probability applied to the attention weights.
        context_length: Side length of the stored mask, i.e. the longest
            sequence the mask covers.
        qkv_bias: Whether the three linear projections use a bias term.
    """

    def __init__(self, d_in, d_out, dropout, context_length, qkv_bias=False):
        super().__init__()
        # Projections are created in the order query, key, value.
        self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = torch.nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the "future" positions to hide.
        future_positions = torch.triu(
            torch.ones(context_length, context_length), diagonal=1
        )
        self.register_buffer("mask", future_positions)

    def forward(self, x):
        """Compute masked attention for `x` of shape (batch, num_tokens, d_in).

        Returns a tensor of shape (batch, num_tokens, d_out).
        """
        _, seq_len, _ = x.shape

        q = self.W_query(x)
        k = self.W_key(x)
        v = self.W_value(x)

        scores = torch.matmul(q, k.transpose(1, 2))

        # Crop the stored mask to the current sequence length and set hidden
        # positions to -inf so the softmax gives them zero weight.
        hidden = self.mask[:seq_len, :seq_len].bool()
        scores = scores.masked_fill(hidden, -torch.inf)

        scale = k.shape[-1] ** 0.5
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))

        return torch.matmul(weights, v)

In [54]:
# Mask size and feature size are read from the batch; dropout is disabled (0.0).
context_length = batch_input.shape[1]
dropout = 0.0
d_in = batch_input.shape[-1]
d_out = 8

# Seed before construction so the Linear layers get reproducible initial weights.
torch.manual_seed(123)
ca = CasualSelfAttention(d_in,d_out,dropout,context_length)

In [55]:
# Forward pass over the whole batch.
output = ca(batch_input)

In [56]:
# Display the context vectors: one d_out-sized vector per token per sequence.
output

tensor([[[ 0.2543,  0.1893, -0.3156, -0.2557, -0.3265, -0.0626, -0.4241,
           0.3383],
         [ 0.3759,  0.1769, -0.3647, -0.0093,  0.0024, -0.2020, -0.2818,
           0.4137],
         [ 0.4877,  0.2316, -0.4990, -0.0266,  0.0284, -0.2606, -0.3621,
           0.5537],
         [ 0.5084,  0.2354, -0.5043, -0.0032,  0.0435, -0.2778, -0.3616,
           0.5659],
         [ 0.4597,  0.2109, -0.4537,  0.0041,  0.0478, -0.2533, -0.3208,
           0.5099]],

        [[ 0.8828,  0.3709, -0.7252,  0.1622,  0.0849, -0.5184, -0.5523,
           0.8784],
         [ 0.6921,  0.3256, -0.7218, -0.0315,  0.0798, -0.3738, -0.4960,
           0.7940],
         [ 0.6895,  0.3127, -0.6290,  0.0328,  0.0173, -0.3819, -0.4906,
           0.7309],
         [ 0.7661,  0.3452, -0.6851,  0.0477,  0.0123, -0.4262, -0.5429,
           0.8028],
         [ 0.7174,  0.3149, -0.6558,  0.0682,  0.0845, -0.4089, -0.4702,
           0.7592]],

        [[ 0.4352,  0.1834, -0.2946,  0.0959, -0.0555, -0.2525, -0

# Multi-Head Causal Self-Attention

### Stacking multiple causal self-attention layers in parallel

In [91]:
class MultiHeadAttentionLayer(torch.nn.Module):
    """Multi-head attention built from independent `CasualSelfAttention` heads.

    Every head receives the same input; the per-head outputs are concatenated
    along the last dimension, so the result has `num_head * d_out` features.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Output size of each individual head.
        context_length: Mask size passed to every head.
        dropout: Dropout probability passed to every head.
        qkv_bias: Whether each head's projections use a bias term.
        num_head: Number of heads to create.
    """

    def __init__(self,d_in,d_out,context_length,dropout,qkv_bias=False,num_head=2):
        super().__init__()
        self.heads = torch.nn.ModuleList()
        for _ in range(num_head):
            self.heads.append(
                CasualSelfAttention(
                    d_in=d_in,
                    d_out=d_out,
                    context_length=context_length,
                    dropout=dropout,
                    qkv_bias=qkv_bias,
                )
            )

    def forward(self,x):
        """Run every head on `x` and join the results on the feature axis."""
        per_head = [head(x) for head in self.heads]
        return torch.cat(per_head, dim=-1)

In [92]:
# Mask size and feature size are read from the batch; dropout is disabled (0.0).
context_length = batch_input.shape[1]
dropout = 0.0
d_in = batch_input.shape[-1]
# Here d_out is the size of each individual head.
d_out = 8

# Seed before construction so all heads get reproducible initial weights.
torch.manual_seed(123)

# 12 heads of size 8 are concatenated, so the output has 96 features.
mha = MultiHeadAttentionLayer(d_in=d_in,d_out=d_out,context_length=context_length,dropout=dropout,num_head=12)

In [93]:
# Forward pass through all heads; rebinds `output` from the single-head run.
output = mha(batch_input)

In [94]:
# Last dimension is num_head * d_out = 12 * 8 = 96.
output.shape

torch.Size([4, 5, 96])

In [95]:
# Display the concatenated per-head context vectors.
output

tensor([[[ 0.2543,  0.1893, -0.3156,  ...,  0.3447,  0.4177,  0.2244],
         [ 0.3759,  0.1769, -0.3647,  ...,  0.4790,  0.4910,  0.0391],
         [ 0.4877,  0.2316, -0.4990,  ...,  0.6047,  0.6287,  0.0344],
         [ 0.5084,  0.2354, -0.5043,  ...,  0.6281,  0.6488,  0.0401],
         [ 0.4597,  0.2109, -0.4537,  ...,  0.5671,  0.5853,  0.0342]],

        [[ 0.8828,  0.3709, -0.7252,  ...,  1.0484,  1.0458,  0.1661],
         [ 0.6921,  0.3256, -0.7218,  ...,  0.8800,  0.9161,  0.0352],
         [ 0.6895,  0.3127, -0.6290,  ...,  0.8424,  0.8612,  0.1372],
         [ 0.7661,  0.3452, -0.6851,  ...,  0.9309,  0.9485,  0.1665],
         [ 0.7174,  0.3149, -0.6558,  ...,  0.8694,  0.8813,  0.0835]],

        [[ 0.4352,  0.1834, -0.2946,  ...,  0.5007,  0.4941,  0.1883],
         [ 0.5066,  0.2588, -0.4774,  ...,  0.6167,  0.6530,  0.1904],
         [ 0.5847,  0.2907, -0.5289,  ...,  0.7078,  0.7429,  0.2220],
         [ 0.5597,  0.2770, -0.5438,  ...,  0.6882,  0.7246,  0.1459],
  

### Multi-head attention with combined weight matrices

In [96]:
from torch import nn
class MultiHeadAttentionLayer(torch.nn.Module):
    """Multi-head causal self-attention using one fused projection per Q, K and V.

    The `d_out`-wide projections are split into `num_heads` slices of
    `d_out // num_heads` features. Attention is computed per slice, the slices
    are merged back, and a final linear layer mixes them.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Total output width across all heads; must be divisible by `num_heads`.
        context_length: Side length of the stored causal mask.
        dropout: Dropout probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the Q/K/V projections use a bias term. The output
            projection is created with `nn.Linear`'s default bias setting.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        # Width of each head's slice of the fused projection.
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the positions a token may not see.
        upper = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.register_buffer("mask", upper)

    def forward(self, x):
        """Compute attention for `x` of shape (b, num_tokens, d_in).

        Returns a tensor of shape (b, num_tokens, d_out).
        """
        b, num_tokens, _ = x.shape

        def split_heads(t):
            # (b, num_tokens, d_out) -> (b, num_heads, num_tokens, head_dim)
            return t.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        queries = split_heads(self.W_query(x))
        keys = split_heads(self.W_key(x))
        values = split_heads(self.W_value(x))

        # Per-head scores: (b, num_heads, num_tokens, num_tokens)
        scores = torch.matmul(queries, keys.transpose(2, 3))

        # Crop the mask to the current length; it broadcasts over batch and heads.
        hidden = self.mask[:num_tokens, :num_tokens].bool()
        scores = scores.masked_fill(hidden, -torch.inf)

        weights = torch.softmax(scores / self.head_dim ** 0.5, dim=-1)
        weights = self.dropout(weights)

        # Back to (b, num_tokens, num_heads, head_dim), then merge the heads.
        merged = torch.matmul(weights, values).transpose(1, 2).contiguous()
        merged = merged.view(b, num_tokens, self.d_out)

        return self.out_proj(merged)


In [97]:
# Mask size and feature size are read from the batch; dropout is disabled (0.0).
context_length = batch_input.shape[1]
dropout = 0.0
d_in = batch_input.shape[-1]
num_heads = 12
# Here d_out is the total width across all heads: 8 per head * 12 heads = 96.
d_out = 8*num_heads


# Seed before construction so the Linear layers get reproducible initial weights.
torch.manual_seed(123)

mha = MultiHeadAttentionLayer(d_in=d_in,d_out=d_out,context_length=context_length,dropout=dropout,num_heads=num_heads)

In [98]:
# Forward pass through the fused multi-head layer; rebinds `output`.
output = mha(batch_input)

In [99]:
# Last dimension equals d_out (96).
output.shape

torch.Size([4, 5, 96])

In [100]:
# Display the output of the final projection.
output

tensor([[[ 0.1579,  0.2051, -0.2036,  ...,  0.2524,  0.1098,  0.0859],
         [ 0.0522,  0.1414, -0.2023,  ...,  0.0997,  0.1077,  0.1170],
         [ 0.0709,  0.1912, -0.2327,  ...,  0.1151,  0.1214,  0.1673],
         [ 0.0653,  0.2020, -0.2400,  ...,  0.1119,  0.1164,  0.1710],
         [ 0.0529,  0.1757, -0.2272,  ...,  0.1049,  0.1070,  0.1511]],

        [[ 0.0944,  0.4150, -0.3520,  ...,  0.1285,  0.1355,  0.2875],
         [ 0.0947,  0.2966, -0.3042,  ...,  0.1338,  0.1250,  0.2499],
         [ 0.1032,  0.3218, -0.3034,  ...,  0.1455,  0.1265,  0.2263],
         [ 0.1185,  0.3665, -0.3292,  ...,  0.1536,  0.1321,  0.2522],
         [ 0.0855,  0.3096, -0.3079,  ...,  0.1134,  0.1246,  0.2384]],

        [[ 0.0641,  0.2157, -0.2313,  ...,  0.1242,  0.0978,  0.1131],
         [ 0.1239,  0.2810, -0.2597,  ...,  0.1950,  0.1219,  0.1643],
         [ 0.1373,  0.3280, -0.2832,  ...,  0.2037,  0.1272,  0.1897],
         [ 0.1173,  0.2820, -0.2719,  ...,  0.1741,  0.1247,  0.1872],
  