In [49]:
import torch
import torch.nn as nn

In [50]:
# Seed torch's global RNG so the torch.rand inputs and nn.Linear initialisations
# in the following cells are reproducible when the notebook is run top to bottom.
SEED = 13
torch.manual_seed(SEED);  # trailing ';' hides the returned Generator repr in the cell output

<torch._C.Generator at 0x1e871a77470>

In [51]:
class SimpleSelfAttention(nn.Module):
    """Parameter-free self-attention: every token attends to every token using
    raw (unscaled) dot-product similarity, with no learned projections.

    Accepts a single sequence of shape ``(num_tokens, dim)`` or a batch of
    sequences of shape ``(..., num_tokens, dim)``; the output has the same
    shape as the input.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Swap only the last two axes. ``x.T`` reverses *all* axes, which is
        # only the matrix transpose for 2-D tensors (and is deprecated for
        # other ranks), so it broke batched inputs.
        attention_scores = x @ x.transpose(-2, -1)
        attention_weights = torch.softmax(attention_scores, dim = -1)
        context_vectors = attention_weights @ x

        return context_vectors

# Six random 3-dimensional token vectors fed through the parameter-free attention.
inputs = torch.rand(6, 3)
simple_attn = SimpleSelfAttention()

print("input vectors:", inputs)
print("context_vectors:", simple_attn(inputs))

input vectors: tensor([[0.0918, 0.4794, 0.8106],
        [0.0151, 0.0153, 0.6036],
        [0.2318, 0.8633, 0.9859],
        [0.1975, 0.0830, 0.4253],
        [0.9149, 0.4799, 0.5348],
        [0.2695, 0.2530, 0.3390]])
context_vectors: tensor([[0.2910, 0.4493, 0.6830],
        [0.2795, 0.3944, 0.6475],
        [0.3042, 0.4965, 0.7123],
        [0.2998, 0.3949, 0.6389],
        [0.3800, 0.4470, 0.6501],
        [0.3116, 0.4064, 0.6421]])


In [52]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with learned projections.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the query/key/value projections and of the output.
        bias: whether the three projections have a bias term.

    Works on a single sequence ``(num_tokens, d_in)`` or a batch
    ``(..., num_tokens, d_in)``; returns ``(..., num_tokens, d_out)``.
    """

    def __init__(self, d_in, d_out, bias = False):
        super().__init__()

        self.W_q = nn.Linear(d_in, d_out, bias)
        self.W_k = nn.Linear(d_in, d_out, bias)
        self.W_v = nn.Linear(d_in, d_out, bias)

    def forward(self, x):
        x_q = self.W_q(x)
        x_k = self.W_k(x)
        x_v = self.W_v(x)

        # Transpose only the token/feature axes so batched inputs work;
        # ``x_k.T`` would reverse every axis of a 3-D tensor.
        attn_scores = x_q @ x_k.transpose(-2, -1)
        # Scaled dot-product attention: divide by sqrt(d_out) before softmax.
        attn_weights = torch.softmax(attn_scores / x_k.shape[-1]**0.5, dim = -1)

        context_vectors = attn_weights @ x_v

        return context_vectors
    
# Draw the inputs first, then build the module: nn.Linear init also draws from
# the global RNG, so this order keeps results reproducible under the seed.
inputs = torch.rand(6, 3)
self_attn = SelfAttention(d_in=3, d_out=2)

print("input vectors:", inputs)
print("context_vectors:", self_attn(inputs))

input vectors: tensor([[0.8367, 0.1289, 0.9693],
        [0.4495, 0.4031, 0.8202],
        [0.9792, 0.3278, 0.4076],
        [0.7276, 0.4506, 0.2334],
        [0.0411, 0.2441, 0.6077],
        [0.8273, 0.9808, 0.2711]])
context_vectors: tensor([[0.0317, 0.2223],
        [0.0301, 0.2131],
        [0.0315, 0.2122],
        [0.0305, 0.2057],
        [0.0289, 0.2092],
        [0.0304, 0.1988]], grad_fn=<MmBackward0>)


In [53]:
class CasualAttention(nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention.

    "Casual" is a misspelling of "causal"; the class name is kept so existing
    callers keep working, and ``CausalAttention`` is provided as an alias.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the query/key/value projections and of the output.
        context_length: largest sequence length the causal mask covers.
        dropout: dropout probability applied to the attention weights.
        bias: whether the three projections have a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, bias = False):
        super().__init__()
        self.W_q = nn.Linear(d_in, d_out, bias)
        self.W_k = nn.Linear(d_in, d_out, bias)
        self.W_v = nn.Linear(d_in, d_out, bias)

        self.dropout = nn.Dropout(dropout)

        # Registered as a buffer so `.to(device)` / `.cuda()` move it together
        # with the weights; as a plain attribute it stayed on the CPU and
        # `masked_fill_` failed for GPU inputs. `persistent=False` keeps it
        # out of `state_dict()`, so previously saved checkpoints still load.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
            persistent=False,
        )

    def forward(self, x):
        # The unpack requires a 3-D input (batch, num_tokens, d_in).
        b, num_tokens, d_in = x.shape

        x_q = self.W_q(x)
        x_k = self.W_k(x)
        x_v = self.W_v(x)

        # (b, T, d_out) @ (b, d_out, T) -> (b, T, T) raw scores.
        attn_scores = x_q @ x_k.transpose(1, 2)
        # Hide future positions (ones above the diagonal); slicing lets
        # sequences shorter than context_length through.
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)

        attn_weights = torch.softmax(attn_scores / x_k.shape[-1] ** 0.5, dim = -1)

        attn_weights = self.dropout(attn_weights)

        context_vectors = attn_weights @ x_v

        return context_vectors


