In [8]:
import numpy as np
import pandas as pd
import gymnasium as gym
from time import time
from bettermdptools.utils.test_env import TestEnv
from bettermdptools.algorithms.rl import RL
from bettermdptools.algorithms.planner import Planner

from grid_search2 import set_seed, check_convergence, check_ql_convergence

# --- Experiment grid -------------------------------------------------------
# Every environment is evaluated both slippery (stochastic) and deterministic,
# with each of the algorithms listed below.
ENVIRONMENTS = ["FrozenLake8x8-v1", "FrozenLake16x16-v1"]
SLIPPERY_STATES = [True, False]
ALGORITHMS = ["VI", "PI", "QL"]

# Number of repeated runs per configuration (plain repetitions, not
# cross-validation folds, despite the name).
CV_JOBS = 20
# Seed handed to set_seed() before each training run.
SEED = 666
# Fallback iteration / episode budget when a grid-search result lacks one.
N_ITERATIONS = 10_000


def run_algorithm(env, algo_name, config, n_episodes=100, seed=None):
    """Train one algorithm/configuration on ``env`` and evaluate the policy.

    Parameters
    ----------
    env : gymnasium environment
        Environment to train on and evaluate. The planning algorithms (VI/PI)
        read the transition model from ``env.unwrapped.P``.
    algo_name : str
        One of ``"VI"`` (value iteration), ``"PI"`` (policy iteration) or
        ``"QL"`` (Q-learning).
    config : dict
        Keyword arguments forwarded unchanged to the planner / Q-learning call.
    n_episodes : int, default 100
        Number of evaluation episodes passed to ``TestEnv.test_env``.
    seed : int or None, default None
        Seed passed to ``set_seed`` before training; ``None`` means the
        module-level ``SEED``.

    Returns
    -------
    tuple
        ``(cumulative_score, runtime, iterations_to_converge)`` where
        ``cumulative_score`` is the summed evaluation reward and ``runtime``
        is the wall-clock training time only (evaluation is excluded).

    Raises
    ------
    ValueError
        If ``algo_name`` is not one of the supported algorithms.
    """
    # Fail fast: previously any unrecognised name silently ran Q-learning.
    if algo_name not in ("VI", "PI", "QL"):
        raise ValueError(
            f"Unknown algorithm {algo_name!r}; expected 'VI', 'PI' or 'QL'"
        )

    set_seed(SEED if seed is None else seed)
    start_time = time()

    if algo_name == "QL":
        agent = RL(env)
        _, _, pi, Q_track, _ = agent.q_learning(**config)
        iterations_to_converge = check_ql_convergence(Q_track)
    else:
        # Use the unwrapped env: reading ``P`` through gymnasium wrappers relies
        # on deprecated attribute forwarding.
        planner = Planner(env.unwrapped.P)
        solve = (
            planner.value_iteration if algo_name == "VI" else planner.policy_iteration
        )
        _, V_track, pi = solve(**config)
        iterations_to_converge = check_convergence(V_track)

    runtime = time() - start_time
    episode_rewards = TestEnv.test_env(env=env, n_iters=n_episodes, pi=pi)
    cumulative_score = np.sum(episode_rewards)

    return cumulative_score, runtime, iterations_to_converge


