## Attention With Trainable Weights

In [1]:
import torch

In [2]:
inputs = torch.nn.Embedding(4,8)

In [3]:
inputs = inputs.weight.data
inputs

tensor([[-1.4949, -0.0219,  0.2485, -0.0637, -1.1915, -0.7459,  0.5395,  0.1476],
        [ 0.3478, -1.0369,  0.5293,  0.3848,  1.0521,  1.0733, -1.3499, -0.5940],
        [ 0.0870, -1.2924, -2.5694, -0.5932, -0.5718, -0.7813, -0.6482,  0.9455],
        [-1.2236, -1.0342, -0.4716, -0.5929, -1.4903,  0.9641,  0.4968,  2.9186]])

In [4]:
# Set dimensions
d_in = 8
d_out = 6
# Create Weight Matricies query, key, value with random entries
W_q = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_k = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_v = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [5]:
#choose an input vector (we're using 2) and transform it into our query vector using using W_q
query = inputs[2] @ W_q
query


tensor([-1.9723, -3.9607, -3.7837, -0.6660, -2.7144, -1.9260])

In [6]:
#Calculate keys and values matrices
keys = inputs @ W_k
values = inputs @ W_v
print(keys)
print(values)

tensor([[-2.6288, -1.2798, -1.2445, -1.5749, -1.2617, -0.9750],
        [ 1.5251, -0.1952, -0.3732,  0.5101, -0.5340, -0.6139],
        [-2.4367, -2.9807, -4.0509, -2.6120, -2.5843, -2.1185],
        [-0.4189,  0.6339, -0.4167, -0.0466, -1.4462,  0.3206]])
tensor([[-0.5259, -1.3863, -1.9953, -1.5142, -1.3107, -1.9609],
        [-0.2136,  0.2361, -0.2488, -0.4069,  0.7552, -0.0896],
        [-1.9484, -3.7887, -2.3165, -1.4110, -1.6432, -1.5987],
        [ 0.3877, -0.1307, -2.5233,  0.8071, -0.5758, -1.9153]])


In [7]:
attention_scores = query @ keys.T
attention_scores

tensor([21.3140,  1.4697, 44.7730,  3.2313])

In [8]:
#normalize the scores to get the weights
# denominator is just because that's what the researchers figured out worked best for training 
attention_weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim= -1)
attention_weights

tensor([6.9292e-05, 2.1003e-08, 9.9993e-01, 4.3115e-08])

In [9]:
#Make sure you did it right
attention_weights.sum()

tensor(1.)

In [10]:
# make the context vector
context_vector = attention_weights @ values
context_vector

tensor([-1.9483, -3.7885, -2.3165, -1.4110, -1.6431, -1.5988])

## Making the Simple Attention Class Using Parameter 

In [11]:
import torch.nn as nn

In [12]:
#Create a simple attention class v1
class SimpleAttention(nn.Module):
    #Constructor, initialize those matrix dimensions
    def __init__(self, d_in, d_out):
        super().__init__()
        #Create weight matrices
        self.W_q = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.W_k = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.W_v = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

    #   x = embedding vectors (inputs)
    def forward(self,x):
        #Make the queries, keys and values
        queries = x @ self.W_q
        keys = x @ self.W_k
        values = x @ self.W_v
        #Compute scores, then normalize into weights
        scores = queries @ keys.T
        weights = torch.softmax(scores / keys.shape[-1]**0.5, dim= -1)
        #Compute context vector using weights and values
        context = weights @ values
        return context


In [13]:
#Using the class example
#instatiate an instance
simple = SimpleAttention(d_in = 8, d_out = 6)


In [14]:
#It has attributes!
simple.W_k

Parameter containing:
tensor([[0.4454, 0.6886, 0.3427, 0.4631, 0.9360, 0.8681],
        [0.0611, 0.6497, 0.2175, 0.5279, 0.8336, 0.2066],
        [0.2820, 0.4408, 0.9402, 0.3710, 0.4399, 0.9453],
        [0.3997, 0.1564, 0.7702, 0.1675, 0.5019, 0.8055],
        [0.7183, 0.2357, 0.1873, 0.8530, 0.2531, 0.7312],
        [0.4589, 0.5474, 0.2109, 0.6925, 0.1446, 0.4996],
        [0.2881, 0.0170, 0.1121, 0.7964, 0.3911, 0.4384],
        [0.0045, 0.3141, 0.2574, 0.3820, 0.7746, 0.6186]])

In [15]:
#Class returns the context vector
context_vectors = simple(inputs)
context_vectors

