# Attention with trainable weights

In [591]:
import torch

In [592]:
inputs = torch.nn.Embedding( 4, 8 )

In [593]:
inputs = inputs.weight.data
inputs

tensor([[ 1.0533,  0.1388, -0.2044, -2.2685, -0.9133, -0.4204,  1.3111, -0.2199],
        [ 0.1838,  0.2293,  0.6177, -0.2876,  0.8218,  0.1512,  0.1036, -2.1996],
        [-2.3229,  1.0878, -0.0635, -0.4486, -1.2785, -1.1440,  0.2436, -0.0567],
        [ 0.4403, -1.4465, -0.5581, -0.0517, -0.9083,  0.3507,  1.5434,  0.1406]])

In [594]:
# set dimension
d_in = 8
d_out = 6

# create weight matrices
w_query = torch.nn.Parameter( torch.randn( d_in, d_out ), requires_grad=False )
w_key = torch.nn.Parameter( torch.randn( d_in, d_out ), requires_grad=False )
w_value = torch.nn.Parameter( torch.randn( d_in, d_out ), requires_grad=False )


In [595]:
# choose and input vector and tranfrom it into our query vector using w_query
query = inputs[2] @ w_query
query

tensor([ 1.4048,  5.5231, -2.0989, -1.6392, -2.1866,  2.1391])

In [596]:
# calculate attention scores using the keys generated by w_key
keys = inputs @ w_key
values = inputs @ w_value
print("Values: ", values, "\nKeys: ", keys)

Values:  tensor([[ 4.0310, -0.1499,  1.7580,  3.0791, -1.6066, -0.5806],
        [ 3.8501,  1.1053, -2.9978, -2.0425,  0.4336, -0.8436],
        [-3.7869, -1.9328,  1.8063,  0.6338, -3.7921, -1.2053],
        [ 1.7406,  1.1961, -0.9290,  0.2997, -0.7902,  2.6447]]) 
Keys:  tensor([[ 2.6992, -3.4495, -4.9359, -5.9429, -1.5600,  1.5451],
        [ 0.7882, -5.2803,  0.1594,  0.2066,  1.0121, -0.5606],
        [ 2.9669,  0.1861,  1.8968, -1.0446, -5.5319,  6.1505],
        [-0.7256, -0.5183, -0.9423, -2.1783,  2.1761,  2.1374]])


In [597]:
attention_scores = query @ keys.T
attention_scores

tensor([ 11.5579, -32.1414,  28.1797,   1.4804])

In [598]:
# attention_weights = attention_scores.softmax(dim = -1)

attention_weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1 )
attention_weights

tensor([1.1283e-03, 2.0163e-11, 9.9885e-01, 1.8438e-05])

In [599]:
attention_weights.sum()

tensor(1.)

In [600]:
context_vector = attention_weights @ values
context_vector

tensor([-3.7779, -1.9307,  1.8062,  0.6366, -3.7895, -1.2045])

In [601]:
import torch.nn as nn

In [602]:
class SimpleAttention( nn.Module ):
  def __init__( self, d_in, d_out ):
    super().__init__()
    # create weight matrices
    self.W_query = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_key = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_value = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

  # x = embedding vectors (inputs)
  def forward(self, x):
    query = x @ self.W_query
    keys = x @ self.W_key
    values = x @ self.W_value
    attention_scores = query @ keys.T
    weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1 )
    context_vector = weights @ values
    return context_vector



In [603]:
# how to use, instantiate an instance of it
simple = SimpleAttention( d_in=8, d_out=6 )

In [604]:
simple.W_value

