## Attention With Trainable Weights

In [1]:
import torch

In [2]:
inputs = torch.nn.Embedding(4,8)

In [3]:
inputs = inputs.weight.data
inputs

tensor([[ 0.3476,  1.0008,  1.4503,  0.3622,  0.7486, -0.0103, -0.4972, -2.9958],
        [-0.4068, -0.5580, -0.1595, -0.7773, -0.5482, -0.2031,  0.2543,  1.1040],
        [-0.8071,  0.7434,  0.9972, -0.2365,  0.6343,  0.2737, -0.2480,  2.2735],
        [ 0.8415, -1.1536,  1.2855,  0.5882,  0.3148, -0.2748, -1.1443,  2.0168]])

In [4]:
# Input embedding size (d_in) and projected query/key/value size (d_out)
d_in, d_out = 8, 6
# Create the query, key and value weight matrices (in that order) with random
# entries drawn uniformly from [0, 1); requires_grad=False keeps them frozen
W_q, W_k, W_v = (
    torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
    for _ in range(3)
)

In [5]:
# Choose one input vector (here the one at index 2) and project it into our query vector using W_q
query = inputs[2] @ W_q
query


tensor([2.1511, 1.3342, 1.9950, 2.2145, 3.2097, 1.1408])

In [6]:
#Calculate keys and values matrices
keys = inputs @ W_k
values = inputs @ W_v
print(keys)
print(values)

tensor([[-0.3760,  1.0337, -1.1583, -0.2894, -0.4050,  1.6662],
        [-0.4264, -0.8371, -0.4320, -0.0674, -0.6597, -1.5198],
        [ 1.3435,  0.7114,  3.0836,  2.1317,  1.5462,  1.1950],
        [-0.8361, -0.5026,  1.9398,  1.9284,  1.8886, -0.0835]])
tensor([[ 0.3589,  1.2143, -0.6569,  0.8759,  1.0961, -0.6289],
        [-0.6003, -1.1869, -0.3515, -0.9751, -1.3323, -0.0853],
        [ 2.9386,  0.3470,  2.3401,  1.7831,  0.8412,  2.7903],
        [ 2.5762, -1.2917,  1.0480,  1.4931,  1.0715,  1.5268]])


In [7]:
attention_scores = query @ keys.T
attention_scores

tensor([-1.7802, -6.8963, 21.0379, 11.6379])

In [8]:
# Normalize the scores with softmax to get the attention weights.
# Dividing by sqrt(d_k) (d_k = key dimension) stops the dot products from growing with the
# embedding size; large scores would push softmax toward a near one-hot output with tiny gradients.
attention_weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim= -1)
attention_weights

tensor([8.8114e-05, 1.0913e-05, 9.7881e-01, 2.1090e-02])

In [9]:
#Make sure you did it right
attention_weights.sum()

tensor(1.0000)

In [10]:
# make the context vector
context_vector = attention_weights @ values
context_vector

tensor([2.9307, 0.3125, 2.3125, 1.7769, 0.8460, 2.7633])

## Making the Simple Attention Class Using Parameter 

In [11]:
import torch.nn as nn

In [12]:
#Create a simple attention class v1
class SimpleAttention(nn.Module):
    """Scaled dot-product self-attention with hand-made weight matrices (v1).

    Args:
        d_in: size of each input embedding vector.
        d_out: size of each query/key/value (and output context) vector.
    """

    #Constructor, initialize those matrix dimensions
    def __init__(self, d_in, d_out):
        super().__init__()
        #Create weight matrices (uniform random in [0, 1)).
        #requires_grad=False means autograd does not track them.
        self.W_q = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.W_k = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.W_v = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

    def forward(self, x):
        """Return one context vector per input token.

        Args:
            x: embedding vectors of shape (num_tokens, d_in), or a batch of
               shape (batch, num_tokens, d_in).

        Returns:
            Tensor shaped like x with the last dimension replaced by d_out.
        """
        #Make the queries, keys and values
        queries = x @ self.W_q
        keys = x @ self.W_k
        values = x @ self.W_v
        #Compute scores. Swap only the last two dims: `.T` reverses *all* dims,
        #which is only correct for 2-D input and breaks on batches.
        scores = queries @ keys.transpose(-2, -1)
        #Scale by sqrt(d_k), then normalize each row into weights
        weights = torch.softmax(scores / keys.shape[-1]**0.5, dim=-1)
        #Compute context vectors using weights and values
        context = weights @ values
        return context


In [13]:
#Using the class example
#instantiate an instance
simple = SimpleAttention(d_in = 8, d_out = 6)


In [14]:
#It has attributes!
simple.W_k

Parameter containing:
tensor([[0.9237, 0.6093, 0.1456, 0.4082, 0.4097, 0.6360],
        [0.7457, 0.0569, 0.3666, 0.2114, 0.4545, 0.4032],
        [0.9186, 0.6230, 0.0799, 0.5688, 0.5827, 0.2930],
        [0.9995, 0.6606, 0.4155, 0.4434, 0.5952, 0.9104],
        [0.2912, 0.7817, 0.8116, 0.6722, 0.2601, 0.3147],
        [0.2672, 0.2461, 0.2986, 0.3389, 0.9953, 0.3255],
        [0.0884, 0.8030, 0.1776, 0.0880, 0.7437, 0.5777],
        [0.1080, 0.9056, 0.7430, 0.6177, 0.7764, 0.7794]])

In [15]:
#Class returns the context vector
context_vectors = simple(inputs)
context_vectors

tensor([[ 0.5131, -0.9164, -0.1836,  1.5753, -0.6551,  0.4312],
        [-0.2766,  0.0189, -0.2676, -0.2599, -0.0839, -0.4699],
        [ 0.1757,  1.4299,  0.8559,  2.2200,  1.6469,  2.4665],
        [ 0.0137,  1.4395,  0.8817,  2.1822,  1.5771,  2.3379]])

## Making the Simple Attention Class Again using nn.Linear

In [24]:
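#Create a simple attention class v2
#Using nn.Linear instead of a hand-made nn.Parameter: with bias=False it performs the same
#matrix multiplication, but it comes with a better default weight initialization for training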

class SimpleAttention(nn.Module):
    """Scaled dot-product self-attention built from bias-free nn.Linear layers (v2).

    Args:
        d_in: size of each input embedding vector.
        d_out: size of each query/key/value (and output context) vector.
    """

    #Constructor, initialize those matrix dimensions
    def __init__(self, d_in, d_out):
        super().__init__()
        #Create weight matrices as bias-free linear layers
        self.W_q = nn.Linear(d_in, d_out, bias=False)
        self.W_k = nn.Linear(d_in, d_out, bias=False)
        self.W_v = nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        """Return one context vector per input token.

        Args:
            x: embedding vectors of shape (num_tokens, d_in), or a batch of
               shape (batch, num_tokens, d_in).

        Returns:
            Tensor shaped like x with the last dimension replaced by d_out.
        """
        #Make the queries, keys and values
        #Still just multiplying the matrices
        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v(x)
        #Compute scores. Swap only the last two dims: `.T` reverses *all* dims,
        #which is only correct for 2-D input and breaks on batches.
        scores = queries @ keys.transpose(-2, -1)
        #Scale by sqrt(d_k), then normalize each row into weights
        weights = torch.softmax(scores / keys.shape[-1]**0.5, dim=-1)
        #Compute context vectors using weights and values
        context = weights @ values
        return context


In [25]:
simple = SimpleAttention(d_in = 8, d_out = 4)

In [26]:
context_vectors = simple(inputs)
context_vectors

tensor([[ 0.0883,  0.5871,  0.1271,  0.0613],
        [-0.1756,  0.3361,  0.1367, -0.2108],
        [-0.1097,  0.3380,  0.1416, -0.1644],
        [-0.1760,  0.3377,  0.1179, -0.1991]], grad_fn=<MmBackward0>)

In [19]:
# The problem here is that each context vector uses information from ALL of the embedding vectors
# In practice, each token should only use information from itself and the preceding tokens
# To do so, implement causal attention, AKA masked attention

### Simple Masked Attention (1.0)

In [None]:
# `weights` was never assigned in a cell that survives Restart & Run All, so build it here.
# NOTE: despite the name, these are the raw (unnormalized) attention scores,
# queries @ keys^T, computed with the projections of the `simple` module created above.
weights = simple.W_q(inputs) @ simple.W_k(inputs).T
weights

tensor([[-0.2313, -0.4622, -0.3702,  0.2098],
        [-0.0081, -0.1016, -0.1492, -0.1211],
        [-0.1679, -0.3332, -0.2500,  0.0806],
        [-0.1755, -0.3499, -0.2763,  0.1009]], grad_fn=<MmBackward0>)

In [27]:
# Sanity check: the rows do NOT sum to 1, so these values are raw scores, not normalized attention weights
weights.sum(dim=-1)

tensor([-0.8540, -0.3801, -0.6706, -0.7008], grad_fn=<SumBackward1>)

In [28]:
#tril returns the lower triangular parts of the matrix (and the diagonal), turns the rest to zero
simple_mask = torch.tril(torch.ones(weights.shape[0], weights.shape[0]))
simple_mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [29]:
# * is a cell by cell product
masked_weights = weights*simple_mask
masked_weights

tensor([[-0.2313, -0.0000, -0.0000,  0.0000],
        [-0.0081, -0.1016, -0.0000, -0.0000],
        [-0.1679, -0.3332, -0.2500,  0.0000],
        [-0.1755, -0.3499, -0.2763,  0.1009]], grad_fn=<MulBackward0>)

In [30]:
#no longer normalized...
masked_weights.sum(dim=-1)

tensor([-0.2313, -0.1098, -0.7512, -0.7008], grad_fn=<SumBackward1>)

In [31]:
#Need to normalize masked weights
row_sums = masked_weights.sum(dim=-1, keepdim=True)
row_sums

tensor([[-0.2313],
        [-0.1098],
        [-0.7512],
        [-0.7008]], grad_fn=<SumBackward1>)

In [32]:
#Dividing by row sums normalizes
masked_weights = masked_weights / row_sums
masked_weights.sum(dim=-1)

tensor([1., 1., 1., 1.], grad_fn=<SumBackward1>)

### More Efficient Masking Method (1.1)

In [33]:
# Masking method 2
# triu(..., diagonal=1) keeps only the entries strictly above the main diagonal (the ones)
# and zeroes the diagonal and everything below it, so the ones mark the "future" positions
mask = torch.triu(torch.ones(weights.shape[0], weights.shape[0]), diagonal=1)
mask

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

In [34]:
mask.bool()

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

In [35]:
weights

tensor([[-0.2313, -0.4622, -0.3702,  0.2098],
        [-0.0081, -0.1016, -0.1492, -0.1211],
        [-0.1679, -0.3332, -0.2500,  0.0806],
        [-0.1755, -0.3499, -0.2763,  0.1009]], grad_fn=<MmBackward0>)

In [36]:
#Need a matrix of true/false to fill in, fills the true spots with the given value (-inf)
weights = weights.masked_fill(mask.bool(), -torch.inf, )
weights

tensor([[-0.2313,    -inf,    -inf,    -inf],
        [-0.0081, -0.1016,    -inf,    -inf],
        [-0.1679, -0.3332, -0.2500,    -inf],
        [-0.1755, -0.3499, -0.2763,  0.1009]], grad_fn=<MaskedFillBackward0>)

In [37]:
masked_weights = torch.softmax(weights, dim=-1)
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5234, 0.4766, 0.0000, 0.0000],
        [0.3612, 0.3061, 0.3327, 0.0000],
        [0.2462, 0.2068, 0.2226, 0.3245]], grad_fn=<SoftmaxBackward0>)

### Dropout Explanation

In [38]:
## Dropout helps us avoid overfitting during training 
# (don't want to overfit to training set bc then model cannot generalize)
# Randomly and uniformly ignoring certain data points
# we must set a dropout rate
dropout = nn.Dropout(0.5)

In [39]:
# Each entry is zeroed independently with probability 0.5 (so roughly, not exactly, half are dropped);
# the surviving entries are scaled by 1 / (1 - 0.5) = 2 so the expected total stays the same
dropout(masked_weights)

tensor([[0.0000, 0.0000, 0.0000, 0.0000],
        [1.0467, 0.9533, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4135, 0.0000, 0.6490]], grad_fn=<MulBackward0>)

## Causal Attention Class

In [40]:
# we need to be able to give our LLM batches of input
# For example:
# torch.stack joins the two (4, 8) tensors along a NEW leading batch dimension -> shape (2, 4, 8)
# (unlike torch.cat, which would join them along an existing dimension)
batches = torch.stack((inputs, inputs), dim=0)

In [41]:
batches

tensor([[[ 0.3476,  1.0008,  1.4503,  0.3622,  0.7486, -0.0103, -0.4972,
          -2.9958],
         [-0.4068, -0.5580, -0.1595, -0.7773, -0.5482, -0.2031,  0.2543,
           1.1040],
         [-0.8071,  0.7434,  0.9972, -0.2365,  0.6343,  0.2737, -0.2480,
           2.2735],
         [ 0.8415, -1.1536,  1.2855,  0.5882,  0.3148, -0.2748, -1.1443,
           2.0168]],

        [[ 0.3476,  1.0008,  1.4503,  0.3622,  0.7486, -0.0103, -0.4972,
          -2.9958],
         [-0.4068, -0.5580, -0.1595, -0.7773, -0.5482, -0.2031,  0.2543,
           1.1040],
         [-0.8071,  0.7434,  0.9972, -0.2365,  0.6343,  0.2737, -0.2480,
           2.2735],
         [ 0.8415, -1.1536,  1.2855,  0.5882,  0.3148, -0.2748, -1.1443,
           2.0168]]])

In [42]:
#this class needs to handle batches of input
class CausalAttention(nn.Module):
    """Single-head causal (masked) self-attention over batched input.

    Args:
        d_in: size of each input embedding vector.
        d_out: size of each query/key/value (and output context) vector.
        context_length: side length of the precomputed causal mask, i.e. the
            largest number of tokens the mask covers.
        dropout: dropout probability applied to the attention weights.
        qkv_bias: whether the query/key/value projections include a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        #Create weight matrices; honor qkv_bias instead of hard-coding bias=False
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)
        # include dropout
        self.dropout = nn.Dropout(dropout)
        #Register the mask as a buffer: it is not a trainable parameter, but it is
        #part of the module's state, so it follows the module across devices
        #(.to(), .cuda()) and is saved in the state_dict.
        #Ones strictly above the diagonal mark the "future" positions to hide.
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
            )

    def forward(self, x):
        """Return causally-masked context vectors.

        Args:
            x: embedding vectors of shape (batch, num_tokens, d_in).

        Returns:
            Tensor of shape (batch, num_tokens, d_out).
        """
        _, num_tokens, _ = x.shape
        #Make the queries, keys and values
        #Still just multiplying the matrices
        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v(x)
        #Compute scores (dot products of queries and keys); transpose(1, 2)
        #swaps the token and feature dims while leaving the batch dim in place
        scores = queries @ keys.transpose(1, 2)
        #Masking: crop the mask to the actual sequence length, then overwrite the
        #future positions with -inf (in place) so softmax maps them to 0
        scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        #Scale by sqrt(d_k) and normalize; -inf entries turn into zeros
        weights = torch.softmax(scores / keys.shape[-1]**0.5, dim=-1)
        weights = self.dropout(weights)
        #Compute context vectors using weights and values
        context = weights @ values
        return context

In [43]:
#Instantiate a causal attention mechanism
causal = CausalAttention(d_in=8, d_out=4, context_length= 4, dropout=0)

In [44]:
causal(batches)

tensor([[[ 0.1315, -0.0068, -0.5639, -1.1360],
         [ 0.0835, -0.0535, -0.2836, -0.3171],
         [-0.0576, -0.1803, -0.1704, -0.1978],
         [-0.1245, -0.3059,  0.0398,  0.0690]],

        [[ 0.1315, -0.0068, -0.5639, -1.1360],
         [ 0.0835, -0.0535, -0.2836, -0.3171],
         [-0.0576, -0.1803, -0.1704, -0.1978],
         [-0.1245, -0.3059,  0.0398,  0.0690]]], grad_fn=<UnsafeViewBackward0>)

### an example of how to transpose batches correctly

In [45]:
W_q = nn.Linear(d_in, d_out, bias=False)
W_k = nn.Linear(d_in, d_out, bias=False)
W_v = nn.Linear(d_in, d_out, bias=False)

In [46]:
queries = W_q(batches)
queries

tensor([[[ 0.0670,  0.2508, -0.4446, -0.4527, -0.3071,  1.5719],
         [ 0.0319, -0.0357,  0.3191,  0.1176,  0.1593, -0.5529],
         [-0.7918, -0.4235,  0.8298,  0.1500,  0.4593,  0.1080],
         [-0.9608, -0.2707,  0.7450, -0.3383,  0.9127, -0.4438]],

        [[ 0.0670,  0.2508, -0.4446, -0.4527, -0.3071,  1.5719],
         [ 0.0319, -0.0357,  0.3191,  0.1176,  0.1593, -0.5529],
         [-0.7918, -0.4235,  0.8298,  0.1500,  0.4593,  0.1080],
         [-0.9608, -0.2707,  0.7450, -0.3383,  0.9127, -0.4438]]],
       grad_fn=<UnsafeViewBackward0>)

In [47]:
keys = W_k(batches)
keys

tensor([[[-0.4463, -0.8833, -0.9793,  0.5838, -0.1625,  1.1621],
         [ 0.0057,  0.2712,  0.3750, -0.0250,  0.2763, -0.4203],
         [ 0.2456, -0.4015,  0.7905, -0.4507,  0.7702, -0.4926],
         [ 0.3099, -0.4570,  0.4964, -0.6015,  0.0530, -0.3242]],

        [[-0.4463, -0.8833, -0.9793,  0.5838, -0.1625,  1.1621],
         [ 0.0057,  0.2712,  0.3750, -0.0250,  0.2763, -0.4203],
         [ 0.2456, -0.4015,  0.7905, -0.4507,  0.7702, -0.4926],
         [ 0.3099, -0.4570,  0.4964, -0.6015,  0.0530, -0.3242]]],
       grad_fn=<UnsafeViewBackward0>)

In [48]:
#This transposes without mixing up the batches
keys.transpose(1,2)

tensor([[[-0.4463,  0.0057,  0.2456,  0.3099],
         [-0.8833,  0.2712, -0.4015, -0.4570],
         [-0.9793,  0.3750,  0.7905,  0.4964],
         [ 0.5838, -0.0250, -0.4507, -0.6015],
         [-0.1625,  0.2763,  0.7702,  0.0530],
         [ 1.1621, -0.4203, -0.4926, -0.3242]],

        [[-0.4463,  0.0057,  0.2456,  0.3099],
         [-0.8833,  0.2712, -0.4015, -0.4570],
         [-0.9793,  0.3750,  0.7905,  0.4964],
         [ 0.5838, -0.0250, -0.4507, -0.6015],
         [-0.1625,  0.2763,  0.7702,  0.0530],
         [ 1.1621, -0.4203, -0.4926, -0.3242]]], grad_fn=<TransposeBackward0>)

##  Multi-Head Attention 1.0 (not efficient)

In [68]:
# Different attention heads (above is one attention head) pay attention to different things
# A bunch of causal attention instances together

class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention built by running independent CausalAttention heads.

    Every head receives the same input; the per-head context vectors are
    concatenated along the last dimension, so the output's last dimension is
    num_heads * d_out.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # Build the heads one at a time, each with its own projections
        head_modules = []
        for _ in range(num_heads):
            head_modules.append(
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            )
        # ModuleList registers every head as a submodule
        self.heads = nn.ModuleList(head_modules)

    def forward(self, x):
        # Run each head on the same input, then join the results feature-wise
        per_head_context = [attention_head(x) for attention_head in self.heads]
        return torch.cat(per_head_context, dim=-1)


In [69]:
# Instantiate the wrapper defined in the previous cell: 3 independent heads of size 4 each
mha = MultiHeadAttentionWrapper(d_in=8, d_out= 4, context_length= 4, dropout= 0, num_heads=3)

## Real Multi-Head Attention Mechanism

In [None]:
# Much more efficient: a single matrix multiplication per projection computes the queries/keys/values for all attention heads at once

class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with one shared projection per Q/K/V.

    Each of the query/key/value projections maps d_in -> d_out once; the result
    is then split into num_heads slices of size head_dim = d_out // num_heads.

    Args:
        d_in: size of each input embedding vector.
        d_out: total output size across all heads (must be divisible by num_heads).
        context_length: side length of the precomputed causal mask.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the query/key/value projections include a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        # Per-head width, chosen so the concatenated heads add up to d_out
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        # Final linear layer that mixes the concatenated head outputs
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal flag the future positions to hide
        future_positions = torch.triu(
            torch.ones(context_length, context_length), diagonal=1
        )
        self.register_buffer("mask", future_positions)

    def forward(self, x):
        # x: (batch, seq_len, d_in). The mask is only context_length wide, so
        # inputs longer than context_length are not supported here.
        batch_size, seq_len, _ = x.shape

        def split_heads(projected):
            # (batch, seq_len, d_out) -> (batch, seq_len, num_heads, head_dim)
            #                         -> (batch, num_heads, seq_len, head_dim)
            per_head = projected.view(batch_size, seq_len, self.num_heads, self.head_dim)
            return per_head.transpose(1, 2)

        k = split_heads(self.W_key(x))
        q = split_heads(self.W_query(x))
        v = split_heads(self.W_value(x))

        # Dot product of every query with every key, separately for each head
        scores = q @ k.transpose(2, 3)

        # Crop the mask to the actual sequence length and hide future tokens
        hide_future = self.mask.bool()[:seq_len, :seq_len]
        scores.masked_fill_(hide_future, -torch.inf)

        # Scale by sqrt(head_dim), normalize, then apply dropout
        scale = k.shape[-1] ** 0.5
        probs = self.dropout(torch.softmax(scores / scale, dim=-1))

        # (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, num_heads, head_dim)
        per_head_context = (probs @ v).transpose(1, 2)

        # Glue the heads back together: num_heads * head_dim == d_out
        merged = per_head_context.contiguous().view(batch_size, seq_len, self.d_out)

        # Optional output projection
        return self.out_proj(merged)
