In [13]:
# ====================================
# 1) IMPORTING LIBRARIES
# ====================================
import os
import matplotlib
import torch
import datetime
import csv
import cv2

import gymnasium as gym
import gymnasium.wrappers as gym_wrap
import matplotlib.pyplot as plt
import numpy as np

from gymnasium.spaces import Box
from tensordict import TensorDict
from torch import nn
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage

# Detect whether matplotlib renders inline (e.g. in Jupyter); in that case the
# live training plot is refreshed in place through IPython.display.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Interactive mode so plt.pause() redraws figures without blocking. The
# trailing ';' suppresses the repr of the context manager plt.ion() returns,
# which otherwise shows up as stray cell output.
plt.ion();

<contextlib.ExitStack at 0x156067320>

In [4]:
# ====================================
# 2) GYMNASIUM WRAPPER FUNCTIONS
# ====================================

def create_skip_frame_wrapper(env, skip=4):
    """Repeat each action for ``skip`` frames and sum the collected rewards.

    Frame skipping reduces the number of agent decisions (and network forward
    passes) per episode.

    Args:
        env: The gymnasium environment to wrap.
        skip: Number of consecutive ``env.step`` calls per agent action (>= 1).

    Returns:
        The wrapped environment. Its ``step`` returns the last observation,
        the summed reward, and the terminated/truncated/info of the last
        inner step.

    Raises:
        ValueError: if ``skip`` < 1 (no inner step would run, so there would be
            no observation to return).
    """
    if skip < 1:
        raise ValueError(f"skip must be >= 1, got {skip}")

    class SkipFrame(gym.Wrapper):
        def __init__(self, env, skip):
            super().__init__(env)
            self._skip = skip

        def step(self, action):
            total_reward = 0.0
            for _ in range(self._skip):
                state, reward, terminated, truncated, info = self.env.step(action)
                total_reward += reward
                # Stop as soon as the episode ends for either reason: after a
                # termination OR a truncation the env must be reset, not stepped.
                if terminated or truncated:
                    break
            return state, total_reward, terminated, truncated, info

    return SkipFrame(env, skip)

def create_grayscale_wrapper(env):
    """Wrap ``env`` so every RGB observation is reduced to one gray channel.

    The declared observation space keeps the inner frame's height and width,
    drops the channel axis, and is uint8 in [0, 255].
    """
    class GrayScaleObservation(gym.ObservationWrapper):
        def __init__(self, env):
            super().__init__(env)
            height, width = env.observation_space.shape[:2]
            gray_shape = (height, width)
            self.observation_space = gym.spaces.Box(
                low=0, high=255, shape=gray_shape, dtype=np.uint8
            )

        def observation(self, obs):
            # Input channels are interpreted in RGB order.
            gray = cv2.cvtColor(obs, cv2.COLOR_RGB2GRAY)
            return gray

    wrapped = GrayScaleObservation(env)
    return wrapped

def create_resize_wrapper(env, shape=84):
    """Wrap ``env`` so each observation is resized to a ``shape`` x ``shape`` square.

    The declared observation space is uint8 in [0, 255] with shape
    ``(shape, shape)``; resizing uses OpenCV area interpolation.
    """
    class ResizeObservation(gym.ObservationWrapper):
        def __init__(self, env, shape):
            super().__init__(env)
            side = shape
            self.shape = (side, side)
            self.observation_space = gym.spaces.Box(
                low=0, high=255, shape=(side, side), dtype=np.uint8
            )

        def observation(self, obs):
            # self.shape is square, so cv2's (width, height) order is moot here.
            resized = cv2.resize(obs, self.shape, interpolation=cv2.INTER_AREA)
            return resized

    wrapped = ResizeObservation(env, shape)
    return wrapped