def load_hyperparameters(env_name, algo_name, results_dir="results2"):
    """Load the best-scoring configuration from the grid-search results.

    Reads ``{results_dir}/{env_name}/{algo_name}_grid_search_results.csv``,
    selects the row with the highest ``cumulative_score`` (ties keep file
    order), and drops the metric columns so the remaining entries can be
    passed as keyword arguments to the algorithm.

    The iteration budget is normalised to the single key the algorithm
    accepts: ``n_episodes`` for ``"QL"`` and ``n_iters`` otherwise. The budget
    is cast to ``int`` because a CSV round-trip of an all-numeric row yields
    floats, which ``range()`` rejects. If no budget column is present,
    ``N_ITERATIONS`` is used.

    Parameters
    ----------
    env_name : str
        Environment id; used as the sub-directory name.
    algo_name : str
        Algorithm name; used as the file-name prefix.
    results_dir : str, default "results2"
        Root directory holding the grid-search CSVs.

    Returns
    -------
    dict
        Hyperparameters of the best row.

    Raises
    ------
    ValueError
        If the CSV contains no rows.
    """
    df = pd.read_csv(f"{results_dir}/{env_name}/{algo_name}_grid_search_results.csv")
    if df.empty:
        raise ValueError(f"No grid-search results for {algo_name} on {env_name}")

    # Stable sort so ties are resolved deterministically (first row in file order).
    best_row = df.sort_values(
        by="cumulative_score", ascending=False, kind="stable"
    ).iloc[0]
    best_config = best_row.to_dict()

    # Metrics recorded by the grid search are not hyperparameters.
    for col in ("cumulative_score", "runtime", "iterations_to_converge"):
        best_config.pop(col, None)

    # Remove both possible budget keys, then store the value under the key
    # this algorithm actually accepts (preferring n_episodes, as before).
    candidates = [best_config.pop(key, None) for key in ("n_episodes", "n_iters")]
    budget = next(
        (value for value in candidates if value is not None and not pd.isna(value)),
        None,
    )
    iter_key = "n_episodes" if algo_name == "QL" else "n_iters"
    best_config[iter_key] = int(budget) if budget is not None else N_ITERATIONS
    return best_config


def _summarize_runs(scores, runtimes, convergence_iters):
    """Aggregate per-run metrics into mean / population-std-dev result columns."""
    return {
        "average_score": np.mean(scores),
        "std_dev_score": np.std(scores),
        "average_runtime": np.mean(runtimes),
        "std_dev_runtime": np.std(runtimes),
        "average_convergence": np.mean(convergence_iters),
        "std_dev_convergence": np.std(convergence_iters),
    }


def experiment(env_names=None, slippery_states=None, algorithms=None, n_runs=None):
    """Compare baseline vs. grid-search-optimised configs across the experiment grid.

    For each (environment, slipperiness, algorithm) combination, two
    configurations are run ``n_runs`` times each:

    * ``baseline``: only the iteration budget (library defaults otherwise).
    * ``optimized``: the best row from ``load_hyperparameters``.

    Parameters
    ----------
    env_names, slippery_states, algorithms : list or None
        Grid axes; ``None`` means ``ENVIRONMENTS`` / ``SLIPPERY_STATES`` /
        ``ALGORITHMS`` respectively.
    n_runs : int or None
        Repetitions per configuration; ``None`` means ``CV_JOBS``.

    Returns
    -------
    pandas.DataFrame
        One row per (environment, is_slippery, algorithm, configuration) with
        mean and standard deviation of score, runtime and convergence.

    Raises
    ------
    ValueError
        If ``n_runs`` is less than 1 (the statistics would be undefined).

    NOTE(review): every repetition calls run_algorithm identically, and
    run_algorithm re-seeds via set_seed each time. Repetitions may therefore
    differ only through RNGs that set_seed does not cover; confirm this is
    intended.
    """
    env_names = ENVIRONMENTS if env_names is None else env_names
    slippery_states = SLIPPERY_STATES if slippery_states is None else slippery_states
    algorithms = ALGORITHMS if algorithms is None else algorithms
    n_runs = CV_JOBS if n_runs is None else n_runs
    if n_runs < 1:
        raise ValueError(f"n_runs must be >= 1, got {n_runs}")

    results = []
    for env_name in env_names:
        for slippery in slippery_states:
            env = gym.make(env_name, is_slippery=slippery)
            try:
                for algo_name in algorithms:
                    # QL takes an episode budget; VI/PI take an iteration budget.
                    iter_key = "n_episodes" if algo_name == "QL" else "n_iters"
                    optimized_config = dict(load_hyperparameters(env_name, algo_name))

                    # Pop both budget keys so only the accepted one is forwarded,
                    # and cast: values read back from CSV may be floats.
                    episodes_budget = optimized_config.pop("n_episodes", None)
                    iters_budget = optimized_config.pop("n_iters", None)
                    budget = (
                        episodes_budget if episodes_budget is not None else iters_budget
                    )
                    n_iters = N_ITERATIONS if budget is None else int(budget)
                    optimized_config[iter_key] = n_iters
                    baseline_config = {iter_key: n_iters}

                    for config_name, config in (
                        ("baseline", baseline_config),
                        ("optimized", optimized_config),
                    ):
                        runs = [
                            run_algorithm(env, algo_name, config) for _ in range(n_runs)
                        ]
                        scores, runtimes, convergence_iters = zip(*runs)
                        results.append(
                            {
                                "environment": env_name,
                                "is_slippery": slippery,
                                "algorithm": algo_name,
                                "configuration": config_name,
                                **_summarize_runs(scores, runtimes, convergence_iters),
                            }
                        )
            finally:
                env.close()

    return pd.DataFrame(results)


