# In [5]:
import tensorflow as tf
import math
from keras.layers import LayerNormalization, MultiHeadAttention, Add, Dropout, Reshape

# In [6]:
class MultiHeadInvertedAttention(tf.keras.layers.MultiHeadAttention):
    """Multi-head attention layer used as the iTransformer "inverted" attention.

    This subclass adds no computation of its own. Construction and the forward
    pass are both delegated unchanged to ``tf.keras.layers.MultiHeadAttention``.
    NOTE(review): the "inverted" aspect presumably comes from how the caller
    lays out its input (variates as tokens) — confirm at the call site.
    """

    def __init__(self, **kwargs):
        # Every option (num_heads, key_dim, dropout, ...) is handled by the base class.
        super().__init__(**kwargs)

    def call(self, query, value, attention_mask=None, training=None):
        """Attend from ``query`` over ``value``; return only the attention output.

        No ``key`` is forwarded, so the base class's default for ``key`` applies.
        Attention scores are not returned.
        """
        return super().call(
            query,
            value,
            attention_mask=attention_mask,
            training=training,
        )


def imha_block(input_feature, key_dim=8, num_heads=2, dropout=0.5):
    """iTransformer-style multi-head inverted attention block.

    Pipeline: self-attention -> LayerNorm -> position-wise FFN (expand 4x,
    ReLU, project back) -> LayerNorm, with the result added onto the input.

    Args:
        input_feature: input tensor. It is used as-is with no embedding layer,
            so it is assumed to be already embedded. NOTE(review): for the
            inverted iTransformer reading the layout is presumably
            (batch, variates, time) — confirm against the caller.
        key_dim: size of each attention head for query and key.
        num_heads: number of attention heads.
        dropout: dropout rate passed to the attention layer.

    Returns:
        Tensor with the same shape as ``input_feature``.
    """
    # Width of the residual stream. The FFN must project back to exactly this
    # size, otherwise the final skip connection cannot add the two tensors.
    # (Previously the second Dense read the width from the already-expanded
    # tensor and produced 4*d features, so the Add always failed.)
    d_model = input_feature.shape[-1]

    # Multi-head inverted self-attention. Its output keeps the query's last dim.
    x = MultiHeadInvertedAttention(
        key_dim=key_dim, num_heads=num_heads, dropout=dropout
    )(input_feature, input_feature)

    # LayerNormalization normalizes over the last axis (timestamps, if the
    # input is laid out as (batch, variates, time) — TODO confirm).
    x = LayerNormalization(epsilon=1e-6)(x)

    # Feed-forward network: expand 4x, then project back to d_model.
    x = tf.keras.layers.Dense(units=d_model * 4, activation='relu')(x)
    x = tf.keras.layers.Dense(units=d_model)(x)

    x = LayerNormalization(epsilon=1e-6)(x)

    # Skip connection: shapes now match by construction.
    imha_feature = Add()([input_feature, x])
    return imha_feature


