# Multimodal Depression Detection Model

In [None]:
from typing import Dict, List, Tuple

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import ParameterGrid, TimeSeriesSplit
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

from models.audio_rnn import AudioRNN
from models.face_strnn import FaceSTRNN
from models.multimodal_fusion import MultimodalFusion
from preprocessing.loader_audio import AudioLoader
from preprocessing.loader_face import FaceLoader
from preprocessing.loader_results import ResultsLoader
from preprocessing.loader_text import TextLoader
from training.trainer_multimodal_fusion import MultimodalFusionTrainer

# Constants
RANDOM_STATE = 42
# Forwarded as `percentage=` to every loader in prepare_data().
# NOTE(review): the value 0.05 suggests a fraction (5 %) rather than a
# percentage - confirm against the loader implementations.
DATA_PERCENTAGE = 0.05
BATCH_SIZE = 32
N_EPOCHS = 50
FIGURE_SIZE = (15, 8)

# Hyperparameter grid for model tuning
# NOTE(review): the training cell in this notebook only reads 'learning_rate'
# and 'weight_decay'; 'dropout' is enumerated by the grid but never consumed.
PARAM_GRID = {
    'learning_rate': [0.001, 0.0001],
    'weight_decay': [0.01, 0.001],
    'dropout': [0.2, 0.3],
}

# Device configuration: prefer CUDA, then Apple-silicon MPS, then CPU.
# `torch.backends.mps` does not exist in older torch releases, hence the
# getattr guard instead of a direct attribute access.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

## Data Preparation

In [None]:
from typing import Any
from utils.pca_utils import load_and_transform_pca


