In [11]:
import cv2
import gym
import gym.spaces

import argparse
import time
import numpy as np
import collections

import torch
import torch.nn as nn
import torch.optim as optim

from tensorboardX import SummaryWriter

# Wrappers

Adapted from the OpenAI Baselines Atari wrappers: https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py

In [15]:
class FireResetEnv(gym.Wrapper):
    """Press FIRE (action 1) and then action 2 on every reset so the game starts."""

    def __init__(self, env=None):
        """For environments where the user needs to press FIRE for the game to start.

        Requires action 1 to be 'FIRE' and at least three actions, because
        reset() presses actions 1 and 2.
        """
        super(FireResetEnv, self).__init__(env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def step(self, action):
        return self.env.step(action)

    def reset(self, **kwargs):
        """Reset the inner env, then press actions 1 (FIRE) and 2.

        If either press ends the episode, the inner env is reset again. In that
        case the returned observation is the one from that fresh reset, not the
        terminal observation of the discarded episode.
        """
        obs = self.env.reset(**kwargs)
        for action in (1, 2):
            obs, _, done, _ = self.env.step(action)
            if done:
                obs = self.env.reset(**kwargs)
        return obs


class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env=None, skip=4):
        """Return only every `skip`-th frame, max-pooled over the last two raw frames.

        Each step() repeats the action `skip` times (stopping early if the
        episode ends), sums the rewards, and returns the element-wise max of the
        two most recent raw observations.

        Raises:
            ValueError: if `skip` < 1, since step() would then never obtain an
                observation to return.
        """
        if skip < 1:
            raise ValueError("skip must be >= 1, got %r" % (skip,))
        super(MaxAndSkipEnv, self).__init__(env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = collections.deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        total_reward = 0.0
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break
        # The deque keeps only the last two raw frames; max over them.
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)
        return max_frame, total_reward, done, info

    def reset(self, **kwargs):
        """Clear past frame buffer and init. to first obs. from inner env."""
        self._obs_buffer.clear()
        obs = self.env.reset(**kwargs)
        self._obs_buffer.append(obs)
        return obs


class ProcessFrame84(gym.ObservationWrapper):
    """Convert raw RGB Atari frames into 84x84x1 uint8 grayscale images."""

    def __init__(self, env=None):
        super(ProcessFrame84, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def observation(self, obs):
        return ProcessFrame84.process(obs)

    @staticmethod
    def process(frame):
        """Grayscale, downscale and crop a single frame.

        Accepts frames with 210*160*3 or 250*160*3 elements, which are reshaped
        to (rows, 160, 3). Returns a uint8 array of shape (84, 84, 1).

        Raises:
            ValueError: if the frame has any other number of elements.
        """
        if frame.size == 210 * 160 * 3:
            img = np.reshape(frame, [210, 160, 3]).astype(np.float32)
        elif frame.size == 250 * 160 * 3:
            img = np.reshape(frame, [250, 160, 3]).astype(np.float32)
        else:
            # Explicit raise instead of `assert False`, so the check survives `python -O`.
            raise ValueError("Unknown resolution.")
        # Weighted RGB -> luma (ITU-R BT.601 coefficients).
        img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
        # cv2.resize takes dsize as (width, height): 84 columns x 110 rows.
        resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA)
        # Keep rows 18..101 to obtain an 84x84 crop.
        x_t = resized_screen[18:102, :]
        x_t = np.reshape(x_t, [84, 84, 1])
        return x_t.astype(np.uint8)


class ImageToPyTorch(gym.ObservationWrapper):
    """Reorder observations from channel-last (H, W, C) to channel-first (C, H, W).

    NOTE(review): the declared space is float32 in [0.0, 1.0], but observation()
    only moves an axis, so the values keep the inner env's dtype and range (uint8
    0-255 frames from ProcessFrame84 in make_env). The final make_env space is
    still consistent: BufferWrapper copies these bounds, and ScaledFloatFrame,
    which does the /255 scaling, does not set its own observation_space.
    """

    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        src_shape = self.observation_space.shape
        height, width = src_shape[0], src_shape[1]
        self.observation_space = gym.spaces.Box(
            low=0.0,
            high=1.0,
            shape=(src_shape[-1], height, width),
            dtype=np.float32,
        )

    def observation(self, observation):
        # Axis 2 (channels) moves to the front; the remaining axes keep their order.
        return np.moveaxis(observation, 2, 0)


