In [1]:
import gymnasium as gym
from gymnasium.wrappers import TimeLimit
import mani_skill.envs
from stable_baselines3 import DQN
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.env_checker import check_env
from mani_skill.utils import gym_utils
from mani_skill.utils.wrappers.record import RecordEpisode
from mani_skill.vector.wrappers.sb3 import ManiSkillSB3VectorEnv
import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import imageio

# Experiment-wide configuration read by the cells below.
SEED = 17                    # random seed for envs / models
NUM_ENVS = 8                 # NOTE(review): not referenced by any visible cell yet
total_timesteps = 200_000    # default training budget read by run_dqn

In [2]:
# Training env (validated against the SB3 env checker) plus a separate env for evaluation,
# so evaluation rollouts never disturb the training env's episode state.
env = gym.make("CartPole-v1", render_mode = "rgb_array")
states, _ = env.reset(seed = SEED)
check_env(env, warn = True)
eval_env = gym.make("CartPole-v1", render_mode = "rgb_array")
# Seed the evaluation env once; later reset() calls without a seed continue from this
# RNG state, which makes the evaluation rollouts reproducible across kernel restarts.
eval_env.reset(seed = SEED)

In [3]:
# Helper for training a DQN agent on a given env and saving the trained model.

# dqn_model = DQN(policy = "MlpPolicy", buffer_size = 10000, tau = 1.0, target_update_interval = 1000, env = env, learning_rate = 1e-3, 
# batch_size = 64, gamma = .99, exploration_fraction = .1, learning_starts = 1000, verbose = 1, exploration_final_eps = .01)
# dqn_model.learn(total_timesteps = total_timesteps, progress_bar = True)
# dqn_model.save("dqn_cartpole")


def run_dqn(policy = "MlpPolicy", env = None, learning_rate = 1e-3, gamma = .99, save_file_name = "",
            timesteps = None, seed = None):
    """Train a stable-baselines3 DQN agent on ``env`` and optionally save it.

    Args:
        policy: SB3 policy name or class passed straight to ``DQN``.
        env: Environment to train on. Required.
        learning_rate: Optimizer learning rate.
        gamma: Discount factor.
        save_file_name: Path passed to ``DQN.save``. When empty (the default),
            the model is not written to disk.
        timesteps: Number of environment steps to train for. ``None`` falls back
            to the notebook-level ``total_timesteps`` so existing calls behave
            as before.
        seed: Seed forwarded to ``DQN``. ``None`` leaves SB3's default
            (unseeded) behavior unchanged; pass ``SEED`` for reproducible runs.

    Returns:
        The trained ``DQN`` model, so callers can evaluate it without reloading.

    Raises:
        ValueError: If ``env`` is None.
    """
    if env is None:
        # Fail early with a clear message instead of deep inside SB3.
        raise ValueError("run_dqn requires an environment to train on")
    if timesteps is None:
        timesteps = total_timesteps

    dqn_model = DQN(
        policy = policy,
        env = env,
        learning_rate = learning_rate,
        gamma = gamma,
        buffer_size = 10000,
        batch_size = 64,
        learning_starts = 1000,
        tau = 1.0,                      # 1.0 = hard copy into the target network
        target_update_interval = 1000,  # ...performed every 1000 steps
        exploration_fraction = .1,
        exploration_final_eps = .01,
        seed = seed,
        verbose = 1,
    )
    dqn_model.learn(total_timesteps = timesteps, progress_bar = True)
    # Saving under an empty name would produce a meaningless file path, so skip it.
    if save_file_name:
        dqn_model.save(save_file_name)
    return dqn_model

In [4]:
def get_total_reward(model_name, eval_env, iterations = 10, max_reward = 500):
    """Roll out a trained DQN policy and collect one score per episode.

    Args:
        model_name: Checkpoint path passed to ``DQN.load``. An already-loaded
            model (any object with a ``predict`` method) is also accepted and
            used as-is, which avoids re-reading the checkpoint from disk.
        eval_env: Gymnasium-style env whose ``reset()`` returns ``(obs, info)``
            and whose ``step()`` returns
            ``(obs, reward, terminated, truncated, info)``.
        iterations: Number of evaluation episodes to run.
        max_reward: Divisor applied to each episode's total reward. The default
            of 500 keeps the previous hard-coded value (presumably CartPole-v1's
            500-step episode cap). Pass ``None`` to get raw totals.

    Returns:
        A list with one entry per episode: ``total_reward / max_reward``, or the
        raw total when ``max_reward`` is ``None``.
    """
    model = model_name if hasattr(model_name, "predict") else DQN.load(model_name)
    total_rewards_list = []

    for episode in range(iterations):
        obs, _info = eval_env.reset()
        terminated = truncated = False
        total_reward = 0

        # An episode ends on either termination or truncation (e.g. a time limit).
        while not (terminated or truncated):
            action, _states = model.predict(obs, deterministic = True)
            obs, reward, terminated, truncated, _info = eval_env.step(action)
            total_reward += reward

        score = total_reward if max_reward is None else total_reward / max_reward
        total_rewards_list.append(score)
        # Use a dedicated loop variable: reusing `_` let reset() overwrite it with the
        # info dict, so the old printout showed `{}` instead of the episode number.
        print(f"Iteration: {episode}\nTotal Rewards {total_reward}")

    return total_rewards_list

