In [1]:
import torch
import torch.nn as nn
import gym
import numpy as np
from random import random
import argparse
from tqdm import tqdm
from collections import OrderedDict
# import visdom
import time
import cv2

## 经验回放类

In [2]:
class ReplayBuffer:
    """Fixed-size ring buffer of transitions that stacks recent frames on read.

    Every slot is a dict holding one frame under 'obs' (stored channel-first)
    together with the 'action', 'reward' and 'done' that resulted from acting on
    that frame. Observations handed out by the buffer are the last
    ``frame_history_len`` frames of the same episode concatenated along the
    channel axis; positions for which no frame is available are zero-filled.

    NOTE(review): once the buffer is full, ``_check_idx`` does not consult
    ``next_idx``, so a window that straddles the write position mixes the newest
    and the oldest frames.
    """

    def __init__(self, pool_size=1000000, frame_history_len=4):
        """
        :param pool_size: maximum number of transitions kept
        :param frame_history_len: number of frames stacked into one observation
        """
        # the total size of the Buffer
        self.pool_size = pool_size
        # the number of frames stacked into each observation
        self.frame_history_len = frame_history_len

        # stored memories (list of dicts), allocated lazily on the first store
        self.memories = None
        # shape of one stored frame (c, h, w), learned from the first frame
        self.obs_shape = None

        self.number_of_memories = 0
        # slot that the next stored frame will be written to
        self.next_idx = 0

    def _latest_done(self, indices):
        """Return the first index in ``indices`` whose slot is flagged 'done', else None.

        Callers pass the indices newest-first, so the result is the most recent
        episode boundary inside the scanned range.
        """
        for idx in indices:
            if self.memories[idx % self.pool_size]['done']:
                return idx
        return None

    def _check_idx(self, cur_idx):
        """
        Work out which stored frames make up the observation ending at ``cur_idx``.

        If the pool cannot supply "frame_history_len" frames of the current
        episode, the missing leading frames are zero-padded.

        situation 1: cur_idx < frame_history_len - 1 and pool is not full  --> padding 0
        situation 2: cur_idx < frame_history_len - 1 and pool is full      --> no padding,
                     the window wraps around to the tail of the pool
        situation 3: a "done" flag lies inside the window                  --> padding 0,
                     the window starts right after the most recent "done"
        situation 4: other                                                 --> no padding

        The frame at ``cur_idx`` itself is never checked for "done": its flag
        describes the transition that leaves the frame, not the frame.

        :param cur_idx: index of the newest frame of the observation
        :return: idx_flag, missing_context, start_idx (inclusive), end_idx (exclusive);
                 start_idx > end_idx means the window wraps around the end of the pool
        """
        history = self.frame_history_len
        end_idx = cur_idx + 1  # exclusive
        start_idx = end_idx - history  # inclusive

        runs_off_front = start_idx < 0
        if runs_off_front:
            start_idx = 0

        # Scan newest-first so that, when several episode boundaries fall inside
        # the window, the one closest to cur_idx wins and no frame of an earlier
        # episode leaks into the stack.
        boundary = self._latest_done(range(end_idx - 2, start_idx - 1, -1))
        if boundary is not None:
            start_idx = boundary + 1
            return 3, history - (end_idx - start_idx), start_idx, end_idx

        # situation 4
        if not runs_off_front:
            return 4, 0, start_idx, end_idx

        missing_context = history - end_idx

        # situation 1: nothing is stored before index 0 yet
        if self.number_of_memories != self.pool_size:
            return 1, missing_context, start_idx, end_idx

        # situation 2: the frames preceding index 0 live at the tail of the pool
        tail_start = self.pool_size - missing_context
        boundary = self._latest_done(range(self.pool_size - 1, tail_start - 1, -1))
        if boundary is not None:
            start_idx = (boundary + 1) % self.pool_size
            # ..., end_idx|, ..., |start_idx, ., end
            if start_idx > end_idx:
                available = self.pool_size - start_idx + end_idx
            else:
                available = end_idx - start_idx
            return 3, history - available, start_idx, end_idx

        return 2, 0, tail_start, end_idx

    def _encoder_observation(self, cur_idx):
        """
        Concatenate the recent "frame_history_len" frames ending at ``cur_idx``.
        obs: (c, h, w) => (frame_history_len*c, h, w)

        :param cur_idx: current frame's index
        :return: numpy array
        """
        _, missing_context, start_idx, end_idx = self._check_idx(cur_idx)

        # zero frames stand in for history that does not exist
        frames = [np.zeros_like(self.memories[0]['obs']) for _ in range(missing_context)]

        if start_idx > end_idx:
            # window wraps: tail of the pool first, then the head
            slots = list(range(start_idx, self.pool_size)) + list(range(end_idx))
        else:
            slots = range(start_idx, end_idx)
        frames.extend(self.memories[idx % self.pool_size]['obs'] for idx in slots)

        # [k, c, h, w] => [k*c, h, w]
        return np.concatenate(frames, 0)

    def encoder_recent_observation(self):
        """
        Concatenate the recent "frame_history_len" frames ending at the frame
        stored most recently.

        :return: numpy array of shape (frame_history_len*c, h, w)
        """
        assert self.number_of_memories > 0

        current_idx = self.next_idx - 1
        # next_idx == 0 means the last write went to the final slot
        if current_idx < 0:
            current_idx = self.pool_size - 1

        return self._encoder_observation(current_idx)

    def sample_memories(self, batch_size):
        """
        Draw "batch_size" transitions uniformly at random (with replacement).

        :param batch_size: number of transitions to draw
        :return: obs_batch and next_obs_batch of shape
                 [batch_size, frame_history_len*c, h, w], action_batch and
                 reward_batch of shape [batch_size, 1], and done_batch as a list
        :raises ValueError: if fewer than two memories are stored
        """
        if self.number_of_memories < 2:
            raise ValueError("need at least two stored memories to sample a transition")

        # the upper bound is exclusive, which ensures s_{i+1} exists
        sample_idxs = np.random.randint(0, self.number_of_memories - 1, [batch_size])

        # [batch_size, frame_history_len*c, h, w]
        obs_batch = np.zeros(
            [batch_size, self.obs_shape[0] * self.frame_history_len, self.obs_shape[1], self.obs_shape[2]])
        next_obs_batch = np.copy(obs_batch)
        action_batch = np.zeros([batch_size, 1])
        reward_batch = np.zeros([batch_size, 1])
        done_batch = []

        for row, idx in enumerate(sample_idxs):
            idx = int(idx)
            memory = self.memories[idx]
            obs_batch[row] = self._encoder_observation(idx)
            next_obs_batch[row] = self._encoder_observation(idx + 1)
            action_batch[row] = memory['action']
            reward_batch[row] = memory['reward']
            done_batch.append(memory['done'])

        return obs_batch, next_obs_batch, action_batch, reward_batch, done_batch

    def store_memory_obs(self, frame):
        """
        Store the observation part of a memory in the next slot.

        :param frame: numpy array of shape (img_h, img_w, img_c)
        :return: index of the slot written, to be passed to store_memory_effect
        """
        # (h, w, c) => (c, h, w)
        frame = frame.transpose(2, 0, 1)

        if self.obs_shape is None:
            self.obs_shape = frame.shape

        if self.memories is None:
            self.memories = [dict() for _ in range(self.pool_size)]

        index = self.next_idx
        self.memories[index]['obs'] = frame

        self.next_idx = (index + 1) % self.pool_size
        self.number_of_memories = min(self.number_of_memories + 1, self.pool_size)

        return index

    def store_memory_effect(self, index, action, reward, done):
        """
        Store the outcome of acting on the frame held in slot ``index``.

        :param index: slot returned by store_memory_obs
        :param action: scalar
        :param reward: scalar
        :param done: bool
        :return:
        """
        memory = self.memories[index]
        memory['action'] = action
        memory['reward'] = reward
        memory['done'] = done

