# In [None]:
import joblib
import pandas as pd
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler

from preprocessing.loader import ResultsLoader, TextLoader, AudioLoader, FaceLoader

# Load individual models and their preprocessors
def load_models(text_path='text_model.joblib', audio_path='audio_model.pth',
                face_path='face_model.pth', *, map_location=None):
    """Load the pre-trained text, audio and face models.

    Args:
        text_path: joblib file holding the text model.
        audio_path: torch checkpoint used to rebuild the ``AudioRNN`` model.
        face_path: torch checkpoint used to rebuild the ``STRNN`` model.
        map_location: forwarded to ``torch.load``. Pass e.g. ``'cpu'`` to load
            checkpoints that were saved on a GPU onto a CPU-only machine.
            ``None`` keeps torch's default behaviour.

    Returns:
        ``(text_model, audio_model, face_model, audio_scaler, face_scaler)``.
        The two torch models are returned in eval mode. ``audio_scaler`` and
        ``face_scaler`` are the raw ``'scaler_state_dict'`` entries stored in
        the checkpoints, not fitted scaler objects.
    """
    # NOTE(review): joblib.load (and torch.load, depending on torch version /
    # weights_only) unpickle their input -- only load trusted files.
    text_model = joblib.load(text_path)

    audio_model, audio_scaler = _load_checkpointed_model(AudioRNN, audio_path, map_location)
    face_model, face_scaler = _load_checkpointed_model(STRNN, face_path, map_location)

    return text_model, audio_model, face_model, audio_scaler, face_scaler


