In [1]:
import torch

  cpu = _conversion_method_template(device=torch.device("cpu"))


In [3]:
print(torch.cuda.is_available())

False


In [5]:
print("Torch version:", torch.__version__)
print("MPS available:", torch.backends.mps.is_available())
print("MPS built:", torch.backends.mps.is_built())
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")


Torch version: 2.7.1
MPS available: True
MPS built: True


In [None]:
x1 = torch.tensor([0.43,0.19,0.89])
x2 = torch.tensor([0.55,0.87,0.66]) 

Goal is to write 
1. Simple
2. Self (With trainable weights)
    2.1. Scaled dot proeduct attention
    2.2. Create a class
    Exercise 3.1
3.5 Causal 
    3.5.1 Causal attention mask
    3.5.2 Add dropout
    3.5.3. Causal attention class
4. Multi-head

In [12]:
inputs = torch.tensor(
    [   [0.43, 0.15, 0.89],     # your
        [0.55, 0.87, 0.66],     # journey
        [0.57, 0.85, 0.64],     # starts
        [0.22, 0.58, 0.33],     # with
        [0.77, 0.25, 0.10],     # one
        [0.05, 0.80, 0.55]]     # step
)

In [27]:
x2 = inputs[1]
x2.shape

torch.Size([3])

In [28]:
attention_scores_1 = torch.zeros(inputs.shape[0])
for i, x_i in enumerate(inputs): 
    attention_scores_1[i] = x2.dot(x_i)

attention_scores_1

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

In [31]:
attn_scores = inputs @  inputs.T # 6 x 3 3 x 6 = 6 x 6
attn_scores[1]

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

In [34]:
# Normalize using sum 
attn_weights = attn_scores / attn_scores.sum(dim=1, keepdim=True) 
attn_weights

tensor([[0.2241, 0.2140, 0.2113, 0.1066, 0.1026, 0.1415],
        [0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656],
        [0.1454, 0.2277, 0.2248, 0.1280, 0.1104, 0.1637],
        [0.1304, 0.2313, 0.2275, 0.1354, 0.0953, 0.1801],
        [0.1436, 0.2219, 0.2245, 0.1090, 0.2088, 0.0921],
        [0.1350, 0.2325, 0.2269, 0.1405, 0.0628, 0.2022]])

In [35]:
# Normalize using softmax
attn_weights = torch.softmax(attn_scores, dim=1)
attn_weights[1]

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])

In [65]:
# Calculate context vectors 
context_vecs = attn_weights @ inputs # 6 x 6, 6 x 3
context_vecs

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

In [70]:
# Implementing it with trainable weights
torch.manual_seed(789)
dim_in = 3
dim_out = 2
Wq = torch.randn(dim_in, dim_out, requires_grad=False)
Wk = torch.randn(dim_in, dim_out, requires_grad=False)
Wv = torch.randn(dim_in, dim_out, requires_grad=False)

q = inputs @ Wq # 6 x 3 @ 3 x 2 6 x 2 
k = inputs @ Wk
v = inputs @ Wv

attn_scores = q @ k.T # 6 x 2 @ 6 x 2 
attn_wts = torch.softmax(attn_scores / k.shape[-1] ** 0.5, dim=1) # 6 x 6

context_vecs = attn_wts @ v  #  6 x 6, 6 x 2 = 
k.shape[-1]

2

In [67]:
context_vecs

tensor([[-0.8073, -0.0946],
        [-0.7652, -0.0935],
        [-0.7697, -0.0930],
        [-0.7786, -0.0966],
        [-0.8164, -0.1016],
        [-0.7326, -0.1010]])

In [None]:
import torch.nn as nn
# ⚠️ What's the deal with not having bias in the initial layers
class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.Wq = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.Wk = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.Wv = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x): 
        q = self.Wq(x) # 6 x 3 @ 3 x 2 6 x 2 
        k = self.Wk(x)
        v = self.Wv(x)

        attn_scores = q @ k.T # 6 x 2 @ 6 x 2 
        attn_wts = torch.softmax(attn_scores / k.shape[-1] ** 0.5, dim=1) # 6 x 6

        context_vecs = attn_wts @ v  #  6 x 6, 6 x 2 = 
        return context_vecs

In [69]:
torch.manual_seed(789)
sa = SelfAttention(dim_in, dim_out)
sa.forward(inputs)

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)

In [78]:
# 3.5 Causal attention
context_length = 6
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
mask

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])

In [79]:
attn_scores


tensor([[-0.3296, -0.1134, -0.0996, -0.0581,  0.1766, -0.1964],
        [-0.9665, -0.2984, -0.2345, -0.2161,  0.9836, -0.8459],
        [-0.9228, -0.2862, -0.2259, -0.2047,  0.9224, -0.7980],
        [-0.6798, -0.2045, -0.1558, -0.1592,  0.7650, -0.6374],
        [ 0.1251,  0.0156, -0.0085,  0.0589, -0.4413,  0.2915],
        [-1.1313, -0.3323, -0.2457, -0.2758,  1.3832, -1.1245]])

In [80]:
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
masked

