From 459ba6fde5cf6f846bf166bac39d5cad75bab4db Mon Sep 17 00:00:00 2001 From: "chenbangduo.cbd" Date: Thu, 16 Mar 2023 11:24:24 +0800 Subject: [PATCH] [CheckpointSaver] Optimize the time consumption of CheckpointSaver when creating a MonitoredSession. --- .../training/basic_session_run_hooks.py | 47 ++++++++++++++----- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py index eff7c25b390..98f570504a1 100644 --- a/tensorflow/python/training/basic_session_run_hooks.py +++ b/tensorflow/python/training/basic_session_run_hooks.py @@ -25,6 +25,7 @@ import os import time +import threading import numpy as np import six @@ -555,6 +556,15 @@ def __init__(self, self._steps_per_run = 1 self._incremental_save_secs = incremental_save_secs self._incremental_save = self._incremental_save_secs is not None + + self._optimize_start_time = False + env_key = 'OPTIMIZE_START_TIME' + optimize_start_time = os.getenv(env_key, default='false') + if optimize_start_time.lower() == 'true': + logging.info("CheckpointSaver: Optimize start time") + self._optimize_start_time = True + self._thread = None + logging.info("Init incremental saver , incremental_save:%s, incremental_path:%s", str(self._incremental_save), str(self._incremental_save_path)) def _set_steps_per_run(self, steps_per_run): @@ -574,17 +584,28 @@ def after_create_session(self, session, coord): # We do write graph and saver_def at the first call of before_run. # We cannot do this in begin, since we let other hooks to change graph and # add variables in begin. Graph is finalized after all begin calls. - training_util.write_graph( - ops.get_default_graph().as_graph_def(add_shapes=True), - self._checkpoint_dir, "graph.pbtxt") - saver_def = self._get_saver().saver_def if self._get_saver() else None - graph = ops.get_default_graph() - meta_graph_def = meta_graph.create_meta_graph_def( - graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def) - self._summary_writer.add_graph(graph) - self._summary_writer.add_meta_graph(meta_graph_def) - # The checkpoint saved here is the state at step "global_step". - self._save(session, global_step) + def write_graph_pbtxt(graph, checkpoint_dir): + training_util.write_graph(graph, checkpoint_dir, "graph.pbtxt") + logging.info("CheckpointSaver: Finish write graph") + + if not self._optimize_start_time: + write_graph_pbtxt(ops.get_default_graph().as_graph_def(add_shapes=True), + self._checkpoint_dir) + saver_def = self._get_saver().saver_def if self._get_saver() else None + graph = ops.get_default_graph() + meta_graph_def = meta_graph.create_meta_graph_def( + graph_def=graph.as_graph_def(add_shapes=True), + saver_def=saver_def) + self._summary_writer.add_graph(graph) + self._summary_writer.add_meta_graph(meta_graph_def) + # The checkpoint saved here is the state at step "global_step". + self._save(session, global_step) + else: + self._thread = threading.Thread(target=write_graph_pbtxt, + args=(ops.get_default_graph().as_graph_def(add_shapes=True), self._checkpoint_dir,)) + self._thread.setDaemon(True) + self._thread.start() + self._timer.update_last_triggered_step(global_step) self._incremental_save = self._incremental_save and self._get_incr_saver() is not None logging.info("Create incremental timer, incremental_save:%s, incremental_save_secs:%s", str(self._incremental_save), str(self._incremental_save_secs)) @@ -622,6 +643,10 @@ def end(self, session): for l in self._listeners: l.end(session, last_step) + if self._thread: + self._thread.join() + self._thread = None + def _save(self, session, step): """Saves the latest checkpoint, returns should_stop.""" logging.info("Saving checkpoints for %d into %s.", step, self._save_path)