# Correctly spelled alias for new code.
CausalAttention = CasualAttention
    
# A batch of two sequences, each with six 3-dimensional tokens.
inputs = torch.rand(2, 6, 3)
cas_attn = CasualAttention(d_in=3, d_out=3, context_length=6, dropout=0)

print("input vectors:", inputs)
print("context_vectors:", cas_attn(inputs))

input vectors: tensor([[[0.6962, 0.5262, 0.2312],
         [0.6436, 0.1238, 0.6160],
         [0.0873, 0.3889, 0.3837],
         [0.3722, 0.0526, 0.0294],
         [0.3711, 0.5818, 0.8501],
         [0.9642, 0.6270, 0.0568]],

        [[0.0031, 0.2991, 0.2265],
         [0.5508, 0.3316, 0.1820],
         [0.0930, 0.6206, 0.9111],
         [0.2370, 0.0340, 0.9527],
         [0.1492, 0.7638, 0.5858],
         [0.8560, 0.2700, 0.7211]]])
context_vectors: tensor([[[ 0.5911, -0.5495,  0.0423],
         [ 0.5815, -0.4355,  0.0276],
         [ 0.5029, -0.3696,  0.1214],
         [ 0.4228, -0.3201,  0.0585],
         [ 0.4895, -0.3488,  0.1428],
         [ 0.5084, -0.4006,  0.0977]],

        [[ 0.2106, -0.1561,  0.2397],
         [ 0.3211, -0.2721,  0.1165],
         [ 0.4388, -0.3043,  0.2942],
         [ 0.4588, -0.2569,  0.2949],
         [ 0.4818, -0.2955,  0.3405],
         [ 0.5293, -0.3267,  0.2915]]], grad_fn=<UnsafeViewBackward0>)


In [54]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention followed by an output projection.

    The ``d_out``-wide projections are split into ``num_heads`` heads of
    width ``d_out // num_heads``. Each head runs causally masked, scaled
    dot-product attention; the heads are then concatenated back to
    ``d_out`` and passed through ``out_proj``.

    Args:
        d_in: size of the last dimension of the input.
        d_out: total projection width across all heads; must be divisible
            by ``num_heads``.
        context_length: largest sequence length the causal mask covers.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        bias: whether the query/key/value projections have a bias term
            (``out_proj`` uses the ``nn.Linear`` default).
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, bias = False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Keep this creation order (W_q, W_k, W_v, out_proj): each nn.Linear
        # draws its initial weights from the global RNG in sequence.
        self.W_q = nn.Linear(d_in, d_out, bias)
        self.W_k = nn.Linear(d_in, d_out, bias)
        self.W_v = nn.Linear(d_in, d_out, bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark "future" positions to hide.
        future_positions = torch.ones(context_length, context_length).triu(diagonal=1)
        self.register_buffer("mask", future_positions)

    def forward(self, x):
        batch, seq_len, _ = x.shape

        def split_heads(t):
            # (batch, seq, d_out) -> (batch, heads, seq, head_dim)
            return t.view(batch, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.W_q(x))
        k = split_heads(self.W_k(x))
        v = split_heads(self.W_v(x))

        hidden = self.mask.bool()[:seq_len, :seq_len]
        scores = (q @ k.transpose(2, 3)).masked_fill(hidden, -torch.inf)
        weights = self.dropout(torch.softmax(scores / k.shape[-1] ** 0.5, dim = -1))

        # (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, self.d_out)
        return self.out_proj(merged)
    
# RNG draws happen in this order: inputs, module init, then dropout in forward.
inputs = torch.rand(2, 6, 3)
mul_attn = MultiHeadAttention(d_in=3, d_out=4, context_length=6, dropout=0.3, num_heads=2)

print("input vectors:", inputs)

context_vecs = mul_attn(inputs)
print("context_vectors:", context_vecs.shape, "\n", context_vecs)

input vectors: tensor([[[0.6069, 0.6346, 0.2570],
         [0.4978, 0.5292, 0.6207],
         [0.6042, 0.9654, 0.2792],
         [0.7618, 0.5583, 0.0870],
         [0.7451, 0.8755, 0.9874],
         [0.6039, 0.3986, 0.2782]],

        [[0.8604, 0.1442, 0.1881],
         [0.9453, 0.4741, 0.7547],
         [0.5609, 0.3357, 0.5604],
         [0.7825, 0.4710, 0.4185],
         [0.3601, 0.0065, 0.6250],
         [0.3426, 0.0959, 0.9077]]])
context_vectors: torch.Size([2, 6, 4]) 
 tensor([[[ 0.3551,  0.2561,  0.5592, -0.1051],
         [ 0.3093, -0.0781,  0.6471, -0.0024],
         [ 0.2425, -0.0514,  0.5994,  0.0654],
         [ 0.3411, -0.0645,  0.6742, -0.0191],
         [ 0.3074, -0.1097,  0.6693,  0.0169],
         [ 0.2898,  0.0148,  0.5932, -0.0062]],

        [[ 0.3018,  0.1400,  0.4481, -0.1227],
         [ 0.2839,  0.1380,  0.4653, -0.0608],
         [ 0.3076, -0.1412,  0.5609, -0.0545],
         [ 0.2869, -0.0890,  0.5627, -0.0163],
         [ 0.2167,  0.0412,  0.5111,  0.0495],
 