Skip to content

Commit

Permalink
fix docs
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicki Skafte authored and Borda committed Jun 16, 2020
1 parent 5fbda6e commit 339814f
Showing 1 changed file with 51 additions and 18 deletions.
69 changes: 51 additions & 18 deletions pytorch_lightning/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,12 +607,24 @@ def __init__(
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = MulticlassROC()
>>> classes_roc = metric(pred, target)
.. testcode::
pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
[0.05, 0.85, 0.05, 0.05],
[0.05, 0.05, 0.85, 0.05],
[0.05, 0.05, 0.05, 0.85]])
target = torch.tensor([0, 1, 3, 2])
metric = MulticlassROC()
classes_roc = metric(pred, target)
Out:
.. testoutput::
# TODO: fix bug - @nicki skafte
((tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0., 0., 1.]), tensor([0., 1., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])),
(tensor([0.0000, 0.3333, 1.0000]), tensor([0., 0., 1.]), tensor([1.8500, 0.8500, 0.0500])))
"""
super().__init__(name='multiclass_roc',
Expand All @@ -630,7 +642,7 @@ def forward(
Actual metric computation
Args:
pred: predicted labels
pred: predicted probability for each label
target: groundtruth labels
sample_weight: Weights for each sample defining the sample's impact on the score
Expand Down Expand Up @@ -666,12 +678,24 @@ def __init__(
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = MulticlassPrecisionRecall()
>>> classes_pr = metric(pred, target)
.. testcode::
pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
[0.05, 0.85, 0.05, 0.05],
[0.05, 0.05, 0.85, 0.05],
[0.05, 0.05, 0.05, 0.85]])
target = torch.tensor([0, 1, 3, 2])
metric = MulticlassPrecisionRecall()
classes_pr = metric(pred, target)
Out:
# TODO: fix bug - @nicki skafte
.. testoutput::
((tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])),
(tensor([1., 1.]), tensor([1., 0.]), tensor([0.8500])),
(tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])),
(tensor([0.2500, 0.0000, 1.0000]), tensor([1., 0., 0.]), tensor([0.0500, 0.8500])))
"""
super().__init__(name='multiclass_precision_recall_curve',
Expand All @@ -690,7 +714,7 @@ def forward(
Actual metric computation
Args:
pred: predicted labels
pred: predicted probability for each label
target: groundtruth labels
sample_weight: Weights for each sample defining the sample's impact on the score
Expand Down Expand Up @@ -730,12 +754,21 @@ def __init__(
Example:
>>> pred = torch.tensor([0, 1, 2, 3])
>>> target = torch.tensor([0, 1, 2, 2])
>>> metric = DiceCoefficient()
>>> classes_pr = metric(pred, target)
.. testcode::
pred = torch.tensor([[0.85, 0.05, 0.05, 0.05],
[0.05, 0.85, 0.05, 0.05],
[0.05, 0.05, 0.85, 0.05],
[0.05, 0.05, 0.05, 0.85]])
target = torch.tensor([0, 1, 3, 2])
metric = DiceCoefficient()
classes_pr = metric(pred, target)
# TODO: fix bug - @nicki skafte
Out:
.. testoutput::
tensor(0.3333)
"""
super().__init__(name='dice',
reduce_group=reduce_group,
Expand All @@ -751,7 +784,7 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Actual metric computation
Args:
pred: predicted labels
pred: predicted probability for each label
target: groundtruth labels
Return:
Expand Down

0 comments on commit 339814f

Please sign in to comment.