In [None]:
# Mount Google Drive at /content/drive so the data and checkpoints stored there are reachable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
cd drive/MyDrive

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaTokenizer, RobertaModel
import pandas as pd
import numpy as np

In [None]:
from torch import cuda

# Run on the GPU when CUDA reports one, otherwise stay on the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(device)

In [None]:
train_df = pd.read_json('data/EDiReF_train_data/MELD_train_efr.json')

# Turn missing trigger annotations (None) into NaN ...
train_df["triggers"] = train_df["triggers"].apply(
    lambda trigger_list: [np.nan if value is None else value for value in trigger_list]
)
# ... then keep only the conversations whose trigger list is fully annotated.
train_df = train_df.loc[
    train_df["triggers"].apply(
        lambda trigger_list: all(not pd.isna(value) for value in trigger_list)
    )
]

def flatten(xss):
    """Concatenate the items of every iterable in ``xss`` into one flat list (one level deep)."""
    flat = []
    for xs in xss:
        flat.extend(xs)
    return flat

flattened_emotions = flatten(train_df['emotions'])
unique_emotions = set(flattened_emotions)

# Enumerate the labels in sorted order. Iteration order of a set of strings can
# change between Python processes (string hashing is randomised), which would
# give every kernel restart a different label <-> id mapping and silently
# mismatch a checkpoint trained in an earlier session.
labels_to_ids = {label: idx for idx, label in enumerate(sorted(unique_emotions))}
ids_to_labels = {idx: label for label, idx in labels_to_ids.items()}

# Parallel per-conversation lists: utterances, emotion ids and trigger labels.
train_conversations = list(train_df['utterances'])
train_emotions = [[labels_to_ids[emotion] for emotion in conv] for conv in list(train_df['emotions'])]
train_triggers = list(train_df['triggers'])

In [None]:
val_df = pd.read_json('data/EDiReF_val_data/MELD_val_efr.json')

# Same cleaning as the training split: None -> NaN, then drop any conversation
# that still has a missing trigger annotation.
val_df["triggers"] = val_df["triggers"].apply(
    lambda trigger_list: [np.nan if value is None else value for value in trigger_list]
)
val_df = val_df.loc[
    val_df["triggers"].apply(
        lambda trigger_list: all(not pd.isna(value) for value in trigger_list)
    )
]

# Emotion names are mapped with the ids derived from the training split.
val_conversations = val_df['utterances'].tolist()
val_emotions = [
    [labels_to_ids[name] for name in conversation_emotions]
    for conversation_emotions in val_df['emotions']
]
val_triggers = val_df['triggers'].tolist()

In [None]:
# Pretrained 'roberta-base' tokenizer; tokenize_conversation() reads this notebook-level name.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

In [None]:
def tokenize_conversation(conversations, max_length = 128, tok = None):
    """Encode every conversation as a single fixed-length token sequence.

    The utterances of a conversation are joined with the tokenizer's own
    separator token and the result is padded / truncated to ``max_length``.

    Args:
        conversations: iterable of conversations, each an iterable of utterance strings.
        max_length: number of token positions in every encoded sequence.
        tok: tokenizer to use; when omitted, the notebook-level ``tokenizer`` is used.

    Returns:
        ``(input_ids, attention_masks)``: two parallel lists holding one 1-D
        tensor of length ``max_length`` per conversation.
    """
    tok = tokenizer if tok is None else tok

    # The literal "[SEP]" is BERT's separator; RoBERTa does not know it and
    # would split it into ordinary sub-word pieces. Ask the tokenizer for its
    # real separator and only fall back to the old literal if it has none.
    separator = f" {getattr(tok, 'sep_token', None) or '[SEP]'} "

    input_ids = []
    attention_masks = []

    for conversation in conversations:
        dialogue = separator.join(conversation)
        encoded = tok(
            dialogue,
            truncation = True,
            padding = 'max_length',
            max_length = max_length,
            return_tensors = "pt"
        )
        # return_tensors="pt" adds a leading batch axis of size 1; drop it.
        input_ids.append(encoded["input_ids"].squeeze(0))
        attention_masks.append(encoded["attention_mask"].squeeze(0))

    return input_ids, attention_masks