if __name__ == "__main__":
    # Run every (environment, slipperiness, algorithm) combination with a
    # baseline config (iteration budget only) and the grid-search optimum,
    # repeating each CV_JOBS times, then save the aggregated metrics.
    # NOTE(review): this loop mirrors experiment(); consider a single code path.
    results = []
    for env_name in ENVIRONMENTS:
        for slippery in SLIPPERY_STATES:
            env = gym.make(env_name, is_slippery=slippery)
            try:
                for algo_name in ALGORITHMS:
                    # QL takes an episode budget; VI/PI take an iteration budget.
                    iter_key = "n_episodes" if algo_name == "QL" else "n_iters"
                    optimized_config = load_hyperparameters(env_name, algo_name)

                    # Pop both budget keys so a stale one (e.g. "n_iters" on a QL
                    # config) is never forwarded, and cast: CSV values may be floats.
                    episodes_budget = optimized_config.pop("n_episodes", None)
                    iters_budget = optimized_config.pop("n_iters", None)
                    budget = (
                        episodes_budget if episodes_budget is not None else iters_budget
                    )
                    n_iters = N_ITERATIONS if budget is None else int(budget)
                    optimized_config[iter_key] = n_iters
                    baseline_config = {iter_key: n_iters}

                    for config_name, config in (
                        ("baseline", baseline_config),
                        ("optimized", optimized_config),
                    ):
                        cumulative_scores, runtimes, convergence_iters = [], [], []
                        for _ in range(CV_JOBS):
                            score, runtime, convergence = run_algorithm(
                                env, algo_name, config
                            )
                            cumulative_scores.append(score)
                            runtimes.append(runtime)
                            convergence_iters.append(convergence)

                        results.append(
                            {
                                "environment": env_name,
                                "is_slippery": slippery,
                                "algorithm": algo_name,
                                "configuration": config_name,
                                "average_score": np.mean(cumulative_scores),
                                "std_dev_score": np.std(cumulative_scores),
                                "average_runtime": np.mean(runtimes),
                                "std_dev_runtime": np.std(runtimes),
                                "average_convergence": np.mean(convergence_iters),
                                "std_dev_convergence": np.std(convergence_iters),
                            }
                        )
            finally:
                env.close()

    results_df = pd.DataFrame(results)
    results_df.to_csv("stochasticity_results.csv", index=False)

runtime = 0.40 seconds


  if not isinstance(terminated, (bool, np.bool8)):


