### Import libraries

In [None]:
import os

import numpy as np
import pandas as pd

import tensorflow as tf
# Turn on memory growth for every visible GPU (tf.config.experimental.set_memory_growth).
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    # An empty device list simply skips the loop.
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    # NOTE(review): TF raises RuntimeError here, presumably when the devices were
    # already initialised; the message is printed and execution continues.
    print(err)

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Input, LayerNormalization
from tensorflow.keras.layers import MultiHeadAttention, GlobalAveragePooling1D, Layer, Add

from tqdm import tqdm
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import Callback
from sklearn.metrics import classification_report

### Labels (train, validation, test)

In [None]:
# Target column 'Empathic-Level' from each split's annotation CSV (train, val, test).
train_labels, val_labels, test_labels = (
    pd.read_csv(csv_path)['Empathic-Level']
    for csv_path in (
        'train-data-annotation-v1.csv',
        'val-data-annotation-v1.csv',
        'test-data-annotation-v1.csv',
    )
)

### Load the visual features

In [None]:
# FaceNet visual features, one array per split (train, val, test).
train_visual_facenet, val_visual_facenet, test_visual_facenet = map(
    np.load,
    (
        'visual_train_features_facenet.npy',
        'visual_val_features_facenet.npy',
        'visual_test_features_facenet.npy',
    ),
)

# Cell output: the three array shapes, as a quick sanity check.
train_visual_facenet.shape, val_visual_facenet.shape, test_visual_facenet.shape

In [None]:
# FER visual features, one array per split (train, val, test).
train_visual_fer, val_visual_fer, test_visual_fer = map(
    np.load,
    (
        'train_visual_features_fer.npy',
        'val_visual_features_fer.npy',
        'test_visual_features_fer.npy',
    ),
)

# Cell output: the three array shapes, as a quick sanity check.
train_visual_fer.shape, val_visual_fer.shape, test_visual_fer.shape

In [None]:
# Pose visual features, one array per split (train, val, test).
train_visual_pose, val_visual_pose, test_visual_pose = map(
    np.load,
    (
        'visual_train_features_pose.npy',
        'visual_val_features_pose.npy',
        'visual_test_features_pose.npy',
    ),
)

# Cell output: the three array shapes, as a quick sanity check.
train_visual_pose.shape, val_visual_pose.shape, test_visual_pose.shape

In [None]:
# Gaze visual features, one array per split (train, val, test).
train_visual_gaze, val_visual_gaze, test_visual_gaze = map(
    np.load,
    (
        'visual_train_features_gaze.npy',
        'visual_val_features_gaze.npy',
        'visual_test_features_gaze.npy',
    ),
)

# Cell output: the three array shapes, as a quick sanity check.
train_visual_gaze.shape, val_visual_gaze.shape, test_visual_gaze.shape

### Load the audio features

In [None]:
# Wav2Vec 2.0 listener audio features, one array per split (train, val, test).
train_audio_features_wav, val_audio_features_wav, test_audio_features_wav = map(
    np.load,
    (
        'train_listener_features_wv2.npy',
        'val_listener_features_wv2.npy',
        'test_listener_features_wv2.npy',
    ),
)

# Cell output: the three array shapes, as a quick sanity check.
train_audio_features_wav.shape, val_audio_features_wav.shape, test_audio_features_wav.shape

In [None]:
# HuBERT listener audio features, one array per split (train, val, test).
train_audio_features_hb, val_audio_features_hb, test_audio_features_hb = map(
    np.load,
    (
        'train_listener_features_hb.npy',
        'val_listener_features_hb.npy',
        'test_listener_features_hb.npy',
    ),
)

# Cell output: the three array shapes, as a quick sanity check.
train_audio_features_hb.shape, val_audio_features_hb.shape, test_audio_features_hb.shape

### Load the text features

In [None]:
# Text features (file names indicate DiKoBERT), one array per split.
train_text_features, val_text_features, test_text_features = map(
    np.load,
    (
        'train_text_features_dikobert.npy',
        'val_text_features_dikobert.npy',
        'test_text_features_dikobert.npy',
    ),
)

# Cell output: the three array shapes, as a quick sanity check.
train_text_features.shape, val_text_features.shape, test_text_features.shape

### Load the bio features

