In [None]:
import os, io, threading, time

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

import ipywidgets
import PIL, PIL.Image
import bqplot, bqplot.pyplot
import numpy as np

import tensorflow as tf
import tensorflow.keras as keras
tf.get_logger().setLevel('ERROR')

import rl.games.catch
import rl.games.catch_basic
import rl.games.fruit
import rl.games.snake
import rl.games.tromis
import rl.tools.gui
import rl.callbacks.callbacks
import rl.callbacks.gamestore
import rl.agents.ddqn
import rl.memory.basicmemory
import rl.memory.uniqmemory

In [None]:
class GamePlot(threading.Thread):
    """Background thread that replays recorded games on an ``ipywidgets.Image``.

    Games are popped from ``gameQueue`` (anything whose ``pop()`` raises
    ``IndexError`` when empty); for each item of a game, ``item[0]`` is
    rendered as a GIF and pushed to ``self.canvas``.

    Args:
        width: size of axis 0 of the frame arrays (used to size the canvas
            and the initial blank frame).
        height: size of axis 1 of the frame arrays.
        gameQueue: source of recorded games.
        stop_event: optional ``threading.Event`` that ends ``run``. When
            omitted, the module-level ``stopEvent`` is used instead, so it
            must exist by the time ``start()`` is called.
    """

    def __init__(self, width, height, gameQueue, stop_event=None):
        # daemon=True so a replay loop that was never stopped cannot keep the
        # process alive.
        threading.Thread.__init__(self, name="GamePlot", daemon=True)
        self.gameQueue = gameQueue
        self.stop_event = stop_event
        self.imbuf = io.BytesIO()
        # plot_frame builds a PIL image `cols` pixels wide and `rows` pixels
        # tall; size the widget identically so non-square boards aren't stretched.
        rows, cols = self._scaled_size(width, height)
        self.canvas = ipywidgets.Image(width=cols, height=rows)
        self.plot_frame(np.zeros((width, height, 3)))

    @staticmethod
    def _scaled_size(fx, fy, target=256):
        """Scale (fx, fy) so the larger side equals ``target``, keeping the aspect ratio.

        Returns:
            (rows, cols) of the rendered image, as ints.
        """
        if fx > fy:
            return target, int(fy * target / fx)
        return int(fx * target / fy), target

    def run(self):
        """Replay queued games until the stop event is set."""
        stop = self.stop_event if self.stop_event is not None else stopEvent
        while not stop.is_set():
            try:
                game = self.gameQueue.pop()
            except IndexError:
                # Nothing recorded yet: poll again shortly (wakes early on stop).
                stop.wait(0.1)
                continue
            for frame in game:
                self.plot_frame(frame[0])
                # Event.wait returns True once the event is set, so a stop
                # request interrupts a long replay instead of waiting for it.
                if stop.wait(0.1):
                    return
            stop.wait(0.5)

    def plot_frame(self, frame):
        """Render one frame as a GIF and display it on ``self.canvas``.

        Args:
            frame: array of shape (fx, fy, 3). Values are scaled by 255, so
                they are presumably in [0, 1]; out-of-range values are clipped
                rather than wrapping around in the uint8 cast.
        """
        rows, cols = self._scaled_size(frame.shape[0], frame.shape[1])
        pixels = (np.clip(frame, 0.0, 1.0) * 255).astype('uint8')
        # PIL's resize takes (width, height) = (cols, rows).
        image = PIL.Image.fromarray(pixels).resize((cols, rows), resample=PIL.Image.NEAREST)
        # Rewind AND truncate: BytesIO.getvalue() returns the whole buffer, so a
        # shorter GIF would otherwise carry stale bytes from the previous frame.
        self.imbuf.seek(0)
        self.imbuf.truncate()
        image.save(self.imbuf, 'gif')
        self.canvas.value = self.imbuf.getvalue()

