@@ -55,8 +55,8 @@ def _register_policy(self, policy: TFPolicy) -> None:
         with self.policy.graph.as_default():
             self.tf_saver = tf.train.Saver(max_to_keep=self._keep_checkpoints)

-    def save_checkpoint(self, behavior_name: str, step: int) -> str:
-        checkpoint_path = os.path.join(self.model_path, f"{behavior_name}-{step}")
+    def save_checkpoint(self, brain_name: str, step: int) -> str:
+        checkpoint_path = os.path.join(self.model_path, f"{brain_name}-{step}")
         # Save the TF checkpoint and graph definition
         if self.graph:
             with self.graph.as_default():
@@ -66,16 +66,16 @@ def save_checkpoint(self, behavior_name: str, step: int) -> str:
                     self.graph, self.model_path, "raw_graph_def.pb", as_text=False
                 )
         # also save the policy so we have optimized model files for each checkpoint
-        self.export(checkpoint_path, behavior_name)
+        self.export(checkpoint_path, brain_name)
         return checkpoint_path

-    def export(self, output_filepath: str, behavior_name: str) -> None:
+    def export(self, output_filepath: str, brain_name: str) -> None:
         # save model if there is only one worker or
         # only on worker-0 if there are multiple workers
         if self.policy and self.policy.rank is not None and self.policy.rank != 0:
             return
         export_policy_model(
-            self.model_path, output_filepath, behavior_name, self.graph, self.sess
+            self.model_path, output_filepath, brain_name, self.graph, self.sess
         )

     def initialize_or_load(self, policy: Optional[TFPolicy] = None) -> None:
@@ -94,7 +94,6 @@ def initialize_or_load(self, policy: Optional[TFPolicy] = None) -> None:
             self._load_graph(policy, self.model_path, reset_global_steps=reset_steps)
         else:
             policy.initialize()
-
         TFPolicy.broadcast_global_variables(0)

     def _load_graph(
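
For reference, a minimal usage sketch of the renamed methods, assuming a model-saver object that defines them; the model_saver variable and the brain name below are hypothetical, and only the save_checkpoint and export signatures come from the diff itself:

    # Hypothetical illustration: model_saver stands in for the object whose
    # methods are changed above, and "3DBall" is a made-up brain name.
    checkpoint_path = model_saver.save_checkpoint("3DBall", step=50000)
    # save_checkpoint already calls export internally; export can also be
    # invoked directly to write optimized model files for an existing checkpoint.
    model_saver.export(checkpoint_path, "3DBall")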