Skip to content

Commit

Permalink
Merge pull request #1646 from vlomonaco/master
Browse files Browse the repository at this point in the history
task_labels experience attribute type fix: now it is a list, not a set
  • Loading branch information
AntonioCarta authored May 23, 2024
2 parents 0d6f715 + 12a6a0f commit 8f0e61f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
5 changes: 3 additions & 2 deletions avalanche/benchmarks/scenarios/task_aware.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,10 @@ def with_task_labels(obj):

def _add_task_labels(exp):
tls = exp.dataset.targets_task_labels.uniques
# tls is a set, we need to convert to list to call __getitem__
tls = list(tls)
if len(tls) == 1:
# tls is a set. we need to convert to list to call __getitem__
exp.task_label = list(tls)[0]
exp.task_label = tls[0]
exp.task_labels = tls
return exp

Expand Down
3 changes: 2 additions & 1 deletion tests/benchmarks/scenarios/test_task_aware.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class TestsTaskAware(unittest.TestCase):
def test_taskaware(self):
"""Common use case: add tas labels to class-incremental benchmark."""
"""Common use case: add task labels to class-incremental benchmark."""
n_classes, n_samples_per_class, n_features = 10, 3, 7

for _ in range(10000):
Expand Down Expand Up @@ -58,6 +58,7 @@ def test_taskaware(self):
ci_train = bm_ci.train_stream
for eid, exp in enumerate(bm_ti.train_stream):
assert exp.task_label == eid
assert isinstance(exp.task_labels, list)
assert len(ci_train[eid].dataset) == len(exp.dataset)


Expand Down

0 comments on commit 8f0e61f

Please sign in to comment.