-
Notifications
You must be signed in to change notification settings - Fork 13
/
fingertip.py
27 lines (24 loc) · 883 Bytes
/
fingertip.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numpy as np
from net.vgg16 import model as vgg_model
from net.inception import model as inception_model
from net.xception import model as xception_model
from net.mobilenet import model as mobilenet_model
class Fingertips:
    """Thin wrapper that builds one of the project's networks and runs it.

    The network is selected by name, created with the matching ``model()``
    factory imported at the top of this file, and then has its weights
    loaded via ``load_weights``.
    """

    def __init__(self, model, weights):
        """Build the selected network and load its weights.

        Args:
            model: Name of the network to build: ``'vgg'``, ``'inception'``,
                ``'xception'`` or ``'mobilenet'``.
            weights: Forwarded unchanged to the network's ``load_weights``
                (presumably a path to a weights file -- confirm with callers).

        Raises:
            AssertionError: If ``model`` is not one of the supported names.
        """
        # Compare by value. ``is`` tests object identity, which only matches
        # string literals by accident of interning and fails for names that
        # are built at runtime (parsed from argv, read from a config, ...).
        if model == 'vgg':
            self.model = vgg_model()
        elif model == 'inception':
            self.model = inception_model()
        elif model == 'xception':
            self.model = xception_model()
        elif model == 'mobilenet':
            self.model = mobilenet_model()
        else:
            # Raised explicitly instead of ``assert False`` so the check is
            # not stripped under ``python -O``. The exception type is kept
            # as AssertionError so existing callers see the same error.
            # ``format`` keeps the message usable for non-string input.
            raise AssertionError('{} does not exist.'.format(model))
        self.model.load_weights(weights)

    def classify(self, image):
        """Run the network on a single image and return its prediction.

        Args:
            image: A single image as an array (or array-like) of pixel
                values; assumed to be in the 0-255 range given the scaling
                below -- TODO confirm against the training pipeline.

        Returns:
            The first (and only) entry of the batch returned by
            ``self.model.predict``.
        """
        # Scale pixel values by 1/255 before feeding the network.
        scaled = np.asarray(image) / 255.0
        # The network is fed a batch, so add a leading axis of size 1.
        batch = np.expand_dims(scaled, axis=0)
        prediction = self.model.predict(batch)
        # Drop the batch axis again: only one image was submitted.
        return prediction[0]