Skip to content

Commit

Permalink
TransformedLoader fix
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewilyas committed Mar 31, 2020
1 parent 7c1e88e commit 6347646
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
3 changes: 3 additions & 0 deletions docs/example_usage/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
CHANGELOG
=========

robustness 1.1.post2
- Critical fix in :meth:`robustness.loaders.TransformedLoader`, allow for data shuffling

robustness 1.1
''''''''''''''
- Added ability to superclass ImageNet to make
Expand Down
7 changes: 5 additions & 2 deletions robustness/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ def __getattr__(self, attr):
return getattr(self.data_loader, attr)

def TransformedLoader(loader, func, transforms, workers=None,
batch_size=None, do_tqdm=False, augment=False, fraction=1.0):
batch_size=None, do_tqdm=False, augment=False, fraction=1.0,
shuffle=True):
'''
This is a function that allows one to apply any given (fixed)
transformation to the output from the loader *once*.
Expand Down Expand Up @@ -214,6 +215,8 @@ def TransformedLoader(loader, func, transforms, workers=None,
fraction (float): fraction of image-label pairs in the output loader
which are transformed. The remainder is just original image-label
pairs from loader.
shuffle (bool) : whether or not the resulting loader should shuffle every
epoch (defaults to True)
Returns:
A loader and validation loader according to the
Expand Down Expand Up @@ -245,4 +248,4 @@ def TransformedLoader(loader, func, transforms, workers=None,

dataset = folder.TensorDataset(ch.cat(new_ims, 0), ch.cat(new_targs, 0), transform=transforms)
return ch.utils.data.DataLoader(dataset, num_workers=workers,
batch_size=batch_size)
batch_size=batch_size, shuffle=shuffle)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
# Versions should comply with PEP440. For a discussion on single-sourcing
# the version across setup.py and the project code, see
# https://packaging.python.org/en/latest/single_source_version.html
version='1.1.post1',
version='1.1.post2',

description='Tools for Robustness',
long_description=long_description,
Expand Down

0 comments on commit 6347646

Please sign in to comment.