In [None]:
import torch
import torch.nn as nn
import numpy as np
from transformers import BartTokenizer, BartModel
from datasets import load_dataset
from sklearn.metrics import classification_report


# Load the GoEmotions dataset from the Hugging Face Hub.
dataset = load_dataset("go_emotions")

print(dataset)

# Emotion names indexed by GoEmotions label id. Later cells turn ground-truth id j and
# prediction column j into candidate_labels[j], so this order must match the dataset.
candidate_labels = [
    "admiration", "amusement", "anger", "annoyance", "approval", "caring",
    "confusion", "curiosity", "desire", "disappointment", "disapproval",
    "disgust", "embarrassment", "excitement", "fear", "gratitude", "grief",
    "joy", "love", "nervousness", "optimism", "pride", "realization", "relief",
    "remorse", "sadness", "surprise", "neutral"
]

# Guard the positional id -> name mapping against a reordered or changed label set.
assert candidate_labels == dataset["test"].features["labels"].feature.names, (
    "candidate_labels does not match the dataset's label order"
)

In [None]:
# Load the pretrained BART encoder-decoder backbone and its tokenizer.
# Note: this is the *large* checkpoint; "base_model" means "backbone", not bart-base.
model_name = "facebook/bart-large"
tokenizer = BartTokenizer.from_pretrained(model_name)
base_model = BartModel.from_pretrained(model_name)

# Define BART with a classification head
class BartWithClassificationHead(nn.Module):
    """BART backbone plus a linear + sigmoid head for multi-label classification.

    ``forward`` returns per-label sigmoid probabilities in [0, 1] (not raw logits),
    so callers threshold them directly.

    NOTE(review): ``classification_head`` is freshly initialized by ``nn.Linear`` and
    nothing in this class trains it; its outputs are not meaningful until fine-tuned.
    """

    def __init__(self, base_model, num_labels):
        super().__init__()
        self.bart = base_model
        # One independent sigmoid score per label (multi-label, not a softmax).
        self.classification_head = nn.Sequential(
            nn.Linear(self.bart.config.d_model, num_labels),
            nn.Sigmoid()
        )

    def forward(self, input_ids, attention_mask):
        """Return per-label probabilities of shape (batch, num_labels).

        The sequence is pooled at its last EOS token (the pooling used by Hugging
        Face's ``BartForSequenceClassification``) rather than at position 0, which for
        ``BartModel`` holds the decoder-start token. Rows that contain no EOS token fall
        back to their last attended position.
        """
        outputs = self.bart(input_ids=input_ids, attention_mask=attention_mask)
        # For BartModel this is the decoder's final layer: (batch, seq_len, d_model).
        hidden_states = outputs.last_hidden_state

        positions = torch.arange(input_ids.size(1), device=input_ids.device).expand_as(input_ids)
        is_eos = input_ids.eq(self.bart.config.eos_token_id)
        # Largest position holding EOS per row, or -1 when the row has none.
        last_eos = torch.where(is_eos, positions, torch.full_like(positions, -1)).max(dim=1).values
        last_attended = attention_mask.long().sum(dim=1) - 1
        pool_index = torch.where(last_eos >= 0, last_eos, last_attended)

        batch_index = torch.arange(hidden_states.size(0), device=hidden_states.device)
        pooled = hidden_states[batch_index, pool_index]
        return self.classification_head(pooled)

# Initialize the model. The classification head's weights are drawn randomly, so seed
# torch first to make the head (and therefore every prediction) reproducible.
SEED = 42
num_labels = len(candidate_labels)  # 28 GoEmotions labels
torch.manual_seed(SEED)
model = BartWithClassificationHead(base_model, num_labels)

In [None]:
# Tokenize the test dataset
def preprocess_function(examples):
    return tokenizer(
        examples["text"], 
        truncation=True, 
        padding="max_length", 
        max_length=128,
        return_tensors="pt"
    )

tokenized_test = dataset["test"].map(preprocess_function, batched=True)

In [None]:
# Tokenize the raw test texts once, padded to the longest example (capped at 128 tokens).
test_texts = tokenized_test["text"]
test_inputs = tokenizer(test_texts, truncation=True, padding=True, max_length=128, return_tensors="pt")

# Run inference in mini-batches: pushing the whole split through bart-large in a single
# forward pass materializes activations for every example at once and can exhaust memory.
inference_batch_size = 32
device = next(model.parameters()).device

model.eval()
prob_chunks = []
with torch.no_grad():
    for start in range(0, test_inputs["input_ids"].size(0), inference_batch_size):
        stop = start + inference_batch_size
        batch_probs = model(
            test_inputs["input_ids"][start:stop].to(device),
            test_inputs["attention_mask"][start:stop].to(device),
        )
        prob_chunks.append(batch_probs.cpu())  # keep results on CPU for .numpy() later
probs = torch.cat(prob_chunks) if prob_chunks else torch.empty(0, len(candidate_labels))

# Multi-label decision rule: predict every label whose probability exceeds the threshold.
threshold = 0.5
predictions = (probs > threshold).int()

# Human-readable label names for each example's positive columns.
predicted_labels = [
    [candidate_labels[j] for j in np.flatnonzero(row)]
    for row in predictions.numpy()
]

In [None]:
# Convert ground truth to binary format
def binarize_labels(example, n_labels=None):
    """Convert one example's label ids into a multi-hot vector.

    Args:
        example: mapping whose ``"labels"`` entry is an iterable of integer label ids.
        n_labels: length of the output vector. Defaults to the notebook-level
            ``num_labels`` so existing ``binarize_labels(example)`` calls are unchanged.

    Returns:
        1-D integer ``np.ndarray`` with 1 at every listed label id and 0 elsewhere.

    Raises:
        IndexError: if a label id is outside ``[-n_labels, n_labels)``.
    """
    width = num_labels if n_labels is None else n_labels
    binary_labels = np.zeros(width, dtype=int)
    for label in example["labels"]:
        binary_labels[label] = 1
    return binary_labels

binary_ground_truth = np.array(list(map(binarize_labels, dataset["test"])))  # one multi-hot row per test example

In [None]:
# Generate classification report
# Per-label precision / recall / F1 for the multi-hot predictions; labels with no
# predicted or true samples report 0 rather than an undefined value.
report = classification_report(
    y_true=binary_ground_truth,
    y_pred=predictions.numpy(),
    target_names=candidate_labels,
    zero_division=0,
)
print(report)