
Commit

Metric (#37)
* switch tensorboardX to tensorboard

* add more config for sgd

* cosine annealing lr scheduler

* handle prob gold label in accuracy

* support accuracy@k metric
senwu authored Nov 23, 2019
1 parent 2f2224f commit 25933d1
Showing 5 changed files with 43 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -3,6 +3,7 @@ Unreleased_
 
 Added
 ^^^^^
+* `@senwu`_: Support accuracy@k metric.
 * `@senwu`_: Support cosine annealing lr scheduler.
 
 0.0.4_ - 2019-11-11
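
The cosine annealing entry lands in the learning-rate scheduler config, which is not among the files shown in this diff. As a rough illustration of what such a schedule does, here is a minimal sketch using plain PyTorch; the model and optimizer are placeholders, not Emmental's actual wiring:

import torch

# Toy model and optimizer purely for illustration; not Emmental's setup.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Cosine annealing decays the lr from its initial value toward eta_min over
# T_max steps: lr_t = eta_min + (lr_0 - eta_min) * (1 + cos(pi * t / T_max)) / 2
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=0.0)

for step in range(100):
    optimizer.step()   # would normally follow loss.backward()
    scheduler.step()   # move the lr along the cosine curve

print(optimizer.param_groups[0]["lr"])   # close to eta_min after T_max steps
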
20 changes: 16 additions & 4 deletions src/emmental/metrics/accuracy.py
@@ -9,18 +9,20 @@
 def accuracy_scorer(
     golds: ndarray,
     probs: Optional[ndarray],
-    preds: ndarray,
+    preds: Optional[ndarray],
     uids: Optional[List[str]] = None,
     normalize: bool = True,
+    topk: int = 1,
 ) -> Dict[str, Union[float, int]]:
     r"""Accuracy classification score.
 
     Args:
       golds(ndarray): Ground truth values.
       probs(ndarray or None): Predicted probabilities.
-      preds(ndarray): Predicted values.
+      preds(ndarray or None): Predicted values.
       uids(list, optional): Unique ids, defaults to None.
       normalize(bool, optional): Normalize the results or not, defaults to True.
+      topk(int, optional): Top K accuracy, defaults to 1.
 
     Returns:
       dict: Accuracy, if normalize is True, return the fraction of correctly
@@ -32,7 +34,17 @@ def accuracy_scorer(
     if len(golds.shape) == 2:
         golds = prob_to_pred(golds)
 
+    if topk == 1 and preds is not None:
+        n_matches = np.where(golds == preds)[0].shape[0]
+    else:
+        topk_preds = probs.argsort(axis=1)[:, -topk:][:, ::-1]
+        n_matches = np.logical_or.reduce(
+            topk_preds == golds.reshape(-1, 1), axis=1
+        ).sum()
+
     if normalize:
-        return {"accuracy": np.where(golds == preds)[0].shape[0] / golds.shape[0]}
+        return {
+            "accuracy" if topk == 1 else f"accuracy@{topk}": n_matches / golds.shape[0]
+        }
     else:
-        return {"accuracy": np.where(golds == preds)[0].shape[0]}
+        return {"accuracy" if topk == 1 else f"accuracy@{topk}": n_matches}
10 changes: 8 additions & 2 deletions src/emmental/scorer.py
@@ -1,4 +1,5 @@
 import logging
+from functools import partial
 from typing import Callable, Dict, List
 
 from numpy import ndarray
@@ -26,9 +27,14 @@ def __init__(
     ) -> None:
         self.metrics: Dict[str, Callable] = dict()
         for metric in metrics:
-            if metric not in METRICS:
+            if metric in METRICS:
+                self.metrics[metric] = METRICS[metric]
+            elif metric.startswith("accuracy@"):
+                self.metrics[metric] = partial(
+                    METRICS["accuracy"], topk=int(metric.split("@")[1])
+                )
+            else:
                 raise ValueError(f"Unrecognized metric: {metric}")
-            self.metrics[metric] = METRICS[metric]
 
         self.metrics.update(customize_metric_funcs)
 
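
The branch added above is where the accuracy@k naming convention is resolved: any metric name of the form accuracy@<k> is routed to the accuracy scorer with topk pre-bound via functools.partial. A minimal standalone sketch of that dispatch pattern follows; METRICS here is a stand-in dict rather than the real registry referenced in scorer.py:

from functools import partial

def accuracy_scorer(golds, probs, preds, uids=None, normalize=True, topk=1):
    ...  # see the accuracy.py diff above for the real implementation

# Stand-in for the METRICS registry referenced in the diff.
METRICS = {"accuracy": accuracy_scorer}

def resolve(metric_name):
    if metric_name in METRICS:
        return METRICS[metric_name]
    elif metric_name.startswith("accuracy@"):
        # "accuracy@5" -> accuracy_scorer with topk=5 already bound.
        return partial(METRICS["accuracy"], topk=int(metric_name.split("@")[1]))
    raise ValueError(f"Unrecognized metric: {metric_name}")

top5 = resolve("accuracy@5")   # callable equivalent to accuracy_scorer(..., topk=5)
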
11 changes: 11 additions & 0 deletions tests/metrics/test_metrics.py
@@ -44,12 +44,23 @@ def test_accuracy(caplog):
     caplog.set_level(logging.INFO)
 
     golds = np.array([0, 1, 0, 1, 0, 1])
+    probs = np.array(
+        [[0.9, 0.1], [0.6, 0.4], [1.0, 0.0], [0.8, 0.2], [0.6, 0.4], [0.05, 0.95]]
+    )
     preds = np.array([0, 0, 0, 0, 0, 1])
 
     metric_dict = accuracy_scorer(golds, None, preds)
 
     assert isequal(metric_dict, {"accuracy": 0.6666666666666666})
 
+    metric_dict = accuracy_scorer(golds, probs, None)
+
+    assert isequal(metric_dict, {"accuracy": 0.6666666666666666})
+
+    metric_dict = accuracy_scorer(golds, probs, preds, topk=2)
+
+    assert isequal(metric_dict, {"accuracy@2": 1.0})
+
 
 def test_precision(caplog):
     """Unit test of precision_scorer"""
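
A note on the topk=2 assertion added above: the test uses two-class probabilities, so the top-2 set always contains both classes and accuracy@2 is trivially 1.0. A quick reproduction of just that computation with NumPy, using the same arrays as the test:

import numpy as np

golds = np.array([0, 1, 0, 1, 0, 1])
probs = np.array(
    [[0.9, 0.1], [0.6, 0.4], [1.0, 0.0], [0.8, 0.2], [0.6, 0.4], [0.05, 0.95]]
)

top2 = probs.argsort(axis=1)[:, -2:]   # with only two classes this is always both
hits = np.logical_or.reduce(top2 == golds.reshape(-1, 1), axis=1)
print(hits.mean())                     # 1.0
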
9 changes: 7 additions & 2 deletions tests/test_scorer.py
@@ -12,15 +12,20 @@ def test_scorer(caplog):
 
     golds = np.array([1, 0, 1, 0, 1, 0])
     preds = np.array([1, 1, 1, 1, 1, 0])
-    probs = np.array([0.8, 0.6, 0.9, 0.7, 0.7, 0.2])
+    probs = np.array(
+        [[0.2, 0.8], [0.4, 0.6], [0.1, 0.9], [0.3, 0.7], [0.3, 0.7], [0.8, 0.2]]
+    )
 
     def sum(gold, probs, preds, uids):
         return np.sum(preds)
 
-    scorer = Scorer(metrics=["accuracy", "f1"], customize_metric_funcs={"sum": sum})
+    scorer = Scorer(
+        metrics=["accuracy", "accuracy@2", "f1"], customize_metric_funcs={"sum": sum}
+    )
 
     assert scorer.score(golds, probs, preds) == {
         "accuracy": 0.6666666666666666,
+        "accuracy@2": 1.0,
         "f1": 0.7499999999999999,
         "sum": 5,
     }
