In [161]:
import torch
import torch.nn as nn

In [162]:
vocab_size = 4
output_dimension = 8

inputs = torch.nn.Embedding(vocab_size, output_dimension)

In [163]:
inputs = inputs.weight
inputs

Parameter containing:
tensor([[-1.0298, -1.3108, -0.7712,  0.0066,  0.3956,  0.2575,  1.1512, -1.7080],
        [ 0.9042, -0.2926, -0.6147, -1.3084, -0.5278, -0.1160, -1.7883,  1.4322],
        [-2.1895, -1.3615,  1.3255,  0.5632,  0.0529,  0.4835,  1.1977,  0.8252],
        [-1.6999, -0.2699, -0.9496, -1.0292,  0.3475, -0.3594, -0.5015, -0.3627]],
       requires_grad=True)

In [164]:
inputs = inputs.data # without 'requires_grad=True'
inputs

tensor([[-1.0298, -1.3108, -0.7712,  0.0066,  0.3956,  0.2575,  1.1512, -1.7080],
        [ 0.9042, -0.2926, -0.6147, -1.3084, -0.5278, -0.1160, -1.7883,  1.4322],
        [-2.1895, -1.3615,  1.3255,  0.5632,  0.0529,  0.4835,  1.1977,  0.8252],
        [-1.6999, -0.2699, -0.9496, -1.0292,  0.3475, -0.3594, -0.5015, -0.3627]])

In [165]:
# set dimensions
d_in = 8 # inputs.shape[i]
d_out = 6 # preferred output size

# create weight matrices
W_q = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
W_k = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
W_v = torch.nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

In [166]:
# choose an input vector and transform it into our query vector using W_q
# Note that the output has the preferred size
query = inputs @ W_q
query

tensor([[-0.7134, -0.7977, -1.7937, -2.0356, -0.7256,  0.0425],
        [-2.8993, -1.5911, -0.3089, -0.2292, -0.9948, -2.1418],
        [ 0.9768,  2.2162, -1.8151,  0.6648,  0.7611,  1.3483],
        [-3.3262, -2.6548, -2.0784, -3.3578, -2.4539, -2.2918]])

In [167]:
# calculate attention scores using the keys generated by W_k:
keys = inputs @ W_k
values = inputs @ W_v
print("Keys:", keys)
print("Values:", values )

Keys: tensor([[-0.4895, -1.3706, -1.4372, -1.4638, -1.1333, -0.6843],
        [-0.1225, -1.6693, -1.1195, -1.1018, -1.4315, -2.1682],
        [-0.0236,  0.4821,  0.4305,  1.5627,  0.4584,  2.7486],
        [-0.7373, -2.1763, -1.6892, -1.9651, -2.7166, -2.4873]])
Values: tensor([[-1.2902, -0.9184, -0.8885, -0.9777, -1.7442, -1.7779],
        [-0.7943, -0.7326, -2.6527, -0.8318, -0.9299, -1.4858],
        [ 0.8355,  0.1140,  0.0805, -1.3144,  1.7866,  0.2797],
        [-1.8394, -1.5204, -2.6693, -2.3475, -2.7289, -2.0297]])


In [168]:
# attention scores are how important a token is. Ex. when translating word by word,
# the word being translated gets the highest att score. the others get scores too but not as big as the meant word
attention_scores = query @ keys.T
attention_scores

tensor([[  7.7933,   6.6164,  -4.5368,  11.1575],
        [  6.9725,   9.6773,  -7.5327,  14.6023],
        [ -3.6652,  -6.5325,   5.3576,  -9.2048],
        [ 17.5183,  19.3472, -14.7673,  30.7058]])

Attention weights are different as they refer to the input and other inputs with respect to the highest attention (query) score input
Ex. a23 --> a is the name of the weight, 2 is the highest score, and 3 is the position of the input with respect to the highest, 2.

In [169]:
# dividing by the root of shape of a row in keys to keep them in a reasonable range
attention_weights = torch.softmax( attention_scores / keys.shape[-1]**0.5, dim = -1 )
attention_weights

tensor([[1.7941e-01, 1.1096e-01, 1.1687e-03, 7.0846e-01],
        [3.7665e-02, 1.1363e-01, 1.0097e-04, 8.4860e-01],
        [2.4271e-02, 7.5286e-03, 9.6567e-01, 2.5288e-03],
        [4.5259e-03, 9.5494e-03, 8.5402e-09, 9.8592e-01]])

In [170]:
attention_weights.sum()

tensor(4.)