In [None]:
class HistoryPlot(rl.callbacks.callbacks.Callback):
    """Training callback that live-plots one per-epoch statistic with bqplot.

    Args:
        epochs: number of training epochs; fixes the x-axis range to [0, epochs].
        stat: key read from the ``stats`` mapping passed to ``epoch_end``
            (e.g. ``'win_ratio'``); also used as the y-axis label.
    """

    def __init__(self, epochs, stat):
        # NOTE(review): the base Callback.__init__ is not called (unchanged from
        # the original) -- confirm the base class needs no initialisation.
        self.plot_value = []
        self.epochs = epochs
        self.stat = stat
        self.axes = {'x': {'label': 'Epochs'},
                     'y': {'label': self.stat, 'label_offset': '50px', 'tick_style': {'font-size': 11}}}
        self.hist_plt = bqplot.pyplot.figure()
        self.hist_plt.layout = {"height": "256px", "width": "512px"}
        self.hist_plt.fig_margin = {"top": 10, "bottom": 30, "left": 60, "right": 0}
        # Pin the aspect ratio to match the 512x256 layout above.
        self.hist_plt.min_aspect_ratio = 512.0 / 256.0
        self.hist_plt.max_aspect_ratio = 512.0 / 256.0
        bqplot.pyplot.scales(scales={'x': bqplot.scales.LinearScale(min=0, max=self.epochs)})
        # Seed the figure with a single placeholder point; epoch_end replaces it.
        bqplot.pyplot.plot([0], [0.0], axes_options=self.axes)

    def epoch_end(self, stats):
        """Record ``stats[self.stat]`` for the finished epoch and redraw the curve.

        Raises:
            KeyError: if ``stats`` has no entry for ``self.stat``.
        """
        self.plot_value.append(stats[self.stat])
        line = self.hist_plt.marks[0]
        # Push x and y in one sync so the front end never sees arrays of
        # different lengths between the two assignments.
        with line.hold_sync():
            line.x = np.arange(len(self.plot_value))
            line.y = np.asarray(self.plot_value)

In [None]:
# --- Configuration -----------------------------------------------------------
nb_frames, grid_size = 2, 16   # frames per model input, board side length
memory_size = 65536            # passed to UniqMemory as memory_size
epochs = 30

# Environment. Other games available in this package:
#   rl.games.catch_basic.Catch(grid_size)
#   rl.games.fruit.Fruit(grid_size, with_poison=True)
#   rl.games.snake.Snake(grid_size, max_turn=96)
game = rl.games.catch.Catch(grid_size, split_reward=True, with_penalty=True, hop=0.2)

# --- Q-network: stacked RGB frames -> one linear value per action --------------
inp = keras.layers.Input(shape=(nb_frames, grid_size, grid_size, 3))
x = keras.layers.Conv3D(16, 3, padding='same', strides=1, activation='relu')(inp)
x = keras.layers.AveragePooling3D(padding='same')(x)
x = keras.layers.Flatten()(x)
x = keras.layers.Dense(32, activation='relu')(x)
act = keras.layers.Dense(game.nb_actions, activation='linear')(x)

model = keras.models.Model(inputs=inp, outputs=act)
model.compile(keras.optimizers.RMSprop(), keras.losses.LogCosh())
model.summary()

memory = rl.memory.uniqmemory.UniqMemory(memory_size=memory_size)
agent = rl.agents.ddqn.Agent(model, memory, with_target=True)

# --- Live visualisation ----------------------------------------------------------
stopEvent = threading.Event()   # read by GamePlot.run; a fresh Event starts unset
gameStore = rl.callbacks.gamestore.GameStore()
plotter = GamePlot(grid_size, grid_size, gameStore.gameQueue)
histPlot = HistoryPlot(epochs, 'win_ratio')

display(ipywidgets.HBox([plotter.canvas, histPlot.hist_plt]))

plotter.start()
try:
    agent.train(game, batch_size=256, epochs=epochs, episodes=100, train_freq=8, target_sync=512,
                epsilon_start=1.0, epsilon_decay=0.5, epsilon_final=0.01,
                gamma=0.99, reset_memory=False, observe=100, verbose=1,
                callbacks=[gameStore, histPlot])
finally:
    # Always stop the replay thread, even if training raises or the cell is
    # interrupted; GamePlot.run loops until this event is set.
    stopEvent.set()

In [None]:
# Persist the trained model to catch_dqn.h5, relative to the current working directory.
model.save(filepath="catch_dqn.h5")