In [35]:
import numpy as np
import torch
import torch.nn as nn

In [36]:
def softmax(Z):
    """Numerically stable softmax over the last axis of `Z`.

    The per-row maximum is subtracted before exponentiating so large logits
    cannot overflow. Entries equal to -inf (masked positions) receive
    probability 0, provided each row has at least one finite entry.
    """
    Z = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return Z / Z.sum(axis=-1, keepdims=True)


def self_attention(X, mask, W_KQV, W_out):
    """Single-head scaled dot-product self-attention.

    Parameters
    ----------
    X : array of shape (..., T, d)
        Input sequence(s); leading batch axes are broadcast by `@`.
    mask : array broadcastable to (..., T, T)
        Additive mask applied to the scores before the softmax
        (e.g. -inf above the diagonal for causal attention).
    W_KQV : array of shape (d, 3 * d_k)
        Packed projection whose columns are ordered [query | key | value],
        the same order as PyTorch's ``in_proj_weight.T``.
    W_out : array of shape (d_k, d_out)
        Output projection.

    Returns
    -------
    (Y, attn)
        Output of shape (..., T, d_out) and attention weights of shape
        (..., T, T).
    """
    Q, K, V = np.split(X @ W_KQV, 3, axis=-1)
    # Scale by the per-projection width d_k (the 1/sqrt(d_k) of scaled
    # dot-product attention); this equals X.shape[-1] when W_KQV is (d, 3d).
    attn = softmax(Q @ K.swapaxes(-1, -2) / np.sqrt(Q.shape[-1]) + mask)
    return attn @ V @ W_out, attn

In [37]:
# Causal mask for a length-5 sequence: entry (i, j) is -inf whenever j > i and
# 0 otherwise, so after the softmax position i puts no weight on later steps.
T = 5
M = torch.full((T, T), -float("inf")).triu(diagonal=1)
M

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])

In [38]:
T, d = 100, 64
# Seed once, before the first random draw, so Restart & Run All reproduces the
# weights and inputs used by every later cell.
torch.manual_seed(0)
attn = nn.MultiheadAttention(d, 1, bias=False, batch_first=True)
# Causal ("look-ahead") mask: entry (i, j) is -inf for j > i. It is added to
# the attention scores before the softmax, so position i assigns zero weight
# to every later position and can only attend to itself and the past.
M = torch.triu(-float("inf") * torch.ones(T, T), 1)
X = torch.randn(1, T, d)
Y_, A_ = attn(X, X, X, attn_mask=M)

In [39]:
# PyTorch stores these weights transposed relative to the `X @ W` convention
# used inside `self_attention`, hence the `.T` on each one.
Y, A = self_attention(
    X=X[0].numpy(),
    mask=M.numpy(),
    W_KQV=attn.in_proj_weight.detach().numpy().T,
    W_out=attn.out_proj.weight.detach().numpy().T,
)

In [40]:
print(np.linalg.norm(A - A_[0].detach().numpy()))
print(np.linalg.norm(Y - Y_[0].detach().numpy()))
# Fail loudly on a real mismatch instead of relying on eyeballing the norms;
# the tolerance comfortably covers float32 round-off.
np.testing.assert_allclose(A, A_[0].detach().numpy(), atol=1e-5)
np.testing.assert_allclose(Y, Y_[0].detach().numpy(), atol=1e-5)

1.8144881e-07
1.4463458e-06


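## Minibatching with batched matrix multiplication (BMM)

NumPy's `@` operator treats every axis before the last two as a batch axis, so `self_attention` handles a whole minibatch of sequences without any code changes.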

In [41]:
# illustration of batch matmul
B = np.random.randn(10,3,5,4)
C = np.random.randn(10,3,4,3)
(B@C).shape

(10, 3, 5, 3)

In [42]:
# Run the same single-head layer on a minibatch of N sequences; one (T, T)
# causal mask is shared by every sequence in the batch.
N = 10
M = torch.full((T, T), -float("inf")).triu(1)
X = torch.randn(N, T, d)
Y_, A_ = attn(X, X, X, attn_mask=M)

In [43]:
# Same call as before, but the whole (N, T, d) batch goes through at once.
Y, A = self_attention(
    X=X.numpy(),
    mask=M.numpy(),
    W_KQV=attn.in_proj_weight.detach().numpy().T,
    W_out=attn.out_proj.weight.detach().numpy().T,
)

In [44]:
print(np.linalg.norm(A - A_.detach().numpy()))
print(np.linalg.norm(Y - Y_.detach().numpy()))
# Assert agreement so a regression stops the notebook instead of printing a
# slightly larger number that is easy to miss.
np.testing.assert_allclose(A, A_.detach().numpy(), atol=1e-5)
np.testing.assert_allclose(Y, Y_.detach().numpy(), atol=1e-5)

5.503326e-07
4.5866805e-06


## Multihead Attention

In [45]:
def multihead_attention(X, mask, heads, W_KQV, W_out):
    """Multi-head scaled dot-product self-attention on a batch of sequences.

    Parameters
    ----------
    X : array of shape (N, T, d)
        Batch of N input sequences of length T.
    mask : array broadcastable to (N, heads, T, T)
        Additive mask applied to the scores before the softmax.
    heads : int
        Number of attention heads; must divide the projection width.
    W_KQV : array of shape (d, 3 * e)
        Packed projection with columns ordered [query | key | value], the same
        order as PyTorch's ``in_proj_weight.T``.
    W_out : array of shape (e, d_out)
        Output projection applied after the heads are concatenated.

    Returns
    -------
    (Y, attn)
        Y of shape (N, T, d_out) and per-head attention weights of shape
        (N, heads, T, T).

    Raises
    ------
    ValueError
        If the projection width ``e`` is not divisible by ``heads``.
    """
    N, T, _ = X.shape
    Q, K, V = np.split(X @ W_KQV, 3, axis=-1)
    width = Q.shape[-1]
    if width % heads:
        raise ValueError(
            f"projection width {width} is not divisible by heads={heads}")
    d_head = width // heads

    # (N, T, width) -> (N, heads, T, d_head): each head attends independently.
    Q, K, V = (a.reshape(N, T, heads, d_head).swapaxes(1, 2) for a in (Q, K, V))

    attn = softmax(Q @ K.swapaxes(-1, -2) / np.sqrt(d_head) + mask)
    # Concatenate the heads back to (N, T, width) before the output projection.
    Y = (attn @ V).swapaxes(1, 2).reshape(N, T, width) @ W_out
    return Y, attn

In [46]:
# Reference 4-head PyTorch layer. With default arguments it returns attention
# weights averaged over heads (see the shapes compared below).
heads = 4
attn = nn.MultiheadAttention(embed_dim=d, num_heads=heads, bias=False, batch_first=True)
Y_, A_ = attn(X, X, X, attn_mask=M)

In [47]:
# Pass the `heads` variable (not a literal) so the NumPy reference always uses
# the same head count as the PyTorch layer built in the previous cell.
Y, A = multihead_attention(X.numpy(), M.numpy(), heads,
                           attn.in_proj_weight.detach().numpy().T,
                           attn.out_proj.weight.detach().numpy().T)

In [48]:
# PyTorch's weights have no head axis: (N, T, T)
A_.size()

torch.Size([10, 100, 100])

In [49]:
# Ours keep one attention matrix per head: (N, heads, T, T)
np.shape(A)

(10, 4, 100, 100)

In [50]:
print(np.linalg.norm(Y - Y_.detach().numpy()))
# PyTorch reports head-averaged weights, so average ours over the head axis
# (axis 1) before comparing.
print(np.linalg.norm(A.mean(1) - A_.detach().numpy()))
np.testing.assert_allclose(Y, Y_.detach().numpy(), atol=1e-5)
np.testing.assert_allclose(A.mean(1), A_.detach().numpy(), atol=1e-5)

4.314429e-06
3.8336384e-07


## Transformer block

In [51]:
def layer_norm(Z, eps):
    """Normalize `Z` to zero mean and unit variance along its last axis.

    No learnable scale or shift is applied; `eps` is added to the variance
    before the square root to keep the division well-defined.
    """
    mu = Z.mean(axis=-1, keepdims=True)
    var = Z.var(axis=-1, keepdims=True)
    return (Z - mu) / np.sqrt(var + eps)

def relu(Z):
    """Element-wise max(Z, 0)."""
    return np.maximum(Z, 0)

def transformer(X, mask, heads, W_KQV, W_out, W_ff1, W_ff2, eps):
    """Post-norm transformer encoder block without biases.

    Z   = LayerNorm(MHA(X) + X)
    out = LayerNorm(Z + ReLU(Z @ W_ff1) @ W_ff2)
    Both normalizations use the same `eps` and no affine parameters.
    """
    attn_out, _ = multihead_attention(X, mask, heads, W_KQV, W_out)
    Z = layer_norm(attn_out + X, eps)
    hidden = relu(Z @ W_ff1)
    return layer_norm(Z + hidden @ W_ff2, eps)

In [52]:
trans = nn.TransformerEncoderLayer(d, heads, dim_feedforward=128, dropout=0.0, batch_first=True)
# The NumPy `transformer` reference has no bias terms, so zero every bias this
# layer adds. The attention biases appear to be zero-initialised by PyTorch
# already; zeroing them explicitly removes the reliance on that default.
# `torch.no_grad()` replaces the discouraged `.data` mutation.
with torch.no_grad():
    trans.linear1.bias.zero_()
    trans.linear2.bias.zero_()
    trans.self_attn.in_proj_bias.zero_()
    trans.self_attn.out_proj.bias.zero_()
Y_ = trans(X, src_mask=M)

In [53]:
# Every PyTorch weight is transposed to match the `Z @ W` convention of the
# NumPy code. A single eps is used for both norms (assumes norm1.eps equals
# norm2.eps, which holds for the defaults here).
Y = transformer(
    X.numpy(),
    M.numpy(),
    heads,
    W_KQV=trans.self_attn.in_proj_weight.detach().numpy().T,
    W_out=trans.self_attn.out_proj.weight.detach().numpy().T,
    W_ff1=trans.linear1.weight.detach().numpy().T,
    W_ff2=trans.linear2.weight.detach().numpy().T,
    eps=trans.norm1.eps,
)

In [54]:
print(np.linalg.norm(Y - Y_.detach().numpy()))
# Two LayerNorms plus the feed-forward path accumulate more float32 round-off
# than attention alone, hence the looser tolerance than the earlier checks.
np.testing.assert_allclose(Y, Y_.detach().numpy(), atol=1e-4)

2.807236e-05
