Среды, в которых мы до настоящего времени обучали агентов, были относительно дружелюбными к ним. Пространство состояний, хотя порой и являлось непрерывным, всё же имело понятную структуру. Что делать, если на вход агента поступает, например, набор изображений?


Рассмотрим среду Pong, где наблюдением на каждом шаге является RGB-изображение экрана. Это трёхмерный массив размером (210, 160, 3): высота, ширина и три цветовых канала. Агент может двигаться вверх и вниз, чтобы отбивать мяч.

In [None]:
!pip install gym pyvirtualdisplay atari-py > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg x11-utils > /dev/null 2>&1

In [None]:
!wget http://www.atarimania.com/roms/Roms.rar &> /dev/null

In [None]:
!unrar x Roms.rar &> /dev/null

In [None]:
!python -m atari_py.import_roms ROMS &> /dev/null

In [None]:
import gym
from gym import logger as gymlogger
from gym.wrappers import Monitor
gymlogger.set_level(40) #error only
import tensorflow as tf
import numpy as np
import random
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
import math
import glob
import io
import base64

from IPython.display import HTML
from IPython import display as ipythondisplay

In [None]:
# Start a virtual display of 1400x900 that is not shown on screen (visible=0).
# NOTE(review): presumably needed so env.render() works on a headless machine
# (e.g. Colab) — confirm.
from pyvirtualdisplay import Display
display = Display(visible=0, size=(1400, 900))
display.start()

<pyvirtualdisplay.display.Display at 0x7f8fec535dd0>

In [None]:
def wrap_env(env):
  """Return `env` wrapped in a gym `Monitor` that writes to './video'.

  `force=True` is passed to the Monitor; `show_video` later looks for the
  resulting mp4 files in the same directory.
  """
  return Monitor(env, './video', force=True)

def show_video():
  """Embed the first recorded episode video into the notebook output.

  Looks for mp4 files in the 'video' directory, base64-encodes the first one
  found and displays it as an inline HTML <video> element. Prints a message
  if no video file exists.
  """
  mp4list = glob.glob('video/*.mp4')
  if not mp4list:
    print("Could not find video")
    return

  mp4 = mp4list[0]
  # Read-only binary mode inside a context manager: the original opened the
  # file as 'r+b' (read-write) and never closed the handle.
  with open(mp4, 'rb') as video_file:
    video = video_file.read()
  encoded = base64.b64encode(video)
  ipythondisplay.display(HTML(data='''<video alt="test" autoplay 
                loop controls style="height: 400px;">
                <source src="data:video/mp4;base64,{0}" type="video/mp4" />
             </video>'''.format(encoded.decode('ascii'))))

In [None]:
# Create the Pong environment and wrap it with wrap_env so that the episode is
# written to ./video. (Without the wrapper one would call gym.envs.make
# directly and get no recording.)
env = wrap_env(gym.envs.make("PongDeterministic-v4"))

In [None]:
# Inspect the environment: shape of a single observation and the number of
# discrete actions. NOTE: n_action is overwritten in a later cell, where the
# action set is restricted to a subset.
state_shape = env.observation_space.shape
n_action = env.action_space.n

print(state_shape)
print(n_action)
# Human-readable names of the actions, in index order.
print(env.unwrapped.get_action_meanings())

(210, 160, 3)
6
['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE']


In [None]:
# Restrict the agent to three of the six Atari actions. According to the
# action meanings printed above, 0 = NOOP, 2 = RIGHT and 3 = LEFT
# (presumably these move the paddle up/down in Pong — confirm).
ACTIONS = [0, 2, 3]
# Number of actions the agent chooses from; must match len(ACTIONS).
n_action = 3
env.reset()
env.render()
is_done = False

# Play one full episode with uniformly random actions.
while not is_done:
    action = ACTIONS[random.randint(0, n_action - 1)]
    # `obs` keeps the last observation; a later cell reuses it to try out
    # the preprocessing.
    obs, reward, is_done, _ = env.step(action)
    env.render()

env.close()

In [None]:
show_video()

In [None]:
from collections import deque
import copy
import torch
from torch.autograd import Variable