In [None]:
# Listener bio features, one array per split, read from ./extracted-features/.
train_bio_features, val_bio_features, test_bio_features = map(
    np.load,
    (
        './extracted-features/train_listener_bio_features.npy',
        './extracted-features/val_listener_bio_features.npy',
        './extracted-features/test_listener_bio_features.npy',
    ),
)

# Cell output: the three array shapes, as a quick sanity check.
train_bio_features.shape, val_bio_features.shape, test_bio_features.shape

### Transformers

In [None]:
class FeedForward(Layer):
    """Feed-forward sub-layer of a Transformer block.

    Expands the last axis to ``intermediate_size`` with GELU, projects it back
    to ``hidden_size`` and applies dropout to the projection.

    Args:
        hidden_size: output width (the model width).
        intermediate_size: width of the hidden expansion.
        dropout_rate: dropout rate applied to the output.
    """

    def __init__(self, hidden_size, intermediate_size, dropout_rate):
        super().__init__()
        self.linear_1 = Dense(intermediate_size, activation='gelu')  # expansion
        self.linear_2 = Dense(hidden_size)  # projection back to model width
        self.dropout = Dropout(dropout_rate)

    def call(self, x):
        expanded = self.linear_1(x)
        projected = self.linear_2(expanded)
        return self.dropout(projected)

class TransformerEncoderLayer(Layer):
    """Pre-LayerNorm Transformer encoder block (self-attention + feed-forward).

    Each sub-layer normalizes only its own input and adds its output back onto
    the un-normalized residual stream: ``x = x + f(LayerNorm(x))``.

    Args:
        hidden_size: model width; also used as the per-head ``key_dim``.
        intermediate_size: hidden width of the feed-forward network.
        num_heads: number of attention heads.
        dropout_rate: dropout rate inside the feed-forward network.
    """

    def __init__(self, hidden_size, intermediate_size, num_heads, dropout_rate):
        super(TransformerEncoderLayer, self).__init__()
        # Layer normalization before attention and before feed-forward.
        self.layer_norm_1 = LayerNormalization(epsilon=1e-6)
        self.layer_norm_2 = LayerNormalization(epsilon=1e-6)
        # NOTE: Keras' key_dim is the size of *each* head, so every head projects
        # to hidden_size here; the output is projected back to the query width.
        self.attention = MultiHeadAttention(num_heads=num_heads, key_dim=hidden_size)
        self.feed_forward = FeedForward(hidden_size, intermediate_size, dropout_rate)

    def call(self, x):
        # Self-attention sub-layer with residual connection.
        hidden_state = self.layer_norm_1(x)
        attn_output = self.attention(hidden_state, hidden_state, hidden_state)
        x = x + attn_output
        # Feed-forward sub-layer: normalize the branch input only, so the
        # residual path keeps the un-normalized activations (pre-LN).
        x = x + self.feed_forward(self.layer_norm_2(x))
        return x

In [None]:
class PositionalEmbedding(tf.keras.layers.Layer):
    """Add learned absolute position embeddings, then LayerNorm and dropout.

    Args:
        max_position_embeddings: number of rows in the embedding table, i.e. the
            longest sequence (axis 1 of the input) the layer accepts.
        feature_dim: width of each embedding; must equal the input's last axis.
        dropout_rate: dropout rate applied after normalization.

    Raises:
        tf.errors.InvalidArgumentError: (at call time) if the input sequence is
            longer than ``max_position_embeddings``.
    """

    def __init__(self, max_position_embeddings, feature_dim, dropout_rate=0.1):
        super(PositionalEmbedding, self).__init__()
        self.max_position_embeddings = max_position_embeddings
        self.position_embeddings = self.add_weight(
            name="pos_embeddings",
            shape=[max_position_embeddings, feature_dim],
            initializer="glorot_uniform")
        self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

    def call(self, x):
        seq_length = tf.shape(x)[1]
        # Fail loudly on over-long inputs: tf.gather on GPU would silently
        # return zero rows for out-of-range positions.
        tf.debugging.assert_less_equal(
            seq_length, self.max_position_embeddings,
            message="sequence length exceeds max_position_embeddings")

        # (seq_length, feature_dim) slice; broadcasts over the batch axis.
        position_embeddings = self.position_embeddings[:seq_length]

        embeddings = x + position_embeddings
        embeddings = self.layer_norm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

