### Using a metric subclass (`NerMetric`)

In [2]:
import numpy as np
from huggingface_utils.metrics.ner import NerMetric

def fake_labels_and_logits(N, T, K, seed=None):
    """Generate a random token-classification batch for exercising metrics.

    Args:
        N: number of sequences (rows).
        T: tokens per sequence.
        K: number of classes; label ids are drawn from ``[0, K)``.
        seed: optional int seed. ``None`` (the default) draws from NumPy's
            global RNG, so ``np.random.seed(...)`` still controls the output.
            An int uses a private ``RandomState`` and leaves the global RNG
            untouched.

    Returns:
        ``(labels, logits)``:
        - ``labels`` is an ``(N, T)`` int array.
        - ``logits`` is an ``(N, T, K)`` float array with values in ``[0, 1)``.

        Positions the metric should ignore are marked with ``-100``:
        - the first and last token of every sequence;
        - the last 2 tokens of the last sequence;
        - the last 4 tokens of the second-to-last sequence.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)

    labels = rng.randint(0, K, size=(N, T))
    # Presumably stand-ins for special tokens (e.g. [CLS]/[SEP]) that must be
    # skipped by the metric -- TODO confirm intent.
    labels[:, 0] = -100
    labels[:, -1] = -100
    # Fake trailing padding of different lengths on the last two sequences.
    # Guarded so batches with fewer rows don't raise IndexError.
    if N >= 1:
        labels[-1, -2:] = -100
    if N >= 2:
        labels[-2, -4:] = -100

    logits = rng.random_sample((N, T, K))
    return labels, logits

label_names = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']

# Seed NumPy's global RNG, which fake_labels_and_logits draws from, so the fake
# batch -- and therefore the scores below -- are reproducible on re-run.
np.random.seed(0)
labels, logits = fake_labels_and_logits(100, 20, len(label_names))

# NerMetric receives the label names once at construction, so calling it
# needs only the (logits, labels) pair -- no module-level globals involved.
compute_metrics = NerMetric(label_names)
compute_metrics((logits, labels))
    

{'precision': 0.16445993031358885, 'recall': 0.16043507817811012, 'f1': 0.16242257398485893, 'accuracy': 0.111358574610245}


### Why not a plain function?

In [3]:
import evaluate

# compute_metrics below reads `metric` and `label_names` as module-level
# globals instead of taking them as arguments. That hidden dependency is the
# drawback this cell illustrates:
# - both names must be defined before the function is called;
# - rebinding either one silently changes what the function does.
metric = evaluate.load('seqeval')
label_names = [
    'O',
    'B-PER', 'I-PER',
    'B-ORG', 'I-ORG',
    'B-LOC', 'I-LOC',
    'B-MISC', 'I-MISC',
]

def compute_metrics(logits_and_labels):
    """Score token-classification predictions with the seqeval metric.

    Deliberately relies on two module-level globals instead of arguments:
    - ``metric``: the loaded seqeval metric;
    - ``label_names``: maps a tag id to its tag string.

    Whatever those names are bound to at call time is what gets used.

    Args:
        logits_and_labels: a ``(logits, labels)`` pair.
            - ``logits`` carries the class scores on its last axis.
            - ``labels`` holds integer tag ids, with ``-100`` marking
              positions to ignore.

    Returns:
        dict with the overall ``precision``, ``recall``, ``f1`` and
        ``accuracy`` reported by ``metric.compute``.
    """
    logits, gold_ids = logits_and_labels

    # Highest-scoring class id for every token.
    guessed_ids = np.argmax(logits, axis=-1)
    # Copy the -100 markers from the gold labels onto the predictions, so both
    # sides drop exactly the same positions when converted to tags.
    guessed_ids = np.where(gold_ids == -100, -100, guessed_ids)

    def as_tags(id_rows):
        # One list of tag strings per row, skipping -100 entries.
        return [
            [label_names[tag_id] for tag_id in row if tag_id != -100]
            for row in id_rows
        ]

    scores = metric.compute(
        predictions=as_tags(guessed_ids),
        references=as_tags(gold_ids),
    )

    # Keep only the overall scores, renamed without the "overall_" prefix.
    summary = {}
    for short_key, long_key in (
        ("precision", "overall_precision"),
        ("recall", "overall_recall"),
        ("f1", "overall_f1"),
        ("accuracy", "overall_accuracy"),
    ):
        summary[short_key] = scores[long_key]
    return summary

# Seed NumPy's global RNG, which fake_labels_and_logits draws from, so this
# cell's fake batch -- and therefore its scores -- are reproducible on re-run.
np.random.seed(0)
labels, logits = fake_labels_and_logits(100, 20, len(label_names))
compute_metrics((logits, labels))

{'precision': 0.18305814788226848,
 'recall': 0.17696044413601666,
 'f1': 0.17995765702187724,
 'accuracy': 0.11525612472160357}