class DQN:
    """Deep Q-network with a separate, periodically synchronised target network.

    The learning model is a fully connected network with two hidden ReLU
    layers; the target network is a deep copy of it that only changes when
    `copy_target` is called.
    """

    def __init__(self, n_state, n_action, n_hidden, lr=0.05):
        """
        Build the learning network, its target copy and the optimizer
        @param n_state: size of the (flat) input state vector
        @param n_action: number of actions, i.e. size of the output layer
        @param n_hidden: sequence with the sizes of the two hidden layers
        @param lr: learning rate passed to the Adam optimizer
        """
        self.criterion = torch.nn.MSELoss()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(n_state, n_hidden[0]),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden[0], n_hidden[1]),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden[1], n_action)
        )
        self.model_target = copy.deepcopy(self.model)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr)

    @staticmethod
    def _as_tensor(x):
        """Convert a tensor, array or nested list to a float32 tensor (no copy if already one)."""
        return torch.as_tensor(x, dtype=torch.float32)

    def update(self, s, y):
        """
        Update the weights of the DQN given a training sample
        @param s: state (or batch of states)
        @param y: target value (or batch of target values)
        """
        y_pred = self.model(self._as_tensor(s))
        loss = self.criterion(y_pred, self._as_tensor(y))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def predict(self, s):
        """
        Compute the Q values of the state for all actions using the learning model
        @param s: input state
        @return: Q values of the state for all actions
        """
        with torch.no_grad():
            return self.model(self._as_tensor(s))

    def target_predict(self, s):
        """
        Compute the Q values of the state for all actions using the target network
        @param s: input state
        @return: targeted Q values of the state for all actions
        """
        with torch.no_grad():
            return self.model_target(self._as_tensor(s))

    def replay(self, memory, replay_size, gamma):
        """
        Experience replay with target network
        @param memory: a sequence of experience tuples
                       (state, action, next_state, reward, is_done)
        @param replay_size: the number of samples we use to update the model each time
        @param gamma: the discount factor
        """
        # Not enough experience collected yet: nothing to learn from.
        if len(memory) < replay_size:
            return

        replay_data = random.sample(memory, replay_size)
        states, actions, next_states, rewards, dones = zip(*replay_data)

        # Stack the minibatch so that each network is evaluated once per
        # replay instead of once per sample.
        states = torch.stack([self._as_tensor(s) for s in states])
        next_states = torch.stack([self._as_tensor(s) for s in next_states])
        actions = torch.tensor(actions, dtype=torch.long)
        rewards = torch.tensor(rewards, dtype=torch.float32)
        dones = torch.tensor(dones, dtype=torch.bool)

        # Start from the current predictions so that only the entry of the
        # action actually taken contributes to the loss.
        td_targets = self.predict(states)
        max_q_next = self.target_predict(next_states).max(dim=1).values

        # Terminal transitions have no bootstrap term.
        taken_targets = torch.where(dones, rewards, rewards + gamma * max_q_next)
        td_targets[torch.arange(replay_size), actions] = taken_targets

        self.update(states, td_targets)

    def copy_target(self):
        """Copy the learning model's weights into the target network."""
        self.model_target.load_state_dict(self.model.state_dict())

Пространство наблюдений размером (210, 160, 3) на самом деле довольно большое. Чтобы процесс обучения шёл быстрее, мы уменьшим изображение до размера 84 на 84 пикселя, а вместо трёх цветовых каналов оставим только оттенки серого.

In [None]:
from torchvision.transforms.functional import InterpolationMode
import torchvision.transforms as T
from PIL import Image

# Side length (in pixels) of the square image fed to the agent.
image_size = 84


# Preprocessing pipeline applied to every frame: convert to a PIL image,
# reduce to a single grayscale channel, resize to image_size x image_size
# with BOX interpolation, and convert back to a tensor.
transform = T.Compose([T.ToPILImage(),
                       T.Grayscale(num_output_channels=1),
                       T.Resize((image_size, image_size), interpolation=InterpolationMode.BOX),
                       T.ToTensor(),
                       ])