## 神经网络

In [3]:
class DQNet(nn.Module):
    """Convolutional Q-network: three conv layers followed by two linear layers.

    The first linear layer expects 7*7*64 features, which is what the conv
    stack produces for 84x84 inputs (84 -> 20 -> 9 -> 7 per spatial side).

    :param input_channels: channels of the stacked input frames
    :param out_channels: number of Q-values produced (one per action)
    """

    def __init__(self, input_channels, out_channels):
        super().__init__()
        # Layers are created in this fixed order and under these attribute
        # names so that seeded initialisation and saved state dicts line up.
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        self.fc1 = nn.Linear(7 * 7 * 64, 512)
        self.fc2 = nn.Linear(512, out_channels)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a batch [b, input_channels, h, w] to Q-values [b, out_channels]."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.relu(conv(x))
        # collapse everything but the batch dimension
        flat = x.view(x.size(0), -1)
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)

In [4]:
def change_to_tensor(data_np, dtype=torch.float32):
    """
    Convert a numpy array into a torch tensor of the requested dtype and move
    it to the GPU when CUDA is available.

    :param data_np: numpy array to convert
    :param dtype: torch dtype of the result (float32 by default)
    :return: torch.Tensor (on CUDA if available, otherwise on CPU)
    """
    converted = torch.from_numpy(data_np).type(dtype)
    return converted.cuda() if torch.cuda.is_available() else converted

