In [257]:
import numpy as np
from scipy import special
# Show floats with 4 decimals in fixed-point notation (no scientific notation)
# so the printed attention matrices below are easy to read.
np.set_printoptions(precision=4, suppress=True)

In [282]:
# Seed NumPy's legacy global generator so every random draw below is reproducible.
np.random.seed(0)

# Toy dimensions shared by the demo cells in this notebook.
batch_size, seq_len, d_k = 2, 4, 3

# Draw q, then k, then v (in that order) as integer tensors with entries in [1, 9].
q, k, v = (
    np.random.randint(1, 10, size=(batch_size, seq_len, d_k)) for _ in range(3)
)
print('q', q, sep='\n')



q
[[[6 1 4]
  [4 8 4]
  [6 3 5]
  [8 7 9]]

 [[9 2 7]
  [8 8 9]
  [2 6 9]
  [5 4 1]]]


In [269]:
def softmax(x):
    """Return the softmax of ``x`` taken over its last axis.

    Mathematically this is ``exp(x) / exp(x).sum(axis=-1, keepdims=True)``,
    but the work is delegated to ``scipy.special.softmax``, which is the
    numerically stable way to evaluate it (a naive ``np.exp`` overflows for
    large inputs).

    Parameters
    ----------
    x : array_like
        Input scores; normalisation happens along the last axis.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``x`` whose entries along the last axis
        sum to 1.
    """
    return special.softmax(x, axis=-1)


# Sanity check for softmax: random integer scores in [1, 9] stand in for the
# scaled attention logits; each row of the result should sum to 1.
scaled = np.random.randint(low=1, high=10, size=(batch_size, seq_len, seq_len))
print(scaled)
attention = softmax(scaled)
print(attention)
# Row sums for the first batch element only.
print(attention.sum(axis=-1, keepdims=True)[0])

[[[7 8 4 9]
  [7 2 6 6]
  [9 6 3 1]
  [6 8 1 1]]

 [[1 7 7 2]
  [2 5 6 4]
  [4 8 6 1]
  [8 6 5 1]]]
[[[0.0896 0.2436 0.0045 0.6623]
  [0.5739 0.0039 0.2111 0.2111]
  [0.95   0.0473 0.0024 0.0003]
  [0.119  0.8794 0.0008 0.0008]]

 [[0.0012 0.4977 0.4977 0.0034]
  [0.012  0.2418 0.6572 0.0889]
  [0.0159 0.8661 0.1172 0.0008]
  [0.8431 0.1141 0.042  0.0008]]]
[[1.]
 [1.]
 [1.]
 [1.]]


