In [20]:
from __future__ import division
from easydict import EasyDict as edict

import retro

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, Permute, Activation
from keras.layers import Conv2D, MaxPooling2D
from keras.optimizers import Adam
import keras.backend as K

from rl.agents.dqn import DQNAgent
from rl.policy import LinearAnnealedPolicy, BoltzmannQPolicy, EpsGreedyQPolicy
from rl.memory import SequentialMemory
from rl.core import Processor
from rl.callbacks import FileLogger, ModelIntervalCheckpoint

import skimage
from skimage import data,color, io
from skimage.color import rgb2gray
from skimage.transform import rescale, resize, downscale_local_mean

import numpy as np

INPUT_SHAPE = (84, 84)
WINDOW_LENGTH = 4
#Processor for keras-rl
class retro_processor(Processor):
    def process_observation(self,_obs):
        assert _obs.ndim == 3  # (height, width, channel)
        img = skimage.color.rgb2gray(_obs)
        img = skimage.transform.resize(img,INPUT_SHAPE)  # resize and convert to grayscale
        assert img.shape == INPUT_SHAPE
        return img  # saves storage in experience memory
    def process_state_batch(self, batch):
        processed_batch = batch.astype('float32') / 255.
        return processed_batch
    def process_reward(self, reward):
        return np.clip(reward, -1., 1.)
    def process_action(self, action):
        action_ = [0]*9
        action_[action] = 1
        return action_


args = edict({
    #'mode': 'train',
    'mode': 'test',
    '--env-name': 'PokemonPinball-Gbc',
    'weights': None
})


In [21]:
env = retro.make(game='PokemonPinball-Gbc', record='.')
np.random.seed(123)
nb_actions = env.action_space.n
print(nb_actions)

RuntimeError: Cannot create multiple emulator instances per process

In [22]:

input_shape =(WINDOW_LENGTH,)+INPUT_SHAPE
#creation of neural network
model = Sequential()
if K.image_dim_ordering() == 'tf':
    # (width, height, channels)
    model.add(Permute((2, 3, 1), input_shape=input_shape))
elif K.image_dim_ordering() == 'th':
    # (channels, width, height)
    model.add(Permute((1, 2, 3), input_shape=input_shape))
else:
    raise RuntimeError('Unknown image_dim_ordering.')
model.add(Conv2D(32, kernel_size=(8, 8), strides=(4, 4),activation='relu',input_shape=input_shape))
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Conv2D(64, kernel_size=(4, 4),activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, kernel_size=(3, 3),activation='relu'))
model.add(Flatten())
model.add(Dense(3136, activation='relu'))
model.add(Dense(512)) #'activation=softmax'
model.add(Dense(nb_actions))
model.add(Activation('linear'))
print(model.summary())
#include max pooling with convolution

memory = SequentialMemory(limit=1000000, window_length=WINDOW_LENGTH)
processor = retro_processor()
policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps', value_max=1., value_min=.1, value_test=.05,nb_steps=1000000)
#create agent
dqn = DQNAgent(model=model, nb_actions=nb_actions, policy=policy, memory=memory,
               processor=processor, nb_steps_warmup=50000, gamma=.99, target_model_update=10000,
               train_interval=4, delta_clip=1.)
dqn.compile(Adam(lr=1), metrics=['mae'])

if args.mode == 'train':
    # Okay, now it's time to learn something! We capture the interrupt exception so that training
    # can be prematurely aborted. Notice that you can the built-in Keras callbacks!
    weights_filename = 'dqn_{}_weights.h5f'.format('PokemonPinball-Gbc')
    checkpoint_weights_filename = 'dqn_' + 'PokemonPinball-Gbc' + '_weights_{step}.h5f'
    log_filename = 'dqn_{}_log.json'.format('PokemonPinball-Gbc')
    callbacks = [ModelIntervalCheckpoint(checkpoint_weights_filename, interval=250000)]
    callbacks += [FileLogger(log_filename, interval=100)]
    dqn.fit(env, callbacks=callbacks, nb_steps=175000, log_interval=10000)

    # After training is done, we save the final weights one more time.
    dqn.save_weights(weights_filename, overwrite=True)

    
    # Finally, evaluate our algorithm for 10 episodes.
    #dqn.test(env, nb_episodes=10, visualize=True)
    #dqn.save_weights(weights_filename, overwrite=True)

    
elif args.mode == 'test': #Broken, stalls on first epoch
    weights_filename = 'dqn_{}_weights.h5f'.format('PokemonPinball-Gbc')
    if args.weights:
        weights_filename = args.weights
    dqn.load_weights(weights_filename)
    dqn.test(env, nb_episodes=10, visualize=True)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
permute_8 (Permute)          (None, 84, 84, 4)         0         
_________________________________________________________________
conv2d_22 (Conv2D)           (None, 20, 20, 32)        8224      
_________________________________________________________________
max_pooling2d_15 (MaxPooling (None, 10, 10, 32)        0         
_________________________________________________________________
conv2d_23 (Conv2D)           (None, 7, 7, 64)          32832     
_________________________________________________________________
max_pooling2d_16 (MaxPooling (None, 3, 3, 64)          0         
_________________________________________________________________
conv2d_24 (Conv2D)           (None, 1, 1, 64)          36928     
_________________________________________________________________
flatten_8 (Flatten)          (None, 64)                0         
__________

KeyboardInterrupt: 

[<rl.callbacks.ModelIntervalCheckpoint object at 0x0000017A505BAE10>, <rl.callbacks.FileLogger object at 0x00000179B3590588>]
