In [1]:
from importlib.metadata import version
import torch,torch.nn as nn
print(f"Torch version: {version('torch')}")

Torch version: 2.9.1


# Attending To Different Parts Of Input With Self-Attention
## Simple Self-Attention Mechanism Without Trainable Weights

In [2]:
inputs=torch.tensor([[.43,.15,.89],
                     [.55,.87,.66],
                     [.57,.85,.64],
                     [.22,.58,.33],
                     [.77,.25,.1],
                     [.05,.8,.55]])
query=inputs[1]
attn_scores_2=torch.empty(inputs.shape[0])
for i,x_i in enumerate(inputs):
    attn_scores_2[i]=torch.dot(x_i,query)
attn_scores_2

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

In [3]:
res=0
for idx,element in enumerate(inputs[0]):
    res+=inputs[0][idx]*query[idx]
res

tensor(0.9544)

In [4]:
attn_weights_2_tmp=attn_scores_2/attn_scores_2.sum()
print(f'Attention weights: {attn_weights_2_tmp}\nSum: {attn_weights_2_tmp.sum()}')

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: 1.0000001192092896


In [5]:
def softmax_naive(x):
    return torch.exp(x)/torch.exp(x).sum(dim=0)
attn_weights_2_naive=softmax_naive(attn_scores_2)
print(f'Attention weights: {attn_weights_2_naive}\nSum: {attn_weights_2_naive.sum()}')

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: 1.0


In [6]:
attn_weights_2=torch.softmax(attn_scores_2,dim=0)
query=inputs[1]
context_vec_2=torch.zeros(query.shape)
for i,x_i in enumerate(inputs):
    context_vec_2+=attn_weights_2[i]*x_i
context_vec_2

tensor([0.4419, 0.6515, 0.5683])

## Computing Attention Weights For All Input Tokens

In [7]:
attn_scores=torch.empty(6,6)
attn_scores=inputs@inputs.T
attn_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [8]:
attn_weights=torch.softmax(attn_scores,dim=-1)
attn_weights

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [9]:
row_2_sum=sum([.1385,.2379,.2333,.124,.1082,.1581])
print(f'Row 2 sum: {row_2_sum}\nAll row sums: {attn_weights.sum(dim=-1)}')

Row 2 sum: 1.0
All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


In [10]:
all_context_vecs=attn_weights@inputs
print(f'{all_context_vecs}\nPrevious 2nd context vector: {context_vec_2}')

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
Previous 2nd context vector: tensor([0.4419, 0.6515, 0.5683])


# Implementing Self-Attention With Trainable Weights
## Computing Attention Weights Step By Step

In [11]:
x_2=inputs[1]
d_in=inputs.shape[1]
d_out=2
W_query=torch.nn.Parameter(torch.rand(d_in,d_out),requires_grad=False)
W_key=torch.nn.Parameter(torch.rand(d_in,d_out),requires_grad=False)
W_value=torch.nn.Parameter(torch.rand(d_in,d_out),requires_grad=False)
query_2=x_2@W_query
key_2=x_2@W_key 
value_2=x_2@W_value
query_2

tensor([1.1910, 0.9843])

In [12]:
keys=inputs@W_key 
values=inputs@W_value
print(f'keys.shape: {keys.shape}\nvalues.shape: {values.shape}')

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


In [13]:
keys_2=keys[1]
attn_score_22=query_2.dot(keys_2)
attn_score_22

tensor(1.5368)

In [14]:
attn_scores_2=query_2@keys.T
attn_scores_2

tensor([0.9034, 1.5368, 1.5338, 0.8304, 1.0483, 0.9155])

In [15]:
d_k=keys.shape[1]
attn_weights_2=torch.softmax(attn_scores_2/d_k**.5,dim=-1)
attn_weights_2

tensor([0.1390, 0.2176, 0.2171, 0.1320, 0.1540, 0.1402])

In [16]:
context_vec_2=attn_weights_2@values
context_vec_2

tensor([1.0064, 1.0248])

In [17]:
class SelfAttention_v1(nn.Module):
    def __init__(self,d_in,d_out):
        super().__init__()
        self.W_query=nn.Parameter(torch.rand(d_in,d_out))
        self.W_key=nn.Parameter(torch.rand(d_in,d_out))
        self.W_value=nn.Parameter(torch.rand(d_in,d_out))
    def forward(self,x):
        keys=x@self.W_key
        queries=x@self.W_query
        values=x@self.W_value
        attn_scores=queries@keys.T
        attn_weights=torch.softmax(attn_scores/keys.shape[-1]**.5,dim=-1)
        context_vec=attn_weights@values
        return context_vec
