In [None]:
from google.colab import drive
# Mount Google Drive at /content/gdrive (Colab only).
# NOTE(review): BASE_DIR below presumably points inside this mount — confirm.
drive.mount('/content/gdrive')

In [2]:
%%capture
!pip install fsspec==2023.6.0
!pip install git+https://github.com/huggingface/transformers@4.54.0.dev0
!python -m pip install matplotlib==3.10.0

In [3]:
from transformers import AutoTokenizer
from sklearn.preprocessing import MultiLabelBinarizer
from datasets import Dataset, DatasetDict
from matplotlib.lines import Line2D
import pickle, os
import numpy as np
from collections import defaultdict
from transformers import TrainingArguments, Trainer, AutoModelForSequenceClassification
from sklearn.metrics import f1_score, precision_score, recall_score

In [4]:
from transformers import set_seed
# Fix the random seed through transformers' helper so repeated runs are comparable.
set_seed(0)

In [5]:
# Root folder of the project. Placeholder value: replace it before running.
# The notebook reads data/dataset.jsonl, data/k_fold_ds/ and output/ below it.
BASE_DIR = "<Project Folder>"

In [None]:
# Base model background: https://huggingface.co/blog/modernbert

# Hub id of the base model; only the tokenizer is loaded from it here.
# The model weights come from the per-fold checkpoints loaded further down.
model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [7]:
# Full dataset as JSON lines; used below only to collect the label vocabulary
# from its "str_label" field.
hf_dataset = Dataset.from_json(os.path.join(BASE_DIR, "data/dataset.jsonl"))

In [10]:
# Label vocabulary: every distinct label found in any row's "str_label" list.
# The column is read once instead of indexing the dataset row by row, and the
# result is sorted so the list is identical on every run (iteration order of a
# set of strings is not stable across Python sessions).
labels = sorted({
    label
    for row_labels in hf_dataset["str_label"]
    for label in row_labels
})

In [None]:
# Fit the binarizer on a single "sample" that contains every label, so that
# mlb.classes_ (used for the id <-> label maps below) covers the whole vocabulary.
mlb = MultiLabelBinarizer()
mlb.fit([labels])

In [12]:
# Index <-> label lookups, both following the order of mlb.classes_.
id2label = dict(enumerate(mlb.classes_))
label2id = {name: position for position, name in id2label.items()}

In [8]:
# Load the five saved cross-validation folds from data/k_fold_ds/<n>-fold.
kfold_datasets = [
    DatasetDict.load_from_disk(os.path.join(BASE_DIR, "data/k_fold_ds", f"{fold_number}-fold"))
    for fold_number in range(5)
]

# ITS Dataset Evaluation

## Operating Point Selection Algorithm for Multi-Label Binary Classifiers

In [13]:
from sklearn.metrics import roc_curve, auc
def sigmoid(x):
    """
    Numerically stable logistic sigmoid, 1 / (1 + exp(-x)).

    Args:
        x: scalar, sequence, or NumPy array of any shape (e.g. raw logits).

    Returns:
        Element-wise sigmoid with the same shape as ``x``. Floating-point
        arrays keep their dtype; non-float input is computed in float64.
        A scalar input yields a NumPy scalar.

    The naive form overflows in ``exp`` (emitting a RuntimeWarning) for large
    negative inputs. Here ``exp`` is only ever applied to non-positive values,
    so it can underflow to 0 but never overflow.
    """
    x = np.asarray(x)
    if not np.issubdtype(x.dtype, np.floating):
        x = x.astype(np.float64)

    decay = np.exp(-np.abs(x))  # always in (0, 1]
    # x >= 0: 1 / (1 + e^-x);  x < 0: e^x / (1 + e^x) -- algebraically identical.
    result = np.where(x >= 0, 1 / (1 + decay), decay / (1 + decay))
    # Indexing with () turns a 0-d result back into a scalar and is a no-op otherwise.
    return result[()]

