Skip to content

Commit

Permalink
Merge pull request #476 from aiplan4eu/causal-graph
Browse files Browse the repository at this point in the history
Refactored causal graph generation in a problem method
  • Loading branch information
alvalentini committed Sep 5, 2023
2 parents f7590a3 + 5936968 commit 96f7940
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 105 deletions.
3 changes: 2 additions & 1 deletion unified_planning/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from unified_planning.model.operators import OperatorKind
from unified_planning.model.parameter import Parameter
from unified_planning.model.abstract_problem import AbstractProblem
from unified_planning.model.problem import Problem
from unified_planning.model.problem import Problem, generate_causal_graph
from unified_planning.model.contingent_problem import ContingentProblem
from unified_planning.model.delta_stn import DeltaSimpleTemporalNetwork
from unified_planning.model.problem_kind import ProblemKind
Expand Down Expand Up @@ -98,6 +98,7 @@
"Parameter",
"AbstractProblem",
"Problem",
"generate_causal_graph",
"ProblemKind",
"State",
"UPState",
Expand Down
141 changes: 140 additions & 1 deletion unified_planning/model/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""This module defines the problem class."""


from numbers import Real
from itertools import chain, product
import unified_planning as up
import unified_planning.model.tamp
from unified_planning.model import Fluent
Expand All @@ -35,7 +35,12 @@
from unified_planning.exceptions import (
UPProblemDefinitionError,
UPTypeError,
UPUsageError,
UPNoSuitableEngineAvailableException,
UPUnsupportedProblemTypeError,
)

import networkx as nx
from fractions import Fraction
from typing import Optional, List, Dict, Set, Tuple, Union, cast, Iterable

Expand Down Expand Up @@ -996,3 +1001,137 @@ def update_problem_kind_action(
self.kind.set_time("CONTINUOUS_TIME")
else:
raise NotImplementedError


def generate_causal_graph(
problem: Problem,
) -> Tuple[
nx.DiGraph,
Dict[
Tuple["up.model.fnode.FNode", "up.model.fnode.FNode"],
Set["up.plans.ActionInstance"],
],
]:
"""
This method generates the causal graph of the given problem. The causal graph is a directed
graph where the nodes are the fluents of the problem (instantiated to objects) and there is
an edge from node A to node B if an action (instantiated to objects) reads/writes A and
writes B. This means that somehow the A and B fluents are related trough that action.
:param problem: The problem to compute the causal graph.
:return: The tuple where the first element is the causal graph and the second element is the
mapping from the pairs of nodes connected in the graph to the set of actions that link
the first node to the second; every element of the set is composed by 2 elements, the
first one is the lifted action, the second one is the tuple of parameters used to ground
the action.
"""
if isinstance(
problem, (up.model.htn.HierarchicalProblem, up.model.ContingentProblem)
):
raise NotImplementedError
assert type(problem) == Problem, "Error not handled."

if not problem.actions:
raise UPUsageError("Can't create the causal graph of a Problem without actions")
to_ground = any(a.parameters for a in problem.actions)
actions_mapping: Dict[
"up.model.action.Action",
Tuple["up.model.action.Action", Tuple["up.model.fnode.FNode", ...]],
] = {}
grounded_problem = problem
if to_ground:
try:
with problem.environment.factory.Compiler(
problem_kind=problem.kind, compilation_kind="GROUNDING"
) as grounder:
res = grounder.compile(problem)
grounded_problem = res.problem
ai_mapping = res.map_back_action_instance
for ga in grounded_problem.actions:
lifted_ai = ai_mapping(ga())
actions_mapping[ga] = (
lifted_ai.action,
lifted_ai.actual_parameters,
)
except UPNoSuitableEngineAvailableException as ex:
raise UPUsageError(
"To plot the causal graph of a problem, the problem must be grounded or a grounder capable of handling the given problem must be installed.\n"
+ str(ex)
)

# Populate 2 maps:
# one from a fluent to all the actions reading that fluent
# one from a fluent to all the actions writing that fluent
fluents_red: Dict["up.model.fnode.FNode", Set["up.model.action.Action"]] = {}
fluents_written: Dict["up.model.fnode.FNode", Set["up.model.action.Action"]] = {}

fve = problem.environment.free_vars_extractor
for action in grounded_problem.actions:
assert not action.parameters
if isinstance(action, up.model.action.InstantaneousAction):
for p in action.preconditions:
for fluent in fve.get(p):
if any(map(fve.get, fluent.args)):
raise UPUnsupportedProblemTypeError(
f"Fluent {fluent} contains other fluents. Causal Graph can't be generated with nested fluents."
)
fluents_red.setdefault(fluent, set()).add(action)
for e in action.effects:
fluent = e.fluent
assert fluent.is_fluent_exp()
assert not any(map(fve.get, fluent.args)), "Error in effect definition"
fluents_written.setdefault(fluent, set()).add(action)
for fluent in chain(fve.get(e.value), fve.get(e.condition)):
if any(map(fve.get, fluent.args)):
raise NotImplementedError(
"nested fluents still are not implemented"
)
fluents_red.setdefault(fluent, set()).add(action)
elif isinstance(action, up.model.action.DurativeAction):
for p in chain(*action.conditions.values()):
for fluent in fve.get(p):
if any(map(fve.get, fluent.args)):
raise UPUnsupportedProblemTypeError(
f"Fluent {fluent} contains other fluents. Causal Graph can't be generated with nested fluents."
)
fluents_red.setdefault(fluent, set()).add(action)
for e in chain(*action.effects.values()):
fluent = e.fluent
assert fluent.is_fluent_exp()
if any(map(fve.get, fluent.args)):
raise NotImplementedError("nested fluents are not implemented")
fluents_written.setdefault(fluent, set()).add(action)
for fluent in chain(fve.get(e.value), fve.get(e.condition)):
if any(map(fve.get, fluent.args)):
raise NotImplementedError("nested fluents are not implemented")
fluents_red.setdefault(fluent, set()).add(action)
else:
raise NotImplementedError
edge_actions_map: Dict[
Tuple["up.model.fnode.FNode", "up.model.fnode.FNode"],
Set["up.plans.ActionInstance"],
] = {}
graph = nx.DiGraph()
all_fluents = set(chain(fluents_red.keys(), fluents_written.keys()))
# Add an edge if a fluent that is red or written and it's in the same action of a written fluent
empty_set: Set["up.model.fnode.FNode"] = set()
for left_node, right_node in product(all_fluents, fluents_written.keys()):
rn_actions = fluents_written.get(right_node, empty_set)
if left_node != right_node and rn_actions:
actions = edge_actions_map.setdefault((left_node, right_node), set())
edge_created = False
for ln_action in chain(
fluents_red.get(left_node, empty_set),
fluents_written.get(left_node, empty_set),
):
assert isinstance(ln_action, up.model.Action)
if ln_action in rn_actions:
if not edge_created:
edge_created = True
graph.add_edge(left_node, right_node)
actions.add(
up.plans.ActionInstance(
*actions_mapping.get(ln_action, (ln_action, tuple()))
)
)
return graph, edge_actions_map
110 changes: 8 additions & 102 deletions unified_planning/plot/causal_graph_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@
Problem,
FNode,
Action,
InstantaneousAction,
DurativeAction,
generate_causal_graph,
)
from unified_planning.plot.utils import (
ARROWSIZE,
Expand All @@ -37,6 +36,7 @@
)
from unified_planning.engines import CompilationKind


from itertools import chain, product
import networkx as nx
from typing import (
Expand Down Expand Up @@ -126,107 +126,12 @@ def plot_causal_graph(
if generate_node_label is None:
generate_node_label = str

if not problem.actions:
raise UPUsageError("Can't plot the causal graph of a Problem without actions")
to_ground = any(a.parameters for a in problem.actions)
actions_mapping: Dict[Action, Tuple[Action, Tuple[FNode, ...]]] = {}
grounded_problem = problem
if to_ground:
try:
with problem.environment.factory.Compiler(
problem_kind=problem.kind, compilation_kind=CompilationKind.GROUNDING
) as grounder:
res = grounder.compile(problem)
grounded_problem = res.problem
ai_mapping = res.map_back_action_instance
for ga in grounded_problem.actions:
lifted_ai = ai_mapping(ga())
actions_mapping[ga] = (
lifted_ai.action,
lifted_ai.actual_parameters,
)
except UPNoSuitableEngineAvailableException as ex:
raise UPUsageError(
"To plot the causal graph of a problem, the problem must be grounder or a grounder capable of handling the given problem must be installed.\n"
+ str(ex)
)

# Populate 2 maps:
# one from a fluent to all the actions reading that fluent
# one from a fluent to all the actions writing that fluent
fluents_red: Dict[FNode, Set[Action]] = {}
fluents_written: Dict[FNode, Set[Action]] = {}
graph, edge_actions = generate_causal_graph(problem)
edge_labels_set: Dict[Tuple[FNode, FNode], Set[str]] = {
k: set(edge_label_function(e.action, e.actual_parameters) for e in v)
for k, v in edge_actions.items()
}

fve = problem.environment.free_vars_extractor
for action in grounded_problem.actions:
assert not action.parameters
if isinstance(action, InstantaneousAction):
for p in action.preconditions:
for fluent in fve.get(p):
if any(map(fve.get, fluent.args)):
raise UPUnsupportedProblemTypeError(
f"Fluent {fluent} contains other fluents. Causal can't be plotted with nested fluents."
)
fluents_red.setdefault(fluent, set()).add(action)
for e in action.effects:
fluent = e.fluent
assert fluent.is_fluent_exp()
assert not any(map(fve.get, fluent.args)), "Error in effect definition"
fluents_written.setdefault(fluent, set()).add(action)
for fluent in chain(fve.get(e.value), fve.get(e.condition)):
if any(map(fve.get, fluent.args)):
raise NotImplementedError(
"nested fluents still are not implemented"
)
fluents_red.setdefault(fluent, set()).add(action)
elif isinstance(action, DurativeAction):
for p in chain(*action.conditions.values()):
for fluent in fve.get(p):
if any(map(fve.get, fluent.args)):
raise UPUnsupportedProblemTypeError(
f"Fluent {fluent} contains other fluents. Causal can't be plotted with nested fluents."
)
fluents_red.setdefault(fluent, set()).add(action)
for e in chain(*action.effects.values()):
fluent = e.fluent
assert fluent.is_fluent_exp()
if any(map(fve.get, fluent.args)):
raise NotImplementedError(
"nested fluents still are not implemented"
)
fluents_written.setdefault(fluent, set()).add(action)
for fluent in chain(fve.get(e.value), fve.get(e.condition)):
if any(map(fve.get, fluent.args)):
raise NotImplementedError(
"nested fluents still are not implemented"
)
fluents_red.setdefault(fluent, set()).add(action)
else:
raise NotImplementedError
edge_labels_set: Dict[Tuple[FNode, FNode], Set[str]] = {}
graph = nx.DiGraph()
all_fluents = set(chain(fluents_red.keys(), fluents_written.keys()))
# Add an edge if a fluent that is red or written and it's in the same action of a written fluent
empty_set: Set[FNode] = set()
for left_node, right_node in product(all_fluents, fluents_written.keys()):
rn_actions = fluents_written.get(right_node, empty_set)
if left_node != right_node and rn_actions:
label = edge_labels_set.setdefault((left_node, right_node), set())
edge_created = False
for ln_action in chain(
fluents_red.get(left_node, empty_set),
fluents_written.get(left_node, empty_set),
):
assert isinstance(ln_action, Action)
if ln_action in rn_actions:
if not edge_created:
edge_created = True
graph.add_edge(left_node, right_node)
label.add(
edge_label_function(
*actions_mapping.get(ln_action, (ln_action, tuple()))
)
)
edge_labels: Dict[Tuple[FNode, FNode], str] = {
edge: ", ".join(labels) for edge, labels in edge_labels_set.items() if labels
}
Expand All @@ -243,6 +148,7 @@ def plot_causal_graph(
font_size=font_size,
font_color=font_color,
draw_networkx_kwargs=draw_networkx_kwargs,
prog="dot",
)
nx.draw_networkx_edge_labels(
graph,
Expand Down
3 changes: 2 additions & 1 deletion unified_planning/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def draw_base_graph(
font_size: int = FONT_SIZE,
font_color: str = FONT_COLOR,
draw_networkx_kwargs: Optional[Dict[str, Any]] = None,
prog: str = "dot",
):
import matplotlib.pyplot as plt # type: ignore[import]

Expand Down Expand Up @@ -92,7 +93,7 @@ def length_factor(label_length: int) -> float:
fig = plt.figure(figsize=figsize)
ax = fig.add_subplot()

pos = _generate_positions(graph, prog="dot", top_bottom=top_bottom)
pos = _generate_positions(graph, prog=prog, top_bottom=top_bottom)

nx.draw_networkx(
graph,
Expand Down

0 comments on commit 96f7940

Please sign in to comment.