# Attention with trainable weights

In [138]:
import torch

In [139]:
inputs = torch.nn.Embedding( 4, 8 )

In [140]:
inputs = inputs.weight.data
inputs

tensor([[ 0.5546,  0.2966,  2.0589, -0.9608,  1.1342, -0.3602,  0.3127, -0.3441],
        [-1.8143, -1.9086,  1.5761, -0.0748,  1.0284,  0.8431, -0.3166,  0.4879],
        [-2.3368, -0.8830, -0.5530,  1.1507,  0.5392,  1.2049,  1.1838, -0.0752],
        [-1.7607, -0.9615,  0.3583, -2.0808, -1.0154, -0.7016, -0.2094, -0.8293]])

In [141]:
# set dimension
d_in = 8
d_out = 6

# create weight matrices
w_query = torch.nn.Parameter( torch.randn( d_in, d_out ), requires_grad=False )
w_key = torch.nn.Parameter( torch.randn( d_in, d_out ), requires_grad=False )
w_value = torch.nn.Parameter( torch.randn( d_in, d_out ), requires_grad=False )


In [142]:
# choose and input vector and tranfrom it into our query vector using w_query
query = inputs[2] @ w_query
query

tensor([-1.5475, -5.5059, -4.9801, -8.9749,  1.7650, -2.8583])

In [143]:
# calculate attention scores using the keys generated by w_key
keys = inputs @ w_key
values = inputs @ w_value
print("Values: ", values, "\nKeys: ", keys)

Values:  tensor([[ 4.2913, -0.9399,  2.7514,  1.3161,  2.1229, -5.5799],
        [-4.2806, -4.3206,  6.4974,  4.4086,  0.1625, -4.3899],
        [-0.6404, -1.2269,  1.1270,  3.6036,  0.9804, -0.5700],
        [-4.0006, -2.8639,  6.1219,  1.1201, -6.6592,  4.3211]]) 
Keys:  tensor([[-0.9820, -4.2075, -0.2588, -0.8410,  0.8486, -1.2789],
        [-2.2700,  3.0331, -1.1149,  1.8968,  3.6503, -3.3202],
        [-1.9354,  4.8751, -3.1151,  5.3309,  2.1982, -0.5466],
        [ 2.0072,  3.4482,  6.0506,  1.9756,  0.8990, -3.9215]])


In [144]:
attention_scores = query @ keys.T
attention_scores

tensor([ 38.6757,  -8.7250, -50.7352, -57.1592])

In [145]:
# attention_weights = attention_scores.softmax(dim = -1)

attention_weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1 )
attention_weights

tensor([1.0000e+00, 3.9433e-09, 1.4043e-16, 1.0197e-17])

In [146]:
attention_weights.sum()

tensor(1.)

In [147]:
context_vector = attention_weights @ values
context_vector

tensor([ 4.2913, -0.9399,  2.7514,  1.3161,  2.1229, -5.5799])

In [148]:
import torch.nn as nn

In [149]:
class SimpleAttention( nn.Module ):
  def __init__( self, d_in, d_out ):
    super().__init__()
    # create weight matrices
    self.W_query = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_key = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_value = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

  # x = embedding vectors (inputs)
  def forward(self, x):
    query = x @ self.W_query
    keys = x @ self.W_key
    values = x @ self.W_value
    attention_scores = query @ keys.T
    weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1 )
    context_vector = weights @ values
    return context_vector



In [150]:
# how to use, instantiate an instance of it
simple = SimpleAttention( d_in=8, d_out=6 )

In [151]:
simple.W_value

Parameter containing:
tensor([[0.3606, 0.4550, 0.8046, 0.3757, 0.7438, 0.2470],
        [0.7580, 0.7419, 0.1634, 0.2725, 0.3278, 0.0818],
        [0.4613, 0.2523, 0.7435, 0.0907, 0.2169, 0.4255],
        [0.3549, 0.2882, 0.0194, 0.5397, 0.7751, 0.6892],
        [0.9210, 0.0218, 0.9647, 0.4730, 0.6370, 0.3097],
        [0.9146, 0.2590, 0.0981, 0.7211, 0.5695, 0.6908],
        [0.4844, 0.3496, 0.7356, 0.2102, 0.4970, 0.6801],
        [0.5970, 0.4330, 0.3571, 0.5969, 0.0214, 0.2182]])

