diff --git a/ml-agents/mlagents/trainers/trainer_controller.py b/ml-agents/mlagents/trainers/trainer_controller.py index 870ce7a813..550a514a1d 100644 --- a/ml-agents/mlagents/trainers/trainer_controller.py +++ b/ml-agents/mlagents/trainers/trainer_controller.py @@ -30,6 +30,11 @@ from mlagents.trainers.agent_processor import AgentManager from mlagents.tf_utils.globals import get_rank +try: + import torch +except ModuleNotFoundError: + torch = None # type: ignore + class TrainerController: def __init__( @@ -66,6 +71,8 @@ def __init__( self.kill_trainers = False np.random.seed(training_seed) tf.set_random_seed(training_seed) + if torch is not None: + torch.manual_seed(training_seed) self.rank = get_rank() @timed