Skip to content

Commit

Permalink
Refactor SHREC16 to inherit from KaolinDataset (#190)
Browse files Browse the repository at this point in the history
* Refactor SHREC16 to inherit from KaolinDataset

Signed-off-by: Krishna Murthy <krrish94@gmail.com>

* Fix docstrings, default behavior

Signed-off-by: Krishna Murthy <krrish94@gmail.com>

* Address review comments

Signed-off-by: Krishna Murthy <krrish94@gmail.com>

* Add smoketest

Signed-off-by: Krishna Murthy <krrish94@gmail.com>

* Update SHREC16_ROOT

Signed-off-by: Krishna Murthy <krrish94@gmail.com>

* Update arguments

Signed-off-by: Krishna Murthy <krrish94@gmail.com>
  • Loading branch information
krrish94 committed Apr 19, 2020
1 parent e073a39 commit 4063413
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 71 deletions.
1 change: 0 additions & 1 deletion .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ exclude = .git, tests/, build/,
kaolin/cuda,
kaolin/datasets/scannet.py,
kaolin/datasets/shapenet.py,
kaolin/datasets/shrec.py,
kaolin/datasets/usdfile.py,
kaolin/engine,
kaolin/graphics/dib_renderer,
Expand Down
164 changes: 108 additions & 56 deletions kaolin/datasets/shrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from typing import Iterable

import os
import glob
from torch.utils.data import Dataset


from kaolin.rep import TriangleMesh
from ..rep import TriangleMesh

from .base import KaolinDataset

class SHREC16(Dataset):
r"""Class to help in loading the SHREC16 dataset.

SHREC16 is the dataset used for the "Large-scale 3D shape retrieval
class SHREC16(KaolinDataset):
r"""Dataset class for SHREC16, used for the "Large-scale 3D shape retrieval
from ShapeNet Core55" contest at Eurographics 2016.
More details about the challenge and the dataset are available
Expand All @@ -34,71 +33,124 @@ class SHREC16(Dataset):
root (str): Path to the root directory of the dataset.
categories (list): List of categories to load (each class is
specified as a string, and must be a valid `SHREC16`
category).
mode (str, choices=['train', 'test']): Whether to load the
'train' split or the 'test' split
category). If this argument is not specified, all categories
are loaded by default.
train (bool): If True, return the train split, else return the test
split (default: True).
Returns:
dict: Dictionary with keys: 'vertices' : vertices , 'faces' : faces
.. code-block::
dict: {
attributes: {path: str, category: str, label: int},
data: kaolin.rep.TriangleMesh
}
path: The filepath to the .obj file on disk.
category: A human-readable string describing the loaded sample.
label: An integer (in the range :math:`[0, \text{len(categories)}]`)
and can be used for training classifiers for example.
vertices: Vertices of the loaded mesh (:math:`(*, 3)`), where :math:`*`
indicates a positive integer.
faces: Faces of the loaded mesh (:math:`(*, 3)`), where :math:`*`
indicates a positive integer.
Example:
>>> dataset = SHREC16(root='/path/to/SHREC16/', categories=['alien', 'ants'], train=False)
>>> sample = dataset[0]
>>> sample["attributes"]["path"]
/path/to/SHREC16/alien/test/T411.obj
>>> sample["attributes"]["category"]
alien
>>> sample["attributes"]["label"]
0
>>> sample["data"].vertices.shape
torch.Size([252, 3])
>>> sample["data"].faces.shape
torch.Size([500, 3])
"""

def __init__(self, root: str, categories: list = ['alien'],
mode: list = 'train'):

super(SHREC16, self).__init__()

if mode not in ['train', 'test']:
raise ValueError('Argument \'mode\' must be one of \'train\''
'or \'test\'. Got {0} instead.'.format(mode))

VALID_CATEGORIES = [
'alien', 'ants', 'armadillo', 'bird1', 'bird2', 'camel',
'cat', 'centaur', 'dinosaur', 'dino_ske', 'dog1', 'dog2',
'flamingo', 'glasses', 'gorilla', 'hand', 'horse', 'lamp',
'laptop', 'man', 'myScissor', 'octopus', 'pliers', 'rabbit',
'santa', 'shark', 'snake', 'spiders', 'two_balls', 'woman'
]

_VALID_CATEGORIES = [
"alien",
"ants",
"armadillo",
"bird1",
"bird2",
"camel",
"cat",
"centaur",
"dinosaur",
"dino_ske",
"dog1",
"dog2",
"flamingo",
"glasses",
"gorilla",
"hand",
"horse",
"lamp",
"laptop",
"man",
"myScissor",
"octopus",
"pliers",
"rabbit",
"santa",
"shark",
"snake",
"spiders",
"two_balls",
"woman",
]

def initialize(
self,
root: str,
categories: Iterable = None,
train: bool = True,
):

if not categories:
categories = SHREC16._VALID_CATEGORIES
for category in categories:
if category not in VALID_CATEGORIES:
raise ValueError(f'Specified category {category} is not valid. '
'Valid categories are {VALID_CATEGORIES}')
if category not in SHREC16._VALID_CATEGORIES:
raise ValueError(
f"Specified category {category} is not valid. "
f"Valid categories are {SHREC16._VALID_CATEGORIES}"
)

self.mode = mode
self.root = root
self.categories = categories
self.categories_to_load = categories
self.train = train
self.num_samples = 0
self.paths = []
self.categories = []
for cl in self.categories:
clsdir = os.path.join(root, cl, self.mode)
cur = glob.glob(clsdir + '/*.obj')
self.category_names = []
self.labels = []
for i, cl in enumerate(self.categories_to_load):
clsdir = os.path.join(root, cl, "train" if self.train else "test")
cur = glob.glob(clsdir + "/*.obj")

self.paths = self.paths + cur
self.categories += [cl] * len(cur)
self.category_names += [cl] * len(cur)
self.labels += [i] * len(cur)
self.num_samples += len(cur)
if len(cur) == 0:
raise RuntimeWarning('No .obj files could be read '
f'for category \'{cl}\'. Skipping...')
raise RuntimeWarning(
"No .obj files could be read " f"for category '{cl}'. Skipping..."
)

def __len__(self):
"""Returns the length of the dataset. """
return self.num_samples

def __getitem__(self, idx):
"""Returns the sample at index idx. """

# Read in the list of vertices and faces
# from the obj file.
def _get_data(self, idx):
obj_location = self.paths[idx]
mesh = TriangleMesh.from_obj(obj_location)
category = self.categories[idx]
# Return these tensors as a dictionary.
data = dict()
attributes = dict()
data['vertices'] = mesh.vertices
data['faces'] = mesh.faces
attributes['rep'] = 'Mesh'
attributes['name'] = obj_location
attributes['class']: cagetory
return {'attributes': attributes, 'data': data}
return mesh

def _get_attributes(self, idx):
attributes = {
"path": self.paths[idx],
"category": self.category_names[idx],
"label": self.labels[idx],
}
return attributes
26 changes: 12 additions & 14 deletions tests/datasets/test_shrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest

import torch
import sys
import os
import shutil

import pytest

import kaolin as kal
from torch.utils.data import DataLoader

SHREC16_ROOT = "/data/SHREC16/"
CACHE_DIR = "tests/datasets/cache"

# Tests below can only be run is a ShapeNet dataset is available

# def test_SHREC16():

# shreck = kal.dataloader.SHREC16(root = 'tests/datasets_eval/shrec_16/', categories = ['alien', 'ants'], mode = 'train')
# for obj in shreck:
# assert obj['verts'].shape[0] > 0
# assert obj['faces'].shape[0] > 0
REASON = "SHREC16 not found at default location: {}".format(SHREC16_ROOT)


# test_SHREC16()
@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.skipif(not os.path.exists(SHREC16_ROOT), reason=REASON)
def test_SHREC16(device):
models = kal.datasets.SHREC16(
root=SHREC16_ROOT, categories=["ants"], train=False
)
assert len(models) == 4

0 comments on commit 4063413

Please sign in to comment.