From d94ca4758a3cfa30edde6b9b34fad9dd037ab365 Mon Sep 17 00:00:00 2001 From: zhaohanguang Date: Tue, 7 Aug 2018 20:42:57 +0800 Subject: [PATCH] Update demo --- demo/sentiment_analysis.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/demo/sentiment_analysis.py b/demo/sentiment_analysis.py index a1b3b0d..8f2786b 100644 --- a/demo/sentiment_analysis.py +++ b/demo/sentiment_analysis.py @@ -29,7 +29,7 @@ train_neg_files = train_neg_files[:train_num] epoch_num = 10 -batch_size = 16 +batch_size = 64 train_steps = train_num * 2 // batch_size val_steps = val_num * 2 // batch_size @@ -57,11 +57,11 @@ for file_name in train_pos_files: with codecs.open(os.path.join(TRAIN_ROOT, 'pos', file_name), 'r', 'utf8') as reader: text = reader.read().strip() - dicts_generator(sentence=get_word_list_eng(text)) + dicts_generator(sentence=get_word_list_eng(text)) for file_name in train_neg_files: with codecs.open(os.path.join(TRAIN_ROOT, 'neg', file_name), 'r', 'utf8') as reader: text = reader.read().strip() - dicts_generator(sentence=get_word_list_eng(text)) + dicts_generator(sentence=get_word_list_eng(text)) word_dict, char_dict, max_word_len = dicts_generator(return_dict=True) print('Word dict size: %d Char dict size: %d Max word len: %d' % (len(word_dict), len(char_dict), max_word_len)) @@ -177,7 +177,7 @@ def test_batch_generator(batch_size=32): predicts = numpy.argmax(predicts, axis=-1).tolist() correct = 0 for i in range(len(predicts)): - if i % batch_size < batch_size: + if i % batch_size < batch_size // 2: expect = 1 else: expect = 0