### Attention

RNNs favour more recent inputs: the contribution of an earlier input to the current output decays exponentially, because it is passed repeatedly through the same set of recurrent weights. Attention solves this problem by directly considering every symbol in a sequence.

Attention is modeled as a lookup in a database of key-value pairs. Each query is matched against every key and a similarity score is computed. The similarity scores are passed through a softmax to produce a weight for each key, and the output for each query is the weighted sum of all the values. The keys, queries and values each have their own learned projection into a latent space.

Attention is also parallelisable, unlike RNNs, where the output of each step is needed as input to the next.

In [59]:
import numpy as np
from scipy.special import softmax

class Attention:
    """Scaled dot-product attention with learned query/key/value projections.

    Queries, keys and values are projected into a latent space of size
    ``d_L``. Every query is compared with every key, and the
    softmax-normalised similarity scores weight a sum over the projected
    values.

    Parameters
    ----------
    d_K : int
        Feature dimension of the query and key inputs.
    d_V : int
        Feature dimension of the value input.
    d_L : int
        Dimension of the latent (projected) space.
    mask : array_like, optional
        Array broadcastable to ``(n_queries, n_keys)``. Positions where the
        mask is zero/False are excluded from attention: they receive exactly
        zero weight. Non-zero/True positions are attended to. ``None``
        disables masking. A query whose mask row is entirely zero has no key
        to attend to and yields NaN weights.
    """

    def __init__(self, d_K, d_V, d_L, mask=None):
        self.d_K = d_K
        self.d_V = d_V
        self.d_L = d_L

        self.mask = mask

        # Glorot/Xavier uniform limits depend on each matrix's own fan-in and
        # fan-out, so w_V (d_V x d_L) needs a different limit than w_Q/w_K.
        qk_limit = np.sqrt(6 / (d_K + d_L))
        v_limit = np.sqrt(6 / (d_V + d_L))

        # Weight matrices projecting inputs into the latent space.
        self.w_Q = np.random.uniform(-qk_limit, qk_limit, size=(d_K, d_L))
        self.w_K = np.random.uniform(-qk_limit, qk_limit, size=(d_K, d_L))
        self.w_V = np.random.uniform(-v_limit, v_limit, size=(d_V, d_L))

    def _keep_mask(self):
        """Return the mask as a boolean array (True = attend), or None."""
        if self.mask is None:
            return None
        return np.asarray(self.mask) != 0

    def calc_attention(self, Q, K, V):
        """Compute the attention output for every query.

        Parameters
        ----------
        Q : ndarray, shape (n_queries, d_K)
            Query input.
        K : ndarray, shape (n_keys, d_K)
            Key input.
        V : ndarray, shape (n_keys, d_V)
            Value input. For self-attention pass the same array as Q, K and V.

        Returns
        -------
        ndarray, shape (n_queries, d_L)
            Row i is the attention-weighted sum of the latent values for
            query i.

        Inputs, latent projections and attention weights are cached on
        ``self`` for use by :meth:`calc_gradients`.
        """
        # Hold inputs for gradient calculations.
        self.Q = Q
        self.K = K
        self.V = V

        # Latent space representations.
        self.lQ = Q @ self.w_Q
        self.lK = K @ self.w_K
        self.lV = V @ self.w_V

        # Match each query to every key, scaled to keep softmax inputs
        # from growing with the latent dimension.
        similarity = (self.lQ @ self.lK.T) / np.sqrt(self.d_L)

        keep = self._keep_mask()
        if keep is not None:
            # Fill masked scores with -inf so softmax gives them exactly zero
            # weight. Multiplying by 0 would instead leave them with exp(0)=1.
            similarity = np.where(keep, similarity, -np.inf)

        # One softmax per query row.
        self.w_A = softmax(similarity, axis=1)

        # Weighted sum of latent values for each query.
        return self.w_A @ self.lV

    def calc_gradients(self, g_A):
        """Backpropagate a gradient through the last :meth:`calc_attention` call.

        Parameters
        ----------
        g_A : ndarray, shape (n_queries, d_L)
            Gradient of the loss with respect to the attention output.

        Returns
        -------
        tuple of ndarray
            ``(g_Q, g_K, g_V)``, the gradients with respect to the Q, K and V
            inputs. For self-attention the caller should sum them.

        Weight gradients are stored on ``self.g_w_Q``, ``self.g_w_K`` and
        ``self.g_w_V``. ``self.g_w_lV`` is kept as an alias of ``g_w_V`` for
        backward compatibility.
        """
        # output = w_A @ lV
        g_lV = self.w_A.T @ g_A
        g_wA = g_A @ self.lV.T

        # Values: gradient before the latent projection, and of the projection.
        g_V = g_lV @ self.w_V.T
        self.g_w_V = self.V.T @ g_lV
        self.g_w_lV = self.g_w_V

        # Row-wise softmax backward. For s = softmax(x) the Jacobian is
        # diag(s) - s s^T, so dL/dx = s * (dL/ds - sum(dL/ds * s)).
        # This avoids building the per-row Jacobian matrices explicitly.
        w_A = self.w_A
        g_sim = w_A * (g_wA - np.sum(g_wA * w_A, axis=1, keepdims=True))

        keep = self._keep_mask()
        if keep is not None:
            # Masked scores were replaced by a constant (-inf), so they carry
            # no gradient back to the similarity scores.
            g_sim = np.where(keep, g_sim, 0.0)

        # Gradient of the 1/sqrt(d_L) scaling.
        g_sim = g_sim / np.sqrt(self.d_L)

        # similarity = lQ @ lK.T
        g_lQ = g_sim @ self.lK
        g_lK = g_sim.T @ self.lQ

        # Gradients before the latent projections.
        g_Q = g_lQ @ self.w_Q.T
        g_K = g_lK @ self.w_K.T

        # Gradients of the latent projections.
        self.g_w_Q = self.Q.T @ g_lQ
        self.g_w_K = self.K.T @ g_lK

        return (g_Q, g_K, g_V)



[[ 2.40355304 -0.66996238  1.89421245]
 [ 7.64042136 -1.92736712  5.0570968 ]]


(array([[ 6.69904830e-02, -9.31449150e-01],
        [ 1.56768043e-26, -2.24831539e-25]]),
 array([[-7.30964179e-01, -5.32471819e-01],
        [-1.17104623e-25, -8.46764949e-26]]),
 array([[ 1.27416367,  3.89954714],
        [ 4.55091222, 11.70034358]]))