In [121]:
import torch
import torch.nn as nn

In [122]:
vocab_size = 4
output_dimension = 8

inputs = torch.nn.Embedding(vocab_size, output_dimension)

In [123]:
inputs = inputs.weight
inputs

Parameter containing:
tensor([[ 0.2026,  0.6488, -0.1370, -1.1637,  0.2137,  0.7658,  1.0017,  0.8374],
        [ 0.9614, -1.6557,  1.0424,  1.2569,  0.5858, -1.1519, -0.8093,  1.0280],
        [ 1.2628,  0.9452,  0.1253,  0.5801,  1.0466,  0.9004,  0.0573,  0.1743],
        [-2.6966, -1.2928,  0.0785,  0.3935,  0.2669, -0.5604, -0.5411,  1.2409]],
       requires_grad=True)

In [124]:
inputs = inputs.data # without 'requires_grad=True'
inputs

tensor([[ 0.2026,  0.6488, -0.1370, -1.1637,  0.2137,  0.7658,  1.0017,  0.8374],
        [ 0.9614, -1.6557,  1.0424,  1.2569,  0.5858, -1.1519, -0.8093,  1.0280],
        [ 1.2628,  0.9452,  0.1253,  0.5801,  1.0466,  0.9004,  0.0573,  0.1743],
        [-2.6966, -1.2928,  0.0785,  0.3935,  0.2669, -0.5604, -0.5411,  1.2409]])

In [None]:
# set dimensions
d_in = 8 # inputs.shape[i]
d_out = 6 # preferred output size

# create weight matrices
W_q = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
W_k = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
W_v = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

In [146]:
# choose an input vector and transform it into our query vector using W_q
# Note that the output has the preferred size
query = inputs @ W_q
query

tensor([[ 1.0851,  1.5374,  1.3586,  2.0521,  1.4855,  0.9387],
        [ 0.4557, -1.5418,  0.3188, -0.5352, -0.7327,  2.3348],
        [ 2.5883,  2.8919,  1.0787,  2.9465,  2.2346,  2.2001],
        [-0.9973, -1.9330, -0.2969, -3.5545, -1.6691, -0.5671]])

In [None]:
# calculate attention scores using the keys generated by W_k:
keys = inputs @ W_k
values = inputs @ W_v
print("Keys:", keys)
print("Values:", values )

Keys: tensor([[ 1.5925,  1.7724,  0.5706,  1.6434,  0.9033,  0.3941],
        [ 0.5997, -0.8042,  1.0577,  1.7515,  0.8616,  0.8988],
        [ 1.8331,  3.2341,  2.8132,  3.5662,  2.3614,  2.5543],
        [-0.0117, -2.0766, -2.5004, -2.0864, -2.1412, -1.2068]])
Values: tensor([[ 2.4172,  0.6776,  0.2720,  2.0063,  0.8952,  1.5085],
        [-0.8929,  0.8504,  1.5221,  0.2792,  1.1410, -0.2061],
        [ 3.2835,  3.7520,  1.7643,  1.8684,  1.3785,  3.6897],
        [-3.6619, -2.4133, -0.0285, -1.2269,  0.2540, -2.1437]])


In [None]:
# attention scores are how important a token is. Ex. when translating word by word,
# the word being translated gets the highest att score. the others get scores too but not as big as the meant word
attention_scores = query @ keys.T
attention_scores

tensor([ 17.5908,   9.4310,  38.5359, -22.3202])

Attention weights are different as they refer to the input and other inputs with respect to the highest attention (query) score input
Ex. a23 --> a is the name of the weight, 2 is the highest score, and 3 is the position of the input with respect to the highest, 2.

In [147]:
# dividing by the root of shape of a row in keys to keep them in a reasonable range
attention_weights = torch.softmax( attention_scores / keys.shape[-1]**0.5, dim = -1 )
attention_weights

tensor([1.9335e-04, 6.9123e-06, 9.9980e-01, 1.6223e-11])

In [148]:
attention_weights.sum()

tensor(1.)

In [131]:
context_vector = attention_weights @ values
context_vector

tensor([3.2833, 3.7514, 1.7641, 1.8684, 1.3784, 3.6893])

In [None]:
# here's a first version of a SimpleAttention class:

class SimpleAttention(nn.Module):
  def __init__(self, d_in, d_out):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_k = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_v = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    queries = x @ self.W_q
    keys = x @ self.W_k
    values = x @ self.W_v

    scores = queries @ keys.T
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    
    return context

In [133]:
# here's how to use this class:
# instantiate an instance of it:
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [134]:
simple.W_v