tensor([[-2.5240, -2.4606, -1.4712, -0.8811, -1.9217, -1.3459],
        [ 0.5158,  1.1164,  0.3386,  0.2991, -0.1793,  0.2796],
        [-2.5311, -2.4689, -1.4806, -0.8782, -1.9316, -1.3404],
        [-2.1427, -1.9348, -0.9703, -0.8859, -1.4233, -1.4814]])

## Making the Simple Attention Class Again using nn.Linear

In [24]:
#Create a simple attention class v2
#Using nn.Linear to be more efficient

class SimpleAttention(nn.Module):
    #Constructor, initialize those matrix dimensions
    def __init__(self, d_in, d_out):
        super().__init__()
        #Create weight matrices
        self.W_q = nn.Linear(d_in, d_out, bias=False)
        self.W_k = nn.Linear(d_in, d_out, bias=False)
        self.W_v = nn.Linear(d_in, d_out, bias=False)

    #   x = embedding vectors (inputs)
    def forward(self,x):
        #Make the queries, keys and values
        #Still just muktiplying the matrices
        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v (x)
        #Compute scores, then normalize into weights
        scores = queries @ keys.T
        weights = torch.softmax(scores / keys.shape[-1]**0.5, dim= -1)
        #Compute context vector using weights and values
        context = weights @ values
        return context
    


In [17]:
simple = SimpleAttention(d_in = 8, d_out = 4)

In [18]:
context_vectors = simple(inputs)
context_vectors

tensor([[ 0.0688, -0.1599, -0.1387,  0.4715],
        [ 0.3319, -0.1109,  0.1228,  0.6896],
        [ 0.3608, -0.1032,  0.1626,  0.7137],
        [ 0.1311, -0.1515, -0.0888,  0.5293]], grad_fn=<MmBackward0>)

In [19]:
# the problem here is that each context vector uses information from ALL of the embedding vectors
#In practice, we should only use information about preceeding vectors
# To do so, implement causal attention AKA masked attention

### Simple Masked Attention (1.0)

In [None]:
#Hack to get some weights
#weights = simple(inputs)
weights

tensor([[ 0.0688, -0.1599, -0.1387,  0.4715],
        [ 0.3319, -0.1109,  0.1228,  0.6896],
        [ 0.3608, -0.1032,  0.1626,  0.7137],
        [ 0.1311, -0.1515, -0.0888,  0.5293]], grad_fn=<MmBackward0>)

In [None]:
#Note that these are already normalized
weights.sum(dim=-1)

tensor([-0.8540, -0.3801, -0.6706, -0.7008], grad_fn=<SumBackward1>)

In [None]:
#tril returns the lower triangular parts of the matrix (and the diagonal), turns the rest to zero
simple_mask = torch.tril(torch.ones(weights.shape[0], weights.shape[0]))
simple_mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [None]:
# * is a cell by cell product
masked_weights = weights*simple_mask
masked_weights

tensor([[-0.2313, -0.0000, -0.0000,  0.0000],
        [-0.0081, -0.1016, -0.0000, -0.0000],
        [-0.1679, -0.3332, -0.2500,  0.0000],
        [-0.1755, -0.3499, -0.2763,  0.1009]], grad_fn=<MulBackward0>)

In [None]:
#no longer normalized...
masked_weights.sum(dim=-1)

tensor([-0.2313, -0.1098, -0.7512, -0.7008], grad_fn=<SumBackward1>)

In [None]:
#Need to normalize masked weights
row_sums = masked_weights.sum(dim=-1, keepdim=True)
row_sums

tensor([[-0.2313],
        [-0.1098],
        [-0.7512],
        [-0.7008]], grad_fn=<SumBackward1>)

In [None]:
#Dividing by row sums normalizes
masked_weights = masked_weights / row_sums
masked_weights.sum(dim=-1)

tensor([1., 1., 1., 1.], grad_fn=<SumBackward1>)

### More Efficient Masking Method (1.1)

In [21]:
# Masking method 2
# triu() returns the upper triangular part of the matrix, puts zeros on the diagonal
mask = torch.triu(torch.ones(weights.shape[0], weights.shape[0]), diagonal=1)
mask

NameError: name 'weights' is not defined

In [None]:
mask.bool()

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

In [None]:
weights

tensor([[-0.2313, -0.4622, -0.3702,  0.2098],
        [-0.0081, -0.1016, -0.1492, -0.1211],
        [-0.1679, -0.3332, -0.2500,  0.0806],
        [-0.1755, -0.3499, -0.2763,  0.1009]], grad_fn=<MmBackward0>)

In [None]:
#Need a matrix of true/false to fill in, fills the true spots with the given value (-inf)
weights = weights.masked_fill(mask.bool(), -torch.inf, )
weights