In [None]:
# record a video
# print("Recording a video...")
# video_frames = []
# obs, _ = eval_env.reset()
# done = False

# while not done:
#     frame = eval_env.render()
#     video_frames.append(frame)
#     action, _ = dqn_model.predict(obs, deterministic = True)
#     obs, reward, done, truncated, info = eval_env.step(action)
#     if truncated:
#         break

# eval_env.close()

# # save the video
# os.makedirs("videos", exist_ok = True)
# video_path = os.path.join("videos", "dqn_cartpole_video.mp4")
# imageio.mimsave(video_path, video_frames, fps = 30)
# print(f"Video saved to {video_path}")

def record_video(model = None, env = None, video_name = None, fps = 30, close_env = True):
    """Roll out one episode of ``model`` in ``env`` and save the frames as an MP4.

    Args:
        model: Trained policy exposing ``predict(obs, deterministic=...)``.
        env: Gymnasium env created with ``render_mode="rgb_array"``, so that
            ``render()`` returns an image frame.
        video_name: File stem; the video is written to ``videos/<video_name>.mp4``.
            Defaults to ``"recording"`` when not given.
        fps: Frames per second of the saved video.
        close_env: Whether to close ``env`` after recording. This is the
            previous behavior and remains the default.

    Returns:
        The path of the saved video file.

    Raises:
        ValueError: If ``model`` or ``env`` is None.
    """
    if model is None or env is None:
        raise ValueError("record_video requires both a model and an env")
    if video_name is None:
        video_name = "recording"

    print("Recording a video...")
    video_frames = []
    try:
        obs, _info = env.reset()
        terminated = truncated = False
        while not (terminated or truncated):
            # Render the env actually being stepped. The previous code rendered the
            # global `eval_env`, which is wrong whenever a different env is passed in.
            video_frames.append(env.render())
            action, _states = model.predict(obs, deterministic = True)
            obs, _reward, terminated, truncated, _info = env.step(action)
    finally:
        # Close the env even if the rollout raises, so rendering resources are released.
        if close_env:
            env.close()

    # save the video
    os.makedirs("videos", exist_ok = True)
    video_path = os.path.join("videos", f"{video_name}.mp4")
    imageio.mimsave(video_path, video_frames, fps = fps)
    print(f"Video saved to {video_path}")
    return video_path

In [None]:
####################################
# Run the Tests
####################################

# Hyperparameter grid for the sweep. NOTE(review): not yet wired to run_dqn in the visible cells.
iterations = [100_000, 200_000]        # presumably training budgets in timesteps
learning_rate = [1e-1, 1e-3, 1e-5]
gamma = [0.99]



In [None]:
# Evaluate the saved "dqn_cartpole" checkpoint over 100 deterministic episodes.
total_rewards = get_total_reward(
    model_name = "dqn_cartpole",
    eval_env = eval_env,
    iterations = 100,
)
print(total_rewards)

In [None]:
# Per-episode evaluation scores from get_total_reward (each is the episode's total
# reward divided by 500), with their mean drawn as a reference line.
# The previous legend ("alpha = .01") matched no hyperparameter set in this notebook
# (the visible learning rate is 1e-3), so the curve is labelled by checkpoint name instead.
fig, ax = plt.subplots()
mean_reward = np.mean(total_rewards)
ax.plot(total_rewards, color = "red", label = "dqn_cartpole")
ax.axhline(mean_reward, color = "black", linestyle = "--", label = f"mean = {mean_reward:.3f}")
ax.set_title("DQN Average Reward")
ax.set_xlabel("Evaluation Episode")
ax.set_ylabel("Normalized Episode Reward")
ax.legend()
plt.show()

In [None]:


# Record one evaluation episode of the saved agent to videos/dqn_cartpole_video.mp4.
# `dqn_model` was previously never defined at notebook level (it existed only in
# commented-out code and as a local inside run_dqn), so this cell raised NameError on
# a fresh kernel. Load the checkpoint explicitly and reuse record_video instead of
# duplicating its loop here.
# Note: record_video closes the env it is given, so eval_env is closed after this cell.
dqn_model = DQN.load("dqn_cartpole")
record_video(model = dqn_model, env = eval_env, video_name = "dqn_cartpole_video")