In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch

In [2]:
def self_attention(query, key, value, dropout=None, mask=None):
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)
    if mask is not None:
        mask.cuda()
        scores = scores.masked_fill(mask == 0, -1e9)
    self_attn = F.softmax(scores, dim=-1)
    if dropout is not None:
        self_attn = dropout(self_attn)
    return torch.matmul(self_attn, value), self_attn

def multi_head_self_attention(input, num_heads, attn_w, output_w):
    
    """
    TODO: Your implementation for the multihead attention function.
    We assume the input and the output have the same shape.
    input: the input with shape [L, N, E], where L is the sequence length,
         N is the batch size, E is the embedding dimension.
    num_heads: number of the attention heads each with dimension E // num_heads.
    attn_w: the weight for the query, key, and value, with shape [3 * E, E].
    output_w: the additional linear layer with shape [E, E].
    """
    L,N,E = input.shape
    assert (E % num_heads == 0)
    d_k = E // num_heads
    head = num_heads
    d_model = E
    
    linear_query = nn.Linear(d_model, d_model)
    linear_key = nn.Linear(d_model, d_model)
    linear_value = nn.Linear(d_model, d_model)
    linear_out = nn.Linear(d_model, d_model)
    
    n_batch = N
    
    query = linear_query(input).view(n_batch, -1, head, d_k).transpose(1, 2)  # [b, 8, 32, 64]
    key = linear_key(input).view(n_batch, -1, head, d_k).transpose(1, 2)  # [b, 8, 32, 64]
    value = linear_value(input).view(n_batch, -1, head, d_k).transpose(1, 2)  # [b, 8, 32, 64]

    x, attn = self_attention(query, key, value)
    x = x.transpose(1, 2).contiguous().view(n_batch, -1, head * d_k)

    return linear_out(x)
    

   

In [3]:
""" Reset random seed """
torch.manual_seed(0)

""" Configuration """
Ls = [4, 8, 16]
N = 1
Es = [4, 8, 16]
Heads = [1, 2, 4]

for L in Ls:
  for E in Es:
    for num_heads in Heads:
      """ Create weight and input """
      attn_layer = nn.MultiheadAttention(embed_dim=E, 
                                         num_heads=num_heads, 
                                         bias=False)
      attn_w, output_w = attn_layer.parameters()
      input = torch.randn([L, N, E])
      
      result_torch, _ = attn_layer(input, input, input)
      result_yours = multi_head_self_attention(input, 
                                               num_heads,
                                               attn_w,
                                               output_w)
      assert torch.allclose(result_torch, result_yours, atol=1e-07)
      print('OK', L, E, num_heads)

NameError: name 'math' is not defined