In [None]:
import copy

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

class MultimodalFusionDataset(Dataset):
    """Map-style dataset pairing text, audio and face features with a label.

    The four inputs are converted once, at construction time, into
    ``float32`` tensors that are indexed by *position*. Indexing the raw
    containers lazily is unsafe for pandas objects: ``series[idx]`` is a
    label lookup and ``frame[idx]`` is a column lookup, so a shuffled or
    filtered index would return the wrong row or raise ``KeyError``.

    Args:
        text_features: Per-sample text features (array-like, pandas object
            or tensor); the first axis enumerates samples.
        audio_features: Per-sample audio features, same layout.
        face_features: Per-sample face features, same layout.
        labels: Per-sample targets. A 1-D input is stored as a column so
            each item's label has shape ``(1,)``.

    Raises:
        ValueError: If the four inputs do not contain the same number of
            samples.
    """

    def __init__(self, text_features, audio_features, face_features, labels):
        self.text_features = self._as_float_tensor(text_features)
        self.audio_features = self._as_float_tensor(audio_features)
        self.face_features = self._as_float_tensor(face_features)

        label_tensor = self._as_float_tensor(labels)
        if label_tensor.dim() == 1:
            # Keep the per-item label shape at (1,), as the loss expects.
            label_tensor = label_tensor.unsqueeze(1)
        self.labels = label_tensor

        sizes = {
            'text': len(self.text_features),
            'audio': len(self.audio_features),
            'face': len(self.face_features),
            'labels': len(self.labels),
        }
        if len(set(sizes.values())) > 1:
            raise ValueError(
                f'All inputs must contain the same number of samples, got {sizes}'
            )

    @staticmethod
    def _as_float_tensor(values):
        """Convert an array-like, pandas object or tensor to a float32 tensor."""
        if isinstance(values, torch.Tensor):
            return values.detach().to(torch.float32)

        array = np.asarray(values)
        if array.dtype == object:
            # e.g. a pandas column whose cells are lists/arrays: stack the
            # rows into one dense 2-D array.
            array = np.stack([np.asarray(row, dtype=np.float32) for row in array])
        return torch.as_tensor(array, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``idx``-th sample as a dict of float32 tensors."""
        return {
            'text': self.text_features[idx],
            'audio': self.audio_features[idx],
            'face': self.face_features[idx],
            'label': self.labels[idx]
        }

class MultimodalFusion(nn.Module):
    """Attention-based fusion of text, audio and face features.

    Each modality is projected to ``hidden_dim`` by its own encoder. The
    three encodings are stacked along a new leading axis and passed through
    multi-head self-attention over that axis; the three attended vectors are
    concatenated and fed to a classifier ending in a sigmoid, so the output
    is a value in ``[0, 1]`` per sample.

    Args:
        text_dim: Size of the text feature vector.
        audio_dim: Size of the audio feature vector.
        face_dim: Size of the face feature vector.
        hidden_dim: Width of every modality encoding. Must be divisible by
            ``num_heads`` (enforced by ``nn.MultiheadAttention``).
        num_heads: Number of attention heads.
        dropout: Dropout probability used in the encoders and classifier.
    """

    def __init__(self, text_dim, audio_dim, face_dim, hidden_dim=128,
                 num_heads=4, dropout=0.3):
        super().__init__()

        # Individual modality encoders. Attribute names and construction
        # order are unchanged, so state_dict keys stay the same.
        self.text_encoder = self._make_encoder(text_dim, hidden_dim, dropout)
        self.audio_encoder = self._make_encoder(audio_dim, hidden_dim, dropout)
        self.face_encoder = self._make_encoder(face_dim, hidden_dim, dropout)

        # Attention mechanism for fusion. batch_first is left at its default
        # (False), so inputs are laid out as (sequence, batch, embedding).
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads=num_heads)

        # Final classification layers
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    @staticmethod
    def _make_encoder(in_dim, hidden_dim, dropout):
        """Build the Linear -> ReLU -> Dropout encoder shared by all modalities."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

    def forward(self, text, audio, face):
        """Fuse the three modalities and return the sigmoid output.

        For batched inputs of shape ``(batch, *_dim)`` the result has shape
        ``(batch, 1)``.
        """
        # Stack the encodings on a new leading axis: the three modalities
        # become the "sequence" that self-attention mixes.
        stacked = torch.stack([
            self.text_encoder(text),
            self.audio_encoder(audio),
            self.face_encoder(face),
        ], dim=0)

        # Apply self-attention (query = key = value); weights are discarded.
        attended, _ = self.attention(stacked, stacked, stacked)

        # Concatenate the attended modality vectors along the feature axis.
        fused = torch.cat(attended.unbind(dim=0), dim=-1)

        # Final classification
        return self.classifier(fused)

In [None]:
def _run_fusion_epoch(model, loader, criterion, device, optimizer=None):
    """Run one pass over ``loader`` and return ``(summed_loss, num_batches)``.

    When ``optimizer`` is given the model is put in training mode and its
    weights are updated after every batch; otherwise the model is put in
    evaluation mode and gradients are disabled.
    """
    training = optimizer is not None
    model.train(training)

    total_loss = 0.0
    num_batches = 0
    with torch.set_grad_enabled(training):
        for batch in loader:
            text = batch['text'].to(device)
            audio = batch['audio'].to(device)
            face = batch['face'].to(device)
            labels = batch['label'].to(device)

            if training:
                optimizer.zero_grad()

            outputs = model(text, audio, face)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            num_batches += 1

    return total_loss, num_batches


def train_model(model, train_loader, val_loader, num_epochs=50, learning_rate=0.001):
    """Train ``model`` with BCE loss and return the best-validation weights.

    Args:
        model: Module called as ``model(text, audio, face)`` whose output is
            passed to ``nn.BCELoss`` together with the batch labels.
        train_loader: Iterable of dict batches with keys ``'text'``,
            ``'audio'``, ``'face'`` and ``'label'``.
        val_loader: Iterable of validation batches with the same keys.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Initial learning rate for Adam.

    Returns:
        A deep copy of the ``state_dict`` taken at the epoch with the lowest
        summed validation loss, or ``None`` if no epoch improved on infinity
        (for example when ``num_epochs`` is 0).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5)

    best_val_loss = float('inf')
    best_model = None

    for epoch in range(num_epochs):
        train_loss, train_batches = _run_fusion_epoch(
            model, train_loader, criterion, device, optimizer
        )
        val_loss, val_batches = _run_fusion_epoch(
            model, val_loader, criterion, device
        )

        # Print progress. Batches are counted while iterating, so loaders
        # without __len__ work and an empty loader cannot divide by zero.
        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Training Loss: {train_loss/max(train_batches, 1):.4f}')
        print(f'Validation Loss: {val_loss/max(val_batches, 1):.4f}')

        # Learning rate scheduling
        scheduler.step(val_loss)

        # Save best model. state_dict() hands back references to the live
        # parameter tensors, which later optimizer steps overwrite in place;
        # a deep copy is needed to freeze this epoch's weights.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model = copy.deepcopy(model.state_dict())

    return best_model