## Attention With Trainable Weights

In [2]:
import torch

In [3]:
inputs = torch.nn.Embedding(4,8)

In [4]:
inputs = inputs.weight.data
inputs

tensor([[ 0.0239, -1.7226,  0.7044,  0.4385,  0.4857,  0.1602,  1.4788,  0.3629],
        [ 1.2462,  0.1250,  0.3095,  0.2417,  0.2231,  0.7651,  0.7462, -1.3258],
        [-1.0885,  0.6456,  0.3154, -0.0432,  0.0079, -0.0806,  0.5046,  1.6984],
        [-2.6347,  0.2324,  0.0873, -1.0514, -0.3753,  1.2293, -0.2159,  0.1539]])

In [5]:
# Set dimensions
d_in = 8
d_out = 6
# Create Weight Matricies query, key, value with random entries
W_q = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_k = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_v = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [6]:
#choose an input vector (we're using 2) and transform it into our query vector using using W_q
query = inputs[2] @ W_q
query


tensor([ 0.6138,  1.2582,  0.9669,  1.5001,  0.7241, -0.1785])

In [7]:
#Calculate keys and values matrices
keys = inputs @ W_k
values = inputs @ W_v
print(keys)
print(values)

tensor([[ 1.4578,  1.0081,  1.3433,  1.8586,  1.5021,  1.6440],
        [ 2.1286,  1.1845,  1.6787,  1.3754,  0.4569,  0.6616],
        [ 0.6965,  2.1634,  0.3031,  1.0812,  1.4640,  0.9794],
        [-3.0087, -1.3744, -2.3723, -0.4954, -0.6695, -2.3872]])
tensor([[ 2.1312,  0.9543,  0.4384,  1.1125,  0.6249,  0.6792],
        [ 1.3873,  0.9746,  1.7408,  1.3421,  1.8992, -0.1012],
        [ 1.0119,  1.7650,  0.8197,  0.4706,  0.0826,  1.2183],
        [-0.4323, -0.5147, -0.0686, -1.1670, -2.1212, -1.1722]])


In [8]:
attention_scores = query @ keys.T
attention_scores

tensor([ 7.0442,  6.6958,  5.9496, -6.6715])

In [9]:
#normalize the scores to get the weights
# denominator is just because that's what the researchers figured out worked best for training 
attention_weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim= -1)
attention_weights

tensor([0.3983, 0.3455, 0.2548, 0.0015])

In [10]:
#Make sure you did it right
attention_weights.sum()

tensor(1.0000)

In [11]:
# make the context vector
context_vector = attention_weights @ values
context_vector

tensor([1.5853, 1.1657, 0.9847, 1.0249, 0.9229, 0.5442])

## Making the Simple Attention Class Using Parameter 

In [12]:
import torch.nn as nn

In [13]:
#Create a simple attention class v1
class SimpleAttention(nn.Module):
    #Constructor, initialize those matrix dimensions
    def __init__(self, d_in, d_out):
        super().__init__()
        #Create weight matrices
        self.W_q = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.W_k = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.W_v = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

    #   x = embedding vectors (inputs)
    def forward(self,x):
        #Make the queries, keys and values
        queries = x @ self.W_q
        keys = x @ self.W_k
        values = x @ self.W_v
        #Compute scores, then normalize into weights
        scores = queries @ keys.T
        weights = torch.softmax(scores / keys.shape[-1]**0.5, dim= -1)
        #Compute context vector using weights and values
        context = weights @ values
        return context


In [14]:
#Using the class example
#instatiate an instance
simple = SimpleAttention(d_in = 8, d_out = 6)


In [15]:
#It has attributes!
simple.W_k

Parameter containing:
tensor([[0.5559, 0.0741, 0.6156, 0.2810, 0.7015, 0.6712],
        [0.7562, 0.4573, 0.4618, 0.2484, 0.4273, 0.4339],
        [0.7648, 0.4864, 0.6730, 0.1632, 0.9166, 0.0370],
        [0.7251, 0.4688, 0.3398, 0.2526, 0.0536, 0.2267],
        [0.0152, 0.5302, 0.5565, 0.7878, 0.8187, 0.4568],
        [0.3125, 0.1330, 0.2019, 0.5791, 0.4915, 0.9064],
        [0.3664, 0.3168, 0.8228, 0.7732, 0.4664, 0.4628],
        [0.8123, 0.8187, 0.6959, 0.5389, 0.5679, 0.6873]])

