# Imports

In [1]:
import gymnasium as gym
import numpy as np
from Models.DDQN.DDQN_Agent import DDQN_Agent
import torch
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

# TensorBoard logging: every run writes into its own timestamped directory
now = datetime.now()
run_stamp = now.strftime("%Y%m%d-%H%M%S")
logdir = f"Logging/DDQN-CARTPOLE/Tensorboard/{run_stamp}/"
writer = SummaryWriter(log_dir=logdir)

  from .autonotebook import tqdm as notebook_tqdm


# Set up the environment

In [2]:
env = gym.make("CartPole-v1")
# Sanity check of the spaces: number of discrete actions, then observation size
print(env.action_space.n, env.observation_space.shape[0], sep="\n")

2
4


# Device

In [3]:
# Prefer CUDA, then Apple's Metal backend (MPS), and fall back to the CPU
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')

# Report which device was selected
print("Device used: ", DEVICE)

Device used:  cuda:0


# Hyperparameter setup

In [4]:
# --- Reproducibility -------------------------------------------------------
seed = 42  # global seed read inside train()

# --- Agent hyperparameters (all forwarded to DDQN_Agent) -------------------
alpha = 0.06        # presumably the optimizer learning rate -- confirm in DDQN_Agent
epsilon = 0.1       # initial exploration rate
epsilon_decay = 1   # used by agent.decay_epsilon() once per episode; presumably multiplicative, so 1 means no decay -- confirm
min_epsilon = 0.05  # presumably the lower bound for epsilon
buffer_size = 1000  # presumably the replay memory capacity
gamma = 0.95        # presumably the discount factor
batch_size = 64

# --- Training schedule ------------------------------------------------------
horizon = 20                # Number of steps before training the agent
n_training_iterations = 5   # agent.training_iteration() calls per training burst
max_num_steps = 1_000_000   # environment-step budget for the whole run

# --- Evaluation and checkpointing ---------------------------------------------
num_test_runs = 10                    # greedy episodes per evaluation
num_episodes_before_test_runs = 100   # evaluate every this many training episodes
num_steps_before_save = 500_000       # checkpoint interval in environment steps

# NOTE(review): the two names below are not referenced anywhere in this notebook;
# they appear to be leftovers from a multi-agent setup.
threshold_test_return_to_update_opponents = 0.5
time_scale = 90

In [5]:
def evaluate(env, agent, num_eval_episodes):
    """Run greedy evaluation episodes and return the average undiscounted return.

    The agent's Q-network is put into eval mode for the rollouts and switched
    back to training mode afterwards. The switch back happens in a ``finally``
    block, so it also happens if evaluation is interrupted (e.g. by a
    KeyboardInterrupt).

    Args:
        env: Gymnasium environment to roll out in. It is reset at the start of
            every episode, so any in-progress episode on ``env`` is discarded.
        agent: Agent exposing ``q_network`` and
            ``select_action(state_tensor, greedy=True)``. The returned action
            must support ``.item()``.
        num_eval_episodes: Number of full episodes to run; must be positive.

    Returns:
        The summed reward over all episodes divided by ``num_eval_episodes``.

    Raises:
        ValueError: If ``num_eval_episodes`` is not positive.
    """
    if num_eval_episodes <= 0:
        raise ValueError(f"num_eval_episodes must be positive, got {num_eval_episodes}")

    # Set the model in evaluation mode
    agent.q_network.eval()

    total_return = 0
    try:
        for _ in tqdm(range(num_eval_episodes)):
            state, _ = env.reset()
            done = False
            while not done:
                # Greedy action; no gradients are needed for evaluation rollouts
                with torch.no_grad():
                    action = agent.select_action(torch.FloatTensor(state).to(DEVICE), greedy=True)

                next_state, reward, terminated, truncated, _ = env.step(action.item())

                # Either a terminal state or the time limit ends the episode
                done = terminated or truncated

                # Accumulate the (undiscounted) reward across all episodes
                total_return += reward

                state = next_state
    finally:
        # Always put the model back in training mode, even if interrupted
        agent.q_network.train()

    # Average return per episode
    return total_return / num_eval_episodes

