-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
rewrite the text classification demo. #83
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,3 @@ | ||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License | ||
|
||
import gzip | ||
|
||
import paddle.v2 as paddle | ||
|
@@ -51,10 +37,10 @@ def main(): | |
learning_rate_schedule="discexp", ) | ||
|
||
train_reader = paddle.batch( | ||
paddle.reader.shuffle(reader.test_reader("train.list"), buf_size=1000), | ||
paddle.reader.shuffle(reader.train_reader("train.list"), buf_size=1000), | ||
batch_size=BATCH_SIZE) | ||
test_reader = paddle.batch( | ||
reader.train_reader("test.list"), batch_size=BATCH_SIZE) | ||
reader.test_reader("test.list"), batch_size=BATCH_SIZE) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. test.list是文本数据集吗?没有在目录下找到 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这处修改用来修改图像分类例子的bug,目前每个例子读取数据的方式确实不统一。后续提PR修改图像分类的例子。 |
||
|
||
# End batch and end pass event handler | ||
def event_handler(event): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
data | ||
*.tar.gz | ||
*.log | ||
*.pyc |
Large diffs are not rendered by default.
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
import sys | ||
import os | ||
import gzip | ||
|
||
import paddle.v2 as paddle | ||
|
||
import network_conf | ||
import reader | ||
from utils import * | ||
|
||
|
||
def infer(topology, data_dir, model_path, word_dict_path, label_dict_path,
          batch_size):
    """Run inference with a trained text-classification model.

    :param topology: network definition function (fc_net or convolution_net);
        invoked with is_infer=True to build the probability output layer.
    :param data_dir: directory holding the test data; when None, the built-in
        paddle imdb dataset is used instead.
    :param model_path: path of the gzipped trained parameters archive.
    :param word_dict_path: path of the word dictionary file (unused when
        data_dir is None).
    :param label_dict_path: path of the label dictionary file (unused when
        data_dir is None).
    :param batch_size: number of samples fed to each inference call.
    """

    def _predict_batch(engine, samples, id_to_word, id_to_label):
        # Predict one batch, printing "label<TAB>probabilities<TAB>text"
        # for every sample in it.
        probs = engine.infer(input=samples, field=['value'])
        assert len(probs) == len(samples)
        for word_ids, prob in zip(samples, probs):
            text = " ".join([id_to_word[id] for id in word_ids[0]])
            prob_str = " ".join(["{:0.4f}".format(p) for p in prob])
            print("%s\t%s\t%s" % (id_to_label[prob.argmax()], prob_str, text))

    logger.info('begin to predict...')

    if data_dir is None:
        # No user data supplied: fall back to the built-in imdb dataset.
        word_dict = paddle.dataset.imdb.word_dict()
        word_reverse_dict = dict((value, key)
                                 for key, value in word_dict.iteritems())
        label_reverse_dict = {0: "positive", 1: "negative"}
        test_reader = paddle.dataset.imdb.test(word_dict)
    else:
        assert os.path.exists(
            word_dict_path), 'the word dictionary file does not exist'
        assert os.path.exists(
            label_dict_path), 'the label dictionary file does not exist'

        word_dict = load_dict(word_dict_path)
        word_reverse_dict = load_reverse_dict(word_dict_path)
        label_reverse_dict = load_reverse_dict(label_dict_path)
        test_reader = reader.test_reader(data_dir, word_dict)()

    # Build the inference-only topology sized to the dictionaries.
    prob_layer = topology(
        len(word_dict), len(label_reverse_dict), is_infer=True)

    # initialize PaddlePaddle
    paddle.init(use_gpu=False, trainer_count=1)

    # Load the trained parameters from the gzipped archive.
    parameters = paddle.parameters.Parameters.from_tar(
        gzip.open(model_path, 'r'))
    inferer = paddle.inference.Inference(
        output_layer=prob_layer, parameters=parameters)

    # Accumulate samples and flush a prediction every batch_size items;
    # the trailing partial batch is predicted after the loop.
    batch = []
    for item in test_reader:
        batch.append([item[0]])
        if len(batch) == batch_size:
            _predict_batch(inferer, batch, word_reverse_dict,
                           label_reverse_dict)
            batch = []
    if batch:
        _predict_batch(inferer, batch, word_reverse_dict, label_reverse_dict)
