In [2]:
import torch

In [4]:
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [7]:
# Without Weights
attn_scores = inputs @ inputs.T
print(attn_scores)
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)


tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [60]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 3


In [9]:
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key   = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [10]:
W_query

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])

In [14]:
# Consider the second word:
query_2 = x_2 @ W_query 
key_2 = x_2 @ W_key 
value_2 = x_2 @ W_value
print(query_2)

tensor([0.4306, 1.4551])


In [20]:
keys = inputs @ W_key 
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)
keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
print("Attention Score: ", attn_score_22)
attn_scores_2 = query_2 @ keys.T # All attention scores for given query
print("Attention Scores: ", attn_scores_2)
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print("Attention Weights: ", attn_weights_2)
context_vec_2 = attn_weights_2 @ values
print("Context Vector (multiplying value matrix): ", context_vec_2)




keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])
Attention Score:  tensor(1.8524)
Attention Scores:  tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])
Attention Weights:  tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
Context Vector (multiplying value matrix):  tensor([0.3061, 0.8210])


In [30]:
import torch.nn as nn

class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias = False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self. W_value(x)
        attention_scores = queries @ keys.T
        attention_weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1)
        context_vector = attention_weights @ values
        return context_vector
    

In [40]:
torch.manual_seed(123)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))


tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)


In [42]:
# Causal Attention
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)
# Make all the rows sum to 1
row_sums = masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<DivBackward0>)


In [58]:
import torch.nn as nn

# Causal Attention with Dropout for Batches
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout=0.2, qkv_bias = False):
        super().__init__()
        self.d_out = d_out
        self.d_in = d_in
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length)))

    def forward(self, inputs):
        b, num_tokens, d_in = inputs.shape
        assert d_in == self.d_in, "Embedding Dimension is Incompatible with Initilized"
        keys = self.W_key(inputs) 
        queries = self.W_query(inputs)
        values = self. W_value(inputs) # q,k,v will have shape (num_tokens, d_out)
        attention_scores = queries @ keys.transpose(1,2) # (num_tokens, d_out) * (d_out, num_tokens) = (num_tokens, num_tokens)
        masked = attention_scores.masked_fill(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attention_weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1)
        attention_weights = self.dropout(attention_weights)
        context_vector = attention_weights @ values
        return context_vector

In [72]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length)
print(ca(batch))

torch.Size([2, 6, 3])
tensor([[[-0.4050,  0.5388, -0.4052],
         [-0.2598,  0.3360, -0.1981],
         [-0.4339,  0.6317, -0.5422],
         [-0.3437,  0.5202, -0.5002],
         [-0.3705,  0.5540, -0.4977],
         [-0.3431,  0.5197, -0.5017]],

        [[-0.3397,  0.4587, -0.3627],
         [-0.4337,  0.6315, -0.5423],
         [-0.4094,  0.5443, -0.4061],
         [-0.4318,  0.6302, -0.5418],
         [-0.3294,  0.4880, -0.4258],
         [-0.4297,  0.6279, -0.5425]]], grad_fn=<UnsafeViewBackward0>)


In [75]:
# Multihead Attention
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads=1, qkv_bias = False):
        super().__init__()
        self.heads = nn.ModuleList([
            CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for i  in range(num_heads)
        ])
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)  

In [81]:
torch.manual_seed(123)
context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 3
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
 
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[ 0.2633,  0.4277, -0.1353,  0.0031,  0.5378,  0.2513],
         [ 0.2641,  0.4296, -0.1350,  0.0017,  0.5346,  0.2495],
         [ 0.2641,  0.4296, -0.1350,  0.0017,  0.5345,  0.2495],
         [ 0.2647,  0.4316, -0.1381,  0.0023,  0.5353,  0.2504],
         [ 0.2642,  0.4303, -0.1373,  0.0021,  0.5330,  0.2490],
         [ 0.2648,  0.4316, -0.1375,  0.0023,  0.5363,  0.2508]],

        [[ 0.2633,  0.4277, -0.1353,  0.0031,  0.5378,  0.2513],
         [ 0.2641,  0.4296, -0.1350,  0.0017,  0.5346,  0.2495],
         [ 0.2641,  0.4296, -0.1350,  0.0017,  0.5345,  0.2495],
         [ 0.2647,  0.4316, -0.1381,  0.0023,  0.5353,  0.2504],
         [ 0.2642,  0.4303, -0.1373,  0.0021,  0.5330,  0.2490],
         [ 0.2648,  0.4316, -0.1375,  0.0023,  0.5363,  0.2508]]],
       grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 6])


In [94]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False, num_heads=1):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
                torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )
    def forward(self, inputs):
        b, num_tokens, d_in = inputs.shape
        query = self.W_query(inputs)
        key = self.W_key(inputs)
        value = self.W_value(inputs)
        queries = query.view(b, num_tokens, self.num_heads, self.head_dim)
        keys = key.view(b, num_tokens, self.num_heads, self.head_dim)
        values = value.view(b, num_tokens, self.num_heads, self.head_dim)

        keys = keys.transpose(1,2)
        queries = queries.transpose(1,2)
        values = values.transpose(1,2)

        attn_scores = queries @ keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
  
        attn_scores.masked_fill_(mask_bool, -torch.inf)
 
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
 
        context_vec = (attn_weights @ values).transpose(1, 2)

        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec


In [101]:
batch_size, context_length, d_in = batch.shape
d_out = 8
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=8)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-9.2058e-05, -2.6669e-01,  1.6402e-01,  3.5730e-01, -5.3622e-02,
          -2.1625e-01, -6.7612e-02,  4.8780e-01],
         [ 2.9047e-02, -2.4400e-01,  1.9984e-01,  3.6479e-01, -2.2243e-04,
          -1.9901e-01, -9.1584e-02,  4.3615e-01],
         [ 4.0933e-02, -2.3906e-01,  2.1329e-01,  3.6909e-01,  1.5201e-02,
          -1.9030e-01, -1.0079e-01,  4.1939e-01],
         [ 4.3469e-02, -2.2697e-01,  2.3146e-01,  3.2855e-01,  5.4247e-03,
          -1.4201e-01, -9.1032e-02,  4.0699e-01],
         [ 4.1461e-02, -2.2034e-01,  2.0402e-01,  3.1036e-01, -3.8359e-02,
          -1.0335e-01, -8.3763e-02,  3.8687e-01],
         [ 4.2823e-02, -2.1431e-01,  2.2892e-01,  2.9669e-01, -1.7885e-02,
          -9.8132e-02, -8.2856e-02,  3.8911e-01]],

        [[-9.2058e-05, -2.6669e-01,  1.6402e-01,  3.5730e-01, -5.3622e-02,
          -2.1625e-01, -6.7612e-02,  4.8780e-01],
         [ 2.9047e-02, -2.4400e-01,  1.9984e-01,  3.6479e-01, -2.2243e-04,
          -1.9901e-01, -9.1584e-02,  4.3615e-01]