# Linear Transformer with SPE

## Import modules 

In [1]:
import numpy as np 
import math

from einops import rearrange

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers 

## Linear Multi Head Attention with SineSPE

In [2]:
class SineSPE(layers.Layer):
    """Sinusoidal stochastic positional encoding (SineSPE) code generator.

    Holds learnable frequencies, phase offsets and gains for `num_sines`
    sinusoidal components per (head, feature) pair and, when called, draws
    a random pair of codes (qbar, kbar) to be applied to queries and keys
    by an SPE filter.

    Args:
        num_heads: number of attention heads.
        in_features: dimension of keys and queries (per head).
        num_realizations: size of the last axis of the generated codes,
            i.e. the number of noise realizations drawn per call.
        num_sines: number of sinusoidal components.
    """
    def __init__(self,
                 num_heads: int = 8,
                 in_features: int = 64,
                 num_realizations: int = 256,
                 num_sines: int = 1):
        super(SineSPE, self).__init__()

        self.num_heads = num_heads
        self.in_features = in_features
        self.num_sines = num_sines
        self.num_realizations = num_realizations

        param_shape = (num_heads, in_features, num_sines)

        # NOTE: the bias / normalization below are applied to the *initial
        # values* before the variables are created. Doing arithmetic on an
        # existing tf.Variable and re-assigning the attribute would replace
        # the variable by a plain (non-trainable) tensor.

        # Bias initial freqs: after the sigmoid(.)/2 applied in `call`, a shift
        # of -4 makes the frequencies start close to 0 (about 0.009).
        freqs_init = tf.random_normal_initializer()
        self.freqs = tf.Variable(
            initial_value=freqs_init(shape=param_shape, dtype="float32") - 4,
            trainable=True,
        )

        offsets_init = tf.random_normal_initializer()
        self.offsets = tf.Variable(
            initial_value=offsets_init(shape=param_shape, dtype="float32"),
            trainable=True,
        )

        # Normalize gains over the sines axis
        gains_init = tf.random_normal_initializer()
        gains_value = gains_init(shape=param_shape, dtype="float32")
        gains_value = gains_value / (
            tf.math.sqrt(tf.norm(gains_value, axis=-1, keepdims=True)) / 2
        )
        self.gains = tf.Variable(
            initial_value=gains_value,
            trainable=True,
        )

        # Inner shape of the generated codes, consumed by the SPE filter
        self.code_shape = (num_heads, in_features)

    def call(self, shape):
        """
        Generate the code, composed of a random QBar and Kbar,
        depending on the parameters, and return them for use with a
        SPE module to actually encode queries and keys.
        Args:
            shape: The outer shape of the inputs: (batchsize, length)
        Returns:
            A tuple (qbar, kbar) of tensors, each of shape
            (1, length, num_heads, in_features, num_realizations).
        Raises:
            ValueError: if `shape` does not have exactly two entries.
        """

        if len(shape) != 2:
            raise ValueError('Only 1D inputs are supported by SineSPE')

        max_len = shape[1]

        # build omega_q and omega_k
        # with shape (1,num_heads,keys_dim,length,2*num_sines)
        # positions 0, 1, ..., max_len-1 as float32
        indices = tf.range(max_len, dtype=tf.float32)

        # make sure freqs are in [0,.5]
        freqs = tf.nn.sigmoid(self.freqs[:, :, None, :]) / 2

        # (num_heads, keys_dim, length, num_sines)
        phases_k = 2 * math.pi * freqs * indices[None, None, :, None]
        # queries use the same phases, shifted by the learned offsets
        phases_q = phases_k + self.offsets[:, :, None, :]

        omega_shape = [1, self.num_heads, self.in_features, max_len, 2 * self.num_sines]

        # cos/sin are interleaved along the last axis by the stack + reshape
        omega_q = tf.stack([tf.math.cos(phases_q), tf.math.sin(phases_q)], axis=-1)
        omega_q = tf.reshape(omega_q, omega_shape)

        omega_k = tf.stack([tf.math.cos(phases_k), tf.math.sin(phases_k)], axis=-1)
        omega_k = tf.reshape(omega_k, omega_shape)

        # Gains is (num_heads,keys_dim,num_sines), make nonnegative with softplus
        gains = tf.math.softplus(self.gains)

        # Upsample: repeat each gain twice so it matches the interleaved
        # cos/sin layout of omega_q / omega_k
        gains = tf.stack([gains, gains], axis=-1)
        gains = tf.reshape(gains, [self.num_heads, self.in_features, 2 * self.num_sines])

        # Draw noise
        z = tf.random.normal(
            (1, self.num_heads, self.in_features, 2 * self.num_sines, self.num_realizations)
        )
        z = z / tf.math.sqrt(tf.cast(self.num_sines * 2, dtype=tf.float32))

        # Scale each of the 2*num_sines by the appropriate gain
        z = z * gains[None, ..., None]

        # Compute sums over sines:
        # (1,heads,keys_dim,length,2*sines) @ (1,heads,keys_dim,2*sines,realizations)
        qbar = tf.linalg.matmul(omega_q, z)
        kbar = tf.linalg.matmul(omega_k, z)

        # Permute to (1,length,num_heads,key_dim,num_realization)
        qbar = tf.transpose(qbar, perm=[0, 3, 1, 2, 4])
        kbar = tf.transpose(kbar, perm=[0, 3, 1, 2, 4])

        # scale: each code is divided by (R*F)**.25, so a product of a
        # qbar entry with a kbar entry is divided by sqrt(R*F)
        scale = (self.num_realizations * self.in_features) ** .25
        return (qbar / scale, kbar / scale)