In [None]:
class TransformerEncoder(tf.keras.Model):
    """Position embeddings followed by a stack of Transformer encoder blocks.

    Args:
        num_hidden_layers: number of TransformerEncoderLayer blocks.
        hidden_size: model width (must match the input's last axis).
        intermediate_size: feed-forward hidden width in each block.
        num_heads: attention heads per block.
        dropout_rate: dropout rate for embeddings and feed-forward layers.
        max_position_embeddings: size of the position-embedding table.
    """

    def __init__(self, num_hidden_layers, hidden_size, intermediate_size, num_heads, dropout_rate, max_position_embeddings):
        super().__init__()
        self.positional_embedding = PositionalEmbedding(max_position_embeddings, hidden_size, dropout_rate)

        blocks = []
        for _ in range(num_hidden_layers):
            blocks.append(TransformerEncoderLayer(hidden_size, intermediate_size, num_heads, dropout_rate))
        self.encoder_layers = blocks

    def call(self, x):
        hidden = self.positional_embedding(x)
        for block in self.encoder_layers:
            hidden = block(hidden)
        return hidden

### Bidirectional MLP-Mixer layer

In [None]:
class MixerLayer(Layer):
    """One MLP-Mixer block: a token-mixing MLP followed by a channel-mixing MLP.

    ``build`` reads ``input_shape[1]`` (tokens) and ``input_shape[2]``
    (channels), so inputs are rank 3: ``(batch, tokens, channels)``.

    Args:
        tokens_mlp_dim: hidden width of the token-mixing MLP.
        channels_mlp_dim: hidden width of the channel-mixing MLP.
        dropout_rate: dropout applied to the block output.
        return_sequences: if False, average the output over the token axis and
            return ``(batch, channels)``.
    """

    def __init__(self, tokens_mlp_dim, channels_mlp_dim, dropout_rate, return_sequences=True):
        super().__init__()
        self.tokens_mlp_dim = tokens_mlp_dim
        self.channels_mlp_dim = channels_mlp_dim
        self.dropout_rate = dropout_rate
        self.return_sequences = return_sequences

    def build(self, input_shape):
        num_tokens, num_channels = input_shape[1], input_shape[2]
        self.layer_norm1 = LayerNormalization(epsilon=1e-6)
        self.layer_norm2 = LayerNormalization(epsilon=1e-6)
        # Token-mixing MLP (applied to the transposed input).
        self.dense1 = Dense(self.tokens_mlp_dim, activation=tf.nn.gelu)
        self.dense2 = Dense(num_tokens, activation=tf.nn.gelu)
        # Channel-mixing MLP (applied along the last axis).
        self.dense3 = Dense(self.channels_mlp_dim, activation=tf.nn.gelu)
        self.dense4 = Dense(num_channels, activation=tf.nn.gelu)
        self.dropout = Dropout(self.dropout_rate)

    def call(self, inputs):
        # Token mixing: (B, T, C) -> (B, C, T) so Dense acts on the token axis.
        x = self.layer_norm1(inputs)
        x_t = tf.transpose(x, perm=[0, 2, 1])
        x_t = self.dense1(x_t)
        x_t = self.dense2(x_t)
        x_t = tf.transpose(x_t, perm=[0, 2, 1])
        # Plain tensor addition rather than building a new Add() layer per call.
        # NOTE(review): the residual is added to the normalized input, not to
        # `inputs` as in the standard Mixer; kept to preserve behaviour.
        x = x + x_t

        # Channel mixing along the last axis.
        y = self.layer_norm2(x)
        y = self.dense3(y)
        y = self.dense4(y)
        y = x + y
        # NOTE(review): dropout is applied to the whole residual sum.
        y = self.dropout(y)

        if not self.return_sequences:
            y = tf.reduce_mean(y, axis=1)

        return y

