## Attention With Trainable Weights

In [3]:
import torch

In [4]:
inputs = torch.nn.Embedding(4,8)

In [5]:
inputs = inputs.weight.data
inputs

tensor([[ 1.1235,  0.3848,  0.2176, -0.1915,  0.3218, -1.1199,  0.4628,  1.0018],
        [-0.2260,  1.5758,  0.1162, -1.9782, -0.3457, -1.1353,  1.3096,  0.4596],
        [-0.3055,  0.7803, -2.1933,  0.9615,  0.2858,  1.4231,  0.7304,  0.7601],
        [ 0.1555, -2.5241,  1.0640,  0.1207,  0.7441, -1.2528,  1.6269, -1.0157]])

In [6]:
# Set dimensions
d_in = 8
d_out = 6
# Create Weight Matricies query, key, value with random entries
W_q = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_k = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_v = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [7]:
#choose an input vector (we're using 2) and transform it into our query vector using using W_q
query = inputs[2] @ W_q
query


tensor([1.6538, 1.5932, 1.6082, 0.8148, 0.4994, 1.3328])

In [8]:
#Calculate keys and values matrices
keys = inputs @ W_k
values = inputs @ W_v
print(keys)
print(values)

tensor([[ 1.1261,  0.3428,  0.6490,  1.6041,  0.8039,  0.6276],
        [-0.3990, -1.6304, -0.3453, -0.6979,  0.7171,  0.0516],
        [ 2.2449,  2.9626,  0.7028,  0.0294,  2.8479,  1.7154],
        [-0.2247, -2.1196,  0.9183, -0.3535, -1.9207,  0.3464]])
tensor([[ 1.6006,  0.6530,  1.9231,  0.7862,  1.2701,  1.1857],
        [-0.2345, -0.8941, -0.1930, -1.8626,  0.3781, -0.9824],
        [ 1.4429,  0.8960,  2.4966,  0.7620,  0.5672,  0.9624],
        [-0.0681, -0.5490, -2.6490,  1.0066, -0.5846,  0.9395]])


In [9]:
attention_scores = query @ keys.T
attention_scores

tensor([ 5.9973, -3.9544, 13.2953, -3.0572])

In [10]:
#normalize the scores to get the weights
# denominator is just because that's what the researchers figured out worked best for training 
attention_weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim= -1)
attention_weights

tensor([4.8267e-02, 8.3025e-04, 9.4971e-01, 1.1975e-03])

In [11]:
#Make sure you did it right
attention_weights.sum()

tensor(1.)

In [12]:
# make the context vector
context_vector = attention_weights @ values
context_vector

tensor([1.4473, 0.8811, 2.4605, 0.7613, 0.5996, 0.9715])

## Making the Simple Attention Class Using Parameter 

In [13]:
import torch.nn as nn

In [14]:
#Create a simple attention class v1
class SimpleAttention(nn.Module):
    #Constructor, initialize those matrix dimensions
    def __init__(self, d_in, d_out):
        super().__init__()
        #Create weight matrices
        self.W_q = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.W_k = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.W_v = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

    #   x = embedding vectors (inputs)
    def forward(self,x):
        #Make the queries, keys and values
        queries = x @ self.W_q
        keys = x @ self.W_k
        values = x @ self.W_v
        #Compute scores, then normalize into weights
        scores = queries @ keys.T
        weights = torch.softmax(scores / keys.shape[-1]**0.5, dim= -1)
        #Compute context vector using weights and values
        context = weights @ values
        return context


In [15]:
#Using the class example
#instatiate an instance
simple = SimpleAttention(d_in = 8, d_out = 6)


In [16]:
#It has attributes!
simple.W_k

Parameter containing:
tensor([[0.1909, 0.7732, 0.1293, 0.9050, 0.7625, 0.0677],
        [0.5439, 0.8092, 0.6802, 0.0667, 0.9650, 0.8042],
        [0.5610, 0.0248, 0.1272, 0.1991, 0.9177, 0.0663],
        [0.9340, 0.2450, 0.0773, 0.9778, 0.0867, 0.3319],
        [0.8942, 0.1181, 0.7229, 0.6745, 0.9474, 0.2715],
        [0.3403, 0.5913, 0.8111, 0.3381, 0.6692, 0.8602],
        [0.7867, 0.0910, 0.5758, 0.8477, 0.7767, 0.7543],
        [0.2919, 0.8024, 0.7835, 0.4280, 0.0323, 0.3201]])