In [6]:
def train(alpha, epsilon, epsilon_decay, min_epsilon, buffer_size, gamma, horizon, batch_size, n_training_iterations, max_num_steps, num_test_runs, num_episodes_before_test_runs, num_steps_before_save):
    """Train a DDQN agent on CartPole-v1 with TensorBoard logging and checkpoints.

    Reads the notebook-level globals ``seed``, ``DEVICE`` and ``writer``.

    Args:
        alpha, epsilon, epsilon_decay, min_epsilon, buffer_size, gamma, batch_size:
            Forwarded unchanged to ``DDQN_Agent``.
        horizon: Run a burst of training updates every ``horizon`` environment steps.
        n_training_iterations: ``agent.training_iteration()`` calls per burst.
        max_num_steps: No new episode is started once this many environment
            steps have been taken (the last episode may overshoot it).
        num_test_runs: Episodes per greedy evaluation (see ``evaluate``).
        num_episodes_before_test_runs: Evaluate every this many training
            episodes, including once before any training.
        num_steps_before_save: Save a checkpoint every this many environment steps.
    """
    import random

    # Create the environment and seed every RNG source for reproducibility
    env = gym.make("CartPole-v1")
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Gymnasium removed env.seed(); seeding happens through reset(seed=...),
    # and later unseeded resets keep drawing from this seeded generator.
    env.reset(seed=seed)

    agent = DDQN_Agent(state_size=env.observation_space.shape[0], 
                       action_size=env.action_space.n, 
                       alpha=alpha, 
                       epsilon=epsilon, 
                       epsilon_decay=epsilon_decay, 
                       min_epsilon=min_epsilon, 
                       buffer_size=buffer_size, 
                       gamma=gamma, 
                       batch_size=batch_size,
                       DEVICE=DEVICE)

    n_steps = 0  # Total environment steps taken so far
    e = 0        # Number of completed training episodes
    try:
        while n_steps < max_num_steps:

            # Periodic greedy evaluation (also runs once before any training)
            if e % num_episodes_before_test_runs == 0:
                average_test_return = evaluate(env, agent, num_test_runs)
                writer.add_scalar("AverageTestReturn-TrainStep", average_test_return, n_steps)
                writer.add_scalar("AverageTestReturn-TrainEpisode", average_test_return, e)
                writer.flush()

            state, _ = env.reset()

            # The episode's initial state is stored directly; remember() below only
            # receives next_state, so presumably the memory pairs consecutive
            # states itself -- confirm in DDQN_Agent.
            agent.memory.states.append(state)

            # Run one training episode
            done = False
            total_return = 0
            while not done:
                action = agent.select_action(torch.FloatTensor(state).to(DEVICE))

                next_state, reward, terminated, truncated, _ = env.step(action)

                # Either a terminal state or the time limit ends the episode
                done = terminated or truncated

                # Only `terminated` is stored, presumably so time-limit truncation
                # is not treated as a true terminal state in the targets.
                agent.remember(next_state, action, reward, terminated)

                state = next_state
                total_return += reward

                # Train every `horizon` steps
                if n_steps > 0 and n_steps % horizon == 0:
                    for _ in range(n_training_iterations):
                        agent.training_iteration()

                n_steps += 1

                # Checkpoint on every exact multiple of num_steps_before_save. This
                # must be checked per step: episode lengths vary, so n_steps almost
                # never equals an exact multiple at the start of an episode.
                if n_steps % num_steps_before_save == 0:
                    agent.save_models("./Logging/DDQN-CARTPOLE/Checkpoints", 1, n_steps)

            # Epsilon is decayed once per episode
            agent.decay_epsilon()

            e += 1

            # Log the training return
            writer.add_scalar("TrainReturn-TrainStep", total_return, n_steps)
            writer.add_scalar("TrainReturn-TrainEpisode", total_return, e)

            writer.flush()
    finally:
        # Persist pending logs and release the environment even if training
        # is interrupted (e.g. KeyboardInterrupt in the notebook)
        writer.flush()
        env.close()

In [7]:
# Gather the hyperparameters defined above and launch a full training run
train_kwargs = dict(
    alpha=alpha,
    epsilon=epsilon,
    epsilon_decay=epsilon_decay,
    min_epsilon=min_epsilon,
    buffer_size=buffer_size,
    gamma=gamma,
    horizon=horizon,
    batch_size=batch_size,
    n_training_iterations=n_training_iterations,
    max_num_steps=max_num_steps,
    num_test_runs=num_test_runs,
    num_episodes_before_test_runs=num_episodes_before_test_runs,
    num_steps_before_save=num_steps_before_save,
)
train(**train_kwargs)

  return torch.Tensor(self.states)[idx], torch.LongTensor(self.actions)[idx], \
100%|██████████| 10/10 [00:00<00:00, 154.73it/s]
100%|██████████| 10/10 [00:00<00:00, 150.33it/s]
100%|██████████| 10/10 [00:00<00:00, 168.24it/s]
100%|██████████| 10/10 [00:00<00:00, 156.31it/s]
100%|██████████| 10/10 [00:00<00:00, 156.21it/s]
100%|██████████| 10/10 [00:00<00:00, 165.84it/s]
100%|██████████| 10/10 [00:00<00:00, 131.52it/s]
100%|██████████| 10/10 [00:00<00:00, 187.66it/s]
100%|██████████| 10/10 [00:00<00:00, 211.56it/s]
100%|██████████| 10/10 [00:00<00:00, 207.65it/s]
100%|██████████| 10/10 [00:00<00:00, 159.93it/s]
100%|██████████| 10/10 [00:00<00:00, 172.86it/s]
100%|██████████| 10/10 [00:00<00:00, 149.91it/s]
100%|██████████| 10/10 [00:00<00:00, 175.28it/s]
100%|██████████| 10/10 [00:00<00:00, 157.85it/s]
100%|██████████| 10/10 [00:00<00:00, 190.88it/s]
100%|██████████| 10/10 [00:00<00:00, 175.43it/s]
100%|██████████| 10/10 [00:00<00:00, 168.65it/s]
100%|██████████| 10/10 [00:00<00:00, 1

KeyboardInterrupt: 