In [None]:
import os, sys, time, threading, collections
from io import BytesIO

# Put the notebook's parent directory on sys.path so the `import KerasTools`
# below can resolve (presumably KerasTools lives in that parent folder — confirm).
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

import keras
import KerasTools as KT
import numpy as np

import ipywidgets
import skimage
import matplotlib.pyplot as plt

import tensorflow as tf
# Only let TensorFlow log messages at ERROR level or above.
# NOTE(review): `tf.logging` is the TensorFlow 1.x API; under TF 2.x the
# equivalent is `tf.compat.v1.logging` — confirm the installed version.
tf.logging.set_verbosity(tf.logging.ERROR)

In [None]:
# State shared between the training callbacks and the GamePlot replay thread(s).
# Once `stopEvent` is set, every GamePlot thread leaves its replay loop.
stopEvent = threading.Event()
# Single-slot mailbox of finished games: appending evicts any game not yet replayed.
currentGame = collections.deque(maxlen=1)
class GamePlot(threading.Thread):
    """Background thread that replays finished games in a notebook image widget.

    The thread repeatedly pops the most recent game (an iterable of frames)
    from the module-level ``currentGame`` deque and draws it frame by frame
    into an ``ipywidgets.Image``. It exits once the module-level ``stopEvent``
    is set, including in the middle of a replay.

    Args:
        ratio: width/height aspect ratio of the rendered image (height is 256 px).
        frame_delay: seconds to pause between frames of a replayed game.
        game_delay: seconds to pause after a game has been fully replayed.
    """

    def __init__(self, ratio=1.0, frame_delay=0.1, game_delay=0.5):
        threading.Thread.__init__(self, name="GamePlot")
        # Daemon: a replay thread left running (e.g. training raised before
        # stopEvent.set()) must not keep the interpreter alive at exit.
        self.daemon = True
        self.ratio = ratio
        self.frame_delay = frame_delay
        self.game_delay = game_delay
        self.imbuf = BytesIO()
        self.img = ipywidgets.Image(width=int(ratio * 256), height=256)
        display(self.img)
        # Show a blank frame until the first game arrives.
        self.plot_frame(np.zeros((10, 10), np.uint8))

    def run(self):
        """Replay games from ``currentGame`` until ``stopEvent`` is set."""
        while not stopEvent.is_set():
            try:
                frames = currentGame.pop()
            except IndexError:
                # Nothing queued yet; wait() returns early if a stop is requested.
                stopEvent.wait(0.1)
                continue
            for frame in frames:
                # Abort mid-replay rather than finishing a possibly long game.
                if stopEvent.is_set():
                    return
                self.plot_frame(frame)
                stopEvent.wait(self.frame_delay)
            stopEvent.wait(self.game_delay)

    def plot_frame(self, frame):
        """Render one frame as a PNG and push it into the image widget.

        The frame is resized with nearest-neighbour interpolation (order=0) to
        256 x int(ratio*256) pixels and saved with the 'nipy_spectral' colormap
        over a fixed 0..1 value range.
        """
        scaled = skimage.transform.resize(frame, (256, int(self.ratio * 256)), order=0,
                                          mode='constant', anti_aliasing=False)
        # Reset the reused buffer before writing: without truncate(), a PNG that
        # is shorter than the previous one would leave stale trailing bytes in
        # getvalue().
        self.imbuf.seek(0)
        self.imbuf.truncate()
        # Explicit format so the bytes always match the widget's PNG format,
        # independent of matplotlib's savefig.format rcParam.
        plt.imsave(self.imbuf, scaled, vmin=0.0, vmax=1.0, cmap='nipy_spectral', format='png')
        self.img.value = self.imbuf.getvalue()

class Callback(object):
    """No-op base class for the hooks passed via ``agent.train(callbacks=[...])``.

    Subclasses override only the events they need; every default hook accepts
    its arguments, does nothing, and returns None.
    """

    def game_start(self, frame):
        """Hook for the start of a game; receives one frame."""
        return None

    def game_frame(self, frame):
        """Hook for each subsequent frame of a game."""
        return None

    def game_over(self):
        """Hook for the end of a game."""
        return None

    def epoch_end(self, *args):
        """Hook for the end of a training epoch; receives arbitrary positional stats."""
        return None

