Skip to content

Commit

Permalink
Merge pull request #309 from Dana-Farber-AIOS/dev-pannuke-dataloader-…
Browse files Browse the repository at this point in the history
…augs

don't augment test or valid splits for PanNuke
  • Loading branch information
jacob-rosenthal committed Apr 21, 2022
2 parents eb122b6 + 716cd55 commit 4be71f4
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions pathml/datasets/pannuke.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,15 @@ def __init__(
self.batch_size = batch_size
self.hovernet_preprocess = hovernet_preprocess

def _get_dataset(self, fold_ix):
def _get_dataset(self, fold_ix, augment=True):
if augment:
transforms = self.transforms
else:
transforms = None
return PanNukeDataset(
data_dir=self.data_dir,
fold_ix=fold_ix,
transforms=self.transforms,
transforms=transforms,
nucleus_type_labels=self.nucleus_type_labels,
hovernet_preprocess=self.hovernet_preprocess,
)
Expand Down Expand Up @@ -363,7 +367,7 @@ def train_dataloader(self):
Yields (image, mask, tissue_type), or (image, mask, hv, tissue_type) for HoVer-Net
"""
return data.DataLoader(
dataset=self._get_dataset(fold_ix=self.split),
dataset=self._get_dataset(fold_ix=self.split, augment=True),
batch_size=self.batch_size,
shuffle=self.shuffle,
pin_memory=True,
Expand All @@ -380,7 +384,7 @@ def valid_dataloader(self):
else:
fold_ix = 1
return data.DataLoader(
self._get_dataset(fold_ix=fold_ix),
self._get_dataset(fold_ix=fold_ix, augment=False),
batch_size=self.batch_size,
shuffle=self.shuffle,
pin_memory=True,
Expand All @@ -397,7 +401,7 @@ def test_dataloader(self):
else:
fold_ix = 1
return data.DataLoader(
self._get_dataset(fold_ix=fold_ix),
self._get_dataset(fold_ix=fold_ix, augment=False),
batch_size=self.batch_size,
shuffle=self.shuffle,
pin_memory=True,
Expand Down

0 comments on commit 4be71f4

Please sign in to comment.