Skip to content

Commit

Permalink
fix ZipFolder
Browse files Browse the repository at this point in the history
  • Loading branch information
ain-soph committed Dec 30, 2020
1 parent 89b7653 commit ce1ea5d
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 14 deletions.
26 changes: 15 additions & 11 deletions trojanvision/datasets/imagefolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,19 +92,22 @@ def initialize_folder(self, verbose: bool = True, img_type: str = '.jpg', **kwar
except FileNotFoundError:
pass

def initialize_zip(self, **kwargs):
print('{yellow}initialize zip{reset}'.format(**ansi))
mode_list: list[str] = ['train', 'valid'] if self.valid_set else ['train']
def initialize_zip(self, mode_list: list[str] = ['train', 'valid'], **kwargs):
if not self.valid_set:
mode_list.remove('valid')
for mode in mode_list:
src_path = os.path.normpath(os.path.join(self.folder_path, mode))
dst_path = os.path.join(self.folder_path, f'{self.name}_{mode}_store.zip')
with open(zipfile.ZipFile(dst_path, mode='w', compression=zipfile.ZIP_STORED)) as zf:
for root, dirs, files in os.walk(src_path):
_dir = root.removeprefix(os.path.normpath(self.folder_path, ''))
for _file in files:
org_path = os.path.join(root, _file)
zip_path = os.path.join(_dir, _file)
zf.write(org_path, zip_path)
if not os.path.exists(dst_path):
print('{yellow}initialize zip{reset}: '.format(**ansi), dst_path)
src_path = os.path.normpath(os.path.join(self.folder_path, mode))
with zipfile.ZipFile(dst_path, mode='w', compression=zipfile.ZIP_STORED) as zf:
for root, dirs, files in os.walk(src_path):
_dir = root.removeprefix(os.path.join(self.folder_path, ''))
for _file in files:
org_path = os.path.join(root, _file)
zip_path = os.path.join(_dir, _file)
zf.write(org_path, zip_path)
print('{green}initialize zip finish{reset}'.format(**ansi))

def initialize_npz(self, mode_list: list[str] = ['train', 'valid'],
transform: transforms.Lambda = transforms.Lambda(lambda x: np.array(x)),
Expand All @@ -122,6 +125,7 @@ def initialize_npz(self, mode_list: list[str] = ['train', 'valid'],
np.savez(npz_path, data=data, targets=targets)
with open(json_path, 'w') as f:
json.dump(dataset.class_to_idx, f)
print('{green}initialize npz finish{reset}: '.format(**ansi))

def get_org_dataset(self, mode: str, transform: Union[str, object] = 'default',
data_format: str = None, **kwargs) -> Union[datasets.ImageFolder, MemoryDataset]:
Expand Down
34 changes: 31 additions & 3 deletions trojanvision/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,35 @@ class ZipFolder(DatasetFolder):
def __init__(self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
is_valid_file: Optional[Callable[[str], bool]] = None) -> None:
if not root.endswith('.zip'):
raise TypeError("Need to ZIP file for data source: ", self.root)
self.root_zip = ZipLookup(os.path.realpath(self.root))
raise TypeError("Need to ZIP file for data source: ", root)
self.root_zip = ZipLookup(os.path.realpath(root))
super().__init__(root, self.zip_loader, IMG_EXTENSIONS if is_valid_file is None else None,
transform=transform, target_transform=target_transform, is_valid_file=is_valid_file)
self.imgs = self.samples

def make_dataset(
self,
directory: str,
class_to_idx: dict[str, int],
extensions: Optional[tuple[str, ...]] = None,
is_valid_file: Optional[Callable[[str], bool]] = None,
) -> list[tuple[str, int]]:
instances = []
both_none = extensions is None and is_valid_file is None
both_something = extensions is not None and is_valid_file is not None
if both_none or both_something:
raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
if extensions is not None:
def is_valid_file(x: str) -> bool:
return has_file_allowed_extension(x, cast(tuple[str, ...], extensions))
is_valid_file = cast(Callable[[str], bool], is_valid_file)
for filepath in self.root_zip.keys():
if is_valid_file(filepath):
_, target_class = os.path.split(os.path.dirname(filepath))
item = filepath, class_to_idx[target_class]
instances.append(item)
return instances

def zip_loader(self, path) -> Any:
f = self.root_zip[path]
if get_image_backend() == 'accimage':
Expand All @@ -80,7 +103,12 @@ def _find_classes(self, *args, **kwargs):
Ensures:
No class is a subdirectory of another.
"""
classes = list({path.split('')[-2] for path in self.root_zip.keys() if '/' in path})
classes = set()
for filepath in self.root_zip.keys():
root, target_class = os.path.split(os.path.dirname(filepath))
if root:
classes.add(target_class)
classes = list(classes)
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
return classes, class_to_idx
Expand Down

0 comments on commit ce1ea5d

Please sign in to comment.