class BackwardMixerLayer(Layer):
    """MLP-Mixer block that applies token mixing to the *reversed* token order.

    Same structure and parameters as MixerLayer, but the token sequence
    (axis 1 of a ``(batch, tokens, channels)`` input) is reversed before the
    token-mixing MLP and restored afterwards.

    Args:
        tokens_mlp_dim: hidden width of the token-mixing MLP.
        channels_mlp_dim: hidden width of the channel-mixing MLP.
        dropout_rate: dropout applied to the block output.
        return_sequences: if False, average the output over the token axis.
    """

    def __init__(self, tokens_mlp_dim, channels_mlp_dim, dropout_rate, return_sequences=True):
        super().__init__()
        self.tokens_mlp_dim = tokens_mlp_dim
        self.channels_mlp_dim = channels_mlp_dim
        self.dropout_rate = dropout_rate
        self.return_sequences = return_sequences

    def build(self, input_shape):
        num_tokens, num_channels = input_shape[1], input_shape[2]
        self.layer_norm1 = LayerNormalization(epsilon=1e-6)
        self.layer_norm2 = LayerNormalization(epsilon=1e-6)
        self.dense1 = Dense(self.tokens_mlp_dim, activation=tf.nn.gelu)
        self.dense2 = Dense(num_tokens, activation=tf.nn.gelu)
        self.dense3 = Dense(self.channels_mlp_dim, activation=tf.nn.gelu)
        self.dense4 = Dense(num_channels, activation=tf.nn.gelu)
        self.dropout = Dropout(self.dropout_rate)

    def call(self, inputs):
        # Token mixing (backward). Reverse the *token* axis (axis 1 of (B, T, C))
        # before transposing; reversing axis 1 after the transpose would flip
        # the channel axis instead and mix channel c into channel C-1-c.
        x = self.layer_norm1(inputs)
        x_rev = tf.reverse(x, axis=[1])                  # (B, T, C), tokens reversed
        x_t = tf.transpose(x_rev, perm=[0, 2, 1])        # (B, C, T)
        x_t = self.dense1(x_t)
        x_t = self.dense2(x_t)                           # (B, C, T)
        x_t = tf.transpose(x_t, perm=[0, 2, 1])          # (B, T, C)
        x_t = tf.reverse(x_t, axis=[1])                  # restore token order
        x = x + x_t

        # Channel mixing: Dense acts on each token independently, so reversing
        # the tokens before and after would cancel out; no reversal needed.
        y = self.layer_norm2(x)
        y = self.dense3(y)
        y = self.dense4(y)
        y = x + y
        y = self.dropout(y)

        if not self.return_sequences:
            y = tf.reduce_mean(y, axis=1)

        return y

class BidirectionalMixerLayer(Layer):
    """Run a forward and a backward mixer block on the same input and merge them.

    Args:
        tokens_mlp_dim, channels_mlp_dim, dropout_rate, return_sequences:
            passed unchanged to both MixerLayer and BackwardMixerLayer.
        merge_mode: 'sum' (default), 'concat' (along the last axis), 'mul',
            'ave', or None to return ``[forward_output, backward_output]``.

    Raises:
        ValueError: if ``merge_mode`` is not one of the values above.
    """

    _MERGE_MODES = ('concat', 'sum', 'mul', 'ave', None)

    def __init__(self, tokens_mlp_dim, channels_mlp_dim, dropout_rate, merge_mode='sum', return_sequences=True):
        super().__init__()
        # Reject unknown modes up front instead of silently returning a list.
        if merge_mode not in self._MERGE_MODES:
            raise ValueError(
                f"Invalid merge_mode {merge_mode!r}; expected one of {self._MERGE_MODES}")
        self.forward_layer = MixerLayer(tokens_mlp_dim, channels_mlp_dim, dropout_rate, return_sequences)
        self.backward_layer = BackwardMixerLayer(tokens_mlp_dim, channels_mlp_dim, dropout_rate, return_sequences)
        self.merge_mode = merge_mode
        self.return_sequences = return_sequences

    def call(self, inputs):
        forward_output = self.forward_layer(inputs)
        backward_output = self.backward_layer(inputs)

        if self.merge_mode == 'concat':
            return tf.concat([forward_output, backward_output], axis=-1)
        if self.merge_mode == 'sum':
            return forward_output + backward_output
        if self.merge_mode == 'mul':
            return forward_output * backward_output
        if self.merge_mode == 'ave':
            return (forward_output + backward_output) / 2
        # merge_mode is None: hand both directions back unmerged.
        return [forward_output, backward_output]

### Model Definition (Visual, Audio)

