Skip to content

Commit

Permalink
added mode counting
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Jun 18, 2024
1 parent a2721d3 commit c57a708
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
36 changes: 35 additions & 1 deletion tutorials/examples/train_hypergrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ def main(args): # noqa: C901
states_visited = 0
n_iterations = args.n_trajectories // args.batch_size
validation_info = {"l1_dist": float("inf")}
discovered_modes = set()

for iteration in trange(n_iterations):
trajectories = gflownet.sample_trajectories(
env,
Expand Down Expand Up @@ -254,11 +256,12 @@ def main(args): # noqa: C901
if use_wandb:
wandb.log(to_log, step=iteration)
if (iteration % args.validation_interval == 0) or (iteration == n_iterations - 1):
validation_info = validate(
validation_info, discovered_modes = validate_hypergrid(
env,
gflownet,
args.validation_samples,
visited_terminating_states,
discovered_modes,
)
if use_wandb:
wandb.log(validation_info, step=iteration)
Expand All @@ -268,6 +271,37 @@ def main(args): # noqa: C901
return validation_info["l1_dist"]


def validate_hypergrid(
env,
gflownet,
n_validation_samples,
visited_terminating_states,
discovered_modes,
):
validation_info = validate( # Standard validation shared across envs.
env,
gflownet,
n_validation_samples,
visited_terminating_states,
)

# Add the mode counting metric.
states, scale = visited_terminating_states.tensor, env.scale_factor

normalized_states = ((states * scale) - (scale / 2) * (env.height - 1)).abs()

modes = torch.all(
(normalized_states > (0.3 * scale) * (env.height - 1))
& (normalized_states <= (0.4 * scale) * (env.height - 1)),
dim=-1,
)
modes_found = set([tuple(s.tolist()) for s in states[modes.bool()]])
discovered_modes.update(modes_found)
validation_info["n_modes_found"] = len(discovered_modes)

return validation_info, discovered_modes


if __name__ == "__main__":
parser = ArgumentParser()

Expand Down
7 changes: 5 additions & 2 deletions tutorials/examples/train_hypergrid_multinode.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from gfn.modules import DiscretePolicyEstimator, ScalarEstimator
from gfn.utils.common import set_seed
from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular
from gfn.utils.training import validate
from .train_hypergrid import validate_hypergrid

DEFAULT_SEED = 4444

Expand Down Expand Up @@ -292,6 +292,8 @@ def main(args): # noqa: C901
print ("n_iterations = ", n_iterations)
print ("my_batch_size = ", my_batch_size)
time_start = time.time()
discovered_modes = set()

for iteration in trange(n_iterations):
sample_start = time.time()
trajectories = gflownet.sample_trajectories(
Expand Down Expand Up @@ -333,11 +335,12 @@ def main(args): # noqa: C901
if use_wandb:
wandb.log(to_log, step=iteration)
if (iteration % args.validation_interval == 0) or (iteration == n_iterations - 1):
validation_info = validate(
validation_info, discovered_modes = validate_hypergrid(
env,
gflownet,
args.validation_samples,
visited_terminating_states,
discovered_modes,
)
if use_wandb:
wandb.log(validation_info, step=iteration)
Expand Down

0 comments on commit c57a708

Please sign in to comment.