Parameter containing:
tensor([[0.7355, 0.6248, 0.1638, 0.5158, 0.6000, 0.2299],
        [0.2890, 0.9078, 0.4596, 0.4947, 0.1836, 0.2010],
        [0.9603, 0.6861, 0.4209, 0.8046, 0.2621, 0.0638],
        [0.0036, 0.7032, 0.3051, 0.8070, 0.9271, 0.6647],
        [0.9296, 0.3848, 0.9357, 0.2616, 0.4344, 0.8323],
        [0.2410, 0.8815, 0.6226, 0.4902, 0.9279, 0.8751],
        [0.2943, 0.5485, 0.5583, 0.9096, 0.7810, 0.9049],
        [0.8048, 0.0649, 0.8322, 0.3672, 0.9012, 0.8146]])

In [605]:
context_vectors = simple(inputs)
context_vectors

tensor([[-2.8343, -2.1591, -1.8431, -1.7378, -3.0583, -2.4742],
        [-0.8974, -0.6558, -0.8224, -0.1684, -0.4084, -0.3660],
        [-2.8581, -2.1767, -1.8535, -1.7533, -3.0820, -2.4929],
        [-1.7793, -1.3494, -1.4022, -1.0888, -2.1340, -1.7414]])

In [606]:
class SimpleAttention2( nn.Module ):
  def __init__( self, d_in, d_out ):
    super().__init__()
    # create weight matrices
    self.W_query = nn.Linear( d_in, d_out, bias=False )
    self.W_key = nn.Linear( d_in, d_out, bias=False )
    self.W_value = nn.Linear( d_in, d_out, bias=False )

  # x = embedding vectors (inputs)
  def forward(self, x):
    query = self.W_query( x )
    keys = self.W_key( x )
    values = self.W_value( x )
    attention_scores = query @ keys.T
    weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1 )
    context_vector = weights @ values
    return context_vector

In [607]:
# how to use, instantiate an instance of it
simple = SimpleAttention2( d_in=8, d_out=6 )


In [608]:
simple.W_value

Linear(in_features=8, out_features=6, bias=False)

In [609]:
context_vectors = simple(inputs)
context_vectors

tensor([[ 5.6137e-01,  2.3243e-01,  1.2150e-01, -1.9756e-01,  4.7015e-01,
          1.5979e-01],
        [ 6.0625e-01,  2.6822e-01,  1.9601e-01, -2.0487e-01,  3.3184e-01,
          7.4284e-02],
        [ 4.7718e-01,  1.8690e-01, -3.9953e-04, -1.8236e-01,  5.9303e-01,
          2.5653e-01],
        [ 4.2967e-01,  1.5835e-01,  3.3162e-03, -1.5923e-01,  4.1212e-01,
          1.4546e-01]], grad_fn=<MmBackward0>)

In [610]:
# the problem wit h this is that each context vector uses infortmation from ALL of the embedding vectors
# om practice, wer should only use the information about the preceding embedding vectors
# to accomplish this, we'll implement causal attention AKA masked attention

In [611]:
class SimpleAttention2( nn.Module ):
  def __init__( self, d_in, d_out ):
    super().__init__()
    # create weight matrices
    self.W_query = nn.Linear( d_in, d_out, bias=False )
    self.W_key = nn.Linear( d_in, d_out, bias=False )
    self.W_value = nn.Linear( d_in, d_out, bias=False )

  # x = embedding vectors (inputs)
  def forward(self, x):
    query = self.W_query( x )
    keys = self.W_key( x )
    values = self.W_value( x )
    attention_scores = query @ keys.T
    weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1 )
    context_vector = weights @ values
    return weights

In [612]:
simple = SimpleAttention2( d_in=8, d_out=6 )
weights = simple( inputs )
weights

tensor([[0.2708, 0.2100, 0.2373, 0.2819],
        [0.2025, 0.2152, 0.3724, 0.2099],
        [0.2983, 0.2595, 0.1685, 0.2737],
        [0.2575, 0.1819, 0.3239, 0.2367]], grad_fn=<SoftmaxBackward0>)

In [613]:
#already normalized
weights.sum(dim= -1)

tensor([1., 1., 1., 1.], grad_fn=<SumBackward1>)

