In [None]:
import torch
from torch import nn
from torchvision import transforms as T
from PIL import Image
import numpy as np

# Standard Library
from pathlib import Path
from collections import deque, namedtuple
from typing import List
import random
import datetime
import os
import copy
import itertools

# Gym is an OpenAI toolkit for RL
import gym
from gym.spaces import Box, Discrete
from gym.wrappers import FrameStack

In [None]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'PyTorch using {device} device')

In [None]:
class AutoFire(gym.Wrapper):
    """Expose a 3-action space and automatically re-launch after a lost life.

    Agent action ``a`` is forwarded to the wrapped environment as ``a + 1``.
    Whenever ``info['ale.lives']`` drops below the last seen value, the next
    call to :meth:`step` ignores the agent's action and sends action ``1``
    instead (presumably FIRE in the ALE Breakout action set -- confirm
    against the environment's action meanings).
    """

    # Life count assumed at the start of every episode.
    _INITIAL_LIVES = 5
    # Raw environment action sent on the step following a lost life.
    _FIRE_ACTION = 1

    def __init__(self, env):
        """Wrap ``env`` and restrict the agent-facing action space to 3 actions."""
        super().__init__(env)
        self._lives = self._INITIAL_LIVES
        self.action_space = Discrete(3)
        self._fire = False

    def reset(self, **kwargs):
        """Reset the wrapped environment and the life tracking state.

        Without this, ``_lives`` would keep the final value of the previous
        episode, so no later drop in ``info['ale.lives']`` would ever be
        detected and the automatic fire would stop working after episode one.
        """
        self._lives = self._INITIAL_LIVES
        self._fire = False
        return self.env.reset(**kwargs)

    def step(self, action):
        """Step the wrapped environment, overriding the action after a lost life.

        Args:
            action: agent action in ``{0, 1, 2}``; forwarded as ``action + 1``.

        Returns:
            The ``(obs, reward, done, info)`` tuple of the wrapped environment.
        """
        if self._fire:
            # A life was lost on the previous step: send the fire action once.
            obs, reward, done, info = self.env.step(self._FIRE_ACTION)
            self._fire = False
        else:
            obs, reward, done, info = self.env.step(action + 1)

        if info['ale.lives'] < self._lives:
            self._lives = info['ale.lives']
            self._fire = True
        return obs, reward, done, info

In [None]:
class SkipFrame(gym.Wrapper):
    """Wrapper that repeats every action for a fixed number of environment steps."""

    def __init__(self, env, skip):
        """Wrap ``env``; each call to ``step`` advances it up to ``skip`` times."""
        super().__init__(env)
        self._skip = skip

    def step(self, action):
        """Apply ``action`` repeatedly and return the last frame with the summed reward.

        Stops early as soon as the wrapped environment reports ``done``.
        """
        done = False
        collected = []
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            collected.append(reward)
            if done:
                break
        # Left-to-right float sum starting from 0.0, same as a running total.
        return obs, sum(collected, 0.0), done, info

In [None]:
class GrayScaleObservation(gym.ObservationWrapper):
    """Observation wrapper that converts colour frames to grayscale float tensors."""

    def __init__(self, env):
        """Wrap ``env`` and declare an observation space without the channel axis."""
        super().__init__(env)
        height_width = self.observation_space.shape[:2]
        self.observation_space = Box(low=0, high=255, shape=height_width, dtype=np.uint8)

    def permute_orientation(self, observation):
        """Turn an [H, W, C] array into a [C, H, W] float tensor."""
        channels_first = np.transpose(observation, (2, 0, 1))
        # Copy so the tensor is built from a contiguous array, not a strided view.
        return torch.tensor(channels_first.copy(), dtype=torch.float)

    def observation(self, observation):
        """Return the grayscale version of ``observation`` as a torch tensor."""
        to_gray = T.Grayscale()
        return to_gray(self.permute_orientation(observation))

