Skip to content
Permalink
Browse files

Added gpu_options.allow_growth option.

  • Loading branch information...
Armour committed Jun 29, 2018
1 parent 7ec60cc commit 98b71724df77ff78996823e7027f89b15985298d
Showing with 8 additions and 4 deletions.
  1. +4 −2 test.py
  2. +4 −2 train.py
@@ -16,12 +16,14 @@
# Init model.
is_training, global_step, _, loss, predict_rgb, color_image_rgb, gray_image, file_paths = init_model(train=False)

# Init scaffold and hooks.
# Init scaffold, hooks and config.
scaffold = tf.train.Scaffold()
summary_hook = tf.train.SummarySaverHook(output_dir=testing_summary, save_steps=display_step, scaffold=scaffold)
checkpoint_hook = tf.train.CheckpointSaverHook(checkpoint_dir=summary_path, save_steps=saving_step, scaffold=scaffold)
num_step_hook = tf.train.StopAtStepHook(num_steps=len(file_paths))
session_creator = tf.train.ChiefSessionCreator(config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True))
config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
config.gpu_options.allow_growth = True # pylint: disable=E1101
session_creator = tf.train.ChiefSessionCreator(scaffold=scaffold, config=config, checkpoint_dir=summary_path)

# Create a session for running operations in the Graph.
with tf.train.MonitoredSession(session_creator=session_creator, hooks=[checkpoint_hook, summary_hook]) as sess:
@@ -16,17 +16,19 @@
# Init model.
is_training, global_step, optimizer, loss, predict_rgb, color_image_rgb, gray_image, _ = init_model(train=True)

# Init scaffold and hooks.
# Init scaffold, hooks and config.
scaffold = tf.train.Scaffold()
summary_hook = tf.train.SummarySaverHook(output_dir=training_summary, save_steps=display_step, scaffold=scaffold)
checkpoint_hook = tf.train.CheckpointSaverHook(checkpoint_dir=summary_path, save_steps=saving_step, scaffold=scaffold)
num_step_hook = tf.train.StopAtStepHook(num_steps=training_iters)
config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)
config.gpu_options.allow_growth = True # pylint: disable=E1101

# Create a session for running operations in the Graph.
with tf.train.MonitoredTrainingSession(checkpoint_dir=summary_path,
hooks=[summary_hook, checkpoint_hook, num_step_hook],
scaffold=scaffold,
config=tf.ConfigProto(allow_soft_placement=True, log_device_placement=True)) as sess:
config=config) as sess:
print("🤖 Start training...")

while not sess.should_stop():

0 comments on commit 98b7172

Please sign in to comment.
You can’t perform that action at this time.