Parameter containing:
tensor([[0.3949, 0.1490, 0.5330, 0.4385, 0.1507, 0.6974],
        [0.0810, 0.6689, 0.7093, 0.0489, 0.9008, 0.6195],
        [0.2792, 0.3700, 0.3167, 0.5520, 0.5143, 0.1974],
        [0.2238, 0.9672, 0.4168, 0.9716, 0.1322, 0.2362],
        [0.3618, 0.5748, 0.1897, 0.6266, 0.5694, 0.9321],
        [0.5822, 0.4235, 0.4769, 0.4325, 0.1058, 0.6787],
        [0.1098, 0.5989, 0.8953, 0.7697, 0.6388, 0.4676],
        [0.4718, 0.9047, 0.1847, 0.8678, 0.2331, 0.6411]])

In [135]:
context_vectors = simple( inputs )
context_vectors

tensor([[ 1.5701,  2.3257,  1.9721,  2.3300,  1.6557,  2.9421],
        [ 1.7132,  2.5713,  2.3164,  2.4417,  1.9381,  3.3222],
        [ 1.7311,  2.6024,  2.3361,  2.4728,  1.9512,  3.3526],
        [-0.7629, -0.1420, -2.6364, -0.2343, -1.4419, -2.1611]])

In [149]:
# here's a second version of a SimpleAttention class ;
# it uses nn.Linear to do things more efficiently and gives better training results

class SimpleAttentionV2( nn.Module ):
  def __init__(self, d_in, d_out):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Linear( d_in, d_out, bias=False )
    self.W_k = nn.Linear( d_in, d_out, bias=False )
    self.W_v = nn.Linear( d_in, d_out, bias=False )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )

    scores = queries @ keys.T
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    
    return context

In [150]:
# here's how to use this class:
# instantiate an instance of it:
simple = SimpleAttentionV2( d_in = 8, d_out = 6 )

In [151]:
context_vectors = simple( inputs )
context_vectors

tensor([[ 0.1999, -0.0027,  0.3246,  0.2052,  0.2975, -0.1162],
        [ 0.0551,  0.0104,  0.1952,  0.0808,  0.2693, -0.1608],
        [ 0.1449,  0.0077,  0.2736,  0.1402,  0.2850, -0.1151],
        [ 0.0479,  0.0077,  0.1920,  0.0908,  0.2710, -0.1757]],
       grad_fn=<MmBackward0>)

- The problem with this is that each context vector uses information from ALL of the embedding vectors
- In practice, we should only use information about the preceding embedding vectors
- To accomplish this, we'll implement causal attention AKA masked attention
- It briefly means hiding future words

In [None]:
# this is a hack to get some example weights to work with!
# weights = simple( inputs ) # The hack didn't work :(

queries = simple.W_q(inputs)
keys = simple.W_k(inputs)
values = simple.W_v(inputs)

scores = queries @ keys.T
weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )

weights

tensor([[0.2021, 0.2944, 0.2535, 0.2500],
        [0.3036, 0.2168, 0.2889, 0.1907],
        [0.2446, 0.2528, 0.2562, 0.2464],
        [0.3020, 0.2217, 0.3018, 0.1745]], grad_fn=<SoftmaxBackward0>)

In [154]:
# note that these have already been normalized:
weights.sum( dim=-1 )

tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)

In [155]:
# masking method #1
simple_mask = torch.tril( torch.ones( weights.shape[0], weights.shape[0] ) ) #Triangular mask
simple_mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [156]:
masked_weights = weights * simple_mask
masked_weights

tensor([[0.2021, 0.0000, 0.0000, 0.0000],
        [0.3036, 0.2168, 0.0000, 0.0000],
        [0.2446, 0.2528, 0.2562, 0.0000],
        [0.3020, 0.2217, 0.3018, 0.1745]], grad_fn=<MulBackward0>)

In [157]:
masked_weights.sum( dim=-1 )

tensor([0.2021, 0.5204, 0.7536, 1.0000], grad_fn=<SumBackward1>)

In [158]:
# now, we need to normalize the masked_weights so that each row has sum 1 as it is good for optimization
# What this code does -> simple_mask / row_sums
row_sums = masked_weights.sum( dim=-1, keepdim=True)
row_sums

tensor([[0.2021],
        [0.5204],
        [0.7536],
        [1.0000]], grad_fn=<SumBackward1>)

In [159]:
masked_weights = masked_weights / row_sums
masked_weights.sum( dim=-1)

tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)

In [160]:
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5834, 0.4166, 0.0000, 0.0000],
        [0.3245, 0.3355, 0.3400, 0.0000],
        [0.3020, 0.2217, 0.3018, 0.1745]], grad_fn=<DivBackward0>)

In [161]:
# masking method #2
# This way scores -> mask -> soft max
mask = torch.triu( torch.ones(weights.shape[0], weights.shape[0]), diagonal = 1 ) 
mask

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

In [162]:
mask.bool()

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

In [163]:
weights

tensor([[0.2021, 0.2944, 0.2535, 0.2500],
        [0.3036, 0.2168, 0.2889, 0.1907],
        [0.2446, 0.2528, 0.2562, 0.2464],
        [0.3020, 0.2217, 0.3018, 0.1745]], grad_fn=<SoftmaxBackward0>)