In [16]:
#Class returns the context vector
context_vectors = simple(inputs)
context_vectors

tensor([[ 0.8045,  1.3429,  0.3836,  1.1769,  1.4100,  1.0680],
        [ 0.7073,  1.4433,  0.2528,  1.2199,  1.4084,  1.0402],
        [ 0.6566,  1.4735,  0.1903,  1.2191,  1.3733,  0.9962],
        [-0.8812, -2.4792, -0.0976, -1.3198,  0.2601, -1.2102]])

## Making the Simple Attention Class Again using nn.Linear

In [30]:
#Create a simple attention class v2
#Using nn.Linear to be more efficient

class SimpleAttention(nn.Module):
    #Constructor, initialize those matrix dimensions
    def __init__(self, d_in, d_out):
        super().__init__()
        #Create weight matrices
        self.W_q = nn.Linear(d_in, d_out, bias=False)
        self.W_k = nn.Linear(d_in, d_out, bias=False)
        self.W_v = nn.Linear(d_in, d_out, bias=False)

    #   x = embedding vectors (inputs)
    def forward(self,x):
        #Make the queries, keys and values
        #Still just muktiplying the matrices
        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v (x)
        #Compute scores, then normalize into weights
        scores = queries @ keys.T
        weights = torch.softmax(scores / keys.shape[-1]**0.5, dim= -1)
        #Compute context vector using weights and values
        context = weights @ values
        return context


In [31]:
simple = SimpleAttention(d_in = 8, d_out = 6)

In [32]:
context_vectors = simple(inputs)
context_vectors

tensor([[ 0.3095,  0.2706, -0.0924, -0.0816,  0.5705,  0.0940],
        [ 0.2962,  0.2670, -0.1222, -0.1218,  0.5529,  0.1088],
        [ 0.3377,  0.2674, -0.0573, -0.0208,  0.5803,  0.0515],
        [ 0.3687,  0.3480, -0.0360, -0.0423,  0.6836,  0.0205]],
       grad_fn=<MmBackward0>)

In [28]:
# the problem here is that each context vector uses information from ALL of the embedding vectors
#In practice, we should only use information about preceeding vectors
# To do so, implement causal attention AKA masked attention

### Simple Masked Attention (1.0)

In [None]:
#Hack to get some weights
#weights = simple(inputs)
weights

tensor([[0.2237, 0.2290, 0.2012, 0.3461],
        [0.2439, 0.3024, 0.1850, 0.2687],
        [0.1952, 0.1441, 0.3208, 0.3399],
        [0.2289, 0.2066, 0.3235, 0.2410]], grad_fn=<SoftmaxBackward0>)

In [51]:
#Note that these are already normalized
weights.sum(dim=-1)

tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)

In [None]:
#tril returns the lower triangular parts of the matrix (and the diagonal), turns the rest to zero
simple_mask = torch.tril(torch.ones(weights.shape[0], weights.shape[0]))
simple_mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [53]:
# * is a cell by cell product
masked_weights = weights*simple_mask
masked_weights

tensor([[0.2237, 0.0000, 0.0000, 0.0000],
        [0.2439, 0.3024, 0.0000, 0.0000],
        [0.1952, 0.1441, 0.3208, 0.0000],
        [0.2289, 0.2066, 0.3235, 0.2410]], grad_fn=<MulBackward0>)

In [54]:
#no longer normalized...
masked_weights.sum(dim=-1)

tensor([0.2237, 0.5463, 0.6601, 1.0000], grad_fn=<SumBackward1>)

In [55]:
#Need to normalize masked weights
row_sums = masked_weights.sum(dim=-1, keepdim=True)
row_sums

tensor([[0.2237],
        [0.5463],
        [0.6601],
        [1.0000]], grad_fn=<SumBackward1>)