class ScaledFloatFrame(gym.ObservationWrapper):
    """Cast observations to float32 and divide by 255 (pixel values -> [0, 1])."""

    def observation(self, obs):
        as_float = np.asarray(obs, dtype=np.float32)
        return as_float / 255.0


class BufferWrapper(gym.ObservationWrapper):
    """Stack the `n_steps` most recent observations along the first axis.

    The newest observation sits at index -1; after reset() the older slots are
    zeros. NOTE(review): each observation is written into a single slot
    (`buffer[-1]`), so this assumes the inner observation has a leading axis of
    size 1 (e.g. the (1, 84, 84) frames ImageToPyTorch yields in make_env).
    """

    def __init__(self, env, n_steps, dtype=np.float32):
        super(BufferWrapper, self).__init__(env)
        self.dtype = dtype
        old_space = env.observation_space
        self.observation_space = gym.spaces.Box(old_space.low.repeat(n_steps, axis=0),
                                                old_space.high.repeat(n_steps, axis=0), dtype=dtype)
        # Allocate up front so observation() does not fail if used before reset().
        self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype)

    def reset(self, **kwargs):
        """Zero the frame stack and push the first observation of the new episode."""
        self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype)
        return self.observation(self.env.reset(**kwargs))

    def observation(self, observation):
        # Shift every slot one position towards the front, dropping the oldest.
        self.buffer[:-1] = self.buffer[1:]
        self.buffer[-1] = observation
        # Return a copy: the buffer is mutated in place on every step, so handing
        # it out directly would silently alter observations callers already stored.
        return self.buffer.copy()

In [13]:
class DQN(nn.Module):
    """Convolutional Q-network: a stack of frames in, one Q-value per action out.

    Args:
        input_shape: shape of one observation, channels first; input_shape[0]
            is the number of input channels of the first convolution.
        n_actions: number of discrete actions (size of the output layer).
    """

    def __init__(self, input_shape, n_actions):
        super(DQN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        conv_out_size = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )

    def _get_conv_out(self, shape):
        """Number of features the conv stack produces for a single input of `shape`."""
        # Shape probe only: no autograd graph is needed.
        with torch.no_grad():
            o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    # Backward-compatible alias for the original (misspelled) helper name.
    _get_conf_out = _get_conv_out

    def forward(self, x):
        """Map a batch of observations (batch, *input_shape) to Q-values (batch, n_actions)."""
        conv_out = self.conv(x).flatten(start_dim=1)
        return self.fc(conv_out)

In [14]:
# Environment, and the 100-game mean reward above which training stops.
DEFAULT_ENV_NAME = "PongNoFrameskip-v4"
MEAN_REWARD_BOUND = 19.5

# Optimisation / replay settings.
GAMMA = 0.99                # discount factor for the TD target
BATCH_SIZE = 32             # transitions per gradient step
REPLAY_SIZE = 10000         # capacity of the experience buffer
LEARNING_RATE = 1e-4
SYNC_TARGET_FRAMES = 1000   # copy online weights into the target net every N frames
REPLAY_START_SIZE = 10000   # transitions collected before any training happens

# Epsilon decays linearly from EPSILON_START to the EPSILON_FINAL floor
# over the first EPSILON_DECAY_LAST_FRAME frames.
EPSILON_DECAY_LAST_FRAME = 100_000
EPSILON_START = 1.0
EPSILON_FINAL = 0.02

# One environment transition as stored in the replay buffer.
Experience = collections.namedtuple(
    'Experience',
    field_names=('state', 'action', 'reward', 'done', 'new_state'),
)


class ExperienceBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    When `capacity` transitions are held, appending another evicts the oldest.
    """

    def __init__(self, capacity):
        self.buffer = collections.deque(maxlen=capacity)

    def __len__(self):
        return len(self.buffer)

    def append(self, experience):
        """Store one transition (evicting the oldest one if the buffer is full)."""
        self.buffer.append(experience)

    def sample(self, batch_size):
        """Draw `batch_size` distinct transitions uniformly at random.

        Returns five index-aligned numpy arrays: states, actions, rewards
        (float32), dones (uint8) and next states.
        """
        picks = np.random.choice(len(self.buffer), batch_size, replace=False)
        chosen = [self.buffer[i] for i in picks]
        states, actions, rewards, dones, next_states = zip(*chosen)
        return (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.array(dones, dtype=np.uint8),
            np.array(next_states),
        )


class Agent:
    """Plays an environment with an epsilon-greedy policy and records each transition.

    Args:
        env: environment using the old Gym API (`reset()` returns an observation,
            `step()` returns `(obs, reward, done, info)`).
        exp_buffer: object with an `append` method; receives one `Experience`
            per step.
    """

    def __init__(self, env, exp_buffer):
        self.env = env
        self.exp_buffer = exp_buffer
        self._reset()

    def _reset(self):
        """Start a new episode on this agent's env and zero the running return."""
        # Must use self.env (not a module-level `env`) so the agent drives the
        # environment it was constructed with.
        self.state = self.env.reset()
        self.total_reward = 0.0

    def play_step(self, net, epsilon=0.0, device="cpu"):
        """Take one environment step and append the transition to the buffer.

        With probability `epsilon` a random action is sampled from the env's
        action space; otherwise the arg-max action of `net` is taken.

        Args:
            net: callable mapping a (1, *state_shape) tensor to Q-values of
                shape (1, n_actions). Not called on exploratory steps.
            epsilon: exploration probability.
            device: device the state tensor is moved to before calling `net`.

        Returns:
            The episode's total reward if this step ended the episode (the env
            is then reset), otherwise None.
        """
        done_reward = None

        if np.random.random() < epsilon:
            action = self.env.action_space.sample()
        else:
            # Add a batch axis. np.asarray avoids `np.array(..., copy=False)`,
            # which raises under NumPy 2 whenever a copy is unavoidable (always
            # the case when wrapping the state in a list).
            state_a = np.asarray([self.state])
            state_v = torch.tensor(state_a).to(device)
            # Action selection needs no gradients.
            with torch.no_grad():
                q_vals_v = net(state_v)
            _, act_v = torch.max(q_vals_v, dim=1)
            action = int(act_v.item())

        # do step in the environment
        new_state, reward, is_done, _ = self.env.step(action)
        self.total_reward += reward

        exp = Experience(self.state, action, reward, is_done, new_state)
        self.exp_buffer.append(exp)
        self.state = new_state
        if is_done:
            done_reward = self.total_reward
            self._reset()
        return done_reward


def calc_loss(batch, net, tgt_net, device="cpu", gamma=None):
    """Mean-squared one-step TD error for a batch of transitions.

    Target: reward + gamma * max_a tgt_net(next_state)[a], with the bootstrap
    term zeroed for transitions whose `done` flag is set.

    Args:
        batch: (states, actions, rewards, dones, next_states) arrays, e.g. as
            returned by ExperienceBuffer.sample.
        net: online Q-network; gradients flow through its output.
        tgt_net: target Q-network; evaluated without gradient tracking.
        device: device the batch tensors are created on.
        gamma: discount factor; defaults to the module-level GAMMA.

    Returns:
        Scalar loss tensor.
    """
    if gamma is None:
        gamma = GAMMA
    states, actions, rewards, dones, next_states = batch

    states_v = torch.as_tensor(states, device=device)
    next_states_v = torch.as_tensor(next_states, device=device)
    # gather() requires int64 indices; action arrays may be int32 on some platforms.
    actions_v = torch.as_tensor(actions, dtype=torch.int64, device=device)
    rewards_v = torch.as_tensor(rewards, dtype=torch.float32, device=device)
    # Boolean mask: indexing with uint8 (ByteTensor) masks is deprecated.
    done_mask = torch.as_tensor(dones, dtype=torch.bool, device=device)

    state_action_values = net(states_v).gather(1, actions_v.unsqueeze(-1)).squeeze(-1)
    # The target side is a constant for this update: no autograd graph needed.
    with torch.no_grad():
        next_state_values = tgt_net(next_states_v).max(1)[0]
        next_state_values[done_mask] = 0.0

    expected_state_action_values = next_state_values * gamma + rewards_v
    return nn.MSELoss()(state_action_values, expected_state_action_values)

