Skip to content

Commit 5439edf

Browse files
committed
demo
1 parent 2eef4e2 commit 5439edf

16 files changed

+63
-58
lines changed

README.md

+16-8
Original file line numberDiff line numberDiff line change
@@ -21,27 +21,35 @@ cmake ..
2121
make
2222
```
2323

24-
# Demo
24+
# Simple Demo
25+
This demo used `c=a+b` to show how to save the model and load it using C++ for prediction. [tensorflow_c++_api_prediction_basic](http://mathmach.com/2017/10/09/tensorflow_c++_api_prediction_basic/)
26+
```bash
27+
cd demo/simple_model
28+
# train
29+
sh train.sh
30+
# predict
31+
sh predict.sh
32+
```
2533

26-
## Transform text file into TFRecord
34+
# Deep CTR Model Demo
35+
This demo show a real-wrold deep model usage in click through rate prediction. [tensorflow_c++_api_prediction_advance](http://mathmach.com/2017/10/11/tensorflow_c++_api_prediction_advance/)
36+
37+
## Transform LibFM data into TFRecord
38+
* LibFM format: `label fieldId:featureId:value ...`
2739
```bash
28-
cd demo
40+
cd demo/deep_model
2941
sh trans_data_to_tfrecord.sh
3042
cd ..
3143
```
3244

3345
## Train model
3446
```bash
35-
cd demo
3647
sh train.sh
37-
cd ..
3848
```
3949

4050
## Predict using C++
4151
```bash
42-
cd demo
43-
sh test.sh
44-
cd ..
52+
sh predict.sh
4553
```
4654

4755
# Reference
File renamed without changes.

demo/freeze_graph.sh demo/deep_model/freeze_graph.sh

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@
1111
# --output_node_names=cross_entropy
1212
#cd -
1313

14-
python ../python/freeze_graph.py \
15-
--model_dir=./model \
14+
python ../../python/freeze_graph.py \
15+
--model_dir=./saved_model \
1616
--output_node_names=Softmax

demo/deep_model/predict.sh

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#!/usr/bin/env bash
2+
3+
# TODO

demo/deep_model/train.sh

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#!/usr/bin/env bash
2+
3+
python ../../python/train.py \
4+
--dict "./data/dict.data" \
5+
--continuous_fields "" \
6+
--sparse_fields "9,6,116" \
7+
--linear_fields "152,179" \
8+
--train_file "./data/libfm.tfrecord" \
9+
--valid_file "./data/libfm.tfrecord"
+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#!/usr/bin/env bash
2+
3+
# generate field dict
4+
python ../../python/dict.py \
5+
'0' \
6+
'9,6,116' \
7+
'152,179' \
8+
./data/libfm.data \
9+
./data/dict.data
10+
11+
# transform libfm data into tfrecord
12+
python ../../python/data.py \
13+
./data/dict.data \
14+
'0' \
15+
'9,6,116' \
16+
'152,179' \
17+
./data/libfm.data \
18+
./data/libfm.tfrecord

demo/simple_model/predict.sh

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#!/usr/bin/env bash
2+
3+
../../bin/simple_model.bin "./saved_model/graph.pb"

demo/simple_model/train.sh

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#!/usr/bin/env bash
2+
3+
python ../../python/simple_model.py

demo/train.sh

-9
This file was deleted.

demo/trans_data_to_tfrecord.sh

-18
This file was deleted.

python/model.py python/deep_model.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@ def forward(self, sparse_id, sparse_val, linear_id, linear_val, continuous_val):
4747
self.hiddenW = []
4848
self.hiddenB = []
4949

50+
# sparse field embedding
5051
net = self.concat(self.sparse_field, sparse_id, sparse_val)
52+
53+
# concat sparse field embedding and continuous field
5154
if len(self.continuous_field) > 0:
5255
net = tf.concat([net, continuous_val], 1, name='concat_sparse_continuous')
5356

@@ -63,7 +66,7 @@ def forward(self, sparse_id, sparse_val, linear_id, linear_val, continuous_val):
6366
#net = tf.nn.dropout(net, self.drop_out, name='dropout_'+str(i))
6467
#tf.summary.histogram('hidden_w' + str(i), weight)
6568

66-
# merge linear sparse
69+
# merge linear sparse field embedding
6770
if len(self.linear_field) > 0:
6871
linear_embedding = self.concat(self.linear_field, linear_id, linear_val)
6972
net = tf.concat([net, linear_embedding], 1, name='concat_linear')

demo/simpile_model.py python/simple_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,4 @@
1414
print(b.eval()) # 6.0
1515
print(c.eval()) # 30.0
1616

17-
tf.train.write_graph(sess.graph_def, 'simple_model/', 'graph.pb', as_text=False)
17+
tf.train.write_graph(sess.graph_def, 'saved_model/', 'graph.pb', as_text=False)

python/test.py

-15
This file was deleted.

python/train.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
import datetime
66
import tensorflow as tf
7-
from model import Model
7+
from deep_model import Model
88
from data import Data
99

1010
# config
@@ -16,9 +16,8 @@
1616
flags.DEFINE_integer("valid_batch_size", 100, "validate set batch size")
1717
flags.DEFINE_integer("thread_num", 1, "number of thread to read data")
1818
flags.DEFINE_integer("min_after_dequeue", 100, "min_after_dequeue for shuffle queue")
19-
flags.DEFINE_string("model_dir", "./model/", "model dirctory")
19+
flags.DEFINE_string("model_dir", "./saved_model/", "model dirctory")
2020
flags.DEFINE_string("tensorboard_dir", "./tensorboard/", "summary data saved for tensorboard")
21-
flags.DEFINE_string("model_type", "wide_and_deep", "model type, option: wide, deep, wide_and_deep")
2221
flags.DEFINE_string("optimizer", "adagrad", "optimization algorithm")
2322
flags.DEFINE_integer('steps_to_validate', 1, 'steps to validate and print')
2423
flags.DEFINE_bool("train_from_checkpoint", False, "reload model from checkpoint and go on training")

src/fnn.cc src/deep_model.cc

File renamed without changes.

src/simple_model.cc

+2-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ int main(int argc, char* argv[]) {
1919
// when using `bazel run` since the cwd isn't where you call
2020
// `bazel run` but from inside a temp folder.)
2121
GraphDef graph_def;
22-
status = ReadBinaryProto(Env::Default(), "../demo/simple_model/graph.pb", &graph_def);
22+
std::string model_path = argv[1];
23+
status = ReadBinaryProto(Env::Default(), model_path, &graph_def);
2324
if (!status.ok()) {
2425
std::cout << status.ToString() << std::endl;
2526
return 1;

0 commit comments

Comments
 (0)