In [None]:
# Keras Input shapes: dims 1 and 2 of each training array, which the mixer
# layers treat as (tokens, channels).
visual_input_shape_facenet = tuple(train_visual_facenet.shape[d] for d in (1, 2))
visual_input_shape_fer = tuple(train_visual_fer.shape[d] for d in (1, 2))
visual_input_shape_pose = tuple(train_visual_pose.shape[d] for d in (1, 2))
visual_input_shape_gaze = tuple(train_visual_gaze.shape[d] for d in (1, 2))

In [None]:
# Keras Input shapes: dims 1 and 2 of each training array, which the mixer
# layers treat as (tokens, channels).
audio_input_shape_wav = tuple(train_audio_features_wav.shape[d] for d in (1, 2))
audio_input_shape_hb = tuple(train_audio_features_hb.shape[d] for d in (1, 2))

In [None]:
def _build_visual_mixer(input_shape, num_blocks=12, tokens_mlp_dim=128,
                        channels_mlp_dim=128, dropout_rate=0.4):
    """Build a visual encoder: Input -> bidirectional mixer blocks -> token average.

    Args:
        input_shape: per-sample (tokens, channels) shape for the Input layer.
        num_blocks: number of stacked BidirectionalMixerLayer blocks.
        tokens_mlp_dim, channels_mlp_dim, dropout_rate: forwarded to every block.

    Returns:
        A Sequential model mapping (batch, tokens, channels) to (batch, channels).
    """
    return Sequential([
        Input(shape=input_shape),
        *[
            BidirectionalMixerLayer(tokens_mlp_dim=tokens_mlp_dim,
                                    channels_mlp_dim=channels_mlp_dim,
                                    dropout_rate=dropout_rate,
                                    return_sequences=True)
            for _ in range(num_blocks)
        ],
        GlobalAveragePooling1D(),
    ])


# One independent (non-weight-sharing) encoder per visual feature stream.
visual_model_facenet = _build_visual_mixer(visual_input_shape_facenet)
visual_model_fer = _build_visual_mixer(visual_input_shape_fer)
visual_model_pose = _build_visual_mixer(visual_input_shape_pose)
visual_model_gaze = _build_visual_mixer(visual_input_shape_gaze)

In [None]:
def _build_audio_mixer(input_shape, num_blocks=12, tokens_mlp_dim=128,
                       channels_mlp_dim=128, dropout_rate=0.4):
    """Build an audio encoder: Input -> bidirectional mixer blocks -> token average.

    Args:
        input_shape: per-sample (tokens, channels) shape for the Input layer.
        num_blocks: number of stacked BidirectionalMixerLayer blocks.
        tokens_mlp_dim, channels_mlp_dim, dropout_rate: forwarded to every block.

    Returns:
        A Sequential model mapping (batch, tokens, channels) to (batch, channels).
    """
    return Sequential([
        Input(shape=input_shape),
        *[
            BidirectionalMixerLayer(tokens_mlp_dim=tokens_mlp_dim,
                                    channels_mlp_dim=channels_mlp_dim,
                                    dropout_rate=dropout_rate,
                                    return_sequences=True)
            for _ in range(num_blocks)
        ],
        GlobalAveragePooling1D(),
    ])


# One independent (non-weight-sharing) encoder per audio feature stream.
audio_model_wav = _build_audio_mixer(audio_input_shape_wav)
audio_model_hb = _build_audio_mixer(audio_input_shape_hb)

### Text Model

In [None]:
class TransformerForText(tf.keras.Model):
    """Text encoder: Transformer encoder -> mean over tokens -> linear projection.

    Args:
        num_hidden_layers, hidden_size, intermediate_size, num_heads,
        dropout_rate, max_position_embeddings: forwarded to TransformerEncoder.
        num_classes: accepted for backward compatibility but unused; this model
            returns an embedding, not class scores.
        output_dim: width of the returned embedding (default 256).

    Returns (from ``call``):
        Tensor of shape (batch, output_dim).
    """

    def __init__(self, num_hidden_layers, hidden_size, intermediate_size, num_heads, dropout_rate,
                 max_position_embeddings, num_classes, output_dim=256):
        super(TransformerForText, self).__init__()
        self.encoder = TransformerEncoder(num_hidden_layers, hidden_size, intermediate_size, num_heads,
                                          dropout_rate, max_position_embeddings)
        self.pooling = GlobalAveragePooling1D()
        self.dense = Dense(output_dim)

    def call(self, inputs):
        x = self.encoder(inputs)
        x = self.pooling(x)  # average over the token axis
        x = self.dense(x)
        return x