In [None]:
def pad_labels(labels, max_length = 128):
    """Pad (or truncate) every label sequence to exactly ``max_length`` entries.

    Args:
        labels: iterable of label sequences (lists of numbers), one per conversation.
        max_length: length of each returned tensor.

    Returns:
        A list of 1-D ``torch.float`` tensors of length ``max_length``. Slots
        after the last real label hold -1, the sentinel the training and
        evaluation loops filter out before computing losses.

    A sequence longer than ``max_length`` is cut to its first ``max_length``
    labels; previously this case raised because of a negative pad size.

    NOTE(review): label ``i`` ends up at position ``i`` of the padded tensor,
    while the model produces one logit per token position. Confirm that
    pairing position ``i`` with utterance ``i`` is really what is intended.
    """
    padded_labels = []
    for label_set in labels:
        label_tensor = torch.tensor(label_set, dtype = torch.float)[:max_length]
        # Explicit float fill so the result dtype never depends on type promotion.
        padding = torch.full((max_length - label_tensor.numel(),), -1.0, dtype = torch.float)
        padded_labels.append(torch.cat([label_tensor, padding]))
    return padded_labels

In [None]:
class ConversationDataset(Dataset):
    """Index-aligned view over four parallel sequences describing conversations.

    Item ``i`` is a dict with the keys ``input_ids``, ``attention_mask``,
    ``emotion_labels`` and ``trigger_labels`` taken from position ``i`` of the
    corresponding constructor argument.
    """

    def __init__(self, input_ids, attention_masks, emotion_labels, trigger_labels):
        self.input_ids, self.attention_masks = input_ids, attention_masks
        self.emotion_labels, self.trigger_labels = emotion_labels, trigger_labels

    def __len__(self):
        # The number of samples is defined by the encoded inputs.
        return len(self.input_ids)

    def __getitem__(self, idx):
        sample = dict(
            input_ids = self.input_ids[idx],
            attention_mask = self.attention_masks[idx],
            emotion_labels = self.emotion_labels[idx],
            trigger_labels = self.trigger_labels[idx],
        )
        return sample

In [None]:
# Colab form fields: length every encoded conversation / label tensor is padded to,
# and the number of conversations per batch. The option list now includes the
# value actually assigned to BATCH_SIZE.
MAX_LENGTH = 256    # @param [96, 128, 256] {type: 'raw'}
BATCH_SIZE = 8    # @param [8, 16, 32, 64] {type: 'raw'}

In [None]:
train_input_ids, train_attention_masks = tokenize_conversation(
    train_conversations,
    max_length = MAX_LENGTH,
)

# Emotion and trigger labels receive the same -1 padding out to MAX_LENGTH.
train_emotion_labels, train_trigger_labels = (
    pad_labels(label_lists, max_length = MAX_LENGTH)
    for label_lists in (train_emotions, train_triggers)
)

train_dataset = ConversationDataset(
    input_ids = train_input_ids,
    attention_masks = train_attention_masks,
    emotion_labels = train_emotion_labels,
    trigger_labels = train_trigger_labels,
)

In [None]:
val_input_ids, val_attention_masks = tokenize_conversation(
    val_conversations,
    max_length = MAX_LENGTH,
)

# Emotion and trigger labels receive the same -1 padding out to MAX_LENGTH.
val_emotion_labels, val_trigger_labels = (
    pad_labels(label_lists, max_length = MAX_LENGTH)
    for label_lists in (val_emotions, val_triggers)
)

val_dataset = ConversationDataset(
    input_ids = val_input_ids,
    attention_masks = val_attention_masks,
    emotion_labels = val_emotion_labels,
    trigger_labels = val_trigger_labels,
)