In [14]:
# Inference-only configuration shared by every Trainer built in this notebook.
args = TrainingArguments(
    output_dir=None, # No output directory is set: the notebook only calls predict() and keeps the returned outputs in memory. NOTE(review): confirm the pinned transformers version accepts None here.
    log_level="error",
    per_device_eval_batch_size=8, # Batch size per device used by predict()
    disable_tqdm=False,
    report_to="none",
)

In [None]:
# Locate the per-fold checkpoint directories under <BASE_DIR>/output/kfold_checkpoints*/.
output_folder = os.path.join(BASE_DIR, "output")

# os.listdir returns entries in arbitrary order, so both listings are sorted:
# the chosen run folder becomes deterministic, and checkpoint position k can be
# paired with kfold_datasets[k] in the prediction cells.
# NOTE(review): this assumes the fold directory names sort in fold order
# (e.g. a 0..4 prefix) — confirm against the training notebook.
kfold_checkpoint_folder = sorted(
    folder for folder in os.listdir(output_folder) if folder.startswith("kfold_checkpoints")
)[0]
kfold_checkpoint_full_path = os.path.join(output_folder, kfold_checkpoint_folder)
five_model_checkpoints = sorted(
    os.path.join(kfold_checkpoint_full_path, fold)
    for fold in os.listdir(kfold_checkpoint_full_path)
    if os.path.isdir(os.path.join(kfold_checkpoint_full_path, fold))
)
five_model_checkpoints

In [None]:
# Run each fold's fine-tuned checkpoint on that fold's validation split.
# The tokenizer loaded next to `model_id` is reused; it does not need reloading.
valid_set_outputs = []
for idx, fold_dataset in enumerate(kfold_datasets):
  loaded_model = AutoModelForSequenceClassification.from_pretrained(
      five_model_checkpoints[idx],
      problem_type="multi_label_classification",
      num_labels=len(labels),
      id2label=id2label,
      label2id=label2id,
  )
  # `processing_class` is the current name for Trainer's deprecated `tokenizer` argument.
  predictor = Trainer(model=loaded_model, args=args, processing_class=tokenizer)

  valid_split = fold_dataset["valid"]

  # Each element of `valid_set_outputs` contains:
  #   - predictions: numpy array of logits
  #   - label_ids: numpy array of original labels (if available in the dataset)
  #   - metrics: dictionary of evaluation metrics (if compute_metrics was defined and run)
  valid_set_outputs.append(predictor.predict(valid_split))

In [17]:
# Per fold: sigmoid-transformed validation logits and the matching ground-truth labels.
valid_set_probs = []
valid_true_labels = []
for fold_output in valid_set_outputs:
  valid_set_probs.append(sigmoid(fold_output.predictions))
  valid_true_labels.append(fold_output.label_ids)

In [18]:
def get_binary_predictions(probabilities, threshold):
    """Binarize scores: 1 where the score is >= threshold, 0 otherwise.

    Args:
        probabilities: array-like of scores.
        threshold: cut-off compared against every score (inclusive).

    Returns:
        Integer NumPy array with the same shape as ``probabilities``.
    """
    scores = np.array(probabilities)
    return np.greater_equal(scores, threshold).astype(int)

**Operating Point Selection Algorithm for Multi-Label Binary Classifiers**

We propose an operating-point selection algorithm tailored for binary classifiers within a multi-label classification framework, employing a **5-fold cross-validation strategy**.

For each fold and each label, the algorithm performs the following steps:

1. **ROC Curve Derivation:**

  * The Receiver Operating Characteristic (ROC) curve is derived from the validation set probability scores and corresponding ground-truth labels.

  * This yields False Positive Rates (FPRs), True Positive Rates (TPRs), and associated thresholds.

2. **F1 Score Computation:**

 * Subsequently, the F1 score is computed for each threshold by binarizing the prediction scores at that threshold (a score greater than or equal to the threshold counts as positive) and evaluating the result against the true labels.

