In [None]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from collections import deque, namedtuple
import matplotlib.pyplot as plt
import os

In [None]:
# Hyperparameters and run configuration
env_name = 'LunarLander-v3'
seed = 42

# Exploration and discounting
epsilon = 0.3  # probability of taking a uniformly random action
gamma = 0.99   # discount applied to the bootstrapped next-state value

# Optimisation
lr = 1e-2                   # learning rate handed to the Adam optimiser
mini_batch_size = 128       # transitions drawn from the replay buffer per update
buffer_size_limit = 10_000  # maximum number of transitions kept in the replay buffer

# Schedules, counted in environment steps
steps_until_value_iteration = 500     # steps between gradient updates of the main network
steps_until_target_net_update = 2000  # steps between target-network weight syncs

In [None]:
# Environment and seeds
env = gym.make(env_name)

# Seed every random number generator this notebook draws from, so that
# Restart Kernel -> Run All reproduces the same run.
env.reset(seed=seed)     # environment's own RNG
random.seed(seed)        # stdlib RNG: random.sample() picks the replay mini-batches
np.random.seed(seed)     # NumPy RNG: drives the epsilon-greedy exploration draws
torch.manual_seed(seed)  # PyTorch RNG: network weight initialisation

In [None]:
# Q-network architecture (used for both the main and the target network)
observation_size = env.observation_space.shape[0]
action_size = env.action_space.n


def make_mlp():
    '''
    Build a freshly initialised MLP that maps an observation vector of
    length observation_size to action_size outputs (one per action).

    returns: nn.Sequential with two hidden ReLU layers of width 64
    '''
    hidden_size = 64
    layers = [
        nn.Linear(observation_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, action_size),
    ]
    return nn.Sequential(*layers)

main_net = make_mlp()
target_net = make_mlp()

# target_net is never trained directly; it only receives copies of
# main_net's weights. Freeze its parameters so that a backward pass through
# a value computed from it never produces gradients for it (the optimiser
# only manages, and only zeroes, main_net's parameters).
target_net.requires_grad_(False)


def update_target_net():
    '''Overwrite target_net's weights with a copy of main_net's current weights.'''
    target_net.load_state_dict(main_net.state_dict())


update_target_net() # They are same from the start

# Only main_net's parameters are optimised.
optimiser = optim.Adam(main_net.parameters(), lr=lr)

In [None]:
# Implement epsilon-greedy policy
def get_action(observation):
    '''
    Choose an action with an epsilon-greedy policy over main_net's outputs.

    observation: numpy array returned by env.step() / env.reset()
    returns: action as a plain Python int
    '''
    # Explore: with probability epsilon pick uniformly among all actions.
    if np.random.random() < epsilon:
        # np.random.choice returns a NumPy integer; cast it so both branches
        # hand back the same type.
        return int(np.random.choice(action_size))

    # Exploit: take the action whose predicted value is largest.
    observation = torch.as_tensor(observation, dtype=torch.float32)
    with torch.no_grad():  # action selection needs no gradients
        q_star_per_each_action = main_net(observation)
    return torch.argmax(q_star_per_each_action).item()

In [None]:
# Replay buffer

# One environment transition. Field types:
#   state / next_state: numpy array
#   action: int
#   reward: float
#   done: bool
Timestep = namedtuple('timestep', ["state", "action", "reward", "next_state", "done"])

# A sampled batch of transitions: one field per Timestep field.
Mini_batch = namedtuple('mini_batch', ["states", "actions", "rewards", "next_states", "dones"])

# Bounded FIFO: once maxlen entries are stored, appending discards the oldest.
replay_buffer = deque(maxlen=buffer_size_limit)


def record_timestep(timestep):
    '''Append one Timestep to the replay buffer.'''
    replay_buffer.append(timestep)


def sample_a_mini_batch():
    '''
    Draw mini_batch_size transitions from the replay buffer without
    replacement and convert them, field by field, into tensors.

    returns: Mini_batch of tensors with the batch along the first axis:
             states / next_states float32 (per-transition arrays stacked),
             actions int64, rewards float32, dones bool
    '''
    transitions = random.sample(replay_buffer, mini_batch_size)
    columns = tuple(zip(*transitions))  # one tuple per Timestep field

    def _stacked(arrays):
        # Stack into a single ndarray first: building a tensor directly from a
        # list of ndarrays is extremely slow.
        return torch.tensor(np.array(arrays), dtype=torch.float32)

    return Mini_batch(
        states=_stacked(columns[0]),
        actions=torch.tensor(columns[1], dtype=torch.int64),
        rewards=torch.tensor(columns[2], dtype=torch.float32),
        next_states=_stacked(columns[3]),
        dones=torch.tensor(columns[4], dtype=torch.bool),
    )

In [None]:
# Use mini_batch to get loss
def compute_loss(mini_batch):
    '''
    Mean-squared TD error of main_net on a mini-batch.

    mini_batch: Mini_batch of tensors, as returned by sample_a_mini_batch()
    returns: scalar loss tensor; gradients flow only into main_net
    '''
    # Compute targets. They are constants for the regression, so build them
    # without autograd: otherwise loss.backward() would also backpropagate
    # through target_net.
    with torch.no_grad():
        v_star_of_next_states = target_net(mini_batch.next_states).max(dim=1).values
        # No bootstrapping from next states flagged as done.
        v_star_of_next_states = v_star_of_next_states * (~mini_batch.dones)
        y = mini_batch.rewards + gamma * v_star_of_next_states

    # Compute main_net's predictions for the actions that were actually taken
    q_values = main_net(mini_batch.states)
    predictions = q_values.gather(1, mini_batch.actions.unsqueeze(1)).squeeze(1)

    # Loss
    return nn.functional.mse_loss(predictions, y)

In [None]:
# Training loop
observation, _ = env.reset()

episode_return = 0.0  # sum of rewards collected in the current episode
episode_returns = []  # return of every finished episode, in order

t = 0  # number of environment steps taken so far
while True:  # no stopping criterion: runs until the cell is interrupted
    action = get_action(observation)

    new_observation, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated
    t += 1
    episode_return += reward

    # Store `terminated` (not `done`) as the transition's done flag: the flag
    # switches off bootstrapping in compute_loss, which is only correct when
    # the episode really ended. A time-limit truncation still has a
    # meaningful next-state value.
    timestep = Timestep(observation, action, reward, new_observation, terminated)
    record_timestep(timestep)

    observation = new_observation
    if done:
        episode_returns.append(episode_return)
        episode_return = 0.0
        observation, _ = env.reset()

    # Do weights update if its time to (and the buffer can fill a mini-batch;
    # random.sample raises if asked for more items than are stored)
    if t % steps_until_value_iteration == 0 and len(replay_buffer) >= mini_batch_size:
        mini_batch = sample_a_mini_batch()
        loss = compute_loss(mini_batch)

        optimiser.zero_grad()
        loss.backward()
        optimiser.step()

    # Update target net if its time to
    if t % steps_until_target_net_update == 0:
        update_target_net()