In [1]:
import csv
import datetime
import os
from collections import deque

import gymnasium as gym
import gymnasium.wrappers as gym_wrap
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from gymnasium.spaces import Box
from tensordict import TensorDict
from torch import nn
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage

import DQN_model as DQN

# True when the name of matplotlib's active backend contains 'inline'
# (presumably the case when figures are rendered inside a notebook — confirm).
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    # Only imported for the inline backend; nothing in this cell uses it.
    from IPython import display

# Turn on matplotlib's interactive mode; as the last expression of the cell,
# its return value is what the notebook shows as the cell output.
plt.ion()

  register_pytree_node(
  register_pytree_node(
  register_pytree_node(


<contextlib.ExitStack at 0x15a0774d0>

In [2]:
import cv2

class GrayScaleObservation(gym.ObservationWrapper):
    """Observation wrapper that turns RGB frames into single-channel grayscale.

    The advertised observation space becomes a 2-D ``uint8`` Box in [0, 255]
    whose height and width are the first two dimensions of the wrapped
    environment's observation space.
    """

    def __init__(self, env):
        super().__init__(env)
        # Only the two leading (spatial) dimensions survive the conversion.
        height, width = env.observation_space.shape[:2]
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(height, width),
            dtype=np.uint8,
        )

    def observation(self, obs):
        """Return ``obs`` converted from RGB to grayscale with OpenCV."""
        gray_frame = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
        return gray_frame


class ResizeObservation(gym.ObservationWrapper):
    """Observation wrapper that resizes every frame with OpenCV area interpolation.

    Args:
        env: Environment whose observations are image arrays.
        shape: Target size. Either a single int for a square ``shape x shape``
            frame, or a ``(height, width)`` pair.

    Raises:
        ValueError: If ``shape`` is not one int or a pair of positive ints.

    Attributes:
        shape: The target size as a ``(height, width)`` tuple.
    """

    def __init__(self, env, shape):
        super().__init__(env)

        # Normalise the int shorthand and the explicit pair to (height, width).
        if np.isscalar(shape):
            target = (shape, shape)
        else:
            target = tuple(shape)
        if len(target) != 2 or any(int(side) <= 0 for side in target):
            raise ValueError(
                f"shape must be a positive int or a (height, width) pair, got {shape!r}"
            )
        self.shape = (int(target[0]), int(target[1]))

        # Keep any axes after the two spatial ones (e.g. colour channels) so the
        # declared space matches what cv2.resize returns for multi-channel frames.
        # For 2-D (grayscale) input there are none and the space is (height, width).
        source_shape = env.observation_space.shape or ()
        trailing_axes = tuple(source_shape[2:])
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=self.shape + trailing_axes, dtype=np.uint8
        )

    def observation(self, obs):
        """Return ``obs`` resized to the configured ``(height, width)``."""
        height, width = self.shape
        # cv2.resize expects dsize as (width, height), the reverse of array order.
        return cv2.resize(obs, (width, height), interpolation=cv2.INTER_AREA)


class FrameStack(gym.Wrapper):
    """Stack the ``num_stack`` most recent observations along a new leading axis.

    After ``reset`` the stack is filled with copies of the first observation;
    every ``step`` then pushes the newest frame and drops the oldest one.

    Args:
        env: Environment to wrap.
        num_stack: Number of consecutive frames per stacked observation (>= 1).

    Raises:
        ValueError: If ``num_stack`` is smaller than 1.
    """

    def __init__(self, env, num_stack):
        super().__init__(env)
        if num_stack < 1:
            raise ValueError(f"num_stack must be at least 1, got {num_stack}")
        self.num_stack = num_stack
        # Bounded deque: appending a frame evicts the oldest one automatically.
        self.frames = deque(maxlen=num_stack)
        obs_shape = env.observation_space.shape
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(num_stack, *obs_shape),
            dtype=np.uint8
        )

    def reset(self, **kwargs):
        """Reset the wrapped env and fill the whole stack with the first frame."""
        obs, info = self.env.reset(**kwargs)
        self.frames.clear()
        self.frames.extend([obs] * self.num_stack)
        return self._get_observation(), info

    def step(self, action):
        """Step the wrapped env and return the updated frame stack.

        Raises:
            RuntimeError: If called before ``reset`` has filled the stack.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        if len(self.frames) < self.num_stack:
            # Without this check a partially filled stack would be returned
            # with the wrong leading dimension.
            raise RuntimeError(
                "FrameStack.step() was called before reset(); call reset() first."
            )
        self.frames.append(obs)
        return self._get_observation(), reward, terminated, truncated, info

    def _get_observation(self):
        """Return the buffered frames as one array, oldest frame first."""
        return np.stack(tuple(self.frames), axis=0)


## Train a model from scratch

In [3]:
# Build the environment. Wrappers are applied in this order: frame skipping,
# grayscale conversion, resize to 84x84, then stacking of the 4 latest frames.
env = gym.make("CarRacing-v3", continuous=False)
env = DQN.SkipFrame(env, skip=4)
env = GrayScaleObservation(env)
env = ResizeObservation(env, shape=84)
env = FrameStack(env, num_stack=4)

# NOTE(review): no random seed is set in this cell, so runs are not
# reproducible — consider seeding env.reset() and the RNGs the agent uses.
state, info = env.reset()
action_n = env.action_space.n
driver = DQN.Agent(state.shape, action_n, double_q=False)

batch_n = 32            # batch size handed to driver.update_net
play_n_episodes = 3000  # number of training episodes to play

# Per-episode history. Epsilon is recorded when an episode starts, everything
# else when it ends, so all lists have equal length between episodes.
episode_epsilon_list = []
episode_reward_list = []
episode_length_list = []
episode_loss_list = []
episode_date_list = []
episode_time_list = []

episode = 0     # 1-based index of the current episode (0 before training starts)
timestep_n = 0  # number of env.step calls made so far, across all episodes

# Schedules. All are counted in env.step calls, except when2log (episodes).
when2learn = 4       # call driver.update_net every N steps
when2sync = 5000     # copy updating_net weights into frozen_net every N steps
when2save = 100000   # write a checkpoint every N steps
when2report = 5000   # short text report every N steps ('text' mode only)
when2eval = 50000    # text report with reward/length stats every N steps ('text' mode only)
when2log = 10        # write the CSV log every N episodes
report_type = 'plot'  # 'plot': DQN.plot_reward after each episode; 'text': printed reports

# Positional arguments of driver.write_log, in the order it is called with.
# The tuple holds references, so later appends to the lists are visible here.
log_columns = (
    episode_date_list,
    episode_time_list,
    episode_reward_list,
    episode_length_list,
    episode_loss_list,
    episode_epsilon_list,
)
log_filename = 'DQN_log_test.csv'


def _save_checkpoint(agent):
    """Save both networks, the optimizer state and exploration progress.

    The file is written to ``agent.save_dir`` (created if missing) and named
    after ``agent.act_taken``. Returns the path of the written file.
    """
    os.makedirs(agent.save_dir, exist_ok=True)
    checkpoint_path = os.path.join(agent.save_dir, f"DQN_{agent.act_taken}.pt")
    torch.save({
        'upd_model_state_dict': agent.updating_net.state_dict(),
        'frz_model_state_dict': agent.frozen_net.state_dict(),
        'optimizer_state_dict': agent.optimizer.state_dict(),
        'action_number': agent.act_taken,
        'epsilon': agent.epsilon
    }, checkpoint_path)
    return checkpoint_path


def _recent_mean_std(values, window):
    """Return ``(mean, std)`` of the last ``window`` values, rounded to 2 decimals.

    Yields NaN for the mean when there are no values and for the standard
    deviation when there are fewer than two, which is what torch would compute.
    """
    recent = torch.tensor(values[-window:], dtype=torch.float)
    if recent.numel() == 0:
        return float('nan'), float('nan')
    mean = round(torch.mean(recent).item(), 2)
    if recent.numel() < 2:
        return mean, float('nan')
    return mean, round(torch.std(recent).item(), 2)


def _print_report(header, timestep, episodes, agent, rewards=None, lengths=None, window=0):
    """Print a text progress report.

    When ``window`` is positive, reward and episode-length statistics over the
    last ``window`` finished episodes are included.
    """
    print(f'{header}: {timestep} timestep')
    if window > 0:
        mean_reward, std_reward = _recent_mean_std(rewards, window)
        mean_length, std_length = _recent_mean_std(lengths, window)
        print(f'    reward {mean_reward}±{std_reward}')
        print(f'    episode length {mean_length}±{std_length}')
    print(f'    episodes: {episodes}')
    print(f'    n_updates: {agent.n_updates}')
    print(f'    epsilon: {agent.epsilon}')


try:
    # range() plays exactly play_n_episodes episodes. The previous
    # `while episode <= play_n_episodes` loop incremented inside the body and
    # therefore ran one episode too many.
    for episode in range(1, play_n_episodes + 1):
        episode_reward = 0
        episode_length = 0
        updating = True
        loss_list = []
        episode_epsilon_list.append(driver.epsilon)

        while updating:
            timestep_n += 1
            episode_length += 1

            action = driver.take_action(state)
            new_state, reward, terminated, truncated, info = env.step(action)
            episode_reward += reward
            # Only `terminated` is stored with the transition; a truncated
            # episode still ends the loop below but is not stored as terminal.
            driver.store(state, action, reward, new_state, terminated)
            state = new_state
            updating = not (terminated or truncated)

            if timestep_n % when2sync == 0:
                upd_net_param = driver.updating_net.state_dict()
                driver.frozen_net.load_state_dict(upd_net_param)

            if timestep_n % when2save == 0:
                save_path = _save_checkpoint(driver)

            if timestep_n % when2learn == 0:
                q, loss = driver.update_net(batch_n)
                loss_list.append(loss)

            if timestep_n % when2report == 0 and report_type == 'text':
                _print_report('Report', timestep_n, episode, driver)

            if timestep_n % when2eval == 0 and report_type == 'text':
                _print_report(
                    'Evaluation', timestep_n, episode, driver,
                    rewards=episode_reward_list,
                    lengths=episode_length_list,
                    window=50,
                )

        state, info = env.reset()

        episode_reward_list.append(episode_reward)
        episode_length_list.append(episode_length)
        # An episode without any update step has no losses; record NaN
        # explicitly instead of letting np.mean([]) emit a RuntimeWarning.
        episode_loss_list.append(np.mean(loss_list) if loss_list else np.nan)
        now_time = datetime.datetime.now()
        episode_date_list.append(now_time.date().strftime('%Y-%m-%d'))
        episode_time_list.append(now_time.time().strftime('%H:%M:%S'))

        if report_type == 'plot':
            draw_check = DQN.plot_reward(episode, episode_reward_list, timestep_n)

        if episode % when2log == 0:
            driver.write_log(*log_columns, log_filename=log_filename)

    if report_type == 'text':
        _print_report(
            'Final evaluation', timestep_n, episode, driver,
            rewards=episode_reward_list,
            lengths=episode_length_list,
            window=100,
        )

    # Final save
    save_path = _save_checkpoint(driver)
    driver.write_log(*log_columns, log_filename=log_filename)
finally:
    # Release the environment even when the run is interrupted or fails.
    env.close()

plt.ioff()
plt.show()


<Figure size 640x480 with 0 Axes>