## MDP visual representations

Colosseum provides two visual representations for the MDPs in the form of graphs.
In the *MDP representation*, states are depicted by circular nodes, actions by square nodes, and transition probabilities by the thickness of the edges connecting action nodes to state nodes.
Although quite complex, this visualization gives a complete picture of the MDP structure and allows representing the state-action value function.
The *Markov chain representation* is simpler: like a Markov chain graph, it contains only state nodes.
A directed edge connects state $s$ to state $s'$ if there is a non-zero probability of transitioning from $s$ to $s'$.
Through this simpler representation, it is possible to obtain a solid intuition about the MDP structure and inspect the state value function.

In [None]:
import matplotlib.cm
import matplotlib.pyplot as plt

from colosseum.mdps.deep_sea import DeepSeaContinuous
from colosseum.mdps.frozen_lake import FrozenLakeContinuous
from colosseum.mdps.minigrid_empty import MiniGridEmptyContinuous
from colosseum.mdps.minigrid_rooms import MiniGridRoomsContinuous
from colosseum.mdps.visualization import plot_MDP_graph, plot_MCGraph

### $\texttt{DeepSea}$ and $\texttt{FrozenLake}$ MDP plots

In [None]:
# Two MDP graphs side by side: DeepSea on the left, FrozenLake on the right.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(24, 8))
ax1, ax2 = axes

mdp = DeepSeaContinuous(random_action_p=0.3, size=5, seed=0)
plot_MDP_graph(mdp, ax=ax1, prog="dot", fontsize=16, ncol=1)

mdp = FrozenLakeContinuous(p_frozen=0.8, size=4, seed=0)
plot_MDP_graph(mdp, ax=ax2, prog="dot", fontsize=16, ncol=3)

# Adjust subplot spacing so both panels fit, then render.
plt.tight_layout()
plt.show()

### $\texttt{MiniGridEmpty}$ MDP plot

In [None]:
# MDP graph of a size-4 MiniGridEmpty instance.
mdp = MiniGridEmptyContinuous(size=4, seed=0)
plot_MDP_graph(
    mdp,
    ncol=3,
    fontsize=18,
    figsize=(12, 12),
)

### $\texttt{MiniGridRooms}$ MDP plot

In [None]:
# MDP graph of a MiniGridRooms instance with 9 rooms of size 3.
mdp = MiniGridRoomsContinuous(n_rooms=9, room_size=3, seed=0)
plot_MDP_graph(
    mdp,
    ncol=3,
    fontsize=19,
    figsize=(20, 20),
)

### $\texttt{DeepSea}$ and $\texttt{FrozenLake}$ Markov chain plots

In [None]:
# Markov chain graphs of larger (size 15) DeepSea and FrozenLake instances.
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(24, 8))
ax1, ax2 = axes

mdp = DeepSeaContinuous(random_action_p=0.3, size=15, seed=0)
plot_MCGraph(mdp, ax=ax1, fontsize=12)

mdp = FrozenLakeContinuous(p_frozen=0.8, size=15, seed=0)
plot_MCGraph(mdp, ax=ax2, fontsize=12)

# Adjust subplot spacing so both panels fit, then render.
plt.tight_layout()
plt.show()

### $\texttt{MiniGridEmpty}$ and $\texttt{MiniGridRooms}$ Markov chain plots

In [None]:
# Markov chain graphs of MiniGridEmpty (left) and MiniGridRooms (right).
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(24, 8))
ax1, ax2 = axes

mdp = MiniGridEmptyContinuous(size=7, seed=0)
plot_MCGraph(mdp, ax=ax1, fontsize=12)

mdp = MiniGridRoomsContinuous(n_rooms=9, room_size=3, seed=0)
plot_MCGraph(mdp, ax=ax2, fontsize=12)

# Adjust subplot spacing so both panels fit, then render.
plt.tight_layout()
plt.show()

### $\texttt{DeepSea}$ visitation counts labels

In [None]:
# Collect visitation counts from 800 random steps on DeepSea and write them
# on the MDP graph as text labels.
mdp = DeepSeaContinuous(random_action_p=0.3, size=5, seed=0)
mdp.reset()
for _step in range(800):
    mdp.random_step()

state_counts = mdp.get_visitation_counts()
# NOTE(review): `False` appears to select the state-action counts (they are
# used as action labels below) — confirm against the Colosseum API.
state_action_counts = mdp.get_visitation_counts(False)

# Options controlling how the integer labels are written on the graph.
label_options = dict(
    int_labels_offset_x=0,
    int_labels_offset_y=0,
    font_color_state_actions_labels="white",
    no_written_state_action_labels=False,
    no_written_state_labels=False,
)
plot_MDP_graph(
    mdp,
    prog="dot",
    ncol=1,
    figsize=(7, 7),
    fontsize=12,
    node_labels=state_counts,
    action_labels=state_action_counts,
    **label_options,
)

### $\texttt{FrozenLake}$ visitation counts labels

In [None]:
# Collect visitation counts from 500 random steps on FrozenLake and write
# them on the MDP graph as text labels.
mdp = FrozenLakeContinuous(p_frozen=0.8, size=4, seed=0)
mdp.reset()
for _step in range(500):
    mdp.random_step()

