diff --git a/pywick/datasets/MultiFolderDataset.py b/pywick/datasets/MultiFolderDataset.py index f6ada4f..34578ab 100644 --- a/pywick/datasets/MultiFolderDataset.py +++ b/pywick/datasets/MultiFolderDataset.py @@ -68,7 +68,9 @@ class MultiFolderDataset(FolderDataset): For semantic segmentation this is required so the default is a binary mask. However, if you want to turn off this feature then specify target_index_map=None """ -def __init__(self, + + + def __init__(self, roots, class_mode='label', class_to_idx=None, @@ -86,45 +88,45 @@ def __init__(self, exclusion_file=None, target_index_map=None): - # call the super constructor first, then set our own parameters - # super().__init__() - self.num_inputs = 1 # these are hardcoded for the fit module to work - self.num_targets = 1 # these are hardcoded for the fit module to work - - if default_loader == 'npy': - default_loader = npy_loader - elif default_loader == 'pil': - default_loader = pil_loader - self.default_loader = default_loader - - # separate loading for targets (e.g. for black/white masks) - self.target_loader = target_loader - - if class_to_idx: - self.classes = class_to_idx.keys() - self.class_to_idx = class_to_idx - else: - self.classes, self.class_to_idx = _find_classes(roots) - - data_list = list() - for root in roots: - datai, _ = _finds_inputs_and_targets(root, class_mode=class_mode, class_to_idx=self.class_to_idx, input_regex=input_regex, - rel_target_root=rel_target_root, target_prefix=target_prefix, target_postfix=target_postfix, - target_extension=target_extension, exclusion_file=exclusion_file) - data_list.append(datai) - - self.data = list(itertools.chain.from_iterable(data_list)) - - if len(self.data) == 0: - raise (RuntimeError('Found 0 data items in subfolders of: {}'.format(roots))) - else: - print('Found %i data items' % len(self.data)) - - self.roots = [os.path.expanduser(x) for x in roots] - self.transform = transform - self.target_transform = target_transform - self.co_transform = co_transform - self.apply_co_transform_first = apply_co_transform_first - self.target_index_map = target_index_map - - self.class_mode = class_mode + # call the super constructor first, then set our own parameters + # super().__init__() + self.num_inputs = 1 # these are hardcoded for the fit module to work + self.num_targets = 1 # these are hardcoded for the fit module to work + + if default_loader == 'npy': + default_loader = npy_loader + elif default_loader == 'pil': + default_loader = pil_loader + self.default_loader = default_loader + + # separate loading for targets (e.g. for black/white masks) + self.target_loader = target_loader + + if class_to_idx: + self.classes = class_to_idx.keys() + self.class_to_idx = class_to_idx + else: + self.classes, self.class_to_idx = _find_classes(roots) + + data_list = list() + for root in roots: + datai, _ = _finds_inputs_and_targets(root, class_mode=class_mode, class_to_idx=self.class_to_idx, input_regex=input_regex, + rel_target_root=rel_target_root, target_prefix=target_prefix, target_postfix=target_postfix, + target_extension=target_extension, exclusion_file=exclusion_file) + data_list.append(datai) + + self.data = list(itertools.chain.from_iterable(data_list)) + + if len(self.data) == 0: + raise (RuntimeError('Found 0 data items in subfolders of: {}'.format(roots))) + else: + print('Found %i data items' % len(self.data)) + + self.roots = [os.path.expanduser(x) for x in roots] + self.transform = transform + self.target_transform = target_transform + self.co_transform = co_transform + self.apply_co_transform_first = apply_co_transform_first + self.target_index_map = target_index_map + + self.class_mode = class_mode