# XAI In Action - pgeon

This notebook demonstrates the current functionality of the **pgeon** library.

## Preparation

We first load an environment, an agent and a discretizer: the elements needed to generate a Policy Graph.

In [1]:
import gymnasium as gym

from discretizer import CartpoleDiscretizer

In [2]:
environment = gym.make('CartPole-v1')
discretizer = CartpoleDiscretizer()

In [3]:
from simple_agent import SimpleCartpoleAgent

# Use the simple agent instead of loading large checkpoint files

In [4]:
agent = SimpleCartpoleAgent()

## Policy Graph generation

In [5]:
from pgeon import PolicyApproximatorFromBasicObservation, GraphRepresentation, PredicateBasedState

A policy approximator is instantiated with a discretizer, a graph representation, an environment and an agent.

In [6]:
representation = GraphRepresentation()
approximator = PolicyApproximatorFromBasicObservation(discretizer, representation, environment, agent)

We generate a Policy Graph with the `fit()` function, in this case generating 200 episode trajectories from our agent. If the PG has been previously fit, one can choose to update the PG with new trajectories (instead of re-generating the PG) with `update=True`.

In [7]:
approximator.fit(n_episodes=200)
pg = approximator.policy_representation

Fitting policy approximator...: 100%|██████████| 200/200 [00:02<00:00, 76.91it/s] 


In [8]:
print(f'Number of states: {len(list(pg.states))}')
print(f'Number of transitions: {len(list(pg.transitions))}')


Number of states: 14
Number of transitions: 58


Each node has information about a discretized state:

In [9]:
arbitrary_state = next(iter(pg.states))

print(arbitrary_state)
print(f'  Times visited: {pg.states[arbitrary_state].metadata.frequency}')
print(f'  p(s):          {pg.states[arbitrary_state].metadata.probability:.3f}')

predicates=frozenset({Angle(STABILIZING_LEFT), Position(MIDDLE), Velocity(RIGHT)})
  Times visited: 9052
  p(s):          0.273


Each edge has information about a transition between states:

In [10]:
arbitrary_transition = next(iter(pg.transitions))
# The transition is already a TransitionData object
from_state = arbitrary_transition.from_state
to_state = arbitrary_transition.to_state
transition_data = arbitrary_transition

print(f'From:    {from_state}')
print(f'Action:  {transition_data.transition.action}')
print(f'To:      {to_state}')
print(f'  Times visited:      {transition_data.transition.frequency}')
print(f'  p(s_to,a | s_from): {transition_data.transition.probability:.3f}')

From:    predicates=frozenset({Angle(STABILIZING_LEFT), Position(MIDDLE), Velocity(RIGHT)})
Action:  0
To:      predicates=frozenset({Angle(STABILIZING_RIGHT), Position(MIDDLE), Velocity(RIGHT)})
  Times visited:      5
  p(s_to,a | s_from): 0.001
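
The numbers above are consistent with simple counting: 5 traversals of this edge out of 9052 visits to its source state is about 0.0006, which prints as 0.001 with three decimals. As a minimal sketch of that idea (an assumption about how such statistics can be estimated, not pgeon's actual code; `estimate_graph` and the toy states are hypothetical):

```python
from collections import Counter


def estimate_graph(triples):
    """Estimate node and edge statistics from (state, action, next_state) triples."""
    # Assumption: a state's frequency is the number of times it was departed from.
    node_freq = Counter(s for s, _, _ in triples)
    edge_freq = Counter(triples)
    total = sum(node_freq.values())
    # p(s): share of all recorded visits that fall on state s.
    node_prob = {s: n / total for s, n in node_freq.items()}
    # p(s_to, a | s_from): share of the departures from s_from that used this edge.
    edge_prob = {(s, a, t): n / node_freq[s] for (s, a, t), n in edge_freq.items()}
    return node_freq, node_prob, edge_freq, edge_prob


triples = [("A", 0, "A"), ("A", 0, "B"), ("A", 1, "B"), ("B", 1, "A")]
node_freq, node_prob, edge_freq, edge_prob = estimate_graph(triples)
print(node_freq["A"], node_prob["A"])  # 3 0.75
print(edge_prob[("A", 0, "B")])
```

With this definition, the probabilities of all edges leaving a state sum to 1.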


The approximator also stores the full discretized episode trajectories of the last fit.

In [11]:
len(approximator._trajectories_of_last_fit)

200

Each trajectory is stored as a flat `[state0, action0, state1, ..., stateN]` list.

In [12]:
approximator._trajectories_of_last_fit[0]

[(Position(MIDDLE), Velocity(LEFT), Angle(STUCK_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_RIGHT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(RIGHT), Angle(STABILIZING_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(RIGHT), Angle(S
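
A flat trajectory like the one above alternates states and actions, so it can be walked as (state, action, next state) triples. A minimal sketch, where `iter_transitions` is a hypothetical helper and not part of pgeon:

```python
def iter_transitions(trajectory):
    """Yield (state, action, next_state) triples from a flat trajectory."""
    # States sit at even indices, actions at odd indices.
    for i in range(0, len(trajectory) - 2, 2):
        yield trajectory[i], trajectory[i + 1], trajectory[i + 2]


toy_trajectory = ["s0", 0, "s1", 1, "s2"]
print(list(iter_transitions(toy_trajectory)))
# [('s0', 0, 's1'), ('s1', 1, 's2')]
```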

## Loading and saving Policy Graphs

### Pickle

Saving as pickle lets you restore the full state of the object.

In [13]:
approximator.save("pickle", "./ppo-cartpole.pickle")

In [14]:
approximator_from_pickle = PolicyApproximatorFromBasicObservation.from_pickle("./ppo-cartpole.pickle")

print(f"Number of states:             {len(list(approximator_from_pickle.policy_representation.states))}")
print(f"Number of transitions:             {len(list(approximator_from_pickle.policy_representation.transitions))}")
print(f"Num. of stored trajectories: {len(approximator_from_pickle._trajectories_of_last_fit)}")

Number of states:             14
Number of transitions:             58
Num. of stored trajectories: 200
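
Pickle can restore everything because it serializes arbitrary Python objects, including sets and nested containers. The sketch below shows the round trip with the standard `pickle` module on a toy dictionary; the `snapshot` structure is hypothetical and is not the layout pgeon writes to disk:

```python
import os
import pickle
import tempfile

# Toy stand-in for a fitted object: visit counts keyed by a set of predicates,
# plus the stored trajectories.
state = frozenset({"Position(MIDDLE)", "Velocity(LEFT)"})
snapshot = {"states": {state: 3}, "trajectories": [[state, 0, state]]}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy.pickle")
    with open(path, "wb") as f:
        pickle.dump(snapshot, f)
    with open(path, "rb") as f:
        restored = pickle.load(f)

print(restored == snapshot)  # True
```

Unpickling can execute arbitrary code, so only load pickle files from sources you trust.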


### CSV

Saving as CSV creates three separate CSV files for node, edge and trajectory information.

In [15]:
import csv

In [16]:
# TODO: fix
# approximator.save("csv", ["./ppo-cartpole_nodes.csv", "./ppo-cartpole_edges.csv"])

In [17]:
# TODO: fix
# with open('ppo-cartpole_nodes.csv', 'r+') as f:
#     csv_r = csv.reader(f)
#     for i in range(10):
#         print(next(csv_r))

Edges and trajectories use the IDs of the nodes, from the corresponding node CSV file.

In [18]:
# with open('ppo-cartpole_edges.csv', 'r+') as f:
#     csv_r = csv.reader(f)
#     for i in range(10):
#         print(next(csv_r))

Each trajectory is stored as a (state0, action0, state1, ..., stateN) list.

In [19]:
# with open('ppo-cartpole_trajectories.csv', 'r+') as f:
#     csv_r = csv.reader(f)
#     for i in range(1):
#         print(next(csv_r))
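
To illustrate how edges can refer to nodes by ID, here is a minimal sketch using the standard `csv` module and in-memory files. The column names (`id`, `value`, `from`, `to`, `action`, `frequency`) are assumptions made for this example and may not match the files pgeon produces:

```python
import csv
import io

states = ["Position(MIDDLE)&Velocity(LEFT)", "Position(MIDDLE)&Velocity(RIGHT)"]
edges = [(states[0], 0, states[1], 5), (states[1], 1, states[0], 2)]

# Assign each state a numeric ID; edges only store these IDs.
node_ids = {state: i for i, state in enumerate(states)}

nodes_file = io.StringIO(newline="")
writer = csv.writer(nodes_file)
writer.writerow(["id", "value"])
for state, node_id in node_ids.items():
    writer.writerow([node_id, state])

edges_file = io.StringIO(newline="")
writer = csv.writer(edges_file)
writer.writerow(["from", "to", "action", "frequency"])
for src, action, dst, freq in edges:
    writer.writerow([node_ids[src], node_ids[dst], action, freq])

# Reading back: resolve the IDs through the node file.
nodes_file.seek(0)
id_to_state = {int(row["id"]): row["value"] for row in csv.DictReader(nodes_file)}
edges_file.seek(0)
restored_edges = [
    (id_to_state[int(r["from"])], int(r["action"]), id_to_state[int(r["to"])], int(r["frequency"]))
    for r in csv.DictReader(edges_file)
]
print(restored_edges == edges)  # True
```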

There are two ways of loading Policy Graphs from CSV files: from nodes and trajectories, or from nodes and edges. When loading from nodes and edges, though, episode trajectories cannot be restored.

In [20]:
# pg_csv = PolicyApproximatorFromBasicObservation.from_nodes_and_trajectories("./ppo-cartpole_nodes.csv", "./ppo-cartpole_trajectories.csv",
#                                           environment, discretizer)
# print(f"Number of states:             {len(list(pg_csv.policy_representation.states))}")
# print(f"Number of transitions:             {len(list(pg_csv.policy_representation.transitions))}")
# print(f"Num. of stored trajectories: {len(pg_csv._trajectories_of_last_fit)}")

In [21]:
# pg_csv = PolicyApproximatorFromBasicObservation.from_nodes_and_edges("./ppo-cartpole_nodes.csv", "./ppo-cartpole_edges.csv",
#                                           environment, discretizer)
# print(f"Number of states:             {len(list(pg_csv.policy_representation.states))}")
# print(f"Number of transitions:             {len(list(pg_csv.policy_representation.transitions))}")
# print(f"Num. of stored trajectories: {len(pg_csv._trajectories_of_last_fit)}")

### Gram

PGs can also be exported to the [gram](https://neo4j.com/developer-blog/gram-a-data-graph-format/) format, allowing visualization using Neo4j. Episode trajectories cannot be stored in this format, though.

PGs currently cannot be loaded from a Gram file.

In [22]:
approximator.save("gram", "./ppo-cartpole.gram")

In [23]:
!head ./ppo-cartpole.gram


CREATE (s0:State {
  uid: "s0",
  value: "('predicates', frozenset({Angle(STABILIZING_LEFT), Position(MIDDLE), Velocity(RIGHT)}))",
  probability: 0.2730288954575617, 
  frequency:9052
});
CREATE (s1:State {
  uid: "s1",
  value: "('predicates', frozenset({Angle(STUCK_RIGHT), Position(MIDDLE), Velocity(LEFT)}))",


In [24]:
!tail ./ppo-cartpole.gram

MATCH (s11:State) WHERE s11.uid = "s11" MATCH (s2:State) WHERE s2.uid = "s2" CREATE (s11)-[:a1 {probability:0.024390243902439025, frequency:1}]->(s2);
MATCH (s11:State) WHERE s11.uid = "s11" MATCH (s11:State) WHERE s11.uid = "s11" CREATE (s11)-[:a0 {probability:0.024390243902439025, frequency:1}]->(s11);
MATCH (s11:State) WHERE s11.uid = "s11" MATCH (s1:State) WHERE s1.uid = "s1" CREATE (s11)-[:a1 {probability:0.0975609756097561, frequency:4}]->(s1);
MATCH (s11:State) WHERE s11.uid = "s11" MATCH (s0:State) WHERE s0.uid = "s0" CREATE (s11)-[:a1 {probability:0.3170731707317073, frequency:13}]->(s0);
MATCH (s11:State) WHERE s11.uid = "s11" MATCH (s3:State) WHERE s3.uid = "s3" CREATE (s11)-[:a1 {probability:0.17073170731707318, frequency:7}]->(s3);
MATCH (s12:State) WHERE s12.uid = "s12" MATCH (s11:State) WHERE s11.uid = "s11" CREATE (s12)-[:a0 {probability:0.0625, frequency:1}]->(s11);
MATCH (s12:State) WHERE s12.uid = "s12" MATCH (s0:State) WHERE s0.uid = "s0" CREATE (s12)-[:a1 {probabil
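
The statements above follow a regular pattern: one `CREATE` per node and one `MATCH ... CREATE` per edge. As a rough sketch of how such lines could be produced (the helpers `node_to_gram` and `edge_to_gram` are hypothetical, written only to mirror the output shown, and are not pgeon functions):

```python
def node_to_gram(uid, value, probability, frequency):
    """Format one node as a CREATE statement."""
    return (
        f"CREATE ({uid}:State {{\n"
        f'  uid: "{uid}",\n'
        f'  value: "{value}",\n'
        f"  probability: {probability},\n"
        f"  frequency: {frequency}\n"
        f"}});"
    )


def edge_to_gram(src, dst, action, probability, frequency):
    """Format one edge; the relationship type encodes the action."""
    return (
        f'MATCH ({src}:State) WHERE {src}.uid = "{src}" '
        f'MATCH ({dst}:State) WHERE {dst}.uid = "{dst}" '
        f"CREATE ({src})-[:a{action} {{probability:{probability}, frequency:{frequency}}}]->({dst});"
    )


print(node_to_gram("s0", "Position(MIDDLE)&Velocity(LEFT)", 0.25, 10))
print(edge_to_gram("s12", "s11", 0, 0.0625, 1))
```

This sketch does not escape double quotes inside `value`, which a real exporter would need to handle.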

## Using PG-based policies

Using the `PGBasedPolicy`, we can create policies that replicate an agent's behavior, based on its generated Policy Graph. These policies are subclasses of the `pgeon.Agent` class.

The policy mode (greedy/stochastic) can be specified via the `PGBasedPolicyMode` enum. The behavior when encountering an unknown node (select random action/search nearest node in PG) can be specified via the `PGBasedPolicyNodeNotFoundMode` enum.

In [25]:
# TODO: fix
# from pgeon import PGBasedPolicy, PGBasedPolicyMode, PGBasedPolicyNodeNotFoundMode

In [26]:
# TODO: fix
# policy = PGBasedPolicy(pg, mode=PGBasedPolicyMode.GREEDY,
#                        node_not_found_mode=PGBasedPolicyNodeNotFoundMode.RANDOM_UNIFORM)

In [27]:
# obs, _ = environment.reset()
# action = policy.act(obs)

# print(f'Observed state:  {obs}')
# print(f'Discretization:  {policy.pg.discretizer.discretize(obs)}')
# print(f'Selected action: {action}')
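
The difference between the greedy and stochastic modes, and the random fallback for unknown states, can be sketched without the library. In this illustration `select_action`, `action_probs` and `n_actions` are hypothetical names; it does not use the `PGBasedPolicy` API and does not cover the nearest-node fallback:

```python
import random


def select_action(action_probs, state, n_actions, greedy=True, rng=random):
    """Pick an action for `state` from per-state action probabilities."""
    probs = action_probs.get(state)
    if not probs:
        # Unknown state: fall back to a uniformly random action.
        return rng.randrange(n_actions)
    if greedy:
        # Greedy mode: always take the most probable action.
        return max(probs, key=probs.get)
    # Stochastic mode: sample an action according to its probability.
    actions = list(probs)
    return rng.choices(actions, weights=[probs[a] for a in actions], k=1)[0]


action_probs = {"s0": {0: 0.8, 1: 0.2}}
print(select_action(action_probs, "s0", n_actions=2))  # 0
print(select_action(action_probs, "s0", n_actions=2, greedy=False))
print(select_action(action_probs, "unseen", n_actions=2))
```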

## Implementing new Discretizers

To generate Policy Graphs for a given environment, you need a Discretizer that transforms each state into a set of predicates. Implement one by creating a class that inherits from `pgeon.Discretizer` and implements all of its abstract methods.

In [28]:
from enum import Enum, auto

from pgeon import Predicate

Firstly, a set of predicates and their values has to be decided. In this case we use three: the cartpole's `Position` (is the cart in the middle, left or right?) and `Velocity` (is the cart moving left or right?), and the state of its pole (`Angle`, meaning the pole is standing, falling to one side, stabilizing...).

Each of the predicates and its possible values are represented as an enum.

In [29]:
class Position(Enum):
    LEFT = auto()
    MIDDLE = auto()
    RIGHT = auto()

class Velocity(Enum):
    LEFT = auto()
    RIGHT = auto()

class Angle(Enum):
    STANDING = auto()
    STUCK_LEFT = auto()
    STUCK_RIGHT = auto()
    FALLING_LEFT = auto()
    FALLING_RIGHT = auto()
    STABILIZING_LEFT = auto()
    STABILIZING_RIGHT = auto()

This is an example of a state as a set of predicates. Note that a predicate accepts an ordered list of values (e.g. `[Position.LEFT, Velocity.RIGHT]`), as some environments benefit from that level of description.

In [30]:
PredicateBasedState([Predicate(Position.LEFT), Predicate(Velocity.LEFT), Predicate(Angle.STABILIZING_RIGHT)])

PredicateBasedState(predicates=frozenset({Angle(STABILIZING_RIGHT), Velocity(LEFT), Position(LEFT)}))

A discretizer class needs to implement the following methods:

- `discretize(self, state)`: Converts an environment's raw observation into a discretized state.
- `state_to_str(self, state) -> str`: Converts a discrete state into a string (used in serialization).
- `str_to_state(self, state: str)`: Converts a string representing a state into said state (used in serialization).
- `nearest_state(self, state)`: A generator function that, given a certain discrete state, yields the nearest discrete states, in order. The distance heuristic is left to the implementer.
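
To make the four methods concrete, here is a standalone sketch of what a CartPole discretizer could look like. It deliberately does not inherit from `pgeon.Discretizer` so that it runs on its own; the class name, the thresholds and the rules for the `Angle` values are assumptions for illustration and are not taken from the `CartpoleDiscretizer` used in this notebook:

```python
from enum import Enum, auto
from itertools import product


class Position(Enum):
    LEFT = auto()
    MIDDLE = auto()
    RIGHT = auto()


class Velocity(Enum):
    LEFT = auto()
    RIGHT = auto()


class Angle(Enum):
    STANDING = auto()
    STUCK_LEFT = auto()
    STUCK_RIGHT = auto()
    FALLING_LEFT = auto()
    FALLING_RIGHT = auto()
    STABILIZING_LEFT = auto()
    STABILIZING_RIGHT = auto()


class SketchCartpoleDiscretizer:
    # Hypothetical thresholds, chosen only for illustration.
    X_LIMIT, ANGLE_EPS, STUCK_EPS = 2.0, 0.0025, 0.1
    PREDICATES = {"Position": Position, "Velocity": Velocity, "Angle": Angle}

    def discretize(self, state):
        x, x_dot, theta, theta_dot = state
        if x < -self.X_LIMIT:
            position = Position.LEFT
        elif x > self.X_LIMIT:
            position = Position.RIGHT
        else:
            position = Position.MIDDLE
        velocity = Velocity.LEFT if x_dot < 0 else Velocity.RIGHT
        if abs(theta) < self.ANGLE_EPS:
            angle = Angle.STANDING
        elif abs(theta_dot) < self.STUCK_EPS:
            angle = Angle.STUCK_LEFT if theta < 0 else Angle.STUCK_RIGHT
        elif (theta < 0) == (theta_dot < 0):
            # Tilt and rotation point the same way: the pole is falling.
            angle = Angle.FALLING_LEFT if theta < 0 else Angle.FALLING_RIGHT
        else:
            angle = Angle.STABILIZING_LEFT if theta < 0 else Angle.STABILIZING_RIGHT
        return (position, velocity, angle)

    def state_to_str(self, state):
        return "&".join(f"{type(v).__name__}({v.name})" for v in state)

    def str_to_state(self, state):
        values = []
        for part in state.split("&"):
            name, _, value = part.rstrip(")").partition("(")
            values.append(self.PREDICATES[name][value])
        return tuple(values)

    def nearest_state(self, state):
        # Heuristic: states that differ in fewer predicates come first.
        candidates = [c for c in product(Position, Velocity, Angle) if c != state]
        candidates.sort(key=lambda c: sum(a is not b for a, b in zip(c, state)))
        yield from candidates


sketch = SketchCartpoleDiscretizer()
sketch_state = sketch.discretize([0.0, 0.05, -0.02, -0.04])
print(sketch.state_to_str(sketch_state))
# Position(MIDDLE)&Velocity(RIGHT)&Angle(STUCK_LEFT)
print(sketch.str_to_state(sketch.state_to_str(sketch_state)) == sketch_state)  # True
```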

This is an example use of these methods:

In [31]:
obs, _ = environment.reset()
discretized_obs = discretizer.discretize(obs)
str_obs = discretizer.state_to_str(discretized_obs)
str_to_state = discretizer.str_to_state(str_obs)

In [32]:
print(f'Observed state:  {obs}')
print(f'Discretization:  {discretized_obs}')
print(f'State to str:    {str_obs}')
print(f'Str to state:    {str_to_state}')

Observed state:  [ 0.00483929  0.04979198 -0.02431643 -0.04302176]
Discretization:  (Position(MIDDLE), Velocity(RIGHT), Angle(STUCK_LEFT))
State to str:    Position(MIDDLE)&Velocity(RIGHT)&Angle(STUCK_LEFT)
Str to state:    (Position(MIDDLE), Velocity(RIGHT), Angle(STUCK_LEFT))

## Asking questions to the Policy Graph

Once fitted, the approximator can answer three kinds of questions about the agent's behavior: which actions it takes in a given state (`question1`), in which states it performs a given action (`question2`), and why it did not perform a given action in a given state (`question3`).


In [33]:
# Create a PredicateBasedState for the query
query_state = PredicateBasedState([
    Predicate(Position.MIDDLE),
    Predicate(Velocity.RIGHT),
    Predicate(Angle.STANDING),
])

possible_actions = approximator.question1(query_state)

print(f'From {discretized_obs}, I will take one of these actions:')
for action, prob in possible_actions:
    action_name = "LEFT" if action == 0 else "RIGHT"
    print(f"	-> {action_name}	Prob: {round(prob * 100, 2)}%")

From (Position(MIDDLE), Velocity(RIGHT), Angle(STUCK_LEFT)), I will take one of these actions:
	-> LEFT	Prob: 59.52%
	-> LEFT	Prob: 26.18%
	-> LEFT	Prob: 8.42%
	-> LEFT	Prob: 3.78%
	-> LEFT	Prob: 0.52%
	-> LEFT	Prob: 0.45%
	-> LEFT	Prob: 0.41%
	-> LEFT	Prob: 0.27%
	-> LEFT	Prob: 0.18%
	-> LEFT	Prob: 0.13%
	-> LEFT	Prob: 0.09%
	-> LEFT	Prob: 0.06%


In [34]:
best_states = approximator.question2(0)
print(f'I will perform action {0} in these states:')
print('\n'.join([str(state) for state in best_states]))

I will perform action 0 in these states:
predicates=frozenset({Position(MIDDLE), Velocity(LEFT), Angle(STUCK_LEFT)})
predicates=frozenset({Angle(STABILIZING_LEFT), Position(RIGHT), Velocity(RIGHT)})
predicates=frozenset({Angle(STABILIZING_LEFT), Position(MIDDLE), Velocity(RIGHT)})
predicates=frozenset({Angle(STABILIZING_RIGHT), Position(MIDDLE), Velocity(LEFT)})
predicates=frozenset({Angle(STABILIZING_RIGHT), Position(MIDDLE), Velocity(RIGHT)})
predicates=frozenset({Position(MIDDLE), Velocity(RIGHT), Angle(STUCK_LEFT)})
predicates=frozenset({Angle(STUCK_RIGHT), Position(RIGHT), Velocity(RIGHT)})
predicates=frozenset({Position(MIDDLE), Angle(STANDING), Velocity(RIGHT)})
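
The first two queries can be pictured as lookups over per-state action probabilities. The sketch below is an illustration on toy data, under the assumption that "performing" an action in a state means it is the most probable one there; the function names are hypothetical and this does not reproduce pgeon's implementation:

```python
def actions_in_state(action_probs, state):
    """Return (action, probability) pairs for `state`, most likely first."""
    probs = action_probs.get(state, {})
    return sorted(probs.items(), key=lambda item: item[1], reverse=True)


def states_for_action(action_probs, action):
    """Return the states whose most probable action is `action`."""
    return [
        state
        for state, probs in action_probs.items()
        if probs and max(probs, key=probs.get) == action
    ]


toy_probs = {
    "pole_left": {0: 0.9, 1: 0.1},
    "pole_right": {0: 0.2, 1: 0.8},
    "standing": {0: 0.6, 1: 0.4},
}
print(actions_in_state(toy_probs, "pole_right"))  # [(1, 0.8), (0, 0.2)]
print(states_for_action(toy_probs, 0))  # ['pole_left', 'standing']
```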


In [35]:
print(f'Supposing I was in the middle, moving right, with the pole standing upright, '
      f'if I did not choose to move left was due to...')
try:
    counterfactuals = approximator.question3(query_state, 0)
    for ct in counterfactuals:
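        # Build the message without nesting double quotes inside an f-string,
        # which is a syntax error before Python 3.12.
        print("..." + " and ".join(f"{i[0]} -> {i[1]}" for i in ct.values()))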
except ValueError as e:
    print(f"Could not generate counterfactuals: {e}")
    print("This might be because the query state is not well-represented in the policy graph.")

Supposing I was in the middle, moving right, with the pole standing upright, if I did not choose to move left was due to...
Could not generate counterfactuals: 'a' cannot be empty unless no samples are taken
This might be because the query state is not well-represented in the policy graph.
