
Commit 036b0fa

Update metrics
lixfz committed Dec 13, 2023
1 parent d4d195b commit 036b0fa
Showing 1 changed file with 13 additions and 2 deletions.
hypernets/tabular/metrics.py — 15 changes: 13 additions & 2 deletions
@@ -2,6 +2,7 @@
 """
 """
+import inspect
 import math
 import os
 import pickle
@@ -18,6 +19,10 @@
 
 _MIN_BATCH_SIZE = 100000
 
+_DEFAULT_RECALL_OPTIONS = {}
+if 'zero_division' in inspect.signature(sk_metrics.recall_score).parameters.keys():
+    _DEFAULT_RECALL_OPTIONS['zero_division'] = 0.0
+
 
 def _task_to_average(task):
     if task == const.TASK_MULTICLASS:
@@ -37,10 +42,15 @@ def calc_score(y_true, y_preds, y_proba=None, metrics=('accuracy',), task=const.
     if len(y_preds.shape) == 2 and y_preds.shape[-1] == 1:
         y_preds = y_preds.reshape(-1)
 
+    recall_options = _DEFAULT_RECALL_OPTIONS.copy()
+
     if average is None:
         average = _task_to_average(task)
+    recall_options['average'] = average
 
+    if classes is not None:
+        recall_options['labels'] = classes
+
-    recall_options = dict(average=average, labels=classes)
     if pos_label is not None:
         recall_options['pos_label'] = pos_label
 
@@ -112,7 +122,8 @@ def metric_to_scoring(metric, task=const.TASK_BINARY, pos_label=None):
         raise ValueError(f'Not found matching scoring for {metric}')
 
     if metric_lower in metric2fn.keys():
-        options = dict(average=_task_to_average(task))
+        options = _DEFAULT_RECALL_OPTIONS.copy()
+        options['average'] = _task_to_average(task)
         if pos_label is not None:
             options['pos_label'] = pos_label
         scoring = sk_metrics.make_scorer(metric2fn[metric_lower], **options)
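
For context, the guard added here is a feature-detection pattern: the `zero_division` parameter of `recall_score` exists only in newer scikit-learn releases (it was introduced around scikit-learn 0.22, to my knowledge), so the commit probes the function signature instead of passing the argument unconditionally, which would raise a TypeError on older installs. Below is a minimal standalone sketch of the same pattern; the variable names are illustrative, not from the repository:

import inspect

from sklearn import metrics as sk_metrics

# Probe the installed scikit-learn: older releases reject unknown kwargs
# with a TypeError, so only pass zero_division where it is supported.
recall_kwargs = {}
if 'zero_division' in inspect.signature(sk_metrics.recall_score).parameters:
    # Return 0.0 (and skip UndefinedMetricWarning) when recall is undefined
    # because the positive class never occurs in y_true.
    recall_kwargs['zero_division'] = 0.0

# Recall for class 1 is 0/0 here; with zero_division=0.0 the call returns
# 0.0 silently, while without it scikit-learn emits UndefinedMetricWarning.
y_true = [0, 0, 0]
y_pred = [0, 1, 0]
print(sk_metrics.recall_score(y_true, y_pred, **recall_kwargs))

Note that both call sites in the diff copy `_DEFAULT_RECALL_OPTIONS` before adding per-call keys, which keeps the module-level default unchanged across calls.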
