-
Notifications
You must be signed in to change notification settings - Fork 4
/
main.go
93 lines (74 loc) · 1.74 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
package main
import (
"context"
"fmt"
"io/ioutil"
"os"
"path"
"github.com/Applifier/go-tensorflow/predict"
"github.com/Applifier/go-tensorflow/savedmodel"
"github.com/Applifier/go-tensorflow/serving"
)
// getLocalPredictor constructs a predictor with savedmodel.NewPredictor for the
// "mobilenet" model and the "serving_default" signature, using the directory
// returned by getModelsDir. It panics if the predictor cannot be created.
func getLocalPredictor() predict.Predictor {
	modelsDir := getModelsDir()

	// The literal 1 is presumably the model version — confirm against the
	// savedmodel.NewPredictor signature.
	localPredictor, err := savedmodel.NewPredictor(modelsDir, "mobilenet", 1, "serving_default")
	if err == nil {
		return localPredictor
	}
	panic(err)
}
// getServingPredictor creates a model prediction client for the address
// returned by getServingAddr (model "mobilenet", signature "serving_default")
// and wraps it with serving.NewPredictor. It panics if the client cannot be
// created.
func getServingPredictor() predict.Predictor {
	addr := getServingAddr()

	client, err := serving.NewModelPredictionClientFromAddr(addr, "mobilenet", "serving_default")
	if err != nil {
		panic(err)
	}

	predictor := serving.NewPredictor(client)
	return predictor
}
// readImage returns the complete contents of the file named by url.
//
// Despite its name, url is passed straight to the filesystem and must be a
// local file path, not a network URL. The function panics if the file cannot
// be opened or read.
func readImage(url string) []byte {
	// ReadFile opens, reads and closes the file in one call, replacing the
	// manual Open / ReadAll / Close sequence.
	data, err := ioutil.ReadFile(url)
	if err != nil {
		panic(err)
	}
	return data
}
// main reads the image named by the first command-line argument, converts it
// to a tensor with makeTensorFromImage, and runs the same prediction against
// the TensorFlow Serving predictor and the embedded (local) predictor, in that
// order. For each predictor it prints one "label = score" line per entry of
// the first row of the "detection_classes" output.
//
// It exits with status 2 and a usage message when no image path is given, and
// panics on any I/O, prediction, or malformed-output error.
func main() {
	if len(os.Args) < 2 {
		fmt.Fprintln(os.Stderr, "usage: <program> <image-path>")
		os.Exit(2)
	}

	imagePath := os.Args[1]
	imageBytes := readImage(imagePath)
	labels := loadLabels(getLabelsFile())

	tensor, err := makeTensorFromImage(imageBytes, path.Ext(imagePath))
	if err != nil {
		panic(err)
	}

	// Run prediction both against a local model and tensorflow serving.
	// A slice (rather than a map) keeps the execution and output order stable,
	// since Go randomizes map iteration order.
	backends := []struct {
		name      string
		predictor predict.Predictor
	}{
		{"serving", getServingPredictor()},
		{"embedded", getLocalPredictor()},
	}

	for _, backend := range backends {
		name := backend.name

		fmt.Printf("\nRunning prediction on %s\n", name)
		res, _, err := backend.predictor.Predict(
			context.Background(),
			map[string]interface{}{
				"inputs": tensor,
			},
			nil,
		)
		if err != nil {
			panic(err)
		}

		// Validate the outputs up front so a missing key or unexpected type
		// produces a descriptive failure instead of an opaque runtime panic.
		classesOut, ok := res["detection_classes"]
		if !ok {
			panic(fmt.Errorf("%s: prediction result has no %q output", name, "detection_classes"))
		}
		classes, ok := classesOut.Value().([][]float32)
		if !ok || len(classes) == 0 {
			panic(fmt.Errorf("%s: output %q is not a non-empty [][]float32", name, "detection_classes"))
		}

		scoresOut, ok := res["detection_scores"]
		if !ok {
			panic(fmt.Errorf("%s: prediction result has no %q output", name, "detection_scores"))
		}
		scores, ok := scoresOut.Value().([][]float32)
		if !ok || len(scores) == 0 {
			panic(fmt.Errorf("%s: output %q is not a non-empty [][]float32", name, "detection_scores"))
		}

		// Only the first row of each output is reported; every class in that
		// row needs a matching score.
		if len(scores[0]) < len(classes[0]) {
			panic(fmt.Errorf("%s: got %d scores for %d classes", name, len(scores[0]), len(classes[0])))
		}

		fmt.Printf("Results from %s\n", name)
		for i, val := range classes[0] {
			label := labels[int(val)]
			fmt.Printf("%s = %f\n", label, scores[0][i])
		}
	}
}