## Wrapper

In [5]:
class PreprocessFrame(gym.ObservationWrapper):
    """
    Preprocess the observation of env: convert the RGB screen to grayscale,
    shrink it and crop it to a single-channel 84x84 uint8 image.
    """

    def __init__(self, env):
        super(PreprocessFrame, self).__init__(env)
        # observation() returns uint8 pixels, so the space is declared as uint8
        # explicitly instead of relying on Box's default dtype.
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def observation(self, observation):
        """
        raw data: [210, 160, 3]
        processed data: [84, 84, 1]

        :param observation: raw frame that can be reshaped to [210, 160, 3]
        :return: uint8 numpy array of shape [84, 84, 1]
        """
        img = np.reshape(observation, [210, 160, 3]).astype(np.float32)
        # weighted sum of the three colour channels => grayscale
        img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
        # cv2 takes the target size as (width, height): result is 110 rows x 84 columns
        resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_LINEAR)
        # keep rows 18..101 so the image becomes square
        x_t = resized_screen[18:102, :]
        x_t = np.reshape(x_t, [84, 84, 1])
        return x_t.astype(np.uint8)


# one life one episode
class EpisodicLifeEnv(gym.Wrapper):
    """Report the loss of a life as the end of an episode, while only truly
    resetting the underlying game once it is really over.
    """

    def __init__(self, env):
        """Make end-of-life == end-of-episode, but only reset on true game over.
        Done by DeepMind for the DQN and co. since it helps value estimation.
        """
        super().__init__(env)
        # lives seen after the previous step / reset
        self.lives = 0
        # whether the wrapped env itself reported done on the last step
        self.was_real_done = True

    def step(self, action):
        """Step the wrapped env and flag done when a life was lost."""
        obs, reward, done, info = self.env.step(action)
        self.was_real_done = done

        remaining = self.env.unwrapped.ale.lives()
        # Requiring remaining > 0 matters for games such as Qbert that sit at
        # zero lives for a few frames: the episode then ends only once the
        # environment itself advertises done.
        lost_life = 0 < remaining < self.lives
        if lost_life:
            done = True
        # always refresh the count so that bonus lives are picked up
        self.lives = remaining
        return obs, reward, done, info

    def reset(self, **kwargs):
        """Reset only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.
        """
        if not self.was_real_done:
            # no-op step to advance from the lost-life state
            obs = self.env.step(0)[0]
        else:
            obs = self.env.reset(**kwargs)
        self.lives = self.env.unwrapped.ale.lives()
        return obs


