In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json
from dataclasses import asdict
from pathlib import Path

import matplotlib.pyplot as plt

from rl_intro.agent.core import AgentConfig
from rl_intro.agent.agent_expected_sarsa import AgentExpectedSarsa
from rl_intro.agent.agent_sarsa import AgentSarsa
from rl_intro.agent.agent_q_learning import AgentQLearning
from rl_intro.environment.gridworld import GridWorld, GridWorldConfig
from rl_intro.agent.policy import EpsilonGreedyPolicy, EpsilonGreedyConfig
from rl_intro.simulation.experiment import (
    Experiment,
    ExperimentBatch,
    ExperimentConfig,
    AgentRecipe,
    EnvironmentRecipe,
)
from rl_intro.evaluation.parse import parse_experiment_json, parse_experiment_batch_json
from rl_intro.evaluation.analyze import analyze_experiment, analyze_experiments
from rl_intro.evaluation.plot import (
    plot_state_visit_frequency,
    plot_final_values,
    plot_cumulative_reward,
    plot_average_reward_per_episode,
)
from rl_intro.utils.logger import logger
from rl_intro.utils.visualize import grid_str


## Single experiment

### Running the experiment

Configuration of the environment, agent and experiment

In [None]:
# Grid dimensions and the size of the action set; both the environment and the
# agent configurations below are derived from these.
n_rows, n_cols = 4, 10
n_actions = 4

# Episode budget for one experiment.
experiment_config = ExperimentConfig(max_steps=200, n_episodes=1000)

# Grid layout; special states are listed as integer state indices.
env_config = GridWorldConfig(
    height=n_rows,
    width=n_cols,
    start_states=[0],
    terminal_states=[39],
    cliff_states=[4, 24, 5, 25],
    wall_states=[2, 12, 22, 17, 27, 37],
    random_seed=42,
)

# Learning hyper-parameters; the state count matches the grid size.
agent_config = AgentConfig(
    n_states=n_rows * n_cols,
    n_actions=n_actions,
    learning_rate=0.3,
    discount=1.0,
    random_seed=42,
)

# Exploration setting for the epsilon-greedy policy.
policy_config = EpsilonGreedyConfig(epsilon=0.1)

# DEBUG level so the `logger.debug(...)` calls in later cells are emitted.
logger.setLevel("DEBUG")

Running an experiment with the given configuration

In [None]:
# Build the agent (Expected SARSA driven by an epsilon-greedy policy), the
# grid-world environment, and the experiment that ties them together.
# `Experiment` has to be in scope from the notebook's imports cell; without
# that import this cell raises NameError on a fresh kernel.
agent = AgentExpectedSarsa(agent_config, EpsilonGreedyPolicy(policy_config))
env = GridWorld(env_config)
experiment = Experiment(agent, env, experiment_config)

logger.info(f"Environment: {env.to_str()}")
logger.info(f"Agent: {agent}")

# Run the experiment; the returned log is what later cells save and analyze.
experiment_log = experiment.run()

# Extra diagnostics (only emitted at DEBUG level): the agent's greedy actions
# and greedy values, rendered as grids.
logger.debug(grid_str(agent.get_greedy_actions(), n_cols, n_rows))
logger.debug(grid_str(agent.get_greedy_values(), n_cols, n_rows))

Saving experiment data

In [None]:
# Destination of the single-experiment log, relative to the notebook's folder.
log_file = Path("../data", "single_experiment_logs.json")

In [None]:

# Create the output directory if needed, so a fresh checkout without a
# `data` folder does not fail with FileNotFoundError on the `open` below.
log_file.parent.mkdir(parents=True, exist_ok=True)

# Convert the dataclass log to plain dicts and write it as indented JSON.
with open(log_file, "w", encoding="utf-8") as f:
    json.dump(asdict(experiment_log), f, indent=4)
logger.info(f"Experiment completed and logs saved to '{log_file}'.")

### Analyzing the experiment

In [None]:
# Reload the log from the JSON file written above (replacing the in-memory
# object) and analyze it for the plots in the following cells.
experiment_log = parse_experiment_json(log_file)
experiment_analysis = analyze_experiment(
    experiment_log,
    n_rows,
    n_cols,
)

In [None]:
# Cumulative reward of the single experiment, limited to the interval (0, 5000).
plot_cumulative_reward(
    [experiment_analysis],
    plt.subplots()[1],  # only the Axes is passed on; the Figure is not kept
    interval=(0, 5000),
);

In [None]:
# Average reward per episode of the single experiment, limited to the
# interval (0, 300).
plot_average_reward_per_episode(
    [experiment_analysis],
    plt.subplots()[1],  # only the Axes is passed on; the Figure is not kept
    interval=(0, 300),
);