In [None]:
# Training batches are reshuffled every epoch so the optimiser does not see the
# conversations in the same file order each time; validation keeps a fixed order.
train_loader = DataLoader(train_dataset, batch_size = BATCH_SIZE, shuffle = True)
val_loader = DataLoader(val_dataset, batch_size = BATCH_SIZE, shuffle = False)

In [None]:
# Colab form fields: number of experts, how many of them are active per
# sequence (TOP_K), the AdamW learning rate and the number of training epochs.
NUM_EXPERTS = 8 # @param [2, 4, 8] {type: 'raw'}
TOP_K = 4 # @param [1, 2, 4, 8] {type: 'raw'}
LEARNING_RATE = 0.00001 # @param ["0.00001", "0.00002","0.00005","0.0001"] {"type":"raw"}
NUM_EPOCHS = 25 # @param ["5", "10", "15", "20", "25"] {"type":"raw"}

In [None]:
# At least one and at most NUM_EXPERTS experts must be active; TOP_K = 0 would
# make the mixture output all zeros without raising anywhere.
assert 1 <= TOP_K <= NUM_EXPERTS, "Select different values for TOP_K and NUM_EXPERTS!"

In [None]:
class MoEForEmotionAndTriggerClassification(nn.Module):
    """RoBERTa encoder + top-k mixture of linear experts + two per-position heads.

    A gating network scores ``num_experts`` linear experts from a pooled
    sequence embedding. For each sequence the ``k`` highest-scoring experts are
    applied to all of its token embeddings and summed, weighted by their
    softmax gate values (the weights are not renormalised over the chosen
    ``k``). The mixed features feed an emotion head (``num_classes`` logits per
    position) and a trigger head (one logit per position).
    """

    def __init__(self, num_experts, num_classes, k, model_name = 'roberta-base'):
        """
        Args:
            num_experts: number of linear experts in the mixture.
            num_classes: number of emotion classes.
            k: number of experts activated per sequence.
            model_name: checkpoint name given to ``RobertaModel.from_pretrained``.
        """
        super().__init__()

        self.roberta = RobertaModel.from_pretrained(model_name)
        for param in self.roberta.parameters():
            param.requires_grad = True  # Set to False to freeze RoBERTa instead of fine-tuning it
        hidden_size = self.roberta.config.hidden_size

        self.gating_network = nn.Linear(hidden_size, num_experts)
        self.experts = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(num_experts)])

        self.emotion_classifier = nn.Linear(hidden_size, num_classes)
        self.trigger_classifier = nn.Linear(hidden_size, 1)

        self.k = k
        self.dropout = nn.Dropout(p = 0.1)

    def forward(self, input_ids, attention_mask):
        """Return ``(emotion_logits, trigger_logits)``.

        Args:
            input_ids: (batch_size, seq_len) token ids.
            attention_mask: (batch_size, seq_len) mask, non-zero on real tokens;
                ``None`` is treated as "every position is real".

        Returns:
            emotion_logits: (batch_size, seq_len, num_classes)
            trigger_logits: (batch_size, seq_len)
        """
        roberta_outputs = self.roberta(input_ids = input_ids, attention_mask = attention_mask)
        embeddings = roberta_outputs.last_hidden_state  # (batch_size, seq_len, hidden_size)

        # Pool only over real tokens: a plain mean would let padding positions
        # dilute the vector the gate routes on, more so for short conversations.
        if attention_mask is None:
            pooled_embeddings = embeddings.mean(dim = 1)
        else:
            token_mask = attention_mask.unsqueeze(-1).to(embeddings.dtype)     # (batch_size, seq_len, 1)
            token_counts = token_mask.sum(dim = 1).clamp(min = 1.0)            # guards an all-padding row
            pooled_embeddings = (embeddings * token_mask).sum(dim = 1) / token_counts
        pooled_embeddings = self.dropout(pooled_embeddings)    # (batch_size, hidden_size)

        expert_weights = self.gating_network(pooled_embeddings) # (batch_size, num_experts)
        expert_weights = torch.softmax(expert_weights, dim = -1)

        # top-k experts only are activated
        topk_weights, topk_indices = torch.topk(expert_weights, self.k, dim = -1)
        # One transfer to Python ints instead of one tensor -> int conversion per lookup.
        routed_experts = topk_indices.tolist()

        combined_output = torch.zeros_like(embeddings)  # (batch_size, seq_len, hidden_size)

        for slot in range(self.k):
            weight = topk_weights[:, slot, None, None]  # (batch_size, 1, 1), broadcasts over tokens

            # Each sequence is sent through the expert chosen for it in this slot.
            expert_outputs = torch.stack(
                [
                    self.experts[routed_experts[row][slot]](embeddings[row])
                    for row in range(embeddings.size(0))
                ]
            )
            combined_output = combined_output + weight * expert_outputs  # (batch_size, seq_len, hidden_size)

        combined_output = self.dropout(combined_output)
        emotion_logits = self.emotion_classifier(combined_output)   # (batch_size, seq_len, num_classes)
        trigger_logits = self.trigger_classifier(combined_output).squeeze(-1)   # (batch_size, seq_len)

        return emotion_logits, trigger_logits