def get_state(obs):
    """Turn a raw environment frame into the preprocessed state tensor.

    The frame's axes are reordered from (0, 1, 2) to (2, 0, 1), i.e. the last
    axis (the colour channels of the RGB frame) is moved to the front, the
    result is wrapped in a torch tensor and passed through the module-level
    `transform` pipeline.
    """
    channels_first = torch.from_numpy(obs.transpose((2, 0, 1)))
    return transform(channels_first)


# Try the preprocessing on `obs`, the last observation of the random rollout
# above, and check the resulting tensor shape.
state = get_state(obs)


print(state.shape)

torch.Size([1, 84, 84])


In [None]:
def gen_epsilon_greedy_policy(estimator, epsilon, n_action):
    """
    Build an epsilon-greedy policy on top of a Q-value estimator
    @param estimator: object whose predict(state) returns the Q values of all actions
    @param epsilon: probability of picking a uniformly random action
    @param n_action: number of available actions
    @return: function mapping a state to an action index
    """
    def policy_function(state):
        explore = random.random() < epsilon
        if not explore:
            # Exploit: take the action with the highest predicted Q value.
            best_action = torch.argmax(estimator.predict(state))
            return best_action.item()
        return random.randint(0, n_action - 1)

    return policy_function

In [None]:
def q_learning(env, estimator, n_episode, replay_size, target_update=10, gamma=1.0, epsilon=0.1, epsilon_decay=.99):
    """
    Deep Q-Learning with a target network and experience replay.

    Besides its arguments, the function uses these module-level names defined
    in other cells: ACTIONS, n_action, image_size, get_state, memory and
    total_reward_episode (the last two are modified in place).

    @param env: Gym environment
    @param estimator: DQN object
    @param n_episode: number of episodes
    @param replay_size: number of samples we use to update the model each time
    @param target_update: number of episodes before updating the target network
    @param gamma: the discount factor
    @param epsilon: parameter for epsilon_greedy
    @param epsilon_decay: epsilon decreasing factor
    """
    for episode in range(n_episode):
        # Synchronise the target network every `target_update` episodes.
        if episode % target_update == 0:
            estimator.copy_target()
        policy = gen_epsilon_greedy_policy(estimator, epsilon, n_action)
        obs = env.reset()

        # Flatten the preprocessed frame into the vector the network expects.
        state = get_state(obs).view(image_size * image_size)

        is_done = False
        while not is_done:
            action = policy(state)
            # The policy returns an index into ACTIONS, not a raw env action.
            next_obs, reward, is_done, _ = env.step(ACTIONS[action])

            total_reward_episode[episode] += reward

            # Build the next state from the frame returned by this step.
            # (Using the reset frame `obs` here would freeze the state for
            # the whole episode.)
            next_state = get_state(next_obs).view(image_size * image_size)

            memory.append((state, action, next_state, reward, is_done))

            if is_done:
                break

            estimator.replay(memory, replay_size, gamma)

            state = next_state

        print('Episode: {}, total reward: {}, epsilon: {}'.format(episode, total_reward_episode[episode], epsilon))
        # Decay exploration, but never below 0.01.
        epsilon = max(epsilon * epsilon_decay, 0.01)


In [None]:
# Size of the network input: the flattened image_size x image_size frame.
n_state = image_size * image_size

n_episode = 1000

# Sizes of the two hidden layers, learning rate, minibatch size for experience
# replay, and the number of episodes between target-network synchronisations.
n_hidden = [200, 50]
lr = 0.003
replay_size = 32
target_update = 10

# n_action is the module-level value set in an earlier cell.
dqn = DQN(n_state, n_action, n_hidden, lr)
# Replay buffer holding at most the 10000 most recent transitions.
memory = deque(maxlen=10000)
# Per-episode reward totals, filled in by q_learning.
total_reward_episode = [0] * n_episode

# Train starting from fully random behaviour (epsilon=1), which q_learning
# decays after every episode.
q_learning(env, dqn, n_episode, replay_size, target_update, gamma=.9, epsilon=1)

In [None]:
# Plot the total reward collected in each training episode.
import matplotlib.pyplot as plt
plt.plot(total_reward_episode)
plt.title('Episode reward over time')
plt.xlabel('Episode')
plt.ylabel('Total reward')
plt.show()