In [1]:
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
import numpy as np
import random
import pandas as pd
import copy
import gym
from tqdm import tqdm
import collections
import cv2
import gym.spaces

In [2]:
class FireResetEnv(gym.Wrapper):
    def __init__(self, env=None):
        super(FireResetEnv, self).__init__(env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3
    
    def step(self, action):
        return self.env.step(action)

    def reset(self):
        self.env.reset()
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset()
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset()
        return obs

class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env=None, skip=4):
        """Return only every 'skip'-th frame"""
        super(MaxAndSkipEnv, self).__init__(env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = collections.deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        total_reward = 0.0
        done = None
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)
        return max_frame, total_reward, done, info

    def _reset(self):
        self._obs_buffer.clear()
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs

class ProcessFrame84(gym.ObservationWrapper):
    def __init__(self, env=None):
        super(ProcessFrame84, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def observation(self, obs):
        return ProcessFrame84.process(obs)

    @staticmethod
    def process(frame):
        if frame.size == 210 * 160 * 3:
            img = np.reshape(frame, [210, 160, 3]).astype(np.float32)
        elif frame.size == 250 * 160 * 3:
            img = np.reshape(frame, [250, 160, 3]).astype(np.float32)
        else:
            assert False, "Unknown resolution."
        img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
        resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA)
        x_t = resized_screen[18:102, :]
        x_t = np.reshape(x_t, [84, 84, 1])
        return x_t.astype(np.uint8)

class BufferWrapper(gym.ObservationWrapper):
    def __init__(self, env, n_steps, dtype=np.float32):
        super(BufferWrapper, self).__init__(env)
        self.dtype = dtype
        old_space = env.observation_space
        self.observation_space = gym.spaces.Box(old_space.low.repeat(n_steps, axis=0),
        old_space.high.repeat(n_steps, axis=0), dtype=dtype)

    def reset(self):
        self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype)
        return self.observation(self.env.reset())

    def observation(self, observation):
        self.buffer[:-1] = self.buffer[1:]
        self.buffer[-1] = observation
        return self.buffer

class ImageToPyTorch(gym.ObservationWrapper):
    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        old_shape = self.observation_space.shape
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(old_shape[-1], old_shape[0], old_shape[1]), dtype=np.float32)

    def observation(self, observation):
        return np.moveaxis(observation, 2, 0)

class ScaledFloatFrame(gym.ObservationWrapper):
    def observation(self, obs):
        return np.array(obs).astype(np.float32) / 255.0
    
def make_env(env_name):
    env = gym.make(env_name)
    env = MaxAndSkipEnv(env)
    env = FireResetEnv(env)
    env = ProcessFrame84(env)
    env = ImageToPyTorch(env)
    env = BufferWrapper(env, 4)
    return ScaledFloatFrame(env)

env = make_env('PongNoFrameskip-v4')

In [3]:
class DQN(nn.Module):
    def __init__(self, input_shape, n_actions):
        super(DQN, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        conv_out_size = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(conv_out_size, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions)
        )
        
    def _get_conv_out(self, shape):
        o = self.conv(torch.zeros(1, *shape))
        return int(np.prod(o.size()))

    def forward(self, x):
        conv_out = self.conv(x)
        conv_out = conv_out.view(x.size()[0], -1)
        return self.fc(conv_out)

