In [1]:
import torch
import numpy as np
from torchmetrics.classification import MultilabelAccuracy
from torchmetrics.functional.classification import multilabel_exact_match
from torchmetrics.functional.classification import multilabel_accuracy, multilabel_f1_score
from torchmetrics.functional.classification import multilabel_recall, multilabel_precision
from torchmetrics.functional.classification import multiclass_accuracy, multiclass_f1_score
from torchmetrics.functional.classification import multiclass_recall, multiclass_precision
from torchmetrics.functional.classification import multiclass_auroc, multilabel_auroc
from sklearn.model_selection import KFold
from datasets import Dataset
from atel.data import BookCollection
from data_clean import *
from transformers import AutoTokenizer

In [2]:
problem_type = 'multilabel'
logit_func = torch.nn.Sigmoid()
NUM_LABELS = 4
SEED = 42
NUM_SPLITS = 10

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = logit_func(torch.tensor(logits))
    labels = torch.tensor(labels).int()
    
    if problem_type == 'multilabel':
        acc_exact = multilabel_exact_match(preds, labels, num_labels=NUM_LABELS)
        acc_macro = multilabel_accuracy(preds, labels, num_labels=NUM_LABELS)
        
        # How are they calculated?:
        # The metrics are calculated for each label. 
        # So if there is 4 labels, then 4 recalls are calculated.
        # These 4 values are then averaged, which is the end score that is logged.
        # The default average applied is 'macro' 
        # precision_macro = multilabel_precision(preds, labels, num_labels=NUM_LABELS)
        # recall_macro = multilabel_recall(preds, labels, num_labels=NUM_LABELS)
        f1_macro = multilabel_f1_score(preds, labels, num_labels=NUM_LABELS)
        
        # AUROC score of 1 is a perfect score
        # AUROC score of 0.5 corresponds to random guessing.
        auroc_macro = multilabel_auroc(preds, labels, num_labels=NUM_LABELS, average="macro", thresholds=None)
        
        metrics = {
            'accuracy_exact':  acc_exact,
            'accuracy_macro':  acc_macro,
            # 'precision_macro': precision_macro,
            # 'recall_macro':    recall_macro,
            'f1_macro':        f1_macro,
            'AUROC_macro':     auroc_macro
        }
    else:
        acc_micro = multiclass_accuracy(preds, labels, num_classes=NUM_LABELS, average='micro')
        acc_macro = multiclass_accuracy(preds, labels, num_classes=NUM_LABELS, average='macro')
        # precision_macro = multiclass_precision(preds, labels, num_classes=NUM_LABELS)
        # recall_macro = multiclass_recall(preds, labels, num_classes=NUM_LABELS)
        f1_macro = multiclass_f1_score(preds, labels, num_classes=NUM_LABELS)
        auroc_macro = multiclass_auroc(preds, labels, num_classes=NUM_LABELS, average="macro", thresholds=None)
        
        metrics = {
            'accuracy_micro':  acc_micro,
            'accuracy_macro':  acc_macro,
            # 'precision_macro': precision_macro,
            # 'recall_macro':    recall_macro,
            'f1_macro':        f1_macro,
            'AUROC_macro':     auroc_macro
        }
        
    return metrics

In [3]:
TARGET = 'Tekstbånd'
book_col = BookCollection(data_file="./data/book_col_271120.pkl")
df, labels = get_pandas_dataframe(book_col, TARGET)

Loaded from disk: ./data/book_col_271120.pkl


In [4]:
tokenizer = AutoTokenizer.from_pretrained("Maltehb/danish-bert-botxo")

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

In [5]:
dataset = Dataset.from_pandas(df)
token_dataset = dataset.map(tokenize_function, batched=True)

kf = KFold(n_splits=NUM_SPLITS, shuffle=True, random_state=SEED)
all_splits = [k for k in kf.split(token_dataset)]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [6]:
train_idx, val_idx = all_splits[0]
train_dataset = token_dataset.select(train_idx)
val_dataset   = token_dataset.select(val_idx)

In [6]:
path = 'huggingface_logs/Tekstbånd/BERT-BS16-BA4-ep100-seed42-WD0.01-LR2e-05/CV_1/Tekstbånd_CV1_best_model_logits.pt'
test1 = torch.load(path)

