In [1]:
import pandas as pd
import torch
import torch.nn as nn
import joblib
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from preprocessing.loader import ResultsLoader, TextLoader, AudioLoader, FaceLoader
from models.audio_rnn import AudioRNN
from models.face_strnn import FaceSTRNN, SpatialAttention, TemporalAttention
from utils import training as train


# Load and preprocess data
def prepare_data(percentage, random_state, agg_func='mean'):
    """Load every modality and merge them into a single frame keyed by ``ID``.

    Audio and face frames are merged on ``['ID', 'timestamp']`` and the time
    axis is then collapsed to one row per ``ID`` with ``agg_func``. The result
    is inner-merged with the text and results frames on ``ID``.

    Args:
        percentage: Forwarded unchanged to every loader's ``get_data``.
        random_state: Forwarded unchanged to every loader's ``get_data``.
        agg_func: Aggregation applied to every audio/face feature column when
            collapsing the time axis. It accepts anything that
            ``DataFrameGroupBy.agg`` accepts per column, for example ``'mean'``
            (the default and original behaviour), ``'median'``, ``'max'``, a
            callable, or a list such as ``['mean', 'std']``. With a list, the
            output columns are named ``<column>_<func>``.

    Returns:
        pandas.DataFrame containing the text columns, the aggregated audio/face
        features and the result columns. Only IDs present in every modality
        are kept, because all merges are inner merges.

    Raises:
        ValueError: if the merged audio/face frame has no feature columns
            besides ``ID`` and ``timestamp``.
    """
    # Initialize loaders
    results_loader = ResultsLoader()
    text_loader = TextLoader()
    audio_loader = AudioLoader()
    face_loader = FaceLoader()

    # NOTE(review): each loader samples independently with the same
    # percentage/random_state. If sampling is per-row rather than per-ID, the
    # inner merges below may silently drop most rows. Confirm in
    # preprocessing.loader.
    df_result = results_loader.get_data(percentage=percentage, random_state=random_state)
    df_text = text_loader.get_data(percentage=percentage, random_state=random_state)
    df_audio = audio_loader.get_data(percentage=percentage, random_state=random_state)
    df_face = face_loader.get_data(percentage=percentage, random_state=random_state)

    # Reset index for time series data to make ID and timestamp regular columns
    df_audio = df_audio.reset_index()
    df_face = df_face.reset_index()

    # Inner merge: keep only (ID, timestamp) pairs present in both modalities.
    df_timeseries = pd.merge(df_audio, df_face, on=['ID', 'timestamp'])

    feature_cols = [col for col in df_timeseries.columns if col not in ('ID', 'timestamp')]
    if not feature_cols:
        raise ValueError(
            "No audio/face feature columns to aggregate after merging on ['ID', 'timestamp']"
        )

    # Collapse the time axis to one row (sequence-level features) per ID.
    df_timeseries_grouped = df_timeseries.groupby('ID').agg(
        {col: agg_func for col in feature_cols}
    )
    if isinstance(df_timeseries_grouped.columns, pd.MultiIndex):
        # A list of aggregations produces (column, func) column pairs. Flatten
        # them so the frame can be merged with the flat-column text/result frames.
        df_timeseries_grouped.columns = [
            '_'.join(map(str, key)) for key in df_timeseries_grouped.columns
        ]
    df_timeseries_grouped = df_timeseries_grouped.reset_index()

    # Merge with non-time series data (text and results)
    df = pd.merge(df_text, df_timeseries_grouped, on='ID')
    df = pd.merge(df, df_result, on='ID')

    return df