In [171]:
context_vector = attention_weights @ values
context_vector

tensor([[-1.6218, -1.3231, -2.3448, -1.9323, -2.3474, -1.9215],
        [-1.6997, -1.4081, -2.6001, -2.1236, -2.4869, -1.9582],
        [ 0.7648,  0.0784,  0.0295, -1.3052,  1.6690,  0.2106],
        [-1.8269, -1.5102, -2.6611, -2.3268, -2.7073, -2.0234]])

In [172]:
# here's a first version of a SimpleAttention class:

class SimpleAttention(nn.Module):
  def __init__(self, d_in, d_out):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_k = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )
    self.W_v = nn.Parameter( torch.rand( d_in, d_out ), requires_grad=False )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    queries = x @ self.W_q
    keys = x @ self.W_k
    values = x @ self.W_v

    scores = queries @ keys.T
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    
    return context

In [173]:
# here's how to use this class:
# instantiate an instance of it:
simple = SimpleAttention( d_in = 8, d_out = 6 )

In [174]:
simple.W_v

Parameter containing:
tensor([[0.7882, 0.5257, 0.9291, 0.2225, 0.9376, 0.6388],
        [0.3784, 0.7168, 0.4756, 0.7102, 0.2398, 0.6359],
        [0.0596, 0.8810, 0.2340, 0.6542, 0.7770, 0.7621],
        [0.6871, 0.6001, 0.0829, 0.9675, 0.9467, 0.4399],
        [0.9656, 0.0739, 0.4295, 0.5311, 0.9109, 0.7623],
        [0.6701, 0.5710, 0.3946, 0.5935, 0.1420, 0.7334],
        [0.1673, 0.1421, 0.8042, 0.7996, 0.5478, 0.3332],
        [0.6309, 0.7429, 0.7442, 0.1275, 0.8110, 0.2621]])

In [175]:
context_vectors = simple( inputs )
context_vectors

tensor([[-2.3679, -3.0142, -2.6160, -2.6218, -3.5951, -2.6449],
        [-2.3093, -2.9916, -2.5485, -2.5144, -3.4955, -2.5765],
        [-0.8409, -0.2411, -0.7559,  0.1175, -0.2481, -0.5641],
        [-2.4182, -3.0579, -2.6747, -2.6560, -3.6653, -2.6890]])

In [176]:
# here's a second version of a SimpleAttention class ;
# it uses nn.Linear to do things more efficiently and gives better training results

class SimpleAttentionV2( nn.Module ):
  def __init__(self, d_in, d_out):
    super().__init__()
    # create weight matrices:
    self.W_q = nn.Linear( d_in, d_out, bias=False )
    self.W_k = nn.Linear( d_in, d_out, bias=False )
    self.W_v = nn.Linear( d_in, d_out, bias=False )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )

    scores = queries @ keys.T
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    context = weights @ values
    
    return context

In [177]:
# here's how to use this class:
# instantiate an instance of it:
simple = SimpleAttentionV2( d_in = 8, d_out = 6 )

In [178]:
context_vectors = simple( inputs )
context_vectors

tensor([[ 0.3019, -0.2753,  0.2304,  0.3721, -0.3348,  0.3127],
        [ 0.3258, -0.2385,  0.2726,  0.3736, -0.2890,  0.2899],
        [ 0.2462, -0.1964,  0.2366,  0.4313, -0.2747,  0.2815],
        [ 0.0653, -0.0971,  0.0548,  0.4059, -0.2953,  0.2586]],
       grad_fn=<MmBackward0>)

- The problem with this is that each context vector uses information from ALL of the embedding vectors
- In practice, we should only use information about the preceding embedding vectors
- To accomplish this, we'll implement causal attention AKA masked attention
- It briefly means hiding future words

In [179]:
# this is a hack to get some example weights to work with!
# weights = simple( inputs ) # The hack didn't work :(

queries = simple.W_q(inputs)
keys = simple.W_k(inputs)
values = simple.W_v(inputs)

scores = queries @ keys.T
weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )

weights

tensor([[0.1752, 0.2578, 0.2785, 0.2885],
        [0.2506, 0.2582, 0.2851, 0.2061],
        [0.2414, 0.2247, 0.2271, 0.3068],
        [0.1990, 0.3362, 0.1744, 0.2905]], grad_fn=<SoftmaxBackward0>)

In [180]:
# note that these have already been normalized:
weights.sum( dim=-1 )

tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)