In [8]:
def main(nEpisode=100, GAMMA=0.99, EPSILON_0=1, EPSILON_FINAL=0.02, DECAYING_RATE=10**(-5), storeQ=1000,
         MAX_ITER=200000, BATCH_SIZE = 32, REPLAY_SIZE = 10000, REPLAY_START_SIZE=10000, LEARNING_RATE=1e-4, gpu=False):
    
    device = torch.device("cuda" if gpu else "cpu")
    Q = DQN(env.observation_space.shape, env.action_space.n).to(device)
    QHat = DQN(env.observation_space.shape, env.action_space.n).to(device)
    epsilon = EPSILON_0
    buffer = collections.deque(maxlen=REPLAY_SIZE)
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(Q.parameters(), lr=LEARNING_RATE)
    total_rewards = []
    frame_id = 0
    
    # visualize with tensorboardX
    writer = SummaryWriter(comment="-" + "Pong")
    
    # best mean reward for the last 100 episodes
    best_mean_reward = None
    
    # main loop
    for step in range(nEpisode):
        obs = env.reset()
        total_reward = 0
        for _ in range(MAX_ITER):
            frame_id += 1
            epsilon = max(EPSILON_FINAL, EPSILON_0 - frame_id * DECAYING_RATE)
            if np.random.random() < epsilon:
                action = np.random.randint(env.action_space.n)
            else:
                obs1 = np.array([obs], copy=False)
                obs1 = torch.tensor(obs1).to(device)
                qVals = Q(obs1)
                _, actionV = torch.max(qVals, dim=1)
                action = int(actionV.item())
            obsNext, reward, done , _ = env.step(action)
            total_reward += reward
            buffer.append(collections.deque([obs, action, reward, done, obsNext]))
            obs = obsNext
            
            if len(buffer) >= REPLAY_START_SIZE:
                indices = np.random.choice(len(buffer), BATCH_SIZE, replace=False)
                print(indices)
                observations, actions, rewards, dones, observationsNext = zip(*[buffer[idx] for idx in indices])
                
                observations, actions, rewards, dones, observationsNext = np.array(observations), np.array(actions), np.array(rewards, dtype=np.float32), np.array(dones, dtype=np.uint8), np.array(observationsNext) 
                observationsV = torch.FloatTensor(observations).to(device)
                observationsNextV = torch.FloatTensor(observationsNext).to(device)
                actionsV = torch.tensor(actions).to(device)
                rewardsV = torch.tensor(rewards).to(device)
                doneMask = torch.ByteTensor(dones).to(device)                
                

                stateActionValues = Q(observationsV).gather(1, actionsV.unsqueeze(-1)).squeeze(-1)
                nextStateValues = QHat(observationsNextV).max(1)[0]
                nextStateValues[doneMask] = 0.0
                nextStateValues = nextStateValues.detach()

                expectedStateActionValues = nextStateValues * GAMMA + rewardsV
                optimizer.zero_grad()
                loss = loss_fn(stateActionValues, expectedStateActionValues)
                loss.backward()
                optimizer.step()
            
            if frame_id % storeQ == 0:
                QHat = copy.deepcopy(Q)
            
            if done:
                break
        
        # report progress
        if total_reward is not None:
            total_rewards.append(total_reward)
            mean_reward = np.mean(total_rewards[-100:])
            writer.add_scalar("epsilon", epsilon, frame_id)
            writer.add_scalar("reward_100", mean_reward, frame_id)
            writer.add_scalar("reward", total_reward, frame_id)
            
        # save model and update best_mean_reward 
        if best_mean_reward is None or best_mean_reward < mean_reward:
            torch.save({
                'game': step,
                'model_state_dict': Q.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss
                }, "DQN_saved_models\\Pong_best.tar")
            best_mean_reward = mean_reward
        print(step, mean_reward, frame_id)
    return Q, total_rewards

In [9]:
Q, totalRewards = main(500, REPLAY_START_SIZE=100, gpu=False)

[92 29  9 84 45 56 82 76 12 90  5 54 38 18 59 95 89  4 27 87 53  6 98 78
 43 97 20 88 69 23 48 21]
[34 32 51 71 42 93  3 77 29 94 39 22 38 27 28 98 58  7 96 35 90 41 31 30
 44 86 26 40  9 91 16 10]
[75 50 44 31 40  6 41 81 22  7 64 92 38  3 53 20 74 23 25 88 16 29 49 47
 30 21 58 97 71 63 76 43]
[ 0 15 96  2 22 99 65 45 81 46  8 25 32  1 12  3 69 61 53 42 28 19 44 33
 20 97 50 90 76 37 31 26]
[  1  73  34  61   2  35  74  72  40 103  99  94  89  82   8  90  45  96
  27  91   5   0  66  76  43  68  75   9  52  28   4  39]
[  8  70  21  16  14  72  39 101   9  29  25  93  65  27  85  73  30  35
  76 103  98  22  33  79  24  58  91  19  78  51  68  49]
[ 57  40  85  83  60  47  59  20  67  30  24  66  18 105 103  44  12  74
  46  69  99  86  96  10  28   6  68  22  64 104  25  14]
[40 12 28 68 77 96 89 53 30 44 14 72 56 46 27 29 34 24 11 95 50 15 49 85
  5 47 23 66 42 16 83 33]
[ 74  82  35 102 107  89   6  36  37  71  22  80  65  54  49  16  97  94
  96   0  44  68  62 100  57  29  48  8

KeyboardInterrupt: 