class History(Callback):
    """Callback that logs one comma-separated row of training stats per epoch.

    On construction it creates ``<name>-<YYYYmmdd_HHMMSS>.log`` (UTC time,
    relative to the current working directory) containing a header row.
    """

    def __init__(self, name):
        # First six struct_time fields: year, month, day, hour, minute, second (UTC).
        utc_fields = time.gmtime()[:6]
        self.timestamp = "{:04d}{:02d}{:02d}_{:02d}{:02d}{:02d}".format(*utc_fields)
        self.filename = '{}-{}.log'.format(name, self.timestamp)
        # 'w+' truncates any existing file of the same name.
        with open(self.filename, 'w+') as log_file:
            log_file.write('Epoch, Epsilon,    Loss, Win Ratio, Avg Score, Max Score,   Memory\n')

    def epoch_end(self, *args):
        """Append one epoch's stats to the log file.

        Expects exactly nine positional values:
        (model, name, epoch, epsilon, loss, win_ratio, avg_score, max_score, memory).
        The first two are not logged; unpacking raises ValueError on any other count.
        """
        stats = args[2:]
        epoch, epsilon, loss, win_ratio, avg_score, max_score, memory = stats
        with open(self.filename, 'a') as log_file:
            row = '{:> 5d}, {:>7.2f}, {:>7.4f}, {:>9.2%}, {:>9.2f}, {:>9.2f}, {:>8d}\n'.format(
                epoch, epsilon, loss, win_ratio, avg_score, max_score, memory)
            log_file.write(row)

class Checkpoint(Callback):
    """Callback that saves the model every ``interval`` epochs.

    Args:
        interval: save when ``epoch % interval == 0``; must be >= 1.
        pattern: format string for the file name, filled with (name, epoch).
            Defaults to ``'<name>_<epoch:03d>.h5'`` in the working directory.

    Raises:
        ValueError: if ``interval`` < 1. Validating here avoids a
            ZeroDivisionError that would only surface at the end of the
            first training epoch, after that epoch's work is already done.
    """

    def __init__(self, interval=1, pattern='{}_{:03d}.h5'):
        if interval < 1:
            raise ValueError('Checkpoint interval must be >= 1, got {!r}'.format(interval))
        self.interval = interval
        self.pattern = pattern

    def epoch_end(self, *args):
        """Save the model if this epoch falls on the checkpoint interval.

        Expects exactly nine positional values:
        (model, name, epoch, epsilon, loss, win_ratio, avg_score, max_score, memory).
        Only model, name and epoch are used.
        """
        model, name, epoch, epsilon, loss, win_ratio, avg_score, max_score, memory = args
        if epoch % self.interval == 0:
            model.save(self.pattern.format(name, epoch))
            
class GameStore(Callback):
    """Callback that records each game's frames and queues them for live replay.

    Constructing a GameStore starts a new GamePlot thread. Each completed game
    is pushed into the single-slot ``currentGame`` deque, where the thread
    picks it up.

    Args:
        ratio: aspect ratio forwarded to GamePlot.
    """

    def __init__(self, ratio=1.0):
        # Push an empty game; with the deque's maxlen of 1 this evicts any
        # game still waiting from a previous run.
        currentGame.append([])
        self.plotter = GamePlot(ratio)
        self.plotter.start()

    def game_start(self, frame):
        """Begin a fresh frame list for a new game."""
        self.gamestore = list((frame,))

    def game_frame(self, frame):
        """Record one more frame of the current game."""
        frames = self.gamestore
        frames.append(frame)

    def game_over(self):
        """Hand the finished game's frames to the replay thread."""
        finished = self.gamestore
        currentGame.append(finished)

In [None]:
grid_size = 16
nb_frames = 2

game = KT.qlearn.catch.Catch(grid_size)

# Dense network: flatten an input of nb_frames x grid_size x grid_size x 3 and
# map it to one linear output per game action.
model = keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape=(nb_frames, grid_size, grid_size, 3)))
model.add(keras.layers.Dense(64, activation='relu'))
model.add(keras.layers.Dense(game.nb_actions, activation='linear'))
model.compile(keras.optimizers.rmsprop(), "logcosh")
model.summary()

