In [14]:
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from IPython.display import Video
import imageio
import os

In [None]:
ENV_NAME = "MountainCar-v0"
PATH = f'Models/DQN/{ENV_NAME}/'
MODEL_FILE = os.path.join(PATH, ENV_NAME + ".pth")
EVAL_VIDEO_FILE = os.path.join(PATH, f'{ENV_NAME}_eval_video.mp4')
PLOT_TRAINING = False
os.makedirs(PATH, exist_ok=True)

In [16]:
env = gym.make(ENV_NAME, render_mode='rgb_array')
# Get number of actions + observations from gym action space
n_actions = env.action_space.n
state, info = env.reset()
n_observations = len(state)

In [17]:
# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

<contextlib.ExitStack at 0x178ab5040>

In [18]:
# if GPU is to be used
device = torch.device(
    "cuda" if torch.cuda.is_available() else
    "mps" if torch.backends.mps.is_available() else
    "cpu"
)
device

device(type='mps')

In [19]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


# Create a Queue that will store each S,A,S,R 
class ReplayMemory(object):

    # make a queue 
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    # Puts S,A,S,R on the queue (FIFO)
    def push(self, *args):
        self.memory.append(Transition(*args))

    # Select random entries
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    
    # get the length of our memory
    def __len__(self):
        return len(self.memory)

In [20]:
class DQN(nn.Module):

    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__() 

        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    def forward(self, x):
        out = F.relu(self.layer1(x))
        out = F.relu(self.layer2(out))
        return self.layer3(out) 


# Setup Training Varaibles + Algo

In [21]:
# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the ``AdamW`` optimizer
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4


policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000) # this must always be greater than the batch size

In [22]:
# Epsilon-greedy policy: balances exploration and exploitation
steps_done = 0

def eps_greedy(state):
    global steps_done
    sample = random.random()

    # Calculate epsilon: decreases over time to explore less
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-steps_done / EPS_DECAY)
    steps_done += 1

    if sample > eps_threshold:
        # Exploitation: choose the best action
        with torch.no_grad():
            return policy_net(state).argmax(dim=1).view(1, 1)
    else:
        # Exploration: choose a random action
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)


In [23]:
# List to store the duration of each episode
episode_durations = []

def plot_durations(show_result=False):
    # Create or select the figure for plotting
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    
    # Set the title based on whether it's the final result or ongoing training
    if show_result:
        plt.title('Result')
    else:
        plt.clf()  # Clear the current figure for live updates
        plt.title('Training...')
    
    # Label the axes and plot the episode durations
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    
    # Plot the 100-episode running average for smoother trends
    if len(durations_t) >= 100:
        # Calculate the running average using a sliding window of 100 episodes
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        # Pad the beginning with zeros to align with the episode timeline
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())
    
    # Briefly pause to allow the plot to update in real-time
    plt.pause(0.001)
    
    # Handle interactive display for Jupyter Notebooks (if applicable)
    if is_ipython:
        if not show_result:
            display.display(plt.gcf())
            display.clear_output(wait=True)  # Clear the output for live updates
        else:
            display.display(plt.gcf())  # Display the final result


In [24]:
def optimize_model():
    # Exit if there aren't enough samples
    if len(memory) < BATCH_SIZE:
        return

    # Sample a batch of transitions and unpack into separate batches
    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    # Prepare batches for states, actions, rewards, and non-final next states
    non_final_mask = torch.tensor([s is not None for s in batch.next_state], device=device, dtype=torch.bool) # tells us if end state or not
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Calculate Q-values for current states and actions
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Calculate expected Q-values for next states
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values

    # this is the predicted "target" or estimated ground truth goodness of the state action pair we are looking at 
    expected_state_action_values = reward_batch + (GAMMA * next_state_values)

    # Compute loss and update model (we want the policy network to basically predict the estimated goodness, so we just take the error between the current output and the estimated ground truth)
    loss = nn.SmoothL1Loss()(state_action_values, expected_state_action_values.unsqueeze(1))
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()


# Train Model

In [None]:
if torch.cuda.is_available() or torch.backends.mps.is_available():
    num_episodes = 600
else:
    num_episodes = 50

for i_episode in range(num_episodes):
    # Initialize the environment and get its state
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    total_reward = 0 

    if i_episode % 100 == 0:
        frames = []

    for t in count():
        action = eps_greedy(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        total_reward += reward
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated

        if i_episode % 100 == 0:
            frames.append(env.render())

        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            episode_durations.append(t + 1)
            if  PLOT_TRAINING:
                plot_durations()
            else:
                print(f"Episode : {i_episode} | Duration: {t + 1} | Total Reward : {total_reward}")
            break
    
    if len(frames) > 0:
        video_path = os.path.join(PATH, f'{ENV_NAME}_train_video_ep_{i_episode}.mp4')
        imageio.mimsave(video_path, frames, fps=20)
        frames = []

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()

KeyboardInterrupt: 

<Figure size 640x480 with 0 Axes>

# Save Model

In [None]:
torch.save(policy_net.state_dict(), MODEL_FILE)
policy_net.eval()  # Set the model to evaluation mode

DQN(
  (layer1): Linear(in_features=4, out_features=128, bias=True)
  (layer2): Linear(in_features=128, out_features=128, bias=True)
  (layer3): Linear(in_features=128, out_features=2, bias=True)
)

# Load Model

In [None]:
# Assuming `model` is your model class
model = DQN(n_observations, n_actions)  # Reinstantiate the model
model.load_state_dict(torch.load(MODEL_FILE))
model.to(device)
model.eval()  # Set the model to evaluation mode

  model.load_state_dict(torch.load(MODEL_FILE))


DQN(
  (layer1): Linear(in_features=4, out_features=128, bias=True)
  (layer2): Linear(in_features=128, out_features=128, bias=True)
  (layer3): Linear(in_features=128, out_features=2, bias=True)
)

# Test The Model

In [None]:
env = gym.make(ENV_NAME, render_mode='rgb_array')
env.reset()
frames = []

# Initialize the environment and get its state
state, info = env.reset()

for i_episode in range(1000):

    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    action = model(state).argmax(dim=1).view(1, 1)
    observation, reward, terminated, truncated, _ = env.step(action.item())

    if terminated or truncated:
        print(f"Episode ended after {i_episode+1} steps")
        break

    frames.append(env.render())
    state = observation

env.close()

imageio.mimsave(EVAL_VIDEO_FILE, frames, fps=20)
Video(EVAL_VIDEO_FILE, embed=True) # Display the video in Jupyter Notebook

Episode ended after 439 steps


