In [1]:
# For tips on running notebooks in Google Colab, see
# https://pytorch.org/tutorials/beginner/colab
%matplotlib inline

In [2]:
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

env = gym.make("CartPole-v1")

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if GPU is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [4]:
class DQN(nn.Module):

    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)

In [5]:
# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the ``AdamW`` optimizer
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4

# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = len(state)

policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)

steps_done = 0


def select_action(state):
    global steps_done  # Keeps track of the number of steps (actions selected)
    sample = random.random()  # Generates a random sample for epsilon-greedy strategy

    # Calculate the epsilon threshold for the current step using exponential decay
    # Starts from EPS_START and decays towards EPS_END at a rate determined by EPS_DECAY
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
                    math.exp(-1. * steps_done / EPS_DECAY)

    steps_done += 1  # Increment the steps_done counter

    # Decision making for choosing an action based on epsilon-greedy strategy
    if sample > eps_threshold:
        # Exploitation: Choose the best action based on current policy
        with torch.no_grad():  # Disable gradient calculation for inference
            # The policy network predicts the Q-values for all actions given the current state
            # .max(1) finds the action with the highest Q-value
            # .indices.view(1, 1) formats the chosen action for compatibility with environment
            return policy_net(state).max(1).indices.view(1, 1)
    else:
        # Exploration: Choose a random action
        # This allows the agent to explore the action space and discover new strategies
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)


episode_durations = []


def plot_durations(show_result=False):
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())

    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        if not show_result:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            display.display(plt.gcf())

In [6]:
def optimize_model():
    # Check if enough samples are available in memory to create a batch
    if len(memory) < BATCH_SIZE:
        return  # Not enough samples, skip this round of optimization

    # Sample a batch of transitions from memory
    transitions = memory.sample(BATCH_SIZE)

    # This clever trick transposes the batch of transitions to a Transition of batch-arrays.
    # It effectively organizes the data for easy batch processing.
    batch = Transition(*zip(*transitions))

    # Create a mask for non-final states (i.e., states that are not the end of an episode)
    # and a tensor for holding non-final next states
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])

    # Concatenate all states, actions, and rewards into separate tensors
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Compute Q(s_t, a) for each action taken in the batch
    # This step involves forward passing the state_batch through the policy_net
    # and using gather to select the Q-values for the actions actually taken
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # Initialize a tensor for the next state values with zeros for all batch samples
    # This will be updated with the predicted Q values for non-final states
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # Compute V(s_{t+1}) for all non-final next states using the target network
    # The max predicted Q value for the next states are selected with max(1).values
    # This operation is wrapped in torch.no_grad() to prevent gradient computation
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1).values

    # Compute the expected Q values for the current state-action pairs
    # This is done by adding the (discounted) best future rewards to the immediate rewards
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    # Compute the loss between the current Q values and the expected Q values
    # The Huber loss is used here, which is less sensitive to outliers than squared error loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))

    # Zero all gradients for the variables which the optimizer will update
    optimizer.zero_grad()
    # Calculate the gradients of the loss with respect to all parameters
    # in the policy network involved in its computation
    loss.backward()
    # Clip gradients to prevent very large values which can destabilize training
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    # Perform a single optimization step (parameter update)
    optimizer.step()


In [10]:
if torch.cuda.is_available():
    num_episodes = 600
else:
    num_episodes = 50

for i_episode in range(num_episodes):
    # Initialize the environment and get its state
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():
        action = select_action(state)

        print(f"Action:{action}")

        observation, reward, terminated, truncated, _ = env.step(action.item())

        print(f"Observation: {observation}\n"
              f"Reward: {reward}\n"
              f"Terminated: {terminated}\n"
              f"Truncated: {truncated}\n")

        reward = torch.tensor([reward], device=device)
        done = terminated or truncated

        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key] * TAU + target_net_state_dict[key] * (1 - TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()

Action:tensor([[0]], device='cuda:0')
Observation: [-0.03249109 -0.19444314  0.01266395  0.3309469 ]
Reward: 1.0
Terminated: False
Truncated: False

Action:tensor([[1]], device='cuda:0')
Observation: [-0.03637995  0.00049628  0.01928289  0.04228433]
Reward: 1.0
Terminated: False
Truncated: False

Action:tensor([[0]], device='cuda:0')
Observation: [-0.03637003 -0.19489679  0.02012858  0.34098828]
Reward: 1.0
Terminated: False
Truncated: False

Action:tensor([[1]], device='cuda:0')
Observation: [-4.0267963e-02 -6.6930283e-05  2.6948344e-02  5.4720085e-02]
Reward: 1.0
Terminated: False
Truncated: False

Action:tensor([[0]], device='cuda:0')
Observation: [-0.0402693  -0.1955647   0.02804274  0.35578212]
Reward: 1.0
Terminated: False
Truncated: False

Action:tensor([[1]], device='cuda:0')
Observation: [-0.04418059 -0.00085246  0.03515839  0.07207207]
Reward: 1.0
Terminated: False
Truncated: False

Action:tensor([[0]], device='cuda:0')
Observation: [-0.04419765 -0.19646035  0.03659983  0.375


KeyboardInterrupt