agent = KT.qlearn.agent.Agent(model=model, memory_size=65536, nb_frames=nb_frames)
# Clear the stop flag first: GamePlot threads started by GameStore leave their
# loop as soon as stopEvent is set.
stopEvent.clear()
try:
    agent.train(game, batch_size=256, epochs=10, train_interval=32, gamma=0.99,
                epsilon=[0.5, 0.0], epsilon_rate=0.1, reset_memory=False,
                callbacks=[GameStore(), Checkpoint(1), History(game.name)])
finally:
    # Always stop the replay thread, even if training raises or the kernel is
    # interrupted; otherwise it keeps polling in the background indefinitely.
    stopEvent.set()

In [None]:
# Tromis board size in cells.
width = 5
height = 8
# NOTE(review): grid_size is not used by the Tromis game below; only the
# commented-out Snake/Catch alternatives reference it.
grid_size = 10
nb_frames = 16

# Aspect ratio (width / height) for the GamePlot replay widget.
ratio = width/float(height)
game = KT.qlearn.tromis.Tromis(width=width,height=height, max_turn=250)
#game = KT.qlearn.snake.Snake(grid_size, max_turn=64)
#game = KT.qlearn.catch.Catch(grid_size=grid_size)

# Per-frame feature extractor: two 3x3 'same'-padded conv layers followed by
# global max pooling yield a 32-value vector regardless of the frame's
# height/width (both are None in the input shape).
# NOTE: layers are auto-named in creation order and those names are stored in
# saved .h5 files, so keep this construction order stable.
inpc = keras.layers.Input(shape=(None, None, 3))
conv1 = keras.layers.Conv2D(16, (3, 3), activation='relu', padding='same')(inpc)
conv2 = keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same')(conv1)
gpool = keras.layers.GlobalMaxPooling2D()(conv2)
convm = keras.models.Model(inputs=inpc, outputs=gpool)
convm.summary()

# Sequence model: apply convm to every frame along axis 1 (TimeDistributed),
# summarise the frame sequence with a SimpleRNN, and emit one linear output per
# action. Input is presumably (batch, time, height, width, 3) — confirm against
# what KerasTools' Agent feeds in.
inp = keras.layers.Input(shape=(None, None, None, 3))
x = keras.layers.TimeDistributed(convm)(inp)
x = keras.layers.SimpleRNN(32, return_sequences=False)(x)
act = keras.layers.Dense(game.nb_actions, activation='linear')(x)

model = keras.models.Model(inputs=inp, outputs=act)
model.compile(keras.optimizers.rmsprop(), 'logcosh')
model.summary()

agent = KT.qlearn.agent.Agent(model=model, memory_size=65536, nb_frames = nb_frames)
# Reset the stop flag so GamePlot threads created by a later training cell keep running.
stopEvent.clear()

In [None]:
# Snapshot the untrained weights before the first training run.
model.save('tromis_000.h5')
# Clear the stop flag here as well, so re-running this cell alone works.
# GamePlot threads exit immediately while stopEvent is set, and the `finally`
# below leaves it set.
stopEvent.clear()
try:
    # NOTE(review): epsilon=[0.5, 0.0] with epsilon_rate=0.25 presumably anneals
    # exploration from 0.5 down to 0.0 — confirm in KerasTools Agent.train.
    agent.train(game, batch_size=256, epochs=20, train_interval=128,
                epsilon=[0.5, 0.0], epsilon_rate=0.25,
                gamma=0.9, reset_memory=False, callbacks=[GameStore(ratio=ratio)])
finally:
    # Always stop the replay thread, even if training raises or is interrupted.
    stopEvent.set()

In [None]:
# Continue training Tromis with epsilon=0.0 (no random exploration) and gamma=0.95,
# keeping the replay memory from earlier runs (reset_memory=False).
stopEvent.clear()
try:
    agent.train(game, batch_size=256, epochs=20, train_interval=128,
                epsilon=0.0, gamma=0.95, reset_memory=False, callbacks=[GameStore(ratio=ratio)])
finally:
    # Always stop the replay thread, even if training raises or is interrupted.
    stopEvent.set()