In [3]:
class SPEFilter(layers.Layer):
    """Stochastic positional encoding filter
    Applies a positional code provided by a SPE module on actual queries and keys.
    Implements gating, i.e. some "dry" parameter, that lets original queries and keys through if activated.
    Args:
    gated: whether to use the gated version, which learns to balance
        positional and positionless features.
    code_shape: the inner shape of the codes, i.e. (num_heads, key_dim),
        as given by `spe.code_shape`
    Raises:
    RuntimeError: if `gated` is True and `code_shape` is None.
    """
    def __init__(self,gated,code_shape):
        super(SPEFilter, self).__init__()

        self.gated = gated
        self.code_shape = code_shape

        # create the gating parameters if required
        if gated:
            if code_shape is None:
                raise RuntimeError('code_shape has to be provided if gated is True.')

            # one gate logit per (head, feature); squashed to [0, 1] in `call`
            gate_init = tf.random_normal_initializer()
            self.gate = tf.Variable(
                initial_value=gate_init(shape=(code_shape), dtype="float32"),
                trainable=True,
            )

    def call(self,queries,keys,code):
        """
        Apply SPE on keys with a given code.
        Expects keys and queries of shape `(batch_size, ..., num_heads,
        key_dim)` and outputs keys and queries of shape `(batch_size,
        ..., num_heads, num_realizations)`. code is the tuple
        of the 2 tensors provided by the code instance, each one of
        shape (1, ..., num_heads, key_dim, num_realizations)

        Raises:
            ValueError: if the codes do not match `code_shape`, are shorter
                than the queries/keys along a positional dimension, or do
                not match the (num_heads, key_dim) of the queries.
        """
        assert (queries.shape == keys.shape), \
            "As of current implementation, queries and keys must have the same shape. "\
            "got queries: {} and keys: {}".format(queries.shape, keys.shape)

        # qbar and kbar are (1, *shape, num_heads, keys_dim, num_realizations)
        (qbar, kbar) = code

        # check that codes have the shape we are expecting
        if self.code_shape is not None and qbar.shape[-3:-1] != self.code_shape:
            raise ValueError(
                f'The inner shape of codes is {qbar.shape[-3:-1]}, '
                f'but expected {self.code_shape}')

        # check shapes: size of codes should be bigger than queries, keys
        code_size = qbar.shape[1:-3]
        query_size = queries.shape[1:-2]
        # Static shapes are compared in plain Python: no tf.Variable is
        # created per call (which is wasteful and not allowed once a
        # tf.function has been traced).
        code_dims = list(code_size)
        query_dims = list(query_size)
        if (len(code_dims) != len(query_dims)
                or any(c < q for c, q in zip(code_dims, query_dims))):
            raise ValueError(f'Keys/queries have length {query_size}, '
                             f'but expected at most {code_size}')
        if qbar.shape[-3:-1] != queries.shape[-2:]:
            raise ValueError(f'shape mismatch. codes have shape {qbar.shape}, '
                             f'but queries are {queries.shape}')

        # truncate qbar and kbar for matching current queries and keys,
        # but only if we need to. Every positional dimension of the codes is
        # at least as long as the queries' (checked above), so one slice that
        # keeps the first query_dims[d] entries of each dimension is enough.
        if any(c > q for c, q in zip(code_dims, query_dims)):
            indices = (slice(1), *(slice(q) for q in query_dims))
            qbar = qbar[indices]
            kbar = kbar[indices]

        # apply gate if required
        if self.gated:
            # incorporate the constant bias for Pd if required. First draw noise,
            # shared by queries and keys, for each head, feature, realization.
            # qbar is : (1, *shape, num_heads, keys_dim, num_realizations)
            in_features = qbar.shape[-2]
            num_realizations = qbar.shape[-1]
            gating_noise = tf.random.normal(
                tuple(self.code_shape) + (num_realizations,)
            ) / (in_features * num_realizations) ** .25

            # NOTE: the noise is only rescaled by (in_features*num_realizations)**.25;
            # it is not explicitly normalized.

            # constrain the gate parameter to be in [0 1]
            gate = tf.math.sigmoid(self.gate[..., None])

            # qbar is (1, *shape, num_heads, keys_dim, num_realizations)
            # gating noise is (num_heads, keys_dim, num_realizations)
            # gate is (num_heads, keys_dim, 1)
            qbar = tf.math.sqrt(1.-gate) * qbar  + tf.math.sqrt(gate) * gating_noise
            kbar = tf.math.sqrt(1.-gate) * kbar  + tf.math.sqrt(gate) * gating_noise

        # sum over d after multiplying by queries and keys
        # qbar/kbar are (1, *shape, num_heads, keys_dim, num_realizations)
        # queries/keys  (batchsize, *shape, num_heads, keys_dim)
        qhat = tf.math.reduce_sum(qbar * queries[..., None],axis=-2)
        khat = tf.math.reduce_sum(kbar * keys[..., None],axis=-2)

        # result is (batchsize, ..., num_heads, num_realizations)
        return (qhat, khat)

