
Changes to avoid memory leak in Rollout worker #161

Merged
merged 13 commits on Jan 4, 2019
3 changes: 2 additions & 1 deletion rl_coach/graph_managers/graph_manager.py
@@ -562,7 +562,8 @@ def restore_checkpoint(self):
            screen.warning("No checkpoint to restore in: {}".format(self.task_parameters.checkpoint_restore_dir))
        else:
            screen.log_title("Loading checkpoint: {}".format(checkpoint.model_checkpoint_path))
-            self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)
+            if not hasattr(self.agent_params.memory, 'memory_backend_params') or self.agent_params.memory.memory_backend_params.run_type != str(RunType.ROLLOUT_WORKER):
+                self.checkpoint_saver.restore(self.sess, checkpoint.model_checkpoint_path)

[manager.restore_checkpoint(self.task_parameters.checkpoint_restore_dir) for manager in self.level_managers]
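
For context, the hunk above makes the tf.train.Saver restore conditional: a rollout worker whose agent memory uses a distributed memory backend skips the Saver restore, since repeatedly restoring through that path is what leaked memory. Below is a minimal sketch of the same guard as a standalone helper; should_restore_with_saver is illustrative and not part of the diff, while RunType is the enum imported in the test changes further down.

from rl_coach.base_parameters import RunType

def should_restore_with_saver(agent_params):
    """Sketch of the guard added above (illustrative helper, not in the PR)."""
    memory = agent_params.memory
    if not hasattr(memory, 'memory_backend_params'):
        # Single-worker runs have no memory backend configured: restore as before.
        return True
    # Only the rollout-worker run type skips the tf.train.Saver restore.
    return memory.memory_backend_params.run_type != str(RunType.ROLLOUT_WORKER)

In the PR the condition stays inline at the restore call; the helper form just spells out the two cases the one-liner covers.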

26 changes: 25 additions & 1 deletion rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py
@@ -1,8 +1,10 @@
import os
import sys
+import gc
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import tensorflow as tf
-from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks
+from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks, RunType
+from rl_coach.memories.backend.memory import MemoryBackendParameters
from rl_coach.utils import get_open_port
from multiprocessing import Process
from tensorflow import logging
@@ -41,12 +43,34 @@ def test_basic_rl_graph_manager_with_cartpole_dqn():
                                                              experiment_path="./experiments/test"))
    # graph_manager.improve()

+# Test for identifying memory leak in restore_checkpoint
+@pytest.mark.unit_test
+def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore():
+    tf.reset_default_graph()
+    from rl_coach.presets.CartPole_DQN import graph_manager
+    assert graph_manager
+    graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow,
+                                                              experiment_path="./experiments/test",
+                                                              apply_stop_condition=True))
+    # graph_manager.improve()
+    # graph_manager.save_checkpoint()
+    #
+    # graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"
+    # graph_manager.agent_params.memory.register_var('memory_backend_params',
+    #                                                MemoryBackendParameters(store_type=None,
+    #                                                                        orchestrator_type=None,
+    #                                                                        run_type=str(RunType.ROLLOUT_WORKER)))
+    # while True:
+    #     graph_manager.restore_checkpoint()
+    #     gc.collect()
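
The commented-out tail above records the repro recipe: register 'memory_backend_params' with run_type=str(RunType.ROLLOUT_WORKER), then call restore_checkpoint() in an endless loop with gc.collect() and watch resident memory climb. As a sketch of how that recipe could become a self-checking assertion (hypothetical helper; psutil is an assumed extra dependency, and the 50 MB threshold is arbitrary):

import gc
import os

import psutil  # assumed third-party dependency, not used by the original test

def assert_no_rss_growth(fn, iterations=50, allowed_growth_mb=50.0):
    """Illustrative leak check, not part of the PR: a leak in `fn` shows up
    as steady RSS growth across repeated calls instead of a flat profile."""
    process = psutil.Process(os.getpid())
    gc.collect()
    baseline = process.memory_info().rss
    for _ in range(iterations):
        fn()
        gc.collect()
    growth_mb = (process.memory_info().rss - baseline) / float(2 ** 20)
    assert growth_mb < allowed_growth_mb, \
        "RSS grew by {:.1f} MB over {} calls".format(growth_mb, iterations)

Wired into the setup above, assert_no_rss_growth(graph_manager.restore_checkpoint) would replace the unbounded while True loop.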


if __name__ == '__main__':
    pass
    # test_basic_rl_graph_manager_with_pong_a3c()
    # test_basic_rl_graph_manager_with_ant_a3c()
    # test_basic_rl_graph_manager_with_pong_nec()
+    # test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore()
    # test_basic_rl_graph_manager_with_cartpole_dqn()
    #test_basic_rl_graph_manager_multithreaded_with_pong_a3c()
    #test_basic_rl_graph_manager_with_doom_basic_dqn()