add multi-label metric class and func #269
Conversation
Thanks Amir, looks good.
Please see my comments inline.
Please also list it here: https://github.com/BiomedSciAI/fuse-med-ml/tree/master/fuse/eval#implemented-low-level-metrics
And please add an example and a test: here (https://github.com/BiomedSciAI/fuse-med-ml/blob/master/fuse/eval/examples/examples.py) and here (https://github.com/BiomedSciAI/fuse-med-ml/blob/master/fuse/eval/tests/test_eval.py)
        y_score=y_score[:, i], y_true=y_true[:, i], sample_weight=sample_weight, max_fpr=max_fpr
    )
    all_auc.append(auc)
missing return of average of all_auc
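As a sketch, the macro branch could end like this, assuming all_auc holds the per-class scores collected in the loop above (numpy import assumed):

import numpy as np

# reduce the per-class AUCs to a single macro-averaged score
return float(np.mean(all_auc))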
n_class = y_score.shape[1]
all_auc = []
for i in range(n_class):
    auc = metrics.roc_auc_score(
We should be careful here and ignore labels that don't have both a positive and a negative class.
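One possible guard, sketched here as a suggestion: check each target column with np.unique before scoring, so single-class columns are skipped instead of making roc_auc_score raise (names reuse the snippet above):

for i in range(n_class):
    # skip labels whose target column contains only one class
    if len(np.unique(y_true[:, i])) < 2:
        continue
    auc = metrics.roc_auc_score(
        y_score=y_score[:, i], y_true=y_true[:, i], sample_weight=sample_weight, max_fpr=max_fpr
    )
    all_auc.append(auc)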
@@ -61,6 +61,44 @@ def auc_roc(
    y_score=y_score, y_true=np.asarray(target) == pos_class_index, sample_weight=sample_weight, max_fpr=max_fpr
)

@staticmethod
def auc_roc_mult_label(
multi_binary_label
Let's add the word "binary" next to "multi_label" to be more precise.
) -> float:
    """
    Compute multi label auc roc (Receiver operating characteristic) score using sklearn
    :param pred: prediction array per sample. Each element shape [num_classes]
Let's change num_classes -> num_labels
It is still the number of classes (num_labels is the number of 1's in the vector).
    :param pred: prediction array per sample. Each element shape [num_classes]
    :param target: target per sample. Each element shape [num_classes] with 0 and 1 only.
    :param sample_weight: Optional - weight per sample for a weighted auc. Each element is float in range [0-1]
    :param pos_class_index: the class to compute the metrics in one vs rest manner - set to 1 in binary classification
not relevant here
@@ -145,6 +145,28 @@ def __init__(
        super().__init__(pred, target, metric_func=auc_roc, class_names=class_names, **kwargs)


class MetricAUCROCmultLabel(MetricMultiClassDefault):
MetricDefault instead of MetricMultiClassDefault
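A minimal sketch of the suggested change, with the capital-M rename from the later comment applied as well (the base-class swap assumes MetricDefault accepts the same pred/target keys):

# suggested base class and capitalization
class MetricAUCROCMultLabel(MetricDefault):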
@@ -145,6 +145,28 @@ def __init__(
        super().__init__(pred, target, metric_func=auc_roc, class_names=class_names, **kwargs)


class MetricAUCROCmultLabel(MetricMultiClassDefault):
    """
    Compute auc roc (Receiver operating characteristic) score using sklearn (one vs rest)
Update the comment
        **kwargs,
    ):
        """
        See MetricMultiClassDefault for the missing params
Let's specify explicitly here what the expected format of pred and target is.
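One hedged way to spell that out in the docstring (the exact wording is an assumption based on the snippets above):

"""
:param pred: key to a prediction array per sample, each element of shape [num_classes] with a score per class
:param target: key to a binary target array per sample, each element of shape [num_classes] containing 0s and 1s only
See MetricMultiClassDefault for the remaining params
"""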
        self,
        pred: str,
        target: str,
        class_names: Optional[Sequence[str]] = None,
You are not using class_names, right?
    )
    all_auc.append(auc)

elif average == "micro":  # micro average
Can you add a link that suggests this method?
This is better in case there are no positive labels in one column.
Here is the ref:
"In a multi-class classification setup with highly imbalanced classes, micro-averaging is preferable over macro-averaging. In such cases, one can alternatively use a weighted macro-averaging, not demoed here."
https://scikit-learn.org/stable/auto_examples/model_selection/plot_roc.html
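For reference, a sketch of what that micro branch might look like, pooling all label columns into one binary problem as in the linked scikit-learn example (names reuse the snippets above; per-sample weights would need to be repeated per class and are omitted here):

elif average == "micro":
    # flatten [num_samples, num_classes] into one long binary vector and score once
    return metrics.roc_auc_score(y_true=y_true.ravel(), y_score=y_score.ravel(), max_fpr=max_fpr)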
        If not ``None``, the standardized partial AUC over the range [0, max_fpr] is returned.
        """
        auc_roc = partial(MetricsLibClass.auc_roc_mult_binary_label, average=average, max_fpr=max_fpr)
        super().__init__(pred, target, metric_func=auc_roc, **kwargs)
To fix your problem, specify the argument names explicitly, since the order is different, i.e. (pred=pred, target=target, ...).
Please also change the name to MetricAUCROCMultLabel (capital M).
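A sketch of the fixed call under that suggestion (only the keyword names are the point; the rest mirrors the snippet above):

# naming the arguments makes the call robust to a different parameter order in the base class
super().__init__(pred=pred, target=target, metric_func=auc_roc, **kwargs)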
fixed