In [614]:
# Masking Method 1
simple_mask = torch.tril( torch.ones(weights.shape[0],weights.shape[0]))
simple_mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [615]:
# multiply to get the coordinate by coordinate product, NOT dot product
masked_weights = weights * simple_mask
masked_weights

tensor([[0.2708, 0.0000, 0.0000, 0.0000],
        [0.2025, 0.2152, 0.0000, 0.0000],
        [0.2983, 0.2595, 0.1685, 0.0000],
        [0.2575, 0.1819, 0.3239, 0.2367]], grad_fn=<MulBackward0>)

In [616]:
# now we need to normaliz the masked_weights so that each row has a sum 1
row_sums = masked_weights.sum(dim=-1, keepdim=True)
row_sums

tensor([[0.2708],
        [0.4177],
        [0.7263],
        [1.0000]], grad_fn=<SumBackward1>)

In [617]:
masked_weights_norm = masked_weights / row_sums
print(masked_weights_norm)
print(masked_weights_norm.sum(dim=-1))

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.4848, 0.5152, 0.0000, 0.0000],
        [0.4107, 0.3573, 0.2320, 0.0000],
        [0.2575, 0.1819, 0.3239, 0.2367]], grad_fn=<DivBackward0>)
tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)


In [618]:
# Masking method 2
mask = torch.triu(torch.ones(weights.shape[0], weights.shape[0]), diagonal =1)
mask

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

In [619]:
# can use mask == 1 or mask.bool()
weights_masked = weights.masked_fill(mask == 1, -torch.inf)
weights_masked

tensor([[0.2708,   -inf,   -inf,   -inf],
        [0.2025, 0.2152,   -inf,   -inf],
        [0.2983, 0.2595, 0.1685,   -inf],
        [0.2575, 0.1819, 0.3239, 0.2367]], grad_fn=<MaskedFillBackward0>)

In [620]:
weights_masked_norm = torch.softmax(weights_masked, dim=-1)
weights_masked_norm
print(weights_masked_norm.sum(dim=-1))

tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)


In [621]:
## DROPOUT - avoiding overfitting by randomly leaving out data
# idea : randomly select some data to leave out to avoid overfitting
dropout = nn.Dropout(p=0.5) # 50% dropout weight
dropout( inputs )

tensor([[ 0.0000,  0.2776, -0.0000, -4.5371, -0.0000, -0.8407,  0.0000, -0.4399],
        [ 0.3677,  0.4587,  0.0000, -0.5752,  0.0000,  0.0000,  0.2073, -0.0000],
        [-0.0000,  0.0000, -0.1271, -0.8973, -0.0000, -2.2880,  0.0000, -0.1135],
        [ 0.0000, -0.0000, -0.0000, -0.0000, -0.0000,  0.7014,  0.0000,  0.2812]])

In [622]:
# we need to be able to give our LLM batches of input
# for example:
batches = torch.stack((inputs, inputs), dim=0)
print(batches)
# torch.stack?

tensor([[[ 1.0533,  0.1388, -0.2044, -2.2685, -0.9133, -0.4204,  1.3111,
          -0.2199],
         [ 0.1838,  0.2293,  0.6177, -0.2876,  0.8218,  0.1512,  0.1036,
          -2.1996],
         [-2.3229,  1.0878, -0.0635, -0.4486, -1.2785, -1.1440,  0.2436,
          -0.0567],
         [ 0.4403, -1.4465, -0.5581, -0.0517, -0.9083,  0.3507,  1.5434,
           0.1406]],

        [[ 1.0533,  0.1388, -0.2044, -2.2685, -0.9133, -0.4204,  1.3111,
          -0.2199],
         [ 0.1838,  0.2293,  0.6177, -0.2876,  0.8218,  0.1512,  0.1036,
          -2.1996],
         [-2.3229,  1.0878, -0.0635, -0.4486, -1.2785, -1.1440,  0.2436,
          -0.0567],
         [ 0.4403, -1.4465, -0.5581, -0.0517, -0.9083,  0.3507,  1.5434,
           0.1406]]])


