In [None]:
import os

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.init as init
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
    

In [None]:
# Class for Gated Multimodal Unit of Arevalo et al. (2017)
class GatedMultimodalUnit(torch.nn.Module):
    """Fuse text, audio and video features with a three-way gated multimodal unit.

    Each modality is first projected to ``output_dim``:

    * text with a linear layer;
    * audio and video with kernel-size-1 ``Conv1d`` layers, averaged over time.

    Each projection then goes through its own transform ``tanh(x @ W_i)``.
    Every transformed representation is weighted element-wise by a sigmoid
    gate computed from the concatenation of all three projections, and the
    three gated terms are summed.

    Args:
        text_dim: Size of the last dimension of ``text_features``.
        audio_dim: Size of the feature (last) dimension of ``audio_features``.
        video_dim: Size of the feature (last) dimension of ``video_features``.
        output_dim: Size of the fused output representation.
    """

    def __init__(self, text_dim, audio_dim, video_dim, output_dim):
        super().__init__()

        # Linear transformation for text features
        self.text_linear = nn.Linear(text_dim, output_dim)

        # Kernel-size-1 convolutions act as per-time-step linear projections
        # from the modality's feature size to output_dim.
        self.audio_conv = nn.Conv1d(audio_dim, output_dim, kernel_size=1)
        self.video_conv = nn.Conv1d(video_dim, output_dim, kernel_size=1)

        # Activation functions
        self.activation = nn.Tanh()
        self.gate_activation = nn.Sigmoid()  # logistic gate

        # Per-modality transform matrices (W_i in h_i = tanh(x_i @ W_i)).
        # torch.empty replaces the legacy torch.Tensor constructor; values are
        # filled in by _initialize_weights below.
        self.W1 = nn.Parameter(torch.empty(output_dim, output_dim))
        self.W2 = nn.Parameter(torch.empty(output_dim, output_dim))
        self.W3 = nn.Parameter(torch.empty(output_dim, output_dim))

        # Gating matrices: each maps the concatenated projections
        # (3 * output_dim) to one output_dim-sized gate (applied transposed).
        self.Y1 = nn.Parameter(torch.empty(output_dim, output_dim * 3))
        self.Y2 = nn.Parameter(torch.empty(output_dim, output_dim * 3))
        self.Y3 = nn.Parameter(torch.empty(output_dim, output_dim * 3))

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-uniform initialize the transform and gating matrices."""
        for weight in (self.W1, self.W2, self.W3, self.Y1, self.Y2, self.Y3):
            init.xavier_uniform_(weight)

    def forward(self, text_features, audio_features, video_features):
        """Return the gated fusion of the three modalities.

        Args:
            text_features: Tensor of shape ``[batch_size, text_dim]``. It must
                be 2-D because its projection is concatenated along dim 1 with
                the 2-D audio/video summaries.
            audio_features: Tensor of shape ``[batch_size, seq_len, audio_dim]``
                or ``[batch_size, 1, seq_len, audio_dim]``.
            video_features: Tensor of shape ``[batch_size, seq_len, video_dim]``.

        Returns:
            Tensor of shape ``[batch_size, output_dim]``.
        """
        # Text: project, then apply the modality-specific tanh transform.
        x_t = self.text_linear(text_features)  # [batch_size, output_dim]
        h1 = self.activation(torch.matmul(x_t, self.W1))  # [batch_size, output_dim]

        # Audio: drop the extra singleton axis only for 4-D input. An
        # unconditional squeeze(1) would also remove the time axis of a 3-D
        # input whose seq_len is 1, making the permute below fail.
        if audio_features.dim() == 4:
            audio_features = audio_features.squeeze(1)
        audio_features = audio_features.permute(0, 2, 1)  # [batch_size, audio_dim, seq_len]
        # NOTE: mean over the full time axis; padded steps (if any) are included.
        x_a = self.audio_conv(audio_features).mean(dim=-1)  # [batch_size, output_dim]
        h2 = self.activation(torch.matmul(x_a, self.W2))  # [batch_size, output_dim]

        # Video: same per-step projection and time averaging as audio.
        video_features = video_features.permute(0, 2, 1)  # [batch_size, video_dim, seq_len]
        x_v = self.video_conv(video_features).mean(dim=-1)  # [batch_size, output_dim]
        h3 = self.activation(torch.matmul(x_v, self.W3))  # [batch_size, output_dim]

        # Gates are computed from the (pre-tanh) projections of all modalities.
        x = torch.cat((x_t, x_a, x_v), dim=1)  # [batch_size, 3 * output_dim]

        # Independent sigmoid gates per modality; they are NOT normalized to
        # sum to 1, so each modality's contribution is gated separately.
        z1 = self.gate_activation(torch.matmul(x, self.Y1.t()))
        z2 = self.gate_activation(torch.matmul(x, self.Y2.t()))
        z3 = self.gate_activation(torch.matmul(x, self.Y3.t()))

        # Fused output: sum of gated modality representations.
        return z1 * h1 + z2 * h2 + z3 * h3