In [None]:
import joblib
import pandas as pd
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset

from preprocessing.loader import AudioLoader, FaceLoader, ResultsLoader, TextLoader

# Load individual models and their preprocessors
def load_models(map_location='cpu'):
    """Load the three pretrained unimodal models and the scalers stored with them.

    Args:
        map_location: forwarded to ``torch.load``. Defaults to ``'cpu'`` so that
            checkpoints saved from a GPU can still be opened on a CPU-only machine.
            ``load_state_dict`` copies the weights into freshly built (CPU) modules
            anyway, and the fusion model is moved to the training device later.

    Returns:
        ``(text_model, audio_model, face_model, audio_scaler, face_scaler)``.
        ``text_model`` is whatever was stored in ``text_model.joblib``.
        The two scalers are the raw ``'scaler_state_dict'`` entries of the
        checkpoints. NOTE(review): confirm whether these are fitted scaler
        objects or plain state dicts.
    """
    text_model = joblib.load('text_model.joblib')

    audio_model, audio_scaler = _load_rnn_checkpoint('audio_model.pth', AudioRNN, map_location)
    face_model, face_scaler = _load_rnn_checkpoint('face_model.pth', STRNN, map_location)

    return text_model, audio_model, face_model, audio_scaler, face_scaler


def _load_rnn_checkpoint(path, model_cls, map_location):
    """Rebuild ``model_cls`` from the checkpoint at ``path``; return ``(model, scaler)``."""
    # NOTE(review): torch>=2.6 defaults to weights_only=True. If the checkpoint
    # pickles a fitted sklearn scaler, this call will need weights_only=False,
    # which is only acceptable for trusted files.
    checkpoint = torch.load(path, map_location=map_location)
    model = model_cls(
        input_size=checkpoint['input_size'],
        **checkpoint['best_params']
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    # These encoders are used as frozen feature extractors. Put them in eval mode
    # so that any dropout / batch-norm layers behave deterministically.
    model.eval()
    return model, checkpoint['scaler_state_dict']

# Define multimodal fusion model
class MultimodalFusion(nn.Module):
    """Late-fusion classifier over a TF-IDF text featurizer and frozen audio/face encoders.

    Each modality's features are linearly projected to ``projection_dim``. The three
    projections are concatenated and passed through a small MLP head.

    Args:
        text_model: fitted object exposing ``named_steps['tfidf']`` (presumably an
            sklearn ``Pipeline``). Only that TF-IDF step is used to featurize text,
            and its vocabulary size sets the input width of ``text_projection``.
        audio_model, face_model: ``nn.Module`` encoders exposing ``hidden_size``.
            Their forward returns a 3-tuple whose first element is the embedding.
            The width is assumed to be ``2 * hidden_size`` (bidirectional); TODO confirm.
            Their parameters are frozen.
        projection_dim: width of each per-modality projection (default 256).
        dropout: dropout probability in the fusion head (default 0.3).
        num_classes: number of output logits (default 2).
    """

    def __init__(self, text_model, audio_model, face_model,
                 projection_dim=256, dropout=0.3, num_classes=2):
        super().__init__()
        self.text_model = text_model
        self.audio_model = audio_model
        self.face_model = face_model

        # Freeze the pretrained neural encoders; only the projections and the head train.
        # (The text model is not an nn.Module, so it contributes no parameters.)
        for encoder in self._frozen_encoders():
            for param in encoder.parameters():
                param.requires_grad = False
            encoder.eval()

        # Output width of each modality's features.
        self.text_output_size = len(self.text_model.named_steps['tfidf'].get_feature_names_out())
        self.audio_output_size = self.audio_model.hidden_size * 2  # *2 assumes bidirectional
        self.face_output_size = self.face_model.hidden_size * 2    # *2 assumes bidirectional

        # Projection layers to a common width.
        self.text_projection = nn.Linear(self.text_output_size, projection_dim)
        self.audio_projection = nn.Linear(self.audio_output_size, projection_dim)
        self.face_projection = nn.Linear(self.face_output_size, projection_dim)

        # Fusion head over the concatenated projections.
        self.fusion_layers = nn.Sequential(
            nn.Linear(projection_dim * 3, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes),
        )

    def _frozen_encoders(self):
        """The encoders whose weights are never trained."""
        return (self.audio_model, self.face_model)

    def train(self, mode=True):
        """Set train/eval mode on the trainable parts while keeping frozen encoders in eval.

        ``nn.Module.train`` recurses into every child. Without this override, the
        frozen encoders would run dropout during fusion training. They would also
        update batch-norm running statistics, because buffers change even when
        ``requires_grad`` is False.
        """
        super().train(mode)
        for encoder in self._frozen_encoders():
            encoder.eval()
        return self

    def forward(self, text_input, audio_input, face_input):
        """Return fusion logits of shape ``(batch, num_classes)``.

        Args:
            text_input: raw text batch accepted by the TF-IDF step's ``transform``
                (e.g. a list of strings). It does not need to be a tensor.
            audio_input: input for ``audio_model``.
            face_input: input for ``face_model``.
        """
        # text_projection was sized from the TF-IDF vocabulary, so featurize with
        # that step only. The full pipeline may end in an estimator with no
        # ``transform``, or may change the feature width.
        text_features = self.text_model.named_steps['tfidf'].transform(text_input)
        if hasattr(text_features, 'toarray'):  # scipy sparse matrix -> dense ndarray
            text_features = text_features.toarray()
        # Raw text carries no device, so follow the projection layer's dtype and device.
        target = self.text_projection.weight
        text_output = torch.as_tensor(text_features, dtype=target.dtype, device=target.device)
        text_output = self.text_projection(text_output)

        # The encoders return 3-tuples; only the first element (the embedding) is used.
        audio_output, _, _ = self.audio_model(audio_input)
        audio_output = self.audio_projection(audio_output)

        face_output, _, _ = self.face_model(face_input)
        face_output = self.face_projection(face_output)

        combined = torch.cat((text_output, audio_output, face_output), dim=1)
        return self.fusion_layers(combined)

# Load and preprocess data
def prepare_data(percentage, random_state, agg_func='mean'):
    """Load every modality and merge them into one row per ``ID``.

    Audio and face frames are inner-joined on ``['ID', 'timestamp']`` and collapsed
    to one row per ``ID`` with ``agg_func``. The result is then inner-joined with
    the text and result frames on ``ID``.

    Args:
        percentage: forwarded to every loader's ``get_data``.
        random_state: forwarded to every loader's ``get_data``.
        agg_func: a single aggregation accepted by ``DataFrameGroupBy.agg``
            (a name such as ``'mean'``/``'median'`` or a callable). It is used to
            collapse each ID's time series. Default ``'mean'`` preserves the
            original behaviour.

    Returns:
        ``pd.DataFrame`` with one aggregated feature row per ID that is present in
        every source. IDs missing from any source are dropped by the inner joins.

    NOTE(review): aggregation removes the time dimension, while the audio/face
    encoders are RNNs. Confirm whether per-ID sequences should be kept instead.
    """
    # Initialize loaders
    results_loader = ResultsLoader()
    text_loader = TextLoader()
    audio_loader = AudioLoader()
    face_loader = FaceLoader()

    # Every loader receives the same sampling arguments.
    load_kwargs = dict(percentage=percentage, random_state=random_state)
    df_result = results_loader.get_data(**load_kwargs)
    df_text = text_loader.get_data(**load_kwargs)
    df_audio = audio_loader.get_data(**load_kwargs)
    df_face = face_loader.get_data(**load_kwargs)

    # ID and timestamp sit in the index of the time-series frames; make them columns.
    df_audio = df_audio.reset_index()
    df_face = df_face.reset_index()

    # Align audio and face samples (inner join: unmatched (ID, timestamp) rows are dropped).
    df_timeseries = pd.merge(df_audio, df_face, on=['ID', 'timestamp'])

    # Collapse each ID's sequence into a single row of aggregated features.
    feature_cols = [col for col in df_timeseries.columns if col not in ('ID', 'timestamp')]
    df_timeseries_grouped = (
        df_timeseries.groupby('ID')[feature_cols]
        .agg(agg_func)
        .reset_index()
    )

    # Attach the per-ID text and results.
    df = pd.merge(df_text, df_timeseries_grouped, on='ID')
    df = pd.merge(df, df_result, on='ID')

    return df

# Training function
def train_multimodal(model, train_loader, val_loader, criterion, optimizer, n_epochs, device,
                     scheduler=None, early_stopping_patience=7, max_grad_norm=1.0,
                     checkpoint_path='best_multimodal_model.pth', log_every=5):
    """Train ``model`` with gradient clipping and early stopping on validation loss.

    Each loader must yield ``(text, audio, face, y)`` batches. Tensors are moved to
    ``device``. Non-tensor items, such as a list of raw strings, are passed through
    unchanged.

    Args:
        model: module called as ``model(text, audio, face)``.
        train_loader, val_loader: iterables of ``(text, audio, face, y)`` batches.
        criterion: loss called as ``criterion(outputs, y)``.
        optimizer: optimizer over the model's trainable parameters.
        n_epochs: maximum number of epochs.
        device: device the model and tensor batches are moved to.
        scheduler: optional LR scheduler stepped once per epoch. A
            ``ReduceLROnPlateau`` is stepped with the mean validation loss; any
            other scheduler is stepped with no argument. Previously this was read
            from an undefined global.
        early_stopping_patience: epochs without improvement in validation loss
            before training stops.
        max_grad_norm: gradient-norm clipping threshold.
        checkpoint_path: file that receives the best model/optimizer state.
            ``None`` disables checkpointing.
        log_every: print losses every ``log_every`` epochs; 0/None disables printing.

    Returns:
        ``(train_losses, val_losses)``: per-epoch mean batch losses.

    Raises:
        ValueError: if either loader yields no batches.
    """
    model = model.to(device)
    best_val_loss = float('inf')
    epochs_without_improvement = 0
    train_losses, val_losses = [], []

    for epoch in range(n_epochs):
        # Training phase
        model.train()
        avg_train_loss = _run_epoch(model, train_loader, criterion, device,
                                    optimizer=optimizer, max_grad_norm=max_grad_norm)

        # Validation phase
        model.eval()
        with torch.no_grad():
            avg_val_loss = _run_epoch(model, val_loader, criterion, device)

        # Learning rate scheduling
        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(avg_val_loss)
            else:
                scheduler.step()

        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)

        # Early stopping check: checkpoint on improvement, otherwise count the epoch.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_without_improvement = 0
            if checkpoint_path is not None:
                torch.save({
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'train_loss': avg_train_loss,
                    'val_loss': avg_val_loss,
                    'epoch': epoch,
                }, checkpoint_path)
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= early_stopping_patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

        if log_every and (epoch + 1) % log_every == 0:
            print(f'Epoch [{epoch+1}/{n_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}')

    return train_losses, val_losses


def _to_device(batch, device):
    """Move a tensor to ``device``; return any other object (e.g. a list of strings) unchanged."""
    return batch.to(device) if isinstance(batch, torch.Tensor) else batch


def _run_epoch(model, loader, criterion, device, optimizer=None, max_grad_norm=None):
    """Run one pass over ``loader`` and return the mean batch loss.

    When ``optimizer`` is given, each batch also takes an optimization step with
    optional gradient clipping. Otherwise only the loss is computed.

    Raises:
        ValueError: if ``loader`` yields no batches (the mean would be undefined).
    """
    total_loss = 0.0
    n_batches = 0
    for batch_text, batch_audio, batch_face, batch_y in loader:
        batch_text, batch_audio, batch_face, batch_y = (
            _to_device(item, device) for item in (batch_text, batch_audio, batch_face, batch_y)
        )
        if optimizer is not None:
            optimizer.zero_grad()
        outputs = model(batch_text, batch_audio, batch_face)
        loss = criterion(outputs, batch_y)
        if optimizer is not None:
            loss.backward()
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
            optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    if n_batches == 0:
        raise ValueError('loader yielded no batches')
    return total_loss / n_batches

# Main execution
if __name__ == "__main__":
    # Configuration
    percentage = 0.02   # passed to every loader's get_data (presumably a sampling fraction)
    random_state = 42
    # Seed torch so the randomly initialised projection/fusion layers are reproducible.
    torch.manual_seed(random_state)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load individual models
    text_model, audio_model, face_model, audio_scaler, face_scaler = load_models()

    # Create multimodal model
    multimodal_model = MultimodalFusion(text_model, audio_model, face_model)

    # Prepare data
    df = prepare_data(percentage=percentage, random_state=random_state)

    # Create data loaders
    # TODO: split `df` into train/validation sets and wrap them in DataLoaders.

    # Train multimodal model
    # TODO: call train_multimodal(...) on the loaders above.

In [None]:
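# EXAMPLE
# Sketch of the main block that builds the data loaders and trains the fusion model.
# It assumes the train/validation splits (X_train_*, X_val_*, y_train, y_val) and a
# MultimodalDataset class are already defined; neither is defined above.

# Example usage in main execution block: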
if __name__ == "__main__":
    # ... existing code ...
    # NOTE(review): X_train_*/X_val_*/y_train/y_val, MultimodalDataset and `plt`
    # (matplotlib.pyplot) must be defined/imported before this cell can run.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create data loaders
    batch_size = 32
    train_dataset = MultimodalDataset(X_train_text, X_train_audio, X_train_face, y_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataset = MultimodalDataset(X_val_text, X_val_audio, X_val_face, y_val)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Initialize criterion, optimizer and scheduler
    criterion = nn.CrossEntropyLoss()
    # The audio/face encoders are frozen, so only optimize parameters that still need gradients.
    trainable_params = [p for p in multimodal_model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=0.001,
        weight_decay=0.01
    )
    # `verbose` is not passed: that argument is deprecated in recent PyTorch releases.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.1,
        patience=3,
    )

    # Train multimodal model
    train_losses, val_losses = train_multimodal(
        model=multimodal_model,
        train_loader=train_loader,
        val_loader=val_loader,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        n_epochs=50,
        device=device
    )

    # Plot training curves
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(train_losses, label='Training Loss')
    ax.plot(val_losses, label='Validation Loss')
    ax.set_title('Training and Validation Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True)
    plt.show()