In [56]:
compute_metrics((test1, val_dataset['labels']))

  preds = logit_func(torch.tensor(logits))


{'accuracy_exact': tensor(0.4231),
 'accuracy_macro': tensor(0.8045),
 'f1_macro': tensor(0.6889),
 'AUROC_macro': tensor(0.8303)}

In [57]:
compute_metrics((test2, val_dataset['labels']))

  preds = logit_func(torch.tensor(logits))


{'accuracy_exact': tensor(0.4231),
 'accuracy_macro': tensor(0.8045),
 'f1_macro': tensor(0.6889),
 'AUROC_macro': tensor(0.8303)}

In [8]:
test2 = torch.Tensor([[ 5.16554260e+00, -3.19429779e+00,  3.20303631e+00,
        -2.51977801e+00],
       [ 2.22733569e+00, -1.60944903e+00, -2.73981905e+00,
        -5.95094109e+00],
       [ 5.68270922e+00, -2.50618291e+00, -3.83019865e-01,
        -2.50338840e+00],
       [ 4.87027931e+00,  5.51717103e-01,  5.10679674e+00,
         4.34238374e-01],
       [ 4.10413647e+00, -3.37481213e+00,  4.63073397e+00,
        -2.90059710e+00],
       [-1.60844517e+00,  4.57576084e+00,  4.24985456e+00,
         3.79182220e+00],
       [ 1.73678637e+00, -3.20532775e+00,  4.23913574e+00,
        -3.21407962e+00],
       [ 2.55214661e-01, -3.08052242e-01,  5.57446432e+00,
        -3.87515616e+00],
       [-1.36731982e+00,  7.30290234e-01,  2.95798612e+00,
        -2.31041527e+00],
       [ 2.83534074e+00,  2.48327827e+00, -2.88640904e+00,
        -4.68701220e+00],
       [-4.85026932e+00, -2.38049936e+00,  2.36961079e+00,
         1.99056661e+00],
       [-3.11256814e+00, -4.56319809e+00, -4.53811693e+00,
        -3.38206840e+00],
       [ 2.13222909e+00, -2.03173184e+00, -4.78000069e+00,
        -5.31168985e+00],
       [ 2.66329437e-01, -5.67291069e+00, -4.31413278e-02,
        -4.33494711e+00],
       [-4.80130005e+00, -4.29580212e+00, -5.09505844e+00,
        -3.47965002e+00],
       [-1.77052450e+00,  2.54218674e+00, -4.66702795e+00,
        -4.11903477e+00],
       [-1.81171334e+00,  1.88290441e+00, -4.56986570e+00,
        -4.82688379e+00],
       [-1.63080490e+00,  4.96254778e+00, -3.20065689e+00,
        -2.90217423e+00],
       [-2.48930979e+00, -3.98357004e-01, -4.04225588e+00,
        -5.60342264e+00],
       [ 4.77345133e+00, -1.12253988e+00, -2.20296407e+00,
        -4.55760527e+00],
       [-4.19607496e+00, -3.73130941e+00, -3.77044272e+00,
        -4.46575117e+00],
       [-2.83136463e+00,  1.59740949e+00, -5.05442905e+00,
        -3.77460551e+00],
       [-3.94084573e+00,  3.81104970e+00, -2.58745241e+00,
        -1.79828942e+00],
       [-1.68688238e+00,  5.24768925e+00, -6.76316619e-01,
        -2.72801399e+00],
       [-2.94148111e+00, -4.06552696e+00, -4.52295446e+00,
        -4.57724571e+00],
       [-6.98032379e-01, -6.07388973e+00, -1.87265766e+00,
        -5.41809607e+00],
       [-3.85732841e+00, -2.74241471e+00, -4.49357462e+00,
        -4.75534678e+00],
       [-1.43901455e+00, -4.18212509e+00, -3.29569268e+00,
        -2.98177361e+00],
       [ 5.84422874e+00,  9.62716222e-01,  2.20568323e+00,
         1.52428284e-01],
       [-2.54855061e+00, -4.19054174e+00, -3.61895919e+00,
        -4.66419744e+00],
       [-4.85175514e+00, -4.87695158e-01,  2.94838935e-01,
        -3.89464355e+00],
       [-3.88984323e+00,  4.29880428e+00, -2.03587532e+00,
         3.40583444e-01],
       [-2.80201197e+00, -3.30086350e+00, -4.22864866e+00,
        -5.15278006e+00],
       [ 5.48697090e+00,  2.51765156e+00,  3.83449435e+00,
        -1.39336526e+00],
       [-2.81599951e+00, -4.28130341e+00, -5.46455765e+00,
        -4.23700762e+00],
       [ 1.84177220e-01, -4.70478201e+00,  1.05635965e+00,
        -6.05461454e+00],
       [ 4.89777613e+00,  2.43476987e+00,  5.18893957e+00,
        -1.18302945e-02],
       [ 2.39241505e+00, -4.08998299e+00,  4.99989653e+00,
        -3.04086637e+00],
       [ 4.21136618e+00,  1.80318415e+00,  5.35461092e+00,
         1.50881791e+00],
       [-4.48399991e-01,  2.65689039e+00,  3.62911463e+00,
         4.99298763e+00],
       [-4.15511894e+00,  7.15964615e-01,  5.55582285e+00,
        -1.40282309e+00],
       [ 1.41760576e+00,  2.28500581e+00, -2.95246077e+00,
        -4.28444052e+00],
       [ 4.84543037e+00, -2.74104863e-01,  4.63908148e+00,
         2.38733602e+00],
       [ 4.03057051e+00, -4.90907478e+00,  6.77943006e-02,
        -3.18245053e+00],
       [-4.74699926e+00, -2.12090731e+00,  3.68814349e+00,
         7.15106070e-01],
       [ 4.45185661e+00, -4.40527081e-01, -7.19632804e-01,
        -4.85729599e+00],
       [-4.13652515e+00,  1.68945098e+00,  1.95994675e+00,
        -3.64911199e+00],
       [-3.36652875e+00, -1.46753180e+00,  4.99944830e+00,
        -3.64281917e+00],
       [-3.32274699e+00,  3.32543707e+00,  4.88996696e+00,
         5.58458984e-01],
       [-4.33163452e+00, -1.09995711e+00,  4.56538486e+00,
         3.43731046e-01],
       [-3.91006303e+00, -2.28928423e+00,  2.81870365e+00,
        -3.53220606e+00],
       [-4.30100632e+00, -1.65304947e+00, -2.92177296e+00,
        -3.96378994e+00],
       [-4.44125938e+00,  3.53695202e+00,  1.99212111e-03,
        -2.96911550e+00],
       [-2.97775483e+00, -5.35744667e+00,  3.16697091e-01,
        -5.06005430e+00],
       [-4.36757517e+00,  2.84198523e+00, -4.42518806e+00,
        -1.57003200e+00],
       [-3.73018217e+00, -4.76512098e+00, -4.51156092e+00,
        -4.60586834e+00],
       [-3.48317814e+00, -4.26795053e+00, -5.03264761e+00,
        -4.25665522e+00],
       [-4.41503286e+00,  7.93811262e-01, -3.87158227e+00,
        -4.64890909e+00],
       [-4.19175625e+00,  3.64704823e+00, -2.57251787e+00,
         8.08576345e-01],
       [-2.66953802e+00,  3.82662249e+00, -2.20182109e+00,
        -4.05804777e+00],
       [-3.96176314e+00,  3.07162213e+00,  4.30738115e+00,
        -8.40573192e-01],
       [-4.15608692e+00, -3.13124633e+00, -1.08846200e+00,
        -5.02937174e+00],
       [-3.83839321e+00,  4.76321030e+00,  2.86117220e+00,
         1.94900250e+00],
       [-5.80294561e+00, -1.27957910e-01,  2.44033027e+00,
         5.11822402e-01],
       [-9.09365341e-02,  8.21714699e-01, -4.70790720e+00,
        -5.34248495e+00],
       [-2.60192442e+00, -3.92364383e+00,  4.89422178e+00,
        -1.85849464e+00],
       [-3.87045693e+00,  7.57515788e-01,  4.66726589e+00,
        -1.27215052e+00],
       [-3.93891549e+00, -3.68174374e-01, -5.35932970e+00,
        -4.87996531e+00],
       [-3.70994210e+00,  3.00657535e+00, -2.37425351e+00,
        -4.43487215e+00],
       [-4.00082970e+00,  3.97570682e+00,  2.20934510e+00,
        -2.24306956e-01],
       [-5.28425741e+00, -2.90008521e+00, -4.83394194e+00,
        -3.26275826e+00],
       [-4.13419294e+00, -2.27572227e+00,  4.59969950e+00,
        -3.99268889e+00],
       [-5.99212456e+00, -1.59950304e+00, -1.84705400e+00,
        -2.98176527e+00],
       [ 4.30202723e+00, -3.54521036e+00,  3.09559155e+00,
        -3.23310804e+00],
       [-2.78096581e+00, -2.37547517e+00,  1.36402965e+00,
        -5.49689198e+00],
       [ 3.30583858e+00, -3.34388161e+00,  4.88955355e+00,
        -2.51294398e+00],
       [-4.19996738e+00, -7.75937021e-01,  7.50196695e-01,
        -4.17683506e+00],
       [-4.70574427e+00,  1.78250635e+00,  5.08841562e+00,
        -1.31903604e-01]])

