In [2]:
# For interactive plotting 
%matplotlib qt



In [3]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque, namedtuple
import matplotlib.pyplot as plt
import os

In [4]:
# Parameters
env_name = 'LunarLander-v3'
seed = 42
main_net_path = 'main_net.pth'

mini_batch_size = 128
buffer_size_limit = 10000
steps_until_value_iteration = 500
steps_until_target_net_update = 2000

gamma = 0.99
epsilon_start = 1.0
epsilon_end = 0.1
epsilon_decay = 0.999
epsilon = epsilon_start
lr=1e-2

In [5]:
# Environment and seeds
env = gym.make(env_name, render_mode='rgb_array')

env.reset(seed=seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7be5b2340e90>

In [6]:
# Main q-network and target q-network
observation_size = env.observation_space.shape[0]
action_size = env.action_space.n

def make_mlp():
    return nn.Sequential(
        nn.Linear(observation_size, 64),
        nn.ReLU(),
        nn.Linear(64, 64),
        nn.ReLU(),
        nn.Linear(64, action_size)
    )

if os.path.exists(main_net_path):
    main_net = torch.load(main_net_path)
else:
    main_net = make_mlp()
    torch.save(main_net, main_net_path)
target_net = make_mlp()

def update_target_net():
    target_net.load_state_dict(main_net.state_dict())

update_target_net() # They are same from the start

optimiser = optim.Adam(main_net.parameters(), lr=lr)

In [7]:
# Implement expsilon-greedy policy
def get_action(observation):
    '''
    observation: numpy array returned by env.step()
    returns: integer action
    '''
    possible_actions = [i for i in range(env.action_space.n)]
    if np.random.random() < epsilon:
        return np.random.choice(possible_actions)

    observation = torch.as_tensor(observation, dtype=torch.float32)
    with torch.no_grad():
        q_star_per_each_action = main_net(observation)
        action = torch.argmax(q_star_per_each_action).item()
    return action

In [8]:
# Epsilon decay 
def decrease_epsilon():
    global epsilon
    epsilon = max(epsilon_end, epsilon * epsilon_decay)

In [9]:
# Replay buffer
replay_buffer = deque([], maxlen=buffer_size_limit)

Timestep   = namedtuple('timestep',   ["state", "action", "reward", "next_state", "done"]) 
'''
Data types of Timestep:
state/next_state is numpy array; 
action is int; 
reward is float; 
done is bool;
'''

Mini_batch = namedtuple('mini_batch', ["states", "actions", "rewards", "next_states", "dones"])


def record_timestep(timestep):
    replay_buffer.append(timestep)


def sample_a_mini_batch():
    '''
    returns: named tuple with 5 1d tensors
    '''
    mini_batch = random.sample(replay_buffer, mini_batch_size)
    mini_batch = list(zip(*mini_batch)) # Transpose

    # Convert list of ndarrays to ndarray because Creating a tensor from a list of numpy.ndarrays is extremely slow. 
    states = np.array(mini_batch[0])
    next_states = np.array(mini_batch[3])
    
    states = torch.tensor(states, dtype=torch.float32)
    actions = torch.tensor(mini_batch[1], dtype=torch.int64)
    rewards = torch.tensor(mini_batch[2], dtype=torch.float32)
    next_states = torch.tensor(next_states, dtype=torch.float32)
    dones = torch.tensor(mini_batch[4], dtype = torch.bool)
    
    return Mini_batch(states, actions, rewards, next_states, dones)

In [10]:
# Use mini_batch to get loss
def compute_loss(mini_batch):
    # Compute targets 
    v_star_of_next_states = torch.max(target_net(mini_batch.next_states), dim=1)[0]
    v_star_of_next_states = v_star_of_next_states * (~mini_batch.dones) 
    y = mini_batch.rewards + gamma * v_star_of_next_states 

    # Compute main_net's predictions
    predictions = main_net(mini_batch.states).gather(1, mini_batch.actions.unsqueeze(1)).squeeze(1)

    # Loss 
    loss = nn.functional.mse_loss(predictions, y)

    return loss

In [11]:
# Plot episode returns throughout training
episode_returns = []
smoothed_returns = []
smooth_alpha = 0.01  # Lower = smoother
current_episode_rewards = []

# Create a separate window for the plot at the start
plt.ion()
fig, ax = plt.subplots()
fig.canvas.manager.set_window_title('DQN Training Progress')
returns_line, = ax.plot([], [], label='Episode Return', alpha=0.5)
smooth_line, = ax.plot([], [], label='Smoothed Return', color='orange')
ax.set_xlabel('Episode')
ax.set_ylabel('Return')
ax.legend()
fig.show()

def log_reward_for_plotting(reward):
    global current_episode_rewards
    current_episode_rewards.append(reward)

def update_plot():
    global episode_returns, smoothed_returns, current_episode_rewards
    episode_return = sum(current_episode_rewards)
    episode_returns.append(episode_return)
    # Exponential moving average for smoothing
    if smoothed_returns:
        new_smooth = smooth_alpha * episode_return + (1 - smooth_alpha) * smoothed_returns[-1]
    else:
        new_smooth = episode_return
    smoothed_returns.append(new_smooth)
    current_episode_rewards = []

    returns_line.set_data(range(len(episode_returns)), episode_returns)
    smooth_line.set_data(range(len(smoothed_returns)), smoothed_returns)
    ax.relim()
    ax.autoscale_view()
    fig.canvas.draw()
    fig.canvas.flush_events()

qt.glx: qglx_findConfig: Failed to finding matching FBConfig for QSurfaceFormat(version 2.0, options QFlags<QSurfaceFormat::FormatOption>(), depthBufferSize -1, redBufferSize 1, greenBufferSize 1, blueBufferSize 1, alphaBufferSize -1, stencilBufferSize -1, samples -1, swapBehavior QSurfaceFormat::SingleBuffer, swapInterval 1, colorSpace QSurfaceFormat::DefaultColorSpace, profile  QSurfaceFormat::NoProfile)
No XVisualInfo for format QSurfaceFormat(version 2.0, options QFlags<QSurfaceFormat::FormatOption>(), depthBufferSize -1, redBufferSize 1, greenBufferSize 1, blueBufferSize 1, alphaBufferSize -1, stencilBufferSize -1, samples -1, swapBehavior QSurfaceFormat::SingleBuffer, swapInterval 1, colorSpace QSurfaceFormat::DefaultColorSpace, profile  QSurfaceFormat::NoProfile)
Falling back to using screens root_visual.


In [12]:
# Training loop
observation, _ = env.reset()

t = 0
while True:
    t += 1
    
    action = get_action(observation)
    
    new_observation, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated 

    timestep = Timestep(observation, action, reward, new_observation, done)
    record_timestep(timestep)

    observation = new_observation

    log_reward_for_plotting(reward)


    if done:
        observation, _ = env.reset()
        decrease_epsilon()
        update_plot()

    timesteps_passed = t + 1
    
    # Do weights update if its time to
    if timesteps_passed % steps_until_value_iteration == 0:
        mini_batch = sample_a_mini_batch()
        loss = compute_loss(mini_batch)

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

        torch.save(main_net, main_net_path)

    # Update target net if its time to
    if timesteps_passed % steps_until_target_net_update == 0:
        update_target_net()

qt.glx: qglx_findConfig: Failed to finding matching FBConfig for QSurfaceFormat(version 2.0, options QFlags<QSurfaceFormat::FormatOption>(), depthBufferSize -1, redBufferSize 1, greenBufferSize 1, blueBufferSize 1, alphaBufferSize -1, stencilBufferSize -1, samples -1, swapBehavior QSurfaceFormat::SingleBuffer, swapInterval 1, colorSpace QSurfaceFormat::DefaultColorSpace, profile  QSurfaceFormat::NoProfile)
No XVisualInfo for format QSurfaceFormat(version 2.0, options QFlags<QSurfaceFormat::FormatOption>(), depthBufferSize -1, redBufferSize 1, greenBufferSize 1, blueBufferSize 1, alphaBufferSize -1, stencilBufferSize -1, samples -1, swapBehavior QSurfaceFormat::SingleBuffer, swapInterval 1, colorSpace QSurfaceFormat::DefaultColorSpace, profile  QSurfaceFormat::NoProfile)
Falling back to using screens root_visual.
qt.glx: qglx_findConfig: Failed to finding matching FBConfig for QSurfaceFormat(version 2.0, options QFlags<QSurfaceFormat::FormatOption>(), depthBufferSize -1, redBufferSize 1

KeyboardInterrupt: 