In [13]:
import torch.nn as nn
import torch.nn.functional as F
import torch

In [14]:
def multi_head_self_attention(input, num_heads, attn_w, output_w):
    
    """
    TODO: Your implementation for the multihead attention function.
    We assume the input and the output have the same shape.
    input: the input with shape [L, N, E], where L is the sequence length,
         N is the batch size, E is the embedding dimension.
    num_heads: number of the attention heads each with dimension E // num_heads.
    attn_w: the weight for the query, key, and value, with shape [3 * E, E].
    output_w: the additional linear layer with shape [E, E].
    """
    L,N,E = input.shape
    assert (E % num_heads == 0)
    d_k = E // num_heads
    
    q_w = attn_w[:E,:E]
    k_w = attn_w[E:(E*2),:E]
    v_w = attn_w[(E*2):,:E] 
    
    linear_query = torch.matmul(input,q_w)
    linear_key = torch.matmul(input,k_w)
    linear_value = torch.matmul(input,v_w)

    
   # print("batch,head,sequence legth,dk:",N,num_heads,L,d_k)
    
    query = linear_query.view(N, L, num_heads, d_k).transpose(1, 2)  # [batch, head, L , dk]
    key = linear_key.view(N, L, num_heads, d_k).transpose(1, 2)  
    value = linear_value.view(N, L, num_heads, d_k).transpose(1, 2)  
    
   # print("reshape:",query.shape)
    
    #self attention
    scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)
    self_attn = F.softmax(scores, dim=-1)
    x = torch.matmul(self_attn, value)
    
    x = x.transpose(1, 2).reshape(N, L, E)

    return torch.matmul(x,output_w)
    

   

In [15]:
""" Reset random seed """
torch.manual_seed(0)

""" Configuration """
Ls = [4, 8, 16]
N = 1
Es = [4, 8, 16]
Heads = [1, 2, 4]

for L in Ls:
  for E in Es:
    for num_heads in Heads:
      """ Create weight and input """
      attn_layer = nn.MultiheadAttention(embed_dim=E, 
                                         num_heads=num_heads, 
                                         bias=False)
      attn_w, output_w = attn_layer.parameters()
      input = torch.randn([L, N, E])
      
      result_torch, _ = attn_layer(input, input, input)
      result_yours = multi_head_self_attention(input, 
                                               num_heads,
                                               attn_w,
                                               output_w)
      assert torch.allclose(result_torch, result_yours, atol=1e-07)
      print('OK', L, E, num_heads)

AssertionError: 