In [6]:
import numpy as np
from stable_baselines3.dqn.policies import DQNPolicy

class CustomDQNPolicy(DQNPolicy):
    """DQN policy that mixes Q-network actions with a fallback policy.

    Each time actions are requested, a single uniform sample ``u`` in [0, 1)
    is drawn for the whole batch. If ``u > self.threshold``, the actions come
    from ``default_policy``, or are drawn uniformly at random when no default
    policy is given. Otherwise they come from the DQN Q-network. After every
    request the threshold is lowered by ``threshold_step``, never below
    ``min_threshold``, so the fallback branch is taken increasingly often
    as training proceeds.

    The mixing logic lives in ``_predict`` because stable-baselines3 selects
    actions through ``BasePolicy.predict`` -> ``_predict``. Overriding only
    ``forward`` would never be reached during rollouts or evaluation.

    Args:
        threshold: Initial threshold compared against the uniform sample.
        threshold_step: Amount subtracted from the threshold after each request.
        default_policy: Optional fallback. Either an SB3 algorithm exposing
            ``.policy`` (e.g. a loaded ``PPO`` model) or an SB3 policy object.
        min_threshold: Floor for the decayed threshold. Defaults to 0.01.
        *args, **kwargs: Forwarded unchanged to ``DQNPolicy``.
    """

    def __init__(self, *args, threshold=0.99, threshold_step=1e-5, default_policy=None,
                 min_threshold=0.01, **kwargs):
        super().__init__(*args, **kwargs)
        self.threshold = threshold
        self.threshold_decrement = threshold_step
        self.min_threshold = min_threshold
        self.default_policy = default_policy

    def forward(self, obs, deterministic=False):
        """Select actions for ``obs``; identical to ``_predict`` (kept for direct calls)."""
        return self._predict(obs, deterministic=deterministic)

    def _predict(self, obs, deterministic=True):
        """Pick actions from the Q-network or the fallback, then decay the threshold.

        The fallback branches assume ``obs`` is a batched tensor whose first
        dimension is the batch size. Dict observations are not handled there.

        Returns:
            Tensor with one action per observation, on ``self.device``.
        """
        if np.random.rand() > self.threshold:
            if self.default_policy is None:
                actions = self._random_actions(obs)
            else:
                actions = self._default_policy_actions(obs, deterministic)
        else:
            actions = super()._predict(obs, deterministic=deterministic)

        # Decays once per action request (including evaluation-time requests).
        self.threshold = max(self.min_threshold, self.threshold - self.threshold_decrement)
        return actions

    def _random_actions(self, obs):
        # One uniformly random discrete action per observation in the batch,
        # returned as a tensor so it matches the Q-network branch.
        import torch as th

        return th.randint(0, int(self.action_space.n), (obs.shape[0],), device=self.device)

    def _default_policy_actions(self, obs, deterministic):
        # SB3 algorithms (e.g. a loaded PPO model) wrap their network in `.policy`;
        # plain SB3 policies are used as-is. `_predict` returns an action tensor for
        # every SB3 policy, whereas `forward` does not: ActorCriticPolicy.forward
        # returns an (actions, values, log_prob) tuple, and algorithm objects have
        # no `forward` at all.
        policy = getattr(self.default_policy, "policy", self.default_policy)
        actions = policy._predict(obs.to(policy.device), deterministic=deterministic)
        return actions.to(self.device)


In [11]:
import sys 
sys.path.append("/home/tromero_client/LazyMDP")
import gymnasium as gym
from stable_baselines3 import DQN, PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import EvalCallback, StopTrainingOnRewardThreshold
from wandb.integration.sb3 import WandbCallback
from lazywrapper.custom_dqn_policy import CustomDQNPolicy
import wandb

