In [None]:
import os, io, threading, time

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

import ipywidgets
import PIL, PIL.Image
import tensorflow as tf
import tensorflow.keras as keras
tf.get_logger().setLevel('ERROR')

import rl.games.catch
import rl.games.fruit
import rl.games.snake
import rl.games.tromis
import rl.tools.gui
import rl.callbacks.callback
import rl.callbacks.gamestore
import rl.agents.ddqn
import rl.memory.basicmemory
import rl.memory.uniqmemory

In [None]:
class GamePlot(threading.Thread):
    """Background thread that replays recorded games into an ``ipywidgets.Image``.

    Games are taken with ``gameQueue.pop()`` (an empty queue is signalled by
    ``IndexError``, as with a ``list`` or ``collections.deque``). Each game is
    iterated step by step and ``step[0]`` of every step is rendered to
    ``self.canvas`` via :meth:`plot_frame`.

    Args:
        width, height: board dimensions; only used to size the canvas so that its
            longer side is 256 px with the aspect ratio preserved.
        gameQueue: shared queue of recorded games (in this notebook,
            ``gameStore.gameQueue``).
        stop_event: ``threading.Event`` that ends :meth:`run` once set. When omitted,
            the notebook-global ``stopEvent`` is looked up at run time (the original
            behaviour).
        frame_delay: seconds to wait after each rendered frame.
        game_delay: seconds to wait after a complete game has been replayed.
        idle_delay: seconds to wait before polling an empty queue again.
    """

    def __init__(self, width, height, gameQueue, stop_event=None,
                 frame_delay=0.1, game_delay=0.5, idle_delay=0.1):
        # daemon=True: a player that was never told to stop must not keep the
        # interpreter (kernel) alive on shutdown.
        threading.Thread.__init__(self, name="GamePlot", daemon=True)
        self.gameQueue = gameQueue
        self.stop_event = stop_event
        self.frame_delay = frame_delay
        self.game_delay = game_delay
        self.idle_delay = idle_delay
        self.imbuf = io.BytesIO()
        # Longer side -> 256 px; clamp to 1 so very elongated boards stay valid.
        if width > height:
            rx, ry = 256, max(1, int(height * 256 / width))
        else:
            rx, ry = max(1, int(width * 256 / height)), 256
        self.canvas = ipywidgets.Image(width=rx, height=ry)

    def _get_stop_event(self):
        """Return the event that controls shutdown (explicit one, else the global)."""
        # The global name is resolved on every call, exactly like the original
        # ``stopEvent.isSet()`` check, so re-binding it in the notebook is honoured.
        return self.stop_event if self.stop_event is not None else stopEvent

    def run(self):
        """Replay queued games until the stop event is set."""
        while not self._get_stop_event().is_set():
            try:
                game = self.gameQueue.pop()
            except IndexError:
                # Event.wait() instead of time.sleep(): set() wakes us immediately.
                self._get_stop_event().wait(self.idle_delay)
                continue
            for step in game:
                # Check between frames so a stop request does not have to wait for
                # the whole game to finish replaying.
                if self._get_stop_event().is_set():
                    break
                self.plot_frame(step[0])
                self._get_stop_event().wait(self.frame_delay)
            self._get_stop_event().wait(self.game_delay)

    def plot_frame(self, frame):
        """Render one board frame into the canvas as a nearest-neighbour-scaled GIF.

        ``frame.shape[0]`` is treated as rows and ``frame.shape[1]`` as columns. The
        array is multiplied by 255 before the ``uint8`` cast, so values are assumed to
        be floats in [0, 1]. TODO: confirm this against the game implementations.
        """
        rows, cols = frame.shape[0], frame.shape[1]
        # Longer side -> 256 px. PIL's resize() takes (width, height) = (cols, rows).
        if rows > cols:
            out_h, out_w = 256, max(1, int(cols * 256 / rows))
        else:
            out_h, out_w = max(1, int(rows * 256 / cols)), 256
        image = PIL.Image.fromarray((frame * 255).astype('uint8'))
        image = image.resize((out_w, out_h), resample=PIL.Image.NEAREST)
        # Rewind AND truncate: seek(0) alone leaves the tail of a previous, longer
        # encoding in the buffer, and getvalue() would return it appended.
        self.imbuf.seek(0)
        self.imbuf.truncate()
        image.save(self.imbuf, 'gif')
        self.canvas.value = self.imbuf.getvalue()

In [None]:
class HistoryPlot(rl.callbacks.Callback):
    """Placeholder epoch-end callback; not referenced by the training cell.

    NOTE(review): the base class is resolved as ``rl.callbacks.Callback`` while the
    imports cell loads the submodule ``rl.callbacks.callback``. Confirm that the
    package re-exports ``Callback``.
    """

    def __init__(self):
        """Deliberately empty: the base-class initialiser is not invoked."""

    def epoch_end(self, *args):
        """Receive the per-epoch statistics; they are currently discarded.

        The unpacking requires exactly eight positional values and raises
        ``ValueError`` otherwise.
        """
        (model, name, epoch, epsilon,
         win_ratio, avg_score, max_score, memory) = args

In [None]:
# Experiment configuration: one frame per observation on a 16x16 Catch board.
nb_frames, grid_size = 1, 16
memory_size = 4096

game = rl.games.catch.Catch(grid_size)

# Q-network: input is (nb_frames, grid_size, grid_size, 3); one linear output per action.
inp = keras.layers.Input(shape=(nb_frames, grid_size, grid_size, 3))
x = keras.layers.Conv3D(16, 5, padding='same', strides=1, activation='relu')(inp)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(32, activation='relu')(x)
act = keras.layers.Dense(game.nb_actions, activation='linear')(x)

model = keras.models.Model(inputs=inp, outputs=act)
model.compile(keras.optimizers.RMSprop(), keras.losses.LogCosh())
model.summary()

memory = rl.memory.uniqmemory.UniqMemory(memory_size=memory_size)
agent = rl.agents.ddqn.Agent(model, memory, with_target=True)

# Live replay of recorded games: GameStore fills the queue, GamePlot renders it
# until stopEvent is set. A fresh Event starts cleared, so no clear() is needed.
stopEvent = threading.Event()
gameStore = rl.callbacks.gamestore.GameStore()
plotter = GamePlot(grid_size, grid_size, gameStore.gameQueue)
display(plotter.canvas)
plotter.start()

# try/finally: training is normally stopped early with Kernel > Interrupt. Without
# this, stopEvent would never be set, the replay thread would keep polling forever,
# and re-running the cell would leave a second player running.
try:
    agent.train(game, batch_size=64, epochs=50, episodes=32, train_freq=8, target_sync=64,
                epsilon_start=1.0, epsilon_decay=0.5, epsilon_final=0.0,
                gamma=0.98, reset_memory=False, observe=128, verbose=1,
                callbacks=[gameStore])
finally:
    stopEvent.set()
    # Wait for the player to exit so no background thread outlives this cell
    # (it may first finish rendering the game it is currently replaying).
    plotter.join()