Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Offline-ER to work with collate_fn #390

Merged
merged 14 commits into from
Sep 22, 2023
Merged

Refactor Offline-ER to work with collate_fn #390

merged 14 commits into from
Sep 22, 2023

Conversation

wistuba
Copy link
Contributor

@wistuba wistuba commented Aug 23, 2023

Offline-ER applies collate_fn individually on new and memory data. This change will apply the collate function on the entire batch instead.

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

@github-actions
Copy link

github-actions bot commented Aug 23, 2023

Coverage report

Note

Coverage evolution disabled because this PR targets a different branch
than the default branch, for which coverage data is not available.

The coverage rate is 85.56%.

89.47% of new lines are covered.

Diff Coverage details (click to unfold)

src/renate/memory/buffer.py

100% of new lines are covered (94.02% of the complete file).

src/renate/utils/pytorch.py

100% of new lines are covered (96.11% of the complete file).

src/renate/updaters/learner.py

100% of new lines are covered (95.97% of the complete file).

src/renate/updaters/experimental/offline_er.py

62.5% of new lines are covered (82.19% of the complete file).
Missing lines: 75, 81, 82, 116, 126, 127

Args:
dataset_lengths: The length for the different datasets.
batch_sizes: Batch sizes used for specific datasets.
complete_dataset_iteration: Provide an index to indicate over which dataset to fully
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly rename?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestions?

else num_batches[self.complete_dataset_iteration]
)

def __iter__(self) -> Iterator[List[int]]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add comments about the exact logic?

yield [j for i in samples for j in i]
else:
iterators = [iter(sampler) for sampler in self.subset_samplers]
for s in iterators[self.complete_dataset_iteration]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this optimized? Nested for-loops for each batch seems like a lot.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is no nested loop for each batch. it is a single loop over each iterator. in case 1 this is hidden within zip but it also has a loop over each iterator and calls next.

Copy link
Contributor

@prabhuteja12 prabhuteja12 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you check this works with distributed training? That uses something like a DistributedSampler which also modifies the data to sample from.

data_start_idx = data_end_idx
self.length = (
min(num_batches)
if complete_dataset_iteration is None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not self.complete_dataset_iteration here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -156,3 +156,76 @@ def complementary_indices(num_outputs: int, valid_classes: Set[int]) -> List[int
valid_classes: A set of integers of valid classes.
"""
return [class_idx for class_idx in range(num_outputs) if class_idx not in valid_classes]


class ConcatRandomSampler(BatchSampler):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why inherit from BatchSampler?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed to Sampler

start_idx = data_start_idx + round(dataset_length / num_replicas * rank)
end_idx = data_start_idx + round(dataset_length / num_replicas * (rank + 1))
subset_sampler = BatchSampler(
SubsetRandomSampler(list(range(start_idx, end_idx)), generator),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why BatchSampler of SubsetRandomSampler?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BatchSampler creates batches, SubsetRandomSampler creates random ints from the provided list (List[int] vs int)



@pytest.mark.parametrize(
"complete_dataset_iteration,expected_batches", [[None, 2], [0, 7], [1, 5], [2, 2]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For None batches is 2 because 20//8 = 2?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. it is identical to the [2, 2] case

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So a drop_last is implicit?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. improved doc

@@ -11,6 +11,7 @@
from renate.memory.buffer import ReservoirBuffer
from renate.utils import pytorch
from renate.utils.pytorch import (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to add a DistributedSampler to a test?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a unit test for the distributed case instead

@wistuba wistuba merged commit f99cf17 into dev Sep 22, 2023
18 checks passed
@wistuba wistuba deleted the mw-offline-er-fix branch September 22, 2023 09:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants