formath/tensorflow-predictor-cpp

tensorflow-predictor-cpp

TensorFlow prediction using its C++ API.

With this repo, you can run predictions directly from C++ and do not need TensorFlow Serving. This project has been tested on OSX.

Contains two examples:

  • a simple model, c = a * b
  • an industrial deep model for large-scale click-through rate (CTR) prediction

Topics covered:

  • save a model and checkpoint
  • freeze a model with its checkpoint
  • replace some nodes in the model for prediction
  • transform LibFM data into TFRecord
  • load a model in C++
  • construct a SparseTensor in C++
  • run prediction in C++
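
To make the SparseTensor item more concrete: a 2-D sparse tensor is described by three parts, the indices of the non-zero entries, their values, and the dense shape. The sketch below builds those three parts in plain C++ without depending on TensorFlow. The names `CooMatrix`, `SparseRow`, and `BuildCoo` are hypothetical and are not taken from this repo; the repo's own code may organize this differently.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical plain-C++ holder for the three parts of a 2-D sparse tensor.
struct CooMatrix {
    std::vector<std::array<int64_t, 2>> indices;  // one (row, column) per non-zero
    std::vector<float> values;                    // values[i] belongs to indices[i]
    std::array<int64_t, 2> dense_shape;           // (number of rows, number of columns)
};

// One example: a list of (column, value) pairs.
using SparseRow = std::vector<std::pair<int64_t, float>>;

// Flattens a batch of sparse rows into coordinate (COO) form.
// Entries are emitted row by row, so the result is in row-major order
// as long as each row is already sorted by column.
CooMatrix BuildCoo(const std::vector<SparseRow>& rows, int64_t num_cols) {
    CooMatrix coo;
    coo.dense_shape = {static_cast<int64_t>(rows.size()), num_cols};
    for (std::size_t r = 0; r < rows.size(); ++r) {
        for (const auto& entry : rows[r]) {
            // Precondition: every column index fits inside the dense shape.
            assert(entry.first >= 0 && entry.first < num_cols);
            coo.indices.push_back({static_cast<int64_t>(r), entry.first});
            coo.values.push_back(entry.second);
        }
    }
    return coo;
}
```

In a real predictor, these three arrays would then be copied into TensorFlow tensors before the session is run; that step is left out here because it depends on the model's input signature.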

Build

Build TensorFlow

Follow the instructions for building TensorFlow from source:

git clone --recursive https://github.com/tensorflow/tensorflow.git
cd tensorflow
# build_all_linux.sh works on both Linux and OSX
sh tensorflow/contrib/makefile/build_all_linux.sh
cd ..

Build this repo

Clone this repo into the same parent directory as tensorflow, so that the two directories sit side by side.

git clone https://github.com/formath/tensorflow-predictor-cpp.git
cd tensorflow-predictor-cpp
mkdir build && cd build
cmake ..
make

Simple Model

This demo uses c = a * b to show how to save a model and load it in C++ for prediction.

  • Save model
  • Load model
  • Prediction

More detail in Chinese: tensorflow_c++_api_prediction

cd demo/simple_model
# train
sh train.sh
# predict
sh predict.sh

Deep CTR Model

This demo shows how a real-world deep model is used for click-through rate prediction.

  • Transform LibFM data into TFRecord
  • Save model and checkpoint
  • Replace parts of the model and freeze the graph with the checkpoint
  • Load model and checkpoint
  • Prediction

More detail in Chinese: tensorflow_c++_api_prediction

Transform LibFM data into TFRecord

  • LibFM format: label fieldId:featureId:value ...

cd demo/deep_model
sh trans_data_to_tfrecord.sh
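
As an illustration of the LibFM format above, here is a minimal C++ sketch that parses one line of the form label fieldId:featureId:value ... into a struct. It is not the repo's conversion script, and it does not write TFRecord output. The names `LibfmEntry`, `LibfmExample`, and `ParseLibfmLine` are hypothetical, and treating the label as a float is an assumption.

```cpp
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical types for one parsed LibFM line.
struct LibfmEntry {
    int64_t field_id;
    int64_t feature_id;
    float value;
};

struct LibfmExample {
    float label;  // assumption: the label is numeric, e.g. 0 or 1 for click data
    std::vector<LibfmEntry> entries;
};

// Parses "label fieldId:featureId:value ..." and throws
// std::invalid_argument if the label or a token is malformed.
LibfmExample ParseLibfmLine(const std::string& line) {
    std::istringstream in(line);
    LibfmExample example;
    if (!(in >> example.label)) {
        throw std::invalid_argument("missing label: " + line);
    }
    std::string token;
    while (in >> token) {
        // Each token must contain exactly the two separators of a:b:c.
        const std::size_t first = token.find(':');
        const std::size_t second =
            (first == std::string::npos) ? std::string::npos
                                         : token.find(':', first + 1);
        if (second == std::string::npos) {
            throw std::invalid_argument("expected fieldId:featureId:value, got: " + token);
        }
        LibfmEntry entry;
        entry.field_id = std::stoll(token.substr(0, first));
        entry.feature_id = std::stoll(token.substr(first + 1, second - first - 1));
        entry.value = std::stof(token.substr(second + 1));
        example.entries.push_back(entry);
    }
    return example;
}
```

For example, `ParseLibfmLine("1 0:12:1.0 3:4567:0.5")` yields a label of 1 and two entries, the second with field 3, feature 4567, and value 0.5.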

Train model

sh train.sh

Freeze model

sh freeze_graph.sh

Predict using C++

sh predict.sh