Skip to content

Commit

Permalink
Addressed comments from francisco
Browse files Browse the repository at this point in the history
  • Loading branch information
ravinkohli committed Feb 18, 2021
1 parent 0e20577 commit 26d5ac0
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 8 deletions.
8 changes: 4 additions & 4 deletions autoPyTorch/pipeline/components/training/losses.py
Expand Up @@ -14,14 +14,14 @@

losses = dict(classification=dict(
CrossEntropyLoss=dict(
module=CrossEntropyLoss, supported_output_types=(MULTICLASS, BINARY)),
module=CrossEntropyLoss, supported_output_types=[MULTICLASS, BINARY]),
BCEWithLogitsLoss=dict(
module=BCEWithLogitsLoss, supported_output_types=(BINARY,))),
module=BCEWithLogitsLoss, supported_output_types=[BINARY])),
regression=dict(
MSELoss=dict(
module=MSELoss, supported_output_types=(CONTINUOUS,)),
module=MSELoss, supported_output_types=[CONTINUOUS]),
L1Loss=dict(
module=L1Loss, supported_output_types=(CONTINUOUS,))))
module=L1Loss, supported_output_types=[CONTINUOUS])))

default_losses = dict(classification=CrossEntropyLoss, regression=MSELoss)

Expand Down
19 changes: 16 additions & 3 deletions test/conftest.py
Expand Up @@ -367,7 +367,12 @@ def loss_cross_entropy_multiclass():
predictions = torch.randn(4, 4, requires_grad=True)
name = 'CrossEntropyLoss'
targets = torch.empty(4, dtype=torch.long).random_(4)
labels = torch.empty(20, dtype=torch.long).random_(4)
# to ensure we have all classes in the labels
while True:
labels = torch.empty(20, dtype=torch.long).random_(4)
if len(torch.unique(labels)) == 4:
break

return dataset_properties, predictions, name, targets, labels


Expand All @@ -377,7 +382,11 @@ def loss_cross_entropy_binary():
predictions = torch.randn(4, 2, requires_grad=True)
name = 'CrossEntropyLoss'
targets = torch.empty(4, dtype=torch.long).random_(2)
labels = torch.empty(20, dtype=torch.long).random_(2)
# to ensure we have all classes in the labels
while True:
labels = torch.empty(20, dtype=torch.long).random_(2)
if len(torch.unique(labels)) == 2:
break
return dataset_properties, predictions, name, targets, labels


Expand All @@ -387,7 +396,11 @@ def loss_bce():
predictions = torch.empty(4).random_(2)
name = 'BCEWithLogitsLoss'
targets = torch.empty(4).random_(2)
labels = torch.empty(20, dtype=torch.long).random_(2)
# to ensure we have all classes in the labels
while True:
labels = torch.empty(20, dtype=torch.long).random_(2)
if len(torch.unique(labels)) == 2:
break
return dataset_properties, predictions, name, targets, labels


Expand Down
16 changes: 15 additions & 1 deletion test/test_pipeline/test_losses.py
Expand Up @@ -2,8 +2,9 @@

import torch
from torch import nn
from torch.nn.modules.loss import _Loss as Loss

from autoPyTorch.pipeline.components.training.losses import get_loss
from autoPyTorch.pipeline.components.training.losses import get_loss, losses
from autoPyTorch.utils.implementations import get_loss_weight_strategy


Expand Down Expand Up @@ -52,3 +53,16 @@ def test_losses(weighted, loss_details):
loss = loss() if weights is None else loss(**kwargs)
score = loss(predictions, targets)
assert isinstance(score, torch.Tensor)
# Ensure it is a one element tensor
assert len(score.size()) == 0


def test_loss_dict():
assert 'classification' in losses.keys()
assert 'regression' in losses.keys()
for task in losses.values():
for loss in task.values():
assert 'module' in loss.keys()
assert isinstance(loss['module'](), Loss)
assert 'supported_output_types' in loss.keys()
assert isinstance(loss['supported_output_types'], list)

0 comments on commit 26d5ac0

Please sign in to comment.