In [None]:
# masking method #1: we soft max first then mask
simple_mask = torch.tril( torch.ones( weights.shape[0], weights.shape[0] ) ) #Triangular mask: returns the lower triangular part
simple_mask

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [182]:
masked_weights = weights * simple_mask
masked_weights

tensor([[0.1752, 0.0000, 0.0000, 0.0000],
        [0.2506, 0.2582, 0.0000, 0.0000],
        [0.2414, 0.2247, 0.2271, 0.0000],
        [0.1990, 0.3362, 0.1744, 0.2905]], grad_fn=<MulBackward0>)

In [183]:
masked_weights.sum( dim=-1 )

tensor([0.1752, 0.5088, 0.6932, 1.0000], grad_fn=<SumBackward1>)

In [184]:
# now, we need to normalize the masked_weights so that each row has sum 1 as it is good for optimization
# What this code does -> simple_mask / row_sums
row_sums = masked_weights.sum( dim=-1, keepdim=True)
row_sums

tensor([[0.1752],
        [0.5088],
        [0.6932],
        [1.0000]], grad_fn=<SumBackward1>)

In [185]:
masked_weights = masked_weights / row_sums
masked_weights.sum( dim=-1)

tensor([1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)

In [186]:
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.4925, 0.5075, 0.0000, 0.0000],
        [0.3483, 0.3242, 0.3276, 0.0000],
        [0.1990, 0.3362, 0.1744, 0.2905]], grad_fn=<DivBackward0>)

In [None]:
# masking method #2
# This way scores -> mask -> soft max
mask = torch.triu( torch.ones(weights.shape[0], weights.shape[0]), diagonal = 1 )  #returns the upper triangular part
mask

tensor([[0., 1., 1., 1.],
        [0., 0., 1., 1.],
        [0., 0., 0., 1.],
        [0., 0., 0., 0.]])

In [188]:
mask.bool()

tensor([[False,  True,  True,  True],
        [False, False,  True,  True],
        [False, False, False,  True],
        [False, False, False, False]])

In [189]:
weights

tensor([[0.1752, 0.2578, 0.2785, 0.2885],
        [0.2506, 0.2582, 0.2851, 0.2061],
        [0.2414, 0.2247, 0.2271, 0.3068],
        [0.1990, 0.3362, 0.1744, 0.2905]], grad_fn=<SoftmaxBackward0>)

In [190]:
# We masked the values first by hiding future values with -infinity
weights = weights.masked_fill( mask.bool(), -torch.inf )
weights

tensor([[0.1752,   -inf,   -inf,   -inf],
        [0.2506, 0.2582,   -inf,   -inf],
        [0.2414, 0.2247, 0.2271,   -inf],
        [0.1990, 0.3362, 0.1744, 0.2905]], grad_fn=<MaskedFillBackward0>)

In [191]:
# Now, every row sums up to 1
masked_weights = torch.softmax( weights, dim=-1 )
masked_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.4981, 0.5019, 0.0000, 0.0000],
        [0.3368, 0.3312, 0.3320, 0.0000],
        [0.2370, 0.2719, 0.2313, 0.2598]], grad_fn=<SoftmaxBackward0>)

In [192]:
## Dropout Mask
# idea: randomly select some data to leave out to avoid overfitting
dropout = nn.Dropout( 0.5 ) # 50%

In [193]:
dropout( torch.ones(6,6) )

tensor([[0., 0., 2., 2., 2., 0.],
        [2., 0., 0., 0., 2., 0.],
        [0., 0., 2., 0., 2., 0.],
        [0., 2., 0., 2., 2., 0.],
        [2., 0., 2., 2., 2., 2.],
        [2., 2., 0., 2., 2., 2.]])

In [194]:
# we need to be able to give our LLM batches of input
# for example:
batches = torch.stack( (inputs, inputs), dim=0) #stack 2 inputs on top of eachother

In [195]:
batches

tensor([[[-1.0298, -1.3108, -0.7712,  0.0066,  0.3956,  0.2575,  1.1512,
          -1.7080],
         [ 0.9042, -0.2926, -0.6147, -1.3084, -0.5278, -0.1160, -1.7883,
           1.4322],
         [-2.1895, -1.3615,  1.3255,  0.5632,  0.0529,  0.4835,  1.1977,
           0.8252],
         [-1.6999, -0.2699, -0.9496, -1.0292,  0.3475, -0.3594, -0.5015,
          -0.3627]],

        [[-1.0298, -1.3108, -0.7712,  0.0066,  0.3956,  0.2575,  1.1512,
          -1.7080],
         [ 0.9042, -0.2926, -0.6147, -1.3084, -0.5278, -0.1160, -1.7883,
           1.4322],
         [-2.1895, -1.3615,  1.3255,  0.5632,  0.0529,  0.4835,  1.1977,
           0.8252],
         [-1.6999, -0.2699, -0.9496, -1.0292,  0.3475, -0.3594, -0.5015,
          -0.3627]]])