tensor([[-0.2313,    -inf,    -inf,    -inf],
        [-0.0081, -0.1016,    -inf,    -inf],
        [-0.1679, -0.3332, -0.2500,    -inf],
        [-0.1755, -0.3499, -0.2763,  0.1009]], grad_fn=<MaskedFillBackward0>)

In [None]:
masked_weights = torch.softmax(weights, dim=-1)
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5234, 0.4766, 0.0000, 0.0000],
        [0.3612, 0.3061, 0.3327, 0.0000],
        [0.2462, 0.2068, 0.2226, 0.3245]], grad_fn=<SoftmaxBackward0>)

### Dropout Explanation

In [None]:
## Dropout helps us avoid overfitting during training 
# (don't want to overfit to training set bc then model cannot generalize)
# Randomly and uniformly ignoring certain data points
# we must set a dropout rate
dropout = nn.Dropout(0.5)

In [None]:
# It removes around half of the data points ...
dropout(masked_weights)

tensor([[0.0000, 0.0000, 0.0000, 0.0000],
        [1.0467, 0.9533, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4135, 0.0000, 0.6490]], grad_fn=<MulBackward0>)

## Causal Attenion Class

In [None]:
# we need to be able to give our LLM batches of input
# For example:
# Concatenates two tensors
batches = torch.stack((inputs, inputs), dim=0)

In [None]:
batches

tensor([[[ 0.3476,  1.0008,  1.4503,  0.3622,  0.7486, -0.0103, -0.4972,
          -2.9958],
         [-0.4068, -0.5580, -0.1595, -0.7773, -0.5482, -0.2031,  0.2543,
           1.1040],
         [-0.8071,  0.7434,  0.9972, -0.2365,  0.6343,  0.2737, -0.2480,
           2.2735],
         [ 0.8415, -1.1536,  1.2855,  0.5882,  0.3148, -0.2748, -1.1443,
           2.0168]],

        [[ 0.3476,  1.0008,  1.4503,  0.3622,  0.7486, -0.0103, -0.4972,
          -2.9958],
         [-0.4068, -0.5580, -0.1595, -0.7773, -0.5482, -0.2031,  0.2543,
           1.1040],
         [-0.8071,  0.7434,  0.9972, -0.2365,  0.6343,  0.2737, -0.2480,
           2.2735],
         [ 0.8415, -1.1536,  1.2855,  0.5882,  0.3148, -0.2748, -1.1443,
           2.0168]]])

In [None]:
#this class needs to handle batches of input
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        #Create weight matrices
        self.W_q = nn.Linear(d_in, d_out, bias=False)
        self.W_k = nn.Linear(d_in, d_out, bias=False)
        self.W_v = nn.Linear(d_in, d_out, bias=False)
        # include dropout
        self.dropout = nn.Dropout(dropout)
        #Use the following to manage memory efficiently:
        self.register_buffer(
            'mask', 
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
            )

    #   x = embedding vectors (inputs)
    def forward(self,x):
        b, num_tokens, d_in = x.shape
        #Make the queries, keys and values
        #Still just muktiplying the matrices
        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v (x)
        #Compute scores, then normalize into weights (doing dot products of queries and keys)
        scores = queries @ keys.transpose(1,2)
        #masking process: turns into trues/falses then converts some into -inf which is normalized to 0 in next step
        scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        #Normalize, -inf turns into zeros
        weights = torch.softmax(scores / keys.shape[-1]**0.5, dim= -1)
        weights = self.dropout(weights)
        #Compute context vector using weights and values
        context = weights @ values
        return context

In [None]:
#Instantiate a causal attention mechanism
causal = CausalAttention(d_in=8, d_out=4, context_length= 4, dropout=0)

In [None]:
causal(batches)

tensor([[[ 0.1315, -0.0068, -0.5639, -1.1360],
         [ 0.0835, -0.0535, -0.2836, -0.3171],
         [-0.0576, -0.1803, -0.1704, -0.1978],
         [-0.1245, -0.3059,  0.0398,  0.0690]],

        [[ 0.1315, -0.0068, -0.5639, -1.1360],
         [ 0.0835, -0.0535, -0.2836, -0.3171],
         [-0.0576, -0.1803, -0.1704, -0.1978],
         [-0.1245, -0.3059,  0.0398,  0.0690]]], grad_fn=<UnsafeViewBackward0>)

### an example of how to transpose batches correctly

In [None]:
W_q = nn.Linear(d_in, d_out, bias=False)
W_k = nn.Linear(d_in, d_out, bias=False)
W_v = nn.Linear(d_in, d_out, bias=False)

