In [1]:
import torch
from torch import nn
import gymnasium as gym
from collections import namedtuple, deque
from itertools import count
import random
import math
import os
from tqdm.notebook import tqdm
# from torch.utils.tensorboard import SummaryWriter

In [2]:
# One environment step as it is stored in the replay buffer.
Transition = namedtuple(
    'Transition',
    ('state', 'action', 'next_state', 'reward'),
)


class ReplayMemory:
    """Bounded buffer of ``Transition`` records with random sampling."""

    def __init__(self, capacity):
        """Create an empty buffer that retains at most ``capacity`` records."""
        # A deque with ``maxlen`` discards its oldest entry once it is full.
        self.memory = deque(maxlen=capacity)

    def __len__(self):
        """Return how many transitions are currently stored."""
        return len(self.memory)

    def push(self, *args):
        """Pack ``args`` into a ``Transition`` and append it to the buffer."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return a list of ``batch_size`` stored transitions picked by ``random.sample``."""
        return random.sample(self.memory, batch_size)

In [3]:
class DQN(nn.Module):
    """Fully connected network that emits one value per action."""

    def __init__(self, n_observations, n_actions):
        """Build the layer stack.

        Args:
            n_observations: Number of input features of the first layer.
            n_actions: Number of outputs of the last layer.
        """
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(n_observations, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, n_actions),
        ]
        self.net = nn.Sequential(*layers)

        # Disabled experiment, kept for reference:
        # self.recurrent = nn.LSTM(128+n_actions+1, 128, 1)

    def forward(self, x):
        """Pass ``x`` through the layer stack and return the result."""
        return self.net(x)

In [4]:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')  # everything below is pinned to the CPU

# --- Hyperparameters -------------------------------------------------------
BATCH_SIZE = 128     # transitions sampled from replay memory per update
GAMMA = 0.99         # discount applied to the next-state value
EPS_START = 0.9      # exploration threshold at step 0
EPS_END = 0.05       # value the exploration threshold decays towards
EPS_DECAY = 1000     # step scale of the exponential epsilon decay
TAU = 0.005          # weight of policy_net in the soft target update
LR = 1e-4            # Adam learning rate
NUM_EPISODES = 600   # number of training episodes
SAVE_FREQ = 50       # checkpoints are written on episodes divisible by this
SEED = 0             # seed for the RNGs initialised below

# Seed the RNGs used for exploration, replay sampling and weight
# initialisation so that a fresh "Restart & Run All" is repeatable.
# Note: the environment's reset RNG is not seeded here.
random.seed(SEED)
torch.manual_seed(SEED)

game_name = 'CartPole-v1'
env = gym.make(game_name)
env.action_space.seed(SEED)  # makes env.action_space.sample() repeatable
n_observations = env.observation_space.shape[0]
n_actions = env.action_space.n

# Global counter of action selections; drives the epsilon decay schedule.
steps_done = 0

In [5]:
# Online network: the one that is optimised and used to pick actions.
policy_net = DQN(n_observations, n_actions).to(device)

# Target network: same architecture, initialised with policy_net's weights.
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

# Only policy_net's parameters are handed to the optimiser.
optimizer = torch.optim.Adam(params=policy_net.parameters(), lr=LR)

# Replay buffer holding at most 10,000 transitions.
memory = ReplayMemory(10_000)
# writer = SummaryWriter(log_dir='logs/dqn')

In [6]:
def select_action(state):
    """Choose an action for ``state`` with an epsilon-greedy rule.

    The exploration threshold decays exponentially from ``EPS_START`` towards
    ``EPS_END`` as the global ``steps_done`` counter grows; the counter is
    incremented on every call.

    Args:
        state: Input passed to ``policy_net`` when the greedy branch is taken.

    Returns:
        A ``(1, 1)`` ``torch.long`` tensor holding the chosen action index.
    """
    global steps_done

    roll = random.random()
    decay = math.exp(-1. * steps_done / EPS_DECAY)
    epsilon = EPS_END + (EPS_START - EPS_END) * decay
    steps_done += 1

    # Explore: draw a random action from the environment's action space.
    if roll <= epsilon:
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)

    # Exploit: pick the action with the largest network output.
    with torch.no_grad():
        q_values = policy_net(state)
        return q_values.max(1).indices.view(1, 1)