## Multi-label

In [131]:
preds = torch.tensor(
    [
        [1, 0, 1, 0],
        [1, 1, 0, 0]
    ]
)

target = torch.tensor(
    [
        [1, 0, 1, 0],
        [0, 1, 1, 1]S
    ]
)

# for individual samples
TP, TN = 2, 2
FP, FN = 2, 2

multilabel_exact_match(preds, target, num_labels=4)

tensor(0.5000)

In [132]:
multilabel_accuracy(preds, target, num_labels=4, average='micro')

tensor(0.6250)

In [96]:
r = multilabel_recall(preds, target, num_labels=4)
r, TP/(TP+FN)

(tensor(0.3750), 0.5)

In [95]:
p = multilabel_precision(preds, target, num_labels=4,)
p, TP/(TP+FP)

(tensor(0.3750), 0.5)

In [97]:
f1 = multilabel_f1_score(preds, target, num_labels=4)
f1, 2*(p*r)/(p+r)

(tensor(0.3333), tensor(0.3750))

In [63]:
2*(p*r)/(p+r)

tensor(0.3750)

In [86]:
from sklearn.metrics import f1_score, precision_score, recall_score
f1_score(target.detach(), preds.detach(), average=None)

array([0.66666667, 0.        , 0.66666667, 0.        ])

