## breakout

Breakout, the brick-breaking game.

The goal is to train an agent to break the bricks without dropping the ball.

https://gym.openai.com/envs/Breakout-ram-v0/

![breakout](https://thumbs.gfycat.com/AnchoredScornfulAustraliansilkyterrier-size_restricted.gif)

In [1]:
%matplotlib inline

import os
import io
import base64

from IPython.display import HTML
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import gym

from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten, Convolution2D, Permute
from keras.optimizers import Adam

from rl.agents.dqn import DQNAgent
from rl.policy import LinearAnnealedPolicy, EpsGreedyQPolicy
from rl.memory import SequentialMemory
from rl.core import Processor
from rl.callbacks import Callback

Using TensorFlow backend.


### Running with completely random actions

In [2]:
env = gym.make('BreakoutDeterministic-v4')
env = gym.wrappers.Monitor(env, "./gym-results/breakout_random", force=True, video_callable=(lambda _: True))

for i in range(10):
    env.reset()
    done = False
    while not done:
        action = env.action_space.sample()
        ob, reward, done, _ = env.step(action)
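To compare the random baseline with a trained agent later, it helps to record each episode's total reward rather than discarding it. The following is a minimal sketch of the same loop wrapped in a function. `run_random_episodes` and `StubEnv` are hypothetical names introduced here; the stub only imitates the `reset` / `step` / `action_space.sample` interface (with the 4-tuple `step` return used in this notebook) so that the example runs without an Atari emulator.

```python
import random


def run_random_episodes(env, n_episodes):
    """Play n_episodes with random actions; return the total reward of each."""
    returns = []
    for _ in range(n_episodes):
        env.reset()
        done, total = False, 0.0
        while not done:
            action = env.action_space.sample()
            _, reward, done, _ = env.step(action)
            total += reward
        returns.append(total)
    return returns


class StubEnv:
    """Hypothetical stand-in for a gym env: 5 steps per episode, reward 1 each."""

    class _Space:
        def sample(self):
            return random.randrange(4)

    action_space = _Space()

    def reset(self):
        self.t = 0
        return None

    def step(self, action):
        self.t += 1
        return None, 1.0, self.t >= 5, {}


episode_returns = run_random_episodes(StubEnv(), 3)
print(episode_returns)
```

With the real environment, `run_random_episodes(env, 10)` would presumably replace the bare loop in the cell above.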

In [3]:
video = io.open('./gym-results/breakout_random/openaigym.video.%s.video000000.mp4' % env.file_infix, 'rb').read()
encoded = base64.b64encode(video)
HTML(data='''
    <video width="360" height="auto" alt="test" controls><source src="data:video/mp4;base64,{0}" type="video/mp4" /></video>'''
.format(encoded.decode('ascii')))

### Implementing a reinforcement learning agent to play the game

In [2]:
INPUT_SHAPE = (84, 84)
WINDOW_LENGTH = 4

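This game is learned from image input, so a Processor implementation is required.

To turn the frames into features suitable for learning, the Processor converts them to grayscale and rescales the 0-255 integer pixel values to floats in the range 0-1.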

In [3]:
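class AtariProcessor(Processor):
    def process_observation(self, observation):
        # Expect an RGB frame of shape (height, width, channel)
        assert observation.ndim == 3
        img = Image.fromarray(observation)
        # Shrink to 84x84 and convert to 8-bit grayscale ('L')
        img = img.resize(INPUT_SHAPE).convert('L')
        processed_observation = np.array(img)
        assert processed_observation.shape == INPUT_SHAPE
        # Keep uint8 so that frames stay small in the replay memory
        return processed_observation.astype('uint8')

    def process_state_batch(self, batch):
        # Rescale to floats in [0, 1] only when a batch is fed to the network
        processed_batch = batch.astype('float32') / 255.
        return processed_batch

    def process_reward(self, reward):
        # Clip rewards to [-1, 1]
        return np.clip(reward, -1., 1.)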
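As a quick illustration of the two numeric steps in the Processor, the sketch below applies the same scaling and clipping to dummy data. The random batch and the sample reward values are assumptions made up for the example, not output from the emulator.

```python
import numpy as np

# Hypothetical batch of stacked frames as they would come out of replay memory:
# shape (batch, window, height, width), dtype uint8.
batch = np.random.randint(0, 256, size=(2, 4, 84, 84), dtype=np.uint8)

# Same operation as process_state_batch: uint8 in 0..255 -> float32 in 0..1
scaled = batch.astype('float32') / 255.0

# Same operation as process_reward, applied to some assumed raw rewards
raw_rewards = (0.0, 1.0, 4.0, 7.0)
clipped = [float(np.clip(r, -1.0, 1.0)) for r in raw_rewards]

print(scaled.dtype, scaled.shape)
print(clipped)
```

Clipping makes every positive reward count as 1, so the agent is told only that a brick was hit, not how many points it was worth.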

In [4]:
weight_path = 'models/breakout/keras_weights.h5'

In [5]:
env = gym.make('BreakoutDeterministic-v4')

In [6]:
env = gym.wrappers.Monitor(env, "./gym-results/breakout", force=True, video_callable=(lambda ep: ep % 100 == 0))

In [7]:
input_shape = (WINDOW_LENGTH,) + INPUT_SHAPE
nb_actions = env.action_space.n

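Because the input is an image, the network is also somewhat larger.

It uses convolutions, which are effective for learning from images: three convolutional layers followed by two fully connected layers.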

In [8]:
model = Sequential()
model.add(Permute((2, 3, 1), input_shape=input_shape))
model.add(Convolution2D(32, (8, 8), strides=(4, 4)))
model.add(Activation('relu'))
model.add(Convolution2D(64, (4, 4), strides=(2, 2)))
model.add(Activation('relu'))
model.add(Convolution2D(64, (3, 3), strides=(1, 1)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dense(nb_actions))
model.add(Activation('linear'))

In [9]:
model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
permute_1 (Permute)          (None, 84, 84, 4)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 20, 20, 32)        8224      
_________________________________________________________________
activation_1 (Activation)    (None, 20, 20, 32)        0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 9, 9, 64)          32832     
_________________________________________________________________
activation_2 (Activation)    (None, 9, 9, 64)          0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 7, 7, 64)          36928     
_________________________________________________________________
activation_3 (Activation)    (None, 7, 7, 64)         
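The output shapes and parameter counts in the summary can be checked by hand. The sketch below assumes unpadded ("valid") convolutions, which is the Keras default when no `padding` argument is given. The helper names `conv_out` and `conv_params` are introduced here for illustration only.

```python
def conv_out(size, kernel, stride):
    """Output length along one axis of an unpadded convolution."""
    return (size - kernel) // stride + 1


def conv_params(kernel, in_channels, out_channels):
    """Weights of a square kernel plus one bias per output channel."""
    return kernel * kernel * in_channels * out_channels + out_channels


input_shape = (4, 84, 84)  # (WINDOW_LENGTH,) + INPUT_SHAPE

# Permute((2, 3, 1)) reorders the non-batch axes (1-indexed) to channels-last
permuted = tuple(input_shape[i - 1] for i in (2, 3, 1))

size, channels = permuted[0], permuted[2]
layers = []
for filters, kernel, stride in [(32, 8, 4), (64, 4, 2), (64, 3, 1)]:
    params = conv_params(kernel, channels, filters)
    size = conv_out(size, kernel, stride)
    layers.append(((size, size, filters), params))
    channels = filters

flat = size * size * channels  # number of inputs to the first Dense layer

print(permuted)
for shape, params in layers:
    print(shape, params)
print(flat)
```

The computed shapes (20, 20, 32), (9, 9, 64), (7, 7, 64) and parameter counts 8224, 32832, 36928 agree with the `conv2d_1` to `conv2d_3` rows above.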

In [10]:
memory = SequentialMemory(limit=1000000, window_length=WINDOW_LENGTH)
processor = AtariProcessor()
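A replay memory of 1,000,000 entries is large, which is one reason the Processor keeps frames as `uint8`. The following back-of-the-envelope sketch estimates the raw pixel storage only. It ignores Python object overhead and the actions, rewards and terminal flags that are stored alongside, so the real footprint will be higher.

```python
limit = 1_000_000          # SequentialMemory(limit=...)
height, width = 84, 84     # INPUT_SHAPE

bytes_uint8 = limit * height * width * 1    # 1 byte per pixel
bytes_float32 = limit * height * width * 4  # 4 bytes per pixel

print("uint8:   %.2f GB" % (bytes_uint8 / 1e9))
print("float32: %.2f GB" % (bytes_float32 / 1e9))
```

Under these assumptions, storing floats instead of bytes would multiply the pixel storage by four.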

In [11]:
policy = LinearAnnealedPolicy(
    EpsGreedyQPolicy(), attr='eps', value_max=1., value_min=.1, value_test=.05,
    nb_steps=1000000)
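The policy above lowers the exploration rate `eps` linearly from `value_max` to `value_min` over `nb_steps` training steps and then holds it there. A minimal sketch of that schedule, written as a plain function rather than keras-rl's own code (`annealed_eps` is a hypothetical name):

```python
def annealed_eps(step, value_max=1.0, value_min=0.1, nb_steps=1_000_000):
    """Linearly interpolate from value_max to value_min, then stay at value_min."""
    slope = -(value_max - value_min) / nb_steps
    return max(value_min, slope * step + value_max)


for step in (0, 50_000, 500_000, 1_000_000, 2_000_000):
    print(step, round(annealed_eps(step), 3))
```

After 50,000 steps `eps` is still about 0.955, so almost every action is random at that point. `value_test=.05` is the separate, fixed rate used during evaluation.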

In [12]:
dqn = DQNAgent(
    model=model, nb_actions=nb_actions, policy=policy, memory=memory,
    processor=processor, nb_steps_warmup=50000, gamma=.99, target_model_update=10000,
    train_interval=4, delta_clip=1.)
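Several of these arguments map onto the standard DQN update: `gamma` is the discount factor in the bootstrapped target, `target_model_update=10000` means a separate target network is refreshed every 10,000 steps, and `delta_clip=1.` bounds the error gradient with a Huber-style loss. The sketch below shows those two formulas in plain Python as they are usually described. It is an illustration under that assumption, not keras-rl's internal implementation, and `td_target` and `huber` are hypothetical names.

```python
def td_target(reward, next_q_values, done, gamma=0.99):
    """Reward plus discounted best next-state value (no bootstrap at episode end)."""
    if done:
        return reward
    return reward + gamma * max(next_q_values)


def huber(error, delta=1.0):
    """Quadratic for small errors, linear beyond delta, so gradients stay bounded."""
    abs_err = abs(error)
    if abs_err <= delta:
        return 0.5 * error * error
    return delta * (abs_err - 0.5 * delta)


# next_q_values would come from the target network in DQN
target = td_target(reward=1.0, next_q_values=[0.5, 2.0, 1.0, 0.0], done=False)
print(target)
print(huber(0.5), huber(2.0))
```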

In [13]:
dqn.compile(Adam(lr=.00025), metrics=['mae'])

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


In [14]:
if os.path.exists(weight_path):
    dqn.load_weights(weight_path)

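Compared with cartpole, which was fully learned in about 5 minutes, training here takes far longer.

More than 8 hours of training are needed.

For now, we limit training to about 50,000 steps and look at the results. Note that `nb_steps_warmup` is also set to 50000 above, so a run of this length should mostly just fill the replay memory and perform little or no actual learning.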

In [15]:
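try:
    dqn.fit(
        env,
        nb_steps=50000,
        # nb_steps=1750000,  # about 8 hours
        visualize=False,
    )
except KeyboardInterrupt:
    # Allow stopping early from the notebook and still save the weights
    pass
finally:
    # Make sure the target directory exists before saving
    os.makedirs(os.path.dirname(weight_path), exist_ok=True)
    dqn.save_weights(weight_path, overwrite=True)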

Training for 50000 steps ...
Interval 1 (0 steps performed)

56 episodes - episode_reward: 1.089 [0.000, 4.000] - ale.lives: 2.968

Interval 2 (10000 steps performed)
56 episodes - episode_reward: 1.125 [0.000, 4.000] - ale.lives: 2.930

Interval 3 (20000 steps performed)
61 episodes - episode_reward: 0.738 [0.000, 3.000] - ale.lives: 2.959

Interval 4 (30000 steps performed)
54 episodes - episode_reward: 1.241 [0.000, 6.000] - ale.lives: 2.892

Interval 5 (40000 steps performed)
done, took 151.915 seconds
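The numbers in this notebook give a rough idea of throughput. The sketch below only divides step counts by elapsed time: 50,000 steps in 151.915 seconds from the log above, and 1,750,000 steps in roughly 8 hours from the commented-out setting. Treat the comparison as approximate, since the two runs do different amounts of work per step.

```python
short_steps, short_seconds = 50_000, 151.915
long_steps, long_seconds = 1_750_000, 8 * 3600  # "8h" taken as exactly 8 hours

short_rate = short_steps / short_seconds
long_rate = long_steps / long_seconds

print("short run: %.1f steps/s" % short_rate)
print("long run:  %.1f steps/s" % long_rate)
```

The short run is several times faster per step, which would be consistent with it spending its time in the warmup phase, where no network updates are performed.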


In [16]:
video = io.open('./gym-results/breakout/openaigym.video.%s.video000000.mp4' % env.file_infix, 'rb').read()
encoded = base64.b64encode(video)
HTML(data='''
    <video width="360" height="auto" alt="test" controls><source src="data:video/mp4;base64,{0}" type="video/mp4" /></video>'''
.format(encoded.decode('ascii')))

In [17]:
video = io.open('./gym-results/breakout/openaigym.video.%s.video000200.mp4' % env.file_infix, 'rb').read()
encoded = base64.b64encode(video)
HTML(data='''
    <video width="360" height="auto" alt="test" controls><source src="data:video/mp4;base64,{0}" type="video/mp4" /></video>'''
.format(encoded.decode('ascii')))
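The video-embedding snippet appears three times in this notebook with only the file path changing. One possible way to factor it out is a small helper that builds the HTML string from the raw MP4 bytes. `video_tag` is a hypothetical name introduced here, and the dummy bytes stand in for a real recording.

```python
import base64


def video_tag(mp4_bytes, width=360):
    """Return an HTML <video> element with the MP4 data embedded as base64."""
    encoded = base64.b64encode(mp4_bytes).decode('ascii')
    return (
        '<video width="{w}" height="auto" controls>'
        '<source src="data:video/mp4;base64,{data}" type="video/mp4" />'
        '</video>'
    ).format(w=width, data=encoded)


tag = video_tag(b'abc')  # dummy bytes, not a playable video
print(tag)
```

In the notebook this could be used as `HTML(data=video_tag(open(path, 'rb').read()))`, where `path` is one of the recorded `.mp4` files.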