In [None]:
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss

# Model under training, sized from the notebook-level hyperparameters.
moe = MoEForEmotionAndTriggerClassification(
    num_experts = NUM_EXPERTS,
    num_classes = len(labels_to_ids),
    k = TOP_K,
)
# A single optimiser over every parameter of the model.
optimizer = AdamW(moe.parameters(), lr = LEARNING_RATE)

# Multi-class loss for emotions, binary-with-logits loss for triggers.
emotion_loss_fn, trigger_loss_fn = CrossEntropyLoss(), BCEWithLogitsLoss()

In [None]:
# Move the model's parameters and buffers to the device selected earlier.
moe.to(device)

In [None]:
def evaluate(model, val_loader):
    """Score ``model`` on ``val_loader`` without updating it.

    Uses the notebook-level ``device``, ``emotion_loss_fn`` and
    ``trigger_loss_fn``. Positions whose label equals -1 are excluded from both
    the loss and the accuracy.

    Returns:
        ``(avg_val_loss, avg_val_accuracy)``: the summed emotion + trigger loss
        averaged over batches, and the mean of the emotion accuracy and the
        trigger accuracy.
    """
    batch_keys = ('input_ids', 'attention_mask', 'emotion_labels', 'trigger_labels')

    model.eval()
    running_loss, steps_done = 0.0, 0
    emotion_hits = emotion_seen = 0
    trigger_hits = trigger_seen = 0

    with torch.no_grad():
        for step, batch in enumerate(val_loader):
            input_ids, attention_mask, emotion_labels, trigger_labels = (
                batch[key].to(device) for key in batch_keys
            )

            emotion_logits, trigger_logits = model(input_ids, attention_mask)

            # Drop padded label slots (-1) from the flattened logits and labels.
            keep_emotion = (emotion_labels != -1).view(-1)
            emotion_logits = emotion_logits.view(-1, emotion_logits.size(-1))[keep_emotion]
            emotion_labels = emotion_labels.view(-1)[keep_emotion]

            keep_trigger = (trigger_labels != -1).view(-1)
            trigger_logits = trigger_logits.view(-1)[keep_trigger]
            trigger_labels = trigger_labels.view(-1)[keep_trigger]

            # Joint objective: emotion loss plus trigger loss, unweighted.
            loss = (
                emotion_loss_fn(emotion_logits, emotion_labels.long())
                + trigger_loss_fn(trigger_logits, trigger_labels)
            )
            running_loss += loss.item()

            # Hard predictions: arg-max class, and sigmoid thresholded at 0.5.
            emotion_preds = emotion_logits.argmax(dim=-1)
            trigger_preds = (torch.sigmoid(trigger_logits).squeeze(-1) > 0.5).long()

            emotion_hits += (emotion_preds == emotion_labels).sum().item()
            trigger_hits += (trigger_preds == trigger_labels).sum().item()

            emotion_seen += emotion_labels.numel()
            trigger_seen += trigger_labels.numel()

            steps_done += 1

            if step % 100 == 0:
                print(f'      Validation loss per 100 training steps: {running_loss / steps_done}')

    avg_val_loss = running_loss / len(val_loader)
    avg_val_accuracy = (emotion_hits / emotion_seen + trigger_hits / trigger_seen)/2

    return avg_val_loss, avg_val_accuracy