In [76]:
precision_score(target.detach(), preds.detach(), average='macro')

  _warn_prf(average, modifier, msg_start, len(result))


0.375

In [77]:
recall_score(target.detach(), preds.detach(), average='macro')

  _warn_prf(average, modifier, msg_start, len(result))


0.375

In [126]:
preds = torch.tensor([[0.75, 0.05, 0.35],
                      [0.45, 0.75, 0.05],
                      [0.05, 0.55, 0.75],
                      [0.05, 0.65, 0.05]])
target = torch.tensor([[1, 0, 1],
                       [0, 0, 0],
                       [0, 1, 1],
                       [1, 1, 1]])
multilabel_auroc(preds, target, num_labels=3, average="macro", thresholds=None)

tensor(0.6528)

In [124]:
multilabel_auroc(preds, target, num_labels=3, average=None, thresholds=None).mean()

tensor(0.6528)

## Multi-class

In [111]:
from torchmetrics.functional.classification import multiclass_accuracy, multiclass_f1_score
from torchmetrics.functional.classification import multiclass_recall, multiclass_precision
from torchmetrics.functional.classification import multiclass_accuracy
from torchmetrics.functional import accuracy

In [100]:
target = torch.tensor([2, 1, 0, 0])
preds = torch.tensor([
  [0.16, 0.26, 0.58],
  [0.22, 0.61, 0.17],
  [0.71, 0.09, 0.20],
  [0.05, 0.82, 0.13],
])

In [112]:
accuracy(preds, target)

tensor(0.7500)

In [119]:
multiclass_accuracy(preds, target, num_classes=3, average='none').mean()

tensor(0.8333)

In [110]:
# 2*(p*r)/(p+r)
multiclass_f1_score(preds, target, num_classes=3, average='none')

tensor([0.6667, 0.6667, 1.0000])

In [109]:
# TP/(TP+FP)
multiclass_precision(preds, target, num_classes=3, average='none')

tensor([1.0000, 0.5000, 1.0000])

In [108]:
# TP/(TP+FN)
multiclass_recall(preds, target, num_classes=3, average='none')

tensor([0.5000, 1.0000, 1.0000])