-
Notifications
You must be signed in to change notification settings - Fork 21
/
hand.py
87 lines (71 loc) · 2.94 KB
/
hand.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
from PIL import Image
import numpy as np
import pandas as pd
import torch
import torch.utils.data as data
from ..utils import gaussianHeatmap, transformer
class Hand(data.Dataset):
    """Landmark-heatmap dataset backed by ``<prefix>/jpg`` and ``<prefix>/all.csv``.

    Each item is a dict with keys ``'name'`` (file stem), ``'input'``
    (normalized image) and ``'gt'`` (one heatmap per landmark, optionally
    followed by a background channel), both passed through ``self.transform``
    and converted to ``torch.FloatTensor``.

    Image files are sorted by name and split deterministically: the first
    ``train_num`` go to ``'train'``, the next ``val_num`` to ``'validate'``
    and the remainder to ``'test'``.
    """

    def __init__(self, prefix, phase, transform_params=None, sigma=10,
                 num_landmark=19, size=(1000, 1400),
                 use_background_channel=False, train_num=550, val_num=59):
        """Index the images under ``prefix`` and select the ``phase`` split.

        Args:
            prefix: dataset root containing the ``jpg`` folder and ``all.csv``.
            phase: one of ``'train'``, ``'validate'`` or ``'test'``.
            transform_params: options forwarded to ``transformer``; ``None``
                means an empty dict (avoids a shared mutable default).
            sigma: width passed to ``gaussianHeatmap``.
            num_landmark: stored on the instance; note that ``readLandmark``
                returns every coordinate pair in the CSV row regardless.
            size: target (width, height) that images and heatmaps use.
            use_background_channel: if true, append
                ``1 - min(sum(heatmaps), 1)`` as an extra ``gt`` channel.
            train_num: number of (sorted) files in the training split.
            val_num: number of files in the validation split.

        Raises:
            ValueError: if ``phase`` is not a recognized split name.
        """
        if transform_params is None:
            transform_params = {}
        self.transform = transformer(transform_params)
        self.size = tuple(size)
        self.num_landmark = num_landmark
        self.pth_Image = os.path.join(prefix, 'jpg')
        self.use_background_channel = use_background_channel
        # First CSV column is the integer image id; the rest are coordinates.
        self.labels = pd.read_csv(os.path.join(
            prefix, 'all.csv'), header=None, index_col=0)

        index_set = set(self.labels.index)
        files = []
        for fname in sorted(os.listdir(self.pth_Image)):
            stem, ext = os.path.splitext(fname)
            # readImage always opens '<stem>.jpg' and readLandmark looks the
            # row up by int(stem); skip anything else (e.g. hidden/system
            # files) instead of crashing on int().
            if ext.lower() != '.jpg':
                continue
            try:
                key = int(stem)
            except ValueError:
                continue
            if key in index_set:
                files.append(stem)

        self.indexes = self._split_indexes(files, phase, train_num, val_num)
        self.genHeatmap = gaussianHeatmap(sigma, dim=len(size))

    @staticmethod
    def _split_indexes(files, phase, train_num, val_num):
        """Return the sub-list of ``files`` that belongs to ``phase``.

        Raises:
            ValueError: if ``phase`` is not 'train', 'validate' or 'test'.
        """
        # Use explicit start offsets instead of slicing with ``-test_num``:
        # ``files[-0:]`` is the whole list, so an empty test split would
        # otherwise swallow every file.
        val_end = train_num + val_num
        if phase == 'train':
            return files[:train_num]
        if phase == 'validate':
            return files[train_num:val_end]
        if phase == 'test':
            return files[val_end:]
        raise ValueError("Unknown phase: {phase}".format(phase=phase))

    def __getitem__(self, index):
        """Load sample ``index`` and build its target heatmaps.

        Returns:
            dict with ``'name'`` (file stem), ``'input'`` and ``'gt'``
            (``torch.FloatTensor`` values produced after ``self.transform``).
        """
        name = self.indexes[index]
        img, origin_size = self.readImage(
            os.path.join(self.pth_Image, name + '.jpg'))
        points = self.readLandmark(name, origin_size)
        heatmaps = [self.genHeatmap(point, self.size) for point in points]
        if self.use_background_channel:
            # Background = 1 - (summed landmark heatmaps clipped at 1).
            heatmaps.append(1 - np.minimum(sum(heatmaps), 1))
        gt = np.array(heatmaps)
        img, gt = self.transform(img, gt)
        return {
            'name': name,
            'input': torch.FloatTensor(img),
            'gt': torch.FloatTensor(gt),
        }

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.indexes)

    def readLandmark(self, name, origin_size):
        """Return the landmark points of ``name`` rescaled to ``self.size``.

        The CSV row for ``int(name)`` is read as alternating x, y values in
        the original image's pixel coordinates. Each pair is scaled by
        ``self.size / origin_size`` per axis and rounded.

        Args:
            name: file stem; ``int(name)`` is the CSV row label.
            origin_size: original image size as PIL ``(width, height)``.

        Returns:
            list of ``(x, y)`` tuples.
        """
        row = list(self.labels.loc[int(name), :])
        r1, r2 = [s / o for s, o in zip(self.size, origin_size)]
        return [(round(row[i] * r1), round(row[i + 1] * r2))
                for i in range(0, len(row), 2)]

    def readImage(self, path):
        """Read ``path`` as a normalized float64 array of shape 1 x width x height.

        The image is resized to ``self.size`` and then standardized to zero
        mean and unit variance per channel.

        Returns:
            ``(array, origin_size)`` where ``origin_size`` is PIL's
            ``(width, height)`` before resizing.
        """
        # The context manager releases the file handle; resize() returns a
        # new, fully loaded image that stays usable after the file closes.
        with Image.open(path) as img:
            origin_size = img.size
            resized = img.resize(self.size)  # PIL takes (width, height)
        arr = np.array(resized)
        # A single-channel image arrives as (height, width); transpose it to
        # width x height and add a leading channel axis.
        # NOTE(review): assumes grayscale JPEGs; an RGB (3-D) array would
        # fail the 2-axis transpose.
        arr = np.expand_dims(np.transpose(arr, (1, 0)), 0).astype(np.float64)
        # Converting to float before normalizing matters: writing the
        # normalized values back into an integer array would truncate them.
        for c in range(arr.shape[0]):
            arr[c] = (arr[c] - arr[c].mean()) / (arr[c].std() + 1e-20)
        return arr, origin_size