In [7]:
import torch
from torch import nn

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [8]:
#query = inputs[1]
d_in = inputs.shape[1]
d_out = 2

## Attention Weights

In [9]:
torch.manual_seed(123)
weights_q = torch.nn.Parameter(torch.rand(d_in,d_out), requires_grad=False)
weights_k = torch.nn.Parameter(torch.rand(d_in,d_out), requires_grad=False)
weights_v = torch.nn.Parameter(torch.rand(d_in,d_out), requires_grad=False)

In [10]:
query = inputs @ weights_q
keys = inputs @ weights_k
value = inputs @ weights_v

In [11]:
attention_score = query @ keys.T
attention_weights = torch.softmax(attention_score, dim=-1)
attention_weights

tensor([[0.1484, 0.2285, 0.2217, 0.1301, 0.0883, 0.1831],
        [0.1401, 0.2507, 0.2406, 0.1157, 0.0687, 0.1842],
        [0.1406, 0.2496, 0.2397, 0.1164, 0.0696, 0.1841],
        [0.1548, 0.2130, 0.2083, 0.1394, 0.1047, 0.1799],
        [0.1577, 0.2067, 0.2028, 0.1428, 0.1122, 0.1777],
        [0.1494, 0.2267, 0.2202, 0.1310, 0.0901, 0.1825]])

In [12]:
context_vector = attention_weights @ value
context_vector

tensor([[0.3071, 0.8230],
        [0.3157, 0.8430],
        [0.3152, 0.8421],
        [0.3006, 0.8080],
        [0.2978, 0.8016],
        [0.3063, 0.8214]])

In [13]:
nn.Linear(d_in,d_out, bias=False)

Linear(in_features=3, out_features=2, bias=False)

## Self Attention

In [14]:
class SelfAttention(nn.Module):

    def __init__(self,d_in,d_out,qkv_bias=False):
        super().__init__()
        self.weights_Q = nn.Linear(d_in,d_out, bias=qkv_bias)
        self.weights_K = nn.Linear(d_in,d_out, bias=qkv_bias)
        self.weights_V = nn.Linear(d_in,d_out, bias=qkv_bias)

    def forward(self,x):
        query = self.weights_Q(x)
        key = self.weights_K(x)
        value = self.weights_V(x)

        attention_score = query @ key.T
        attention_weight = torch.softmax(attention_score/ (key.shape[1] ** 0.5), dim=-1)

        context_vector = attention_weight @ value

        return context_vector

torch.manual_seed(123)
sa = SelfAttention(3,2)
sa(inputs)

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)

In [15]:
mask = torch.triu(torch.ones([6,6]), diagonal=1)
print(torch.rand([6,6]).masked_fill(mask.bool(), -torch.inf))
mask.bool()[:4][:4]

tensor([[0.3821,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.2745, 0.6584,   -inf,   -inf,   -inf,   -inf],
        [0.9268, 0.7388, 0.7179,   -inf,   -inf,   -inf],
        [0.0772, 0.3565, 0.1479, 0.5331,   -inf,   -inf],
        [0.4545, 0.9737, 0.4606, 0.5159, 0.4220,   -inf],
        [0.9455, 0.8057, 0.6775, 0.6087, 0.6179, 0.6932]])


tensor([[False,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True],
        [False, False, False, False,  True,  True]])

## Causal Attention (Masked-Self Attention)

In [16]:
class CausalAttention(nn.Module):
    def __init__(self,d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.weights_Q = nn.Linear(d_in,d_out, bias = qkv_bias)
        self.weights_K = nn.Linear(d_in,d_out, bias = qkv_bias)
        self.weights_V = nn.Linear(d_in,d_out, bias = qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.triu(torch.ones(context_length,context_length), diagonal=1))
    
    def forward(self,x):
        batch_size, context_length, emb_dim = x.shape
        query = self.weights_Q(x)
        key = self.weights_K(x)
        value = self.weights_V(x)

        attention_score = query @ key.transpose(1,2)
        attention_score = attention_score.masked_fill(mask.bool()[:context_length,:context_length], -torch.inf)
        attention_weights = torch.softmax(attention_score/ key.shape[1] ** 2, dim = -1)

        attention_weights = self.dropout(attention_weights)
        context_vector = attention_weights @ value

        return context_vector

torch.manual_seed(123)
inputs = torch.rand([2,4,3])
context_length = inputs.shape[1]
ca = CausalAttention(d_in = 3,d_out = 2, context_length=context_length, dropout = 0.0)
ca(inputs)

tensor([[[-0.3325, -0.1223],
         [-0.5163, -0.1861],
         [-0.3971, -0.1450],
         [-0.4687, -0.1633]],

        [[-0.1982, -0.1163],
         [-0.3748, -0.1848],
         [-0.4641, -0.2301],
         [-0.5462, -0.2600]]], grad_fn=<UnsafeViewBackward0>)

## MultiHead Attention Wrapper

In [17]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in,d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias=qkv_bias) for _ in range(num_heads)]
        )

    def forward(self,x):
        return torch.cat([head(x) for head in self.heads], dim = -1)