tensor([[-0.3296,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.9665, -0.2984,    -inf,    -inf,    -inf,    -inf],
        [-0.9228, -0.2862, -0.2259,    -inf,    -inf,    -inf],
        [-0.6798, -0.2045, -0.1558, -0.1592,    -inf,    -inf],
        [ 0.1251,  0.0156, -0.0085,  0.0589, -0.4413,    -inf],
        [-1.1313, -0.3323, -0.2457, -0.2758,  1.3832, -1.1245]])

In [83]:
test = torch.ones((context_length, context_length))
test += 1
print(test)
dropout = nn.Dropout(0.5)
print(dropout(test))

tensor([[2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2.],
        [2., 2., 2., 2., 2., 2.]])
tensor([[0., 0., 0., 4., 4., 0.],
        [4., 0., 0., 4., 0., 4.],
        [0., 0., 4., 0., 4., 0.],
        [4., 4., 0., 4., 4., 0.],
        [4., 0., 0., 0., 4., 4.],
        [4., 4., 4., 0., 0., 4.]])


In [94]:
class CausalAttention(nn.Module): 
    def __init__(self, d_in, d_out, context_length, dropout, qkv_biases=False):
        super().__init__()
        self.wq = nn.Linear(d_in, d_out, bias=qkv_biases)
        self.wk = nn.Linear(d_in, d_out, bias=qkv_biases)
        self.wv = nn.Linear(d_in, d_out, bias=qkv_biases)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask', 
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x): 
        b, num_tokens, d_in = x.shape
        q = self.wq(x) # b, T, d_out
        v = self.wv(x)
        k = self.wk(x)

        attn_scores = q @ k.transpose(1, 2)  # b t dout - b t dout = b t t 
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(attn_scores / k.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ v # b t t, b t c
        return context_vec



In [96]:
torch.manual_seed(123)
ca = CausalAttention(d_in=dim_in, d_out=dim_out, context_length=6, dropout=0.5)
inp = torch.stack((inputs, inputs), dim=0)
out = ca(inp)
out

tensor([[[-0.9038,  0.4432],
         [-0.4368,  0.2142],
         [-0.4849, -0.1341],
         [-0.5834,  0.0081],
         [-0.6219, -0.0526],
         [-0.1417, -0.0505]],

        [[ 0.0000,  0.0000],
         [-1.1749,  0.0116],
         [-0.7733,  0.0073],
         [-0.9140, -0.2769],
         [-0.7679, -0.0735],
         [-0.6749, -0.0984]]], grad_fn=<UnsafeViewBackward0>)

In [97]:
class MultiHeadAttentionWrapper(nn.Module): 
    def __init__(self, d_in, d_out, context_length, dropout, n_heads, qkv_biases=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_biases) for head in range(n_heads)]
        )
        
    def forward(self, x): 
        return torch.cat([head(x) for head in self.heads], dim=-1)


In [101]:
ma = MultiHeadAttentionWrapper(dim_in, dim_out, context_length, 0.5, n_heads=2)
out = ma.forward(inp)
out.shape

torch.Size([2, 6, 4])

In [142]:
class MultiHeadAttention(nn.Module): 
    def __init__(self, d_in, d_out, context_length, dropout, n_heads, qkv_biases=False):
        super().__init__()
        if d_out % n_heads != 0: 
            raise ValueError("Please make it divisible")
        
        self.n_heads = n_heads
        self.d_out = d_out
        self.head_dim = d_out // n_heads

        self.Wq = nn.Linear(d_in, d_out, bias=qkv_biases)
        self.Wk = nn.Linear(d_in, d_out, bias=qkv_biases)
        self.Wv = nn.Linear(d_in, d_out, bias=qkv_biases)

        self.dropout = nn.Dropout(dropout)

        self.register_buffer(
            'mask', 
            torch.triu(torch.ones((context_length, context_length)), diagonal=1)
        )

        self.out_proj = nn.Linear(d_out, d_out)

    def forward(self, x): 
        b, t, d_in = x.shape
                
        q = self.Wq(x) # b, t, d_out
        k = self.Wk(x)
        v = self.Wv(x) # b, t, c

        q = q.view(b, t, self.n_heads, self.head_dim) # b, t, n_head, head_dim
        k = k.view(b, t, self.n_heads, self.head_dim)
        v = v.view(b, t, self.n_heads, self.head_dim)

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2) # b, n_head, t, head_dim
        print(q.shape, k.shape, v.shape)

        attn_scores = q @ k.transpose(2, 3) # b, n_head, t, head_dim @ b, n_head, head_dim, t
        attn_scores.masked_fill_(
            mask.bool()[:t, :t], -torch.inf
        )
        attn_weights = torch.softmax(attn_scores, dim=-1)

        context_vec = (attn_weights @ v).transpose(1,2) # b, n_head, t, t @ b, n_head, t, head_dim = b, n_head, t, head_dim
        print(context_vec.shape, self.d_out)
        context_vec = context_vec.contiguous().view(b, t, self.d_out)

        context_vec = self.out_proj(context_vec)
        return context_vec

In [145]:
dim_out = 2

In [146]:
torch.manual_seed(123)
mha = MultiHeadAttention(dim_in, dim_out, context_length, 0.5, 2, qkv_biases=False)
out = mha(inp)
out.shape

torch.Size([2, 2, 6, 1]) torch.Size([2, 2, 6, 1]) torch.Size([2, 2, 6, 1])
torch.Size([2, 6, 2, 1]) 2


torch.Size([2, 6, 2])

In [147]:
out

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)