In [18]:
import matplotlib.pyplot as plt
import gymnasium as gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import ale_py
import random
import glob
import imageio
import cv2 as cv
import os
import time
from collections import deque

In [19]:
# Shared Pong environment (default render mode); the evaluation helper and the
# training loop further down both step this same instance.
env = gym.make("ALE/Pong-v5")

In [20]:
# Policy output index -> ALE action id. Only three of Pong's six actions are
# exposed to the agent: 0 (NOOP), 2 (RIGHT) and 3 (LEFT); the keyboard-mapping
# note later in this notebook describes 2 as "move up" and 3 as "move down".
action_map = dict(enumerate((0, 2, 3)))


In [21]:
def evaluate_average_return(Q, n_episodes, epsilon):
    """Play ``n_episodes`` with a fixed epsilon-greedy policy and return the mean return.

    The function reads several notebook globals at call time: ``env``,
    ``device``, ``action_map`` and a ``get_observation(frame)`` helper.
    NOTE(review): ``get_observation`` is not defined in the visible part of
    this notebook and ``device`` is only assigned in a later cell, so both
    must exist before this function is called.

    Parameters
    ----------
    Q : callable
        Maps the stacked state tensor to one score per key of ``action_map``;
        the greedy action is the argmax of that output.
    n_episodes : int
        Number of episodes to play.
    epsilon : float
        Probability of picking a uniformly random action at each step. It is
        held constant for the whole evaluation: the previous version applied
        the training decay ``max(epsilon_min, epsilon - epsilon_decay)`` here,
        which silently raised ``epsilon=0`` to ``epsilon_min`` and made a
        purely greedy evaluation impossible.

    Returns
    -------
    float
        Mean of the per-episode sums of rewards collected in the action loop
        (rewards from the initial NOOP warm-up steps are not counted).
        NaN when ``n_episodes`` is 0 or negative.
    """

    def _encode(frame):
        # Frame -> float tensor on the configured device.
        return torch.tensor(get_observation(frame)).float().to(device)

    returns = []
    for _episode in range(n_episodes):

        total_reward = 0

        # Skip first frame (different color)
        env.reset()
        env.step(0)

        # The initial state is 4 consecutive NOOP observations concatenated
        # along dim 0.
        warmup = []
        for _ in range(4):
            frame, _reward, _terminated, _truncated, _info = env.step(0)
            warmup.append(_encode(frame))
        state = torch.cat(warmup)

        done = False
        while not done:

            # Fixed-epsilon action selection (no decay during evaluation).
            if random.random() < epsilon:
                action = np.random.choice(list(action_map.keys()))
            else:
                with torch.no_grad():
                    action = torch.argmax(Q(state)).item()

            frame, reward, terminated, truncated, _info = env.step(action_map[action])
            total_reward += reward

            # Slide the window: drop the oldest observation, append the newest.
            # The shift equals the observation's own length, so this matches
            # the former hard-coded [:18] / [6:] / [18:] slicing for
            # 6-entry observations.
            newest = _encode(frame)
            state = torch.cat((state[newest.shape[0]:], newest))

            done = terminated or truncated

        returns.append(total_reward)

    # Avoid np.mean([]) (RuntimeWarning + NaN) when no episode was played.
    if not returns:
        return float("nan")
    return np.mean(returns)

In [22]:
# checkpoints_to_evaluate = [1750, 1800, 1850]
# av_returns = {}

# for checkpoint in checkpoints_to_evaluate:
    
#     training_vars = load_checkpoint(Q, Q_optimizer, Buffer, f"checkpoints/training3/{checkpoint}.pth")
#     av_return = evaluate_average_return(Q, 10, 0)
#     av_returns[checkpoint] = float(av_return)
#     print(f"{checkpoint} : {float(av_return)}")

# print(av_returns)

In [23]:
# # Count action taken

# n_episodes = 10
# epsilon = 0
# actions = np.zeros(6)
# for episode in range(n_episodes):

#     # Skip first frame (different color)
#     frame, _ = env.reset()
#     _ = env.step(0)

#     frame, reward, terminated, truncated, info = env.step(0)
#     observation = get_observation(frame)
#     state = torch.tensor(observation).float().to(device)

#     # Wait to have 4 frames as a first full sequence
#     for _ in range(3):
#         frame, reward, terminated, truncated, info = env.step(0)
#         observation = get_observation(frame)
#         state = torch.cat((state, torch.tensor(observation).float().to(device)))

#     done = False
#     while not done:

#         if np.random.rand() < epsilon:
#             action = env.action_space.sample()
#         else:
#             action = torch.argmax(Q(state)).item()

#         actions[action] += 1
#         frame, reward, terminated, truncated, info = env.step(action)
#         observation = get_observation(frame)

#         next_state = state.clone()
#         next_state[:18] = next_state[6:].clone()
#         next_state[18:] = torch.tensor(observation).float().to(device)

#         done = terminated or truncated
#         if done:
#             break

#         state = next_state

# print(actions / np.sum(actions))

In [24]:
# # --- Pygame init ---
# pygame.init()
# screen = pygame.display.set_mode((400, 300))
# pygame.display.set_caption("Play Pong with Keyboard")
# clock = pygame.time.Clock()

# # --- Gym init ---
# env = gym.make("ALE/Pong-v5", render_mode="human")
# obs, info = env.reset()

# # Mapping: keys -> actions
# # 0: NOOP, 1: FIRE, 2: RIGHT, 3: LEFT, 4: RIGHTFIRE, 5: LEFTFIRE
# key_action_map = {
#     pygame.K_UP: 2,     # Move up (RIGHT in Pong's terms)
#     pygame.K_DOWN: 3,   # Move down (LEFT in Pong's terms)
#     pygame.K_SPACE: 1,  # Fire (start the game)
# }

# done = False
# while True:
#     action = 0  # Default NOOP

#     # --- Handle events ---
#     for event in pygame.event.get():
#         if event.type == pygame.QUIT:
#             env.close()
#             pygame.quit()
#             raise SystemExit
        
#         if event.type == pygame.KEYDOWN and event.key == pygame.K_ESCAPE:
#             env.close()
#             pygame.quit()
#             raise SystemExit

#     # Get pressed keys
#     keys = pygame.key.get_pressed()
#     for key, mapped_action in key_action_map.items():
#         if keys[key]:
#             action = mapped_action

#     # Step the environment
#     obs, reward, terminated, truncated, info = env.step(action)
#     done = terminated or truncated

#     if done:
#         obs, info = env.reset()

#     clock.tick(60)  # Limit loop to 60 FPS

In [25]:
# env = gym.make("ALE/Pong-v5", render_mode="human")

# frame, info = env.reset()

# observation = get_observation(frame)
# state = torch.tensor(observation).float().to(device)

# # Wait to have 4 first frames as a first full sequence
# for _ in range(3):
#     frame, reward, terminated, truncated, info = env.step(0)
#     observation = get_observation(frame)
#     state = torch.cat((state, torch.tensor(observation).float().to(device)))

# for _ in range(10000):

#     action = torch.argmax(Q(state)).item()
#     frame, reward, terminated, truncated, info = env.step(action)
#     observation = get_observation(frame)

#     action = env.action_space.sample()  
#     observation, reward, terminated, truncated, info = env.step(action)

#     next_state = state.clone()
#     next_state[:18] = next_state[6:].clone()
#     next_state[18:] = torch.tensor(observation).float().to(device)


#     if terminated or truncated:
#             break

#     state = next_state

#     if terminated or truncated:
#         break
# env.close()

In [26]:
# unique_colors = {tuple(map(int, pixel)) for pixel in state.reshape(-1, 3)}
# print(unique_colors)

In [27]:
# states, actions, rewards, next_states, terminateds, truncateds = Buffer.sample(32)

Inputs :
- Can see the ball (bool)
- Can see the adversary (bool)
- x,y of the ball
- y of both paddles

Reward :
- -1 if lost a point
- +1 if won a point

In [28]:
# from utils import get_state

# frame, _ = env.reset()
# ball_position = (0, np.array([0, 0]))
# frames = []

# for i in range(20):
#     frame, reward, terminated, truncated, info = env.step(env.action_space.sample())
#     ball_position, state = get_state(frame, ball_position)
#     frames.append(frame)

In [29]:
# import matplotlib.pyplot as plt

# fig, ax = plt.subplots(1,2)
# ax[0].imshow(frames[-2])
# ax[1].imshow(frames[-1])
# state

In [30]:
# from importlib import reload
# import utils

# # Force reload
# reload(utils)

In [31]:
# from utils import generate_evaluation_states
# ev_states = generate_evaluation_states(env, device, 100)
# torch.save(ev_states, "ev_states.pt")

# 1. DQN

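Train a Deep Q-Network (DQN) on Pong using the low-dimensional state described above (9 input features, 3 actions), with a replay buffer, a periodically synchronised target network and ε-greedy exploration.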

In [32]:
# Resume training from a checkpoint, or start from scratch
resume_training = False
# Checkpoint file loaded when resume_training is True
checkpoint = "training/dqn/training3/10.pth"

# Wall-clock budget for the training loop, in hours (checked once per episode)
max_training_time = 7 # hours

In [33]:
# Allocate a fresh run directory training/dqn/training<N>, where N is one more
# than the highest index already present (1 if there is none).
# Only entries named exactly "training<digits>" are considered, so stray files
# or folders under training/dqn/ no longer make int() raise.
training_numbers = [
    int(os.path.basename(folder)[len("training"):])
    for folder in glob.glob("training/dqn/training*")
    if os.path.basename(folder)[len("training"):].isdecimal()
]
training_number = max(training_numbers) + 1 if len(training_numbers) > 0 else 1
# makedirs also creates the missing "training/dqn" parents on a first run;
# it still raises if the target run directory already exists.
os.makedirs(f"training/dqn/training{training_number}")

In [34]:
from dqn import QNetwork, ReplayBuffer, Update_Q

device = "cpu"

# Online network and its target copy: same constructor arguments, and the
# target starts with the online network's weights.
Q = QNetwork(input_dim=9, output_dim=3).to(device)
Q_target = QNetwork(input_dim=9, output_dim=3).to(device)
Q_target.load_state_dict(Q.state_dict())

Q_optimizer = torch.optim.Adam(Q.parameters(), lr=1e-4)
Buffer = ReplayBuffer()

# States passed to get_avg_Qvalues after every training episode.
ev_states = torch.load("ev_states.pt")

# --- Hyperparameters ---

# Passed to Update_Q together with the replay buffer and both networks.
gamma = 0.99
batch_size = 32

# Exploration: epsilon is lowered by epsilon_decay at every environment step
# of the training loop and never goes below epsilon_min.
epsilon = 1
epsilon_min = 0.1
epsilon_decay = 1e-5

# Loop bounds: number of episodes and maximum steps per episode.
max_episode = 3000
max_time_steps = 10000

# Schedules: Q update every `update_frequency` steps of an episode, target
# sync every `target_update_frequency` total steps, checkpoint + video every
# `checkpoint_frequency` episodes.
update_frequency = 1
target_update_frequency = 1000
checkpoint_frequency = 50

# Set to True by the training loop the first time an episode reaches 21 points.
ever_won = False

In [35]:
from dqn import load_checkpoint

if not resume_training:
    # Fresh run: start at episode 0 with empty histories; epsilon keeps the
    # value set in the hyperparameter cell.
    episode_start = 0
    returns, avg_Qvalues, td_losses = [], [], []
else:
    # Resume: load_checkpoint receives Q, its optimizer and the replay buffer
    # and returns the logged histories, the episode to restart from and epsilon.
    training_vars = load_checkpoint(Q, Q_optimizer, Buffer, checkpoint)
    # Re-align the target network with the restored online network.
    Q_target.load_state_dict(Q.state_dict())
    (returns, avg_Qvalues, td_losses, episode_start, epsilon) = training_vars

In [None]:
from utils import get_state, generate_video
from dqn import save_checkpoint, get_avg_Qvalues

# DQN training loop.
# Per step: epsilon-greedy action -> env step -> store transition -> (once the
# buffer holds more than 1000 transitions) one Update_Q call -> periodic
# target-network sync.
# Per episode: log return and average Q-value, periodic checkpoint + video,
# and a wall-clock budget check.

start_time = time.time()
tot_training_steps = 0
# Value threaded through get_state(): passed in and replaced on every call.
# NOTE(review): it is initialised once and carried across episodes (not reset
# after env.reset()) — confirm get_state tolerates a stale value.
ball_position = (0, np.array([0, 0]))

for episode in range(episode_start, max_episode):

    total_reward = 0
    points_scored = 0

    # Skip first frame (different color)
    frame, _ = env.reset()
    _ = env.step(0)

    frame, reward, terminated, truncated, info = env.step(0)
    ball_position, state = get_state(frame, ball_position)
    state = torch.tensor(state).float().to(device)

    for t in range(max_time_steps):

        with torch.no_grad():
            # Decaying epsilon (linear, floored at epsilon_min)
            epsilon = max(epsilon_min, epsilon - epsilon_decay)
            # epsilon-greedy action selection over the keys of action_map
            if random.random() < epsilon:
                action = np.random.choice(list(action_map.keys()))
            else:
                action = torch.argmax(Q(state)).item()

        # action_map translates the policy index into the environment action id
        frame, reward, terminated, truncated, info = env.step(action_map[action])

        ball_position, next_state = get_state(frame, ball_position)
        next_state = torch.tensor(next_state).float().to(device)

        total_reward += reward
        if reward == 1:
            points_scored += 1

        Buffer.put([state, int(action), reward, next_state, terminated, truncated])

        # Start learning only once the buffer holds more than 1000 transitions
        if Buffer.size() > 1000 and t % update_frequency == 0:
            td_loss = Update_Q(Buffer, Q, Q_target, Q_optimizer, batch_size, gamma)
            td_losses.append(td_loss)

        # Target sync is counted in total environment steps of this run
        tot_training_steps += 1
        if tot_training_steps % target_update_frequency == 0:
            Q_target.load_state_dict(Q.state_dict())

        if terminated or truncated:
            break

        state = next_state

    # print('episode: {}, reward: {:.1f}'.format(episode, total_reward))
    returns.append(total_reward)
    avg_Qvalues.append(get_avg_Qvalues(Q, ev_states))

    if (episode + 1) % 30 == 0:
        print(f"{episode+1} episodes done. Average reward on last 30 ep. : {np.mean(returns[-30:])}")

    if points_scored == 21 and not ever_won:
        ever_won = True
        print(f"First win ! (Episode {episode})")

    # Training checkpoint
    checkpoint_saved = False
    if (episode + 1) % checkpoint_frequency == 0:
        save_checkpoint(Q, Q_optimizer, Buffer, returns, avg_Qvalues, td_losses, episode, epsilon, f"training/dqn/training{training_number}/{episode+1}.pth")
        generate_video(env, Q, device, action_map, epsilon=0, n_episodes=1, filename=f"training/dqn/training{training_number}/{episode+1}.mp4")
        checkpoint_saved = True

    if time.time() - start_time > 3600 * max_training_time:
        # Persist progress before stopping; otherwise everything since the last
        # periodic checkpoint (up to checkpoint_frequency - 1 episodes) is lost.
        if not checkpoint_saved:
            save_checkpoint(Q, Q_optimizer, Buffer, returns, avg_Qvalues, td_losses, episode, epsilon, f"training/dqn/training{training_number}/{episode+1}.pth")
        # `episode` is a 0-based index, so episode + 1 episodes have been run
        # (same convention as the progress message above).
        print(f"Maximum training time of {max_training_time}h exceeded. Interrupting training after {episode+1} episodes.")
        break

env.close()

episode: 0, reward: -20.0
episode: 1, reward: -20.0


KeyboardInterrupt: 

In [None]:
# 100-episode moving average of the episode returns.
# The point at x = i is the mean of returns[i-100:i]. The range is inclusive of
# len(returns) so the most recent window is plotted too; previously it was
# dropped, and a run of exactly 100 episodes produced an empty plot.
avg_returns = [np.mean(returns[i-100:i]) for i in range(100, len(returns) + 1)]
plt.plot(range(100, len(returns) + 1), avg_returns)
plt.title("Average return per episode (100 last episodes)")
plt.xlabel("Episodes")
plt.ylabel("Average Return")

In [None]:
# One point per training episode: the value appended to avg_Qvalues by
# get_avg_Qvalues(Q, ev_states) at the end of that episode.
plt.plot(
    range(len(avg_Qvalues)),
    avg_Qvalues,
)
plt.xlabel("Episodes")
plt.title("Average Q_value")
plt.ylabel("Average Q_value")

In [None]:
# One point per Update_Q call: the loss value it returned during training.
plt.plot(
    range(len(td_losses)),
    td_losses,
)
plt.xlabel("Timesteps")
plt.title("TD Loss")
plt.ylabel("TD Loss")