# **Mixture of Experts Model with Single Gate for Emotion & Trigger Classification**

This notebook, developed as part of a Master's thesis, trains and evaluates a single-gate Mixture of Experts model for joint emotion and trigger classification using BERT-based features.

### **Imports & Setup**

This section loads all necessary libraries. The code is designed to run on Google Colab with GPU support.

In [None]:
# Standard libraries
from collections import defaultdict
from collections.abc import Callable
import os
import random
from typing import Any, Dict, List, Tuple, Union

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Data processing
import numpy as np
import pandas as pd

# PyTorch
import torch
import torch.nn as nn
from torch import cuda
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader

# Transformers
from transformers import BertTokenizer, BertModel

# Scikit-learn
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    precision_score, 
    recall_score, 
    f1_score, 
    accuracy_score, 
    confusion_matrix,
)

In [None]:
# Mount Google Drive at /content/drive (Google Colab only).
from google.colab import drive
drive.mount("/content/drive")

In [None]:
# Switch to the project folder; the data files are later read via relative "data/..." paths.
%cd "drive/MyDrive/Colab Notebooks/EDiReF"

In [None]:
def set_seed(seed: int = 2024) -> None:
    """
    Seed every random number generator used in this notebook.

    The same value is handed to Python's `random`, NumPy and PyTorch, and
    additionally to all CUDA devices when CUDA is available.

    Args:
        seed (int): The seed value. Defaults to 2024.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)

In [None]:
# Seed all RNGs with the default seed (2024) before any data splitting or model creation.
set_seed()

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if cuda.is_available() else "cpu"
print(device)

### **Data Loading & Preprocessing**

This section loads and preprocesses the conversation, emotion, and trigger label data.

#### Input Format

Each dataset (e.g. MELD, MaSaC) should be stored as a `.json` file in the `data/` directory.

Each file must contain:
- `"utterances"`: list of utterances per conversation (List[List[str]])
- `"emotions"`: corresponding emotion labels (List[List[str]])
- `"triggers"`: corresponding trigger indicators, binary or float (List[List[Union[int, float]]])

Files should be named like:
`data/EDiReF_train_data/MaSaC_train_efr.json`

In [None]:
# Colab form fields: dataset to load, maximum sequence length used for
# tokenization and label padding, and DataLoader batch size.
dataset = "MELD"    # @param ["MELD", "MaSaC"]
max_length = 96    # @param [96, 128, 256] {type: "raw"}
batch_size = 8    # @param [8, 16, 32] {type: "raw"}

In [None]:
# Select the pretrained BERT checkpoint for the chosen dataset.
if dataset == "MELD":
    designated_model = "bert-base-cased"
elif dataset == "MaSaC":
    designated_model = "bert-base-multilingual-cased"
else:
    # Fail here with a clear message instead of a NameError further down,
    # when `designated_model` is first used.
    raise ValueError(
        f"Unsupported dataset: {dataset!r}. Expected 'MELD' or 'MaSaC'."
    )

In [None]:
def get_data(
    dataset_name: str, 
    stage: str
) -> Tuple[List[List[str]], List[List[str]], List[List[float]]]:
    """
    Load and clean EDiReF JSON data for a given dataset and stage.

    The function reads `data/EDiReF_{stage}_data/{dataset_name}_{stage}_efr.json`,
    drops every conversation whose "triggers" list contains a missing value
    (None/NaN), converts the remaining trigger values to floats, and extracts
    the "utterances", "emotions", and "triggers" columns.

    Args:
        dataset_name (str): Name of the dataset ("MELD", or "MaSaC").
        stage (str): Subset name corresponding to the file suffix ("train" or "val").

    Returns:
        Tuple[List[List[str]], List[List[str]], List[List[float]]]: A tuple containing:
            - conversations: List of conversations (utterances).
            - emotions: List of corresponding emotion labels.
            - triggers: List of corresponding trigger values.
    """
    def to_float(x):
        # A present but non-numeric value (e.g. a stray string) is mapped to 1.0.
        try:
            return float(x)
        except ValueError:
            return 1.0

    df = pd.read_json(f"data/EDiReF_{stage}_data/{dataset_name}_{stage}_efr.json")

    # pd.isna() is True for both None and NaN, so one pass finds all gaps.
    has_missing = df["triggers"].apply(
        lambda lst: any(pd.isna(x) for x in lst)
    ).astype(bool)

    # Work on an explicit copy so the column assignment below does not write
    # into a slice of the original frame (SettingWithCopyWarning).
    df = df.loc[~has_missing].copy()
    df["triggers"] = df["triggers"].apply(
        lambda lst: [to_float(x) for x in lst]
    )

    conversations = list(df["utterances"])
    emotions = list(df["emotions"])
    triggers = list(df["triggers"])

    return conversations, emotions, triggers

In [None]:
# Load the cleaned "train" and "val" files of the selected dataset.
train_conversations, train_emotions, train_triggers = get_data(dataset, "train")
val_conversations, val_emotions, val_triggers = get_data(dataset, "val")

In [None]:
# Pool both files into one collection; it is re-split into train/val/test below.
conversations = train_conversations + val_conversations
emotions = train_emotions + val_emotions
triggers = train_triggers + val_triggers

In [None]:
flattened_emotions = [sent for conv in emotions for sent in conv]

# Sort the distinct labels so the label <-> id mapping is identical on every
# run. Iterating a plain set of strings yields an order that can change between
# Python processes (string hash randomization), which would silently permute
# the class ids and make saved models and metrics incomparable across runs.
unique_emotions = sorted(set(flattened_emotions))

labels_to_ids = {label: idx for idx, label in enumerate(unique_emotions)}
ids_to_labels = {idx: label for idx, label in enumerate(unique_emotions)}

# Replace the string labels of every conversation with their integer ids.
emotions = [[labels_to_ids[emotion] for emotion in conv] for conv in emotions]

In [None]:
def train_val_test_split(
    X: List[List[str]], 
    y1: List[List[int]], 
    y2: List[List[float]], 
    val_size: float = 0.2, 
    test_size: float = 0.2, 
    random_state: Union[int, None] = None
) -> Tuple[
    List[List[str]], List[List[str]], List[List[str]],
    List[List[int]], List[List[int]], List[List[int]],
    List[List[float]], List[List[float]], List[List[float]]
]:
    """
    Split data into train, validation, and test sets with consistent label alignment.

    This function performs a two-step split:
    - First, it splits the dataset into training+validation and test sets.
    - Then it splits the training+validation set again to separate out a validation set.
    This ensures y1 and y2 (labels) stay aligned with X throughout the process.

    Args:
        X (List[List[str]]): The main input data (e.g., conversations).
        y1 (List[List[int]]): The first set of target labels (e.g., emotions).
        y2 (List[List[float]]): The second set of target labels (e.g., triggers).
        val_size (float, optional): Proportion of the whole data to use for 
            validation. Defaults to 0.2.
        test_size (float, optional): Proportion of the whole data to use for 
            the test set. Defaults to 0.2.
        random_state (int, optional): Random seed for reproducibility; the same
            value is used for both splits. Defaults to None.

    Returns:
        Tuple[
            List[List[str]], List[List[str]], List[List[str]],
            List[List[int]], List[List[int]], List[List[int]],
            List[List[float]], List[List[float]], List[List[float]]
        ]: A tuple containing:
            - X_train, X_val, X_test
            - y1_train, y1_val, y1_test
            - y2_train, y2_val, y2_test

    Raises:
        ValueError: If `val_size` or `test_size` is not strictly between 0 and 1,
            or if together they leave no data for training.
    """
    # Validate up front: otherwise bad proportions only surface as an error
    # about the derived relative size (or a division by zero) further down.
    if not (0 < val_size < 1 and 0 < test_size < 1 and val_size + test_size < 1):
        raise ValueError(
            "val_size and test_size must each lie in (0, 1) and sum to less "
            f"than 1, got val_size={val_size} and test_size={test_size}."
        )

    X_train_val, X_test, y1_train_val, y1_test, y2_train_val, y2_test = train_test_split(
        X, y1, y2, test_size=test_size, random_state=random_state
    )

    # `val_size` refers to the full dataset, so rescale it to the share of the
    # remaining train+validation portion.
    val_relative_size = val_size / (1 - test_size)

    X_train, X_val, y1_train, y1_val, y2_train, y2_val = train_test_split(
        X_train_val, y1_train_val, y2_train_val, 
        test_size=val_relative_size, 
        random_state=random_state
    )

    return (
        X_train, X_val, X_test, 
        y1_train, y1_val, y1_test, 
        y2_train, y2_val, y2_test
    )

In [None]:
# 70% train / 15% validation / 15% test, with a fixed seed for a repeatable split.
X_train, X_val, X_test, y1_train, y1_val, y1_test, y2_train, y2_val, y2_test = train_val_test_split(
    conversations, emotions, triggers, test_size=0.15, val_size=0.15, random_state=2024
)

### **Tokenization & Padding**

Utterances are tokenized using the HuggingFace `BertTokenizer`. Label sequences are padded with `-1` up to the maximum sequence length; the padded positions are ignored during loss calculation.

In [None]:
# Load the tokenizer that belongs to the checkpoint selected for this dataset.
tokenizer = BertTokenizer.from_pretrained(designated_model)

In [None]:
def tokenize_conversation(
    conversations: List[List[str]], 
    tokenizer: BertTokenizer, 
    max_length: int = 128
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """
    Turn conversations into fixed-length token ID and attention mask tensors.

    The utterances of a conversation are joined into one string, separated by
    the tokenizer's separator token, and the result is encoded with truncation
    and padding to `max_length`.

    Args:
        conversations (List[List[str]]): A list of conversations, where each
            conversation is a list of utterances (strings).
        tokenizer (BertTokenizer): A HuggingFace tokenizer used to encode
            the joined conversations.
        max_length (int, optional): Length (in tokens) every conversation is
            padded or truncated to. Defaults to 128.

    Returns:
        Tuple[List[torch.Tensor], List[torch.Tensor]]: A tuple containing:
            - input_ids: One 1D token ID tensor per conversation.
            - attention_masks: One 1D attention mask tensor per conversation.
    """
    separator = f" {tokenizer.sep_token} "

    encodings = [
        tokenizer(
            separator.join(utterances),
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt",
        )
        for utterances in conversations
    ]

    # return_tensors="pt" yields a leading batch axis of size 1; drop it.
    input_ids = [enc["input_ids"].squeeze(0) for enc in encodings]
    attention_masks = [enc["attention_mask"].squeeze(0) for enc in encodings]

    return input_ids, attention_masks

In [None]:
def pad_labels(
    labels: List[List[Union[int, float]]], 
    max_length: int = 128
) -> List[torch.Tensor]:
    """
    Pad or truncate each list of labels to a specified max length.

    Each list of labels is converted to a float tensor of exactly `max_length`
    entries: shorter sequences are right-padded with -1.0 (the value used to
    mask positions out of the loss), longer sequences are cut off after
    `max_length` labels.

    Args:
        labels (List[List[Union[int, float]]]): A list of label sequences
            (e.g., emotions or triggers).
        max_length (int, optional): Maximum length to pad/truncate each
            sequence to. Defaults to 128.

    Returns:
        List[torch.Tensor]: A list of padded 1D float tensors of length
        `max_length`, one per input sequence.
    """
    padded_labels = []

    for label_set in labels:
        # Truncate first: without this, a sequence longer than `max_length`
        # requests a negative-sized padding tensor and raises.
        label_tensor = torch.tensor(label_set[:max_length], dtype=torch.float)
        padding_tensor = torch.full((max_length - label_tensor.numel(),), -1.0)
        padded_labels.append(torch.cat([label_tensor, padding_tensor]))

    return padded_labels

In [None]:
class ConversationDataset(Dataset):
    """
    PyTorch dataset pairing tokenized conversations with their labels.

    Holds four parallel lists (token IDs, attention masks, emotion labels and
    trigger labels); the item at a given index is assembled from the entries
    at that index in each list.
    """

    def __init__(
        self, 
        input_ids: List[torch.Tensor], 
        attention_masks: List[torch.Tensor], 
        emotion_labels: List[torch.Tensor], 
        trigger_labels: List[torch.Tensor]
    ) -> None:
        """
        Store the tokenized inputs and their labels.

        Args:
            input_ids (List[torch.Tensor]): Token IDs for each conversation.
            attention_masks (List[torch.Tensor]): Attention masks for each conversation.
            emotion_labels (List[torch.Tensor]): Label tensors for emotion classification.
            trigger_labels (List[torch.Tensor]): Label tensors for trigger classification.
        """
        self.input_ids, self.attention_masks = input_ids, attention_masks
        self.emotion_labels, self.trigger_labels = emotion_labels, trigger_labels

    def __len__(self) -> int:
        """
        Return the dataset size.

        Returns:
            int: Number of stored conversations (length of `input_ids`).
        """
        return len(self.input_ids)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Fetch the sample stored at position `idx`.

        Args:
            idx (int): Index of the item to retrieve.

        Returns:
            Dict[str, torch.Tensor]: A dictionary with the keys
            "input_ids", "attention_mask", "emotion_labels" and "trigger_labels".
        """
        columns = (
            ("input_ids", self.input_ids),
            ("attention_mask", self.attention_masks),
            ("emotion_labels", self.emotion_labels),
            ("trigger_labels", self.trigger_labels),
        )
        return {key: values[idx] for key, values in columns}

In [None]:
def create_dataloader(
    conversations: List[List[str]], 
    emotions: List[List[int]], 
    triggers: List[List[float]],
    tokenizer: BertTokenizer,  
    batch_size: int,
    max_length: int = 128, 
    shuffle: bool = False
) -> DataLoader:
    """
    Build a DataLoader from raw conversations and their label sequences.

    The conversations are encoded with `tokenize_conversation()`, both label
    sets are brought to `max_length` with `pad_labels()`, and everything is
    wrapped in a `ConversationDataset` served by a PyTorch DataLoader.

    Args:
        conversations (List[List[str]]): A list of conversations, where each conversation
            is a list of utterances (strings).
        emotions (List[List[int]]): A list of emotion labels.
        triggers (List[List[float]]): A list of trigger labels.
        tokenizer (BertTokenizer): A HuggingFace tokenizer used to tokenize
            the input conversations.
        batch_size (int): Number of samples per batch.
        max_length (int, optional): Length used for both tokenization and
            label padding. Defaults to 128.
        shuffle (bool, optional): Whether to shuffle the data at every epoch.
            Defaults to False.

    Returns:
        DataLoader: An iterable over the constructed ConversationDataset.
    """
    token_ids, masks = tokenize_conversation(
        conversations, tokenizer, max_length=max_length
    )
    padded_emotions, padded_triggers = (
        pad_labels(sequences, max_length=max_length)
        for sequences in (emotions, triggers)
    )

    return DataLoader(
        ConversationDataset(token_ids, masks, padded_emotions, padded_triggers),
        batch_size=batch_size,
        shuffle=shuffle,
    )

In [None]:
# Build one loader per split. Only the training loader is shuffled, so
# validation and test batches are served in a fixed order.
train_loader = create_dataloader(
    X_train, y1_train, y2_train, 
    tokenizer=tokenizer, batch_size=batch_size, max_length=max_length, 
    shuffle=True
)
val_loader = create_dataloader(
    X_val, y1_val, y2_val, 
    tokenizer=tokenizer, batch_size=batch_size, max_length=max_length, 
    shuffle=False
)
test_loader = create_dataloader(
    X_test, y1_test, y2_test, 
    tokenizer=tokenizer, batch_size=batch_size, max_length=max_length, 
    shuffle=False
)

### **Model Definition**

Defines a Mixture of Experts model with a single gating mechanism that weights and combines the expert outputs.

In [None]:
# Colab form fields: architecture of the gate and of each expert, number of
# experts, and how many of the highest-weighted experts are combined (top_k).
gate_type = "linear"  # @param ["linear", "mlp"]
expert_type = "linear" # @param ["linear", "mlp", "rnn"]
num_experts = 2 # @param {type: "slider", min: 1, max: 8, step: 1}
top_k = 2 # @param {type: "slider", min: 1, max: 8, step: 1}

In [None]:
# The model cannot select more experts than exist.
assert top_k <= num_experts, "Select different values for top_k and num_experts!"

In [None]:
class MoEForEmotionAndTriggerClassification(nn.Module):
    """
    A Mixture of Experts model for emotion and trigger classification.

    A single gating network, fed with the mean-pooled BERT embeddings of a
    conversation, weights the outputs of the top-k experts. The combined
    representation is shared by two heads: emotion classification and
    trigger detection.

    Attributes:
        model (transformers.BertModel): Pretrained BERT model for feature extraction.
        gating_network (Union[nn.Linear, nn.Sequential]): Gating mechanism
            shared by both tasks.
        experts (nn.ModuleList): Independently initialised expert modules
            (Linear, MLP, or LSTM), one per expert.
        emotion_classifier (nn.Linear): Output layer for emotion classification.
        trigger_classifier (nn.Linear): Output layer for trigger classification.
        k (int): Number of top experts to activate.
        dropout (nn.Dropout): Dropout layer for regularization.
    """

    def __init__(
        self, 
        num_experts: int, 
        k: int, 
        num_classes: int, 
        gate_type: str, 
        expert_type: str, 
        model_name: str = "bert-base-uncased", 
        train_bert: bool = True
    ) -> None:
        """
        Initialize the MoE model.

        Args:
            num_experts (int): Number of expert modules.
            k (int): Number of top experts to use per forward pass.
            num_classes (int): Number of emotion classes.
            gate_type (str): Type of gating mechanism ("linear" or "mlp").
            expert_type (str): Type of expert layer ("linear", "mlp", or "rnn").
            model_name (str): Name of the pretrained BERT model. Defaults to "bert-base-uncased".
            train_bert (bool): Whether to fine-tune the BERT model. Defaults to True.

        Raises:
            KeyError: If `gate_type` or `expert_type` is not one of the 
                supported names.
        """
        super().__init__()

        self.model = BertModel.from_pretrained(model_name)
        for param in self.model.parameters():
            param.requires_grad = train_bert  # Set to True if you want to fine-tune model

        hidden_size = self.model.config.hidden_size

        self.gating_network = self._build_gate(gate_type, hidden_size, num_experts)
        # Build a fresh module per expert. Re-using one module instance for
        # every slot would make all experts share the same weights.
        self.experts = nn.ModuleList(
            [self._build_expert(expert_type, hidden_size) for _ in range(num_experts)]
        )

        self.emotion_classifier = nn.Linear(hidden_size, num_classes)
        self.trigger_classifier = nn.Linear(hidden_size, 1)

        self.k = k
        self.dropout = nn.Dropout(p=0.1)

    @staticmethod
    def _build_gate(gate_type: str, hidden_size: int, num_experts: int) -> nn.Module:
        """Create the gating network mapping a pooled embedding to one score per expert."""
        factories = {
            "linear": lambda: nn.Linear(hidden_size, num_experts),
            "mlp": lambda: nn.Sequential(
                nn.Linear(hidden_size, 512), 
                nn.ReLU(), 
                nn.Linear(512, num_experts)
            ),
        }
        return factories[gate_type]()

    @staticmethod
    def _build_expert(expert_type: str, hidden_size: int) -> nn.Module:
        """Create one new, independently initialised expert of the requested type."""
        factories = {
            "linear": lambda: nn.Linear(hidden_size, hidden_size),
            "mlp": lambda: nn.Sequential(
                nn.Linear(hidden_size, 512), 
                nn.ReLU(), 
                nn.Linear(512, hidden_size)
            ),
            # batch_first=True: experts receive (1, seq_len, hidden) inputs, so
            # the LSTM must recur over the sequence axis, not the size-1 axis.
            "rnn": lambda: nn.LSTM(hidden_size, hidden_size, batch_first=True),
        }
        return factories[expert_type]()

    def forward(
        self, 
        input_ids: torch.Tensor, 
        attention_mask: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the model.

        Args:
            input_ids (torch.Tensor): Input IDs of shape (batch_size, seq_len).
            attention_mask (torch.Tensor): Attention mask of shape (batch_size, seq_len).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - emotion_logits: Tensor of shape (batch_size, seq_len, num_classes).
                - trigger_logits: Tensor of shape (batch_size, seq_len).
        """
        model_outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        embeddings = model_outputs.last_hidden_state

        # The gate sees one vector per conversation: the mean over all token
        # positions.
        # NOTE(review): this mean also averages over padded positions because
        # the attention mask is not applied here - confirm this is intended.
        pooled_embeddings = embeddings.mean(dim=1)
        pooled_embeddings = self.dropout(pooled_embeddings)

        expert_weights = torch.softmax(
            self.gating_network(pooled_embeddings), dim=-1
        )

        combined_output = self._compute_expert_output(embeddings, expert_weights)
        combined_output = self.dropout(combined_output)

        emotion_logits = self.emotion_classifier(combined_output)
        trigger_logits = self.trigger_classifier(combined_output).squeeze(-1)

        return emotion_logits, trigger_logits

    def _compute_expert_output(
        self, 
        embeddings: torch.Tensor, 
        expert_weights: torch.Tensor
    ) -> torch.Tensor:
        """
        Combine outputs from the top-k selected experts.

        For every sample, the k experts with the largest gating weights are
        applied to that sample's embeddings and summed, each scaled by its
        gating weight (the weights are not re-normalised after the top-k
        selection).

        Args:
            embeddings (torch.Tensor): BERT embeddings of shape (batch_size, seq_len, hidden_dim).
            expert_weights (torch.Tensor): Gating weights of shape (batch_size, num_experts).

        Returns:
            torch.Tensor: Combined expert output of shape (batch_size, seq_len, hidden_dim).
        """
        combined_output = torch.zeros_like(embeddings)
        topk_weights, topk_indices = torch.topk(expert_weights, self.k, dim=-1)

        # Plain Python ints for indexing the ModuleList (one transfer per batch).
        routed_experts = topk_indices.tolist()

        for rank in range(self.k):
            # Add dimensions for broadcasting across sequence length and hidden size
            weight = topk_weights[:, rank].unsqueeze(-1).unsqueeze(-1)

            expert_outputs = []
            for sample_idx, choices in enumerate(routed_experts):
                expert = self.experts[choices[rank]]
                sample = embeddings[sample_idx]

                if isinstance(expert, nn.LSTM):
                    # LSTM expects a batch axis and returns (output, state).
                    output, _ = expert(sample.unsqueeze(0))
                    expert_outputs.append(output.squeeze(0))
                else:
                    expert_outputs.append(expert(sample))

            combined_output = combined_output + weight * torch.stack(expert_outputs)

        return combined_output

### **Training**

Trains the MoE model and evaluates on validation set after each epoch. Logs loss and accuracy for both tasks.

In [None]:
# Colab form fields: whether BERT is fine-tuned, optimizer learning rate,
# number of epochs, and whether the trigger loss may use a positive-class weight.
train_bert = True   # @param {type: "boolean"}
learning_rate = 0.00002  # @param {type: "slider", min: 1E-5, max: 5E-5, step: 1E-5}
num_epochs = 5  # @param {type: "slider", min: 3, max: 15, step: 1}
weight_triggers = True # @param {type: "boolean"}

In [None]:
# Instantiate the Mixture of Experts model
moe = MoEForEmotionAndTriggerClassification(
    num_experts=num_experts, 
    k=top_k, 
    num_classes=len(labels_to_ids), 
    gate_type=gate_type, 
    expert_type=expert_type, 
    model_name=designated_model, 
    train_bert=train_bert
)

# Initialize optimizer
optimizer = AdamW(moe.parameters(), lr=learning_rate)

# Compute positive class weight for trigger classification (if applicable).
# The weight is only computed for the MaSaC datasets; otherwise it stays None
# and the trigger loss is unweighted.
# NOTE(review): "MaSaC_translated" is not among the options offered by the
# `dataset` form field in this notebook - confirm it is still needed.
pos_weight = None
if weight_triggers and dataset in ["MaSaC", "MaSaC_translated"]:
    flattened_triggers = [x for sequence in y2_train for x in sequence]
    # NOTE(review): the negative count is scaled by 0.8 and the "positive"
    # count is derived from that scaled value, so it also includes 20% of the
    # negatives. Presumably a deliberate damping of the weight - confirm.
    num_negative_samples = len([x for x in flattened_triggers if x == 0])*0.8
    num_positive_samples = len(flattened_triggers) - num_negative_samples
    pos_weight_value = num_negative_samples / num_positive_samples
    pos_weight = torch.tensor([pos_weight_value], device=device)

# Define loss functions
emotion_loss_fn = CrossEntropyLoss()
trigger_loss_fn = BCEWithLogitsLoss(pos_weight=pos_weight)

In [None]:
# Move the model parameters to the selected device.
moe.to(device)

In [None]:
def remove_padding(
    logits: torch.Tensor, 
    labels: torch.Tensor, 
    task: str
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Drop every position whose label is the padding value -1.

    Logits and labels are flattened over the batch and sequence axes, and only
    the entries whose label differs from -1 are kept.

    Args:
        logits (torch.Tensor): Model outputs of shape (batch_size, seq_len, num_classes)
            for "emotion", or (batch_size, seq_len) for "trigger".
        labels (torch.Tensor): Ground-truth labels of shape (batch_size, seq_len).
        task (str): Classification task type, either "emotion" or "trigger".

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - logits: Kept logits, of shape (n_kept, num_classes) for "emotion"
              or (n_kept,) for "trigger".
            - labels: Kept labels, of shape (n_kept,).
    """
    assert task in {"emotion", "trigger"}, "task must be 'emotion' or 'trigger'"

    flat_labels = labels.view(-1)
    keep = flat_labels != -1

    if task == "emotion":
        # keep the class axis, merge batch and sequence axes
        flat_logits = logits.view(-1, logits.size(-1))
    else:
        flat_logits = logits.view(-1)

    return flat_logits[keep], flat_labels[keep]

In [None]:
def evaluate(
    model: MoEForEmotionAndTriggerClassification, 
    val_loader: DataLoader, 
    device: torch.device, 
    emotion_loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], 
    trigger_loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], 
    verbose: bool = True
) -> Tuple[float, float, List[str]]:
    """
    Run the model over a validation set without updating its weights.

    Accumulates the summed emotion + trigger loss per batch and counts correct
    predictions for both tasks on the non-padded positions. When `verbose` is
    set, the running mean loss is printed and logged every 100 batches.

    Args:
        model (MoEForEmotionAndTriggerClassification): The trained model to evaluate.
        val_loader (DataLoader): Dataloader for the validation set.
        device (torch.device): Device the batches are moved to before the 
            forward pass.
        emotion_loss_fn (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]): 
            Loss function used for emotion classification (e.g., CrossEntropyLoss).
        trigger_loss_fn (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]): 
            Loss function used for trigger classification (e.g., BCEWithLogitsLoss).
        verbose (bool): Controls the verbosity. Set False to hide messages. 
            Defaults to True.

    Returns:
        Tuple[float, float, List[str]]: A tuple containing:
            - avg_val_loss: Mean validation loss across all batches.
            - avg_val_accuracy: Mean of emotion and trigger accuracy.
            - val_logs: Log messages recorded during validation.
    """
    model.eval()

    val_logs = []
    running_loss, steps_done = 0.0, 0
    emotion_hits, emotion_seen = 0, 0
    trigger_hits, trigger_seen = 0, 0

    with torch.no_grad():
        for idx, batch in enumerate(val_loader):
            # forward pass on the target device
            emotion_logits, trigger_logits = model(
                batch["input_ids"].to(device), 
                batch["attention_mask"].to(device)
            )

            # keep only the positions that carry a real label
            emotion_logits, emotion_labels = remove_padding(
                emotion_logits, batch["emotion_labels"].to(device), "emotion"
            )
            trigger_logits, trigger_labels = remove_padding(
                trigger_logits, batch["trigger_labels"].to(device), "trigger"
            )

            # joint loss: both task losses are summed with equal weight
            batch_loss = (
                emotion_loss_fn(emotion_logits, emotion_labels.long())
                + trigger_loss_fn(trigger_logits, trigger_labels)
            )
            running_loss += batch_loss.item()

            # predictions: arg-max class for emotions, 0.5 threshold for triggers
            emotion_preds = torch.argmax(emotion_logits, dim=-1)
            trigger_preds = (torch.sigmoid(trigger_logits).squeeze(-1) > 0.5).long()

            emotion_hits += torch.sum(emotion_preds == emotion_labels).item()
            trigger_hits += torch.sum(trigger_preds == trigger_labels).item()
            emotion_seen += emotion_labels.numel()
            trigger_seen += trigger_labels.numel()

            steps_done += 1

            if verbose and idx % 100 == 0:
                message = f"      Validation loss per 100 training steps: {running_loss / steps_done:.4f}"
                print(message)
                val_logs.append(message + "\n")

    # max(..., 1) guards against division by zero on an empty loader
    avg_val_loss = running_loss / max(len(val_loader), 1)
    emotion_accuracy = emotion_hits / max(emotion_seen, 1)
    trigger_accuracy = trigger_hits / max(trigger_seen, 1)
    avg_val_accuracy = (emotion_accuracy + trigger_accuracy) / 2

    return avg_val_loss, avg_val_accuracy, val_logs

In [None]:
def train_and_validate(
    model: MoEForEmotionAndTriggerClassification, 
    train_loader: DataLoader, 
    val_loader: DataLoader, 
    optimizer: torch.optim.Optimizer, 
    device: torch.device, 
    emotion_loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], 
    trigger_loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor], 
    num_epochs: int = 3, 
    verbose: bool = True
) -> List[str]:
    """
    Train and evaluate the Mixture-of-Experts model for emotion and trigger classification.

    For each epoch, computes training loss and accuracy for both tasks. Logs 
    training loss every 100 steps and evaluates the model on the validation set 
    at the end of each epoch.

    Args:
        model (MoEForEmotionAndTriggerClassification): The model to train and evaluate.
        train_loader (DataLoader): Dataloader for the train set.
        val_loader (DataLoader): Dataloader for the validation set.
        optimizer (torch.optim.Optimizer): PyTorch optimizer.
        device (torch.device): Device the batches are moved to before the 
            forward pass.
        emotion_loss_fn (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]): 
            Loss function used for emotion classification (e.g., CrossEntropyLoss).
        trigger_loss_fn (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]): 
            Loss function used for trigger classification (e.g., BCEWithLogitsLoss).
        num_epochs (int, optional): Number of training epochs. Defaults to 3.
        verbose (bool): Controls the verbosity. Set False to hide all printed
            messages. Defaults to True.

    Returns:
        List[str]: A list of formatted strings logging training and validation 
        progress. Epoch headers and epoch summaries are always recorded; the
        per-100-step loss lines are only recorded when `verbose` is True.
    """
    train_logs = []

    for epoch in range(num_epochs):
        # Honour `verbose` for the epoch header as well, like every other print.
        if verbose:
            print(f"\nEpoch [{epoch + 1}/{num_epochs}]")
        train_logs.append(f"Epoch [{epoch + 1}/{num_epochs}]\n")
        model.train()

        train_loss, nb_steps = 0.0, 0
        total_emotion_preds, correct_emotion_preds = 0, 0
        total_trigger_preds, correct_trigger_preds = 0, 0

        for idx, batch in enumerate(train_loader):
            # reset gradients accumulated by the previous step
            optimizer.zero_grad()

            # forward pass
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            emotion_labels = batch["emotion_labels"].to(device)
            trigger_labels = batch["trigger_labels"].to(device)

            emotion_logits, trigger_logits = model(input_ids, attention_mask)

            # remove padding
            emotion_logits, emotion_labels = remove_padding(
                emotion_logits, emotion_labels, "emotion"
            )
            trigger_logits, trigger_labels = remove_padding(
                trigger_logits, trigger_labels, "trigger"
            )

            # compute loss (both task losses are summed with equal weight)
            emotion_loss = emotion_loss_fn(emotion_logits, emotion_labels.long())
            trigger_loss = trigger_loss_fn(trigger_logits, trigger_labels)

            loss = emotion_loss + trigger_loss
            train_loss += loss.item()

            # backward pass and parameter update
            loss.backward()
            optimizer.step()

            # compute accuracy
            emotion_preds = torch.argmax(emotion_logits, dim=-1)
            trigger_preds = (torch.sigmoid(trigger_logits).squeeze(-1) > 0.5).long()

            correct_emotion_preds += torch.sum(emotion_preds == emotion_labels).item()
            correct_trigger_preds += torch.sum(trigger_preds == trigger_labels).item()

            total_emotion_preds += emotion_labels.numel()
            total_trigger_preds += trigger_labels.numel()

            nb_steps += 1

            # logging
            if verbose and idx % 100 == 0:
                loss_step = train_loss / nb_steps
                print(f"      Training loss per 100 training steps: {loss_step:.4f}")
                train_logs.append(f"      Training loss per 100 training steps: {loss_step:.4f}\n")

        # max(..., 1) guards against division by zero on an empty loader
        avg_train_loss = train_loss / max(len(train_loader), 1)
        emotion_accuracy = correct_emotion_preds / max(total_emotion_preds, 1)
        trigger_accuracy = correct_trigger_preds / max(total_trigger_preds, 1)
        avg_train_accuracy = (emotion_accuracy + trigger_accuracy) / 2

        val_loss, val_accuracy, val_logs = evaluate(
            model, 
            val_loader, 
            device=device, 
            emotion_loss_fn=emotion_loss_fn, 
            trigger_loss_fn=trigger_loss_fn, 
            verbose=verbose
        )
        train_logs.extend(val_logs)
        train_logs.append(f"   Training Loss: {avg_train_loss:.4f}, Training Accuracy: {avg_train_accuracy:.4f}\n")
        train_logs.append(f"   Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}\n\n")

        if verbose:
            print(f"   Training Loss: {avg_train_loss:.4f}, Training Accuracy: {avg_train_accuracy:.4f}")
            print(f"   Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}\n")

    return train_logs

In [None]:
# Train for `num_epochs` epochs, validating after each one; keep the returned log lines.
logs = train_and_validate(
    moe, 
    train_loader, 
    val_loader, 
    optimizer=optimizer, 
    device=device, 
    emotion_loss_fn=emotion_loss_fn, 
    trigger_loss_fn=trigger_loss_fn, 
    num_epochs=num_epochs, 
    verbose=True
)

### **Evaluation**

Runs on test set to calculate accuracy, precision, recall, and F1 for emotion and trigger tasks. Also generates confusion matrices.

In [None]:
def _weighted_summary(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """
    Compute accuracy and support-weighted precision, recall, and F1 over all samples.

    Args:
        y_true (np.ndarray): 1-D array of gold class ids.
        y_pred (np.ndarray): 1-D array of predicted class ids, aligned with `y_true`.

    Returns:
        Dict[str, float]: Scores under the keys "accuracy", "precision",
            "recall", and "f1". All scores are 0.0 when there are no samples.
    """
    if y_true.size == 0:
        return {"accuracy": 0.0, "precision": 0.0, "recall": 0.0, "f1": 0.0}

    weighted = {"average": "weighted", "zero_division": 0}
    return {
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "precision": float(precision_score(y_true, y_pred, **weighted)),
        "recall": float(recall_score(y_true, y_pred, **weighted)),
        "f1": float(f1_score(y_true, y_pred, **weighted)),
    }


def get_metrics(
    model: MoEForEmotionAndTriggerClassification,
    data_loader: DataLoader,
    device: torch.device,
    labels_to_ids: Dict[str, int],
    ids_to_labels: Dict[int, str]
) -> Tuple[Dict[str, Any], np.ndarray, np.ndarray]:
    """
    Evaluate a trained Mixture-of-Experts model on a test set and compute metrics.

    Calculates accuracy, precision, recall, and F1 scores for both emotion and
    trigger classification tasks, as well as confusion matrices.

    Predictions and labels are collected across all batches and every metric is
    computed once over the complete set. Averaging per-batch scores would make
    the result depend on batch size and batch composition, and per-class scores
    computed on a single batch cannot be attributed reliably when a class is
    missing from that batch.

    Args:
        model (MoEForEmotionAndTriggerClassification): The trained model.
        data_loader (DataLoader): DataLoader for the test set. Each batch must
            provide "input_ids", "attention_mask", "emotion_labels", and
            "trigger_labels".
        device (torch.device): The device to use for computation.
        labels_to_ids (Dict[str, int]): Dictionary mapping strings into integers.
            Emotion class ids are taken to be 0 .. len(labels_to_ids) - 1.
        ids_to_labels (Dict[int, str]): Dictionary mapping integers to labels.

    Returns:
        Tuple[Dict[str, Any], np.ndarray, np.ndarray]: A tuple containing:
            - metrics (dict): Dictionary with aggregated performance metrics.
              "emotion_classification" holds accuracy, precision, recall, and
              an "f1" dict with the weighted average under "avg" plus one
              entry per emotion label; "trigger_classification" holds
              accuracy, precision, recall, and f1.
            - emotion_cm (np.ndarray): Confusion matrix for emotion classification.
            - trigger_cm (np.ndarray): Confusion matrix for trigger classification.
            Both matrices are all-zero when the loader yields no samples.
    """
    model.eval()

    emotion_class_ids = list(range(len(labels_to_ids)))
    trigger_class_ids = [0, 1]

    emotion_true_chunks: List[np.ndarray] = []
    emotion_pred_chunks: List[np.ndarray] = []
    trigger_true_chunks: List[np.ndarray] = []
    trigger_pred_chunks: List[np.ndarray] = []

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            emotion_labels = batch["emotion_labels"].to(device)
            trigger_labels = batch["trigger_labels"].to(device)

            emotion_logits, trigger_logits = model(input_ids, attention_mask)

            # Emotion classification
            # NOTE(review): remove_padding is defined elsewhere in the notebook;
            # presumably it drops padded positions from logits and labels.
            emotion_logits, emotion_labels = remove_padding(
                emotion_logits, emotion_labels, "emotion"
            )
            emotion_preds = torch.argmax(emotion_logits, dim=-1)

            emotion_pred_chunks.append(emotion_preds.cpu().numpy().reshape(-1))
            emotion_true_chunks.append(emotion_labels.cpu().numpy().reshape(-1))

            # Trigger classification
            trigger_logits, trigger_labels = remove_padding(
                trigger_logits, trigger_labels, "trigger"
            )
            # A sigmoid probability above 0.5 counts as a predicted trigger.
            trigger_preds = (torch.sigmoid(trigger_logits).squeeze(-1) > 0.5).long()

            trigger_pred_chunks.append(trigger_preds.cpu().numpy().reshape(-1))
            trigger_true_chunks.append(trigger_labels.cpu().numpy().reshape(-1))

    def _stack(chunks: List[np.ndarray]) -> np.ndarray:
        # np.concatenate rejects an empty list, so fall back to an empty array.
        return np.concatenate(chunks) if chunks else np.empty(0, dtype=np.int64)

    emotion_true = _stack(emotion_true_chunks)
    emotion_pred = _stack(emotion_pred_chunks)
    trigger_true = _stack(trigger_true_chunks)
    trigger_pred = _stack(trigger_pred_chunks)

    # Per-emotion F1. Passing `labels=` pins position i of the result to
    # class id i, even for classes that never occur in the data.
    unique_emotions_f1 = dict.fromkeys(labels_to_ids, 0.0)
    if emotion_true.size:
        per_class_scores = f1_score(
            emotion_true,
            emotion_pred,
            labels=emotion_class_ids,
            average=None,
            zero_division=0,
        )
        for class_id, score in zip(emotion_class_ids, per_class_scores):
            unique_emotions_f1[ids_to_labels[class_id]] = float(score)

        emotion_cm = confusion_matrix(
            emotion_true, emotion_pred, labels=emotion_class_ids
        )
    else:
        emotion_cm = np.zeros(
            (len(emotion_class_ids), len(emotion_class_ids)), dtype=np.int64
        )

    if trigger_true.size:
        trigger_cm = confusion_matrix(
            trigger_true, trigger_pred, labels=trigger_class_ids
        )
    else:
        trigger_cm = np.zeros(
            (len(trigger_class_ids), len(trigger_class_ids)), dtype=np.int64
        )

    # Aggregate metrics
    emotion_summary = _weighted_summary(emotion_true, emotion_pred)

    metrics = defaultdict(dict)
    metrics["emotion_classification"] = {
        "accuracy": emotion_summary["accuracy"],
        "precision": emotion_summary["precision"],
        "recall": emotion_summary["recall"],
        "f1": {"avg": emotion_summary["f1"], **unique_emotions_f1},
    }
    metrics["trigger_classification"] = _weighted_summary(trigger_true, trigger_pred)

    return metrics, emotion_cm, trigger_cm

In [None]:
def plot_confusion_matrix(
    cm: np.ndarray,
    labels: List[str],
    title: str = "Confusion Matrix",
    save: bool = False
) -> None:
    """
    Draw a confusion matrix as an annotated Seaborn heatmap and display it.

    When `save` is True the figure is also written to a PNG in the current
    working directory; the file name is the lower-cased title with spaces
    replaced by underscores.

    Args:
        cm (np.ndarray): Confusion matrix of shape (n_classes, n_classes).
        labels (List[str]): Class labels used for both axes.
        title (str, optional): Title of the plot. Defaults to "Confusion Matrix".
        save (bool, optional): If True, saves the plot to disk. Defaults to False.
    """
    heatmap_options = {
        "annot": True,
        "fmt": "d",
        "cmap": "Blues",
        "xticklabels": labels,
        "yticklabels": labels,
        "cbar": False,
    }

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, **heatmap_options)

    # Titles and axis captions
    plt.title(title, fontsize=16)
    plt.xlabel("Predicted Labels", fontsize=14)
    plt.ylabel("True Labels", fontsize=14)

    # Slanted x tick labels so neighbouring class names do not overlap
    plt.xticks(rotation=45, ha="right")
    plt.yticks(rotation=0)
    plt.tight_layout()

    if save:
        png_name = f"{title.lower().replace(' ', '_')}.png"
        plt.savefig(png_name, dpi=300)

    plt.show()

In [None]:
# Score the trained model on the held-out test split
metrics, emotion_cm, trigger_cm = get_metrics(
    moe, test_loader, device, labels_to_ids, ids_to_labels
)

# Output results: one section per task, with the nested per-label F1 scores
# printed as an aligned sub-list.
for task, results in metrics.items():
    print(f"Task: {task.upper()}")
    for metric, score in results.items():
        if not (metric == "f1" and isinstance(score, dict)):
            print(f"      {metric}: {score:.4f}")
            continue

        print("      f1: ")
        max_label_len = max(map(len, map(str, score)))
        for x, y in score.items():
            print(f"          {x:<{max_label_len}}: {y:.4f}")

In [None]:
# Emotion confusion matrix, with axes labelled by emotion name
plot_confusion_matrix(
    cm=emotion_cm,
    labels=list(labels_to_ids.keys()),
    title="Emotion Classification",
)

In [None]:
# Trigger confusion matrix (binary task)
plot_confusion_matrix(
    cm=trigger_cm,
    labels=["No trigger", "Trigger"],
    title="Trigger Classification",
)

### **Save Experiment**

Saves the experiment configuration, performance metrics, and confusion matrices to the `results/<dataset>` folder. To also save the `state_dict` of the trained model, uncomment the cell below; the `state_dict` will be written to the `trained_models/<dataset>` folder.

#### Output Format

Each experiment result is saved to a `.txt` file in the `results/<dataset>/` folder.

Each file includes:
- Experiment config summary (batch size, epochs, gating type, etc.)
- Training and validation loss/accuracy per epoch
- Final test set metrics (accuracy, precision, recall, F1)
- Per-class F1 scores for emotion labels
- Confusion matrices for emotion and trigger classification

Example file path:

`results/MELD/moe_model_True_train_bert_linear_single_gate_4_mlp_experts_4_active_3e-5_lr_10_epochs.txt`

In [None]:
# # Optional: Save the trained model
# # --------------------------------
# # Uncomment this section to save the trained model to disk.

# model_dir = os.path.join("trained_models", dataset)
# os.makedirs(model_dir, exist_ok=True)

# model_name = (
#     f"moe_model_{train_bert}_train_bert_{gate_type}_single_gate_"
#     f"{num_experts}_{expert_type}_experts_{top_k}_active_"
#     f"{learning_rate}_lr_{num_epochs}_epochs.pth"
# )
# model_path = os.path.join(model_dir, model_name)

# torch.save(moe.state_dict(), model_path)

In [None]:
def write_confusion_matrix(
    title: str,
    cm: np.ndarray,
    labels: List[str]
) -> List[str]:
    """
    Format a confusion matrix as a list of strings for saving to a .txt file.

    The output is a left-aligned, fixed-width table: a title line, a header
    row with the class labels, and one row per class. Columns are sized to
    the widest label or cell value plus one space, so adjacent cells are
    always separated even when a count has more digits than the longest label.

    Args:
        title (str): Title to be displayed above the matrix.
        cm (np.ndarray): Confusion matrix of shape (n_classes, n_classes).
        labels (List[str]): Class labels corresponding to matrix rows/columns.

    Returns:
        List[str]: A list of newline-terminated strings representing the
            formatted confusion matrix, ready to be written to a plain text
            file. With no labels, only the title and an empty header are
            returned.
    """
    # Only the first len(labels) rows are rendered, one per label.
    rows = [cm[i] for i in range(len(labels))]

    # `default=0` keeps an empty label list or matrix from crashing `max`.
    widest_label = max((len(str(label)) for label in labels), default=0)
    widest_value = max(
        (len(str(value)) for row in rows for value in row), default=0
    )
    col_width = max(widest_label, widest_value) + 1

    cm2txt = [f"\n{title}\n"]

    # Header row: an empty corner cell followed by the column labels
    header = f"{'':<{col_width}}" + "".join(
        f"{label:<{col_width}}" for label in labels
    )
    cm2txt.append(header + "\n")

    # Matrix rows: the row label followed by that row's counts
    for label, row in zip(labels, rows):
        cells = "".join(f"{value:<{col_width}}" for value in row)
        cm2txt.append(f"{label:<{col_width}}" + cells + "\n")

    return cm2txt

In [None]:
# Construct file path
model_name = (
    "weighted_moe_model" if weight_triggers and dataset == "MaSaC"
    else "moe_model"
)
file_name = (
    f"{model_name}_{train_bert}_train_bert_{gate_type}_single_gate_"
    f"{num_experts}_{expert_type}_experts_{top_k}_active_"
    f"{learning_rate}_lr_{num_epochs}_epochs.txt"
)
file_path = os.path.join("results", dataset, file_name)

# Ensure the directory exists
os.makedirs(os.path.dirname(file_path), exist_ok=True)

# Experiment setup metadata
experiment_setup = [
    f"max_length = {max_length}\n",
    f"batch_size = {batch_size}\n\n",
    f"gate_type = {gate_type}\n",
    f"expert_type = {expert_type}\n",
    f"num_experts = {num_experts}\n",
    f"top_k = {top_k}\n",
    f"learning_rate = {learning_rate}\n",
    f"num_epochs = {num_epochs}\n",
    f"train_bert = {train_bert}\n",
    f"weight_triggers = {weight_triggers}\n\n",
]

# Final metrics: one section per task, per-label F1 scores as an aligned sub-list
experiment_results = []
for task, results in metrics.items():
    experiment_results.append(f"\nTask: {task.upper()}\n")
    for metric, score in results.items():
        if metric == "f1" and isinstance(score, dict):
            experiment_results.append("      f1:\n")
            max_label_len = max(len(str(k)) for k in score)
            for x, y in score.items():
                experiment_results.append(f"          {x:<{max_label_len}}: {y:.4f}\n")
        else:
            experiment_results.append(f"      {metric}: {score:.4f}\n")

# Confusion matrices
writeable_emotion_cm = write_confusion_matrix(
    "Emotion confusion matrix", emotion_cm, list(labels_to_ids)
)
writeable_trigger_cm = write_confusion_matrix(
    "Trigger confusion matrix", trigger_cm, ["No trigger", "Trigger"]
)

# Assemble the complete report before opening the file: opening with "w"
# truncates it, so a formatting error after that point would leave a partial
# results file behind.
report_lines = [
    *experiment_setup,
    *logs,              # training and validation logs
    "\n",
    *experiment_results,
    *writeable_emotion_cm,
    "\n",
    *writeable_trigger_cm,
]

# Explicit encoding keeps the output independent of the platform default.
with open(file_path, "w", encoding="utf-8") as f:
    f.writelines(report_lines)

### Load and test trained model

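Loads a saved `state_dict` from the `trained_models/<dataset>/` folder, runs inference on the test set, calculates metrics, and plots confusion matrices for both tasks. The cells below are commented out; uncomment them to evaluate a previously saved model.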

In [None]:
# # Instantiate the Mixture of Experts model
# moe_loaded = MoEForEmotionAndTriggerClassification(
#     num_experts=num_experts,
#     k=top_k,
#     num_classes=len(labels_to_ids),
#     gate_type=gate_type,
#     expert_type=expert_type,
#     model_name=designated_model,
#     train_bert=train_bert,
# )

# # Load state dict from file (onto CPU in this case)
# model_path = (
#     f"trained_models/{dataset}/moe_model_{train_bert}_train_bert_"
#     f"{gate_type}_single_gate_{num_experts}_{expert_type}_experts_"
#     f"{top_k}_active_{learning_rate}_lr_{num_epochs}_epochs.pth"
# )
# moe_loaded.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))

In [None]:
# metrics, emotion_cm, trigger_cm = get_metrics(
#     moe_loaded, 
#     test_loader, 
#     torch.device("cpu"), 
#     labels_to_ids, 
#     ids_to_labels
# )

# # Output results
# for task, results in metrics.items():
#     print(f"Task: {task.upper()}")
#     for metric, score in results.items():
#         if metric == "f1" and isinstance(score, dict):
#             print(f"      f1: ")
#             max_label_len = max(len(str(k)) for k in score.keys())
#             for x, y in score.items():
#                 print(f"          {x:<{max_label_len}}: {y:.4f}")
#         else:
#             print(f"      {metric}: {score:.4f}")

In [None]:
# plot_confusion_matrix(
#     emotion_cm, 
#     list(labels_to_ids), 
#     "Emotion Classification"
# )

In [None]:
# plot_confusion_matrix(
#     trigger_cm, 
#     ["No trigger", "Trigger"], 
#     "Trigger Classification"
#     )