Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon authored and Borda committed Jun 16, 2020
1 parent 339814f commit 8d9f53a
Show file tree
Hide file tree
Showing 3 changed files with 222 additions and 57 deletions.
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
'sphinx.ext.linkcode',
'sphinx.ext.autosummary',
'sphinx.ext.napoleon',
# 'sphinx.ext.imgmath',
'sphinx.ext.imgmath',
'recommonmark',
'sphinx.ext.autosectionlabel',
# 'm2r',
Expand Down
86 changes: 36 additions & 50 deletions pytorch_lightning/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = Accuracy()
>>> metric(pred, target)
tensor([0.7500])
tensor(0.7500)
"""
super().__init__(name='accuracy',
Expand Down Expand Up @@ -111,14 +111,13 @@ def __init__(
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> pred = torch.tensor([0, 1, 2, 2])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = ConfusionMatrix()
>>> metric(pred, target)
tensor([[1., 0., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 1., 1.],
[0., 0., 0., 0.]])
tensor([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 2.]])
"""
super().__init__(name='confusion_matrix',
Expand Down Expand Up @@ -163,8 +162,11 @@ def __init__(
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = PrecisionRecall()
>>> metric(pred, target)
(tensor([0.3333, 0.0000, 0.0000, 1.0000]), tensor([1., 0., 0., 0.]), tensor([1., 2., 3.]))
>>> pr, rc, th = metric(pred, target)
>>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE
(tensor([0.3333, 0.0000, 0.0000, 1.0000]),
tensor([1., 0., 0., 0.]),
tensor([1., 2., 3.]))
"""
super().__init__(name='precision_recall_curve',
Expand Down Expand Up @@ -226,7 +228,7 @@ def __init__(
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = PrecisionRecall()
>>> metric(pred, target)
tensor(1.)
(tensor([0.3333, 0.0000, 0.0000, 1.0000]), tensor([1., 0., 0., 0.]), tensor([1., 2., 3.]))
"""
super().__init__(name='precision',
Expand Down Expand Up @@ -548,6 +550,7 @@ def __init__(
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = ROC()
>>> fp, tp, thresholds = metric(pred, target)
>>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE
(tensor([0.0000, 0.3333, 0.6667, 0.6667, 1.0000]),
tensor([0., 0., 0., 1., 1.]),
tensor([4., 3., 2., 1., 0.]))
Expand Down Expand Up @@ -607,25 +610,18 @@ def __init__(
Example:
.. testcode::
pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
[0.05, 0.85, 0.05, 0.05],
[0.05, 0.05, 0.85, 0.05],
[0.05, 0.05, 0.05, 0.85]])
target = torch.tensor([0, 1, 3, 2])
metric = MulticlassROC()
classes_roc = metric(pred, target)
Out:
.. testoutput::
>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
... [0.05, 0.85, 0.05, 0.05],
... [0.05, 0.05, 0.85, 0.05],
... [0.05, 0.05, 0.05, 0.85]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> metric = MulticlassROC()
>>> classes_roc = metric(pred, target)
>>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE
((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])))
"""
super().__init__(name='multiclass_roc',
reduce_group=reduce_group,
Expand Down Expand Up @@ -678,20 +674,14 @@ def __init__(
Example:
.. testcode::
pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
[0.05, 0.85, 0.05, 0.05],
[0.05, 0.05, 0.85, 0.05],
[0.05, 0.05, 0.05, 0.85]])
target = torch.tensor([0, 1, 3, 2])
metric = MulticlassPrecisionRecall()
classes_pr = metric(pred, target)
Out:
.. testoutput::
>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
... [0.05, 0.85, 0.05, 0.05],
... [0.05, 0.05, 0.85, 0.05],
... [0.05, 0.05, 0.05, 0.85]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> metric = MulticlassPrecisionRecall()
>>> classes_pr = metric(pred, target)
>>> metric(pred, target) # doctest: +NORMALIZE_WHITESPACE
((tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])),
(tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])),
(tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])),
Expand Down Expand Up @@ -756,18 +746,14 @@ def __init__(
.. testcode:
pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
[0.05, 0.85, 0.05, 0.05],
[0.05, 0.05, 0.85, 0.05],
[0.05, 0.05, 0.05, 0.85]])
target = torch.tensor([0, 1, 3, 2])
metric = DiceCoefficient()
classes_pr = metric(pred, target)
Out:
.. testoutput:
>>> pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
... [0.05, 0.85, 0.05, 0.05],
... [0.05, 0.05, 0.85, 0.05],
... [0.05, 0.05, 0.05, 0.85]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> metric = DiceCoefficient()
>>> classes_pr = metric(pred, target)
>>> metric(pred, target)
tensor(0.3333)
"""
super().__init__(name='dice',
Expand Down
Loading

0 comments on commit 8d9f53a

Please sign in to comment.