In [None]:
def train_and_validate(model, train_loader, val_loader, num_epochs = 3):
    """Train ``model`` for ``num_epochs`` epochs and validate after each one.

    Uses the notebook-level ``optimizer``, ``device``, ``emotion_loss_fn`` and
    ``trigger_loss_fn``, and calls ``evaluate`` at the end of every epoch.
    Progress and per-epoch metrics are printed; nothing is returned.

    NOTE(review): labels are matched to logits purely by position (label slot
    ``i`` against model output position ``i``) after removing -1 padding.
    Confirm that the label tensors are aligned with token positions.
    """
    batch_keys = ('input_ids', 'attention_mask', 'emotion_labels', 'trigger_labels')

    for epoch_no in range(1, num_epochs + 1):
        print(f"Epoch [{epoch_no}/{num_epochs}]")
        model.train()
        running_loss, steps_done = 0.0, 0
        emotion_hits = emotion_seen = 0
        trigger_hits = trigger_seen = 0

        for step, batch in enumerate(train_loader):
            optimizer.zero_grad()

            input_ids, attention_mask, emotion_labels, trigger_labels = (
                batch[key].to(device) for key in batch_keys
            )

            emotion_logits, trigger_logits = model(input_ids, attention_mask)

            # Drop padded label slots (-1) from the flattened logits and labels.
            keep_emotion = (emotion_labels != -1).view(-1)
            emotion_logits = emotion_logits.view(-1, emotion_logits.size(-1))[keep_emotion]
            emotion_labels = emotion_labels.view(-1)[keep_emotion]

            keep_trigger = (trigger_labels != -1).view(-1)
            trigger_logits = trigger_logits.view(-1)[keep_trigger]
            trigger_labels = trigger_labels.view(-1)[keep_trigger]

            # Joint objective: emotion loss plus trigger loss, unweighted.
            loss = (
                emotion_loss_fn(emotion_logits, emotion_labels.long())
                + trigger_loss_fn(trigger_logits, trigger_labels)
            )
            running_loss += loss.item()

            loss.backward()
            optimizer.step()

            # Accuracy bookkeeping uses the logits computed before this update.
            emotion_preds = emotion_logits.argmax(dim=-1)
            trigger_preds = (torch.sigmoid(trigger_logits).squeeze(-1) > 0.5).long()

            emotion_hits += (emotion_preds == emotion_labels).sum().item()
            trigger_hits += (trigger_preds == trigger_labels).sum().item()

            emotion_seen += emotion_labels.numel()
            trigger_seen += trigger_labels.numel()
            steps_done += 1

            if step % 100 == 0:
                print(f'      Training loss per 100 training steps: {running_loss / steps_done}')

        avg_train_loss = running_loss / len(train_loader)
        avg_train_accuracy = (emotion_hits / emotion_seen + trigger_hits / trigger_seen)/2

        val_loss, val_accuracy = evaluate(model, val_loader)

        print(f"   Training Loss: {avg_train_loss:.3f}, Training Accuracy: {avg_train_accuracy:.3f}")
        print(f"   Validation Loss: {val_loss:.3f}, Validation Accuracy: {val_accuracy:.3f}\n")

In [None]:
# Run the full schedule; train_and_validate prints the per-epoch training and validation metrics.
train_and_validate(moe, train_loader, val_loader, num_epochs = NUM_EPOCHS)

In [None]:
import os

