-
Notifications
You must be signed in to change notification settings - Fork 1
/
datasets.py
executable file
·102 lines (86 loc) · 3.36 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import torch
import json
import codecs
import numpy as np
import sys
import torchvision.transforms as transforms
import argparse
import json
class PartDataset(data.Dataset):
def __init__(self, root, npoints = 2500, classification = False, class_choice = None, train = True):
self.npoints = npoints
self.root = root
self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
self.cat = {}
self.classification = classification
with open(self.catfile, 'r') as f:
for line in f:
ls = line.strip().split()
self.cat[ls[0]] = ls[1]
#print(self.cat)
if not class_choice is None:
self.cat = {k:v for k,v in self.cat.items() if k in class_choice}
self.meta = {}
for item in self.cat:
#print('category', item)
self.meta[item] = []
dir_point = os.path.join(self.root, self.cat[item], 'points')
dir_seg = os.path.join(self.root, self.cat[item], 'points_label')
#print(dir_point, dir_seg)
fns = sorted(os.listdir(dir_point))
if train:
fns = fns[:int(len(fns) * 0.9)]
else:
fns = fns[int(len(fns) * 0.9):]
#print(os.path.basename(fns))
for fn in fns:
token = (os.path.splitext(os.path.basename(fn))[0])
self.meta[item].append((os.path.join(dir_point, token + '.pts'), os.path.join(dir_seg, token + '.seg')))
self.datapath = []
for item in self.cat:
for fn in self.meta[item]:
self.datapath.append((item, fn[0], fn[1]))
self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))
print(self.classes)
self.num_seg_classes = 0
if not self.classification:
for i in range(len(self.datapath)//50):
l = len(np.unique(np.loadtxt(self.datapath[i][-1]).astype(np.uint8)))
if l > self.num_seg_classes:
self.num_seg_classes = l
#print(self.num_seg_classes)
def __getitem__(self, index):
fn = self.datapath[index]
cls = self.classes[self.datapath[index][0]]
point_set = np.loadtxt(fn[1]).astype(np.float32)
seg = np.loadtxt(fn[2]).astype(np.int64)
#print(point_set.shape, seg.shape)
choice = np.random.choice(len(seg), self.npoints, replace=True)
#resample
point_set = point_set[choice, :]
seg = seg[choice]
point_set = torch.from_numpy(point_set)
seg = torch.from_numpy(seg)
cls = torch.from_numpy(np.array([cls]).astype(np.int64))
if self.classification:
return point_set, cls
else:
return point_set, seg
def __len__(self):
return len(self.datapath)
if __name__ == '__main__':
print('test')
d = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', class_choice = ['Chair'])
print(len(d))
ps, seg = d[0]
print(ps.size(), ps.type(), seg.size(),seg.type())
d = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0', classification = True)
print(len(d))
ps, cls = d[0]
print(ps.size(), ps.type(), cls.size(),cls.type())