In [1]:
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

env = gym.make("CartPole-v1")

# set up matplotlib
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

plt.ion()

# if GPU is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [3]:
class QuantileEmbedding(nn.Module):
    def __init__(self, embedding_dim=64):
        super(QuantileEmbedding, self).__init__()
        self.embedding_dim = embedding_dim
        self.fc = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, taus):
        N = taus.size(0)
        taus = taus.unsqueeze(-1)  # Add an extra dimension
        # Create a constant tensor of dimension (embedding_dim, )
        i_tensor = torch.arange(self.embedding_dim, device=taus.device).float().unsqueeze(0)
        # Broadcast multiply
        cos_trans = torch.cos(taus * math.pi * i_tensor)
        # Pass through the fully connected layer
        quantile_embedding = self.fc(cos_trans)
        return quantile_embedding

class IQN(nn.Module):
    def __init__(self, n_observations, n_actions, embedding_dim=64):
        super(IQN, self).__init__()
        self.quantile_embedding = QuantileEmbedding(embedding_dim)
        self.layer1 = nn.Linear(n_observations + embedding_dim, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    def forward(self, state, taus):
        quantile_embedding = self.quantile_embedding(taus)
        # Repeat each element in the state tensor to match the batch size of quantile_embedding
        repeated_state = state.repeat_interleave(quantile_embedding.size(0) // state.size(0), dim=0)
        combined = torch.cat((repeated_state, quantile_embedding), dim=-1)
        x = F.relu(self.layer1(combined))
        x = F.relu(self.layer2(x))
        return self.layer3(x)



In [4]:
# BATCH_SIZE is the number of transitions sampled from the replay buffer
# GAMMA is the discount factor as mentioned in the previous section
# EPS_START is the starting value of epsilon
# EPS_END is the final value of epsilon
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
# TAU is the update rate of the target network
# LR is the learning rate of the ``AdamW`` optimizer
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4

# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = len(state)

policy_net = IQN(n_observations, n_actions).to(device)
target_net = IQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)


steps_done = 0


def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            num_quantiles = 100
            taus = torch.linspace(0, 1, num_quantiles, device=device).unsqueeze(0)
            action_quantiles = policy_net(state, taus)
            mean_action_values = action_quantiles.mean(1)
            return mean_action_values.max(1).indices.view(1, 1)
    else:
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)





episode_durations = []


def plot_durations(show_result=False):
    plt.figure(1)
    durations_t = torch.tensor(episode_durations, dtype=torch.float)
    if show_result:
        plt.title('Result')
    else:
        plt.clf()
        plt.title('Training...')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations_t.numpy())
    # Take 100 episode averages and plot them too
    if len(durations_t) >= 100:
        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            num_quantiles = 100
            taus = torch.linspace(0, 1, num_quantiles, device=device).unsqueeze(0)
            action_quantiles = policy_net(state, taus)
            mean_action_values = action_quantiles.mean(1)
            return mean_action_values.max(1).indices.view(1, 1)
    else:
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)


    plt.pause(0.001)  # pause a bit so that plots are updated
    if is_ipython:
        if not show_result:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            display.display(plt.gcf())

SyntaxError: invalid syntax (2364314659.py, line 70)

In [None]:
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return

    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    # Convert batch components to tensors
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Sample a set of quantile fractions (τ) for the current batch
    taus = torch.rand(BATCH_SIZE, device=device)

    # Compute the quantile values for the current state and action pairs
    state_action_quantiles = policy_net(state_batch, taus).gather(1, action_batch.unsqueeze(-1).expand(BATCH_SIZE, taus.size(1)))

    # Compute the next state values
    next_state_quantiles = torch.zeros(BATCH_SIZE, taus.size(1), device=device)
    if non_final_mask.sum() > 0:
        with torch.no_grad():
            next_state_values = target_net(non_final_next_states, taus).max(2)[0]
            next_state_quantiles[non_final_mask] = next_state_values

    # Compute the target quantile values
    expected_quantiles = reward_batch.unsqueeze(1) + (GAMMA * next_state_quantiles)

    # Compute the Quantile Regression Loss
    td_error = expected_quantiles.unsqueeze(2) - state_action_quantiles.unsqueeze(1)
    abs_td_error = td_error.abs()
    huber_loss = torch.where(abs_td_error <= 1, 0.5 * td_error.pow(2), abs_td_error - 0.5)
    quantile_loss = torch.abs(taus.unsqueeze(1) - (td_error.detach() < 0).float()) * huber_loss
    loss = quantile_loss.mean()

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()


In [None]:
if torch.cuda.is_available():
    num_episodes = 600
else:
    num_episodes = 50

for i_episode in range(num_episodes):
    # Initialize the environment and get it's state
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
    for t in count():
        action = select_action(state)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        reward = torch.tensor([reward], device=device)
        done = terminated or truncated

        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        # Store the transition in memory
        memory.push(state, action, next_state, reward)

        # Move to the next state
        state = next_state

        # Perform one step of the optimization (on the policy network)
        optimize_model()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

        if done:
            episode_durations.append(t + 1)
            plot_durations()
            break

print('Complete')
plot_durations(show_result=True)
plt.ioff()
plt.show()

RuntimeError: The expanded size of the tensor (64) must match the existing size (4) at non-singleton dimension 1.  Target sizes: [100, 64].  Tensor sizes: [1, 4]

In [None]:
# run the model