In [1]:
import gymnasium as gym
from gymnasium.wrappers import RecordVideo
import matplotlib.pyplot as plt
import numpy as np
import pickle

from scripts.ddqn_agent import DDQNAgent
from scripts.training import Trainer, trainingInspector, test_agent, plot_test_results, compute_decay

## Hyperparameters

In [2]:
def episode_trigger(x):
    """Tell whether ``x`` is an exact multiple of 1000.

    Passed as the ``episode_trigger`` callback of the ``RecordVideo``
    wrapper in this notebook.

    Args:
        x: Number to test (the value handed to the callback).

    Returns:
        ``True`` if ``x % 1000 == 0``, otherwise ``False``.
    """
    remainder = x % 1000
    return bool(remainder == 0)

def process_hyperparameters_ddqn(hyperparameters):
    """Resolve the epsilon-decay entry of one DDQN hyperparameter config.

    The ``"decay_type"`` entry is replaced by whatever
    ``compute_decay(eps_start, eps_end, frac_episodes_to_decay,
    num_episodes, decay_type)`` returns, and the ``"frac_episodes_to_decay"``
    entry (only needed as an input to that call) is removed.

    The caller's dict is not modified: a shallow copy is processed and
    returned, so the same raw configuration can be processed again (for
    example when the cell is re-run) without a ``KeyError`` or a
    doubly-processed ``"decay_type"``.

    Args:
        hyperparameters: Mapping that contains at least the keys
            ``"eps_start"``, ``"eps_end"``, ``"frac_episodes_to_decay"``,
            ``"num_episodes"`` and ``"decay_type"``.

    Returns:
        A new dict with the same entries (in the same order) except that
        ``"decay_type"`` holds the ``compute_decay`` result and
        ``"frac_episodes_to_decay"`` is absent.

    Raises:
        KeyError: If one of the required keys is missing.
    """
    # Work on a copy so the raw configuration stays reusable.
    processed = dict(hyperparameters)

    processed["decay_type"] = compute_decay(
        processed["eps_start"],
        processed["eps_end"],
        processed["frac_episodes_to_decay"],
        processed["num_episodes"],
        processed["decay_type"],
    )

    # The decay fraction is consumed by compute_decay and dropped afterwards.
    processed.pop("frac_episodes_to_decay", None)
    return processed
    
    

In [3]:
# Hyperparameter configurations to evaluate. Each raw dict below is passed
# through process_hyperparameters_ddqn before being stored in the list.
ddqn_type1_hyperparameter_list = [
    process_hyperparameters_ddqn(raw_config)
    for raw_config in (
        # Best performing hyperparameters
        {
            "num_episodes": 10_000,
            "max_return": 500,
            "BUFFER_SIZE": 100_000,
            "BATCH_SIZE": 64,
            "UPDATE_EVERY": 20,
            "LR": 0.1,
            "eps_start": 1,
            "eps_end": 0.01,
            "decay_type": "linear",
            "frac_episodes_to_decay": 0.5,
        },
    )
]

## Running Experiments

In [None]:
# Create the CartPole-v1 environment with rgb_array rendering so that the
# RecordVideo wrapper below has frames to record.
env = gym.make('CartPole-v1', render_mode="rgb_array")

try:
    # Record a video for every episode selected by episode_trigger.
    env = RecordVideo(
        env,
        video_folder="backups/cartpole-ddqn-type1-visualizations",
        name_prefix="eval",
        episode_trigger=episode_trigger
    )

    # NOTE(review): the variable and folder names say "type1" while
    # network_type=2 is passed here — confirm this is intentional.
    ddqn_type1_agent = DDQNAgent(
        state_space=env.observation_space,
        action_space=env.action_space,
        network_type=2,
        seed=0
    )

    trainer = Trainer()
    ddqn_type1_results = test_agent(env, ddqn_type1_agent, trainer, ddqn_type1_hyperparameter_list, num_experiments=1)
finally:
    # Always release the environment (and the video recorder wrapping it),
    # even when the run fails or is interrupted part-way through.
    env.close()

  logger.warn(
Training:  25%|█████████████████████████████▋                                                                                           | 2453/10000 [03:02<08:54, 14.11it/s, Mean Score=21]

In [None]:
# Alias the DDQN results from the experiment cell under the name that is
# handed to the plotting helper below.
combined_results = ddqn_type1_results

# Disabled: would persist the results with pickle to the path below.
# with open("backups/cartpole-plots/cartpole_ddqn_type1_results.pickle", 'wb') as handle:
#     pickle.dump(ddqn_type1_results, handle, protocol=pickle.HIGHEST_PROTOCOL)

# Plot the results. NOTE(review): the second argument [0] presumably selects
# which entries (e.g. configuration index 0) to plot — confirm against
# plot_test_results in scripts.training.
plot_test_results(combined_results, [0])