In [195]:
import numpy as np
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from copy import deepcopy

In [196]:
# Hyper-parameters shared by the attention layers and the test tensors.
batch_size = 4  # sequences per batch
seq_len = 128   # tokens per sequence
d_model = 32    # embedding dimension
n_head = 4      # number of attention heads

## MultiHead Self-Attention

Attention is all you need paper: https://arxiv.org/pdf/1706.03762.pdf

$MultiHead(Q, K, V) =  Concat(head_{1}, ..., head_{h})W^{O}$

$head_{i}=Attention(QW^{Q}_{i}, KW^{K}_{i}, VW^{V}_{i})$

$W^{Q}_{i} \in \mathbb{R}^{d_{model} \times d_{k}}$

$W^{K}_{i} \in \mathbb{R}^{d_{model} \times d_{k}}$

$W^{V}_{i} \in \mathbb{R}^{d_{model} \times d_{v}}$

$d_{k} = d_{v} = d_{model} / h $

## 1. Implementation that follows the paper

In [234]:
class MultiHead_SelfAttention_(nn.Module):
    """Multi-head scaled dot-product attention with separate Q/K/V projections.

    Each of the query, key and value inputs gets its own
    ``d_model -> d_model`` linear map; the result is split into ``n_head``
    heads of width ``d_model // n_head``, attended per head, concatenated and
    passed through an output projection.

    Args:
        d_model: Embedding dimension of the inputs and of the output.
        n_head: Number of attention heads; must divide ``d_model``.
        bias: Whether the four linear projections carry a bias term.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model, n_head, bias=False):
        super().__init__()
        # Fail at construction time instead of with an opaque view() error
        # on the first forward pass.
        if d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({n_head})"
            )

        self.d_model = d_model
        self.n_head = n_head
        self.d_k = d_model // n_head
        self.d_v = d_model // n_head

        self.w_q = nn.Linear(d_model, d_model, bias=bias)
        self.w_k = nn.Linear(d_model, d_model, bias=bias)
        self.w_v = nn.Linear(d_model, d_model, bias=bias)
        self.w_o = nn.Linear(d_model, d_model, bias=bias)

    def forward(self, q, k, v, mask=None):
        """Apply multi-head attention.

        Args:
            q: Query tensor of shape ``(B, T_q, d_model)``.
            k: Key tensor of shape ``(B, T_k, d_model)``.
            v: Value tensor of shape ``(B, T_k, d_model)``.
            mask: Optional tensor broadcastable to ``(B, n_head, T_q, T_k)``.
                A boolean mask uses the ``nn.MultiheadAttention`` ``attn_mask``
                convention: ``True`` marks positions that must NOT be attended
                to. Any other dtype is added to the scaled attention scores
                (use ``-inf`` to block a position). A query row whose keys are
                all blocked yields NaN, because softmax over only ``-inf`` is
                undefined.

        Returns:
            Tensor of shape ``(B, T_q, d_model)``.
        """
        B, T_q, C = q.size()  # batch size, query length, embedding dimension
        T_k = k.size(1)  # key/value length may differ from the query length

        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)

        q = q.view(B, T_q, self.n_head, self.d_k).transpose(1, 2)  # (B, nh, T_q, hs)
        k = k.view(B, T_k, self.n_head, self.d_k).transpose(1, 2)  # (B, nh, T_k, hs)
        v = v.view(B, T_k, self.n_head, self.d_v).transpose(1, 2)  # (B, nh, T_k, hs)

        # (B, nh, T_q, hs) x (B, nh, hs, T_k) -> (B, nh, T_q, T_k)
        att_map = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.d_k))

        if mask is not None:
            if mask.dtype == torch.bool:
                # Blocked scores become -inf so softmax assigns them zero weight.
                att_map = att_map.masked_fill(mask, float("-inf"))
            else:
                att_map = att_map + mask

        # (B, nh, T_q, T_k) x (B, nh, T_k, hs) -> (B, nh, T_q, hs)
        attention = F.softmax(att_map, dim=-1) @ v

        # Merge the heads back: (B, nh, T_q, hs) -> (B, T_q, nh * hs)
        output = attention.transpose(1, 2).contiguous().view(B, T_q, C)
        output = self.w_o(output)

        return output

In [235]:
# Build the paper-style layer and list the shape of each learnable tensor.
test_layer = MultiHead_SelfAttention_(d_model=d_model, n_head=n_head)

for param in test_layer.parameters():
    print(param.shape)

torch.Size([32, 32])
torch.Size([32, 32])
torch.Size([32, 32])
torch.Size([32, 32])


## 2. Implementation that matches PyTorch's parameter layout

In [198]:
class MultiHead_SelfAttention(nn.Module):
    """Multi-head attention with one fused Q/K/V projection.

    Unlike the three separate projections of the paper-style layer, the
    query, key and value weights live in a single
    ``nn.Linear(d_model, 3 * d_model)`` whose weight has shape
    ``(3 * d_model, d_model)``, the same shape as
    ``nn.MultiheadAttention.in_proj_weight``.

    Args:
        d_model: Embedding dimension of the inputs and of the output.
        n_head: Number of attention heads; must divide ``d_model``.
        bias: Whether the projections carry a bias term.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model, n_head, bias=False):
        super().__init__()
        # Fail at construction time instead of with an opaque view() error
        # on the first forward pass.
        if d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({n_head})"
            )

        self.d_model = d_model
        self.n_head = n_head
        self.d_k = d_model // n_head
        self.d_v = d_model // n_head

        # Rows [0, d), [d, 2d), [2d, 3d) of the weight project q, k and v.
        self.c_attn = nn.Linear(d_model, 3 * d_model, bias=bias)
        self.c_proj = nn.Linear(d_model, d_model, bias=bias)

    def _project_qkv(self, q, k, v):
        """Project q, k and v with their respective slices of ``c_attn``."""
        if k is q and v is q:
            # Self-attention: one fused matmul, then split along the features.
            return self.c_attn(q).split(self.d_model, dim=2)

        # Distinct inputs: each one gets its own slice of the fused weight,
        # so k and v are honoured instead of being silently ignored.
        w_q, w_k, w_v = self.c_attn.weight.split(self.d_model, dim=0)
        if self.c_attn.bias is None:
            b_q = b_k = b_v = None
        else:
            b_q, b_k, b_v = self.c_attn.bias.split(self.d_model, dim=0)
        return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v)

    def forward(self, q, k, v, mask=None):
        """Apply multi-head attention.

        Args:
            q: Query tensor of shape ``(B, T_q, d_model)``.
            k: Key tensor of shape ``(B, T_k, d_model)``.
            v: Value tensor of shape ``(B, T_k, d_model)``.
            mask: Optional tensor broadcastable to ``(B, n_head, T_q, T_k)``.
                A boolean mask uses the ``nn.MultiheadAttention`` ``attn_mask``
                convention: ``True`` marks positions that must NOT be attended
                to. Any other dtype is added to the scaled attention scores
                (use ``-inf`` to block a position). A query row whose keys are
                all blocked yields NaN, because softmax over only ``-inf`` is
                undefined.

        Returns:
            Tensor of shape ``(B, T_q, d_model)``.
        """
        B, T_q, C = q.size()  # batch size, query length, embedding dimension

        q, k, v = self._project_qkv(q, k, v)
        T_k = k.size(1)  # key/value length may differ from the query length

        q = q.view(B, T_q, self.n_head, self.d_k).transpose(1, 2)  # (B, nh, T_q, hs)
        k = k.view(B, T_k, self.n_head, self.d_k).transpose(1, 2)  # (B, nh, T_k, hs)
        v = v.view(B, T_k, self.n_head, self.d_v).transpose(1, 2)  # (B, nh, T_k, hs)

        # (B, nh, T_q, hs) x (B, nh, hs, T_k) -> (B, nh, T_q, T_k)
        att_map = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.d_k))

        if mask is not None:
            if mask.dtype == torch.bool:
                # Blocked scores become -inf so softmax assigns them zero weight.
                att_map = att_map.masked_fill(mask, float("-inf"))
            else:
                att_map = att_map + mask

        # (B, nh, T_q, T_k) x (B, nh, T_k, hs) -> (B, nh, T_q, hs)
        attention = F.softmax(att_map, dim=-1) @ v

        # Merge the heads back: (B, nh, T_q, hs) -> (B, T_q, nh * hs)
        output = attention.transpose(1, 2).contiguous().view(B, T_q, C)
        output = self.c_proj(output)

        return output


