In [1]:
import numpy as np

In [2]:
def positional_encoding(seq_len, d_model):
    '''
    Build the sinusoidal positional-encoding matrix.

    Column 2i holds sin(pos / 10000^(2i / d_model)) and column 2i + 1 holds the
    matching cos term, so each (sin, cos) column pair shares one frequency.

    Args:
        seq_len: number of positions (rows) to encode.
        d_model: encoding width (columns). Odd widths are supported; the last
            column is then a sin term with no cos partner.

    Returns:
        np.ndarray of shape (seq_len, d_model).
    '''
    pe = np.zeros((seq_len, d_model))  # Positional encoding matrix
    position = np.arange(seq_len)[:, np.newaxis]  # (seq_len, 1) column for broadcasting
    # One frequency per sin/cos pair, 1 / 10000^(2i / d_model), computed in log space.
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    angles = position * div_term  # (seq_len, ceil(d_model / 2))
    pe[:, 0::2] = np.sin(angles)
    # For odd d_model there is one fewer odd column than frequencies; drop the extra one
    # so the assignment shapes match instead of raising a broadcast error.
    pe[:, 1::2] = np.cos(angles[:, : d_model // 2])
    return pe

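$$
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V
$$

where $d_k$ is the dimension of each query/key vector; dividing by $\sqrt{d_k}$ keeps the dot products from growing with the vector width.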

In [3]:
def scaled_dot_product_attention(query, key, value, mask=None):
    '''
    Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query: array of shape (..., seq_q, d_k).
        key: array of shape (..., seq_k, d_k).
        value: array of shape (..., seq_k, d_v).
        mask: optional array broadcastable to (..., seq_q, seq_k). Positions
            where mask == 1 get -1e9 added to their score before the softmax,
            which drives their weight to ~0; mask == 0 leaves the score as is.

    Returns:
        (attention_output, attention_weights) with shapes (..., seq_q, d_v)
        and (..., seq_q, seq_k). Each row of attention_weights sums to 1.
    '''
    d_k = query.shape[-1]
    # Swap only the last two axes so any number of leading (batch, head, ...)
    # dimensions is accepted, not just the 4-D layout.
    scores = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(d_k)
    if mask is not None:
        scores = scores + (mask * -1e9)
    # Numerically stable softmax: subtracting the row max does not change the
    # result mathematically but keeps np.exp from overflowing to inf (-> NaN).
    # initial=-inf lets an empty key axis pass through instead of raising.
    scores = scores - np.max(scores, axis=-1, keepdims=True, initial=-np.inf)
    exp_scores = np.exp(scores)
    attention_weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)
    attention_output = np.matmul(attention_weights, value)
    return attention_output, attention_weights

In [4]:
class MultiHeadAttention:
    '''
    Multi-head self-attention layer (forward pass only, NumPy).

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of width ``d_k = d_model // num_heads``, passed through
    ``scaled_dot_product_attention``, re-joined, and projected by ``W_o``.
    '''

    def __init__(self, d_model, num_heads, rng=None):
        '''
        Args:
            d_model: model width; must be divisible by num_heads.
            num_heads: number of attention heads.
            rng: optional np.random.Generator used to draw the projection
                weights for reproducible initialisation. When None, the global
                np.random state is used.
        '''
        assert d_model % num_heads == 0, (
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
        )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # Unit-normal weights are scaled by 1/sqrt(d_model) so each projection keeps
        # roughly unit variance; unscaled randn weights make attention scores grow
        # with d_model and saturate (or overflow) the softmax.
        normal = np.random.standard_normal if rng is None else rng.standard_normal
        scale = 1.0 / np.sqrt(d_model)
        self.W_q = normal((d_model, d_model)) * scale
        self.W_k = normal((d_model, d_model)) * scale
        self.W_v = normal((d_model, d_model)) * scale
        self.W_o = normal((d_model, d_model)) * scale

    def _split_heads(self, projected, batch_size, seq_len):
        '''Reshape (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k).'''
        return projected.reshape(batch_size, seq_len, self.num_heads, self.d_k).transpose(0, 2, 1, 3)

    def forward(self, x, mask=None):
        '''
        Apply multi-head self-attention to x.

        Args:
            x: input array of shape (batch, seq_len, d_model).
            mask: optional mask forwarded to scaled_dot_product_attention; it must
                broadcast to (batch, num_heads, seq_len, seq_len).

        Returns:
            (output, weights): output of shape (batch, seq_len, d_model) and the
            per-head attention weights of shape (batch, num_heads, seq_len, seq_len).

        Raises:
            ValueError: if the last dimension of x is not self.d_model.
        '''
        batch_size, seq_len, d_model = x.shape
        if d_model != self.d_model:
            raise ValueError(
                f"expected input width {self.d_model}, got {d_model}"
            )
        Q = self._split_heads(np.matmul(x, self.W_q), batch_size, seq_len)
        K = self._split_heads(np.matmul(x, self.W_k), batch_size, seq_len)
        V = self._split_heads(np.matmul(x, self.W_v), batch_size, seq_len)
        output, weights = scaled_dot_product_attention(Q, K, V, mask)
        # Merge heads back: (batch, heads, seq, d_k) -> (batch, seq, d_model).
        output = output.transpose(0, 2, 1, 3).reshape(batch_size, seq_len, d_model)
        output = np.matmul(output, self.W_o)
        return output, weights