def prepare_data(
    percentage: float = DATA_PERCENTAGE, random_state: int = RANDOM_STATE
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Load all four data sources and PCA-transform the audio and face features.

    Args:
        percentage: Forwarded unchanged to every loader's ``get_data``.
        random_state: Forwarded unchanged to every loader's ``get_data``.

    Returns:
        ``(df_text, df_audio_pca, df_face_pca, df_result)`` - the text frame
        as loaded, the PCA-transformed audio and face frames, and the results
        frame as loaded.
    """
    # Every loader receives the same sampling arguments.
    sampling = {"percentage": percentage, "random_state": random_state}
    # Audio and face loaders additionally receive the same two "10s" settings
    # (presumably resampling frequency and rolling-window size - TODO confirm
    # in the loader implementations).
    windowing = {"ds_freq": "10s", "rw_size": "10s"}

    # Build all loaders first, then query them.
    labels_source = ResultsLoader()
    text_source = TextLoader()
    audio_source = AudioLoader()
    face_source = FaceLoader()

    df_result = labels_source.get_data(**sampling)
    df_text = text_source.get_data(**sampling)
    raw_audio = audio_source.get_data(**sampling, **windowing)
    raw_face = face_source.get_data(**sampling, **windowing)

    # Text features are passed through untouched (no preprocessing here yet).

    # Audio: a single stored PCA.
    audio_pca_paths = ["models/pca_audio.pkl"]
    df_audio_pca = load_and_transform_pca(raw_audio, audio_pca_paths)

    # Face: one stored PCA per feature family (action units, gaze, pose).
    face_pca_paths = [
        "models/pca_face_action_units.pkl",
        "models/pca_face_gaze.pkl",
        "models/pca_face_pose.pkl",
    ]
    df_face_pca = load_and_transform_pca(raw_face, face_pca_paths)

    return df_text, df_audio_pca, df_face_pca, df_result


def load_models() -> Tuple[Any, nn.Module, nn.Module, StandardScaler, StandardScaler]:
    """Load the pre-trained unimodal models and the scalers stored with them.

    All paths are relative to the current working directory.

    Returns:
        ``(text_model, audio_model, face_model, audio_scaler, face_scaler)``.
    """
    # joblib files are pickle-based: only load artifacts you trust.
    text_model = joblib.load("text_model.joblib")

    # Imported here, after the text model has been read, exactly as before.
    from training.trainer import load_model as restore_checkpoint

    # Each checkpoint yields a (model, scaler) pair.
    audio_bundle = restore_checkpoint(AudioRNN, "audio_model.pth", DEVICE)
    face_bundle = restore_checkpoint(FaceSTRNN, "face_model.pth", DEVICE)

    audio_model, audio_scaler = audio_bundle
    face_model, face_scaler = face_bundle

    return text_model, audio_model, face_model, audio_scaler, face_scaler


# Load the sampled modalities (audio/face already PCA-transformed) and labels
df_text, df_audio, df_face, df_result = prepare_data()

# The pre-trained models are loaded later by the training and evaluation cells
# text_model, audio_model, face_model, audio_scaler, face_scaler = load_models()

# Preview the first rows of every frame
for heading, frame in (
    ("Text Data:", df_text),
    ("\nAudio Data:", df_audio),
    ("\nFace Data:", df_face),
    ("\nResults Data:", df_result),
):
    print(heading)
    display(frame.head())

## Data Splitting

In [None]:
# This function aligns and merges the three modalities (text, audio, face) by a common set of keys (ID and time window).
# Audio and face are both time series data, so they are expected to have features extracted per time window (e.g., every 10s).
# Text is non-time series, but for fusion, we align each text sample to the same time window as audio/face (e.g., by transcript segment or by aggregating text features per window).
# The merge ensures that each row in the final dataset corresponds to a single sample with all three modalities for the same subject and time window.
# After merging, the function performs a stratified train/val/test split, so that all splits are aligned across modalities.
# This ensures that each sample in the split contains the correct text, audio, and face features for the same instance.

from sklearn.model_selection import train_test_split


def prepare_aligned_data_splits(
    df_text: pd.DataFrame,
    df_audio: pd.DataFrame,
    df_face: pd.DataFrame,
    df_result: pd.DataFrame,
    test_size: float = 0.2,
    val_size: float = 0.1,
    random_state: int = RANDOM_STATE
):
    """Merge the modalities and split them into train/val/test by session ID.

    The three feature frames are inner-joined on ``ID`` (plus ``window`` when
    the audio frame has such a column), the labels are joined on ``ID``, and
    rows containing any missing value are dropped.

    The split is performed on the *unique IDs*, not on individual rows: every
    row of a given ID ends up in exactly one of train/val/test. Splitting the
    rows directly would place windows (and the transcript) of the same session
    in several splits and leak information into validation and test.

    Args:
        df_text: Text frame; must contain ``ID`` and a ``TRANSCRIPT`` column.
        df_audio: Audio feature frame keyed by ``ID`` (and optionally ``window``).
        df_face: Face feature frame keyed like ``df_audio``.
        df_result: Label frame containing ``ID`` and ``PHQ_Binary``.
        test_size: Fraction of the IDs (not rows) assigned to the test split.
        val_size: Fraction of the IDs (not rows) assigned to the validation
            split, expressed relative to the full set of IDs.
        random_state: Seed forwarded to ``train_test_split``.

    Returns:
        ``{'train': ..., 'val': ..., 'test': ...}`` where each value is a dict
        with ``'text'`` (list), ``'audio'`` and ``'face'`` (2-D arrays, one row
        per sample) and ``'label'`` (1-D array).

    Raises:
        ValueError: If no rows survive the merge, or if there are too few IDs
            to produce the requested splits.
    """
    import warnings

    # Merge on ID and time window (adjust 'window' to your actual time window column if needed)
    merge_keys = ['ID', 'window'] if 'window' in df_audio.columns else ['ID']
    df_all = df_text.merge(df_audio, on=merge_keys, suffixes=('_text', '_audio'))
    df_all = df_all.merge(df_face, on=merge_keys, suffixes=('', '_face'))
    df_all = df_all.merge(df_result, on='ID')

    # Drop rows with missing values (optional, or handle differently)
    df_all = df_all.dropna().reset_index(drop=True)
    if df_all.empty:
        raise ValueError("No rows left after merging the modalities and dropping missing values")

    def merged_name(column, suffix):
        # pandas only appends a suffix to column names that collide during a
        # merge, so a column is found either under its suffixed or plain name.
        suffixed = f'{column}{suffix}'
        return suffixed if suffixed in df_all.columns else column

    # Prepare features and target
    audio_cols = [merged_name(col, '_audio') for col in df_audio.columns if col not in merge_keys]
    face_cols = [merged_name(col, '_face') for col in df_face.columns if col not in merge_keys]
    text_features = df_all[merged_name('TRANSCRIPT', '_text')].to_numpy()
    audio_features = df_all[audio_cols].to_numpy()
    face_features = df_all[face_cols].to_numpy()
    labels = df_all['PHQ_Binary'].to_numpy()

    # One label per session: PHQ_Binary comes from df_result, joined on ID only.
    session_labels = df_all.groupby('ID')['PHQ_Binary'].first()

    def split_sessions(ids, ys, fraction):
        # Stratify by label when possible; tiny samples may not have enough
        # IDs per class, in which case fall back to a plain random split.
        try:
            return train_test_split(
                ids, ys, test_size=fraction, random_state=random_state, stratify=ys
            )
        except ValueError:
            warnings.warn(
                "Stratified split by ID is not possible with the available "
                "sessions; falling back to an unstratified split.",
                RuntimeWarning,
                stacklevel=2,
            )
            return train_test_split(
                ids, ys, test_size=fraction, random_state=random_state
            )

    # First split the IDs into train+val and test
    trainval_ids, test_ids, trainval_labels, _ = split_sessions(
        session_labels.index.to_numpy(), session_labels.to_numpy(), test_size
    )
    # Then split train+val into train and val
    val_relative_size = val_size / (1 - test_size)
    train_ids, val_ids, _, _ = split_sessions(trainval_ids, trainval_labels, val_relative_size)

    def unpack_split(split_ids):
        # Select every row whose session belongs to this split.
        rows = df_all['ID'].isin(split_ids).to_numpy()
        return {
            'text': text_features[rows].tolist(),
            'audio': audio_features[rows],
            'face': face_features[rows],
            'label': labels[rows]
        }

    return {
        'train': unpack_split(train_ids),
        'val': unpack_split(val_ids),
        'test': unpack_split(test_ids)
    }
# Build the aligned train/val/test splits used by all following cells
splits = prepare_aligned_data_splits(
    df_text=df_text,
    df_audio=df_audio,
    df_face=df_face,
    df_result=df_result,
)


## Model Training

In [None]:
from training.trainer import save_model

def create_data_loaders(
    splits: dict,
    text_vectorizer,
    batch_size: int = BATCH_SIZE
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Turn the ``'train'``/``'val'``/``'test'`` splits into DataLoaders.

    Args:
        splits: Mapping with keys ``'train'``, ``'val'`` and ``'test'``; each
            value holds ``'text'``, ``'audio'``, ``'face'`` and ``'label'``.
        text_vectorizer: Object whose ``transform(texts)`` result offers
            ``.toarray()``; used to turn the raw text into numeric features.
        batch_size: Batch size used by all three loaders.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles. Each batch is ``(text, audio, face, label)``.
    """
    def as_float_tensor(values):
        return torch.tensor(np.array(values), dtype=torch.float32)

    def as_sequence_tensor(values):
        # Ensure audio and face are 3D for LSTM: [batch, seq_len, input_dim];
        # 2-D input gets a sequence axis of length 1.
        tensor = as_float_tensor(values)
        return tensor.unsqueeze(1) if tensor.dim() == 2 else tensor

    def build_dataset(part):
        # The textual data is made numeric with the supplied vectorizer.
        text = as_float_tensor(text_vectorizer.transform(part['text']).toarray())
        audio = as_sequence_tensor(part['audio'])
        face = as_sequence_tensor(part['face'])
        target = torch.tensor(part['label'], dtype=torch.long)
        return TensorDataset(text, audio, face, target)

    train_dataset = build_dataset(splits['train'])
    val_dataset = build_dataset(splits['val'])
    test_dataset = build_dataset(splits['test'])

    return (
        DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
        DataLoader(val_dataset, batch_size=batch_size),
        DataLoader(test_dataset, batch_size=batch_size),
    )


def train_model_with_grid_search(splits: dict, param_grid: Dict, n_epochs: int = N_EPOCHS) -> Tuple[Dict, List[Dict]]:
    """Grid-search the fusion model and checkpoint the best configuration.

    For every combination in ``param_grid`` a fresh ``MultimodalFusion`` model
    is trained for ``n_epochs``. Configurations are ranked by the validation
    loss of their *last* epoch; whenever a configuration improves on the best
    one so far it is written to ``checkpoints/multimodal_best_model.pth``.

    Args:
        splits: Train/val/test splits as consumed by ``create_data_loaders``.
        param_grid: Grid passed to ``ParameterGrid``; ``'learning_rate'`` and
            ``'weight_decay'`` are read from each combination.
        n_epochs: Number of epochs per configuration.

    Returns:
        ``(best_params, results)``: the best combination (``None`` if no run
        produced a comparable loss, e.g. all NaN) and one record per
        combination with ``'params'``, ``'final_val_loss'``,
        ``'train_losses'`` and ``'val_losses'``.
    """
    import copy
    from pathlib import Path

    # Load individual models
    text_model, audio_model, face_model, audio_scaler, face_scaler = load_models()
    text_vectorizer = text_model.named_steps["tfidf"]
    text_feature_dim = len(text_vectorizer.get_feature_names_out())

    # Create data loaders
    train_loader, val_loader, _ = create_data_loaders(splits, text_vectorizer, batch_size=BATCH_SIZE)

    # Checkpoint metadata. The previous code referenced the undefined names
    # `scaler` and `X_train` here and crashed on the first improvement.
    save_path = 'checkpoints/multimodal_best_model.pth'
    # NOTE(review): save_model's expected scaler type is not visible from this
    # notebook; both modality scalers are stored together - confirm that
    # save_model/load_model can round-trip a dict.
    scalers = {'audio': audio_scaler, 'face': face_scaler}
    input_size = {
        'audio': np.asarray(splits['train']['audio']).shape[-1],
        'face': np.asarray(splits['train']['face']).shape[-1],
    }
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    # Initialize tracking variables
    best_val_loss = float('inf')
    best_params = None
    results = []

    # Grid search
    # NOTE(review): a 'dropout' entry in the grid is enumerated but not passed
    # to the model, so it only duplicates runs - wire it in or drop it.
    for params in tqdm(ParameterGrid(param_grid)):
        # Give every configuration its own copy of the pre-trained encoders so
        # that weights updated during one run cannot carry over into the next.
        model = MultimodalFusion(
            text_feature_dim,
            copy.deepcopy(audio_model),
            copy.deepcopy(face_model)
        ).to(DEVICE)

        # Training setup
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=params['learning_rate'],
            weight_decay=params['weight_decay']
        )
        # NOTE(review): `verbose` is deprecated in recent PyTorch releases -
        # confirm against the installed version.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.1,
            patience=3,
            verbose=True
        )

        # Initialize trainer
        trainer = MultimodalFusionTrainer(
            model=model,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            device=DEVICE
        )

        # Train model
        train_losses, val_losses = trainer.train(
            train_loader=train_loader,
            val_loader=val_loader,
            n_epochs=n_epochs
        )

        # Record results
        final_val_loss = val_losses[-1]
        results.append({
            'params': params,
            'final_val_loss': final_val_loss,
            'train_losses': train_losses,
            'val_losses': val_losses
        })

        # Update best parameters and persist the new best model
        if final_val_loss < best_val_loss:
            best_val_loss = final_val_loss
            best_params = params

            save_model(
                model=model,
                scaler=scalers,
                input_size=input_size,
                best_params=params,
                save_path=save_path
            )

    return best_params, results