runtime = 0.42 seconds
runtime = 0.40 seconds
runtime = 0.40 seconds
runtime = 0.40 seconds
runtime = 0.40 seconds
runtime = 0.40 seconds
runtime = 0.40 seconds
runtime = 0.40 seconds
runtime = 0.40 seconds
runtime = 0.40 seconds
runtime = 0.40 seconds
runtime = 0.40 seconds
runtime = 0.40 seconds
runtime = 0.40 seconds
runtime = 0.41 seconds
runtime = 0.40 seconds
runtime = 0.40 seconds
runtime = 0.40 seconds
runtime = 0.40 seconds
runtime = 0.31 seconds
runtime = 0.32 seconds
runtime = 0.32 seconds
runtime = 0.32 seconds
runtime = 0.31 seconds
runtime = 0.31 seconds
runtime = 0.31 seconds
runtime = 0.32 seconds
runtime = 0.31 seconds
runtime = 0.31 seconds
runtime = 0.31 seconds
runtime = 0.31 seconds
runtime = 0.33 seconds
runtime = 0.31 seconds
runtime = 0.31 seconds
runtime = 0.31 seconds
runtime = 0.31 seconds
runtime = 0.31 seconds
runtime = 0.31 seconds
runtime = 0.31 seconds
runtime = 0.25 seconds
runtime = 0.25 seconds
runtime = 0.25 seconds
runtime = 0.24 seconds
runtime = 0

                                                      

runtime = 6.25 seconds


                                                      

runtime = 6.12 seconds


                                                      

runtime = 6.27 seconds


                                                      

runtime = 6.25 seconds


                                                      

runtime = 6.18 seconds


                                                      

runtime = 5.03 seconds


                                                      

runtime = 6.37 seconds


                                                      

runtime = 6.27 seconds


                                                      

runtime = 6.24 seconds


                                                      

runtime = 6.23 seconds


                                                      

runtime = 6.18 seconds


                                                      

runtime = 6.20 seconds


                                                      

runtime = 4.88 seconds


                                                      

runtime = 6.14 seconds


                                                      

runtime = 6.28 seconds


                                                      

runtime = 6.19 seconds


                                                      

runtime = 6.22 seconds


                                                      

runtime = 6.23 seconds


                                                      

runtime = 6.19 seconds


                                                      

runtime = 6.20 seconds


                                                      

runtime = 4.72 seconds


                                                      

runtime = 8.24 seconds


                                                      

runtime = 8.10 seconds


                                                      

runtime = 8.25 seconds


                                                      

runtime = 8.16 seconds


                                                      

runtime = 4.79 seconds


                                                      

runtime = 8.26 seconds


                                                      

runtime = 8.16 seconds


                                                      

runtime = 4.73 seconds


                                                      

runtime = 8.32 seconds


                                                      

runtime = 8.20 seconds


                                                      

runtime = 8.20 seconds


                                                      

runtime = 8.37 seconds


                                                      

runtime = 8.58 seconds


                                                      

runtime = 8.32 seconds


                                                      

runtime = 8.26 seconds


                                                      

runtime = 4.75 seconds


                                                      

runtime = 8.70 seconds


                                                      

runtime = 8.65 seconds


  if not isinstance(terminated, (bool, np.bool8)):


runtime = 8.22 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0.00 seconds
runtime = 0

                                                      

runtime = 11.41 seconds


                                                      

runtime = 11.54 seconds


                                                      

runtime = 11.58 seconds


                                                      

runtime = 11.46 seconds


                                                      

runtime = 11.54 seconds


                                                      

runtime = 11.43 seconds


                                                      

runtime = 11.47 seconds


                                                      

runtime = 11.66 seconds


                                                      

runtime = 11.44 seconds


                                                      

runtime = 11.51 seconds


                                                      

runtime = 11.38 seconds


                                                      

runtime = 11.57 seconds


                                                      

runtime = 11.56 seconds


                                                      

runtime = 11.41 seconds


                                                      

runtime = 11.49 seconds


                                                      

runtime = 11.33 seconds


                                                      

runtime = 11.48 seconds


                                                      

runtime = 11.40 seconds


                                                      

runtime = 11.62 seconds


                                                      

runtime = 11.42 seconds


                                                      

runtime = 11.90 seconds


                                                      

runtime = 11.97 seconds


                                                      

runtime = 11.92 seconds


                                                      

