In [1]:
import torch
from torch import nn
import gymnasium as gym
from collections import namedtuple, deque
from itertools import count
import random
import math
import os
from tqdm.notebook import tqdm
import cv2
import numpy as np
# from torch.utils.tensorboard import SummaryWriter

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')

BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.01
EPS_DECAY = 10000
TAU = 0.005
LR = 1e-4
NUM_EPISODES = 10000
SAVE_FREQ = 100
MAX_EPISODE_STEPS = 10000

game_name = 'MsPacman-v5'

steps_done = 0

In [3]:
class NoopResetEnv(gym.Wrapper):
    def __init__(self, env, noop_max=30):
        """start the game with no-op actions to provide random starting positions
        No-op is assumed to be action 0.
        """
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def reset(self, **kwargs):
        """ Do no-op action for a number of steps in [1, noop_max]."""
        self.env.reset(**kwargs)
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            noops = np.random.randint(1, self.noop_max + 1) #pylint: disable=E1101
        assert noops > 0
        obs = None
        for _ in range(noops):
            obs, _, terminated, truncated, _ = self.env.step(self.noop_action)
            if terminated or truncated:
                obs = self.env.reset(**kwargs)
        return obs, {}

    def step(self, action):
        return self.env.step(action)



class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env, width=84, height=84):
        """
        Warp frames to 84x84 as done in the Nature paper and later work.
        """
        super().__init__(env)
        self._width = width
        self._height = height

        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(1, self._height, self._width),
            dtype=np.uint8,
        )

    def observation(self, obs):

        obs = cv2.resize(
            obs, (self._width, self._height), interpolation=cv2.INTER_AREA
        )

        obs = np.expand_dims(obs, 0)

        return obs


def create_env(env_name=game_name, noop_start=True, render=False):

    env = gym.make(f'ALE/{env_name}', obs_type='grayscale', frameskip=4, repeat_action_probability=0, full_action_space=False, render_mode='human' if render else 'rgb_array')

    env = WarpFrame(env)
    if noop_start:
        env = NoopResetEnv(env)

    return env

In [4]:
env = create_env()
n_observations = env.observation_space.shape[0]
n_actions = env.action_space.n
env.close()

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


In [5]:
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

class ReplayMemory:
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)
        
    def push(self, *args):
        self.memory.append(Transition(*args))
    
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)
    
    def __len__(self):
        return len(self.memory)

In [6]:
class DQN(nn.Module):
    def __init__(self, n_observations, n_actions):
        super().__init__()

        self.feature = nn.Sequential(
            nn.Conv2d(1, 32, 8, 4),
            nn.ReLU(True),
            nn.Conv2d(32, 64, 4, 2),
            nn.ReLU(True),
            nn.Conv2d(64, 64, 3, 1),
            nn.ReLU(True),
            nn.Flatten(),
            nn.Linear(3136, 512),
            nn.ReLU(True),
        )

        self.net = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, n_actions)
        )
    
    def forward(self, x):
        features = self.feature(x)
        return self.net(features)

In [7]:
policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = torch.optim.Adam(policy_net.parameters(), lr=LR)
memory = ReplayMemory(100000)
# writer = SummaryWriter(log_dir='logs/dqn')

In [8]:
def select_action(state):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            return policy_net(state).max(1).indices.view(1, 1)
    else:
        return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)

In [9]:
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))
    
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
        
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
    
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
    
    optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

    return loss.item()

In [10]:
assert False

AssertionError: 

In [11]:
try:
    for i_episode in tqdm(range(NUM_EPISODES)):
        env = create_env()
        state, info = env.reset()
        state = torch.tensor(state, device=device, dtype=torch.float).unsqueeze(0)
        
        sum_reward = 0

        for t in count():
            action = select_action(state)
            observation, reward, terminated, truncated, _ = env.step(action.item())
            sum_reward += reward
            reward = torch.tensor(reward, device=device).unsqueeze(0)
            done = terminated or truncated
            
            if terminated:
                next_state = None
            else:
                next_state = torch.tensor(observation, device=device, dtype=torch.float).unsqueeze(0)
            
            memory.push(state, action, next_state, reward)            
            state = next_state
            
            loss = optimize_model()
            
            target_net_state_dict = target_net.state_dict()
            policy_net_state_dict = policy_net.state_dict()
            for key in target_net_state_dict:
                target_net_state_dict[key] = TAU * policy_net_state_dict[key] + (1 - TAU) * target_net_state_dict[key]
            target_net.load_state_dict(target_net_state_dict)
            
            if i_episode % SAVE_FREQ == 0:
                os.makedirs(f'models/dqn/{game_name}', exist_ok=True)
                torch.save(policy_net.state_dict(), f'models/dqn/{game_name}/{i_episode}.pt')

            if done or t >= MAX_EPISODE_STEPS:
                os.makedirs(f'logs/dqn/{game_name}', exist_ok=True)
                with open(f'logs/dqn/{game_name}/episode_return.txt', 'a') as f:
                    f.write(f'{sum_reward} {i_episode}\n')

                if loss:
                    with open(f'logs/dqn/{game_name}/training_loss.txt', 'a') as f:
                        f.write(f'{loss} {i_episode}\n')
                break
        env.close()
except KeyboardInterrupt:
    pass

  0%|          | 0/10000 [00:00<?, ?it/s]

### 4 min training

In [12]:
env = create_env(render=True)
obs, _ = env.reset()
state = torch.tensor(obs, device=device, dtype=torch.float).unsqueeze(0)
done = False
total_reward = 0
for t in count():
    with torch.no_grad():
        action = select_action(state)
    obs, reward, terminated, truncated, _, = env.step(action.item())
    state = torch.tensor(obs, device=device, dtype=torch.float).unsqueeze(0)
    done = terminated or truncated
    if done:
        break
    total_reward += reward
env.close()
print(total_reward)

700.0


: 

In [None]:
# torch.save({
#     'policy_net_state_dict': policy_net.state_dict(),
#     'target_net_state_dict': target_net.state_dict(),
#     'optimizer_state_dict': optimizer.state_dict()
# }, 'dqn_checkpoint.tar')