In [23]:
# Standard library
import copy
import math

# Third-party
import torch
import torch.nn as nn

# Seed torch's global RNG so the randomly initialised nn.MultiheadAttention
# parameters used below are the same on every run; `;` suppresses the
# Generator repr that would otherwise be displayed as this cell's output.
torch.manual_seed(1773);

### PyTorch's way

In [24]:
# Toy self-attention input: one batch holding 3 tokens of 2 features each.
# batch_first=True below, so the layout is (batch, seq_len, embed_dim).
embed_dim = 2
num_heads = 1
query = torch.tensor([
    [0.7204, 0.0731],
    [0.9699, 0.1078],
    [0.8829, 0.4132],
]).unsqueeze(0)
# Self-attention: the very same tensor object is used as query, key and value.
key = value = query

multihead_attn = nn.MultiheadAttention(
    embed_dim=embed_dim,
    num_heads=num_heads,
    bias=True,
    batch_first=True,
)
attn_output, attn_output_weights = multihead_attn(
    query=query, key=key, value=value, need_weights=True
)

### Getting the parameters

In [25]:
# Packed input-projection weight (the query, key and value projections stacked row-wise).
multihead_attn.in_proj_weight.detach()

tensor([[-0.5774,  0.5295],
        [ 0.2826, -0.5617],
        [ 0.3290, -0.5050],
        [-0.0958,  0.8166],
        [ 0.0949, -0.7858],
        [-0.1903, -0.7558]])

In [26]:
# The first embed_dim rows of the packed weight are the query projection W_q.
multihead_attn.in_proj_weight.detach()[:embed_dim]

tensor([[-0.5774,  0.5295],
        [ 0.2826, -0.5617]])

In [27]:
# Packed Q/K/V weight: 3 * embed_dim rows, embed_dim columns.
multihead_attn.in_proj_weight.size()

torch.Size([6, 2])

In [28]:
# One input-projection bias entry per row of in_proj_weight.
multihead_attn.in_proj_bias.size()

torch.Size([6])

In [29]:
# Output projection weight: maps embed_dim features back to embed_dim features.
multihead_attn.out_proj.weight.size()

torch.Size([2, 2])

In [30]:
# Output projection bias: one entry per output feature.
multihead_attn.out_proj.bias.size()

torch.Size([2])

In [31]:
# PyTorch's attention output; layout (batch, seq_len, embed_dim) because batch_first=True.
attn_output

tensor([[[ 0.2134, -0.0416],
         [ 0.2143, -0.0420],
         [ 0.2120, -0.0410]]], grad_fn=<TransposeBackward0>)

In [32]:
# PyTorch's attention weights, (batch, target_len, source_len); each row is a softmax and sums to 1.
attn_output_weights

tensor([[[0.3282, 0.3227, 0.3491],
         [0.3265, 0.3193, 0.3542],
         [0.3317, 0.3273, 0.3410]]], grad_fn=<DivBackward0>)

### Replicating PyTorch by hand

In [36]:
# Recompute nn.MultiheadAttention's forward pass by hand from the parameters of
# `multihead_attn` (created in the "PyTorch's way" cell) so it can be compared
# with PyTorch's `attn_output` / `attn_output_weights`.
embed_dim = 2
num_heads = 1
query = torch.tensor([[[0.7204, 0.0731],
         [0.9699, 0.1078],
         [0.8829, 0.4132]]])
key = query
value = query

# The manual computation is only meaningful with the module's own hyperparameters.
assert (embed_dim, num_heads) == (multihead_attn.embed_dim, multihead_attn.num_heads)
head_dim = embed_dim // num_heads

params = multihead_attn.state_dict()
# in_proj_weight / in_proj_bias stack the query, key and value projections along
# dim 0, in that order: shapes (3*embed_dim, embed_dim) and (3*embed_dim,).
w_q, w_k, w_v = params['in_proj_weight'].chunk(3)
b_q, b_k, b_v = params['in_proj_bias'].chunk(3)

# Input projections for the single sequence in the batch -> (seq_len, embed_dim).
# Biases are included: leaving them out only matches PyTorch while they are zero
# (presumably their initial value, since the bias-free version matched).
q = query[0] @ w_q.T + b_q
k = key[0] @ w_k.T + b_k
v = value[0] @ w_v.T + b_v

# Split the embedding dimension into heads -> (num_heads, seq_len, head_dim).
q = q.reshape(-1, num_heads, head_dim).transpose(0, 1)
k = k.reshape(-1, num_heads, head_dim).transpose(0, 1)
v = v.reshape(-1, num_heads, head_dim).transpose(0, 1)

# Scaled dot-product attention; the scale is sqrt(head_dim), not sqrt(embed_dim).
attn_scores = (q / math.sqrt(head_dim)) @ k.transpose(-2, -1)  # (num_heads, L, S)
attn_weights = torch.softmax(attn_scores, dim=-1)
heads = attn_weights @ v                                        # (num_heads, L, head_dim)

# Concatenate the heads back into (seq_len, embed_dim) and apply the output projection.
output = heads.transpose(0, 1).reshape(-1, embed_dim)
my_output = output @ params['out_proj.weight'].T + params['out_proj.bias']

# PyTorch averages the returned attention weights over heads by default
# (a no-op when there is a single head) -> (seq_len, seq_len).
my_attn_output_weights = attn_weights.mean(dim=0)

In [37]:
# Check the hand-written output against PyTorch's before displaying it.
# attn_output carries a leading batch dimension of size 1, so compare with its
# first element; detach() drops the autograd graph so only values are compared.
torch.testing.assert_close(my_output, attn_output.detach()[0])
my_output

tensor([[ 0.2134, -0.0416],
        [ 0.2143, -0.0420],
        [ 0.2120, -0.0410]])

In [38]:
# Check the hand-written attention weights against PyTorch's before displaying them
# (attn_output_weights carries a leading batch dimension of size 1).
torch.testing.assert_close(my_attn_output_weights, attn_output_weights.detach()[0])
my_attn_output_weights

tensor([[0.3282, 0.3227, 0.3491],
        [0.3265, 0.3193, 0.3542],
        [0.3317, 0.3273, 0.3410]])