In [None]:
# imports:
import gymnasium as gym
from gymnasium.wrappers import MaxAndSkipObservation, ResizeObservation, GrayscaleObservation, FrameStackObservation, ReshapeObservation
import ale_py

import numpy as np

import torch
import torch.nn as nn        
import torch.optim as optim 
from torchsummary import summary

import collections

import wandb
import datetime

In [18]:
# version
# Report the installed Gymnasium version so the run can be reproduced.
print("Using Gymnasium version {}".format(gym.__version__))

# Environment id shared by the rest of the notebook.
ENV_NAME = "ALE/Breakout-v5"
# Throwaway, unwrapped-pipeline instance used only to inspect the raw action set
# and raw observation shape below.
# NOTE(review): test_env is never closed in the visible cells — consider test_env.close().
test_env = gym.make(ENV_NAME)

# Human-readable names of the discrete actions, then the raw frame shape.
print(test_env.unwrapped.get_action_meanings())
print(test_env.observation_space.shape)

Using Gymnasium version 1.0.0
['NOOP', 'FIRE', 'RIGHT', 'LEFT']
(210, 160, 3)


In [19]:
# Source: M3-2_Example_1a (DQN on Pong, train)
class ImageToPyTorch(gym.ObservationWrapper):
    """Observation wrapper that moves the channel axis to the front.

    The wrapped environment is expected to emit channel-last images; this
    wrapper re-declares the observation space with the last dimension first
    and moves axis 2 of every observation to position 0.

    Note: the observation space is declared as float32 in [0.0, 1.0], while
    `observation` only reorders axes — it performs no scaling or dtype cast.
    """

    def __init__(self, env):
        super().__init__(env)
        source_shape = self.observation_space.shape
        # Last dim (channels) first, followed by the first two dims.
        channel_first_shape = (source_shape[-1], source_shape[0], source_shape[1])
        self.observation_space = gym.spaces.Box(
            low=0.0,
            high=1.0,
            shape=channel_first_shape,
            dtype=np.float32,
        )

    def observation(self, observation):
        """Return `observation` with axis 2 moved to axis 0."""
        return np.moveaxis(observation, source=2, destination=0)


class ScaledFloatFrame(gym.ObservationWrapper):
    """Observation wrapper that converts observations to float32 and divides by 255."""

    def observation(self, obs):
        """Return `obs` as a float32 array with every value divided by 255.0."""
        frame = np.array(obs)
        as_float32 = frame.astype(np.float32)
        return as_float32 / 255.0


def make_env(env_name):
    """Create `env_name` and apply the preprocessing wrappers in a fixed order.

    After the base environment is created and after each wrapper is applied,
    the current observation-space shape is printed.

    Args:
        env_name: Environment id passed to `gym.make`.

    Returns:
        The fully wrapped environment.
    """
    env = gym.make(env_name)
    print("Standard Env.        : {}".format(env.observation_space.shape))

    # (report template, wrapper factory) pairs — applied strictly in this order.
    # FireResetEnv was commented out in the original pipeline (between
    # MaxAndSkipObservation and ResizeObservation) and is therefore not applied.
    pipeline = (
        ("MaxAndSkipObservation: {}", lambda e: MaxAndSkipObservation(e, skip=4)),
        ("ResizeObservation    : {}", lambda e: ResizeObservation(e, (84, 84))),
        ("GrayscaleObservation : {}", lambda e: GrayscaleObservation(e, keep_dim=True)),
        ("ImageToPyTorch       : {}", lambda e: ImageToPyTorch(e)),
        ("ReshapeObservation   : {}", lambda e: ReshapeObservation(e, (84, 84))),
        ("FrameStackObservation: {}", lambda e: FrameStackObservation(e, stack_size=4)),
        ("ScaledFloatFrame     : {}", lambda e: ScaledFloatFrame(e)),
    )
    for template, wrap in pipeline:
        env = wrap(env)
        print(template.format(env.observation_space.shape))

    return env

# Module-level training environment: ENV_NAME wrapped by the full make_env pipeline.
env=make_env(ENV_NAME)

