In [None]:
import time, threading, collections, io

import ipywidgets
import PIL.Image

import numpy as np
import keras
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)

import agent, memory, callbacks, catch, snake, tromis

In [None]:
# Shared state between the training callback (GameStore, producer) and the
# GamePlot display thread (consumer).
# Set after training to make GamePlot.run() leave its polling loop.
stopEvent = threading.Event()
# Holds at most one recorded game: appending a newer game silently evicts an
# unplayed older one, so the display only ever shows the latest game.
currentGame = collections.deque(maxlen=1)
class GamePlot(threading.Thread):
    """Background thread that plays recorded games in an ipywidgets Image.

    Repeatedly pops a game (an iterable of frames) from the module-level
    ``currentGame`` deque and shows its frames 0.1 s apart, pausing 0.5 s
    between games, until the module-level ``stopEvent`` is set.
    """

    def __init__(self, ratio=1.0):
        """Create the image widget, display it, and draw a blank frame.

        Args:
            ratio: width/height aspect ratio; the widget is ``int(ratio*256)``
                pixels wide and 256 pixels high.
        """
        threading.Thread.__init__(self, name="GamePlot")
        # Daemon thread: if the training cell dies before stopEvent is set,
        # this poller must not block interpreter (kernel) shutdown.
        self.daemon = True
        self.ratio = ratio
        self.imbuf = io.BytesIO()
        self.img = ipywidgets.Image(width=int(ratio*256), height=256)
        display(self.img)
        self.plot_frame(np.zeros((2, 2), np.uint8))

    def run(self):
        """Poll for recorded games and play them back until stopEvent is set."""
        # is_set() is the supported name; isSet() is a deprecated alias.
        while not stopEvent.is_set():
            try:
                game = currentGame.pop()
            except IndexError:
                # No game published yet; poll again shortly.
                time.sleep(0.1)
            else:
                for frame in game:
                    self.plot_frame(frame)
                    time.sleep(0.1)
                time.sleep(0.5)

    def plot_frame(self, frame):
        """Render one frame into the image widget as a PNG.

        Args:
            frame: array-like of pixel intensities. Values are multiplied by
                255 and cast to uint8, so presumably in [0, 1] (TODO confirm
                against the game modules). The caller's array is NOT modified.
        """
        # Scale into a new array. The previous in-place `frame *= 255` rewrote
        # the recorded game's frames (possibly shared with the game/agent) and
        # raised for bool frames.
        pixels = (np.asarray(frame) * 255).astype('uint8')
        image = PIL.Image.fromarray(pixels).resize((int(self.ratio*256), 256))
        # Rewind AND truncate: seek(0) alone leaves stale bytes from a longer
        # previous PNG at the end of the buffer, and getvalue() returns them.
        self.imbuf.seek(0)
        self.imbuf.truncate()
        image.save(self.imbuf, 'png')
        self.img.value = self.imbuf.getvalue()
        
class GameStore(callbacks.Callback):
    """Training callback that records each game's frames and publishes the
    finished game to a GamePlot display thread via ``currentGame``.

    Each instance creates and starts its own GamePlot thread.
    """

    def __init__(self, ratio=1.0):
        """Start the display thread.

        Args:
            ratio: width/height aspect ratio forwarded to GamePlot.
        """
        # NOTE(review): callbacks.Callback.__init__ is not called; confirm the
        # base class needs no initialisation.
        # Start with an empty buffer so game_frame()/game_over() arriving
        # before any game_start() do not raise AttributeError.
        self.gamestore = []
        # Prime the deque with an empty game (played back as zero frames).
        currentGame.append([])
        self.plotter = GamePlot(ratio)
        self.plotter.start()

    def game_start(self, frame):
        """Begin recording a new game, starting with its first frame."""
        self.gamestore = [frame]

    def game_frame(self, frame):
        """Append a frame to the game currently being recorded."""
        self.gamestore.append(frame)

    def game_over(self):
        """Publish the recorded game to the display thread.

        Any older game not yet picked up by the display is replaced, because
        ``currentGame`` holds at most one entry.
        """
        currentGame.append(self.gamestore)

In [None]:
# --- Game selection: keep exactly one of the configurations below active. ---
# width/height are the frame size fed to the network (see the input shapes
# below); nb_frames is the number of frames per state (the model's time axis).
grid_size, width, height, nb_frames = 12, 12, 12, 2
game = catch.Catch(grid_size)

#grid_size, width, height, nb_frames = 8, 8, 8, 2
#game = snake.Snake(grid_size, max_turn=32)

#width, height, nb_frames = 6, 9, 1
#game = tromis.Tromis(width=width, height=height, max_turn=128)

# --- Per-frame convolutional encoder ---
# cs is the base channel count. Each stride-2 'same' conv halves the spatial
# size (rounding up) and doubles the channels: cs, 2cs, 4cs, 8cs.
cs = 4
inpc = keras.layers.Input(shape=(height, width, 3), name='conv_input')
conv1 = keras.layers.Conv2D(cs,3,padding='same',strides=2,activation='relu', name='conv1')(inpc)
conv2 = keras.layers.Conv2D(cs*2,3,padding='same',strides=2,activation='relu', name='conv2')(conv1)
conv3 = keras.layers.Conv2D(cs*4,3,padding='same',strides=2,activation='relu', name='conv3')(conv2)
conv4 = keras.layers.Conv2D(cs*8,3,padding='same',strides=2,activation='relu', name='conv4')(conv3)
flat = keras.layers.Flatten(name='flatten')(conv4)
convm = keras.models.Model(inputs=inpc, outputs=flat, name='CONV_BASE')
convm.summary()

# --- Q-network: encode every frame with convm, then summarise the sequence ---
ls = cs*16  # width of the recurrent and hidden dense layers
inp = keras.layers.Input(shape=(nb_frames, height, width, 3), name='state_input')
x = keras.layers.TimeDistributed(convm, name='conv_distributed')(inp)
x = keras.layers.SimpleRNN(ls, return_sequences=False, name='rnn')(x)
x = keras.layers.Dense(ls, activation='relu', name='dense1')(x)
# One linear output per action (presumably its Q-value; the model is named 'DQN').
act = keras.layers.Dense(game.nb_actions, activation='linear', name='actions')(x)

model = keras.models.Model(inputs=inp, outputs=act, name='DQN')
# Use the RMSprop class rather than the lowercase `rmsprop` alias, which newer
# Keras releases no longer provide.
model.compile(optimizer=keras.optimizers.RMSprop(), loss='logcosh')
model.summary()

# Replay memory (memory_size=65536) and the agent that trains `model`.
m = memory.UniqMemory(memory_size=65536)
a = agent.Agent(model=model, mem=m, num_frames = nb_frames)

In [None]:
# Train the agent while GameStore streams finished games to the GamePlot widget.
game.max_turn=32
stopEvent.clear()
game_store = GameStore(ratio=width/height)
try:
    a.train(game, batch_size=64, epochs=10, train_interval=8, episodes=256,
            epsilon=0.01, gamma=0.98, reset_memory=False,
            callbacks=[game_store])
finally:
    # Always stop the display thread, even if training raises or is
    # interrupted (Kernel > Interrupt). Otherwise GamePlot keeps polling
    # forever, and every re-run of this cell (which clears stopEvent) starts
    # yet another one.
    stopEvent.set()