In [1]:
# A compact self-attention class
!pip install torch



In [2]:
import torch
import torch.nn as nn

In [3]:
class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention with trainable projections.

    The input is projected into queries, keys and values by three learnable
    matrices; every row of the input then attends to every row (no masking).

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Size of the last dimension of the queries, keys, values and
            of the returned context vectors.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Weights are drawn from U[0, 1) via torch.rand; seed the global RNG
        # before constructing the module to get repeatable values.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute the context vectors for ``x``.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)`` or, with leading batch
                dimensions, ``(..., num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(..., num_tokens, d_out)``.
        """
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        # transpose(-2, -1) swaps only the last two axes: it equals ``keys.T``
        # for 2-D input and also works with leading batch dimensions, where
        # ``.T`` would reverse every axis.
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before the row-wise softmax.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec

In [4]:
# Layer sizes used throughout the demo:
#   d_in  - number of features per input row
#   d_out - width of the projected queries/keys/values
d_in, d_out = 4, 6

In [6]:
# Sample input of shape [num_tokens, d_in]: each row is one token embedding,
# and the rows attend to one another in the attention computation.
# Seed first so these random values (and everything computed from them) are
# reproduced on Restart & Run All.
torch.manual_seed(123)
inputs = torch.randn(3, d_in)  # 3 tokens, each with d_in features

In [7]:
# Fix the RNG state so the randomly initialised projection weights are
# repeatable, then pass the input rows through the attention layer.
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in=d_in, d_out=d_out)
print(sa_v1(inputs))

tensor([[ 0.1996,  0.3756,  0.5111, -1.0476, -0.9225, -0.2788],
        [ 0.1823,  0.4863,  0.4892, -1.2474, -1.0161, -0.2384],
        [ 2.0155,  1.0874,  0.9103,  3.2482,  2.5879,  1.5314]],
       grad_fn=<MmBackward0>)


## Applying a causal attention mask

In [9]:
# Recompute the unmasked attention weights step by step, using the
# projection matrices owned by sa_v1.
values = torch.matmul(inputs, sa_v1.W_value)  # not used in this cell
keys = torch.matmul(inputs, sa_v1.W_key)
queries = torch.matmul(inputs, sa_v1.W_query)

# Pairwise query-key scores, scaled by sqrt(d_k) and normalised per row.
attn_scores = torch.matmul(queries, keys.T)
attn_weights = (attn_scores / keys.shape[-1] ** 0.5).softmax(dim=-1)
print(attn_weights)

tensor([[1.7554e-01, 8.1036e-01, 1.4092e-02],
        [1.0222e-01, 8.9574e-01, 2.0375e-03],
        [3.7166e-05, 2.5175e-07, 9.9996e-01]], grad_fn=<SoftmaxBackward0>)


In [10]:
# Lower-triangular matrix of ones: row i keeps columns 0..i and zeroes the rest.
context_length = attn_scores.size(0)
mask_simple = torch.ones(context_length, context_length).tril()
print(mask_simple)

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [1., 1., 1.]])


In [11]:
# Element-wise product with the mask zeroes every weight above the diagonal.
masked_simple = torch.mul(attn_weights, mask_simple)
print(masked_simple)

tensor([[1.7554e-01, 0.0000e+00, 0.0000e+00],
        [1.0222e-01, 8.9574e-01, 0.0000e+00],
        [3.7166e-05, 2.5175e-07, 9.9996e-01]], grad_fn=<MulBackward0>)


In [12]:
# Renormalise so that every row of the masked weights sums to 1 again.
row_sums = torch.sum(masked_simple, dim=-1, keepdim=True)
masked_simple_norm = torch.div(masked_simple, row_sums)
print(masked_simple_norm)

tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0243e-01, 8.9757e-01, 0.0000e+00],
        [3.7166e-05, 2.5175e-07, 9.9996e-01]], grad_fn=<DivBackward0>)


In [13]:
# Causal masking applied to the raw scores: entries strictly above the
# diagonal are replaced with -inf, so a following softmax gives them weight 0.
mask = torch.ones(context_length, context_length).triu(diagonal=1)
masked = attn_scores.masked_fill(mask.to(torch.bool), float("-inf"))
print(masked)

tensor([[ -0.4524,     -inf,     -inf],
        [ -1.2347,   4.0821,     -inf],
        [  1.4937, -10.7408,  26.4787]], grad_fn=<MaskedFillBackward0>)


In [14]:
# Softmax over the last dimension (each row), consistent with the other
# softmax calls in this notebook; the -inf entries receive zero weight.
# NOTE: this rebinds attn_weights, replacing the unmasked weights computed
# earlier with the causally masked ones.
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.0243e-01, 8.9757e-01, 0.0000e+00],
        [3.7166e-05, 2.5175e-07, 9.9996e-01]], grad_fn=<SoftmaxBackward0>)


In [15]:
# Dropout demo: with p=0.5 roughly half of the entries are zeroed and the
# survivors are scaled by 1 / (1 - p) = 2.
torch.manual_seed(123)
dropout = torch.nn.Dropout(p=0.5)
example = torch.ones((6, 6))
print(dropout(example))

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])


In [16]:
# Apply the same dropout layer to attn_weights (the causally masked weights
# when the cells are run in order). Re-seeding makes the random drop pattern
# repeatable; surviving weights are scaled by 2 because p=0.5.
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000e+00, 0.0000e+00, 0.0000e+00],
        [2.0485e-01, 1.7951e+00, 0.0000e+00],
        [0.0000e+00, 5.0351e-07, 0.0000e+00]], grad_fn=<MulBackward0>)
