diff --git a/rl_coach/architectures/tensorflow_components/savers.py b/rl_coach/architectures/tensorflow_components/savers.py
index 38a36ee94..67c0c8b67 100644
--- a/rl_coach/architectures/tensorflow_components/savers.py
+++ b/rl_coach/architectures/tensorflow_components/savers.py
@@ -32,6 +32,16 @@ def __init__(self, name):
         # target network is never saved or restored directly from checkpoint, so we are removing all its variables from the list
         # the target network would be synched back from the online network in graph_manager.improve(...), at the beginning of the run flow.
         self._variables = [v for v in self._variables if '/target' not in v.name]
+
+        # Use a placeholder to update each variable during restore, to avoid a memory leak.
+        # Ref: https://github.com/tensorflow/tensorflow/issues/4151
+        self._variable_placeholders = []
+        self._variable_update_ops = []
+        for v in self._variables:
+            variable_placeholder = tf.placeholder(v.dtype, shape=v.get_shape())
+            self._variable_placeholders.append(variable_placeholder)
+            self._variable_update_ops.append(v.assign(variable_placeholder))
+
         self._saver = tf.train.Saver(self._variables)
 
     @property
@@ -66,8 +76,10 @@ def restore(self, sess: Any, restore_path: str):
             # TODO: Can this be more generic so that `global/` and `online/` are not hardcoded here?
             new_name = var_name.replace('global/', 'online/')
             variables[new_name] = reader.get_tensor(var_name)
-        # Assign all variables
-        sess.run([v.assign(variables[v.name.split(':')[0]]) for v in self._variables])
+
+        # Assign all variables using placeholders
+        placeholder_dict = {ph: variables[v.name.split(':')[0]] for ph, v in zip(self._variable_placeholders, self._variables)}
+        sess.run(self._variable_update_ops, placeholder_dict)
 
     def merge(self, other: 'Saver'):
         """
diff --git a/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py b/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py
index 214ef31e3..fbdc09803 100644
--- a/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py
+++ b/rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py
@@ -1,8 +1,10 @@
+import gc
 import os
 import sys
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
 import tensorflow as tf
 from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks
+from rl_coach.core_types import EnvironmentSteps
 from rl_coach.utils import get_open_port
 from multiprocessing import Process
 from tensorflow import logging
@@ -41,6 +43,24 @@ def test_basic_rl_graph_manager_with_cartpole_dqn():
                                                               experiment_path="./experiments/test"))
     # graph_manager.improve()
 
+# Test for reproducing the memory leak in restore_checkpoint
+@pytest.mark.unit_test
+def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore():
+    tf.reset_default_graph()
+    from rl_coach.presets.CartPole_DQN import graph_manager
+    assert graph_manager
+    graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow,
+                                                              experiment_path="./experiments/test",
+                                                              apply_stop_condition=True))
+    # graph_manager.improve()
+    # graph_manager.evaluate(EnvironmentSteps(1000))
+    # graph_manager.save_checkpoint()
+    #
+    # graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"
+    # while True:
+    #     graph_manager.restore_checkpoint()
+    #     graph_manager.evaluate(EnvironmentSteps(1000))
+    #     gc.collect()
 
 if __name__ == '__main__':
     pass
@@ -48,5 +68,6 @@ def test_basic_rl_graph_manager_with_cartpole_dqn():
     # test_basic_rl_graph_manager_with_ant_a3c()
     # test_basic_rl_graph_manager_with_pong_nec()
     # test_basic_rl_graph_manager_with_cartpole_dqn()
+    # test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore()
     #test_basic_rl_graph_manager_multithreaded_with_pong_a3c()
     #test_basic_rl_graph_manager_with_doom_basic_dqn()
\ No newline at end of file
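
Reviewer note, not part of the patch: in TF 1.x graph mode, sess.run(v.assign(value)) builds a fresh assign op (plus a Const holding `value`) on every call, so restoring in a loop grows the graph without bound; that is the leak tracked in tensorflow/tensorflow#4151. Building one placeholder and one assign op per variable up front, and feeding checkpoint values through the placeholders, keeps the graph at a fixed size. A minimal standalone sketch of both patterns (the helper names restore_leaky/restore_fixed are illustrative, not from the coach codebase):

    import numpy as np
    import tensorflow as tf

    v = tf.Variable(np.zeros(3, dtype=np.float32))

    def restore_leaky(sess, value):
        # Builds a NEW assign op (and a Const) in the default graph on every call.
        sess.run(v.assign(value))

    # Fixed pattern: construct the placeholder and assign op exactly once...
    value_ph = tf.placeholder(v.dtype.base_dtype, shape=v.get_shape())
    update_op = v.assign(value_ph)

    def restore_fixed(sess, value):
        # ...and only feed data at restore time; no new ops are created.
        sess.run(update_op, feed_dict={value_ph: value})

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ops_before = len(tf.get_default_graph().get_operations())
        for i in range(5):
            restore_leaky(sess, np.full(3, i, dtype=np.float32))
        ops_after_leaky = len(tf.get_default_graph().get_operations())
        for i in range(5):
            restore_fixed(sess, np.full(3, i, dtype=np.float32))
        ops_after_fixed = len(tf.get_default_graph().get_operations())
        # The leaky loop adds ops on every iteration; the fixed loop adds none.
        print(ops_after_leaky - ops_before, ops_after_fixed - ops_after_leaky)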
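
The commented-out body of the new test sketches a manual repro: save a checkpoint once, then restore it in a loop while watching process memory, with gc.collect() ruling out ordinary Python garbage. A lighter-weight automated check, hypothetical and not in the patch, would assert that the default graph stops growing across repeated restores:

    import tensorflow as tf

    # Hypothetical helper: graph_manager is assumed to be a coach GraphManager
    # whose task_parameters.checkpoint_restore_dir already points at a saved
    # checkpoint, as in the commented-out test body above.
    def assert_restore_does_not_grow_graph(graph_manager, repeats=5):
        graph_manager.restore_checkpoint()  # warm-up; may add ops lazily
        baseline = len(tf.get_default_graph().get_operations())
        for _ in range(repeats):
            graph_manager.restore_checkpoint()
        assert len(tf.get_default_graph().get_operations()) == baseline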