In [17]:
def make_env(env_name):
    """Create `env_name` wrapped in the DQN preprocessing stack.

    Wrappers are applied innermost first: frame skip with max-pooling, FIRE on
    reset, 84x84 grayscale, channel-first layout, a 4-frame stack, and finally
    scaling to float32 in [0, 1].
    """
    env = gym.make(env_name)
    for wrapper_cls in (MaxAndSkipEnv, FireResetEnv, ProcessFrame84, ImageToPyTorch):
        env = wrapper_cls(env)
    stacked = BufferWrapper(env, 4)
    return ScaledFloatFrame(stacked)

In [21]:
# Train on the GPU when available; fall back to the CPU so this cell also runs
# on machines without CUDA (a hard-coded "cuda" device fails there).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

env = make_env(DEFAULT_ENV_NAME)

net = DQN(env.observation_space.shape, env.action_space.n).to(device)
tgt_net = DQN(env.observation_space.shape, env.action_space.n).to(device)
# Start the target net as an exact copy of the online net, so training never
# bootstraps from an unsynced random target regardless of how REPLAY_START_SIZE
# and SYNC_TARGET_FRAMES relate.
tgt_net.load_state_dict(net.state_dict())
writer = SummaryWriter(comment="-" + DEFAULT_ENV_NAME)
print(net)

buffer = ExperienceBuffer(REPLAY_SIZE)
agent = Agent(env, buffer)
epsilon = EPSILON_START

optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
total_rewards = []
frame_idx = 0
ts_frame = 0
ts = time.time()
best_mean_reward = None

# try/finally so the TensorBoard writer is flushed and closed even when
# training is interrupted (e.g. a KeyboardInterrupt from stopping the kernel).
try:
    while True:
        frame_idx += 1
        # Linear decay from EPSILON_START down to the EPSILON_FINAL floor.
        epsilon = max(EPSILON_FINAL, EPSILON_START - frame_idx / EPSILON_DECAY_LAST_FRAME)

        reward = agent.play_step(net, epsilon, device=device)
        if reward is not None:
            # An episode just finished: log progress and checkpoint improvements.
            total_rewards.append(reward)
            speed = (frame_idx - ts_frame) / (time.time() - ts)
            ts_frame = frame_idx
            ts = time.time()
            mean_reward = np.mean(total_rewards[-100:])
            print("%d: done %d games, mean reward %.3f, eps %.2f, speed %.2f f/s" % (
                frame_idx, len(total_rewards), mean_reward, epsilon,
                speed
            ))
            writer.add_scalar("epsilon", epsilon, frame_idx)
            writer.add_scalar("speed", speed, frame_idx)
            writer.add_scalar("reward_100", mean_reward, frame_idx)
            writer.add_scalar("reward", reward, frame_idx)
            if best_mean_reward is None or best_mean_reward < mean_reward:
                torch.save(net.state_dict(), DEFAULT_ENV_NAME + "-best.dat")
                if best_mean_reward is not None:
                    print("Best mean reward updated %.3f -> %.3f, model saved" % (best_mean_reward, mean_reward))
                best_mean_reward = mean_reward
            if mean_reward > MEAN_REWARD_BOUND:
                print("Solved in %d frames!" % frame_idx)
                break

        # Only collect experience until the replay buffer is warm.
        if len(buffer) < REPLAY_START_SIZE:
            continue

        if frame_idx % SYNC_TARGET_FRAMES == 0:
            tgt_net.load_state_dict(net.state_dict())

        optimizer.zero_grad()
        batch = buffer.sample(BATCH_SIZE)
        loss_t = calc_loss(batch, net, tgt_net, device=device)
        loss_t.backward()
        optimizer.step()