# Load individual models and their preprocessors
def load_models(device=None,
                text_model_path='text_model.joblib',
                audio_model_path='audio_model.pth',
                face_model_path='face_model.pth'):
    """Load the pretrained text, audio and face models plus the audio/face scalers.

    Args:
        device: Device handed to ``train.load_model`` for the audio and face
            models. When None (the default), CUDA is used if available,
            otherwise CPU. Pass it explicitly to share one device object with
            the rest of the notebook.
        text_model_path: Path of the joblib-serialised text model.
        audio_model_path: Checkpoint path for the ``AudioRNN`` model.
        face_model_path: Checkpoint path for the ``FaceSTRNN`` model.

    Returns:
        Tuple ``(text_model, audio_model, face_model, audio_scaler, face_scaler)``.
    """
    # Device configuration
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load text model.
    # SECURITY: joblib.load unpickles the file and can execute arbitrary code,
    # so only load model files from a trusted source.
    text_model = joblib.load(text_model_path)

    # train.load_model returns (model, scaler). Presumably the scaler is the
    # one fitted at training time and must be applied to inputs before
    # inference. TODO confirm in utils.training, including whether it calls
    # model.eval().
    audio_model, audio_scaler = train.load_model(AudioRNN, audio_model_path, device)
    face_model, face_scaler = train.load_model(FaceSTRNN, face_model_path, device)

    return text_model, audio_model, face_model, audio_scaler, face_scaler


In [None]:

# Main execution
if __name__ == "__main__":
    # Sampling configuration: fraction of the dataset to load and the seed
    # forwarded to every loader.
    percentage = 0.02
    random_state = 42

    # Seed torch so any freshly initialised (e.g. fusion) layers are reproducible.
    torch.manual_seed(random_state)

    # Defined at notebook scope because the training cell passes
    # `device=device`; load_models() otherwise keeps its device local.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load individual models
    text_model, audio_model, face_model, audio_scaler, face_scaler = load_models()

    # Create multimodal model.
    # NOTE(review): MultimodalFusion is neither imported nor defined in this
    # notebook, so this line raises NameError on a fresh kernel. Add its
    # import to the imports cell.
    multimodal_model = MultimodalFusion(text_model, audio_model, face_model)

    # Prepare data
    df = prepare_data(percentage=percentage, random_state=random_state)

    # Create data loaders
    # TODO: Implementation here...

    # Train multimodal model
    # TODO: Implementation here...

In [None]:
# EXAMPLE
# An example of how the main block could implement multimodal training.
# It is a template: some of the names it uses are not yet defined in this notebook.

# Example usage in the main execution block:
if __name__ == "__main__":
    # ... existing code ...
    # NOTE(review): template cell. MultimodalDataset, train_multimodal and the
    # X_train_*/X_val_*/y_* splits are not defined anywhere in this notebook.
    # `plt` also needs `import matplotlib.pyplot as plt` in the imports cell.
    # `multimodal_model` comes from the main-execution cell.

    # Training hyperparameters, kept in one place for tuning.
    BATCH_SIZE = 32
    LEARNING_RATE = 0.001
    WEIGHT_DECAY = 0.01
    N_EPOCHS = 50

    # `device` is local to load_models(), so define it here for train_multimodal.
    # NOTE(review): confirm whether train_multimodal moves the model to `device`.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create data loaders; only the training split is shuffled.
    train_dataset = MultimodalDataset(X_train_text, X_train_audio, X_train_face, y_train)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_dataset = MultimodalDataset(X_val_text, X_val_audio, X_val_face, y_val)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

    # Initialize criterion, optimizer and scheduler
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(
        multimodal_model.parameters(),
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY
    )
    # `verbose=True` is deprecated for PyTorch LR schedulers and was dropped.
    # Read optimizer.param_groups[0]['lr'] to report learning-rate reductions.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=3
    )

    # Train multimodal model
    train_losses, val_losses = train_multimodal(
        model=multimodal_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        n_epochs=N_EPOCHS,
        device=device
    )

    # Plot training curves
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(train_losses, label='Training Loss')
    ax.plot(val_losses, label='Validation Loss')
    ax.set(title='Training and Validation Loss', xlabel='Epoch', ylabel='Loss')
    ax.legend()
    ax.grid(True)
    plt.show()