class NoopResetEnv(gym.Wrapper):
    """Randomise the initial state by playing a random number of no-ops after reset."""

    def __init__(self, env, noop_max=30):
        """Sample initial states by taking random number of no-ops on reset.
        No-op is assumed to be action 0.

        :param env: environment to wrap
        :param noop_max: largest number of no-ops that may be taken
        """
        super().__init__(env)
        self.noop_max = noop_max
        # when set, this fixed number of no-ops is used instead of a random one
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def reset(self, **kwargs):
        """ Do no-op action for a number of steps in [1, noop_max]."""
        self.env.reset(**kwargs)

        noops = self.override_num_noops
        if noops is None:
            noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
        assert noops > 0

        obs = None
        for _ in range(noops):
            outcome = self.env.step(self.noop_action)
            obs, done = outcome[0], outcome[2]
            if done:
                # the game ended during the no-ops: start over and keep going
                obs = self.env.reset(**kwargs)
        return obs

    def step(self, ac):
        """Pass the action straight through to the wrapped env."""
        return self.env.step(ac)


class FireResetEnv(gym.Wrapper):
    """Press FIRE after every reset for games that stay frozen until firing."""

    def __init__(self, env):
        """Take action on reset for environments that are fixed until firing."""
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self, **kwargs):
        """Reset the env, then take action 1 (FIRE) followed by action 2.

        :return: the observation the env is currently showing
        """
        self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(2)
        if done:
            # Return the observation of the freshly reset env rather than the
            # last frame of the episode that just ended.
            obs = self.env.reset(**kwargs)
        return obs

    def step(self, ac):
        """Pass the action straight through to the wrapped env."""
        return self.env.step(ac)


