In [1]:
import torch
import math

inputs = torch.tensor(
[[0.43, 0.15, 0.89], 
[0.55, 0.87, 0.66], 
[0.57, 0.85, 0.64], 
[0.22, 0.58, 0.33],
[0.77, 0.25, 0.10], 
[0.05, 0.80, 0.55]] 
)

query = inputs[1]

In [2]:
attention_scores = torch.empty(inputs.shape[0])
for i in range(inputs.shape[0]):
    attention_scores[i] = torch.dot(query, inputs[i])

In [3]:
attention_scores

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

In [4]:
attention_weights = inputs @ inputs.T

attention_weights = torch.softmax(attention_weights, dim=1)

context_scores = attention_weights @ inputs


In [5]:
context_scores

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

In [6]:
attention_weights

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [7]:
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [8]:
import torch.nn as nn
class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        #self.d_out = d_out
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_keys = nn.Parameter(torch.rand(d_in, d_out))
        self.W_values = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_keys 
        querys = x @ self.W_query
        values = x @ self.W_values

        attention_scores = querys @ keys.T
        attention_weights = torch.softmax(
            attention_scores / keys.shape[-1]**0.5, dim=-1)
        
        context_vec = attention_weights @ values
        return context_vec
    
torch.manual_seed(123)

d_in = inputs.shape[1]
d_out = 2

sa_v1 = SelfAttention(3, 2)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/18.webp" width="1000px">

In [9]:
batch = torch.stack((inputs, inputs), dim=0)

In [10]:
batch

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])

In [11]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, num_heads, dropout=0):
        super().__init__()
        self.d_out = d_out
        self.num_heads = num_heads
        self.W_keys = nn.Linear(d_in, d_out)
        self.W_query = nn.Linear(d_in, d_out)
        self.W_values = nn.Linear(d_in, d_out)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_out, d_out)

        self.head_dim = d_out // num_heads

        self.register_buffer('mask', torch.triu(torch.ones(context_lengt, context_lengt), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        
        keys = self.W_keys(x)
        queries = self.W_query(x)
        values = self.W_values(x)

        keys = keys.view(b, self.num_heads, num_tokens, self.head_dim)
        queries = queries.view(b, self.num_heads, num_tokens, self.head_dim)
        values = values.view(b, self.num_heads, num_tokens, self.head_dim)

        attention_scores = queries @ keys.transpose(2, 3)
        
        mask = self.mask.bool()[:num_tokens, :num_tokens]
        attention_scores.masked_fill_(mask, float('-inf'))

        attention_weights = torch.softmax(attention_scores / math.sqrt(self.head_dim), dim=-1)
        attention_weights = self.dropout(attention_weights)

        context_vec = attention_weights @ values
        context_vec = context_vec.view(b, num_tokens, self.d_out)

        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec

    
b, context_lengt, d_in = batch.shape

model = MultiHeadAttention(d_in = d_in, d_out=2, num_heads=2)
model(batch)

tensor([[[-0.3432,  0.1220],
         [-0.3927,  0.1201],
         [-0.3999,  0.1241],
         [-0.3731,  0.1135],
         [-0.2825,  0.1169],
         [-0.3665,  0.1220]],

        [[-0.3432,  0.1220],
         [-0.3927,  0.1201],
         [-0.3999,  0.1241],
         [-0.3731,  0.1135],
         [-0.2825,  0.1169],
         [-0.3665,  0.1220]]], grad_fn=<ViewBackward0>)