In [1]:
import math 
import torch
import torch.nn as nn
import numpy as np

In [2]:
# Seed every RNG library the notebook imports so Restart & Run All reproduces
# the printed tensors below. torch is seeded first because the next cell draws
# from torch.rand; seeding numpy does not touch the torch generator.
# np.random.seed returns None, so this cell no longer displays the
# <torch._C.Generator ...> repr that torch.manual_seed returns.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

<torch._C.Generator at 0x1f44d712b50>

In [3]:
batch_size = 2
seq_len = 3
num_heads = 2
d_model = 4
# Each head gets an equal slice of the model dimension. If d_model is not a
# multiple of num_heads, floor division would shrink d_k and the later .view()
# into (..., num_heads, d_k) would fail with an opaque numel error, so fail fast
# here with a clear message instead.
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
d_k = d_model // num_heads

# Random query, key and value tensors of shape (batch_size, seq_len, d_model),
# drawn from torch's global RNG.
query = torch.rand(batch_size, seq_len, d_model)
key = torch.rand(batch_size, seq_len, d_model)
value = torch.rand(batch_size, seq_len, d_model)

# Current query, key and value matrices
print(query)
print(key)
print(value)

tensor([[[0.8823, 0.9150, 0.3829, 0.9593],
         [0.3904, 0.6009, 0.2566, 0.7936],
         [0.9408, 0.1332, 0.9346, 0.5936]],

        [[0.8694, 0.5677, 0.7411, 0.4294],
         [0.8854, 0.5739, 0.2666, 0.6274],
         [0.2696, 0.4414, 0.2969, 0.8317]]])
tensor([[[0.1053, 0.2695, 0.3588, 0.1994],
         [0.5472, 0.0062, 0.9516, 0.0753],
         [0.8860, 0.5832, 0.3376, 0.8090]],

        [[0.5779, 0.9040, 0.5547, 0.3423],
         [0.6343, 0.3644, 0.7104, 0.9464],
         [0.7890, 0.2814, 0.7886, 0.5895]]])
tensor([[[0.7539, 0.1952, 0.0050, 0.3068],
         [0.1165, 0.9103, 0.6440, 0.7071],
         [0.6581, 0.4913, 0.8913, 0.1447]],

        [[0.5315, 0.1587, 0.6542, 0.3278],
         [0.6532, 0.3958, 0.9147, 0.2036],
         [0.2018, 0.2018, 0.9497, 0.6666]]])


In [4]:
# Split the trailing d_model axis into num_heads heads of width d_k, keeping the
# leading (batch_size, seq_len) axes as they are:
#   (batch_size, seq_len, d_model) --> (batch_size, seq_len, num_heads, d_k)
query, key, value = (
    t.view(*t.shape[:2], num_heads, d_k) for t in (query, key, value)
)

# One print call, one tensor per line -- same text as three separate prints.
print(query, key, value, sep="\n")

tensor([[[[0.8823, 0.9150],
          [0.3829, 0.9593]],

         [[0.3904, 0.6009],
          [0.2566, 0.7936]],

         [[0.9408, 0.1332],
          [0.9346, 0.5936]]],


        [[[0.8694, 0.5677],
          [0.7411, 0.4294]],

         [[0.8854, 0.5739],
          [0.2666, 0.6274]],

         [[0.2696, 0.4414],
          [0.2969, 0.8317]]]])
tensor([[[[0.1053, 0.2695],
          [0.3588, 0.1994]],

         [[0.5472, 0.0062],
          [0.9516, 0.0753]],

         [[0.8860, 0.5832],
          [0.3376, 0.8090]]],


        [[[0.5779, 0.9040],
          [0.5547, 0.3423]],

         [[0.6343, 0.3644],
          [0.7104, 0.9464]],

         [[0.7890, 0.2814],
          [0.7886, 0.5895]]]])
tensor([[[[0.7539, 0.1952],
          [0.0050, 0.3068]],

         [[0.1165, 0.9103],
          [0.6440, 0.7071]],

         [[0.6581, 0.4913],
          [0.8913, 0.1447]]],


        [[[0.5315, 0.1587],
          [0.6542, 0.3278]],

         [[0.6532, 0.3958],
          [0.9147, 0.2036]],

      

In [5]:
# Move the head axis in front of the sequence axis so that each head can be
# processed independently:
#   (batch_size, seq_len, num_heads, d_k) --> (batch_size, num_heads, seq_len, d_k)
# transpose(1, 2) is its own inverse, so re-running this cell would silently swap
# the axes back. Guard on the expected pre-transpose shape so a second run fails
# loudly instead. (The guard cannot tell the layouts apart if seq_len == num_heads.)
assert all(
    tuple(t.shape) == (batch_size, seq_len, num_heads, d_k)
    for t in (query, key, value)
), "query/key/value are not (batch_size, seq_len, num_heads, d_k); was this cell already run?"
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)

print(query)
print(key)
print(value)

tensor([[[[0.8823, 0.9150],
          [0.3904, 0.6009],
          [0.9408, 0.1332]],

         [[0.3829, 0.9593],
          [0.2566, 0.7936],
          [0.9346, 0.5936]]],


        [[[0.8694, 0.5677],
          [0.8854, 0.5739],
          [0.2696, 0.4414]],

         [[0.7411, 0.4294],
          [0.2666, 0.6274],
          [0.2969, 0.8317]]]])
tensor([[[[0.1053, 0.2695],
          [0.5472, 0.0062],
          [0.8860, 0.5832]],

         [[0.3588, 0.1994],
          [0.9516, 0.0753],
          [0.3376, 0.8090]]],


        [[[0.5779, 0.9040],
          [0.6343, 0.3644],
          [0.7890, 0.2814]],

         [[0.5547, 0.3423],
          [0.7104, 0.9464],
          [0.7886, 0.5895]]]])
tensor([[[[0.7539, 0.1952],
          [0.1165, 0.9103],
          [0.6581, 0.4913]],

         [[0.0050, 0.3068],
          [0.6440, 0.7071],
          [0.8913, 0.1447]]],


        [[[0.5315, 0.1587],
          [0.6532, 0.3958],
          [0.2018, 0.2018]],

         [[0.6542, 0.3278],
          [0.9147,