Standard Env.        : (210, 160, 3)
MaxAndSkipObservation: (210, 160, 3)
ResizeObservation    : (84, 84, 3)
GrayscaleObservation : (84, 84, 1)
ImageToPyTorch       : (1, 84, 84)
ReshapeObservation   : (84, 84)
FrameStackObservation: (4, 84, 84)
ScaledFloatFrame     : (4, 84, 84)


In [20]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [21]:
def make_DQN(input_shape, output_shape):
    """Build the convolutional Q-network.

    The network is three Conv2d+ReLU stages, a Flatten, and a two-layer MLP
    head. The size of the flattened convolutional output is derived from
    `input_shape` with a dummy forward pass instead of being hard-coded, so
    the network also works for inputs other than 84x84 (for a (C, 84, 84)
    input the size is still 64*7*7).

    Layer order is unchanged, so `state_dict` keys (0, 2, 4, 7, 9) stay
    compatible with previously saved weights.

    Args:
        input_shape: Observation shape as (channels, height, width).
        output_shape: Number of outputs (one Q-value per action).

    Returns:
        An `nn.Sequential` mapping a (batch, *input_shape) float tensor to a
        (batch, output_shape) tensor.
    """
    feature_layers = [
        nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
        nn.ReLU(),
        nn.Conv2d(32, 64, kernel_size=4, stride=2),
        nn.ReLU(),
        nn.Conv2d(64, 64, kernel_size=3, stride=1),
        nn.ReLU(),
        nn.Flatten(),
    ]

    # Probe the convolutional stack with an all-zero batch of one observation to
    # learn how many features reach the first Linear layer. Uses no randomness.
    probe_shape = tuple(int(dim) for dim in input_shape)
    with torch.no_grad():
        n_features = nn.Sequential(*feature_layers)(torch.zeros(1, *probe_shape)).shape[1]

    net = nn.Sequential(
        *feature_layers,
        nn.Linear(n_features, 512),
        nn.ReLU(),
        nn.Linear(512, output_shape)
    )
    return net



In [34]:
# --- Stopping criterion -------------------------------------------------------
# Training stops once the mean reward of the most recent
# NUMBER_OF_REWARDS_TO_AVERAGE episodes exceeds MEAN_REWARD_BOUND.
MEAN_REWARD_BOUND = 100
NUMBER_OF_REWARDS_TO_AVERAGE = 10

# --- Learning -----------------------------------------------------------------
GAMMA = 0.99             # discount factor applied to next-state values
BATCH_SIZE = 32          # transitions drawn from the replay buffer per update
LEARNING_RATE = 0.0001   # Adam step size

# --- Replay buffer / target network -------------------------------------------
EXPERIENCE_REPLAY_SIZE = 10_000   # buffer capacity; updates start once it is full
SYNC_TARGET_NETWORK = 1_000       # copy online weights to the target net every N steps

# --- Epsilon-greedy exploration -----------------------------------------------
EPS_START = 1.0        # initial epsilon
EPS_DECAY = 0.999985   # multiplicative decay applied every step
EPS_MIN = 0.02         # lower bound for epsilon

# --- Prioritized replay --------------------------------------------------------
INITIAL_BETA = 0.4     # starting importance-sampling exponent

In [23]:
# One stored transition: (state, action, reward, done, new_state).
Experience = collections.namedtuple(
    'Experience',
    ['state', 'action', 'reward', 'done', 'new_state'],
)