# Run the hyperparameter search over PARAM_GRID
best_params, results = train_model_with_grid_search(splits, PARAM_GRID)

# Report the winning configuration
print("\nBest parameters:")
for param, value in best_params.items():
    print(f"{param}: {value}")
# print(f"Best validation loss: {best_val_loss:.4f}")

# Loss curves of the run with the lowest final validation loss
plt.figure(figsize=FIGURE_SIZE)
best_result = min(results, key=lambda run: run['final_val_loss'])
for history_key, curve_label in (
    ('train_losses', 'Training Loss'),
    ('val_losses', 'Validation Loss'),
):
    plt.plot(best_result[history_key], label=curve_label)
plt.title('Training and Validation Loss (Best Model)')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()

## Model Evaluation

In [None]:
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay


# def evaluate_model(model: nn.Module, test_loader: DataLoader, device: torch.device) -> Tuple[np.ndarray, np.ndarray]:
#     #Evaluate the model on the test set.
#     model.eval()
#     all_preds = []
#     all_labels = []

#     with torch.no_grad():
#         for X, y in test_loader:
#             X, y = X.to(device), y.to(device)
#             outputs = model(X)
#             _, preds = torch.max(outputs, 1)
#             all_preds.extend(preds.cpu().numpy())
#             all_labels.extend(y.cpu().numpy())