In [4]:
def compute_linear_mhsa(q, k, v, eps=1e-4):
    """Compute (non-causal) linear multi-head attention.

    Queries and keys are passed through the feature map elu(x) + 1, then
    the attention output is computed as q @ (k^T v), normalised by
    q @ sum_s(k), without ever forming the (length x length) attention matrix.

    Args:
        q: queries, shape (..., heads, length_q, d).
        k: keys, shape (..., heads, length_k, d).
        v: values, shape (..., heads, length_k, m).
        eps: small constant added to the normaliser to avoid division by zero.

    Returns:
        Tensor of shape (..., heads, length_q, m).
    """
    # positive feature map
    q = tf.nn.elu(q) + 1
    k = tf.nn.elu(k) + 1

    # sum over key positions of k_s v_s^T -> (..., heads, m, d)
    kv = tf.einsum('...hsd,...hsm->...hmd', k, v)

    # sum of keys over the sequence axis -> (..., heads, d).
    # axis=-2 (rather than 2) stays correct for any number of leading
    # batch dimensions, consistently with the ellipsis in the einsums.
    k_sum = tf.math.reduce_sum(k, axis=-2)

    # inverse normaliser, one scalar per query position -> (..., heads, length_q)
    z = 1 / (tf.einsum('...hld,...hd->...hl', q, k_sum) + eps)

    Vhat = tf.einsum('...hld,...hmd,...hl->...hlm', q, kv, z)
    return Vhat

