Skip to content

Commit

Permalink
Merge pull request #302 from alexhernandezgarcia/fix_replay_bug
Browse files Browse the repository at this point in the history
Fix permutation sampling in buffer select
  • Loading branch information
alexhernandezgarcia committed Apr 4, 2024
2 parents 030f207 + 9274c37 commit ba8129a
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion gflownet/utils/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,12 @@ def select(
samples = list(samples.values())
if mode == "permutation":
assert rng is not None
samples = [samples[idx] for idx in rng.permutation(n)]
indices = rng.choice(
len(samples),
size=n,
replace=False,
)
samples = [samples[idx] for idx in indices]
elif mode == "weighted":
if "rewards" in data_dict:
score = "rewards"
Expand Down

0 comments on commit ba8129a

Please sign in to comment.