#     return np.array(all_labels), np.array(all_preds)


# # Create test loader
# _, _, test_loader = create_data_loaders(X_train, X_val, X_test, y_train, y_val, y_test)

# # Load individual models
# text_model, audio_model, face_model, audio_scaler, face_scaler = load_models()

# # Initialize best model
# best_model = MultimodalFusion(
#     text_model,
#     audio_model,
#     face_model
# ).to(DEVICE)

# # Evaluate model
# y_true, y_pred = evaluate_model(best_model, test_loader, DEVICE)

# # Print classification report
# print("Classification Report:")
# print(classification_report(y_true, y_pred))

# # Plot confusion matrix
# plt.figure(figsize=FIGURE_SIZE)
# class_labels = [0, 1]
# cm = confusion_matrix(y_true, y_pred, labels=class_labels)
# disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_labels)
# disp.plot(cmap=plt.cm.Blues)
# plt.title('Confusion Matrix')
# plt.show()

from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from training.trainer import load_model

def evaluate_model(model: nn.Module, test_loader: DataLoader, device: torch.device) -> Tuple[np.ndarray, np.ndarray]:
    """Run the model over a loader and collect true and predicted labels.

    The model is switched to eval mode and executed without gradient
    tracking. Summary statistics (mean softmax probability and number of
    predictions per class) are printed.

    Args:
        model: Module called as ``model(text, audio, face)`` that returns
            per-class scores of shape ``[batch, n_classes]``.
        test_loader: Loader yielding ``(text, audio, face, labels)`` batches.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``(y_true, y_pred)`` as 1-D arrays; both are empty when the loader
        yields no batches.
    """
    model.eval()
    label_batches = []
    pred_batches = []
    prob_batches = []

    with torch.no_grad():
        for text_features, audio_features, face_features, labels in test_loader:
            # Forward pass with all inputs on the target device
            outputs = model(
                text_features.to(device),
                audio_features.to(device),
                face_features.to(device),
            )
            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)

            # Collect whole batches; they are concatenated once at the end
            label_batches.append(labels.cpu().numpy())
            pred_batches.append(preds.cpu().numpy())
            prob_batches.append(probs.cpu().numpy())

    print("\nPrediction Statistics:")
    if not label_batches:
        # An empty loader previously crashed when indexing the probabilities.
        print("No samples to evaluate.")
        return np.empty(0, dtype=np.int64), np.empty(0, dtype=np.int64)

    y_true = np.concatenate(label_batches)
    y_pred = np.concatenate(pred_batches)
    probs_array = np.concatenate(prob_batches)

    # Report every class the model outputs (two lines each for binary models)
    n_classes = probs_array.shape[1]
    for cls in range(n_classes):
        print(f"Mean probability for class {cls}: {probs_array[:, cls].mean():.3f}")
    for cls in range(n_classes):
        print(f"Number of predictions for class {cls}: {(y_pred == cls).sum()}")

    return y_true, y_pred

