From ce1ea5de2e8d3d462ae9e777a20c078b6abc67b8 Mon Sep 17 00:00:00 2001 From: ain-soph Date: Wed, 30 Dec 2020 17:18:47 -0500 Subject: [PATCH] fix ZipFolder --- trojanvision/datasets/imagefolder.py | 26 ++++++++++++--------- trojanvision/utils/data.py | 34 +++++++++++++++++++++++++--- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/trojanvision/datasets/imagefolder.py b/trojanvision/datasets/imagefolder.py index 18dd933d..19637177 100644 --- a/trojanvision/datasets/imagefolder.py +++ b/trojanvision/datasets/imagefolder.py @@ -92,19 +92,22 @@ def initialize_folder(self, verbose: bool = True, img_type: str = '.jpg', **kwar except FileNotFoundError: pass - def initialize_zip(self, **kwargs): - print('{yellow}initialize zip{reset}'.format(**ansi)) - mode_list: list[str] = ['train', 'valid'] if self.valid_set else ['train'] + def initialize_zip(self, mode_list: list[str] = ['train', 'valid'], **kwargs): + if not self.valid_set: + mode_list.remove('valid') for mode in mode_list: - src_path = os.path.normpath(os.path.join(self.folder_path, mode)) dst_path = os.path.join(self.folder_path, f'{self.name}_{mode}_store.zip') - with open(zipfile.ZipFile(dst_path, mode='w', compression=zipfile.ZIP_STORED)) as zf: - for root, dirs, files in os.walk(src_path): - _dir = root.removeprefix(os.path.normpath(self.folder_path, '')) - for _file in files: - org_path = os.path.join(root, _file) - zip_path = os.path.join(_dir, _file) - zf.write(org_path, zip_path) + if not os.path.exists(dst_path): + print('{yellow}initialize zip{reset}: '.format(**ansi), dst_path) + src_path = os.path.normpath(os.path.join(self.folder_path, mode)) + with zipfile.ZipFile(dst_path, mode='w', compression=zipfile.ZIP_STORED) as zf: + for root, dirs, files in os.walk(src_path): + _dir = root.removeprefix(os.path.join(self.folder_path, '')) + for _file in files: + org_path = os.path.join(root, _file) + zip_path = os.path.join(_dir, _file) + zf.write(org_path, zip_path) + print('{green}initialize zip finish{reset}'.format(**ansi)) def initialize_npz(self, mode_list: list[str] = ['train', 'valid'], transform: transforms.Lambda = transforms.Lambda(lambda x: np.array(x)), @@ -122,6 +125,7 @@ def initialize_npz(self, mode_list: list[str] = ['train', 'valid'], np.savez(npz_path, data=data, targets=targets) with open(json_path, 'w') as f: json.dump(dataset.class_to_idx, f) + print('{green}initialize npz finish{reset}: '.format(**ansi)) def get_org_dataset(self, mode: str, transform: Union[str, object] = 'default', data_format: str = None, **kwargs) -> Union[datasets.ImageFolder, MemoryDataset]: diff --git a/trojanvision/utils/data.py b/trojanvision/utils/data.py index 92262f4d..50ecebcf 100644 --- a/trojanvision/utils/data.py +++ b/trojanvision/utils/data.py @@ -54,12 +54,35 @@ class ZipFolder(DatasetFolder): def __init__(self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, is_valid_file: Optional[Callable[[str], bool]] = None) -> None: if not root.endswith('.zip'): - raise TypeError("Need to ZIP file for data source: ", self.root) - self.root_zip = ZipLookup(os.path.realpath(self.root)) + raise TypeError("Need to ZIP file for data source: ", root) + self.root_zip = ZipLookup(os.path.realpath(root)) super().__init__(root, self.zip_loader, IMG_EXTENSIONS if is_valid_file is None else None, transform=transform, target_transform=target_transform, is_valid_file=is_valid_file) self.imgs = self.samples + def make_dataset( + self, + directory: str, + class_to_idx: dict[str, int], + extensions: Optional[tuple[str, ...]] = None, + is_valid_file: Optional[Callable[[str], bool]] = None, + ) -> list[tuple[str, int]]: + instances = [] + both_none = extensions is None and is_valid_file is None + both_something = extensions is not None and is_valid_file is not None + if both_none or both_something: + raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") + if extensions is not None: + def is_valid_file(x: str) -> bool: + return has_file_allowed_extension(x, cast(tuple[str, ...], extensions)) + is_valid_file = cast(Callable[[str], bool], is_valid_file) + for filepath in self.root_zip.keys(): + if is_valid_file(filepath): + _, target_class = os.path.split(os.path.dirname(filepath)) + item = filepath, class_to_idx[target_class] + instances.append(item) + return instances + def zip_loader(self, path) -> Any: f = self.root_zip[path] if get_image_backend() == 'accimage': @@ -80,7 +103,12 @@ def _find_classes(self, *args, **kwargs): Ensures: No class is a subdirectory of another. """ - classes = list({path.split('')[-2] for path in self.root_zip.keys() if '/' in path}) + classes = set() + for filepath in self.root_zip.keys(): + root, target_class = os.path.split(os.path.dirname(filepath)) + if root: + classes.add(target_class) + classes = list(classes) classes.sort() class_to_idx = {classes[i]: i for i in range(len(classes))} return classes, class_to_idx