In [1]:
import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
env = gym.make("CartPole-v1")

# if GPU is to be used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition"""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [3]:
class DQN(nn.Module):

    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    # Called with either one element to determine next action, or a batch
    # during optimization. Returns tensor([[left0exp,right0exp]...]).
    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)

In [4]:
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4

# Get number of actions from gym action space
n_actions = env.action_space.n
# Get the number of state observations
state, info = env.reset()
n_observations = len(state)

policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)


steps_done = 0

def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END)*math.exp(-1.*steps_done/EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor(env.action_space.sample()).view(1, 1)

In [6]:
def optimize_model():
    if len(memory) < BATCH_SIZE: return
    target_net.eval()
    batch = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*batch))
    
    non_final_masks = torch.tensor([s is not None for s in batch.next_state])
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    
    state_batch = torch.cat(batch.state)
    reward_batch = torch.cat(batch.reward)
    action_batch = torch.cat(batch.action)
    
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    
    next_state_values = torch.zeros(BATCH_SIZE, device = device)
    with torch.no_grad():
        next_state_values[non_final_masks] = target_net(non_final_next_states).max(1)[0]
    
    expected_state_action_values = reward_batch + (GAMMA*next_state_values)
    
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, next_state_values[..., None])
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 80)
    optimizer.step()

In [None]:
for episode in range(num_episodes):
    observation, _ = env.reset()
    observation = torch.from_numpy(observation)[None, ...]
    rewards = []
    for t in count():
        action = select_action(observation)
        next_observation, reward, terminated, truncated, _ = env.step(action.item())
        next_observation = torch.from_numpy(next_observation)[None, ...]
        rewards.append(reward)
        reward = torch.as_tensor(reward)[None, ...]
        done = terminated or truncated
        
        if terminated:
            next_observation = None
    
        memory.push(observation, action, next_observation, reward)
        observation = next_observation   
        
        optimize_model()
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = (1-TAU)*target_net_state_dict[key] + TAU*policy_net_state_dict[key]
        target_net.load_state_dict(target_net_state_dict)
        if done:
            break
    print(f'Episode: {episode}   Performance: {sum(rewards)}')

Episode: 0   Performance: 8.0
Episode: 0   Performance: 8.0
Episode: 0   Performance: 9.0
Episode: 0   Performance: 9.0
Episode: 0   Performance: 11.0
Episode: 0   Performance: 10.0
Episode: 0   Performance: 10.0
Episode: 0   Performance: 10.0
Episode: 0   Performance: 10.0
Episode: 0   Performance: 8.0
Episode: 0   Performance: 10.0
Episode: 0   Performance: 10.0
Episode: 0   Performance: 8.0
Episode: 0   Performance: 10.0
Episode: 0   Performance: 9.0
Episode: 0   Performance: 8.0
Episode: 0   Performance: 9.0
Episode: 0   Performance: 10.0
Episode: 0   Performance: 11.0
Episode: 0   Performance: 10.0
Episode: 0   Performance: 10.0
Episode: 0   Performance: 10.0
Episode: 0   Performance: 10.0
Episode: 0   Performance: 10.0
Episode: 0   Performance: 11.0
Episode: 0   Performance: 10.0
Episode: 0   Performance: 11.0
Episode: 0   Performance: 9.0
Episode: 0   Performance: 10.0
Episode: 0   Performance: 14.0
Episode: 0   Performance: 10.0
Episode: 0   Performance: 10.0
Episode: 0   Perfo

Episode: 0   Performance: 135.0
Episode: 0   Performance: 489.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 107.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 79.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 297.0
Episode: 0   Performance: 500.0
Episode: 0   Performance: 305.0
Episode: 

In [None]:
batch = memory.sample(BATCH_SIZE)
batch = Transition(*zip(*batch))
zzzzzz
reward_batch = torch.cat(batch.reward)

non_final_masks = torch.tensor([s is not None for s in batch.next_state])
non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])

In [None]:
next_state_max = torch.zeros(BATCH_SIZE)
with torch.no_grad():
    next_state_max[non_final_masks] = target_net(non_final_next_states).max(1)[0]
expected_state_action_values = reward_batch + (GAMMA*next_state_max)