In [1]:
import pandas as pd
import os
import torch
import torch.nn as nn
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from models.emotimebert import EmotionalTimeBert
from sklearn.metrics import f1_score
from tqdm import tqdm
import matplotlib.pyplot as plt
from data_handling.datasets.empathetic_dialogues import EmpatheticDialoguesDataset
from data_handling.data_cleaning.empathetic_dialogues import load_empath_conversations, load_empath_test_conversations
from utils.utils import train_model, validate_model, test_model, collate_conversations, create_new_run_dir, save_table_as_csv
from pathlib import Path
import re
# Disable parallelism inside the Hugging Face `tokenizers` library.
# NOTE(review): presumably set because the DataLoaders below use worker processes — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Master switch for the notebook: when True the training loop runs and the weights, loss
# tables and loss plots are saved; when False training is skipped and `model.pt` is loaded
# from the run directory instead.
TRAINING = False


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Directory used for every artifact of this run (model.pt, loss CSVs, plots, classification report).
# NOTE(review): presumably a fresh directory is created only when TRAINING is True and an
# existing run is reused otherwise — confirm in utils.utils.create_new_run_dir.
base_save_location = create_new_run_dir(new_dir=TRAINING)

In [3]:
# Load the training and validation conversations together with the label metadata:
# `emotion_labels` sizes the classification head, `emotion_to_id` is reused when loading the test split.
conversations, val_conversations, emotion_labels, emotion_to_id = load_empath_conversations()


In [4]:
# (Moved Empathetic Dialogues preprocessing to data_handling/data_cleaning/empathetic_dialogues.py)


In [5]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


def _collate(batch):
    """Collate one batch of conversations with the shared tokenizer."""
    return collate_conversations(batch, tokenizer)


# Settings shared by the training and validation loaders.
# pin_memory and persistent_workers are deliberately left at their defaults.
_loader_kwargs = {
    "batch_size": 40,
    "num_workers": 8,
    "collate_fn": _collate,
}

# Training split: reshuffled every epoch.
dataset = EmpatheticDialoguesDataset(conversations, tokenizer)
loader = DataLoader(dataset, shuffle=True, **_loader_kwargs)

# Validation split: keep shuffle=False — a fixed order helps reproduce results.
val_dataset = EmpatheticDialoguesDataset(val_conversations, tokenizer)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)


In [6]:
# Prefer the GPU when one is available; otherwise everything stays on the CPU.
device = "cpu"

if torch.cuda.is_available():
    device = "cuda"
    print("Device:", torch.cuda.get_device_name(0))
    print("CUDA Enabled!")

model = EmotionalTimeBert("./medbert_4_epochs", num_labels=len(emotion_labels)).to(device)

if not TRAINING:
    # Evaluation-only run: restore the weights saved by a previous training run.
    # map_location remaps the stored tensors onto `device`, so a checkpoint written on a
    # CUDA machine can still be loaded when this notebook falls back to the CPU.
    model.load_state_dict(
        torch.load(f"{base_save_location}/model.pt", map_location=device)
    )

# Positions labelled -1 are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-1)

# Only the parameters listed here are updated: the last two encoder layers get a small
# learning rate, the remaining listed modules a larger one.
optimizer = torch.optim.AdamW([
    {"params": model.encoder.encoder.layer[-2:].parameters(), "lr": 1e-5},
    {"params": model.temporal_transformer.parameters(), "lr": 3e-4},
    {"params": model.time_embed.parameters(), "lr": 3e-4},
    {"params": model.speakers_embed.parameters(), "lr": 3e-4},
    {"params": model.head_emotions.parameters(), "lr": 3e-4},
])

Device: NVIDIA GeForce RTX 5080
CUDA Enabled!


In [7]:
# Per-batch histories. They are handed to train_model / validate_model (which presumably
# append to them — confirm in utils.utils) and are read again by the plotting cells.
train_step_losses, val_step_losses = [], []
global_steps, val_steps = [], []