In [None]:
class ResizeObservation(gym.ObservationWrapper):
    """Observation wrapper that resizes frames and rescales them with ``Normalize(0, 255)``."""

    def __init__(self, env, shape):
        """Wrap ``env``.

        Args:
            env: environment to wrap.
            shape: target size; an ``int`` means a square ``(shape, shape)``,
                anything else is converted with ``tuple(shape)``.
        """
        super().__init__(env)
        self.shape = (shape, shape) if isinstance(shape, int) else tuple(shape)

        # Keep any trailing axes (beyond the first two) of the wrapped space.
        resized_shape = self.shape + self.observation_space.shape[2:]
        self.observation_space = Box(low=0, high=255, shape=resized_shape, dtype=np.uint8)

    def observation(self, observation):
        """Resize and normalise ``observation``, then drop its leading axis if it has size 1."""
        resize = T.Resize(self.shape)
        rescale = T.Normalize(0, 255)
        pipeline = T.Compose([resize, rescale])
        return pipeline(observation).squeeze(0)

In [None]:
class Epsilon:
    """Exploration rate that decays multiplicatively down to a floor.

    Calling ``next()`` on an instance multiplies the current value by
    ``decay``, clamps it from below at ``end``, stores and returns it.
    """

    def __init__(self, start, end, decay):
        """Store the schedule parameters and start at ``start``."""
        self._start, self._end, self._decay = start, end, decay
        self._current = start

    def reset(self):
        """Restore the current value to the initial ``start`` value."""
        self._current = self._start

    def __next__(self):
        """Advance the schedule by one step and return the new value."""
        decayed = self._current * self._decay
        floor = self._end
        # Never fall below the configured floor.
        self._current = floor if floor > decayed else decayed
        return self._current

    @property
    def value(self):
        """Current exploration rate."""
        return self._current

    @property
    def start(self):
        """Initial exploration rate."""
        return self._start

    @property
    def decay(self):
        """Multiplicative factor applied on every ``next()`` call."""
        return self._decay

    @property
    def end(self):
        """Lower bound of the exploration rate."""
        return self._end

    def __str__(self):
        return (
            f'Epsilon(start={self.start}, end={self.end}, '
            f'decay={self.decay}, current={self.value})'
        )

    def __repr__(self):
        return str(self)

In [None]:
# One stored experience; the field order is the storage order used everywhere.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward', 'done'],
)


class ReplayMemory:
    """Bounded FIFO buffer of ``Transition`` records with uniform random sampling."""

    def __init__(self, capacity):
        """Create an empty buffer holding at most ``capacity`` transitions."""
        # A deque with maxlen silently drops the oldest entry once it is full.
        self.memory = deque(maxlen=capacity)

    def push(self, state, action, next_state, reward, done):
        """Append one transition built from the given fields."""
        record = Transition(
            state=state,
            action=action,
            next_state=next_state,
            reward=reward,
            done=done,
        )
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct transitions drawn without replacement."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [None]:
def conv2d_size_out(size, kernel_size, stride):
    """Return the output length along one axis of an unpadded, undilated convolution."""
    return (size - (kernel_size - 1) - 1) // stride + 1