In [196]:
batches.shape
# The output means 2 inputs, each tensor has 4 tokens, and each token is 8-dim vector

torch.Size([2, 4, 8])

In [197]:
# this class needs to handle batches of input!

class CausalAttention( nn.Module ):
  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
    super().__init__()
    self.d_out = d_out
    # create weight matrices:
    self.W_q = nn.Linear( d_in, d_out, bias=False )
    self.W_k = nn.Linear( d_in, d_out, bias=False )
    self.W_v = nn.Linear( d_in, d_out, bias=False )
    
    # include dropout:
    self.dropout = nn.Dropout( dropout )
    
    # use the following to manage memory efficiently
    # When passing this to a GPU, it's better because GPUs dont read tensors
    self.register_buffer(
        'mask',
        torch.triu( torch.ones(context_length, context_length), diagonal = 1 )
    )

  # x = embedding vectors (inputs)
  def forward( self, x ):
    b, num_tokens, d_in = x.shape # b is batch size (num inputs)
    queries = self.W_q( x )
    keys = self.W_k( x )
    values = self.W_v( x )

    scores = queries @ keys.transpose(1,2)
    scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) # inplace operation for better efficiency
    weights = torch.softmax( scores / keys.shape[-1]**0.5, dim = -1 )
    weights = self.dropout( weights )
    
    context = weights @ values
    return context

In [198]:
# instantiate a causal attention mechanism:
causal = CausalAttention( d_in=8, d_out=6, context_length=4, dropout=0 )

In [199]:
causal( batches )

tensor([[[-0.4548,  0.0840,  0.1201,  0.2948,  0.1458,  0.0185],
         [ 0.0907,  0.3469,  0.2073,  0.3515, -0.3509, -0.2936],
         [-0.0654,  0.1142,  0.2091,  0.1274, -0.0935, -0.0815],
         [-0.0398,  0.1362,  0.0924,  0.1151,  0.0739, -0.1267]],

        [[-0.4548,  0.0840,  0.1201,  0.2948,  0.1458,  0.0185],
         [ 0.0907,  0.3469,  0.2073,  0.3515, -0.3509, -0.2936],
         [-0.0654,  0.1142,  0.2091,  0.1274, -0.0935, -0.0815],
         [-0.0398,  0.1362,  0.0924,  0.1151,  0.0739, -0.1267]]],
       grad_fn=<UnsafeViewBackward0>)

In [200]:
# everything below is just to show what happens with batches

queries = W_q( batches )
queries

TypeError: 'Parameter' object is not callable

In [None]:
keys = W_k( batches )
keys

TypeError: 'Parameter' object is not callable

In [None]:
keys.transpose(1,2)

IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

In [201]:
# here's a first pass at multi-head attention
class MultiHeadAttention( nn.Module ):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList( 
            [ CausalAttention( d_in, d_out, context_length, dropout, qkv_bias ) for _ in range(num_heads) ]
        )

    def forward( self, x ):
        return torch.cat( [ head(x) for head in self.heads ], dim=-1 )

In [202]:
mha = MultiHeadAttention( d_in = 8, d_out = 6, context_length= 4, dropout=0, num_heads = 3 )

In [203]:
mha_out = mha( batches )

In [204]:
mha_out

