-
Notifications
You must be signed in to change notification settings - Fork 0
/
pokemon_classifier_tflite.py
55 lines (45 loc) · 1.93 KB
/
pokemon_classifier_tflite.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import tflite_runtime.interpreter as tflite
# from matplotlib import pyplot as plt
from sklearn.preprocessing import LabelEncoder
# from PIL import Image
import numpy as np
from PIL import Image
# Render NumPy floats in fixed-point decimal form instead of scientific notation.
np.set_printoptions(
    suppress=True,
    formatter={'float_kind': '{:f}'.format},
)
def get_label_encoder():
    """Return a LabelEncoder whose classes are loaded from the saved class file.

    The encoder is not fitted; its ``classes_`` attribute is assigned directly
    from the array stored in ``./resources/classifier_model/best_classes.npy``.
    """
    classes_path = './resources/classifier_model/best_classes.npy'
    label_encoder = LabelEncoder()
    label_encoder.classes_ = np.load(classes_path)
    return label_encoder
def predict_top_n_pokemon(image_filename, num_top_pokemon):
    """Run the TFLite classifier on an image and return its best-scoring classes.

    Args:
        image_filename: Image location, passed straight to ``Image.open``.
        num_top_pokemon: Number of highest-scoring classes to return. Values
            larger than the number of model outputs are clamped to that
            number; zero or negative values produce empty results.

    Returns:
        A ``(top_k_labels, top_k_values)`` tuple. ``top_k_values`` holds the
        selected entries of the model's first output tensor, ordered from
        highest to lowest; ``top_k_labels`` holds the matching labels decoded
        with the encoder from ``get_label_encoder``.
    """
    TFLITE_MODEL = "./resources/classifier_model/vecchio_modello_nuovo_dataset_55fotoclasse_hue.tflite"

    # Load the TFLite model and allocate tensors.
    interpreter = tflite.Interpreter(TFLITE_MODEL)
    interpreter.allocate_tensors()

    # Get input and output tensors
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Load image; the context manager releases the file handle once the
    # resized copy has been produced.
    with Image.open(image_filename) as source_image:
        # Force three channels so RGBA / palette / grayscale files still yield
        # a (224, 224, 3) array.
        # NOTE(review): assumes the model takes 3-channel input - confirm.
        rgb_image = source_image.convert('RGB')
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
        resized_image = rgb_image.resize((224, 224), Image.LANCZOS)

    # Scale pixel values to [0, 1] and add the leading batch dimension.
    img = np.asarray(resized_image, dtype=np.float32) / 255
    img = np.expand_dims(img, axis=0)
    # Uncomment only when experimenting with channels-first models
    # (PyTorch-to-TFLite case).
    # img = np.transpose(img, [0, 3, 1, 2])
    print(img.shape)

    # img is already float32, so it can be fed to the interpreter directly.
    interpreter.set_tensor(input_details[0]['index'], img)
    interpreter.invoke()

    # Get output
    output_data = interpreter.get_tensor(output_details[0]['index'])

    # Get label encoder
    label_encoder = get_label_encoder()

    # Get best num_top_pokemon
    results = np.squeeze(output_data, axis=0)
    # Clamp k: a plain [-k:] slice returns every class for k == 0 and an
    # arbitrary tail for negative k.
    top_k = max(0, min(num_top_pokemon, results.shape[0]))
    # Reversing the ascending argsort and taking the first k entries gives the
    # same ordering as [-k:][::-1] for every valid k.
    top_k_idx = np.argsort(results)[::-1][:top_k]
    top_k_values = results[top_k_idx]
    top_k_labels = label_encoder.inverse_transform(top_k_idx)
    return top_k_labels, top_k_values