In [1]:
import os

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from transformers import get_linear_schedule_with_warmup
from torch.optim import AdamW

from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
from transformers.modeling_outputs import SequenceClassifierOutput

from datasets import load_dataset
from sklearn import metrics

In [2]:
go_emotions = load_dataset("go_emotions", "simplified")
labels = go_emotions["test"].features["labels"].feature.names
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [3]:
class GoEmotionDataset(Dataset):
    def __init__(self, data):
        self.data_frame = data.to_pandas()

    def __len__(self):
        return len(self.data_frame)

    def __getitem__(self, idx):
        text = self.data_frame.iloc[idx, 0]
        label = self.data_frame.iloc[idx, 1]
        hot_label = np.eye(28)[label].sum(axis=0)  # todo calculate max_length
        encoded_input = tokenizer(text, padding='max_length', truncation=True, max_length=64, return_tensors='pt')
        return {"input_ids": encoded_input['input_ids'].squeeze(0), "attention_mask": encoded_input['attention_mask'].squeeze(0), "labels": hot_label}

class BertForMultilabelSequenceClassification(BertForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_state = outputs[0]
        pooled_output = hidden_state[:, 0]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = torch.nn.BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, self.num_labels),
                            labels.float().view(-1, self.num_labels))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


In [4]:
train_dataset = GoEmotionDataset(go_emotions['train'])
test_dataset = GoEmotionDataset(go_emotions['test'])
eval_dataset = GoEmotionDataset(go_emotions['validation'])

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2048, shuffle=True)
eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=2048, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2048, shuffle=False)
# %%
train_losses = []
eval_losses = []
train_accuracies = []
eval_accuracies = []

In [5]:
num_epochs = 1
model = BertForMultilabelSequenceClassification.from_pretrained("bert-base-cased", num_labels=28)
# for param in model.bert.parameters():
#     param.requires_grad = False

# model_path = "best_model.pth"
# if os.path.exists(model_path):
#     model.load_state_dict(torch.load(model_path))
#     print("Loaded the pre-trained model.")

Some weights of BertForMultilabelSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
batch_size = 64
# configure logging so we see training loss
logging_steps = len(train_dataset) // batch_size

args = TrainingArguments(
    output_dir="emotion",
    evaluation_strategy="epoch",
    save_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_steps=logging_steps,
)

trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)
# trainer.train(resume_from_checkpoint=True)
trainer.train()

Epoch,Training Loss,Validation Loss
1,0.1373,0.097027
2,0.0919,0.088449
3,0.0831,0.086264


TrainOutput(global_step=2037, training_loss=0.10406537126488234, metrics={'train_runtime': 1389.9808, 'train_samples_per_second': 93.692, 'train_steps_per_second': 1.465, 'total_flos': 4284118958791680.0, 'train_loss': 0.10406537126488234, 'epoch': 3.0})

In [12]:
# model.load_state_dict(torch.load('emotion/checkpoint-/rng_state.pth'))
device = "cuda"
# model = BertForMultilabelSequenceClassification.from_pretrained("emotion/checkpoint-66", num_labels=28)
# model.to(device)
def calc_label_metrics(label, y_targets, y_preds, threshold):
    return {
        "label": label,
        "accuracy": metrics.accuracy_score(y_targets, y_preds),
        "precision": metrics.precision_score(y_targets, y_preds, zero_division=0),
        "recall": metrics.recall_score(y_targets, y_preds, zero_division=0),
        "f1": metrics.f1_score(y_targets, y_preds, zero_division=0),
        "mcc": metrics.matthews_corrcoef(y_targets, y_preds),
        "support": y_targets.sum(),
        "threshold": threshold,
    }


threshold = 0.5
y_probas_all = []
y_targets_all = []
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        target = batch['labels']
        input_ids, attention_mask, target = input_ids.to(device), attention_mask.to(device), target.to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=target)
        y_probas_all.extend(outputs.logits.sigmoid().cpu().numpy())
        y_targets_all.extend(target.cpu().numpy())

y_preds_all = (np.array(y_probas_all) > threshold).astype(int)
y_targets_all = np.array(y_targets_all)

sum_precision = 0
sum_recall = 0
sum_f1 = 0
sum_mcc = 0

results = []
for label_index, label in enumerate(labels):
    y_targets, y_preds = y_targets_all[:, label_index], y_preds_all[:, label_index]
    label_metrics = calc_label_metrics(label, y_targets, y_preds, threshold)
    results.append(label_metrics)

    # Sum up metrics for macro-average
    sum_precision += label_metrics["precision"]
    sum_recall += label_metrics["recall"]
    sum_f1 += label_metrics["f1"]
    sum_mcc += label_metrics["mcc"]

# Calculate macro-average metrics
num_labels = len(labels)
macro_avg_precision = sum_precision / num_labels
macro_avg_recall = sum_recall / num_labels
macro_avg_f1 = sum_f1 / num_labels
macro_avg_mcc = sum_mcc / num_labels

# Append macro-average metrics to results
macro_avg_results = {
    "label": "macro_avg",
    "accuracy": None,  # Macro-average accuracy is not typically used
    "precision": macro_avg_precision,
    "recall": macro_avg_recall,
    "f1": macro_avg_f1,
    "mcc": macro_avg_mcc,
    "support": None,  # Support doesn't make sense for macro-average
    "threshold": threshold
}
results.append(macro_avg_results)

per_label_results = pd.DataFrame(results, index=[label["label"] for label in results])
display(per_label_results.drop(columns=["label"]).round(3))


Unnamed: 0,accuracy,precision,recall,f1,mcc,support,threshold
admiration,0.943,0.699,0.681,0.689,0.658,504.0,0.5
amusement,0.982,0.799,0.83,0.814,0.805,264.0,0.5
anger,0.969,0.623,0.384,0.475,0.474,198.0,0.5
annoyance,0.942,0.549,0.088,0.151,0.203,320.0,0.5
approval,0.943,0.65,0.259,0.371,0.387,351.0,0.5
caring,0.975,0.521,0.185,0.273,0.301,135.0,0.5
confusion,0.975,0.614,0.281,0.386,0.405,153.0,0.5
curiosity,0.951,0.547,0.412,0.47,0.45,284.0,0.5
desire,0.987,0.719,0.277,0.4,0.441,83.0,0.5
disappointment,0.973,0.667,0.053,0.098,0.183,151.0,0.5


In [14]:
import random

threshold = 0.1

# Randomly select a sample from the test dataset
sample_idx = random.randint(0, len(test_loader.dataset) - 1)

batch = test_loader.dataset[sample_idx]
sample_input_ids = batch['input_ids']
sample_attention_mask = batch['attention_mask']
sample_target = batch['labels']
# Move the sample to the same device as your model
sample_input_ids, sample_attention_mask = sample_input_ids.to(device), sample_attention_mask.to(device)

# Get the model's prediction for this sample
with torch.no_grad():
    sample_output = model(sample_input_ids.unsqueeze(0), attention_mask=sample_attention_mask.unsqueeze(0))
    sample_prediction = sample_output.logits.sigmoid().squeeze().cpu().numpy() > threshold
# Convert sample input ids to text if your dataset is text-based
# This conversion depends on the tokenizer you used for your model
sample_text = tokenizer.decode(sample_input_ids, skip_special_tokens=True)

# Print the sample text, true labels, and predicted labels
print("Sample Text:\n", sample_text)
print("\nTrue Labels:", sample_target)
print("Predicted Labels:", sample_prediction.astype(int))

Sample Text:
 Been saying horford has been overrated since he got that fat contract

True Labels: [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0.]
Predicted Labels: [0 0 0 1 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1]