tensor([[[-0.8212, -0.8456, -0.6401, -0.6974,  0.1105, -0.1379, -0.1625,
           0.0825,  0.8190, -0.5085,  0.3563,  0.7604,  0.4894,  0.7208,
           0.3837, -0.2336,  0.0953,  0.2158],
         [-0.2423, -0.2805, -0.1383,  0.3809,  0.2915,  0.0933, -0.1245,
           0.1647,  0.1735, -0.5619,  0.2576,  0.1994,  0.4498,  0.3860,
           0.0192,  0.2289,  0.1943,  0.0930],
         [-0.3217, -0.5457, -0.0944, -0.2841,  0.1091, -0.5424, -0.1049,
          -0.0095, -0.2659,  0.1045,  0.2647,  0.4789,  0.4424,  0.0798,
           0.3823,  0.1615,  0.2073,  0.5586],
         [-0.4849, -0.5947, -0.2791, -0.2977, -0.0066, -0.4671, -0.2279,
          -0.0983, -0.0127,  0.1463,  0.2048,  0.6803,  0.5510,  0.0346,
           0.1713,  0.1727,  0.2000,  0.3906]],

        [[-0.8212, -0.8456, -0.6401, -0.6974,  0.1105, -0.1379, -0.1625,
           0.0825,  0.8190, -0.5085,  0.3563,  0.7604,  0.4894,  0.7208,
           0.3837, -0.2336,  0.0953,  0.2158],
         [-0.2423, -0.2805, -0.13

In [205]:
mha_out.shape

torch.Size([2, 4, 18])

In [206]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # As in `CausalAttention`, for inputs where `num_tokens` exceeds `context_length`, 
        # this will result in errors in the mask creation further below. 
        # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs  
        # do not exceed `context_length` before reaching this forward method.

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) 
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2) 
        
        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec


In [207]:
batches.shape

torch.Size([2, 4, 8])

In [208]:
batches

tensor([[[-1.0298, -1.3108, -0.7712,  0.0066,  0.3956,  0.2575,  1.1512,
          -1.7080],
         [ 0.9042, -0.2926, -0.6147, -1.3084, -0.5278, -0.1160, -1.7883,
           1.4322],
         [-2.1895, -1.3615,  1.3255,  0.5632,  0.0529,  0.4835,  1.1977,
           0.8252],
         [-1.6999, -0.2699, -0.9496, -1.0292,  0.3475, -0.3594, -0.5015,
          -0.3627]],

        [[-1.0298, -1.3108, -0.7712,  0.0066,  0.3956,  0.2575,  1.1512,
          -1.7080],
         [ 0.9042, -0.2926, -0.6147, -1.3084, -0.5278, -0.1160, -1.7883,
           1.4322],
         [-2.1895, -1.3615,  1.3255,  0.5632,  0.0529,  0.4835,  1.1977,
           0.8252],
         [-1.6999, -0.2699, -0.9496, -1.0292,  0.3475, -0.3594, -0.5015,
          -0.3627]]])

In [209]:
batches.view( 2, 4, 2, 4 )

tensor([[[[-1.0298, -1.3108, -0.7712,  0.0066],
          [ 0.3956,  0.2575,  1.1512, -1.7080]],

         [[ 0.9042, -0.2926, -0.6147, -1.3084],
          [-0.5278, -0.1160, -1.7883,  1.4322]],

         [[-2.1895, -1.3615,  1.3255,  0.5632],
          [ 0.0529,  0.4835,  1.1977,  0.8252]],

         [[-1.6999, -0.2699, -0.9496, -1.0292],
          [ 0.3475, -0.3594, -0.5015, -0.3627]]],


        [[[-1.0298, -1.3108, -0.7712,  0.0066],
          [ 0.3956,  0.2575,  1.1512, -1.7080]],

         [[ 0.9042, -0.2926, -0.6147, -1.3084],
          [-0.5278, -0.1160, -1.7883,  1.4322]],

         [[-2.1895, -1.3615,  1.3255,  0.5632],
          [ 0.0529,  0.4835,  1.1977,  0.8252]],

         [[-1.6999, -0.2699, -0.9496, -1.0292],
          [ 0.3475, -0.3594, -0.5015, -0.3627]]]])

In [210]:
mha = MultiHeadAttention( d_in = 8, d_out = 6, context_length=4, dropout=0, num_heads=3 )

In [211]:
mha_out = mha( batches )

In [212]:
mha_out

tensor([[[ 0.0985,  0.0660,  0.2198,  0.0374, -0.1643,  0.0888],
         [ 0.2072, -0.1460,  0.2532,  0.0242,  0.4572, -0.3227],
         [ 0.2035, -0.1053,  0.2840, -0.1345,  0.2946, -0.2536],
         [ 0.1382, -0.1410,  0.2790,  0.0616,  0.2923, -0.1300]],

        [[ 0.0985,  0.0660,  0.2198,  0.0374, -0.1643,  0.0888],
         [ 0.2072, -0.1460,  0.2532,  0.0242,  0.4572, -0.3227],
         [ 0.2035, -0.1053,  0.2840, -0.1345,  0.2946, -0.2536],
         [ 0.1382, -0.1410,  0.2790,  0.0616,  0.2923, -0.1300]]],
       grad_fn=<ViewBackward0>)

In [213]:
mha_out.shape

torch.Size([2, 4, 6])