In [None]:
from behavior_generation_lecture_python.mdp.policy import CategorialPolicy
from behavior_generation_lecture_python.utils.grid_plotting import (
    make_plot_policy_step_function,
)
from behavior_generation_lecture_python.mdp.mdp import (
    GridMDP,
    policy_gradient,
    derive_deterministic_policy,
    GRID_MDP_DICT,
    HIGHWAY_MDP_DICT,
    LC_RIGHT_ACTION,
    STAY_IN_LANE_ACTION,
)

# Configuration for the highway example further below: turn off
# `restrict_actions_to_available_states` before the highway MDP is built.
# NOTE(review): this assigns into the imported module-level dict in place, so the
# change persists for the whole kernel session (also for any other code in this
# kernel that imports HIGHWAY_MDP_DICT).
HIGHWAY_MDP_DICT["restrict_actions_to_available_states"] = False

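## Toy example

Train a policy with policy gradient on the small grid MDP (`GRID_MDP_DICT`) and step through the deterministic policy derived from each training checkpoint.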

In [None]:
# Build the toy grid MDP from the predefined configuration dict.
grid_mdp = GridMDP(**GRID_MDP_DICT)

In [None]:
# Policy network for the toy MDP. `sizes` is presumably the list of layer
# widths: state length in, 32 hidden units, one output per action.
pol = CategorialPolicy(
    sizes=[
        len(grid_mdp.initial_state),
        32,
        len(grid_mdp.actions),
    ],
    actions=list(grid_mdp.actions),
)

In [None]:
# Train for 100 iterations. With return_history=True the call returns a
# sequence of model checkpoints, which the next cell iterates over.
model_checkpoints = policy_gradient(
    mdp=grid_mdp, pol=pol, iterations=100, return_history=True
)

In [None]:
# One deterministic policy per training checkpoint, in checkpoint order.
policy_array = list(
    derive_deterministic_policy(mdp=grid_mdp, pol=checkpoint)
    for checkpoint in model_checkpoints
)

In [None]:
# Plotting callback over the 4-column x 3-row toy grid; it is called with an
# iteration index into policy_array.
plot_policy_step_grid_map = make_plot_policy_step_function(
    columns=4,
    rows=3,
    policy_over_time=policy_array,
)

In [None]:
# When True, show an interactive slider that replays the derived policy for
# every training checkpoint. The widget libraries are only imported on this path.
mkdocs_flag = True
if mkdocs_flag:
    from IPython.display import display
    import ipywidgets

    iteration_slider = ipywidgets.IntSlider(
        min=0,
        max=len(model_checkpoints) - 1,
        step=1,
        value=0,
    )
    w = ipywidgets.interactive(
        plot_policy_step_grid_map,
        iteration=iteration_slider,
    )
    display(w)

In [None]:
# Show the policy derived from the last training checkpoint. The index is taken
# from the history length instead of hard-coding the iteration count, so this
# cell stays valid when `iterations` in the training cell is changed.
plot_policy_step_grid_map(len(policy_array) - 1)

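## Highway example

Repeat the experiment on the highway MDP (`HIGHWAY_MDP_DICT`), this time training for 200 iterations.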

In [None]:
# Set to True to see the effect of an unreliable lane change. Each entry below
# is presumably a (probability, executed action) pair: LC_RIGHT_ACTION would then
# only change lanes 40% of the time and otherwise act like STAY_IN_LANE_ACTION.
# NOTE: this mutates the imported HIGHWAY_MDP_DICT in place, so the modified
# transitions persist for the rest of the kernel session. Restart the kernel
# after switching this back to False.
UNRELIABLE_LANE_CHANGE = False
if UNRELIABLE_LANE_CHANGE:
    HIGHWAY_MDP_DICT["transition_probabilities_per_action"][LC_RIGHT_ACTION] = [
        (0.4, LC_RIGHT_ACTION),
        (0.6, STAY_IN_LANE_ACTION),
    ]

In [None]:
# Build the highway MDP from HIGHWAY_MDP_DICT, which earlier cells of this
# notebook modify in place (restrict_actions_to_available_states, and optionally
# the lane-change transition probabilities).
highway_mdp = GridMDP(**HIGHWAY_MDP_DICT)

In [None]:
# Fresh policy network for the highway MDP (replaces the toy `pol`). `sizes` is
# presumably the list of layer widths: state length in, 32 hidden units, one
# output per action.
pol = CategorialPolicy(
    sizes=[
        len(highway_mdp.initial_state),
        32,
        len(highway_mdp.actions),
    ],
    actions=list(highway_mdp.actions),
)

In [None]:
# Train for 200 iterations. With return_history=True the call returns a
# sequence of model checkpoints, which the next cell iterates over.
model_checkpoints = policy_gradient(
    mdp=highway_mdp, pol=pol, iterations=200, return_history=True
)

In [None]:
# One deterministic highway policy per training checkpoint, in checkpoint order.
policy_array = list(
    derive_deterministic_policy(mdp=highway_mdp, pol=checkpoint)
    for checkpoint in model_checkpoints
)

In [None]:
# Plotting callback over the 10-column x 4-row highway grid; it is called with
# an iteration index into policy_array.
plot_policy_step_grid_map = make_plot_policy_step_function(
    columns=10,
    rows=4,
    policy_over_time=policy_array,
)

In [None]:
# Interactive replay of the highway training run. This reuses mkdocs_flag from
# the toy-example widget cell; the widget libraries are only imported when it
# is set.
if mkdocs_flag:
    from IPython.display import display
    import ipywidgets

    iteration_slider = ipywidgets.IntSlider(
        min=0,
        max=len(model_checkpoints) - 1,
        step=1,
        value=0,
    )
    w = ipywidgets.interactive(
        plot_policy_step_grid_map,
        iteration=iteration_slider,
    )
    display(w)

In [None]:
# Show the highway policy derived from the last training checkpoint. The index
# is taken from the history length instead of hard-coding the iteration count,
# so this cell stays valid when `iterations` in the training cell is changed.
plot_policy_step_grid_map(len(policy_array) - 1)