In [623]:
batches.shape

torch.Size([2, 4, 8])

In [624]:
# # this class needs to hand batches of input

# class CausalAttention( nn.Module ) :
#   def __init__( self, d_in, d_out, context_length, dropout, qky_bias=False ):
#     super().__init__()
#     # create weight matrices
#     self.W_query = nn.Linear( d_in, d_out, bias=False )
#     self.W_key = nn.Linear( d_in, d_out, bias=False )
#     self.W_value = nn.Linear( d_in, d_out, bias=False )
#     self.dropout = nn.Dropout( dropout )

#   # x = embedding vectors (inputs)
#   def forward(self, x):
#     query = self.W_query( x )
#     keys = self.W_key( x )
#     values = self.W_value( x )
#     attention_scores = query @ keys.T
#     weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1 )
#     context_vector = weights @ values
#     return context_vector

In [625]:
# this class needs to handle batches of input!

class CausalAttention( nn.Module ):
  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
    super().__init__()
    self.d_out = d_out
    # create weight matrices:
    self.W_q = nn.Linear( d_in, d_out, bias=False )
    self.W_k = nn.Linear( d_in, d_out, bias=False )
    self.W_v = nn.Linear( d_in, d_out, bias=False )
    # include dropout:
    self.dropout = nn.Dropout( dropout )
    # use the following to manage memory efficiently:
    self.register_buffer(
        'mask',
        torch.triu( torch.ones(context_length, context_length), diagonal = 1 )
    )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    b, num_tokens, d_in = x.shape
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )
    scores = queries @ keys.transpose(1,2)
    scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    weights = self.dropout( weights )
    context = weights @ values
    return context

In [626]:
# instantiate a causal attention mechanism:
causal = CausalAttention( d_in=8, d_out=6, context_length=4, dropout=0 )

In [627]:
causal( batches )

tensor([[[ 0.1295,  0.7283, -1.0537,  0.5340, -0.4069, -1.0032],
         [-0.4305,  0.0564, -0.2973,  0.0053, -0.0448,  0.0306],
         [ 0.3387,  0.0275, -0.2099, -0.1538, -0.3687, -0.1400],
         [ 0.0866,  0.2247, -0.1328,  0.0500, -0.2527, -0.2858]],

        [[ 0.1295,  0.7283, -1.0537,  0.5340, -0.4069, -1.0032],
         [-0.4305,  0.0564, -0.2973,  0.0053, -0.0448,  0.0306],
         [ 0.3387,  0.0275, -0.2099, -0.1538, -0.3687, -0.1400],
         [ 0.0866,  0.2247, -0.1328,  0.0500, -0.2527, -0.2858]]],
       grad_fn=<UnsafeViewBackward0>)

In [628]:
# heres the first pass of multiheaded attention
class MultiHeadAtttention ( nn.Module ):
  def __init__( self, d_in, d_out, num_heads, context_length, dropout, qkv_bias=False ):
    super().__init__()
    self.heads = nn.ModuleList(
      [ CausalAttention( d_in, d_out, context_length, dropout, qkv_bias ) for _ in range( num_heads ) ]
    )
  def forward( self, x ):
    return torch.cat( [head(x) for head in self.heads] , dim=-1 )

In [629]:
mha = MultiHeadAtttention( d_in=8, d_out=6, num_heads=4, context_length=4, dropout=0 )

In [630]:
mha_out = mha( batches )
mha_out

