In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import torchaudio
import librosa
import os
from transformers import (
    Wav2Vec2Processor, 
    Wav2Vec2Model, 
    BertModel, 
    BertTokenizer
)
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

# Custom Dataset
import torch.nn.functional as F

class MultimodalEmotionDataset(Dataset):
    def __init__(self, csv_path, processor, tokenizer, max_length=128, target_audio_length=16000):
        # Read the CSV
        self.data = pd.read_csv(csv_path, sep='\t')
        
        # Add random text column if not exists
        if 'text' not in self.data.columns:
            print([f"Random text for {name}" for name in self.data['name']])
            self.data['text'] = [f"Random text for {name}" for name in self.data['name']]
        
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.target_audio_length = target_audio_length
        
        # Mapping emotions to indices
        self.emotion_to_idx = {emotion: idx for idx, emotion in enumerate(self.data['emotion'].unique())}
        
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        
        # Load audio file
        speech_array, sampling_rate = torchaudio.load(row['path'])  # (Channels, Samples)
        
        # Convert to mono (if stereo, take the first channel)
        if speech_array.shape[0] > 1:
            speech_array = torch.mean(speech_array, dim=0, keepdim=True)
        
        speech_array = speech_array.squeeze().numpy()  # Convert to numpy array
        
        # Resample to target sampling rate
        speech_array = librosa.resample(y=speech_array, orig_sr=sampling_rate, target_sr=self.processor.feature_extractor.sampling_rate)
        
        # Ensure fixed length
        if len(speech_array) > self.target_audio_length:
            speech_array = speech_array[:self.target_audio_length]
        elif len(speech_array) < self.target_audio_length:
            padding = self.target_audio_length - len(speech_array)
            speech_array = np.pad(speech_array, (0, padding), mode='constant', constant_values=0)
    
        # Process audio with Wav2Vec2
        audio_inputs = self.processor(
            speech_array, 
            sampling_rate=self.processor.feature_extractor.sampling_rate, 
            return_tensors="pt"
        )
        
        # Process text
        text_inputs = self.tokenizer(
            row['text'], 
            max_length=self.max_length, 
            padding='max_length', 
            truncation=True, 
            return_tensors="pt"
        )
        
        # Get emotion label
        label = self.emotion_to_idx[row['emotion']]
        
        return {
            'audio_input': audio_inputs.input_values.squeeze(),  # Ensure correct shape
            'audio_mask': audio_inputs.attention_mask.squeeze(),
            'text_input_ids': text_inputs['input_ids'].squeeze(),
            'text_attention_mask': text_inputs['attention_mask'].squeeze(),
            'label': label
        }


    # def __getitem__(self, idx):
    #     row = self.data.iloc[idx]
        
    #     # Process Audio
    #     speech_array, sampling_rate = torchaudio.load(row['path'])
    #     speech_array = speech_array.squeeze().numpy()
    #     # speech_array = librosa.resample(np.asarray(speech_array), sampling_rate, self.processor.feature_extractor.sampling_rate)
    #     speech_array = librosa.resample(y=np.asarray(speech_array), orig_sr=sampling_rate, target_sr=self.processor.feature_extractor.sampling_rate)
    #     # Pad or truncate audio
    #     if len(speech_array) > self.target_audio_length:
    #         speech_array = speech_array[:self.target_audio_length]
    #     elif len(speech_array) < self.target_audio_length:
    #         padding = self.target_audio_length - len(speech_array)
    #         speech_array = np.pad(speech_array, (0, padding), mode='constant', constant_values=0)
        
    #     audio_inputs = self.processor(
    #         speech_array, 
    #         sampling_rate=self.processor.feature_extractor.sampling_rate, 
    #         return_tensors="pt"
    #     )
        
    #     # Process Text
    #     text_inputs = self.tokenizer(
    #         row['text'], 
    #         max_length=self.max_length, 
    #         padding='max_length', 
    #         truncation=True, 
    #         return_tensors="pt"
    #     )
        
    #     # Get emotion label
    #     label = self.emotion_to_idx[row['emotion']]
        
    #     return {
    #         'audio_input': audio_inputs.input_values.squeeze(),
    #         'audio_mask': audio_inputs.attention_mask.squeeze(),
    #         'text_input_ids': text_inputs['input_ids'].squeeze(),
    #         'text_attention_mask': text_inputs['attention_mask'].squeeze(),
    #         'label': label
    #     }