# Load individual models first; the text pipeline supplies the TF-IDF step
# that create_data_loaders needs to vectorize the transcripts
text_model, audio_model, face_model, audio_scaler, face_scaler = load_models()
text_vectorizer = text_model.named_steps["tfidf"]

# Create test loader with all modalities
_, _, test_loader = create_data_loaders(splits, text_vectorizer, batch_size=BATCH_SIZE)

# Load the best multimodal model
best_model, _ = load_model(
    model_class=MultimodalFusion,
    load_path='checkpoints/multimodal_best_model.pth',
    device=DEVICE
)

# Make sure the model is in eval mode
best_model.eval()

# Evaluate model
y_true, y_pred = evaluate_model(best_model, test_loader, DEVICE)

# zero_division=1 silences the undefined-metric warning; note that it reports
# such metrics as 1.0 (e.g. precision of a class that is never predicted)
print("\nClassification Report:")
print(classification_report(y_true, y_pred, zero_division=1))

# Plot confusion matrix on an explicitly created axes. ConfusionMatrixDisplay
# opens its own figure when no `ax` is given, which previously left an empty
# FIGURE_SIZE figure behind and drew the matrix at the default size.
class_labels = [0, 1]
cm = confusion_matrix(y_true, y_pred, labels=class_labels)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_labels)
fig, ax = plt.subplots(figsize=FIGURE_SIZE)
disp.plot(cmap=plt.cm.Blues, ax=ax)
ax.set_title('Confusion Matrix')
plt.show()


## Save Model

In [None]:
# Persist only the weights (state_dict) of the evaluated model
final_weights = best_model.state_dict()
torch.save(final_weights, 'multimodal_model.pth')
print("Model saved successfully!")