In [None]:
queries = W_q(batches)
queries

tensor([[[ 0.0670,  0.2508, -0.4446, -0.4527, -0.3071,  1.5719],
         [ 0.0319, -0.0357,  0.3191,  0.1176,  0.1593, -0.5529],
         [-0.7918, -0.4235,  0.8298,  0.1500,  0.4593,  0.1080],
         [-0.9608, -0.2707,  0.7450, -0.3383,  0.9127, -0.4438]],

        [[ 0.0670,  0.2508, -0.4446, -0.4527, -0.3071,  1.5719],
         [ 0.0319, -0.0357,  0.3191,  0.1176,  0.1593, -0.5529],
         [-0.7918, -0.4235,  0.8298,  0.1500,  0.4593,  0.1080],
         [-0.9608, -0.2707,  0.7450, -0.3383,  0.9127, -0.4438]]],
       grad_fn=<UnsafeViewBackward0>)

In [None]:
keys = W_k(batches)
keys

tensor([[[-0.4463, -0.8833, -0.9793,  0.5838, -0.1625,  1.1621],
         [ 0.0057,  0.2712,  0.3750, -0.0250,  0.2763, -0.4203],
         [ 0.2456, -0.4015,  0.7905, -0.4507,  0.7702, -0.4926],
         [ 0.3099, -0.4570,  0.4964, -0.6015,  0.0530, -0.3242]],

        [[-0.4463, -0.8833, -0.9793,  0.5838, -0.1625,  1.1621],
         [ 0.0057,  0.2712,  0.3750, -0.0250,  0.2763, -0.4203],
         [ 0.2456, -0.4015,  0.7905, -0.4507,  0.7702, -0.4926],
         [ 0.3099, -0.4570,  0.4964, -0.6015,  0.0530, -0.3242]]],
       grad_fn=<UnsafeViewBackward0>)

In [None]:
#This transposes without mixing up the batches
keys.transpose(1,2)

tensor([[[-0.4463,  0.0057,  0.2456,  0.3099],
         [-0.8833,  0.2712, -0.4015, -0.4570],
         [-0.9793,  0.3750,  0.7905,  0.4964],
         [ 0.5838, -0.0250, -0.4507, -0.6015],
         [-0.1625,  0.2763,  0.7702,  0.0530],
         [ 1.1621, -0.4203, -0.4926, -0.3242]],

        [[-0.4463,  0.0057,  0.2456,  0.3099],
         [-0.8833,  0.2712, -0.4015, -0.4570],
         [-0.9793,  0.3750,  0.7905,  0.4964],
         [ 0.5838, -0.0250, -0.4507, -0.6015],
         [-0.1625,  0.2763,  0.7702,  0.0530],
         [ 1.1621, -0.4203, -0.4926, -0.3242]]], grad_fn=<TransposeBackward0>)

##  Multi-Head Attention 1.0 (not efficient)

In [None]:
# Different attention heads (above is one attention head) pay attention to different things
# A bunch of causal attention instances together

class MultiHeadAttentionWrapper(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) 
             for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)


In [None]:
mha = MultiHeadAttention(d_in=8, d_out= 4, context_length= 4, dropout= 0, num_heads=3)

In [None]:
mha_out = mha(batches)
mha_out