class DQN(nn.Module):
    """Convolutional Q-network mapping a stack of frames to one Q-value per action.

    Args:
        state_shape: shape of one state; ``state_shape[0]`` is the number of
            input channels. If it also contains two further entries they are
            taken as the input height and width, otherwise 84x84 is assumed.
        action_count: number of discrete actions (size of the output layer).
    """

    # (kernel_size, stride) of the three convolutions, in order.
    _CONV_SPECS = ((8, 4), (4, 2), (3, 1))
    # Number of output channels of the last convolution.
    _CONV_OUT_CHANNELS = 64

    def __init__(self, state_shape, action_count):
        super(DQN, self).__init__()
        self.state_shape = state_shape
        self.action_count = action_count
        spatial = tuple(state_shape[1:3])
        if len(spatial) == 2:
            self.net = DQN.construct_net(state_shape[0], action_count, *spatial)
        else:
            # No spatial size given: keep the historical 84x84 default.
            self.net = DQN.construct_net(state_shape[0], action_count)

    @classmethod
    def construct_net(cls, channels, actions, height=84, width=84):
        """Build the conv + MLP stack.

        Args:
            channels: number of input channels.
            actions: number of output Q-values.
            height: input height in pixels (default 84).
            width: input width in pixels (default 84).

        Returns:
            An ``nn.Sequential`` taking ``[N, channels, height, width]`` input.

        Raises:
            ValueError: if the input is too small for the convolution stack.
        """
        out_h, out_w = height, width
        for kernel_size, stride in cls._CONV_SPECS:
            out_h = conv2d_size_out(out_h, kernel_size, stride)
            out_w = conv2d_size_out(out_w, kernel_size, stride)
        if out_h < 1 or out_w < 1:
            raise ValueError(
                f'Input of size {height}x{width} is too small for the DQN convolution stack'
            )
        # 84x84 input gives 64 * 7 * 7 = 3136 features.
        flat_features = cls._CONV_OUT_CHANNELS * out_h * out_w

        layers = [
            nn.Conv2d(in_channels=channels, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=cls._CONV_OUT_CHANNELS, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, actions),
        ]
        return nn.Sequential(*layers)

    def forward(self, X):
        """Return the Q-values for a batch of states ``X``."""
        return self.net(X)

    def act(self, X):
        """Return the index of the highest Q-value for every state in the batch."""
        return torch.argmax(self.net(X), dim=1)


In [None]:
import numpy as np
import time, datetime
import matplotlib.pyplot as plt


class MetricLogger:
    """Collect per-step training metrics and periodically log/plot moving averages.

    A timestamped sub-directory of ``save_dir`` (relative to the current
    working directory) receives ``log.csv`` and one plot per metric.
    """

    def __init__(self, save_dir):
        """Create the output directory, write the CSV header and reset all metrics.

        Args:
            save_dir: directory (``str`` or ``Path``) under which the
                timestamped run directory is created.
        """
        dt = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M')
        save_dir = Path.joinpath(Path.cwd(), save_dir, dt)
        save_dir.mkdir(parents=True, exist_ok=True)

        self.save_log = save_dir / "log.csv"
        with open(self.save_log, "w") as f:
            # One header per value written by record(), in the same order.
            f.write('Episode,Step,Epsilon,Mean Reward,Mean Length,Mean Loss,Mean Q Value,Time Delta,Time\n')
        self.ep_rewards_plot = save_dir / "reward_plot.jpg"
        self.ep_lengths_plot = save_dir / "length_plot.jpg"
        self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
        self.ep_avg_qs_plot = save_dir / "q_plot.jpg"

        # History metrics, one entry per finished episode
        self.ep_rewards = []
        self.ep_lengths = []
        self.ep_avg_losses = []
        self.ep_avg_qs = []

        # Moving averages, added for every call to record()
        self.moving_avg_ep_rewards = []
        self.moving_avg_ep_lengths = []
        self.moving_avg_ep_avg_losses = []
        self.moving_avg_ep_avg_qs = []

        # Current episode metric
        self.init_episode()

        # Timing
        self.record_time = time.time()

    def log_step(self, reward, loss, q):
        """Accumulate the metrics of one environment step.

        Args:
            reward: reward obtained on this step.
            loss: training loss of this step, or ``None`` if no update ran.
            q: mean Q estimate of this step; only read when ``loss`` is given.
        """
        self.curr_ep_reward += reward
        self.curr_ep_length += 1
        # Test against None so a legitimate loss of exactly 0.0 still counts.
        if loss is not None:
            self.curr_ep_loss += loss
            self.curr_ep_q += q
            self.curr_ep_loss_length += 1

    def log_episode(self):
        """Mark the end of an episode: store its totals and start a fresh one."""
        self.ep_rewards.append(self.curr_ep_reward)
        self.ep_lengths.append(self.curr_ep_length)
        if self.curr_ep_loss_length == 0:
            # No learning step happened during this episode.
            ep_avg_loss = 0
            ep_avg_q = 0
        else:
            ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
            ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
        self.ep_avg_losses.append(ep_avg_loss)
        self.ep_avg_qs.append(ep_avg_q)

        self.init_episode()

    def init_episode(self):
        """Reset the accumulators of the episode currently in progress."""
        self.curr_ep_reward = 0.0
        self.curr_ep_length = 0
        self.curr_ep_loss = 0.0
        self.curr_ep_q = 0.0
        self.curr_ep_loss_length = 0

    def record(self, episode, epsilon, step):
        """Print, append to the CSV log and plot the means over the last 100 episodes.

        Args:
            episode: index of the current episode.
            epsilon: current exploration rate.
            step: total number of agent steps so far.
        """
        mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
        mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
        mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
        mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
        self.moving_avg_ep_rewards.append(mean_ep_reward)
        self.moving_avg_ep_lengths.append(mean_ep_length)
        self.moving_avg_ep_avg_losses.append(mean_ep_loss)
        self.moving_avg_ep_avg_qs.append(mean_ep_q)

        # Wall-clock time elapsed since the previous record() (or construction).
        last_record_time = self.record_time
        self.record_time = time.time()
        time_since_last_record = np.round(self.record_time - last_record_time, 3)

        print(
            f"Episode {episode} - "
            f"Step {step} - "
            f"Epsilon {epsilon} - "
            f"Mean Reward {mean_ep_reward} - "
            f"Mean Length {mean_ep_length} - "
            f"Mean Loss {mean_ep_loss} - "
            f"Mean Q Value {mean_ep_q} - "
            f"Time Delta {time_since_last_record} - "
            f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
        )

        with open(self.save_log, "a") as f:
            f.write(f'{episode},'
                f'{step},'
                f'{epsilon:0.4f},'
                f'{mean_ep_reward:0.3f},'
                f'{mean_ep_length:0.3f},'
                f'{mean_ep_loss:0.3f},'
                f'{mean_ep_q:0.3f},'
                f'{time_since_last_record:0.3f},'
                f'{datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")}\n'
            )

        # Overwrite each plot file with the full moving-average history.
        for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
            plt.plot(getattr(self, f"moving_avg_{metric}"))
            plt.savefig(getattr(self, f"{metric}_plot"))
            plt.clf()

