In [1]:
import pandas as pd
import os
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score, classification_report
from tqdm import tqdm
# Set the TOKENIZERS_PARALLELISM environment variable to "false" for this process.
# NOTE(review): presumably this avoids the tokenizers fork warning when the
# DataLoaders below spawn worker processes — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Notebook mode switch: when False, a saved checkpoint is loaded after the model
# is built and the training loop is skipped; when True, the model is trained.
TRAINING = False

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Read the train and validation splits; rows pandas cannot parse are skipped.
empath_data, val_empath_data = (
    pd.read_csv(csv_path, on_bad_lines="skip")
    for csv_path in (
        "empatheticdialogues/train.csv",
        "empatheticdialogues/valid.csv",
    )
)

# One group per conversation id, for each split.
grouped, val_grouped = (
    split_frame.groupby("conv_id")
    for split_frame in (empath_data, val_empath_data)
)

# Optional sanity check: how many distinct speakers each conversation has.
# speaker_counts = (
#     empath_data
#     .groupby("conv_id")["speaker_idx"]
#     .nunique()
# )
# speaker_counts.value_counts()

In [3]:
def build_conversations(grouped_frames, label_to_id):
    """Convert a ``groupby("conv_id")`` object into per-conversation records.

    Args:
        grouped_frames: Iterable of ``(conv_id, DataFrame)`` pairs whose frames
            have the columns ``utterance``, ``context``, ``utterance_idx`` and
            ``speaker_idx``.
        label_to_id: Mapping from the ``context`` string to an integer class id.

    Returns:
        A list with one dict per conversation, holding parallel lists under
        the keys ``texts``, ``labels``, ``timestamps`` and ``speakers``.

    Raises:
        KeyError: If a ``context`` value is missing from ``label_to_id``.
    """
    records = []
    for _, df_conv in grouped_frames:
        # Alternative input that also includes the prompt:
        # texts = (df_conv["prompt"] + "[SEP]" + df_conv["utterance"]).tolist()

        # Dense-rank the raw speaker ids so that, within each conversation,
        # speakers are re-indexed to 0, 1, ...
        speakers = (
            df_conv["speaker_idx"]
            .rank(method="dense")
            .astype(int)
            .sub(1)
            .tolist()
        )
        records.append({
            "texts": df_conv["utterance"].tolist(),
            "labels": [label_to_id[x] for x in df_conv["context"]],
            "timestamps": df_conv["utterance_idx"].tolist(),
            "speakers": speakers,
        })
    return records


# Class ids follow first-appearance order in the training split. Saved
# checkpoints depend on this order, so do not sort or reorder the labels.
emotion_labels = empath_data["context"].unique().tolist()
emotion_to_id = {emotion: idx for idx, emotion in enumerate(emotion_labels)}

# Train and validation use the same builder and the same label mapping.
conversations = build_conversations(grouped, emotion_to_id)
val_conversations = build_conversations(val_grouped, emotion_to_id)


In [4]:

# texts = (empath_data["prompt"].astype(str) + "[SEP]" + empath_data["utterance"].astype(str)).tolist()
# raw_labels = empath_data["context"].values.tolist()
# labels = [emotion_to_id[label] for label in raw_labels]
# time_stamps = empath_data["utterance_idx"].values.tolist()

In [5]:
# class EmpatheticDialoguesDataset(Dataset):
#     def __init__(self, texts, labels, tokenizer, time_stamps, max_len=128):
#         self.texts = texts
#         self.labels = labels
#         self.tokenizer = tokenizer
#         self.max_len = max_len
#         self.time_stamps = time_stamps
#
#     def __len__(self):
#         return len(self.texts)
#
#     def __getitem__(self, idx):
#         text = self.texts[idx]
#         label = torch.tensor(self.labels[idx], dtype=torch.long)
#         time_stamp = torch.tensor(self.time_stamps[idx], dtype=torch.int) - 1
#         encoding = self.tokenizer(
#             text,
#             truncation=True,
#             padding="max_length",
#             max_length=self.max_len,
#             return_tensors="pt"
#         )
#
#         return {
#             "input_ids": encoding["input_ids"].squeeze(0),
#             "attention_mask": encoding["attention_mask"].squeeze(0),
#             "labels": label,
#             "timestamps": time_stamp,
#         }