In [17]:
#Class returns the context vector
context_vectors = simple(inputs)
context_vectors

tensor([[ 1.4453,  0.9395,  1.9048,  1.3181,  0.5674,  1.5561],
        [-0.6997,  0.5977, -0.4032, -0.6267,  0.0862, -1.3821],
        [ 1.4512,  0.6566,  2.5537,  1.9153,  0.5605,  1.8776],
        [-0.7949,  0.6097, -0.6143, -0.9525,  0.0121, -1.5324]])

## Making the Simple Attention Class Again using nn.Linear

In [41]:
#Create a simple attention class v2
#Using nn.Linear to be more efficient

class SimpleAttention(nn.Module):
    #Constructor, initialize those matrix dimensions
    def __init__(self, d_in, d_out):
        super().__init__()
        #Create weight matrices
        self.W_q = nn.Linear(d_in, d_out, bias=False)
        self.W_k = nn.Linear(d_in, d_out, bias=False)
        self.W_v = nn.Linear(d_in, d_out, bias=False)

    #   x = embedding vectors (inputs)
    def forward(self,x):
        #Make the queries, keys and values
        #Still just muktiplying the matrices
        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v (x)
        #Compute scores, then normalize into weights
        scores = queries @ keys.T
        weights = torch.softmax(scores / keys.shape[-1]**0.5, dim= -1)
        #Compute context vector using weights and values
        context = weights @ values
        return context


In [37]:
simple = SimpleAttention(d_in = 8, d_out = 4)

In [38]:
context_vectors = simple(inputs)
context_vectors

tensor([[0.2320, 0.2654, 0.2398, 0.2627],
        [0.1815, 0.1899, 0.2234, 0.4052],
        [0.2624, 0.2063, 0.3141, 0.2171],
        [0.2308, 0.4676, 0.1540, 0.1475]], grad_fn=<SoftmaxBackward0>)

In [39]:
# the problem here is that each context vector uses information from ALL of the embedding vectors
#In practice, we should only use information about preceeding vectors
# To do so, implement causal attention AKA masked attention

### Simple Masked Attention (1.0)

In [None]:
#Hack to get some weights
#weights = simple(inputs)
weights

tensor([[0.2320, 0.2654, 0.2398, 0.2627],
        [0.1815, 0.1899, 0.2234, 0.4052],
        [0.2624, 0.2063, 0.3141, 0.2171],
        [0.2308, 0.4676, 0.1540, 0.1475]], grad_fn=<SoftmaxBackward0>)

In [42]:
#Note that these are already normalized
weights.sum(dim=-1)

tensor([1., 1., 1., 1.], grad_fn=<SumBackward1>)

In [43]:
#tril returns the lower triangular parts of the matrix (and the diagonal), turns the rest to zero
simple_mask = torch.tril(torch.ones(weights.shape[0], weights.shape[0]))
simple_mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [48]:
# * is a cell by cell product
masked_weights = weights*simple_mask
masked_weights

tensor([[0.2320, 0.0000, 0.0000, 0.0000],
        [0.1815, 0.1899, 0.0000, 0.0000],
        [0.2624, 0.2063, 0.3141, 0.0000],
        [0.2308, 0.4676, 0.1540, 0.1475]], grad_fn=<MulBackward0>)

In [49]:
#no longer normalized...
masked_weights.sum(dim=-1)

tensor([0.2320, 0.3715, 0.7829, 1.0000], grad_fn=<SumBackward1>)

In [50]:
#Need to normalize masked weights
row_sums = masked_weights.sum(dim=-1, keepdim=True)
row_sums

tensor([[0.2320],
        [0.3715],
        [0.7829],
        [1.0000]], grad_fn=<SumBackward1>)

In [51]:
#Dividing by row sums normalizes
masked_weights = masked_weights / row_sums
masked_weights.sum(dim=-1)