In [None]:
def make_env():
    """Build the Breakout environment with the full preprocessing wrapper stack.

    Wrappers are applied innermost first: frame skipping (4), grayscale,
    resize to 84, stacking of 4 frames, and finally ``AutoFire``.
    """
    base = gym.make('Breakout-v0')
    preprocessed = ResizeObservation(
        GrayScaleObservation(SkipFrame(base, skip=4)),
        shape=84,
    )
    return AutoFire(FrameStack(preprocessed, num_stack=4))

In [None]:
# Training schedule, all measured in agent steps (Agent.step).
SAVE_EVERY = 1e4  # Agent.learn() saves a checkpoint whenever step % SAVE_EVERY == 0
BURNIN = 1e4  # no gradient updates while step < BURNIN
LEARN_EVERY = 3  # after burn-in, update the online net only on every LEARN_EVERY-th step
SYNC_EVERY = 100  # copy the online weights into the target net every SYNC_EVERY steps
SAVE_DIR = 'models'  # checkpoint directory (relative to cwd); also the root for MetricLogger output

# Keyword arguments forwarded to Agent(**hyper).
hyper = {
    'REPLAY_MEMORY_SIZE': 12000,  # capacity of the ReplayMemory deque
    'BATCH_SIZE': 32,  # transitions sampled per update
    'GAMMA': 0.9,  # discount factor used in Agent.td_target
    'LEARNING_RATE': 1e-3,  # lr passed to torch.optim.Adam
    'MOMENTUM': 0.95,  # NOTE(review): not read by Agent in the code visible here
    'EPSILON_START': 1,  # initial exploration rate
    'EPSILON_END': 0.2,  # floor of the exploration rate
    'EPSILON_DECAY': 0.999995,  # multiplicative decay applied on every Agent.act() call
}

