Commit

add inference
wufan1991 committed Sep 15, 2021
1 parent 7ddf215 commit c170859
Showing 1 changed file with 56 additions and 56 deletions.
112 changes: 56 additions & 56 deletions fate-serving-cli/fate-serving-client.go
@@ -597,66 +597,66 @@ func init() {
},
}

//batchInferenceCommand := &grumble.Command{
// Name: "batchInference",
// Help: "batchInference",
// LongHelp: "batchInference",
// Flags: func(f *grumble.Flags) {
//		f.String("a", "address", "localhost:8000", "address and port of the serving-server; defaults to localhost:8000 if not set")
//		f.String("f", "filepath", "", "path to the JSON file used for prediction")
// f.String("s", "serviceId", "", "the serviceId of model")
// //f.String("h", "host", "localhost", "the ip of serving-server")
//
// // flag.StringVar(&host, "h", "localhost", "host ip")
// // flag.IntVar(&port, "p", 8000, "port")
// },
// Run: func(c *grumble.Context) error {
// c.App.Config().PromptColor.Add()
// var err error
// filePath := c.Flags.String("filepath")
//
// var content []byte
// address = c.Flags.String("address")
// if len(filePath) == 0 {
// servieId := c.Flags.String("serviceId")
// if len(servieId) == 0 {
// fmt.Println("serviceId for default testing data is expected")
// return nil
// } else {
// position := strings.IndexAny(BATCHINFERENCETEST, ":")
// position += 3
// content = []byte(BATCHINFERENCETEST[0:position] + servieId + BATCHINFERENCETEST[position:])
// }
// } else {
// content, err = cmd.ReadParams(filePath)
// if err != nil {
// fmt.Println(err)
// return nil
// }
// }
//
// inferenceMsg := pb.InferenceMessage{
// Body: content,
// }
//
// inferenceResp, err := rpc.BatchInference(address, &inferenceMsg)
// if err != nil {
// fmt.Println("Connection error")
// return nil
// }
// if inferenceResp != nil {
// resp := string(inferenceResp.GetBody())
// fmt.Println(resp)
// }
// //c.App.Println(c.Flags.String("directory"))
// return nil
// },
//}
batchInferenceCommand := &grumble.Command{
Name: "batchInference",
Help: "batchInference",
LongHelp: "batchInference",
Flags: func(f *grumble.Flags) {
f.String("a", "address", "localhost:8000", "serving-server的端口以及地址,如果不设置则默认使用localhost:8000")
f.String("f", "filepath", "", "用于预测的json文件路径")
f.String("s", "serviceId", "", "the serviceId of model")
//f.String("h", "host", "localhost", "the ip of serving-server")

// flag.StringVar(&host, "h", "localhost", "host ip")
// flag.IntVar(&port, "p", 8000, "port")
},
Run: func(c *grumble.Context) error {
c.App.Config().PromptColor.Add()
var err error
filePath := c.Flags.String("filepath")

var content []byte
address = c.Flags.String("address")
if len(filePath) == 0 {
serviceId := c.Flags.String("serviceId")
if len(serviceId) == 0 {
fmt.Println("a serviceId is required when no input file is given")
return nil
} else {
// Splice the serviceId into the built-in test payload: find the first ':'
// in the template and insert the value three characters past it (the
// template presumably holds an empty serviceId value at that point).
position := strings.IndexAny(BATCHINFERENCETEST, ":")
position += 3
content = []byte(BATCHINFERENCETEST[0:position] + serviceId + BATCHINFERENCETEST[position:])
}
} else {
content, err = cmd.ReadParams(filePath)
if err != nil {
fmt.Println(err)
return nil
}
}

inferenceMsg := pb.InferenceMessage{
Body: content,
}

inferenceResp, err := rpc.BatchInference(address, &inferenceMsg)
if err != nil {
fmt.Println("Connection error")
return nil
}
if inferenceResp != nil {
resp := string(inferenceResp.GetBody())
fmt.Println(resp)
}
//c.App.Println(c.Flags.String("directory"))
return nil
},
}

App.AddCommand(showModelCommand)
App.AddCommand(showConfigCommand)
App.AddCommand(inferenceCommand)
//App.AddCommand(batchInferenceCommand)
App.AddCommand(batchInferenceCommand)
App.AddCommand(showflowCommand)
App.AddCommand(cluster)
App.AddCommand(showJvmCommand)
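
For context, the newly enabled batchInference command builds its request body either from a user-supplied JSON file or, when only a serviceId is given, by splicing that id into the built-in BATCHINFERENCETEST template. The following minimal sketch isolates that splice logic; the sample template string, the helper name, and the example serviceId are illustrative assumptions, not code from the repository.

package main

import (
	"fmt"
	"strings"
)

// Illustrative stand-in for the repository's BATCHINFERENCETEST constant;
// the real template's exact layout is not shown in this diff.
const sampleTemplate = `{"serviceId": "", "batchDataList": []}`

// spliceServiceId mirrors the command's logic: find the first ':' in the
// template, step three characters past it, and insert the serviceId there.
func spliceServiceId(template, serviceId string) []byte {
	position := strings.IndexAny(template, ":")
	position += 3
	return []byte(template[0:position] + serviceId + template[position:])
}

func main() {
	body := spliceServiceId(sampleTemplate, "202109151234")
	fmt.Println(string(body))
	// With the sample template above this prints:
	// {"serviceId": "202109151234", "batchDataList": []}
}

Against a running serving-server, the command would then be invoked inside the fate-serving-cli shell along the lines of `batchInference -a 127.0.0.1:8000 -s <serviceId>`; the flag names come from the declarations in the diff, while the address is only an example.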