In [152]:
context_vectors = simple(inputs)
context_vectors

tensor([[ 1.2834, -0.0511,  1.6461,  0.3132,  0.5517,  0.9596],
        [ 0.7401, -0.8125, -0.3155,  0.5519,  0.1722,  1.3695],
        [ 0.9074, -0.5208,  0.9765,  0.2446,  0.1340,  0.9078],
        [-4.1104, -2.6600, -2.8465, -3.5393, -4.3281, -2.9177]])

In [153]:
class SimpleAttention2( nn.Module ):
  def __init__( self, d_in, d_out ):
    super().__init__()
    # create weight matrices
    self.W_query = nn.Linear( d_in, d_out, bias=False )
    self.W_key = nn.Linear( d_in, d_out, bias=False )
    self.W_value = nn.Linear( d_in, d_out, bias=False )

  # x = embedding vectors (inputs)
  def forward(self, x):
    query = self.W_query( x )
    keys = self.W_key( x )
    values = self.W_value( x )
    attention_scores = query @ keys.T
    weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1 )
    context_vector = weights @ values
    return context_vector

In [154]:
# how to use, instantiate an instance of it
simple = SimpleAttention2( d_in=8, d_out=6 )


In [155]:
simple.W_value

Linear(in_features=8, out_features=6, bias=False)

In [156]:
context_vectors = simple(inputs)
context_vectors

tensor([[-0.3151, -0.0453,  0.8171, -0.5821,  0.0013, -0.3779],
        [-0.4395,  0.1396,  0.6292, -0.2068,  0.1724, -0.3179],
        [-0.4181,  0.1646,  0.7088, -0.2527,  0.1108, -0.2131],
        [-0.3831,  0.0844,  0.8050, -0.4328,  0.0461, -0.2671]],
       grad_fn=<MmBackward0>)

In [157]:
# the problem wit h this is that each context vector uses infortmation from ALL of the embedding vectors
# om practice, wer should only use the information about the preceding embedding vectors
# to accomplish this, we'll implement causal attention AKA masked attention

In [158]:
class SimpleAttention2( nn.Module ):
  def __init__( self, d_in, d_out ):
    super().__init__()
    # create weight matrices
    self.W_query = nn.Linear( d_in, d_out, bias=False )
    self.W_key = nn.Linear( d_in, d_out, bias=False )
    self.W_value = nn.Linear( d_in, d_out, bias=False )

  # x = embedding vectors (inputs)
  def forward(self, x):
    query = self.W_query( x )
    keys = self.W_key( x )
    values = self.W_value( x )
    attention_scores = query @ keys.T
    weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1 )
    context_vector = weights @ values
    return weights

In [159]:
simple = SimpleAttention2( d_in=8, d_out=6 )
weights = simple( inputs )
weights

tensor([[0.2307, 0.2837, 0.2586, 0.2270],
        [0.2975, 0.2660, 0.2691, 0.1674],
        [0.3651, 0.2201, 0.2732, 0.1417],
        [0.3673, 0.2139, 0.2436, 0.1753]], grad_fn=<SoftmaxBackward0>)

In [160]:
#already normalized
weights.sum(dim= -1)

tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)

In [161]:
# Masking Method 1
simple_mask = torch.tril( torch.ones(weights.shape[0],weights.shape[0]))
simple_mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [162]:
# multiply to get the coordinate by coordinate product, NOT dot product
masked_weights = weights * simple_mask
masked_weights

tensor([[0.2307, 0.0000, 0.0000, 0.0000],
        [0.2975, 0.2660, 0.0000, 0.0000],
        [0.3651, 0.2201, 0.2732, 0.0000],
        [0.3673, 0.2139, 0.2436, 0.1753]], grad_fn=<MulBackward0>)

In [163]:
# now we need to normaliz the masked_weights so that each row has a sum 1
row_sums = masked_weights.sum(dim=-1, keepdim=True)
row_sums