tensor([1., 1., 1., 1.], grad_fn=<SumBackward1>)

### More Efficient Masking Method (1.1)

In [52]:
# Masking method 2
# triu() returns the upper triangular part of the matrix, puts zeros on the diagonal
mask = torch.triu(torch.ones(weights.shape[0], weights.shape[0]), diagonal=1)
mask

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

In [53]:
mask.bool()

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

In [54]:
weights

tensor([[0.2320, 0.2654, 0.2398, 0.2627],
        [0.1815, 0.1899, 0.2234, 0.4052],
        [0.2624, 0.2063, 0.3141, 0.2171],
        [0.2308, 0.4676, 0.1540, 0.1475]], grad_fn=<SoftmaxBackward0>)

In [55]:
#Need a matrix of true/false to fill in, fills the true spots with the given value (-inf)
weights = weights.masked_fill(mask.bool(), -torch.inf, )
weights

tensor([[0.2320,   -inf,   -inf,   -inf],
        [0.1815, 0.1899,   -inf,   -inf],
        [0.2624, 0.2063, 0.3141,   -inf],
        [0.2308, 0.4676, 0.1540, 0.1475]], grad_fn=<MaskedFillBackward0>)

In [56]:
masked_weights = torch.softmax(weights, dim=-1)
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.4979, 0.5021, 0.0000, 0.0000],
        [0.3335, 0.3153, 0.3512, 0.0000],
        [0.2431, 0.3081, 0.2251, 0.2237]], grad_fn=<SoftmaxBackward0>)

### Dropout Explanation

In [None]:
## Dropout helps us avoid overfitting during training 
# (don't want to overfit to training set bc then model cannot generalize)
# Randomly and uniformly ignoring certain data points
# we must set a dropout rate
dropout = nn.Dropout(0.5)

In [58]:
# It removes around half of the data points ...
dropout(masked_weights)

tensor([[2.0000, 0.0000, 0.0000, 0.0000],
        [0.9958, 1.0042, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.7024, 0.0000],
        [0.4862, 0.6161, 0.0000, 0.0000]], grad_fn=<MulBackward0>)

## Causal Attenion Class

In [59]:
# we need to be able to give our LLM batches of input
# For example:
# Concatenates two tensors
batches = torch.stack((inputs, inputs), dim=0)

In [62]:
batches

tensor([[[ 1.1235,  0.3848,  0.2176, -0.1915,  0.3218, -1.1199,  0.4628,
           1.0018],
         [-0.2260,  1.5758,  0.1162, -1.9782, -0.3457, -1.1353,  1.3096,
           0.4596],
         [-0.3055,  0.7803, -2.1933,  0.9615,  0.2858,  1.4231,  0.7304,
           0.7601],
         [ 0.1555, -2.5241,  1.0640,  0.1207,  0.7441, -1.2528,  1.6269,
          -1.0157]],

        [[ 1.1235,  0.3848,  0.2176, -0.1915,  0.3218, -1.1199,  0.4628,
           1.0018],
         [-0.2260,  1.5758,  0.1162, -1.9782, -0.3457, -1.1353,  1.3096,
           0.4596],
         [-0.3055,  0.7803, -2.1933,  0.9615,  0.2858,  1.4231,  0.7304,
           0.7601],
         [ 0.1555, -2.5241,  1.0640,  0.1207,  0.7441, -1.2528,  1.6269,
          -1.0157]]])

In [75]:
#this class needs to handle batches of input
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        #Create weight matrices
        self.W_q = nn.Linear(d_in, d_out, bias=False)
        self.W_k = nn.Linear(d_in, d_out, bias=False)
        self.W_v = nn.Linear(d_in, d_out, bias=False)
        # include dropout
        self.dropout = nn.Dropout(dropout)
        #Use the following to manage memory efficiently:
        self.register_buffer(
            'mask', 
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
            )

    #   x = embedding vectors (inputs)
    def forward(self,x):
        b, num_tokens, d_in = x.shape
        #Make the queries, keys and values
        #Still just muktiplying the matrices
        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v (x)
        #Compute scores, then normalize into weights (doing dot products of queries and keys)
        scores = queries @ keys.transpose(1,2)
        #masking process: turns into trues/falses then converts some into -inf which is normalized to 0 in next step
        scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        #Normalize, -inf turns into zeros
        weights = torch.softmax(scores / keys.shape[-1]**0.5, dim= -1)
        weights = self.dropout(weights)
        #Compute context vector using weights and values
        context = weights @ values
        return context