# Multimodal Fusion Model
class MultimodalEmotionClassifier(nn.Module):
    def __init__(self, num_labels, audio_model_path, text_model_path):
        super().__init__()
        
        # Audio Encoder (Wav2Vec2)
        self.audio_encoder = Wav2Vec2Model.from_pretrained(audio_model_path)
        
        # Text Encoder (BERT)
        self.text_encoder = BertModel.from_pretrained(text_model_path)
        
        # Freeze pretrained encoders (optional)
        for param in self.audio_encoder.parameters():
            param.requires_grad = False
        for param in self.text_encoder.parameters():
            param.requires_grad = False
        
        # Fusion Layer
        audio_feature_dim = self.audio_encoder.config.hidden_size
        text_feature_dim = self.text_encoder.config.hidden_size
        fusion_dim = audio_feature_dim + text_feature_dim
        
        self.fusion_layers = nn.Sequential(
            nn.Linear(fusion_dim, fusion_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(fusion_dim // 2, num_labels)
        )
        self.audio_norm = nn.LayerNorm(audio_feature_dim)
        self.text_norm = nn.LayerNorm(text_feature_dim)

    def forward(self, audio_input, audio_mask, text_input_ids, text_attention_mask):
        # Extract audio features
        audio_outputs = self.audio_encoder(audio_input, attention_mask=audio_mask)
        audio_features = torch.mean(audio_outputs.last_hidden_state, dim=1)  # Mean pooling
        audio_features = self.audio_norm(audio_features)  # Apply LayerNorm
    
        # Extract text features
        text_outputs = self.text_encoder(text_input_ids, attention_mask=text_attention_mask)
        text_features = text_outputs.pooler_output  # [CLS] token representation
        text_features = self.text_norm(text_features)  # Apply LayerNorm
    
        # Concatenate features
        combined_features = torch.cat([audio_features, text_features], dim=1)
    
        # Classification
        logits = self.fusion_layers(combined_features)
        
        return logits

    # def forward(self, audio_input, audio_mask, text_input_ids, text_attention_mask):
    #     # Extract audio features
    #     audio_outputs = self.audio_encoder(
    #         audio_input, 
    #         attention_mask=audio_mask
    #     )
    #     audio_features = torch.mean(audio_outputs.last_hidden_state, dim=1)
        
    #     # Extract text features
    #     text_outputs = self.text_encoder(
    #         text_input_ids, 
    #         attention_mask=text_attention_mask
    #     )
    #     text_features = text_outputs.pooler_output
        
    #     # Concatenate features
    #     combined_features = torch.cat([audio_features, text_features], dim=1)
        
    #     # Classification
    #     logits = self.fusion_layers(combined_features)
        
    #     return logits

def train_model(model, train_loader, val_loader, device, epochs=10, learning_rate=1e-4):
    # Loss and Optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    # Create models directory if it doesn't exist
    os.makedirs('saved_models', exist_ok=True)
    
    best_val_accuracy = 0.0
    
    for epoch in range(epochs):
        # Training Phase
        model.train()
        train_losses = []
        train_preds = []
        train_true = []
        
        for batch in train_loader:
            # Move data to device
            audio_input = batch['audio_input'].to(device)
            audio_mask = batch['audio_mask'].to(device)
            text_input_ids = batch['text_input_ids'].to(device)
            text_attention_mask = batch['text_attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            # Zero the parameter gradients
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(
                audio_input, 
                audio_mask, 
                text_input_ids, 
                text_attention_mask
            )
            
            # Compute loss
            loss = criterion(outputs, labels)
            
            # Backward pass and optimize
            loss.backward()
            optimizer.step()
            
            # Track training metrics
            train_losses.append(loss.item())
            preds = torch.argmax(outputs, dim=1)
            train_preds.extend(preds.cpu().numpy())
            train_true.extend(labels.cpu().numpy())
        
        # Validation Phase
        model.eval()
        val_losses = []
        val_preds = []
        val_true = []
        
        with torch.no_grad():
            for batch in val_loader:
                # Move data to device
                audio_input = batch['audio_input'].to(device)
                audio_mask = batch['audio_mask'].to(device)
                text_input_ids = batch['text_input_ids'].to(device)
                text_attention_mask = batch['text_attention_mask'].to(device)
                labels = batch['label'].to(device)
                
                # Forward pass
                outputs = model(
                    audio_input, 
                    audio_mask, 
                    text_input_ids, 
                    text_attention_mask
                )
                
                # Compute loss
                loss = criterion(outputs, labels)
                val_losses.append(loss.item())
                
                # Track validation metrics
                preds = torch.argmax(outputs, dim=1)
                val_preds.extend(preds.cpu().numpy())
                val_true.extend(labels.cpu().numpy())
        
        # Compute metrics
        train_accuracy = accuracy_score(train_true, train_preds)
        val_accuracy = accuracy_score(val_true, val_preds)
        
        # Print epoch summary
        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {np.mean(train_losses):.4f}, Train Accuracy: {train_accuracy:.4f}")
        print(f"Val Loss: {np.mean(val_losses):.4f}, Val Accuracy: {val_accuracy:.4f}")
        
        # Save best model
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_accuracy': best_val_accuracy,
                'epoch': epoch
            }, f'saved_models/best_multimodal_model.pth')
            
            print(f"Saved new best model with validation accuracy: {best_val_accuracy:.4f}")
        
        # Optional: Print classification report for validation set
        if epoch % 5 == 0:
            print("\nValidation Classification Report:")
            print(classification_report(val_true, val_preds))
    
    return model

def main():
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Paths and Configurations
    audio_model_path = "./emotion_recognition_model"
    text_model_path = "bert-base-uncased"
    
    # Initialize Processors
    audio_processor = Wav2Vec2Processor.from_pretrained(audio_model_path)
    text_tokenizer = BertTokenizer.from_pretrained(text_model_path)
    
    # Create Datasets
    train_dataset = MultimodalEmotionDataset("dataset/train.csv", audio_processor, text_tokenizer)
    test_dataset = MultimodalEmotionDataset("dataset/test.csv", audio_processor, text_tokenizer)
    
    # Split train into train and validation
    train_dataset, val_dataset = torch.utils.data.random_split(
        train_dataset, 
        [int(len(train_dataset)*0.8), len(train_dataset)-int(len(train_dataset)*0.8)]
    )
    
    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)
    
    # Get number of labels
    num_labels = len(train_dataset.dataset.emotion_to_idx)
    
    # Initialize Model
    model = MultimodalEmotionClassifier(
        num_labels=num_labels, 
        audio_model_path=audio_model_path, 
        text_model_path=text_model_path
    ).to(device)
    print(model)
    print(f"Model Initialized with {num_labels} emotion classes")
    print("Emotion to Index mapping:", train_dataset.dataset.emotion_to_idx)
    
    # Train Model
    trained_model = train_model(model, train_loader, val_loader, device, epochs=50)
    
    # Optional: Load and evaluate best saved model
    best_model_path = 'saved_models/best_multimodal_model.pth'
    checkpoint = torch.load(best_model_path)
    
    # Reinitialize model and load state dict
    best_model = MultimodalEmotionClassifier(
        num_labels=num_labels, 
        audio_model_path=audio_model_path, 
        text_model_path=text_model_path
    ).to(device)
    best_model.load_state_dict(checkpoint['model_state_dict'])
    
    # Test model
    best_model.eval()
    test_preds = []
    test_true = []
    
    with torch.no_grad():
        for batch in test_loader:
            # Move data to device
            audio_input = batch['audio_input'].to(device)
            audio_mask = batch['audio_mask'].to(device)
            text_input_ids = batch['text_input_ids'].to(device)
            text_attention_mask = batch['text_attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            # Forward pass
            outputs = best_model(
                audio_input, 
                audio_mask, 
                text_input_ids, 
                text_attention_mask
            )
            
            # Track test metrics
            preds = torch.argmax(outputs, dim=1)
            test_preds.extend(preds.cpu().numpy())
            test_true.extend(labels.cpu().numpy())
    
    # Print test classification report
    print("\nTest Classification Report:")
    print(classification_report(test_true, test_preds, 
        target_names=list(train_dataset.dataset.emotion_to_idx.keys())))

if __name__ == "__main__":
    main()

MultimodalEmotionClassifier(
  (audio_encoder): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (1-4): 4 x Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (5-6): 2 x Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(2,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
      )
    )
    (feature_projection): Wav2Vec2FeatureProjection(
      (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=T

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Epoch 2/50
Train Loss: 1.6258, Train Accuracy: 0.2378
Val Loss: 1.6276, Val Accuracy: 0.2532
Saved new best model with validation accuracy: 0.2532
Epoch 3/50
Train Loss: 1.6210, Train Accuracy: 0.2280
Val Loss: 1.6837, Val Accuracy: 0.2013
Epoch 4/50
Train Loss: 1.6138, Train Accuracy: 0.2476
Val Loss: 1.6024, Val Accuracy: 0.2792
Saved new best model with validation accuracy: 0.2792
Epoch 5/50
Train Loss: 1.6076, Train Accuracy: 0.2590
Val Loss: 1.5930, Val Accuracy: 0.2468
Epoch 6/50
Train Loss: 1.5867, Train Accuracy: 0.2704
Val Loss: 1.6044, Val Accuracy: 0.2403

Validation Classification Report:
              precision    recall  f1-score   support

           0       0.07      0.03      0.04        33
           1       0.31      0.67      0.43        30
           2       0.29      0.07      0.11        29
           3       0.40      0.23      0.29        35
           4       0.12      0.22      0.16        27

    accuracy                           0.24       154
   macro avg