# Attention with trainable weights

In [1]:
import torch

In [2]:
inputs = torch.nn.Embedding( 4, 8 )

In [3]:
inputs = inputs.weight.data
inputs

tensor([[-1.1409,  1.9904, -1.5029, -0.3719, -0.0094,  1.0202, -1.4527, -0.1686],
        [ 0.3157,  0.1487, -0.5505, -0.8596,  0.7156,  0.9594, -0.3960, -0.5231],
        [ 0.5807,  2.9491, -0.3287,  0.3263, -0.4029,  0.9054, -0.0462,  0.0200],
        [ 0.2366, -0.6871, -1.3611,  0.1971,  0.0978,  1.2356, -1.6434,  2.7711]])

In [4]:
# set dimension
d_in = 8
d_out = 6

# create weight matrices
w_query = torch.nn.Parameter( torch.randn( d_in, d_out ), requires_grad=False )
w_key = torch.nn.Parameter( torch.randn( d_in, d_out ), requires_grad=False )
w_value = torch.nn.Parameter( torch.randn( d_in, d_out ), requires_grad=False )


In [5]:
# choose and input vector and tranfrom it into our query vector using w_query
query = inputs[2] @ w_query
query

tensor([-3.3555,  0.5383,  6.5373,  3.3256, -4.0626, -3.2256])

In [6]:
# calculate attention scores using the keys generated by w_key
keys = inputs @ w_key
values = inputs @ w_value
print("Values: ", values, "\nKeys: ", keys)

Values:  tensor([[ 4.3321,  1.0042,  0.9976, -1.1313, -1.1402, -6.5772],
        [-0.5589, -1.1667, -2.6606,  0.4368, -0.1820, -0.7293],
        [ 5.1428,  3.3527,  4.1686, -3.2166, -4.1008, -8.9574],
        [-4.3656, -2.6636, -5.4837,  0.0429, -2.9598, -2.6025]]) 
Keys:  tensor([[-1.6257, -3.9317,  0.7758,  0.7769,  1.5306,  3.7090],
        [-3.3261, -0.4287,  1.5865, -1.3945,  0.7230,  0.8282],
        [-2.8306, -1.7201, -2.0632, -3.3845,  0.9093,  4.2520],
        [ 0.8098, -6.1397,  3.2592, -4.1501,  5.3712, -0.7537]])


In [7]:
attention_scores = query @ keys.T
attention_scores

tensor([ -7.1880,  11.0549, -33.5805, -17.9073])

In [8]:
# attention_weights = attention_scores.softmax(dim = -1)

attention_weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1 )
attention_weights

tensor([5.8249e-04, 9.9941e-01, 1.2187e-08, 7.3239e-06])

In [9]:
attention_weights.sum()

tensor(1.0000)

In [10]:
context_vector = attention_weights @ values
context_vector

tensor([-0.5561, -1.1655, -2.6585,  0.4359, -0.1826, -0.7327])

In [11]:
import torch.nn as nn

In [13]:
class SimpleAttention( nn.Module ):
  def __init__( self, d_in, d_out ):
    super().__init__()
    # create weight matrices
    self.W_query = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_key = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_value = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

  # x = embedding vectors (inputs)
  def forward(self, x):
    query = x @ self.W_query
    keys = x @ self.W_key
    values = x @ self.W_value
    attention_scores = query @ keys.T
    weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1 )
    context_vector = weights @ values
    return context_vector



In [14]:
# how to use, instantiate an instance of it
simple = SimpleAttention( d_in=8, d_out=6 )

In [15]:
simple.W_value

Parameter containing:
tensor([[0.5402, 0.2507, 0.2656, 0.3683, 0.5517, 0.7879],
        [0.1108, 0.5651, 0.4232, 0.5728, 0.5119, 0.9222],
        [0.4943, 0.6898, 0.2489, 0.6047, 0.7045, 0.3921],
        [0.3669, 0.6006, 0.3672, 0.3818, 0.1152, 0.5886],
        [0.5593, 0.9466, 0.1252, 0.1099, 0.1236, 0.2779],
        [0.1003, 0.2716, 0.3923, 0.8284, 0.3606, 0.8240],
        [0.1995, 0.4309, 0.6025, 0.8863, 0.3609, 0.4383],
        [0.0184, 0.0788, 0.0637, 0.3206, 0.3667, 0.6317]])

In [16]:
context_vectors = simple(inputs)
context_vectors

tensor([[-1.3229, -0.7206, -0.4259, -0.7522, -0.8451,  0.2374],
        [-1.1290, -0.6293, -0.3639, -0.5985, -0.6762,  0.3977],
        [ 0.4543,  1.6270,  1.7179,  2.4994,  1.9029,  3.8663],
        [ 0.4268,  1.5556,  1.6526,  2.4177,  1.8440,  3.7889]])

In [23]:
class SimpleAttention2( nn.Module ):
  def __init__( self, d_in, d_out ):
    super().__init__()
    # create weight matrices
    self.W_query = nn.Linear( d_in, d_out, bias=False )
    self.W_key = nn.Linear( d_in, d_out, bias=False )
    self.W_value = nn.Linear( d_in, d_out, bias=False )

  # x = embedding vectors (inputs)
  def forward(self, x):
    query = self.W_query( x )
    keys = self.W_key( x )
    values = self.W_value( x )
    attention_scores = query @ keys.T
    weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1 )
    context_vector = weights @ values
    return context_vector

In [24]:
# how to use, instantiate an instance of it
simple = SimpleAttention2( d_in=8, d_out=6 )


In [25]:
simple.W_value

Linear(in_features=8, out_features=6, bias=False)

In [26]:
context_vectors = simple(inputs)
context_vectors

tensor([[-4.0908e-02,  6.2397e-01,  6.9920e-02,  1.1415e-01,  4.8267e-01,
          3.2548e-01],
        [-5.7384e-04,  7.1514e-01,  1.0783e-01,  1.4399e-01,  4.8921e-01,
          2.5816e-01],
        [-3.9691e-02,  6.3777e-01,  7.5006e-02,  1.2596e-01,  4.7954e-01,
          3.0081e-01],
        [-2.6508e-02,  6.4697e-01,  8.1854e-02,  1.0655e-01,  4.9552e-01,
          3.1444e-01]], grad_fn=<MmBackward0>)