In [None]:
#Dividing by row sums normalizes
masked_weights = masked_weights / row_sums
masked_weights.sum(dim=-1)

tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)

### More Efficient Masking Method (1.1)

In [61]:
# Masking method 2
# triu() returns the upper triangular part of the matrix, puts zeros on the diagonal
mask = torch.triu(torch.ones(weights.shape[0], weights.shape[0]), diagonal=1)
mask

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

In [62]:
mask.bool()

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

In [63]:
weights

tensor([[0.2237, 0.2290, 0.2012, 0.3461],
        [0.2439, 0.3024, 0.1850, 0.2687],
        [0.1952, 0.1441, 0.3208, 0.3399],
        [0.2289, 0.2066, 0.3235, 0.2410]], grad_fn=<SoftmaxBackward0>)

In [None]:
#Need a matrix of true/false to fill in, fills the true spots with the given value (-inf)
weights = weights.masked_fill(mask.bool(), -torch.inf, )
weights

tensor([[0.2237,   -inf,   -inf,   -inf],
        [0.2439, 0.3024,   -inf,   -inf],
        [0.1952, 0.1441, 0.3208,   -inf],
        [0.2289, 0.2066, 0.3235, 0.2410]], grad_fn=<MaskedFillBackward0>)

In [67]:
masked_weights = torch.softmax(weights, dim=-1)
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.4854, 0.5146, 0.0000, 0.0000],
        [0.3242, 0.3081, 0.3677, 0.0000],
        [0.2445, 0.2391, 0.2688, 0.2475]], grad_fn=<SoftmaxBackward0>)

### Dropout Explanation

In [None]:
## Dropout helps us avoid overfitting during training
# Randomly and uniformly ignoring certain data points
# we must set a dropout rate
dropout = nn.Dropout(0.5)

In [None]:
# It removes around half of the data points ...
dropout(masked_weights)

tensor([[0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.0292, 0.0000, 0.0000],
        [0.6485, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4783, 0.0000, 0.0000]], grad_fn=<MulBackward0>)

## Causal Attenion Class

In [69]:
# we need to be able to give our LLM batches of input
# For example:
# Concatenates two tensors
batches = torch.stack((inputs, inputs), dim=0)

In [70]:
batches

tensor([[[ 0.0239, -1.7226,  0.7044,  0.4385,  0.4857,  0.1602,  1.4788,
           0.3629],
         [ 1.2462,  0.1250,  0.3095,  0.2417,  0.2231,  0.7651,  0.7462,
          -1.3258],
         [-1.0885,  0.6456,  0.3154, -0.0432,  0.0079, -0.0806,  0.5046,
           1.6984],
         [-2.6347,  0.2324,  0.0873, -1.0514, -0.3753,  1.2293, -0.2159,
           0.1539]],

        [[ 0.0239, -1.7226,  0.7044,  0.4385,  0.4857,  0.1602,  1.4788,
           0.3629],
         [ 1.2462,  0.1250,  0.3095,  0.2417,  0.2231,  0.7651,  0.7462,
          -1.3258],
         [-1.0885,  0.6456,  0.3154, -0.0432,  0.0079, -0.0806,  0.5046,
           1.6984],
         [-2.6347,  0.2324,  0.0873, -1.0514, -0.3753,  1.2293, -0.2159,
           0.1539]]])

In [None]:
#this class needs to handle batches of input
class CausalAttention(nn.module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        #Create weight matrices
        self.W_q = nn.Linear(d_in, d_out, bias=False)
        self.W_k = nn.Linear(d_in, d_out, bias=False)
        self.W_v = nn.Linear(d_in, d_out, bias=False)
        self.dropout = Dropout(dropout)

    #   x = embedding vectors (inputs)
    def forward(self,x):
        #Make the queries, keys and values
        #Still just muktiplying the matrices
        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v (x)
        #Compute scores, then normalize into weights
        scores = queries @ keys.T
        weights = torch.softmax(scores / keys.shape[-1]**0.5, dim= -1)
        #Compute context vector using weights and values
        context = weights @ values
        return context