# 🎓 Final Project: Emotion Classification with BERT on GoEmotions
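This notebook walks you through fine-tuning a pre-trained BERT model (`bert-base-uncased`) for multi-label emotion classification on the GoEmotions dataset: each text can carry several of the 28 emotion labels at once, so the model makes an independent yes/no decision for every label.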

In [None]:
# Install and upgrade required libraries
!pip install -U transformers datasets "scikit-learn<1.7"

In [None]:
# Import required libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments, EarlyStoppingCallback
from datasets import load_dataset
from sklearn.metrics import f1_score, accuracy_score

In [None]:
# Load the GoEmotions dataset from the Hugging Face Hub.
# Dataset card: https://huggingface.co/datasets/google-research-datasets/go_emotions
# (A Kaggle mirror also exists at https://www.kaggle.com/datasets/debarshichanda/goemotions/,
#  but that is not what `load_dataset` downloads.)
# The "simplified" configuration is the one whose `labels` column is a list of
# class ids with ClassLabel names, as used throughout this notebook.
dataset = load_dataset("google-research-datasets/go_emotions", "simplified")

In [None]:
# Tokenizer matching the "bert-base-uncased" checkpoint fine-tuned below; the
# model and the tokenizer must share the same vocabulary.
tokenizer = BertTokenizer.from_pretrained(
    pretrained_model_name_or_path="bert-base-uncased",
)

In [None]:
# Tokenize the dataset
def tokenize_function(example, max_length=30):
    """Encode a batch of GoEmotions texts into fixed-length BERT inputs.

    Args:
        example: A batch dict as passed by ``Dataset.map(..., batched=True)``;
            its ``"text"`` entry holds the raw strings to encode.
        max_length: Fixed sequence length, counting BERT's special tokens.
            Longer texts are truncated and shorter ones padded to this length.
            Keep it equal to the ``max_length`` used at inference time
            (``predict_emotions`` uses 30).

    Returns:
        The tokenizer's encoding (e.g. ``input_ids`` and ``attention_mask``),
        which ``map`` adds to the dataset as new columns.
    """
    return tokenizer(
        example["text"],
        padding="max_length",   # static padding, so the default collator can stack batches
        truncation=True,
        max_length=max_length,
    )

tokenized_datasets = dataset.map(tokenize_function, batched=True)

In [None]:
# Format multi-label outputs (force float32 tensors)
def format_labels(example, num_labels=28):
    """Replace an example's list of label ids with a multi-hot float32 vector.

    Args:
        example: A single (non-batched) example whose ``"labels"`` entry is an
            iterable of integer class ids.
        num_labels: Length of the multi-hot vector. 28 matches the size of the
            classification head used in this notebook.

    Returns:
        The same ``example`` dict, with ``"labels"`` now a float32 tensor of
        shape ``(num_labels,)``: 1.0 at every listed id and 0.0 elsewhere.
        BCE-with-logits needs float targets, not integer ids.

    Raises:
        IndexError: if a label id lies outside ``[0, num_labels)``.
    """
    label_vector = torch.zeros(num_labels, dtype=torch.float32)
    for label in example["labels"]:
        # Check explicitly: a negative id would otherwise wrap around silently
        # and mark the wrong emotion.
        if not 0 <= label < num_labels:
            raise IndexError(f"label id {label} out of range for {num_labels} labels")
        label_vector[label] = 1.0
    example["labels"] = label_vector
    return example

tokenized_datasets = tokenized_datasets.map(function=format_labels)

# 💡 Hand the model only its inputs plus the multi-hot targets, as PyTorch
# tensors (format_labels already built the targets as float32 vectors).
tokenized_datasets.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "labels"],
)

In [None]:
# Load BERT for multi-label classification.
# The head size comes from the dataset's ClassLabel names rather than a
# hard-coded 28. Storing id2label/label2id in the model config also makes saved
# checkpoints report emotion names instead of LABEL_0..LABEL_27.
label_names = dataset["train"].features["labels"].feature.names
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=len(label_names),
    id2label=dict(enumerate(label_names)),
    label2id={name: idx for idx, name in enumerate(label_names)},
    # One independent sigmoid decision per emotion, not a softmax over them.
    problem_type="multi_label_classification",
)

In [None]:
# Define evaluation metrics
def compute_metrics(pred, threshold=0.5):
    """Compute multi-label metrics from the Trainer's evaluation output.

    Args:
        pred: An ``EvalPrediction`` or any ``(logits, labels)`` pair.
            ``logits`` are raw model outputs. ``labels`` are the multi-hot
            targets, with the same shape as the logits.
        threshold: Sigmoid probability above which an emotion counts as predicted.

    Returns:
        A dict with:
          * ``"f1"``: micro-averaged F1 over every (example, emotion) decision.
            ``metric_for_best_model="eval_f1"`` selects on this one.
          * ``"f1_macro"``: macro-averaged F1. Every emotion is weighted
            equally, so rare emotions count as much as frequent ones.
          * ``"accuracy"``: subset accuracy. An example counts as correct only
            when its whole predicted label set matches exactly.
    """
    logits, labels = pred
    if isinstance(logits, tuple):
        # Some models return extra outputs alongside the logits; keep the logits.
        logits = logits[0]
    # Logits may arrive in half precision when evaluating with fp16, so upcast
    # to float32 before the sigmoid.
    logits = np.asarray(logits, dtype=np.float32)
    y_true = np.asarray(labels).astype(np.int64)
    probs = torch.sigmoid(torch.from_numpy(logits)).numpy()
    y_pred = (probs > threshold).astype(np.int64)
    return {
        "f1": f1_score(y_true, y_pred, average="micro", zero_division=0),
        # zero_division=0: an emotion that is never predicted scores 0 without a warning.
        "f1_macro": f1_score(y_true, y_pred, average="macro", zero_division=0),
        "accuracy": accuracy_score(y_true, y_pred),
    }

In [None]:
# Training arguments.
# Evaluation and checkpointing share one 100-step schedule. `load_best_model_at_end`
# can only restore a checkpoint that was actually saved, so every evaluation
# needs a matching checkpoint.
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="steps",                 # Evaluate every `eval_steps` optimizer steps
    eval_steps=100,
    save_strategy="steps",                 # Checkpoint on the same schedule as evaluation
    save_steps=100,                        # (the default would be 500, i.e. every 5th eval only)
    learning_rate=2e-5,                    # Common choice for BERT fine-tuning
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=2,         # Effective train batch = 32 x 2 = 64 per device
    num_train_epochs=3,
    weight_decay=0.01,                     # Regularization
    lr_scheduler_type="linear",            # Linear decay after warmup
    warmup_ratio=0.1,                      # First 10% of steps warm the LR up
    logging_dir="./logs",
    logging_strategy="steps",
    logging_steps=100,                     # Training loss logged every 100 steps
    load_best_model_at_end=True,           # Restore the best checkpoint after training
    metric_for_best_model="eval_f1",       # The "f1" key from compute_metrics (micro F1)
    greater_is_better=True,                # Higher F1 is better
    save_total_limit=2,                    # Cap on checkpoints kept on disk (best one is retained)
    fp16=torch.cuda.is_available(),        # Mixed precision only when a CUDA GPU is present
    report_to=["tensorboard"],             # For visualization
    seed=42,                               # Reproducibility
)

In [None]:
# Trainer subclass that computes a multi-label BCE-with-logits loss itself.
class MultilabelTrainer(Trainer):
    """``Trainer`` whose loss treats every emotion logit as an independent binary decision."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the mean BCE-with-logits loss, plus the model outputs if requested.

        The ``"labels"`` entry is removed from ``inputs`` before the forward
        pass, so the model only receives its input tensors and returns logits.
        Extra keyword arguments passed by ``Trainer`` are accepted and ignored.
        """
        targets = inputs.pop("labels").float()
        model_output = model(**inputs)
        # Mean over every (example, emotion) cell, the same as BCEWithLogitsLoss().
        bce = torch.nn.functional.binary_cross_entropy_with_logits(
            model_output.logits, targets
        )
        if return_outputs:
            return bce, model_output
        return bce

In [None]:
# Initialize trainer and start training.
# Early stopping reuses `metric_for_best_model` / `load_best_model_at_end` from
# `training_args`. It stops training once eval F1 has failed to improve for 3
# consecutive evaluations (300 steps with eval_steps=100).
trainer = MultilabelTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

trainer.train()

In [None]:
# Evaluate the model.
# The validation split also chose the best checkpoint (load_best_model_at_end),
# so its score is optimistic. The untouched test split gives the number to report.
eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")

test_results = trainer.evaluate(
    eval_dataset=tokenized_datasets["test"],
    metric_key_prefix="test",   # keys become test_loss, test_f1, ...
)
print(f"Test results: {test_results}")

In [None]:
# Predict emotions for custom input.
# Emotion names indexed by label id. `dataset` (loaded at the top of the
# notebook, never modified in place) still carries the original ClassLabel
# feature, so there is no need to import and download GoEmotions a second time.
emotion_labels = dataset["train"].features["labels"].feature.names

def predict_emotions(text, threshold=0.5):
    """Return the emotion names the fine-tuned model assigns to ``text``.

    Args:
        text: A single input string.
        threshold: Minimum sigmoid probability for an emotion to be reported.
            The default 0.5 is the same cut-off ``compute_metrics`` uses.

    Returns:
        The names from ``emotion_labels`` whose probability exceeds
        ``threshold``, in label-id order. The list is empty if none do.
    """
    model.eval()  # inference mode (disables dropout)
    # Tokenize exactly as during training: same padding, truncation and max_length.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=30)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}  # same device as the model
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.sigmoid(logits)[0].cpu().numpy()  # single example -> row 0, on CPU for numpy
    return [emotion_labels[i] for i, p in enumerate(probs) if p > threshold]

# Example usage.
# Fixed examples keep the notebook runnable end to end ("Restart & Run All" or
# headless execution); `input()` would block waiting for the keyboard. Edit the
# list to try your own comments.
example_comments = [
    "Thank you so much, this really made my day!",
    "I can't believe they cancelled the show again.",
]
for user_input in example_comments:
    predicted_emotions = predict_emotions(user_input)
    print(f"{user_input!r} -> Predicted emotions: {predicted_emotions}")

In [None]:
# Plot the training and evaluation loss curves recorded by the Trainer.
# In `trainer.state.log_history`, training-loop log entries carry "loss" and
# evaluation entries carry "eval_loss". Only entries that also carry "epoch"
# (the x-axis) are plotted.
log_history = trainer.state.log_history

epochs_train = [rec["epoch"] for rec in log_history if "loss" in rec and "epoch" in rec]
train_loss = [rec["loss"] for rec in log_history if "loss" in rec and "epoch" in rec]
epochs_eval = [rec["epoch"] for rec in log_history if "eval_loss" in rec and "epoch" in rec]
eval_loss = [rec["eval_loss"] for rec in log_history if "eval_loss" in rec and "epoch" in rec]

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(epochs_train, train_loss, label="Train Loss", marker="o")
ax.plot(epochs_eval, eval_loss, label="Eval Loss", marker="x")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training & Evaluation Loss")
ax.legend()
ax.grid(True)
plt.show()
