# Part 1: Pong Tournament

In [None]:
import warnings
warnings.filterwarnings("ignore")

import cv2
import numpy as np
import collections
import random
import time
import datetime
import os

import gymnasium as gym
from gymnasium import spaces

import torch
import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
print(f"Using gym version: {gym.__version__}")

Using device: cpu
Using gym version: 0.26.2


## Data Preprocessing (Wrappers)

In [None]:
class FireResetEnv(gym.Wrapper):
    """Wrapper that presses action 1 (FIRE) and then action 2 right after every reset.

    The constructor checks that the wrapped game exposes 'FIRE' as action 1 and
    has at least three actions. Follows the gymnasium API: ``reset`` returns
    ``(obs, info)`` and ``step`` returns
    ``(obs, reward, terminated, truncated, info)``.
    """

    def __init__(self, env=None):
        super().__init__(env)
        meanings = env.unwrapped.get_action_meanings()
        assert meanings[1] == 'FIRE'
        assert len(meanings) >= 3

    def step(self, action):
        """Forward the action unchanged to the wrapped environment."""
        return self.env.step(action)

    def reset(self, **kwargs):
        """Reset the wrapped env, then play actions 1 and 2 once each.

        Args:
            **kwargs: Forwarded to the wrapped ``reset`` (e.g. ``seed``, ``options``).

        Returns:
            tuple: ``(obs, info)`` for the first frame the agent should act on.
        """
        obs, info = self.env.reset(**kwargs)
        for start_action in (1, 2):
            obs, _, terminated, truncated, info = self.env.step(start_action)
            if terminated or truncated:
                # The episode ended during the start-up presses: start over and
                # hand back the fresh observation rather than the stale one.
                obs, info = self.env.reset(**kwargs)
        return obs, info

    
class MaxAndSkipEnv(gym.Wrapper):
    """Repeat each action ``skip`` times and return the pixel-wise max of the last two frames.

    Rewards collected during the repeated steps are summed. Follows the
    gymnasium API: ``step`` returns a 5-tuple and ``reset`` returns
    ``(obs, info)``.
    """

    def __init__(self, env=None, skip=4):
        super().__init__(env)
        # Only the two most recent raw frames are kept for the max operation.
        self._obs_buffer = collections.deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        """Apply ``action`` up to ``skip`` times, stopping early if the episode ends.

        Returns:
            tuple: ``(max_frame, total_reward, terminated, truncated, info)`` where
            ``info`` comes from the last underlying step.
        """
        total_reward = 0.0
        terminated = False
        truncated = False
        info = {}
        for _ in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if terminated or truncated:
                break
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)
        return max_frame, total_reward, terminated, truncated, info

    def reset(self, **kwargs):
        """Clear the frame buffer, reset the wrapped env and seed the buffer with its first frame.

        Args:
            **kwargs: Forwarded to the wrapped ``reset`` (e.g. ``seed``, ``options``).

        Returns:
            tuple: ``(obs, info)`` from the wrapped environment.
        """
        self._obs_buffer.clear()
        obs, info = self.env.reset(**kwargs)
        self._obs_buffer.append(obs)
        return obs, info


class ProcessFrame84(gym.ObservationWrapper):
    """Convert RGB Atari frames to 84x84x1 uint8 grayscale images."""

    def __init__(self, env=None):
        super().__init__(env)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def observation(self, obs):
        """Return the processed version of a raw frame (see ``process``)."""
        return ProcessFrame84.process(obs)

    @staticmethod
    def process(frame):
        """Grayscale, downscale and crop one frame.

        Args:
            frame: Array holding 210*160*3 or 250*160*3 values.

        Returns:
            np.ndarray: uint8 array of shape (84, 84, 1).

        Raises:
            AssertionError: If ``frame.size`` matches neither supported resolution.
        """
        for height in (210, 250):
            if frame.size == height * 160 * 3:
                img = np.reshape(frame, [height, 160, 3]).astype(np.float32)
                break
        else:
            # Raised explicitly (instead of ``assert False``) so the check
            # still fires when Python runs with -O.
            raise AssertionError("Unknown resolution.")
        # Weighted sum of the R, G and B channels gives a single gray channel.
        img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
        # cv2 takes (width, height): the result has 110 rows and 84 columns.
        resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA)
        # Keep rows 18..101 so the final image is 84x84.
        x_t = resized_screen[18:102, :]
        x_t = np.reshape(x_t, [84, 84, 1])
        return x_t.astype(np.uint8)


