Skip to content

Commit

Permalink
Fixes post-Evaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhernandezgarcia committed Jun 16, 2024
1 parent cd94357 commit 23aef3d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
2 changes: 2 additions & 0 deletions gflownet/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ def load_gflow_net_from_run_path(
if forward_final is None:
print("Warning: no forward policy checkpoint found")
else:
print(f"\nLoading forward policy checkpoint: {str(forward_final)}")
gflownet.forward_policy.model.load_state_dict(
torch.load(forward_final, map_location=set_device(device))
)
Expand All @@ -393,6 +394,7 @@ def load_gflow_net_from_run_path(
if backward_final is None:
print("Warning: no backward policy checkpoint found")
else:
print(f"Loading backward policy checkpoint: {str(backward_final)}\n")
gflownet.backward_policy.model.load_state_dict(
torch.load(backward_final, map_location=set_device(device))
)
Expand Down
23 changes: 12 additions & 11 deletions scripts/eval_gflownet.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def main(args):
# ---------------------------------

if not args.samples_only:
gflownet.logger.test.n = args.n_samples
gflownet.evaluator.n = args.n_samples
eval_results = gflownet.evaluator.eval()

# TODO-V: legacy -> ok to remove?
Expand All @@ -184,16 +184,18 @@ def main(args):
print("output_dir: ", str(output_dir))
output_dir.mkdir(parents=True, exist_ok=True)

for figname, fig in eval_results["figs"].items():
output_fig = output_dir / (path_compatible(figname) + ".pdf")
if fig is not None:
fig.savefig(output_fig, bbox_inches="tight")
print(f"Saved figures to {output_dir}")
if "figs" in eval_results:
for figname, fig in eval_results["figs"].items():
output_fig = output_dir / (path_compatible(figname) + ".pdf")
if fig is not None:
fig.savefig(output_fig, bbox_inches="tight")
print(f"Saved figures to {output_dir}")

# Print metrics
print("Metrics:")
for k, v in eval_results["metrics"].items():
print(f"\t{k}: {v:.4f}")
if "metrics" in eval_results:
print("Metrics:")
for k, v in eval_results["metrics"].items():
print(f"\t{k}: {v:.4f}")

# ------------------------------------------
# ----- Sample GFlowNet -----
Expand All @@ -209,7 +211,6 @@ def main(args):
config_cond_env = config_cond_env.env
env_cond = instantiate(
config_cond_env,
proxy=env.proxy,
device=config.device,
float_precision=config.float_precision,
)
Expand All @@ -234,7 +235,7 @@ def main(args):
n_forward=bs, env_cond=env_cond, train=False
)
x_sampled = batch.get_terminating_states(proxy=True)
energies = env.proxy(x_sampled)
energies = gflownet.proxy(x_sampled)
x_sampled = batch.get_terminating_states()
df = pd.DataFrame(
{
Expand Down

0 comments on commit 23aef3d

Please sign in to comment.