Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
Merge pull request #30 from PaddlePaddle/text_classification
Browse files Browse the repository at this point in the history
add multi card for text_classification
  • Loading branch information
guochaorong committed May 29, 2018
2 parents 6b8c122 + 0e2ba06 commit 2f701dc
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 28 deletions.
8 changes: 7 additions & 1 deletion text_classification/continuous_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,10 @@
lstm_train_cost_kpi = CostKpi('lstm_train_cost', 5, 0)
lstm_pass_duration_kpi = DurationKpi('lstm_pass_duration', 0.02, 0, actived=True)

tracking_kpis = [lstm_train_cost_kpi, lstm_pass_duration_kpi]
lstm_train_cost_kpi_card4 = CostKpi('lstm_train_cost_card4', 0.2, 0)
lstm_pass_duration_kpi_card4 = DurationKpi('lstm_pass_duration_card4', 0.02, 0, actived=True)

tracking_kpis = [
lstm_train_cost_kpi, lstm_pass_duration_kpi,
lstm_train_cost_kpi_card4, lstm_pass_duration_kpi_card4,
]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[17.750867716471355]
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[0.0030332264248281717]
9 changes: 7 additions & 2 deletions text_classification/run.xsh
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@

export MKL_NUM_THREADS=1
export OMP_NUM_THREADS=1
cudaid=${text_classification:=0} # use 0-th card as default

cudaid=${text_classification:=0}
export CUDA_VISIBLE_DEVICES=$cudaid
FLAGS_benchmark=true python train.py --model lstm

cudaid=${text_classification_m:=0,1,2,3} # use 0-th card as default
export CUDA_VISIBLE_DEVICES=$cudaid

#LSTM pass_num 15
FLAGS_benchmark=true python train.py lstm
FLAGS_benchmark=true python train.py --model lstm --gpu_card_num 4
46 changes: 33 additions & 13 deletions text_classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,24 @@

import paddle.fluid as fluid
import paddle

import argparse
import utils
from nets import bow_net
from nets import cnn_net
from nets import lstm_net
from nets import gru_net
from continuous_evaluation import lstm_train_cost_kpi, lstm_pass_duration_kpi
from continuous_evaluation import *
fluid.default_startup_program().random_seed = 99

def parse_args():
parser = argparse.ArgumentParser("text_classification model benchmark.")
parser.add_argument(
'--model', type=str, default="lstm", help='model to run.')
parser.add_argument(
'--gpu_card_num', type=int, default=1, help='gpu card num used.')

args = parser.parse_args()
return args

def train(train_reader,
word_dict,
Expand All @@ -26,6 +36,7 @@ def train(train_reader,
"""
train network
"""
args = parse_args()
data = fluid.layers.data(
name="words", shape=[1], dtype="int64", lod_level=1)

Expand All @@ -34,7 +45,7 @@ def train(train_reader,
if not parallel:
cost, acc, prediction = network(data, label, len(word_dict))
else:
places = fluid.layers.get_places(device_count=2)
places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places)
with pd.do():
cost, acc, prediction = network(
Expand Down Expand Up @@ -76,20 +87,29 @@ def train(train_reader,
print("pass_id: %d, avg_acc: %f, avg_cost: %f" %
(pass_id, avg_acc, avg_cost))
if pass_id == pass_num - 1:
lstm_train_cost_kpi.add_record(newest_avg_cost)
lstm_pass_duration_kpi.add_record(total_time / pass_num)
if args.gpu_card_num == 1:
lstm_train_cost_kpi.add_record(newest_avg_cost)
lstm_pass_duration_kpi.add_record(total_time / pass_num)
else:
lstm_train_cost_kpi_card4.add_record(newest_avg_cost)
lstm_pass_duration_kpi_card4.add_record(total_time / pass_num)

epoch_model = save_dirname + "/" + "epoch" + str(pass_id)
fluid.io.save_inference_model(epoch_model, ["words", "label"], acc,
exe)
lstm_train_cost_kpi.persist()
lstm_pass_duration_kpi.persist()

if args.gpu_card_num == 1:
lstm_train_cost_kpi.persist()
lstm_pass_duration_kpi.persist()
else:
lstm_train_cost_kpi_card4.persist()
lstm_pass_duration_kpi_card4.persist()

def train_net():
args = parse_args()
word_dict, train_reader, test_reader = utils.prepare_data(
"imdb", self_dict=False, batch_size=128, buf_size=50000)

if sys.argv[1] == "bow":
if args.model == "bow":
train(
train_reader,
word_dict,
Expand All @@ -100,7 +120,7 @@ def train_net():
lr=0.002,
pass_num=30,
batch_size=128)
elif sys.argv[1] == "cnn":
elif args.model == "cnn":
train(
train_reader,
word_dict,
Expand All @@ -111,18 +131,18 @@ def train_net():
lr=0.01,
pass_num=30,
batch_size=4)
elif sys.argv[1] == "lstm":
elif args.model == "lstm":
train(
train_reader,
word_dict,
lstm_net,
use_cuda=True,
parallel=False,
parallel=True,
save_dirname="lstm_model",
lr=0.05,
pass_num=15,
batch_size=4)
elif sys.argv[1] == "gru":
elif args.model == "gru":
train(
train_reader,
word_dict,
Expand Down
18 changes: 6 additions & 12 deletions text_classification/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,35 +69,29 @@ def prepare_data(data_type="imdb",

if data_type == "imdb":
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.imdb.train(word_dict), buf_size=buf_size),
paddle.dataset.imdb.train(word_dict),
batch_size=batch_size)

test_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.imdb.test(word_dict), buf_size=buf_size),
paddle.dataset.imdb.test(word_dict),
batch_size=batch_size)

elif data_type == "light_imdb":
train_reader = paddle.batch(
paddle.reader.shuffle(
light_imdb.train(word_dict), buf_size=buf_size),
light_imdb.train(word_dict),
batch_size=batch_size)

test_reader = paddle.batch(
paddle.reader.shuffle(
light_imdb.test(word_dict), buf_size=buf_size),
light_imdb.test(word_dict),
batch_size=batch_size)

elif data_type == "tiny_imdb":
train_reader = paddle.batch(
paddle.reader.shuffle(
tiny_imdb.train(word_dict), buf_size=buf_size),
tiny_imdb.train(word_dict),
batch_size=batch_size)

test_reader = paddle.batch(
paddle.reader.shuffle(
tiny_imdb.test(word_dict), buf_size=buf_size),
tiny_imdb.test(word_dict),
batch_size=batch_size)
else:
raise RuntimeError("no such dataset")
Expand Down

0 comments on commit 2f701dc

Please sign in to comment.