class BufferWrapper(gym.ObservationWrapper):
    """Stack the most recent ``n_steps`` observations along axis 0.

    The observation space is the wrapped space with its bounds repeated
    ``n_steps`` times along axis 0. Follows the gymnasium API: ``reset``
    returns ``(obs, info)``.
    """

    def __init__(self, env, n_steps, dtype=np.float32):
        super().__init__(env)
        self.dtype = dtype
        old_space = env.observation_space
        self.observation_space = gym.spaces.Box(
            old_space.low.repeat(n_steps, axis=0),
            old_space.high.repeat(n_steps, axis=0),
            dtype=dtype,
        )
        # Allocate up front so ``observation`` is usable even before the first reset.
        self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype)

    def reset(self, **kwargs):
        """Zero the frame stack, reset the wrapped env and push its first observation.

        Args:
            **kwargs: Forwarded to the wrapped ``reset`` (e.g. ``seed``, ``options``).

        Returns:
            tuple: ``(stacked_obs, info)``.
        """
        self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype)
        obs, info = self.env.reset(**kwargs)
        return self.observation(obs), info

    def observation(self, observation):
        """Shift the stack by one slot and store ``observation`` in the last slot.

        Note: the internal buffer itself is returned (not a copy), so callers
        that keep it across steps must copy it.
        """
        self.buffer[:-1] = self.buffer[1:]
        self.buffer[-1] = observation
        return self.buffer


class ImageToPyTorch(gym.ObservationWrapper):
    """Move the last (channel) axis of image observations to the front."""

    def __init__(self, env):
        super().__init__(env)
        src_shape = self.observation_space.shape
        # Last axis first, followed by the first two axes of the source shape.
        dst_shape = (src_shape[-1], src_shape[0], src_shape[1])
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=dst_shape, dtype=np.float32
        )

    def observation(self, observation):
        """Return ``observation`` with axis 2 moved to position 0."""
        return np.moveaxis(observation, source=2, destination=0)


class ScaledFloatFrame(gym.ObservationWrapper):
    """Convert observations to float32 and divide them by 255."""

    def observation(self, obs):
        """Return a float32 copy of ``obs`` scaled by 1/255."""
        as_float = np.array(obs).astype(np.float32)
        return as_float / 255.0

    
def make_env(env_name):
    """Create ``env_name`` and apply the preprocessing wrappers in order.

    The observation-space shape is printed after the base environment is
    created and again after each wrapper is applied.

    Args:
        env_name: Environment id passed to ``gym.make``.

    Returns:
        The fully wrapped environment.
    """
    # (report template, wrapper constructor), applied from top to bottom.
    stages = (
        ("MaxAndSkipEnv        : {}", MaxAndSkipEnv),
        ("FireResetEnv         : {}", FireResetEnv),
        ("ProcessFrame84       : {}", ProcessFrame84),
        ("ImageToPyTorch       : {}", ImageToPyTorch),
        ("BufferWrapper        : {}", lambda inner: BufferWrapper(inner, 4)),
        ("ScaledFloatFrame     : {}", ScaledFloatFrame),
    )

    env = gym.make(env_name)
    print("Standard Env.        : {}".format(env.observation_space.shape))
    for template, wrap in stages:
        env = wrap(env)
        print(template.format(env.observation_space.shape))

    return env

def print_env_info(name, env):
    """Reset ``env`` and print the shape, dtype, value range and contents of its first observation.

    Args:
        name: Label used in the printed header.
        env: Environment whose ``reset()`` returns ``(obs, info)`` (gymnasium API).
    """
    # gymnasium's reset returns (observation, info); only the observation is shown.
    obs, _info = env.reset()
    print("*** {} Environment ***".format(name))
    print("Observation shape: {}, type: {} and range [{},{}]".format(obs.shape, obs.dtype, np.min(obs), np.max(obs)))
    print("Observation sample:\n{}".format(obs))

# Try out the environment: build the wrapped Pong env and show its first observation.
ENV_NAME = "PongNoFrameskip-v4"
env = make_env(env_name=ENV_NAME)
print_env_info(name="Wrapped", env=env)


## DQN

