# XAI In Action - pgeon

This notebook showcases the current functionality of the **pgeon** library.

## Preparation

We load an environment, an agent and a discretizer: the three elements needed to generate a Policy Graph.

In [1]:
import gymnasium as gym

from example.cartpole.discretizer import CartpoleDiscretizer

In [2]:
import torch

In [3]:
environment = gym.make('CartPole-v1')
discretizer = CartpoleDiscretizer()

In [4]:
from pgeon import Agent
from ray.rllib.algorithms.algorithm import Algorithm

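class CartpoleAgent(Agent):
    def __init__(self, path):
        # Restore a trained RLlib algorithm (here, PPO) from a checkpoint directory
        self.agent = Algorithm.from_checkpoint(path)

    def act(self, state):
        # Map a raw CartPole observation to an action (0: push left, 1: push right)
        return self.agent.compute_single_action(state)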
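`CartpoleAgent` only needs an `act(state)` method that maps an observation to an action. As a minimal sketch, assuming that method is the whole interface `fit()` relies on, a hand-written heuristic agent could look like this. The class name and the 0.5 weighting are hypothetical; in the notebook the class would inherit from `pgeon.Agent`, but it is shown standalone here so it runs without pgeon or Ray.

```python
class HeuristicCartpoleAgent:
    """Hypothetical stand-in for CartpoleAgent (would subclass pgeon.Agent in the notebook)."""

    def act(self, state):
        # Gymnasium CartPole observations are [cart position, cart velocity,
        # pole angle, pole angular velocity]; actions are 0 (push left) and 1 (push right).
        _, _, angle, angular_velocity = state
        # Push the cart towards the side the pole is leaning or rotating to
        # (the 0.5 weight on angular velocity is an arbitrary illustrative choice).
        return 1 if angle + 0.5 * angular_velocity > 0 else 0


agent = HeuristicCartpoleAgent()
print(agent.act([0.0, 0.0, 0.05, 0.0]))   # pole leans right -> 1
print(agent.act([0.0, 0.0, -0.05, 0.0]))  # pole leans left -> 0
```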

In [5]:
agent = CartpoleAgent('checkpoints/PPO_CartPole-v1_1acbb_00000_0_2023-12-05_19-28-36/checkpoint_000000')



## Policy Graph generation

In [6]:
from pgeon import PolicyGraph

Policy Graphs are instantiated with an environment and a discretizer.

In [7]:
pg = PolicyGraph(environment, discretizer)

We generate a Policy Graph with the `fit()` method, in this case from 200 episode trajectories of our agent. If the PG has already been fit, it can be updated with new trajectories (instead of being regenerated from scratch) by passing `update=True`.
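As a rough mental model, and not pgeon's actual implementation, fitting can be thought of as counting visited states and transitions over the generated trajectories; `update=True` would then add the new counts to the existing ones instead of starting over. The function name and signature below are hypothetical, and trajectories use the alternating `[s0, a0, s1, ..., sN]` layout the notebook shows later.

```python
from collections import Counter


def count_trajectories(trajectories, node_counts=None, edge_counts=None, update=False):
    """Hypothetical sketch of the counting behind fit(): node and edge frequencies.

    Each trajectory is a list [s0, a0, s1, a1, ..., sN]. With update=True the new
    counts are added to copies of the existing ones; otherwise counting starts fresh.
    """
    if not update or node_counts is None or edge_counts is None:
        node_counts, edge_counts = Counter(), Counter()
    else:
        node_counts, edge_counts = Counter(node_counts), Counter(edge_counts)
    for traj in trajectories:
        states, actions = traj[0::2], traj[1::2]
        node_counts.update(states)
        # Each (s_from, s_to, action) triple is one edge of the graph.
        edge_counts.update(zip(states, states[1:], actions))
    return node_counts, edge_counts


nodes, edges = count_trajectories([['A', 0, 'B', 1, 'A']])
nodes, edges = count_trajectories([['B', 1, 'A']], nodes, edges, update=True)
print(nodes['A'], nodes['B'])  # 3 2
print(edges[('B', 'A', 1)])    # 2
```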

In [8]:
pg = pg.fit(agent, num_episodes=200, update=False)

Fitting PG...: 100%|██████████| 200/200 [00:27<00:00,  7.31it/s]


In [9]:
print(f'Number of nodes: {len(pg.nodes)}')
print(f'Number of edges: {len(pg.edges)}')

Number of nodes: 14
Number of edges: 136


Each node holds information about a discretized state:

In [10]:
arbitrary_state = list(pg.nodes)[0]

print(arbitrary_state)
print(f'  Times visited: {pg.nodes[arbitrary_state]["frequency"]}')
print(f'  p(s):          {pg.nodes[arbitrary_state]["probability"]:.3f}')

(Position(MIDDLE), Velocity(LEFT), Angle(FALLING_LEFT))
  Times visited: 184
  p(s):          0.005


Each edge has information about a transition between states:

In [11]:
arbitrary_edge = list(pg.edges)[0]

print(f'From:    {arbitrary_edge[0]}')
print(f'Action:  {arbitrary_edge[2]}')
print(f'To:      {arbitrary_edge[1]}')
print(f'  Times visited:      {pg[arbitrary_edge[0]][arbitrary_edge[1]][arbitrary_edge[2]]["frequency"]}')
print(f'  p(s_to,a | s_from): {pg[arbitrary_edge[0]][arbitrary_edge[1]][arbitrary_edge[2]]["probability"]:.3f}')

From:    (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_LEFT))
Action:  1
To:      (Position(MIDDLE), Velocity(RIGHT), Angle(FALLING_LEFT))
  Times visited:      9
  p(s_to,a | s_from): 0.049
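The printed probability is consistent with dividing the edge's frequency by the number of transitions leaving the source state: the edges CSV further down lists five outgoing edges from this node with frequencies 9, 126, 47, 1 and 1, which sum to 184, and 9 / 184 ≈ 0.049. A minimal sketch of that normalization, assuming this is how pgeon computes it (the helper name is hypothetical):

```python
from collections import defaultdict


def edge_probabilities(edge_counts):
    """Normalize edge frequencies into p(s_to, a | s_from).

    edge_counts maps (s_from, s_to, action) -> frequency. Each frequency is divided
    by the total number of transitions leaving s_from.
    """
    outgoing = defaultdict(int)
    for (s_from, _, _), freq in edge_counts.items():
        outgoing[s_from] += freq
    return {edge: freq / outgoing[edge[0]] for edge, freq in edge_counts.items()}


# Outgoing edges of node 0, with the frequencies from the edges CSV shown later.
node0_edges = {(0, 3, 1): 9, (0, 2, 0): 126, (0, 4, 0): 47, (0, 0, 1): 1, (0, 0, 0): 1}
probs = edge_probabilities(node0_edges)
print(round(probs[(0, 3, 1)], 3))  # 0.049
```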


The `PolicyGraph` object also stores the full discretized episode trajectories of the last fit.

In [12]:
len(pg._trajectories_of_last_fit)

200

Each trajectory is stored as a list of the form [state0, action0, state1, ..., stateN].
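Because states and actions alternate, consecutive (state, action, next state) triples can be recovered with slicing. A minimal helper with a hypothetical name, using plain strings in place of discretized states:

```python
def to_transitions(trajectory):
    """Split a [s0, a0, s1, a1, ..., sN] list into (s_t, a_t, s_{t+1}) triples."""
    states, actions = trajectory[0::2], trajectory[1::2]
    if len(states) != len(actions) + 1:
        raise ValueError('trajectory must alternate states and actions and end with a state')
    return list(zip(states[:-1], actions, states[1:]))


print(to_transitions(['s0', 1, 's1', 0, 's2']))
# [('s0', 1, 's1'), ('s1', 0, 's2')]
```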

In [13]:
pg._trajectories_of_last_fit[0]

[(Position(MIDDLE), Velocity(LEFT), Angle(STUCK_LEFT)),
 1,
 (Position(MIDDLE), Velocity(RIGHT), Angle(FALLING_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_LEFT)),
 1,
 (Position(MIDDLE), Velocity(RIGHT), Angle(FALLING_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_LEFT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_LEFT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_LEFT)),
 0,
 ...

## Loading and saving Policy Graphs

### Pickle

Saving as pickle lets you restore the full state of the object.
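This is the usual behaviour of Python's `pickle` module: every attribute of the object, including private ones such as the stored trajectories, is serialized. A standalone illustration with a hypothetical stand-in object (a `SimpleNamespace`, not `PolicyGraph` itself):

```python
import pickle
from types import SimpleNamespace

# Hypothetical stand-in for a fitted graph: nodes, edges and stored trajectories.
original = SimpleNamespace(
    nodes={'s0': {'frequency': 2}},
    edges={('s0', 's0', 0): {'frequency': 1}},
    _trajectories_of_last_fit=[['s0', 0, 's0']],
)

restored = pickle.loads(pickle.dumps(original))
print(restored._trajectories_of_last_fit)  # [['s0', 0, 's0']]
print(restored == original, restored is original)  # True False
```

One consequence is that unpickling needs the same classes to be importable (for instance, the discretizer's predicate enums), and, as with any pickle, files should only be loaded from trusted sources.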

In [14]:
pg.save('pickle', './ppo-cartpole.pickle')

In [15]:
pg_pickle = PolicyGraph.from_pickle('./ppo-cartpole.pickle')

print(f'Number of nodes:             {len(pg_pickle.nodes)}')
print(f'Number of edges:             {len(pg_pickle.edges)}')
print(f'Num. of stored trajectories: {len(pg_pickle._trajectories_of_last_fit)}')

Number of nodes:             14
Number of edges:             136
Num. of stored trajectories: 200


### CSV

Saving as CSV creates three separate CSV files, holding node, edge and trajectory information respectively.

In [16]:
import csv

In [17]:
pg.save('csv', ['./ppo-cartpole_nodes.csv', './ppo-cartpole_edges.csv', './ppo-cartpole_trajectories.csv'])

In [18]:
with open('ppo-cartpole_nodes.csv', newline='') as f:
    csv_r = csv.reader(f)
    # Print the header and the first nine node rows
    for _ in range(10):
        print(next(csv_r))

['id', 'value', 'p(s)', 'frequency']
['0', 'Position(MIDDLE)&Velocity(LEFT)&Angle(FALLING_LEFT)', '0.004879342349509414', '184']
['1', 'Position(MIDDLE)&Velocity(RIGHT)&Angle(STABILIZING_LEFT)', '0.2239459029435163', '8445']
['2', 'Position(MIDDLE)&Velocity(LEFT)&Angle(STABILIZING_RIGHT)', '0.07340228056218509', '2768']
['3', 'Position(MIDDLE)&Velocity(RIGHT)&Angle(FALLING_LEFT)', '0.06070007955449483', '2289']
['4', 'Position(MIDDLE)&Velocity(LEFT)&Angle(STUCK_LEFT)', '0.024051975603288252', '907']
['5', 'Position(MIDDLE)&Velocity(LEFT)&Angle(STUCK_RIGHT)', '0.03898170246618934', '1470']
['6', 'Position(MIDDLE)&Velocity(LEFT)&Angle(FALLING_RIGHT)', '0.12686290108724477', '4784']
['7', 'Position(MIDDLE)&Velocity(RIGHT)&Angle(STUCK_RIGHT)', '0.16727658446035534', '6308']
['8', 'Position(MIDDLE)&Velocity(LEFT)&Angle(STABILIZING_LEFT)', '0.005118005833996287', '193']


Edges and trajectories use the IDs of the nodes, from the corresponding node CSV file.
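To read an edge back as a pair of states, the node file can be loaded into an ID lookup table. A minimal sketch using the standard `csv` module on the first node rows printed above, inlined through `io.StringIO` so it runs without the files; the edge row is the first data row of the edges file printed below.

```python
import csv
import io

# Two rows of ppo-cartpole_nodes.csv, as printed above.
nodes_csv = io.StringIO(
    'id,value,p(s),frequency\n'
    '0,Position(MIDDLE)&Velocity(LEFT)&Angle(FALLING_LEFT),0.004879342349509414,184\n'
    '3,Position(MIDDLE)&Velocity(RIGHT)&Angle(FALLING_LEFT),0.06070007955449483,2289\n'
)
# Map node ID -> state string.
id_to_state = {row['id']: row['value'] for row in csv.DictReader(nodes_csv)}

# An edge row: from, to, action, probability, frequency.
edge_row = ['0', '3', '1', '0.04891304347826087', '9']
s_from, s_to = id_to_state[edge_row[0]], id_to_state[edge_row[1]]
print(f'{s_from} --a{edge_row[2]}--> {s_to}')
```

This resolves to the same transition that was printed from `pg.edges` earlier (from FALLING_LEFT moving left, action 1, to FALLING_LEFT moving right).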

In [19]:
with open('ppo-cartpole_edges.csv', newline='') as f:
    csv_r = csv.reader(f)
    # Print the header and the first nine edge rows
    for _ in range(10):
        print(next(csv_r))

['from', 'to', 'action', 'p(s)', 'frequency']
['0', '3', '1', '0.04891304347826087', '9']
['0', '2', '0', '0.6847826086956522', '126']
['0', '4', '0', '0.2554347826086957', '47']
['0', '0', '1', '0.005434782608695652', '1']
['0', '0', '0', '0.005434782608695652', '1']
['1', '5', '0', '0.06052351060049745', '511']
['1', '6', '0', '0.08930474949662442', '754']
['1', '7', '0', '0.31126376880255835', '2628']
['1', '9', '0', '0.17150302025346442', '1448']


Each trajectory is stored as one CSV row of the form state0, action0, state1, ..., stateN, where states are node IDs:
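In each trajectory row, even positions hold node IDs and odd positions hold actions; the row printed below starts with `4`, `1`, `3`, meaning node 4, then action 1, then node 3. A minimal decoder, assuming integer actions as in CartPole; the function name and the ID map are hypothetical (in practice the map would be read from the nodes CSV):

```python
def decode_trajectory_row(row, id_to_state):
    """Turn a trajectory CSV row of node IDs and actions into states and int actions.

    Even positions hold node IDs (looked up in id_to_state); odd positions hold actions.
    """
    return [id_to_state[v] if i % 2 == 0 else int(v) for i, v in enumerate(row)]


# Hypothetical ID map with placeholder state names.
id_to_state = {'3': 'S3', '4': 'S4'}
print(decode_trajectory_row(['4', '1', '3', '0', '4'], id_to_state))
# ['S4', 1, 'S3', 0, 'S4']
```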

In [20]:
with open('ppo-cartpole_trajectories.csv', newline='') as f:
    csv_r = csv.reader(f)
    # Print the first stored trajectory
    print(next(csv_r))

['4', '1', '3', '0', '4', '0', '2', '1', '0', '1', '3', '0', '0', '0', '2', '1', '0', '0', '4', '1', '0', '0', '4', '1', '0', '0', '4', '0', '2', '1', '4', '0', '2', '0', '2', '1', '2', '0', '2', '1', '2', '0', '2', '1', '2', '1', '0', '0', '2', '0', '2', '1', '4', '0', '2', '1', '4', '0', '2', '0', '2', '1', '6', '1', '5', '0', '6', '1', '5', '1', '8', '0', '6', '1', '8', '0', '6', '1', '8', '0', '6', '0', '6', '1', '6', '1', '8', '0', '6', '1', '8', '0', '6', '1', '5', '0', '6', '0', '6', '1', '6', '1', '5', '0', '6', '1', '5', '0', '6', '1', '5', '0', '6', '1', '5', '1', '8', '0', '6', '1', '8', '0', '6', '0', '6', '1', '6', '1', '5', '0', '6', '1', '5', '1', '1', '0', '5', '1', '1', '0', '6', '1', '1', '0', '6', '1', '1', '0', '6', '1', '7', '1', '1', '0', '7', '0', '6', '1', '7', '1', '1', '0', '7', '1', '1', '0', '9', '1', '1', '1', '1', '0', '7', '0', '9', '1', '7', '0', '9', '1', '7', '0', '9', '1', '7', '1', '1', '0', '9', '0', '9', '1', '9', '1', '1', '1', '1', '0', '7', '1',

There are two ways of loading Policy Graphs from CSV files: from nodes and trajectories, or from nodes and edges. When loading from nodes and edges, however, episode trajectories cannot be restored.
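The reason is that node and edge frequencies are aggregates: counting over stored trajectories yields the edges (consistent with `from_nodes_and_trajectories` recovering all 136 edges), but different sets of trajectories can produce exactly the same counts, so the episodes cannot be rebuilt from them. A small counterexample with hypothetical states `A` and `B`:

```python
from collections import Counter


def aggregate_counts(trajectories):
    """Node and (s_from, s_to, action) edge frequencies over [s0, a0, ..., sN] trajectories."""
    nodes, edges = Counter(), Counter()
    for traj in trajectories:
        states, actions = traj[0::2], traj[1::2]
        nodes.update(states)
        edges.update(zip(states, states[1:], actions))
    return nodes, edges


# Two different sets of two episodes each...
run_a = [['A', 0, 'B', 1, 'A'], ['B', 1, 'A', 0, 'B']]
run_b = [['A', 0, 'B', 1, 'A', 0, 'B'], ['B', 1, 'A']]
# ...yield identical node and edge frequencies.
print(aggregate_counts(run_a) == aggregate_counts(run_b))  # True
print(run_a == run_b)  # False
```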

In [21]:
pg_csv = PolicyGraph.from_nodes_and_trajectories('./ppo-cartpole_nodes.csv', './ppo-cartpole_trajectories.csv',
                                                 environment, discretizer)
print(f'Number of nodes:             {len(pg_csv.nodes)}')
print(f'Number of edges:             {len(pg_csv.edges)}')
print(f'Num. of stored trajectories: {len(pg_csv._trajectories_of_last_fit)}')

Number of nodes:             14
Number of edges:             136
Num. of stored trajectories: 200


In [22]:
pg_csv = PolicyGraph.from_nodes_and_edges('./ppo-cartpole_nodes.csv', './ppo-cartpole_edges.csv',
                                          environment, discretizer)
print(f'Number of nodes:             {len(pg_csv.nodes)}')
print(f'Number of edges:             {len(pg_csv.edges)}')
print(f'Num. of stored trajectories: {len(pg_csv._trajectories_of_last_fit)}')

Number of nodes:             14
Number of edges:             136
Num. of stored trajectories: 0


### Gram

PGs can also be exported to the [gram](https://neo4j.com/developer-blog/gram-a-data-graph-format/) format, allowing visualization using Neo4j. Episode trajectories cannot be stored in this format, though.

PGs currently cannot be loaded from a Gram file.
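The excerpts below show the structure of the generated file: one `CREATE` statement per state and one `MATCH ... CREATE` statement per edge, with the action encoded in the relationship type (`a0`, `a1`). A hedged sketch that reproduces that layout from Python values; the helper names are hypothetical and pgeon's exact whitespace may differ.

```python
def node_to_gram(node_id, value, probability, frequency):
    """Format one state as a Cypher CREATE statement, mirroring the .gram excerpt."""
    uid = f's{node_id}'
    # Values are inserted verbatim; strings containing double quotes would need escaping.
    return (f'CREATE ({uid}:State {{\n'
            f'  uid: "{uid}",\n'
            f'  value: "{value}",\n'
            f'  probability: {probability},\n'
            f'  frequency:{frequency}\n'
            f'}});')


def edge_to_gram(from_id, to_id, action, probability, frequency):
    """Format one edge; the action becomes the relationship type (e.g. a0, a1)."""
    f_uid, t_uid = f's{from_id}', f's{to_id}'
    return (f'MATCH ({f_uid}:State) WHERE {f_uid}.uid = "{f_uid}" '
            f'MATCH ({t_uid}:State) WHERE {t_uid}.uid = "{t_uid}" '
            f'CREATE ({f_uid})-[:a{action} {{probability:{probability}, frequency:{frequency}}}]->({t_uid});')


print(edge_to_gram(13, 9, 1, 0.05405405405405406, 8))
```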

In [23]:
pg.save('gram', './ppo-cartpole.gram')

In [24]:
!head ./ppo-cartpole.gram


CREATE (s0:State {
  uid: "s0",
  value: "Position(MIDDLE)&Velocity(LEFT)&Angle(FALLING_LEFT)",
  probability: 0.004879342349509414, 
  frequency:184
});
CREATE (s1:State {
  uid: "s1",
  value: "Position(MIDDLE)&Velocity(RIGHT)&Angle(STABILIZING_LEFT)",


In [25]:
!tail ./ppo-cartpole.gram

MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s9:State) WHERE s9.uid = "s9" CREATE (s13)-[:a1 {probability:0.05405405405405406, frequency:8}]->(s9);
MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s6:State) WHERE s6.uid = "s6" CREATE (s13)-[:a0 {probability:0.14864864864864866, frequency:22}]->(s6);
MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s6:State) WHERE s6.uid = "s6" CREATE (s13)-[:a1 {probability:0.0945945945945946, frequency:14}]->(s6);
MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s7:State) WHERE s7.uid = "s7" CREATE (s13)-[:a1 {probability:0.20945945945945946, frequency:31}]->(s7);
MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s4:State) WHERE s4.uid = "s4" CREATE (s13)-[:a0 {probability:0.02027027027027027, frequency:3}]->(s4);
MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s5:State) WHERE s5.uid = "s5" CREATE (s13)-[:a1 {probability:0.12837837837837837, frequency:19}]->(s5);
MATCH (s13:State) WHERE s13.uid = "s13" MATCH (s13:State) WHERE s13.uid = "s13" CREATE (s13

## Using PG-based policies

Using `PGBasedPolicy`, we can create policies that replicate an agent's behavior based on its generated Policy Graph. These policies are subclasses of `pgeon.Agent`.

The policy mode (greedy or stochastic) is specified via the `PGBasedPolicyMode` enum. The behavior when the current state is not a node of the PG (select a random action, or search for the nearest node in the PG) is specified via the `PGBasedPolicyNodeNotFoundMode` enum.
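Conceptually, a greedy policy picks the most probable action for the current discrete state, a stochastic one samples from p(a | s), and the node-not-found behaviour decides what happens for states absent from the graph. A plain-Python sketch of that logic, assuming per-state action distributions are already available; the function, its `'greedy'`/`'stochastic'` strings, the `nearest_states` callback and the default action space are hypothetical and not pgeon's API.

```python
import random


def select_action(action_probs, state, mode='greedy', nearest_states=None,
                  action_space=(0, 1), rng=random):
    """Hypothetical sketch of PG-based action selection.

    action_probs maps a discrete state to {action: p(a | s)}.
    """
    dist = action_probs.get(state)
    if dist is None and nearest_states is not None:
        # Walk the nearest states (closest first) until one is in the graph.
        for candidate in nearest_states(state):
            if candidate in action_probs:
                dist = action_probs[candidate]
                break
    if dist is None:
        # Unknown state and no usable neighbour: pick uniformly at random.
        return rng.choice(list(action_space))
    if mode == 'greedy':
        return max(dist, key=dist.get)
    # Stochastic mode: sample an action proportionally to p(a | s).
    actions = list(dist)
    return rng.choices(actions, weights=[dist[a] for a in actions], k=1)[0]


probs = {'S': {0: 0.9, 1: 0.1}}
print(select_action(probs, 'S'))  # 0
print(select_action(probs, 'T', nearest_states=lambda s: iter(['U', 'S'])))  # 0
```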

In [26]:
from pgeon import PGBasedPolicy, PGBasedPolicyMode, PGBasedPolicyNodeNotFoundMode

In [27]:
policy = PGBasedPolicy(pg, mode=PGBasedPolicyMode.GREEDY,
                       node_not_found_mode=PGBasedPolicyNodeNotFoundMode.RANDOM_UNIFORM)

In [28]:
obs, _ = environment.reset()
action = policy.act(obs)

print(f'Observed state:  {obs}')
print(f'Discretization:  {policy.pg.discretizer.discretize(obs)}')
print(f'Selected action: {action}')

Observed state:  [ 0.00483929  0.04979198 -0.02431643 -0.04302176]
Discretization:  (Position(MIDDLE), Velocity(RIGHT), Angle(STUCK_LEFT))
Selected action: 0


## Implementing new Discretizers

To generate Policy Graphs for a given environment, you need a Discretizer that transforms the environment's state into a set of predicates. It is implemented by creating a class that inherits from `pgeon.Discretizer` and implements all of its abstract methods.

In [29]:
from enum import Enum, auto

from pgeon import Predicate

First, a set of predicates and their possible values has to be chosen. In this case we use three: the cart's `Position` (is the cart in the middle, on the left or on the right?), its `Velocity` (is the cart moving left or right?), and the state of the pole (`Angle`: is the pole standing, falling to one side, stabilizing...?).

Each predicate, together with its possible values, is represented as an enum.

In [30]:
class Position(Enum):
    LEFT = auto()
    MIDDLE = auto()
    RIGHT = auto()

class Velocity(Enum):
    LEFT = auto()
    RIGHT = auto()

class Angle(Enum):
    STANDING = auto()
    STUCK_LEFT = auto()
    STUCK_RIGHT = auto()
    FALLING_LEFT = auto()
    FALLING_RIGHT = auto()
    STABILIZING_LEFT = auto()
    STABILIZING_RIGHT = auto()

This is an example of a state as a set of predicates. Note that a predicate accepts an ordered list of values (e.g. `[Position.LEFT, Velocity.RIGHT]`), as some environments benefit from that level of description.

In [31]:
Predicate(Position, [Position.LEFT]), Predicate(Velocity, [Velocity.LEFT]), Predicate(Angle, [Angle.STABILIZING_RIGHT])

(Position(LEFT), Velocity(LEFT), Angle(STABILIZING_RIGHT))

A discretizer class needs to implement the following methods:

- `discretize(self, state)`: Converts an environment's raw observation into a discretized state.
- `state_to_str(self, state) -> str`: Converts a discrete state into a string (used in serialization).
- `str_to_state(self, state: str)`: Converts a string representing a state into said state (used in serialization).
- `nearest_state(self, state)`: A generator function that, given a certain discrete state, yields the nearest discrete states, in order. The distance heuristic is left to the implementer.
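As a rough, standalone sketch of these four methods, here is a hypothetical two-predicate discretizer (position and velocity only; the position threshold is illustrative and not taken from `CartpoleDiscretizer`). In the notebook such a class would inherit from `pgeon.Discretizer` and wrap values in `Predicate`; plain enum tuples are used here so the code runs without pgeon, and the string format imitates the `Position(MIDDLE)&...` strings shown in this notebook.

```python
from enum import Enum, auto
from itertools import product


class Position(Enum):
    LEFT = auto()
    MIDDLE = auto()
    RIGHT = auto()


class Velocity(Enum):
    LEFT = auto()
    RIGHT = auto()


# Order of the predicates inside a discrete state.
PREDICATES = (Position, Velocity)


class ToyCartpoleDiscretizer:
    """Hypothetical two-predicate discretizer (no pole angle), for illustration only."""

    def __init__(self, position_threshold=0.5):
        # Illustrative threshold on the cart position; not taken from CartpoleDiscretizer.
        self.position_threshold = position_threshold

    def discretize(self, state):
        # Gymnasium CartPole observation: [cart position, cart velocity, pole angle, pole angular velocity]
        x, x_dot = state[0], state[1]
        if x < -self.position_threshold:
            position = Position.LEFT
        elif x > self.position_threshold:
            position = Position.RIGHT
        else:
            position = Position.MIDDLE
        velocity = Velocity.LEFT if x_dot < 0 else Velocity.RIGHT
        return (position, velocity)

    def state_to_str(self, state):
        # e.g. (Position.MIDDLE, Velocity.LEFT) -> 'Position(MIDDLE)&Velocity(LEFT)'
        return '&'.join(f'{type(value).__name__}({value.name})' for value in state)

    def str_to_state(self, state):
        values = []
        for predicate, token in zip(PREDICATES, state.split('&')):
            name, value_name = token.rstrip(')').split('(')
            if name != predicate.__name__:
                raise ValueError(f'expected {predicate.__name__}, got {name}')
            values.append(predicate[value_name])
        return tuple(values)

    def nearest_state(self, state):
        # Yield every other discrete state, closest first, using the number of
        # differing predicates (Hamming distance) as the distance heuristic.
        others = [s for s in product(*PREDICATES) if s != state]
        others.sort(key=lambda s: sum(a != b for a, b in zip(s, state)))
        yield from others


d = ToyCartpoleDiscretizer()
s = d.discretize([-0.0146, -0.0199, 0.0462, 0.0422])
print(d.state_to_str(s))  # Position(MIDDLE)&Velocity(LEFT)
```

Hamming distance is only one possible heuristic; a domain-aware ordering (for instance, treating `LEFT` as closer to `MIDDLE` than to `RIGHT`) may suit other environments better.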

This is an example use of these methods:

In [32]:
obs, _ = environment.reset()
discretized_obs = discretizer.discretize(obs)
str_obs = discretizer.state_to_str(discretized_obs)
str_to_state = discretizer.str_to_state(str_obs)

In [33]:
print(f'Observed state:  {obs}')
print(f'Discretization:  {discretized_obs}')
print(f'State to str:    {str_obs}')
print(f'Str to state:    {str_to_state}')

Observed state:  [-0.01459135 -0.01986943  0.04615097  0.04219142]
Discretization:  (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_RIGHT))
State to str:    Position(MIDDLE)&Velocity(LEFT)&Angle(STUCK_RIGHT)
Str to state:    (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_RIGHT))


In [34]:
possible_actions = pg.question1(discretized_obs)

print(f'From {discretized_obs}, I will take one of these actions:')
for action, prob in possible_actions:
    print('\t->', action.name, '\tProb:', round(prob * 100, 2), '%')

From (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_RIGHT)), I will take one of these actions:
	-> RIGHT 	Prob: 50.2 %
	-> LEFT 	Prob: 49.8 %
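One natural way to obtain a per-state action distribution like this is to sum the edge probabilities p(s_to, a | s) over destination states s_to, which gives p(a | s). A sketch of that marginalization, assuming (not verified against pgeon's source) that this is what `question1` does; the helper name is hypothetical, and the example reuses node 0's outgoing edges from the edges CSV above (frequencies normalized by their total of 184) rather than the state queried here.

```python
from collections import defaultdict


def action_distribution(edge_probs, state):
    """Marginalize p(s_to, a | s_from) over s_to to get p(a | s_from).

    edge_probs maps (s_from, s_to, action) -> p(s_to, a | s_from).
    """
    dist = defaultdict(float)
    for (s_from, _, action), p in edge_probs.items():
        if s_from == state:
            dist[action] += p
    return dict(dist)


# Outgoing edge probabilities of node 0 (frequencies 9, 126, 47, 1, 1 out of 184).
node0 = {(0, 3, 1): 9 / 184, (0, 2, 0): 126 / 184, (0, 4, 0): 47 / 184,
         (0, 0, 1): 1 / 184, (0, 0, 0): 1 / 184}
print({a: round(p, 3) for a, p in action_distribution(node0, 0).items()})
# {1: 0.054, 0: 0.946}
```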


In [35]:
action = 0
best_states = pg.question2(action)
print(f'I will perform action {action} in these states:')
print('\n'.join(str(state) for state in best_states))

I will perform action 0 in these states:
(Position(MIDDLE), Velocity(LEFT), Angle(FALLING_LEFT))
(Position(MIDDLE), Velocity(RIGHT), Angle(STABILIZING_LEFT))
(Position(MIDDLE), Velocity(RIGHT), Angle(FALLING_LEFT))
(Position(MIDDLE), Velocity(LEFT), Angle(STUCK_LEFT))
(Position(MIDDLE), Velocity(RIGHT), Angle(STANDING))
(Position(MIDDLE), Velocity(LEFT), Angle(STUCK_RIGHT))
(Position(MIDDLE), Velocity(RIGHT), Angle(STUCK_LEFT))
(Position(MIDDLE), Velocity(LEFT), Angle(STABILIZING_LEFT))


In [None]:
print('Supposing I was in the middle, moving right, with the pole standing upright, '
      'if I did not choose to move left, it was due to...')
counterfactuals = pg.question3((
    Predicate(Position, [Position.MIDDLE]),
    Predicate(Velocity, [Velocity.RIGHT]),
    Predicate(Angle, [Angle.STANDING])), 0)
for ct in counterfactuals:
    print(f'...{" and ".join([str(i[0]) + " -> " + str(i[1]) for i in ct.values()])}')