sa_v1=SelfAttention_v1(d_in,d_out)
sa_v1(inputs)

tensor([[1.0648, 0.9735],
        [1.0831, 0.9858],
        [1.0840, 0.9864],
        [1.0507, 0.9568],
        [1.0812, 0.9812],
        [1.0449, 0.9533]], grad_fn=<MmBackward0>)

In [18]:
class SelfAttention_v2(nn.Module):
    def __init__(self,d_in,d_out,qkv_bias=False):
        super().__init__()
        self.W_query=nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_key=nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_value=nn.Linear(d_in,d_out,bias=qkv_bias)
    def forward(self,x):
        keys=self.W_key(x)
        queries=self.W_query(x)
        values=self.W_value(x)
        attn_scores=queries@keys.T
        attn_weights=torch.softmax(attn_scores/keys.shape[-1]**.5,dim=-1)
        context_vec=attn_weights@values
        return context_vec
sa_v2=SelfAttention_v2(d_in,d_out)
sa_v2(inputs)

tensor([[0.0962, 0.1514],
        [0.0965, 0.1543],
        [0.0965, 0.1542],
        [0.0962, 0.1532],
        [0.0961, 0.1510],
        [0.0964, 0.1545]], grad_fn=<MmBackward0>)

# Hiding Future Words With Causal Attention
## Applying Causal Attention Mask

In [19]:
queries=sa_v2.W_query(inputs)
keys=sa_v2.W_key(inputs) 
attn_scores=queries@keys.T
attn_weights=torch.softmax(attn_scores/keys.shape[-1]**.5,dim=-1)
attn_weights

tensor([[0.1658, 0.1717, 0.1716, 0.1624, 0.1645, 0.1640],
        [0.1740, 0.1693, 0.1693, 0.1612, 0.1637, 0.1626],
        [0.1736, 0.1695, 0.1694, 0.1612, 0.1637, 0.1626],
        [0.1724, 0.1673, 0.1672, 0.1637, 0.1650, 0.1644],
        [0.1654, 0.1706, 0.1705, 0.1636, 0.1651, 0.1648],
        [0.1760, 0.1664, 0.1663, 0.1630, 0.1646, 0.1636]],
       grad_fn=<SoftmaxBackward0>)

In [20]:
context_length=attn_scores.shape[0]
mask_simple=torch.tril(torch.ones(context_length,context_length))
mask_simple

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [21]:
masked_simple=attn_weights*mask_simple
masked_simple

tensor([[0.1658, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1740, 0.1693, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1736, 0.1695, 0.1694, 0.0000, 0.0000, 0.0000],
        [0.1724, 0.1673, 0.1672, 0.1637, 0.0000, 0.0000],
        [0.1654, 0.1706, 0.1705, 0.1636, 0.1651, 0.0000],
        [0.1760, 0.1664, 0.1663, 0.1630, 0.1646, 0.1636]],
       grad_fn=<MulBackward0>)

In [22]:
row_sums=masked_simple.sum(dim=-1,
                           keepdim=True)
masked_simple_norm=masked_simple/row_sums
masked_simple_norm

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5068, 0.4932, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3388, 0.3307, 0.3305, 0.0000, 0.0000, 0.0000],
        [0.2571, 0.2494, 0.2494, 0.2442, 0.0000, 0.0000],
        [0.1980, 0.2043, 0.2042, 0.1959, 0.1977, 0.0000],
        [0.1760, 0.1664, 0.1663, 0.1630, 0.1646, 0.1636]],
       grad_fn=<DivBackward0>)

In [23]:
mask=torch.triu(torch.ones(context_length,context_length),diagonal=1)
masked=attn_scores.masked_fill(mask.bool(),-torch.inf)
masked

tensor([[0.1170,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.1624, 0.1241,   -inf,   -inf,   -inf,   -inf],
        [0.1606, 0.1262, 0.1255,   -inf,   -inf,   -inf],
        [0.0885, 0.0457, 0.0454, 0.0156,   -inf,   -inf],
        [0.0834, 0.1275, 0.1269, 0.0684, 0.0812,   -inf],
        [0.1116, 0.0316, 0.0313, 0.0030, 0.0166, 0.0082]],
       grad_fn=<MaskedFillBackward0>)

In [24]:
attn_weights=torch.softmax(masked/keys.shape[-1]**.5,dim=-1)
attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5068, 0.4932, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3388, 0.3307, 0.3305, 0.0000, 0.0000, 0.0000],
        [0.2571, 0.2494, 0.2494, 0.2442, 0.0000, 0.0000],
        [0.1980, 0.2043, 0.2042, 0.1959, 0.1977, 0.0000],
        [0.1760, 0.1664, 0.1663, 0.1630, 0.1646, 0.1636]],
       grad_fn=<SoftmaxBackward0>)