class PrioritizedExperienceReplayBuffer:
    """Bounded replay buffer that samples transitions in proportion to priority.

    Transitions and their priorities are kept in two parallel deques with the
    same `maxlen`, so position i in `priorities` belongs to position i in
    `buffer`, and the oldest entries of both are dropped together once
    `capacity` is reached.

    Attributes:
        buffer: Stored experiences.
        priorities: Priority of each stored experience.
        eps: Constant added to |error| so that no priority is zero.
        alpha: Exponent applied to priorities when building sampling probabilities.
        beta: Exponent applied when computing importance-sampling weights.
    """

    def __init__(self, capacity, eps=0.001, alpha=0.6, beta=INITIAL_BETA):
        self.buffer = collections.deque(maxlen=capacity)
        self.priorities = collections.deque(maxlen=capacity)
        self.eps = eps
        self.alpha = alpha
        self.beta = beta

    def __len__(self):
        """Number of experiences currently stored."""
        return len(self.buffer)

    def append(self, experience):
        """Store `experience` with the largest priority currently in the buffer.

        When no priority exists yet, 1.0 is used.
        """
        self.buffer.append(experience)
        # max() is evaluated before the new priority is appended.
        self.priorities.append(max(self.priorities, default=1.0))

    def _get_probabilities(self):
        """Return priorities ** alpha, normalised to sum to 1."""
        powered = np.array(self.priorities) ** self.alpha
        total = powered.sum()
        return powered / total

    def _get_importance(self, probabilities):
        """Return importance-sampling weights for `probabilities`.

        Each weight is ((1 / N) * (1 / p)) ** beta with N the current buffer
        length; the result is divided by its own maximum.
        """
        raw_weights = ((1 / len(self.buffer)) * (1 / probabilities)) ** self.beta
        return raw_weights / raw_weights.max()

    def sample(self, batch_size):
        """Draw a priority-weighted batch (with replacement).

        Args:
            batch_size: Requested batch size; capped at the current buffer length.

        Returns:
            A 3-tuple of
            (states, actions, rewards, dones, next_states) as numpy arrays
            (rewards as float32, dones as uint8), the importance weights of
            the drawn experiences, and the drawn buffer indices (to be passed
            to `update_priorities`).
        """
        stored = len(self.buffer)
        probabilities = self._get_probabilities()
        picked = np.random.choice(stored, size=min(stored, batch_size), p=probabilities)

        weights = self._get_importance(probabilities[picked])

        states, actions, rewards, dones, next_states = zip(*(self.buffer[i] for i in picked))
        batch = (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.array(dones, dtype=np.uint8),
            np.array(next_states),
        )
        return batch, weights, picked

    def update_priorities(self, indices, errors):
        """Set the priority at each of `indices` to |error| + eps."""
        for position, td_error in zip(indices, errors):
            self.priorities[position] = abs(td_error) + self.eps

In [24]:
class DQNAgent:
    """Epsilon-greedy agent that plays one environment and fills a replay buffer.

    Attributes:
        env: Environment the agent acts in.
        exp_replay_buffer: Buffer receiving one `Experience` per step (via `append`).
        current_state: Latest observation returned by the environment.
        total_reward: Reward accumulated in the current episode.
    """

    def __init__(self, env, exp_replay_buffer):
        self.env = env
        self.exp_replay_buffer = exp_replay_buffer
        self._reset()

    def _reset(self):
        """Reset the environment and the running episode reward."""
        # env.reset() returns (observation, info); only the observation is kept.
        self.current_state = self.env.reset()[0]
        self.total_reward = 0.0

    def step(self, net, epsilon=0.0, device="cpu"):
        """Take one environment step and store the transition.

        Args:
            net: Q-network used for greedy action selection.
            epsilon: Probability of taking a random action instead.
            device: Device the state tensor is moved to before the forward pass.

        Returns:
            The total reward of the episode if this step ended it (terminated
            or truncated), otherwise None.
        """
        done_reward = None
        if np.random.random() < epsilon:
            # Sample from the agent's own environment, not from a module-level
            # `env`, so the agent works with whatever env it was constructed with.
            action = self.env.action_space.sample()
        else:
            state_ = np.array([self.current_state])
            state = torch.tensor(state_).to(device)
            # Action selection needs no gradients; skip building the autograd graph.
            with torch.no_grad():
                q_vals = net(state)
            _, act_ = torch.max(q_vals, dim=1)
            action = int(act_.item())

        new_state, reward, terminated, truncated, _ = self.env.step(action)
        is_done = terminated or truncated
        self.total_reward += reward

        exp = Experience(self.current_state, action, reward, is_done, new_state)
        self.exp_replay_buffer.append(exp)
        self.current_state = new_state

        if is_done:
            done_reward = self.total_reward
            self._reset()

        return done_reward

In [35]:
# Authenticate with Weights & Biases.
wandb.login()

# Start a new wandb run to track this notebook's training session.
# NOTE(review): only four hyperparameters are recorded in the run config;
# the remaining constants (batch size, buffer size, sync interval, EPS_MIN,
# INITIAL_BETA, ...) are not logged.
wandb.init(
    project="Part1_DQN",
    config={
        "gamma": GAMMA,
        "learning_rate": LEARNING_RATE,
        "eps_start": EPS_START,
        "eps_decay": EPS_DECAY
    }
)