def get_default_policy(env: str = None, file_path: str = None,
                       base_dir: str = "/home/tromero_client/LazyMDP/baselines/suboptimal_pretrained_policies",
                       ) -> "PPO":
    """Load a pretrained PPO model to serve as the fallback ("default") policy.

    When ``env`` is given, the model is loaded from ``<base_dir>/ppo_<env>_0``
    and ``file_path`` is ignored. Otherwise ``file_path`` is loaded directly.
    This is useful for special cases such as a random-policy checkpoint.

    Args:
        env (str, optional): Environment id, e.g. ``"CartPole-v1"``. Takes
            precedence over ``file_path``. Defaults to None.
        file_path (str, optional): Explicit path to a saved PPO model. It is
            used only when ``env`` is None. Defaults to None.
        base_dir (str, optional): Directory holding the per-environment
            pretrained PPO checkpoints.

    Raises:
        ValueError: If neither ``env`` nor ``file_path`` is given, or if the
            model cannot be loaded from the resolved path. The underlying
            error is chained as ``__cause__``.

    Returns:
        PPO: The loaded PPO model.
    """
    import os

    if env is not None:
        file_path = os.path.join(base_dir, f"ppo_{env}_0")
    if file_path is None:
        raise ValueError("Either `env` or `file_path` must be provided")
    try:
        return PPO.load(file_path)
    except Exception as exc:
        # Keep the original error as the cause instead of silently discarding it.
        raise ValueError(f"Could not load policy from file path: {file_path}") from exc

# Experiment configuration
environment = "CartPole-v1"
SEED = 0
N_TRAIN_ENVS = 8
TOTAL_TIMESTEPS = 1_000_000
# NOTE(review): CartPole-v1 episodes are capped at 500 steps (gymnasium's own
# reward threshold is 475); stopping at 200 ends training well before that.
# Confirm this is intended.
REWARD_THRESHOLD = 200

policy_kwargs = {
    "threshold": 0.1,
    "threshold_step": 1e-5,
    # Pretrained (suboptimal) PPO model used as the fallback policy
    "default_policy": get_default_policy(env=environment),
}

# Parallel environments: a vectorised env for training and a separate
# single env for evaluation, so evaluation episodes never reset or advance
# the training environments mid-rollout.
vec_env = make_vec_env(environment, n_envs=N_TRAIN_ENVS, seed=SEED)
eval_env = make_vec_env(environment, n_envs=1, seed=SEED + 1)
wandb.init(project="dqn-cartpole-test", config={"environment": environment, "seed": SEED}, mode="disabled")

# Stop training once the mean evaluation reward reaches REWARD_THRESHOLD.
callback_on_best = StopTrainingOnRewardThreshold(reward_threshold=REWARD_THRESHOLD, verbose=1)
# eval_freq counts callback calls (vectorised steps), so with N_TRAIN_ENVS
# training envs an evaluation runs roughly every eval_freq * N_TRAIN_ENVS timesteps.
eval_callback = EvalCallback(eval_env, callback_on_new_best=callback_on_best, n_eval_episodes=10, eval_freq=1000, verbose=1)
wandb_callback = WandbCallback(verbose=2)

# NOTE(review): `CustomDQNPolicy` here resolves to the class imported from
# lazywrapper.custom_dqn_policy, which shadows the class defined in the earlier
# notebook cell. Confirm which implementation is intended.
model = DQN(policy=CustomDQNPolicy, env=vec_env, verbose=1, policy_kwargs=policy_kwargs, seed=SEED)
model.learn(total_timesteps=TOTAL_TIMESTEPS, callback=[eval_callback, wandb_callback])



Using cuda device
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 14       |
|    ep_rew_mean      | 14       |
|    exploration_rate | 0.999    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 10624    |
|    time_elapsed     | 0        |
|    total_timesteps  | 136      |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.517    |
|    n_updates        | 1        |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 16.8     |
|    ep_rew_mean      | 16.8     |
|    exploration_rate | 0.998    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 10570    |
|    time_elapsed     | 0        |
|    total_timesteps  | 224      |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.525    |
| 

<stable_baselines3.dqn.dqn.DQN at 0x7f997dff10f0>