In [None]:
%pip install numpy scikit-learn wlplan

To see how you can use `wlplan` for both training and search, see this [test](../../tests/train_eval_blocks_test.py). This notebook only contains the training part.

In [None]:
import os
import numpy as np
import pymimir
import wlplan
from wlplan.data import Dataset, ProblemStates
from wlplan.feature_generation import WLFeatures
from wlplan.planning import Predicate, parse_domain
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import DotProduct

1. Define the state space for the Blocksworld domain with first-order logic.

In [None]:
# Blocksworld predicate symbols paired with their arities.
predicate_arities = (
    ("on", 2),
    ("on-table", 1),
    ("clear", 1),
    ("holding", 1),
    ("arm-empty", 0),
)

# Lookup table from predicate name to its wlplan Predicate object; the parsing
# code uses it to translate atoms coming from the PDDL parser.
name_to_predicate = {
    symbol: Predicate(symbol, arity) for symbol, arity in predicate_arities
}
predicates = [name_to_predicate[symbol] for symbol, _ in predicate_arities]

wlplan_domain = wlplan.planning.Domain(name="blocksworld", predicates=predicates)
# The domain can also be parsed straight from a PDDL file instead:
# wlplan_domain = parse_domain("blocksworld/domain.pddl")

2. [The most work] Parse the training data into (state, optimal cost-to-go) pairs using a parser of your choice. Here, I used `mimir`, but any other PDDL parser will do.

In [None]:
domain_pddl = "blocksworld/domain.pddl"
problem_dir = "blocksworld/training"
plan_dir = "blocksworld/training_plans"

mimir_domain = pymimir.DomainParser(domain_pddl).parse()
# Action schemas belong to the domain, so this lookup is built once up front
# rather than once per problem.
name_to_schema = {s.name: s for s in mimir_domain.action_schemas}


def mimir_to_wlplan_state(mimir_state: pymimir.State) -> list:
    """Convert a mimir state into a list of wlplan atoms.

    Each atom returned by ``mimir_state.get_atoms()`` is rebuilt as a
    ``wlplan.planning.Atom`` using the predicate registered under the same
    name in ``name_to_predicate`` and the names of the atom's terms.

    Raises:
        KeyError: if an atom uses a predicate missing from ``name_to_predicate``.
    """
    return [
        wlplan.planning.Atom(
            predicate=name_to_predicate[atom.predicate.name],
            objects=[o.name for o in atom.terms],
        )
        for atom in mimir_state.get_atoms()
    ]


# Reset the accumulators here so that re-running this cell is idempotent.
wlplan_data = []
y = []

# Sort for a reproducible dataset order (os.listdir returns entries in
# arbitrary order) and ignore anything that is not a plan file.
plan_names = sorted(p for p in os.listdir(plan_dir) if p.endswith(".plan"))

for plan_name in plan_names:
    problem_pddl = os.path.join(problem_dir, plan_name.replace(".plan", ".pddl"))
    plan_file = os.path.join(plan_dir, plan_name)

    # Parse problem with mimir
    mimir_problem = pymimir.ProblemParser(problem_pddl).parse(mimir_domain)
    mimir_state = mimir_problem.create_state(mimir_problem.initial)

    name_to_object = {o.name: o for o in mimir_problem.objects}

    # Construct wlplan problem (only positive goal literals are supported)
    positive_goals = []
    for literal in mimir_problem.goal:
        assert not literal.negated
        mimir_atom = literal.atom
        positive_goals.append(
            wlplan.planning.Atom(
                predicate=name_to_predicate[mimir_atom.predicate.name],
                objects=[o.name for o in mimir_atom.terms],
            )
        )

    wlplan_problem = wlplan.planning.Problem(
        domain=wlplan_domain,
        objects=list(name_to_object.keys()),
        positive_goals=positive_goals,
        negative_goals=[],
    )
    # Alternatively, you can directly parse the problem from a PDDL file
    # (`parse_problem` is not imported above; presumably it lives next to
    # `parse_domain` in `wlplan.planning` -- confirm before using):
    # wlplan_problem = parse_problem(domain_pddl, problem_pddl)

    # Collect actions: one "(schema arg1 arg2 ...)" entry per plan line
    actions = []
    with open(plan_file, "r") as plan_fh:
        for line in plan_fh:
            action_str = line.strip()
            # Skip blank lines and ";" comment lines (e.g. plan cost footers)
            if not action_str or action_str.startswith(";"):
                continue
            # split() without an argument tolerates repeated whitespace
            toks = action_str.replace("(", "").replace(")", "").split()
            if not toks:
                continue
            schema = name_to_schema[toks[0]]
            args = [name_to_object[arg] for arg in toks[1:]]
            actions.append(pymimir.Action.new(mimir_problem, schema, args))

    # Collect plan trace states: the initial state plus one state per action
    wlplan_states = [mimir_to_wlplan_state(mimir_state)]
    for action in actions:
        mimir_state = action.apply(mimir_state)
        wlplan_states.append(mimir_to_wlplan_state(mimir_state))

    # The label of each trace state is the number of plan steps still to go:
    # len(actions) for the initial state down to 0 for the final state.
    # NOTE(review): this equals the optimal cost-to-go only if the plan files
    # contain optimal plans -- confirm for your training data.
    y.extend(range(len(actions), -1, -1))

    problem_states = ProblemStates(problem=wlplan_problem, states=wlplan_states)
    wlplan_data.append(problem_states)

# This is what we need to feed into our feature generator below
dataset = Dataset(domain=wlplan_domain, data=wlplan_data)

3. Collect and generate features from the preprocessed data

In [None]:
# Set up the WL feature generator for the domain with 4 iterations and let it
# collect from the training dataset before embedding.
feature_generator = WLFeatures(iterations=4, domain=wlplan_domain)
feature_generator.collect(dataset)

# Embed the dataset and convert features and targets to numpy arrays
# for scikit-learn.
embeddings = feature_generator.embed(dataset)
X = np.array(embeddings).astype(float)
y = np.array(y)

print(f"{X.shape=}", f"{y.shape=}", sep="\n")

4. Train a Gaussian Process Regression model

In [None]:
# Dot-product (linear) kernel with sigma_0 set to 0 and its bounds marked
# "fixed".
linear_kernel = DotProduct(sigma_0=0, sigma_0_bounds="fixed")
model = GaussianProcessRegressor(
    kernel=linear_kernel,
    alpha=1e-7,
    random_state=0,
)
model.fit(X, y)

# Mean squared error of the fitted model on the training data itself.
y_pred = model.predict(X)
residuals = y - y_pred
loss = (residuals ** 2).mean()
print(f"{loss=}")