# Text-encoder hyper-parameters. NOTE: the bio-model cell below re-binds
# num_hidden_layers, hidden_size, intermediate_size, num_heads, dropout_rate
# and max_position_embeddings, so rerun this cell before rebuilding text_model.
num_hidden_layers, num_heads = 4, 6
# hidden_size must equal the text features' last axis: PositionalEmbedding
# adds a (seq, hidden_size) table to the input.
hidden_size, intermediate_size = 768, 2048
dropout_rate = 0.1
max_position_embeddings = 210
num_classes = 7
feature_dim = hidden_size  # NOTE(review): not referenced in the visible code

text_model = TransformerForText(
    num_hidden_layers=num_hidden_layers,
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    num_heads=num_heads,
    dropout_rate=dropout_rate,
    max_position_embeddings=max_position_embeddings,
    num_classes=num_classes,
)

### Bio Model

In [None]:
class TransformerForBio(tf.keras.Model):
    """Bio-feature encoder: Transformer encoder -> mean over tokens -> linear projection.

    Args:
        num_hidden_layers, hidden_size, intermediate_size, num_heads,
        dropout_rate, max_position_embeddings: forwarded to TransformerEncoder.
        output_dim: width of the returned embedding (default 256). It must match
            the width the downstream fusion expects.

    Returns (from ``call``):
        Tensor of shape (batch, output_dim).
    """

    def __init__(self, num_hidden_layers, hidden_size, intermediate_size, num_heads, dropout_rate,
                 max_position_embeddings, output_dim=256):
        super(TransformerForBio, self).__init__()
        self.encoder = TransformerEncoder(num_hidden_layers, hidden_size, intermediate_size, num_heads,
                                          dropout_rate, max_position_embeddings)
        self.pooling = GlobalAveragePooling1D()
        self.dense = Dense(output_dim)

    def call(self, inputs):
        x = self.encoder(inputs)
        x = self.pooling(x)  # average over the token axis
        x = self.dense(x)
        return x

# Bio-encoder hyper-parameters. These re-bind the module-level names set by
# the text-model cell (num_hidden_layers, hidden_size, ...).
num_hidden_layers, num_heads = 4, 2
hidden_size, intermediate_size = 4, 32
dropout_rate = 0.1
max_position_embeddings = 3

bio_model = TransformerForBio(
    num_hidden_layers=num_hidden_layers,
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    num_heads=num_heads,
    dropout_rate=dropout_rate,
    max_position_embeddings=max_position_embeddings,
)

### Cross-Attention

In [None]:
class CrossAttention(Layer):
    """Scaled dot-product attention of a query input over a key input.

    ``call`` takes ``[query, key]``. Keys and values are both projected from
    the raw key input.

    Inputs may be sequences ``(batch, steps, features)`` or single vectors
    ``(batch, features)``. Rank-2 inputs are treated as length-1 sequences, so
    attention never crosses samples, and a rank-2 query yields a
    ``(batch, units)`` output.

    Args:
        units: projection width for query, key and value; also the output width.
    """

    def __init__(self, units):
        super(CrossAttention, self).__init__()
        self.units = units
        # Dense layers for transforming query, key, and value
        self.dense_query = Dense(units)
        self.dense_key = Dense(units)
        self.dense_value = Dense(units)
        # Scaling factor for dot-product attention
        self.scale = tf.sqrt(tf.cast(units, tf.float32))
        self.softmax = tf.keras.layers.Softmax(axis=-1)

    def call(self, inputs):
        query_in, key_in = inputs

        # With rank-2 inputs, Q @ K^T would be (batch, batch): each sample would
        # attend over the *other samples* in the batch. Add a token axis instead.
        squeeze_output = query_in.shape.rank == 2
        if squeeze_output:
            query_in = tf.expand_dims(query_in, axis=1)
        if key_in.shape.rank == 2:
            key_in = tf.expand_dims(key_in, axis=1)
        # NOTE(review): with a single key token the softmax weight is exactly 1,
        # so the output reduces to the value projection of the key input.

        query = self.dense_query(query_in)
        key = self.dense_key(key_in)
        # Values come from the raw key input, not from the projected keys.
        value = self.dense_value(key_in)

        score = tf.matmul(query, key, transpose_b=True) / self.scale
        alignment = self.softmax(score)
        context = tf.matmul(alignment, value)

        if squeeze_output:
            context = tf.squeeze(context, axis=1)
        return context