# In [7]:
class MultiHeadAttention_LSA(tf.keras.layers.MultiHeadAttention):
    """Multi-head attention with the Locality Self-Attention temperature.

    Locality Self-Attention (LSA) is described in
    https://arxiv.org/abs/2112.13492v1. This implementation is adapted from
    the Keras example https://keras.io/examples/vision/vit_small_ds/.

    The override scales queries by a learnable temperature ``tau`` inside
    ``_compute_attention``. The diagonal masking half of LSA is NOT done
    here: the caller must pass an appropriate ``attention_mask``.
    NOTE(review): this relies on private base-class members (``_key_dim``,
    ``_dot_product_equation``, ``_combine_equation``, ``_masked_softmax``,
    ``_dropout_layer``), which may change between Keras versions.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Trainable temperature, initialised to sqrt(key_dim).
        initial_tau = math.sqrt(float(self._key_dim))
        self.tau = tf.Variable(initial_tau, trainable=True)

    def _compute_attention(self, query, key, value, attention_mask=None, training=None):
        """Compute attention using the learnable temperature.

        Returns:
            A tuple ``(attention_output, attention_scores)``. The scores are
            the softmax-normalised weights before dropout is applied.
        """
        scaled_query = tf.multiply(query, 1.0 / self.tau)
        scores = tf.einsum(self._dot_product_equation, key, scaled_query)
        scores = self._masked_softmax(scores, attention_mask)
        # Dropout only affects the weights used to mix `value`;
        # the returned scores stay un-dropped.
        dropped_scores = self._dropout_layer(scores, training=training)
        output = tf.einsum(self._combine_equation, dropped_scores, value)
        return output, scores



def mha_block(input_feature, key_dim=8, num_heads=2, dropout=0.5, vanilla=True,
              out_dropout=0.3):
    """Multi-head self-attention (MHA) block with a pre-norm residual connection.

    Computes ``input_feature + Dropout(Attention(LayerNorm(input_feature)))``.

    Two attention variants are supported:
      * ``vanilla=True``: standard multi-head self-attention
        ('Attention Is All You Need', https://arxiv.org/abs/1706.03762).
      * ``vanilla=False``: multi-head Locality Self-Attention
        ('Vision Transformer for Small-Size Datasets',
        https://arxiv.org/abs/2112.13492v1). It uses a learnable temperature
        plus a mask that is 0 on the diagonal, so no token attends to itself.

    Args:
        input_feature: input tensor. Axis 1 is the token axis used to size
            the LSA mask. NOTE(review): assumed (batch, tokens, features) —
            ``attention_block`` reshapes to rank 3 before calling this.
        key_dim: size of each attention head for query and key.
        num_heads: number of attention heads.
        dropout: dropout rate applied inside the attention layer.
        vanilla: choose standard MHA (True) or LSA (False).
        out_dropout: dropout rate applied to the attention output before the
            residual add (previously hard-coded to 0.3).

    Returns:
        Tensor with the same shape as ``input_feature``.

    Raises:
        ValueError: if ``vanilla`` is False and the token axis has no static size.
    """
    x = LayerNormalization(epsilon=1e-6)(input_feature)

    if vanilla:
        x = MultiHeadAttention(key_dim=key_dim, num_heads=num_heads, dropout=dropout)(x, x)
    else:
        num_patches = input_feature.shape[1]
        if num_patches is None:
            # tf.eye needs a concrete size; fail with a clear message instead
            # of an opaque error deep inside TensorFlow.
            raise ValueError(
                "Locality self-attention needs a static size for axis 1 of the input."
            )
        # 1 = may attend, 0 = masked. Zeroing the diagonal blocks self-attention.
        # The leading axis of size 1 broadcasts the mask over the batch.
        diag_attn_mask = tf.expand_dims(tf.cast(1 - tf.eye(num_patches), dtype=tf.int8), axis=0)
        x = MultiHeadAttention_LSA(key_dim=key_dim, num_heads=num_heads, dropout=dropout)(
            x, x, attention_mask=diag_attn_mask)

    x = Dropout(out_dropout)(x)
    # Skip connection
    mha_feature = Add()([input_feature, x])
    return mha_feature


def attention_block(in_layer, attention_model, ratio=8, residual=False, apply_to_input=True):
    """Apply the selected attention module and return a tensor shaped like ``in_layer``.

    Only ``attention_model == 'mha'`` (see ``mha_block``) is supported. Inputs
    with rank greater than 3 are reshaped to ``(batch, in_layer.shape[1], -1)``,
    which keeps axis 1 as the token axis and folds the remaining axes into
    features. After attention the output is reshaped back to the original shape.

    Args:
        in_layer: input tensor. The non-batch dims must be static when rank > 3.
        attention_model: name of the attention module; only 'mha' is accepted.
        ratio, residual, apply_to_input: currently unused. They are kept for
            interface compatibility with callers.

    Returns:
        Tensor with the same shape as ``in_layer``.

    Raises:
        ValueError: if ``attention_model`` is not supported.
    """
    if attention_model != 'mha':
        raise ValueError("'{}' is not supported attention module!".format(attention_model))

    in_sh = in_layer.shape
    in_len = len(in_sh)
    flattened = in_len > 3

    if flattened:
        in_layer = Reshape((in_sh[1], -1))(in_layer)

    # Multi-head self attention layer (preserves the rank of its input).
    out_layer = mha_block(in_layer)

    if flattened:
        # Restore every original non-batch dim. Previously only rank-4 inputs
        # were restored, so rank >= 5 inputs came back flattened.
        out_layer = Reshape(tuple(in_sh[1:]))(out_layer)
    return out_layer