Skip to content

Commit

Permalink
Refactor
Browse files — browse the repository at this point in the history
  • Loading branch information
zeynepakkalyoncu committed Dec 17, 2018
1 parent 59f0ac4 commit f1dbcf8
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
23 changes: 12 additions & 11 deletions lib/model/__main__.py
Expand Up @@ -21,12 +21,13 @@
root_dir = os.getcwd()

# Set GPU usage
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.log_device_placement = True
sess = tf.Session(config=config)
if not args.cpu:
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.log_device_placement = True
sess = tf.Session(config=config)

set_session(sess)
set_session(sess)

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.devices
Expand Down Expand Up @@ -82,9 +83,8 @@
else:
raise Exception("Unsupported dataset")

model = None
metrics.DATASET = args.dataset
metrics.TARGET_VOCAB = target_vocab
training_generator = WMTSequence(encoder_train_input, decoder_train_input, decoder_train_target, model_config)
validation_generator = WMTSequence(encoder_dev_input, decoder_dev_input, decoder_dev_target, model_config)

model_config = deepcopy(args)
source_vocab_size = len(source_vocab)
Expand All @@ -100,11 +100,12 @@
model_config.source_embedding_map = source_embedding_map
model_config.target_embedding_map = target_embedding_map

training_generator = WMTSequence(encoder_train_input, decoder_train_input, decoder_train_target, model_config)
validation_generator = WMTSequence(encoder_dev_input, decoder_dev_input, decoder_dev_target, model_config)
model = None
metrics.DATASET = args.dataset
metrics.TARGET_VOCAB = target_vocab

if args.cpu:
model = TinySeq2Seq(args)
model = TinySeq2Seq(model_config)
else:
model = Seq2Seq(model_config)

Expand Down
2 changes: 1 addition & 1 deletion lib/model/seq2seq.py
Expand Up @@ -14,7 +14,7 @@
from lib.model.util import lr_scheduler

def encode(config, recurrent_unit='lstm'):
initial_weights = RandomUniform(minval=-0.08, maxval=0.08, seed=self.config.seed)
initial_weights = RandomUniform(minval=-0.08, maxval=0.08, seed=config.seed)
encoder_inputs = Input(shape=(None, ))
encoder_embedding = Embedding(config.source_vocab_size, config.embedding_dim,
weights=[config.source_embedding_map], trainable=False)
Expand Down

0 comments on commit f1dbcf8

Please sign in to comment.