In [7]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import math

In [8]:
def multi_head_self_attention(input, num_heads, attn_w, output_w,
                              attn_b=None, output_b=None):
    """Multi-head self-attention equivalent to ``nn.MultiheadAttention`` (sequence-first layout).

    The same tensor serves as query, key and value, and the output has the
    same shape as the input.

    Args:
        input: tensor of shape [L, N, E] -- sequence length L (e.g. the words
            of a sentence), batch size N, embedding dimension E (the numbers
            used to represent each token).
        num_heads: number of attention heads; each head has dimension
            E // num_heads, so E must be divisible by num_heads.
        attn_w: packed query/key/value projection weight of shape [3 * E, E].
            Rows [0, E) project the query, [E, 2E) the key, [2E, 3E) the value.
        output_w: output projection weight of shape [E, E].
        attn_b: optional packed query/key/value bias of shape [3 * E].
            ``None`` (default) means no bias.
        output_b: optional output projection bias of shape [E].
            ``None`` (default) means no bias.

    Returns:
        Tensor of shape [L, N, E].

    Raises:
        ValueError: if num_heads is not positive or does not divide E.
    """
    L, N, E = input.shape
    if num_heads <= 0 or E % num_heads != 0:
        raise ValueError(
            f"embedding dim {E} must be divisible by num_heads={num_heads}")
    d_k = E // num_heads  # per-head query/key/value size

    # One packed projection (F.linear computes x @ W.T + b), then split the
    # last axis into the query, key and value parts, in that order.
    q_proj, k_proj, v_proj = F.linear(input, attn_w, attn_b).chunk(3, dim=-1)

    def split_heads(proj):
        # [L, N, E] -> [L, N, num_heads, d_k] -> [num_heads, N, L, d_k]
        return proj.reshape(L, N, num_heads, d_k).permute(2, 1, 0, 3)

    query, key, value = map(split_heads, (q_proj, k_proj, v_proj))

    # Scaled dot-product attention per head: softmax(Q K^T / sqrt(d_k)) V.
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    weights = F.softmax(scores, dim=-1)      # [num_heads, N, L, L]
    context = torch.matmul(weights, value)   # [num_heads, N, L, d_k]

    # Merge heads: [num_heads, N, L, d_k] -> [L, N, num_heads, d_k] -> [L, N, E]
    merged = context.permute(2, 1, 0, 3).reshape(L, N, E)
    return F.linear(merged, output_w, output_b)


In [9]:
# Reset the random seed so the layer weights and inputs are reproducible.
torch.manual_seed(0)

# Configuration: every (L, E, num_heads) combination is checked against PyTorch.
Ls = [4, 8, 16]      # sequence lengths
N = 1                # batch size
Es = [4, 8, 16]      # embedding dimensions
Heads = [1, 2, 4]    # numbers of attention heads
ATOL = 1e-07         # absolute tolerance passed to torch.allclose

for L in Ls:
    for E in Es:
        for num_heads in Heads:
            # Reference layer without biases, plus a random [L, N, E] input.
            attn_layer = nn.MultiheadAttention(embed_dim=E,
                                               num_heads=num_heads,
                                               bias=False)
            # Fetch weights by name rather than relying on parameters() order.
            attn_w = attn_layer.in_proj_weight      # [3E, E]
            output_w = attn_layer.out_proj.weight   # [E, E]
            inputs = torch.randn([L, N, E])         # 3D input: [L, N, E]

            # Pure comparison: no autograd graph is needed.
            with torch.no_grad():
                result_torch, _ = attn_layer(inputs, inputs, inputs)
                result_yours = multi_head_self_attention(inputs,
                                                         num_heads,
                                                         attn_w,
                                                         output_w)
            assert torch.allclose(result_torch, result_yours, atol=ATOL), (
                f"mismatch for L={L}, E={E}, num_heads={num_heads}: max abs diff "
                f"{(result_torch - result_yours).abs().max().item():.3e}")
            print('OK', L, E, num_heads)

OK 4 4 1
OK 4 4 2
OK 4 4 4
OK 4 8 1
OK 4 8 2
OK 4 8 4
OK 4 16 1
OK 4 16 2
OK 4 16 4
OK 8 4 1
OK 8 4 2
OK 8 4 4
OK 8 8 1
OK 8 8 2
OK 8 8 4
OK 8 16 1
OK 8 16 2
OK 8 16 4
OK 16 4 1
OK 16 4 2
OK 16 4 4
OK 16 8 1
OK 16 8 2
OK 16 8 4
OK 16 16 1
OK 16 16 2
OK 16 16 4
