In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [22]:
# Reference PyTorch implementation: embed_dim=8, num_heads=2, inputs laid out (batch, seq, feature).
# Seed first so the randomly initialised weights (and any random input drawn afterwards) are
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
ref_MHA = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)

In [23]:
# Toy problem: 2 sequences of 8 tokens, each token an 8-dim vector; h is the head count.
batch_size, seq_size, dimension = 2, 8, 8
h = 2
x = torch.randn(batch_size, seq_size, dimension)  # (batch, seq, feature)

In [24]:
# Self-attention: x serves as query, key and value; the module returns a tuple whose first item is the output.
ref_MHA(query=x, key=x, value=x)

(tensor([[[-0.2981, -0.1443,  0.2256, -0.1200, -0.0094,  0.1198, -0.2828,
           -0.1166],
          [-0.3639, -0.1603,  0.2851, -0.1088,  0.1658, -0.0218, -0.2217,
           -0.0960],
          [-0.3159, -0.2360,  0.3608, -0.0896,  0.1391,  0.0792, -0.1101,
           -0.0736],
          [-0.4521, -0.2679,  0.4190, -0.1395,  0.2153, -0.0358, -0.1985,
           -0.1083],
          [-0.4160, -0.2677,  0.4059, -0.1558,  0.1903,  0.0625, -0.0398,
           -0.1208],
          [-0.3387, -0.1959,  0.3296, -0.1553,  0.0607,  0.0791, -0.2709,
           -0.1384],
          [-0.2129, -0.1457,  0.2581, -0.0703,  0.0280,  0.1399, -0.1811,
           -0.0857],
          [-0.4141, -0.3248,  0.4855, -0.1528,  0.1461,  0.0779, -0.1661,
           -0.1330]],
 
         [[ 0.0408,  0.3094, -0.5754, -0.2493, -0.2430, -0.3085, -0.1636,
           -0.1119],
          [ 0.0700,  0.2905, -0.5372, -0.2187, -0.2034, -0.3521, -0.1580,
           -0.0876],
          [ 0.0338,  0.3593, -0.5925, -0.2440, 

In [25]:
import math

class MyMHA(nn.Module):
  """From-scratch multi-head attention intended to reproduce nn.MultiheadAttention.

  Uses the batch-first layout: inputs and output are (batch, seq, embed_dim).
  Only the attention output is returned (no attention weights), and no masking
  or dropout is applied.

  Args:
    embed_dim: model width; must be divisible by ``num_heads``.
    num_heads: number of parallel attention heads.
  """

  def __init__(self, embed_dim, num_heads):
    super().__init__()
    assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.emb_dim_h = embed_dim // num_heads  # width of each head
    # Separate Q/K/V/output projections, each embed_dim -> embed_dim with bias.
    self.weight_q = nn.Linear(in_features=embed_dim, out_features=embed_dim)
    self.weight_k = nn.Linear(in_features=embed_dim, out_features=embed_dim)
    self.weight_v = nn.Linear(in_features=embed_dim, out_features=embed_dim)
    self.weight_o = nn.Linear(in_features=embed_dim, out_features=embed_dim)

  def _split_heads(self, t):
    """Reshape (b, s, embed_dim) -> (b, num_heads, s, emb_dim_h) using t's own b and s."""
    b, s, _ = t.shape
    return t.view(b, s, self.num_heads, self.emb_dim_h).transpose(1, 2)

  def forward(self, q, k, v):
    """Apply multi-head scaled dot-product attention.

    Args:
      q: query tensor, (batch, tgt_len, embed_dim).
      k: key tensor, (batch, src_len, embed_dim).
      v: value tensor, (batch, src_len, embed_dim); src_len must match ``k``.

    Returns:
      Tensor of shape (batch, tgt_len, embed_dim).
    """
    batch_size, tgt_len, _ = q.shape
    # Each input is split with its own sequence length, so key/value may be
    # longer or shorter than the query (cross-attention).
    Q_h = self._split_heads(self.weight_q(q))  # (b, h, tgt_len, d_h)
    K_h = self._split_heads(self.weight_k(k))  # (b, h, src_len, d_h)
    V_h = self._split_heads(self.weight_v(v))  # (b, h, src_len, d_h)

    # softmax(Q K^T / sqrt(d_h)) V, computed independently per head.
    scores = (Q_h @ K_h.transpose(-1, -2)) / math.sqrt(self.emb_dim_h)
    attention_weights = F.softmax(scores, dim=-1)  # (b, h, tgt_len, src_len)
    heads = attention_weights @ V_h                # (b, h, tgt_len, d_h)

    # Concatenate heads back: (b, h, tgt_len, d_h) -> (b, tgt_len, embed_dim).
    merged = heads.transpose(1, 2).reshape(batch_size, tgt_len, self.embed_dim)
    return self.weight_o(merged)



In [26]:
# Build with the same hyper-parameters as the toy setup (dimension, h) rather than repeating literals.
my_mha = MyMHA(embed_dim=dimension, num_heads=h)

In [27]:
# Before the weights are copied: my_mha has its own random initialisation, so this output differs from ref_MHA's.
my_mha(q=x, k=x, v=x)

tensor([[[ 0.0508,  0.1309, -0.2686, -0.4265, -0.0218, -0.5444, -0.2060,
          -0.2166],
         [ 0.1651,  0.0321, -0.1619, -0.4330,  0.1050, -0.4894, -0.0453,
          -0.2427],
         [ 0.1700,  0.0630, -0.1057, -0.4782,  0.1156, -0.4420, -0.1329,
          -0.2356],
         [ 0.1641,  0.0803, -0.1178, -0.4632,  0.1065, -0.4361, -0.1212,
          -0.2562],
         [-0.0569,  0.1254, -0.4866, -0.3621, -0.1282, -0.7312, -0.1489,
          -0.2096],
         [ 0.1099,  0.1206, -0.1766, -0.4572,  0.0527, -0.4962, -0.1977,
          -0.2440],
         [ 0.0681,  0.0847, -0.2812, -0.4309,  0.0148, -0.5930, -0.1301,
          -0.2308],
         [ 0.1390,  0.1433, -0.1067, -0.5000,  0.1028, -0.4292, -0.2347,
          -0.2613]],

        [[ 0.4342, -0.1298,  0.2493, -0.0350, -0.0274, -0.3431,  0.0299,
           0.0101],
         [ 0.4005, -0.0942,  0.2325, -0.0865, -0.0338, -0.3194,  0.0137,
          -0.0153],
         [ 0.4222, -0.1426,  0.2213, -0.0562, -0.0008, -0.3646,  0.0

In [28]:
# Inspect the reference module's output-projection layer.
ref_MHA.get_submodule("out_proj")

NonDynamicallyQuantizableLinear(in_features=8, out_features=8, bias=True)

In [29]:
# Copy ref_MHA's parameters into my_mha so both modules compute the same function.
# nn.MultiheadAttention stores the Q/K/V projections as one packed in_proj_weight /
# in_proj_bias whose row blocks are ordered q, k, v; chunk(3) splits them back apart
# using the module's own size instead of relying on the input's feature dimension.
assert ref_MHA.in_proj_weight is not None and ref_MHA.in_proj_bias is not None, \
    "expected packed in-projection weights and biases (kdim == vdim == embed_dim, bias=True)"
with torch.no_grad():
  q_k_v_projs = (my_mha.weight_q, my_mha.weight_k, my_mha.weight_v)
  packed_weights = ref_MHA.in_proj_weight.chunk(3, dim=0)
  packed_biases = ref_MHA.in_proj_bias.chunk(3, dim=0)
  for proj, w, b in zip(q_k_v_projs, packed_weights, packed_biases):
    proj.weight.copy_(w)
    proj.bias.copy_(b)

  my_mha.weight_o.weight.copy_(ref_MHA.out_proj.weight)
  my_mha.weight_o.bias.copy_(ref_MHA.out_proj.bias)

In [30]:
# With identical parameters, the from-scratch module must match the reference output.
my_mha.eval()
ref_MHA.eval()
my_out = my_mha(x, x, x)
ref_out = ref_MHA(x, x, x)[0]  # element 0 is the attention output; element 1 the weights
assert my_out.shape == ref_out.shape, f"shape mismatch: {my_out.shape} vs {ref_out.shape}"
outputs_match = torch.allclose(my_out, ref_out)
assert outputs_match, f"max abs diff: {(my_out - ref_out).abs().max().item():.3e}"
outputs_match

True

In [31]:
# Output size of the from-scratch module.
my_mha(x, x, x).size()

torch.Size([2, 8, 8])

In [32]:
# Output size of the reference module (first element of its returned tuple).
ref_MHA(x, x, x)[0].size()


torch.Size([2, 8, 8])

In [33]:
# First block of the packed in-projection bias (the slice copied into weight_q.bias).
ref_MHA.in_proj_bias[0:dimension]

tensor([0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<SliceBackward0>)

In [34]:
# Second block of the packed in-projection bias (the slice copied into weight_k.bias).
ref_MHA.in_proj_bias[dimension:dimension * 2]

tensor([0., 0., 0., 0., 0., 0., 0., 0.], grad_fn=<SliceBackward0>)

In [15]:
# Bias of the reference output projection.
ref_MHA.get_parameter("out_proj.bias")

Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)

In [35]:
# After syncing weights: compare with the reference output shown in the next cell.
synced_out = my_mha(x, x, x)
synced_out

tensor([[[-0.2981, -0.1443,  0.2256, -0.1200, -0.0094,  0.1198, -0.2828,
          -0.1166],
         [-0.3639, -0.1603,  0.2851, -0.1088,  0.1658, -0.0218, -0.2217,
          -0.0960],
         [-0.3159, -0.2360,  0.3608, -0.0896,  0.1391,  0.0792, -0.1101,
          -0.0736],
         [-0.4521, -0.2679,  0.4190, -0.1395,  0.2153, -0.0358, -0.1985,
          -0.1083],
         [-0.4160, -0.2677,  0.4059, -0.1558,  0.1903,  0.0625, -0.0398,
          -0.1208],
         [-0.3387, -0.1959,  0.3296, -0.1553,  0.0607,  0.0791, -0.2709,
          -0.1384],
         [-0.2129, -0.1457,  0.2581, -0.0703,  0.0280,  0.1399, -0.1811,
          -0.0857],
         [-0.4141, -0.3248,  0.4855, -0.1528,  0.1461,  0.0779, -0.1661,
          -0.1330]],

        [[ 0.0408,  0.3094, -0.5754, -0.2493, -0.2430, -0.3085, -0.1636,
          -0.1119],
         [ 0.0700,  0.2905, -0.5372, -0.2187, -0.2034, -0.3521, -0.1580,
          -0.0876],
         [ 0.0338,  0.3593, -0.5925, -0.2440, -0.2011, -0.3368, -0.2

In [36]:
# Reference attention output (the attention weights are discarded).
ref_attn_out, _ref_attn_weights = ref_MHA(x, x, x)
ref_attn_out

tensor([[[-0.2981, -0.1443,  0.2256, -0.1200, -0.0094,  0.1198, -0.2828,
          -0.1166],
         [-0.3639, -0.1603,  0.2851, -0.1088,  0.1658, -0.0218, -0.2217,
          -0.0960],
         [-0.3159, -0.2360,  0.3608, -0.0896,  0.1391,  0.0792, -0.1101,
          -0.0736],
         [-0.4521, -0.2679,  0.4190, -0.1395,  0.2153, -0.0358, -0.1985,
          -0.1083],
         [-0.4160, -0.2677,  0.4059, -0.1558,  0.1903,  0.0625, -0.0398,
          -0.1208],
         [-0.3387, -0.1959,  0.3296, -0.1553,  0.0607,  0.0791, -0.2709,
          -0.1384],
         [-0.2129, -0.1457,  0.2581, -0.0703,  0.0280,  0.1399, -0.1811,
          -0.0857],
         [-0.4141, -0.3248,  0.4855, -0.1528,  0.1461,  0.0779, -0.1661,
          -0.1330]],

        [[ 0.0408,  0.3094, -0.5754, -0.2493, -0.2430, -0.3085, -0.1636,
          -0.1119],
         [ 0.0700,  0.2905, -0.5372, -0.2187, -0.2034, -0.3521, -0.1580,
          -0.0876],
         [ 0.0338,  0.3593, -0.5925, -0.2440, -0.2011, -0.3368, -0.2