runtime = 11.95 seconds


                                                      

runtime = 12.00 seconds


                                                      

runtime = 11.93 seconds


                                                      

runtime = 11.83 seconds


                                                      

runtime = 11.84 seconds


                                                      

runtime = 12.00 seconds


                                                      

runtime = 11.90 seconds


                                                      

runtime = 11.95 seconds


                                                      

runtime = 11.85 seconds


                                                      

runtime = 11.90 seconds


                                                      

runtime = 11.79 seconds


                                                      

runtime = 11.96 seconds


                                                      

runtime = 11.80 seconds


                                                      

runtime = 11.91 seconds


                                                      

runtime = 11.87 seconds


                                                      

runtime = 11.90 seconds


                                                      

runtime = 12.00 seconds
runtime = 3.88 seconds


  if not isinstance(terminated, (bool, np.bool8)):


runtime = 3.82 seconds
runtime = 3.85 seconds
runtime = 3.82 seconds
runtime = 3.92 seconds
runtime = 3.85 seconds
runtime = 3.86 seconds
runtime = 3.85 seconds
runtime = 3.79 seconds
runtime = 3.87 seconds
runtime = 3.85 seconds
runtime = 3.84 seconds
runtime = 3.81 seconds
runtime = 3.84 seconds
runtime = 3.91 seconds
runtime = 3.83 seconds
runtime = 3.87 seconds
runtime = 3.85 seconds
runtime = 3.95 seconds
runtime = 3.84 seconds
runtime = 0.30 seconds
runtime = 0.29 seconds
runtime = 0.30 seconds
runtime = 0.31 seconds
runtime = 0.29 seconds
runtime = 0.29 seconds
runtime = 0.29 seconds
runtime = 0.30 seconds
runtime = 0.29 seconds
runtime = 0.30 seconds
runtime = 0.30 seconds
runtime = 0.30 seconds
runtime = 0.30 seconds
runtime = 0.29 seconds
runtime = 0.29 seconds
runtime = 0.29 seconds
runtime = 0.29 seconds
runtime = 0.29 seconds
runtime = 0.29 seconds
runtime = 0.30 seconds
runtime = 6.13 seconds
runtime = 6.14 seconds
runtime = 6.13 seconds
runtime = 6.13 seconds
runtime = 6

                                                      

runtime = 4.71 seconds


                                                      

runtime = 4.80 seconds


                                                      

runtime = 4.74 seconds


                                                      

runtime = 4.78 seconds


                                                      

runtime = 4.82 seconds


                                                      

runtime = 4.76 seconds


                                                      

runtime = 4.78 seconds


                                                      

runtime = 4.80 seconds


                                                      

runtime = 4.79 seconds


                                                      

runtime = 4.77 seconds


                                                      

runtime = 4.80 seconds


                                                      

runtime = 4.83 seconds


                                                      

runtime = 4.77 seconds


                                                      

runtime = 4.78 seconds


                                                      

runtime = 4.78 seconds


                                                      

runtime = 4.86 seconds


                                                      

runtime = 4.81 seconds


                                                      

runtime = 4.75 seconds


                                                      

runtime = 4.92 seconds


                                                      

runtime = 4.98 seconds


                                                      

runtime = 5.91 seconds


                                                      

runtime = 5.91 seconds


                                                      

runtime = 5.95 seconds


                                                      

runtime = 5.93 seconds


                                                      

runtime = 5.82 seconds


                                                      

runtime = 5.92 seconds


                                                      

runtime = 5.93 seconds


                                                      

runtime = 5.84 seconds


                                                      

runtime = 5.97 seconds


                                                      

runtime = 5.64 seconds


                                                      

runtime = 5.65 seconds


                                                      

runtime = 5.73 seconds


                                                      

runtime = 5.76 seconds


                                                      

runtime = 5.87 seconds


                                                      

runtime = 5.78 seconds


                                                      

runtime = 5.90 seconds


                                                      