In [76]:
#Instantiate a causal attention mechanism
causal = CausalAttention(d_in=8, d_out=4, context_length= 4, dropout=0)

In [77]:
causal(batches)

tensor([[[-0.5388, -0.1939, -0.3264,  0.0441],
         [-0.1614, -0.1862, -0.0703,  0.0295],
         [ 0.1411,  0.0746, -0.0504, -0.3851],
         [-0.0500,  0.2988, -0.6009, -0.0190]],

        [[-0.5388, -0.1939, -0.3264,  0.0441],
         [-0.1614, -0.1862, -0.0703,  0.0295],
         [ 0.1411,  0.0746, -0.0504, -0.3851],
         [-0.0500,  0.2988, -0.6009, -0.0190]]], grad_fn=<UnsafeViewBackward0>)

### an example of how to transpose batches correctly

In [67]:
W_q = nn.Linear(d_in, d_out, bias=False)
W_k = nn.Linear(d_in, d_out, bias=False)
W_v = nn.Linear(d_in, d_out, bias=False)

In [68]:
queries = W_q(batches)
queries

tensor([[[ 0.2543, -0.1610,  0.0497, -0.9070, -0.5010,  0.3876],
         [-0.8582, -0.3220,  0.1259, -0.9665,  0.5995,  0.5388],
         [-1.2267, -0.5747,  0.4660,  0.6398, -0.0497,  0.2588],
         [ 0.5372,  0.8420, -0.9136,  0.1309,  0.2768, -0.7435]],

        [[ 0.2543, -0.1610,  0.0497, -0.9070, -0.5010,  0.3876],
         [-0.8582, -0.3220,  0.1259, -0.9665,  0.5995,  0.5388],
         [-1.2267, -0.5747,  0.4660,  0.6398, -0.0497,  0.2588],
         [ 0.5372,  0.8420, -0.9136,  0.1309,  0.2768, -0.7435]]],
       grad_fn=<UnsafeViewBackward0>)

In [69]:
keys = W_k(batches)
keys

tensor([[[-0.0905,  0.1475, -0.8634, -0.1281,  0.2515, -0.3039],
         [-0.1949,  0.4223, -0.0847, -1.2047,  1.0438, -0.7071],
         [ 0.1571,  0.9005, -1.0104, -0.1815, -0.5400,  0.4482],
         [ 0.0291, -0.0851, -0.3023,  0.4536, -0.6966, -0.2584]],

        [[-0.0905,  0.1475, -0.8634, -0.1281,  0.2515, -0.3039],
         [-0.1949,  0.4223, -0.0847, -1.2047,  1.0438, -0.7071],
         [ 0.1571,  0.9005, -1.0104, -0.1815, -0.5400,  0.4482],
         [ 0.0291, -0.0851, -0.3023,  0.4536, -0.6966, -0.2584]]],
       grad_fn=<UnsafeViewBackward0>)

In [None]:
#This transposes without mixing up the batches
keys.transpose(1,2)

tensor([[[-0.0905, -0.1949,  0.1571,  0.0291],
         [ 0.1475,  0.4223,  0.9005, -0.0851],
         [-0.8634, -0.0847, -1.0104, -0.3023],
         [-0.1281, -1.2047, -0.1815,  0.4536],
         [ 0.2515,  1.0438, -0.5400, -0.6966],
         [-0.3039, -0.7071,  0.4482, -0.2584]],

        [[-0.0905, -0.1949,  0.1571,  0.0291],
         [ 0.1475,  0.4223,  0.9005, -0.0851],
         [-0.8634, -0.0847, -1.0104, -0.3023],
         [-0.1281, -1.2047, -0.1815,  0.4536],
         [ 0.2515,  1.0438, -0.5400, -0.6966],
         [-0.3039, -0.7071,  0.4482, -0.2584]]], grad_fn=<TransposeBackward0>)