state_counts = mdp.get_visitation_counts()
# NOTE(review): `False` appears to select the state-action counts (they are
# used as action labels below) — confirm against the Colosseum API.
state_action_counts = mdp.get_visitation_counts(False)

# Options controlling how the integer labels are written on the graph.
label_options = dict(
    int_labels_offset_x=0,
    int_labels_offset_y=0,
    font_color_state_actions_labels="white",
    no_written_state_action_labels=False,
    no_written_state_labels=False,
)
plot_MDP_graph(
    mdp,
    prog="dot",
    ncol=3,
    figsize=(6, 6),
    node_labels=state_counts,
    action_labels=state_action_counts,
    **label_options,
)

### $\texttt{DeepSea}$ visitation counts heatmap

In [None]:
# DeepSea visitation counts shown as a heatmap: after 5000 random steps the
# counts are passed to the plot together with colormaps for the state nodes
# (Blues) and the state-action nodes (Greens).
mdp = DeepSeaContinuous(seed=0, size=5, random_action_p=0.3)
mdp.reset()
for _ in range(5000):
    mdp.random_step()
node_labels = mdp.get_visitation_counts()
# NOTE(review): `False` appears to select the state-action counts — confirm.
action_labels = mdp.get_visitation_counts(False)
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in
# 3.9. `pyplot.get_cmap` returns the same registered colormap and remains
# supported.
plot_MDP_graph(
    mdp,
    prog="dot",
    ncol=1,
    figsize=(7, 7),
    fontsize=12,
    node_labels=node_labels,
    action_labels=action_labels,
    cm_state_labels=plt.get_cmap("Blues"),
    cm_state_actions_labels=plt.get_cmap("Greens"),
)

### $\texttt{FrozenLake}$ visitation counts heatmap

In [None]:
# FrozenLake visitation counts shown as a heatmap: after 5000 random steps
# the counts are passed to the plot together with colormaps for the state
# nodes (Blues) and the state-action nodes (Greens).
mdp = FrozenLakeContinuous(seed=0, size=4, p_frozen=0.8)
mdp.reset()
for _ in range(5000):
    mdp.random_step()
node_labels = mdp.get_visitation_counts()
# NOTE(review): `False` appears to select the state-action counts — confirm.
action_labels = mdp.get_visitation_counts(False)
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in
# 3.9. `pyplot.get_cmap` returns the same registered colormap and remains
# supported.
plot_MDP_graph(
    mdp,
    prog="dot",
    ncol=3,
    figsize=(6, 6),
    node_labels=node_labels,
    action_labels=action_labels,
    cm_state_labels=plt.get_cmap("Blues"),
    cm_state_actions_labels=plt.get_cmap("Greens"),
)

### $\texttt{MiniGridEmpty}$ visitation counts labels

In [None]:
# Collect visitation counts from 2000 random steps on MiniGridEmpty and
# write them on the Markov chain graph as white text labels.
mdp = MiniGridEmptyContinuous(size=7, seed=0)
mdp.reset()
for _step in range(2_000):
    mdp.random_step()

state_counts = mdp.get_visitation_counts()
plot_MCGraph(
    mdp,
    labels=state_counts,
    font_color_labels="white",
    node_size=250,
    fontsize=14,
    figsize=(8, 8),
)

### $\texttt{MiniGridRooms}$ visitation counts labels

In [None]:
# Collect visitation counts from 2000 random steps on MiniGridRooms and
# write them on the Markov chain graph as white text labels.
mdp = MiniGridRoomsContinuous(n_rooms=9, room_size=3, seed=0)
mdp.reset()
for _step in range(2_000):
    mdp.random_step()

state_counts = mdp.get_visitation_counts()
plot_MCGraph(
    mdp,
    labels=state_counts,
    font_color_labels="white",
    node_size=250,
    fontsize=15,
    figsize=(9, 9),
)

### $\texttt{MiniGridEmpty}$ visitation counts heatmap

In [None]:
# MiniGridEmpty visitation counts shown as a heatmap on the Markov chain
# graph: after 10,000 random steps the counts are passed together with a
# colormap (Blues) for the state nodes.
mdp = MiniGridEmptyContinuous(seed=0, size=7)
mdp.reset()
for _ in range(10_000):
    mdp.random_step()
node_labels = mdp.get_visitation_counts()
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in
# 3.9. `pyplot.get_cmap` returns the same registered colormap and remains
# supported.
plot_MCGraph(
    mdp,
    figsize=(8, 8),
    fontsize=14,
    labels=node_labels,
    node_size=250,
    cm_state_labels=plt.get_cmap("Blues"),
)

### $\texttt{MiniGridRooms}$ visitation counts heatmap

In [None]:
# MiniGridRooms visitation counts shown as a heatmap on the Markov chain
# graph: after 10,000 random steps the counts are passed together with a
# colormap (Blues) for the state nodes.
mdp = MiniGridRoomsContinuous(seed=0, room_size=3, n_rooms=9)
mdp.reset()
for _ in range(10_000):
    mdp.random_step()
node_labels = mdp.get_visitation_counts()
# `matplotlib.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in
# 3.9. `pyplot.get_cmap` returns the same registered colormap and remains
# supported.
plot_MCGraph(
    mdp,
    figsize=(9, 9),
    fontsize=15,
    labels=node_labels,
    node_size=250,
    cm_state_labels=plt.get_cmap("Blues"),
)