In [1]:
import gym
from stable_baselines3 import DQN
import torch
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from stable_baselines3.common.evaluation import evaluate_policy

# Create the environment
env = gym.make("LunarLander-v2")

# Define a custom feature extractor for the policy and target models
class CustomFeatureExtractor(BaseFeaturesExtractor):
    def forward(self, observations):
        # Assuming observations are already normalized
        return observations

# Create two DQN models, one for the policy network and another for the target network
policy_model = DQN("MlpPolicy", env, verbose=1, policy_kwargs={"features_extractor_class": CustomFeatureExtractor})
target_model = DQN("MlpPolicy", env, verbose=1, policy_kwargs={"features_extractor_class": CustomFeatureExtractor})

# Define a custom loss function for the critic network (target network)
def double_q_loss(policy_q, target_q1, target_q2):
    return torch.mean((policy_q - torch.min(target_q1, target_q2)) ** 2)

# Set the loss function for the critic network in the policy model
policy_model.policy.net.critic.loss_fn = double_q_loss

# Train the DDQN model
total_timesteps = 10000  # Adjust this as needed
target_update_interval = 100  # Adjust this as needed

for t in range(total_timesteps):
    action, _ = policy_model.predict(obs, deterministic=True)
    new_obs, reward, done, info = env.step(action)

    # Update the policy network
    policy_model.learn(total_timesteps=1)

    if t % target_update_interval == 0:
        # Update the target network
        target_model.load_state_dict(policy_model.state_dict())

    obs = new_obs

    if done:
        obs = env.reset()

# Save and load the DDQN model
policy_model.save("DDQN-LunarLander-v2")
del policy_model

policy_model = DQN.load("DDQN-LunarLander-v2")

# Evaluate the policy
mean_reward, std_reward = evaluate_policy(policy_model, env, n_eval_episodes=10)
print(f"Mean Reward: {mean_reward:.2f} +/- {std_reward:.2f}")


Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.




AssertionError: 