3. **Optimal Operating Point Identification:**

 * The operating point achieving the **maximum F1 score** is identified.

 * **Tie-breaking Mechanism:**

    * In instances of a unique maximum, it is selected outright.

    * For ties—wherein multiple thresholds yield equivalent maximum F1 scores—the algorithm resolves ambiguity by selecting the operating point corresponding to the **highest TPR** among candidates. This emphasizes sensitivity while preserving an optimal balance between precision and recall.

This methodology yields fold- and label-specific thresholds optimized for validation performance to facilitate subsequent model evaluation processes.

In [None]:
# For every fold and every label, pick the decision threshold that maximises
# F1 on that fold's validation split. Result: one list of per-label thresholds
# per fold, in fold order.
operating_points_for_all_folds = []

for fold_probs, fold_true in zip(valid_set_probs, valid_true_labels):
  operating_points = []
  for j in range(len(labels)):
    scores = fold_probs[:, j]
    y_true = fold_true[:, j]

    # Candidate thresholds are the ones roc_curve reports for this label;
    # tpr[k] is the true-positive rate obtained at thresholds[k].
    fpr, tpr, thresholds = roc_curve(y_true, scores)
    thresholds = thresholds.tolist()

    # zero_division=0 scores a threshold that predicts no positives as F1 = 0
    # (the same value sklearn's default falls back to) without a warning.
    f1_scores = [
        f1_score(y_true, get_binary_predictions(scores, threshold), zero_division=0)
        for threshold in thresholds
    ]

    # Among the thresholds tied for the best F1, take the one with the highest
    # TPR; max() keeps the first candidate if the TPRs tie as well.
    best_f1 = max(f1_scores)
    best_idx = max(
        (k for k, value in enumerate(f1_scores) if value == best_f1),
        key=lambda k: tpr[k],
    )
    operating_points.append(thresholds[best_idx])

  operating_points_for_all_folds.append(operating_points)

## Calculate Precision, Recall and F1 score

In [None]:
# Run each fold's fine-tuned checkpoint on that fold's held-out test split.
# The tokenizer loaded next to `model_id` is reused; it does not need reloading.
test_set_outputs = []
for idx, fold_dataset in enumerate(kfold_datasets):
  loaded_model = AutoModelForSequenceClassification.from_pretrained(
      five_model_checkpoints[idx],
      problem_type="multi_label_classification",
      num_labels=len(labels),
      id2label=id2label,
      label2id=label2id,
  )
  # `processing_class` is the current name for Trainer's deprecated `tokenizer` argument.
  predictor = Trainer(model=loaded_model, args=args, processing_class=tokenizer)

  test_split = fold_dataset["test"]

  # Each element of `test_set_outputs` contains:
  #   - predictions: numpy array of logits
  #   - label_ids: numpy array of original labels (if available in the dataset)
  #   - metrics: dictionary of evaluation metrics (if compute_metrics was defined and run)
  test_set_outputs.append(predictor.predict(test_split))

In [21]:
# Micro-averaged precision / recall / F1 on each fold's test split, using the
# per-label thresholds selected on that fold's validation split, then averaged
# over folds.
precisions_micro = []
recalls_micro = []
f1_scores_micro = []

for idx, output in enumerate(test_set_outputs):
  probs = sigmoid(output.predictions)
  y_true = output.label_ids
  # One threshold per label column: the threshold vector broadcasts across the
  # rows of `probs`, giving the (samples, labels) prediction matrix directly.
  y_pred = get_binary_predictions(probs, np.asarray(operating_points_for_all_folds[idx]))

  precisions_micro.append(precision_score(y_true, y_pred, average="micro", zero_division=0))
  recalls_micro.append(recall_score(y_true, y_pred, average="micro", zero_division=0))
  # Same zero_division handling as precision and recall above.
  f1_scores_micro.append(f1_score(y_true, y_pred, average="micro", zero_division=0))


print(f"precision_<micro>: {np.mean(precisions_micro)}")
print(f"recall_<micro>:    {np.mean(recalls_micro)}")
print(f"f1_scores_<micro>:   {np.mean(f1_scores_micro)}")

precision_<micro>: 0.7428576222228833
recall_<micro>:    0.7629318805968361
f1_scores_<micro>:   0.752123596574215