In [None]:
class WeightedSum(tf.keras.layers.Layer):
    """Fuse equally-shaped inputs with learned, softmax-normalized weights.

    Holds one trainable logit per input; the softmax of those logits gives
    the mixing weights.

    Args:
        num_inputs: number of tensors passed to ``call``.
    """

    def __init__(self, num_inputs):
        super().__init__()
        self.attention_weights = self.add_weight(shape=(num_inputs,),
                                                 initializer='random_normal',
                                                 trainable=True)

    def call(self, inputs):
        as_float = [tf.convert_to_tensor(t, dtype=tf.float32) for t in inputs]
        # New trailing axis indexes the inputs; weights broadcast along it.
        stacked = tf.stack(as_float, axis=-1)
        mixing = tf.nn.softmax(self.attention_weights)
        return tf.reduce_sum(stacked * mixing, axis=-1)

### Multi-modal Model

In [None]:
class MultiModalModel(tf.keras.Model):
    """Late-fusion classifier over eight per-modality encoders.

    Pairs of encoder outputs are fused with CrossAttention:
    (query=FaceNet, key=FER), (query=pose, key=gaze) and
    (query=Wav2Vec, key=HuBERT). The three fused vectors, plus the text and
    bio embeddings, are combined by a learned WeightedSum and classified by a
    softmax Dense layer.

    Args:
        visual_model_facenet, visual_model_fer, visual_model_pose,
        visual_model_gaze, audio_model_wav, audio_model_hb, text_model,
        bio_model: encoders, each called on the matching element of the
            8-tuple passed to ``call``.
        num_classes: number of output classes (default 7).
        fusion_units: output width of each CrossAttention. text_model and
            bio_model must return this same width, because WeightedSum stacks
            all five fused tensors.
    """

    def __init__(self, visual_model_facenet, visual_model_fer, visual_model_pose, visual_model_gaze,
                 audio_model_wav, audio_model_hb, text_model, bio_model, num_classes=7, fusion_units=256):
        super(MultiModalModel, self).__init__()
        # Per-modality encoders
        self.visual_model_facenet = visual_model_facenet
        self.visual_model_fer = visual_model_fer
        self.visual_model_pose = visual_model_pose
        self.visual_model_gaze = visual_model_gaze
        self.audio_model_wav = audio_model_wav
        self.audio_model_hb = audio_model_hb
        self.text_model = text_model
        self.bio_model = bio_model

        # Cross-attention fusion for the visual pairs and the audio pair
        self.cross_attention_visual_1 = CrossAttention(fusion_units)
        self.cross_attention_visual_2 = CrossAttention(fusion_units)
        self.cross_attention_audio = CrossAttention(fusion_units)

        # Learned weighting of the five fused streams
        self.weighted_sum = WeightedSum(num_inputs=5)

        self.classifier = Dense(num_classes, activation='softmax')

    def call(self, inputs):
        (visual_input_facenet, visual_input_fer, visual_input_pose, visual_input_gaze,
         audio_input_wav, audio_input_hb, text_input, bio_input) = inputs

        visual_output_facenet = self.visual_model_facenet(visual_input_facenet)
        visual_output_fer = self.visual_model_fer(visual_input_fer)
        visual_output_pose = self.visual_model_pose(visual_input_pose)
        visual_output_gaze = self.visual_model_gaze(visual_input_gaze)
        audio_output_wav = self.audio_model_wav(audio_input_wav)
        audio_output_hb = self.audio_model_hb(audio_input_hb)
        text_output = self.text_model(text_input)
        bio_output = self.bio_model(bio_input)

        cross_attention_visual_output_1 = self.cross_attention_visual_1([visual_output_facenet, visual_output_fer])
        cross_attention_visual_output_2 = self.cross_attention_visual_2([visual_output_pose, visual_output_gaze])
        cross_attention_audio_output = self.cross_attention_audio([audio_output_wav, audio_output_hb])

        fusion = self.weighted_sum([
            cross_attention_visual_output_1,
            cross_attention_visual_output_2,
            cross_attention_audio_output,
            text_output,
            bio_output,
        ])

        return self.classifier(fusion)