In [37]:
# Print the wall-clock time at which training starts (printed only, not stored).
print(">>> Training starts at ",datetime.datetime.now())

>>> Training starts at  2024-11-30 13:32:58.729366


In [38]:
# Online network (optimised) and target network (provides bootstrap targets).
net = make_DQN(env.observation_space.shape, env.action_space.n).to(device)
target_net = make_DQN(env.observation_space.shape, env.action_space.n).to(device)
# Start the target network as an exact copy of the online network. Without this
# the first update bootstraps from an unrelated random initialisation, and only
# gets synced straight afterwards because EXPERIENCE_REPLAY_SIZE happens to be a
# multiple of SYNC_TARGET_NETWORK.
target_net.load_state_dict(net.state_dict())

# Here we replace standard buffer with the PER buffer
priorized_buffer = PrioritizedExperienceReplayBuffer(EXPERIENCE_REPLAY_SIZE)
agent = DQNAgent(env, priorized_buffer)

# Number of steps over which beta is annealed linearly from INITIAL_BETA to 1.0.
# This is the expected length of a training run, so beta reaches 1.0 at that step.
BETA_ANNEAL_STEPS = 120000

epsilon = EPS_START
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
total_rewards = []
losses = []
step_number = 0

while True:
    step_number += 1
    epsilon = max(epsilon * EPS_DECAY, EPS_MIN)

    # Anneal beta from INITIAL_BETA towards 1.0 (capped at 1.0)
    priorized_buffer.beta = min(1.0, priorized_buffer.beta + (1.0 - INITIAL_BETA) / BETA_ANNEAL_STEPS)

    # One environment step; returns the episode's total reward when it ends, else None
    reward = agent.step(net, epsilon, device=device)
    if reward is not None:
        total_rewards.append(reward)

        # Mean over the last NUMBER_OF_REWARDS_TO_AVERAGE episodes
        # (the wandb key is named "reward_100" regardless of the window size)
        mean_reward = np.mean(total_rewards[-NUMBER_OF_REWARDS_TO_AVERAGE:])

        print(f"Step:{step_number} | Total games:{len(total_rewards)} | Mean reward: {mean_reward:.3f}  (epsilon used: {epsilon:.2f})")
        wandb.log({"epsilon": epsilon, "reward_100": mean_reward, "reward": reward}, step=step_number)

        if mean_reward > MEAN_REWARD_BOUND:
            print(f"SOLVED in {step_number} steps and {len(total_rewards)} games")
            break

    # No learning until the replay buffer is full
    if len(priorized_buffer) < EXPERIENCE_REPLAY_SIZE:
        continue

    (states_, actions_, rewards_, dones_, next_states_), importance, indices = priorized_buffer.sample(BATCH_SIZE)
    importance = torch.tensor(importance, dtype=torch.float32).to(device)

    states = torch.tensor(states_).to(device)
    next_states = torch.tensor(next_states_).to(device)
    actions = torch.tensor(actions_).to(device)
    rewards = torch.tensor(rewards_).to(device)
    dones = torch.BoolTensor(dones_).to(device)

    # Q(s, a) of the online network for the actions actually taken
    Q_values = net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)

    # Bootstrap targets come from the target network and must not carry gradients,
    # so they are computed without building an autograd graph.
    with torch.no_grad():
        next_state_values = target_net(next_states).max(1)[0]
        next_state_values[dones] = 0.0  # no bootstrap value after the episode ended
        expected_Q_values = next_state_values * GAMMA + rewards

    errors = torch.abs(Q_values - expected_Q_values).detach().cpu().numpy()  # Here we calculate the TD errors
    # reduction='none' gives the per-sample squared error, which is weighted by the
    # importance-sampling weights before the mean is taken
    loss = (importance * nn.MSELoss(reduction='none')(Q_values, expected_Q_values)).mean()

    # We update the buffer priorities based on the TD errors
    priorized_buffer.update_priorities(indices, errors)

    losses.append(loss.item())
    # Mean over the last NUMBER_OF_REWARDS_TO_AVERAGE losses (key is named "loss_100")
    mean_losses = np.mean(losses[-NUMBER_OF_REWARDS_TO_AVERAGE:])
    wandb.log({"loss_100": mean_losses, "loss": loss.item()}, step=step_number)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Periodically copy the online weights into the target network
    if step_number % SYNC_TARGET_NETWORK == 0:
        target_net.load_state_dict(net.state_dict())