In [270]:
def scaled_dot_product(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    Parameters
    ----------
    q, k, v : numpy.ndarray
        Query, key and value arrays. The last two axes are treated as
        (sequence, feature); any leading axes are handled by matmul
        broadcasting.
    mask : array_like or None
        Optional additive mask applied to the scaled scores before the
        softmax (e.g. large negative numbers on future positions for a
        decoder). ``None`` means no masking, as in an encoder.

    Returns
    -------
    (values, attention) : tuple of numpy.ndarray
        ``attention`` is the softmax of the (masked) scaled scores and
        ``values`` is ``attention @ v``.
    """
    feature_width = q.shape[-1]
    # k.T would reverse every axis; only the last two must be swapped.
    keys_t = np.swapaxes(k, -1, -2)
    # Dividing by sqrt(d_k) is intended to keep the score variance close to
    # that of q and k, which keeps gradient steps stable.
    scaled = (q @ keys_t) / np.sqrt(feature_width)
    if mask is not None:
        # Element-wise, in-place addition (mask must broadcast onto scaled).
        scaled += mask
    attention = softmax(scaled)
    values = attention @ v
    return values, attention

def get_forward_mask(seq_len, eps=-1e9):
    """Build an additive causal ("look-ahead") mask.

    Parameters
    ----------
    seq_len : int
        Sequence length; the mask is ``seq_len x seq_len``.
    eps : float
        Value placed on the strictly upper triangle (future positions).
        It acts as "negative infinity": after softmax those positions get
        (almost) zero weight. ``-np.inf`` and fractional values are supported.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(seq_len, seq_len)`` with ``eps`` above the
        main diagonal and ``0.0`` on and below it.
    """
    # Select with a boolean triangle instead of multiplying ones by eps:
    # 0 * -inf would produce NaN, and casting to int would truncate
    # fractional eps values (and cannot represent infinities).
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    return np.where(future, float(eps), 0.0)

# Causal mask for the toy sequence length, then masked attention on q, k, v.
mask = get_forward_mask(seq_len=seq_len)
values, attention = scaled_dot_product(q, k, v, mask=mask)
# One (seq_len x seq_len) attention matrix per batch element.
attention.shape

(2, 4, 4)

In [276]:
def multihead_attention(x, mask, W_qkv, W_out, num_heads):
    """Multi-head self-attention over ``x``.

    Parameters
    ----------
    x : numpy.ndarray
        Input of shape (batch, seq, d_model).
    mask : array_like or None
        Additive mask forwarded unchanged to ``scaled_dot_product``.
    W_qkv : numpy.ndarray
        Joint projection producing queries, keys and values; the reshape
        below expects ``x @ W_qkv`` to have ``num_heads * 3 * head_dim``
        features, where ``head_dim = d_model // num_heads``.
    W_out : numpy.ndarray
        Output projection applied to the concatenated heads.
    num_heads : int
        Number of attention heads.

    Returns
    -------
    numpy.ndarray
        ``(concatenated heads) @ W_out`` with leading shape (batch, seq).
    """
    batch_size, sequence_length, d_model = x.shape
    head_dim = d_model // num_heads
    qkv = x @ W_qkv                                   # batch x seq x (heads * 3 * head_dim)
    qkv = qkv.reshape(batch_size, sequence_length, num_heads, 3 * head_dim)
    qkv = qkv.transpose(0, 2, 1, 3)                   # batch x heads x seq x (3 * head_dim)
    q, k, v = np.array_split(qkv, 3, axis=-1)         # each: batch x heads x seq x head_dim
    values, attention = scaled_dot_product(q, k, v, mask)
    # Bring the head axis back next to head_dim BEFORE flattening. Reshaping
    # the (batch, heads, seq, head_dim) array directly would interleave
    # sequence positions and heads instead of concatenating the heads of
    # each position.
    values = values.transpose(0, 2, 1, 3)             # batch x seq x heads x head_dim
    values = values.reshape(batch_size, sequence_length, num_heads * head_dim)
    out = values @ W_out
    return out


In [283]:

# Toy masked self-attention run: random integer input, random Gaussian
# projection weights, one head per feature (num_heads == d_k == 3).
x = np.random.randint(1, 10, size=(batch_size, seq_len, d_k))
mask = get_forward_mask(seq_len, eps=-1_000_000_000)
W_qkv = np.random.randn(d_k, 3 * d_k)
W_out = np.random.randn(d_k, d_k)
out = multihead_attention(x, mask, W_qkv, W_out, num_heads=3)
out.shape


(2, 4, 3)

In [286]:


def multihead_crossattention(x, y, mask, W_kv, W_q, W_out, num_heads):
    """Multi-head cross-attention: queries come from ``y``, keys/values from ``x``.

    Parameters
    ----------
    x : numpy.ndarray
        Source of keys and values, shape (batch, kv_len, d_model).
    y : numpy.ndarray
        Source of queries, shape (batch, q_len, ...); ``q_len`` may differ
        from ``kv_len``.
    mask : array_like or None
        Accepted for signature parity with ``multihead_attention`` but not
        used: no mask is passed to ``scaled_dot_product`` here.
    W_kv : numpy.ndarray
        Joint key/value projection; ``x @ W_kv`` must have
        ``num_heads * 2 * head_dim`` features.
    W_q : numpy.ndarray
        Query projection; ``y @ W_q`` must have ``num_heads * head_dim``
        features.
    W_out : numpy.ndarray
        Output projection applied to the concatenated heads.
    num_heads : int
        Number of attention heads.

    Returns
    -------
    numpy.ndarray
        ``(concatenated heads) @ W_out`` with leading shape (batch, q_len).
    """
    batch_size, kv_length, d_model = x.shape
    # The query side has its own sequence length; using x's length here
    # would break as soon as the two sequences differ in length.
    q_length = y.shape[1]
    head_dim = d_model // num_heads
    kv = x @ W_kv                                     # batch x kv_len x (heads * 2 * head_dim)
    q = y @ W_q                                       # batch x q_len x (heads * head_dim)
    kv = kv.reshape(batch_size, kv_length, num_heads, 2 * head_dim)
    q = q.reshape(batch_size, q_length, num_heads, head_dim)
    kv = kv.transpose(0, 2, 1, 3)                     # batch x heads x kv_len x (2 * head_dim)
    q = q.transpose(0, 2, 1, 3)                       # batch x heads x q_len x head_dim
    k, v = np.array_split(kv, 2, axis=-1)             # each: batch x heads x kv_len x head_dim
    values, attention = scaled_dot_product(q, k, v)   # values: batch x heads x q_len x head_dim
    # Move the head axis back next to head_dim before flattening; a direct
    # reshape would mix sequence positions with heads instead of
    # concatenating the heads of each position.
    values = values.transpose(0, 2, 1, 3)             # batch x q_len x heads x head_dim
    values = values.reshape(batch_size, q_length, num_heads * head_dim)
    out = values @ W_out
    return out


# Toy cross-attention run: x supplies keys/values, y supplies queries.
x = np.random.randint(1, 10, size=(batch_size, seq_len, d_k))
y = np.random.randint(1, 10, size=(batch_size, seq_len, d_k))
mask = get_forward_mask(seq_len, eps=-1_000_000_000)
W_kv = np.random.randn(d_k, 2 * d_k)
W_q = np.random.randn(d_k, d_k)
W_out = np.random.randn(d_k, d_k)
out = multihead_crossattention(x, y, mask, W_kv, W_q, W_out, num_heads=3)
out.shape

(2, 4, 3)

In [290]:
def encoder_block(x, num_heads=3, d_ff=2048):
    """Run one toy encoder block with freshly sampled random weights.

    The block applies unmasked multi-head self-attention followed by a
    two-layer feed-forward network with a ReLU in between. All weights are
    drawn from ``np.random.randn`` on every call, so results depend on the
    global NumPy random state.

    Parameters
    ----------
    x : numpy.ndarray
        Input of shape (batch, seq, d_model).
    num_heads : int
        Number of attention heads (default 3, the value previously
        hard-coded).
    d_ff : int
        Hidden width of the feed-forward network (default 2048, the value
        previously hard-coded).

    Returns
    -------
    numpy.ndarray
        Output of the feed-forward network, with ``d_model`` features.
    """
    # Take the model width from the input itself rather than from a
    # notebook-level global, so the function does not rely on hidden state.
    d_model = x.shape[-1]
    # Draw order (W_qkv, W_out, W_ff1, W_ff2) is kept so seeded runs reproduce.
    W_qkv = np.random.randn(d_model, 3 * d_model)
    W_out = np.random.randn(d_model, d_model)
    x = multihead_attention(x, mask=None, W_qkv=W_qkv, W_out=W_out, num_heads=num_heads)
    W_ff1 = np.random.randn(d_model, d_ff)
    W_ff2 = np.random.randn(d_ff, d_model)
    # Feed-forward: linear -> ReLU -> linear.
    x = np.maximum(0, x @ W_ff1) @ W_ff2
    return x

# Feed a random integer input through N stacked encoder blocks; each block
# consumes the previous block's output.
N = 5
x = np.random.randint(low=1, high=10, size=(batch_size, seq_len, d_k))
for layer_id in range(N):
    x = encoder_block(x)
x.shape

(2, 4, 3)