tensor([[[ 5.5415e-02,  6.9044e-01, -5.0947e-01, -1.0197e-01,  1.0220e+00,
           2.3957e-01, -6.1575e-01, -5.1908e-01,  8.1423e-01,  2.5077e-01,
           8.6917e-01, -1.4233e-01, -2.9188e-01, -4.5756e-01, -6.3211e-01,
           1.0346e-01, -1.0247e-01, -1.1738e+00, -1.2562e-01,  1.1179e+00,
          -4.0048e-01,  1.8521e-01,  4.7765e-01,  9.9035e-01],
         [-4.4370e-01,  5.9249e-02, -6.1015e-01, -1.6997e-01,  5.8622e-01,
           2.2276e-01, -3.8214e-01, -1.0437e-01,  6.9650e-01,  3.9591e-02,
           7.8057e-01,  2.3940e-01,  1.9300e-01,  5.0624e-02, -6.4491e-01,
          -1.1939e-01, -1.3995e-01, -8.7192e-01,  7.0109e-02,  4.4353e-01,
          -4.2574e-03, -4.8048e-01,  1.8570e-01,  1.6159e-01],
         [-2.0398e-01,  3.1655e-01, -7.3651e-01,  6.3688e-02,  6.5674e-01,
           2.3183e-01, -1.5374e-01, -1.7352e-01,  2.4260e-01, -1.6717e-01,
           6.4879e-01,  2.7125e-02, -1.8632e-01,  2.0006e-01, -5.4526e-01,
           1.2717e-01,  1.0821e-02, -2.6455e-01, 

In [631]:
mha_out.shape

torch.Size([2, 4, 24])

In [632]:
# copited from LLMs-from-scratch/ch03/
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # As in `CausalAttention`, for inputs where `num_tokens` exceeds `context_length`,
        # this will result in errors in the mask creation further below.
        # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs
        # do not exceed `context_length` before reaching this forward method.

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec

torch.manual_seed(123)

batch_size, context_length, d_in = batches.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batches)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.0942,  0.1862],
         [-0.0618,  0.1937],
         [ 0.0984,  0.2526],
         [ 0.0100,  0.2169]],

        [[-0.0942,  0.1862],
         [-0.0618,  0.1937],
         [ 0.0984,  0.2526],
         [ 0.0100,  0.2169]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 4, 2])


In [634]:
mha = MultiHeadAtttention( d_in=8, d_out=6, num_heads=4, context_length=4, dropout=0 )
mha_out = mha( batches )
mha_out

tensor([[[-1.1008e-01,  9.3279e-01, -2.2153e-01,  1.2646e-01, -1.0283e-01,
           6.7520e-01, -9.1997e-02, -1.0719e+00, -2.6806e-02,  9.1233e-01,
           8.4830e-01,  9.2786e-01, -9.3980e-01,  1.2979e+00, -6.6498e-01,
           9.9515e-02, -6.2438e-01,  9.0697e-01, -1.1210e-01,  6.1818e-02,
          -7.8646e-01, -8.2610e-03, -4.8232e-01, -4.9239e-01],
         [ 1.9144e-01,  8.6248e-01, -4.4032e-01, -5.1212e-02, -4.1459e-01,
           4.7685e-01,  1.2396e-01, -8.5839e-01, -3.5714e-01,  8.1635e-01,
           1.3785e-01,  7.1400e-02, -8.3076e-01,  7.3520e-01, -3.1956e-01,
           3.1299e-01, -3.1264e-01,  9.7956e-01, -2.3090e-02, -3.4840e-01,
          -5.3236e-01,  1.4260e-01, -2.9180e-01, -7.3608e-01],
         [ 2.5967e-01,  2.7055e-01, -5.1644e-01,  2.4873e-01, -3.0062e-01,
          -1.2209e-02,  3.3210e-01, -5.3176e-01, -3.5665e-01,  5.2701e-01,
           4.6658e-01,  5.9160e-02, -6.5873e-01,  4.9429e-01, -2.2862e-03,
           3.1935e-01,  7.0749e-04,  5.8875e-01, 

In [635]:
mha_out.shape

torch.Size([2, 4, 24])