In [None]:
class DQN(nn.Module):
    """Convolutional Q-network: three conv layers followed by two linear layers.

    Args:
        input_shape: Shape of one observation; its first entry is used as the
            number of input channels.
        output_shape: Number of outputs of the final linear layer.
    """

    def __init__(self, input_shape, output_shape):
        super().__init__()
        in_channels = input_shape[0]
        feature_layers = [
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.net = nn.Sequential(*feature_layers)

        # Push one dummy observation through the conv stack to learn the
        # flattened feature size needed by the first linear layer.
        with torch.no_grad():
            probe = torch.zeros(1, *input_shape)
            n_flatten = self.net(probe).shape[1]

        head_layers = [
            nn.Linear(n_flatten, 512),
            nn.ReLU(),
            nn.Linear(512, output_shape),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the output of the linear head applied to the conv features of ``x``."""
        features = self.net(x)
        return self.fc(features)
    
# Smoke test: build the network for the wrapped env's observation and action spaces.
env = make_env(ENV_NAME)
model = DQN(input_shape=env.observation_space.shape, output_shape=env.action_space.n)
model = model.to(device)
print(model)

In [None]:
# Experience Buffer and Agent 
class ExperienceBuffer:
    """Fixed-capacity replay memory of (state, action, reward, next_state, done) tuples.

    Backed by a ``deque(maxlen=capacity)``, so the oldest transition is dropped
    once the capacity is exceeded.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = collections.deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition to the buffer."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions and stack them field by field.

        Returns:
            tuple: ``(state, action, reward, next_state, done)`` numpy arrays,
            each with ``batch_size`` entries along axis 0.
        """
        picked = random.sample(self.buffer, batch_size)
        # Transpose the list of transitions into one column per field.
        columns = zip(*picked)
        state, action, reward, next_state, done = (np.stack(column) for column in columns)
        return state, action, reward, next_state, done

    def __len__(self):
        """Return the number of stored transitions."""
        return len(self.buffer)

class Agent:
    """Plays an environment one step at a time and records transitions in a replay buffer.

    Args:
        env: Environment following the gymnasium API (``reset`` returns
            ``(obs, info)``; ``step`` returns a 5-tuple).
        exp_buffer: Object with a ``push(state, action, reward, next_state, done)`` method.
        device: Torch device used for inference. When ``None`` the device of the
            network passed to ``play_step`` is used.
    """

    def __init__(self, env, exp_buffer, device=None):
        self.env = env
        self.exp_buffer = exp_buffer
        self.device = device
        # Filled by ``reset``; ``play_step`` resets lazily when this is still None.
        self.state = None
        self.total_reward = 0.0

    def reset(self):
        """Start a new episode and clear the running episode reward."""
        self.state, _info = self.env.reset()
        self.total_reward = 0.0

    def play_step(self, net, epsilon=0.0):
        """Play one environment step with an epsilon-greedy policy.

        With probability ``epsilon`` a random action is sampled from the env's
        action space; otherwise the action with the highest value predicted by
        ``net`` is taken. The transition is pushed to the replay buffer.

        Args:
            net: Network mapping a batch of states to per-action values.
            epsilon: Probability of taking a random action.

        Returns:
            The total reward of the episode if this step ended it, else ``None``.
        """
        if self.state is None:
            self.reset()

        if random.random() < epsilon:
            action = int(self.env.action_space.sample())
        else:
            device = self.device
            if device is None:
                device = next(net.parameters()).device
            state_t = torch.as_tensor(
                np.asarray([self.state]), dtype=torch.float32, device=device
            )
            # Action selection needs no gradients.
            with torch.no_grad():
                q_values = net(state_t)
            action = int(torch.argmax(q_values, dim=1).item())

        next_state, reward, terminated, truncated, _info = self.env.step(action)
        done = terminated or truncated
        self.total_reward += reward

        self.exp_buffer.push(self.state, action, reward, next_state, done)
        self.state = next_state

        if not done:
            return None

        episode_reward = self.total_reward
        self.reset()
        return episode_reward

# Smoke test: wire a small replay buffer to an agent.
# Agent takes (env, exp_buffer, device); the network is passed later to play_step.
exp_buffer = ExperienceBuffer(10000)
agent = Agent(env, exp_buffer, device)

In [None]:
import typing as tt
def batch_to_tensors(batch: tt.Sequence[tt.Any], device: torch.device) -> tt.Tuple[torch.Tensor, ...]:
    """
    Convert a sampled batch to a tuple of tensors on ``device``.

    Two batch layouts are accepted:

    * the tuple returned by ``ExperienceBuffer.sample``:
      ``(states, actions, rewards, next_states, dones)`` arrays;
    * a sequence of record objects exposing ``state``, ``action``, ``reward``,
      ``done_trunc`` and ``new_state`` attributes.

    Args:
        batch: The sampled batch in one of the layouts above.
        device (torch.device): The device to which the tensors will be moved.

    Returns:
        A tuple ``(states, actions, rewards, dones, new_states)`` of tensors with
        dtypes float32, int64, float32, bool and float32 respectively. Note that
        ``dones`` comes before ``new_states`` in the result.
    """
    if len(batch) > 0 and hasattr(batch[0], "new_state"):
        # Sequence of per-transition records.
        states = [e.state for e in batch]
        actions = [e.action for e in batch]
        rewards = [e.reward for e in batch]
        dones = [e.done_trunc for e in batch]
        new_states = [e.new_state for e in batch]
    else:
        # Field-wise arrays in ExperienceBuffer.sample order.
        states, actions, rewards, new_states, dones = batch

    states_t = torch.as_tensor(np.asarray(states), dtype=torch.float32, device=device)
    actions_t = torch.as_tensor(np.asarray(actions), dtype=torch.int64, device=device)
    rewards_t = torch.as_tensor(np.asarray(rewards), dtype=torch.float32, device=device)
    dones_t = torch.as_tensor(np.asarray(dones), dtype=torch.bool, device=device)
    new_states_t = torch.as_tensor(np.asarray(new_states), dtype=torch.float32, device=device)

    return states_t, actions_t, rewards_t, dones_t, new_states_t


def calc_loss(batch: tt.Sequence[tt.Any], net: nn.Module, tgt_net: nn.Module, device: torch.device, gamma: float = 0.99) -> torch.Tensor:
    """
    Calculate the one-step TD loss for a batch of experiences.

    The target for each transition is ``reward + gamma * max_a Q_tgt(next_state, a)``,
    with the bootstrap term set to zero for transitions flagged as done.

    Args:
        batch: A sampled batch in a layout accepted by ``batch_to_tensors``.
        net (nn.Module): The current Q-network (receives gradients).
        tgt_net (nn.Module): The target Q-network (evaluated without gradients).
        device (torch.device): The device to which the tensors will be moved.
        gamma (float): Discount factor applied to the next-state value.

    Returns:
        torch.Tensor: The mean squared error between predicted and target values.
    """
    states_t, actions_t, rewards_t, dones_t, new_states_t = batch_to_tensors(batch, device)

    # Q(s, a) for the action actually taken in each transition.
    state_action_values = net(states_t).gather(1, actions_t.unsqueeze(-1)).squeeze(-1)
    with torch.no_grad():
        next_state_values = tgt_net(new_states_t).max(1)[0]
        # No bootstrapping past the end of an episode.
        next_state_values[dones_t] = 0.0

    expected_state_action_values = next_state_values * gamma + rewards_t

    return nn.MSELoss()(state_action_values, expected_state_action_values)

def plot_rewards(rewards, ma_window=100):
    """Plot per-episode rewards and, once enough data exists, their moving average.

    Args:
        rewards: Sequence of episode rewards.
        ma_window: Window length of the moving average; the average is drawn
            only when at least this many rewards are available.
    """
    plt.figure(figsize=(12, 8))
    plt.title("Rewards")
    plt.xlabel("Episode")
    plt.ylabel("Reward")
    plt.plot(rewards, label='Rewards')

    n_points = len(rewards)
    if n_points >= ma_window:
        kernel = np.ones(ma_window) / ma_window
        smoothed = np.convolve(rewards, kernel, mode='valid')
        # Each 'valid' output is plotted at the index of the last episode in its window.
        episode_axis = range(ma_window - 1, n_points)
        plt.plot(
            episode_axis,
            smoothed,
            label='Moving Average (window={})'.format(ma_window),
            color='orange',
        )

    plt.legend()
    plt.show()

In [None]:
# Hyper-parameters for the DQN run.

# Environment and stopping criterion.
ENV_NAME = "PongNoFrameskip-v4"
MEAN_REWARD_BOUND = 19
NUM_EPISODES = 5000

# Optimisation.
GAMMA = 0.99
LEARNING_RATE = 1e-4
BATCH_SIZE = 32

# Replay memory and target-network synchronisation.
REPLAY_SIZE = 10_000
REPLAY_START_SIZE = 10_000
SYNC_TARGET_FRAMES = 1_000

# Epsilon-greedy exploration schedule.
EPSILON_START = 1.0
EPSILON_FINAL = 0.01
EPSILON_DECAY_LAST_FRAME = 150_000

print(f"Training DQN on {ENV_NAME} environment")
print(f"Batch size: {BATCH_SIZE}, Replay size: {REPLAY_SIZE}, Learning rate: {LEARNING_RATE}")
print(f"Sync target frames: {SYNC_TARGET_FRAMES}")
print(f"Epsilon start: {EPSILON_START}, Epsilon final: {EPSILON_FINAL}")

## Training

In [None]:
def train_dqn():
    """Train a DQN agent on ``ENV_NAME`` and return the list of episode rewards.

    The loop advances one environment frame per iteration and stops when either
    ``NUM_EPISODES`` episodes have finished or the mean reward over the last 100
    episodes exceeds ``MEAN_REWARD_BOUND``. Gradient updates begin once the
    replay buffer holds ``REPLAY_START_SIZE`` transitions; the target network is
    re-synchronised every ``SYNC_TARGET_FRAMES`` frames. The current network is
    saved under ``models/`` whenever the 100-episode mean reward improves.

    Returns:
        list: Total reward of every finished episode, in order.
    """
    env = make_env(ENV_NAME)
    net = DQN(env.observation_space.shape, env.action_space.n).to(device)
    tgt_net = DQN(env.observation_space.shape, env.action_space.n).to(device)
    # Both networks must start from the same weights.
    tgt_net.load_state_dict(net.state_dict())
    print(net)

    buffer = ExperienceBuffer(REPLAY_SIZE)
    agent = Agent(env, buffer, device)
    agent.reset()
    epsilon = EPSILON_START

    optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
    total_rewards = []
    frame_idx = 0
    ts_frame = 0
    ts = time.time()
    best_m_reward = None

    model_dir = "models"
    model_name = os.path.join(model_dir, ENV_NAME + "_DQN.pth")

    # Metrics are sent to Weights & Biases only when the user has imported
    # `wandb` in the notebook and started a run; otherwise logging is skipped.
    tracker = globals().get("wandb")
    if tracker is not None and getattr(tracker, "run", None) is None:
        tracker = None

    print(">>> Training starts at ",datetime.datetime.now())

    # One iteration == one environment frame; episodes are counted via total_rewards.
    while len(total_rewards) < NUM_EPISODES:
        frame_idx += 1
        epsilon = max(EPSILON_FINAL, EPSILON_START - frame_idx / EPSILON_DECAY_LAST_FRAME)

        reward = agent.play_step(net, epsilon)
        if reward is not None:
            total_rewards.append(reward)
            elapsed = time.time() - ts
            speed = (frame_idx - ts_frame) / elapsed if elapsed > 0 else float("inf")
            ts_frame = frame_idx
            ts = time.time()
            m_reward = np.mean(total_rewards[-100:])
            print(f"{frame_idx}: done {len(total_rewards)} games, reward {m_reward:.3f}, eps {epsilon:.2f}, speed {speed:.2f} f/s")
            if tracker is not None:
                tracker.log({"epsilon": epsilon, "speed": speed, "reward_100": m_reward, "reward": reward}, step=frame_idx)

            if best_m_reward is None or best_m_reward < m_reward:
                if best_m_reward is not None:
                    print(f"Best reward updated {best_m_reward:.3f} -> {m_reward:.3f}")
                    print(f"Saving model '{model_name}'")
                    os.makedirs(model_dir, exist_ok=True)
                    torch.save(net.state_dict(), model_name)
                best_m_reward = m_reward

            # Refresh the reward plot every 50 finished episodes.
            if len(total_rewards) % 50 == 0:
                plot_rewards(total_rewards)

            if m_reward > MEAN_REWARD_BOUND:
                print("Solved in %d frames!" % frame_idx)
                break

        # Collect experience only until the buffer is warm.
        if len(buffer) < REPLAY_START_SIZE:
            continue
        if frame_idx % SYNC_TARGET_FRAMES == 0:
            tgt_net.load_state_dict(net.state_dict())

        optimizer.zero_grad()
        batch = buffer.sample(BATCH_SIZE)
        loss_t = calc_loss(batch, net, tgt_net, device, GAMMA)
        loss_t.backward()
        optimizer.step()

    print(">>> Training ends at ",datetime.datetime.now())

    return total_rewards

# Start training
dqn_rewards = train_dqn()

# REINFORCE