runtime = 5.90 seconds


                                                      

runtime = 5.67 seconds


                                                      

runtime = 5.76 seconds


  if not isinstance(terminated, (bool, np.bool8)):


runtime = 5.64 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.02 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.01 seconds
runtime = 0.02 seconds
runtime = 0.02 seconds
runtime = 0

                                                      

runtime = 11.19 seconds


                                                      

runtime = 11.19 seconds


                                                      

runtime = 11.21 seconds


                                                      

runtime = 11.23 seconds


                                                      

runtime = 11.27 seconds


                                                      

runtime = 11.24 seconds


                                                      

runtime = 11.27 seconds


                                                      

runtime = 11.17 seconds


                                                      

runtime = 11.21 seconds


                                                      

runtime = 11.25 seconds


                                                      

runtime = 11.24 seconds


                                                      

runtime = 11.16 seconds


                                                      

runtime = 11.21 seconds


                                                      

runtime = 11.18 seconds


                                                      

runtime = 11.24 seconds


                                                      

runtime = 11.24 seconds


                                                      

runtime = 11.18 seconds


                                                      

runtime = 11.27 seconds


                                                      

runtime = 11.13 seconds


                                                      

runtime = 11.15 seconds


                                                      

runtime = 12.73 seconds


                                                      

runtime = 12.51 seconds


                                                      

runtime = 12.65 seconds


                                                      

runtime = 12.59 seconds


                                                      

runtime = 12.45 seconds


                                                      

runtime = 12.49 seconds


                                                      

runtime = 12.66 seconds


                                                      

runtime = 12.88 seconds


                                                      

runtime = 12.90 seconds


                                                      

runtime = 12.71 seconds


                                                      

runtime = 12.70 seconds


                                                      

runtime = 12.91 seconds


                                                      

runtime = 12.72 seconds


                                                      

runtime = 12.69 seconds


                                                      

runtime = 12.79 seconds


                                                      

runtime = 12.66 seconds


                                                      

runtime = 12.70 seconds


                                                      

runtime = 12.72 seconds


                                                      

runtime = 12.78 seconds


                                                      

runtime = 12.86 seconds


In [10]:
# Display the aggregated per-configuration metrics built by the experiment cell.
# NOTE(review): the saved output below has an "Unnamed: 0" column, which the
# in-memory frame from pd.DataFrame(results) does not contain (and
# to_csv(index=False) would not add). This output appears to come from a CSV
# reloaded in an earlier kernel state; regenerate with Restart & Run All.
results_df

Unnamed: 0,environment,is_slippery,algorithm,configuration,average_score,std_dev_score,average_runtime,std_dev_runtime,average_convergence,std_dev_convergence
0,FrozenLake8x8-v1,True,VI,baseline,88.4,2.782086,0.400669,0.004733,467.0,0.0
1,FrozenLake8x8-v1,True,VI,optimized,88.25,4.010923,0.313757,0.004203,467.0,0.0
2,FrozenLake8x8-v1,True,PI,baseline,87.4,3.679674,0.245451,0.002376,6.0,0.0
3,FrozenLake8x8-v1,True,PI,optimized,88.35,2.885741,0.118794,0.002254,7.0,0.0
4,FrozenLake8x8-v1,True,QL,baseline,8.45,25.350493,6.0964,0.383775,1.0,0.0
5,FrozenLake8x8-v1,True,QL,optimized,13.75,27.863731,7.599329,1.4338,1.0,0.0
6,FrozenLake8x8-v1,False,VI,baseline,0.0,0.0,0.002418,0.000142,15.0,0.0
7,FrozenLake8x8-v1,False,VI,optimized,0.0,0.0,0.002322,0.000101,15.0,0.0
8,FrozenLake8x8-v1,False,PI,baseline,0.0,0.0,0.003207,0.000102,15.0,0.0
9,FrozenLake8x8-v1,False,PI,optimized,0.0,0.0,0.003195,8.5e-05,15.0,0.0
