diff --git a/predicators/approaches/bilevel_planning_approach.py b/predicators/approaches/bilevel_planning_approach.py index 41d596a591..33ba291679 100644 --- a/predicators/approaches/bilevel_planning_approach.py +++ b/predicators/approaches/bilevel_planning_approach.py @@ -66,8 +66,11 @@ def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: task, nsrts, preds, timeout, seed) self._last_nsrt_plan = nsrt_plan self._last_atoms_seq = atoms_seq - policy = utils.nsrt_plan_to_greedy_policy(nsrt_plan, task.goal, - self._rng) + policy = utils.nsrt_plan_to_greedy_policy( + nsrt_plan, + task.goal, + self._rng, + abstract_function=lambda s: utils.abstract(s, preds)) logging.debug("Current Task Plan:") for act in nsrt_plan: logging.debug(act) @@ -110,7 +113,7 @@ def _run_sesame_plan( self._task_planning_heuristic, self._max_skeletons_optimized, max_horizon=CFG.horizon, - allow_noops=CFG.sesame_allow_noops, + allow_waits=CFG.sesame_allow_waits, use_visited_state_set=CFG.sesame_use_visited_state_set, **kwargs) except PlanningFailure as e: diff --git a/predicators/approaches/grammar_search_invention_approach.py b/predicators/approaches/grammar_search_invention_approach.py index 5197ce575f..390e4265c6 100644 --- a/predicators/approaches/grammar_search_invention_approach.py +++ b/predicators/approaches/grammar_search_invention_approach.py @@ -25,9 +25,9 @@ from predicators.predicate_search_score_functions import \ _PredicateSearchScoreFunction, create_score_function from predicators.settings import CFG -from predicators.structs import Dataset, GroundAtom, GroundAtomTrajectory, \ - Object, ParameterizedOption, Predicate, Segment, State, Task, Type, \ - VLMPredicate +from predicators.structs import Dataset, DerivedPredicate, GroundAtom, \ + GroundAtomTrajectory, Object, ParameterizedOption, Predicate, Segment, \ + State, Task, Type, VLMPredicate ################################################################################ # Programmatic classifiers # @@ -38,19 
+38,23 @@ def _create_grammar(dataset: Dataset, given_predicates: Set[Predicate]) -> _PredicateGrammar: # We start with considering various ways to split either single or # two feature values across our dataset. - grammar: _PredicateGrammar = _SingleFeatureInequalitiesPredicateGrammar( - dataset) + grammar: Optional[_PredicateGrammar] = None + if CFG.grammar_search_grammar_use_single_feature: + grammar = _SingleFeatureInequalitiesPredicateGrammar(dataset) if CFG.grammar_search_grammar_use_diff_features: diff_grammar = _FeatureDiffInequalitiesPredicateGrammar(dataset) - grammar = _ChainPredicateGrammar([grammar, diff_grammar], - alternate=True) + grammar = _ChainPredicateGrammar( + ([grammar] if grammar is not None else []) + [diff_grammar], + alternate=True) if CFG.grammar_search_grammar_use_euclidean_dist: for (t1_f1, t1_f2, t2_f1, t2_f2) in CFG.grammar_search_euclidean_feature_names: euclidean_dist_grammar = _EuclideanDistancePredicateGrammar( dataset, t1_f1, t2_f1, t1_f2, t2_f2) - grammar = _ChainPredicateGrammar([grammar, euclidean_dist_grammar], - alternate=True) + grammar = _ChainPredicateGrammar( + ([grammar] if grammar is not None else []) + + [euclidean_dist_grammar], + alternate=True) # We next optionally add in the given predicates because we want to allow # negated and quantified versions of the given predicates, in # addition to negated and quantified versions of new predicates. @@ -58,14 +62,20 @@ def _create_grammar(dataset: Dataset, # given predicates, then the single feature inequality ones. if CFG.grammar_search_grammar_includes_givens: given_grammar = _GivenPredicateGrammar(given_predicates) - grammar = _ChainPredicateGrammar([given_grammar, grammar]) + if grammar is not None: + grammar = _ChainPredicateGrammar([given_grammar, grammar]) + else: + grammar = given_grammar # Now, the grammar will undergo a series of transformations. # For each predicate enumerated by the grammar, we also # enumerate the negation of that predicate. 
- grammar = _NegationPredicateGrammarWrapper(grammar) + if CFG.grammar_search_grammar_includes_negation: + assert grammar is not None + grammar = _NegationPredicateGrammarWrapper(grammar) # For each predicate enumerated, we also optionally enumerate foralls # for that predicate, along with appropriate negations. if CFG.grammar_search_grammar_includes_foralls: + assert grammar is not None grammar = _ForallPredicateGrammarWrapper(grammar) # Prune proposed predicates by checking if they are equivalent to # any already-generated predicates with respect to the dataset. @@ -77,17 +87,22 @@ def _create_grammar(dataset: Dataset, # predicates. if not CFG.grammar_search_use_handcoded_debug_grammar and \ CFG.grammar_search_prune_redundant_preds: + assert grammar is not None grammar = _PrunedGrammar(dataset, grammar) # We don't actually need to enumerate the given predicates # because we already have them in the initial predicate set, # so we just filter them out from actually being enumerated. # But remember that we do want to enumerate their negations # and foralls, which is why they're included originally. - grammar = _SkipGrammar(grammar, given_predicates) + if CFG.grammar_search_grammar_use_skip_grammar: + assert grammar is not None + grammar = _SkipGrammar(grammar, given_predicates) # If we're using the DebugGrammar, filter out all other predicates. if CFG.grammar_search_use_handcoded_debug_grammar: + assert grammar is not None grammar = _DebugGrammar(grammar) # We're done! Return the final grammar. + assert grammar is not None return grammar @@ -867,6 +882,9 @@ class _NegationPredicateGrammarWrapper(_PredicateGrammar): def enumerate(self) -> Iterator[Tuple[Predicate, float]]: for (predicate, cost) in self.base_grammar.enumerate(): yield (predicate, cost) + if isinstance(predicate, DerivedPredicate): + # Don't negate derived predicates. 
+ continue classifier = _NegationClassifier(predicate) negated_predicate = Predicate(str(classifier), predicate.types, classifier) @@ -1104,7 +1122,7 @@ def rename(p: str) -> str: # pragma: no cover score_function = create_score_function( CFG.grammar_search_score_function, self._initial_predicates, atom_dataset, candidates, - self._train_tasks) + self._train_tasks, None) self._learned_predicates = \ self._select_predicates_by_score_hillclimbing( candidates, score_function, self._initial_predicates, diff --git a/predicators/approaches/maple_q_process_approach.py b/predicators/approaches/maple_q_process_approach.py new file mode 100644 index 0000000000..52bd059e44 --- /dev/null +++ b/predicators/approaches/maple_q_process_approach.py @@ -0,0 +1,202 @@ +"""A parameterized action reinforcement learning approach inspired by MAPLE, +(https://ut-austin-rpl.github.io/maple/) but where only a Q-function is +learned. + +Base samplers and applicable actions are used to perform the argmax. +""" + +from __future__ import annotations + +import logging +from typing import Any, Callable, List, Optional, Set + +import dill as pkl +from gym.spaces import Box + +from predicators import utils +from predicators.approaches.pp_online_process_learning_approach import \ + OnlineProcessLearningAndPlanningApproach +from predicators.explorers import BaseExplorer, create_explorer +from predicators.ml_models import MapleQFunction +from predicators.nsrt_learning.segmentation import segment_trajectory +from predicators.settings import CFG +from predicators.structs import Action, GroundAtom, InteractionRequest, \ + LowLevelTrajectory, ParameterizedOption, Predicate, Segment, State, Task, \ + Type, _GroundCausalProcess, _Option + + +class MapleQProcessApproach(OnlineProcessLearningAndPlanningApproach): + """A parameterized action RL approach inspired by MAPLE.""" + + def __init__(self, initial_predicates: Set[Predicate], + initial_options: Set[ParameterizedOption], types: Set[Type], + action_space: 
Box, train_tasks: List[Task]) -> None: + super().__init__(initial_predicates, initial_options, types, + action_space, train_tasks) + + # The current implementation assumes that NSRTs are not changing. + assert CFG.strips_learner == "oracle" + # The base sampler should also be unchanging and from the oracle. + assert CFG.sampler_learner == "oracle" + + # Log all transition data. + self._interaction_goals: List[Set[GroundAtom]] = [] + self._last_seen_segment_traj_idx = -1 + # For Q-learning data updates (segments by option changes). + self._segmented_trajs: List[List[Segment]] = [] + self._offline_segmented_trajs: List[List[Segment]] = [] + + # Store the Q function. Note that this implicitly + # contains a replay buffer. + self._q_function = MapleQFunction( + seed=CFG.seed, + hid_sizes=CFG.mlp_regressor_hid_sizes, + max_train_iters=CFG.mlp_regressor_max_itr, + clip_gradients=CFG.mlp_regressor_clip_gradients, + clip_value=CFG.mlp_regressor_gradient_clip_value, + learning_rate=CFG.learning_rate, + weight_decay=CFG.weight_decay, + use_torch_gpu=CFG.use_torch_gpu, + train_print_every=CFG.pytorch_train_print_every, + n_iter_no_change=CFG.active_sampler_learning_n_iter_no_change, + num_lookahead_samples=CFG. + active_sampler_learning_num_lookahead_samples, + predicates=self._get_current_predicates()) + + @classmethod + def get_name(cls) -> str: + return "maple_q_with_process" + + # pylint: disable=arguments-differ + def _solve(self, + task: Task, + timeout: int, + train_or_test: str = "") -> Callable[[State], Action]: + + def _option_policy(state: State) -> _Option: + option = self._q_function.get_option( + state, + task.goal, + num_samples_per_ground_nsrt=CFG. 
+ active_sampler_learning_num_samples, + train_or_test=train_or_test) + logging.debug(f"taking option: {option}") + return option + + return utils.option_policy_to_policy( + _option_policy, max_option_steps=CFG.max_num_steps_option_rollout) + + def _create_explorer(self) -> BaseExplorer: + """Create a new explorer at the beginning of each interaction cycle.""" + # Geometrically increase the length of exploration. + b = CFG.active_sampler_learning_explore_length_base + max_steps = b**(1 + self._online_learning_cycle) + preds = self._get_current_predicates() + assert CFG.explorer == "maple_q" + explorer = create_explorer( + CFG.explorer, + preds, + self._initial_options, + self._types, + self._action_space, + self._train_tasks, + # Endogenous processes are action-like + self._get_current_endogenous_processes(), # type: ignore[arg-type] + self._option_model, + max_steps_before_termination=max_steps, + maple_q_function=self._q_function) + return explorer + + def load(self, online_learning_cycle: Optional[int]) -> None: + super().load(online_learning_cycle) + save_path = utils.get_approach_load_path_str() + with open(f"{save_path}_{online_learning_cycle}.DATA", "rb") as f: + save_dict = pkl.load(f) + self._q_function = save_dict["q_function"] + self._last_seen_segment_traj_idx = save_dict[ + "last_seen_segment_traj_idx"] + self._interaction_goals = save_dict["interaction_goals"] + self._online_learning_cycle = CFG.skip_until_cycle + 1 + + def _learn_processes(self, + trajectories: List[LowLevelTrajectory], + online_learning_cycle: Optional[int], + annotations: Optional[List[Any]] = None) -> None: + # # Learn endogenous/exogenous processes via superclass. + # super()._learn_processes(trajectories, online_learning_cycle, + # annotations) + # Ground current endogenous processes for Q-learning. 
+ all_ground_processes: Set[_GroundCausalProcess] = set() + all_objects = {o for t in self._train_tasks for o in t.init} + for process in self._get_current_endogenous_processes(): + all_ground_processes.update( + utils.all_ground_nsrts(process, + all_objects)) # type: ignore[arg-type] + goals = [t.goal for t in self._train_tasks] + self._q_function.set_grounding( + all_objects, goals, all_ground_processes) # type: ignore[arg-type] + # Refresh segmentation by option changes. + prev_segmenter = CFG.segmenter + try: + CFG.segmenter = "option_changes" + new_segments = [ + segment_trajectory(traj, self._get_current_predicates()) + for traj in trajectories + ] + finally: + CFG.segmenter = prev_segmenter + # if online_learning_cycle is None: + # # Offline phase: only offline trajectories are included. + # self._offline_segmented_trajs = new_segments + # self._segmented_trajs = list(self._offline_segmented_trajs) + # else: + # # Online phase: input trajectories are only the online ones so far. + # self._segmented_trajs = list(self._offline_segmented_trajs) + \ + # list(new_segments) + if online_learning_cycle is not None: + self._segmented_trajs = list(new_segments) + # Update the data using the updated self._segmented_trajs. + self._update_maple_data() + # Re-learn Q function. 
+ self._q_function.train_q_function() + # Save the things we need other than the NSRTs, which were already + # saved in the above call to self._learn_processes() + save_path = utils.get_approach_save_path_str() + with open(f"{save_path}_{online_learning_cycle}.DATA", "wb") as f: + pkl.dump( + { + "q_function": self._q_function, + "last_seen_segment_traj_idx": + self._last_seen_segment_traj_idx, + "interaction_goals": self._interaction_goals, + }, f) + + def _update_maple_data(self) -> None: + start_idx = self._last_seen_segment_traj_idx + 1 + new_trajs = self._segmented_trajs[start_idx:] + + goal_offset = 0 + assert len(self._segmented_trajs) == goal_offset + \ + len(self._interaction_goals) + new_traj_goals = self._interaction_goals[goal_offset + start_idx:] + + for traj_i, segmented_traj in enumerate(new_trajs): + self._last_seen_segment_traj_idx += 1 + for seg_i, segment in enumerate(segmented_traj): + s = segment.states[0] + goal = new_traj_goals[traj_i] + o = segment.get_option() + ns = segment.states[-1] + reward = 1.0 if goal.issubset(segment.final_atoms) else 0.0 + terminal = reward > 0 or seg_i == len(segmented_traj) - 1 + self._q_function.add_datum_to_replay_buffer( + (s, goal, o, ns, reward, terminal)) + + def get_interaction_requests(self) -> List[InteractionRequest]: + # Save the goals for each interaction request so we can later associate + # states, actions, and goals. 
+ requests = super().get_interaction_requests() + for request in requests: + goal = self._train_tasks[request.train_task_idx].goal + self._interaction_goals.append(goal) + return requests diff --git a/predicators/approaches/pp_online_predicate_invention_approach.py b/predicators/approaches/pp_online_predicate_invention_approach.py new file mode 100644 index 0000000000..143b62f665 --- /dev/null +++ b/predicators/approaches/pp_online_predicate_invention_approach.py @@ -0,0 +1,1239 @@ +import logging +import os +import re +import time +import traceback +from collections import defaultdict +from pprint import pformat +from typing import Any, Dict, FrozenSet, Iterator, List, Optional, Sequence, \ + Set, Tuple + +import dill as pkl +import PIL +import wandb +from gym.spaces import Box +from PIL import ImageDraw, ImageFont + +from predicators import utils +from predicators.approaches.grammar_search_invention_approach import \ + _create_grammar, _GivenPredicateGrammar +from predicators.approaches.pp_online_process_learning_approach import \ + OnlineProcessLearningAndPlanningApproach +from predicators.approaches.pp_predicate_invention_approach import \ + PredicateInventionProcessPlanningApproach +from predicators.envs import create_new_env +from predicators.nsrt_learning.process_learning_main import \ + filter_explained_segment +from predicators.nsrt_learning.segmentation import segment_trajectory +from predicators.option_model import _OptionModelBase +from predicators.planning_with_processes import process_task_plan_grounding +from predicators.predicate_search_score_functions import \ + _ExpectedNodesScoreFunction +from predicators.settings import CFG +from predicators.structs import CausalProcess, Dataset, DerivedPredicate, \ + EndogenousProcess, ExogenousProcess, GroundAtom, GroundAtomTrajectory, \ + Image, InteractionResult, LowLevelTrajectory, ParameterizedOption, \ + Predicate, Segment, State, Task, Type, _GroundExogenousProcess + + +class 
OnlinePredicateInventionProcessPlanningApproach( + PredicateInventionProcessPlanningApproach, + OnlineProcessLearningAndPlanningApproach): + """A bilevel planning approach that invent predicates.""" + + def __init__(self, + initial_predicates: Set[Predicate], + initial_options: Set[ParameterizedOption], + types: Set[Type], + action_space: Box, + train_tasks: List[Task], + task_planning_heuristic: str = "default", + max_skeletons_optimized: int = -1, + bilevel_plan_without_sim: Optional[bool] = None, + option_model: Optional[_OptionModelBase] = None): + # just used for oracle predicate proposal or learned predicate + self._oracle_predicates = create_new_env( + CFG.env, use_gui=False).target_predicates + self._candidate_predicates: Set[Predicate] = set() + self._llm = utils.create_llm_by_name(CFG.llm_model_name) + self._vlm = utils.create_vlm_by_name( + CFG.llm_model_name) # type: ignore[assignment] + super().__init__(initial_predicates, + initial_options, + types, + action_space, + train_tasks, + task_planning_heuristic, + max_skeletons_optimized, + bilevel_plan_without_sim, + option_model=option_model) + + @classmethod + def get_name(cls) -> str: + return "online_predicate_invention_and_process_planning" + + def learn_from_offline_dataset(self, dataset: Dataset) -> None: + # Just store the dataset, don't learn from it yet. + self._offline_dataset = dataset + # proposed_predicates = self._get_predicate_proposals( + # "transition_modelling", + # self._offline_dataset.trajectories) + self.save() + + def learn_from_interaction_results( + self, results: Sequence[InteractionResult]) -> None: + # --- Process the interaction results --- + assert self._requests_train_task_idxs is not None, \ + "Missing request->task index mapping." 
+ for i, result in enumerate(results): + task_idx = self._requests_train_task_idxs[i] + traj = LowLevelTrajectory(result.states, + result.actions, + _train_task_idx=task_idx) + self._online_dataset.append(traj) + + all_trajs = self._offline_dataset.trajectories + \ + self._online_dataset.trajectories + + # TODO: change to only propose when stop improving? + # Test to only generate proposals at cycle 0. + if self._online_learning_cycle == 0: + proposed_predicates = self._get_predicate_proposals( + "subgoals", all_trajs) + else: + proposed_predicates = set() + logging.info(f"Done: created {len(proposed_predicates)} predicates") + + # --- Select the predicates to keep --- + self._select_predicates_and_learn_processes( + ite=self._online_learning_cycle, + all_trajs=all_trajs, + proposed_predicates=proposed_predicates, + train_tasks=self._train_tasks) + logging.debug(f"Learned predicates: " + f"{self._learned_predicates-self._initial_predicates}") + + if CFG.learn_process_parameters: + self._learn_process_parameters(all_trajs) + self.save(self._online_learning_cycle) + + self._online_learning_cycle += 1 + + def save(self, online_learning_cycle: Optional[int] = None) -> None: + # Saving the learned processes, dataset, candidate predicates + save_path = utils.get_approach_save_path_str() + with open(f"{save_path}_{online_learning_cycle}.PROCes", "wb") as f: + save_dict = { + "processes": self._processes, + "learned_predicates": self._learned_predicates, + "candidate_predicates": self._candidate_predicates, + "offline_dataset": self._offline_dataset, + "online_dataset": self._online_dataset, + "online_learning_cycle": self._online_learning_cycle + } + pkl.dump(save_dict, f) + logging.info(f"Saved approach to {save_path}_" + f"{online_learning_cycle}.PROCes") + + def load(self, online_learning_cycle: Optional[int] = None) -> None: + save_path = utils.get_approach_load_path_str() + with open(f"{save_path}_{online_learning_cycle}.PROCes", "rb") as f: + save_dict = pkl.load(f) 
+ # check save_dict has "processes", "candidate_predicate" values + assert "processes" in save_dict, "Processes not found in save_dict" + assert "candidate_predicates" in save_dict, \ + "Candidate predicates not found in save_dict" + assert "offline_dataset" in save_dict, \ + "Offline dataset not found in save_dict" + assert "online_dataset" in save_dict, \ + "Online dataset not found in save_dict" + self._processes = save_dict["processes"] + self._learned_predicates = save_dict["learned_predicates"] + self._candidate_predicates = save_dict["candidate_predicates"] + self._offline_dataset = save_dict["offline_dataset"] + self._online_dataset = save_dict["online_dataset"] + self._online_learning_cycle = save_dict["online_learning_cycle"] + 1 + logging.info(f"\n\nLoaded Processes:") + for process in sorted(self._processes): + logging.info(process) + logging.info( + f"Loaded {len(self._learned_predicates)} learned predicates") + logging.info(f"{sorted(self._learned_predicates)}") + logging.info( + f"Loaded {len(self._processes)} processes, " + f"{len(self._candidate_predicates)} candidate predicates, " + f"{len(self._offline_dataset.trajectories)} offline trajectories, " + f"{len(self._online_dataset.trajectories)} online trajectories\n") + + for proc in self._processes: + if isinstance(proc, EndogenousProcess): + proc.option.params_space.seed(CFG.seed) + pass + + def _get_predicate_proposals( + self, proposal_method: str, + trajectories: List[LowLevelTrajectory]) -> Set[Predicate]: + if CFG.vlm_predicator_oracle_base_predicates: + base_candidates = self._oracle_predicates - self._initial_predicates + else: + base_candidates: Set[Predicate] = set() # type: ignore[no-redef] + + # noisy_but_complete_proposal = True + # if noisy_but_complete_proposal: + # base_candidates |= set(p for p in self._oracle_predicates + # if p.name in [ + # # "NoWaterSpilled", + # "NoJugAtFaucetOrAtFaucetAndFilled" + # ]) + + for i in range(CFG.vlm_predicator_num_proposal_batches): + 
base_candidates |= self._get_predicate_proposals_from_fm( + proposal_method, trajectories, i, + invent_derived_predicates=\ + CFG.predicate_invent_invent_derived_predicates) + # TODO: filter semantically equivalent predicate by evaluation + return base_candidates + + def _get_predicate_proposals_from_fm( + self, proposal_method: str, trajectories: List[LowLevelTrajectory], + proposal_batch_id: int, + invent_derived_predicates: bool) -> Set[Predicate]: + """Get predicate proposals from the FM.""" + ###### Invent predicates in NL based on the dataset ###### + b_id = proposal_batch_id + seed = CFG.seed * 100 + self._online_learning_cycle * 10 + b_id + + assert proposal_method in [ + "transition_modeling", "discrimination", "unconditional", + "subgoals" + ] + + # transition modelling (2 fm calls): spec -> implementation + # discrimination (3 fm calls): nl -> spec -> implementation + # unconditional: (3 calls): spec -> primitive impl -> concept impl + if proposal_method in ["transition_modeling", "subgoals"]: + # 1. Get template + successful_trajectory = traj_is_successful(trajectories[0], + self._train_tasks) + if successful_trajectory: + if invent_derived_predicates: + prompt_template_f = f"prompts/invent_{proposal_method}"\ + "_solved_derived.outline" + else: + prompt_template_f =\ + f"prompts/invent_{proposal_method}_solved.outline" + else: + prompt_template_f = f"prompts/invent_{proposal_method}_failed"\ + f".outline" + with open(prompt_template_f, "r") as f: + prompt_template = f.read() + + # 2. 
Fill and save the template + pred_str = _get_predicates_str(self._get_current_predicates()) + types = set(o.type for o in set(trajectories[0].states[0])) + logging.info( + f"Inventing predicates from only the offline dataset.") + experience_str, state_str = _get_transition_str( + self._offline_dataset.trajectories, # +\ + # self._online_dataset.trajectories, + self._train_tasks, + self._get_current_predicates(), + ite=self._online_learning_cycle, + use_abstract_state_str=invent_derived_predicates, + ) + prompt = prompt_template.format( + PREDICATES_IN_ENV=pred_str, + TYPES_IN_ENV=_get_types_str(types), + EXPERIENCE_IN_ENV=experience_str, + GOAL_PREDICATE=self._train_tasks[0].goal) + with open( + f"{CFG.log_file}/ite{self._online_learning_cycle}_b{b_id}" + f"_s1.prompt", "w") as f: + f.write(prompt) + + # 3. Get spec proposals + temperature = 0.2 + if CFG.rgb_observation: + images = load_images_from_directory( + CFG.log_file + + f"ite{self._online_learning_cycle}_b{b_id}_obs/") + spec_response = self._vlm.sample_completions( # type: ignore[union-attr] + prompt, + images, + temperature=temperature, + num_completions=1, + seed=seed)[0] + else: + spec_response = self._llm.sample_completions( + prompt, + imgs=None, + temperature=temperature, + num_completions=1, + seed=seed)[0] + with open( + f"{CFG.log_file}/ite{self._online_learning_cycle}_b{b_id}" + f"_s1.response", "w") as f: + f.write(spec_response) + elif proposal_method == "discrimination": + # Method 1: Find each state, if it satisfies the condition of an + # exogenous process, check later that its effect did take place, save + # it if not. + # Then for each exogenous process, compare the above negative state + # with positive states where the effect took place (e.g. in the demo). + # Maybe this will mirror the planner. 
+ # Remember to reset at the end + + # Step 1: Find the false positive examples + exogenous_processes = list(self._get_current_exogenous_processes()) + false_positive_process_state = get_false_positive_states( + self._online_dataset.trajectories, + self._get_current_predicates(), exogenous_processes) + + # Step 2: Find the true positive examples + # For each expected effect that did not take place, find in the demo + # the initial state where it did take place, and save it as a positive + # example. + true_positive_process_state = get_true_positive_process_states( + self._get_current_predicates(), exogenous_processes, + list(false_positive_process_state.keys()), + self._offline_dataset.trajectories) + + # Step 3: Prompt VLM to invent predicates + # TODO: prepare the prompt + # TODO: implement the prompt and parse logic + else: + raise NotImplementedError + + ###### Implement the predicates in python ###### + # Create the implementation prompt + if CFG.predicate_invent_neural_symbolic_predicates: + raise NotImplementedError + else: + template_f = "prompts/invent_sym_pred_implementation.outline" + state_api_f = "prompts/api_oo_state.py" + pred_api_f = "prompts/api_sym_predicate.py" + + with open(f"./{template_f}", "r") as f: + template = f.read() + with open(f"./{state_api_f}", "r") as f: + state_cls_str = f.read() + with open(f"./{pred_api_f}", "r") as f: + pred_cls_str = f.read() + + prompt = template.format( + STRUCT_DEFINITION=add_python_quote(state_cls_str + "\n\n" + + pred_cls_str), + TYPES_IN_ENV=add_python_quote( + _get_types_str(types, use_python_def_str=True)), + PREDICATES_IN_ENV=pred_str, + LISTED_STATES=state_str, + PREDICATE_SPECS=spec_response, + ) + with open( + f"{CFG.log_file}/ite{self._online_learning_cycle}_b{b_id}" + f"_s2_impl.prompt", "w") as f: + f.write(prompt) + + impl_response = self._llm.sample_completions(prompt, + imgs=None, + temperature=0, + num_completions=1, + seed=seed)[0] + with open( + 
f"{CFG.log_file}/ite{self._online_learning_cycle}_b{b_id}" + f"_s2_impl.response", "w") as f: + f.write(impl_response) + + prim_predicates, deri_predicates =\ + _parse_predicates_predictions(impl_response, + self._initial_predicates, + self._candidate_predicates, + types, + self._train_tasks[0].init + ) + base_candidates = set(prim_predicates) | set(deri_predicates) + return base_candidates + + def _select_predicates_and_learn_processes( + self, + ite: int, + all_trajs: List[LowLevelTrajectory], + proposed_predicates: Set[Predicate], + train_tasks: List[Task] = [], + enumerate_processes: bool = False, + ) -> None: + if CFG.vlm_predicator_oracle_learned_predicates: + if CFG.boil_goal_simple_human_happy: + selected_preds = { + p + for p in proposed_predicates if p.name in {"JugFilled"} + } + else: + selected_preds = proposed_predicates + self._learned_predicates |= selected_preds + # --- Learn processes & parameters --- + self._learn_processes( + all_trajs, online_learning_cycle=self._online_learning_cycle) + else: + self._candidate_predicates |= proposed_predicates + + all_candidates: Dict[Predicate, int] = { + p: p.arity + for p in self._initial_predicates + } + if CFG.vlm_predicator_use_grammar: + grammar = _create_grammar(dataset=Dataset(all_trajs), + given_predicates=\ + self._candidate_predicates) + else: + grammar = _GivenPredicateGrammar(self._candidate_predicates) + all_candidates.update( + grammar.generate(max_num=CFG.grammar_search_max_predicates) + ) # type: ignore[arg-type] + + atom_dataset: List[GroundAtomTrajectory] =\ + utils.create_ground_atom_dataset(all_trajs, + set(all_candidates)) + + new_preds = set(all_candidates) - self._initial_predicates + logging.info(f"Candidate predicates:\n{pformat(new_preds)}") + if CFG.use_wandb: + wandb.log({"candidate_predicates": pformat(new_preds)}) + + self._learned_predicates = set(all_candidates) # temp + # TODO: we need to save the top ranking conditions here so it can be + # used later in predicate selection. 
+ self._learn_processes( + all_trajs, online_learning_cycle=self._online_learning_cycle) + + if CFG.learn_process_parameters: + self._learn_process_parameters(all_trajs) + # Whether to do predicate selection by scoring different predicate + # set or by scoring different process set. + start_time = time.perf_counter() + if enumerate_processes: + # Learn processes based on all the candidates. + + # Search by scoring different set of processes. + # When commented out: keeping all candidates. + selected_processes =\ + self._select_processes_by_score_optimization(train_tasks, + self._processes, atom_dataset) + self._processes = selected_processes + # TODO: remove duplicate predicates + self._learned_predicates = self._get_predicates_in_processes( + self._processes, set(all_candidates)) + else: + # select predicates + logging.info("[Start] Predicate search.") + self._learned_predicates =\ + self._select_predicates_by_score_optimization( + train_tasks, all_candidates, self._processes, # type: ignore[arg-type] + all_trajs, atom_dataset) + logging.info("[Finished] Predicate search.") + logging.info("Total search time " + f"{time.perf_counter() - start_time:.2f}s") + return None + + def _get_predicates_in_processes( + self, processes: Set[CausalProcess], + all_candidates: Set[Predicate]) -> Set[Predicate]: + """Get the predicates in the processes.""" + all_process_predicates = set() + for process in processes: + all_process_predicates |= { + atom.predicate + for atom in process.condition_at_start + } + all_process_predicates |= { + atom.predicate + for atom in process.add_effects + } + all_process_predicates |= { + atom.predicate + for atom in process.delete_effects + } + selected_predicates = set() + for pred in all_candidates: + if pred in all_process_predicates: + selected_predicates.add(pred) + return selected_predicates + + def _select_processes_by_score_optimization( + self, + train_tasks: List[Task], + all_processes: Set[CausalProcess], + atom_dataset: 
List[GroundAtomTrajectory], + ) -> Set[CausalProcess]: + """Perform a greedy search over process sets.""" + endogenous_processes = { + p + for p in all_processes if isinstance(p, EndogenousProcess) + } + exogenous_processes = { + p + for p in all_processes if isinstance(p, ExogenousProcess) + } + + # Precompute stuff for scoring. + segmented_trajs = [ + segment_trajectory(ll_traj, self._get_current_predicates(), + atom_seq) + for (ll_traj, atom_seq) in atom_dataset + ] + score_func = _ExpectedNodesScoreFunction( + _initial_predicates=set(), + _atom_dataset=[], + _candidates=dict(), + _train_tasks=train_tasks, + _current_processes=set(), + _use_processes=True, + metric_name="num_nodes_expanded") + + # Define the score function for a set of processes. + def _score_processes( + candidate_exogenous_processes: FrozenSet[ExogenousProcess] + ) -> float: + process_score = score_func.evaluate_with_operators( + candidate_predicates=self._get_current_predicates( + ), # type: ignore[arg-type] + low_level_trajs=self._offline_dataset.trajectories + + self._online_dataset.trajectories, + segmented_trajs=segmented_trajs, + strips_ops= + candidate_exogenous_processes # type: ignore[arg-type] + | endogenous_processes, + option_specs=[]) + process_penalty = _ExpectedNodesScoreFunction._get_operator_penalty( + candidate_exogenous_processes) # type: ignore[arg-type] + return process_score + process_penalty + + # Set up the search. + init_set: FrozenSet[ExogenousProcess] = frozenset() + + def _check_goal(s: FrozenSet[ExogenousProcess]) -> bool: + del s # unused + return False + + def _get_successors( + s: FrozenSet[ExogenousProcess] + ) -> Iterator[Tuple[None, FrozenSet[ExogenousProcess], float]]: + for process in sorted(exogenous_processes - s): + yield (None, frozenset(s | {process}), 1.0) + + # Run the search. 
+ if CFG.grammar_search_search_algorithm == "hill_climbing": + path, _, heuristics = utils.run_hill_climbing( + init_set, + _check_goal, + _get_successors, + _score_processes, + enforced_depth=CFG.grammar_search_hill_climbing_depth, + parallelize=CFG.grammar_search_parallelize_hill_climbing) + logging.info("\nHill climbing summary:") + for i in range(1, len(path)): # pragma: no cover + new_additions = path[i] - path[i - 1] + assert len(new_additions) == 1 + new_addition = next(iter(new_additions)) + h = heuristics[i] + prev_h = heuristics[i - 1] + logging.info(f"\tOn step {i}, added {new_addition}, with " + f"heuristic {h:.3f} (an improvement of " + f"{prev_h - h:.3f} over the previous step)") + elif CFG.grammar_search_search_algorithm == "gbfs": + path, _ = utils.run_gbfs( + init_set, + _check_goal, + _get_successors, + _score_processes, + max_evals=CFG.grammar_search_gbfs_num_evals, + ) + else: + raise NotImplementedError( + "Unrecognized grammar_search_search_algorithm: " + f"{CFG.grammar_search_search_algorithm}.") + + selected_exogenous_processes = path[-1] + logging.debug(f"Selected processes: " + f"{pformat(selected_exogenous_processes)}") + + return endogenous_processes | selected_exogenous_processes + + def _select_predicates_by_score_optimization( + self, + train_tasks: List[Task], + candidates: Dict[Predicate, float], + all_processes: Set[CausalProcess], + all_trajs: List[LowLevelTrajectory], + atom_dataset: List[GroundAtomTrajectory], + ) -> Set[Predicate]: + """Perform a greedy search over predicate sets.""" + endogenous_processes = { + p + for p in all_processes if isinstance(p, EndogenousProcess) + } + + # Precompute stuff for scoring. 
+ segmented_trajs = [ + segment_trajectory(ll_traj, self._get_current_predicates(), + atom_seq) + for (ll_traj, atom_seq) in atom_dataset + ] + score_func = _ExpectedNodesScoreFunction( + _initial_predicates=set(), + _atom_dataset=[], + _candidates=dict(), + _train_tasks=train_tasks, + _current_processes=set(), + _use_processes=True, + metric_name="num_nodes_expanded") + + def _filter_process( + process: CausalProcess, + candidate_predicates: FrozenSet[Predicate]) -> CausalProcess: + """Filter a process to only keep atoms with candidate + predicates.""" + proc_copy = process.copy() + proc_copy.condition_at_start = { + atom + for atom in proc_copy.condition_at_start + if atom.predicate in candidate_predicates + } + proc_copy.condition_overall = proc_copy.condition_at_start.copy() + proc_copy.add_effects = { + atom + for atom in proc_copy.add_effects + if atom.predicate in candidate_predicates + } + proc_copy.delete_effects = { + atom + for atom in proc_copy.delete_effects + if atom.predicate in candidate_predicates + } + # Make sure the parameter only include variables that appear in the + # conditions and effects + remaining_variables = set() + for atom in proc_copy.condition_at_start | proc_copy.add_effects |\ + proc_copy.delete_effects: + remaining_variables |= set(atom.variables) + proc_copy.parameters = [ + v for v in proc_copy.parameters if v in remaining_variables + ] + return proc_copy + + def _get_best_compatible_exo_processes( + candidate_predicates: FrozenSet[Predicate] + ) -> Set[ExogenousProcess]: + """Get the best compatible exogenous processes. + + # Get the processes compatible with the candidate + predicates. # Look at all the scored conditions, find the + top one that's a # subset of the candidate predicates; if + none, remove the none # candidates from the top conditions. 
+ # Remove parts that are outside of candidates predicates + """ + new_predicates = candidate_predicates - self._initial_predicates + remaining_exogenous_processes = set() + for _, results in self._proc_name_to_results.items(): + best_compatible_process = results[0][3] + effect_pred = { + atom.predicate + for atom in best_compatible_process.add_effects + | best_compatible_process.delete_effects + } + if any(effect_p in candidate_predicates + for effect_p in effect_pred): + for _, (_, condition, _, proc) in enumerate(results): + condition_pred = {atom.predicate for atom in condition} + if new_predicates.issubset(condition_pred): + best_compatible_process = proc + break + if condition_pred.issubset(candidate_predicates): + # If the condition is a subset of the candidate + # predicates, then we can use this process. + best_compatible_process = proc + # logging.debug(f"Found compatible condition for " + # f"{proc.name}") + break + # else: + # logging.debug(f"No compatible condition found for " + # f"{best_compatible_process.name}, " + # f"filtering out non-candidate atoms.") + # Haven't found a condition that is a subset of the + # candidate predicates, so we filter out the non-candidate + # condition + proc_copy = _filter_process(best_compatible_process, + candidate_predicates) + if proc_copy.add_effects | proc_copy.delete_effects: + remaining_exogenous_processes.add(proc_copy) + logging.debug(f"Remaining exogenous processes:\n" + f"{pformat(remaining_exogenous_processes)}") + return remaining_exogenous_processes # type: ignore[return-value] + + def _score_predicates( + candidate_predicates: FrozenSet[Predicate]) -> float: + new_preds = candidate_predicates - self._initial_predicates + logging.debug(f"Evaluating predicates: {sorted(set(new_preds))}") + remaining_exogenous_processes = _get_best_compatible_exo_processes( + candidate_predicates) + # Score processes with the score function. 
+ process_score = score_func.evaluate_with_operators( + candidate_predicates=candidate_predicates, + low_level_trajs=all_trajs, + segmented_trajs=segmented_trajs, + strips_ops= + remaining_exogenous_processes # type: ignore[arg-type] + | endogenous_processes, + option_specs=[]) + process_penalty = _ExpectedNodesScoreFunction._get_operator_penalty( + remaining_exogenous_processes) # type: ignore[arg-type] + final_score = process_score + process_penalty + logging.debug(f"Candidate scores: {final_score:.4f}") + return final_score + + def _check_goal(s: FrozenSet[Predicate]) -> bool: + del s # unused + return False + + # Successively consider larger predicate sets. + def _get_successors( + s: FrozenSet[Predicate] + ) -> Iterator[Tuple[None, FrozenSet[Predicate], float]]: + for predicate in sorted(set(candidates) - s): # determinism + # Actions not needed. Frozensets for hashing. The cost of + # 1.0 is irrelevant because we're doing GBFS / hill + # climbing and not A* (because we don't care about the + # path). + yield (None, frozenset(s | {predicate}), 1.0) + + # Start the search with no candidates. + # Don't need to include the initial predicates here because its + init: FrozenSet[Predicate] = frozenset(self._initial_predicates) + + # Greedy local hill climbing search. 
+ if CFG.grammar_search_search_algorithm == "hill_climbing": + path, _, heuristics = utils.run_hill_climbing( + init, + _check_goal, + _get_successors, + _score_predicates, + enforced_depth=CFG.grammar_search_hill_climbing_depth, + parallelize=CFG.grammar_search_parallelize_hill_climbing, + exhaustive_lookahead=True) + logging.info("\nHill climbing summary:") + for i in range(1, len(path)): + new_additions = path[i] - path[i - 1] + assert len(new_additions) == 1 + new_addition = next(iter(new_additions)) + h = heuristics[i] + prev_h = heuristics[i - 1] + logging.info(f"\tOn step {i}, added {new_addition}, with " + f"heuristic {h:.3f} (an improvement of " + f"{prev_h - h:.3f} over the previous step)") + elif CFG.grammar_search_search_algorithm == "gbfs": + path, _ = utils.run_gbfs( + init, + _check_goal, + _get_successors, + _score_predicates, + max_evals=CFG.grammar_search_gbfs_num_evals, + ) + else: + raise NotImplementedError( + "Unrecognized grammar_search_search_algorithm: " + f"{CFG.grammar_search_search_algorithm}.") + kept_predicates = path[-1] + # The total number of predicate sets evaluated is just the + # ((number of candidates selected) + 1) * total number of candidates. + # However, since 'path' always has length one more than the + # number of selected candidates (since it evaluates the empty + # predicate set first), we can just compute it as below. + self._metrics["total_num_predicate_evaluations"] = len(path) * len( + candidates) + + # # Filter out predicates that don't appear in some operator + # # preconditions. 
+ # logging.info("\nFiltering out predicates that don't appear in " + # "preconditions...") + # preds = kept_predicates | initial_predicates + # pruned_atom_data = utils.prune_ground_atom_dataset(atom_dataset, preds) + # segmented_trajs = [ + # segment_trajectory(ll_traj, set(preds), atom_seq=atom_seq) + # for (ll_traj, atom_seq) in pruned_atom_data + # ] + # low_level_trajs = [ll_traj for ll_traj, _ in pruned_atom_data] + # preds_in_preconds = set() + # for pnad in learn_strips_operators(low_level_trajs, + # train_tasks, + # set(kept_predicates + # | initial_predicates), + # segmented_trajs, + # verify_harmlessness=False, + # annotations=None, + # verbose=False): + # for atom in pnad.op.preconditions: + # preds_in_preconds.add(atom.predicate) + # kept_predicates &= preds_in_preconds + + newly_selected = kept_predicates - self._initial_predicates + new_candidates = set(candidates) - self._initial_predicates + logging.info(f"\n[ite {self._online_learning_cycle}] Selected " + f"{len(newly_selected)} predicates" + f" out of {len(new_candidates)} candidates:") + for pred in newly_selected: + logging.info(f"\t{pred}") + _score_predicates(kept_predicates) # log useful numbers + self._processes = endogenous_processes |\ + _get_best_compatible_exo_processes(kept_predicates) + + # Log processes and predicates to wandb if enabled + if CFG.use_wandb: + # Log each process as a separate entry + for i, process in enumerate(self._processes): + wandb.log({ + f"process_{i}_cycle_{self._online_learning_cycle}": + str(process), + "online_learning_cycle": + self._online_learning_cycle, + "process_index": + i, + "process_type": + type(process).__name__ + }) + + # Log each predicate as a separate entry + for i, pred in enumerate(kept_predicates): + wandb.log({ + f"predicate_{i}_cycle_{self._online_learning_cycle}": + str(pred), + "online_learning_cycle": + self._online_learning_cycle, + "predicate_index": + i, + "predicate_name": + pred.name, + }) + return set(kept_predicates) + + +def 
get_false_positive_states_from_seg_trajs( + segmented_trajs: List[List[Segment]], + exogenous_processes: List[ExogenousProcess], +) -> Dict[_GroundExogenousProcess, List[State]]: + + # Map from ground_exogenous_process to a list of init states where the + # condition is satisfied. + false_positive_process_state: Dict[_GroundExogenousProcess, List[State]] = \ + defaultdict(list) + + # Cache for ground_exogenous_processes to avoid recomputation + objects_to_ground_processes = {} + + for segmented_traj in segmented_trajs: + # Checking each segmented trajectory + objects = frozenset(segmented_traj[0].trajectory.states[0]) + # Only recompute if objects are different + if objects not in objects_to_ground_processes: + ground_exogenous_processes, _ = process_task_plan_grounding( + set(), + objects, # type: ignore[arg-type] + exogenous_processes, + allow_waits=True, + compute_reachable_atoms=False) + objects_to_ground_processes[objects] = ground_exogenous_processes + else: + ground_exogenous_processes = objects_to_ground_processes[objects] + + # Pre-compute segment init_atoms for efficiency + segment_init_atoms = [segment.init_atoms for segment in segmented_traj] + + for g_exo_process in ground_exogenous_processes: + condition = g_exo_process.condition_at_start # Cache reference + add_effects = g_exo_process.add_effects + delete_effects = g_exo_process.delete_effects + + for i, segment in enumerate(segmented_traj): + satisfy_condition = condition.issubset(segment_init_atoms[i]) + first_state_or_prev_state_doesnt_satisfy = i == 0 or \ + not condition.issubset(segment_init_atoms[i - 1]) + + if satisfy_condition and first_state_or_prev_state_doesnt_satisfy: + false_positive_process_state[ + g_exo_process].append( # type: ignore[index] + # segment.trajectory.states[0]) + segment.init_atoms) # type: ignore[arg-type] + + # Check for removal condition + if (add_effects.issubset(segment.add_effects) + and delete_effects.issubset(segment.delete_effects)): + if 
false_positive_process_state[ + g_exo_process]: # type: ignore[index] + # TODO: we don't really know which one to remove, pop + # the first one is a bias. + false_positive_process_state[g_exo_process].pop( + 0) # type: ignore[index] + return false_positive_process_state + + +def get_false_positive_states( + trajectories: List[LowLevelTrajectory], + predicates: Set[Predicate], + exogenous_processes: List[ExogenousProcess], +) -> Dict[_GroundExogenousProcess, List[State]]: + """Get the false positive states for each exogenous process. + + Return: + ground_exogenous_process -> + Tuple[List[State], List[GroundAtom], List[GroundAtom]] per + trajectory where List[State] is the list of states where the + process is activated in the trajectory. + """ + initial_segmenter_method = CFG.segmenter + # TODO: use option_changes allows for creating a segment for the noop option + # in the end, but would cause problem if the start and end of option + # execution doesn't satisfy the condition but somewhere in the middle does + # it. The same problem exists for the effects. + # + # The fix for the atom_changes segmenter would be to create a segment in + # the end if there is still sttes after the last atom change. 
+ CFG.segmenter = "atom_changes" + segmented_trajs = [ + segment_trajectory(traj, predicates, verbose=False) + for traj in trajectories + ] + CFG.segmenter = initial_segmenter_method + + return get_false_positive_states_from_seg_trajs(segmented_trajs, + exogenous_processes) + + +def get_true_positive_process_states( + predicates: Set[Predicate], + exogenous_processes: List[ExogenousProcess], + ground_exogenous_processes: List[_GroundExogenousProcess], + trajectories: List[LowLevelTrajectory], +) -> Dict[_GroundExogenousProcess, List[State]]: + """Get the true positive states for each exogenous process.""" + initial_segmenter_method = CFG.segmenter + CFG.segmenter = "atom_changes" + segmented_trajs = [ + segment_trajectory(traj, predicates) for traj in trajectories + ] + CFG.segmenter = initial_segmenter_method + + # Filter out segments explained by endogenous processes. + filtered_segmented_trajs = filter_explained_segment( + segmented_trajs, + exogenous_processes, # type: ignore[arg-type] + remove_options=True) + true_positive_process_state: Dict[_GroundExogenousProcess, + List[State]] = defaultdict(list) + for g_exo_process in ground_exogenous_processes: + for segmented_traj in filtered_segmented_trajs: + # Checking each segmented trajectory + for segment in segmented_traj: + # Check if the segment is a positive example for any + # exogenous process + if g_exo_process.condition_at_start.issubset( + segment.init_atoms) and \ + g_exo_process.add_effects.issubset( + segment.add_effects) and \ + g_exo_process.delete_effects.issubset( + segment.delete_effects): + true_positive_process_state[g_exo_process].append( + segment.trajectory.states[0]) + return true_positive_process_state + + +def _get_predicates_str(predicates: Set[Predicate], + include_primitive_preds: bool = True, + include_derived_preds: bool = True) -> str: + + init_pred_str = [] + for p in predicates: + if include_primitive_preds and not isinstance(p, DerivedPredicate): + 
init_pred_str.append(p.pretty_str_with_assertion()) + elif include_derived_preds and isinstance(p, DerivedPredicate): + init_pred_str.append(p.pretty_str_with_assertion()) + logging.debug(f"Current predicate str: {init_pred_str}") + init_pred_str = sorted(init_pred_str) + return "\n".join(init_pred_str) + + +def _get_types_str(types: Set[Type], + include_features: bool = True, + use_python_def_str: bool = False) -> str: + """Get the types string.""" + excluded_types = [] + if CFG.excluded_objects_in_state_str: + excluded_types = CFG.excluded_objects_in_state_str.split(",") + + if use_python_def_str: + type_str = [ + t.python_definition_str() for t in types + if t.name not in excluded_types + ] + else: + type_str = [ + t.pretty_str() for t in types if t.name not in excluded_types + ] + type_str = sorted(type_str) + return "\n".join(type_str) + + +def _get_transition_str( + trajectories: List[LowLevelTrajectory], + train_tasks: List[Task], + predicates: Set[Predicate], + ite: int, + max_num_trajs: int = 1, + only_use_successful_trajs: bool = False, + use_abstract_state_str: bool = False, +) -> Tuple[str, str]: + """Get the state before and after some actions. + + Prioritize successful trajectories. + TODO: save images of the states. + """ + if CFG.rgb_observation: + obs_dir = CFG.log_file + f"ite{ite}_obs/" + os.makedirs(obs_dir, exist_ok=True) + + if only_use_successful_trajs: + successful_trajs = [ + traj for traj in trajectories + if traj_is_successful(traj, train_tasks) + ] + if successful_trajs: + trajectories = successful_trajs + trajectories = trajectories[:max_num_trajs] + + # Segment the trajectories and get states before and after the actions. 
+ segmented_trajs = [ + segment_trajectory(ll_traj, predicates) for ll_traj in trajectories + ] + result_str, state_str_set = [], [] + state_hash_to_id: Dict[int, int] = {} + for seg_traj in segmented_trajs: + for i, segment in enumerate(seg_traj): + # Get state cache and observation name + init_state_hash = hash(segment.states[0]) + if init_state_hash not in state_hash_to_id: + state_hash_to_id[init_state_hash] = len(state_hash_to_id) + init_state_id = state_hash_to_id[init_state_hash] + obs_name = "state_" + str(init_state_id) + + # Append state + if i == 0: + result_str.append( + f"Starting at {obs_name} with additional info:") + state = segment.states[0] + assert isinstance(state, utils.PyBulletState) + if use_abstract_state_str: + state_str = sorted(utils.abstract(state, predicates)) + else: + state_str = state.dict_str( + indent=2, # type: ignore[assignment] + use_object_id=CFG.rgb_observation) + + result_str.append(f"{state_str}") + str_for_this_state = [f" {obs_name} with additional info:"] + str_for_this_state.append(f"{state_str}") + state_str_set.append("\n".join(str_for_this_state)) + if CFG.rgb_observation: + save_image_with_label( + state.labeled_image.copy(), + obs_name, # type: ignore[union-attr] + obs_dir) + + # Append action + action_str = segment.actions[0].get_option().simple_str( + use_object_id=CFG.rgb_observation) + result_str.append( + f"\nAction {action_str} was executed in {obs_name}") + + # Get state cache and observation name + end_state_hash = hash(segment.states[-1]) + if end_state_hash not in state_hash_to_id: + state_hash_to_id[end_state_hash] = len(state_hash_to_id) + end_state_id = state_hash_to_id[end_state_hash] + obs_name = "state_" + str(end_state_id) + result_str.append(f"\nThis action results in {obs_name} " + "with additional info:") + # Append final state + state = segment.states[-1] + if use_abstract_state_str: + state_str = sorted(utils.abstract(state, predicates)) + else: + state_str = state.dict_str( + indent=2, # type: 
ignore[assignment] + use_object_id=CFG.rgb_observation) + result_str.append(f"{state_str}") + str_for_this_state = [f" {obs_name} with additional info:"] + str_for_this_state.append(f"{state_str}") + state_str_set.append("\n".join(str_for_this_state)) + if CFG.rgb_observation: + save_image_with_label( + state.labeled_image.copy(), + obs_name, # type: ignore[attr-defined] + obs_dir) + + return "\n".join(result_str), "\n\n".join(state_str_set) + + +def save_image_with_label(img_copy: Image, + s_name: str, + obs_dir: str, + f_suffix: str = ".png") -> None: + draw = ImageDraw.Draw(img_copy) + font = ImageFont.load_default().font_variant( + size=50) # type: ignore[union-attr] + text_color = (0, 0, 0) # white + draw.text((0, 0), s_name, fill=text_color, font=font) + img_copy.save(os.path.join(obs_dir, s_name + + f_suffix)) # type: ignore[attr-defined] + logging.debug(f"Saved image {s_name}") + + +def load_images_from_directory(dir: str) -> List[PIL.Image.Image]: + images = [] + for filename in os.listdir(dir): + file_path = os.path.join(dir, filename) + if filename.lower().endswith(('.png', '.jpg')): + img = PIL.Image.open(file_path) + images.append(img) + return images + + +def traj_is_successful(traj: LowLevelTrajectory, + train_tasks: List[Task]) -> bool: + """Check if the trajectory is successful for any of the train tasks.""" + goal_atoms = train_tasks[traj._train_task_idx].goal # type: ignore[index] + goal_predicates = {atom.predicate for atom in goal_atoms} + abstract_state = utils.abstract(traj.states[-1], goal_predicates) + return goal_atoms.issubset(abstract_state) + + +def add_python_quote(text: str) -> str: + return f"```python\n{text}\n```" + + +def _parse_predicates_predictions( + response: str, + initial_predicates: Set[Predicate], + candidate_predicates: Set[Predicate], + # existing_primitive_candidates: Set[Predicate], + # existing_derived_candidates: Set[DerivedPredicate], + types: Set[Type], + example_state: State, +) -> Tuple[List[Predicate], 
List[DerivedPredicate]]: + # Regular expression to match Python code blocks + pattern = re.compile(r'```python(.*?)```', re.DOTALL) + python_blocks = [] + # Find all Python code blocks in the text + for match in pattern.finditer(response): + # Extract the Python code block and add it to the list + python_blocks.append(match.group(1).strip()) + + existing_primitive_candidates: Set[Predicate] = set( + p for p in candidate_predicates if not isinstance(p, DerivedPredicate)) + existing_derived_candidates: Set[DerivedPredicate] = set( + p for p in candidate_predicates if isinstance(p, DerivedPredicate)) + primitive_preds: Set[Predicate] = set() + context: Dict[str, Any] = {} + untranslated_derived_pred_str: List[str] = [] + # --- Existing predicates and their classifiers + for p in initial_predicates: + context[f"_{p.name}_NSP_holds"] = p._classifier + + for p in existing_derived_candidates: + context[f"_{p.name}_CP_holds"] = p._classifier + + for p in existing_primitive_candidates | existing_derived_candidates: + context[f"{p.name}"] = p + + # --- Types --- + for t in types: + context[f"_{t.name}_type"] = t + + # --- Imports --- + exec(import_str, context) + + # --- Interpret the Python blocks --- + for code_str in python_blocks: + # Extract name from code block + match = re.search(r'(\w+)\s*=\s*(NS)?Predicate', + code_str) # type: ignore[assignment] + if match is None: + logging.warning("No predicate name found in the code block" + ) # type: ignore[unreachable] + continue + pred_name = match.group(1) + logging.info(f"Found definition for predicate {pred_name}") + vlm_invention_use_concept_predicates = False + if vlm_invention_use_concept_predicates: + is_concept_predicate = check_is_derived_predicate(code_str) + logging.info(f"\t it's a derived predicate: " + f"{is_concept_predicate}") + else: + is_concept_predicate = False + # logging.info(f"\t derived predicate disabled") + + # Recognize that it's a derived predicate + if is_concept_predicate: + 
untranslated_derived_pred_str.append(add_python_quote(code_str)) + else: + # Type check the code + # passed = False + # while not passed: + # result, passed = self.type_check_proposed_predicates( + # pred_name, + # code_str) + # if not passed: + # # Ask the LLM or the User to fix the code + # pass + # else: + # break + + # Instantiate the primitive predicates + # check if it's roughly runable, and add it to list if it is. + try: + exec(code_str, context) + logging.debug(f"Testing predicate {pred_name}") + # Check1: Make sure it uses types present in the environment + proposed_pred = context[pred_name] + for t in proposed_pred.types: + if t not in types: + logging.warning(f"Type {t} not in the environment") + raise Exception(f"Type {t} not in the environment") + utils.abstract(example_state, [context[pred_name]]) + except Exception as e: + error_trace = traceback.format_exc() + logging.warning(f"Test failed: {e}\n{error_trace}") + continue + else: + logging.debug(f"Test passed!") + primitive_preds.add(context[pred_name]) + + # TODO: --- Convert the derived predicates to DerivedPredicate --- + derived_predicates: Set[DerivedPredicate] = set() + + return primitive_preds, derived_predicates # type: ignore[return-value] + + +import_str = """ +import numpy as np +from typing import Sequence, Set, List +from predicators.structs import State, Object, Type, GroundAtom, Predicate, \ + NSPredicate, DerivedPredicate +from predicators.settings import CFG +""" + + +def check_is_derived_predicate(code_str: str) -> bool: + """Check if the predicate is a derived predicate by looking for `get` or + `evaluate_simple` in the code block.""" + if "state.get(" in code_str or\ + "state.evaluate_simple_assertion" in code_str: + return False + return True diff --git a/predicators/approaches/pp_online_process_learning_approach.py b/predicators/approaches/pp_online_process_learning_approach.py new file mode 100644 index 0000000000..a9babe4bff --- /dev/null +++ 
b/predicators/approaches/pp_online_process_learning_approach.py @@ -0,0 +1,160 @@ +import logging +from typing import List, Optional, Sequence, Set + +from gym.spaces import Box +from scipy.optimize import minimize + +from predicators.approaches.pp_process_learning_approach import \ + ProcessLearningAndPlanningApproach +from predicators.explorers import BaseExplorer, create_explorer +from predicators.option_model import _OptionModelBase +from predicators.settings import CFG +from predicators.structs import Dataset, InteractionRequest, \ + InteractionResult, LowLevelTrajectory, ParameterizedOption, Predicate, \ + Task, Type + + +class OnlineProcessLearningAndPlanningApproach( + ProcessLearningAndPlanningApproach): + """A bilevel planning approach that uses hand-specified processes.""" + + def __init__(self, + initial_predicates: Set[Predicate], + initial_options: Set[ParameterizedOption], + types: Set[Type], + action_space: Box, + train_tasks: List[Task], + task_planning_heuristic: str = "default", + max_skeletons_optimized: int = -1, + bilevel_plan_without_sim: Optional[bool] = None, + option_model: Optional[_OptionModelBase] = None): + super().__init__(initial_predicates, + initial_options, + types, + action_space, + train_tasks, + task_planning_heuristic, + max_skeletons_optimized, + bilevel_plan_without_sim, + option_model=option_model) + self._online_dataset = Dataset([]) + self._online_learning_cycle = 0 + self._requests_train_task_idxs: Optional[List[int]] = None + + @classmethod + def get_name(cls) -> str: + return "online_process_learning_and_planning" + + def learn_from_offline_dataset(self, dataset: Dataset) -> None: + """Learn models from the offline datasets.""" + if len(dataset.trajectories) > 0: + self._learn_processes( + dataset.trajectories, + online_learning_cycle=None, + annotations=(dataset.annotations + if dataset.has_annotations else None)) + else: + logging.info("Offline dataset is empty, skipping learning.") + + def 
get_interaction_requests(self) -> List[InteractionRequest]: + """Designing experiments to collect data. + TODO: This is currently the same as the one for OnlineNSRTLearning + We want to collect data to learn processes for solving, for now, the + planning tasks. + To achieve the goal, we want to learn the conditions and effects that + allows for efficient and effective sequencing of actions and processes. + + There are various exploration strategies: + 1. as in VisualPredicator, make plans for solving the tasks and learn + from the failure cases. + 2. try whether removing one of the conditions of the exogenous + process would allow the process to succeed. + """ + explorer = self._create_explorer() + + # As in OnlineNSRTLearningApproach, do some resets. + self._last_nsrt_plan = [] + self._last_atoms_seq = [] + self._last_plan = [] + + # Create the interaction requests. + requests = [] + # Can also just use CFG.online_nsrt_learning_requests_per_cycle + # for _ in range(CFG.online_nsrt_learning_number_of_tasks_to_try): + # # Select a random task (with replacement). + # task_idx = self._rng.choice(len(self._train_tasks)) + # for i in range(CFG.online_nsrt_learning_requests_per_task): + # logging.info(f"Getting strategy {i} for task {task_idx}") + # # Set up the explorer policy and termination function. + # policy, termination_function = explorer.get_exploration_strategy( + # task_idx, CFG.timeout) + # # Create the interaction request. 
+ # req = InteractionRequest( + # train_task_idx=task_idx, + # act_policy=policy, + # query_policy=lambda s: None, + # termination_function=termination_function) + # requests.append(req) + self._requests_train_task_idxs = [] + for i in range(CFG.online_nsrt_learning_requests_per_cycle): + task_idx = self._rng.choice(len(self._train_tasks)) + logging.info(f"Getting strategy {i}; this is for task {task_idx}") + self._requests_train_task_idxs.append(task_idx) + policy, termination_function = explorer.get_exploration_strategy( + task_idx, CFG.timeout) + req = InteractionRequest(train_task_idx=task_idx, + act_policy=policy, + query_policy=lambda s: None, + termination_function=termination_function) + requests.append(req) + return requests + + def learn_from_interaction_results( + self, results: Sequence[InteractionResult]) -> None: + """Learn from interaction results. + + We will organize the interaction results as follows: + 1. interaction trajectories + 2. failed initial states for options? (might not work well with weak + option termination classifiers.) + Old: + For endogenous process, initial states where it succeeded and failed. + For exogenous process, suffixes of the trajectories where that atom + changed. 
+ """ + # TODO: update _dataset based on the results + # Can potentially have a positive and negative dataset + for result in results: + traj = LowLevelTrajectory(result.states, result.actions) + self._online_dataset.append(traj) + + # Learn from the dataset + annotations = None + if self._online_dataset.has_annotations: + annotations = self._online_dataset.annotations # pragma: no cover + self._learn_processes( + self._online_dataset.trajectories, + online_learning_cycle=self._online_learning_cycle, + annotations=annotations) + + if CFG.learn_process_parameters: + self._learn_process_parameters(self._offline_dataset.trajectories+\ + self._online_dataset.trajectories) + + self._online_learning_cycle += 1 + + def _create_explorer(self) -> BaseExplorer: + """Create a new explorer at the beginning of each interaction cycle.""" + # Note that greedy lookahead is not yet supported. + preds = self._get_current_predicates() + explorer = create_explorer( + CFG.explorer, + preds, + self._initial_options, + self._types, + self._action_space, + self._train_tasks, + self._get_current_processes(), # type: ignore[arg-type] + self._option_model, + ) + return explorer diff --git a/predicators/approaches/pp_oracle_approach.py b/predicators/approaches/pp_oracle_approach.py new file mode 100644 index 0000000000..6bca8a8cf9 --- /dev/null +++ b/predicators/approaches/pp_oracle_approach.py @@ -0,0 +1,87 @@ +from typing import Callable, List, Optional, Set + +from gym.spaces import Box + +from predicators.approaches.process_planning_approach import \ + BilevelProcessPlanningApproach +from predicators.ground_truth_models import augment_task_with_helper_objects, \ + get_gt_helper_predicates, get_gt_helper_types, get_gt_processes +from predicators.option_model import _OptionModelBase +from predicators.settings import CFG +from predicators.structs import NSRT, Action, CausalProcess, \ + ParameterizedOption, Predicate, State, Task, Type + + +class 
class OracleBilevelProcessPlanningApproach(BilevelProcessPlanningApproach):
    """A bilevel planning approach that uses hand-specified (ground-truth)
    causal processes instead of learned ones."""

    def __init__(self,
                 initial_predicates: Set[Predicate],
                 initial_options: Set[ParameterizedOption],
                 types: Set[Type],
                 action_space: Box,
                 train_tasks: List[Task],
                 task_planning_heuristic: str = "default",
                 max_skeletons_optimized: int = -1,
                 bilevel_plan_without_sim: Optional[bool] = None,
                 processes: Optional[Set[CausalProcess]] = None,
                 option_model: Optional[_OptionModelBase] = None) -> None:
        super().__init__(initial_predicates,
                         initial_options,
                         types,
                         action_space,
                         train_tasks,
                         task_planning_heuristic,
                         max_skeletons_optimized,
                         bilevel_plan_without_sim,
                         option_model=option_model)
        # Augment the approach with optional env-specific helper types and
        # predicates (e.g., in dominoes, ones about positions and directions).
        helper_types = get_gt_helper_types(CFG.env)
        helper_predicates = get_gt_helper_predicates(CFG.env)
        self._types = types | helper_types
        self._initial_predicates = initial_predicates | helper_predicates

        if processes is None:
            # only_endogenous is used for the no_invent baseline.
            processes = get_gt_processes(
                CFG.env,
                self._initial_predicates,
                self._initial_options,
                only_endogenous=CFG.running_no_invent_baseline)

        # Optionally force all process strength parameters to 1.
        if CFG.process_planning_set_parameters_one:
            # Local import keeps torch out of the module's import cost when
            # the flag is off.
            import torch
            modified_processes = set()
            for process in processes:
                # NOTE(review): this mutates each process in place via the
                # private _set_parameters; no copy is made despite the set
                # being rebuilt below.
                strength_params = torch.tensor([1.0])
                delay_params = torch.ones(
                    len(process.delay_distribution.get_parameters()))
                process._set_parameters(
                    torch.cat([strength_params, delay_params]).tolist())
                modified_processes.add(process)
            processes = modified_processes

        self._processes = processes

    @classmethod
    def get_name(cls) -> str:
        return "oracle_process_planning"

    @property
    def is_learning_based(self) -> bool:
        # Oracle processes are given, so no learning phase is needed.
        return False

    def _get_current_processes(self) -> Set[CausalProcess]:
        return self._processes

    def _get_current_nsrts(self) -> Set[NSRT]:
        """Get the current set of NSRTs (none; this approach plans with
        processes only)."""
        return set()

    def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]:
        # Augment the task with env-specific helper objects if needed before
        # delegating to the standard bilevel process-planning solve.
        task = augment_task_with_helper_objects(task, CFG.env)
        return super()._solve(task, timeout)
class ParamLearningBilevelProcessPlanningApproach(
        BilevelProcessPlanningApproach):
    """A bilevel planning approach that starts from hand-specified processes
    and learns their continuous parameters from data."""

    def __init__(self,
                 initial_predicates: Set[Predicate],
                 initial_options: Set[ParameterizedOption],
                 types: Set[Type],
                 action_space: Box,
                 train_tasks: List[Task],
                 task_planning_heuristic: str = "default",
                 max_skeletons_optimized: int = -1,
                 bilevel_plan_without_sim: Optional[bool] = None,
                 processes: Optional[Set[CausalProcess]] = None,
                 option_model: Optional[_OptionModelBase] = None):
        super().__init__(initial_predicates,
                         initial_options,
                         types,
                         action_space,
                         train_tasks,
                         task_planning_heuristic,
                         max_skeletons_optimized,
                         bilevel_plan_without_sim,
                         option_model=option_model)
        if processes is None:
            # Fall back to the ground-truth process structures; only their
            # parameters are learned.
            processes = get_gt_processes(CFG.env, self._initial_predicates,
                                         self._initial_options)
        self._processes: Set[CausalProcess] = processes
        self._offline_dataset = Dataset([])

    @classmethod
    def get_name(cls) -> str:
        return "param_learning_process_planning"

    @property
    def is_learning_based(self) -> bool:
        return True

    def _get_current_processes(self) -> Set[CausalProcess]:
        return self._processes

    def _get_current_exogenous_processes(self) -> Set[ExogenousProcess]:
        """Get the current set of exogenous processes."""
        return {p for p in self._processes if isinstance(p, ExogenousProcess)}

    def _get_current_endogenous_processes(self) -> Set[EndogenousProcess]:
        """Get the current set of endogenous processes."""
        return {p for p in self._processes if isinstance(p, EndogenousProcess)}

    def _get_current_nsrts(self) -> Set[NSRT]:
        """Get the current set of NSRTs (none; planning uses processes)."""
        return set()

    def learn_from_offline_dataset(self, dataset: Dataset) -> None:
        """Learn parameters of processes from the offline datasets.

        This is currently achieved by optimizing the marginal data
        likelihood.
        """
        self._learn_process_parameters(dataset.trajectories)

    def _learn_process_parameters(
        self,
        trajectories: List[LowLevelTrajectory],
        use_lbfgs: bool = False,
    ) -> None:
        """Stochastic (mini-batch) optimisation of process parameters.

        Mutates self._processes in place via learn_process_parameters.
        """
        processes = sorted(self._get_current_processes())
        # BUG FIX: previously only trajectories[:1] was passed (an apparent
        # debugging leftover), silently discarding all but the first
        # trajectory; learn from the full dataset as documented.
        _, scores = learn_process_parameters(
            trajectories,
            self._get_current_predicates(),
            processes,
            use_lbfgs=use_lbfgs,
            lbfgs_max_iter=CFG.process_param_learning_num_steps,
            adam_num_steps=CFG.process_param_learning_num_steps,
            early_stopping_patience=20,
            use_empirical=CFG.process_param_learning_use_empirical,
        )
        logging.debug(f"ELBO: {scores[0]}, exp_state: {scores[1]}, "
                      f"exp_delay: {scores[2]}, entropy: {scores[3]}")
        logging.debug("Learned processes:")
        for p in processes:
            logging.debug(pformat(p))
        logging.debug(f"Log frame strength: {scores[4]}")
def learn_process_parameters(
    trajectories: List[LowLevelTrajectory],
    predicates: Set[Predicate],
    processes: Sequence[CausalProcess],
    use_lbfgs: bool = False,
    plot_training_curve: bool = True,
    lbfgs_max_iter: int = 200,
    seed: int = 0,
    display_progress: bool = True,
    adam_num_steps: int = 200,
    std_regularization: Optional[int] = None,
    early_stopping_patience: Optional[int] = None,
    early_stopping_tolerance: float = 1e-4,
    check_condition_overall: bool = True,
    batch_size: int = 16,
    debug_log: bool = False,
    use_empirical: bool = False,
) -> Tuple[Sequence[CausalProcess], Tuple[float, float, float, float, float]]:
    """Learn process parameters using stochastic optimization or empirical
    estimation.

    If use_empirical=True, bypasses variational inference and directly
    estimates delay parameters from observed data. Otherwise maximizes an
    ELBO over process parameters, per-instance guide distributions, and a
    global frame-axiom strength, with mini-batch Adam or LBFGS.

    Returns the (mutated in place) processes and a tuple of
    (mean ELBO, mean expected state log-prob, mean expected delay log-prob,
    mean guide entropy, final log frame strength).
    """

    # ----- Empirical shortcut: estimate delay parameters directly --------- #
    if use_empirical:
        processes, stats = learn_process_parameters_empirical(
            trajectories, predicates, processes, use_empirical=True)

        # Even when using empirical estimation, we still prepare the cached
        # per-trajectory data so the model can be scored on the dataset.
        # NOTE(review): max_traj_len is computed here but never used in this
        # branch — dead code carried over from the VI path.
        max_traj_len = max(
            len(traj.states)
            for traj in trajectories) if len(trajectories) > 0 else 0

        per_traj_data, proc_and_guide_params_full, num_proc_params = \
            _prepare_training_data_and_model_params(
                predicates,
                processes,
                trajectories,
                check_condition_overall,
            )

        # Guide parameters stay randomly initialized: they are not learned
        # empirically.
        guide_params = proc_and_guide_params_full[num_proc_params:]
        # NOTE(review): this is a *log* frame strength everywhere else
        # (elbo_torch's log_frame_strength), so 1.0 here means strength e,
        # not 1 — confirm this default is intended.
        final_frame_param = torch.tensor(1.0)

        # Evaluate the empirically parameterized model on the dataset.
        mean_elbo, mean_exp_state, mean_exp_delay, mean_entropy = evaluate_model_on_dataset(
            per_traj_data=per_traj_data,
            frame_param=final_frame_param,
            guide_params=guide_params,
            debug_log=debug_log)

        return processes, (mean_elbo, mean_exp_state, mean_exp_delay,
                           mean_entropy, 1.0)

    # ----- Optimizer schedule selection ----------------------------------- #
    if use_lbfgs:
        # One outer step; LBFGS iterates internally.
        num_steps = 1
        batch_size = 100
        inner_lbfgs_max_iter = lbfgs_max_iter
    else:
        num_steps = adam_num_steps
        batch_size = batch_size  # NOTE(review): no-op self-assignment

    torch.manual_seed(seed)
    random.seed(seed)

    # -------------------------------------------------------------- #
    # 0. Cache per-trajectory data & build a global param layout     #
    # -------------------------------------------------------------- #
    max_traj_len = max(len(traj.states) for traj in trajectories)\
        if len(trajectories) > 0 else 0

    per_traj_data, proc_and_guide_params_full, num_proc_params = \
        _prepare_training_data_and_model_params(
            predicates,
            processes,
            trajectories,
            check_condition_overall,
        )

    # --- Optionally initialize process parameters with empirical estimates
    if CFG.use_empirical_init_for_vi_params:
        _initialize_params_with_empirical_estimates(
            trajectories, predicates, processes, proc_and_guide_params_full,
            num_proc_params)

    # --- Separate parameter tensor into logical, learnable components ---
    # All process parameters (strength + delay) come first in the flat tensor;
    # guide (variational) parameters follow.
    proc_params_full = proc_and_guide_params_full[:num_proc_params]

    learnable_params_for_optim = []

    guide_params = torch.nn.Parameter(
        proc_and_guide_params_full[num_proc_params:])
    learnable_params_for_optim.append(guide_params)

    learnable_proc_params = torch.nn.Parameter(proc_params_full)
    learnable_params_for_optim.append(learnable_proc_params)

    # Global log frame-axiom strength, learned jointly.
    frame_param = torch.nn.Parameter(torch.randn(1) * 0.01)
    learnable_params_for_optim.append(frame_param)

    # Push the initial parameters into the process objects (max_k bounds the
    # delay support to the longest trajectory).
    init_proc_param = proc_params_full.detach()
    _set_process_parameters(processes, init_proc_param,
                            **{'max_k': max_traj_len})

    # ------------------- progress bar -------------------------- #
    if use_lbfgs:
        pbar_total = num_steps * inner_lbfgs_max_iter
        desc = "Training (mini‑batch LBFGS)"
    else:
        pbar_total = num_steps
        desc = "Training (Adam)"
    if display_progress:
        pbar = tqdm(total=pbar_total, desc=desc)
    else:
        pbar = None

    best_elbo = -float("inf")
    curve: Dict = {
        "iterations": [],
        "elbos": [],
        "best_elbos": [],
        "wall_time": []
    }
    training_start_time = time.time()

    # --- Early stopping setup (patience is decremented on non-improvement) -
    patience_counter = early_stopping_patience
    best_params_state = None

    optim: Optional[torch.optim.Optimizer] = None
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None
    if use_lbfgs:
        # LBFGS is re-initialized per outer step inside the loop.
        pass
    else:
        # lr=1e-1 chosen empirically (best ELBO across 500–10k step runs;
        # 5e-2 and 1e-2 converged slower without improving the optimum).
        lr = 1e-1
        if debug_log:
            logging.debug(f"lr={lr}")
        optim = Adam(learnable_params_for_optim, lr=lr)
        # An LR scheduler (e.g. ReduceLROnPlateau) was tried and did not
        # improve results; scheduler stays None.
        if debug_log:
            if scheduler:
                logging.debug(f"Scheduler initialized: {scheduler}")

    # ------------------- training loop ----------------------------- #
    iteration = 0
    for outer_step in range(num_steps):
        current_optim: Optional[LBFGS] = None
        if use_lbfgs:
            current_optim = LBFGS(learnable_params_for_optim,
                                  max_iter=inner_lbfgs_max_iter,
                                  line_search_fn="strong_wolfe")
        else:
            current_optim = optim  # type: ignore[assignment]

        assert current_optim is not None, "Optimizer not initialized"

        # Uniformly sample a mini-batch of trajectory indices.
        num_trajs = len(per_traj_data)
        batch_ids = random.sample(range(num_trajs),
                                  k=min(batch_size, num_trajs))

        def closure() -> float:
            """Compute the (negative) mini-batch ELBO and its gradients.

            Shared by both Adam (called once per step) and LBFGS (called
            repeatedly by the optimizer's line search).
            """
            nonlocal best_elbo, iteration
            nonlocal patience_counter

            if current_optim:
                current_optim.zero_grad(set_to_none=True)

            proc_param = learnable_proc_params

            # Keep the process objects in sync with the current parameters.
            _set_process_parameters(processes, proc_param)

            guide_flat = guide_params
            frame = frame_param

            elbo = torch.tensor(0.0, dtype=frame.dtype, device=frame.device)

            for tidx in batch_ids:
                td = per_traj_data[tidx]
                guide_dict = _create_guide_dict_for_trajectory(
                    td, guide_flat, td["traj_len"])

                data_elbo, _, _, _ = elbo_torch(
                    td["trajectory"],
                    td["sparse_trajectory"],
                    td["ground_causal_processes"],
                    td["start_times_per_gp"],
                    guide_dict,
                    frame,
                    set(td["all_atoms"]),
                    td["atom_to_val_to_gps"],
                    td["condition_cache"],
                )
                elbo = elbo + data_elbo

            # Minimize the mean negative ELBO over the batch.
            loss = -(elbo / len(batch_ids))
            # Optional L1-style penalty on the log-sigma entries (every 3rd
            # process parameter).
            if std_regularization and learnable_params_for_optim:
                loss = loss + std_regularization * (proc_param[2::3].sum())

            if learnable_params_for_optim:
                loss.backward()  # type: ignore

            detached_elbo_item = elbo.detach().item()

            # --- Early stopping bookkeeping ---
            if early_stopping_patience is not None:
                if detached_elbo_item > best_elbo + early_stopping_tolerance:
                    best_elbo = detached_elbo_item
                    patience_counter = early_stopping_patience
                else:
                    if patience_counter is not None:
                        patience_counter -= 1
            elif detached_elbo_item > best_elbo:
                best_elbo = detached_elbo_item

            curve["iterations"].append(iteration)
            curve["elbos"].append(detached_elbo_item)
            curve["best_elbos"].append(best_elbo)
            curve["wall_time"].append(time.time() - training_start_time)
            if pbar:
                pbar.set_postfix(ELBO=detached_elbo_item, best=best_elbo)
                pbar.update(1)

            iteration += 1
            return loss.item()

        if use_lbfgs:
            current_optim.step(closure)
        else:
            loss = closure()
            current_optim.step()
            if scheduler:
                if debug_log:
                    prev_lr = scheduler.get_last_lr()
                scheduler.step(loss)  # type: ignore[arg-type]
                if debug_log:
                    curr_lr = scheduler.get_last_lr()
                    if curr_lr != prev_lr:
                        logging.debug(
                            f"decreasing lr from {prev_lr} to {curr_lr}")

        # --- Trigger early stop if patience has run out ---
        if early_stopping_patience is not None and patience_counter is not None and patience_counter <= 0:
            break

    if pbar:
        pbar.close()

    # --- Restore best parameters before evaluation (currently inert:
    # best_params_state is never saved during training) ---
    if best_params_state is not None:
        for param, best_state in zip(  # type: ignore[unreachable]
                learnable_params_for_optim, best_params_state):
            param.data.copy_(best_state)

    # --- Persist final parameters and evaluate on the full dataset ---
    final_guide_params = guide_params.detach()
    final_proc_params = learnable_proc_params.detach()

    _set_process_parameters(processes, final_proc_params)
    final_frame_param = frame_param.detach()

    mean_elbo, mean_exp_state, mean_exp_delay, mean_entropy = evaluate_model_on_dataset(
        per_traj_data=per_traj_data,
        frame_param=final_frame_param,
        guide_params=final_guide_params,
        debug_log=debug_log)

    if plot_training_curve:
        _plot_training_curve(curve)

    return processes, (mean_elbo, mean_exp_state, mean_exp_delay, mean_entropy,
                       final_frame_param.item())
def elbo_torch(
    atom_option_trajectory: AtomOptionTrajectory,
    sparse_trajectory: List[Tuple[Set[GroundAtom], int, int]],
    ground_processes: List[
        _GroundCausalProcess],  # All potential ground causal processes
    start_times_per_gp: List[List[
        int]],  # start_times_per_gp[gp_idx] is list of s_i for ground_processes[gp_idx]
    guide: Dict[_GroundCausalProcess,
                Dict[int, Tensor]],  # Variational params q(z_t ; gp, s_i)
    log_frame_strength: Tensor,
    all_possible_atoms: Set[GroundAtom],
    atom_to_val_to_gps: Dict[GroundAtom, Dict[bool,
                                              Set[_GroundCausalProcess]]],
    condition_cache: Dict[_GroundCausalProcess, Dict[int, Dict[int, bool]]],
    use_sparse_trajectory: bool = True,
    debug_log: bool = False,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """*Differentiable* ELBO computation with efficient, cached condition
    checks.

    Returns (elbo, expected state log-prob, expected delay log-prob,
    guide entropy), all as tensors carrying gradients w.r.t. the process
    parameters, the guide, and log_frame_strength.
    """
    trajectory = atom_option_trajectory
    num_time_steps = len(trajectory.states)

    ll = torch.tensor(0.0, dtype=log_frame_strength.dtype)
    yt_prev = trajectory.states[0]

    # -----------------------------------------------------------------
    # 1. Expected log state probabilities
    # -----------------------------------------------------------------
    exp_state_prob = torch.tensor(0.0, dtype=log_frame_strength.dtype)
    if use_sparse_trajectory:
        # Iterate over segments of constant state instead of every step.
        if debug_log:
            logging.debug(f"Compute exp_state_prob from "
                          f"{len(sparse_trajectory)-1} segments")
        for i, (yt, start_t, _) in enumerate(sparse_trajectory[1:]):
            state_prob_t = torch.tensor(0.0, dtype=log_frame_strength.dtype)
            E_log_Zt = torch.tensor(0.0, dtype=log_frame_strength.dtype)
            for atom, val_to_gps in atom_to_val_to_gps.items():
                sum_ytj = torch.tensor(0.0, dtype=log_frame_strength.dtype)
                for val in (True, False):
                    # Ground processes which have this atom in the add or
                    # delete effects.
                    gps = val_to_gps.get(val, set())

                    prod = torch.tensor(1.0, dtype=log_frame_strength.dtype)
                    for gp in gps:
                        for st, q in guide.get(gp, {}).items():
                            if st < start_t:
                                # --- Efficient Cache Lookup ---
                                # Default to True if not in cache (e.g., no
                                # overall condition).
                                condition_overall_holds = condition_cache.get(
                                    gp, {}).get(st, {}).get(start_t - 1, True)
                                prev_val = atom in yt_prev
                                # --- Numerator Part ---
                                if val == (atom
                                           in yt) and condition_overall_holds:
                                    exp_state_prob = exp_state_prob + \
                                        q[start_t] * \
                                        gp.factored_effect_factor(val, atom,
                                                                  prev_val)
                                    state_prob_t = state_prob_t + \
                                        q[start_t] * \
                                        gp.factored_effect_factor(val, atom,
                                                                  prev_val)

                                # --- Denominator Part ---
                                if condition_overall_holds:
                                    prod = prod * (q[start_t] * torch.exp(
                                        gp.factored_effect_factor(
                                            val, atom, prev_val)) +
                                        (1 - q[start_t]))
                    sum_ytj = sum_ytj + prod * torch.exp(log_frame_strength *
                                                         (val ==
                                                          (atom in yt_prev)))
                # 1e-12 guards against log(0).
                E_log_Zt = E_log_Zt + torch.log(sum_ytj + 1e-12)

            # Atoms not referenced in any process law: frame axiom reward for
            # staying unchanged.
            add_atoms = yt - yt_prev
            del_atoms = yt_prev - yt
            atoms_unchanged = all_possible_atoms - del_atoms - add_atoms
            exp_state_prob = exp_state_prob + log_frame_strength * len(
                atoms_unchanged)
            state_prob_t = state_prob_t + log_frame_strength * len(
                atoms_unchanged)

            # Normalization contribution from atoms not described by the
            # processes.
            atoms_in_law_effects = set(atom_to_val_to_gps)
            atoms_not_in_law_effects = all_possible_atoms - atoms_in_law_effects
            E_log_Zt = E_log_Zt + len(atoms_not_in_law_effects) * torch.log(
                1 + torch.exp(log_frame_strength))

            exp_state_prob = exp_state_prob - E_log_Zt
            state_prob_t = state_prob_t - E_log_Zt
            yt_prev = yt
            if debug_log:
                logging.debug(
                    f"seg {i}: start_t={start_t}, "
                    f"exp_state_prob_t={state_prob_t.detach().item():.4f}, "
                    f"add atoms: {add_atoms}, del atoms: {del_atoms}")
    else:
        # Dense version: same computation, one term per time step.
        for t in range(1, num_time_steps):
            yt = trajectory.states[t]

            E_log_Zt = torch.tensor(0.0, dtype=log_frame_strength.dtype)
            for atom, val_to_gps in atom_to_val_to_gps.items():
                sum_ytj = torch.tensor(0.0, dtype=log_frame_strength.dtype)
                for val in (True, False):
                    gps = val_to_gps.get(val, set())

                    prod = torch.tensor(1.0, dtype=log_frame_strength.dtype)
                    for gp in gps:
                        for st, q in guide.get(gp, {}).items():
                            if st < t:
                                # --- Efficient Cache Lookup ---
                                # Default to True if not in cache (e.g., no
                                # overall condition).
                                condition_overall_holds = condition_cache.get(
                                    gp, {}).get(st, {}).get(t - 1, True)

                                prev_val = atom in yt_prev
                                # --- Numerator Part ---
                                if val == (atom
                                           in yt) and condition_overall_holds:
                                    exp_state_prob = exp_state_prob + q[t] * \
                                        gp.factored_effect_factor(val, atom,
                                                                  prev_val)
                                # --- Denominator Part ---
                                if condition_overall_holds:
                                    prod = prod * (q[t] * torch.exp(
                                        gp.factored_effect_factor(
                                            val, atom, prev_val)) + (1 - q[t]))

                    sum_ytj = sum_ytj + prod * torch.exp(log_frame_strength *
                                                         (val ==
                                                          (atom in yt_prev)))
                E_log_Zt = E_log_Zt + torch.log(sum_ytj + 1e-12)

            # Atoms not referenced in any process law
            add_atoms = yt - yt_prev
            del_atoms = yt_prev - yt
            atoms_unchanged = all_possible_atoms - del_atoms - add_atoms
            exp_state_prob = exp_state_prob + log_frame_strength * len(
                atoms_unchanged)

            # Normalization contribution from atoms not described by the
            # processes.
            atoms_in_law_effects = set(atom_to_val_to_gps)
            atoms_not_in_law_effects = all_possible_atoms - atoms_in_law_effects
            E_log_Zt = E_log_Zt + len(atoms_not_in_law_effects) * torch.log(
                1 + torch.exp(log_frame_strength))

            exp_state_prob = exp_state_prob - E_log_Zt
            yt_prev = yt
    ll = ll + exp_state_prob

    # -----------------------------------------------------------------
    # 2. Expected delay probabilities
    # -----------------------------------------------------------------
    exp_delay_prob = torch.tensor(0.0, dtype=log_frame_strength.dtype)
    for gp_idx, gp_obj in enumerate(ground_processes):
        for s_i in start_times_per_gp[gp_idx]:
            if s_i + 1 < num_time_steps:
                delay_values = torch.arange(1,
                                            num_time_steps - s_i,
                                            dtype=torch.long,
                                            device=log_frame_strength.device)
                if delay_values.numel() == 0:
                    continue
                t_indices_for_guide = s_i + delay_values
                all_delay_log_probs = gp_obj.delay_distribution.log_prob(  # type: ignore[attr-defined]
                    delay_values)
                q_dist_for_instance = guide.get(gp_obj, {}).get(s_i, None)
                if q_dist_for_instance is None:
                    raise Exception(
                        f"Guide distribution not found for {gp_obj} at s_i={s_i}"
                    )
                guide_slice_for_delays = q_dist_for_instance[
                    t_indices_for_guide]
                # Mask out -inf delay log-probs and near-zero guide mass.
                valid_mask = ~torch.isneginf(all_delay_log_probs) & (
                    guide_slice_for_delays > 1e-9)
                if valid_mask.any():
                    single_exp_delay_prob = torch.sum(
                        guide_slice_for_delays[valid_mask] *
                        all_delay_log_probs[valid_mask])
                    exp_delay_prob = exp_delay_prob + single_exp_delay_prob
                    if debug_log:
                        logging.debug(
                            f"exp_delay_prob={single_exp_delay_prob.detach().item():.4f} "
                            f"start_t={s_i}, "
                            f"max_guide_values: at t={torch.argmax(q_dist_for_instance)}: {torch.max(q_dist_for_instance)}"
                        )
                        # NOTE(review): hard-coded index 94 looks like a
                        # debugging leftover and will raise IndexError on
                        # trajectories shorter than 95 steps when
                        # debug_log=True — confirm and remove.
                        logging.debug(
                            f"guide_prob at arrival_t (94): {q_dist_for_instance[94]}"
                        )

    ll = ll + exp_delay_prob

    # -----------------------------------------------------------------
    # 3. Entropy of the variational distributions
    # -----------------------------------------------------------------
    num_started_delays = 0
    entropy = torch.tensor(0.0, dtype=log_frame_strength.dtype)
    for start_time_q_map in guide.values():
        for q_dist_for_instance in start_time_q_map.values():
            mask = q_dist_for_instance > 1e-9
            if mask.any():
                entropy -= torch.sum(q_dist_for_instance[mask] *
                                     torch.log(q_dist_for_instance[mask]))
            num_started_delays += 1
    # Add entropy for guide for delay variables which were not activated:
    # treated as uniform over num_time_steps outcomes.
    num_gp = len(ground_processes)
    num_unstarted_delays = num_gp * num_time_steps - num_started_delays
    unstarted_delay_entropy = num_unstarted_delays * torch.log(
        torch.tensor(1 / num_time_steps, dtype=log_frame_strength.dtype))
    entropy -= unstarted_delay_entropy

    elbo = ll + entropy
    return elbo, exp_state_prob, exp_delay_prob, entropy
def compute_empirical_delays(
    trajectories: List[LowLevelTrajectory],
    predicates: Set[Predicate],
    processes: Sequence[CausalProcess],
) -> Dict[str, List[int]]:
    """Compute empirical delays for each process type from trajectory data.

    Returns a dictionary mapping process names to lists of observed
    delays (one entry per observed trigger whose effects later appeared).
    """
    atom_option_dataset = utils.create_ground_atom_option_dataset(
        trajectories, predicates)

    # Dictionary to store delays for each process type.
    process_delays: Dict[str, List[int]] = defaultdict(list)

    for traj in atom_option_dataset:
        traj_len = len(traj.states)
        # NOTE(review): reaches into the private _low_level_states to get the
        # object set — presumably the set of objects in the initial state.
        objs = set(traj._low_level_states[0])

        # Ground the processes for this trajectory.
        _ground_processes, _ = process_task_plan_grounding(
            init_atoms=set(),
            objects=objs,
            cps=processes,
            allow_waits=True,
            compute_reachable_atoms=False,
        )
        ground_processes = [
            gp for gp in _ground_processes
            if isinstance(gp, _GroundCausalProcess)
        ]

        # For each ground process, find when it was triggered and when its
        # effects first appeared.
        for gp in ground_processes:
            # All times at which this process's cause fired.
            trigger_times = []
            for t in range(traj_len):
                if gp.cause_triggered(traj.states[:t + 1],
                                      traj.actions[:t + 1]):
                    trigger_times.append(t)

            # For each trigger, the delay is the time until the first state
            # where all add effects hold and no delete effect holds.
            for trigger_t in trigger_times:
                for effect_t in range(trigger_t + 1, traj_len):
                    add_satisfied = gp.add_effects.issubset(
                        traj.states[effect_t])
                    delete_satisfied = not any(atom in traj.states[effect_t]
                                               for atom in gp.delete_effects)

                    if add_satisfied and delete_satisfied:
                        # Found the effect time — record the delay, keyed by
                        # the (lifted) parent process name.
                        delay = effect_t - trigger_t
                        process_delays[gp.parent.name].append(delay)
                        break

    return process_delays
def learn_process_parameters_empirical(
    trajectories: List[LowLevelTrajectory],
    predicates: Set[Predicate],
    processes: Sequence[CausalProcess],
    use_empirical: bool = False,
) -> Tuple[Sequence[CausalProcess], Dict[str, Tuple[Optional[float],
                                                    Optional[float]]]]:
    """Learn process parameters using empirical estimation of delays.

    When use_empirical=True, directly computes mean and std from observed
    delays. Returns the processes (mutated in place) and a dict mapping each
    process name to its (mean, std), or (None, None) when no observations
    were found.

    Raises:
        ValueError: if use_empirical is False.
    """
    if not use_empirical:
        raise ValueError("This function is only for empirical estimation")

    # Compute empirical delays for each process type.
    process_delays = compute_empirical_delays(trajectories, predicates,
                                              processes)

    # Statistics dictionary to return.
    stats: Dict[str, Tuple[Optional[float], Optional[float]]] = {}

    # Update each process with empirical parameters.
    for process in processes:
        if process.name in process_delays and len(
                process_delays[process.name]) > 0:
            delays = torch.tensor(process_delays[process.name],
                                  dtype=torch.float32)

            # Mean / std of the observed delays; a single observation gets a
            # small default std, and std is floored to avoid degeneracy.
            empirical_mean = delays.mean()
            empirical_std = delays.std() if len(delays) > 1 else torch.tensor(
                0.1)
            empirical_std = torch.clamp(empirical_std, min=0.1)

            # Parameter layout is [log_strength, log_mu, log_sigma]; strength
            # is fixed at 1 (log 0) since it is not estimated empirically.
            params = torch.tensor([
                0.0,  # log_strength = 0 (strength = 1)
                torch.log(empirical_mean),  # log_mu
                torch.log(empirical_std)  # log_sigma
            ])

            # Update the process parameters in place.
            process._set_parameters(params.tolist())

            stats[process.name] = (empirical_mean.item(),
                                   empirical_std.item())

            # CONSISTENCY FIX: use the module's logging (as everywhere else)
            # instead of bare print.
            logging.info("Process %s:", process.name)
            logging.info("  Observed delays: %s",
                         process_delays[process.name])
            logging.info("  Empirical mean: %.2f", empirical_mean)
            logging.info("  Empirical std: %.2f", empirical_std)
        else:
            # No observations for this process — keep its current defaults.
            logging.info(
                "Process %s: No observations found, keeping defaults",
                process.name)
            stats[process.name] = (None, None)

    return processes, stats
@torch.no_grad()
def evaluate_model_on_dataset(
        per_traj_data: List[Dict[str, Any]],
        frame_param: torch.Tensor,
        guide_params: torch.Tensor,
        ignore_entropy: bool = False,
        debug_log: bool = False) -> Tuple[float, float, float, float]:
    """Score a trained model on every cached trajectory.

    Returns per-trajectory means of (ELBO, expected state log-prob,
    expected delay log-prob, guide entropy). With ignore_entropy=True the
    reported ELBO has its entropy term removed.

    TODO: maybe normalize by number of segments? or total number of steps?
    """
    elbo_sum = 0.0
    state_sum = 0.0
    delay_sum = 0.0
    entropy_sum = 0.0

    for traj_data in per_traj_data:
        guide = _create_guide_dict_for_trajectory(traj_data, guide_params,
                                                  traj_data["traj_len"])

        elbo, exp_state, exp_delay, entropy = elbo_torch(
            traj_data["trajectory"],
            traj_data["sparse_trajectory"],
            traj_data["ground_causal_processes"],
            traj_data["start_times_per_gp"],
            guide,
            frame_param,
            set(traj_data["all_atoms"]),
            traj_data["atom_to_val_to_gps"],
            traj_data["condition_cache"],
            debug_log=debug_log)

        elbo_sum += elbo.item()
        if ignore_entropy:
            # Strip the entropy contribution from the reported ELBO.
            elbo_sum -= entropy.item()
        state_sum += exp_state.item()
        delay_sum += exp_delay.item()
        entropy_sum += entropy.item()

    count = len(per_traj_data)
    return (elbo_sum / count, state_sum / count, delay_sum / count,
            entropy_sum / count)
def _set_process_parameters(processes: Sequence[CausalProcess],
                            parameters: Tensor, **kwargs: Any) -> None:
    """Write a flat parameter tensor back into the process objects.

    Parameters are for the CausalProcess types, not ground instances.
    Assumes 3 parameters per CausalProcess type (e.g., for its delay
    distribution); kwargs are forwarded to each process's _set_parameters.
    """
    num_causal_process_types = len(processes)
    expected_len = 3 * num_causal_process_types
    assert len(parameters) == expected_len, \
        f"Expected {expected_len} params, got {len(parameters)}"

    # Loop through the CausalProcess types, consuming 3 params each.
    for i in range(num_causal_process_types):
        param_slice = parameters[i * 3:(i + 1) * 3]
        processes[i]._set_parameters(param_slice.tolist(), **kwargs)


def _compute_condition_cache_for_traj(
    ground_processes: List[_GroundCausalProcess],
    start_times_per_gp: List[List[int]], history: List[Set[GroundAtom]],
    num_time_steps: int
) -> Dict[_GroundCausalProcess, Dict[int, Dict[int, bool]]]:
    """Pre-computes which `condition_overall` holds at each time step for a
    single trajectory.

    Result maps gp -> start time -> interval end -> whether the overall
    condition held over the whole interval (st, interval end].
    """
    condition_cache: Dict[_GroundCausalProcess, Dict[int, Dict[int,
                                                               bool]]] = {}
    for gp_idx, gp in enumerate(ground_processes):
        # Only need to cache for processes that have an overall condition.
        if not gp.condition_overall:
            continue
        condition_cache[gp] = {}
        for st in start_times_per_gp[gp_idx]:
            condition_cache[gp][st] = {}
            # Dynamic programming: once the condition breaks at any step, it
            # stays broken for every longer interval.
            is_still_holding = True
            for t_interval in range(st + 1, num_time_steps):
                # Check only the new state at the end of the interval.
                if not gp.condition_overall.issubset(history[t_interval]):
                    is_still_holding = False
                # The result for the interval [st+1, t_interval+1] is stored.
                condition_cache[gp][st][t_interval] = is_still_holding
    return condition_cache
def _prepare_training_data_and_model_params(
    predicates: Set[Predicate], processes: Sequence[CausalProcess],
    trajectories: List[LowLevelTrajectory], check_condition_overall: bool
) -> Tuple[List[Dict[str, Any]], torch.nn.Parameter, int]:
    """Cache per-trajectory data, build global param layout for process and
    guide parameters, and initialize them.

    Returns (per_traj_data, flat parameter tensor, number of process params).
    The flat tensor holds the 3-per-process parameters first, then one
    traj_len-sized guide slice per (ground process, start time) instance.
    """
    atom_option_dataset = utils.create_ground_atom_option_dataset(
        trajectories, predicates)

    per_traj_data: List[Dict[str, Any]] = []
    # num_proc_params is just the number of process parameters (3 per type).
    num_proc_params = 3 * len(processes)
    # Running offset into the guide region of the flat parameter tensor.
    q_offset = 0

    for traj_id, traj in enumerate(atom_option_dataset):
        traj_len = len(traj.states)
        # NOTE(review): presumably the object set of the initial low-level
        # state; accesses the private _low_level_states attribute.
        objs = set(traj._low_level_states[0])

        _ground_processes, _ = process_task_plan_grounding(
            init_atoms=set(),
            objects=objs,
            cps=processes,
            allow_waits=True,
            compute_reachable_atoms=False,
        )
        ground_processes = [
            gp for gp in _ground_processes
            if isinstance(gp, _GroundCausalProcess)
        ]

        # Index: atom -> effect polarity (True=add, False=delete) -> set of
        # ground processes with that effect on the atom.
        atom_to_val_to_gps: Dict[GroundAtom, Dict[
            bool,
            Set[_GroundCausalProcess]]] = defaultdict(lambda: defaultdict(set))
        for gp in ground_processes:
            for a in gp.add_effects:
                atom_to_val_to_gps[a][True].add(gp)
            for a in gp.delete_effects:
                atom_to_val_to_gps[a][False].add(gp)

        # Times at which each ground process's cause fired in this
        # trajectory.
        start_times = [[
            t for t in range(traj_len)
            if gp.cause_triggered(traj.states[:t + 1], traj.actions[:t + 1])
        ] for gp in ground_processes]

        # Pre-compute the overall-condition cache for this trajectory.
        condition_cache: Dict[_GroundCausalProcess,
                              Dict[int, Dict[int, bool]]] = {}
        if check_condition_overall:
            condition_cache = _compute_condition_cache_for_traj(
                ground_processes, start_times, traj.states, traj_len)

        # Each (ground process, start time) instance owns a traj_len-long
        # slice [lo, hi) of the flat guide parameter vector.
        gp_qparam_id_map: Dict[Tuple[_GroundCausalProcess, int],
                               Tuple[int, int]] = {}
        for gp_idx, gp in enumerate(ground_processes):
            for s_i in start_times[gp_idx]:
                lo, hi = q_offset, q_offset + traj_len
                gp_qparam_id_map[(gp, s_i)] = (lo, hi)
                q_offset = hi

        # 1. Create sparse representation: [(state, start_time, end_time)]
        # — one entry per maximal run of identical abstract states.
        sparse_trajectory = []
        if len(traj.states) > 1:
            yt_prev = traj.states[0]
            start_t = 0
            for t in range(1, len(traj.states)):
                if traj.states[t] != yt_prev:
                    sparse_trajectory.append((yt_prev, start_t, t - 1))
                    yt_prev = traj.states[t]
                    start_t = t
            sparse_trajectory.append((yt_prev, start_t, len(traj.states) - 1))

        per_traj_data.append({
            "trajectory": traj,
            "sparse_trajectory": sparse_trajectory,
            "traj_len": traj_len,
            "ground_causal_processes": ground_processes,
            "start_times_per_gp": start_times,
            "atom_to_val_to_gps": atom_to_val_to_gps,
            "all_atoms": utils.all_possible_ground_atoms(
                traj._low_level_states[0], predicates),
            "gp_qparam_id_map": gp_qparam_id_map,
            "condition_cache": condition_cache
        })

    # Total parameters for processes and the guide ONLY (frame strength is a
    # separate tensor created by the caller).
    total_params_len = num_proc_params + q_offset
    model_params = torch.nn.Parameter(torch.randn(total_params_len) * 0.01)

    return per_traj_data, model_params, num_proc_params
+ """ + # Compute empirical delays for each process type + process_delays = compute_empirical_delays(trajectories, predicates, + processes) + + # Initialize process parameters with empirical estimates + with torch.no_grad(): + for i, process in enumerate(processes): + param_start_idx = i * 3 # 3 parameters per process + + if process.name in process_delays and len( + process_delays[process.name]) > 0: + delays = torch.tensor(process_delays[process.name], + dtype=torch.float32) + + # Compute mean and std + empirical_mean = delays.mean() + empirical_std = delays.std( + ) if len(delays) > 1 else torch.tensor(0.1) + empirical_std = torch.clamp(empirical_std, min=0.1) + + # Set parameters: [log_strength, log_mu, log_sigma] + model_params.data[ + param_start_idx] = 0.0 # log_strength = 0 (strength = 1) + model_params.data[param_start_idx + 1] = torch.log( + empirical_mean) # log_mu + model_params.data[param_start_idx + 2] = torch.log( + empirical_std) # log_sigma + + print(f"Empirically initialized process {process.name}:") + print(f" Mean delay: {empirical_mean:.2f}") + print(f" Std delay: {empirical_std:.2f}") + else: + # No observations - keep random initialization but log it + print( + f"Process {process.name}: No empirical data, keeping random initialization" + ) + + +def _create_guide_dict_for_trajectory( + td: Dict[str, Any], + guide_flat: Tensor, + traj_len: int, +) -> Dict[_GroundCausalProcess, Dict[int, Tensor]]: + """Helper to create the guide distribution dictionary for a single + trajectory.""" + guide_dict: Dict[_GroundCausalProcess, Dict[int, + Tensor]] = defaultdict(dict) + for (gp, s_i), (lo, hi) in td["gp_qparam_id_map"].items(): + # Create the causality mask to prevent effects from occurring at or before the cause + mask = torch.ones(traj_len, + dtype=torch.float32, + device=guide_flat.device) + mask[:s_i + 1] = 0 + + # Current behavior: softmax over learnable logits + raw = guide_flat[lo:hi] + probs = torch.softmax(raw + torch.log(mask + 1e-20), dim=0) 
+
+        guide_dict[gp][s_i] = probs
+    return guide_dict
+
+
+def _plot_training_curve(training_curve: Dict,
+                         image_dir: str = "images") -> None:
+    """Plot the training curve showing ELBO over iterations."""
+    import matplotlib.pyplot as plt
+
+    iterations = training_curve['iterations']
+    elbos = training_curve['elbos']
+    best_elbos = training_curve['best_elbos']
+    wall_time = training_curve['wall_time']
+
+    plt.figure(figsize=(18, 6))  # Adjusted figure size for the two plots
+
+    # Plot current ELBO vs Iteration
+    plt.subplot(1, 2, 1)
+    plt.plot(iterations, elbos, 'b-', alpha=0.7, label='Current ELBO')
+    plt.plot(iterations, best_elbos, 'r-', linewidth=2, label='Best ELBO')
+    plt.xlabel('Iteration')
+    plt.ylabel('ELBO')
+    plt.title('ELBO vs Iteration')
+    plt.legend()
+    plt.grid(True, alpha=0.3)
+
+    # Plot ELBO vs Wall Time
+    plt.subplot(1, 2, 2)
+    plt.plot(wall_time, elbos, 'b-', alpha=0.7, label='Current ELBO')
+    plt.plot(wall_time, best_elbos, 'r-', linewidth=2, label='Best ELBO')
+    plt.xlabel('Wall Time (s)')
+    plt.ylabel('ELBO')
+    plt.title('ELBO vs Wall Time')
+    plt.legend()
+    plt.grid(True, alpha=0.3)
+
+    plt.tight_layout()
+
+    # Save the plot
+    filepath = os.path.join(image_dir, "training_curve.png")
+    plt.savefig(filepath)
+    logging.info(f"Training curve saved to {filepath}")
+    plt.close()
diff --git a/predicators/approaches/pp_predicate_invention_approach.py b/predicators/approaches/pp_predicate_invention_approach.py
new file mode 100644
index 0000000000..207e80f280
--- /dev/null
+++ b/predicators/approaches/pp_predicate_invention_approach.py
@@ -0,0 +1,62 @@
+import logging
+from typing import Any, Dict, List, Optional, Sequence, Set
+
+from gym.spaces import Box
+
+from predicators.approaches.pp_process_learning_approach import \
+    ProcessLearningAndPlanningApproach
+from predicators.option_model import _OptionModelBase
+from predicators.settings import CFG
+from predicators.structs import Dataset, ParameterizedOption, Predicate, \
+    Task, Type
+
+
+class 
PredicateInventionProcessPlanningApproach( + ProcessLearningAndPlanningApproach): + """A bilevel planning approach that invent predicates.""" + + def __init__(self, + initial_predicates: Set[Predicate], + initial_options: Set[ParameterizedOption], + types: Set[Type], + action_space: Box, + train_tasks: List[Task], + task_planning_heuristic: str = "default", + max_skeletons_optimized: int = -1, + bilevel_plan_without_sim: Optional[bool] = None, + option_model: Optional[_OptionModelBase] = None): + self._learned_predicates: Set[Predicate] = set() + super().__init__(initial_predicates, + initial_options, + types, + action_space, + train_tasks, + task_planning_heuristic, + max_skeletons_optimized, + bilevel_plan_without_sim, + option_model=option_model) + + @classmethod + def get_name(cls) -> str: + return "predicate_invention_and_process_planning" + + def _get_current_predicates(self) -> Set[Predicate]: + """Get the current predicates.""" + return self._initial_predicates | self._learned_predicates + + def learn_from_offline_dataset(self, dataset: Dataset) -> None: + self._offline_dataset = dataset + # --- Invent Predicates --- + # Check the atomic trajectory + + # ----- Predicate Proposal ----- + + # ----- Predicate Selection ----- + + # --- Learn Processes --- + self._learn_processes(dataset.trajectories, + online_learning_cycle=None, + annotations=(dataset.annotations if + dataset.has_annotations else None)) + if CFG.learn_process_parameters: + self._learn_process_parameters(dataset.trajectories) diff --git a/predicators/approaches/pp_process_learning_approach.py b/predicators/approaches/pp_process_learning_approach.py new file mode 100644 index 0000000000..19708f5a44 --- /dev/null +++ b/predicators/approaches/pp_process_learning_approach.py @@ -0,0 +1,99 @@ +import logging +from typing import Any, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple + +import dill as pkl +from gym.spaces import Box + +from predicators import utils +from 
predicators.approaches.pp_param_learning_approach import \ + ParamLearningBilevelProcessPlanningApproach +from predicators.ground_truth_models import get_gt_processes +from predicators.nsrt_learning.process_learning_main import \ + learn_processes_from_data +from predicators.option_model import _OptionModelBase +from predicators.settings import CFG +from predicators.structs import CausalProcess, Dataset, ExogenousProcess, \ + GroundAtomTrajectory, LiftedAtom, LowLevelTrajectory, \ + ParameterizedOption, Predicate, Task, Type + + +class ProcessLearningAndPlanningApproach( + ParamLearningBilevelProcessPlanningApproach): + """A bilevel planning approach that learns processes.""" + + def __init__(self, + initial_predicates: Set[Predicate], + initial_options: Set[ParameterizedOption], + types: Set[Type], + action_space: Box, + train_tasks: List[Task], + task_planning_heuristic: str = "default", + max_skeletons_optimized: int = -1, + bilevel_plan_without_sim: Optional[bool] = None, + option_model: Optional[_OptionModelBase] = None): + super().__init__(initial_predicates, + initial_options, + types, + action_space, + train_tasks, + task_planning_heuristic, + max_skeletons_optimized, + bilevel_plan_without_sim, + option_model=option_model) + if CFG.only_learn_exogenous_processes: + self._processes = get_gt_processes(CFG.env, + self._initial_predicates, + self._initial_options, + only_endogenous=True) + else: + # Learn all + self._processes: Set[CausalProcess] = set() + self._proc_name_to_results: Dict[str, List[Tuple[float, + FrozenSet[LiftedAtom], + Tuple, + CausalProcess]]] = {} + + @classmethod + def get_name(cls) -> str: + return "process_learning_and_planning" + + def learn_from_offline_dataset(self, dataset: Dataset) -> None: + """Learn models from the offline datasets.""" + self._learn_processes(dataset.trajectories, + online_learning_cycle=None, + annotations=(dataset.annotations if + dataset.has_annotations else None)) + # Optional: learn process parameters + if 
CFG.learn_process_parameters: + self._learn_process_parameters(dataset.trajectories) + + def _learn_processes(self, + trajectories: List[LowLevelTrajectory], + online_learning_cycle: Optional[int], + annotations: Optional[List[Any]] = None) -> None: + """Learn processes from the offline datasets.""" + dataset_fname, _ = utils.create_dataset_filename_str( + saving_ground_atoms=True, + online_learning_cycle=online_learning_cycle) + ground_atom_dataset: Optional[List[GroundAtomTrajectory]] = None + if CFG.load_atoms: + ground_atom_dataset = utils.load_ground_atom_dataset( + dataset_fname, trajectories) + elif CFG.save_atoms: + ground_atom_dataset = utils.create_ground_atom_dataset( + trajectories, self._get_current_predicates()) + self._processes, self._proc_name_to_results = \ + learn_processes_from_data(trajectories, + self._train_tasks, + self._get_current_predicates(), + self._initial_options, + self._action_space, + ground_atom_dataset, + sampler_learner=CFG.sampler_learner, + annotations=annotations, + current_processes=self._get_current_processes(), + online_learning_cycle=online_learning_cycle,) + + save_path = utils.get_approach_save_path_str() + with open(f"{save_path}_{online_learning_cycle}.PROCes", "wb") as f: + pkl.dump(self._processes, f) diff --git a/predicators/approaches/process_planning_approach.py b/predicators/approaches/process_planning_approach.py new file mode 100644 index 0000000000..e86ecc85fc --- /dev/null +++ b/predicators/approaches/process_planning_approach.py @@ -0,0 +1,316 @@ +import abc +import logging +from typing import Any, Callable, Dict, FrozenSet, List, Optional, Set, Tuple + +from gym.spaces import Box + +from predicators import utils +from predicators.approaches import ApproachFailure, ApproachTimeout +from predicators.approaches.bilevel_planning_approach import \ + BilevelPlanningApproach +from predicators.option_model import _OptionModelBase +from predicators.planning import PlanningFailure, PlanningTimeout +from 
predicators.planning_with_processes import ProcessWorldModel, \ + process_task_plan_grounding, run_task_plan_with_processes_once, \ + sesame_plan_with_processes +from predicators.settings import CFG +from predicators.structs import AbstractProcessPolicy, Action, CausalProcess, \ + EndogenousProcess, GroundAtom, Metrics, Object, ParameterizedOption, \ + Predicate, State, Task, Type, _GroundEndogenousProcess, _Option + + +class BilevelProcessPlanningApproach(BilevelPlanningApproach): + """A bilevel planning approach that doesn't use the nsrt world model but + uses the process world model.""" + + def __init__(self, + initial_predicates: Set[Predicate], + initial_options: Set[ParameterizedOption], + types: Set[Type], + action_space: Box, + train_tasks: List[Task], + task_planning_heuristic: str = "default", + max_skeletons_optimized: int = -1, + bilevel_plan_without_sim: Optional[bool] = None, + option_model: Optional[_OptionModelBase] = None) -> None: + super().__init__(initial_predicates, + initial_options, + types, + action_space, + train_tasks, + task_planning_heuristic, + max_skeletons_optimized, + bilevel_plan_without_sim, + option_model=option_model) + self._last_option_plan: List[_Option] = [] # used if plan WITH sim + + # Conditionally load VLM components if an abstract policy is used. + self._vlm = None + self.base_prompt = "" + if CFG.process_planning_use_abstract_policy: + # Set up the vlm and base prompt. 
+ self._vlm = utils.create_llm_by_name(CFG.llm_model_name) + # Note: requires a new CFG setting, e.g., + # process_planning_vlm_prompt_suffix = "_process" + prompt_suffix = CFG.process_planning_vlm_prompt_suffix + filepath_to_vlm_prompt = utils.get_path_to_predicators_root() + \ + "/predicators/approaches/vlm_planning_prompts/no_few_shot_hla_plan" + \ + f"{prompt_suffix}.txt" + with open(filepath_to_vlm_prompt, "r", encoding="utf-8") as f: + self.base_prompt = f.read() + + @abc.abstractmethod + def _get_current_processes(self) -> Set[CausalProcess]: + """Get the current set of Processes.""" + raise NotImplementedError("Override me!") + + def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: + self._num_calls += 1 + # ensure random over successive + seed = self._seed + self._num_calls + processes = self._get_current_processes() + preds = self._get_current_predicates() + + abstract_policy = None + if CFG.process_planning_use_abstract_policy: + abstract_policy = self._build_abstract_policy(task) + + # Run task planning only and then greedily sample + # and execute in the policy. + if self._plan_without_sim: + process_plan, atoms_seq, metrics =\ + self._run_task_plan_with_processes( + task, + processes, + preds, + timeout, + seed, + abstract_policy=abstract_policy, + max_policy_guided_rollout=CFG. + process_planning_max_policy_guided_rollout) + self._last_process_plan = process_plan + self._last_atoms_seq = atoms_seq + policy = utils.process_plan_to_greedy_policy( + process_plan, + task.goal, + self._rng, + abstract_function=lambda s: utils.abstract(s, preds)) + logging.debug("Current Task Plan:") + for process in process_plan: + logging.debug(process.name) + else: + option_plan, process_plan, metrics = \ + self._run_sesame_plan_with_processes( + task, + processes, + preds, + timeout, + seed, + abstract_policy=abstract_policy, + max_policy_guided_rollout=CFG. 
+ process_planning_max_policy_guided_rollout) + self._last_option_plan = option_plan + self._last_process_plan = process_plan + policy = utils.option_plan_to_policy(option_plan) + + self._save_metrics(metrics, processes, preds) + + def _policy(s: State) -> Action: + try: + return policy(s) + except utils.OptionExecutionFailure as e: + raise ApproachFailure(e.args[0], e.info) + + return _policy + + def _run_task_plan_with_processes( + self, task: Task, processes: Set[CausalProcess], preds: Set[Predicate], + timeout: int, seed: int, **kwargs: Any + ) -> Tuple[List[_GroundEndogenousProcess], List[Set[GroundAtom]], Metrics]: + try: + plan, atoms_seq, metrics = run_task_plan_with_processes_once( + task, + processes, + preds, + self._types, + timeout, + seed, + task_planning_heuristic=self._task_planning_heuristic, + max_horizon=float(CFG.horizon), + **kwargs) + except PlanningFailure as e: + raise ApproachFailure(e.args[0], e.info) + except PlanningTimeout as e: + raise ApproachTimeout(e.args[0], e.info) + + return plan, atoms_seq, metrics + + def _run_sesame_plan_with_processes( + self, task: Task, processes: Set[CausalProcess], preds: Set[Predicate], + timeout: float, seed: int, **kwargs: Any + ) -> Tuple[List[_Option], List[_GroundEndogenousProcess], Metrics]: + """Run full bilevel planning with processes. + + Subclasses may override, e.g. to insert an abstract policy. 
+ """ + try: + option_plan, process_plan, metrics = sesame_plan_with_processes( + task, + self._option_model, + processes, + preds, + timeout, + seed, + max_skeletons_optimized=self._max_skeletons_optimized, + max_horizon=CFG.horizon, + **kwargs) + except PlanningFailure as e: + raise ApproachFailure(e.args[0], e.info) + except PlanningTimeout as e: + raise ApproachTimeout(e.args[0], e.info) + + return option_plan, process_plan, metrics + + def _save_metrics( # type: ignore[override] + self, metrics: Metrics, processes: Set[CausalProcess], + predicates: Set[Predicate]) -> None: + for metric in [ + "num_samples", "num_skeletons_optimized", + "num_failures_discovered", "num_nodes_expanded", + "num_nodes_created", "plan_length", "refinement_time" + ]: + self._metrics[f"total_{metric}"] += metrics[metric] + self._metrics["total_num_processes"] += len(processes) + self._metrics["total_num_preds"] += len(predicates) + for metric in [ + "num_samples", + "num_skeletons_optimized", + ]: + self._metrics[f"min_{metric}"] = min( + metrics[metric], self._metrics[f"min_{metric}"]) + self._metrics[f"max_{metric}"] = max( + metrics[metric], self._metrics[f"max_{metric}"]) + + def _build_abstract_policy(self, task: Task) -> AbstractProcessPolicy: + """Use a VLM to generate a plan and build a policy from it.""" + # 1. Set up for VLM query. + init_atoms = utils.abstract(task.init, self._get_current_predicates()) + objects = set(task.init) + all_processes = self._get_current_processes() + endogenous_processes = sorted( + [p for p in all_processes if isinstance(p, EndogenousProcess)]) + vlm_process_plan = self._get_vlm_plan(task, init_atoms, objects, + endogenous_processes) + + # 3. Build the partial policy dictionary by simulating the plan. 
+ partial_policy_dict: Dict[FrozenSet[GroundAtom], + _GroundEndogenousProcess] = {} + current_atoms = init_atoms.copy() + all_ground_processes, _ = process_task_plan_grounding(init_atoms, + objects, + all_processes, + allow_waits=True) + all_predicates = utils.add_in_auxiliary_predicates( + self._get_current_predicates()) + derived_predicates = utils.get_derived_predicates(all_predicates) + + # Build indexes for efficient world model execution (do this once outside the loop) + from collections import defaultdict + + from predicators.planning_with_processes import \ + _build_exogenous_process_index + + precondition_to_exogenous_processes = None + if CFG.build_exogenous_process_index_for_planning: + precondition_to_exogenous_processes = _build_exogenous_process_index( + all_ground_processes) + + # Pre-compute dependencies for incremental derived predicates + dep_to_derived_preds = defaultdict(list) + for der_pred in derived_predicates: + for aux_pred in (der_pred.auxiliary_predicates or set()): + dep_to_derived_preds[aux_pred].append(der_pred) + + for ground_process in vlm_process_plan: + if not ground_process.condition_at_start.issubset(current_atoms): + logging.warning(f"VLM plan deviates, precondition not met for " + f"{ground_process.name_and_objects_str()}") + break + + frozen_atoms = frozenset(current_atoms) + partial_policy_dict[frozen_atoms] = ground_process + + # Simulate the step to get the next state with proper indexing. + world_model = ProcessWorldModel( + ground_processes=all_ground_processes, + state=current_atoms.copy(), + state_history=[], + action_history=[], + scheduled_events={}, + t=0, + derived_predicates=derived_predicates, + objects=objects, + precondition_to_exogenous_processes= + precondition_to_exogenous_processes, + dep_to_derived_preds=dep_to_derived_preds) + + world_model.big_step(ground_process) + current_atoms = world_model.state + + # 4. Create and return the abstract policy. 
+ abstract_policy = lambda atoms, _1, _2: partial_policy_dict.get( + frozenset(atoms), None) + + return abstract_policy + + def _get_vlm_plan( + self, task: Task, init_atoms: Set[GroundAtom], objects: Set[Object], + endogenous_processes: List[EndogenousProcess] + ) -> List[_GroundEndogenousProcess]: + + # 2. Query VLM for a process plan. + processes_str = "\n".join(str(p) for p in endogenous_processes) + objects_list = sorted(list(objects)) + objects_str = "\n".join(str(obj) for obj in objects_list) + goal_str = "\n".join(str(g) for g in sorted(task.goal)) + type_hierarchy_str = utils.create_pddl_types_str(self._types) + init_state_str = "\n".join(map(str, sorted(init_atoms))) + + prompt = self.base_prompt.format(processes=processes_str, + typed_objects=objects_str, + type_hierarchy=type_hierarchy_str, + init_state_str=init_state_str, + goal_str=goal_str) + + try: + assert self._vlm is not None + vlm_output = self._vlm.sample_completions( + prompt, + imgs=None, # No images for process planning. + temperature=CFG.vlm_temperature, + seed=CFG.seed, + num_completions=1) + plan_prediction_txt = vlm_output[0] + start_index = plan_prediction_txt.index("Plan:\n") + len("Plan:\n") + parsable_plan_prediction = plan_prediction_txt[start_index:] + except (ValueError, IndexError, AssertionError) as e: + logging.warning(f"VLM output parsing failed, returning trivial " + f"policy. Reason: {e}") + # Return an empty plan on parsing failure + vlm_process_plan = [] + + # Note: this requires a new utility function, + # `parse_model_output_into_process_plan`, which should be analogous + # to `parse_model_output_into_option_plan`. 
+ try: + parsed_process_plan = utils.parse_model_output_into_process_plan( # type: ignore[attr-defined] + parsable_plan_prediction, objects_list, self._types, + endogenous_processes) + vlm_process_plan = [ + p.ground(objs) for p, objs in parsed_process_plan + ] + except Exception as e: + logging.warning(f"Failed to parse/ground VLM process plan: {e}") + vlm_process_plan = [] + + return vlm_process_plan diff --git a/predicators/nsrt_learning/process_learning/__init__.py b/predicators/nsrt_learning/process_learning/__init__.py new file mode 100644 index 0000000000..5d1b9bb2f7 --- /dev/null +++ b/predicators/nsrt_learning/process_learning/__init__.py @@ -0,0 +1,28 @@ +from typing import Any, List, Optional, Set + +from predicators import utils +from predicators.nsrt_learning.process_learning.base_process_learner import \ + BaseProcessLearner +from predicators.settings import CFG +from predicators.structs import PAPAD, CausalProcess, ExogenousProcess, \ + LowLevelTrajectory, Predicate, Segment, Task + +__all__ = ["BaseProcessLearner"] + +# Import submodules to register them. +utils.import_submodules(__path__, __name__) + + +def learn_exogenous_processes(trajectories: List[LowLevelTrajectory], + train_tasks: List[Task], + predicates: Set[Predicate], + segmented_trajs: List[List[Segment]], + verify_harmlessness: bool, + annotations: Optional[List[Any]], + verbose: bool = True) -> List[ExogenousProcess]: + """Learn exogenous processes on the given data segments.""" + for cls in utils.get_all_subclasses(BaseProcessLearner): + if not cls.__abstractmethods__ and \ + cls.get_name() == CFG.exogenous_process_learner: + learner = cls(...) 
+ raise NotImplementedError diff --git a/predicators/nsrt_learning/process_learning/base_process_learner.py b/predicators/nsrt_learning/process_learning/base_process_learner.py new file mode 100644 index 0000000000..b8a562ca04 --- /dev/null +++ b/predicators/nsrt_learning/process_learning/base_process_learner.py @@ -0,0 +1,6 @@ +"""Base class for process learning algorithms.""" +import abc + + +class BaseProcessLearner(abc.ABC): + """Base class definition.""" \ No newline at end of file diff --git a/predicators/nsrt_learning/process_learning_main.py b/predicators/nsrt_learning/process_learning_main.py new file mode 100644 index 0000000000..955f12155a --- /dev/null +++ b/predicators/nsrt_learning/process_learning_main.py @@ -0,0 +1,270 @@ +import logging +from pprint import pformat +from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple, cast + +from gym.spaces import Box + +from predicators import utils +from predicators.nsrt_learning.nsrt_learning_main import _learn_pnad_options, \ + _learn_pnad_samplers +from predicators.nsrt_learning.segmentation import segment_trajectory +from predicators.nsrt_learning.strips_learning import learn_strips_operators +from predicators.nsrt_learning.strips_learning.clustering_learner import \ + ClusterAndSearchProcessLearner +from predicators.settings import CFG +from predicators.structs import PNAD, CausalProcess, DerivedPredicate, \ + DummyOption, EndogenousProcess, ExogenousProcess, GroundAtom, \ + GroundAtomTrajectory, LiftedAtom, LowLevelTrajectory, \ + ParameterizedOption, Predicate, Segment, Task + + +def learn_processes_from_data( + trajectories: List[LowLevelTrajectory], + train_tasks: List[Task], + predicates: Set[Predicate], + known_options: Optional[Set[ParameterizedOption]] = None, + action_space: Optional[Box] = None, + ground_atom_dataset: Optional[List[GroundAtomTrajectory]] = None, + sampler_learner: Optional[str] = None, + annotations: Optional[List[Any]] = None, + current_processes: 
Optional[Set[CausalProcess]] = None,
+    log_all_processes: bool = True,
+    online_learning_cycle: Optional[int] = None,
+) -> Tuple[Set[CausalProcess], Dict[str, List]]:
+    """Learn CausalProcesses from the given dataset of low-level transitions,
+    using the given set of predicates."""
+    logging.info(f"\nLearning CausalProcesses on {len(trajectories)} "
                 "trajectories...")
+    # remember to reset at the end
+    initial_segmentation_method = CFG.segmenter
+
+    # We will probably learn endogenous and exogenous processes separately.
+    if CFG.only_learn_exogenous_processes:
+        endogenous_processes = [
+            p for p in (current_processes or [])
+            if isinstance(p, EndogenousProcess)
+        ]
+    else:
+        assert sampler_learner is not None, \
+            "Sampler learner must be specified for action model learning."
+        # --- Learn the endogenous processes ---
+        CFG.segmenter = "option_changes"
+        # STEP 1: Segment the trajectory by options. (don't currently consider
+        # segmenting by predicates).
+        # Segment each trajectory in the dataset based on changes in
+        # either predicates or options. If we are doing option learning,
+        # then the data will not contain options, so this segmenting
+        # procedure only uses the predicates.
+        # If we know the option segmentations this is pretty similar to
+        # learning NSRTs.
+        if ground_atom_dataset is None:
+            segmented_trajs = [
+                segment_trajectory(traj, predicates) for traj in trajectories
+            ]
+        else:
+            segmented_trajs = [
+                segment_trajectory(traj, predicates, ground_atom_dataset[i][1])
+                for i, traj in enumerate(trajectories)
+            ]
+
+        # STEP 2: Learn STRIPS operators on the given data segments as for NSRTs.
+        pnads = learn_strips_operators(
+            trajectories,
+            train_tasks,
+            predicates,
+            segmented_trajs,
+            verify_harmlessness=
+            False,  # these processes are in principle 'harmful'
+            # because they should leave some atoms to be explained by exogenous
+            # processes.
+ verbose=(CFG.option_learner != "no_learning"), + annotations=annotations) + + # STEP 3: Learn options and update PNADs + if CFG.strips_learner != "oracle" or CFG.sampler_learner != "oracle" or \ + CFG.option_learner != "no_learning": + assert action_space is not None, \ + "Action space must be provided for option learning." + assert known_options is not None, \ + "Known options must be provided for option learning." + # Updates the endo_papads in-place. + _learn_pnad_options(pnads, known_options, action_space) + + # STEP 4 (currently skipped): Learn samplers and update PNADs + _learn_pnad_samplers(pnads, sampler_learner) + + # STEP 5: Convert PNADs to endogenous processes. (Maybe also make rough + # parameter estimates.) + endogenous_processes = [ + pnad.make_endogenous_process() for pnad in pnads + ] + # for proc in endogenous_processes: + # logging.debug(f"{proc}") + # logging.debug("") + + # --- Learn the exogenous processes. --- + # STEP 1: Segment the trajectory by atom_changes, and filter out the ones + # that are explained by the endogenous processes. + CFG.segmenter = "atom_changes" + CFG.strips_learner = CFG.exogenous_process_learner + + segmented_trajs = [ + segment_trajectory(traj, predicates, verbose=False) + for traj in trajectories + ] + # Filter out segments explained by endogenous processes. + remaining_segmented_trajs = filter_explained_segment( + segmented_trajs, + cast(List[CausalProcess], endogenous_processes), + remove_options=False) + + # STEP 2: Learn the exogenous processes based on unexplained processes. + # This is different from STRIPS/endogenous processes, where these + # don't have options and samplers. 
+    num_unexplained_segments = sum(
+        len(segments) for segments in remaining_segmented_trajs)
+    proc_name_to_results: Dict[str, List[
+        Tuple[float, FrozenSet[LiftedAtom], Tuple, ExogenousProcess]]] = {}
+    if num_unexplained_segments == 0:
+        new_exogenous_processes = []
+    else:
+        process_learner = ClusterAndSearchProcessLearner(
+            trajectories,
+            train_tasks,
+            predicates,
+            remaining_segmented_trajs,
+            verify_harmlessness=False,
+            verbose=(CFG.option_learner != "no_learning"),
+            annotations=annotations,
+            endogenous_processes=set(endogenous_processes),
+            online_learning_cycle=online_learning_cycle,
+        )
+        exogenous_processes_pnad = process_learner.learn()
+        new_exogenous_processes = [
+            pnad.make_exogenous_process() for pnad in exogenous_processes_pnad
+        ]
+        # Get the other conditions' scores through class attributes.
+        proc_name_to_results = process_learner.proc_name_to_results
+    logging.info(
+        f"Learned {len(new_exogenous_processes)} exogenous processes:\n"
+        f"{pformat(new_exogenous_processes)}")
+    if CFG.pause_after_process_learning_for_inspection:
+        input("Press Enter to continue...")  # pause for user inspection
+
+    # STEP 3: Make, log, and return the endogenous and exogenous processes.
+ processes = endogenous_processes + new_exogenous_processes + if log_all_processes: + logging.info(f"\nLearned CausalProcesses:\n{pformat(processes)}") + + CFG.segmenter = initial_segmentation_method + return set(processes), proc_name_to_results + + +def is_endogenous_process_list(processes: List) -> bool: + """Check if all elements in the list are EndogenousProcess.""" + return all(isinstance(p, EndogenousProcess) for p in processes) + + +def is_exogenous_process_list(processes: List) -> bool: + """Check if all elements in the list are ExogenousProcess.""" + return all(isinstance(p, ExogenousProcess) for p in processes) + + +def filter_explained_segment( + segmented_trajs: List[List[Segment]], + processes: List[CausalProcess], + remove_options: bool = False, + log_remaining_trajs: bool = False, +) -> List[List[Segment]]: + """Filter out segments that are explained by the given PNADs.""" + num_segments = sum(len(traj) for traj in segmented_trajs) + if is_endogenous_process_list(processes): + processes_type_str = "endogenous" + elif is_exogenous_process_list(processes): + processes_type_str = "exogenous" + else: + raise NotImplementedError("Currently don't support " + "mixed process types.") + logging.debug(f"\nNum of segments before filtering the ones explained by " + f"{processes_type_str} procs: {num_segments}, from " + f"{len(segmented_trajs)} trajs.") + remaining_trajs = [] + for traj in segmented_trajs: + objects = set(traj[0].trajectory.states[0]) + remaining_segments = [] + for segment in traj: + # TODO: is this kind of like "cover"? + if processes_type_str == "endogenous": + relevant_procs = [ + p for p in processes + if segment.get_option().parent == cast( + EndogenousProcess, p).option + ] + else: + # all exogenous; mixed cases all handle at the top. 
+ relevant_procs = processes + add_atoms = { + a + for a in segment.add_effects + if not isinstance(a.predicate, DerivedPredicate) + } + delete_atoms = { + a + for a in segment.delete_effects + if not isinstance(a.predicate, DerivedPredicate) + } + # if not explained by any; consider explained if the atom change is + # a subset of the add_effects and delete_effects of any compatible + # ground process. + not_explained_by_any = True + for proc in relevant_procs: + if processes_type_str == "endogenous": + endo_proc = cast(EndogenousProcess, proc) + option_vars = endo_proc.option_vars + ignore_effects = endo_proc.ignore_effects + else: + option_vars = [] + ignore_effects = set() + var_to_obj = { + v: o + for v, o in zip(option_vars, + segment.get_option().objects) + } + for g_proc in utils.all_ground_operators_given_partial( + proc, objects, var_to_obj): # type: ignore[arg-type] + _add_atoms = add_atoms.copy() + _delete_atoms = delete_atoms.copy() + if ignore_effects: + _add_atoms = { + a + for a in add_atoms + if a.predicate not in ignore_effects + } + _delete_atoms = { + a + for a in delete_atoms + if a.predicate not in ignore_effects + } + if _add_atoms.issubset(g_proc.add_effects) and \ + _delete_atoms.issubset(g_proc.delete_effects): + not_explained_by_any = False + break + if not_explained_by_any: + if remove_options: + segment.set_option(DummyOption) + remaining_segments.append(segment) + remaining_trajs.append(remaining_segments) + + num_remaining_segments = sum(len(traj) for traj in remaining_trajs) + logging.debug(f"Num of leftover segments: {num_remaining_segments}") + if log_remaining_trajs: + for j, seg_traj in enumerate(remaining_trajs): + logging.debug(f"Trajectory {j}:") + for i, segment in enumerate(seg_traj): + logging.debug(f"Segment {i}. 
Init atoms: " + f"{sorted(segment.init_atoms)}") + logging.debug(f"Add effects: {sorted(segment.add_effects)}") + logging.debug( + f"Delete effects: {sorted(segment.delete_effects)}") + logging.debug(f"Option: {segment.get_option()}\n") + return remaining_trajs diff --git a/predicators/nsrt_learning/segmentation.py b/predicators/nsrt_learning/segmentation.py index c33e7b5e62..f6a3326b98 100644 --- a/predicators/nsrt_learning/segmentation.py +++ b/predicators/nsrt_learning/segmentation.py @@ -1,5 +1,6 @@ """Methods for segmenting low-level trajectories into segments.""" +import logging from typing import Callable, List, Optional, Set from predicators import utils @@ -10,15 +11,18 @@ Predicate, Segment, State -def segment_trajectory( - ll_traj: LowLevelTrajectory, - predicates: Set[Predicate], - atom_seq: Optional[List[Set[GroundAtom]]] = None) -> List[Segment]: +def segment_trajectory(ll_traj: LowLevelTrajectory, + predicates: Set[Predicate], + atom_seq: Optional[List[Set[GroundAtom]]] = None, + verbose: bool = False) -> List[Segment]: """Segment a ground atom trajectory.""" # Start with the segmenters that don't need atom_seq. Still pass it in # because if it was provided, it can be used to avoid calling abstract. 
if CFG.segmenter == "option_changes": - return _segment_with_option_changes(ll_traj, predicates, atom_seq) + return _segment_with_option_changes(ll_traj, + predicates, + atom_seq, + verbose=verbose) if CFG.segmenter == "every_step": return _segment_with_switch_function(ll_traj, predicates, atom_seq, lambda _: True) @@ -26,7 +30,10 @@ def segment_trajectory( if atom_seq is None: atom_seq = [utils.abstract(s, predicates) for s in ll_traj.states] if CFG.segmenter == "atom_changes": - return _segment_with_atom_changes(ll_traj, predicates, atom_seq) + return _segment_with_atom_changes(ll_traj, + predicates, + atom_seq, + verbose=verbose) if CFG.segmenter == "oracle": return _segment_with_oracle(ll_traj, predicates, atom_seq) if CFG.segmenter == "contacts": @@ -35,15 +42,25 @@ def segment_trajectory( def _segment_with_atom_changes( - ll_traj: LowLevelTrajectory, predicates: Set[Predicate], - atom_seq: List[Set[GroundAtom]]) -> List[Segment]: + ll_traj: LowLevelTrajectory, + predicates: Set[Predicate], + atom_seq: List[Set[GroundAtom]], + count_last_unchanged_steps_as_segment: bool = True, + verbose: bool = False) -> List[Segment]: """Segment a trajectory whenever the abstract state changes.""" + if verbose: + logging.debug("Segmenting by atom changes.") def _switch_fn(t: int) -> bool: - return atom_seq[t] != atom_seq[t + 1] + return atom_seq[t] != atom_seq[t + 1] or ( + count_last_unchanged_steps_as_segment + and t == len(ll_traj.actions) - 1) - return _segment_with_switch_function(ll_traj, predicates, atom_seq, - _switch_fn) + return _segment_with_switch_function(ll_traj, + predicates, + atom_seq, + _switch_fn, + verbose=verbose) def _segment_with_contact_changes( @@ -93,10 +110,13 @@ def _switch_fn(t: int) -> bool: _switch_fn) -def _segment_with_option_changes( - ll_traj: LowLevelTrajectory, predicates: Set[Predicate], - atom_seq: Optional[List[Set[GroundAtom]]]) -> List[Segment]: +def _segment_with_option_changes(ll_traj: LowLevelTrajectory, + predicates: 
Set[Predicate], + atom_seq: Optional[List[Set[GroundAtom]]], + verbose: bool = False) -> List[Segment]: """Segment a trajectory whenever the (assumed known) option changes.""" + if verbose: + logging.debug("Segmenting by option changes.") def _switch_fn(t: int) -> bool: # Segment by checking whether the option changes on the next step. @@ -115,11 +135,15 @@ def _switch_fn(t: int) -> bool: option_duration = t - backward_t + 1 if option_duration >= CFG.max_num_steps_option_rollout: return True - return option_t.terminal(ll_traj.states[t + 1]) + return option_t.terminal(ll_traj.states[t + 1]) or \ + option_t.name.lower() == "wait" return option_t is not ll_traj.actions[t + 1].get_option() - return _segment_with_switch_function(ll_traj, predicates, atom_seq, - _switch_fn) + return _segment_with_switch_function(ll_traj, + predicates, + atom_seq, + _switch_fn, + verbose=verbose) def _segment_with_oracle(ll_traj: LowLevelTrajectory, @@ -147,7 +171,7 @@ def _segment_with_oracle(ll_traj: LowLevelTrajectory, } atoms = atom_seq[0] all_expected_next_atoms = [ - utils.apply_operator(n, atoms) + utils.apply_operator(n, atoms) # type: ignore[type-var] for n in utils.get_applicable_operators(ground_nsrts, atoms) ] @@ -163,7 +187,8 @@ def _switch_fn(t: int) -> bool: applicable_nsrts = utils.get_applicable_operators( ground_nsrts, next_atoms) all_expected_next_atoms = [ - utils.apply_operator(n, next_atoms) for n in applicable_nsrts + utils.apply_operator(n, next_atoms) # type: ignore[type-var] + for n in applicable_nsrts ] return True # Not yet time to segment. 
@@ -173,10 +198,11 @@ def _switch_fn(t: int) -> bool: _switch_fn) -def _segment_with_switch_function( - ll_traj: LowLevelTrajectory, predicates: Set[Predicate], - atom_seq: Optional[List[Set[GroundAtom]]], - switch_fn: Callable[[int], bool]) -> List[Segment]: +def _segment_with_switch_function(ll_traj: LowLevelTrajectory, + predicates: Set[Predicate], + atom_seq: Optional[List[Set[GroundAtom]]], + switch_fn: Callable[[int], bool], + verbose: bool = False) -> List[Segment]: """Helper for other segmentation methods. The switch_fn takes in a timestep and returns True if the trajectory @@ -196,6 +222,10 @@ def _segment_with_switch_function( current_segment_states.append(ll_traj.states[t]) current_segment_actions.append(ll_traj.actions[t]) if switch_fn(t): + if verbose: + logging.debug( + f"Segmenting at {t}, executing {ll_traj.actions[t].get_option().name}" + ) # Include the final state as the end of this segment. current_segment_states.append(ll_traj.states[t + 1]) current_segment_traj = LowLevelTrajectory(current_segment_states, @@ -206,6 +236,18 @@ def _segment_with_switch_function( st1 = ll_traj.states[t + 1] current_segment_final_atoms = utils.abstract(st1, predicates) if ll_traj.actions[t].has_option(): + + if len(ll_traj.states) > t + 1: + st = ll_traj.states[t] + delete_atoms = utils.abstract( + st, predicates) - current_segment_final_atoms + add_atoms = current_segment_final_atoms - utils.abstract( + st, predicates) + if verbose: + logging.debug( + f"State change: add {add_atoms}, delete {delete_atoms}" + ) + segment = Segment(current_segment_traj, current_segment_init_atoms, current_segment_final_atoms, diff --git a/predicators/nsrt_learning/strips_learning/__init__.py b/predicators/nsrt_learning/strips_learning/__init__.py index 989db5c41b..b4c1f9c5d3 100644 --- a/predicators/nsrt_learning/strips_learning/__init__.py +++ b/predicators/nsrt_learning/strips_learning/__init__.py @@ -6,8 +6,8 @@ from predicators.nsrt_learning.strips_learning.base_strips_learner 
import \ BaseSTRIPSLearner from predicators.settings import CFG -from predicators.structs import PNAD, LowLevelTrajectory, Predicate, Segment, \ - Task +from predicators.structs import PNAD, EndogenousProcess, LowLevelTrajectory, \ + Predicate, Segment, Task __all__ = ["BaseSTRIPSLearner"] @@ -21,7 +21,8 @@ def learn_strips_operators(trajectories: List[LowLevelTrajectory], segmented_trajs: List[List[Segment]], verify_harmlessness: bool, annotations: Optional[List[Any]], - verbose: bool = True) -> List[PNAD]: + verbose: bool = True, + **kwargs: Any) -> List[PNAD]: """Learn strips operators on the given data segments. Return a list of PNADs with op (STRIPSOperator), datastore, and @@ -32,7 +33,7 @@ def learn_strips_operators(trajectories: List[LowLevelTrajectory], cls.get_name() == CFG.strips_learner: learner = cls(trajectories, train_tasks, predicates, segmented_trajs, verify_harmlessness, annotations, - verbose) + verbose, **kwargs) break else: raise ValueError(f"Unrecognized STRIPS learner: {CFG.strips_learner}") diff --git a/predicators/nsrt_learning/strips_learning/base_strips_learner.py b/predicators/nsrt_learning/strips_learning/base_strips_learner.py index 5d3aa998ac..d82eafbd12 100644 --- a/predicators/nsrt_learning/strips_learning/base_strips_learner.py +++ b/predicators/nsrt_learning/strips_learning/base_strips_learner.py @@ -23,7 +23,8 @@ def __init__(self, segmented_trajs: List[List[Segment]], verify_harmlessness: bool, annotations: Optional[List[Any]], - verbose: bool = True) -> None: + verbose: bool = True, + **kwargs: Any) -> None: self._trajectories = trajectories self._train_tasks = train_tasks self._predicates = predicates @@ -250,7 +251,10 @@ def _find_best_matching_pnad_and_sub( if not check_only_preconditions: # If the atoms resulting from apply_operator() don't # all hold in the segment's final atoms, skip. 
- if not next_atoms.issubset(segment.final_atoms): + # Note: One might want to turn this off, e.g., with LLM + # learner, because it might not account for all the changes. + if not next_atoms.issubset(segment.final_atoms) and \ + CFG.find_best_matching_pnad_skip_if_effect_not_subset: continue # If the segment has a non-None necessary_add_effects, # and the ground operator's add effects don't fit this, diff --git a/predicators/nsrt_learning/strips_learning/clustering_learner.py b/predicators/nsrt_learning/strips_learning/clustering_learner.py index 7cbb9c6b6c..dd8d49583f 100644 --- a/predicators/nsrt_learning/strips_learning/clustering_learner.py +++ b/predicators/nsrt_learning/strips_learning/clustering_learner.py @@ -1,16 +1,90 @@ """Algorithms for STRIPS learning that rely on clustering to obtain effects.""" - import abc +import bisect +import copy import functools +import itertools import logging +import os +import re +import sys +import time from collections import defaultdict -from typing import Dict, FrozenSet, Iterator, List, Set, Tuple, cast +from pprint import pformat +from typing import Any, Dict, FrozenSet, Iterator, List, Optional, Set, \ + Tuple, cast + +import multiprocess as mp +import psutil +import wandb +from pathos.multiprocessing import ProcessingPool as Pool from predicators import utils +from predicators.nsrt_learning.segmentation import segment_trajectory from predicators.nsrt_learning.strips_learning import BaseSTRIPSLearner +from predicators.planning import PlanningFailure, PlanningTimeout +from predicators.planning_with_processes import \ + task_plan_from_task as task_plan_with_processes from predicators.settings import CFG -from predicators.structs import PNAD, Datastore, DummyOption, LiftedAtom, \ - ParameterizedOption, Predicate, STRIPSOperator, VarToObjSub +from predicators.structs import PNAD, CausalProcess, Datastore, \ + DerivedPredicate, DummyOption, EndogenousProcess, ExogenousProcess, \ + GroundAtom, LiftedAtom, Object, 
ParameterizedOption, Predicate, Segment, \ + STRIPSOperator, Variable, VarToObjSub, _TypedEntity + +if sys.platform == "darwin": + # Set this when using macOS, to avoid issues with forked processes. + mp.set_start_method("spawn", force=True) + + +def _flat_pnad_scoring_worker( + args: Tuple[int, int, ExogenousProcess, Set[LiftedAtom], List[Any], + Set[Predicate], int, int, float, Optional[str], Optional[str], + int] +) -> Tuple[int, int, float, Set[LiftedAtom], Tuple[float, ...], + ExogenousProcess]: + """Utility for flat multiprocessing: evaluates one condition candidate for + one PNAD under the data-likelihood scoring regime. + + Returns (pnad_idx, condition_idx, cost, condition_candidate, + scores_tuple, process). + """ + (pnad_idx, condition_idx, base_process, condition_candidate, trajectories, + predicates, seed, num_it, complexity_weight, load_dir, save_dir, + early_stopping_patience) = args + + # Set the conditions on the process object. + base_process.condition_at_start = condition_candidate + base_process.condition_overall = condition_candidate + + # Calculate complexity penalty. + complexity_penalty = complexity_weight * len(condition_candidate) + + # Local import avoids pickling issues with bound methods. + from predicators.approaches.pp_param_learning_approach import \ + learn_process_parameters + + # Perform the expensive part: learning and scoring. + process, scores = learn_process_parameters( + trajectories, + predicates, + [base_process], # The list now contains just the one process to score. + use_lbfgs=False, + plot_training_curve=False, + lbfgs_max_iter=num_it, + adam_num_steps=num_it, + seed=seed, + display_progress=False, + early_stopping_patience=early_stopping_patience, + batch_size=CFG.process_param_learning_batch_size, + use_empirical=CFG.process_learning_use_empirical, + ) + + # Cost is negative log-likelihood plus penalty. 
+ cost = -scores[0] + complexity_penalty + + # Return the identifier, condition index, cost, candidate, and the full scores tuple for logging. + return pnad_idx, condition_idx, cost, condition_candidate, scores, process[ # type: ignore[return-value] + 0] class ClusteringSTRIPSLearner(BaseSTRIPSLearner): @@ -50,7 +124,10 @@ def _learn(self) -> List[PNAD]: if suc: # Add to this PNAD. assert set(sub.keys()) == set(pnad.op.parameters) - pnad.add_to_datastore((segment, sub)) + pnad.add_to_datastore( + (segment, sub), + check_effect_equality=CFG. + clustering_learner_check_effect_equality) break else: # Otherwise, create a new PNAD. @@ -157,6 +234,2000 @@ def _postprocessing_learn_ignore_effects(self, return ret_pnads +class ClusterAndLLMSelectSTRIPSLearner(ClusteringSTRIPSLearner): + """Learn preconditions via LLM selection. + + Note: The current prompt are tailored for exogenous processes. + """ + + def __init__(self, *args: List, + **kwargs: Dict) -> None: # type: ignore[type-arg] + """Initialize the LLM and load the prompt template.""" + super().__init__(*args, **kwargs) # type: ignore[arg-type] + self._llm = utils.create_llm_by_name(CFG.llm_model_name) + prompt_file = utils.get_path_to_predicators_root() + \ + "/predicators/nsrt_learning/strips_learning/" + \ + "llm_op_learning_prompts/condition_selection.prompt" + with open(prompt_file, "r") as f: + self.base_prompt = f.read() + from predicators.approaches.pp_online_predicate_invention_approach import \ + get_false_positive_states + self._get_false_positive_process_states = get_false_positive_states + + @classmethod + def get_name(cls) -> str: + return "cluster_and_llm_select" + + def _learn_pnad_preconditions(self, pnads: List[PNAD]) -> List[PNAD]: + """Assume there is one segment per PNAD We can either do lifting first + and selection second, or the other way around. + + If we have multiple segments per PNAD, lifting requires us to + find a subset of atoms that unifies the segments. 
We'd have to + do this if we want to learn a single condition. But we could + also learn more than one. + """ + # Add var_to_obj for objects in the init state of the segment + new_pnads = [] + for pnad in pnads: + # Removing this assumption because we're now making sure that + # all the init_atoms in the PNAD are the same up to unification. + # assert len(pnad.datastore) == 1 + seg, var_to_obj = pnad.datastore[0] + existing_objs = set(var_to_obj.values()) + # Get the init atoms of the segment + init_atoms = seg.init_atoms + # Get the objects in the init atoms + additional_objects = { + o + for atom in init_atoms for o in atom.objects + if o not in existing_objs + } + # Create a new var_to_obj mapping for the objects + objects_lst = sorted(additional_objects) + params = utils.create_new_variables([o.type for o in objects_lst], + existing_vars=list(var_to_obj)) + var_to_obj.update(dict(zip(params, objects_lst))) + new_pnads.append( + PNAD(pnad.op, [(seg, var_to_obj)], + pnad.option_spec)) # dummy option + + seperate_llm_query_per_pnad = True + effect_and_conditions = "" + proposed_conditions: List[str] = [] + for i, pnad in enumerate(new_pnads): + if seperate_llm_query_per_pnad: + effect_and_conditions += f"Process 0:\n" + else: + effect_and_conditions += f"Process {i}:\n" + add_effects = pnad.op.add_effects + delete_effects = pnad.op.delete_effects + effect_and_conditions += "Add effects: (" + if add_effects: + effect_and_conditions += "and " + " ".join(f"({str(atom)})" for\ + atom in add_effects) + effect_and_conditions += ")\n" + effect_and_conditions += "Delete effects: (" + if delete_effects: + effect_and_conditions += "and " + " ".join(f"({str(atom)})" \ + for atom in delete_effects) + effect_and_conditions += ")\n" + segment_init_atoms = pnad.datastore[0][0].init_atoms + segment_var_to_obj = pnad.datastore[0][1] + obj_to_var = {v: k for k, v in segment_var_to_obj.items()} + conditions_to_choose_from = pformat( + {a.lift(obj_to_var) + for a in 
segment_init_atoms}) + effect_and_conditions += "Conditions to choose from:\n" +\ + conditions_to_choose_from + "\n\n" + + if seperate_llm_query_per_pnad: + prompt = self.base_prompt.format( + EFFECTS_AND_CONDITIONS=effect_and_conditions) + proposals = self._llm.sample_completions( + prompt, None, 0.0, CFG.seed)[0] + pattern = r'```\n(.*?)\n```' + matches = re.findall(pattern, proposals, re.DOTALL) + proposed_conditions.append(matches[0]) + effect_and_conditions = "" + + if not seperate_llm_query_per_pnad: + prompt = self.base_prompt.format( + EFFECTS_AND_CONDITIONS=effect_and_conditions) + proposals = self._llm.sample_completions(prompt, None, 0.0, + CFG.seed)[0] + pattern = r'```\n(.*?)\n```' + matches = re.findall(pattern, proposals, re.DOTALL) + proposed_conditions = matches[0].split("\n\n") + + def atom_in_llm_selection( + atom: LiftedAtom, + conditions: List[Tuple[str, List[Tuple[str, str]]]]) -> bool: + for condition in conditions: + atom_name = condition[0] + atom_variables = condition[1] + if atom.predicate.name == atom_name and \ + all([var_type[0] == var.name for (var_type, var) in + zip(atom_variables, atom.variables)]): + return True + return False + + # Assumes the same number of PNADs and response chunks + assert len(new_pnads) == len(proposed_conditions) + final_pnads: List[PNAD] = [] + for proposed_condition, corresponding_pnad in zip( + proposed_conditions, new_pnads): + # Get the effect atoms + # Get the condition atoms + lines = proposed_condition.split("\n") + # add_effects = self.parse_effects_or_conditions(lines[0]) + # delete_effects = self.parse_effects_or_conditions(lines[1]) + conditions = self.parse_effects_or_conditions(lines[2]) + + segment_init_atoms = corresponding_pnad.datastore[0][0].init_atoms + segment_var_to_obj = corresponding_pnad.datastore[0][1] + obj_to_var = {v: k for k, v in segment_var_to_obj.items()} + conditions_to_choose_from = { # type: ignore[assignment] + a.lift(obj_to_var) + for a in segment_init_atoms + } + 
new_conditions = set( + atom for atom in + conditions_to_choose_from # type: ignore[union-attr] + if atom_in_llm_selection(atom, + conditions)) # type: ignore[arg-type] + add_eff = corresponding_pnad.op.add_effects + del_eff = corresponding_pnad.op.delete_effects + # the variable might also just in the effects + new_parameters = set( + var for atom in new_conditions | add_eff | del_eff + for var in atom.variables) # type: ignore[union-attr] + # Only append if it's unique + for final_pnad in final_pnads: + suc, _ = utils.unify_preconds_effects_options( + frozenset(new_conditions), + frozenset(final_pnad.op.preconditions), + frozenset(corresponding_pnad.op.add_effects), + frozenset(final_pnad.op.add_effects), + frozenset(corresponding_pnad.op.delete_effects), + frozenset(final_pnad.op.delete_effects), + corresponding_pnad.option_spec[0], + final_pnad.option_spec[0], + tuple(corresponding_pnad.option_spec[1]), + tuple(final_pnad.option_spec[1]), + ) + if suc: + break + else: + # We have a new process! + # Create a new PNAD with the new parameters and conditions + # and add it to the final list + pnad = PNAD( + corresponding_pnad.op.copy_with( + parameters=new_parameters, + preconditions=new_conditions), + corresponding_pnad.datastore, + corresponding_pnad.option_spec) + final_pnads.append(pnad) + + # if CFG.process_learner_check_false_positives: + # # Go through the trajectories and check if this process + # # leads to false positive effect predications. 
+ # false_positive_process_state = \ + # self._get_false_positive_process_states( + # self._trajectories, + # self._predicates, + # [pnad.make_exogenous_process()]) + + # for _, states in false_positive_process_state.items(): + # if len(states) > 0: + # # initial_segmenter_method = CFG.segmenter + # # CFG.segmenter = "atom_changes" + # # segments = [segment_trajectory(traj, self._predicates) for traj in self._trajectories] + # # CFG.segmenter = initial_segmenter_method + return final_pnads + + def parse_effects_or_conditions( + self, line: str) -> List[Tuple[str, List[Tuple[str, str]]]]: + """Parse a line containing effects or conditions into a list of tuples. + For example, when given: 'Conditions: (and (FaucetOn(?x1:faucet)) + (JugUnderFaucet(?x2:jug, ?x1:faucet)))'. + + Each returned tuple has: + - An atom name (e.g., "JugFilled") + - A list of (variable_name, type_name) pairs + (e.g., [("?x0", "jug"), ("?x1", "faucet")]). + + Example Return: + [ + ("FaucetOn", [("?x1", "faucet")]), + ("JugUnderFaucet", [("?x2", "jug"), ("?x1", "faucet")]) + ] + """ + + # Remove the top-level (and ...) if present. + # This way, we won't accidentally capture "and" as an atom. + line = re.sub(r"\(\s*and\s+", "(", line) + + # Match an atom name and the entire content inside its parentheses. 
+ pattern = r"\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\((.*?)\)\)" + atom_matches = re.findall(pattern, line) + + var_type_pattern = r"(\?[a-zA-Z0-9]+):([a-zA-Z0-9_]+)" + parsed_atoms: List[Tuple[str, List[Tuple[str, str]]]] = [] + + for atom_name, vars_str in atom_matches: + # Find all variable:type pairs in the string + var_type_pairs = re.findall(var_type_pattern, vars_str) + parsed_atoms.append((atom_name, var_type_pairs)) + + return parsed_atoms + + +class ClusteringProcessLearner(ClusteringSTRIPSLearner): + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.online_learning_cycle = kwargs.get("online_learning_cycle", None) + self._endogenous_processes = kwargs["endogenous_processes"] + from predicators.approaches.pp_online_predicate_invention_approach import \ + get_false_positive_states_from_seg_trajs + self._get_false_positive_states_from_seg_trajs = \ + get_false_positive_states_from_seg_trajs + + from predicators.approaches.pp_param_learning_approach import \ + learn_process_parameters + self._get_data_likelihood_and_learn_params = \ + learn_process_parameters + + self._atom_change_segmented_trajs: List[List[Segment]] = [] + + if CFG.cluster_and_search_process_learner_llm_propose_top_conditions or\ + CFG.cluster_and_search_process_learner_llm_rank_atoms: + self._llm = utils.create_llm_by_name(CFG.llm_model_name) + else: + self._llm = None # type: ignore[assignment] + + def _learn(self) -> List[PNAD]: + segments = [seg for segs in self._segmented_trajs for seg in segs] + # Cluster the segments according to common option and effects. 
+ pnads: List[PNAD] = [] + for i, segment in enumerate(segments): + if segment.has_option(): + segment_option = segment.get_option() + segment_param_option = segment_option.parent + segment_option_objs = tuple(segment_option.objects) + else: + segment_param_option = DummyOption.parent + segment_option_objs = tuple() + if self.get_name() not in [ + "cluster_and_llm_select", + "cluster_and_search_process_learner", + "cluster_and_inverse_planning" + ] or CFG.exogenous_process_learner_do_intersect: + preconds1: FrozenSet = frozenset() # no preconditions + segment_param_option = DummyOption.parent + segment_option_objs = tuple() + else: + # Ground + preconds1 = frozenset(segment.init_atoms) + + # ent_to_ent_sub here is obj_to_var + seg_add_effects = frozenset( + a for a in segment.add_effects + if not isinstance(a.predicate, DerivedPredicate)) + seg_del_effects = frozenset( + a for a in segment.delete_effects + if not isinstance(a.predicate, DerivedPredicate)) + if self.get_name() in ["cluster_and_search_process_learner"]: + # Remove atoms explained by endogenous processes + seg_add_effects, seg_del_effects = \ + self.remove_atoms_explained_by_endogenous_processes( # type: ignore[assignment] + segment, self._endogenous_processes, + set(seg_add_effects), set(seg_del_effects)) # type: ignore[arg-type] + seg_add_effects = frozenset(seg_add_effects) + seg_del_effects = frozenset(seg_del_effects) + + suc, ent_to_ent_sub, pnad = self._unify_segment_with_pnads( # type: ignore[misc] + preconds1, seg_add_effects, seg_del_effects, + segment_param_option, segment_option_objs, pnads) + + if suc: + sub = cast(VarToObjSub, + {v: o + for o, v in ent_to_ent_sub.items()}) + # Add to this PNAD. + if CFG.exogenous_process_learner_do_intersect: + # Find the largest conditions that unifies the init + # atoms of the segment and another segment in the PNAD. + # and add that segment and sub to the datastore. 
+ # Doing this sequentially ensures one of the + # substitutions has the objects we care about with + # intersection. Hence it can fall out later in + # `induce_preconditions_via_intersection`. + (pnad_param_option, pnad_option_vars) = pnad.option_spec + sub = self._find_best_segment_unification( + segment, + seg_add_effects, + seg_del_effects, + pnad, + ent_to_ent_sub, + segment_param_option, + pnad_param_option, + segment_option_objs, # type: ignore[arg-type] + tuple(pnad_option_vars), + self._endogenous_processes) + else: + assert set(sub.keys()) == set(pnad.op.parameters) + pnad.add_to_datastore( + (segment, sub), + check_effect_equality=not self.get_name() + in ["cluster_and_search_process_learner"], + check_option_equality=not self.get_name() + in ["cluster_and_search_process_learner"]) + else: + # Otherwise, create a new PNAD. + objects = {o for atom in segment.add_effects | + segment.delete_effects for o in atom.objects} | \ + set(segment_option_objs) + + if self.get_name() in [ + "cluster_and_llm_select", + "cluster_and_search_process_learner", + "cluster_and_inverse_planning" + ]: + # With cluster_and_llm_select, the param may include + # anything in the init atoms of the segment. 
+ objects |= { + o + for atom in segment.init_atoms for o in atom.objects + } + + objects_lst = sorted(objects) + params = utils.create_new_variables( + [o.type for o in objects_lst]) + preconds: Set[LiftedAtom] = set() # will be learned later + obj_to_var = dict(zip(objects_lst, params)) + var_to_obj = dict(zip(params, objects_lst)) + grd_add_effects = { + atom + for atom in segment.add_effects + if not isinstance(atom.predicate, DerivedPredicate) + } + grd_delete_effects = { + atom + for atom in segment.delete_effects + if not isinstance(atom.predicate, DerivedPredicate) + } + lfd_add_effects = { + atom.lift(obj_to_var) + for atom in grd_add_effects + } + lfd_delete_effects = { + atom.lift(obj_to_var) + for atom in grd_delete_effects + } + ignore_effects: Set[Predicate] = set() # will be learned later + if self.get_name() in ["cluster_and_search_process_learner"]: + # Remove atoms explained by endogenous processes + lfd_add_effects, lfd_delete_effects = \ + self.remove_atoms_explained_by_endogenous_processes( + segment, self._endogenous_processes, lfd_add_effects, + lfd_delete_effects, obj_to_var) + grd_add_effects, grd_delete_effects = \ + self.remove_atoms_explained_by_endogenous_processes( # type: ignore[assignment] + segment, self._endogenous_processes, grd_add_effects, # type: ignore[arg-type] + grd_delete_effects) # type: ignore[arg-type] + + # ---- Single effect bias ---- + if CFG.cluster_learning_one_effect_per_process: + # If there are still processes with multiple effects, + # add multiple PNAD here; after checking such pnad don't + # already exists. 
+ for atom in grd_add_effects | grd_delete_effects: + neg_atom = atom.get_negated_atom() + if atom in grd_add_effects: + add_effect_set = frozenset({atom}) + # Check if the negated atom is in the delete + # effects + if neg_atom in grd_delete_effects: + del_effect_set = frozenset({neg_atom}) + else: + del_effect_set = frozenset() + else: + del_effect_set = frozenset({atom}) + if neg_atom in grd_add_effects: + add_effect_set = frozenset({neg_atom}) + else: + add_effect_set = frozenset() + # Check if the pnad already exists + suc, ent_to_ent_sub, pnad =\ + self._unify_segment_with_pnads( # type: ignore[misc] + frozenset(), add_effect_set, del_effect_set, + segment_param_option, segment_option_objs, + pnads) + if suc: + sub = cast( + VarToObjSub, + {v: o + for o, v in ent_to_ent_sub.items()}) + # Add to this PNAD. + if CFG.exogenous_process_learner_do_intersect: + # Find the largest conditions that unifies the init + # atoms of the segment and another segment in the PNAD. + # and add that segment and sub to the datastore. + # Doing this sequentially ensures one of the + # substitutions has the objects we care about with + # intersection. Hence it can fall out later in + # `induce_preconditions_via_intersection`. 
+ (pnad_param_option, + pnad_option_vars) = pnad.option_spec + sub = self._find_best_segment_unification( + segment, + add_effect_set, + del_effect_set, + pnad, + ent_to_ent_sub, + segment_param_option, + pnad_param_option, + segment_option_objs, # type: ignore[arg-type] + tuple(pnad_option_vars), + self._endogenous_processes) + else: + assert set(sub.keys()) == set( + pnad.op.parameters) + pnad.add_to_datastore( + (segment, sub), + check_effect_equality=False, + check_option_equality=False) + else: + add_effect_set = frozenset({ + atom.lift(obj_to_var) # type: ignore[misc] + for atom in add_effect_set + }) + del_effect_set = frozenset({ + atom.lift(obj_to_var) # type: ignore[misc] + for atom in del_effect_set + }) + # Create a new pnad with this atom + op = STRIPSOperator( + f"Op{len(pnads)}", + params, + preconds, + add_effect_set, # type: ignore[arg-type] + del_effect_set, # type: ignore[arg-type] + ignore_effects) + datastore = [(segment, var_to_obj)] + option_vars = [ + obj_to_var[o] for o in segment_option_objs + ] + option_spec = (segment_param_option, + option_vars) + pnads.append(PNAD(op, datastore, option_spec)) + continue + op = STRIPSOperator(f"Op{len(pnads)}", params, preconds, + lfd_add_effects, lfd_delete_effects, + ignore_effects) + datastore = [(segment, var_to_obj)] + option_vars = [obj_to_var[o] for o in segment_option_objs] + option_spec = (segment_param_option, option_vars) + pnads.append(PNAD(op, datastore, option_spec)) + + if self.get_name() in ["cluster_and_search_process_learner"]: + # Do this extra step for this learner + initial_segmenter_method = CFG.segmenter + CFG.segmenter = "atom_changes" + self._atom_change_segmented_trajs = [ + segment_trajectory(traj, self._predicates, verbose=False) + for traj in self._trajectories + ] + CFG.segmenter = initial_segmenter_method + # Learn the preconditions of the operators in the PNADs. This part + # is flexible; subclasses choose how to implement it. 
+ pnads = self._learn_pnad_preconditions(pnads) + + # Handle optional postprocessing to learn ignore effects. + pnads = self._postprocessing_learn_ignore_effects(pnads) + + # Log and return the PNADs. + if self._verbose: + logging.info("Learned operators (before option learning):") + for pnad in pnads: + logging.info(pnad) + return pnads + + def _unify_segment_with_pnads(self, seg_preconds, seg_add_effects, # type: ignore[no-untyped-def] + seg_del_effects, seg_param_option, + seg_option_objs, pnads: List[PNAD]) -> \ + Tuple[bool, VarToObjSub]: + """Try to unify the segment with the PNADs.""" + for pnad in pnads: + # Try to unify this transition with existing effects. + # Note that both add and delete effects must unify, + # and also the objects that are arguments to the options. + (pnad_param_option, pnad_option_vars) = pnad.option_spec + if self.get_name() not in [ + "cluster_and_llm_select", + "cluster_and_search_process_learner", + "cluster_and_inverse_planning" + ] or CFG.exogenous_process_learner_do_intersect: + preconds2: FrozenSet = frozenset() # no preconditions + else: + # Lifted + obj_to_var = {v: k for k, v in pnad.datastore[-1][1].items()} + preconds2 = frozenset({ + atom.lift(obj_to_var) + for atom in pnad.datastore[-1][0].init_atoms + }) + suc, ent_to_ent_sub = utils.unify_preconds_effects_options( + seg_preconds, preconds2, seg_add_effects, + frozenset(pnad.op.add_effects), seg_del_effects, + frozenset(pnad.op.delete_effects), seg_param_option, + pnad_param_option, seg_option_objs, tuple(pnad_option_vars)) + if suc: + return True, ent_to_ent_sub, pnad # type: ignore[return-value] + return False, dict(), None # type: ignore[return-value] + + @staticmethod + def _find_best_segment_unification( + segment: Segment, seg_add_eff: FrozenSet[GroundAtom], + seg_del_eff: FrozenSet[GroundAtom], pnad: PNAD, + obj_to_var: Dict[Object, Variable], + segment_param_option: ParameterizedOption, + pnad_param_option: ParameterizedOption, + segment_option_objs: 
Tuple[Object], + pnad_option_vars: Tuple[Variable], + endogenous_processes: List[EndogenousProcess]) -> VarToObjSub: + """Try to unify and find the *best* set of matching init atoms between + the given segment and the *last* segment in the PNAD's datastore, then + return the resulting Var->Obj substitution. + + Prioritizes atoms involving effect variables to ensure critical + atoms like SideOf(dest, source, direction) are preserved. + """ + # ---------- 0) Gather init atoms (ground vs. lifted) ---------- + seg_init_atoms_full = set(segment.init_atoms) + + # The last segment in the PNAD's datastore and its variable mapping. + last_seg, last_var_to_obj = pnad.datastore[-1] + last_obj_to_var = {o: v for v, o in last_var_to_obj.items()} + objects_in_last = set(last_obj_to_var) + lifted_last_init_atoms = { + atom.lift(last_obj_to_var) + for atom in last_seg.init_atoms + if all(o in objects_in_last for o in atom.objects) + } + + # Identify effect variables for prioritization + effect_vars = set() + for atom in pnad.op.add_effects | pnad.op.delete_effects: + effect_vars.update(atom.variables) + + # Identify critical ground objects from segment effects + effect_objects = set() + for atom in seg_add_eff | seg_del_eff: # type: ignore[assignment] + effect_objects.update(atom.objects) # type: ignore[attr-defined] + + # Restrict to predicates shared between the two sides. 
+ common_preds = {a.predicate for a in seg_init_atoms_full} & \ + {b.predicate for b in lifted_last_init_atoms} + remove_ignore_atoms = True + if remove_ignore_atoms: + relevant_procs = [ + p for p in endogenous_processes + if segment.get_option().parent == p.option + ] + for endo_proc in relevant_procs: + common_preds -= endo_proc.ignore_effects + + seg_pre_list: List[GroundAtom] = sorted( + [a for a in seg_init_atoms_full if a.predicate in common_preds], + key=str, + ) + pnad_pre_list: List[LiftedAtom] = sorted( + [b for b in lifted_last_init_atoms if b.predicate in common_preds], + key=str, + ) + + # Quick exits: nothing to match or no shared predicates. + if not seg_pre_list or not pnad_pre_list: + return cast(VarToObjSub, {v: o for o, v in obj_to_var.items()}) + + # ---------- 1) Start from the mapping returned by effects+options ---------- + current_map: Dict[_TypedEntity, Variable] = dict( + obj_to_var) # type: ignore[arg-type] + + # We'll try to extend current_map with as many precondition matches as possible. 
+ # Use weighted scoring that prioritizes effect-related atoms + best_map: Dict[_TypedEntity, Variable] = dict(current_map) + best_score: float = 0.0 # Changed to float for weighted scoring + + # ---------- 2) Organize atoms by predicate for bounds & candidate search ---------- + from collections import Counter, defaultdict + + idx_pnad_by_pred: Dict[Predicate, List[int]] = defaultdict(list) + for j, b in enumerate(pnad_pre_list): + idx_pnad_by_pred[b.predicate].append(j) + + # Compute atom weights based on involvement with effects + def compute_atom_weight(ground_atom: GroundAtom, + lifted_atom: LiftedAtom) -> float: + """Compute weight for matching this atom pair.""" + weight = 1.0 # Base weight + + # High priority for atoms involving effect objects/variables + involves_effect_ground = any(obj in effect_objects + for obj in ground_atom.objects) + involves_effect_lifted = any(var in effect_vars + for var in lifted_atom.variables) + + if involves_effect_ground and involves_effect_lifted: + # Critical atoms like SideOf connecting source and dest + if ground_atom.predicate.name == "SideOf": + # Check if it connects effect locations + if len(effect_objects.intersection( + ground_atom.objects)) >= 2: + weight = 100.0 # Highest priority + else: + weight = 10.0 + else: + weight = 5.0 + elif involves_effect_ground or involves_effect_lifted: + weight = 2.0 + + return weight + + # Upper bound helper with weighted scoring + def weighted_upper_bound(seg_idxs: Set[int], + pnad_unused: Set[int]) -> float: + """Compute weighted upper bound on possible score.""" + bound = 0.0 + seg_by_pred = defaultdict(list) + for i in seg_idxs: + seg_by_pred[seg_pre_list[i].predicate].append(i) + + for pred, seg_indices in seg_by_pred.items(): + pnad_indices = [ + j for j in idx_pnad_by_pred[pred] if j in pnad_unused + ] + # For each predicate, we can match at most min(seg_count, pnad_count) + max_matches = min(len(seg_indices), len(pnad_indices)) + if max_matches > 0: + # Use maximum possible 
weight for this predicate + max_weight = max( + compute_atom_weight(seg_pre_list[si], + pnad_pre_list[pi]) + for si in seg_indices[:max_matches] + for pi in pnad_indices[:max_matches] + ) if seg_indices and pnad_indices else 1.0 + bound += max_matches * max_weight + return bound + + # Compatibility check for a single (ground, lifted) atom pair + def compatible_extension( + a: GroundAtom, b: LiftedAtom, mapping: Dict[_TypedEntity, Variable] + ) -> Optional[List[Tuple[_TypedEntity, Variable]]]: + if a.predicate != b.predicate: + return None + new_pairs: List[Tuple[_TypedEntity, Variable]] = [] + inv = {v: k for k, v in mapping.items()} + for obj_ent, var_ent in zip(a.entities, b.entities): + # Types must match + if obj_ent.type != var_ent.type: + return None + # b side should be a Variable (usually), but handle if lifted constant + if isinstance(var_ent, Variable): + # mapping consistency: obj -> var one-to-one + if obj_ent in mapping: + if mapping[obj_ent] != var_ent: + return None + elif var_ent in inv: + if inv[var_ent] != obj_ent: + return None + else: + new_pairs.append((obj_ent, var_ent)) + else: + # If b side is a constant-typed entity, require equality + if obj_ent != var_ent: + return None + return new_pairs + + def search(mapping: Dict[_TypedEntity, Variable], seg_left: Set[int], + pnad_unused: Set[int], score: float) -> None: + nonlocal best_score, best_map + + # Upper bound pruning with weighted scoring + ub = score + weighted_upper_bound(seg_left, pnad_unused) + if ub <= best_score: + return + + if not seg_left: + if score > best_score: + best_score = score + best_map = dict(mapping) + return + + # Choose next atom: prioritize high-weight atoms with few candidates + best_i = None + best_candidates: List[Tuple[int, List, float]] = [] + best_priority = -float('inf') + + for i in list(seg_left): + a = seg_pre_list[i] + candidates = [] + for j in idx_pnad_by_pred[a.predicate]: + if j not in pnad_unused: + continue + ext = compatible_extension(a, 
pnad_pre_list[j], mapping) + if ext is not None: + weight = compute_atom_weight(a, pnad_pre_list[j]) + candidates.append((j, ext, weight)) + + if not candidates: + # This atom cannot be matched; continue without it + seg_left_minus_i = set(seg_left) + seg_left_minus_i.remove(i) + search(mapping, seg_left_minus_i, pnad_unused, score) + return + + # Priority: high weight atoms with few candidates (more constrained) + max_weight = max(c[2] for c in candidates) + priority = max_weight / (len(candidates) + 1 + ) # Favor constrained, high-weight + + if priority > best_priority: + best_i = i + best_candidates = candidates + best_priority = priority + + assert best_i is not None + + # Try candidates, ordered by weight (highest first) + for j, ext_pairs, weight in sorted(best_candidates, + key=lambda x: (-x[2], x[0])): + # Apply extension + for k, v in ext_pairs: + mapping[k] = v + pnad_unused.remove(j) + seg_left.remove(best_i) + + search(mapping, seg_left, pnad_unused, score + weight) + + # Revert + seg_left.add(best_i) + pnad_unused.add(j) + for k, _ in ext_pairs: + try: + del mapping[k] + except KeyError: + pass + + # Run the weighted search + search(dict(current_map), set(range(len(seg_pre_list))), + set(range(len(pnad_pre_list))), 0.0) + + # Convert best map (Object->Variable) back to Var->Object for return + sub = cast(VarToObjSub, {v: o for o, v in best_map.items()}) + return sub + + @staticmethod + def remove_atoms_explained_by_endogenous_processes( + segment: Segment, + endogenous_processes: List[EndogenousProcess], + add_effects: Set[LiftedAtom], + delete_effects: Set[LiftedAtom], + obj_to_var: Optional[Dict[Object, Variable]] = None + ) -> Tuple[Set[LiftedAtom], Set[LiftedAtom]]: + """If obj_to_var is None, we are taking in a set of ground atoms. + + and will return a set of ground atoms. Otherwise they are + lifted. This is to account for some exogenous effect that may + happen in the same time as some endogenous effect. 
+ """ + if obj_to_var: + process_lifted_atoms = True + else: + process_lifted_atoms = False + objects = set(segment.states[0]) + seg_add_eff = segment.add_effects + seg_del_eff = segment.delete_effects + + relevant_procs = [ + p for p in endogenous_processes + if segment.get_option().parent == p.option + ] + for endo_proc in relevant_procs: + if endo_proc.name == "Wait": + continue + add_effects = { + a + for a in add_effects + if a.predicate not in endo_proc.ignore_effects + } + delete_effects = { + a + for a in delete_effects + if a.predicate not in endo_proc.ignore_effects + } + var_to_obj = { + v: o + for v, o in zip(endo_proc.option_vars, + segment.get_option().objects) + } + for g_proc in utils.all_ground_operators_given_partial( + endo_proc, objects, var_to_obj): # type: ignore[arg-type] + if g_proc.add_effects.issubset(seg_add_eff) and\ + g_proc.delete_effects.issubset(seg_del_eff): + if process_lifted_atoms: + add_effects -= { + atom.lift(obj_to_var) # type: ignore[arg-type] + for atom in g_proc.add_effects + } + delete_effects -= { + atom.lift(obj_to_var) # type: ignore[arg-type] + for atom in g_proc.delete_effects + } + else: + add_effects -= g_proc.add_effects + delete_effects -= g_proc.delete_effects + # logging.debug( + # f"Processing lifted atoms: {process_lifted_atoms}, " + # f"Removed effects of {g_proc} \n from " + # f"segment with \n add effect {seg_add_eff} " + # f"and delete effect {seg_del_eff}\n" + # f"new add effects: {add_effects}, del effects: {delete_effects}") + return add_effects, delete_effects + + @staticmethod + def _get_top_candidates( + candidates_with_scores: List, percentage: float, + number: int) -> List[Tuple[float, Set[LiftedAtom]]]: + assert percentage > 0 or number > 0, \ + "At least one of percentage or number must be greater than 0." 
+ n_candidates = len(candidates_with_scores) + if percentage > 0: + num_under_percentage = max(1, + int(n_candidates * percentage / 100.0)) + score_at_threshold = candidates_with_scores[:num_under_percentage][ + -1][0] + scores = [score for score, _ in candidates_with_scores] + # Include all candidates with score_at_threshold + position = bisect.bisect_right(scores, score_at_threshold) + logging.info( + f"Score threshold {score_at_threshold}; " + f"Candidates under threshold: {position}/{n_candidates}") + else: + position = n_candidates + + # include at most top_n_candidates + if number > 0: + position = min(position, number) + logging.debug(f"Returning {position}/{n_candidates} candidates:") + num_to_log = 100 + for i, candidate in enumerate(candidates_with_scores[:num_to_log]): + score, condition_candidate = candidate + logging.debug(f"{i}: {condition_candidate}, Score: {score:.4f}") + if CFG.use_wandb: + wandb.log({ + f"candidate_{i}_score": score, + f"candidate_{i}_condition": str(condition_candidate) + }) + return candidates_with_scores[:position] + + def _get_top_consistent_conditions(self, initial_atom: Set[LiftedAtom], + pnad: PNAD, method: str, + seed: int) -> Iterator[Set[LiftedAtom]]: + """Get the top consistent conditions for a PNAD.""" + exogenous_process = pnad.make_exogenous_process() + logging.debug(f"For Process sketch:\n{exogenous_process}") + candidates_with_scores = self.score_precondition_candidates( # type: ignore[attr-defined] + exogenous_process, initial_atom, seed) + + if method == "top_p_percent": + # Return top p% of candidates + top_candidates = self._get_top_candidates( + candidates_with_scores, + CFG.cluster_process_learner_top_p_percent, + CFG.cluster_process_learner_top_n_conditions) + num_top_candidates = len(top_candidates) + # Record the total number of candidates + if self._total_num_candidates == 0: # type: ignore[attr-defined] + self._total_num_candidates += num_top_candidates # type: ignore[attr-defined] + else: + 
self._total_num_candidates *= num_top_candidates # type: ignore[attr-defined] + elif method == "top_n": + # Return top n candidates + n = CFG.cluster_process_learner_top_n_conditions + top_candidates = candidates_with_scores[:n] + else: + raise NotImplementedError( + f"Unknown top consistent method: {method}") + + # Yield the selected candidates + for candidate in top_candidates: + if len(candidate) == 2: + score, condition_candidate = candidate + else: + score, condition_candidate, _ = candidate # type: ignore[unreachable] + logging.info( + f"Selected condition: {condition_candidate}, Score: {score}") + yield condition_candidate + + +class ClusterAndSearchProcessLearner(ClusteringProcessLearner): + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize the process learner.""" + super().__init__(*args, **kwargs) + self.proc_name_to_results: Dict[str, List[ + Tuple[float, FrozenSet[LiftedAtom], Tuple, ExogenousProcess]]] =\ + defaultdict(list) + + @classmethod + def get_name(cls) -> str: + return "cluster_and_search_process_learner" + + def _learn_pnad_preconditions(self, pnads: List[PNAD]) -> List[PNAD]: + """Learns preconditions for all PNADs. + + This implementation flattens the search for preconditions into a + single multiprocessing pool. It supports an optional preliminary + pruning step using a fast false-positive count metric to reduce + the number of candidates that need to be scored with the more + expensive data-likelihood metric. + """ + cpu_cnt = self._determine_worker_count() + use_parallel = (CFG.cluster_and_search_process_learner_parallel_pnad + and cpu_cnt > 1) + + logging.info( + f"Learning preconditions for {len(pnads)} PNADs using a flat parallel pool." 
+ ) + + # Step 1: Generate candidate conditions + (possible_atoms_per_pnad, + condition_sets_per_pnad) = self._generate_candidate_conditions(pnads) + + # Step 2: Filter PNAD parameters + pnads = self._filter_pnad_parameters(pnads, possible_atoms_per_pnad, + condition_sets_per_pnad) + + # Step 2.5: Ablation - use top condition if flag is set + if CFG.process_learner_ablate_bayes: + logging.info( + "Using ablation: taking top condition from condition_sets_per_pnad" + ) + best_conditions: Dict[int, FrozenSet[LiftedAtom]] = {} + + # Set up proc_name_to_results with placeholder values + for i, pnad in enumerate(pnads): + if (condition_sets_per_pnad is not None + and i < len(condition_sets_per_pnad) + and condition_sets_per_pnad[i]): + # Take the first (top) condition from condition_sets + best_condition = condition_sets_per_pnad[i][0] + else: + # Fallback to empty condition if no condition sets available + best_condition = set() + best_conditions[i] = best_condition # type: ignore[assignment] + + # Create placeholder scored_conditions entry for proc_name_to_results + # Format: (cost, frozenset(condition), scores_tuple, process) + placeholder_process = pnad.make_exogenous_process() + placeholder_process.condition_at_start = best_condition.copy() + placeholder_process.condition_overall = best_condition.copy() + placeholder_scored_conditions = [ + (0.0, frozenset(best_condition), (0.0, ), + placeholder_process) + ] + self.proc_name_to_results[ + pnad.op.name] = placeholder_scored_conditions + + # Construct final PNADs with the top conditions + return self._construct_final_pnads(best_conditions, pnads) + + # Step 3: Calculate candidate limits for CPU utilization + min_candidates_to_keep = self._calculate_candidate_limits( + possible_atoms_per_pnad, condition_sets_per_pnad, cpu_cnt) + + # Step 4: Generate final candidates with pruning + final_candidates_for_pnad = self._generate_final_candidates_with_pruning( + pnads, possible_atoms_per_pnad, condition_sets_per_pnad, + 
min_candidates_to_keep) + + # Step 5: Create work items for parallel scoring + work_items = self._create_scoring_work_items( + pnads, final_candidates_for_pnad) + + if not work_items: + return [] + + # Step 6: Execute parallel scoring + start_time = time.time() + logging.info(f"Scoring {len(work_items)} total conditions for " + f"{len(pnads)} PNADs using up to {cpu_cnt} workers.") + logging.debug(f"Num vi steps: {CFG.cluster_and_search_vi_steps}, " + "Early stopping patience: " + f"{CFG.process_param_learning_patience}") + + if use_parallel: + with Pool(nodes=min(len(work_items), cpu_cnt)) as pool: + results = pool.map(_flat_pnad_scoring_worker, work_items) + else: + logging.info( + "Using sequential scoring as alternative to parallel processing." + ) + results = [] + for work_item in work_items: + result = _flat_pnad_scoring_worker(work_item) + results.append(result) + + logging.info(f"Finished scoring in {time.time() - start_time:.2f}s.") + + # Step 7: Process results and select best conditions + best_conditions = self._process_scoring_results( + results, final_candidates_for_pnad, pnads) + + # Step 8: Construct final PNADs + return self._construct_final_pnads(best_conditions, pnads) + + def _generate_candidate_conditions( + self, pnads: List[PNAD] + ) -> Tuple[List[Set[LiftedAtom]], Optional[List[List[Set[LiftedAtom]]]]]: + """Generate candidate conditions for PNADs using intersection or + LLM.""" + possible_atoms_per_pnad = [ + self._induce_preconditions_via_intersection(pnad) for pnad in pnads + ] + + if CFG.cluster_and_search_process_learner_llm_propose_top_conditions: + condition_sets_per_pnad = self._llm_propose_condition_sets( + possible_atoms_per_pnad, + pnads, + # batch_size=CFG.cluster_and_search_llm_propose_batch_size + ) + elif CFG.cluster_and_search_process_learner_llm_rank_atoms: + ranked_atoms_per_pnad = self._llm_rank_atoms( + possible_atoms_per_pnad, pnads) + possible_atoms_per_pnad = [ + set(atoms) for atoms in ranked_atoms_per_pnad + ] + 
condition_sets_per_pnad = None + else: + condition_sets_per_pnad = None + + return possible_atoms_per_pnad, condition_sets_per_pnad + + def _determine_worker_count(self) -> int: + """Return number of worker processes to use based on config.""" + if CFG.process_learning_process_per_physical_core: + return max(1, psutil.cpu_count(logical=False) - 1) + return max(1, mp.cpu_count() - 1) + + def _build_process_descriptions( + self, + possible_atoms_per_pnad: List[Set[LiftedAtom]], + pnads: Optional[List[PNAD]] = None + ) -> List[Tuple[str, List[LiftedAtom]]]: + """Build process descriptions for LLM prompts. + + Args: + possible_atoms_per_pnad: List of sets of possible precondition atoms + pnads: Optional list of PNADs to get effect information from + + Returns: + List of (process_description, sorted_atoms) tuples + """ + process_descriptions = [] + for i, poss_atoms in enumerate(possible_atoms_per_pnad): + process_desc = f"Process {i}:\n" + + # Add effects information if PNADs are available + if pnads and i < len(pnads): + pnad = pnads[i] + add_effects = pnad.op.add_effects + delete_effects = pnad.op.delete_effects + + process_desc += "Add effects: " + if add_effects: + process_desc += "(" + " ".join( + f"({str(atom)})" for atom in add_effects) + ")" + else: + process_desc += "()" + process_desc += "\n" + + process_desc += "Delete effects: " + if delete_effects: + process_desc += "(" + " ".join( + f"({str(atom)})" for atom in delete_effects) + ")" + else: + process_desc += "()" + process_desc += "\n" + + # Add candidate atoms + sorted_atoms = sorted(poss_atoms, key=str) + process_desc += "Candidate atoms:\n" + for j, atom in enumerate(sorted_atoms): + process_desc += f" {j}: {atom}\n" + process_desc += "\n" + + process_descriptions.append((process_desc, sorted_atoms)) + + return process_descriptions + + def _call_llm_with_template(self, template_path: str, + template_vars: Dict[str, Any], + debug_filename: str) -> str: + """Call LLM with a template and save debug info. 
+ + Args: + template_path: Path to the prompt template file + template_vars: Variables to substitute in template + debug_filename: Name for debug output file + + Returns: + LLM response text + """ + if self._llm is None: + raise ValueError("LLM not available") + + # Load the prompt template + with open(template_path, "r") as f: + template = f.read() + + # Format the prompt + prompt = template.format(**template_vars) + + # Get LLM response - use online_learning_cycle as seed if available + seed = CFG.seed * 10 + self.online_learning_cycle if \ + self.online_learning_cycle is not None else CFG.seed + response = self._llm.sample_completions(prompt, + imgs=None, + temperature=0.1, + seed=seed)[0] + + # Save debug info + with open(f"{CFG.log_file}/{debug_filename}", "w") as f: + f.write(f"{prompt}\n=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*" + f"\n{response}") + + return response + + def _parse_llm_answer_block(self, response: str) -> Optional[str]: + """Extract answer content from LLM response. + + Args: + response: Raw LLM response + + Returns: + Answer text or None if not found + """ + answer_match = re.search(r'<answer>(.*?)</answer>', response, + re.DOTALL) + if not answer_match: + return None + return answer_match.group(1).strip() + + def _llm_rank_atoms( + self, + possible_atoms_per_pnad: List[Set[LiftedAtom]], + pnads: Optional[List[PNAD]] = None, + max_atoms: Optional[int] = None) -> List[List[LiftedAtom]]: + """Rank the possible atoms by their likelihood of being + relevant/necessary for the PNAD's effects. 
+ + Args: + possible_atoms_per_pnad: List of sets of possible precondition atoms, one set per PNAD + pnads: Optional list of PNADs to get effect information from + + Returns: + List of lists of ranked atoms, keeping only the most relevant ones based on LLM assessment + """ + if not possible_atoms_per_pnad or self._llm is None: + return [list(atoms) for atoms in possible_atoms_per_pnad] + + try: + # Build process descriptions + process_descriptions = self._build_process_descriptions( + possible_atoms_per_pnad, pnads) + + # Call LLM with template + template_path = (utils.get_path_to_predicators_root() + + "/predicators/nsrt_learning/strips_learning/" + + "llm_op_learning_prompts/atom_ranking.prompt") + all_descriptions = "\n".join( + [desc for desc, _ in process_descriptions]) + template_vars = { + "PROCESS_EFFECTS_AND_CANDIDATES": all_descriptions + } + response = self._call_llm_with_template( + template_path, template_vars, "atom_ranking_response.txt") + + # Parse the response + answer_text = self._parse_llm_answer_block(response) + if not answer_text: + logging.warning("LLM failed to provide properly formatted " + "answer for atom ranking") + return [list(atoms) for atoms in possible_atoms_per_pnad] + lines = [ + line.strip() for line in answer_text.split('\n') + if line.strip() + ] + + # Parse rankings for each process + ranked_atoms_per_pnad = [] + for i, (_, sorted_atoms) in enumerate(process_descriptions): + # Find the line for this process + process_line = None + for line in lines: + if line.startswith(f"Process {i}:"): + process_line = line + break + + if process_line is None: + logging.warning( + f"No ranking found for process {i}, keeping all atoms") + ranked_atoms_per_pnad.append(list(sorted_atoms)) + continue + + # Extract indices after the colon + try: + indices_str = process_line.split(':', 1)[1].strip() + if indices_str: + indices = [ + int(idx.strip()) for idx in indices_str.split(',') + ] + # Filter valid indices and get corresponding atoms + 
valid_indices = [ + idx for idx in indices + if 0 <= idx < len(sorted_atoms) + ] + if valid_indices: + # Keep atoms in the order specified by LLM ranking + # But limit to top N atoms to avoid combinatorial explosion + if max_atoms is None: + max_atoms = len(valid_indices) + else: + max_atoms = min(max_atoms, len(valid_indices)) + selected_atoms = [ + sorted_atoms[idx] + for idx in valid_indices[:max_atoms] + ] + ranked_atoms_per_pnad.append(selected_atoms) + else: + # No valid indices, keep original list + ranked_atoms_per_pnad.append(list(sorted_atoms)) + else: + # Empty ranking, keep original list + ranked_atoms_per_pnad.append(list(sorted_atoms)) + except (ValueError, IndexError) as e: + logging.warning( + f"Failed to parse ranking for process {i}: {e}") + ranked_atoms_per_pnad.append(list(sorted_atoms)) + + # Log the results + for i, ranked in enumerate(ranked_atoms_per_pnad): + original = list(process_descriptions[i][1]) + logging.info( + f"Process {i}: Kept {len(ranked)}/{len(original)} atoms") + logging.debug(f" Kept atoms: {sorted(ranked, key=str)}") + logging.debug( + f" Removed atoms: {sorted(set(original) - set(ranked), key=str)}" + ) + + return ranked_atoms_per_pnad + + except Exception as e: + logging.warning( + f"LLM atom ranking failed: {e}, keeping original atoms") + return [list(atoms) for atoms in possible_atoms_per_pnad] + + def _llm_propose_condition_sets( + self, + possible_atoms_per_pnad: List[Set[LiftedAtom]], + pnads: Optional[List[PNAD]] = None, + k: Optional[int] = None, + batch_size: Optional[int] = None) -> List[List[Set[LiftedAtom]]]: + """Propose top k condition sets for each PNAD using LLM. 
+ + Args: + possible_atoms_per_pnad: List of sets of possible precondition atoms, one set per PNAD + pnads: Optional list of PNADs to get effect information from + k: Number of condition sets to propose per PNAD + batch_size: Maximum number of PNADs to process in each LLM call + + Returns: + List of lists of condition sets, where each condition set is a set of atoms + """ + if not possible_atoms_per_pnad or self._llm is None: + return [[set(atoms)] for atoms in possible_atoms_per_pnad] + + if k is None: + k = CFG.process_learner_llm_propose_conditions_k + + # If batch_size is not specified or if we have fewer PNADs than the limit, + # process all at once (original behavior) + if (batch_size is None or len(possible_atoms_per_pnad) <= batch_size): + return self._llm_propose_condition_sets_batch( + possible_atoms_per_pnad, pnads, k) + + # Otherwise, process in batches + all_condition_sets = [] + num_pnads = len(possible_atoms_per_pnad) + + for start_idx in range(0, num_pnads, batch_size): + end_idx = min(start_idx + batch_size, num_pnads) + + # Extract batch data + batch_atoms = possible_atoms_per_pnad[start_idx:end_idx] + batch_pnads = pnads[start_idx:end_idx] if pnads else None + + # Process this batch + batch_condition_sets = self._llm_propose_condition_sets_batch( + batch_atoms, batch_pnads, k, batch_idx=start_idx // batch_size) + + all_condition_sets.extend(batch_condition_sets) + + return all_condition_sets + + def _llm_propose_condition_sets_batch( + self, + possible_atoms_per_pnad: List[Set[LiftedAtom]], + pnads: Optional[List[PNAD]] = None, + k: Optional[int] = None, + batch_idx: Optional[int] = None) -> List[List[Set[LiftedAtom]]]: + """Process a batch of PNADs for condition set proposal.""" + try: + # Build process descriptions + process_descriptions = self._build_process_descriptions( + possible_atoms_per_pnad, pnads) + + # Extract unique predicates from all candidate atoms + all_predicates = set() + for poss_atoms in possible_atoms_per_pnad: + for atom in 
poss_atoms: + all_predicates.add(atom.predicate) + + # Create predicate listing string + predicate_listing = "\n".join( + predicate.pretty_str_with_assertion() + for predicate in sorted(all_predicates, key=lambda p: p.name)) + + # Call LLM with template + template_path = ( + utils.get_path_to_predicators_root() + + "/predicators/nsrt_learning/strips_learning/" + + "llm_op_learning_prompts/condition_set_proposal.prompt") + all_descriptions = "\n".join( + [desc for desc, _ in process_descriptions]) + template_vars = { + "PROCESS_EFFECTS_AND_CANDIDATES": all_descriptions, + "PREDICATE_LISTING": predicate_listing, + # "K": k + } + response = self._call_llm_with_template( + template_path, template_vars, + "condition_set_proposal_response_"\ + f"{self.online_learning_cycle}_{batch_idx}.txt") + + # Parse the response + answer_text = self._parse_llm_answer_block(response) + if not answer_text: + logging.warning("LLM failed to provide properly formatted " + "answer for condition set proposal") + return [[set(atoms)] for atoms in possible_atoms_per_pnad] + lines = [ + line.strip() for line in answer_text.split('\n') + if line.strip() + ] + + # Parse condition sets for each process + condition_sets_per_pnad = [] + for i, (_, sorted_atoms) in enumerate(process_descriptions): + # Find lines for this process + process_sets = [] + process_found = False + + for line in lines: + if line.startswith(f"Process {i}:"): + process_found = True + continue + elif process_found and line.startswith("Process "): + # Start of next process, break + break + elif process_found and line.startswith("Set "): + # Parse set line: "Set 1: [2,0,4]" + try: + set_part = line.split(":", 1)[1].strip() + # Remove brackets and split by comma + set_part = set_part.strip("[]") + if set_part: + indices = [ + int(idx.strip()) + for idx in set_part.split(',') + ] + # Filter valid indices and get corresponding atoms + valid_indices = [ + idx for idx in indices + if 0 <= idx < len(sorted_atoms) + ] + if 
valid_indices: + condition_set = { + sorted_atoms[idx] + for idx in valid_indices + } + process_sets.append(condition_set) + except (ValueError, IndexError) as e: + logging.warning( + f"Failed to parse condition set for process {i}: {e}" + ) + + if not process_sets: + # No valid sets found, use original atoms as single set + process_sets.append(set(sorted_atoms)) + + condition_sets_per_pnad.append(process_sets) + + # Log the results + for i, sets in enumerate(condition_sets_per_pnad): + if pnads: + logging.debug(f"Process {i}: {pformat(pnads[i])}\n" + f"Proposed {len(sets)} condition sets") + else: + logging.debug( + f"Process {i}: Proposed {len(sets)} condition sets") + for j, condition_set in enumerate(sets): + logging.debug( + f" Set {j+1}: {sorted(condition_set, key=str)}") + + return condition_sets_per_pnad + + except Exception as e: + logging.warning( + f"LLM condition set proposal failed: {e}, using original atoms" + ) + return [[set(atoms)] for atoms in possible_atoms_per_pnad] + + def _filter_pnad_parameters( + self, pnads: List[PNAD], + possible_atoms_per_pnad: List[Set[LiftedAtom]], + condition_sets_per_pnad: Optional[List[List[Set[LiftedAtom]]]] + ) -> List[PNAD]: + """Filter PNAD parameters to only include variables used in + preconditions or effects.""" + filtered_pnads: List[PNAD] = [] + for i, (pnad, + poss_atoms) in enumerate(zip(pnads, possible_atoms_per_pnad)): + if condition_sets_per_pnad is not None: + poss_atoms = set.union(*condition_sets_per_pnad[i]) + eff_atoms = pnad.op.add_effects | pnad.op.delete_effects + used_vars = { + v + for atom in (poss_atoms | eff_atoms) for v in atom.variables + } + if not used_vars: + filtered_pnads.append(pnad) + continue + new_params = [p for p in pnad.op.parameters if p in used_vars] + if list(pnad.op.parameters) == new_params: + filtered_pnads.append(pnad) + continue + new_op = pnad.op.copy_with(parameters=new_params) + filtered_pnads.append( + PNAD(new_op, pnad.datastore, pnad.option_spec)) + return 
filtered_pnads + + def _calculate_candidate_limits( + self, possible_atoms_per_pnad: List[Set[LiftedAtom]], + condition_sets_per_pnad: Optional[List[List[Set[LiftedAtom]]]], + cpu_cnt: int) -> int: + """Calculate optimal candidate limits per PNAD to utilize available + CPUs.""" + max_candidates_per_pnad = [ + 2**len(possible_atoms) + for possible_atoms in possible_atoms_per_pnad + ] + if condition_sets_per_pnad is not None: + max_candidates_per_pnad = [ + len(condition_sets) + for condition_sets in condition_sets_per_pnad + ] + max_candidates_across_pnads = min(max(max_candidates_per_pnad), + cpu_cnt) + min_candidates_to_keep = 1 + + for i in range(max_candidates_across_pnads, 0, -1): + total_candidates = sum( + [min(num, i) for num in max_candidates_per_pnad]) + if total_candidates <= cpu_cnt: + logging.info( + f"Setting candidate cap per PNAD to {i} to utilize {cpu_cnt} CPUs " + f"(total candidates: {total_candidates}).") + min_candidates_to_keep = i + break + return min_candidates_to_keep + + def _generate_final_candidates_with_pruning( + self, pnads: List[PNAD], + possible_atoms_per_pnad: List[Set[LiftedAtom]], + condition_sets_per_pnad: Optional[List[List[Set[LiftedAtom]]]], + min_candidates_to_keep: int) -> Dict[int, List[Set[LiftedAtom]]]: + """Generate final candidates with optional false positive pruning.""" + final_candidates_for_pnad: Dict[int, List[Set[LiftedAtom]]] = {} + indexed_pnads = {i: p for i, p in enumerate(pnads)} + + fp_count_pruning = ( + CFG.process_scoring_method == 'data_likelihood' + and CFG.process_condition_search_prune_with_fp_count and not CFG. 
+ cluster_and_search_process_learner_llm_propose_top_conditions) + + def _initial_lifted_atoms_for_index(idx: int, + p: PNAD) -> Set[LiftedAtom]: + if CFG.exogenous_process_learner_do_intersect: + return possible_atoms_per_pnad[idx] + init_ground_atoms = p.datastore[0][0].init_atoms + var_to_obj = p.datastore[0][1] + obj_to_var = {v: k for k, v in var_to_obj.items()} + return {atom.lift(obj_to_var) for atom in init_ground_atoms} + + for i, pnad in indexed_pnads.items(): + initial_lift_atoms = _initial_lifted_atoms_for_index(i, pnad) + + if (condition_sets_per_pnad is not None + and i < len(condition_sets_per_pnad)): + all_candidates = condition_sets_per_pnad[i] + else: + all_candidates = list(utils.all_subsets(initial_lift_atoms)) + + if not all_candidates: + final_candidates_for_pnad[i] = [] + continue + + if fp_count_pruning: + pruned_candidates = self._prune_candidates_with_fp_count( + pnad, all_candidates, min_candidates_to_keep, i) + final_candidates_for_pnad[i] = pruned_candidates + else: + final_candidates_for_pnad[ + i] = all_candidates[:min_candidates_to_keep] + + return final_candidates_for_pnad + + def _prune_candidates_with_fp_count( + self, pnad: PNAD, all_candidates: List[Set[LiftedAtom]], + min_candidates_to_keep: int, + pnad_idx: int) -> List[Set[LiftedAtom]]: + """Prune candidates using false positive count metric.""" + base_process = pnad.make_exogenous_process() + logging.debug( + f"Pruning {len(all_candidates)} candidates for PNAD {pnad_idx}:\n{base_process}" + ) + if CFG.use_wandb: + wandb.log({ + "pruning_info": + f"Pruning {len(all_candidates)} candidates for PNAD {pnad_idx}", + "base_process": str(base_process) + }) + + candidates_with_approx_scores = [] + for candidate in all_candidates: + base_process.condition_at_start = candidate + base_process.condition_overall = candidate + complexity_penalty = ( + CFG.process_condition_search_complexity_weight * + len(candidate)) + false_positive_states = 
self._get_false_positive_states_from_seg_trajs( + self._atom_change_segmented_trajs, [base_process]) + num_false_positives = sum( + len(s) for s in false_positive_states.values()) + cost = num_false_positives + complexity_penalty + candidates_with_approx_scores.append((cost, candidate)) + + candidates_with_approx_scores.sort(key=lambda x: x[0]) + top_candidates = self._get_top_candidates( + candidates_with_approx_scores, + percentage=0, + number=min_candidates_to_keep) + pruned_candidates = [cand for _, cand in top_candidates] + + logging.debug( + f"Pruned to {len(pruned_candidates)} candidates for PNAD {pnad_idx}." + ) + if CFG.use_wandb: + wandb.log({ + "pruned_candidates_count": len(pruned_candidates), + "pnad_id": pnad_idx + }) + + return pruned_candidates + + def _create_scoring_work_items( + self, pnads: List[PNAD], + final_candidates_for_pnad: Dict[int, + List[Set[LiftedAtom]]]) -> List: + """Create work items for parallel scoring.""" + load_dir, save_dir = None, None + if (self.online_learning_cycle is not None + and CFG.process_learning_init_at_previous_results): + load_save_dir = os.path.join(CFG.approach_dir, + utils.get_config_path_str()) + load_dir = os.path.join( + load_save_dir, f"online_cycle_{self.online_learning_cycle-1}") + save_dir = os.path.join( + load_save_dir, f"online_cycle_{self.online_learning_cycle}") + + indexed_pnads = {i: p for i, p in enumerate(pnads)} + work_items = [] + + for i, pnad in indexed_pnads.items(): + base_process = pnad.make_exogenous_process() + for condition_idx, condition in enumerate( + final_candidates_for_pnad[i]): + item = (i, condition_idx, copy.deepcopy(base_process), + condition, self._trajectories, self._predicates, + CFG.seed, CFG.cluster_and_search_vi_steps, + CFG.process_condition_search_complexity_weight, + load_dir, save_dir, + CFG.process_param_learning_patience) + work_items.append(item) + + return work_items + + def _process_scoring_results( + self, results: List, + final_candidates_for_pnad: Dict[int, 
List[Set[LiftedAtom]]], + pnads: List[PNAD]) -> Dict[int, FrozenSet[LiftedAtom]]: + """Process parallel scoring results and select best conditions.""" + indexed_pnads = {i: p for i, p in enumerate(pnads)} + pnad_scores: Dict[int, + List[Tuple[float, FrozenSet[LiftedAtom], Tuple[float, + ...], + ExogenousProcess]]] = defaultdict(list) + + for pnad_idx, condition_idx, cost, _, scores_tuple, process in results: + original_condition = final_candidates_for_pnad[pnad_idx][ + condition_idx] + process.condition_at_start = original_condition.copy() + process.condition_overall = original_condition.copy() + pnad_scores[pnad_idx].append( + (cost, frozenset(original_condition), scores_tuple, process)) + + best_conditions: Dict[int, FrozenSet[LiftedAtom]] = {} + for pnad_idx, scored_conditions in pnad_scores.items(): + scored_conditions.sort(key=lambda x: x[0]) + self.proc_name_to_results[ + indexed_pnads[pnad_idx].op.name] = scored_conditions + + self._log_scored_conditions(pnad_idx, scored_conditions, + indexed_pnads[pnad_idx]) + best_condition = self._select_best_condition( + pnad_idx, scored_conditions, indexed_pnads[pnad_idx]) + best_conditions[pnad_idx] = best_condition + logging.info(f"Selected best condition {best_condition}") + + return best_conditions + + def _log_scored_conditions(self, pnad_idx: int, scored_conditions: List, + pnad: PNAD) -> None: + """Log the scored conditions for debugging.""" + logging.debug( + f"Scored conditions for Process sketch {pnad_idx}:\n{pnad.make_exogenous_process()}" + ) + if CFG.use_wandb: + wandb.log({ + f"process_sketch_{pnad_idx}": + str(pnad.make_exogenous_process()) + }) + + for rank, result in enumerate(scored_conditions): + cost, condition_candidate, scores, process = result + process_param_str = ", ".join( + [f"{v:.4f}" for v in process._get_parameters()]) + logging.debug(f"Conditions {rank}: " + f"{sorted(condition_candidate)}, " + f"Cost: {cost}, " + f"ELBO: {scores[0]:.4f}, " + f"Exp_state_prob: {scores[1]:.4f}, " + 
f"Exp_delay_prob: {scores[2]:.4f}, " + f"Entropy: {scores[3]:.4f}, " + f"Process params: {process_param_str}") + + def _select_best_condition(self, pnad_idx: int, scored_conditions: List, + pnad: PNAD) -> FrozenSet[LiftedAtom]: + """Select the best condition from scored candidates.""" + multiple_top_conditions = False + best_ll = scored_conditions[0][2][0] + num_top_conditions = len( + list( + itertools.takewhile(lambda x: x[2][0] == best_ll, + scored_conditions))) + if num_top_conditions > 1: + multiple_top_conditions = True + + if (CFG.cluster_and_search_process_learner_llm_select_condition + and multiple_top_conditions): + best_condition = self._prompt_llm_to_select_from_top_conditions( + pnad, scored_conditions[:num_top_conditions]) + else: + _, best_condition, _, _ = scored_conditions[0] + + return best_condition # type: ignore[return-value] + + def _construct_final_pnads(self, + best_conditions: Dict[int, + FrozenSet[LiftedAtom]], + pnads: List[PNAD]) -> List[PNAD]: + """Construct the final unique PNADs with learned preconditions.""" + indexed_pnads = {i: p for i, p in enumerate(pnads)} + final_pnads: List[PNAD] = [] + + for pnad_idx in sorted(best_conditions.keys()): + cond_at_start = best_conditions[pnad_idx] + base_pnad = indexed_pnads[pnad_idx] + add_eff = base_pnad.op.add_effects + del_eff = base_pnad.op.delete_effects + new_params = { + v + for atom in cond_at_start | add_eff | del_eff + for v in atom.variables + } + + if self._is_unique_pnad(cond_at_start, base_pnad, final_pnads): + final_pnads.append( + PNAD( + base_pnad.op.copy_with(preconditions=cond_at_start, + parameters=new_params), + base_pnad.datastore, base_pnad.option_spec)) + + return final_pnads + + def _is_unique_pnad(self, precon: FrozenSet[LiftedAtom], pnad: PNAD, + final_pnads: List[PNAD]) -> bool: + """Check if a PNAD with given preconditions is unique.""" + for final_pnad in final_pnads: + # Quick size checks first for efficiency + if (len(precon) != len(final_pnad.op.preconditions) 
or + len(pnad.op.add_effects) != len(final_pnad.op.add_effects) + or len(pnad.op.delete_effects) != len( + final_pnad.op.delete_effects)): + continue + + suc, _ = utils.unify_preconds_effects_options( + frozenset(precon), + frozenset(final_pnad.op.preconditions), + frozenset(pnad.op.add_effects), + frozenset(final_pnad.op.add_effects), + frozenset(pnad.op.delete_effects), + frozenset(final_pnad.op.delete_effects), + pnad.option_spec[0], + final_pnad.option_spec[0], + tuple(pnad.option_spec[1]), + tuple(final_pnad.option_spec[1]), + ) + if suc: + return False + return True + + def _prompt_llm_to_select_from_top_conditions( + self, pnad: PNAD, scored_conditions: List[Tuple[float, + FrozenSet[LiftedAtom], + Tuple, CausalProcess]] + ) -> Set[LiftedAtom]: + """Use the LLM to select the best condition from the top scored + conditions for a PNAD.""" + assert self._llm is not None + # 1. Load the prompt template. + prompt_file = utils.get_path_to_predicators_root() + \ + "/predicators/nsrt_learning/strips_learning/" + \ + "llm_op_learning_prompts/"+\ + "cluster_and_search_process_learner_condition_select.prompt" + with open(prompt_file, "r") as f: + self.template = f.read() + + # 2. Fill the prompt template. + prompt = self.template.format( + EXOGENOUS_PROCESS_SKETCH=\ + pnad.make_exogenous_process()._str_wo_params, + TOP_SCORING_CONDITIONS="\n".join( + f"Conditions {i}: {sorted(condition)}" + for i, (_, condition, _, _) in enumerate(scored_conditions) + ) + ) + + # 3. Prompt the LLM. + response = self._llm.sample_completions(prompt, + imgs=None, + temperature=0, + seed=CFG.seed)[0] + + # Save the prompt and response for debugging + with open(f"{CFG.log_file}/pnad_{pnad.op.name}_cond_select.txt", + "w") as f: + f.write(f"{prompt}\n=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*=*\n" + f"{response}") + + # 4. Parse the answer. 
+        indices_str = re.findall(r"<answer>(.*?)</answer>", response) + if indices_str: + try: + selected_idx = int(indices_str[0].strip()) + if 0 <= selected_idx < len(scored_conditions): + # The condition is the second element of the tuple. + _, best_condition, _, _ = scored_conditions[selected_idx] + return set(best_condition) + except (ValueError, IndexError): + # If parsing fails or index is out of bounds, fall back. + logging.warning("LLM response parsing failed or index out of " + "bounds.") + + # Fallback: if LLM fails to produce a valid choice, pick the best one. + logging.warning("LLM failed to select a condition, picking the best.") + _, best_condition, _, _ = scored_conditions[0] + return set(best_condition) + + +class ClusterAndInversePlanningProcessLearner(ClusteringProcessLearner): + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + from predicators.predicate_search_score_functions import \ + _ExpectedNodesScoreFunction + self._get_optimality_prob =\ + _ExpectedNodesScoreFunction._get_refinement_prob + + self._option_change_segmented_trajs: List[List[Segment]] = [] + self._demo_atoms_sequences: List[List[Set[LiftedAtom]]] = [] + self._total_num_candidates = 0 + + @classmethod + def get_name(cls) -> str: + return "cluster_and_inverse_planning" + + def _learn_pnad_preconditions(self, pnads: List[PNAD]) -> List[PNAD]: + """Find the set of PNADs (with corresponding processes) that allows the + agent to make similar plans as the demonstrated/successful plans.""" + + self._total_num_candidates = 0 + # --- Existing exogenous processes --- + exogenous_process = [pnad.make_exogenous_process() for pnad in pnads] + + # Get the segmented trajectories for scoring the processes.
+ initial_segmenter_method = CFG.segmenter + CFG.segmenter = "atom_changes" + self._atom_change_segmented_trajs = [ + segment_trajectory(traj, self._predicates, verbose=False) + for traj in self._trajectories + ] + CFG.segmenter = "option_changes" + self._option_change_segmented_trajs = [ + segment_trajectory(traj, self._predicates, verbose=False) + for traj in self._trajectories + ] + CFG.segmenter = initial_segmenter_method + self._demo_atoms_sequences = [ + utils.segment_trajectory_to_atoms_sequence( + seg_traj) # type: ignore[misc] + for seg_traj in self._option_change_segmented_trajs + ] + # for i, seg_traj in enumerate(self._atom_change_segmented_trajs): + # logging.info(f"atom change trajectory {i}: {pformat(seg_traj)}") + + # --- Get the candidate preconditions --- + # First option. Candidates are all possible subsets. + conditions_at_start = [] + for pnad in pnads: + if CFG.exogenous_process_learner_do_intersect: + init_lift_atoms = self._induce_preconditions_via_intersection( + pnad) + else: + init_ground_atoms = pnad.datastore[0][0].init_atoms + var_to_obj = pnad.datastore[0][1] + obj_to_var = {v: k for k, v in var_to_obj.items()} + init_lift_atoms = set( + atom.lift(obj_to_var) for atom in init_ground_atoms) + + if CFG.cluster_and_inverse_planning_candidates == "all": + # 4 PNADS, with 7, 6, 7, 8 init atoms, possible combinations are + # - 2^7 * 2^6 * 2^7 * 2^8 = 2^28 = 268,435,456 + # - 2^10 * 2^10 * 2^10 * 2^10 = 2^40 = 1,099,511,627,776 + # Get the initial conditions of the PNAD + conditions_at_start.append(utils.all_subsets(init_lift_atoms)) + elif CFG.cluster_and_inverse_planning_candidates == "top_consistent": + conditions_at_start.append( + self._get_top_consistent_conditions( + init_lift_atoms, pnad, + CFG.cluster_and_inverse_planning_top_consistent_method, + CFG.seed)) + else: + raise NotImplementedError + + # --- Search for the best combination of preconditions --- + best_cost = float("inf") + best_conditions = [] + # Score all combinations of 
preconditions + for i, combination in enumerate( + itertools.product(*conditions_at_start)): + # Set the conditions for each process + for process, conditions in zip(exogenous_process, combination): + process.condition_at_start = conditions + process.condition_overall = conditions + + # Score this set of processes + cost = self.compute_processes_score(set(exogenous_process)) + if cost < best_cost: + best_cost = cost + best_conditions = combination + logging.debug( + f"Combination {i+1}/{self._total_num_candidates}: cost = {cost}," + f" Best cost = {best_cost}") + + # --- Create new PNADs with the best conditions --- + final_pnads: List[PNAD] = [] + for pnad, conditions in zip(pnads, best_conditions): + # Check if this PNAD is unique + for final_pnad in final_pnads: + suc, _ = utils.unify_preconds_effects_options( + frozenset(conditions), + frozenset(final_pnad.op.preconditions), + frozenset(pnad.op.add_effects), + frozenset(final_pnad.op.add_effects), + frozenset(pnad.op.delete_effects), + frozenset(final_pnad.op.delete_effects), + pnad.option_spec[0], + final_pnad.option_spec[0], + tuple(pnad.option_spec[1]), + tuple(final_pnad.option_spec[1]), + ) + if suc: + # TODO: merge datastores if they are the same + break + else: + # If we reach here, it means the PNAD is unique + # and we can add it to the final list + new_pnad = PNAD(pnad.op.copy_with(preconditions=conditions), + pnad.datastore, pnad.option_spec) + final_pnads.append(new_pnad) + return final_pnads + + def compute_processes_score( + self, exogenous_processes: Set[ExogenousProcess]) -> float: + """Score the PNAD based on how well it allows the agent to make + plans.""" + # TODO: also incorporate number of nodes expanded to the function + cost = 0.0 + for i, traj in enumerate(self._trajectories): + if not traj.is_demo: + continue + demo_atoms_sequence = self._demo_atoms_sequences[i] + task = self._train_tasks[traj.train_task_idx] + generator = task_plan_with_processes( + task, + self._predicates, + 
exogenous_processes | self._endogenous_processes, + CFG.seed, + CFG.grammar_search_task_planning_timeout, + # max_skeletons_optimized=CFG.sesame_max_skeletons_optimized, + max_skeletons_optimized=1, + use_visited_state_set=True) + + optimality_prob = 0.0 + num_nodes = CFG.grammar_search_expected_nodes_upper_bound + try: + for idx, (_, plan_atoms_sequence, + metrics) in enumerate(generator): + num_nodes = metrics["num_nodes_created"] + optimality_prob = self._get_optimality_prob( + demo_atoms_sequence, + plan_atoms_sequence) # type: ignore[arg-type] + except (PlanningTimeout, PlanningFailure): + pass + # low_quality_prob = 1.0 - optimality_prob + cost += (1 - optimality_prob) # * num_nodes + + return cost + + class ClusterAndSearchSTRIPSLearner(ClusteringSTRIPSLearner): """A clustering STRIPS learner that learns preconditions via search, following the LOFT algorithm: https://arxiv.org/abs/2103.00589.""" diff --git a/predicators/nsrt_learning/strips_learning/llm_op_learning_prompts/atom_ranking.prompt b/predicators/nsrt_learning/strips_learning/llm_op_learning_prompts/atom_ranking.prompt new file mode 100644 index 0000000000..6b6444c0f1 --- /dev/null +++ b/predicators/nsrt_learning/strips_learning/llm_op_learning_prompts/atom_ranking.prompt @@ -0,0 +1,28 @@ +You are an expert in automated planning and causal reasoning. Your task is to rank candidate atoms (predicates) by their likelihood of being necessary preconditions for specific process effects to occur. + +Given a process with specific add effects and delete effects, you need to evaluate which candidate atoms are most likely to be essential preconditions that must be true for the process to successfully achieve its effects. + +Key principles for ranking: +1. **Causal Relevance**: Atoms that are causally necessary for the effects to occur should be ranked higher +2. **Physical Constraints**: Atoms representing physical constraints or requirements should be prioritized +3. 
**Domain Knowledge**: Use common sense about how processes work in the real world +4. **Robot Independence**: Atoms involving robots as arguments are typically NOT necessary for exogenous processes +5. **State Dependencies**: Atoms that represent prerequisite states for the effects should be ranked higher + +For each process, I will provide: +- Add effects: What the process makes true +- Delete effects: What the process makes false +- Candidate atoms: Potential precondition atoms to rank + +{PROCESS_EFFECTS_AND_CANDIDATES} + +Please rank the candidate atoms for each process from most relevant to least relevant. Provide your ranking as a comma-separated list of atom indices (0-indexed), where 0 corresponds to the first atom in the candidate list, 1 to the second, etc. + +Format your response as: + +Process 0: 2,0,4,1,3 +Process 1: 1,3,0,2 +... + + +Only include the atom indices that you believe are actually necessary - you can exclude atoms you think are irrelevant by not including their indices in the ranking. diff --git a/predicators/nsrt_learning/strips_learning/llm_op_learning_prompts/cluster_and_search_process_learner_condition_select.prompt b/predicators/nsrt_learning/strips_learning/llm_op_learning_prompts/cluster_and_search_process_learner_condition_select.prompt new file mode 100644 index 0000000000..c8f1ae7680 --- /dev/null +++ b/predicators/nsrt_learning/strips_learning/llm_op_learning_prompts/cluster_and_search_process_learner_condition_select.prompt @@ -0,0 +1,10 @@ +You are an AI planning expert tasked to select the appropriate "conditions at start" for some processes, similar to how they are defined in PDDL2.1 or PDDL+. + +For each process, the conditions listed below got the same data likelihood. But due to limited data, the highest-likelihood or the simplest conditions might not be the best choice. So we want you to incorporate your world knowledge and reasoning to choose the most suitable "Conditions at start" among the top-scoring ones. 
+ +{EXOGENOUS_PROCESS_SKETCH} + +{TOP_SCORING_CONDITIONS} + +Select the *index* of the condition you judge to be the most suitable. Provide your final answer in a `<answer>...</answer>` tag. +You should think through the reasoning internally, but only output the final answer. \ No newline at end of file diff --git a/predicators/nsrt_learning/strips_learning/llm_op_learning_prompts/condition_selection.prompt b/predicators/nsrt_learning/strips_learning/llm_op_learning_prompts/condition_selection.prompt new file mode 100644 index 0000000000..561bc7db47 --- /dev/null +++ b/predicators/nsrt_learning/strips_learning/llm_op_learning_prompts/condition_selection.prompt @@ -0,0 +1,21 @@ +We are identifying environmental conditions that are necessary for certain effects of some processes to occur. + +For each effect, please select the subset of atoms that you believe are necessary for the corresponding process to take place. As a rule of thumb, atoms involving a robot as an argument are never necessary conditions for exogenous processes. + +For example, +- For wet clothes to dry quickly outdoors, the clothes must be on the outdoor dryer and the weather must be sunny. However, it is not necessary for the robot to be outside or for the laundry basket to be outdoors. + +For a computer to complete running a program, the computer must be powered on and the program must remain active. However, it is not necessary for the robot to be seated nearby, or for a cup next to it to be filled with water. + +{EFFECTS_AND_CONDITIONS} + +Please structure your output in the following format, with one block for each effect (note that the angle brackets here are just for clarifying the syntax; do not output angle brackets in your responses): + +``` +Add effects: (and (<atom1>) (<atom2>) ...) +Delete effects: (and (<atom1>) (<atom2>) ...) +Conditions: (and (<atom1>) (<atom2>) ...) + +Add effects: (and (<atom1>) (<atom2>) ...) +Delete effects: (and (<atom1>) (<atom2>) ...) +Conditions: (and (<atom1>) (<atom2>) ...)
+``` \ No newline at end of file diff --git a/predicators/nsrt_learning/strips_learning/llm_op_learning_prompts/condition_set_proposal.prompt b/predicators/nsrt_learning/strips_learning/llm_op_learning_prompts/condition_set_proposal.prompt new file mode 100644 index 0000000000..bcc4e3e27e --- /dev/null +++ b/predicators/nsrt_learning/strips_learning/llm_op_learning_prompts/condition_set_proposal.prompt @@ -0,0 +1,34 @@ +You are an expert in automated planning and causal reasoning. Your task is to propose the most likely sets of conditions for specific process effects to occur. + +Given a process with specific add effects and delete effects, you need to propose multiple coherent sets of candidate atoms that could serve as necessary conditions for the process to successfully achieve its effects. + +Key principles for proposing condition sets: +1. **Causal Relevance**: Each set should contain atoms that are causally necessary for the effects to occur +2. **Physical Constraints**: Include atoms representing physical constraints or requirements +3. **Domain Knowledge**: Use common sense about how processes work in the real world +4. **State Dependencies**: Include atoms that represent prerequisite states for the effects +5. Terminal-Progress Exclusion: If an add effect is a terminal/complete state within a progression family, do not include any intermediate/progress predicates from the same family as preconditions (e.g., Partially*, Started, InProgress, HasSome). + +Available predicates in the candidate atoms are: +{PREDICATE_LISTING} + +For each process, I will provide: +- Add effects: What the process makes true +- Delete effects: What the process makes false +- Candidate atoms: Potential precondition atoms to choose from + +{PROCESS_EFFECTS_AND_CANDIDATES} + +Please propose as many likely condition sets as you deem suitable for each process. Each condition set should be a coherent combination of atom indices that together form a plausible set of preconditions. 
It's possible that there is a large number of atoms in a condition set in some cases. + +Think step by step if it’s helpful before outputting your final response, formatted strictly as: + +Process 0: +Set 1: [2, 0, 4] +Set 2: [1, 3, 0, 5] +Set 3: [2, 4] +Process 1: +Set 1: [1, 3] +Set 2: [0, 1, 3] +... + \ No newline at end of file diff --git a/predicators/nsrt_learning/strips_learning/llm_strips_learner.py b/predicators/nsrt_learning/strips_learning/llm_strips_learner.py index 81a8209573..62b5f32de3 100644 --- a/predicators/nsrt_learning/strips_learning/llm_strips_learner.py +++ b/predicators/nsrt_learning/strips_learning/llm_strips_learner.py @@ -1,6 +1,7 @@ """Approaches that use an LLM to learn STRIPS operators instead of performing symbolic learning of any kind.""" +import logging import re from typing import Any, Dict, List, Optional, Set, Tuple @@ -84,7 +85,7 @@ def _parse_operator_str_into_structured_elems( closing_paren_loc = name_and_args.find(")") name_str = name_and_args[:opening_paren_loc] arg_str = name_and_args[opening_paren_loc + 1:closing_paren_loc] - args = arg_str.split() + args = arg_str.replace(",", "").split() # remove commas arg_dict = {} for i in range(0, len(args), 3): arg_name = args[i] @@ -119,8 +120,13 @@ def _convert_structured_precs_or_effs_into_lifted_atom_set( prec_arg_vars.append(op_var_name_to_op_var[prec_arg]) if not all_args_valid: continue # pragma: no cover - ret_atoms.add( - LiftedAtom(pred_name_to_pred[prec_name], prec_arg_vars)) + try: + ret_atoms.add( + LiftedAtom(pred_name_to_pred[prec_name], prec_arg_vars)) + except: + # This can happen if the predicate is not valid for the + # given types. We just ignore it. 
+ pass return ret_atoms # NOTE: we actually do test this function, but the many sub-cases diff --git a/predicators/planning.py b/predicators/planning.py index d25b3f2fd7..df79766dd2 100644 --- a/predicators/planning.py +++ b/predicators/planning.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from itertools import islice from typing import Any, Collection, Dict, FrozenSet, Iterator, List, \ - Optional, Sequence, Set, Tuple + Optional, Sequence, Set, Tuple, Union import numpy as np @@ -25,10 +25,10 @@ from predicators.option_model import _OptionModelBase from predicators.refinement_estimators import BaseRefinementEstimator from predicators.settings import CFG -from predicators.structs import NSRT, AbstractPolicy, DefaultState, \ - DummyOption, GroundAtom, Metrics, Object, OptionSpec, \ +from predicators.structs import NSRT, AbstractPolicy, CausalProcess, \ + DefaultState, DummyOption, GroundAtom, Metrics, Object, OptionSpec, \ ParameterizedOption, Predicate, State, STRIPSOperator, Task, Type, \ - _GroundNSRT, _GroundSTRIPSOperator, _Option + _GroundCausalProcess, _GroundNSRT, _GroundSTRIPSOperator, _Option from predicators.utils import EnvironmentFailure, _TaskPlanningHeuristic _NOT_CAUSES_FAILURE = "NotCausesFailure" @@ -59,7 +59,7 @@ def sesame_plan( max_policy_guided_rollout: int = 0, refinement_estimator: Optional[BaseRefinementEstimator] = None, check_dr_reachable: bool = True, - allow_noops: bool = False, + allow_waits: bool = False, use_visited_state_set: bool = False ) -> Tuple[List[_Option], List[_GroundNSRT], Metrics]: """Run bilevel planning. 
@@ -77,7 +77,7 @@ def sesame_plan( task, option_model, nsrts, predicates, types, timeout, seed, task_planning_heuristic, max_skeletons_optimized, max_horizon, abstract_policy, max_policy_guided_rollout, refinement_estimator, - check_dr_reachable, allow_noops, use_visited_state_set) + check_dr_reachable, allow_waits, use_visited_state_set) if CFG.sesame_task_planner == "fdopt": assert abstract_policy is None return _sesame_plan_with_fast_downward(task, @@ -119,11 +119,12 @@ def _sesame_plan_with_astar( max_policy_guided_rollout: int = 0, refinement_estimator: Optional[BaseRefinementEstimator] = None, check_dr_reachable: bool = True, - allow_noops: bool = False, + allow_waits: bool = False, use_visited_state_set: bool = False ) -> Tuple[List[_Option], List[_GroundNSRT], Metrics]: """The default version of SeSamE, which runs A* to produce skeletons.""" init_atoms = utils.abstract(task.init, predicates) + logging.debug(f"Initial atoms: {init_atoms}") objects = list(task.init) start_time = time.perf_counter() ground_nsrts = sesame_ground_nsrts(task, init_atoms, nsrts, objects, @@ -142,7 +143,7 @@ def _sesame_plan_with_astar( # that we need to do this inside the while True here, because an NSRT # that initially has empty effects may later have a _NOT_CAUSES_FAILURE. 
reachable_nsrts = filter_nsrts(task, init_atoms, ground_nsrts, - check_dr_reachable, allow_noops) + check_dr_reachable, allow_waits) heuristic = utils.create_task_planning_heuristic( task_planning_heuristic, init_atoms, task.goal, reachable_nsrts, predicates, objects) @@ -170,6 +171,8 @@ def _sesame_plan_with_astar( key=lambda s: estimator.get_cost(task, *s))) refinement_start_time = time.perf_counter() for skeleton, atoms_sequence in gen: + logging.debug( + f"Found skeleton: {[n.short_str for n in skeleton]}") if CFG.sesame_use_necessary_atoms: atoms_seq = utils.compute_necessary_atoms_seq( skeleton, atoms_sequence, task.goal) @@ -194,6 +197,9 @@ def _sesame_plan_with_astar( return plan, skeleton, metrics partial_refinements.append((skeleton, plan)) if time.perf_counter() - start_time > timeout: + logging.debug("Exiting search due to timeout.") + logging.debug( + f"Partial refinements: {partial_refinements}") raise PlanningTimeout( "Planning timed out in refinement!", info={"partial_refinements": partial_refinements}) @@ -247,13 +253,13 @@ def filter_nsrts( init_atoms: Set[GroundAtom], ground_nsrts: List[_GroundNSRT], check_dr_reachable: bool = True, - allow_noops: bool = False, + allow_waits: bool = False, ) -> List[_GroundNSRT]: """Helper function for _sesame_plan_with_astar(); optionally filter out NSRTs with empty effects and/or those that are unreachable.""" nonempty_ground_nsrts = [ nsrt for nsrt in ground_nsrts - if allow_noops or (nsrt.add_effects | nsrt.delete_effects) + if allow_waits or (nsrt.add_effects | nsrt.delete_effects) ] all_reachable_atoms = utils.get_reachable_atoms(nonempty_ground_nsrts, init_atoms) @@ -269,9 +275,10 @@ def filter_nsrts( def task_plan_grounding( init_atoms: Set[GroundAtom], objects: Set[Object], - nsrts: Collection[NSRT], - allow_noops: bool = False, -) -> Tuple[List[_GroundNSRT], Set[GroundAtom]]: + nsrts: Collection[Union[NSRT, CausalProcess]], + allow_waits: bool = False, + compute_reachable_atoms: bool = True, +) -> 
Tuple[List[Union[_GroundNSRT, _GroundCausalProcess]], Set[GroundAtom]]: """Ground all operators for task planning into dummy _GroundNSRTs, filtering out ones that are unreachable or have empty effects. @@ -283,15 +290,22 @@ def task_plan_grounding( ground_nsrts = [] for nsrt in sorted(nsrts): for ground_nsrt in utils.all_ground_nsrts(nsrt, objects): - if allow_noops or (ground_nsrt.add_effects + if allow_waits or (ground_nsrt.add_effects | ground_nsrt.delete_effects): ground_nsrts.append(ground_nsrt) - reachable_atoms = utils.get_reachable_atoms(ground_nsrts, init_atoms) - reachable_nsrts = [ - nsrt for nsrt in ground_nsrts - if nsrt.preconditions.issubset(reachable_atoms) - ] - return reachable_nsrts, reachable_atoms + if compute_reachable_atoms: + reachable_atoms = utils.get_reachable_atoms(ground_nsrts, init_atoms) + else: + reachable_atoms = set() + + if CFG.planning_filter_unreachable_nsrt: + reachable_nsrts = [ + nsrt for nsrt in ground_nsrts + if nsrt.preconditions.issubset(reachable_atoms) + ] + else: + reachable_nsrts = ground_nsrts + return reachable_nsrts, reachable_atoms # type: ignore[return-value] def task_plan( @@ -460,12 +474,14 @@ def _skeleton_generator( # Generate primitive successors. 
for nsrt in utils.get_applicable_operators(ground_nsrts, node.atoms): - child_atoms = utils.apply_operator(nsrt, set(node.atoms)) + child_atoms = utils.apply_operator(nsrt, set( + node.atoms)) # type: ignore[type-var] if use_visited_state_set: frozen_atoms = frozenset(child_atoms) if frozen_atoms in visited_atom_sets: continue - child_skeleton = node.skeleton + [nsrt] + child_skeleton = node.skeleton + [nsrt + ] # type: ignore[list-item] child_skeleton_tup = tuple(child_skeleton) if child_skeleton_tup in visited_skeletons: # pragma: no cover continue @@ -543,6 +559,7 @@ def run_low_level_search( plan_found = False while cur_idx < len(skeleton): if time.perf_counter() - start_time > timeout: + logging.debug("Exiting low-level search due to timeout.") return longest_failed_refinement, False assert num_tries[cur_idx] < max_tries[cur_idx] try_start_time = time.perf_counter() @@ -562,9 +579,11 @@ def run_low_level_search( cur_idx += 1 if option.initiable(state): try: + logging.info(f"Running option {option}") next_state, num_actions = \ option_model.get_next_state_and_num_actions(state, option) except EnvironmentFailure as e: + logging.debug(f"Discovered a failure: {e}") can_continue_on = False # Remember only the most recent failure. discovered_failures[cur_idx - 1] = _DiscoveredFailure(e, nsrt) @@ -585,12 +604,20 @@ def run_low_level_search( static_obj_changed = True break if static_obj_changed: + logging.debug("Cannot continue: static object changed.") can_continue_on = False - # Check if we have exceeded the horizon. + # Check if we have exceeded the horizon in total. elif np.sum(num_actions_per_option[:cur_idx]) > max_horizon: + logging.debug("Cannot continue: exceeded total horizon.") + can_continue_on = False + # Check if we have exceeded the horizon individually. + elif num_actions >= CFG.max_num_steps_option_rollout: + logging.debug("Cannot continue: exceeded individual " + "horizon.") can_continue_on = False # Check if the option was effectively a noop. 
elif num_actions == 0: + logging.debug("Cannot continue: an noop") can_continue_on = False elif CFG.sesame_check_expected_atoms: # Check atoms against expected atoms_sequence constraint. @@ -611,6 +638,8 @@ def run_low_level_search( if cur_idx == len(skeleton): plan_found = True else: + logging.debug("Cannot continue: expected atoms not " + "hold.") can_continue_on = False else: # If we're not checking expected_atoms, we need to @@ -623,6 +652,7 @@ def run_low_level_search( can_continue_on = False else: # The option is not initiable. + logging.debug("Cannot continue: option not initiable.") can_continue_on = False if refinement_time is not None: try_end_time = time.perf_counter() @@ -668,6 +698,7 @@ def run_low_level_search( longest_failed_refinement }) return longest_failed_refinement, False + logging.debug("Option succeed!") # Should only get here if the skeleton was empty. assert not skeleton return [], True @@ -900,10 +931,10 @@ def task_plan_with_option_plan_constraint( ground_nsrts, _ = task_plan_grounding(init_atoms, objects, dummy_nsrts, - allow_noops=True) + allow_waits=True) heuristic = utils.create_task_planning_heuristic( CFG.sesame_task_planning_heuristic, init_atoms, goal, ground_nsrts, - predicates, objects) + predicates, objects) # type: ignore[type-var] def _check_goal( searchnode_state: Tuple[FrozenSet[GroundAtom], int]) -> bool: @@ -921,26 +952,30 @@ def _get_successor_with_correct_option( gt_param_option = option_plan[idx_into_traj][0] gt_objects = option_plan[idx_into_traj][1] - for applicable_nsrt in utils.get_applicable_operators( + for applicable_nsrt in utils.get_applicable_operators( # type: ignore[type-var] ground_nsrts, atoms): # NOTE: we check that the ParameterizedOptions are equal before # attempting to ground because otherwise, we might # get a parameter mismatch and trigger an AssertionError # during grounding. 
- if applicable_nsrt.option != gt_param_option: + if applicable_nsrt.option != gt_param_option: # type: ignore[attr-defined] continue - if applicable_nsrt.option_objs != gt_objects: + if applicable_nsrt.option_objs != gt_objects: # type: ignore[attr-defined] continue if atoms_seq is not None and not \ - applicable_nsrt.preconditions.issubset( + applicable_nsrt.preconditions.issubset( # type: ignore[attr-defined] atoms_seq[idx_into_traj]): continue - next_atoms = utils.apply_operator(applicable_nsrt, set(atoms)) + next_atoms = utils.apply_operator( + applicable_nsrt, set(atoms)) # type: ignore[type-var] # The returned cost is uniform because we don't # actually care about finding the shortest path; # just one that matches! - yield (applicable_nsrt, (frozenset(next_atoms), idx_into_traj + 1), - 1.0) + yield ( + applicable_nsrt, + (frozenset(next_atoms), + idx_into_traj + 1), # type: ignore[misc] + 1.0) init_atoms_frozen = frozenset(init_atoms) init_searchnode_state = (init_atoms_frozen, 0) @@ -1204,20 +1239,21 @@ def run_task_plan_once( assert task_planning_heuristic is not None heuristic = utils.create_task_planning_heuristic( task_planning_heuristic, init_atoms, goal, ground_nsrts, preds, - objects) + objects) # type: ignore[type-var] duration = time.perf_counter() - start_time timeout -= duration plan, atoms_seq, metrics = next( - task_plan(init_atoms, - goal, - ground_nsrts, - reachable_atoms, - heuristic, - seed, - timeout, - max_skeletons_optimized=1, - use_visited_state_set=True, - **kwargs)) + task_plan( + init_atoms, + goal, + ground_nsrts, # type: ignore[arg-type] + reachable_atoms, + heuristic, + seed, + timeout, + max_skeletons_optimized=1, + use_visited_state_set=True, + **kwargs)) if len(plan) > max_horizon: raise PlanningFailure( "Skeleton produced by A-star exceeds horizon!") diff --git a/predicators/planning_with_processes.py b/predicators/planning_with_processes.py new file mode 100644 index 0000000000..b080f6a40e --- /dev/null +++ 
b/predicators/planning_with_processes.py @@ -0,0 +1,1595 @@ +from __future__ import annotations + +import heapq as hq +import logging +import sys +import time +from collections import defaultdict +from copy import deepcopy +from dataclasses import dataclass +from itertools import islice +from pprint import pformat +from typing import Callable, Collection, Dict, Iterator, List, Optional, Set, \ + Tuple + +import numpy as np + +from predicators import utils +from predicators.option_model import _OptionModelBase +from predicators.planning import PlanningFailure, PlanningTimeout, \ + _DiscoveredFailureException, _MaxSkeletonsFailure, \ + _SkeletonSearchTimeout, run_low_level_search +from predicators.settings import CFG +from predicators.structs import AbstractProcessPolicy, CausalProcess, \ + DefaultState, DerivedPredicate, EndogenousProcess, GroundAtom, Metrics, \ + Object, Predicate, Task, Type, _GroundCausalProcess, \ + _GroundEndogenousProcess, _GroundExogenousProcess, _Option +from predicators.utils import _TaskPlanningHeuristic + + +def _build_exogenous_process_index( + ground_processes: List[_GroundCausalProcess], +) -> Dict[Predicate, List[_GroundExogenousProcess]]: + """Build index mapping predicates to exogenous processes that have those + predicates in their condition_at_start. + + This helps efficiently find which exogenous processes might be + triggered when new facts become true. 
+ """ + precondition_to_exogenous_processes: Dict[ + Predicate, List[_GroundExogenousProcess]] = defaultdict(list) + for p in ground_processes: + if isinstance(p, _GroundExogenousProcess): + for atom in p.condition_at_start: + precondition_to_exogenous_processes[atom.predicate].append(p) + return precondition_to_exogenous_processes + + +def get_reachable_atoms_from_processes( + ground_processes: List[_GroundCausalProcess], + atoms: Set[GroundAtom], + derived_predicates: Set[DerivedPredicate] = set(), + objects: Set[Object] = set(), +) -> Set[GroundAtom]: + """Get all atoms that are reachable from the init atoms using ground + processes. + + This function builds a relaxed planning graph by applying exogenous processes + and derived predicates similar to when building the relaxed planning graph + in the ff_heuristic. + + Args: + ground_processes: List of grounded causal processes + atoms: Initial set of atoms + derived_predicates: Set of derived predicates to consider + objects: Set of objects for derived predicate evaluation + + Returns: + Set of all reachable atoms + """ + # Pre-compute dependencies for incremental derived predicates + dep_to_derived_preds: Dict[Predicate, + List[DerivedPredicate]] = defaultdict(list) + if derived_predicates: + for der_pred in derived_predicates: + if der_pred.auxiliary_predicates is not None: + for aux_pred in der_pred.auxiliary_predicates: + dep_to_derived_preds[aux_pred].append(der_pred) + + # Initialize with input atoms and any initial derived facts + reachable_atoms = atoms.copy() + if derived_predicates: + reachable_atoms.update( + utils.abstract_with_derived_predicates(reachable_atoms, + derived_predicates, + objects)) + + # Build relaxed planning graph until fixed point + while True: + fixed_point_reached = True + previous_atoms = reachable_atoms.copy() + + # Apply all applicable ground processes + newly_added_primitive_facts = set() + for process in ground_processes: + if 
process.condition_at_start.issubset(reachable_atoms): + # Add effects that aren't already reachable + new_effects = process.add_effects - reachable_atoms + if new_effects: + fixed_point_reached = False + newly_added_primitive_facts.update(new_effects) + reachable_atoms.update(new_effects) + + # Handle derived predicates incrementally if we added new primitive facts + if newly_added_primitive_facts and derived_predicates: + newly_derived_facts = _run_incremental_derived_predicate_logic( + newly_added_primitive_facts, + previous_atoms, + objects, + dep_to_derived_preds, + ) + if newly_derived_facts: + fixed_point_reached = False + reachable_atoms.update(newly_derived_facts) + + if fixed_point_reached: + break + + return reachable_atoms + + +def process_task_plan_grounding( + init_atoms: Set[GroundAtom], + objects: Set[Object], + cps: Collection[CausalProcess], + allow_waits: bool = True, + compute_reachable_atoms: bool = False, + derived_predicates: Set[DerivedPredicate] = set(), +) -> Tuple[List[_GroundCausalProcess], Set[GroundAtom]]: + """Ground all operators for task planning into dummy _GroundNSRTs, + filtering out ones that are unreachable or have empty effects. + + Also return the set of reachable atoms, which is used by task + planning to quickly determine if a goal is unreachable. + + See the task_plan docstring for usage instructions. 
+ """ + ground_cps = [] + for cp in sorted(cps): + for ground_cp in utils.all_ground_nsrts(cp, objects): + if allow_waits or (ground_cp.add_effects + | ground_cp.delete_effects): + ground_cps.append(ground_cp) + if compute_reachable_atoms: + reachable_atoms = get_reachable_atoms_from_processes( + ground_cps, init_atoms, derived_predicates, + objects) # type: ignore[arg-type] + else: + reachable_atoms = set() + + reachable_nsrts = ground_cps + return reachable_nsrts, reachable_atoms # type: ignore[return-value] + + +@dataclass(repr=False, eq=False) +class _ProcessPlanningNode(): + """ + Args: + state_history: a finegrained, per-step history of the state trajectory + compared to atoms_sequence which is segmented by action. + action_history: a finegrained, per-step history of the action trajectory + compared to skeleton which is segmented by action. + """ + atoms: Set[GroundAtom] # per big step state + skeleton: List[_GroundEndogenousProcess] # per big step action + atoms_sequence: List[Set[GroundAtom]] # expected state sequence + parent: Optional[_ProcessPlanningNode] + cumulative_cost: float + state_history: List[Set[GroundAtom]] # per small step state + action_history: List[ + Optional[_GroundEndogenousProcess]] # per small step action + scheduled_events: Dict[int, List[Tuple[_GroundCausalProcess, int]]] + + +class ProcessWorldModel: + + def __init__( + self, + ground_processes: List[_GroundCausalProcess], + state: Set[GroundAtom], + state_history: List[Set[GroundAtom]] = [], + action_history: List[Optional[_GroundEndogenousProcess]] = [], + scheduled_events: Dict[int, List[Tuple[_GroundCausalProcess, + int]]] = {}, + t: int = 0, + derived_predicates: Set[DerivedPredicate] = set(), + objects: Set[Object] = set(), + precondition_to_exogenous_processes: Optional[Dict[ + Predicate, List[_GroundExogenousProcess]]] = None, + dep_to_derived_preds: Optional[Dict[Predicate, + List[DerivedPredicate]]] = None + ) -> None: + + self.ground_processes = ground_processes + 
self.state = state + self.state_history = state_history + self.current_action: Optional[_GroundEndogenousProcess] = None + self.action_history = action_history + self.scheduled_events: Dict[int, List[Tuple[_GroundCausalProcess, + int]]] = scheduled_events + self.t = t + self.derived_predicates = derived_predicates + self.objects = objects + + # --- Use provided indexes or build them if not provided --- + self._precondition_to_exogenous_processes: Dict[ + Predicate, List[_GroundExogenousProcess]] + if precondition_to_exogenous_processes is not None: + self._precondition_to_exogenous_processes = precondition_to_exogenous_processes + elif CFG.build_exogenous_process_index_for_planning: + # Fallback: build the index if not provided and CFG allows it + self._precondition_to_exogenous_processes = _build_exogenous_process_index( + self.ground_processes) + else: + # Don't build the index + self._precondition_to_exogenous_processes = defaultdict(list) + + self._dep_to_derived_preds: Dict[Predicate, List[DerivedPredicate]] + if dep_to_derived_preds is not None: + self._dep_to_derived_preds = dep_to_derived_preds + else: + # Fallback: build the index if not provided + self._dep_to_derived_preds = defaultdict(list) + for der_pred in self.derived_predicates: + for aux_pred in der_pred.auxiliary_predicates: # type: ignore[union-attr] + self._dep_to_derived_preds[aux_pred].append(der_pred) + + def small_step( + self, + small_step_action: Optional[_GroundEndogenousProcess] = None + ) -> None: + """Will keep the current action as a class variable for now, as opposed + to a part of the state variable as in the demo code.""" + # 1. self.current_action is set to an action when this small_step is + # first called. And is set back to None when `duration` timesteps + # sampled from its distribution passes. + # `small_step_action` is not None in the first call but becomes None in + # subsequent calls. 
+ if small_step_action is not None: + self.current_action = small_step_action.copy() + self.action_history.append(self.current_action.copy() if self. + current_action is not None else None) + + # 2. Process effects scheduled for this timestep. + if self.t in self.scheduled_events: + primitive_facts_before = { + a + for a in self.state + if not isinstance(a.predicate, DerivedPredicate) + } + + for g_process, start_time in self.scheduled_events[self.t]: + if (all( + g_process.condition_overall.issubset(s) + for s in self.state_history[start_time + 1:]) + and g_process.condition_at_end.issubset(self.state)): + for atom in g_process.delete_effects: + self.state.discard(atom) + for atom in g_process.add_effects: + self.state.add(atom) + if isinstance(g_process, _GroundEndogenousProcess) and\ + small_step_action is None: + self.current_action = None + del self.scheduled_events[self.t] + + if len(self.derived_predicates) > 0: + primitive_facts_after = { + a + for a in self.state + if not isinstance(a.predicate, DerivedPredicate) + } + + # Only update if the primitive facts have changed. + if primitive_facts_before != primitive_facts_after: + deleted_facts = primitive_facts_before - primitive_facts_after + + # If any primitive fact was deleted, a full re-computation + # is the safest way to ensure correctness. + if deleted_facts: + # Remove all old derived facts. + self.state = { + atom + for atom in self.state + if not isinstance(atom.predicate, DerivedPredicate) + } + # Re-compute all derived facts from the new state. + self.state |= utils.abstract_with_derived_predicates( + self.state, self.derived_predicates, self.objects) + + # Otherwise, only additions occurred; we can be incremental. + else: + added_facts = primitive_facts_after - primitive_facts_before + # `existing_facts` includes primitive and derived facts + # before the new additions. 
+ existing_facts_before_increment = self.state - added_facts + newly_derived_facts = _run_incremental_derived_predicate_logic( + added_facts, existing_facts_before_increment, + self.objects, self._dep_to_derived_preds) + self.state.update(newly_derived_facts) + + # 3. Schedule new events whose conditions are met. + # 3a. Handle the endogenous process (action) passed to this step. + # This is for starting a new action. + if (small_step_action is not None + and small_step_action.parent.option.name != 'Wait' + and # type: ignore[attr-defined] + small_step_action.condition_at_start.issubset(self.state)): + delay = small_step_action.delay_distribution.sample() + delay = max(1, delay) + scheduled_time = self.t + delay + if scheduled_time not in self.scheduled_events: + self.scheduled_events[scheduled_time] = [] + self.scheduled_events[scheduled_time].append( + (small_step_action, self.t)) + + # 3b. Handle exogenous processes. + if CFG.build_exogenous_process_index_for_planning: + # Use the index for efficiency. + # Find newly true primitive facts by comparing current vs. previous. + previous_facts = self.state_history[-1] if self.state_history \ + else set() + newly_added_facts = self.state - previous_facts + + # Gather all candidate processes touched by these new facts. + candidate_processes_to_check: Set[_GroundExogenousProcess] = set() + for fact in newly_added_facts: + candidate_processes_to_check.update( + self._precondition_to_exogenous_processes[fact.predicate]) + + # Check the full preconditions for only the candidate processes. 
+ for g_process in candidate_processes_to_check: + if g_process.condition_at_start.issubset(self.state): + delay = g_process.delay_distribution.sample() + delay = max(1, delay) + scheduled_time = self.t + delay + if scheduled_time not in self.scheduled_events: + self.scheduled_events[scheduled_time] = [] + self.scheduled_events[scheduled_time].append( + (g_process, self.t)) + else: + # Fallback: check all exogenous processes (less efficient) + for g_process in self.ground_processes: + if isinstance(g_process, _GroundExogenousProcess): + first_state_or_prev_state_doesnt_satisfy = ( + len(self.state_history) == 0 + or not g_process.condition_at_start.issubset( + self.state_history[-1])) + if g_process.condition_at_start.issubset(self.state) and\ + first_state_or_prev_state_doesnt_satisfy: + delay = g_process.delay_distribution.sample() + delay = max(1, delay) + scheduled_time = self.t + delay + if scheduled_time not in self.scheduled_events: + self.scheduled_events[scheduled_time] = [] + self.scheduled_events[scheduled_time].append( + (g_process, self.t)) + + # --- END MODIFIED --- + + self.state_history.append(self.state.copy()) + + # if the action has finished and set to None. 
+ if self.current_action is None: + return + self.t += 1 + + def big_step(self, + action_process: _GroundEndogenousProcess, + max_num_steps: int = 50) -> Set[GroundAtom]: + """current_action is set to an action in the first call to small_step + and is set to None when 1) the action reaches the end of its duration + 2) some aspects of the state changes; removing this because this can + cause action to stop before the end of its duration 3) reaches + max_num_steps.""" + initial_state = self.state.copy() + num_steps = 0 + action_not_finished = True + + while action_not_finished and num_steps < max_num_steps: + self.small_step(action_process) + num_steps += 1 + + if action_process is not None: + action_process = None # type: ignore[assignment] + + action_not_finished = self.current_action is not None + + # if currently executing Wait and state has changed, then break + if (self.current_action is not None + and self.current_action.parent.option.name == 'Wait' + and # type: ignore[attr-defined] + self.state != initial_state): + break + return self.state + + +def _skeleton_generator_with_processes( + task: Task, + ground_processes: List[_GroundCausalProcess], + init_atoms: Set[GroundAtom], + heuristic: _TaskPlanningHeuristic, + seed: int, + timeout: float, + metrics: Metrics, + max_skeletons_optimized: int, + abstract_policy: Optional[AbstractProcessPolicy] = None, + sesame_max_policy_guided_rollout: int = 0, + use_visited_state_set: bool = False, + log_sucessful_small_steps: bool = False, + log_heuristic: bool = False, + time_heuristic: bool = True, + heuristic_weight: float = 10, + derived_predicates: Set[DerivedPredicate] = set(), + objects: Set[Object] = set(), +) -> Iterator[Tuple[List[_GroundEndogenousProcess], List[Set[GroundAtom]]]]: + + # Filter out all the action from processes + # zero heuristic + objects = objects.copy() + + # --- Build indexes once for all ProcessWorldModel instances --- + # Index for efficient scheduling of exogenous processes + 
precondition_to_exogenous_processes: Optional[Dict[ + Predicate, List[_GroundExogenousProcess]]] = None + if CFG.build_exogenous_process_index_for_planning: + precondition_to_exogenous_processes = _build_exogenous_process_index( + ground_processes) + + # Pre-compute dependencies for incremental derived predicates + dep_to_derived_preds: Dict[Predicate, + List[DerivedPredicate]] = defaultdict(list) + for der_pred in derived_predicates: + for aux_pred in der_pred.auxiliary_predicates: # type: ignore[union-attr] + dep_to_derived_preds[aux_pred].append(der_pred) + # --- End index building --- + ground_action_processes = [ + p for p in ground_processes if isinstance(p, _GroundEndogenousProcess) + ] + start_time = time.perf_counter() + queue: List[Tuple[float, float, _ProcessPlanningNode]] = [] + root_node = _ProcessPlanningNode( + atoms=init_atoms, + skeleton=[], + atoms_sequence=[init_atoms], + parent=None, + cumulative_cost=0, + state_history=[], + action_history=[], + scheduled_events={}, + ) + metrics["num_nodes_created"] += 1 + rng_prio = np.random.default_rng(seed) + if time_heuristic: + heuristic_call_count = 0 + total_heuristic_time = 0.0 + heuristic_start_time = time.perf_counter() + h = heuristic(root_node.atoms) * heuristic_weight + heuristic_end_time = time.perf_counter() + heuristic_call_count += 1 + total_heuristic_time += (heuristic_end_time - heuristic_start_time) + else: + h = heuristic(root_node.atoms) * heuristic_weight + if log_heuristic: + logging.debug(f"Root heuristic: {h}") + hq.heappush(queue, (h, rng_prio.uniform(), root_node)) + # Initialize with empty skeleton for root. + # We want to keep track of the visited skeletons so that we avoid + # repeatedly outputting the same faulty skeletons. 
+ visited_skeletons: Set[Tuple[_GroundCausalProcess, ...]] = set() + visited_skeletons.add(tuple(root_node.skeleton)) + if use_visited_state_set: + # This set will maintain (frozen) atom sets that have been fully + # expanded already, and ensure that we never expand redundantly. + visited_atom_sets = set() + # Start search. + while queue and (time.perf_counter() - start_time < timeout): + if int(metrics["num_skeletons_optimized"]) == max_skeletons_optimized: + raise _MaxSkeletonsFailure( + "Planning reached max_skeletons_optimized!") + _, _, node = hq.heappop(queue) + if use_visited_state_set: + frozen_atoms = frozenset(node.atoms) + visited_atom_sets.add(frozen_atoms) + # Good debug point #1: print out the skeleton here to see what + # the high-level search is doing. You can accomplish this via: + # for act in node.skeleton: + # logging.info(f"{act.name} {act.objects}") + # logging.info("") + if task.goal.issubset(node.atoms): + # If this skeleton satisfies the goal, yield it. + metrics["num_skeletons_optimized"] += 1 + time_taken = time.perf_counter() - start_time + logging.info(f"\n[Task Planner] Found Plan of length " + f"{len(node.skeleton)} in {time_taken:.2f}s:") + for process in node.skeleton: + logging.debug(process.name_and_objects_str()) + logging.debug("") + + if log_sucessful_small_steps: + prev_state: Optional[Set[GroundAtom]] = None + for i, (state, action) in enumerate( + zip(node.state_history, node.action_history)): + if i == 0: + logging.debug(f"State {i}: {sorted(state)}") + else: + assert prev_state is not None + logging.debug( + f"State {i}: " + f"Add atoms: {sorted(state - prev_state)} " + f"Del atoms: {sorted(prev_state - state)}") + action_str = action.name_and_objects_str() \ + if action is not None else None + logging.info(f"Action {i}: {action_str}\n") + prev_state = state + if prev_state is not None: + logging.debug( + f"State {len(node.state_history)}: " + f"Add atoms: " + f"{sorted(node.state_history[-1] - prev_state)} " + f"Del 
atoms: " + f"{sorted(prev_state - node.state_history[-1])}") + + # Log heuristic timing stats when a solution is found + if time_heuristic: + average_heuristic_time = total_heuristic_time / heuristic_call_count if heuristic_call_count > 0 else 0.0 + logging.debug( + f"Heuristic timing stats - Calls: {heuristic_call_count}, Total time: {total_heuristic_time:.4f}s, Average time: {average_heuristic_time:.4f}s" + ) + + yield node.skeleton, node.atoms_sequence + else: + # Generate successors. + metrics["num_nodes_expanded"] += 1 + # If an abstract policy is provided, generate policy-based + # successors first. + if abstract_policy is not None: + current_node = node + for _ in range(sesame_max_policy_guided_rollout): + if task.goal.issubset(current_node.atoms): + yield current_node.skeleton, current_node.atoms_sequence + break + ground_process = abstract_policy(current_node.atoms, + objects, task.goal) + if ground_process is None: + break + # Make sure ground_process is applicable and is an endogenous process + if not isinstance(ground_process, + _GroundEndogenousProcess): + break # type: ignore[unreachable] + if not ground_process.condition_at_start.issubset( + current_node.atoms): + break + + # Run the process through the world model to get the resulting state + world_model = ProcessWorldModel( + ground_processes=ground_processes.copy(), + state=current_node.atoms.copy(), + state_history=current_node.state_history.copy(), + action_history=current_node.action_history.copy(), + scheduled_events=deepcopy( + current_node.scheduled_events), + t=len(current_node.state_history), + derived_predicates=derived_predicates, + objects=objects, + precondition_to_exogenous_processes= + precondition_to_exogenous_processes, + dep_to_derived_preds=dep_to_derived_preds) + + world_model.big_step(ground_process) + child_atoms = world_model.state.copy() + + child_skeleton = current_node.skeleton + [ground_process] + child_skeleton_tup = tuple(child_skeleton) + if child_skeleton_tup in 
visited_skeletons: + continue + visited_skeletons.add(child_skeleton_tup) + # Note: the cost of taking a policy-generated action is 1, + # but the policy-generated skeleton is immediately yielded + # once it reaches a goal. This allows the planner to always + # trust the policy first, but it also allows us to yield a + # policy-generated plan without waiting to exhaustively + # rule out the possibility that some other primitive plans + # are actually lower cost. + child_cost = 1 + current_node.cumulative_cost + child_node = _ProcessPlanningNode( + atoms=child_atoms, + skeleton=child_skeleton, + atoms_sequence=current_node.atoms_sequence + + [child_atoms], + parent=current_node, + cumulative_cost=child_cost, + state_history=world_model.state_history.copy(), + action_history=world_model.action_history.copy(), + scheduled_events=deepcopy( + world_model.scheduled_events)) + metrics["num_nodes_created"] += 1 + # priority is g [cost] plus h [heuristic] + if time_heuristic: + heuristic_start_time = time.perf_counter() + h = heuristic(child_node.atoms) * heuristic_weight + heuristic_end_time = time.perf_counter() + heuristic_call_count += 1 + total_heuristic_time += (heuristic_end_time - + heuristic_start_time) + else: + h = heuristic(child_node.atoms) * heuristic_weight + priority = (child_node.cumulative_cost + h) + hq.heappush(queue, + (priority, rng_prio.uniform(), child_node)) + current_node = child_node + if time.perf_counter() - start_time >= timeout: + break + applicable_actions = list( + utils.get_applicable_operators(ground_action_processes, + node.atoms)) + + # Domain-specific pruning for domino environment + if CFG.env == "pybullet_domino_grid" and CFG.domino_prune_actions: + # Filter out backwards placements and redundant picks + filtered_actions = [] + placed_dominos = set() # Track which dominos have been placed + + # First pass: identify already placed dominos + for prev_action in node.skeleton: + if prev_action.parent.name == "PlaceDomino": + # The domino 
being placed is the second argument + if len(prev_action.objects) > 1: + placed_dominos.add(prev_action.objects[1]) + + for action in applicable_actions: # type: ignore[assignment] + # Always keep Wait and Push actions + if action.parent.name in ["Wait", "PushStartBlock" + ]: # type: ignore[union-attr] + filtered_actions.append(action) + # For Pick, only pick dominos that haven't been placed yet + elif action.parent.name == "PickDomino": # type: ignore[union-attr] + domino_to_pick = action.objects[ + 1] if len( # type: ignore[union-attr] + action.objects + ) > 1 else None # type: ignore[union-attr] + if domino_to_pick and domino_to_pick not in placed_dominos: + filtered_actions.append(action) + # For Place, apply heuristics + elif action.parent.name == "PlaceDomino": # type: ignore[union-attr] + # Keep all place actions for now, but could add more pruning + # E.g., only place in forward direction, avoid cycles, etc. + filtered_actions.append(action) + else: + filtered_actions.append(action) + + # If pruning removed all actions, fall back to unpruned + if filtered_actions: + applicable_actions = filtered_actions # type: ignore[assignment] + + for action_process in applicable_actions: + + # --- Run the action process on the world model + world_model = ProcessWorldModel( + ground_processes=ground_processes.copy(), + state=node.atoms.copy(), + state_history=node.state_history.copy(), + action_history=node.action_history.copy(), + scheduled_events=deepcopy(node.scheduled_events), + t=len(node.state_history), + derived_predicates=derived_predicates, + objects=objects, + precondition_to_exogenous_processes= + precondition_to_exogenous_processes, + dep_to_derived_preds=dep_to_derived_preds) + + assert isinstance(action_process, _GroundEndogenousProcess) + # plan_so_far = [p.name for p in node.skeleton] + # plan_so_far = [p.name_and_objects_str() for p in node.skeleton] + # logging.debug(f"Expand after plan {plan_so_far}:") + # applicable_actions = 
list(utils.get_applicable_operators( + # ground_action_processes, node.atoms)) + # num_applicable_actions = len(applicable_actions) + # logging.debug(f"Num applicable actions: {num_applicable_actions}") + # logging.debug(f"Taking action: {action_process.name_and_objects_str()}") + # action_names = [p.name_and_objects_str() for p in node.skeleton] + # # action_names = [p.name for p in node.skeleton] + # # target_action_names = ['PickJugFromOutsideFaucetAndBurner', + # # 'PlaceUnderFaucet', + # # 'SwitchFaucetOn', + # # 'SwitchBurnerOn', + # # 'SwitchFaucetOff', + # # 'PickJugFromFaucet', + # # 'PlaceOnBurner', + # # 'PickJugFromOutsideFaucetAndBurner', + # # 'PlaceUnderFaucet', + # # 'SwitchFaucetOn', + # # 'SwitchBurnerOn', + # # ] + # target_action_names = [ + # # Update with new location naming format: loc_x.xx_y.yy + # # 'PickDomino(robot:robot, domino_1:domino, loc_0.49_1.23:loc, ang_0:angle)', + # # 'PlaceDomino(robot:robot, domino_2:domino, domino_3:domino, pos_y0_x2:loc, rot_135:rot)', + # # 'PickDomino(robot:robot, domino_1:domino, pos_y0_x0:loc, rot_0:rot)', + # # 'PlaceDomino(robot:robot, domino_1:domino, domino_0:domino, pos_y1_x2:loc, rot_180:rot', + # ] + # # if action_names == target_action_names:# and \ + # # action_process.name_and_objects_str() == 'PlaceDomino(robot:robot, domino_1:domino, domino_0:domino, loc_x1_y0:loc, ang_-90:angle)': + # if False: # Update with actual action string when debugging + # # if action_names == target_action_names: + # breakpoint() + world_model.big_step(action_process) + child_atoms = world_model.state.copy() + # --- End + + # Same as standard skeleton generator + if use_visited_state_set: + frozen_atoms = frozenset(child_atoms) + if frozen_atoms in visited_atom_sets: + continue + child_skeleton = node.skeleton + [action_process] + child_skeleton_tup = tuple(child_skeleton) + if child_skeleton_tup in visited_skeletons: # pragma: no cover + continue + visited_skeletons.add(child_skeleton_tup) + # Action costs are 
unitary. + if action_process.option.name == 'Wait': + action_cost = 0.5 + else: + action_cost = 1.0 + child_cost = node.cumulative_cost + action_cost + child_node = _ProcessPlanningNode( + atoms=child_atoms, + skeleton=child_skeleton.copy(), + atoms_sequence=node.atoms_sequence + [child_atoms], + parent=node, + cumulative_cost=child_cost, + state_history=world_model.state_history.copy(), + action_history=world_model.action_history.copy(), + scheduled_events=deepcopy(world_model.scheduled_events)) + metrics["num_nodes_created"] += 1 + # priority is g [cost] plus h [heuristic] + if time_heuristic: + heuristic_start_time = time.perf_counter() + h = heuristic(child_node.atoms) * heuristic_weight + heuristic_end_time = time.perf_counter() + heuristic_call_count += 1 + total_heuristic_time += (heuristic_end_time - + heuristic_start_time) + else: + h = heuristic(child_node.atoms) * heuristic_weight + priority = (child_node.cumulative_cost + h) + if log_heuristic: + logging.debug( + f"Heuristic: {h}, g: {child_node.cumulative_cost}") + hq.heappush(queue, (priority, rng_prio.uniform(), child_node)) + if time.perf_counter() - start_time >= timeout: + logging.debug(f"Planning timeout of {timeout} reached.") + break + if time_heuristic: + average_heuristic_time = total_heuristic_time / heuristic_call_count if heuristic_call_count > 0 else 0.0 + logging.debug( + f"Heuristic timing stats - Calls: {heuristic_call_count}, " + f"Total time: {total_heuristic_time:.4f}s, " + f"Average time: {average_heuristic_time:.4f}s, " + f"Num_nodes_created: {metrics['num_nodes_created']}, " + f"Num_nodes_expanded: {metrics['num_nodes_expanded']}") + + if not queue: + raise _MaxSkeletonsFailure("Planning ran out of skeletons!") + assert time.perf_counter() - start_time >= timeout + raise _SkeletonSearchTimeout + + +def task_plan_from_task( + task: Task, + predicates: Collection[Predicate], + processes: Set[CausalProcess], + seed: int, + timeout: float, + max_skeletons_optimized: int, + 
use_visited_state_set: bool = True, + abstract_policy: Optional[AbstractProcessPolicy] = None, + max_policy_guided_rollout: int = 0, +) -> Iterator[Tuple[List[_GroundEndogenousProcess], List[Set[GroundAtom]], + Metrics]]: + predicates_set = set(predicates) + all_predicates = utils.add_in_auxiliary_predicates(predicates_set) + derived_predicates = utils.get_derived_predicates(all_predicates) + + init_atoms = utils.abstract(task.init, all_predicates) + logging.debug("[Task Planner] Task goal atoms: " + f"{pformat(sorted(task.goal))}") + logging.debug("[Task Planner] Task init atoms: " + f"{pformat(sorted(init_atoms))}") + goal = task.goal + objects = set(task.init) + ground_processes, reachable_atoms = process_task_plan_grounding( + init_atoms, + objects, + processes, + allow_waits=True, + compute_reachable_atoms=True, + derived_predicates=derived_predicates) + + if CFG.process_task_planning_heuristic == "goal_count": + heuristic = utils.create_task_planning_heuristic( # type: ignore[type-var] + CFG.process_task_planning_heuristic, init_atoms, goal, + ground_processes, all_predicates, objects) + elif CFG.process_task_planning_heuristic == "lm_cut": + heuristic = create_lm_cut_heuristic( # type: ignore[assignment] + goal, + ground_processes, + derived_predicates, + objects, + use_derived_predicates=CFG.use_derived_predicate_in_heuristic) + elif CFG.process_task_planning_heuristic == "h_max": + heuristic = create_h_max_heuristic( # type: ignore[assignment] + goal, + ground_processes, + derived_predicates, + objects, + use_derived_predicates=CFG.use_derived_predicate_in_heuristic) + + elif CFG.process_task_planning_heuristic == "h_ff": + heuristic = create_ff_heuristic( # type: ignore[assignment] + goal, + ground_processes, + derived_predicates, + objects, + use_derived_predicates=CFG.use_derived_predicate_in_heuristic) + else: + raise ValueError( + f"Unrecognized process_task_planning_heuristic: {CFG.process_task_planning_heuristic}" + ) + + return task_plan( + 
init_atoms, + goal, + ground_processes, + reachable_atoms, + heuristic, + seed, + timeout, + max_skeletons_optimized, + use_visited_state_set=use_visited_state_set, + derived_predicates=derived_predicates, + objects=objects, + abstract_policy=abstract_policy, + max_policy_guided_rollout=max_policy_guided_rollout, + ) + + +def task_plan( + init_atoms: Set[GroundAtom], + goal: Set[GroundAtom], + ground_processes: List[_GroundCausalProcess], + reachable_atoms: Set[GroundAtom], + heuristic: _TaskPlanningHeuristic, + seed: int, + timeout: float, + max_skeletons_optimized: int, + use_visited_state_set: bool = True, + derived_predicates: Set[DerivedPredicate] = set(), + objects: Set[Object] = set(), + abstract_policy: Optional[AbstractProcessPolicy] = None, + max_policy_guided_rollout: int = 0, +) -> Iterator[Tuple[List[_GroundEndogenousProcess], List[Set[GroundAtom]], + Metrics]]: + """Run only the task planning portion of SeSamE. A* search is run, and + skeletons that achieve the goal symbolically are yielded. Specifically, + yields a tuple of (skeleton, atoms sequence, metrics dictionary). + + This method is NOT used by SeSamE, but is instead provided as a + convenient wrapper around _skeleton_generator below (which IS used + by SeSamE) that takes in only the minimal necessary arguments. + + This method is tightly coupled with task_plan_grounding -- the reason they + are separate methods is that it is sometimes possible to ground only once + and then plan multiple times (e.g. from different initial states, or to + different goals). To run task planning once, call task_plan_grounding to + get ground_nsrts and reachable_atoms; then create a heuristic using + utils.create_task_planning_heuristic; then call this method. See the tests + in tests/test_planning for usage examples. + """ + if CFG.planning_check_dr_reachable and not goal.issubset(reachable_atoms): + logging.info(f"Detected goal unreachable. 
Goal: {goal}") + logging.info(f"Initial atoms: {init_atoms}") + raise PlanningFailure(f"Goal {goal} not dr-reachable") + dummy_task = Task(DefaultState, goal) + metrics: Metrics = defaultdict(float) + # logging.debug(f"init_atoms: {init_atoms}") + generator = _skeleton_generator_with_processes( + dummy_task, + ground_processes, + init_atoms, + heuristic, + seed, + timeout, + metrics, + max_skeletons_optimized, + abstract_policy=abstract_policy, + sesame_max_policy_guided_rollout=max_policy_guided_rollout, + use_visited_state_set=use_visited_state_set, + derived_predicates=derived_predicates, + objects=objects, + heuristic_weight=CFG.process_planning_heuristic_weight, + ) + + # Note that we use this pattern to avoid having to catch an exception + # when _skeleton_generator runs out of skeletons to optimize. + for skeleton, atoms_sequence in islice(generator, max_skeletons_optimized): + yield skeleton, atoms_sequence, metrics.copy() + + +def run_task_plan_with_processes_once( + task: Task, + processes: Set[CausalProcess], + preds: Set[Predicate], + types: Set[Type], + timeout: float, + seed: int, + task_planning_heuristic: str, + max_horizon: float = np.inf, + compute_reachable_atoms: bool = False, + abstract_policy: Optional[AbstractProcessPolicy] = None, + max_policy_guided_rollout: int = 0, +) -> Tuple[List[_GroundEndogenousProcess], List[Set[GroundAtom]], Metrics]: + """Get a single abstract plan for a task. + + The sequence of ground atom sets returned represent NECESSARY atoms. 
+ """ + + start_time = time.perf_counter() + + if CFG.sesame_task_planner == "astar": + duration = time.perf_counter() - start_time + timeout -= duration + plan, atoms_seq, metrics = next( + task_plan_from_task( + task, + preds, + processes, + seed, + timeout, + max_skeletons_optimized=1, + abstract_policy=abstract_policy, + max_policy_guided_rollout=max_policy_guided_rollout, + )) + if len(plan) > max_horizon: + raise PlanningFailure( + "Skeleton produced by A-star exceeds horizon!") + else: + raise ValueError("Unrecognized sesame_task_planner: " + f"{CFG.sesame_task_planner}") + + # comment out for now + # necessary_atoms_seq = utils.compute_necessary_atoms_seq( + # plan, atoms_seq, goal) + necessary_atoms_seq: List[Set[GroundAtom]] = [] + + return plan, necessary_atoms_seq, metrics + + +def sesame_plan_with_processes( + task: Task, + option_model: _OptionModelBase, + processes: Set[CausalProcess], + predicates: Set[Predicate], + timeout: float, + seed: int, + max_skeletons_optimized: int, + max_horizon: int, + abstract_policy: Optional[AbstractProcessPolicy] = None, + max_policy_guided_rollout: int = 0, +) -> Tuple[List[_Option], List[_GroundEndogenousProcess], Metrics]: + """Run bilevel planning with processes (SeSamE-style). + + Generates process skeletons via A* search and refines each with low- + level search (backtracking over continuous parameter samples). + Returns a sequence of options, the process skeleton, and metrics. + """ + start_time = time.perf_counter() + + gen = task_plan_from_task( + task, + predicates, + processes, + seed, + timeout - (time.perf_counter() - start_time), + max_skeletons_optimized, + abstract_policy=abstract_policy, + max_policy_guided_rollout=max_policy_guided_rollout, + ) + + partial_refinements: list = [] + metrics: Metrics = defaultdict(float) + refinement_start_time = time.perf_counter() + + for skeleton, atoms_sequence, skel_metrics in gen: + # Update metrics from skeleton generation. 
+ for k, v in skel_metrics.items(): + metrics[k] = v + + logging.debug(f"Found process skeleton: " + f"{[p.name_and_objects_str() for p in skeleton]}") + + try: + plan, suc = run_low_level_search( + task, + option_model, + skeleton, # type: ignore[arg-type] + atoms_sequence, + seed, + timeout - (time.perf_counter() - start_time), + metrics, + max_horizon) + except _DiscoveredFailureException: + # Process planning doesn't support failure discovery; + # treat as a failed skeleton. + suc = False + plan = [] + + if suc: + logging.info( + f"Process planning succeeded! Found plan of length " + f"{len(plan)} after " + f"{int(metrics['num_skeletons_optimized'])} " + f"skeletons with {int(metrics['num_samples'])} samples") + metrics["plan_length"] = len(plan) + metrics["refinement_time"] = (time.perf_counter() - + refinement_start_time) + return plan, skeleton, metrics + + partial_refinements.append((skeleton, plan)) + if time.perf_counter() - start_time > timeout: + raise PlanningTimeout( + "Process planning timed out in refinement!", + info={"partial_refinements": partial_refinements}) + + raise PlanningFailure("Process planning exhausted all skeletons!", + info={"partial_refinements": partial_refinements}) + + +def create_ff_heuristic( + goal: Set[GroundAtom], + ground_processes: List[_GroundCausalProcess], + derived_predicates: Set[DerivedPredicate] = set(), + objects: Set[Object] = set(), + use_derived_predicates: bool = True, + debug_log: bool = False, +) -> Callable[[Set[GroundAtom]], float]: + """Creates a callable FF heuristic function with efficient RPG + generation.""" + + adds_map: Dict[GroundAtom, List[_GroundCausalProcess]] = defaultdict(list) + for process in ground_processes: + for atom in process.add_effects: + adds_map[atom].append(process) + + # --- CHANGE START: Use pre-computation for the shared function --- + dep_to_derived_preds: Dict[Predicate, + List[DerivedPredicate]] = defaultdict(list) + if use_derived_predicates: + for der_pred in 
derived_predicates: + assert der_pred.auxiliary_predicates is not None, \ + f"Can't find auxiliary predicates for derived predicate " +\ + f"{der_pred.name}" + for aux_pred in der_pred.auxiliary_predicates: + dep_to_derived_preds[aux_pred].append(der_pred) + # --- CHANGE END --- + + def _ff_heuristic(atoms: Set[GroundAtom]) -> float: + """The FF heuristic using incremental RPG generation.""" + if goal.issubset(atoms): + return 0.0 + + # --- 1. Build the Relaxed Planning Graph (RPG) --- + initial_facts = atoms.copy() + if use_derived_predicates: + # The first layer must be a full, non-incremental computation. + initial_facts.update( + utils.abstract_with_derived_predicates(initial_facts, + derived_predicates, + objects)) + + fact_layers: List[Set[GroundAtom]] = [initial_facts] + process_layers: List[Set[_GroundCausalProcess]] = [] + + if debug_log: + count = 1 + logging.debug(f"Initial facts: {sorted(initial_facts)}") + while not goal.issubset(fact_layers[-1]): + if debug_log: + logging.debug(f"Applying actions {count}...") + count += 1 + current_facts = fact_layers[-1] + + # Find all processes whose preconditions are met in the current layer. + applicable_processes: Set[_GroundCausalProcess] = set() + for process in ground_processes: + if process.condition_at_start.issubset(current_facts): + applicable_processes.add(process) + + process_layers.append(applicable_processes) + + # --- Incremental Fact Generation --- + # a) Collect all new primitive facts from applicable processes. + primitive_add_effects = set() + for process in applicable_processes: + primitive_add_effects.update(process.add_effects) + + newly_added_primitive_facts = primitive_add_effects - current_facts + if debug_log: + logging.debug( + f"Newly added primitive facts: {sorted(newly_added_primitive_facts)}" + ) + + # b) Incrementally compute new derived facts. 
+                newly_derived_facts = set()
+                if use_derived_predicates:
+                    # --- CHANGE START: Call the shared function ---
+                    newly_derived_facts = _run_incremental_derived_predicate_logic(
+                        newly_added_primitive_facts,
+                        current_facts,
+                        objects,
+                        dep_to_derived_preds,
+                    )
+                    # --- CHANGE END ---
+                if debug_log:
+                    logging.debug(
+                        f"Newly derived facts: {sorted(newly_derived_facts)}\n"
+                    )
+
+                next_facts = current_facts | newly_added_primitive_facts | newly_derived_facts
+
+                # If the new layer is identical to the old one, we've stagnated.
+                if next_facts == current_facts:
+                    return float('inf')
+
+                fact_layers.append(next_facts)
+
+        # --- 2. Extract a Relaxed Plan (Backward Search through the RPG) ---
+        relaxed_plan_actions: Set[_GroundEndogenousProcess] = set()
+        subgoals_to_achieve = goal.copy()
+
+        for i in range(len(fact_layers) - 1, 0, -1):
+
+            if use_derived_predicates:
+                for subgoal in subgoals_to_achieve.copy():
+                    # Case 1: The subgoal is a DERIVED predicate.
+                    # It is achieved 'for free' by its supporting auxiliary predicates.
+                    if isinstance(subgoal.predicate, DerivedPredicate):
+                        # The new subgoals are the auxiliary predicates that support it.
+                        # In a relaxed plan, we conservatively add all atoms from the
+                        # previous layer that could be supporters.
+                        try:
+                            supporter_predicates =\
+                                utils.get_base_supporter_predicates(
+                                    subgoal.predicate)
+                        except Exception as e:
+                            logging.error(
+                                f"Error getting base supporter predicates for {subgoal.predicate}: {e}"
+                            )
+                            # Re-raise instead of dropping into the debugger:
+                            # breakpoint() hangs headless runs, and continuing
+                            # past it would hit an unbound-name error below.
+                            raise
+                        new_subgoals = {
+                            atom
+                            for atom in fact_layers[i - 1]
+                            if atom.predicate in supporter_predicates
+                        }
+
+                        subgoals_to_achieve.update(new_subgoals)
+                        subgoals_to_achieve.discard(subgoal)
+            if debug_log:
+                logging.debug(f"\nLayer {i} Subgoals to achieve: "
+                              f"{sorted(subgoals_to_achieve)}")
+
+            unachieved_subgoals = subgoals_to_achieve.copy()
+            for subgoal in unachieved_subgoals:
+                # If the subgoal appeared for the first time in this layer...
+ if subgoal in fact_layers[i] and subgoal not in fact_layers[i - + 1]: + + if debug_log: + logging.debug(f"Considering subgoal: {subgoal}") + + # Case 2: The subgoal is a PRIMITIVE predicate (original logic). + best_supporter = None + # Find a process from the previous layer that achieves it. + for process in adds_map.get(subgoal, []): + if process in process_layers[i - 1]: + if debug_log: + logging.debug( + f"Found supporter for {subgoal}: " + f"{process.name_and_objects_str()}") + best_supporter = process + break + + if best_supporter: + # Only agent actions (endogenous) contribute to the plan cost. + if isinstance(best_supporter, + _GroundEndogenousProcess): + relaxed_plan_actions.add(best_supporter) + + # Add the supporter's preconditions to our set of subgoals. + subgoals_to_achieve.update( + best_supporter.condition_at_start) + subgoals_to_achieve.discard(subgoal) + + return float(len(relaxed_plan_actions)) + + return _ff_heuristic + + +def create_lm_cut_heuristic( + goal: Set[GroundAtom], + ground_processes: List[_GroundCausalProcess], + derived_predicates: Set[DerivedPredicate] = set(), + objects: Set[Object] = set(), + use_derived_predicates: bool = True, +) -> Callable[[Set[GroundAtom]], float]: + """Creates a callable LM-cut heuristic function. + + This heuristic iteratively finds landmarks by computing a relaxed + plan, calculating its cost, and then assuming its effects have been + achieved before solving for the next landmark. This is a practical + implementation of the LM-cut principle. It also correctly handles + exogenous processes and derived predicates (axioms) as zero-cost + events. 
+ """ + + # --- Pre-computation to speed up sub-problems --- + adds_map: Dict[GroundAtom, List[_GroundCausalProcess]] = defaultdict(list) + for process in ground_processes: + for atom in process.add_effects: + adds_map[atom].append(process) + + # --- CHANGE START: Use pre-computation for the shared function --- + dep_to_derived_preds: Dict[Predicate, + List[DerivedPredicate]] = defaultdict(list) + if use_derived_predicates: + for der_pred in derived_predicates: + for aux_pred in der_pred.auxiliary_predicates: # type: ignore[union-attr] + dep_to_derived_preds[aux_pred].append(der_pred) + # --- CHANGE END --- + + def _calculate_relaxed_plan( + current_atoms: Set[GroundAtom], current_goal: Set[GroundAtom] + ) -> Tuple[float, Set[_GroundCausalProcess]]: + """Helper that computes one relaxed plan (our landmark) from a given + state.""" + initial_facts = current_atoms.copy() + if use_derived_predicates: + initial_facts.update( + utils.abstract_with_derived_predicates(initial_facts, + derived_predicates, + objects)) + + if current_goal.issubset(initial_facts): + return 0.0, set() + + fact_layers: List[Set[GroundAtom]] = [initial_facts] + process_layers: List[Set[_GroundCausalProcess]] = [] + + while not current_goal.issubset(fact_layers[-1]): + current_facts = fact_layers[-1] + + applicable_processes: Set[_GroundCausalProcess] = set() + for process in ground_processes: + if process.condition_at_start.issubset(current_facts): + applicable_processes.add(process) + + process_layers.append(applicable_processes) + + primitive_add_effects = set() + for process in applicable_processes: + primitive_add_effects.update(process.add_effects) + newly_added_primitive_facts = primitive_add_effects - current_facts + + newly_derived_facts = set() + if use_derived_predicates: + # --- CHANGE START: Call the shared function --- + newly_derived_facts = _run_incremental_derived_predicate_logic( + newly_added_primitive_facts, + current_facts, + objects, + dep_to_derived_preds, + ) + # --- 
CHANGE END --- + + next_facts = current_facts | newly_added_primitive_facts | newly_derived_facts + + if next_facts == current_facts: + return float('inf'), set() + + fact_layers.append(next_facts) + + # 2. Extract one relaxed plan via backward search. + relaxed_plan: Set[_GroundCausalProcess] = set() + subgoals_to_achieve = current_goal.copy() + + for i in range(len(fact_layers) - 1, 0, -1): + for subgoal in subgoals_to_achieve.copy(): + if subgoal in fact_layers[i] and subgoal not in fact_layers[i - + 1]: + best_supporter = None + for process in adds_map.get(subgoal, []): + if process in process_layers[i - 1]: + best_supporter = process + break + + if best_supporter: + relaxed_plan.add(best_supporter) + subgoals_to_achieve.update( + best_supporter.condition_at_start) + subgoals_to_achieve.discard(subgoal) + + # 3. Calculate the cost of the relaxed plan. + cost = 0.0 + for process in relaxed_plan: + # Endogenous processes (agent actions) have a cost. + if isinstance(process, _GroundEndogenousProcess): + # Use axiom_cost if it's a derived predicate axiom, otherwise default to 1. + cost += getattr(process, 'axiom_cost', 1.0) + + return cost, relaxed_plan + + def _lm_cut_heuristic(atoms: Set[GroundAtom]) -> float: + """The main heuristic function. + + It iteratively calls the relaxed plan solver to find and sum the + costs of landmarks. + """ + total_cost = 0.0 + current_atoms = atoms.copy() + + # Loop until the goal is satisfied in our simulated state. + while not goal.issubset(current_atoms): + # Find the cost and plan for the next landmark. + landmark_cost, landmark_plan = _calculate_relaxed_plan( + current_atoms, goal) + + # If a landmark is infinitely costly, the goal is unreachable. + if landmark_cost == float('inf'): + return float('inf') + + # If we found a plan with no cost (e.g., only free events), + # but haven't reached the goal, we must force progress by adding + # at least one real action. A cost of 1 is the minimum. 
+ if landmark_cost == 0.0: + total_cost += 1.0 + + total_cost += landmark_cost + + # "Apply" the landmark by adding the effects of its plan to our state. + if not landmark_plan: + # Should not be reachable if cost is not inf, but as a safeguard... + return float('inf') + + for process in landmark_plan: + current_atoms.update(process.add_effects) + + return total_cost + + return _lm_cut_heuristic + + +def create_h_max_heuristic( + goal: Set[GroundAtom], + ground_processes: List[_GroundCausalProcess], + derived_predicates: Set[DerivedPredicate] = set(), + objects: Set[Object] = set(), + use_derived_predicates: bool = True, +) -> Callable[[Set[GroundAtom]], float]: + """Creates a callable h_max heuristic function. + + This heuristic is compatible with exogenous processes (zero-cost) + and derived predicates (zero-cost). It works by building a Relaxed + Planning Graph (RPG) and finding the maximum cost to achieve any + single atom in the goal set. The cost of an atom is the cost of the + cheapest process that achieves it, where the cost of a process is + the maximum cost of any of its preconditions plus its own cost (1 + for actions, 0 otherwise). + """ + + # Pre-computation for derived predicate dependencies. + dep_to_derived_preds: Dict[Predicate, + List[DerivedPredicate]] = defaultdict(list) + if use_derived_predicates: + for der_pred in derived_predicates: + for aux_pred in der_pred.auxiliary_predicates: # type: ignore[union-attr] + dep_to_derived_preds[aux_pred].append(der_pred) + + def _h_max_heuristic(atoms: Set[GroundAtom]) -> float: + """The h_max heuristic function.""" + if goal.issubset(atoms): + return 0.0 + + # Initialize costs: 0 for initial atoms, infinity otherwise. + atom_costs = defaultdict(lambda: float('inf')) + for atom in atoms: + atom_costs[atom] = 0.0 + + # Iteratively relax costs until a fixed point is reached. + while True: + costs_changed = False + + # --- 1. 
Propagate costs through primitive processes --- + for process in ground_processes: + # Cost of preconditions is the max cost of any single precond. + precond_cost = max( + [atom_costs[p] for p in process.condition_at_start] + or [0.0]) + + if precond_cost == float('inf'): + continue + + # Actions (endogenous) have cost 1, others (exogenous) have cost 0. + process_cost = 1.0 if isinstance( + process, _GroundEndogenousProcess) else 0.0 + total_cost = precond_cost + process_cost + + # Update costs of effects if we found a cheaper way to achieve them. + for effect in process.add_effects: + if total_cost < atom_costs[effect]: + atom_costs[effect] = total_cost + costs_changed = True + + # --- 2. Propagate costs through derived predicates (zero-cost) --- + if use_derived_predicates: + # We need to loop here to handle chains of derived predicates. + while True: + derived_costs_changed = False + # This logic is a simplified version of the incremental approach, + # adapted for h_max's cost propagation. + current_facts_for_eval = { + a + for a, c in atom_costs.items() if c != float('inf') + } + + # Check all derived predicates whose inputs might have changed. + derived_atoms = utils._abstract_with_derived_predicates( + current_facts_for_eval, derived_predicates, objects) + + for derived_atom in derived_atoms: + # To determine the cost, we need to find the specific + # atoms that make this derived predicate true. This is + # complex, so we approximate by taking the max cost + # of any atom in the current state. This is a safe + # over-approximation for the preconditions. A more + # precise implementation would require inspecting the + # logic inside the 'holds' function. For now, we + # find the cost of the supporter atoms. + # NOTE: This is a simplification. A fully correct h_max + # would need to know the specific atoms that satisfy + # the 'holds' condition. We find the supporters by + # checking the auxiliary predicates. 
+ supporter_atoms: Set[GroundAtom] = set() + for p in derived_atom.predicate.auxiliary_predicates: # type: ignore[attr-defined] + supporter_atoms.update( + a for a in current_facts_for_eval + if a.predicate == p) + + if not supporter_atoms: continue + + derived_cost = max( + [atom_costs[a] for a in supporter_atoms] or [0.0]) + + if derived_cost < atom_costs[derived_atom]: + atom_costs[derived_atom] = derived_cost + derived_costs_changed = True + costs_changed = True + + if not derived_costs_changed: + break + + # If no costs were updated in a full pass, we've reached a fixed point. + if not costs_changed: + break + + # The heuristic value is the max cost of any goal atom. + goal_costs = [atom_costs[g] for g in goal] + + # If any goal atom is infinitely costly, the goal is unreachable. + if not goal_costs or max(goal_costs) == float('inf'): + return float('inf') + + return max(goal_costs) + + return _h_max_heuristic + + +def _run_incremental_derived_predicate_logic( + newly_added_facts: Set[GroundAtom], + existing_facts: Set[GroundAtom], + objects: Set[Object], + dep_to_derived_preds: Dict[Predicate, List[DerivedPredicate]], +) -> Set[GroundAtom]: + """Incrementally compute the fixed point of derived predicate atoms.""" + all_newly_derived_facts: Set[GroundAtom] = set() + facts_for_next_iter = newly_added_facts.copy() + + while facts_for_next_iter: + derived_preds_to_check: Set[DerivedPredicate] = set() + for fact in facts_for_next_iter: + if fact.predicate in dep_to_derived_preds: + derived_preds_to_check.update( + dep_to_derived_preds[fact.predicate]) + + if not derived_preds_to_check: + break + + current_state_for_eval = existing_facts | all_newly_derived_facts |\ + newly_added_facts + potential_new_atoms = utils._abstract_with_derived_predicates( + current_state_for_eval, derived_preds_to_check, objects) + + truly_new_atoms = potential_new_atoms - (existing_facts + | all_newly_derived_facts) + + if not truly_new_atoms: + break + + 
all_newly_derived_facts.update(truly_new_atoms) + facts_for_next_iter = truly_new_atoms + + return all_newly_derived_facts + + +if __name__ == "__main__": + from predicators.envs.pybullet_boil import PyBulletBoilEnv + from predicators.ground_truth_models import get_gt_options, \ + get_gt_processes + args = utils.parse_args() + utils.update_config(args) + str_args = " ".join(sys.argv) + utils.configure_logging() + CFG.seed = 0 + CFG.env = "pybullet_boil" + CFG.planning_filter_unreachable_nsrt = False + CFG.planning_check_dr_reachable = False + + env = PyBulletBoilEnv() + # objects + robot = env._robot + faucet = env._faucet + jug1 = env._jugs[0] + burner1 = env._burners[0] + + # Processes + options = get_gt_options(env.get_name()) + processes = get_gt_processes(env.get_name(), env.predicates, options) + action_processes = [ + p for p in processes if isinstance(p, EndogenousProcess) + ] + pick = [p for p in action_processes if p.name == 'PickJugFromFaucet'][0] + # place = [p for p in action_processes if p.name == 'PlaceUnderFaucet'][0] + switch_on = [p for p in action_processes if p.name == 'SwitchFaucetOn'][0] + switch_off = [p for p in action_processes + if p.name == 'SwitchFaucetOff'][0] + wait_proc = [p for p in action_processes if p.name == 'Wait'][0] + + plan: List[_GroundEndogenousProcess] = [ + switch_on.ground([robot, faucet]), + switch_off.ground([robot, faucet]), + wait_proc.ground([robot]), + wait_proc.ground([robot]) + ] + + # Predicates + predicates = env.predicates + + def policy() -> Optional[_GroundEndogenousProcess]: + global plan + if len(plan) > 0: + return plan.pop(0) + else: + return None + + # Task + rng = np.random.default_rng(CFG.seed) + task = env._make_tasks(1, rng)[0] # type: ignore[call-arg, arg-type] + ground_processes, _reachable_atoms = process_task_plan_grounding( + init_atoms=task.init, # type: ignore[arg-type] + objects=set(task.init), + cps=processes, + allow_waits=True, + compute_reachable_atoms=False) + + world_model = 
ProcessWorldModel(ground_processes=ground_processes, + state=utils.abstract( + task.init, predicates), + state_history=[], + action_history=[], + scheduled_events={}, + t=0) + for _ in range(100): + action = policy() + if action is not None: + world_model.big_step(action) + else: + break diff --git a/predicators/predicate_search_score_functions.py b/predicators/predicate_search_score_functions.py index c4f8a24547..0884615629 100644 --- a/predicators/predicate_search_score_functions.py +++ b/predicators/predicate_search_score_functions.py @@ -7,27 +7,34 @@ import re import time from dataclasses import dataclass, field -from typing import Callable, Collection, Dict, FrozenSet, List, Sequence, \ - Set, Tuple +from typing import Callable, Collection, Dict, FrozenSet, List, Optional, \ + Sequence, Set, Tuple import numpy as np from predicators import utils +from predicators.nsrt_learning.process_learning_main import \ + learn_processes_from_data from predicators.nsrt_learning.segmentation import segment_trajectory from predicators.nsrt_learning.strips_learning import learn_strips_operators from predicators.planning import PlanningFailure, PlanningTimeout, task_plan, \ task_plan_grounding +from predicators.planning_with_processes import \ + task_plan_from_task as task_plan_with_processes from predicators.settings import CFG -from predicators.structs import GroundAtom, GroundAtomTrajectory, \ - LowLevelTrajectory, Object, OptionSpec, Predicate, Segment, \ - STRIPSOperator, Task, _GroundSTRIPSOperator +from predicators.structs import CausalProcess, GroundAtom, \ + GroundAtomTrajectory, LowLevelTrajectory, Object, OptionSpec, Predicate, \ + Segment, STRIPSOperator, Task, _GroundSTRIPSOperator def create_score_function( - score_function_name: str, initial_predicates: Set[Predicate], - atom_dataset: List[GroundAtomTrajectory], candidates: Dict[Predicate, - float], - train_tasks: List[Task]) -> _PredicateSearchScoreFunction: + score_function_name: str, + initial_predicates: 
Set[Predicate], + atom_dataset: List[GroundAtomTrajectory], + candidates: Dict[Predicate, float], + train_tasks: List[Task], + current_processes: Optional[Set[CausalProcess]], + use_processes: bool = False) -> _PredicateSearchScoreFunction: """Public method for creating a score function object.""" if score_function_name == "prediction_error": return _PredictionErrorScoreFunction(initial_predicates, atom_dataset, @@ -38,7 +45,7 @@ def create_score_function( if score_function_name == "hadd_match": return _RelaxationHeuristicMatchBasedScoreFunction( initial_predicates, atom_dataset, candidates, train_tasks, - ["hadd"]) + ["hadd"]) # type: ignore[arg-type] match = re.match(r"([a-z\,]+)_(\w+)_lookaheaddepth(\d+)", score_function_name) if match is not None: @@ -59,7 +66,7 @@ def create_score_function( atom_dataset, candidates, train_tasks, - heuristic_names, + heuristic_names, # type: ignore[arg-type] lookahead_depth=lookahead_depth) assert score_name == "count" return _RelaxationHeuristicCountBasedScoreFunction( @@ -67,7 +74,7 @@ def create_score_function( atom_dataset, candidates, train_tasks, - heuristic_names, + heuristic_names, # type: ignore[arg-type] lookahead_depth=lookahead_depth, demos_only=False) if score_function_name == "exact_energy": @@ -89,9 +96,14 @@ def create_score_function( created_or_expanded = match.groups()[0] assert created_or_expanded in ("created", "expanded") metric_name = f"num_nodes_{created_or_expanded}" - return _ExpectedNodesScoreFunction(initial_predicates, atom_dataset, - candidates, train_tasks, - metric_name) + return _ExpectedNodesScoreFunction( + initial_predicates, + atom_dataset, + candidates, + train_tasks, + _current_processes=current_processes, + _use_processes=use_processes, + metric_name=metric_name) raise NotImplementedError( f"Unknown score function: {score_function_name}.") @@ -122,11 +134,15 @@ def _get_predicate_penalty( @dataclass(frozen=True, eq=False, repr=False) class 
_OperatorLearningBasedScoreFunction(_PredicateSearchScoreFunction): """A score function that learns operators given the set of predicates.""" + _current_processes: Optional[Set[CausalProcess]] = field(default=None) + _use_processes: bool = False def evaluate(self, candidate_predicates: FrozenSet[Predicate]) -> float: + # Lower scores are better. total_cost = sum(self._candidates[pred] for pred in candidate_predicates) - logging.info(f"Evaluating predicates: {candidate_predicates}, with " + new_predicates = candidate_predicates - self._initial_predicates + logging.info(f"Evaluating: {new_predicates}, with " f"total cost {total_cost}") start_time = time.perf_counter() pruned_atom_data = utils.prune_ground_atom_dataset( @@ -143,36 +159,63 @@ def evaluate(self, candidate_predicates: FrozenSet[Predicate]) -> float: low_level_trajs = [ll_traj for ll_traj, _ in pruned_atom_data] del pruned_atom_data try: - pnads = learn_strips_operators(low_level_trajs, - self._train_tasks, - set(candidate_predicates - | self._initial_predicates), - segmented_trajs, - verify_harmlessness=False, - verbose=False, - annotations=None) + if self._use_processes: + assert CFG.only_learn_exogenous_processes, \ + "Learning endogenous processes is not supported yet." + # We can currently use this because we are only learning + # exogenous processes; don't do sampler learning for actions. + processes = learn_processes_from_data( # type: ignore[call-arg] + low_level_trajs, + self._train_tasks, + set(candidate_predicates | self._initial_predicates), + current_processes=self._current_processes, + relearn_all_exogenous_processes=True, + log_all_processes=False, + ) + else: + pnads = learn_strips_operators(low_level_trajs, + self._train_tasks, + set(candidate_predicates + | self._initial_predicates), + segmented_trajs, + verify_harmlessness=False, + verbose=False, + annotations=None) except TimeoutError: logging.info( "Warning: Operator Learning timed out! 
Skipping evaluation.") return float('inf') - logging.debug( - f"Learned {len(pnads)} operators for this predicate set.") - for pnad in pnads: + if self._use_processes: + op_score = self.evaluate_with_operators( + candidate_predicates, + low_level_trajs, + segmented_trajs, + processes, # type: ignore[arg-type] + []) + strips_ops = processes # type: ignore[assignment] + else: logging.debug( - f"Operator {pnad.op.name} has {len(pnad.datastore)} datapoints." - ) - strips_ops = [pnad.op for pnad in pnads] - option_specs = [pnad.option_spec for pnad in pnads] - op_score = self.evaluate_with_operators(candidate_predicates, - low_level_trajs, - segmented_trajs, strips_ops, - option_specs) + f"Learned {len(pnads)} operators for this predicate set.") + for pnad in pnads: + logging.debug( + f"Operator {pnad.op.name} has {len(pnad.datastore)} datapoints." + ) + strips_ops = [pnad.op + for pnad in pnads] # type: ignore[assignment] + option_specs = [pnad.option_spec for pnad in pnads] + op_score = self.evaluate_with_operators( + candidate_predicates, low_level_trajs, segmented_trajs, + strips_ops, option_specs) # type: ignore[arg-type] pred_penalty = self._get_predicate_penalty(candidate_predicates) - op_penalty = self._get_operator_penalty(strips_ops) + op_penalty = self._get_operator_penalty( + strips_ops) # type: ignore[arg-type] total_score = op_score + pred_penalty + op_penalty - logging.info(f"\tTotal score: {total_score} computed in " - f"{time.perf_counter()-start_time:.3f} seconds") + logging.info( + f"\tTotal score: {total_score:.3f}, " + f"model score: {op_score:.3f} " + f"pred penalty: {pred_penalty}, model penalty: {op_penalty} " + f"computed in {time.perf_counter()-start_time:.3f} seconds") return total_score def evaluate_with_operators(self, @@ -263,20 +306,21 @@ def evaluate_with_operators(self, ground_nsrts, reachable_atoms = task_plan_grounding( init_atoms, objects, dummy_nsrts) traj_goal = self._train_tasks[traj.train_task_idx].goal - heuristic = 
utils.create_task_planning_heuristic( + heuristic = utils.create_task_planning_heuristic( # type: ignore[type-var] CFG.sesame_task_planning_heuristic, init_atoms, traj_goal, ground_nsrts, candidate_predicates | self._initial_predicates, objects) try: _, _, metrics = next( - task_plan(init_atoms, - traj_goal, - ground_nsrts, - reachable_atoms, - heuristic, - CFG.seed, - CFG.grammar_search_task_planning_timeout, - max_skeletons_optimized=1)) + task_plan( + init_atoms, + traj_goal, + ground_nsrts, # type: ignore[arg-type] + reachable_atoms, + heuristic, + CFG.seed, + CFG.grammar_search_task_planning_timeout, + max_skeletons_optimized=1)) assert "num_nodes_expanded" in metrics node_expansions = metrics["num_nodes_expanded"] assert node_expansions < node_expansion_upper_bound @@ -301,7 +345,8 @@ class _ExpectedNodesScoreFunction(_OperatorLearningBasedScoreFunction): difference gets larger. """ - metric_name: str # num_nodes_created or num_nodes_expanded + metric_name: str = field( + kw_only=True) # num_nodes_created or num_nodes_expanded def evaluate_with_operators(self, candidate_predicates: FrozenSet[Predicate], @@ -312,30 +357,61 @@ def evaluate_with_operators(self, assert self.metric_name in ("num_nodes_created", "num_nodes_expanded") score = 0.0 seen_demos = 0 + matching_plan_bonus =\ + CFG.grammar_search_additional_bonus_for_matching_plan assert len(low_level_trajs) == len(segmented_trajs) for ll_traj, seg_traj in zip(low_level_trajs, segmented_trajs): if seen_demos >= CFG.grammar_search_max_demos: break + # TODO: can just add here that we only look at successful trajs for + # computing the score; for now this is making a stronger assumption + # of demos if not ll_traj.is_demo: continue demo_atoms_sequence = utils.segment_trajectory_to_atoms_sequence( seg_traj) seen_demos += 1 - init_atoms = demo_atoms_sequence[0] goal = self._train_tasks[ll_traj.train_task_idx].goal - # Ground everything once per demo. 
- objects = set(ll_traj.states[0]) - dummy_nsrts = utils.ops_and_specs_to_dummy_nsrts( - strips_ops, option_specs) - ground_nsrts, reachable_atoms = task_plan_grounding( - init_atoms, - objects, - dummy_nsrts, - allow_noops=CFG.grammar_search_expected_nodes_allow_noops) - heuristic = utils.create_task_planning_heuristic( - CFG.sesame_task_planning_heuristic, init_atoms, goal, - ground_nsrts, candidate_predicates | self._initial_predicates, - objects) + if CFG.grammar_search_expected_nodes_max_skeletons == -1: + max_skeletons = CFG.sesame_max_skeletons_optimized + else: + max_skeletons = CFG.grammar_search_expected_nodes_max_skeletons + assert max_skeletons <= CFG.sesame_max_skeletons_optimized + assert not CFG.sesame_use_visited_state_set + if self._use_processes: + generator = task_plan_with_processes( + self._train_tasks[ll_traj.train_task_idx], + candidate_predicates | self._initial_predicates, + strips_ops, # type: ignore[arg-type] + CFG.seed, + CFG.grammar_search_task_planning_timeout, + max_skeletons_optimized=max_skeletons, + use_visited_state_set=True) + else: + init_atoms = demo_atoms_sequence[0] + # Ground everything once per demo. + objects = set(ll_traj.states[0]) + dummy_nsrts = utils.ops_and_specs_to_dummy_nsrts( + strips_ops, option_specs) + ground_nsrts, reachable_atoms = task_plan_grounding( + init_atoms, + objects, + dummy_nsrts, + allow_waits=CFG.grammar_search_expected_nodes_allow_waits) + heuristic = utils.create_task_planning_heuristic( # type: ignore[type-var] + CFG.sesame_task_planning_heuristic, init_atoms, goal, + ground_nsrts, + candidate_predicates | self._initial_predicates, objects) + generator = task_plan( + init_atoms, # type: ignore[assignment] + goal, + ground_nsrts, # type: ignore[arg-type] + reachable_atoms, + heuristic, + CFG.seed, + CFG.grammar_search_task_planning_timeout, + max_skeletons, + use_visited_state_set=False) # The expected time needed before a low-level plan is found. 
We # approximate this using node creations and by adding a penalty # for every skeleton after the first to account for backtracking. @@ -344,28 +420,18 @@ def evaluate_with_operators(self, # not been found, updated after each new goal-reaching skeleton is # considered. refinable_skeleton_not_found_prob = 1.0 - if CFG.grammar_search_expected_nodes_max_skeletons == -1: - max_skeletons = CFG.sesame_max_skeletons_optimized - else: - max_skeletons = CFG.grammar_search_expected_nodes_max_skeletons - assert max_skeletons <= CFG.sesame_max_skeletons_optimized - assert not CFG.sesame_use_visited_state_set - generator = task_plan(init_atoms, - goal, - ground_nsrts, - reachable_atoms, - heuristic, - CFG.seed, - CFG.grammar_search_task_planning_timeout, - max_skeletons, - use_visited_state_set=False) try: - for idx, (_, plan_atoms_sequence, + for idx, (plan, plan_atoms_sequence, metrics) in enumerate(generator): assert goal.issubset(plan_atoms_sequence[-1]) # Estimate the probability that this skeleton is refinable. - refinement_prob = self._get_refinement_prob( - demo_atoms_sequence, plan_atoms_sequence) + task_unsolvable = not goal.issubset( + demo_atoms_sequence[-1]) + if CFG.env_has_impossible_goals and task_unsolvable: + refinement_prob = 0.0 + else: + refinement_prob = self._get_refinement_prob( + demo_atoms_sequence, plan_atoms_sequence) # Get the number of nodes that have been created or # expanded so far. assert self.metric_name in metrics @@ -373,20 +439,38 @@ def evaluate_with_operators(self, # This contribution to the expected number of nodes is for # the event that the current skeleton is refinable, but no # previous skeleton has been refinable. 
- p = refinable_skeleton_not_found_prob * refinement_prob - expected_planning_time += p * num_nodes + terminate_prob = refinable_skeleton_not_found_prob *\ + refinement_prob + expected_planning_time += terminate_prob * num_nodes + if matching_plan_bonus != 0 and \ + (len(plan_atoms_sequence) == len(demo_atoms_sequence) + ) and \ + ([seg.get_option().name for seg in seg_traj] == \ + [g_proc.option.name for g_proc in plan]): + expected_planning_time -= matching_plan_bonus # Apply a penalty to account for the time that we'd spend # in backtracking if the last skeleton was not refinable. if idx > 0: w = CFG.grammar_search_expected_nodes_backtracking_cost - expected_planning_time += p * w + expected_planning_time += terminate_prob * w # Update the probability that no skeleton yet is refinable. refinable_skeleton_not_found_prob *= (1 - refinement_prob) - except (PlanningTimeout, PlanningFailure): + # logging.debug(f"id {idx}: refinement_prob: {refinement_prob}, " + # f"refinable_skeleton_not_found_prob: {refinable_skeleton_not_found_prob}, " + # f"terminate_prob: {terminate_prob},\n" + # f"num_nodes: {num_nodes}, ") + except (PlanningTimeout, PlanningFailure) as e: # Note if we failed to find any skeleton, the next lines add # the upper bound with refinable_skeleton_not_found_prob = 1.0, # so no special action is required. - pass + if CFG.env_has_impossible_goals: + predicated_unsolvable = "not dr-reachable" in str(e) + # check if the last state in the traj satisfies the goal + task_unsolvable = not goal.issubset( + demo_atoms_sequence[-1]) + if predicated_unsolvable and task_unsolvable: + expected_planning_time -= \ + CFG.grammar_search_recognizing_unsolvable_goals_bonus # After exhausting the skeleton budget or timeout, we use this # probability to estimate a "worst-case" planning time, making the # soft assumption that some skeleton will eventually work. 
@@ -420,7 +504,7 @@ class _HeuristicBasedScoreFunction(_OperatorLearningBasedScoreFunction): Subclasses must choose the heuristic function and how to evaluate against the demonstrations. """ - heuristic_names: Sequence[str] + heuristic_names: Sequence[str] = field(default=("hadd", ), init=False) demos_only: bool = field(default=True) def evaluate_with_operators(self, @@ -713,7 +797,7 @@ def _generate_heuristic( strips_ops, option_specs) ground_nsrts, reachable_atoms = task_plan_grounding( init_atoms, objects, dummy_nsrts) - heuristic = utils.create_task_planning_heuristic( + heuristic = utils.create_task_planning_heuristic( # type: ignore[type-var] CFG.sesame_task_planning_heuristic, init_atoms, goal, ground_nsrts, set(candidate_predicates) | self._initial_predicates, objects) @@ -724,14 +808,15 @@ def _task_planning_h(atoms: Set[GroundAtom]) -> float: return cache[frozenset(atoms)] try: skeleton, atoms_sequence, _ = next( - task_plan(atoms, - goal, - ground_nsrts, - reachable_atoms, - heuristic, - CFG.seed, - CFG.grammar_search_task_planning_timeout, - max_skeletons_optimized=1)) + task_plan( + atoms, + goal, + ground_nsrts, # type: ignore[arg-type] + reachable_atoms, + heuristic, + CFG.seed, + CFG.grammar_search_task_planning_timeout, + max_skeletons_optimized=1)) except (PlanningFailure, PlanningTimeout): return float("inf") assert atoms_sequence[0] == atoms