Skip to content

Commit

Permalink
add logging
Browse files Browse the repository at this point in the history
  • Loading branch information
Limmen committed Feb 4, 2024
1 parent d1e30cf commit d567850
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
csle_cyborg_env = CyborgScenarioTwoWrapper(config=simulation_env_config.simulation_env_input_config)
A = csle_cyborg_env.get_action_space()
initial_particles = csle_cyborg_env.initial_particles
rollout_policy = MetastoreFacade.get_ppo_policy(id=10)
rollout_policy = MetastoreFacade.get_ppo_policy(id=1)
# rollout_policy.save_path = ("/Users/kim/workspace/csle/examples/training/pomcp/cyborg_scenario_two_wrapper/"
# "ppo_test_1706439955.8221297/ppo_model2900_1706522984.6982665.zip")
# rollout_policy.save_path = ("/Users/kim/workspace/csle/examples/training/pomcp/cyborg_scenario_two_wrapper/"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ def update_tree_with_new_samples(self, action_sequence: List[int], observation:
Logger.__call__().get_logger().info(f"Filling {particle_slots} particles")
particles = []
# fill particles by Monte-Carlo using reject sampling
count = 0
while len(particles) < particle_slots:
s = root.sample_state()
self.env.set_state(state=s)
Expand All @@ -356,6 +357,10 @@ def update_tree_with_new_samples(self, action_sequence: List[int], observation:
o = info[constants.COMMON.OBSERVATION]
if o == observation:
particles.append(s_prime)
else:
count += 1
if count >= 20000:
raise ValueError(f"Invalid observation: {o} given state: {root.sample_state()}")
new_root.particles += particles

# We now prune the old root from the tree
Expand Down

0 comments on commit d567850

Please sign in to comment.