Skip to content

Commit

Permalink
MAINT: make ValDataLoaderIter able to be reset only when it is
Browse files Browse the repository at this point in the history
acquired by the syntax of normal `iterator`

In `LRFinder.range_test()`, `val_iter` won't be reset after it runs
out of values, and it makes `LRFinder._validate()` failed to work
correctly after the first iteration of `range_test()`.

To fix it, we add a counter to count the times a `ValDataLoaderIter`
has run (i.e. times of `__next__()` is called). And reset it only
when its `__iter__()` is called. So that it won't be reset
automatically like the way `TrainDataLoaderIter` works.

See also davidtvs#59 and the docstring of `ValDataLoaderIter` for further
details.
  • Loading branch information
NaleRaphael committed Aug 4, 2020
1 parent 2c40a11 commit 34ad3ba
Showing 1 changed file with 33 additions and 1 deletion.
34 changes: 33 additions & 1 deletion torch_lr_finder/lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,39 @@ def __next__(self):


class ValDataLoaderIter(DataLoaderIter):
pass
"""This iterator will reset itself **only** when it is acquired by
the syntax of normal `iterator`. That is, this iterator just works
like a `torch.data.DataLoader`. If you want to restart it, you
should use it like:
```
loader_iter = ValDataLoaderIter(data_loader)
for batch in loader_iter:
...
# `loader_iter` should run out of values now, you can restart it by:
# 1. the way we use a `torch.data.DataLoader`
for batch in loader_iter: # __iter__ is called implicitly
...
# 2. passing it into `iter()` manually
loader_iter = iter(loader_iter) # __iter__ is called by `iter()`
```
"""
def __init__(self, data_loader, auto_reset=True):
super().__init__(data_loader)
self.run_limit = len(self.data_loader)
self.run_counter = 0

def __iter__(self):
if self.run_counter >= self.run_limit:
self._iterator = iter(self.data_loader)
self.run_counter = 0
return self

def __next__(self):
self.run_counter += 1
return super(ValDataLoaderIter, self).__next__()


class LRFinder(object):
Expand Down

0 comments on commit 34ad3ba

Please sign in to comment.