## Masking Additional Attention Weights With Dropout

In [25]:
dropout=torch.nn.Dropout(.5)
example=torch.ones(6,6)
dropout(example)

tensor([[0., 2., 2., 2., 0., 0.],
        [2., 2., 0., 0., 2., 2.],
        [2., 0., 2., 0., 2., 0.],
        [0., 2., 0., 0., 0., 0.],
        [2., 2., 2., 2., 2., 2.],
        [0., 0., 2., 0., 2., 0.]])

In [26]:
dropout(attn_weights)

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.0135, 0.9865, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6610, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4988, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3960, 0.0000, 0.4084, 0.3918, 0.3954, 0.0000],
        [0.3521, 0.3327, 0.3326, 0.0000, 0.3292, 0.3273]],
       grad_fn=<MulBackward0>)

## Implementing Compact Causal Self-Attention Class

In [27]:
batch=torch.stack((inputs,inputs),dim=0)
batch.shape

torch.Size([2, 6, 3])

In [28]:
class CausalAttention(nn.Module):
    def __init__(self,d_in,d_out,context_length,dropout,qkv_bias=False):
        super().__init__()
        self.d_out=d_out
        self.W_query=nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_key=nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_value=nn.Linear(d_in,d_out,bias=qkv_bias)
        self.dropout=nn.Dropout(dropout)
        self.register_buffer('mask',torch.triu(torch.ones(context_length,context_length),diagonal=1))
    def forward(self,x):
        b,num_tokens,d_in=x.shape
        keys=self.W_key(x)
        queries=self.W_query(x)
        values=self.W_value(x)
        attn_scores=queries@keys.transpose(1,2)
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens,
                                                  :num_tokens],-torch.inf)
        attn_weights=torch.softmax(attn_scores/keys.shape[-1]**.5,dim=-1)
        attn_weights=self.dropout(attn_weights)
        context_vec=attn_weights@values
        return context_vec
context_length=batch.shape[1]
ca=CausalAttention(d_in,d_out,context_length,0)
context_vecs=ca(batch)
print(f'{context_vecs}\ncontext_vecs.shape: {context_vecs.shape}')

tensor([[[-0.3282,  0.0936],
         [-0.4517,  0.1618],
         [-0.4805,  0.1857],
         [-0.4516,  0.1627],
         [-0.3364,  0.1961],
         [-0.3885,  0.1684]],

        [[-0.3282,  0.0936],
         [-0.4517,  0.1618],
         [-0.4805,  0.1857],
         [-0.4516,  0.1627],
         [-0.3364,  0.1961],
         [-0.3885,  0.1684]]], grad_fn=<UnsafeViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


# Extending Single To MHA
## Stacking Multiple Single-Head Attention Layers

In [29]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self,d_in,d_out,context_length,dropout,num_heads,qkv_bias=False):
        super().__init__()
        self.heads=nn.ModuleList([CausalAttention(d_in,d_out,context_length,dropout,qkv_bias) for _ in range(num_heads)])
    def forward(self,x):
        return torch.cat([head(x) for head in self.heads],dim=-1)
context_length=batch.shape[1]
d_in,d_out=3,2
mha=MultiHeadAttentionWrapper(d_in,d_out,context_length,0,num_heads=2)
context_vecs=mha(batch)
print(f'{context_vecs}\ncontext_vecs.shape: {context_vecs.shape}')

tensor([[[ 0.0101,  0.1371,  0.3000,  0.0332],
         [-0.1396, -0.0091,  0.1687,  0.1241],
         [-0.1905, -0.0499,  0.1168,  0.1547],
         [-0.1790, -0.0758,  0.0790,  0.1624],
         [-0.2179, -0.0168,  0.0341,  0.1156],
         [-0.1955, -0.0632,  0.0372,  0.1399]],

        [[ 0.0101,  0.1371,  0.3000,  0.0332],
         [-0.1396, -0.0091,  0.1687,  0.1241],
         [-0.1905, -0.0499,  0.1168,  0.1547],
         [-0.1790, -0.0758,  0.0790,  0.1624],
         [-0.2179, -0.0168,  0.0341,  0.1156],
         [-0.1955, -0.0632,  0.0372,  0.1399]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


## Implementing MHA With Weight Splits

In [30]:
class MultiHeadAttention(nn.Module):
    def __init__(self,d_in,d_out,context_length,dropout,num_heads,qkv_bias=False):
        super().__init__()
        assert (d_out%num_heads==0),'d_out must be divisible by num_heads.'
        self.d_out=d_out
        self.num_heads=num_heads
        self.head_dim=d_out//num_heads
        self.W_query=nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_key=nn.Linear(d_in,d_out,bias=qkv_bias)
        self.W_value=nn.Linear(d_in,d_out,bias=qkv_bias)
        self.out_proj=nn.Linear(d_out,d_out)
        self.dropout=nn.Dropout(dropout)
        self.register_buffer('mask',torch.triu(torch.ones(context_length,context_length),diagonal=1))
    def forward(self,x):
        b,num_tokens,d_in=x.shape
        keys=self.W_key(x)
        queries=self.W_query(x)
        values=self.W_value(x)
        keys=keys.view(b,num_tokens,self.num_heads,self.head_dim) 
        values=values.view(b,num_tokens,self.num_heads,self.head_dim)
        queries=queries.view(b,num_tokens,self.num_heads,self.head_dim)
        keys=keys.transpose(1,2)
        queries=queries.transpose(1,2)
        values=values.transpose(1,2)
        attn_scores=queries@keys.transpose(2,3)
        mask_bool=self.mask.bool()[:num_tokens,
                                   :num_tokens]
        attn_scores.masked_fill_(mask_bool,-torch.inf)
        attn_weights=torch.softmax(attn_scores/keys.shape[-1]**.5,dim=-1)
        attn_weights=self.dropout(attn_weights)
        context_vec=(attn_weights@values).transpose(1,2) 
        context_vec=context_vec.contiguous().view(b,num_tokens,self.d_out)
        context_vec=self.out_proj(context_vec)
        return context_vec
batch_size,context_length,d_in=batch.shape
d_out=2
mha=MultiHeadAttention(d_in,d_out,context_length,0,num_heads=2)
context_vecs=mha(batch)
print(f'{context_vecs}\ncontext_vecs.shape: {context_vecs.shape}')

tensor([[[ 0.1650,  0.3430],
         [-0.0802,  0.4443],
         [-0.1536,  0.4775],
         [-0.1298,  0.4528],
         [-0.0799,  0.4977],
         [-0.1025,  0.4644]],

        [[ 0.1650,  0.3430],
         [-0.0802,  0.4443],
         [-0.1536,  0.4775],
         [-0.1298,  0.4528],
         [-0.0799,  0.4977],
         [-0.1025,  0.4644]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


In [31]:
a=torch.tensor([[[[.2745,.6584,.2775,.8573],
                  [.8993,.039,.9268,.7388],
                  [.7179,.7058,.9156,.434]],
                 [[.0772,.3565,.1479,.5331],
                  [.4066,.2318,.4545,.9737],
                  [.4606,.5159,.422,.5786]]]])
a@a.transpose(2,3)

tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])

In [32]:
first_head=a[0,0,:,:]
first_res=first_head@first_head.T
print(f'1st head:\n{first_res}')

1st head:
tensor([[1.3208, 1.1631, 1.2879],
        [1.1631, 2.2150, 1.8424],
        [1.2879, 1.8424, 2.0402]])


In [33]:
second_head=a[0,1,:,:]
second_res=second_head@second_head.T
print(f'2nd head:\n{second_res}')

2nd head:
tensor([[0.4391, 0.7003, 0.5903],
        [0.7003, 1.3737, 1.0620],
        [0.5903, 1.0620, 0.9912]])


In [34]:
class SelfAttention_v1(nn.Module):
    def __init__(self,d_in,d_out):
        super().__init__()
        self.d_out=d_out
        self.W_query=nn.Parameter(torch.rand(d_in,d_out))
        self.W_key=nn.Parameter(torch.rand(d_in,d_out))
        self.W_value=nn.Parameter(torch.rand(d_in,d_out))
    def forward(self,x):
        keys=x@self.W_key
        queries=x@self.W_query
        values=x@self.W_value
        attn_scores=queries@keys.T
        attn_weights=torch.softmax(attn_scores/keys.shape[-1]**.5,dim=-1)
        context_vec=attn_weights@values
        return context_vec
sa_v1=SelfAttention_v1(d_in,d_out)
sa_v1.W_query=torch.nn.Parameter(sa_v2.W_query.weight.T)
sa_v1.W_key=torch.nn.Parameter(sa_v2.W_key.weight.T)
sa_v1.W_value=torch.nn.Parameter(sa_v2.W_value.weight.T)
sa_v1(inputs)

tensor([[0.0962, 0.1514],
        [0.0965, 0.1543],
        [0.0965, 0.1542],
        [0.0962, 0.1532],
        [0.0961, 0.1510],
        [0.0964, 0.1545]], grad_fn=<MmBackward0>)