In [40]:
import torch.nn as nn
import torch.nn.functional as F
import torch

In [38]:
def attend(query, key, value):
    """Scaled dot-product attention.

    Computes ``softmax(query @ key^T / sqrt(d_k)) @ value`` where ``d_k`` is
    the size of the last dimension of ``query``. The last two dimensions of
    ``key`` are swapped before the product, and the softmax is taken over the
    last dimension of the resulting score tensor.
    """
    scale = query.shape[-1] ** (1/2)
    scores = query @ key.transpose(-1, -2)
    # In-place division: `scores` is a fresh tensor owned by this function.
    scores /= scale
    weights = scores.softmax(dim=-1)
    return weights @ value


In [48]:
def get_qkv(input, attn_w):
    """Project ``input`` into query, key and value tensors.

    ``attn_w`` is treated as three weight blocks stacked along its first
    dimension. The block height is taken from ``attn_w.shape[1]``: the first
    block gives the query weights, the second the key weights, and all
    remaining rows the value weights. Each output is ``input @ block.T``.

    Returns:
        A tuple ``(q, k, v)``.
    """
    width = attn_w.shape[1]
    row_blocks = (
        attn_w[:width, :],           # query weights
        attn_w[width:2 * width, :],  # key weights
        attn_w[2 * width:, :],       # value weights
    )
    q, k, v = (torch.matmul(input, block.T) for block in row_blocks)
    return q, k, v

In [49]:
def multi_head_self_attention(input, num_heads, attn_w, output_w):
    """Multi-head self-attention without bias terms.

    Args:
        input: tensor of shape [L, N, E], where L is the sequence length,
            N is the batch size and E is the embedding dimension.
        num_heads: number of attention heads, each of dimension
            E // num_heads. Must be positive and divide E exactly.
        attn_w: stacked query/key/value projection weight, shape [3 * E, E].
        output_w: weight of the final linear layer, shape [E, E].

    Returns:
        A tensor with the same shape as ``input`` ([L, N, E]).

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide E.
    """
    L, N, E = input.shape
    # Fail early with a clear message instead of an opaque reshape error
    # (or a ZeroDivisionError when num_heads == 0).
    if num_heads <= 0 or E % num_heads != 0:
        raise ValueError(
            f"embedding dimension ({E}) must be divisible by a positive "
            f"num_heads (got {num_heads})"
        )
    d_k = E // num_heads

    q, k, v = get_qkv(input, attn_w)

    def split_heads(mat):
        # [L, N, E] -> [L, N, H, d_k] -> [H, N, L, d_k], so the last two
        # dimensions are (sequence, head_dim) as attend() expects.
        return mat.reshape(L, N, num_heads, d_k).permute(2, 1, 0, 3)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)
    attention = attend(q, k, v)
    # Undo the head split: [H, N, L, d_k] -> [L, N, H, d_k] -> [L, N, E].
    attention = attention.permute(2, 1, 0, 3).reshape(L, N, num_heads * d_k)
    return torch.matmul(attention, output_w.T)
    

In [50]:
""" Reset random seed """
torch.manual_seed(0)

# Configuration: sequence lengths, batch size, embedding sizes and head counts.
Ls = [4, 8, 16]
N = 1
Es = [4, 8, 16]
Heads = [1, 2, 4]

for L in Ls:
    for E in Es:
        for num_heads in Heads:
            # Build the reference layer and reuse its weights for our version.
            attn_layer = nn.MultiheadAttention(
                embed_dim=E,
                num_heads=num_heads,
                bias=False,
            )
            attn_w, output_w = attn_layer.parameters()
            input = torch.randn(L, N, E)

            # Self-attention: query, key and value are all the same tensor.
            result_torch, _ = attn_layer(query=input, key=input, value=input)
            result_yours = multi_head_self_attention(
                input=input,
                num_heads=num_heads,
                attn_w=attn_w,
                output_w=output_w,
            )

            assert torch.allclose(result_torch, result_yours, atol=1e-07)
            print('OK', L, E, num_heads)

OK 4 4 1
OK 4 4 2
OK 4 4 4
OK 4 8 1
OK 4 8 2
OK 4 8 4
OK 4 16 1
OK 4 16 2
OK 4 16 4
OK 8 4 1
OK 8 4 2
OK 8 4 4
OK 8 8 1
OK 8 8 2
OK 8 8 4
OK 8 16 1
OK 8 16 2
OK 8 16 4
OK 16 4 1
OK 16 4 2
OK 16 4 4
OK 16 8 1
OK 16 8 2
OK 16 8 4
OK 16 16 1
OK 16 16 2
OK 16 16 4
