diff --git a/ml-agents/mlagents/trainers/trainer_controller.py b/ml-agents/mlagents/trainers/trainer_controller.py index 06f8736aac..be53dcee89 100644 --- a/ml-agents/mlagents/trainers/trainer_controller.py +++ b/ml-agents/mlagents/trainers/trainer_controller.py @@ -291,7 +291,10 @@ def start_learning(self): tf.reset_default_graph() - with tf.Session() as sess: + # Prevent a single session from taking all GPU memory. + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + with tf.Session(config=config) as sess: self._initialize_trainers(trainer_config, sess) for _, t in self.trainers.items(): self.logger.info(t)