# XAI In Action - pgeon

This notebook showcases the current functionality of the **pgeon** library.

In [1]:
import pprint

## Preparation

Here we load an environment, an agent and a discretizer: the elements needed to generate a Policy Graph.

In [2]:
import gymnasium as gym

from example.cartpole.discretizer import CartpoleDiscretizer

In [3]:
environment = gym.make('CartPole-v1')
discretizer = CartpoleDiscretizer()

In [34]:
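from pgeon import Agent
from ray.rllib.algorithms.algorithm import Algorithm

class CartpoleAgent(Agent):
    """Wraps a trained RLlib algorithm so that pgeon can query it for actions."""

    def __init__(self, path):
        # Restore the trained algorithm (here, PPO) from an RLlib checkpoint directory
        self.agent = Algorithm.from_checkpoint(path)

    def act(self, state):
        # Return the action the trained policy selects for a single raw observation
        return self.agent.compute_single_action(state)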
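`CartpoleAgent` only implements `act(state)`, which suggests that is all pgeon needs from an agent. As a hedged sketch of that contract, the block below uses a hypothetical stand-in base class (`AgentInterface`, not pgeon's real `Agent`) and a Ray-free heuristic agent that pushes the cart towards the side the pole is leaning. In the notebook you would subclass `pgeon.Agent` instead.

```python
from abc import ABC, abstractmethod


class AgentInterface(ABC):
    """Hypothetical stand-in for the act(state) contract that CartpoleAgent fulfils."""

    @abstractmethod
    def act(self, state):
        """Return an action for a raw environment observation."""


class LeanHeuristicAgent(AgentInterface):
    """Pushes the cart towards the side the pole is leaning; needs no trained model."""

    def act(self, state):
        # CartPole-v1 observation: [cart position, cart velocity, pole angle, pole angular velocity]
        angle, angular_velocity = state[2], state[3]
        # Action 1 pushes the cart right, action 0 pushes it left.
        # The 0.5 weight on angular velocity is an arbitrary illustrative choice.
        return 1 if angle + 0.5 * angular_velocity > 0 else 0


heuristic_agent = LeanHeuristicAgent()
print(heuristic_agent.act([0.0, 0.0, 0.05, 0.0]))  # prints 1: positive pole angle, so push right
```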

In [5]:
agent = CartpoleAgent('checkpoints/PPO_CartPole-v1_1acbb_00000_0_2023-12-05_19-28-36/checkpoint_000000')



## Policy Graph generation

In [6]:
from pgeon import PolicyGraph

Policy Graphs are instantiated with an environment and a discretizer.

In [7]:
pg = PolicyGraph(environment, discretizer)

We generate a Policy Graph with the `fit()` function, in this case by generating 200 episode trajectories from our agent. If the PG has already been fit, you can update it with new trajectories (instead of regenerating it from scratch) by passing `update=True`.

In [None]:
pg = pg.fit(agent, num_episodes=200, update=False)

Fitting PG...: 100%|██████████| 200/200 [00:46<00:00,  4.29it/s]

In [8]:
print(f'Number of nodes: {len(pg.nodes)}')
print(f'Number of edges: {len(pg.edges)}')

Number of nodes: 14
Number of edges: 130


Each node has information about a discretized state:

In [9]:
arbitrary_state = list(pg.nodes)[0]

print(arbitrary_state)
print(f'  Times visited: {pg.nodes[arbitrary_state]["frequency"]}')
print(f'  p(s):          {pg.nodes[arbitrary_state]["probability"]:.3f}')

(Position(MIDDLE), Velocity(RIGHT), Angle(FALLING_RIGHT))
  Times visited: 7490
  p(s):          0.212
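Conceptually, a node's frequency counts how often its discretized state appears in the generated trajectories, and p(s) normalises that count by the total number of state visits. A minimal sketch of that computation, assuming the alternating state/action trajectory format described further below and using strings as hypothetical states (this is not pgeon's internal code):

```python
from collections import Counter


def node_statistics(trajectories):
    """Count state visits and normalise them into p(s).

    Each trajectory alternates states and actions: [s0, a0, s1, a1, ..., sN],
    so the states sit at the even indices.
    """
    counts = Counter()
    for trajectory in trajectories:
        counts.update(trajectory[::2])
    total = sum(counts.values())
    return {state: (freq, freq / total) for state, freq in counts.items()}


trajectories = [
    ['A', 0, 'B', 1, 'A'],
    ['B', 1, 'A'],
]
stats = node_statistics(trajectories)
print(stats)  # {'A': (3, 0.6), 'B': (2, 0.4)}
```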


Each edge has information about a transition between states:

In [10]:
arbitrary_edge = list(pg.edges)[0]

print(f'From:    {arbitrary_edge[0]}')
print(f'Action:  {arbitrary_edge[2]}')
print(f'To:      {arbitrary_edge[1]}')
print(f'  Times visited:      {pg[arbitrary_edge[0]][arbitrary_edge[1]][arbitrary_edge[2]]["frequency"]}')
print(f'  p(s_to,a | s_from): {pg[arbitrary_edge[0]][arbitrary_edge[1]][arbitrary_edge[2]]["probability"]:.3f}')

From:    (Position(MIDDLE), Velocity(RIGHT), Angle(FALLING_RIGHT))
Action:  1
To:      (Position(MIDDLE), Velocity(RIGHT), Angle(STAB_LEFT))
  Times visited:      1225
  p(s_to,a | s_from): 0.168
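The edge probability is conditional on the origin node. In the edge CSV exported later in this notebook, node 0 has five outgoing edges whose probabilities sum to 1, with frequencies 1225, 489, 1903, 1992 and 1691. These add up to 7300, and 1225 / 7300 ≈ 0.168 matches the value printed above, so p(s_to, a | s_from) appears to be an edge's frequency divided by the total frequency of all edges leaving s_from. A sketch under that assumption:

```python
from collections import defaultdict


def edge_probabilities(edge_counts):
    """Normalise (s_from, s_to, action) counts into p(s_to, a | s_from)."""
    outgoing = defaultdict(int)
    for (s_from, _s_to, _action), freq in edge_counts.items():
        outgoing[s_from] += freq
    return {edge: freq / outgoing[edge[0]] for edge, freq in edge_counts.items()}


# Outgoing edges of node 0, copied from ppo-cartpole_edges.csv: (from, to, action) -> frequency
node0_edges = {
    (0, 5, 1): 1225,
    (0, 6, 0): 489,
    (0, 1, 1): 1903,
    (0, 0, 1): 1992,
    (0, 0, 0): 1691,
}
probs = edge_probabilities(node0_edges)
print(f"{probs[(0, 5, 1)]:.3f}")  # prints 0.168
```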


The `PolicyGraph` object also stores the full discretized episode trajectories of the last fit.

In [11]:
len(pg._trajectories_of_last_fit)

200

Each trajectory is stored as a list of alternating states and actions: (state0, action0, state1, ..., stateN).
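Because states and actions alternate, the (state, action, next state) transitions of a trajectory can be recovered with slicing and `zip`. A small sketch, with strings standing in for the discretized states:

```python
def transitions(trajectory):
    """Return (state, action, next_state) triples from [s0, a0, s1, ..., sN]."""
    states = trajectory[::2]   # even indices hold states
    actions = trajectory[1::2]  # odd indices hold actions
    return list(zip(states[:-1], actions, states[1:]))


trajectory = ['STUCK_LEFT', 0, 'STAB_RIGHT', 1, 'STUCK_LEFT']
print(transitions(trajectory))
# [('STUCK_LEFT', 0, 'STAB_RIGHT'), ('STAB_RIGHT', 1, 'STUCK_LEFT')]
```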

In [12]:
pg._trajectories_of_last_fit[0]

[(Position(MIDDLE), Velocity(LEFT), Angle(STUCK_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STAB_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STAB_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STAB_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(FALLING_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STAB_RIGHT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STAB_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(STAB_RIGHT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STAB_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(STAB_RIGHT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STAB_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_LEFT)),
 0,
 (Position(MIDDLE), Velocity(LEFT), Angle(STAB_RIGHT)),
 1,
 (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_RIGHT)),
 0,
 (Position(MIDDLE), Velocity(LEFT),

## Loading and saving Policy Graphs

### Pickle

Saving as pickle lets you restore the full state of the object.

In [13]:
pg.save('pickle', './ppo-cartpole.pickle')

In [14]:
pg_pickle = PolicyGraph.from_pickle('./ppo-cartpole.pickle')

print(f'Number of nodes:             {len(pg_pickle.nodes)}')
print(f'Number of edges:             {len(pg_pickle.edges)}')
print(f'Num. of stored trajectories: {len(pg_pickle._trajectories_of_last_fit)}')

Number of nodes:             14
Number of edges:             130
Num. of stored trajectories: 200


### CSV

Saving as CSV creates three separate CSV files for node, edge and trajectory information.

In [15]:
import csv

In [16]:
pg.save('csv', ['./ppo-cartpole_nodes.csv', './ppo-cartpole_edges.csv', './ppo-cartpole_trajectories.csv'])

In [17]:
with open('ppo-cartpole_nodes.csv', newline='') as f:
    csv_r = csv.reader(f)
    for i in range(10):
        print(next(csv_r))

['id', 'value', 'p(s)', 'frequency']
['0', 'Position(MIDDLE)&Velocity(RIGHT)&Angle(FALLING_RIGHT)', '0.21179730799683294', '7490']
['1', 'Position(MIDDLE)&Velocity(RIGHT)&Angle(STUCK_RIGHT)', '0.171021377672209', '6048']
['2', 'Position(MIDDLE)&Velocity(RIGHT)&Angle(STAB_RIGHT)', '0.0221128831580138', '782']
['3', 'Position(MIDDLE)&Velocity(LEFT)&Angle(STUCK_RIGHT)', '0.0385702974776609', '1364']
['4', 'Position(MIDDLE)&Velocity(RIGHT)&Angle(FALLING_LEFT)', '0.05762922746295668', '2038']
['5', 'Position(MIDDLE)&Velocity(RIGHT)&Angle(STAB_LEFT)', '0.22455039022734985', '7941']
['6', 'Position(MIDDLE)&Velocity(LEFT)&Angle(FALLING_RIGHT)', '0.12249745503902273', '4332']
['7', 'Position(MIDDLE)&Velocity(LEFT)&Angle(STAB_RIGHT)', '0.06749802058590657', '2387']
['8', 'Position(MIDDLE)&Velocity(LEFT)&Angle(STUCK_LEFT)', '0.022056328469630133', '780']


Edges and trajectories refer to nodes by the IDs assigned in the corresponding node CSV file.

In [18]:
with open('ppo-cartpole_edges.csv', newline='') as f:
    csv_r = csv.reader(f)
    for i in range(10):
        print(next(csv_r))

['from', 'to', 'action', 'p(s)', 'frequency']
['0', '5', '1', '0.1678082191780822', '1225']
['0', '6', '0', '0.06698630136986301', '489']
['0', '1', '1', '0.2606849315068493', '1903']
['0', '0', '1', '0.27287671232876715', '1992']
['0', '0', '0', '0.23164383561643836', '1691']
['1', '6', '0', '0.19284886608177454', '1165']
['1', '5', '1', '0.5272305909617613', '3185']
['1', '0', '0', '0.27528554875020694', '1663']
['1', '13', '1', '0.0013242840589306405', '8']


Each trajectory is stored as one CSV row of alternating node IDs and actions: (state0, action0, state1, ..., stateN).

In [19]:
with open('ppo-cartpole_trajectories.csv', newline='') as f:
    csv_r = csv.reader(f)
    for i in range(1):
        print(next(csv_r))

['8', '0', '7', '1', '8', '0', '7', '1', '10', '0', '7', '1', '10', '0', '7', '0', '7', '1', '7', '0', '7', '1', '7', '0', '7', '1', '8', '0', '7', '1', '3', '0', '6', '1', '3', '0', '6', '1', '3', '1', '9', '0', '6', '0', '6', '1', '6', '0', '6', '1', '6', '0', '6', '1', '6', '0', '6', '1', '6', '1', '3', '1', '5', '0', '3', '0', '6', '1', '3', '1', '5', '0', '3', '1', '5', '1', '5', '0', '5', '0', '6', '1', '5', '1', '5', '0', '5', '1', '5', '0', '5', '0', '6', '1', '1', '0', '6', '1', '1', '0', '6', '1', '1', '1', '5', '0', '1', '0', '6', '1', '1', '0', '6', '1', '1', '1', '5', '0', '1', '0', '6', '1', '1', '1', '5', '0', '1', '1', '5', '1', '5', '0', '5', '0', '1', '1', '5', '0', '0', '1', '5', '0', '0', '1', '5', '0', '0', '1', '5', '0', '0', '1', '5', '0', '0', '0', '6', '1', '0', '1', '5', '0', '0', '1', '5', '0', '0', '1', '5', '0', '0', '1', '5', '0', '0', '0', '6', '1', '0', '1', '1', '0', '0', '1', '1', '1', '5', '0', '1', '1', '5', '0', '1', '1', '5', '0', '1', '1', '5', '0
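Since edge and trajectory rows refer to nodes only by ID, reading them back means joining them against the node file. The sketch below does that with the standard `csv` module on short excerpts of the rows printed above, wrapped in `io.StringIO`; with the real files you would pass `open(path, newline='')` instead. It assumes, as the output above suggests, that the trajectory file has no header row.

```python
import csv
import io

# Excerpts of the node, edge and trajectory CSV rows printed above
nodes_file = io.StringIO(
    "id,value,p(s),frequency\n"
    "0,Position(MIDDLE)&Velocity(RIGHT)&Angle(FALLING_RIGHT),0.21179730799683294,7490\n"
    "5,Position(MIDDLE)&Velocity(RIGHT)&Angle(STAB_LEFT),0.22455039022734985,7941\n"
    "7,Position(MIDDLE)&Velocity(LEFT)&Angle(STAB_RIGHT),0.06749802058590657,2387\n"
    "8,Position(MIDDLE)&Velocity(LEFT)&Angle(STUCK_LEFT),0.022056328469630133,780\n"
)
edges_file = io.StringIO("from,to,action,p(s),frequency\n0,5,1,0.1678082191780822,1225\n")
trajectories_file = io.StringIO("8,0,7,1,8\n")

# Map each node ID to its serialized state string
id_to_state = {row["id"]: row["value"] for row in csv.DictReader(nodes_file)}

# Edges: replace the 'from'/'to' IDs with the states they refer to
resolved_edges = [
    (id_to_state[row["from"]], int(row["action"]), id_to_state[row["to"]])
    for row in csv.DictReader(edges_file)
]

# Trajectories: even positions hold node IDs, odd positions hold actions
resolved_trajectories = [
    [id_to_state[item] if i % 2 == 0 else int(item) for i, item in enumerate(row)]
    for row in csv.reader(trajectories_file)
]
print(resolved_trajectories[0][:3])
```

The decoded start of this trajectory (STUCK_LEFT, action 0, STAB_RIGHT, ...) matches the first entries of `pg._trajectories_of_last_fit[0]` shown earlier.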

Policy Graphs can be loaded from CSV files in two ways: from the node and trajectory files, or from the node and edge files. When loading from nodes and edges, however, episode trajectories cannot be restored.

In [21]:
pg_csv = PolicyGraph.from_nodes_and_trajectories('./ppo-cartpole_nodes.csv', './ppo-cartpole_trajectories.csv',
                                                 environment, discretizer)
print(f'Number of nodes:             {len(pg_csv.nodes)}')
print(f'Number of edges:             {len(pg_csv.edges)}')
print(f'Num. of stored trajectories: {len(pg_csv._trajectories_of_last_fit)}')

Number of nodes:             14
Number of edges:             130
Num. of stored trajectories: 200


In [20]:
pg_csv = PolicyGraph.from_nodes_and_edges('./ppo-cartpole_nodes.csv', './ppo-cartpole_edges.csv',
                                          environment, discretizer)
print(f'Number of nodes:             {len(pg_csv.nodes)}')
print(f'Number of edges:             {len(pg_csv.edges)}')
print(f'Num. of stored trajectories: {len(pg_csv._trajectories_of_last_fit)}')

Number of nodes:             14
Number of edges:             130
Num. of stored trajectories: 0


### Gram

PGs can also be exported to the [gram](https://neo4j.com/developer-blog/gram-a-data-graph-format/) format, allowing visualization using Neo4j. Episode trajectories cannot be stored in this format, though.

PGs currently cannot be loaded from a Gram file.

In [22]:
pg.save('gram', './ppo-cartpole.gram')

In [23]:
!head ./ppo-cartpole.gram


(s0:State {
  value:"Position(MIDDLE)&Velocity(RIGHT)&Angle(FALLING_RIGHT)"
  probability:0.21179730799683294
  frequency:7490
})
(s1:State {
  value:"Position(MIDDLE)&Velocity(RIGHT)&Angle(STUCK_RIGHT)"
  probability:0.171021377672209
  frequency:6048


In [24]:
!tail ./ppo-cartpole.gram

(s13)-[a1 {probability:0.11049723756906077 frequency:20}]->(s4)
(s13)-[a0 {probability:0.0718232044198895 frequency:13}]->(s4)
(s13)-[a0 {probability:0.04419889502762431 frequency:8}]->(s8)
(s13)-[a0 {probability:0.06629834254143646 frequency:12}]->(s2)
(s13)-[a0 {probability:0.06077348066298342 frequency:11}]->(s6)
(s13)-[a1 {probability:0.03314917127071823 frequency:6}]->(s13)
(s13)-[a0 {probability:0.0055248618784530384 frequency:1}]->(s13)
(s13)-[a0 {probability:0.016574585635359115 frequency:3}]->(s0)
(s13)-[a1 {probability:0.011049723756906077 frequency:2}]->(s0)
(s13)-[a1 {probability:0.04419889502762431 frequency:8}]->(s1)
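Gram is plain text, so the export is easy to inspect or reproduce. For illustration only (this is not pgeon's serializer), a function that writes nodes and edges in the same layout as the `head` and `tail` output above; the input structures are hypothetical:

```python
def to_gram(nodes, edges):
    """Serialize a PG in the .gram layout shown above (illustrative sketch).

    nodes: {node_id: (value, probability, frequency)}
    edges: [(from_id, to_id, action, probability, frequency), ...]
    """
    lines = []
    for node_id, (value, probability, frequency) in nodes.items():
        lines += [
            f"(s{node_id}:State {{",
            f'  value:"{value}"',
            f"  probability:{probability}",
            f"  frequency:{frequency}",
            "})",
        ]
    for from_id, to_id, action, probability, frequency in edges:
        # Doubled braces produce literal '{' and '}' inside the f-string
        lines.append(f"(s{from_id})-[a{action} {{probability:{probability} frequency:{frequency}}}]->(s{to_id})")
    return "\n".join(lines)


print(to_gram(
    {0: ("Position(MIDDLE)&Velocity(RIGHT)&Angle(FALLING_RIGHT)", 0.21179730799683294, 7490)},
    [(0, 5, 1, 0.1678082191780822, 1225)],
))
```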

## Using PG-based policies

Using `PGBasedPolicy`, we can create policies that replicate an agent's behavior based on its generated Policy Graph. `PGBasedPolicy` is a subclass of `pgeon.Agent`.

The policy mode (greedy or stochastic) is set via the `PGBasedPolicyMode` enum. The behavior when the policy encounters a state that is not a node in the PG (select a random action, or search for the nearest node in the PG) is set via the `PGBasedPolicyNodeNotFoundMode` enum.
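To illustrate the two modes, one plausible reading (an assumption, not pgeon's documented behavior) is to sum the outgoing edge probabilities of the current node over their destinations, giving p(a | s), and then either take the most likely action (greedy) or sample from it (stochastic), falling back to a uniformly random action for an unknown node. The nearest-node fallback is not sketched here. Using the outgoing edges of node 0 from the edge CSV above:

```python
import random
from collections import defaultdict


def action_distribution(outgoing_edges):
    """Sum p(s_to, a | s_from) over destinations s_to to obtain p(a | s_from)."""
    dist = defaultdict(float)
    for (_s_to, action), prob in outgoing_edges.items():
        dist[action] += prob
    return dict(dist)


def choose_action(dist, greedy=True, actions=(0, 1), rng=random):
    """Greedy: most likely action. Stochastic: sample from dist. Empty dist (unknown node): uniform."""
    if not dist:
        return rng.choice(actions)
    if greedy:
        return max(dist, key=dist.get)
    return rng.choices(list(dist), weights=list(dist.values()), k=1)[0]


# Outgoing edges of node 0 from ppo-cartpole_edges.csv: (to, action) -> p(s_to, a | s_from)
node0 = {(5, 1): 0.1678082191780822, (6, 0): 0.06698630136986301, (1, 1): 0.2606849315068493,
         (0, 1): 0.27287671232876715, (0, 0): 0.23164383561643836}
dist = action_distribution(node0)
print(choose_action(dist))  # prints 1, since p(a=1 | s) ≈ 0.701
```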

In [25]:
from pgeon import PGBasedPolicy, PGBasedPolicyMode, PGBasedPolicyNodeNotFoundMode

In [32]:
policy = PGBasedPolicy(pg, mode=PGBasedPolicyMode.GREEDY,
                       node_not_found_mode=PGBasedPolicyNodeNotFoundMode.RANDOM_UNIFORM)

In [33]:
obs, _ = environment.reset()
action = policy.act(obs)

print(f'Observed state:  {obs}')
print(f'Discretization:  {policy.pg.discretizer.discretize(obs)}')
print(f'Selected action: {action}')

Observed state:  [ 0.04819953 -0.01298102  0.04525431 -0.01622432]
Discretization:  (Position(MIDDLE), Velocity(LEFT), Angle(STUCK_RIGHT))
Selected action: 0


## Implementing new Discretizers

To generate Policy Graphs for a given environment, you need a Discretizer that transforms each state into a set of predicates. It is implemented as a class that inherits from `pgeon.Discretizer` and implements all of its abstract methods.

In [38]:
from enum import Enum, auto

from pgeon import Predicate

First, decide on a set of predicates and their values. In this case we use three: the cart's `Position` (is the cart in the middle, on the left or on the right?), its `Velocity` (is the cart moving left or right?), and the state of its pole (`Angle`: is the pole standing, stuck, falling to one side or stabilizing?).

Each predicate and its possible values are represented as an enum.

In [37]:
class Position(Enum):
    LEFT = auto()
    MIDDLE = auto()
    RIGHT = auto()

class Velocity(Enum):
    LEFT = auto()
    RIGHT = auto()

class Angle(Enum):
    STANDING = auto()
    STUCK_LEFT = auto()
    STUCK_RIGHT = auto()
    FALLING_LEFT = auto()
    FALLING_RIGHT = auto()
    STABILIZING_LEFT = auto()
    STABILIZING_RIGHT = auto()

This is an example of a state as a set of predicates. Note that a predicate accepts an ordered list of values (e.g. `[Position.LEFT, Velocity.RIGHT]`), as some environments benefit from that level of description.

In [41]:
Predicate(Position, [Position.LEFT]), Predicate(Velocity, [Velocity.LEFT]), Predicate(Angle, [Angle.STABILIZING_RIGHT])

(Position(LEFT), Velocity(LEFT), Angle(STABILIZING_RIGHT))

A discretizer class needs to implement the following methods:

- `discretize(self, state)`: Converts an environment's raw observation into a discretized state.
- `state_to_str(self, state) -> str`: Converts a discrete state into a string (used in serialization).
- `str_to_state(self, state: str)`: Converts a string representing a state into said state (used in serialization).
- `nearest_state(self, state)`: A generator function that, given a certain discrete state, yields the nearest discrete states, in order. The distance heuristic is left to the implementer.
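The sketch below implements these four methods for a hypothetical, reduced discretizer with only `Position` and `Velocity`. To stay self-contained it redefines those enums, represents states as plain tuples of enum members rather than pgeon `Predicate` objects, and does not inherit from `pgeon.Discretizer`; the position threshold and the distance heuristic in `nearest_state` are illustrative choices.

```python
from enum import Enum, auto


class Position(Enum):
    LEFT = auto()
    MIDDLE = auto()
    RIGHT = auto()


class Velocity(Enum):
    LEFT = auto()
    RIGHT = auto()


class ToyCartpoleDiscretizer:
    """Hypothetical two-predicate discretizer; states are (Position, Velocity) tuples."""

    # Illustrative cut-off; CartPole-v1 terminates once |cart position| exceeds 2.4
    POSITION_THRESHOLD = 1.0

    def discretize(self, state):
        # CartPole-v1 observation: [cart position, cart velocity, pole angle, pole angular velocity]
        x, x_dot = state[0], state[1]
        if x < -self.POSITION_THRESHOLD:
            position = Position.LEFT
        elif x > self.POSITION_THRESHOLD:
            position = Position.RIGHT
        else:
            position = Position.MIDDLE
        velocity = Velocity.RIGHT if x_dot >= 0 else Velocity.LEFT
        return position, velocity

    def state_to_str(self, state) -> str:
        # Same 'Name(VALUE)&Name(VALUE)' layout as the CSV and gram files above
        return '&'.join(f'{type(value).__name__}({value.name})' for value in state)

    def str_to_state(self, state: str):
        enums = {'Position': Position, 'Velocity': Velocity}
        values = []
        for part in state.split('&'):
            name, member = part.rstrip(')').split('(')
            values.append(enums[name][member])
        return tuple(values)

    def nearest_state(self, state):
        # Distance heuristic: number of predicates that differ; ties keep enum order
        others = [(p, v) for p in Position for v in Velocity if (p, v) != state]
        others.sort(key=lambda s: sum(a != b for a, b in zip(s, state)))
        yield from others


toy_discretizer = ToyCartpoleDiscretizer()
toy_state = toy_discretizer.discretize([0.03045334, 0.04321589, 0.01941581, 0.03160628])
print(toy_discretizer.state_to_str(toy_state))  # prints Position(MIDDLE)&Velocity(RIGHT)
print(next(toy_discretizer.nearest_state(toy_state)))  # prints (<Position.LEFT: 1>, <Velocity.RIGHT: 2>)
```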

Here is an example of these methods in use:

In [66]:
obs, _ = environment.reset()
discretized_obs = discretizer.discretize(obs)
str_obs = discretizer.state_to_str(discretized_obs)
str_to_state = discretizer.str_to_state(str_obs)

In [68]:
print(f'Observed state:  {obs}')
print(f'Discretization:  {discretized_obs}')
print(f'State to str:    {str_obs}')
print(f'Str to state:    {str_to_state}')

Observed state:  [0.03045334 0.04321589 0.01941581 0.03160628]
Discretization:  (Position(MIDDLE), Velocity(RIGHT), Angle(STUCK_RIGHT))
State to str:    Position(MIDDLE)&Velocity(RIGHT)&Angle(STUCK_RIGHT)
Str to state:    (Position(MIDDLE), Velocity(RIGHT), Angle(STUCK_RIGHT))
