# Metrics

> Classification metrics computed from true labels and predicted class scores.

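This module exports `compute_classification_metrics`, which takes true labels and per-class predicted scores and returns a single dictionary containing top-k accuracies (k = 1, 2, 3, 5, 10, 20) together with ROC AUC, Matthews correlation coefficient, macro-averaged F1 / precision / recall, log loss, balanced accuracy, Cohen's kappa and hinge loss, all computed with scikit-learn.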

In [None]:
#| default_exp metrics

In [None]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

In [None]:
import warnings
# NOTE(review): IPython help lookup left over from development; it only displays the
# signature and docstring of `warnings.filterwarnings`. Consider removing it from the final notebook.
warnings.filterwarnings?

[0;31mSignature:[0m
[0;34m[0m    [0maction[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmessage[0m[0;34m=[0m[0;34m''[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmodule[0m[0;34m=[0m[0;34m''[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mlineno[0m[0;34m=[0m[0;36m0[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mappend[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m

'action' -- one of "error", "ignore", "always", "default", "module",
            or "once"
'module' -- a regex that the module name must match
'append' -- if true, append to the list of filters
[0;31mType:[0m      function

In [None]:
#| export
from namable_classify.utils import default_on_exception, ensure_array
import torch.nn.functional as F
import numpy as np
import torch
from sklearn.metrics import roc_auc_score, top_k_accuracy_score, matthews_corrcoef, f1_score, precision_score, recall_score, log_loss, balanced_accuracy_score, cohen_kappa_score, hinge_loss

from namable_classify.utils import MuteWarnings
# `roc_auc_score` wrapped by `default_on_exception(default_value=-1)`.
# NOTE(review): presumably this returns -1 instead of raising when `roc_auc_score` fails —
# confirm against `namable_classify.utils.default_on_exception`.
except_roc_auc_score = default_on_exception(default_value=-1)(roc_auc_score)

def compute_classification_metrics(
    y_true: np.ndarray,  # 1d array-like of true class labels
    y_pred_logits: np.ndarray,  # 2d per-class scores, one column per class (logits, or probabilities when `logits_to_prob` is False)
    logits_to_prob: bool = False,  # whether to apply a softmax over axis 1 to turn the scores into probabilities
    labels:list[int|str]|None = None,  # class labels, in the column order of `y_pred_logits`
    supress_warnings: bool = True,  # whether to suppress warnings while the metrics are computed
):
    """Compute top-k accuracies and a set of scikit-learn classification metrics.

    Returns a dict with the keys `acc1`, `acc2`, `acc3`, `acc5`, `acc10`, `acc20`,
    `roc_auc`, `matthews_corrcoef`, `f1`, `precision`, `recall`, `log_loss`,
    `balanced_accuracy`, `cohen_kappa` and `hinge_loss`.

    The hard prediction used by the label-based metrics is the argmax over axis 1
    of `y_pred_logits`. When `labels` is given, that column index is translated to
    the corresponding entry of `labels`, so the prediction lives in the same label
    space as `y_true`; when `labels` is None the raw column index is used.

    `roc_auc` is computed through `except_roc_auc_score`
    (NOTE(review): presumably yields -1 when `roc_auc_score` raises — confirm).

    Warnings muted via `MuteWarnings` are always resumed, even if a metric raises.
    """
    mute = MuteWarnings() if supress_warnings else None
    if mute is not None:
        mute.mute()
    try:
        y_true = ensure_array(y_true)
        y_pred_logits = ensure_array(y_pred_logits)
        y_pred_probs = (
            np.array(F.softmax(torch.Tensor(y_pred_logits), dim=1))
            if logits_to_prob
            else y_pred_logits
        )  # per-class scores handed to the score-based metrics

        # Hard predictions: argmax gives a column index; map it through `labels`
        # (when provided) so it is comparable with the values in `y_true`.
        top_class_idx = np.argmax(y_pred_logits, axis=1)
        y_pred = top_class_idx if labels is None else np.asarray(labels)[top_class_idx]

        top_k_res = {
            f"acc{k}": top_k_accuracy_score(y_true, y_pred_probs, k=k, labels=labels)
            for k in [1, 2, 3, 5, 10, 20]
        }
        balance_res = dict(
            roc_auc=except_roc_auc_score(
                y_true, y_pred_probs, average="macro", multi_class="ovr", labels=labels
            ),  # "ovr" is somewhat harder; it can be imbalanced
            matthews_corrcoef=matthews_corrcoef(y_true, y_pred),
            f1=f1_score(y_true, y_pred, average="macro", labels=labels),
            precision=precision_score(y_true, y_pred, average="macro", labels=labels),
            recall=recall_score(y_true, y_pred, average="macro", labels=labels),
            log_loss=log_loss(
                y_true,
                y_pred_probs,
                labels=labels
            ),
            balanced_accuracy=balanced_accuracy_score(y_true, y_pred),
            cohen_kappa=cohen_kappa_score(y_true, y_pred, labels=labels),
            hinge_loss=hinge_loss(y_true, y_pred_probs, labels=labels),
        )
        return top_k_res | balance_res
    finally:
        # Restore warnings on every exit path so a failing metric cannot leave them muted.
        if mute is not None:
            mute.resume()

In [None]:
# Smoke test: 100 random integer targets in [0, 20) scored against random
# 20-class probability rows (already softmaxed, hence logits_to_prob=False).
compute_classification_metrics(
    y_true=torch.randint(0, 20, size=(100,)),
    y_pred_logits=torch.softmax(torch.randn(100, 20), dim=1),
    labels=list(range(20)),
    logits_to_prob=False,
)

{'acc1': 0.04,
 'acc2': 0.07,
 'acc3': 0.13,
 'acc5': 0.2,
 'acc10': 0.49,
 'acc20': 1.0,
 'roc_auc': 0.4840177779299868,
 'matthews_corrcoef': -0.010926063988074865,
 'f1': 0.030833333333333334,
 'precision': 0.03166666666666666,
 'recall': 0.03297619047619048,
 'log_loss': 3.534164123045798,
 'balanced_accuracy': 0.03297619047619048,
 'cohen_kappa': -0.010845530167421291,
 'hinge_loss': 1.1727566}

In [None]:
# NOTE(review): looks like leftover scratch code — the popped value is discarded and the
# warning is emitted unconditionally. Presumably a manual check that warnings are shown
# again after `compute_classification_metrics` resumes them; confirm or remove.
[1,2,3].pop(0)
warnings.warn("This is a warning")



In [None]:
#| hide
# Run nbdev's export step (`nbdev_export`) for this notebook's project.
import nbdev; nbdev.nbdev_export()