tensor([[0.2307],
        [0.5635],
        [0.8583],
        [1.0000]], grad_fn=<SumBackward1>)

In [164]:
masked_weights_norm = masked_weights / row_sums
print(masked_weights_norm)
print(masked_weights_norm.sum(dim=-1))

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5280, 0.4720, 0.0000, 0.0000],
        [0.4254, 0.2564, 0.3182, 0.0000],
        [0.3673, 0.2139, 0.2436, 0.1753]], grad_fn=<DivBackward0>)
tensor([1., 1., 1., 1.], grad_fn=<SumBackward1>)


In [165]:
# Masking method 2
mask = torch.triu(torch.ones(weights.shape[0], weights.shape[0]), diagonal =1)
mask

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

In [166]:
# can use mask == 1 or mask.bool()
weights_masked = weights.masked_fill(mask == 1, -torch.inf)
weights_masked

tensor([[0.2307,   -inf,   -inf,   -inf],
        [0.2975, 0.2660,   -inf,   -inf],
        [0.3651, 0.2201, 0.2732,   -inf],
        [0.3673, 0.2139, 0.2436, 0.1753]], grad_fn=<MaskedFillBackward0>)

In [170]:
weights_masked_norm = torch.softmax(weights_masked, dim=-1)
weights_masked_norm
print(weights_masked_norm.sum(dim=-1))

tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)


In [180]:
## DROPOUT - avoiding overfitting by randomly leaving out data
# idea : randomly select some data to leave out to avoid overfitting
dropout = nn.Dropout(p=0.5) # 50% dropout weight
dropout( inputs )

tensor([[ 0.0000,  0.5933,  0.0000, -0.0000,  2.2683, -0.0000,  0.6253, -0.6882],
        [-0.0000, -3.8171,  0.0000, -0.1497,  2.0568,  0.0000, -0.0000,  0.9758],
        [-4.6736, -0.0000, -0.0000,  2.3013,  1.0784,  2.4098,  2.3676, -0.1504],
        [-3.5214, -0.0000,  0.7165, -4.1616, -0.0000, -1.4032, -0.4187, -1.6586]])

In [178]:
# we need to be able to give our LLM batches of input
# for example:
batches = torch.stack((inputs, inputs), dim=0)
print(batches)
torch.stack?

tensor([[[ 0.5546,  0.2966,  2.0589, -0.9608,  1.1342, -0.3602,  0.3127,
          -0.3441],
         [-1.8143, -1.9086,  1.5761, -0.0748,  1.0284,  0.8431, -0.3166,
           0.4879],
         [-2.3368, -0.8830, -0.5530,  1.1507,  0.5392,  1.2049,  1.1838,
          -0.0752],
         [-1.7607, -0.9615,  0.3583, -2.0808, -1.0154, -0.7016, -0.2094,
          -0.8293]],

        [[ 0.5546,  0.2966,  2.0589, -0.9608,  1.1342, -0.3602,  0.3127,
          -0.3441],
         [-1.8143, -1.9086,  1.5761, -0.0748,  1.0284,  0.8431, -0.3166,
           0.4879],
         [-2.3368, -0.8830, -0.5530,  1.1507,  0.5392,  1.2049,  1.1838,
          -0.0752],
         [-1.7607, -0.9615,  0.3583, -2.0808, -1.0154, -0.7016, -0.2094,
          -0.8293]]])


In [174]:
batches.shape

torch.Size([2, 4, 8])

In [167]:
# this class needs to hand batches of input

class CausalAttention( nn.Module )
  def __init__( self, d_in, d_out, context_length, dropout, qky_bias=False ):
    super().__init__()
    # create weight matrices
    self.W_query = nn.Linear( d_in, d_out, bias=False )
    self.W_key = nn.Linear( d_in, d_out, bias=False )
    self.W_value = nn.Linear( d_in, d_out, bias=False )
    self.dropout = nn.Dropout( dropout )

  # x = embedding vectors (inputs)
  def forward(self, x):
    query = self.W_query( x )
    keys = self.W_key( x )
    values = self.W_value( x )
    attention_scores = query @ keys.T
    weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1 )
    context_vector = weights @ values
    return context_vector