Skip to content

Commit

Permalink
TransformedLoader fix
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewilyas committed Mar 31, 2020
1 parent 7c1e88e commit 1274b82
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions robustness/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ def __getattr__(self, attr):
return getattr(self.data_loader, attr)

def TransformedLoader(loader, func, transforms, workers=None,
batch_size=None, do_tqdm=False, augment=False, fraction=1.0):
batch_size=None, do_tqdm=False, augment=False, fraction=1.0,
shuffle=True):
'''
This is a function that allows one to apply any given (fixed)
transformation to the output from the loader *once*.
Expand Down Expand Up @@ -245,4 +246,4 @@ def TransformedLoader(loader, func, transforms, workers=None,

dataset = folder.TensorDataset(ch.cat(new_ims, 0), ch.cat(new_targs, 0), transform=transforms)
return ch.utils.data.DataLoader(dataset, num_workers=workers,
batch_size=batch_size)
batch_size=batch_size, shuffle=shuffle)

0 comments on commit 1274b82

Please sign in to comment.