In [None]:
# We masked the values first by hiding future values with -infinity
weights = weights.masked_fill( mask.bool(), -torch.inf )
weights

tensor([[0.2021,   -inf,   -inf,   -inf],
        [0.3036, 0.2168,   -inf,   -inf],
        [0.2446, 0.2528, 0.2562,   -inf],
        [0.3020, 0.2217, 0.3018, 0.1745]], grad_fn=<MaskedFillBackward0>)

In [None]:
# Now, every row sums up to 1
masked_weights = torch.softmax( weights, dim=-1 )
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5217, 0.4783, 0.0000, 0.0000],
        [0.3311, 0.3339, 0.3350, 0.0000],
        [0.2630, 0.2427, 0.2629, 0.2315]], grad_fn=<SoftmaxBackward0>)

In [None]:
## Dropout Mask
# idea: randomly select some data to leave out to avoid overfitting
dropout = nn.Dropout( 0.5 ) # 50%

In [169]:
dropout( torch.ones(6,6) )

tensor([[0., 0., 2., 2., 2., 0.],
        [2., 0., 2., 0., 2., 2.],
        [2., 2., 2., 0., 2., 2.],
        [2., 0., 0., 2., 0., 0.],
        [2., 2., 0., 0., 2., 2.],
        [2., 0., 2., 0., 2., 2.]])

In [None]:
# we need to be able to give our LLM batches of input
# for example:
batches = torch.stack( (inputs, inputs), dim=0) #stack 2 inputs on top of eachother

In [171]:
batches

tensor([[[ 0.2026,  0.6488, -0.1370, -1.1637,  0.2137,  0.7658,  1.0017,
           0.8374],
         [ 0.9614, -1.6557,  1.0424,  1.2569,  0.5858, -1.1519, -0.8093,
           1.0280],
         [ 1.2628,  0.9452,  0.1253,  0.5801,  1.0466,  0.9004,  0.0573,
           0.1743],
         [-2.6966, -1.2928,  0.0785,  0.3935,  0.2669, -0.5604, -0.5411,
           1.2409]],

        [[ 0.2026,  0.6488, -0.1370, -1.1637,  0.2137,  0.7658,  1.0017,
           0.8374],
         [ 0.9614, -1.6557,  1.0424,  1.2569,  0.5858, -1.1519, -0.8093,
           1.0280],
         [ 1.2628,  0.9452,  0.1253,  0.5801,  1.0466,  0.9004,  0.0573,
           0.1743],
         [-2.6966, -1.2928,  0.0785,  0.3935,  0.2669, -0.5604, -0.5411,
           1.2409]]])

In [None]:
batches.shape
# The output means 2 inputs, each tensor has 4 tokens, and each token is 8-dim vector

torch.Size([2, 4, 8])

In [174]:
# this class needs to handle batches of input!

class CausalAttention( nn.Module ):
  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
    super().__init__()
    self.d_out = d_out
    # create weight matrices:
    self.W_q = nn.Linear( d_in, d_out, bias=False )
    self.W_k = nn.Linear( d_in, d_out, bias=False )
    self.W_v = nn.Linear( d_in, d_out, bias=False )
    
    # include dropout:
    self.dropout = nn.Dropout( dropout )
    
    # use the following to manage memory efficiently
    # When passing this to a GPU, it's better because GPUs dont read tensors
    self.register_buffer(
        'mask',
        torch.triu( torch.ones(context_length, context_length), diagonal = 1 )
    )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    b, num_tokens, d_in = x.shape # b is batch size (num inputs)
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )

    scores = queries @ keys.transpose(1,2)
    scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) # inplace operation for better efficiency
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    weights = self.dropout( weights )
    
    context = weights @ values
    return context

In [175]:
# instantiate a causal attention mechanism:
causal = CausalAttention( d_in=8, d_out=6, context_length=4, dropout=0 )

In [176]:
causal( batches )

tensor([[[ 0.4950, -0.2449,  0.1706, -0.5732,  0.5883, -0.2651],
         [-0.5732,  0.2960, -0.0670, -0.2901,  0.6681, -0.1605],
         [-0.4589,  0.3969, -0.0690, -0.2954,  0.4442, -0.0536],
         [-0.5587,  0.3625, -0.0361, -0.1037,  0.2258, -0.0034]],

        [[ 0.4950, -0.2449,  0.1706, -0.5732,  0.5883, -0.2651],
         [-0.5732,  0.2960, -0.0670, -0.2901,  0.6681, -0.1605],
         [-0.4589,  0.3969, -0.0690, -0.2954,  0.4442, -0.0536],
         [-0.5587,  0.3625, -0.0361, -0.1037,  0.2258, -0.0034]]],
       grad_fn=<UnsafeViewBackward0>)

In [177]:
# everything below is just to show what happens with batches

queries = W_q( batches )
queries

TypeError: 'Parameter' object is not callable

In [143]:
keys = W_k( batches )
keys

TypeError: 'Parameter' object is not callable

In [144]:
keys.transpose(1,2)

IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)