class LinearAttentionSineSPE(tf.keras.layers.Layer):
    """Multi-head linear attention whose queries and keys are encoded with SineSPE.

    Args:
        d_model: model (embedding) dimension; must be divisible by `heads`.
        heads: number of attention heads.
        num_sines: number of sinusoidal components used by the SineSPE encoder.
    """

    def __init__(self, d_model, heads=8, num_sines=5):
        super().__init__()
        self.num_heads = heads
        self.d_model = d_model

        assert d_model % self.num_heads == 0

        # per-head feature dimension
        self.depth = d_model // self.num_heads

        # input projections for queries, keys and values
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        # output projection
        self.dense = tf.keras.layers.Dense(d_model)

        # The positional codes keep the per-head dimension unchanged:
        # num_realizations == in_features == depth.
        self.spe_encoder = SineSPE(
            num_heads=heads,
            in_features=self.depth,
            num_realizations=self.depth,
            num_sines=num_sines,
        )
        self.spe_filter = SPEFilter(
            gated=True,
            code_shape=self.spe_encoder.code_shape,
        )

    def split_heads(self, x, batch_size):
        """Reshape (batch_size, seq_len, d_model) into per-head form.

        The last dimension is split into (num_heads, depth) and the result is
        returned with shape (batch_size, num_heads, seq_len, depth).
        """
        per_head = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(per_head, perm=[0, 2, 1, 3])

    def call(self, v, k, q):
        """Apply SPE-encoded linear attention.

        Args:
            v, k, q: value, key and query tensors of shape
                (batch_size, seq_len, d_model). Note the (v, k, q) order.

        Returns:
            Tensor of shape (batch_size, seq_len_q, d_model).
        """
        batch_size = tf.shape(q)[0]

        # linear projections, each (batch_size, seq_len, d_model)
        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)

        # The SPE filter expects (batch_size, seq_len, num_heads, depth), so
        # queries and keys are reshaped straight into that layout.
        seq_major = (batch_size, -1, self.num_heads, self.depth)
        q = tf.reshape(q, seq_major)
        k = tf.reshape(k, seq_major)
        # values go directly to (batch_size, num_heads, seq_len_v, depth)
        v = self.split_heads(v, batch_size)

        # pos_codes is a tuple (qbar, kbar) built for the (batch, seq_len) shape
        pos_codes = self.spe_encoder(q.shape[:2])
        q, k = self.spe_filter(q, k, pos_codes)

        # back to head-major layout: (batch_size, num_heads, seq_len, depth)
        head_major = [0, 2, 1, 3]
        q = tf.transpose(q, perm=head_major)
        k = tf.transpose(k, perm=head_major)

        # (batch_size, num_heads, seq_len_q, depth)
        attended = compute_linear_mhsa(q, k, v)

        # (batch_size, seq_len_q, num_heads, depth)
        attended = tf.transpose(attended, perm=[0, 2, 1, 3])

        # merge heads: (batch_size, seq_len_q, d_model)
        merged = tf.reshape(attended, (batch_size, -1, self.d_model))

        # (batch_size, seq_len_q, d_model)
        return self.dense(merged)

In [5]:
# Smoke test: run the attention layer on random data and check the output shape.
BATCH_SIZE = 32
seq_len = 100
n_heads = 8
embedding_dim = 64

# random inputs of shape (batch, sequence, embedding), uniform in [0, 1)
x = np.random.uniform(low=0, high=1, size=(BATCH_SIZE, seq_len, embedding_dim))
lha = LinearAttentionSineSPE(d_model=embedding_dim, heads=n_heads)

# self-attention: the same tensor is used as values, keys and queries
y = lha(x, x, x)
print(y.shape)

(32, 100, 64)


## Transformer Block with Linear Multi-Head Attention and SineSPE


In [6]:
class LinearSineSPETransformerBlock(layers.Layer):
    """Transformer encoder block built on SineSPE linear attention.

    The block applies self-attention and a two-layer feed-forward network,
    each followed by dropout, a residual connection and layer normalization.

    Args:
        embed_dim: embedding dimension of the inputs and outputs.
        num_heads: number of attention heads.
        ff_dim: hidden width of the feed-forward network.
        rate: dropout rate used after attention and after the feed-forward net.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super().__init__()

        # attention sub-layer
        self.lha = LinearAttentionSineSPE(embed_dim, num_heads)

        # position-wise feed-forward sub-layer: embed_dim -> ff_dim -> embed_dim
        hidden = layers.Dense(ff_dim, activation="relu")
        projection = layers.Dense(embed_dim)
        self.ffn = keras.Sequential([hidden, projection])

        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs):
        """Run the block on `inputs` and return a tensor of the same shape."""
        # self-attention branch + residual
        attended = self.dropout1(self.lha(inputs, inputs, inputs))
        normed = self.layernorm1(inputs + attended)

        # feed-forward branch + residual
        transformed = self.dropout2(self.ffn(normed))
        return self.layernorm2(normed + transformed)

In [7]:
# Smoke test: run the transformer block on random data and check the output shape.
BATCH_SIZE = 32
seq_len = 100
n_heads = 8
embedding_dim = 64
ff_dim = 40

# random inputs of shape (batch, sequence, embedding), uniform in [0, 1)
x = np.random.uniform(low=0, high=1, size=(BATCH_SIZE, seq_len, embedding_dim))
ltb = LinearSineSPETransformerBlock(embed_dim=embedding_dim, num_heads=n_heads, ff_dim=ff_dim)

y = ltb(x)
print(y.shape)

(32, 100, 64)