In [199]:
# Hand-written layer with the fused Q/K/V projection; bias disabled.
hand_att_layer = MultiHead_SelfAttention(
    d_model=d_model,
    n_head=n_head,
    bias=False,
)

In [236]:
# Shape of each learnable tensor in the hand-written layer.
for param in hand_att_layer.parameters():
    print(param.shape)

torch.Size([96, 32])
torch.Size([32, 32])


## 3. PyTorch layer

In [200]:
# Reference layer from PyTorch; batch_first=True makes it take (batch, seq, feature) input.
torch_att_layer = nn.MultiheadAttention(
    embed_dim=d_model,
    num_heads=n_head,
    bias=False,
    batch_first=True,
)

In [237]:
# Shape of each learnable tensor in the PyTorch layer.
for param in torch_att_layer.parameters():
    print(param.shape)

torch.Size([96, 32])
torch.Size([32, 32])


## 4. Test output

In [201]:
# Copy the weights from torch_att_layer into hand_att_layer so that both layers
# compute with the same parameters. copy_ writes into the existing Parameter
# objects (no rebinding, no deepcopy) and must run under no_grad because it
# modifies leaf tensors that require grad.
with torch.no_grad():
    hand_att_layer.c_attn.weight.copy_(torch_att_layer.in_proj_weight)
    hand_att_layer.c_proj.weight.copy_(torch_att_layer.out_proj.weight)

