# TACT-Net: Triple-Tier Attention-integrated Compact Transformer Network for COVID-19 Prediction in CT Scans

In [None]:
# Import required dependencies
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers as L
from tensorflow.keras.models import Model
from tensorflow.keras.layers import *

## Spatial Attention

In [None]:
class SpatialAttentionModule(tf.keras.layers.Layer):
    """Spatial attention map computed from channel-pooled statistics.

    The input is reduced over axis 3 (the channel axis of channels-last
    Conv2D input) with both a mean and a max. The two single-channel maps are
    stacked and passed through four 'same'-padded convolutions
    (64 -> 32 -> 16 -> 1 filters). The last convolution uses a sigmoid, so
    the output holds one weight in (0, 1) per spatial position, with shape
    ``(batch, H, W, 1)``.

    This is modelled on the spatial branch of CBAM, but it uses a deeper
    convolution stack than the paper's single convolution.
    paper: https://arxiv.org/abs/1807.06521
    code: https://gist.github.com/innat/99888fa8065ecbf3ae2b297e5c10db70

    Args:
        kernel_size: Kernel size shared by all four convolutions.
        **kwargs: Standard Keras ``Layer`` arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, kernel_size=3, **kwargs):
        super(SpatialAttentionModule, self).__init__(**kwargs)
        # Stored so the layer can be re-created from ``get_config()``.
        self.kernel_size = kernel_size
        # The attribute names conv1..conv4 are kept unchanged so existing
        # checkpoints still map onto the same variables.
        self.conv1 = self._build_attention_conv(64, tf.nn.relu)
        self.conv2 = self._build_attention_conv(32, tf.nn.relu)
        self.conv3 = self._build_attention_conv(16, tf.nn.relu)
        self.conv4 = self._build_attention_conv(1, tf.math.sigmoid)

    def _build_attention_conv(self, filters, activation):
        """Create one bias-free, stride-1, 'same'-padded He-initialised conv."""
        return tf.keras.layers.Conv2D(filters, kernel_size=self.kernel_size,
                                      use_bias=False,
                                      kernel_initializer='he_normal',
                                      strides=1, padding='same',
                                      activation=activation)

    def call(self, inputs):
        avg_out = tf.reduce_mean(inputs, axis=3)
        max_out = tf.reduce_max(inputs, axis=3)
        # Two pooled maps stacked on the channel axis -> (batch, H, W, 2).
        x = tf.stack([avg_out, max_out], axis=3)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return self.conv4(x)

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super(SpatialAttentionModule, self).get_config()
        config.update({'kernel_size': self.kernel_size})
        return config
    

## Channel Attention

In [None]:
class ChannelAttentionModule(tf.keras.layers.Layer):
    """Per-channel attention weights from pooled descriptors (CBAM-style).

    The input is reduced with global average pooling and global max pooling.
    Each descriptor goes through a shared bottleneck of two 1x1 convolutions:
    ``C -> C // ratio`` with ReLU, then ``C -> C`` with no activation. The two
    results are summed and passed through a sigmoid. The output has shape
    ``(batch, 1, 1, C)``, one weight in (0, 1) per channel, and broadcasts
    against a ``(batch, H, W, C)`` tensor.

    paper: https://arxiv.org/abs/1807.06521
    code: https://gist.github.com/innat/99888fa8065ecbf3ae2b297e5c10db70

    Args:
        ratio: Channel reduction factor of the bottleneck's hidden layer.
        **kwargs: Standard Keras ``Layer`` arguments (``name``, ``dtype``, ...).
    """

    def __init__(self, ratio=8, **kwargs):
        super(ChannelAttentionModule, self).__init__(**kwargs)
        self.ratio = ratio
        self.gapavg = tf.keras.layers.GlobalAveragePooling2D()
        self.gmpmax = tf.keras.layers.GlobalMaxPooling2D()

    def build(self, input_shape):
        channels = input_shape[-1]
        # Hidden layer of the shared bottleneck. Keep at least one unit so
        # inputs with fewer than `ratio` channels still build.
        self.conv1 = tf.keras.layers.Conv2D(max(channels // self.ratio, 1),
                                            kernel_size=1,
                                            strides=1, padding='same',
                                            use_bias=True, activation=tf.nn.relu)
        # Deliberately linear. With a ReLU here, both branch outputs would be
        # >= 0, so sigmoid(sum) could never fall below 0.5 and no channel
        # could be suppressed. CBAM's shared MLP applies ReLU only after its
        # first layer.
        self.conv2 = tf.keras.layers.Conv2D(channels,
                                            kernel_size=1,
                                            strides=1, padding='same',
                                            use_bias=True, activation=None)
        super(ChannelAttentionModule, self).build(input_shape)

    def call(self, inputs):
        # Pool to (batch, C), then insert two unit spatial axes to get
        # (batch, 1, 1, C) for the 1x1 convs. Plain indexing is used instead of
        # creating fresh Reshape layers on every call.
        gapavg = self.gapavg(inputs)[:, tf.newaxis, tf.newaxis, :]
        gmpmax = self.gmpmax(inputs)[:, tf.newaxis, tf.newaxis, :]
        # The same conv1/conv2 weights are shared by both descriptors.
        gapavg_out = self.conv2(self.conv1(gapavg))
        gmpmax_out = self.conv2(self.conv1(gmpmax))
        return tf.math.sigmoid(gapavg_out + gmpmax_out)

    def get_output_shape_for(self, input_shape):
        # Legacy (Keras 1) alias kept for backward compatibility.
        return self.compute_output_shape(input_shape)

    def compute_output_shape(self, input_shape):
        # Matches what `call` actually returns: (batch, 1, 1, channels).
        return (input_shape[0], 1, 1, input_shape[-1])

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super(ChannelAttentionModule, self).get_config()
        config.update({'ratio': self.ratio})
        return config

## Pixel Attention

In [None]:
def pixel_attention(x, nf):
    """Scale every spatial position of ``x`` by a learned sigmoid gate.

    A 3x3 ReLU convolution with ``nf`` filters gathers local context. A 1x1
    sigmoid convolution turns that context into a one-channel map with values
    in (0, 1). The map is then broadcast-multiplied onto ``x``.

    Args:
        x: 4-D feature tensor (Keras functional tensor).
        nf: Number of filters in the intermediate 3x3 convolution.

    Returns:
        ``x`` multiplied element-wise by the per-position gate.
    """
    context = Conv2D(nf, 3, padding='same', activation='relu')(x)
    gate = Conv2D(1, 1, padding='same', activation='sigmoid')(context)
    return Multiply()([x, gate])

# Triple Tier Attention Module

In [None]:
def tta_module(x, nf):
    """Triple-tier attention: pixel, channel and spatial attention branches.

    Two parallel ReLU projections of ``x`` are built, a 1x1 and a 3x3
    convolution with ``nf`` filters each:

    * the 3x3 projection is gated by ``pixel_attention``;
    * the 1x1 projection feeds both a ``ChannelAttentionModule`` and a
      ``SpatialAttentionModule``, and their outputs are multiplied.

    Both results have ``nf`` channels and are concatenated on the last axis,
    so the output has ``2 * nf`` channels.

    NOTE(review): the channel x spatial product consists of attention weights
    only. It is never multiplied back onto the 1x1 features, whereas in CBAM
    each map rescales the feature tensor. Confirm this is intended.

    Args:
        x: 4-D feature tensor (Keras functional tensor).
        nf: Filters of each projection.
    """
    pointwise = Conv2D(nf, (1, 1), activation='relu', padding='same')(x)
    local = Conv2D(nf, (3, 3), activation='relu', padding='same')(x)

    pixel_branch = pixel_attention(local, nf)

    channel_weights = ChannelAttentionModule()(pointwise)
    spatial_weights = SpatialAttentionModule()(pointwise)
    joint_weights = Multiply()([spatial_weights, channel_weights])

    return Concatenate()([joint_weights, pixel_branch])

# Convolutional Tokenizer

In [None]:
# ConvolutionalTokenizer: converts images / feature maps into a token sequence
class ConvolutionalTokenizer(keras.Model):
    """Turn a 4-D feature map into a sequence of tokens.

    ``conv_layers`` stages of (ReLU ``SeparableConv2D`` -> ``MaxPool2D``) are
    applied. The resulting spatial grid is then flattened, so the output has
    shape ``(batch, H' * W', C')``, where ``C'`` is the filter count of the
    last stage.

    Args:
        kernel_size: Kernel size of every separable convolution.
        stride: Stride of every separable convolution.
        padding: Unused and kept only for signature compatibility. The
            convolutions always use ``padding="same"``.
        pooling_kernel_size: Pool size of every max-pooling layer.
        pooling_stride: Stride of every max-pooling layer.
        conv_layers: Number of conv + pool stages.
        num_output_channels: Filters per stage. Must have at least
            ``conv_layers`` entries; extra entries are ignored.
        **kwargs: Forwarded to ``keras.Model``.

    Raises:
        ValueError: If ``num_output_channels`` has fewer than ``conv_layers``
            entries.
    """

    def __init__(self, kernel_size=3, stride=1, padding=1, pooling_kernel_size=3, pooling_stride=2,
                 conv_layers=3, num_output_channels=(64, 128, 128), **kwargs):
        super(ConvolutionalTokenizer, self).__init__(**kwargs)
        if len(num_output_channels) < conv_layers:
            raise ValueError(
                f"num_output_channels has {len(num_output_channels)} entries "
                f"but conv_layers={conv_layers} stages were requested.")

        self.conv_model = keras.Sequential()  # container for all conv/pool stages
        for filters in num_output_channels[:conv_layers]:
            self.conv_model.add(L.SeparableConv2D(filters, kernel_size=kernel_size, strides=stride, padding="same",
                                                  use_bias=False, activation="relu", depth_multiplier=1,
                                                  depthwise_initializer="he_normal", pointwise_initializer="he_normal"))
            self.conv_model.add(L.MaxPool2D(pool_size=pooling_kernel_size, strides=pooling_stride, padding="same"))

    def call(self, images):
        outputs = self.conv_model(images)

        # Flatten (batch, H', W', C') into tokens (batch, H' * W', C'). Static
        # dims are preferred so the result keeps a known sequence length and
        # channel count. Consumers such as position embeddings, layer
        # normalisation and attention need those dims when the graph is built.
        # Dynamic dims are used only when a static one is unknown.
        static_shape = outputs.shape
        dynamic_shape = tf.shape(outputs)
        height, width, channels = (
            static_shape[axis] if static_shape[axis] is not None else dynamic_shape[axis]
            for axis in (1, 2, 3)
        )
        return tf.reshape(outputs, (-1, height * width, channels))

# Positional Embedding

In [None]:
# PositionEmbedding layer: returns learnable positional encodings for a token sequence
class PositionEmbedding(L.Layer):
    """Learnable position embeddings for a ``(batch, seq_len, dim)`` tensor.

    ``build`` creates an ``Embedding(seq_len, dim)`` table from the static
    input shape. ``call`` returns the embeddings of positions
    ``0 .. seq_len - 1`` as a ``(seq_len, dim)`` tensor. The inputs are NOT
    added here; the caller adds the result to the tokens, and it broadcasts
    over the batch axis.

    Args:
        **kwargs: Standard Keras ``Layer`` arguments (``name``, ``dtype``, ...).

    Raises:
        ValueError: From ``build``, if the sequence length or embedding
            dimension of the input is not statically known.
    """

    def __init__(self, **kwargs):
        super(PositionEmbedding, self).__init__(**kwargs)

    def build(self, input_shape):
        sequence_length = input_shape[1]
        projection_dim = input_shape[-1]
        # The embedding table size is fixed at build time, so both dims must
        # be known statically.
        if sequence_length is None or projection_dim is None:
            raise ValueError(
                "PositionEmbedding needs a statically known sequence length and "
                f"embedding dimension; got input shape {input_shape}.")

        self.embedding = L.Embedding(input_dim=sequence_length, output_dim=projection_dim)
        super(PositionEmbedding, self).build(input_shape)

    def call(self, inputs):
        # One index per token position: 0, 1, ..., seq_len - 1.
        sequence_length = tf.shape(inputs)[1]
        positions = tf.range(start=0, limit=sequence_length, delta=1)
        return self.embedding(positions)


# Transformer Encoder

In [None]:
def mlp(x, mlp_dim, dim, dropout_rate=0.1):
    """Position-wise feed-forward network of a transformer encoder layer.

    Applies Dense(``mlp_dim``, swish) -> Dropout -> Dense(``dim``) -> Dropout,
    so the last dimension of the result is ``dim``.

    Args:
        x: Input tensor (Keras functional tensor).
        mlp_dim: Width of the hidden Dense layer.
        dim: Output width.
        dropout_rate: Rate of both Dropout layers.
    """
    stages = (
        L.Dense(mlp_dim, activation='swish'),
        L.Dropout(dropout_rate),
        L.Dense(dim),  # project back to the requested output width
        L.Dropout(dropout_rate),
    )
    for stage in stages:
        x = stage(x)
    return x

def transformer_encoder(x, num_heads, dim, mlp_dim):
    """One pre-norm transformer encoder layer.

    Computes ``y = x + MHA(LN(x))``, then ``y + MLP(LN(y))``. The attention
    output keeps the input width. The MLP outputs ``dim`` features, so when
    that differs from the width of ``y`` the residual is first projected with
    a Dense layer to make the shapes match.

    Args:
        x: Token tensor (presumably ``(batch, seq_len, width)``).
        num_heads: Number of attention heads.
        dim: Per-head ``key_dim`` of the attention and output width of the MLP.
        mlp_dim: Hidden width of the MLP.
    """
    normed = L.LayerNormalization()(x)
    attended = L.MultiHeadAttention(num_heads=num_heads, key_dim=dim)(normed, normed)
    x = L.Add()([attended, x])

    normed = L.LayerNormalization()(x)
    ffn_out = mlp(normed, mlp_dim, dim)

    out_width = ffn_out.shape[-1]
    residual = x if x.shape[-1] == out_width else L.Dense(out_width)(x)

    return L.Add()([ffn_out, residual])

# TTA-CT Block

In [None]:
def tta_ct_block(inputs, num_filters, dim, num_layers=1, num_heads=4,
                 map_name="Vizualization_Map"):
    """TTA-CT block: fuse triple-tier attention with a compact transformer.

    Two branches run on ``inputs``:

    * ``tta_module`` produces local attention features with
      ``2 * num_filters`` channels;
    * a ``ConvolutionalTokenizer`` feeds learnable position embeddings and
      then ``num_layers`` transformer encoders. Their tokens are reshaped
      back to a square grid, bilinearly upsampled to the input's spatial
      size, and projected to the input's channel count
      (SeparableConv2D 1x1 + BatchNorm + swish).

    The two branches are concatenated in a layer named ``map_name``, so it
    can be looked up for visualisation. They are then fused down to
    ``num_filters`` channels with a 3x3 separable conv + BatchNorm + swish.

    Args:
        inputs: 4-D channels-last feature tensor with static H, W and C.
        num_filters: Filters of the TTA module and of the block's output.
        dim: Attention ``key_dim`` and MLP output width of each encoder layer.
            The MLP hidden width is ``2 * dim``.
        num_layers: Number of transformer encoder layers.
        num_heads: Attention heads per encoder layer.
        map_name: Name of the concatenation layer. Layer names must be unique
            within a model, so give each block a distinct name when the block
            is used more than once.

    Returns:
        Tensor of shape ``(batch, H, W, num_filters)``.

    Raises:
        ValueError: If the token count is unknown or not a perfect square, or
            if the input's spatial size is unknown or not an integer multiple
            of the token grid size.
    """
    # Local branch: triple-tier attention applied to the input features.
    x_tta = tta_module(inputs, num_filters)

    # Global branch: tokens + learnable positions -> transformer encoders.
    tokenizer = ConvolutionalTokenizer()
    tokenized = tokenizer(inputs)

    position_embedding = PositionEmbedding()(tokenized)
    x = tokenized + position_embedding  # (seq, dim) broadcasts over batch

    for _ in range(num_layers):
        x = transformer_encoder(x, num_heads=num_heads, dim=dim, mlp_dim=dim * 2)

    # The tokens must map back onto a square grid that tiles the input
    # exactly. Otherwise Reshape or the final Concatenate would fail with an
    # unhelpful shape error, so check explicitly.
    _, num_tokens, channels = x.shape
    if num_tokens is None:
        raise ValueError("tta_ct_block needs a statically known token count.")
    grid = int(round(num_tokens ** 0.5))
    if grid * grid != num_tokens:
        raise ValueError(
            f"Token count {num_tokens} is not a perfect square; the tokens "
            "cannot be reshaped to a square grid (is the input square?).")
    height, width = inputs.shape[1], inputs.shape[2]
    if height is None or width is None or height % grid or width % grid:
        raise ValueError(
            f"Input spatial size ({height}, {width}) must be statically known "
            f"and an integer multiple of the token grid size {grid}.")

    x = L.Reshape((grid, grid, channels))(x)

    # Upsample back to the input's spatial resolution.
    x = L.UpSampling2D(size=(height // grid, width // grid), interpolation="bilinear")(x)

    # Project the transformer features to the input's channel count.
    x = L.SeparableConv2D(filters=inputs.shape[-1], kernel_size=1, padding='same', use_bias=False)(x)
    x = L.BatchNormalization()(x)
    x = L.Activation('swish')(x)

    x = L.Concatenate(name=map_name)([x, x_tta])

    # Fuse both branches down to `num_filters` channels.
    x = L.SeparableConv2D(filters=num_filters, kernel_size=3, padding='same', use_bias=False)(x)
    x = L.BatchNormalization()(x)
    x = L.Activation('swish')(x)

    return x




# TACT-Net Architecture

In [None]:
def _separable_conv_stage(x, filters):
    """Apply SeparableConv2D(3x3, no bias) -> BatchNorm -> ReLU -> 2x2 max-pool."""
    x = SeparableConv2D(filters=filters, kernel_size=3, padding='same', use_bias=False)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    return MaxPooling2D(pool_size=(2, 2))(x)


def build_model(input_shape=(224, 224, 3)):
    """Build the TACT-Net binary classifier.

    The architecture is:

    * a separable-conv stage with 64 filters;
    * one TTA-CT block (128 filters, 4 transformer layers);
    * five further separable-conv stages with 128, 256, 256, 512 and 1024
      filters, each followed by a 2x2 max-pool;
    * global average pooling;
    * a Dense(ReLU) + BatchNorm head (1024 -> 512 -> 256 -> 128 -> 64 -> 16);
    * a single sigmoid output unit.

    Args:
        input_shape: Shape of one input image, channels last. ``None`` falls
            back to ``(224, 224, 3)``. The TTA-CT block reshapes its tokens to
            a square grid and upsamples by an integer factor. Use a square
            input whose side is a multiple of 16, such as the default 224.

    Returns:
        An uncompiled ``keras.Model`` mapping images to a probability of
        shape ``(batch, 1)``.
    """
    # The argument was previously ignored in favour of a hard-coded shape.
    # Honour it now, keeping the old shape as the default.
    if input_shape is None:
        input_shape = (224, 224, 3)
    inputs = Input(shape=input_shape)

    x = _separable_conv_stage(inputs, 64)

    # TTA-CT Block
    x = tta_ct_block(x, num_filters=128, dim=128, num_layers=4)

    for filters in (128, 256, 256, 512, 1024):
        x = _separable_conv_stage(x, filters)

    # Collapse the spatial axes to one feature vector per image.
    x = GlobalAveragePooling2D()(x)

    # Fully connected head.
    for units in (1024, 512, 256, 128, 64, 16):
        x = Dense(units, activation='relu')(x)
        x = BatchNormalization()(x)

    # Final classification layer (binary output)
    outputs = Dense(1, activation='sigmoid')(x)
    model = keras.Model(inputs, outputs)
    return model