Skip to content
This repository has been archived by the owner on Dec 11, 2022. It is now read-only.

Changes to avoid memory leak in Rollout worker #161

Merged
merged 13 commits into from Jan 4, 2019
16 changes: 14 additions & 2 deletions rl_coach/architectures/tensorflow_components/savers.py
Expand Up @@ -32,6 +32,16 @@ def __init__(self, name):
# target network is never saved or restored directly from checkpoint, so we are removing all its variables from the list
# the target network would be synched back from the online network in graph_manager.improve(...), at the beginning of the run flow.
self._variables = [v for v in self._variables if '/target' not in v.name]

# Using a placeholder to update the variable during restore to avoid memory leak.
# Ref: https://github.com/tensorflow/tensorflow/issues/4151
self._variable_placeholders = []
self._variable_update_ops = []
for v in self._variables:
variable_placeholder = tf.placeholder(v.dtype, shape=v.get_shape())
self._variable_placeholders.append(variable_placeholder)
self._variable_update_ops.append(v.assign(variable_placeholder))

x77a1 marked this conversation as resolved.
Show resolved Hide resolved
self._saver = tf.train.Saver(self._variables)

@property
Expand Down Expand Up @@ -66,8 +76,10 @@ def restore(self, sess: Any, restore_path: str):
# TODO: Can this be more generic so that `global/` and `online/` are not hardcoded here?
new_name = var_name.replace('global/', 'online/')
variables[new_name] = reader.get_tensor(var_name)
# Assign all variables
sess.run([v.assign(variables[v.name.split(':')[0]]) for v in self._variables])

# Assign all variables using placeholder
placeholder_dict = {ph: variables[v.name.split(':')[0]] for ph, v in zip(self._variable_placeholders, self._variables)}
sess.run(self._variable_update_ops, placeholder_dict)

def merge(self, other: 'Saver'):
"""
Expand Down
21 changes: 21 additions & 0 deletions rl_coach/tests/graph_managers/test_basic_rl_graph_manager.py
@@ -1,8 +1,10 @@
import gc
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
import tensorflow as tf
from rl_coach.base_parameters import TaskParameters, DistributedTaskParameters, Frameworks
from rl_coach.core_types import EnvironmentSteps
from rl_coach.utils import get_open_port
from multiprocessing import Process
from tensorflow import logging
Expand Down Expand Up @@ -41,12 +43,31 @@ def test_basic_rl_graph_manager_with_cartpole_dqn():
experiment_path="./experiments/test"))
# graph_manager.improve()

# Test for identifying memory leak in restore_checkpoint
@pytest.mark.unit_test
def test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore():
tf.reset_default_graph()
from rl_coach.presets.CartPole_DQN import graph_manager
assert graph_manager
graph_manager.create_graph(task_parameters=TaskParameters(framework_type=Frameworks.tensorflow,
experiment_path="./experiments/test",
apply_stop_condition=True))
# graph_manager.improve()
zach-nervana marked this conversation as resolved.
Show resolved Hide resolved
# graph_manager.evaluate(EnvironmentSteps(1000))
# graph_manager.save_checkpoint()
#
# graph_manager.task_parameters.checkpoint_restore_dir = "./experiments/test/checkpoint"
# while True:
# graph_manager.restore_checkpoint()
# graph_manager.evaluate(EnvironmentSteps(1000))
# gc.collect()

if __name__ == '__main__':
pass
# test_basic_rl_graph_manager_with_pong_a3c()
# test_basic_rl_graph_manager_with_ant_a3c()
# test_basic_rl_graph_manager_with_pong_nec()
# test_basic_rl_graph_manager_with_cartpole_dqn()
# test_basic_rl_graph_manager_with_cartpole_dqn_and_repeated_checkpoint_restore()
#test_basic_rl_graph_manager_multithreaded_with_pong_a3c()
#test_basic_rl_graph_manager_with_doom_basic_dqn()