class EmpatheticDialoguesDataset(Dataset):
    """Dataset with one item per conversation.

    Each conversation is a dict with the keys ``texts``, ``labels``,
    ``timestamps`` and ``speakers``. Items keep the texts untouched and turn
    the three numeric sequences into ``torch.long`` tensors. The tokenizer and
    ``max_len`` are only stored here; ``__getitem__`` does not tokenize.
    """

    # Per-utterance numeric fields that are converted to long tensors.
    _TENSOR_FIELDS = ("labels", "timestamps", "speakers")

    def __init__(self, conversations, tokenizer, max_len=128):
        self.conversations = conversations
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Number of conversations."""
        return len(self.conversations)

    def __getitem__(self, idx):
        """Return the ``idx``-th conversation as a dict of texts and tensors."""
        record = self.conversations[idx]
        item = {"texts": record["texts"]}
        for field in self._TENSOR_FIELDS:
            item[field] = torch.tensor(record[field], dtype=torch.long)
        return item

In [6]:
def collate_conversations(batch, tokenizer, max_len=128):
    """Pad a batch of conversations to a common length and tokenize them.

    Every conversation is right-padded to the longest one in the batch:
    texts with empty strings, labels with -1, timestamps and speakers with 0.
    All utterances (real and padding) are tokenized together in one call,
    conversation after conversation.

    Args:
        batch: List of items with a ``texts`` list and 1-D long tensors
            ``labels``, ``timestamps`` and ``speakers``.
        tokenizer: Callable accepting ``(texts, padding=, truncation=,
            max_length=, return_tensors=)`` and returning a mapping with
            ``input_ids`` and ``attention_mask``.
        max_len: Token length passed to the tokenizer as ``max_length``.

    Returns:
        Dict with ``input_ids`` and ``attention_mask`` (one row per utterance
        slot), plus ``labels``, ``timestamps``, ``speakers`` and
        ``utterance_mask`` stacked to one row per conversation. The mask is a
        float tensor with 1 for real utterances and 0 for padding.
    """
    longest = max(len(item["texts"]) for item in batch)

    def _right_pad(values, count, fill):
        # Append `count` copies of `fill` to a 1-D long tensor.
        return torch.cat([values, torch.full((count,), fill, dtype=torch.long)])

    flat_texts = []
    label_rows, time_rows, speaker_rows, mask_rows = [], [], [], []

    for item in batch:
        n_real = len(item["texts"])
        n_pad = longest - n_real

        flat_texts += item["texts"] + [""] * n_pad
        mask_rows.append(torch.cat([torch.ones(n_real), torch.zeros(n_pad)]))
        label_rows.append(_right_pad(item["labels"], n_pad, -1))
        time_rows.append(_right_pad(item["timestamps"], n_pad, 0))
        speaker_rows.append(_right_pad(item["speakers"], n_pad, 0))

    encoding = tokenizer(
        flat_texts,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt"
    )

    collated = {
        "input_ids": encoding["input_ids"],
        "attention_mask": encoding["attention_mask"],
    }
    collated["labels"] = torch.stack(label_rows)
    collated["timestamps"] = torch.stack(time_rows)
    collated["speakers"] = torch.stack(speaker_rows)
    collated["utterance_mask"] = torch.stack(mask_rows)
    return collated

In [7]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


def _collate_batch(batch):
    """Collate a list of conversations with the notebook-level tokenizer."""
    return collate_conversations(batch, tokenizer)


dataset = EmpatheticDialoguesDataset(conversations, tokenizer)
val_dataset = EmpatheticDialoguesDataset(val_conversations, tokenizer)

# Settings shared by the train and validation loaders.
_loader_kwargs = dict(
    batch_size=32,
    num_workers=8,
    # pin_memory=True,
    # persistent_workers=True,
    collate_fn=_collate_batch,
)

loader = DataLoader(dataset, shuffle=True, **_loader_kwargs)

# future me, keep shuffle false here, helps reproduce results
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)


In [8]:
class TemporalTransformer(nn.Module):
    """Batch-first Transformer encoder stack applied over a sequence of vectors.

    Args:
        hidden_size: Model width (``d_model``) of every encoder layer.
        num_layers: Number of stacked encoder layers.
        num_heads: Attention heads per layer.
        dropout: Dropout probability used inside each layer.
    """

    def __init__(self, hidden_size, num_layers, num_heads, dropout):
        super().__init__()
        # Keep the attribute name `encoder`: saved state dicts key on it.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                hidden_size,
                num_heads,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers,
        )

    def forward(self, x, padding_mask):
        """Encode ``x``; ``padding_mask`` is passed as ``src_key_padding_mask``.

        The first two dimensions of ``padding_mask`` must equal those of ``x``.
        """
        mask_dims, input_dims = padding_mask.shape[:2], x.shape[:2]
        assert mask_dims == input_dims

        encoded = self.encoder(x, src_key_padding_mask=padding_mask)
        return encoded


In [9]:
class EmotionalTimeBert(nn.Module):
    """Per-utterance emotion classifier over whole conversations.

    Each utterance is encoded with a pretrained encoder (its first-token
    hidden state is used), summed with learned time-step and speaker
    embeddings, contextualised across the conversation by a
    ``TemporalTransformer``, and projected to emotion logits.

    Args:
        encoder_name: Name or path handed to ``AutoModel.from_pretrained``.
        num_labels: Number of emotion classes.
        max_time: Largest timestamp value; the table has ``max_time + 1`` rows.
        max_speakers: Largest shifted speaker id; the table has
            ``max_speakers + 1`` rows.
    """

    def __init__(self, encoder_name, num_labels, max_time = 8, max_speakers=2):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name)
        hidden = self.encoder.config.hidden_size

        self.head_emotions = nn.Linear(hidden, num_labels)
        self.time_embed = nn.Embedding(max_time + 1, hidden)
        self.speakers_embed = nn.Embedding(max_speakers + 1, hidden)
        # 2 layers, 8 heads, dropout 0.1 over the utterance sequence.
        self.temporal_transformer = TemporalTransformer(hidden, 2, 8, 0.1)

        # Freeze the whole encoder, then re-enable gradients for its last
        # four transformer blocks only.
        for p in self.encoder.parameters():
            p.requires_grad = False

        for layer in self.encoder.encoder.layer[-4:]:
            for p in layer.parameters():
                p.requires_grad = True

    def forward(self, input_ids, attention_mask, timestamps=None, speakers=None, labels=None, utterance_mask=None):
        """Compute emotion logits for every utterance slot.

        Args:
            input_ids: Token ids, one row per utterance slot (``B*T`` rows).
            attention_mask: Token attention mask, same leading size as
                ``input_ids``.
            timestamps: ``(B, T)`` long tensor of time-step indices (required).
            speakers: ``(B, T)`` long tensor of speaker ids; shifted by +1
                before the embedding lookup (required).
            labels: Unused. Accepted so existing call sites keep working; the
                padding mask no longer depends on it, so it may be ``None``
                at inference time.
            utterance_mask: ``(B, T)`` tensor, non-zero for real utterances.
                When ``None``, every slot is treated as a real utterance.

        Returns:
            Logits of shape ``(B, T, num_labels)``.
        """
        B, T = timestamps.shape

        # Validity comes from utterance_mask, not from the labels, so the
        # model can run without ground-truth labels.
        if utterance_mask is None:
            valid = torch.ones((B, T), dtype=torch.bool, device=timestamps.device)
        else:
            valid = utterance_mask.bool().reshape(B, T)
        flat_mask = valid.reshape(-1)

        # Only real utterances go through the encoder; padding slots are skipped.
        bert_output = self.encoder(
            input_ids=input_ids[flat_mask],
            attention_mask=attention_mask[flat_mask]
        )

        h = bert_output.last_hidden_state[:, 0, :]   # (num_valid, H)
        H = h.size(-1)

        # Scatter the encoded utterances back into a dense (B*T, H) buffer;
        # padding slots stay zero.
        h_all = torch.zeros(
            (B * T, H),
            device=h.device,
            dtype=h.dtype
        )
        h_all[flat_mask] = h

        h_t = h_all.view(B, T, H)
        time_vec = self.time_embed(timestamps)
        speakers_vec = self.speakers_embed(speakers + 1)
        Z = h_t + time_vec + speakers_vec

        # True marks padding slots that attention must ignore.
        padding_mask = ~valid
        U = self.temporal_transformer(Z, padding_mask)

        logits = self.head_emotions(U)
        return logits

In [10]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

if device == "cuda":
    print("Device:", torch.cuda.get_device_name(0))
    print("CUDA Enabled!")

model = EmotionalTimeBert("./medbert_4_epochs", num_labels=len(emotion_labels)).to(device)

if not TRAINING:
    # map_location lets a checkpoint saved on a GPU be restored on a
    # CPU-only machine (and vice versa).
    model.load_state_dict(
        torch.load("emotional_time_bert_5_to_show.pt", map_location=device)
    )

# Padded utterance slots carry the label -1 and are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-1)

# Fine-tune every encoder parameter the model left trainable. The group used
# to hard-code the last two blocks, which could disagree with what the model
# actually unfreezes and leave trainable blocks without optimizer updates.
encoder_trainable_params = [
    p for p in model.encoder.parameters() if p.requires_grad
]

# Small learning rate for the pretrained encoder, larger for the new modules.
optimizer = torch.optim.AdamW([
    {"params": encoder_trainable_params, "lr": 1e-5},
    {"params": model.temporal_transformer.parameters(), "lr": 3e-4},
    {"params": model.time_embed.parameters(), "lr": 3e-4},
    {"params": model.speakers_embed.parameters(), "lr": 3e-4},
    {"params": model.head_emotions.parameters(), "lr": 3e-4},
])

Some weights of BertModel were not initialized from the model checkpoint at ./medbert_4_epochs and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Device: NVIDIA GeForce RTX 5080
CUDA Enabled!


In [11]:
def train_model(bar):
    """Run one training epoch over ``bar`` and return the mean batch loss.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion`` and
    ``device``.

    Args:
        bar: Iterable of collated batches (dicts with ``input_ids``,
            ``attention_mask``, ``labels``, ``timestamps``, ``speakers`` and
            ``utterance_mask`` tensors) that also provides ``set_postfix``,
            e.g. a tqdm-wrapped DataLoader.

    Returns:
        The average of the per-batch losses as a float.

    Raises:
        ZeroDivisionError: If ``bar`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    # Iterate the bar that was passed in, not a notebook-level global.
    for batch in bar:
        optimizer.zero_grad()

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)
        timestamps = batch["timestamps"].to(device)
        speakers = batch["speakers"].to(device)
        utterance_mask = batch["utterance_mask"].to(device)

        logits = model(input_ids, attention_mask, timestamps, speakers, labels, utterance_mask)
        # Flatten (B, T, C) logits and (B, T) labels to per-utterance rows.
        loss = criterion(
            logits.view(-1, logits.size(-1)),
            labels.view(-1)
        )

        loss.backward()
        optimizer.step()

        # Read the scalar once; every .item() call forces a device sync.
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1

        bar.set_postfix(step = "Training", loss=batch_loss, average=total_loss / num_batches)

    avg_loss = total_loss / num_batches

    return avg_loss

def validate_model(bar):
    """Evaluate the notebook-level ``model`` on ``bar`` without gradients.

    Uses the notebook-level ``model``, ``criterion`` and ``device``.

    Args:
        bar: Iterable of collated batches that also provides ``set_postfix``.

    Returns:
        ``(avg_val_loss, macro_f1)``: the mean per-batch loss, and the macro
        F1 computed over every utterance whose label is not -1.
    """
    tensor_keys = (
        "input_ids",
        "attention_mask",
        "timestamps",
        "speakers",
        "labels",
        "utterance_mask",
    )

    model.eval()
    loss_sum = 0.0
    batches_seen = 0
    pred_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch in bar:
            on_device = {key: batch[key].to(device) for key in tensor_keys}
            labels = on_device["labels"]   # (B, T)

            logits = model(
                on_device["input_ids"],
                on_device["attention_mask"],
                on_device["timestamps"],
                on_device["speakers"],
                labels,
                on_device["utterance_mask"],
            )

            num_classes = logits.size(-1)
            loss = criterion(logits.view(-1, num_classes), labels.view(-1))

            loss_sum += loss.item()
            batches_seen += 1

            # Keep predictions and targets for real (non-padded) utterances only.
            real = labels != -1
            pred_chunks.append(logits.argmax(dim=-1)[real].cpu())
            label_chunks.append(labels[real].cpu())

            bar.set_postfix(step="Validating", loss=loss.item())

    avg_val_loss = loss_sum / batches_seen

    y_pred = torch.cat(pred_chunks)
    y_true = torch.cat(label_chunks)

    macro_f1 = f1_score(y_true.numpy(), y_pred.numpy(), average="macro")

    return avg_val_loss, macro_f1

# Training driver: runs only when the TRAINING flag is set.
if TRAINING:
    num_of_epochs = 5
    for epoch in range(num_of_epochs):
        # Wrap the training loader in a fresh tqdm bar for this epoch.
        progress_bar = tqdm(loader, total=len(loader))
        avg_loss = train_model(progress_bar)
        # Validate after every epoch with its own progress bar.
        validate_progress = tqdm(val_loader, total=len(val_loader))
        val_loss, val_f1 = validate_model(validate_progress)
        print(f"Epoch {epoch+1}: train loss = {avg_loss:.4f}")
        print(f"Epoch {epoch+1}: val loss = {val_loss:.4f}")
        print(f"Epoch {epoch+1}: val F1 = {val_f1:.4f}")

    # Save the weights once, after the final epoch; the file name records the epoch count.
    torch.save(model.state_dict(), f"emotional_time_bert_{num_of_epochs}.pt")



In [12]:
# Test split: same loading options and record layout as the other splits.
test_empath_data = pd.read_csv("empatheticdialogues/test.csv", on_bad_lines="skip")
test_grouped = test_empath_data.groupby("conv_id")

test_conversations = []
for conv_id, df_conv in test_grouped:
    # Alternative input that also includes the prompt:
    # texts = (df_conv["prompt"] + "[SEP]" + df_conv["utterance"]).tolist()

    # Dense rank re-indexes this conversation's speakers to 0, 1, ...
    speaker_rank = df_conv["speaker_idx"].rank(method="dense").astype(int)
    test_conversations.append(
        dict(
            texts=df_conv["utterance"].tolist(),
            labels=[emotion_to_id[name] for name in df_conv["context"]],
            timestamps=df_conv["utterance_idx"].tolist(),
            speakers=(speaker_rank - 1).tolist(),
        )
    )

In [13]:
def test_model(model, dataloader, device, emotion_labels=None):
    model.eval()

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Testing"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            timestamps = batch["timestamps"].to(device)
            speakers = batch["speakers"].to(device)
            labels = batch["labels"].to(device)
            utterance_mask = batch["utterance_mask"].to(device)

            logits = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                timestamps=timestamps,
                speakers=speakers,
                labels=labels,
                utterance_mask=utterance_mask
            )

            preds = logits.argmax(dim=-1)

            mask = labels != -1  # ignore padded utterances

            all_preds.append(preds[mask].cpu())
            all_labels.append(labels[mask].cpu())

    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    macro_f1 = f1_score(
        all_labels.numpy(),
        all_preds.numpy(),
        average="macro"
    )

    print(f"Test Macro F1: {macro_f1:.4f}")

    if emotion_labels is not None:
        print("\nPer-emotion results:")
        print(classification_report(
            all_labels.numpy(),
            all_preds.numpy(),
            target_names=emotion_labels,
            digits=3
        ))

    return macro_f1, all_preds, all_labels


In [14]:
test_dataset = EmpatheticDialoguesDataset(test_conversations, tokenizer)


def _collate_test_batch(batch):
    """Collate test conversations with the notebook-level tokenizer."""
    return collate_conversations(batch, tokenizer)


# Evaluation loader: fixed order, same batch size and worker count as the other loaders.
test_loader = DataLoader(
    test_dataset,
    collate_fn=_collate_test_batch,
    num_workers=8,
    shuffle=False,
    batch_size=32,
)

In [15]:
# test_model returns (macro_f1, all_preds, all_labels); the whole tuple is
# stored in test_f1, not just the score.
test_f1 = test_model(model, test_loader, device, emotion_labels=emotion_labels)

  output = torch._nested_tensor_from_mask(
Testing: 100%|██████████| 80/80 [00:06<00:00, 13.12it/s]

Test Macro F1: 0.4205

Per-emotion results:
              precision    recall  f1-score   support

 sentimental      0.403     0.415     0.409       205
      afraid      0.309     0.262     0.284       164
       proud      0.630     0.308     0.413       221
    faithful      0.733     0.282     0.407       117
   terrified      0.326     0.548     0.409       155
      joyful      0.311     0.171     0.221       187
       angry      0.252     0.149     0.188       181
         sad      0.356     0.477     0.408       195
     jealous      0.564     0.579     0.571       183
    grateful      0.583     0.380     0.460       221
    prepared      0.450     0.595     0.512       173
 embarrassed      0.584     0.642     0.612       179
     excited      0.389     0.624     0.479       202
     annoyed      0.385     0.616     0.474       198
      lonely      0.545     0.772     0.639       171
     ashamed      0.315     0.161     0.213       143
      guilty      0.385     0.638    


