diff --git a/gflownet/utils/common.py b/gflownet/utils/common.py index 739b6b38c..5c29bc238 100644 --- a/gflownet/utils/common.py +++ b/gflownet/utils/common.py @@ -385,6 +385,7 @@ def load_gflow_net_from_run_path( if forward_final is None: print("Warning: no forward policy checkpoint found") else: + print(f"\nLoading forward policy checkpoint: {str(forward_final)}") gflownet.forward_policy.model.load_state_dict( torch.load(forward_final, map_location=set_device(device)) ) @@ -393,6 +394,7 @@ def load_gflow_net_from_run_path( if backward_final is None: print("Warning: no backward policy checkpoint found") else: + print(f"Loading backward policy checkpoint: {str(backward_final)}\n") gflownet.backward_policy.model.load_state_dict( torch.load(backward_final, map_location=set_device(device)) ) diff --git a/scripts/eval_gflownet.py b/scripts/eval_gflownet.py index b2bcdc0fe..fcb509d4e 100644 --- a/scripts/eval_gflownet.py +++ b/scripts/eval_gflownet.py @@ -173,7 +173,7 @@ def main(args): # --------------------------------- if not args.samples_only: - gflownet.logger.test.n = args.n_samples + gflownet.evaluator.n = args.n_samples eval_results = gflownet.evaluator.eval() # TODO-V: legacy -> ok to remove? @@ -184,16 +184,18 @@ def main(args): print("output_dir: ", str(output_dir)) output_dir.mkdir(parents=True, exist_ok=True) - for figname, fig in eval_results["figs"].items(): - output_fig = output_dir / (path_compatible(figname) + ".pdf") - if fig is not None: - fig.savefig(output_fig, bbox_inches="tight") - print(f"Saved figures to {output_dir}") + if "figs" in eval_results: + for figname, fig in eval_results["figs"].items(): + output_fig = output_dir / (path_compatible(figname) + ".pdf") + if fig is not None: + fig.savefig(output_fig, bbox_inches="tight") + print(f"Saved figures to {output_dir}") # Print metrics - print("Metrics:") - for k, v in eval_results["metrics"].items(): - print(f"\t{k}: {v:.4f}") + if "metrics" in eval_results: + print("Metrics:") + for k, v in eval_results["metrics"].items(): + print(f"\t{k}: {v:.4f}") # ------------------------------------------ # ----- Sample GFlowNet ----- @@ -209,7 +211,6 @@ def main(args): config_cond_env = config_cond_env.env env_cond = instantiate( config_cond_env, - proxy=env.proxy, device=config.device, float_precision=config.float_precision, ) @@ -234,7 +235,7 @@ def main(args): n_forward=bs, env_cond=env_cond, train=False ) x_sampled = batch.get_terminating_states(proxy=True) - energies = env.proxy(x_sampled) + energies = gflownet.proxy(x_sampled) x_sampled = batch.get_terminating_states() df = pd.DataFrame( {