Formatting and fixes (#187)
Signed-off-by: TommyX12 <tommyx058@gmail.com>
TommyX12 committed Mar 27, 2020
1 parent fa17f79 commit 6cff805
Showing 2 changed files with 39 additions and 16 deletions.
43 changes: 31 additions & 12 deletions kaolin/datasets/base.py
@@ -21,14 +21,16 @@

from kaolin import helpers


def _preprocess_task(args):
torch.set_num_threads(1)
with torch.no_grad():
idx, get_data, get_attributes, cache_transform = args
name = get_attributes(idx)['name']
if name not in cache_transform.cached_ids:
data = get_data(idx)
cache_transform(name, *data)
cache_transform(name, data)
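
The worker above is deliberately a module-level function: `multiprocessing.Pool` (used further down in `__init__`) must pickle the task callable, and locally defined functions cannot be pickled; `torch.set_num_threads(1)` also keeps each worker process from oversubscribing CPU cores. A minimal, self-contained sketch of that pattern, with a hypothetical `load_item` standing in for the dataset getters:

```python
# Minimal sketch of the top-level-worker pattern; load_item is a hypothetical
# stand-in for the expensive per-index preprocessing.
from multiprocessing import Pool

def load_item(idx):
    return idx * idx  # placeholder for real preprocessing work

def _task(args):  # must live at module level so Pool can pickle it
    idx, get_data = args
    return idx, get_data(idx)

if __name__ == '__main__':
    results = {}
    with Pool(4) as p:
        for idx, value in p.imap_unordered(_task, [(i, load_item) for i in range(8)]):
            results[idx] = value  # arrival order is not deterministic
    print(sorted(results.items()))
```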


class KaolinDatasetMeta(type):
def __new__(metacls, cls_name, base_cls, class_dict):
@@ -45,12 +47,13 @@ def __new__(metacls, cls_name, base_cls, class_dict):
no_progress (bool): disable tqdm progress bar for preprocessing."""
return type.__new__(metacls, cls_name, base_cls, class_dict)
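
`KaolinDatasetMeta.__new__` (its body is mostly collapsed in this view) appends a block of shared argument documentation, ending with the `no_progress` line above, to every subclass docstring. A minimal sketch of that docstring-appending idea, with an invented suffix:

```python
# Sketch of a docstring-appending metaclass; the suffix text is invented,
# standing in for the shared argument docs appended above.
class DocAppendMeta(type):
    SUFFIX = "\n\nShared args:\n    no_progress (bool): disable the progress bar."

    def __new__(metacls, cls_name, base_cls, class_dict):
        class_dict['__doc__'] = (class_dict.get('__doc__') or '') + metacls.SUFFIX
        return type.__new__(metacls, cls_name, base_cls, class_dict)


class Example(metaclass=DocAppendMeta):
    """An example dataset."""


print(Example.__doc__)  # original docstring plus the appended suffix
```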


class KaolinDataset(Dataset, metaclass=KaolinDatasetMeta):
"""
Abstract class for datasets, with support for multiprocess or CUDA preprocessing.
A KaolinDataset child class needs to implement the following:
1) _initialize:
1) initialize:
Initialization function called at the beginning of the constructor.
2) _get_data:
Data getter that will be preprocessed => cached => transformed; takes an index as input.
@@ -59,6 +62,7 @@ class KaolinDataset(Dataset, metaclass=KaolinDatasetMeta):
4) __len__:
Returns the size of the dataset.
"""

def __init__(self, *args, preprocessing_transform=None, preprocessing_params: dict = None,
transform=None, no_progress: bool = False, **kwargs):
"""
@@ -75,42 +79,55 @@ def __init__(self, *args, preprocessing_transform=None, preprocessing_params: di
"""
self.initialize(*args, **kwargs)
if preprocessing_transform is not None:
desc = 'applying preprocessing'
desc = 'Applying preprocessing'
if preprocessing_params is None:
preprocessing_params = {}
assert preprocessing_params.get('cache_dir') is not None

cache_dir = preprocessing_params.get('cache_dir')
assert cache_dir is not None, 'Cache directory is not given'

self.cache_convert = helpers.Cache(
preprocessing_transform, preprocessing_params['cache_dir'],
cache_key=helpers._get_hash(repr(preprocessing_transform)))
if preprocessing_params.get('use_cuda') is None:
preprocessing_params['use_cuda'] = False
preprocessing_transform,
cache_dir=cache_dir,
cache_key=helpers._get_hash(repr(preprocessing_transform))
)
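
Keying the cache on `helpers._get_hash(repr(preprocessing_transform))` means that changing the preprocessing transform (and hence its `repr`) switches to a fresh cache directory instead of silently reusing stale results. `helpers.Cache` itself is not part of this diff; a hypothetical sketch of a disk cache with the same call shape:

```python
# Hypothetical sketch of a Cache-like helper (the real helpers.Cache is not
# shown in this diff): write on (name, data), read back on (name,).
import hashlib
from pathlib import Path

import torch


class DiskCache:
    def __init__(self, transform, cache_dir, cache_key):
        self.transform = transform
        self.dir = Path(cache_dir) / cache_key
        self.dir.mkdir(parents=True, exist_ok=True)
        self.cached_ids = {p.stem for p in self.dir.glob('*.pt')}

    def __call__(self, name, data=None):
        path = self.dir / f'{name}.pt'
        if data is not None:           # preprocess once and persist
            torch.save(self.transform(data), path)
            self.cached_ids.add(name)
        return torch.load(path)        # subsequent calls hit the disk cache


def _get_hash(s):
    return hashlib.md5(s.encode()).hexdigest()  # stable key from repr() text
```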

use_cuda = preprocessing_params.get('use_cuda', False)

num_workers = preprocessing_params.get('num_workers')

if num_workers == 0:
with torch.no_grad():
for idx in tqdm(range(len(self)), desc=desc, disable=no_progress):
name = self._get_attributes(idx)['name']
if name not in self.cache_convert.cached_ids:
data = self._get_data(idx)
self.cache_convert(name, *data)
self.cache_convert(name, data)

else:
p = Pool(num_workers)
iterator = p.imap_unordered(
_preprocess_task,
[(idx, self._get_data, self._get_attributes, self.cache_convert)
for idx in range(len(self))])

for i in tqdm(range(len(self)), desc=desc, disable=no_progress):
next(iterator)

else:
self.cache_convert = None

self.transform = transform

def __getitem__(self, index):
"""Returns the item at index idx. """
attributes = self._get_attributes(index)
data = (self.cache_convert(attributes['name']) if self.cache_convert is not None else
self._get_data(index))
data = (self._get_data(index) if self.cache_convert is None else
self.cache_convert(attributes['name']))

if self.transform is not None:
data = self.transform(data)

return {'data': data, 'attributes': attributes}
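
With the reordered conditional above, data comes from `_get_data` when no preprocessing cache was configured and from the cache otherwise; either way `__getitem__` returns a `{'data': ..., 'attributes': ...}` dict. A usage sketch reusing the hypothetical `SquaresDataset` from earlier:

```python
# Usage sketch of the __getitem__ contract, reusing the invented
# SquaresDataset toy class; no preprocessing cache is configured here.
dataset = SquaresDataset(10, transform=lambda x: x + 1)
item = dataset[3]
print(item['attributes'])  # {'name': '3'}
print(item['data'])        # 10: transform applied to _get_data(3) == 9
```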

@abstractmethod
@@ -129,6 +146,7 @@ def _get_data(self, index):
def __len__(self):
pass


class CombinationDataset(KaolinDataset):
"""Dataset combining a list of datasets into a unified dataset object.
Useful when multiple output representations are needed from a common base representation
@@ -139,7 +157,8 @@ class CombinationDataset(KaolinDataset):
Args:
datasets: list or tuple of KaolinDataset
"""
def _initialize(self, datasets):

def initialize(self, datasets):
self.len = len(datasets[0])
for i, d in enumerate(datasets):
assert len(d) == self.len, \
12 changes: 8 additions & 4 deletions kaolin/datasets/shapenet.py
@@ -110,6 +110,7 @@ class ShapeNet_Meshes(data.Dataset):
>>> obj['data']['faces'].shape
torch.Size([1910, 3])
"""

def __init__(self, root: str, categories: list = ['chair'], train: bool = True,
split: float = .7, no_progress: bool = False):
self.root = Path(root)
@@ -227,7 +228,7 @@ def _get_data(self, index):
synset_idx = self.synset_idxs[index]
obj_location = self.paths[index] / 'model.obj'
mesh = TriangleMesh.from_obj(str(obj_location))
return (mesh,)
return mesh
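
This return-type change pairs with the `cache_transform(name, data)` / `self.cache_convert(name, data)` edits in `base.py`: `_get_data` now returns the mesh itself rather than a one-element tuple, so call sites no longer star-unpack it. A small sketch of the two calling conventions:

```python
# Sketch of the calling-convention change: passing the object directly vs.
# star-unpacking a one-element tuple; 'mesh-object' stands in for a mesh.
def cache_transform(name, data):
    print(name, data)

mesh = 'mesh-object'
cache_transform('chair', mesh)      # new style: _get_data returns mesh
cache_transform('chair', *(mesh,))  # old style: _get_data returned (mesh,)
```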

def _get_attributes(self, index):
synset_idx = self.synset_idxs[index]
@@ -239,6 +240,7 @@ def _get_attributes(self, index):
}
return attributes


class ShapeNet_Images(data.Dataset):
r"""ShapeNet Dataset class for images.
@@ -278,8 +280,8 @@ class ShapeNet_Images(data.Dataset):
torch.Size([10, 4, 137, 137])
"""

def __init__(self, root: str, categories: list=['chair'], train: bool=True,
split: float=.7, views: int=24, transform=None):
def __init__(self, root: str, categories: list = ['chair'], train: bool = True,
split: float = .7, views: int = 24, transform=None):
self.root = Path(root)
self.synsets = _convert_categories(categories)
self.labels = [synset_to_label[s] for s in self.synsets]
@@ -379,6 +381,7 @@ class ShapeNet_Voxels(data.Dataset):
torch.Size([10, 128, 128, 128])
"""

def __init__(self, root: str, cache_dir: str, categories: list = ['chair'], train: bool = True,
split: float = .7, resolutions=[128, 32], no_progress: bool = False):
self.root = Path(root)
@@ -504,7 +507,7 @@ def __init__(self, root: str, cache_dir: str, categories: list = ['chair'], trai

def convert(og_mesh, voxel):
transforms = tfs.Compose([mesh_conversion,
tfs.MeshLaplacianSmoothing(smoothing_iterations)])
tfs.MeshLaplacianSmoothing(smoothing_iterations)])

new_mesh = transforms(voxel)
new_mesh.vertices = pcfunc.realign(new_mesh.vertices, og_mesh.vertices)
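
The indentation fix above just realigns the continuation line inside `tfs.Compose([...])`. Compose-style containers apply a list of callables in order; a generic sketch of the idea (not kaolin's own `tfs.Compose`, whose implementation is outside this diff):

```python
# Generic sketch of a Compose-style transform container; not kaolin's
# tfs.Compose, which this diff does not show.
class Compose:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x):
        for transform in self.transforms:  # apply each step in sequence
            x = transform(x)
        return x


pipeline = Compose([lambda v: v * 2, lambda v: v - 1])
print(pipeline(3))  # (3 * 2) - 1 == 5
```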
@@ -797,6 +800,7 @@ class ShapeNet_Tags(data.Dataset):
torch.Size([10, N])
"""

def __init__(self, dataset, tag_aug=True):
self.root = dataset.root
self.paths = dataset.paths
