In [2]:
import keras
import gym
import gym_piranhas

# Gym id of the Piranhas environment. NOTE(review): presumably this id is
# registered as a side effect of `import gym_piranhas` (that import is not
# otherwise used in this cell) — confirm before removing the import.
ENV_ID = 'piranhas-v0'
env = gym.make(ENV_ID)

In [5]:
from keras.models import Sequential
from keras.layers import Convolution2D, Flatten, Dense, Activation, Permute
import keras.backend as K

# The network input is an INPUT_SHAPE x INPUT_SHAPE grid with DIMENSIONS
# channels, given in channels-last order: (rows, cols, channels).
INPUT_SHAPE = 10
DIMENSIONS  = 3
input_shape = (INPUT_SHAPE, INPUT_SHAPE, DIMENSIONS)
output_length = 2 + 1  # two coordinates + 1 direction
# TODO test with 100 or more outputs (two-hot encoding)

# Image layout the backend is configured for: 'channels_last' or
# 'channels_first'. K.image_data_format() is the Keras 2 replacement for the
# legacy K.image_dim_ordering() ('tf' / 'th').
# The input Permute and EVERY Conv2D below must use this same layout. If the
# input is permuted to channels-first while the convolutions are told it is
# channels-last, they treat the wrong axis as channels (and a 5x5 kernel over
# a 3 x 10 "image" cannot even be built).
data_format = K.image_data_format()

model = Sequential()
if data_format == 'channels_last':
    # Input is already (rows, cols, channels): identity permutation, kept so
    # the layer stack is the same for both layouts.
    model.add(Permute((1, 2, 3), input_shape=input_shape))
elif data_format == 'channels_first':
    # (rows, cols, channels) -> (channels, rows, cols)
    model.add(Permute((3, 1, 2), input_shape=input_shape))
else:
    raise RuntimeError('Unknown image_data_format: {}'.format(data_format))

# The actual network structure. Shapes below are written channels-last; with
# channels_first the same sizes appear as (channels, rows, cols).
# results in a (6 x 6 x 32) output volume
model.add(Convolution2D(32, (5, 5), activation='relu', data_format=data_format))
# results in a (4 x 4 x 64) output volume
model.add(Convolution2D(64, (3, 3), activation='relu', data_format=data_format))
# results in a (2 x 2 x 64) output volume
model.add(Convolution2D(64, (3, 3), activation='relu', data_format=data_format))
# flattens the result (vector of size 2 * 2 * 64 = 256)
model.add(Flatten())
# add fully-connected layer
model.add(Dense(128, activation='relu'))
# map to output: coordinates and move (linear activation -> unbounded outputs)
model.add(Dense(output_length, activation='linear'))
# summary() prints the table itself and returns None; wrapping it in print()
# would only append a stray "None" line to the cell output.
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
permute_3 (Permute)          (None, 10, 10, 3)         0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 6, 6, 32)          2432      
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 4, 4, 64)          18496     
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 2, 2, 64)          36928     
_________________________________________________________________
flatten_2 (Flatten)          (None, 256)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 128)               32896     
_________________________________________________________________
dense_3 (Dense)              (None, 3)                 387       
Total para