Step:43 | Total games:1 | Mean reward: 0.000  (epsilon used: 1.00)
Step:94 | Total games:2 | Mean reward: 0.500  (epsilon used: 1.00)
Step:162 | Total games:3 | Mean reward: 1.333  (epsilon used: 1.00)
Step:211 | Total games:4 | Mean reward: 1.250  (epsilon used: 1.00)
Step:254 | Total games:5 | Mean reward: 1.000  (epsilon used: 1.00)
Step:303 | Total games:6 | Mean reward: 1.000  (epsilon used: 1.00)
Step:404 | Total games:7 | Mean reward: 1.571  (epsilon used: 0.99)
Step:453 | Total games:8 | Mean reward: 1.500  (epsilon used: 0.99)
Step:526 | Total games:9 | Mean reward: 1.667  (epsilon used: 0.99)
Step:575 | Total games:10 | Mean reward: 1.600  (epsilon used: 0.99)




Step:653 | Total games:11 | Mean reward: 2.000  (epsilon used: 0.99)
Step:698 | Total games:12 | Mean reward: 1.900  (epsilon used: 0.99)
Step:761 | Total games:13 | Mean reward: 1.700  (epsilon used: 0.99)
Step:815 | Total games:14 | Mean reward: 1.600  (epsilon used: 0.99)
Step:892 | Total games:15 | Mean reward: 2.000  (epsilon used: 0.99)
Step:960 | Total games:16 | Mean reward: 2.200  (epsilon used: 0.99)
Step:1025 | Total games:17 | Mean reward: 2.000  (epsilon used: 0.98)
Step:1119 | Total games:18 | Mean reward: 2.400  (epsilon used: 0.98)
Step:1199 | Total games:19 | Mean reward: 2.500  (epsilon used: 0.98)
Step:1272 | Total games:20 | Mean reward: 2.600  (epsilon used: 0.98)
Step:1334 | Total games:21 | Mean reward: 2.300  (epsilon used: 0.98)
Step:1394 | Total games:22 | Mean reward: 2.300  (epsilon used: 0.98)
Step:1450 | Total games:23 | Mean reward: 2.400  (epsilon used: 0.98)
Step:1511 | Total games:24 | Mean reward: 2.600  (epsilon used: 0.98)
Step:1584 | Total games:25

KeyboardInterrupt: 

In [None]:
# Persist only the online network's weights (its state_dict), not the whole module.
# NOTE(review): the path is relative to the notebook's working directory and
# presumably requires ../models to exist already — confirm before a long run.
torch.save(net.state_dict(), "../models/Part1_DQN.dat")

In [39]:
# Print the wall-clock time at which training ended (printed only, not stored).
print(">>> Training ends at ",datetime.datetime.now())

>>> Training ends at  2024-11-30 13:38:24.016023


In [40]:
# Finish the wandb run; in a notebook this must be called explicitly because the
# kernel keeps running after the training cell ends.
wandb.finish()

VBox(children=(Label(value='0.026 MB of 0.026 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epsilon,████▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁
loss,█▃▇▇██▃▃▂▃▇▃▂▅▅▄▂▂▂▂▂▄▂▂▃▂▂▁▂▂▁▂▂▁▂▃▂▅▂▂
loss_100,▃▃▆▄▆█▇▄▅▃▃▂▂▁▂▂▂▃▅▃▃▃▂▃▂▂▁▂▂▂▁▁▁▁▁▂▂▂▂▂
reward,▃▄▄▅▃▂▄▃▅▁▂▁▂▃▂▁█▄▃▄▇▄▁█▃▅▆▃▂▃▂▂▄▂▃▃▂▁▄▂
reward_100,▂▃▂▃▃▆▆▇▄▂▅▃▂▁▃▁▅▇▇█▆▇█▇▆▄▇▅▆▅▆▇█▄▃▃▄▄▅▃

0,1
epsilon,0.81946
loss,0.01461
loss_100,0.0098
reward,2.0
reward_100,1.2