if TRAINING:
    num_of_epochs = 5
    step = 0  # running step counter threaded through both training and validation
    for epoch in range(num_of_epochs):
        # --- train for one epoch ---
        progress_bar = tqdm(loader, total=len(loader))
        avg_loss, step = train_model(
            model,
            optimizer,
            device,
            criterion=criterion,
            bar=progress_bar,
            train_step_losses=train_step_losses,
            global_steps=global_steps,
            start_step=step,
        )

        # --- validate ---
        validate_progress = tqdm(val_loader, total=len(val_loader))
        val_loss, val_f1, step = validate_model(
            model,
            device,
            criterion=criterion,
            bar=validate_progress,
            val_step_losses=val_step_losses,
            val_steps=val_steps,
            start_step=step,
        )

        print(f"Epoch {epoch+1}: train loss = {avg_loss:.4f}")
        print(f"Epoch {epoch+1}: val loss   = {val_loss:.4f}")
        print(f"Epoch {epoch+1}: val F1     = {val_f1:.4f}")

    # Persist the final weights, then the step/loss tables used for plotting.
    torch.save(model.state_dict(), f"{base_save_location}/model.pt")
    _artifacts = (
        (global_steps, "global_steps.csv"),
        (train_step_losses, "train_step_losses.csv"),
        (val_step_losses, "val_step_losses.csv"),
        (val_steps, "val_steps.csv"),
    )
    for _table, _filename in _artifacts:
        save_table_as_csv(_table, f"{base_save_location}/{_filename}")



In [8]:
if TRAINING:
    # Raw per-batch losses for both splits on one set of axes.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(global_steps, train_step_losses, alpha=0.4, label="Train (per batch)")
    ax.plot(val_steps, val_step_losses, alpha=0.8, label="Val (per batch)")

    ax.set_xlabel("Training Step")
    ax.set_ylabel("Loss")
    ax.set_title("Per-batch Training & Validation Loss")
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(f"{base_save_location}/Unsmoothed.png")

In [9]:
def ema(values, beta=0.98):
    """Return the exponential moving average of ``values`` as a list.

    The running average is seeded with the first value (no bias correction), then
    updated as ``avg = beta * avg + (1 - beta) * v`` for every element, including
    the first one.

    Args:
        values: Iterable of numbers. Any iterable is accepted, not only sequences.
        beta: Smoothing factor; larger values give a smoother, slower-moving curve.

    Returns:
        A list with one smoothed value per input element. Empty input yields ``[]``
        instead of raising ``IndexError``.
    """
    values = list(values)
    if not values:
        # Nothing to smooth (e.g. no steps were logged).
        return []

    smoothed = []
    avg = values[0]
    for v in values:
        avg = beta * avg + (1 - beta) * v
        smoothed.append(avg)
    return smoothed

if TRAINING:
    # Same curves as the raw per-batch plot, but EMA-smoothed for readability.
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(global_steps, ema(train_step_losses), label="Train")
    ax.plot(val_steps, ema(val_step_losses), label="Val")
    ax.set_xlabel("Training Step")
    ax.set_ylabel("Loss")
    ax.set_title("Per-step Loss")
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(f"{base_save_location}/Smoothed.png")

In [10]:
# Load the test split, passing the emotion->id mapping returned with the training data
# (presumably so test labels use the same ids — confirm in the loader).
test_conversations = load_empath_test_conversations(emotion_to_id)

In [11]:
test_dataset = EmpatheticDialoguesDataset(test_conversations, tokenizer)


def _collate_test(batch):
    """Collate one batch of test conversations with the shared tokenizer."""
    return collate_conversations(batch, tokenizer)


# Test data is served in a fixed order (no shuffling).
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=8,
    collate_fn=_collate_test,
)

In [12]:
# Evaluate the model on the test split and keep the returned score in `test_f1`.
# NOTE(review): `save_dir` is given a path ending in .csv — confirm whether test_model
# expects a directory or a file path.
test_f1 = test_model(
    model=model,
    dataloader=test_loader,
    device=device,
    emotion_labels=emotion_labels,
    save_results=True,
    save_dir=f"{base_save_location}/classification_report.csv",
)
# TODO: also look into BERT's special keywords (special tokens?)




In [13]:

def ablation(model, dataloader, device, ablate, shuffle_dim=0):
    """Evaluate ``model`` with one input stream ablated and return the macro-F1.

    Args:
        model: Called as ``model(input_ids, attention_mask, timestamps, speakers,
            labels, utterance_mask)``; predictions are the argmax of its output over
            the last axis. The model is left in eval mode.
        dataloader: Iterable of batch dicts with the keys ``input_ids``,
            ``attention_mask``, ``labels``, ``utterance_mask``, ``timestamps`` and
            ``speakers``.
        device: Device every batch tensor is moved to.
        ablate: Which input to corrupt:
            ``"none"`` / ``None`` - nothing (baseline),
            ``"time_zero"``      - timestamps replaced by zeros,
            ``"time_shuffle"``   - timestamps randomly permuted along ``shuffle_dim``,
            ``"speaker_zero"``   - speakers replaced by zeros.
            Any other value still runs the baseline but emits a warning, so a typo
            cannot silently pass for an ablation result.
        shuffle_dim: Axis permuted by ``"time_shuffle"``. The default ``0`` keeps the
            previous behaviour.
            NOTE(review): if ``timestamps`` is laid out (batch, utterance) — not
            confirmable from this notebook — axis 0 swaps whole timestamp rows between
            conversations and leaves the order inside each conversation intact. The
            recorded run (time-shuffle F1 roughly equal to the baseline) is consistent
            with that. In that case pass ``shuffle_dim=1`` to permute the utterance
            axis; the same permutation is then applied to every row, padding included.

    Returns:
        Macro-averaged F1 over all positions whose label is not ``-1``.
    """
    import warnings  # local import: the notebook's import cell does not provide it

    known_modes = (None, "none", "time_zero", "time_shuffle", "speaker_zero")
    if ablate not in known_modes:
        warnings.warn(
            f"Unknown ablation mode {ablate!r}; inputs are left unmodified (baseline).",
            stacklevel=2,
        )

    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)
            utterance_mask = batch["utterance_mask"].to(device)

            timestamps = batch["timestamps"].to(device)
            speakers = batch["speakers"].to(device)

            if ablate == "time_zero":
                timestamps = torch.zeros_like(timestamps)
            elif ablate == "time_shuffle":
                # Draw the permutation on the CPU (same RNG stream as before), then move
                # it next to the data so index_select works on any device.
                perm = torch.randperm(timestamps.size(shuffle_dim)).to(timestamps.device)
                timestamps = timestamps.index_select(shuffle_dim, perm)
            elif ablate == "speaker_zero":
                speakers = torch.zeros_like(speakers)

            logits = model(
                input_ids, attention_mask,
                timestamps, speakers,
                labels, utterance_mask
            )

            preds = logits.argmax(-1)
            # -1 marks positions to ignore (same convention as the loss's ignore_index).
            mask = labels != -1

            all_preds.append(preds[mask].cpu())
            all_labels.append(labels[mask].cpu())

    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    # Hand plain NumPy arrays to scikit-learn rather than torch tensors.
    return f1_score(all_labels.numpy(), all_preds.numpy(), average="macro")

In [14]:
# Baseline: "none" matches no ablation branch, so timestamps and speakers pass through unmodified.
print("Normal:", ablation(model, val_loader, device, "none"))

Normal: 0.45721292987638296


In [15]:
# Replace every timestamp with zero to measure how much the model relies on time information.
print("Time zero:", ablation(model, val_loader, device, "time_zero"))

Time zero: 0.30866959758379153


In [16]:
# Randomly permute the timestamps along their first axis.
# NOTE(review): the recorded score is roughly equal to the baseline; if the first axis is the
# batch axis, this swaps timestamps between conversations rather than reordering them in time — confirm.
print("Time shuffle:", ablation(model, val_loader, device, "time_shuffle"))

Time shuffle: 0.4597815262329828


In [17]:
# Replace every speaker id with zero to measure how much the model relies on speaker identity.
print("Speaker zero:", ablation(model, val_loader, device, "speaker_zero"))

Speaker zero: 0.45293499501246337
