Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CheckpointSaver] Optimize the time consumption of CheckpointSaver wh… #756

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 36 additions & 11 deletions tensorflow/python/training/basic_session_run_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import os
import time
import threading

import numpy as np
import six
Expand Down Expand Up @@ -555,6 +556,15 @@ def __init__(self,
self._steps_per_run = 1
self._incremental_save_secs = incremental_save_secs
self._incremental_save = self._incremental_save_secs is not None

self._optimize_start_time = False
env_key = 'OPTIMIZE_START_TIME'
optimize_start_time = os.getenv(env_key, default='false')
if optimize_start_time.lower() == 'true':
logging.info("CheckpointSaver: Optimize start time")
self._optimize_start_time = True
self._thread = None

logging.info("Init incremental saver , incremental_save:%s, incremental_path:%s", str(self._incremental_save), str(self._incremental_save_path))

def _set_steps_per_run(self, steps_per_run):
Expand All @@ -574,17 +584,28 @@ def after_create_session(self, session, coord):
# We do write graph and saver_def at the first call of before_run.
# We cannot do this in begin, since we let other hooks to change graph and
# add variables in begin. Graph is finalized after all begin calls.
training_util.write_graph(
ops.get_default_graph().as_graph_def(add_shapes=True),
self._checkpoint_dir, "graph.pbtxt")
saver_def = self._get_saver().saver_def if self._get_saver() else None
graph = ops.get_default_graph()
meta_graph_def = meta_graph.create_meta_graph_def(
graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
self._summary_writer.add_graph(graph)
self._summary_writer.add_meta_graph(meta_graph_def)
# The checkpoint saved here is the state at step "global_step".
self._save(session, global_step)
def write_graph_pbtxt(graph_def, out_dir):
    """Serialize `graph_def` into `out_dir` as the text file graph.pbtxt.

    Used both inline and as a background-thread target (see the
    `_optimize_start_time` branch), so it must stay self-contained.
    """
    training_util.write_graph(graph_def, out_dir, "graph.pbtxt")
    logging.info("CheckpointSaver: Finish write graph")

if not self._optimize_start_time:
write_graph_pbtxt(ops.get_default_graph().as_graph_def(add_shapes=True),
self._checkpoint_dir)
saver_def = self._get_saver().saver_def if self._get_saver() else None
graph = ops.get_default_graph()
meta_graph_def = meta_graph.create_meta_graph_def(
graph_def=graph.as_graph_def(add_shapes=True),
saver_def=saver_def)
self._summary_writer.add_graph(graph)
self._summary_writer.add_meta_graph(meta_graph_def)
# The checkpoint saved here is the state at step "global_step".
self._save(session, global_step)
else:
self._thread = threading.Thread(target=write_graph_pbtxt,
args=(ops.get_default_graph().as_graph_def(add_shapes=True), self._checkpoint_dir,))
self._thread.setDaemon(True)
self._thread.start()

self._timer.update_last_triggered_step(global_step)
self._incremental_save = self._incremental_save and self._get_incr_saver() is not None
logging.info("Create incremental timer, incremental_save:%s, incremental_save_secs:%s", str(self._incremental_save), str(self._incremental_save_secs))
Expand Down Expand Up @@ -622,6 +643,10 @@ def end(self, session):
for l in self._listeners:
l.end(session, last_step)

if self._thread:
self._thread.join()
self._thread = None

def _save(self, session, step):
"""Saves the latest checkpoint, returns should_stop."""
logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
Expand Down