# Metrics

> Classification metrics computed from ground-truth labels, per-class scores, and predicted labels.

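This module exports `compute_classification_metrics`, which takes ground-truth labels together with per-class scores (logits or probabilities) and/or predicted labels, and returns a dictionary of scikit-learn metrics: top-k accuracy, Matthews correlation coefficient, macro F1 / precision / recall, balanced accuracy, Cohen's kappa, ROC AUC, hinge loss and log loss.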

In [None]:
#| default_exp metrics

In [None]:
#| hide
%load_ext autoreload
%autoreload 2
from nbdev.showdoc import *

In [None]:
import warnings
warnings.filterwarnings?

[0;31mSignature:[0m
[0;34m[0m    [0maction[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmessage[0m[0;34m=[0m[0;34m''[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmodule[0m[0;34m=[0m[0;34m''[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mlineno[0m[0;34m=[0m[0;36m0[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mappend[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m

'action' -- one of "error", "ignore", "always", "default", "module",
            or "once"
'module' -- a regex that the module name must match
'append' -- if true, append to the list of filters
[0;31mType:[0m      function

In [2]:
#| export
from namable_classify.utils import default_on_exception, ensure_array
import torch.nn.functional as F
import numpy as np
import torch
from sklearn.metrics import roc_auc_score, top_k_accuracy_score, matthews_corrcoef, f1_score, precision_score, recall_score, log_loss, balanced_accuracy_score, cohen_kappa_score, hinge_loss, accuracy_score

from namable_classify.utils import MuteWarnings
import warnings
# `roc_auc_score` wrapped by `default_on_exception(default_value=-1)`.
# NOTE(review): presumably this returns -1 instead of raising when `roc_auc_score` fails
# (e.g. a class absent from y_true) — confirm against `namable_classify.utils`.
except_roc_auc_score = default_on_exception(default_value=-1)(roc_auc_score)

def compute_classification_metrics(
    y_true: np.ndarray,  # ground-truth labels; 1d array-like
    y_pred_logits: np.ndarray = None,  # per-class scores (logits or probabilities), one column per class
    logits_to_prob: bool = False,  # whether to softmax `y_pred_logits` into probabilities first
    y_pred: np.ndarray = None,  # predicted labels, if None, will be computed from logits
    labels:list[int|str]|None = None,  # list of labels
    supress_warnings: bool = True,  # whether to suppress warnings
    y_pred_metrics_only: bool = False,  # whether to compute only y_pred related metrics
):
    """Compute a dictionary of classification metrics with scikit-learn.

    Three groups of metrics are merged into the returned dict:

    * score-based (skipped when `y_pred_metrics_only` is true): `acc1`, `acc2`,
      `acc3`, `acc5`, `acc10`, `acc20` (top-k accuracy), `roc_auc` (macro,
      one-vs-rest, computed through `except_roc_auc_score`), `hinge_loss`
      and `log_loss`;
    * label-based: `matthews_corrcoef`, `f1`, `precision`, `recall` (all
      macro-averaged), `balanced_accuracy` and `cohen_kappa`;
    * `acc1_pred`: plain accuracy of `y_pred`, only when `y_pred` was passed
      explicitly.

    When `y_pred` is not given it is derived as the arg-max column of
    `y_pred_logits`; if `labels` is given, the column index is translated to
    the corresponding entry of `labels` so that it is comparable with `y_true`.

    Raises `AssertionError` when `y_pred_logits` is None and either
    `y_pred_metrics_only` is not True or `y_pred` is None.
    """
    mute = None
    if supress_warnings:
        mute = MuteWarnings()
        mute.mute()
    try:
        if y_pred_logits is None:
            assert y_pred_metrics_only == True, "y_pred_logits is None, we can only compute y_pred related metrics! "
            assert y_pred is not None, "y_pred_logits is None, y_pred should be specified! "
        y_true = ensure_array(y_true)

        # The scores are needed either for the score-based metrics or to derive y_pred.
        if not y_pred_metrics_only or y_pred is None:
            y_pred_logits = ensure_array(y_pred_logits)
        if not y_pred_metrics_only:
            y_pred_probs = (
                np.array(F.softmax(torch.Tensor(y_pred_logits), dim=1))
                if logits_to_prob
                else y_pred_logits
            )  # label indicator array / sparse matrix

        other_res = {}
        if y_pred is None:
            # y_pred_logits must be available here
            assert y_pred_logits is not None, "y_pred_logits is None, cannot derive y_pred! "
            class_index = np.argmax(y_pred_logits, axis=1)
            # argmax yields column positions; translate them to the actual labels
            # so they match y_true when the classes are not exactly 0..n-1.
            y_pred = class_index if labels is None else np.asarray(labels)[class_index]
        else:
            # additionally report the accuracy of the explicitly supplied predictions
            if not y_pred_metrics_only:
                warnings.warn("y_pred is specified since it may be different from argmax(y_pred_logits), this may happen to prob SVM. ")
            other_res["acc1_pred"] = accuracy_score(y_true, y_pred)

        if not y_pred_metrics_only:
            top_k_res = {
                f"acc{k}": top_k_accuracy_score(y_true, y_pred_probs, k=k, labels=labels)
                for k in [1, 2, 3, 5, 10, 20]
            }
            prob_res = dict(
                roc_auc=except_roc_auc_score(
                    y_true, y_pred_probs, average="macro", multi_class="ovr", labels=labels
                ),  # one-vs-rest is somewhat harder because each binary problem is imbalanced
                hinge_loss=hinge_loss(y_true, y_pred_probs, labels=labels),
                log_loss=log_loss(
                    y_true,
                    y_pred_probs,
                    labels=labels
                    ),
                )
        else:
            top_k_res = {}
            prob_res = {}

        pred_res = dict(
            matthews_corrcoef=matthews_corrcoef(y_true, y_pred),
            f1=f1_score(y_true, y_pred, average="macro", labels=labels),
            precision=precision_score(y_true, y_pred, average="macro", labels=labels),
            recall=recall_score(y_true, y_pred, average="macro", labels=labels),
            balanced_accuracy=balanced_accuracy_score(y_true, y_pred),
            cohen_kappa=cohen_kappa_score(y_true, y_pred, labels=labels),
        )
        return top_k_res | pred_res | prob_res | other_res
    finally:
        # Always restore warnings, even when an assertion or a metric raised,
        # so a failed call does not leave warnings muted for the rest of the session.
        if mute is not None:
            mute.resume()

In [3]:
# Smoke test on random data: 100 samples over 20 classes, with scores that are
# already probabilities (hence logits_to_prob=False).
demo_targets = torch.randint(0, 20, size=(100,))
demo_probs = torch.softmax(torch.randn(100, 20), dim=1)
compute_classification_metrics(
    demo_targets,
    demo_probs,
    logits_to_prob=False,
    labels=list(range(20)),
)


[1m{[0m
    [32m'acc1'[0m: [1;36m0.08[0m,
    [32m'acc2'[0m: [1;36m0.15[0m,
    [32m'acc3'[0m: [1;36m0.19[0m,
    [32m'acc5'[0m: [1;36m0.3[0m,
    [32m'acc10'[0m: [1;36m0.49[0m,
    [32m'acc20'[0m: [1;36m1.0[0m,
    [32m'matthews_corrcoef'[0m: [1;36m0.03259827883336652[0m,
    [32m'f1'[0m: [1;36m0.06344988344988345[0m,
    [32m'precision'[0m: [1;36m0.07928571428571429[0m,
    [32m'recall'[0m: [1;36m0.07494588744588744[0m,
    [32m'balanced_accuracy'[0m: [1;36m0.07494588744588744[0m,
    [32m'cohen_kappa'[0m: [1;36m0.032190195665895205[0m,
    [32m'roc_auc'[0m: [1;36m0.4935322122458186[0m,
    [32m'hinge_loss'[0m: [1;36m1.1451513[0m,
    [32m'log_loss'[0m: [1;36m3.3784022382628023[0m
[1m}[0m

In [None]:
# Scratch cell: pops the first element of a throwaway list, then emits a sample warning.
[1,2,3].pop(0)
warnings.warn("This is a warning")



In [1]:
#| hide
import nbdev; nbdev.nbdev_export()