In [202]:
# Integer-valued random inputs in [0, 90000), drawn from a seeded generator so
# that re-running the notebook reproduces the same comparison.
rng = np.random.default_rng(0)
batch_data = torch.as_tensor(
    rng.integers(low=0, high=90000, size=(batch_size, seq_len, d_model)),
    dtype=torch.float32,  # same dtype the legacy torch.Tensor(ndarray) call produced
)

In [203]:
# Self-attention: the same tensor serves as query, key and value.
output_1 = hand_att_layer(q=batch_data, k=batch_data, v=batch_data)

In [204]:
# The PyTorch layer returns (output, attention weights); only the output is compared.
output_2, _ = torch_att_layer(query=batch_data, key=batch_data, value=batch_data)

In [238]:
# The outputs count as equal when every element differs by less than 0.1.
((output_1 - output_2).abs() < 1e-1).all()

tensor(True)

## 5. TensorFlow (Keras) layer

In [241]:
import tensorflow as tf

# Symbolic input with the same per-sample shape as the PyTorch experiments,
# taken from the shared constants instead of repeating the literals 128 and 32.
x = tf.keras.Input(shape=[seq_len, d_model])
# One head with key_dim=2 and no bias terms.
layer = tf.keras.layers.MultiHeadAttention(num_heads=1, key_dim=2, use_bias=False)

In [242]:
# Keras orders the positional arguments (query, value, key); naming them avoids
# confusion with PyTorch's (query, key, value). All three are the same tensor here.
output_tensor = layer(query=x, value=x, key=x)

In [243]:
# Fetch the layer's weight arrays and report how many there are.
weights = layer.get_weights()
num_weight_arrays = len(weights)
print(num_weight_arrays)

4


In [244]:
# Shapes of the first four weight arrays
# (presumably the query, key, value and output kernels - confirm against the Keras docs).
for idx in range(4):
    print(weights[idx].shape)

(32, 1, 2)
(32, 1, 2)
(32, 1, 2)
(1, 2, 32)
