In [12]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import numpy as np
import math

In [13]:
def attention(query, key, value, mask = None):
    """
    Scaled dot-product attention.

    query, key, value: tensors whose last two dimensions are
         [sequence length, feature size]; any leading dimensions
         (e.g. heads, batch) are handled by torch.matmul broadcasting.

    mask: optional tensor broadcastable to the score matrix
         [..., query length, key length]. Positions where the mask is
         0 / False are excluded from the softmax. A query row whose keys
         are all masked produces NaN weights, because every score in that
         row becomes -inf.

    Returns (output, weights): the attention-weighted sum of `value` and
    the softmax-normalised attention weights.
    """
    d_k = query.size(-1)
    # Scale by sqrt(d_k) so the score magnitude does not grow with the
    # feature size.
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # -inf scores become exactly zero weight after the softmax.
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value), weights

In [18]:
def multi_head_self_attention(input, num_heads, attn_w, output_w):
    """
    Multi-head self-attention without biases, masking, or dropout.

    The input is projected to queries, keys and values with the three
    row-bands of `attn_w`, split into `num_heads` heads, run through
    scaled dot-product attention, merged back together and finally passed
    through the output projection `output_w`.

    input: the input with shape [L, N, E],
         L is the sequence length,
         N is the batch size,
         E is the embedding dimension.

    num_heads: number of attention heads, each with dimension
         E // num_heads. E is expected to be a multiple of num_heads;
         otherwise the reshape into heads fails.

    attn_w: stacked query/key/value weights with shape [3 * E, E]
         (rows [0, E) -> query, [E, 2E) -> key, [2E, ...) -> value).
    output_w: weight of the final linear layer with shape [E, E].

    Returns a tensor with the same shape as the input, [L, N, E].
    """
    seq_len = input.size(0)
    n_batch = input.size(1)
    embed = input.size(-1)
    head_dim = embed // num_heads

    # Row ranges of attn_w that hold the query, key and value weights.
    row_bands = (
        slice(None, embed),
        slice(embed, embed * 2),
        slice(embed * 2, None),
    )

    per_head = []
    for band in row_bands:
        weight = attn_w[band, :embed]                 # [E, E]
        projected = torch.matmul(input, weight.T)     # [L, N, E]
        split = projected.reshape(seq_len, n_batch, num_heads, head_dim)
        # Move heads to the front: [num_heads, N, L, head_dim].
        per_head.append(split.permute(2, 1, 0, 3))
    q, k, v = per_head

    context, _ = attention(q, k, v)                   # [num_heads, N, L, head_dim]

    # Undo the head split: back to [L, N, num_heads, head_dim], then
    # concatenate the heads along the last axis.
    merged = context.permute(2, 1, 0, 3)
    merged = merged.reshape(seq_len, n_batch, num_heads * head_dim)
    return torch.matmul(merged, output_w.T)

In [19]:
""" Reset random seed """
torch.manual_seed(0)

# Configuration: sequence lengths, batch size, embedding sizes, head counts.
Ls = [4, 8, 16]
N = 1
Es = [4, 8, 16]
Heads = [1, 2, 4]

# Compare the hand-written implementation with torch's reference layer for
# every (L, E, num_heads) combination.
for L in Ls:
    for E in Es:
        for num_heads in Heads:
            # Bias-free reference layer; its two parameters are unpacked as
            # the stacked q/k/v weight and the output projection weight.
            attn_layer = nn.MultiheadAttention(E, num_heads, bias=False)
            attn_w, output_w = attn_layer.parameters()
            input = torch.randn([L, N, E])

            # Self-attention: the same tensor serves as query, key and value.
            result_torch, _ = attn_layer(input, input, input)
            result_yours = multi_head_self_attention(
                input, num_heads, attn_w, output_w
            )
            assert torch.allclose(result_torch, result_yours, atol=1e-07)
            print('OK', L, E, num_heads)

OK 4 4 1
OK 4 4 2
OK 4 4 4
OK 4 8 1
OK 4 8 2
OK 4 8 4
OK 4 16 1
OK 4 16 2
OK 4 16 4
OK 8 4 1
OK 8 4 2
OK 8 4 4
OK 8 8 1
OK 8 8 2
OK 8 8 4
OK 8 16 1
OK 8 16 2
OK 8 16 4
OK 16 4 1
OK 16 4 2
OK 16 4 4
OK 16 8 1
OK 16 8 2
OK 16 8 4
OK 16 16 1
OK 16 16 2
OK 16 16 4