tensor([[[-0.2118, -0.1884, -0.3523, -0.1649,  0.2711,  0.0540, -0.1980,
           0.5261, -0.2650,  0.0665,  0.7665, -0.0167,  0.4012, -0.2532,
          -0.0161, -0.0168, -0.3253, -0.2619],
         [-0.1033,  0.0851, -0.1470, -0.0452,  0.2015,  0.0220, -0.0505,
           0.2666, -0.2875,  0.1883,  0.2228,  0.1200, -0.0183, -0.2629,
           0.1170, -0.1338, -0.3425, -0.1860],
         [-0.1038, -0.0539, -0.1180,  0.0188,  0.2275, -0.0454, -0.0199,
           0.1025, -0.0832,  0.2892, -0.1231,  0.4481, -0.3181,  0.0932,
           0.0616, -0.0186, -0.2017, -0.3157],
         [-0.0654,  0.0051, -0.1084,  0.0742,  0.0912, -0.0942,  0.0107,
           0.1322, -0.0120,  0.1973, -0.1906,  0.4651, -0.5089,  0.2139,
           0.1281, -0.0158, -0.1028, -0.3456]],

        [[-0.2118, -0.1884, -0.3523, -0.1649,  0.2711,  0.0540, -0.1980,
           0.5261, -0.2650,  0.0665,  0.7665, -0.0167,  0.4012, -0.2532,
          -0.0161, -0.0168, -0.3253, -0.2619],
         [-0.1033,  0.0851, -0.14

## Real Multi-Head Attention Mechanism

In [None]:
# Much more efficient, only one matrix multiplicaiton step (all multiplication for all attention heads once)

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # As in `CausalAttention`, for inputs where `num_tokens` exceeds `context_length`, 
        # this will result in errors in the mask creation further below. 
        # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs  
        # do not exceed `context_length` before reaching this forward method.

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) 
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2) 
        
        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec


In [None]:
batches.shape

torch.Size([2, 4, 8])

In [None]:
batches

tensor([[[ 0.3476,  1.0008,  1.4503,  0.3622,  0.7486, -0.0103, -0.4972,
          -2.9958],
         [-0.4068, -0.5580, -0.1595, -0.7773, -0.5482, -0.2031,  0.2543,
           1.1040],
         [-0.8071,  0.7434,  0.9972, -0.2365,  0.6343,  0.2737, -0.2480,
           2.2735],
         [ 0.8415, -1.1536,  1.2855,  0.5882,  0.3148, -0.2748, -1.1443,
           2.0168]],

        [[ 0.3476,  1.0008,  1.4503,  0.3622,  0.7486, -0.0103, -0.4972,
          -2.9958],
         [-0.4068, -0.5580, -0.1595, -0.7773, -0.5482, -0.2031,  0.2543,
           1.1040],
         [-0.8071,  0.7434,  0.9972, -0.2365,  0.6343,  0.2737, -0.2480,
           2.2735],
         [ 0.8415, -1.1536,  1.2855,  0.5882,  0.3148, -0.2748, -1.1443,
           2.0168]]])

In [None]:
#View reshapes a tensor last two dimensions should multiply to be the last dimension of OG
batches.view(2,4,2,4)

tensor([[[[ 0.3476,  1.0008,  1.4503,  0.3622],
          [ 0.7486, -0.0103, -0.4972, -2.9958]],

         [[-0.4068, -0.5580, -0.1595, -0.7773],
          [-0.5482, -0.2031,  0.2543,  1.1040]],

         [[-0.8071,  0.7434,  0.9972, -0.2365],
          [ 0.6343,  0.2737, -0.2480,  2.2735]],

         [[ 0.8415, -1.1536,  1.2855,  0.5882],
          [ 0.3148, -0.2748, -1.1443,  2.0168]]],


        [[[ 0.3476,  1.0008,  1.4503,  0.3622],
          [ 0.7486, -0.0103, -0.4972, -2.9958]],

         [[-0.4068, -0.5580, -0.1595, -0.7773],
          [-0.5482, -0.2031,  0.2543,  1.1040]],

         [[-0.8071,  0.7434,  0.9972, -0.2365],
          [ 0.6343,  0.2737, -0.2480,  2.2735]],

         [[ 0.8415, -1.1536,  1.2855,  0.5882],
          [ 0.3148, -0.2748, -1.1443,  2.0168]]]])

In [None]:
mha = MultiHeadAttention(d_in=8, d_out= 6, context_length= 4, dropout= 0, num_heads= 3)

In [None]:
mha_out = mha(batches)
mha_out

tensor([[[-0.2118, -0.1884, -0.3523, -0.1649,  0.2711,  0.0540, -0.1980,
           0.5261, -0.2650,  0.0665,  0.7665, -0.0167,  0.4012, -0.2532,
          -0.0161, -0.0168, -0.3253, -0.2619],
         [-0.1033,  0.0851, -0.1470, -0.0452,  0.2015,  0.0220, -0.0505,
           0.2666, -0.2875,  0.1883,  0.2228,  0.1200, -0.0183, -0.2629,
           0.1170, -0.1338, -0.3425, -0.1860],
         [-0.1038, -0.0539, -0.1180,  0.0188,  0.2275, -0.0454, -0.0199,
           0.1025, -0.0832,  0.2892, -0.1231,  0.4481, -0.3181,  0.0932,
           0.0616, -0.0186, -0.2017, -0.3157],
         [-0.0654,  0.0051, -0.1084,  0.0742,  0.0912, -0.0942,  0.0107,
           0.1322, -0.0120,  0.1973, -0.1906,  0.4651, -0.5089,  0.2139,
           0.1281, -0.0158, -0.1028, -0.3456]],

        [[-0.2118, -0.1884, -0.3523, -0.1649,  0.2711,  0.0540, -0.1980,
           0.5261, -0.2650,  0.0665,  0.7665, -0.0167,  0.4012, -0.2532,
          -0.0161, -0.0168, -0.3253, -0.2619],
         [-0.1033,  0.0851, -0.14