In [7]:
def optimize_model():
    """Run one gradient step on ``policy_net`` from a sampled replay batch.

    Samples ``BATCH_SIZE`` transitions from ``memory`` and builds the targets
    ``reward + GAMMA * max_a target_net(next_state)[a]``; the bootstrap term
    is zero for transitions whose ``next_state`` is ``None``. The Smooth L1
    loss between those targets and ``policy_net``'s output for the action
    actually taken is then minimised with ``optimizer``.

    Returns:
        The loss as a Python float, or ``None`` when ``memory`` holds fewer
        than ``BATCH_SIZE`` transitions and no update was performed.
    """
    if len(memory) < BATCH_SIZE:
        return None

    transitions = memory.sample(BATCH_SIZE)
    # Transpose: list of Transitions -> one Transition of per-field tuples.
    batch = Transition(*zip(*transitions))

    non_final_next_states = [s for s in batch.next_state if s is not None]
    non_final_mask = torch.tensor(
        [s is not None for s in batch.next_state],
        device=device,
        dtype=torch.bool,
    )

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Output of policy_net for the action that was taken in each state.
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    # torch.cat raises on an empty list, so skip the target-network pass when
    # every sampled transition is terminal; the zeros are already correct.
    if non_final_next_states:
        with torch.no_grad():
            next_values = target_net(torch.cat(non_final_next_states))
            next_state_values[non_final_mask] = next_values.max(1).values

    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    loss = nn.functional.smooth_l1_loss(
        state_action_values,
        expected_state_action_values.unsqueeze(1),
    )

    optimizer.zero_grad()
    loss.backward()
    # Clamp each gradient element to [-100, 100] before the optimiser step.
    nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

    return loss.item()

In [8]:
# Train policy_net for NUM_EPISODES episodes. Per environment step: store the
# transition, run one optimisation step, and soft-update target_net.
# Interrupting the kernel stops training early without raising.
checkpoint_dir = f'models/dqn/{game_name}'
log_dir = f'logs/dqn/{game_name}'
# Create output directories once instead of on every step / episode.
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

try:
    for i_episode in tqdm(range(NUM_EPISODES)):
        state, info = env.reset()
        state = torch.tensor(state, device=device, dtype=torch.float).unsqueeze(0)

        sum_reward = 0

        for t in count():
            action = select_action(state)
            observation, reward, terminated, truncated, _ = env.step(action.item())
            sum_reward += reward
            # Pin the dtype so the stored reward is float32 regardless of the
            # numeric type the environment returns.
            reward = torch.tensor(reward, device=device, dtype=torch.float).unsqueeze(0)
            done = terminated or truncated

            # Only a true termination is stored without a next state; a
            # truncated episode keeps its next state in the transition.
            if terminated:
                next_state = None
            else:
                next_state = torch.tensor(observation, device=device, dtype=torch.float).unsqueeze(0)

            memory.push(state, action, next_state, reward)

            state = next_state

            loss = optimize_model()

            # Soft update: target <- TAU * policy + (1 - TAU) * target.
            target_net_state_dict = target_net.state_dict()
            policy_net_state_dict = policy_net.state_dict()
            for key in target_net_state_dict:
                target_net_state_dict[key] = TAU * policy_net_state_dict[key] + (1 - TAU) * target_net_state_dict[key]
            target_net.load_state_dict(target_net_state_dict)

            if done:
                # Write the checkpoint once per checkpoint episode rather than
                # rewriting the same file on every step of that episode.
                if i_episode % SAVE_FREQ == 0:
                    torch.save(policy_net.state_dict(), f'{checkpoint_dir}/{i_episode}.pt')

                with open(f'{log_dir}/episode_return.txt', 'a') as f:
                    f.write(f'{sum_reward} {i_episode}\n')

                # `loss` is None until the replay memory holds a full batch;
                # a loss of exactly 0.0 is still a value worth logging.
                if loss is not None:
                    with open(f'{log_dir}/training_loss.txt', 'a') as f:
                        f.write(f'{loss} {i_episode}\n')
                break
except KeyboardInterrupt:
    # Deliberate: allow stopping a long run from the notebook UI.
    pass

  0%|          | 0/600 [00:00<?, ?it/s]

### 4 min training

In [12]:
# Play one rendered episode with the greedy policy and print its return.
# A separate environment object is used so the training `env` is not replaced
# by one that gets closed, and actions come straight from policy_net so the
# episode has no random exploration and does not advance `steps_done`.
eval_env = gym.make(game_name, render_mode='human')
total_reward = 0
try:
    obs, _ = eval_env.reset()
    done = False
    while not done:
        state = torch.tensor(obs, device=device, dtype=torch.float).unsqueeze(0)
        with torch.no_grad():
            # Greedy action: index of the largest network output.
            action = policy_net(state).max(1).indices.view(1, 1)
        obs, reward, terminated, truncated, _ = eval_env.step(action.item())
        done = terminated or truncated
        total_reward += reward
finally:
    # Always release the render window, even if stepping fails.
    eval_env.close()
print(total_reward)

286.0


In [9]:
# torch.save({
#     'policy_net_state_dict': policy_net.state_dict(),
#     'target_net_state_dict': target_net.state_dict(),
#     'optimizer_state_dict': optimizer.state_dict()
# }, 'dqn_checkpoint.tar')