In [None]:
class Agent:
    """Double-DQN agent: online network, periodically synced target network, replay memory.

    Args:
        state_shape: shape of a single state, e.g. ``(channels, height, width)``.
        action_count: number of discrete actions.
        device: ``torch.device`` holding the networks and the replay tensors.
        path: optional checkpoint (a whole pickled ``DQN`` module) to start from.
        epsilon: if true, use the decaying schedule from ``hyper``; otherwise
            keep the exploration rate fixed at ``EPSILON_END``.
        **hyper: must provide ``EPSILON_START``, ``EPSILON_END``,
            ``EPSILON_DECAY``, ``REPLAY_MEMORY_SIZE``, ``BATCH_SIZE``,
            ``GAMMA`` and ``LEARNING_RATE``.
    """

    def __init__(self, state_shape, action_count, device, *, path=None, epsilon=True, **hyper):
        self.device = device
        self.action_count = action_count
        self.state_shape = state_shape

        # Construct target, and online net
        if path is None:
            self._target = DQN(state_shape, action_count).to(device)
        else:
            # SECURITY: torch.load unpickles arbitrary objects; only load
            # checkpoints from a trusted source.
            # map_location lets a checkpoint saved on another device
            # (e.g. CUDA) be restored on the device in use here.
            self._target = torch.load(path, map_location=device).to(device)

        if epsilon:
            self.epsilon = Epsilon(start=hyper['EPSILON_START'],
                end=hyper['EPSILON_END'],
                decay=hyper['EPSILON_DECAY'])
        else:
            # Constant exploration rate: start == end and a decay of 1.
            self.epsilon = Epsilon(hyper['EPSILON_END'], hyper['EPSILON_END'], 1)

        self._online = copy.deepcopy(self._target).to(device)
        # Only the online network is trained; the target is a frozen copy.
        for p in self._online.parameters():
            p.requires_grad = True
        for p in self._target.parameters():
            p.requires_grad = False

        self.step = 0
        self.memory = ReplayMemory(hyper['REPLAY_MEMORY_SIZE'])
        self.batch_size = hyper['BATCH_SIZE']
        self.gamma = hyper['GAMMA']
        self.optimizer = torch.optim.Adam(self._online.parameters(), lr=hyper['LEARNING_RATE'])
        self.loss_fn = torch.nn.SmoothL1Loss()

    def save(self):
        """Pickle the target network to ``<cwd>/SAVE_DIR/dqn<N>.chkpt``."""
        filename = f"dqn{int(self.step // SAVE_EVERY)}.chkpt"
        save_path = Path.joinpath(Path.cwd(), SAVE_DIR, filename)
        # Make sure the checkpoint directory exists before writing.
        save_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(self._target, save_path)
        print(f'DQN saved to {save_path} at step {self.step}')

    def reset_epsilon(self):
        """Restart the exploration schedule from its initial value."""
        self.epsilon.reset()

    def learn(self):
        """Run the periodic sync/save hooks and, when due, one gradient update.

        Returns:
            ``(mean_q, loss)`` as Python floats when an update was performed,
            otherwise ``(None, None)``.
        """
        if self.step % SYNC_EVERY == 0:
            self.sync_Q_target()
        if self.step % SAVE_EVERY == 0:
            self.save()
        if self.step < BURNIN:
            return None, None
        if self.step % LEARN_EVERY != 0:
            return None, None
        if len(self.memory) < self.batch_size:
            # Not enough stored transitions to draw a full batch yet.
            return None, None

        state, action, next_state, reward, done = self.recall()
        td_est = self.td_estimate(state, action)
        td_tgt = self.td_target(reward, next_state, done)
        loss = self.update_Q_online(td_est, td_tgt)
        return td_est.mean().item(), loss

    def update_Q_online(self, td_estimate, td_target):
        """Take one optimizer step on the online net and return the loss value."""
        loss = self.loss_fn(td_estimate, td_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_Q_target(self):
        """Copy the online network weights into the target network."""
        self._target.load_state_dict(self._online.state_dict())

    def td_estimate(self, state, action):
        """Return the online Q-value of the action taken in each state of the batch."""
        action = action.view(-1, 1)
        current_Q = self.online(state)
        # Select one Q-value per row; works for any batch size.
        return current_Q.gather(1, action).squeeze(1)

    @torch.no_grad()
    def td_target(self, reward, next_state, done):
        """Return the Double-DQN target ``r + (1 - done) * gamma * Q_target(s', a*)``.

        ``a*`` is the greedy action under the online network; its value is
        read from the target network.
        """
        next_state_Q = self.online(next_state)
        best_action = torch.argmax(next_state_Q, dim=1, keepdim=True)
        next_Q = self.target(next_state).gather(1, best_action).squeeze(1)
        return (reward + (1 - done.float()) * self.gamma * next_Q).float()

    def remember(self, state, action, next_state, reward, done):
        """Convert one transition to tensors on ``self.device`` and store it."""
        state = torch.tensor(state.__array__()).to(self.device)
        action = torch.tensor([action]).to(self.device)
        next_state = torch.tensor(next_state.__array__()).to(self.device)
        reward = torch.tensor([reward]).to(self.device)
        done = torch.tensor([done]).to(self.device)

        self.memory.push(state, action, next_state, reward, done)

    def recall(self):
        """Sample a batch and return ``(state, action, next_state, reward, done)``.

        ``action``, ``reward`` and ``done`` are returned with shape ``[batch]``.
        """
        batch = self.memory.sample(self.batch_size)
        # Unpack in the storage order of the transitions:
        # state, action, next_state, reward, done.
        state, action, next_state, reward, done = map(torch.stack, zip(*batch))
        # Drop only the trailing singleton axis added in remember(), so the
        # batch axis survives even when the batch size is 1.
        return state, action.squeeze(-1), next_state, reward.squeeze(-1), done.squeeze(-1)

    def target(self, X):
        """Forward ``X`` through the target network."""
        return self._target(X)

    def online(self, X):
        """Forward ``X`` through the online network."""
        return self._online(X)

    def act(self, state):
        '''
        Given a state, choose an epsilon-greedy action and update value of step.
        '''
        self.step += 1
        # next() also advances the exploration schedule by one step.
        if np.random.rand() < next(self.epsilon):
            return self.explore()
        else:
            return self.exploit(state)

    def explore(self):
        """Return a uniformly random action index."""
        return np.random.randint(self.action_count)

    def exploit(self, state):
        """Return the greedy action for ``state``.

        NOTE(review): the greedy action is taken from the target network, not
        the online one -- confirm this is intended.
        """
        state = torch.tensor(state.__array__()).to(self.device).view(-1, *self.state_shape)
        with torch.no_grad():
            return self._target.act(state).item()

    @property
    def exploration_rate(self):
        """Current value of the exploration schedule."""
        return self.epsilon.value

In [None]:
# Build the wrapped environment once and read the sizes the agent needs from it.
env = make_env()
ENV_SHAPE = env.observation_space.shape  # shape of one (stacked) observation
ACTION_COUNT = env.action_space.n  # number of discrete actions exposed to the agent

In [None]:
# Resume from a saved checkpoint: Agent hands `path` to torch.load, so this
# file must already exist (pass path=None to start from a freshly built DQN).
agent = Agent(ENV_SHAPE, ACTION_COUNT, device, path='models/dqn1.chkpt', **hyper)

In [None]:
save_dir = Path(SAVE_DIR)
episodes = 4000

# Train until interrupted: every epoch restarts the exploration schedule,
# opens a fresh logger and plays `episodes` episodes.
for i in itertools.count():
    agent.reset_epsilon()
    logger = MetricLogger(save_dir)
    print(f'Running Epoch {i+1}')
    for e in range(episodes):
        state = env.reset()
        done = False
        while not done:
            # Act, store the transition, then (maybe) learn from replay.
            action = agent.act(state)
            next_state, reward, done, info = env.step(action)
            agent.remember(state, action, next_state, reward, done)
            q, loss = agent.learn()
            logger.log_step(reward, loss, q)
            state = next_state
        logger.log_episode()
        # Report moving averages every 20th episode.
        if not e % 20:
            logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.step)