finally:
    writer.close()

DQN(
  (conv): Sequential(
    (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ReLU()
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (5): ReLU()
  )
  (fc): Sequential(
    (0): Linear(in_features=3136, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=6, bias=True)
  )
)
851: done 1 games, mean reward -21.000, eps 0.99, speed 642.39 f/s
1806: done 2 games, mean reward -20.000, eps 0.98, speed 617.86 f/s
Best mean reward updated -21.000 -> -20.000, model saved
2744: done 3 games, mean reward -20.333, eps 0.97, speed 623.05 f/s
3654: done 4 games, mean reward -20.500, eps 0.96, speed 628.08 f/s
4534: done 5 games, mean reward -20.600, eps 0.95, speed 624.29 f/s
5352: done 6 games, mean reward -20.667, eps 0.95, speed 623.61 f/s
6114: done 7 games, mean reward -20.714, eps 0.94, speed 612.55 f/s
6982: done 8 games, mean reward -20.625, eps 0.9

112887: done 89 games, mean reward -18.798, eps 0.02, speed 80.30 f/s
Best mean reward updated -18.841 -> -18.798, model saved
115682: done 90 games, mean reward -18.678, eps 0.02, speed 80.76 f/s
Best mean reward updated -18.798 -> -18.678, model saved
118978: done 91 games, mean reward -18.527, eps 0.02, speed 80.86 f/s
Best mean reward updated -18.678 -> -18.527, model saved
121778: done 92 games, mean reward -18.413, eps 0.02, speed 80.95 f/s
Best mean reward updated -18.527 -> -18.413, model saved
124455: done 93 games, mean reward -18.323, eps 0.02, speed 80.76 f/s
Best mean reward updated -18.413 -> -18.323, model saved
127152: done 94 games, mean reward -18.255, eps 0.02, speed 80.50 f/s
Best mean reward updated -18.323 -> -18.255, model saved
130017: done 95 games, mean reward -18.158, eps 0.02, speed 80.74 f/s
Best mean reward updated -18.255 -> -18.158, model saved
133040: done 96 games, mean reward -18.042, eps 0.02, speed 80.81 f/s
Best mean reward updated -18.158 -> -18.0

309289: done 154 games, mean reward -3.300, eps 0.02, speed 80.86 f/s
Best mean reward updated -3.660 -> -3.300, model saved
311853: done 155 games, mean reward -2.970, eps 0.02, speed 80.18 f/s
Best mean reward updated -3.300 -> -2.970, model saved
314361: done 156 games, mean reward -2.630, eps 0.02, speed 80.41 f/s
Best mean reward updated -2.970 -> -2.630, model saved
316426: done 157 games, mean reward -2.270, eps 0.02, speed 80.59 f/s
Best mean reward updated -2.630 -> -2.270, model saved
318684: done 158 games, mean reward -1.910, eps 0.02, speed 80.78 f/s
Best mean reward updated -2.270 -> -1.910, model saved
321626: done 159 games, mean reward -1.630, eps 0.02, speed 80.87 f/s
Best mean reward updated -1.910 -> -1.630, model saved
324080: done 160 games, mean reward -1.280, eps 0.02, speed 80.64 f/s
Best mean reward updated -1.630 -> -1.280, model saved
326316: done 161 games, mean reward -0.930, eps 0.02, speed 81.10 f/s
Best mean reward updated -1.280 -> -0.930, model saved


450834: done 221 games, mean reward 15.250, eps 0.02, speed 79.17 f/s
Best mean reward updated 15.160 -> 15.250, model saved
452849: done 222 games, mean reward 15.350, eps 0.02, speed 79.19 f/s
Best mean reward updated 15.250 -> 15.350, model saved
454624: done 223 games, mean reward 15.430, eps 0.02, speed 79.58 f/s
Best mean reward updated 15.350 -> 15.430, model saved
456518: done 224 games, mean reward 15.400, eps 0.02, speed 81.13 f/s
458257: done 225 games, mean reward 15.610, eps 0.02, speed 81.32 f/s
Best mean reward updated 15.430 -> 15.610, model saved
460426: done 226 games, mean reward 15.600, eps 0.02, speed 81.60 f/s
462826: done 227 games, mean reward 15.720, eps 0.02, speed 80.65 f/s
Best mean reward updated 15.610 -> 15.720, model saved
464522: done 228 games, mean reward 15.920, eps 0.02, speed 80.18 f/s
Best mean reward updated 15.720 -> 15.920, model saved
466364: done 229 games, mean reward 16.010, eps 0.02, speed 79.96 f/s
Best mean reward updated 15.920 -> 16.01

600198: done 300 games, mean reward 17.970, eps 0.02, speed 79.14 f/s
602159: done 301 games, mean reward 17.960, eps 0.02, speed 79.13 f/s
603963: done 302 games, mean reward 18.000, eps 0.02, speed 79.70 f/s
Best mean reward updated 17.980 -> 18.000, model saved
605792: done 303 games, mean reward 18.000, eps 0.02, speed 79.38 f/s
607584: done 304 games, mean reward 18.010, eps 0.02, speed 80.70 f/s
Best mean reward updated 18.000 -> 18.010, model saved
609364: done 305 games, mean reward 18.000, eps 0.02, speed 79.93 f/s
611308: done 306 games, mean reward 18.010, eps 0.02, speed 79.48 f/s
612982: done 307 games, mean reward 18.070, eps 0.02, speed 79.19 f/s
Best mean reward updated 18.010 -> 18.070, model saved
614917: done 308 games, mean reward 18.040, eps 0.02, speed 79.45 f/s
616561: done 309 games, mean reward 18.080, eps 0.02, speed 80.04 f/s
Best mean reward updated 18.070 -> 18.080, model saved
618323: done 310 games, mean reward 18.340, eps 0.02, speed 79.34 f/s
Best mean 

757325: done 388 games, mean reward 19.160, eps 0.02, speed 80.27 f/s
759312: done 389 games, mean reward 19.130, eps 0.02, speed 77.45 f/s
761236: done 390 games, mean reward 19.140, eps 0.02, speed 81.31 f/s
762941: done 391 games, mean reward 19.160, eps 0.02, speed 80.28 f/s
764630: done 392 games, mean reward 19.160, eps 0.02, speed 78.58 f/s
766944: done 393 games, mean reward 19.110, eps 0.02, speed 79.17 f/s
768689: done 394 games, mean reward 19.140, eps 0.02, speed 79.44 f/s
770636: done 395 games, mean reward 19.140, eps 0.02, speed 81.80 f/s
772433: done 396 games, mean reward 19.150, eps 0.02, speed 81.23 f/s
774205: done 397 games, mean reward 19.150, eps 0.02, speed 79.54 f/s
775893: done 398 games, mean reward 19.190, eps 0.02, speed 81.12 f/s
Best mean reward updated 19.160 -> 19.190, model saved
777804: done 399 games, mean reward 19.170, eps 0.02, speed 81.09 f/s
779439: done 400 games, mean reward 19.200, eps 0.02, speed 78.77 f/s
Best mean reward updated 19.190 -> 

939572: done 491 games, mean reward 19.480, eps 0.02, speed 79.40 f/s
941330: done 492 games, mean reward 19.480, eps 0.02, speed 79.15 f/s
943176: done 493 games, mean reward 19.530, eps 0.02, speed 78.88 f/s
Best mean reward updated 19.480 -> 19.530, model saved
Solved in 943176 frames!