# Create the target folder first: torch.save does not create missing
# directories, and failing here would throw away a finished training run.
checkpoint_path = f'trained_models/moe_model_{NUM_EXPERTS}_experts_{TOP_K}_active_{LEARNING_RATE}_lr_{NUM_EPOCHS}_epochs.pth'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok = True)
torch.save(moe.state_dict(), checkpoint_path)

In [None]:
# Rebuild the architecture with the same hyperparameters, then restore the
# saved weights onto the CPU (the model is moved to `device` later).
moe_loaded = MoEForEmotionAndTriggerClassification(
    num_experts = NUM_EXPERTS,
    num_classes = len(labels_to_ids),
    k = TOP_K,
)
checkpoint_path = f'trained_models/moe_model_{NUM_EXPERTS}_experts_{TOP_K}_active_{LEARNING_RATE}_lr_{NUM_EPOCHS}_epochs.pth'
moe_loaded.load_state_dict(torch.load(checkpoint_path, map_location=torch.device("cpu")))

In [None]:
# NOTE(review): this is an absolute Drive path, unlike the relative data/...
# paths used for the train and validation files. Confirm the file location.
test_df = pd.read_json('/content/drive/MyDrive/MELD_test_efr.json')

test_conversations = test_df['utterances'].tolist()
test_emotions = [
    [labels_to_ids[name] for name in conversation_emotions]
    for conversation_emotions in test_df['emotions']
]
# Trigger labels are not used on this split: fill them with the -1 sentinel,
# one entry per emotion label.
test_triggers = [[-1.0] * len(conversation_emotions) for conversation_emotions in test_emotions]

In [None]:
test_input_ids, test_attention_masks = tokenize_conversation(
    test_conversations,
    max_length = MAX_LENGTH,
)

# Emotion and trigger labels receive the same -1 padding out to MAX_LENGTH.
test_emotion_labels, test_trigger_labels = (
    pad_labels(label_lists, max_length = MAX_LENGTH)
    for label_lists in (test_emotions, test_triggers)
)

test_dataset = ConversationDataset(
    input_ids = test_input_ids,
    attention_masks = test_attention_masks,
    emotion_labels = test_emotion_labels,
    trigger_labels = test_trigger_labels,
)
test_loader = DataLoader(test_dataset, batch_size = BATCH_SIZE, shuffle = False)

In [None]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

moe_loaded.to(device)
moe_loaded.eval()

# Collect every prediction / label first and score once at the end. Averaging
# per-batch weighted precision/recall/F1 does not give the dataset-level
# metric: each batch has its own class mix and its own number of valid labels.
all_emotion_preds, all_emotion_labels = [], []

with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        emotion_labels = batch['emotion_labels'].to(device)

        # Forward pass (trigger logits are not evaluated on this split)
        emotion_logits, _ = moe_loaded(input_ids, attention_mask)

        # Keep only positions that carry a real label (-1 marks padding)
        emotion_mask = (emotion_labels != -1).view(-1)
        emotion_logits = emotion_logits.view(-1, emotion_logits.size(-1))[emotion_mask]
        emotion_labels = emotion_labels.view(-1)[emotion_mask]

        emotion_preds = torch.argmax(emotion_logits, dim = -1)

        all_emotion_preds.append(emotion_preds.cpu())
        all_emotion_labels.append(emotion_labels.long().cpu())

emotion_preds_flat = torch.cat(all_emotion_preds).numpy()
emotion_labels_flat = torch.cat(all_emotion_labels).numpy()
num_samples = len(emotion_labels_flat)

# Dataset-level metrics over all valid positions
avg_accuracy = accuracy_score(emotion_labels_flat, emotion_preds_flat)
avg_precision, avg_recall, avg_f1, _ = precision_recall_fscore_support(
    emotion_labels_flat, emotion_preds_flat, average='weighted', zero_division = 0
)

# Output results
print("Emotion classification:")
print(f"   Test Accuracy: {avg_accuracy:.3f}")
print(f"   Test Precision: {avg_precision:.3f}")
print(f"   Test Recall: {avg_recall:.3f}")
print(f"   Test F1-score: {avg_f1:.3f}")