@@ -1,31 +1,10 @@
 import os
-import pickle
-import argparse
 import tensorflow as tf
 import time
+from model import Model
+from utils import build_dict, build_dataset, batch_iter
 start = time.perf_counter()
-#from model import Model
-#from utils import build_dict, build_dataset, batch_iter
-
-# Uncomment next 2 lines to suppress error and Tensorflow info verbosity. Or change logging levels
-# tf.logging.set_verbosity(tf.logging.FATAL)
-# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
-
-# def add_arguments(parser):
-#     parser.add_argument("--num_hidden", type=int, default=150, help="Network size.")
-#     parser.add_argument("--num_layers", type=int, default=2, help="Network depth.")
-#     parser.add_argument("--beam_width", type=int, default=10, help="Beam width for beam search decoder.")
-#     parser.add_argument("--glove", action="store_true", help="Use glove as initial word embedding.")
-#     parser.add_argument("--embedding_size", type=int, default=300, help="Word embedding size.")
-#
-#     parser.add_argument("--learning_rate", type=float, default=1e-3, help="Learning rate.")
-#     parser.add_argument("--batch_size", type=int, default=64, help="Batch size.")
-#     parser.add_argument("--num_epochs", type=int, default=10, help="Number of epochs.")
-#     parser.add_argument("--keep_prob", type=float, default=0.8, help="Dropout keep prob.")
-#
-#     parser.add_argument("--toy", action="store_true", help="Use only 50K samples of data")
-#
-#     parser.add_argument("--with_model", action="store_true", help="Continue from previously saved model")
+default_path = '.'


 class args:
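The commit replaces the commented-out argparse setup with a plain `class args:` used as an attribute namespace; its assignments are elided between this hunk and the next. A minimal sketch of that pattern, with values mirrored from the old argparse defaults above (illustrative, not the repo's exact body):

    class args:
        pass

    # Module-level attribute assignments stand in for parsed CLI flags.
    args.num_hidden = 150           # old argparse default, "Network size."
    args.batch_size = 64            # old argparse default, "Batch size."
    args.with_model = "store_true"  # a truthy string, so `if args.with_model:` is always true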
@@ -48,20 +27,16 @@ class args:
 args.with_model = "store_true"


-#parser = argparse.ArgumentParser()
-# add_arguments(parser)
-#args = parser.parse_args()
-# with open("args.pickle", "wb") as f:
-#     pickle.dump(args, f)
-
 if not os.path.exists(default_path + "saved_model"):
     os.mkdir(default_path + "saved_model")
 else:
     # if args.with_model:
     old_model_checkpoint_path = open(
         default_path + 'saved_model/checkpoint', 'r')
     old_model_checkpoint_path = "".join(
-        [default_path + "saved_model/", old_model_checkpoint_path.read().splitlines()[0].split('"')[1]])
+        [
+            default_path + "saved_model/",
+            old_model_checkpoint_path.read().splitlines()[0].split('"')[1]])


 print("Building dictionary...")
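For context on the string surgery above: the first line of a TensorFlow 1.x `checkpoint` file written by `tf.train.Saver` has the form `model_checkpoint_path: "model-12345"`, so splitting on `"` and taking index 1 recovers the latest checkpoint name. A standalone sketch of the same parse (the file contents are assumed, not taken from this repo):

    # First line of a TF1 checkpoint file: model_checkpoint_path: "model-12345"
    with open(default_path + "saved_model/checkpoint", "r") as f:
        first_line = f.read().splitlines()[0]
    ckpt_name = first_line.split('"')[1]  # text between the first pair of quotes
    ckpt_path = default_path + "saved_model/" + ckpt_name

The same path can also be obtained with `tf.train.latest_checkpoint(default_path + "saved_model")`, which parses this file internally.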
@@ -98,9 +73,13 @@ class args:
             map(lambda x: list(x) + [word_dict["</s>"]], batch_y))

         batch_decoder_input = list(
-            map(lambda d: d + (summary_max_len - len(d)) * [word_dict["<padding>"]], batch_decoder_input))
+            map(
+                lambda d: d + (summary_max_len - len(d)) * [word_dict["<padding>"]],
+                batch_decoder_input))
         batch_decoder_output = list(
-            map(lambda d: d + (summary_max_len - len(d)) * [word_dict["<padding>"]], batch_decoder_output))
+            map(
+                lambda d: d + (summary_max_len - len(d)) * [word_dict["<padding>"]],
+                batch_decoder_output))

         train_feed_dict = {
             model.batch_size: len(batch_x),
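The reflowed `map` calls above right-pad every decoder sequence with `<padding>` ids up to `summary_max_len`, so each batch is rectangular before being fed to the graph. A toy illustration with hypothetical ids and `summary_max_len = 5`:

    word_dict = {"<padding>": 0, "</s>": 1}  # hypothetical ids
    summary_max_len = 5
    batch = [[7, 8, 1], [7, 1]]
    padded = list(
        map(lambda d: d + (summary_max_len - len(d)) * [word_dict["<padding>"]],
            batch))
    print(padded)  # [[7, 8, 1, 0, 0], [7, 1, 0, 0, 0]]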
@@ -112,7 +91,8 @@ class args:
         }

         _, step, loss = sess.run(
-            [model.update, model.global_step, model.loss], feed_dict=train_feed_dict)
+            [model.update,
+             model.global_step, model.loss], feed_dict=train_feed_dict)

         if step % 1000 == 0:
             print("step {0}: loss = {1}".format(step, loss))
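For readers newer to the TF1 graph API used here: a single `sess.run` call evaluates all three fetches in one graph execution, so the optimizer update, the incremented global step, and the loss come back together. A self-contained sketch of the same pattern on a trivial graph (names are illustrative; assumes the TensorFlow 1.x API, matching the `import tensorflow as tf` above):

    import tensorflow as tf

    x = tf.Variable(5.0)
    loss = tf.square(x)
    global_step = tf.Variable(0, trainable=False)
    update = tf.train.GradientDescentOptimizer(0.1).minimize(
        loss, global_step=global_step)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for _ in range(3):
            # One run() returns all three fetches from a single execution.
            _, step, cur_loss = sess.run([update, global_step, loss])
            print("step {0}: loss = {1}".format(step, cur_loss))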