In [None]:
# Two panels side by side: final values on the left, state visit frequency on
# the right.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
ax = ax.ravel()  # 1-D sequence of the two Axes
plot_final_values(experiment_analysis, ax[0])
plot_state_visit_frequency(experiment_analysis, ax[1]);

## Multi experiment

### Running the experiments

In [None]:
# Same setup as the single experiment, restated here so this section can be
# run without the single-experiment configuration cell.
n_rows, n_cols = 4, 10
n_actions = 4

# Episode budget for every run in the batch.
experiment_config = ExperimentConfig(max_steps=200, n_episodes=1000)

# Grid layout; special states are listed as integer state indices.
env_config = GridWorldConfig(
    height=n_rows,
    width=n_cols,
    start_states=[0],
    terminal_states=[39],
    cliff_states=[4, 24, 5, 25],
    wall_states=[2, 12, 22, 17, 27, 37],
    random_seed=42,
)

# One agent configuration, shared by all agent recipes built from it.
agent_config = AgentConfig(
    n_states=n_rows * n_cols,
    n_actions=n_actions,
    learning_rate=0.3,
    discount=1.0,
    random_seed=42,
)

# Exploration setting for the epsilon-greedy policy.
policy_config = EpsilonGreedyConfig(epsilon=0.1)

logger.setLevel("DEBUG")

In [None]:
# Recipe for the environment: the class to instantiate plus its configuration.
environment_recipe = EnvironmentRecipe(
    environment_config=env_config,
    environment_class=GridWorld,
)

# One recipe per learning algorithm. The three differ only in the agent class;
# the agent configuration, policy class and policy configuration are shared.
agent_sarsa_recipe, agent_expected_sarsa_recipe, agent_q_learning_recipe = (
    AgentRecipe(
        agent_class=agent_class,
        agent_config=agent_config,
        policy_class=EpsilonGreedyPolicy,
        policy_config=policy_config,
    )
    for agent_class in (AgentSarsa, AgentExpectedSarsa, AgentQLearning)
)

In [None]:
# Batch over the three agent recipes and the single environment recipe,
# configured with n_runs=10.
experiment_batch = ExperimentBatch(
    n_runs=10,
    experiment_config=experiment_config,
    env_recipes=[environment_recipe],
    agent_recipes=[
        agent_sarsa_recipe,
        agent_expected_sarsa_recipe,
        agent_q_learning_recipe,
    ],
)

# Execute the whole batch; the result is an iterable of per-experiment logs.
experiment_logs = experiment_batch.run()

In [None]:
# Destination of the batch logs. Deliberately a different file from the
# single-experiment log so the batch does not overwrite it.
log_file = Path("../data") / "multi_experiment_logs.json"

In [None]:
# Create the output directory if needed, so a fresh checkout without a
# `data` folder does not fail with FileNotFoundError on the `open` below.
log_file.parent.mkdir(parents=True, exist_ok=True)

# Convert every dataclass log to plain dicts and write them as one JSON list.
with open(log_file, "w", encoding="utf-8") as f:
    json.dump([asdict(log) for log in experiment_logs], f, indent=4)
logger.info(f"Experiment completed and logs saved to '{log_file}'.")

### Analyzing the experiments

In [None]:
# The in-memory logs from the batch run are analyzed directly. The disabled
# line below presumably reloads them from `log_file` instead (TODO confirm).
# experiment_logs = parse_experiment_batch_json(log_file)
experiment_results = analyze_experiments(
    experiment_logs,
    n_rows,
    n_cols,
    agent_grouping=True,
)

# seeds for each agent are grouped together
n_results = len(experiment_results)
logger.info(f"Grouped analysis into {n_results} result(s)")

In [None]:

# Cumulative reward for all grouped results on one Axes, limited to the
# interval (0, 20000).
plot_cumulative_reward(
    experiment_results,
    plt.subplots()[1],  # only the Axes is passed on; the Figure is not kept
    interval=(0, 20000),
);

In [None]:
# Average reward per episode for all grouped results on one Axes, limited to
# the interval (0, 500).
plot_average_reward_per_episode(
    experiment_results,
    plt.subplots()[1],  # only the Axes is passed on; the Figure is not kept
    interval=(0, 500),
);

In [None]:

# One figure per grouped result: final values on the left, state visit
# frequency on the right.
# The loop variable has its own name so it does not overwrite the
# single-experiment `experiment_analysis` used by earlier cells.
for result_analysis in experiment_results:
    # A 1x2 grid is returned as a 1-D array of Axes, so it unpacks directly.
    fig, (values_ax, visits_ax) = plt.subplots(1, 2, figsize=(12, 6))
    plot_final_values(result_analysis, values_ax)
    plot_state_visit_frequency(result_analysis, visits_ax);
    