# Transformer Variants

In [1]:
import pandas as pd 
import numpy as np 

from einops import rearrange

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

import matplotlib.pyplot as plt 
import seaborn as sns 
sns.set()

## Multi Head Attention

In [2]:
def scaled_dot_product_attention(q, k, v, mask):
    """Compute scaled dot-product attention.

    The query/key similarities are divided by sqrt(depth), where depth is the
    size of the last axis of ``k``, and normalised with a softmax over the key
    axis before being used to average ``v``.

    Args:
        q: queries, shape (..., seq_len_q, depth).
        k: keys, shape (..., seq_len_k, depth).
        v: values, shape (..., seq_len_k, depth_v).
        mask: ``None`` for no masking, or a tensor broadcastable to
            (..., seq_len_q, seq_len_k). Non-zero / True entries mark
            positions that must NOT be attended to (the convention of the
            TensorFlow Transformer tutorial: 1 = masked out).
            NOTE(review): confirm callers use this convention and not the
            inverse "1 = keep" one.

    Returns:
        A tuple ``(output, attention_weights)`` with shapes
        (..., seq_len_q, depth_v) and (..., seq_len_q, seq_len_k).
    """
    scaled_dot_prod = tf.einsum('... i d , ... j d -> ... i j', q, k)
    logits_dtype = scaled_dot_prod.dtype

    # Scale in the dtype of the logits so non-float32 inputs do not hit a
    # dtype mismatch on the division.
    scale_factor = tf.math.sqrt(tf.cast(tf.shape(k)[-1], logits_dtype))
    scaled_dot_prod = scaled_dot_prod / scale_factor

    if mask is not None:
        # Replace masked logits by the most negative finite value so their
        # softmax weight collapses to zero. tf.where is used rather than
        # adding ``mask * -1e9`` because the additive form overflows to
        # -inf (and then NaN) in half precision.
        blocked = tf.cast(mask, tf.bool)
        lowest = tf.constant(logits_dtype.min, dtype=logits_dtype)
        scaled_dot_prod = tf.where(blocked, lowest, scaled_dot_prod)

    attention_weights = tf.nn.softmax(scaled_dot_prod, axis=-1)
    output = tf.einsum('... i j , ... j d -> ... i d', attention_weights, v)
    return output, attention_weights