|
||
|
||
if __name__ == '__main__':
    # Path of the trained model produced by train.py.
    model_path = 'dnn_params_pass_00000.tar.gz'
    assert os.path.exists(model_path), "the trained model does not exist."

    # Which network to rebuild for inference: 'dnn' or 'cnn'.
    nn_type = 'dnn'
    # None values make infer() fall back to the built-in imdb dataset.
    test_dir = None
    word_dict = None
    label_dict = None

    if nn_type == 'dnn':
        topology = network_conf.fc_net
    elif nn_type == 'cnn':
        topology = network_conf.convolution_net
    else:
        # Fail fast with a clear message instead of hitting a NameError
        # on `topology` at the infer() call below.
        raise ValueError("unknown network type: %s" % nn_type)

    infer(
        topology=topology,
        data_dir=test_dir,
        word_dict_path=word_dict,
        label_dict_path=label_dict,
        model_path=model_path,
        batch_size=10)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import sys | ||
import math | ||
import gzip | ||
|
||
from paddle.v2.layer import parse_network | ||
import paddle.v2 as paddle | ||
|
||
__all__ = ["fc_net", "convolution_net"] | ||
|
||
|
||
def fc_net(dict_dim,
           class_num,
           emb_dim=28,
           hidden_layer_sizes=(28, 8),
           is_infer=False):
    """
    define the topology of the dnn network

    :param dict_dim: size of word dictionary
    :type dict_dim: int
    :param class_num: number of instance class
    :type class_num: int
    :param emb_dim: embedding vector dimension
    :type emb_dim: int
    :param hidden_layer_sizes: width of each hidden fc layer, applied in order
    :type hidden_layer_sizes: list|tuple
    :param is_infer: if True return only the probability output layer,
        otherwise return (cost, prob, label) for training
    :type is_infer: bool
    """

    # define the input layers
    data = paddle.layer.data("word",
                             paddle.data_type.integer_value_sequence(dict_dim))
    if not is_infer:
        lbl = paddle.layer.data("label",
                                paddle.data_type.integer_value(class_num))

    # define the embedding layer
    emb = paddle.layer.embedding(input=data, size=emb_dim)
    # max pooling to reduce the input sequence into a vector (non-sequence)
    seq_pool = paddle.layer.pooling(
        input=emb, pooling_type=paddle.pooling.Max())

    # Stack the hidden layers on top of the pooled vector. Tracking the
    # previous layer in `hidden` also keeps an empty hidden_layer_sizes
    # valid (the softmax then reads seq_pool directly), where the old
    # `hidden if idx else seq_pool` form left `hidden` unbound.
    hidden = seq_pool
    for hidden_size in hidden_layer_sizes:
        hidden_init_std = 1.0 / math.sqrt(hidden_size)
        hidden = paddle.layer.fc(
            input=hidden,
            size=hidden_size,
            act=paddle.activation.Tanh(),
            param_attr=paddle.attr.Param(initial_std=hidden_init_std))

    prob = paddle.layer.fc(
        input=hidden,
        size=class_num,
        act=paddle.activation.Softmax(),
        param_attr=paddle.attr.Param(initial_std=1.0 / math.sqrt(class_num)))

    if is_infer:
        return prob
    else:
        return paddle.layer.classification_cost(
            input=prob, label=lbl), prob, lbl
