### Attention

RNNs tend to favour more recent inputs. An earlier input reaches the current output only after being passed repeatedly through the same recurrent weights, so the share of the output it is responsible for typically decays exponentially with its distance from the current step. Attention addresses this by directly considering every symbol in the sequence, however far back it is.
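A minimal sketch of that decay, assuming a linear RNN with no inputs and no nonlinearity (`h_t = W h_{t-1}`). Real RNNs add inputs and activations, but the repeated multiplication by the same `W` is the part that matters here. The hypothetical weight matrix is rescaled so its largest singular value is 0.9, which bounds the norm of the earliest state's contribution after `t` steps by `0.9**t` times its original norm.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical recurrent weights, rescaled so the largest singular value is 0.9
W = rng.standard_normal((4, 4))
W = 0.9 * W / np.linalg.norm(W, 2)

h0 = rng.standard_normal(4)  # contribution of the earliest input
h = h0.copy()
norms = []
for t in range(50):
    h = W @ h  # the same weights are applied at every step
    norms.append(np.linalg.norm(h))

# ||W^t h0|| <= 0.9**t * ||h0||, so the early input fades geometrically
bound = [0.9 ** (t + 1) * np.linalg.norm(h0) for t in range(50)]
print(norms[0], norms[-1])
```

After 50 steps the earliest state retains well under 1% of its original magnitude.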

Attention can be viewed as a soft lookup into a database of key-value pairs. Each query is compared with every key to compute a similarity score. For each query, the scores are passed through a softmax to produce one weight per key, and the output is the weighted sum of all the values. Queries, keys and values are each projected into their own latent space by separate learned weight matrices.
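A small sketch of the soft lookup, using raw vectors with no learned projections and no scaling (both omitted here for clarity). The three keys and values are made-up data. The query is closest to the second key, so nearly all of the softmax weight lands on the second value: a hard dictionary lookup would return that value exactly, while attention returns a weighted blend dominated by it.

```python
import numpy as np
from scipy.special import softmax

# A tiny "database": three keys, each with one associated value
keys = np.array([[1.0, 0.0, 0.0],
                 [0.0, 1.0, 0.0],
                 [0.0, 0.0, 1.0]])
values = np.array([[10.0], [20.0], [30.0]])

query = np.array([0.1, 5.0, 0.2])  # most similar to the second key

scores = keys @ query        # similarity of the query to every key
weights = softmax(scores)    # one weight per key, summing to 1
output = weights @ values    # weighted sum of all values

print(weights.round(3), output)
```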

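Attention is also parallelisable: every query-key score can be computed independently, so a whole sequence is processed with a few matrix products. In an RNN, by contrast, the output of each step is an input to the next, so the steps must run one after another.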
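A minimal sketch of the contrast on hypothetical random data. The RNN loop cannot compute step `t` until `h_{t-1}` exists. Each attention output (here plain self-attention without projections, an assumption for brevity) depends only on the inputs, so computing the queries one at a time gives the same result as a single batched matrix product, which can be evaluated in parallel.

```python
import numpy as np
from scipy.special import softmax

rng = np.random.default_rng(1)
n, d = 6, 4
X = rng.standard_normal((n, d))           # a sequence of n symbols (made-up data)
W_h = rng.standard_normal((d, d)) * 0.1   # hypothetical recurrent weights

# RNN: inherently sequential, each step needs the previous hidden state
h = np.zeros(d)
rnn_states = []
for x in X:
    h = np.tanh(W_h @ h + x)
    rnn_states.append(h)

# Self-attention: every score depends only on X
scores = X @ X.T / np.sqrt(d)

# one query at a time, in any order...
looped = np.stack([softmax(scores[i]) @ X for i in range(n)])
# ...or all queries at once in one batched computation
batched = softmax(scores, axis=1) @ X

print(np.allclose(looped, batched))
```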

```python
import numpy as np
from scipy.special import softmax

class Attention:
    # d_K: dimension of key and query input
    # d_V: dimension of value input
    # d_L: dimension of latent spaces
    def __init__(self, d_K, d_V, d_L):
        self.d_K = d_K
        self.d_V = d_V
        self.d_L = d_L

        # Glorot (Xavier) uniform limits: sqrt(6 / (fan_in + fan_out));
        # W_V has fan_in d_V, so it needs its own limit
        limit_QK = np.sqrt(6 / (d_K + d_L))
        limit_V = np.sqrt(6 / (d_V + d_L))

        # weight matrices projecting query, key and value inputs into latent space
        self.W_Q = np.random.uniform(-limit_QK, limit_QK, size=(d_L, d_K))
        self.W_K = np.random.uniform(-limit_QK, limit_QK, size=(d_L, d_K))
        self.W_V = np.random.uniform(-limit_V, limit_V, size=(d_L, d_V))

    # Q: query input, shape (n_queries, d_K)
    # K: key input, shape (n_keys, d_K)
    # V: value input, shape (n_keys, d_V)
    # for self-attention Q, K and V are the same
    def calc_attention(self, Q, K, V):
        # project inputs into their latent spaces, each of shape (n, d_L)
        self.l_Q = Q @ self.W_Q.T
        self.l_K = K @ self.W_K.T
        self.l_V = V @ self.W_V.T

        # similarity of each query to every key, shape (n_queries, n_keys)
        similarity = self.l_Q @ self.l_K.T

        # scale by sqrt(d_L) so dot products do not grow with d_L and saturate the softmax
        similarity = similarity / np.sqrt(self.d_L)

        # softmax over keys: each query's weights sum to 1
        weights = softmax(similarity, axis=1)

        # weighted sum of values for each query, shape (n_queries, d_L)
        return weights @ self.l_V
```