class MaxAndSkipEnv(gym.Wrapper):
    """Repeat each action for `skip` frames and return the pixel-wise maximum
    of the last two of them."""

    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame"""
        super().__init__(env)
        # two most recent raw observations (for max pooling across time steps)
        self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        reward_sum = 0.0
        done = None
        # frame number that goes into buffer slot 0; the one after it goes into slot 1
        first_kept = self._skip - 2
        for tick in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            slot = tick - first_kept
            if slot == 0 or slot == 1:
                self._obs_buffer[slot] = obs
            reward_sum += reward
            if done:
                break
        # Note that the observation on the done=True frame doesn't matter
        return self._obs_buffer.max(axis=0), reward_sum, done, info

    def reset(self, **kwargs):
        """Reset the wrapped env and return its observation."""
        return self.env.reset(**kwargs)


class ClipRewardEnv(gym.RewardWrapper):
    """Replace every reward by its sign."""

    def __init__(self, env):
        super().__init__(env)

    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        clipped = np.sign(reward)
        return clipped

class ScaledFloatFrame(gym.ObservationWrapper):
    """Convert observations to float32 and divide them by 255."""

    def __init__(self, env):
        super().__init__(env)
        frame_shape = env.observation_space.shape
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=frame_shape, dtype=np.float32)

    def observation(self, observation):
        """Return the observation as float32 scaled by 1/255."""
        # careful! This undoes the memory optimization, use
        # with smaller replay buffers only.
        as_float = np.asarray(observation, dtype=np.float32)
        return as_float / 255.0


def wrap_deepmind(env, skip=4, no_op_max=30, episode_life=True, clip_rewards=True, scale=True):
    """Configure environment for DeepMind-style Atari.

    Wrappers are applied innermost first in this order: EpisodicLifeEnv
    (optional), NoopResetEnv, MaxAndSkipEnv, FireResetEnv (only for games that
    have a FIRE action), PreprocessFrame, ScaledFloatFrame (optional) and
    ClipRewardEnv (optional).

    :param env: raw Atari environment
    :param skip: number of frames each action is repeated for
    :param no_op_max: upper bound on the random no-ops taken after a reset
    :param episode_life: treat the loss of a life as the end of an episode
    :param clip_rewards: replace rewards by their sign
    :param scale: convert observations to floats in [0, 1]
    :return: the wrapped environment
    """
    wrapped = env

    # one life one episode to speed train
    if episode_life:
        wrapped = EpisodicLifeEnv(wrapped)

    # random number of no-ops after each reset
    wrapped = NoopResetEnv(wrapped, no_op_max)

    # same action for k frames, rewards summed, max over the last two frames
    wrapped = MaxAndSkipEnv(wrapped, skip)

    # games that wait for FIRE get it pressed automatically on reset
    has_fire = 'FIRE' in wrapped.unwrapped.get_action_meanings()
    if has_fire:
        wrapped = FireResetEnv(wrapped)

    wrapped = PreprocessFrame(wrapped)

    if scale:
        wrapped = ScaledFloatFrame(wrapped)

    if clip_rewards:
        wrapped = ClipRewardEnv(wrapped)

    return wrapped

## Trainer

In [6]:
class DQN_trainer:
    """Continues training a DQN agent on Breakout from a saved checkpoint.

    The trainer owns the wrapped environment, the replay buffer, the online
    network (``eval_model``) and the target network (``target_model``).
    Typical use: ``collect_memories()`` to pre-fill the buffer with random
    play, then ``train()``.
    """

    # Checkpoint that both networks are initialised from.
    PRETRAINED_PATH = './checkpoints/state_dict_step900000_ave_reward_2.5945.pth'
    # Probability of taking a random action in train() (fixed, not annealed).
    EPSILON = 0.1
    # Number of finished episodes between two progress reports.
    REPORT_EVERY = 1000

    def __init__(self, config):
        """
        :param config: namespace providing action_repeat, no_op_max, is_monitor,
                       replay_memory_size, agent_history_length,
                       initial_exploration, final_exploration,
                       final_exploration_frame, discount_factor, max_epoch,
                       learning_starts, update_freq, target_update_freq,
                       batch_size, model_path, load_model_freq, learning_rate
                       and log_freq
        """
        # env info
        self.env = wrap_deepmind(gym.make('Breakout-v0'), skip=config.action_repeat, no_op_max=config.no_op_max)
        if config.is_monitor:
            self.env = gym.wrappers.Monitor(self.env, 'recording')
        self.action_num = self.env.action_space.n
        self.obs_shape = self.env.observation_space.shape  # [h, w, c]
        self.last_obs = self.env.reset()

        # replay buffer
        self.reply_buffer = ReplayBuffer(config.replay_memory_size, config.agent_history_length)

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        # the networks see agent_history_length frames stacked along the channel axis
        in_channels = self.obs_shape[2] * config.agent_history_length
        self.eval_model = DQNet(in_channels, self.action_num).to(self.device)
        self.target_model = DQNet(in_channels, self.action_num).to(self.device)

        # Read the checkpoint once; map_location lets a checkpoint that was
        # saved from a GPU be loaded on a CPU-only machine as well.
        pretrained = torch.load(self.PRETRAINED_PATH, map_location=self.device)
        self.eval_model.load_state_dict(pretrained)
        self.target_model.load_state_dict(pretrained)

        # train param
        # linear epsilon schedule; kept for callers, train() uses EPSILON instead
        self.exploration = np.linspace(config.initial_exploration, config.final_exploration,
                                       config.final_exploration_frame)
        self.final_exploration_frame = config.final_exploration_frame
        self.discount_factor = config.discount_factor
        self.max_epoch = config.max_epoch
        self.learning_starts = config.learning_starts
        self.update_freq = config.update_freq
        self.target_update_freq = config.target_update_freq
        self.batch_size = config.batch_size

        self.model_path = config.model_path
        self.load_model_freq = config.load_model_freq

        self.criterion = torch.nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.eval_model.parameters(), lr=config.learning_rate)

        self.log_freq = config.log_freq

    def collect_memories(self):
        """
        Before DQN begins to learn, collect ``learning_starts`` transitions by
        acting uniformly at random.
        :return:
        """
        print("-------------------collect memories------------------------")
        for step in tqdm(range(self.learning_starts)):
            # store observation
            cur_index = self.reply_buffer.store_memory_obs(self.last_obs)
            # choose action randomly
            action = self.env.action_space.sample()
            # interact with env
            obs, reward, done, info = self.env.step(action)
            # clip reward
            reward = np.clip(reward, -1.0, 1.0)
            # store other info
            self.reply_buffer.store_memory_effect(cur_index, action, reward, done)

            if done:
                obs = self.env.reset()

            self.last_obs = obs
        print("---------------------------end-----------------------------")

    def train(self):
        """
        Train the DQN agent for ``max_epoch - 1`` environment steps.

        Every step acts epsilon-greedily and stores the transition. Every
        ``update_freq`` steps (after ``learning_starts``) the online network is
        fitted to the one-step Bellman target computed with the target network;
        the target network is synchronised every ``target_update_freq`` steps.
        Average reward and length of the last REPORT_EVERY episodes are printed.
        :return:
        """
        episode_reward = 0
        episode_steps = 0
        # sums over the episodes finished since the last report
        window_reward = 0
        window_steps = 0
        episode = 0

        # NOTE(review): this starts a fresh episode even if collect_memories()
        # left one in progress; the last transition stored there is not flagged done.
        self.last_obs = self.env.reset()
        print("-------------------train DQN agent------------------------")
        for step in tqdm(range(1, self.max_epoch)):
            cur_index = self.reply_buffer.store_memory_obs(self.last_obs)
            encoded_obs = self.reply_buffer.encoder_recent_observation()  # numpy: [m*c, h, w]

            if np.random.random() > self.EPSILON:
                # numpy: [m*c, h, w] => tensor: [1, m*c, h, w]
                state = change_to_tensor(encoded_obs).unsqueeze(0)
                # acting needs no gradients
                with torch.no_grad():
                    pred_action_values = self.eval_model(state)  # [1, action_num]
                _, action = pred_action_values.max(dim=1)
                action = action.item()
            else:
                action = self.env.action_space.sample()

            obs, reward, done, info = self.env.step(action)

            episode_reward += reward
            episode_steps += 1

            self.reply_buffer.store_memory_effect(cur_index, action, reward, done)

            if done:
                obs = self.env.reset()
                episode += 1
                window_reward += episode_reward
                window_steps += episode_steps
                episode_reward = 0
                episode_steps = 0

                if episode % self.REPORT_EVERY == 0:
                    print("---------------------------end-----------------------------")
                    print("average reward in recent 1000 times: {}".format(window_reward / self.REPORT_EVERY))
                    print("average step in recent 1000 times: {}".format(window_steps / self.REPORT_EVERY))
                    print("---------------------------end-----------------------------")
                    # start a new window so the next report covers only its own episodes
                    window_reward = 0
                    window_steps = 0

            self.last_obs = obs

            # train the model
            if step % self.update_freq == 0 and step > self.learning_starts:
                obs_batch, next_obs_batch, action_batch, reward_batch, done_batch = self.reply_buffer.sample_memories(
                    self.batch_size)
                # numpy to tensor
                obs_batch, next_obs_batch = change_to_tensor(obs_batch), change_to_tensor(next_obs_batch)
                action_batch, reward_batch = change_to_tensor(action_batch, torch.int64), change_to_tensor(reward_batch)
                # 1.0 where the transition ended its episode, else 0.0; [b, 1]
                done_mask = change_to_tensor(np.asarray(done_batch, dtype=np.float32).reshape(-1, 1))

                # estimate Q values
                q_values = self.eval_model(obs_batch)  # [b, action_num]
                q_pred = q_values.gather(dim=1, index=action_batch)  # [b, 1]

                # target Q values
                with torch.no_grad():
                    q_next = self.target_model(next_obs_batch).max(dim=1)[0].view(self.batch_size, -1)
                # Bellman equation; a terminal transition has no future value,
                # so the bootstrap term is masked out when done
                q_target = reward_batch + self.discount_factor * (1.0 - done_mask) * q_next

                loss = self.criterion(q_pred, q_target)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # save model (checked here because the file name needs this step's loss)
                if (step / self.update_freq) % self.load_model_freq == 0:
                    torch.save(self.eval_model.state_dict(), self.model_path + '/state_dict_step%d-trainLoss%.4f.pth' % (step, loss.item()*1000))

            # update target net
            if step % self.target_update_freq == 0 and step > self.learning_starts:
                self.target_model.load_state_dict(self.eval_model.state_dict())

        print("---------------------------end-----------------------------")

## Load arguments

In [7]:
def _parse_bool(text):
    """Turn a command-line string such as 'True' / 'false' / '0' into a bool.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, which is True
    for every non-empty value, including 'False'.

    :raises argparse.ArgumentTypeError: if the text is not a recognised boolean
    """
    lowered = str(text).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (text,))


parser = argparse.ArgumentParser()

# paper "Human-level control through deep reinforcement learning" argument
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--replay_memory_size', type=int, default=1000000)
parser.add_argument('--agent_history_length', type=int, default=4)
parser.add_argument('--target_update_freq', type=int, default=100000, help='target net update its param')
parser.add_argument('--discount_factor', type=float, default=0.99, help='discount factor')
parser.add_argument('--action_repeat', type=int, default=4, help='repeat same action in k frames')
parser.add_argument('--update_freq', type=int, default=4, help='DQN learn once per learning freq')
# RMSProp
parser.add_argument('--learning_rate', type=float, default=0.00025)
parser.add_argument('--gradient_momentum', type=float, default=0.95)
parser.add_argument('--squared_gradient_momentum', type=float, default=0.95)
parser.add_argument('--min_squared_gradient', type=float, default=0.01)
# epsilon
parser.add_argument('--initial_exploration', type=float, default=1.0)
parser.add_argument('--final_exploration', type=float, default=0.1)
parser.add_argument('--final_exploration_frame', type=int, default=1000000)

parser.add_argument('--learning_starts', type=int, default=5000, help='after learning starts DQN begin to learn')
parser.add_argument('--no_op_max', type=int, default=30, help='after reset taking random number of no-ops')

parser.add_argument('--load_model_freq', type=int, default=12500)
parser.add_argument('--model_path', type=str, default='./checkpoints', help='path for saving trained models')

# other argument
parser.add_argument('--max_epoch', type=int, default=8000000)
parser.add_argument('--is_monitor', type=_parse_bool, default=False,
                    help='use monitor log the performance of the agent')
parser.add_argument('--log_freq', type=int, default=1000, help='step size for updating visdom')

# parse_known_args skips options it does not know instead of exiting on them
args = parser.parse_known_args()[0]
print(args)

Namespace(action_repeat=4, agent_history_length=4, batch_size=32, discount_factor=0.99, final_exploration=0.1, final_exploration_frame=1000000, gradient_momentum=0.95, initial_exploration=1.0, is_monitor=False, learning_rate=0.00025, learning_starts=5000, load_model_freq=12500, log_freq=1000, max_epoch=8000000, min_squared_gradient=0.01, model_path='./checkpoints/state_dict_step900000_ave_reward_2.5945.pth', no_op_max=30, replay_memory_size=1000000, squared_gradient_momentum=0.95, target_update_freq=100000, update_freq=4)


In [8]:
# Build the trainer from the parsed arguments: this creates the wrapped
# environment, the replay buffer and both networks, and loads the checkpoint.
dqn_agent_trainer = DQN_trainer(args)

In [9]:
# Quick check that the trainer was constructed; DQN_trainer defines no
# __repr__/__str__, so this prints the default object representation.
print(dqn_agent_trainer)

<__main__.DQN_trainer object at 0x0000028401A748C8>


In [10]:
# Pre-fill the replay buffer with random play, then run the training loop.
dqn_agent_trainer.collect_memories()
dqn_agent_trainer.train()

-------------------collect memories------------------------


100%|██████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:52<00:00, 95.22it/s]


---------------------------end-----------------------------
-------------------train DQN agent------------------------


  0%|                                                                        | 6478/7999999 [01:09<32:03:29, 69.26it/s]

KeyboardInterrupt: 