def _load_checkpointed_model(model_cls, path, map_location=None):
    """Rebuild ``model_cls`` from the checkpoint at ``path``; return ``(model, scaler_state)``."""
    checkpoint = torch.load(path, map_location=map_location)
    # The checkpoint stores the constructor arguments next to the weights.
    model = model_cls(
        input_size=checkpoint['input_size'],
        **checkpoint['best_params']
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    # Used as a frozen feature extractor downstream: disable dropout and
    # stop any batch-norm running-statistics updates.
    model.eval()
    return model, checkpoint['scaler_state_dict']

# Define multimodal fusion model
class MultimodalFusion(nn.Module):
    """Late-fusion classifier over text, audio and face representations.

    The three modality representations are concatenated along dim 1 and fed
    to a small MLP (``fusion_layers``). The audio and face encoders are
    frozen: their parameters have ``requires_grad=False`` and they are kept
    in eval mode even when this module is put in training mode. As a result,
    only the fusion head is trained.

    Args:
        text_model: object exposing ``transform(text_input)`` that returns a
            2-D array-like: a dense ndarray, a scipy sparse matrix (e.g. a
            fitted TF-IDF vectoriser's output) or a tensor.
        audio_model: ``nn.Module`` whose forward returns a 3-tuple. The first
            element is used as the audio embedding.
        face_model: same contract as ``audio_model``, applied to face input.
        fusion_input_dims: optional ``(text_dim, audio_dim, face_dim)`` widths
            of the three embeddings. Defaults to 256 each, matching the
            original hard-coded size.
        num_classes: number of output logits (default 2).
        dropout: dropout probability inside the fusion MLP (default 0.3).
    """

    def __init__(self, text_model, audio_model, face_model,
                 fusion_input_dims=None, num_classes=2, dropout=0.3):
        super().__init__()
        self.text_model = text_model
        self.audio_model = audio_model
        self.face_model = face_model

        # Freeze the neural encoders. The text model is not an nn.Module, so
        # it has no torch parameters to freeze.
        for model in (self.audio_model, self.face_model):
            for param in model.parameters():
                param.requires_grad = False
            model.eval()

        # Legacy per-modality default width, kept for backward compatibility.
        self.fusion_input_size = 256
        if fusion_input_dims is None:
            fusion_input_dims = (self.fusion_input_size,) * 3
        fusion_input_dims = tuple(fusion_input_dims)
        if len(fusion_input_dims) != 3:
            raise ValueError(
                "fusion_input_dims must be (text_dim, audio_dim, face_dim), "
                f"got {fusion_input_dims!r}"
            )
        self.fusion_input_dims = fusion_input_dims

        self.fusion_layers = nn.Sequential(
            nn.Linear(sum(self.fusion_input_dims), 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

    def train(self, mode=True):
        """Set training mode on this module while keeping the frozen encoders in eval mode."""
        super().train(mode)
        # nn.Module.train() recurses into every submodule. Undo that for the
        # frozen encoders so their dropout stays off and any batch-norm
        # running statistics are not updated.
        self.audio_model.eval()
        self.face_model.eval()
        return self

    def forward(self, text_input, audio_input, face_input):
        """Return fusion logits for a batch.

        Args:
            text_input: whatever ``text_model.transform`` accepts (e.g. a
                list of strings).
            audio_input: input tensor for ``audio_model``.
            face_input: input tensor for ``face_model``.

        Returns:
            Output of ``fusion_layers`` applied to the concatenated
            embeddings. Presumably ``(batch, num_classes)`` when every
            embedding is 2-D -- confirm against the encoders.
        """
        audio_output, _, _ = self.audio_model(audio_input)
        face_output, _, _ = self.face_model(face_input)
        text_output = self._text_features(text_input, like=audio_output)

        combined = torch.cat((text_output, audio_output, face_output), dim=1)
        return self.fusion_layers(combined)

    def _text_features(self, text_input, like):
        """Run ``text_model.transform`` and return a dense tensor matching ``like``'s dtype/device."""
        features = self.text_model.transform(text_input)
        # sklearn text transformers (e.g. TF-IDF) return scipy sparse
        # matrices, which torch.cat cannot consume. Densify them first.
        if hasattr(features, "toarray"):
            features = features.toarray()
        return torch.as_tensor(features, dtype=like.dtype, device=like.device)

# Load and preprocess data
def prepare_data(percentage, random_state, agg='mean'):
    """Load every modality and merge them into a single table keyed on ``'ID'``.

    Args:
        percentage: forwarded to every loader's ``get_data``. Its exact
            meaning (fraction vs. percent) is defined by the loaders.
        random_state: forwarded to every loader's ``get_data`` for
            reproducible sampling.
        agg: aggregation applied per ID to each audio/face time-series feature
            column. Accepts anything ``DataFrameGroupBy.agg`` accepts that
            yields one value per column, e.g. ``'mean'``, ``'median'`` or
            ``'max'``. Defaults to ``'mean'``.

    Returns:
        DataFrame holding the text columns, the per-ID aggregated audio/face
        features and the result columns, inner-joined on ``'ID'``.
    """
    results_loader = ResultsLoader()
    text_loader = TextLoader()
    audio_loader = AudioLoader()
    face_loader = FaceLoader()

    # NOTE(review): each loader samples independently with the same
    # arguments. Unless they sample consistently by ID, the inner joins in
    # _merge_modalities may keep only a small overlap of IDs -- confirm.
    sample_args = dict(percentage=percentage, random_state=random_state)
    df_result = results_loader.get_data(**sample_args)
    df_text = text_loader.get_data(**sample_args)
    df_audio = audio_loader.get_data(**sample_args)
    df_face = face_loader.get_data(**sample_args)

    return _merge_modalities(df_result, df_text, df_audio, df_face, agg=agg)


def _merge_modalities(df_result, df_text, df_audio, df_face, agg='mean'):
    """Aggregate the audio/face time series per ID and join them with the text and result tables."""
    # presumably 'ID' and 'timestamp' live in the index of the time-series
    # frames (the original code relied on this too); turn them into columns.
    df_audio = df_audio.reset_index()
    df_face = df_face.reset_index()

    # Inner join keeps only frames present in both modalities.
    # NOTE(review): if audio and face use different sampling rates or
    # timestamps, most frames may be dropped here -- confirm.
    # Explicit suffixes make colliding feature names self-describing.
    df_timeseries = pd.merge(
        df_audio, df_face, on=['ID', 'timestamp'], suffixes=('_audio', '_face')
    )

    feature_cols = [c for c in df_timeseries.columns if c not in ('ID', 'timestamp')]
    df_timeseries_grouped = (
        df_timeseries.groupby('ID')[feature_cols].agg(agg).reset_index()
    )

    df = pd.merge(df_text, df_timeseries_grouped, on='ID')
    return pd.merge(df, df_result, on='ID')

# Training function
def train_multimodal(model, train_loader, val_loader, criterion, optimizer, n_epochs, device):
    """Train a ``MultimodalFusion`` model (not implemented yet).

    Intended contract, per the signature: train ``model`` on ``train_loader``
    for ``n_epochs`` epochs with ``criterion`` and ``optimizer`` on
    ``device``, evaluating on ``val_loader``. The batch layout (how text,
    audio, face inputs and labels are packed) has not been decided yet.

    Raises:
        NotImplementedError: always, until the loop is written. Raising
            instead of silently returning ``None`` prevents callers from
            believing the model was trained.
    """
    # TODO: implement, mirroring the unimodal training functions but feeding
    # (text_input, audio_input, face_input) to model(...).
    raise NotImplementedError("train_multimodal is not implemented yet")

# Main execution
if __name__ == "__main__":
    # Restore the three pre-trained unimodal models. The two scaler states
    # are returned as well, but nothing below consumes them yet.
    (
        text_model,
        audio_model,
        face_model,
        audio_scaler,
        face_scaler,
    ) = load_models()

    # Wrap them in the fusion network; its constructor freezes the audio
    # and face encoders.
    multimodal_model = MultimodalFusion(text_model, audio_model, face_model)

    # Small, reproducible sample of the data (percentage is interpreted by
    # the loaders; random_state fixes the sampling seed).
    percentage, random_state = 0.02, 42
    df = prepare_data(percentage=percentage, random_state=random_state)

    # TODO: build train/validation DataLoaders from `df`.

    # TODO: train `multimodal_model` with train_multimodal(...).