def create_frame_stack_wrapper(env, num_stack=4):
    """Wrap ``env`` so each observation is the last ``num_stack`` frames stacked on axis 0.

    On reset the first frame is repeated ``num_stack`` times. Each step drops
    the oldest frame and appends the newest. The declared observation space
    is uint8 in [0, 255] with shape ``(num_stack, *inner_shape)``.
    """
    class FrameStack(gym.Wrapper):
        def __init__(self, env, num_stack):
            super().__init__(env)
            self.num_stack = num_stack
            self.frames = []
            stacked_shape = (num_stack,) + tuple(env.observation_space.shape)
            self.observation_space = gym.spaces.Box(
                low=0,
                high=255,
                shape=stacked_shape,
                dtype=np.uint8
            )

        def reset(self, **kwargs):
            first_obs, info = self.env.reset(**kwargs)
            # Fill the whole history with the first frame.
            self.frames = [first_obs] * self.num_stack
            return self._get_observation(), info

        def step(self, action):
            obs, reward, terminated, truncated, info = self.env.step(action)
            del self.frames[0]
            self.frames.append(obs)
            return self._get_observation(), reward, terminated, truncated, info

        def _get_observation(self):
            return np.stack(self.frames, axis=0)

    wrapped = FrameStack(env, num_stack)
    return wrapped


In [5]:
# ====================================
# 3) DQN NEURAL NETWORK FUNCTIONS
# ====================================

def create_dqn_network(in_dim, out_dim):
    """Create the convolutional Q-network used by the DQN agent.

    Args:
        in_dim: Observation shape as ``(channels, height, width)``. Height and
            width must both be 84.
        out_dim: Number of discrete actions (size of the Q-value output).

    Returns:
        An ``nn.Sequential`` mapping a ``(batch, channels, 84, 84)`` float
        tensor to ``(batch, out_dim)`` Q-values.

    Raises:
        ValueError: if the spatial size is not 84x84.
    """
    channel_n, height, width = in_dim

    if height != 84 or width != 84:
        raise ValueError(
            "DQN model requires input of a (84, 84)-shape. "
            f"Input of a ({height}, {width})-shape was passed."
        )

    # Spatial size after the convs: 84 -> (84 - 8) // 4 + 1 = 20
    # -> (20 - 4) // 2 + 1 = 9, so the flattened vector has 32 * 9 * 9 = 2592 entries.
    flat_features = 32 * 9 * 9

    net = nn.Sequential(
        nn.Conv2d(in_channels=channel_n, out_channels=16,
                  kernel_size=8, stride=4),
        nn.ReLU(),
        nn.Conv2d(in_channels=16, out_channels=32,
                  kernel_size=4, stride=2),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(flat_features, 256),
        nn.ReLU(),
        nn.Linear(256, out_dim),
    )
    return net

In [6]:
# ====================================
# 4) AGENT STATE AND BUFFER FUNCTIONS
# ====================================

