In [1]:
from transformers import pipeline
from datasets import load_dataset
from sklearn.metrics import classification_report
import numpy as np
import torch

# Load the GoEmotions dataset from the Hugging Face Hub (default config).
dataset = load_dataset("go_emotions")

# GoEmotions emotion labels. Later cells treat the integer ids in
# sample["labels"] as positions in this list, so its order must match the
# dataset's own ClassLabel ordering exactly.
candidate_labels = [
    "admiration", "amusement", "anger", "annoyance", "approval", "caring",
    "confusion", "curiosity", "desire", "disappointment", "disapproval",
    "disgust", "embarrassment", "excitement", "fear", "gratitude", "grief",
    "joy", "love", "nervousness", "optimism", "pride", "realization", "relief",
    "remorse", "sadness", "surprise", "neutral"
]
# Fail fast if the hand-written list drifts from the dataset's label ids;
# otherwise the ground-truth vectors would silently be misaligned.
assert candidate_labels == dataset["test"].features["labels"].feature.names, (
    "candidate_labels is out of sync with the dataset's label ids"
)

# Prefer CUDA, then MPS, then CPU. Every branch yields a device *name* string
# instead of mixing names with the integer -1 CPU sentinel.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=device,
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def classify_sample(examples, threshold=0.5, batch_size=10):
    """Zero-shot classify a batch of texts and keep the labels scoring above a threshold.

    Intended for ``Dataset.map(classify_sample, batched=True)``, where
    ``examples`` is a batch dict whose ``"text"`` entry is a list of strings.
    Uses the module-level ``classifier`` pipeline and ``candidate_labels``.

    Args:
        examples: Batch dict with a ``"text"`` key (list of strings).
        threshold: A label is kept when its score is strictly greater than
            this value (default 0.5, the value previously hard-coded).
        batch_size: Batch size forwarded to the pipeline call.

    Returns:
        ``{"predicted_labels": [...]}`` holding one list of label names per input text.
    """
    results = classifier(
        examples["text"],
        candidate_labels=candidate_labels,
        multi_label=True,
        batch_size=batch_size,
    )
    # Some transformers versions may unwrap a single-element input list into a
    # bare dict; normalise so there is always one result per input text.
    if isinstance(results, dict):
        results = [results]
    predicted_labels = [
        [label for label, score in zip(result["labels"], result["scores"]) if score > threshold]
        for result in results
    ]
    return {"predicted_labels": predicted_labels}

# Zero-shot classify the first 100 test examples, feeding the model 10 texts per map batch.
classified_samples = (
    dataset["test"]
    .select(range(100))
    .map(classify_sample, batched=True, batch_size=10)
)

In [3]:
def binarize_labels(sample):
    """Turn one sample's gold label ids and predicted label names into multi-hot vectors.

    Position ``k`` in both output vectors corresponds to ``candidate_labels[k]``.
    Gold labels are matched by integer id (``k in sample["labels"]``);
    predictions are matched by label name.

    Returns:
        dict with ``"binary_ground_truth"`` and ``"binary_predictions"``, each a
        list of 0/1 ints of length ``len(candidate_labels)``.
    """
    gold_ids = sample["labels"]
    predicted_names = sample["predicted_labels"]
    binary_ground_truth = []
    binary_predictions = []
    for idx, name in enumerate(candidate_labels):
        binary_ground_truth.append(int(idx in gold_ids))
        binary_predictions.append(int(name in predicted_names))
    return {"binary_ground_truth": binary_ground_truth, "binary_predictions": binary_predictions}

# Apply binarization to classified samples
binary_data = classified_samples.map(binarize_labels)

# Stack the per-sample multi-hot vectors into indicator matrices of shape
# (n_samples, n_labels), as expected by sklearn for multilabel targets.
ground_truth = np.array([sample["binary_ground_truth"] for sample in binary_data])
predictions = np.array([sample["binary_predictions"] for sample in binary_data])
assert ground_truth.shape == predictions.shape == (len(binary_data), len(candidate_labels))

# Per-label precision/recall/F1 over *every* classified sample (no re-hard-coded
# slice, so changing the subset size upstream is reflected here). zero_division=0
# reports 0.0 for labels with no predicted/true positives instead of warning.
report = classification_report(ground_truth, predictions, target_names=candidate_labels, zero_division=0)

print(report)

                precision    recall  f1-score   support

    admiration       0.23      0.75      0.35         8
     amusement       0.43      0.86      0.57         7
         anger       0.12      1.00      0.22         1
     annoyance       0.10      0.67      0.17         6
      approval       0.08      0.67      0.15         3
        caring       0.05      0.67      0.09         3
     confusion       0.04      0.67      0.07         3
     curiosity       0.10      0.50      0.17         4
        desire       0.06      0.50      0.10         2
disappointment       0.00      0.00      0.00         1
   disapproval       0.14      0.88      0.24         8
       disgust       0.00      0.00      0.00         0
 embarrassment       0.00      0.00      0.00         0
    excitement       0.12      1.00      0.22         2
          fear       0.50      0.83      0.62         6
     gratitude       0.67      0.89      0.76         9
         grief       0.00      0.00      0.00  