In [77]:
import numpy as np
import torch

In [78]:
class MultiHeadAttention:
    def __init__(self, d_model, num_heads):
        assert d_model % num_heads == 0 # d_model must be divisible by num heads
        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // num_heads

        # initialize weight for Q, K, V
        self.WQ = np.random.randn(d_model, d_model) * 0.01
        self.WK = np.random.randn(d_model, d_model) * 0.01
        self.WV = np.random.randn(d_model, d_model) * 0.01
        self.dense = np.random.randn(d_model, d_model) * 0.01

    def split_heads(self, x, batch_size):
        """
        split the last dimension into (num_heads, depth)
        Transpose the result to shape (batch_size, num_heads, seq_len, depth)
        """
        x = x.reshape(batch_size, -1, self.num_heads, self.depth)
        return np.transpose(x, (0, 2, 1, 3))

    def softmax(self, x):
        exp_x = torch.exp(x - torch.max(x, dim=-1, keepdim=True)[0])
        return exp_x / torch.sum(exp_x, dim=-1, keepdim=True)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        d_k = Q.shape[-1]
        scores = np.matmul(Q, K.transpose(-2, -1)) / np.sqrt(d_k)

        if mask is not None:
            scores = np.where(mask == 0, -1e9, scores)

        # softmax
        attention_weights = self.softmax(scores)
        output = np.matmul(attention_weights, V)

        return output, attention_weights

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.shape[0]

        # linear projections
        Q = np.matmul(Q, self.WQ)
        K = np.matmul(K, self.WK)
        V = np.matmul(V, self.WV)

        # split projections into multiple heads
        Q = self.split_heads(Q, batch_size)
        K = self.split_heads(K, batch_size)
        V = self.split_heads(V, batch_size)

        # scaled dot product attention
        attention_output, attention_weights = self.scaled_dot_product_attention(Q, K, V, mask)

        # concat heads
        attention_output = attention_output.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.d_model)
        
        # final layer
        output = np.matmul(attention_output, self.dense)

        return output, attention_weights  

In [79]:
combined_embedding = torch.tensor([[-1.0888,  0.7176,  0.4090,  1.6687],
        [ 1.8317,  0.6463, -0.1360,  1.9677],
        [ 2.0548,  0.1155,  0.0074,  1.1889],
        [-0.2718, -1.5583,  0.2442,  0.3019],
        [-1.9627,  0.0190, -0.5172,  1.1168],
        [-0.2445, -2.1576,  0.4894,  2.0258],
        [-1.8888,  0.4360,  0.5125,  0.8771],
        [ 0.4977,  0.5503,  1.0210,  0.5724],
        [ 1.7046, -0.7652,  0.8749,  2.1788],
        [ 0.4044, -2.5807, -1.2181,  1.3923],
        [-1.4248, -0.7295,  1.2019,  3.3161],
        [-2.5365,  1.3118,  2.5363, -0.3911],
        [-0.9186,  2.9010,  0.6690,  0.4206],
        [ 2.2390,  0.9207, -0.2267,  0.7386],
        [ 2.5905,  0.7127,  0.5017, -0.4269],
        [ 1.7419, -1.8513, -0.3948,  0.5540],
        [ 1.6389, -1.0596, -0.0133,  0.0694],
        [-0.6884, -0.0960, -0.3400,  0.8601],
        [-1.8131,  1.4727, -1.0340, -0.8144],
        [ 0.6698,  1.4205, -0.4873,  0.7800],
        [-0.2251, -0.3391,  0.8395,  1.0639],
        [-0.1970, -1.8507,  0.6383, -0.6770],
        [ 0.3017, -1.7349, -0.0116,  1.7366],
        [-1.1640, -0.6178,  0.0421,  2.2698],
        [-0.3800,  2.1826,  0.9099,  2.8828],
        [ 0.8038,  2.6766, -0.0650,  0.7205],
        [ 2.5170,  1.1129,  0.1357,  1.8460],
        [ 0.9675,  2.4028, -0.6554,  0.6674],
        [ 0.6507, -1.4472, -0.3987,  0.1019],
        [-0.6414, -0.7548,  1.1913,  1.8270],
        [ 0.1130,  0.2623,  1.0054,  2.1297],
        [ 0.5632, -1.0677,  1.8818,  1.4106],
        [ 1.5025,  2.3895,  1.6297,  0.5900],
        [ 0.8548, -0.7370,  0.9209, -0.1616],
        [ 0.8174, -0.1504,  0.2955,  0.9260]])

In [84]:
np.random.seed(42)
batch_size = 1
num_heads = 2
d_model = 4

In [85]:
Q = combined_embedding
K = combined_embedding
V = combined_embedding

In [86]:
multihead_attention = MultiHeadAttention(d_model, num_heads)

In [87]:
output, attention_weights = multihead_attention.forward(Q, K, V)

In [88]:
output

tensor([[[ 9.9593e-05,  2.9258e-04, -6.4020e-04, -6.8415e-04]],

        [[ 2.3448e-04,  2.3441e-04, -5.8210e-04, -3.4590e-04]],

        [[ 1.3601e-04,  1.0283e-04, -2.2347e-04,  8.2476e-06]],

        [[-3.8836e-04,  2.4034e-04,  3.3301e-04,  4.4886e-04]],

        [[-1.4968e-04,  4.1858e-04, -4.2946e-04, -4.6737e-04]],

        [[-5.4662e-04,  6.4713e-04,  6.9862e-05,  3.4845e-04]],

        [[-1.9356e-06,  1.9957e-04, -4.0025e-04, -5.4507e-04]],

        [[ 1.9391e-04, -8.9977e-05, -1.9038e-04, -2.1663e-04]],

        [[-7.9263e-05,  3.4389e-04, -2.0217e-04,  9.6955e-05]],

        [[-6.6544e-04,  7.1029e-04,  2.5435e-04,  6.7115e-04]],

        [[-2.6390e-04,  7.4743e-04, -6.4283e-04, -5.1921e-04]],

        [[ 2.6520e-04, -3.5438e-04, -1.9668e-04, -6.7417e-04]],

        [[ 6.6541e-04, -2.7134e-04, -8.6526e-04, -1.1830e-03]],

        [[ 3.3939e-04, -7.4453e-05, -3.2604e-04, -1.6598e-04]],

        [[ 3.5440e-04, -3.9926e-04,  9.5039e-05,  1.7106e-04]],

        [[-3.6495e-04,  2

In [89]:
output.shape

torch.Size([35, 1, 4])

In [90]:
attention_weights

tensor([[[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1.]],

         [[1.]]],


        [[[1

In [91]:
attention_weights.shape

torch.Size([35, 2, 1, 1])