# Implementation Transformer

## import

In [114]:
import numpy as np
import torch
import torch.nn as nn

$y = softmax({QK^T \over \sqrt d})v $

$y = (softmax({XW_Q W_K^TX \over \sqrt d})XW_V)W_{out} $

$K = XW_K$

$Q = XW_Q$  => [K,Q,V] = X[W_K,W_Q,W_V] = XW_kqv

$V = XW_V$

In [115]:
def softmax(z):
    z = np.exp(z - z.max(axis= -1, keepdims = True))
    return z/z.sum(axis = -1, keepdims = True)

def self_attention(X, mask, W_KQV, W_out):
    K,Q,V = np.split(X@W_KQV,3,axis = 1)
    attn = softmax(K@Q.T / np.sqrt(X.shape[1]) + mask)
    return attn@V@W_out, attn

In [116]:
T,d = 100,64
attn = nn.MultiheadAttention(d,1,bias=False,batch_first=True)
M = torch.triu(-float('inf')*torch.ones(T,T),1)
X = torch.randn(1,T,d)
Y_, A_ = attn(X, X, X, attn_mask = M)

In [117]:
print(M)

tensor([[0., -inf, -inf,  ..., -inf, -inf, -inf],
        [0., 0., -inf,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        ...,
        [0., 0., 0.,  ..., 0., -inf, -inf],
        [0., 0., 0.,  ..., 0., 0., -inf],
        [0., 0., 0.,  ..., 0., 0., 0.]])


In [118]:
attn.out_proj.weight.shape

torch.Size([64, 64])

In [119]:
Y, A = self_attention(X[0].numpy(), M.numpy(),
                      attn.in_proj_weight.detach().numpy().T,attn.out_proj.weight.detach().numpy().T)

In [120]:
np.linalg.norm(A - A_[0].detach().numpy())

1.2434688e-07

In [121]:
np.linalg.norm(Y - Y_[0].detach().numpy())

1.1917103e-06

## Minibatching

$K\in R^{T \times d} $

- $T \times B \times D\quad(for\;RNNs)$
- $B \times T \times D\quad(for\;Transformers)$

In [96]:
C = np.random.randn(5,4,10,3)
D = np.random.randn(3,6)
(C@D).shape

(5, 4, 10, 6)

In [97]:
(C.reshape(-1,3)@D).reshape(5,4,10,6) - C@D

array([[[[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         ...,
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         ...,
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         ...,
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         ...,
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]]],


       [[[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0.

In [98]:
def self_attention(X, mask, W_KQV, W_out):
    K,Q,V = np.split(X@W_KQV,3,axis = 1)
    attn = softmax(K@Q.swapaxes(-1,-2) / np.sqrt(X.shape[1]) + mask)
    return attn@V@W_out, attn

In [99]:
B, T, d = 50,100,64
X = torch.randn(B,T,d)
M = torch.triu(-float('inf')*torch.ones(T,T),1)
Y_, A_ = attn(X,X,X, attn_mask = M)

In [100]:
Y, A = self_attention(X[0].numpy(), M.numpy(),
                      attn.in_proj_weight.detach().numpy().T,attn.out_proj.weight.detach().numpy().T)

In [101]:
np.linalg.norm(A - A_.detach().numpy())

9.538563

## Multihead attention

In [102]:
def multihead_attention(X, mask, heads, W_KQV, W_out):
    B,T,d = X.shape
    K,Q,V = np.split(X@W_KQV, 3, axis = -1)
# B x T x d => B x heads x T x d / heads
    K,Q,V = [a.reshape(B,T,heads,d//heads).swapaxes(1,2) for a in (K,Q,V)]
    attn = softmax(K@Q.swapaxes(-1,-2) / np.sqrt(d//heads) + mask)

    return (attn@V).swapaxes(1,2).reshape(B,T,d) @ W_out, attn

In [103]:
heads = 4
attn = nn.MultiheadAttention(d,heads,bias=False,batch_first=True)
Y_,A_ = attn(X,X,X,attn_mask = M)

In [104]:
Y,A = multihead_attention(X.numpy(),M.numpy(),heads,
                          attn.in_proj_weight.detach().numpy().T,
                          attn.out_proj.weight.detach().numpy().T)

In [105]:
np.linalg.norm(Y-Y_.detach().numpy())

8.57245e-06

## Transformer

In [110]:
def layer_norm(z, eps):
    return (z- z.mean(axis=-1, keepdims = True)) / np.sqrt(z.var(axis = -1, keepdims = True)+eps)
def relu(z):
    return np.maximum(z,0)
def transformer(X, mask, heads, W_KQV, W_out, W_ff1, W_ff2, eps):
    z = layer_norm(X + multihead_attention(X, mask, heads, W_KQV, W_out)[0], eps)
    return layer_norm(z + relu(z@W_ff1)@W_ff2, eps)

In [111]:
trans = nn.TransformerEncoderLayer(d, heads, dim_feedforward= 128, dropout = 0.0, batch_first = True)
trans.linear1.bias.data.zero_()
trans.linear2.bias.data.zero_()
Y_ = trans(X,M)

In [112]:
Y = transformer(X.numpy(),M.numpy(),heads,
                trans.self_attn.in_proj_weight.detach().numpy().T,
                trans.self_attn.out_proj.weight.detach().numpy().T,
                trans.linear1.weight.detach().numpy().T,
                trans.linear2.weight.detach().numpy().T,
                trans.norm1.eps)

In [113]:
np.linalg.norm(Y-Y_.detach().numpy())

6.2397885e-05