In [3]:
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head attention built on ``scaled_dot_product_attention``.

    Queries, keys and values are linearly projected to ``d_model`` features,
    split into ``num_heads`` heads of ``d_model // num_heads`` features each,
    attended per head, re-assembled and passed through a final dense layer.
    """

    def __init__(self, d_model, num_heads, **kwargs):
        """Create the projection layers.

        Args:
            d_model: total feature size of the projections and of the output.
            num_heads: number of attention heads; must divide ``d_model``.
            **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

        Raises:
            AssertionError: if ``d_model`` is not divisible by ``num_heads``.
        """
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % self.num_heads == 0, (
            "d_model must be divisible by num_heads"
        )

        self.depth = d_model // self.num_heads

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).
        Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask=None):
        """Run multi-head attention.

        Args:
            v: values, (batch_size, seq_len_v, features).
            k: keys, (batch_size, seq_len_k, features).
            q: queries, (batch_size, seq_len_q, features).
            mask: optional mask handed unchanged to
                ``scaled_dot_product_attention``; ``None`` (the default)
                disables masking.

        Returns:
            ``(output, attention_weights)`` with shapes
            (batch_size, seq_len_q, d_model) and
            (batch_size, num_heads, seq_len_q, seq_len_k).
        """
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        # Shape traces kept on purpose: the notebook displays them to show
        # the effect of the projection and of the head split.
        print("q shape: {}".format(q.shape))
        print("k shape: {}".format(k.shape))
        print("v shape: {}".format(v.shape))

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

        print("q shape: {}".format(q.shape))
        print("k shape: {}".format(k.shape))
        print("v shape: {}".format(v.shape))

        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = scaled_dot_product_attention(
            q, k, v, mask
        )

        # Undo the head split: move heads next to depth, then merge them.
        scaled_attention = tf.transpose(
            scaled_attention, perm=[0, 2, 1, 3]
        )  # (batch_size, seq_len_q, num_heads, depth)

        concat_attention = tf.reshape(
            scaled_attention, (batch_size, -1, self.d_model)
        )  # (batch_size, seq_len_q, d_model)

        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

        return output, attention_weights


In [4]:
# 50 model features shared by 10 heads -> 5 features per head.
mha = MultiHeadAttention(d_model=50, num_heads=10)

In [5]:
# Self-attention smoke test: the same (batch, seq_len, d_model) tensor is
# used as values, keys and queries.
x = np.random.uniform(low=0, high=1, size=(1, 100, 50))
y, _ = mha(x, x, x, mask=None)
y.shape

q shape: (1, 100, 50)
k shape: (1, 100, 50)
v shape: (1, 100, 50)
q shape: (1, 10, 100, 5)
k shape: (1, 10, 100, 5)
v shape: (1, 10, 100, 5)


TensorShape([1, 100, 50])

## Linformer Attention

In [6]:
def compute_mhsa(q, k, v, scale_factor=1):
    """Softmax attention over the trailing (tokens, features) axes.

    Args:
        q: queries, shape (..., i, d).
        k: keys, shape (..., j, d).
        v: values, shape (..., j, d_v).
        scale_factor: multiplier applied to the raw query/key dot products
            before the softmax.

    Returns:
        Attention-weighted values, shape (..., i, d_v).
    """
    # Query/key similarity for every (i, j) pair, then rescaled.
    logits = tf.einsum('... i d , ... j d -> ... i j', q, k)
    logits = logits * scale_factor

    # Normalise over the key axis and mix the values per head.
    weights = tf.nn.softmax(logits, axis=-1)
    return tf.einsum('... i j , ... j d -> ... i d', weights, v)

In [7]:
def project_vk_linformer(v, k, E):
    """Project values and keys along their token axis with the matrix ``E``.

    Args:
        v: values, shape (b, h, j, d).
        k: keys, shape (b, h, j, d).
        E: projection matrix, shape (j, k_dim).

    Returns:
        ``(v, k)`` after projection, each of shape (b, h, k_dim, d).
    """
    # The same contraction over the token axis j is applied to both tensors.
    equation = 'b h j d , j k -> b h k d'
    projected_v = tf.einsum(equation, v, E)
    projected_k = tf.einsum(equation, k, E)
    return projected_v, projected_k

In [8]:
class LinformerAttention(tf.keras.layers.Layer):
    """Multi-head self-attention with Linformer key/value projection."""

    def __init__(self, dim, heads=8, dim_head=None, proj_shape=None, trainable_proj=True):
        """
        Based on the Linformer paper
        Link: https://arxiv.org/pdf/2006.04768.pdf
        Args:
            dim: token's dimension, i.e. word embedding vector size
            heads: the number of distinct representations to learn
            dim_head: the dim of each head; defaults to ``dim // heads``.
            proj_shape: required 2-tuple (tokens, k), where k is the projection
            dimension of the linformer. ``tokens`` must match the sequence
            length of the inputs given to ``call``.
            trainable_proj: whether the projection matrix E is trainable.
        Raises:
            ValueError: if ``proj_shape`` is missing or is not a 2-tuple.
        """
        super().__init__()

        # Fail early with a readable message instead of crashing inside the
        # initializer or on ``proj_shape[1]`` further down.
        if proj_shape is None or len(proj_shape) != 2:
            raise ValueError(
                "proj_shape must be a 2-tuple (tokens, k), got {!r}".format(proj_shape)
            )

        self.dim_head = (dim // heads) if dim_head is None else dim_head
        _dim = self.dim_head * heads
        self.heads = heads

        # Single weight producing queries, keys and values in one matmul.
        to_qvk_init = tf.random_normal_initializer()
        self.to_qvk = tf.Variable(
            initial_value=to_qvk_init(shape=(dim, _dim * 3), dtype="float32"),
            trainable=True,
        )

        # Output projection from the concatenated heads back to ``dim``.
        W_0_init = tf.random_normal_initializer()
        self.W_0 = tf.Variable(
            initial_value=W_0_init(shape=(_dim, dim), dtype="float32"),
            trainable=True,
        )

        self.scale_factor = self.dim_head ** -0.5

        # Projection applied to the token axis of keys and values.
        E_init = tf.random_normal_initializer()
        self.E = tf.Variable(
            initial_value=E_init(shape=proj_shape, dtype="float32"),
            trainable=trainable_proj,
        )
        self.k = proj_shape[1]

    def call(self, x):
        """Apply Linformer self-attention.

        Args:
            x: input of shape [batch, tokens, dim]; ``tokens`` has to equal
               ``proj_shape[0]`` for the projection with E to be valid.

        Returns:
            Tensor of shape [batch, tokens, dim].
        """
        qkv = x @ self.to_qvk  # [batch, tokens, dim_head*3*heads]

        # Unpack the fused axis as (dim_head, q/k/v index, head) and move the
        # q/k/v index to the front so it can be unpacked into three tensors.
        q, k, v = tuple(rearrange(qkv, 'b t (d k h ) -> k b h t d ', k=3, h=self.heads))

        v, k = project_vk_linformer(v, k, self.E)

        out = compute_mhsa(q, k, v, scale_factor=self.scale_factor)

        # re-compose: merge heads with dim_head
        out = rearrange(out, "b h i d -> b i (h d)")

        # Apply final linear transformation layer
        return out @ self.W_0

In [9]:
# dim is the token embedding size; proj_shape is (seq_len, k_dim).
lha = LinformerAttention(
    dim=50,
    heads=8,
    dim_head=None,
    proj_shape=(100, 20),
    trainable_proj=False,
)

In [10]:
# Input laid out as (batch, seq_len, embed_dim).
x = np.random.uniform(low=0, high=1, size=(1, 100, 50))
y = lha(x)
y.shape

TensorShape([1, 100, 50])

## Linformer Transformer Block

In [11]:
class LinformerTransformerBlock(layers.Layer):
    """Transformer encoder block whose attention is ``LinformerAttention``.

    Attention and a two-layer feed-forward network are each followed by
    dropout, a residual connection and layer normalisation.
    """

    def __init__(self, embed_dim, seq_len, project_dim, num_heads, ff_dim, rate=0.1, trainable_proj=False):
        """
        Args:
            embed_dim: feature size of the tokens entering and leaving the block.
            seq_len: number of tokens; first entry of the Linformer proj_shape.
            project_dim: Linformer projection size; second entry of proj_shape.
            num_heads: number of attention heads.
            ff_dim: width of the hidden feed-forward layer.
            rate: dropout rate used after attention and after the feed-forward net.
            trainable_proj: whether the Linformer projection matrix is trainable.
        """
        super().__init__()

        self.lha = LinformerAttention(
            embed_dim,
            heads=num_heads,
            dim_head=None,
            proj_shape=(seq_len, project_dim),
            trainable_proj=trainable_proj,
        )

        # Position-wise feed-forward network: expand to ff_dim, project back.
        self.ffn = keras.Sequential(
            [
                layers.Dense(ff_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs):
        """Return the block output for ``inputs``; the input shape is preserved."""
        # Attention sub-layer with residual connection.
        attended = self.dropout1(self.lha(inputs))
        residual = self.layernorm1(inputs + attended)

        # Feed-forward sub-layer with residual connection.
        transformed = self.dropout2(self.ffn(residual))
        return self.layernorm2(residual + transformed)

In [12]:
transformer = LinformerTransformerBlock(
    embed_dim=50,
    seq_len=100,
    project_dim=20,
    num_heads=8,
    ff_dim=50,
    trainable_proj=False,
)

In [13]:
# Input laid out as (batch, seq_len, embed_dim).
x = np.random.uniform(low=0, high=1, size=(1, 100, 50))
y = transformer(x)
y.shape

TensorShape([1, 100, 50])

## Linear Transformer using Kernel Trick

In [14]:
def compute_linear_att(q, k, v):
    """Linear attention using the ``elu(x) + 1`` feature map.

    The keys and values are first contracted over the sequence axis, so the
    (seq_len x seq_len) attention matrix is never formed.

    Args:
        q: queries, shape (..., l, d).
        k: keys, shape (..., s, d).
        v: values, shape (..., s, m).

    Returns:
        Attended values, shape (..., l, m).
    """
    # Positive feature map applied to queries and keys.
    q = tf.nn.elu(q) + 1
    k = tf.nn.elu(k) + 1

    # Sum over the sequence of key/value outer products: (..., m, d).
    kv = tf.einsum('... s d, ... s m  -> ... m d', k, v)

    # Sum the keys over the sequence axis. axis=-2 (rather than a fixed
    # positive index) keeps this consistent with the ellipsis in the einsum
    # equations for any number of leading axes, including none.
    k_sum = tf.math.reduce_sum(k, axis=-2)

    # Per-query normaliser; the small constant guards against division by 0.
    z = 1 / (tf.einsum('... l d, ... d -> ... l', q, k_sum) + 1e-4)

    return tf.einsum('... l d, ... m d, ... l -> ... l m', q, kv, z)

# Shape conventions for a single (unbatched) sequence:
# x ~ (seq_len,embed_dim) ~ (s,f)
# Wq ~ (f,d)
# Wk ~ (f,d)
# Wv ~ (f,m)
# q ~ (s,d)
# k ~ (s,d)
# v ~ (s,m)

s = 100 # sequence length
f = 10 # embedding dimension of x (not used below)
d = 30 # feature dimension shared by queries and keys
m = 10 # feature dimension of the values

BATCH_SIZE = 2

# All-ones inputs with a leading batch axis: only the output shape is checked.
q = np.ones((BATCH_SIZE,s,d))
k = np.ones((BATCH_SIZE,s,d))
v = np.ones((BATCH_SIZE,s,m))

Vhat = compute_linear_att(q,k,v)
print(Vhat.shape)

(2, 100, 10)


In [15]:
def compute_linear_mhsa(q, k, v):
    """Multi-head linear attention using the ``elu(x) + 1`` feature map.

    Args:
        q: queries, shape (..., h, l, d).
        k: keys, shape (..., h, s, d).
        v: values, shape (..., h, s, m).

    Returns:
        Attended values per head, shape (..., h, l, m).
    """
    # Positive feature map applied to queries and keys.
    q = tf.nn.elu(q) + 1
    k = tf.nn.elu(k) + 1

    # Per-head sum over the sequence of key/value outer products.
    kv = tf.einsum('... h s d, ...  h s m  -> ... h m d', k, v)

    # Sum the keys over the sequence axis. axis=-2 always addresses that
    # axis, whatever the number of leading axes matched by the ellipsis.
    k_sum = tf.math.reduce_sum(k, axis=-2)

    # Per-query normaliser; the small constant guards against division by 0.
    z = 1 / (tf.einsum('... h l d, ... h d -> ... h l', q, k_sum) + 1e-4)

    return tf.einsum('... h l d, ... h m d, ... h l -> ... h l m', q, kv, z)

class LinearAttention(tf.keras.layers.Layer):
    """Multi-head attention layer that delegates to ``compute_linear_mhsa``.

    Queries, keys and values are projected to ``d_model`` features, split
    into heads, attended with linear attention, merged again and passed
    through a final dense layer.
    """

    def __init__(self, d_model, heads=8):
        """
        Args:
            d_model: feature size of the projections and of the output.
            heads: number of attention heads; must divide ``d_model``.
        """
        super().__init__()
        self.num_heads = heads
        self.d_model = d_model

        assert d_model % self.num_heads == 0

        # Feature size handled by each head.
        self.depth = d_model // self.num_heads

        # Input projections for queries, keys and values.
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        # Output projection applied to the merged heads.
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Reshape (batch_size, seq_len, d_model) into
        (batch_size, num_heads, seq_len, depth).
        """
        per_head = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(per_head, perm=[0, 2, 1, 3])

    def call(self, v, k, q):
        """Return the attended tensor of shape (batch_size, seq_len_q, d_model)."""
        batch_size = tf.shape(q)[0]

        # Project, then split: each result is (batch_size, num_heads, seq_len, depth).
        q_heads = self.split_heads(self.wq(q), batch_size)
        k_heads = self.split_heads(self.wk(k), batch_size)
        v_heads = self.split_heads(self.wv(v), batch_size)

        # (batch_size, num_heads, seq_len_q, depth)
        attended = compute_linear_mhsa(q_heads, k_heads, v_heads)

        # Bring heads next to depth: (batch_size, seq_len_q, num_heads, depth).
        attended = tf.transpose(attended, perm=[0, 2, 1, 3])

        # Merge heads: (batch_size, seq_len_q, d_model).
        merged = tf.reshape(attended, (batch_size, -1, self.d_model))

        return self.dense(merged)
    