torch.manual_seed(123)
inputs = torch.rand([2,6,3])
context_length = inputs.shape[1]
mha = MultiHeadAttentionWrapper(d_in = 3,d_out = 2, context_length=context_length, dropout = 0.0, num_heads = 2)
context_vector = mha(inputs)
print(context_vector)
print(context_vector.shape)

tensor([[[ 0.3863,  0.1636, -0.2143,  0.4306],
         [ 0.4723,  0.2268, -0.4003,  0.4802],
         [ 0.3628,  0.1752, -0.3094,  0.3714],
         [ 0.4285,  0.2032, -0.3585,  0.4283],
         [ 0.3775,  0.1839, -0.3268,  0.3890],
         [ 0.3842,  0.1951, -0.3605,  0.4010]],

        [[ 0.6047,  0.3282, -0.5738,  0.7392],
         [ 0.6680,  0.3543, -0.6236,  0.7846],
         [ 0.5888,  0.2982, -0.4944,  0.6919],
         [ 0.6332,  0.2833, -0.4200,  0.6942],
         [ 0.6835,  0.3021, -0.4513,  0.7311],
         [ 0.6988,  0.3034, -0.4423,  0.7439]]], grad_fn=<CatBackward0>)
torch.Size([2, 6, 4])


## MultiHead Atttention

In [18]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.d_out = d_out

        self.weights_Q = nn.Linear(d_in,d_out, bias=qkv_bias)
        self.weights_K = nn.Linear(d_in,d_out, bias=qkv_bias)
        self.weights_V = nn.Linear(d_in,d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out,d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length),diagonal=1))

    def forward(self,x):
        batch_size, context_length, emb_dim = x.shape

        query = self.weights_Q(x)
        key = self.weights_K(x)
        value = self.weights_V(x)

        query = query.view(batch_size, context_length, self.num_heads, self.head_dim)
        key = key.view(batch_size, context_length, self.num_heads, self.head_dim)
        value = value.view(batch_size, context_length, self.num_heads, self.head_dim)

        query = query.transpose(1,2)
        key = key.transpose(1,2)
        value = value.transpose(1,2)

        attention_score = query @ key.transpose(2,3)
        
        attention_score.masked_fill(self.mask.bool()[:context_length, :context_length], -torch.inf)

        attention_weight = torch.softmax(attention_score/key.shape[-1]**0.5, dim=-1)
        self.dropout(attention_weight)

        context_vector = (attention_weight @ value).transpose(1,2)

        context_vector = context_vector.contiguous().view(batch_size, context_length, self.d_out)
        context_vector = self.out_proj(context_vector)

        return context_vector


torch.manual_seed(123)

inputs = torch.rand([2,4,3])
context_length = inputs.shape[1]
mha = MultiHeadAttention(d_in=3, d_out=4, num_heads=2, dropout=0.0, context_length=context_length)
mha(inputs)

tensor([[[ 0.4856, -0.5293, -0.3424,  0.0471],
         [ 0.4893, -0.5318, -0.3472,  0.0396],
         [ 0.4861, -0.5299, -0.3431,  0.0458],
         [ 0.4882, -0.5310, -0.3457,  0.0420]],

        [[ 0.6483, -0.5114, -0.4756, -0.0350],
         [ 0.6530, -0.5127, -0.4810, -0.0411],
         [ 0.6525, -0.5113, -0.4799, -0.0383],
         [ 0.6543, -0.5116, -0.4819, -0.0403]]], grad_fn=<ViewBackward0>)