def initialize_agent(state_space_shape, action_n, gamma=0.95, epsilon=1, 
                    epsilon_decay=0.9999925, epsilon_min=0.05, double_q=False,
                    lr=0.0002, buffer_size=300000):
    """Initialize the agent's state dict, online/target networks, optimizer, loss and replay buffer.

    Args:
        state_space_shape: Observation shape ``(channels, 84, 84)`` passed to the network.
        action_n: Number of discrete actions.
        gamma: Discount factor for TD targets.
        epsilon: Initial exploration rate.
        epsilon_decay: Multiplicative decay applied to epsilon per action.
        epsilon_min: Lower bound for epsilon.
        double_q: Whether updates use the Double DQN target.
        lr: Adam learning rate.
        buffer_size: Capacity of the replay buffer's memory-mapped storage.

    Returns:
        ``(agent_state, updating_net, frozen_net, optimizer, loss_fn, buffer)``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    agent_state = {
        'gamma': gamma,
        'epsilon': epsilon,
        'epsilon_decay': epsilon_decay,
        'epsilon_min': epsilon_min,
        'state_shape': state_space_shape,
        'action_n': action_n,
        'double_q': double_q,
        'save_dir': './training/saved_models/',
        'log_dir': './training/logs/',
        'device': device,
        'act_taken': 0,
        'n_updates': 0
    }
    
    # Online (trained) network and target (periodically synced) network.
    updating_net = create_dqn_network(state_space_shape, action_n).float().to(device)
    frozen_net = create_dqn_network(state_space_shape, action_n).float().to(device)
    # Start the target as an exact copy of the online net so that the TD
    # targets computed before the first sync are not produced by an unrelated
    # random initialization.
    frozen_net.load_state_dict(updating_net.state_dict())
    
    optimizer = torch.optim.Adam(updating_net.parameters(), lr=lr)
    loss_fn = torch.nn.SmoothL1Loss()
    
    # Replay storage is kept on the CPU; batches are moved to `device` when sampled.
    buffer = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(buffer_size, device=torch.device("cpu"))
    )
    
    return agent_state, updating_net, frozen_net, optimizer, loss_fn, buffer

def store_experience(buffer, state, action, reward, new_state, terminated):
    """Store one ``(s, a, r, s', terminated)`` transition in the replay buffer.

    Scalars are stored with explicit dtypes: action int64, reward float32,
    terminated bool. The memory-mapped storage allocates its fields from the
    first sample. Fixing the dtypes keeps every stored sample consistent, and
    stops e.g. a numpy float64 reward from producing float64 targets later.
    The observation arrays keep their own dtype.
    """
    buffer.add(TensorDict({
        "state": torch.tensor(state),
        "action": torch.tensor(action, dtype=torch.int64),
        "reward": torch.tensor(reward, dtype=torch.float32),
        "new_state": torch.tensor(new_state),
        "terminated": torch.tensor(terminated, dtype=torch.bool)
    }, batch_size=[]))

def get_batch_samples(buffer, batch_size, device):
    """Sample a minibatch from the replay buffer and move it to ``device``.

    Args:
        buffer: Object whose ``sample(batch_size)`` returns a mapping with
            'state', 'action', 'reward', 'new_state' and 'terminated' tensors.
        batch_size: Number of transitions to sample.
        device: Target torch device.

    Returns:
        ``(states, actions, rewards, new_states, terminateds)``. States are
        float32. The per-transition scalars are flattened to shape
        ``(batch,)``, including when the batch has a single element.
    """
    batch = buffer.sample(batch_size)
    # Move the stored frames (compact integer dtype in this pipeline) first and
    # convert to float32 on the target device, which avoids transferring a 4x
    # larger float tensor.
    states = batch.get('state').to(device).float()
    new_states = batch.get('new_state').to(device).float()
    # reshape(-1) instead of squeeze(): squeeze() would turn a batch of 1 into a
    # 0-d tensor and break per-sample indexing downstream.
    actions = batch.get('action').reshape(-1).to(device)
    rewards = batch.get('reward').reshape(-1).to(device)
    terminateds = batch.get('terminated').reshape(-1).to(device)
    return states, actions, rewards, new_states, terminateds


In [7]:
# ====================================
# 5) ACTION SELECTION AND LEARNING FUNCTIONS
# ====================================

def take_action(agent_state, updating_net, state):
    """Select an action with an epsilon-greedy policy and decay epsilon.

    With probability ``agent_state['epsilon']`` a uniformly random action index
    is returned. Otherwise the argmax of ``updating_net``'s Q-values for
    ``state`` is returned. Afterwards epsilon is multiplied by
    ``epsilon_decay`` and floored at ``epsilon_min``, and ``act_taken`` is
    incremented.

    Args:
        agent_state: Agent dict (reads epsilon/epsilon_decay/epsilon_min/
            action_n/device; updates epsilon and act_taken).
        updating_net: Callable mapping a batched float32 state tensor to Q-values.
        state: Current observation (array-like); a batch axis is added.

    Returns:
        The chosen action index as a Python int.
    """
    if np.random.rand() < agent_state['epsilon']:
        action_idx = np.random.randint(agent_state['action_n'])
    else:
        state_tensor = torch.tensor(
            state, dtype=torch.float32, device=agent_state['device']
        ).unsqueeze(0)
        # Pure inference: don't build an autograd graph for action selection.
        with torch.no_grad():
            action_values = updating_net(state_tensor)
        action_idx = torch.argmax(action_values, dim=1).item()

    # Decay epsilon but never let it drop below epsilon_min, not even for one step.
    agent_state['epsilon'] = max(
        agent_state['epsilon'] * agent_state['epsilon_decay'],
        agent_state['epsilon_min'],
    )

    agent_state['act_taken'] += 1
    return action_idx

def update_network(agent_state, updating_net, frozen_net, optimizer, loss_fn, buffer, batch_size):
    """Run one gradient step of DQN (or Double DQN) on a replay minibatch.

    TD target: ``r + gamma * (1 - terminated) * Q_target(s', a*)``, where ``a*``
    is the argmax of ``frozen_net`` (DQN), or of ``updating_net`` evaluated by
    ``frozen_net`` (Double DQN, when ``agent_state['double_q']`` is set).

    Args:
        agent_state: Agent dict (reads device/gamma/double_q; increments n_updates).
        updating_net: Online network being optimized.
        frozen_net: Target network used to evaluate next-state values.
        optimizer: Optimizer over ``updating_net``'s parameters.
        loss_fn: Loss between the Q estimates and the TD targets.
        buffer: Replay buffer to sample from.
        batch_size: Minibatch size.

    Returns:
        ``(td_est, loss)``: the Q(s, a) estimates for the batch (detached from
        the autograd graph so it is not kept alive) and the scalar loss value.
    """
    agent_state['n_updates'] += 1

    states, actions, rewards, new_states, terminateds = get_batch_samples(
        buffer, batch_size, agent_state['device']
    )
    # Column vector of action indices for gather(); view(-1, 1) also copes with
    # a 0-d tensor coming from a batch of one.
    action_idx = actions.long().view(-1, 1)

    # Q(s, a) of the actions actually taken; gather indexes on the same device.
    td_est = updating_net(states).gather(1, action_idx).squeeze(1)

    with torch.no_grad():
        tar_action_values = frozen_net(new_states)
        if agent_state['double_q']:
            # Online net chooses the next action, target net evaluates it.
            next_actions = torch.argmax(updating_net(new_states), dim=1, keepdim=True)
            next_q = tar_action_values.gather(1, next_actions).squeeze(1)
        else:
            next_q = tar_action_values.max(dim=1).values
        not_done = 1 - terminateds.to(td_est.dtype)
        # Match the estimate's dtype so the loss never mixes float32/float64.
        td_tar = rewards.to(td_est.dtype) + not_done * agent_state['gamma'] * next_q

    loss = loss_fn(td_est, td_tar)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return td_est.detach(), loss.item()

In [8]:
# ====================================
# 6) SAVE/LOAD AND LOGGING FUNCTIONS
# ====================================

def save_model(agent_state, updating_net, frozen_net, optimizer, save_name=None):
    """Save a training checkpoint to ``<agent_state['save_dir']>/<save_name>.pt``.

    The checkpoint holds both networks' and the optimizer's state dicts, the
    number of actions taken, the number of updates, and the current epsilon.

    Args:
        agent_state: Agent dict (reads save_dir, act_taken, n_updates, epsilon).
        updating_net: Online network.
        frozen_net: Target network.
        optimizer: Optimizer whose state is saved alongside the networks.
        save_name: File stem; defaults to ``DQN_<act_taken>``.
    """
    save_dir = agent_state['save_dir']
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_dir, exist_ok=True)

    if save_name is None:
        save_name = f"DQN_{agent_state['act_taken']}"

    save_path = os.path.join(save_dir, f"{save_name}.pt")
    torch.save({
        'upd_model_state_dict': updating_net.state_dict(),
        'frz_model_state_dict': frozen_net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'action_number': agent_state['act_taken'],
        'n_updates': agent_state.get('n_updates', 0),
        'epsilon': agent_state['epsilon']
    }, save_path)
    print(f"Model saved to {save_path} at step {agent_state['act_taken']}")

def write_log(agent_state, date_list, time_list, reward_list, length_list, 
              loss_list, epsilon_list, log_filename='training_log.csv'):
    """Write (overwrite) the per-episode training log as CSV.

    Each metric is written as one row: a label followed by one value per
    episode (date, time, reward, length, loss, epsilon).

    Args:
        agent_state: Agent dict (reads log_dir).
        date_list, time_list, reward_list, length_list, loss_list, epsilon_list:
            Per-episode values.
        log_filename: File name inside ``agent_state['log_dir']``.
    """
    log_dir = agent_state['log_dir']
    os.makedirs(log_dir, exist_ok=True)

    rows = [['date'] + date_list,
            ['time'] + time_list,
            ['reward'] + reward_list,
            ['length'] + length_list,
            ['loss'] + loss_list,
            ['epsilon'] + epsilon_list]

    # newline='' is required by the csv module; without it every row gets an
    # extra blank line on Windows.
    with open(os.path.join(log_dir, log_filename), 'w', newline='', encoding='utf-8') as csvfile:
        csvwriter = csv.writer(csvfile)
        csvwriter.writerows(rows)


In [9]:
# ====================================
# 7) PLOTTING AND VISUALIZATION FUNCTIONS
# ====================================

def plot_reward(episode_num, reward_list, n_steps):
    """Redraw figure 1 with the per-episode reward history.

    Once more than 10 episodes exist, the title shows the mean±std of the last
    10 rewards. Once at least 50 exist, a 50-episode moving average is
    overlaid, with its first 49 points padded by the mean of the first 50
    rewards. With an inline backend the figure is re-displayed in place.
    """
    plt.figure(1)
    history = torch.tensor(reward_list, dtype=torch.float)
    n_episodes = len(history)
    plt.clf()

    if n_episodes < 11:
        plt.title('Training...')
    else:
        recent = history[-10:]
        recent_mean = round(recent.mean().item(), 2)
        recent_std = round(recent.std().item(), 2)
        plt.title(f'Episode #{episode_num}: {n_steps} steps, '
                  f'reward {recent_mean}±{recent_std}')

    plt.xlabel('Episode')
    plt.ylabel('Reward')
    plt.plot(history.numpy())

    window = 50
    if n_episodes >= window:
        rolling = history.unfold(0, window, 1).mean(1).view(-1)
        # Pad the first window-1 points so the curve spans every episode.
        padding = torch.ones(window - 1) * history[:window].mean()
        plt.plot(torch.cat((padding, rolling)).numpy())

    plt.pause(0.001)
    if is_ipython:
        display.display(plt.gcf())
        display.clear_output(wait=True)

def print_evaluation(timestep_n, episode, n_updates, epsilon, reward_list, length_list, eval_window=50):
    """Print mean±std reward and episode length over the last ``eval_window`` episodes.

    Prints nothing until at least ``eval_window`` rewards have been recorded.
    Statistics are rounded to two decimals.
    """
    if len(reward_list) < eval_window:
        return

    def _window_stats(values):
        # Mean and std of the trailing window, each rounded to 2 decimals.
        window = torch.tensor(values, dtype=torch.float)[-eval_window:]
        return round(window.mean().item(), 2), round(window.std().item(), 2)

    reward_mean, reward_std = _window_stats(reward_list)
    length_mean, length_std = _window_stats(length_list)

    report_lines = (
        f'Evaluation: {timestep_n} timestep',
        f'    reward {reward_mean}±{reward_std}',
        f'    episode length {length_mean}±{length_std}',
        f'    episodes: {episode}',
        f'    n_updates: {n_updates}',
        f'    epsilon: {epsilon}',
    )
    for line in report_lines:
        print(line)


In [10]:
# ====================================
# 8) ENVIRONMENT SETUP FUNCTION
# ====================================

def setup_environment(env_name="CarRacing-v3", skip_frames=4, resize_shape=84, frame_stack=4):
    """Create the discrete-action environment and apply the preprocessing wrappers.

    Wrappers are applied innermost-first: frame skipping, grayscale
    conversion, square resize, then frame stacking. The resulting
    observations have shape ``(frame_stack, resize_shape, resize_shape)``.
    """
    base_env = gym.make(env_name, continuous=False)
    skipped = create_skip_frame_wrapper(base_env, skip=skip_frames)
    gray = create_grayscale_wrapper(skipped)
    resized = create_resize_wrapper(gray, shape=resize_shape)
    stacked = create_frame_stack_wrapper(resized, num_stack=frame_stack)
    return stacked

In [11]:
# ====================================
# 9) MAIN TRAINING FUNCTION
# ====================================

def train_dqn(play_n_episodes=3000, batch_size=32, double_q=False, 
              when2learn=4, when2sync=5000, when2save=100000, 
              when2report=5000, when2eval=50000, when2log=10,
              report_type='plot'):
    """Train a DQN (or Double DQN) agent on the preprocessed environment.

    Args:
        play_n_episodes: Number of episodes to play (exactly this many).
        batch_size: Minibatch size sampled from the replay buffer per update.
        double_q: If True, use the Double DQN target.
        when2learn: Run one gradient update every this many environment steps.
        when2sync: Copy the online network into the target network every this many steps.
        when2save: Save a checkpoint every this many steps.
        when2report: Print a short progress report every this many steps ('text' mode only).
        when2eval: Print windowed reward/length statistics every this many steps ('text' mode only).
        when2log: Rewrite the CSV training log every this many episodes.
        report_type: 'plot' for a live reward plot, 'text' for printed reports.

    Returns:
        ``(episode_reward_list, episode_length_list)``: total reward and
        length (agent steps) of every episode played.
    """
    env = setup_environment()
    try:
        state, _ = env.reset()
        action_n = env.action_space.n

        agent_state, updating_net, frozen_net, optimizer, loss_fn, buffer = initialize_agent(
            state.shape, action_n, double_q=double_q
        )

        # Per-episode tracking (one entry per finished episode; epsilon is
        # recorded at the start of each episode).
        episode_epsilon_list = []
        episode_reward_list = []
        episode_length_list = []
        episode_loss_list = []
        episode_date_list = []
        episode_time_list = []

        episode = 0
        timestep_n = 0

        # Pre-increment below + strict '<' => exactly play_n_episodes episodes.
        while episode < play_n_episodes:
            episode += 1
            episode_reward = 0
            episode_length = 0
            updating = True
            loss_list = []
            episode_epsilon_list.append(agent_state['epsilon'])

            while updating:
                timestep_n += 1
                episode_length += 1

                action = take_action(agent_state, updating_net, state)
                new_state, reward, terminated, truncated, _ = env.step(action)
                episode_reward += reward

                store_experience(buffer, state, action, reward, new_state, terminated)
                state = new_state
                updating = not (terminated or truncated)

                # Periodically copy online weights into the target network.
                if timestep_n % when2sync == 0:
                    frozen_net.load_state_dict(updating_net.state_dict())

                if timestep_n % when2save == 0:
                    save_model(agent_state, updating_net, frozen_net, optimizer)

                if timestep_n % when2learn == 0 and len(buffer) > batch_size:
                    _, loss = update_network(agent_state, updating_net, frozen_net,
                                             optimizer, loss_fn, buffer, batch_size)
                    loss_list.append(loss)

                if timestep_n % when2report == 0 and report_type == 'text':
                    print(f'Report: {timestep_n} timestep')
                    print(f'    episodes: {episode}')
                    print(f'    n_updates: {agent_state["n_updates"]}')
                    print(f'    epsilon: {agent_state["epsilon"]}')

                if timestep_n % when2eval == 0 and report_type == 'text':
                    print_evaluation(timestep_n, episode, agent_state['n_updates'],
                                     agent_state['epsilon'], episode_reward_list, episode_length_list)

            state, _ = env.reset()

            episode_reward_list.append(episode_reward)
            episode_length_list.append(episode_length)
            episode_loss_list.append(np.mean(loss_list) if loss_list else 0)
            now_time = datetime.datetime.now()
            episode_date_list.append(now_time.date().strftime('%Y-%m-%d'))
            episode_time_list.append(now_time.time().strftime('%H:%M:%S'))

            if report_type == 'plot':
                plot_reward(episode, episode_reward_list, timestep_n)

            if episode % when2log == 0:
                write_log(agent_state, episode_date_list, episode_time_list,
                          episode_reward_list, episode_length_list, episode_loss_list,
                          episode_epsilon_list, log_filename='DQN_training_log.csv')

        if report_type == 'text':
            print_evaluation(timestep_n, episode, agent_state['n_updates'],
                             agent_state['epsilon'], episode_reward_list, episode_length_list, 100)

        save_model(agent_state, updating_net, frozen_net, optimizer, "DQN_final")
        write_log(agent_state, episode_date_list, episode_time_list,
                  episode_reward_list, episode_length_list, episode_loss_list,
                  episode_epsilon_list, log_filename='DQN_final_log.csv')
    finally:
        # Release the environment even if training raises or is interrupted.
        env.close()

    plt.ioff()
    plt.show()

    return episode_reward_list, episode_length_list


In [None]:
# ====================================
# 10) MAIN EXECUTION
# ====================================

if __name__ == "__main__":
    # Training configuration. Set double_q=True for Double DQN, and
    # report_type='text' for printed progress instead of a live plot.
    training_config = dict(
        play_n_episodes=3000,
        batch_size=32,
        double_q=False,
        report_type='plot',
    )
    reward_history, length_history = train_dqn(**training_config)

    print("Training completed!")