class LinearTransformerBlock(layers.Layer):
    """Transformer encoder block whose attention is ``LinearAttention``.

    Self-attention and a two-layer feed-forward network are each followed by
    dropout, a residual connection and layer normalisation.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        """
        Args:
            embed_dim: feature size of the tokens entering and leaving the block.
            num_heads: number of attention heads.
            ff_dim: width of the hidden feed-forward layer.
            rate: dropout rate used after attention and after the feed-forward net.
        """
        super().__init__()

        self.lha = LinearAttention(embed_dim, num_heads)

        # Position-wise feed-forward network: expand to ff_dim, project back.
        self.ffn = keras.Sequential(
            [
                layers.Dense(ff_dim, activation="relu"),
                layers.Dense(embed_dim),
            ]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs):
        """Return the block output for ``inputs``; the input shape is preserved."""
        # Self-attention: the input serves as values, keys and queries.
        attended = self.dropout1(self.lha(inputs, inputs, inputs))
        residual = self.layernorm1(inputs + attended)

        # Feed-forward sub-layer with residual connection.
        transformed = self.dropout2(self.ffn(residual))
        return self.layernorm2(residual + transformed)

In [16]:
# Input laid out as (batch, seq_len, embed_dim).
x = np.random.uniform(low=0, high=1, size=(1, 100, 256))
ltb = LinearTransformerBlock(embed_dim=256, num_heads=2, ff_dim=40)

y = ltb(x)
print(y.shape)

(1, 100, 256)
(1, 100, 256)