|
||
|
||
def convolution_net(dict_dim,
                    class_dim=2,
                    emb_dim=28,
                    hid_dim=128,
                    is_infer=False):
    """
    cnn network definition

    :param dict_dim: size of word dictionary
    :type dict_dim: int
    :param class_dim: number of instance class
    :type class_dim: int
    :param emb_dim: embedding vector dimension
    :type emb_dim: int
    :param hid_dim: number of same size convolution kernels
    :type hid_dim: int
    :param is_infer: if True return only the probability output layer,
        otherwise return (cost, prob, label) for training
    :type is_infer: bool
    """

    # input layers
    data = paddle.layer.data("word",
                             paddle.data_type.integer_value_sequence(dict_dim))
    # the label input is only needed for training; skip it in inference
    # mode (consistent with fc_net, which already guards on is_infer)
    if not is_infer:
        lbl = paddle.layer.data("label",
                                paddle.data_type.integer_value(class_dim))

    # embedding layer
    emb = paddle.layer.embedding(input=data, size=emb_dim)

    # convolution layers with max pooling
    conv_3 = paddle.networks.sequence_conv_pool(
        input=emb, context_len=3, hidden_size=hid_dim)
    conv_4 = paddle.networks.sequence_conv_pool(
        input=emb, context_len=4, hidden_size=hid_dim)

    # fc and output layer
    prob = paddle.layer.fc(
        input=[conv_3, conv_4], size=class_dim, act=paddle.activation.Softmax())

    if is_infer:
        return prob
    else:
        cost = paddle.layer.classification_cost(input=prob, label=lbl)

        return cost, prob, lbl
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
import os | ||
|
||
|
||
def train_reader(data_dir, word_dict, label_dict):
    """
    Reader interface for training data

    :param data_dir: data directory
    :type data_dir: str
    :param word_dict: word dictionary mapping each word to an integer id;
        must contain a "<UNK>" entry for out-of-vocabulary words
    :type word_dict: Python dict
    :param label_dict: label dictionary mapping each label string to an
        integer id
    :type label_dict: Python dict
    """

    def reader():
        # id assigned to any word that is not in the dictionary
        UNK_ID = word_dict["<UNK>"]
        word_col = 1
        lbl_col = 0

        for file_name in os.listdir(data_dir):
            with open(os.path.join(data_dir, file_name), "r") as f:
                for line in f:
                    line_split = line.strip().split("\t")
                    # skip malformed lines that lack the text column;
                    # previously such lines raised IndexError below
                    if len(line_split) <= word_col:
                        continue
                    word_ids = [
                        word_dict.get(w, UNK_ID)
                        for w in line_split[word_col].split()
                    ]
                    yield word_ids, label_dict[line_split[lbl_col]]

    return reader
|
||
|
||
def test_reader(data_dir, word_dict): | ||
""" | ||
Reader interface for testing data | ||
|
||
:param data_dir: data directory. | ||
:type data_dir: str | ||
:param word_dict: path of word dictionary, | ||
the dictionary must has a "UNK" in it. | ||
:type word_dict: Python dict | ||
""" | ||
|
||
def reader(): | ||
UNK_ID = word_dict["<UNK>"] | ||
word_col = 1 | ||
|
||
for file_name in os.listdir(data_dir): | ||
with open(os.path.join(data_dir, file_name), "r") as f: | ||
for line in f: | ||
line_split = line.strip().split("\t") | ||
if len(line_split) < word_col: continue | ||
word_ids = [ | ||
word_dict.get(w, UNK_ID) | ||
for w in line_split[word_col].split() | ||
] | ||
yield word_ids, line_split[word_col] | ||
|
||
return reader |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
#!/bin/sh | ||
|
||
python train.py \ | ||
--nn_type="dnn" \ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里的train的方式怎么又改成shell传参了,按照约定都应该写到train.py里? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shell 里面为 |
||
--batch_size=64 \ | ||
--num_passes=10 \ | ||
2>&1 | tee train.log |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
train.list是文本数据集吗?没有在目录下找到
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这处修改用来修改图像分类例子的bug,目前每个例子读取数据的方式确实不统一。后续提PR修改图像分类的例子。