In [None]:
# Assemble the full multimodal classifier from the per-modality encoders.
baseline_model = MultiModalModel(
    visual_model_facenet=visual_model_facenet,
    visual_model_fer=visual_model_fer,
    visual_model_pose=visual_model_pose,
    visual_model_gaze=visual_model_gaze,
    audio_model_wav=audio_model_wav,
    audio_model_hb=audio_model_hb,
    text_model=text_model,
    bio_model=bio_model,
)

### Custom F1-score callback

In [None]:
class F1ScoreCallback(Callback):
    """Early stopping and checkpointing driven by weighted F1 on a validation set.

    At the end of every epoch the callback:
    - predicts on ``validation_data[0]`` and takes the argmax class;
    - computes the weighted-average F1 against ``validation_data[1]`` and
      records it in ``logs['val_f1']``;
    - saves weights inside ``checkpoint_filepath`` whenever F1 improves;
    - stops training and restores the best weights after ``patience``
      consecutive epochs without improvement.

    Args:
        validation_data: (inputs, targets) pair; targets are integer class ids.
        patience: epochs without improvement before stopping.
        checkpoint_filepath: directory for weight checkpoints (created if missing).
        batch_size: batch size for the validation predictions.
    """

    def __init__(self, validation_data, patience=300, checkpoint_filepath='file_path', batch_size=2):
        super(F1ScoreCallback, self).__init__()
        self.validation_data = validation_data
        self.patience = patience
        self.best_weights = None
        self.best_f1 = -np.inf
        self.wait = 0
        self.checkpoint_filepath = checkpoint_filepath
        self.batch_size = batch_size

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.checkpoint_filepath, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        val_probs = self.model.predict(self.validation_data[0], batch_size=self.batch_size)
        val_predict = np.argmax(val_probs, axis=1)
        val_targ = self.validation_data[1]
        val_f1 = classification_report(val_targ, val_predict, output_dict=True)['weighted avg']['f1-score']
        if logs is not None:
            logs['val_f1'] = val_f1

        if val_f1 > self.best_f1:
            self.best_f1 = val_f1
            self.best_weights = self.model.get_weights()
            # Join with a separator so the file lands *inside* the directory.
            checkpoint = os.path.join(self.checkpoint_filepath,
                                      f"epoch_{epoch:02d}_val_f1_{val_f1:.5f}.h5")
            self.model.save_weights(checkpoint)
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.model.stop_training = True
                if self.best_weights is not None:
                    self.model.set_weights(self.best_weights)

### Model Training

In [None]:
def train_and_evaluate(model, epochs=1000, batch_size=2, learning_rate=2e-5):
    """Compile ``model`` and fit it on the training split with F1-based early stopping.

    Uses sparse categorical cross-entropy with Adam. The F1ScoreCallback monitors
    weighted F1 on the validation split, checkpoints improvements and stops
    training after its patience runs out.

    Args:
        model: the multimodal Keras model to train.
        epochs: maximum number of training epochs.
        batch_size: batch size for training and for the callback's predictions.
        learning_rate: Adam learning rate.

    Returns:
        The History object returned by ``model.fit``.
    """
    # `learning_rate` (not the removed `lr` alias) is the supported Adam kwarg.
    optimizer = Adam(learning_rate=learning_rate)
    model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

    train_inputs = (train_visual_facenet, train_visual_fer, train_visual_pose, train_visual_gaze,
                    train_audio_features_wav, train_audio_features_hb, train_text_features, train_bio_features)
    val_inputs = (val_visual_facenet, val_visual_fer, val_visual_pose, val_visual_gaze,
                  val_audio_features_wav, val_audio_features_hb, val_text_features, val_bio_features)

    f1_score_callback = F1ScoreCallback(validation_data=(val_inputs, val_labels), batch_size=batch_size)

    # The callback must be passed to fit(), otherwise it never runs.
    return model.fit(train_inputs, train_labels, epochs=epochs,
                     validation_data=(val_inputs, val_labels),
                     batch_size=batch_size,
                     callbacks=[f1_score_callback])


train_and_evaluate(baseline_model)