In [None]:
import os, io, threading, time

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

import ipywidgets
import PIL, PIL.Image
import bqplot, bqplot.pyplot
import numpy as np

import tensorflow as tf
import tensorflow.keras as keras
tf.get_logger().setLevel('ERROR')

import rl.games.catch
import rl.games.fruit
import rl.games.snake
import rl.games.tromis
import rl.tools.gui
import rl.callbacks.callbacks
import rl.callbacks.gamestore
import rl.agents.ddqn
import rl.memory.basicmemory
import rl.memory.uniqmemory

In [None]:
class GamePlot(threading.Thread):
    """Background thread that replays finished games in an ``ipywidgets.Image``.

    Games are popped from ``gameQueue`` (``pop()`` takes the last item) and each
    frame is rendered into ``self.canvas`` as a GIF scaled so that its longer
    side is ``display_size`` pixels.

    The loop runs until the stop event is set. By default this is the
    notebook-global ``stopEvent``, looked up when ``run`` starts. Pass
    ``stop_event`` to avoid depending on that global.
    """

    display_size = 256  # pixels on the longer side of the rendered image
    frame_delay = 0.1   # seconds between frames of a replayed game
    game_delay = 0.5    # pause in seconds after a game finishes
    poll_delay = 0.1    # wait in seconds when the queue is empty

    def __init__(self, width, height, gameQueue, stop_event=None):
        """
        Args:
            width, height: grid dimensions used for the initial blank frame and
                the initial canvas size.
            gameQueue: list-like object whose ``pop()`` raises ``IndexError`` when
                empty. Each item is an iterable of frames; element 0 of each frame
                is rendered (assumed float image data in [0, 1] -- TODO confirm
                against GameStore).
            stop_event: optional ``threading.Event``. When None, the global
                ``stopEvent`` is used.
        """
        # daemon=True so a thread that was never stopped cannot block kernel shutdown.
        threading.Thread.__init__(self, name="GamePlot", daemon=True)
        self.gameQueue = gameQueue
        self.stop_event = stop_event
        self.imbuf = io.BytesIO()
        rx, ry = self._scaled_size(width, height, self.display_size)
        # Frames are encoded as GIF, so declare that format (widget default is png).
        self.canvas = ipywidgets.Image(format='gif', width=rx, height=ry)
        self.plot_frame(np.zeros((width, height, 3)))

    @staticmethod
    def _scaled_size(a, b, box):
        """Scale (a, b) so the larger of the two equals ``box``, keeping the ratio."""
        if a > b:
            return box, int(b * box / a)
        return int(a * box / b), box

    def run(self):
        """Replay queued games until the stop event is set."""
        stop = self.stop_event if self.stop_event is not None else stopEvent
        # Event.wait() is used as an interruptible sleep: it returns as soon as
        # the event is set, so the thread exits promptly instead of finishing
        # the current game first.
        while not stop.is_set():
            try:
                game = self.gameQueue.pop()
            except IndexError:
                stop.wait(self.poll_delay)
                continue
            for frame in game:
                if stop.is_set():
                    break
                self.plot_frame(frame[0])
                stop.wait(self.frame_delay)
            stop.wait(self.game_delay)

    def plot_frame(self, frame):
        """Render one frame into ``self.canvas`` as a GIF.

        ``frame`` is read by PIL as (rows, cols, channels). Values are multiplied
        by 255 and clipped to [0, 255] before the uint8 conversion.
        """
        rows, cols = frame.shape[0], frame.shape[1]
        out_rows, out_cols = self._scaled_size(rows, cols, self.display_size)
        pixels = np.clip(np.asarray(frame) * 255, 0, 255).astype('uint8')
        # PIL sizes are (width, height) == (cols, rows).
        image = PIL.Image.fromarray(pixels).resize((out_cols, out_rows), resample=PIL.Image.NEAREST)
        # Rewind AND truncate. Without truncate(), a shorter GIF would leave stale
        # bytes from the previous frame at the end of getvalue().
        self.imbuf.seek(0)
        self.imbuf.truncate()
        image.save(self.imbuf, 'gif')
        # Keep the widget's declared size equal to the rendered image so that
        # non-square frames are not stretched.
        self.canvas.width, self.canvas.height = image.size
        self.canvas.value = self.imbuf.getvalue()

In [None]:
class HistoryPlot(rl.callbacks.callbacks.Callback):
    """Training callback that live-plots the average score per epoch with bqplot.

    Display ``self.avg_plt``. Each ``epoch_end`` call appends the latest values
    and redraws the line.
    """

    def __init__(self, epochs):
        """
        Args:
            epochs: total number of training epochs; used as the x-scale maximum.
        """
        self.win_ratio = []  # per-epoch win ratio (recorded, not plotted)
        self.avg_score = []  # per-epoch average score (plotted)
        self.epochs = epochs
        self.axes_avg = {
            'x': {'label': 'Epochs'},
            'y': {'label': 'Avg Score', 'label_offset': '50px', 'tick_style': {'font-size': 11}},
        }
        # bqplot.pyplot is stateful: figure() makes this figure current, so the
        # scales() and plot() calls below attach to it.
        self.avg_plt = bqplot.pyplot.figure()
        self.avg_plt.layout = {"height": "256px", "width": "512px"}
        self.avg_plt.fig_margin = {"top": 10, "bottom": 30, "left": 60, "right": 0}
        # Pin the aspect ratio to the 512x256 layout above.
        self.avg_plt.min_aspect_ratio = 512.0 / 256.0
        self.avg_plt.max_aspect_ratio = 512.0 / 256.0
        bqplot.pyplot.scales(scales={'x': bqplot.scales.LinearScale(min=0, max=self.epochs)})
        # Single-point placeholder line; epoch_end replaces its data.
        bqplot.pyplot.plot([0], [0.0], axes_options=self.axes_avg)

    def epoch_end(self, *args):
        """Record this epoch's statistics and redraw the average-score line.

        Expects exactly 8 positional arguments:
        (model, name, epoch, epsilon, win_ratio, avg_score, max_score, memory).
        """
        _model, _name, _epoch, _epsilon, win_ratio, avg_score, _max_score, _memory = args
        self.win_ratio.append(win_ratio)
        self.avg_score.append(avg_score)
        line = self.avg_plt.marks[0]
        # hold_sync batches both trait updates into one message, so the frontend
        # never sees x and y with mismatched lengths.
        with line.hold_sync():
            line.x = np.arange(len(self.avg_score))
            line.y = np.asarray(self.avg_score)

In [None]:
# --- Configuration ------------------------------------------------------------
nb_frames, grid_size = 2, 12  # frames stacked per network input; square grid side
memory_size = 65536           # passed to BasicMemory(memory_size=...)
epochs = 100

# Alternative environments:
#game = rl.games.catch.Catch(grid_size, split_reward=True)
#game = rl.games.fruit.Fruit(grid_size, with_poison=True)
game = rl.games.snake.Snake(grid_size, max_turn=96)

# --- Q-network: (nb_frames, grid, grid, 3) input -> one linear output per action
inp = keras.layers.Input(shape=(nb_frames, grid_size, grid_size, 3))
x = keras.layers.Conv3D(32, 7, padding='same', strides=1, activation='relu')(inp)
x = keras.layers.Conv3D(64, 3, padding='same', strides=1, activation='relu')(x)
x = keras.layers.GlobalMaxPooling3D()(x)
x = keras.layers.Dense(128, activation='relu')(x)
act = keras.layers.Dense(game.nb_actions, activation='linear')(x)

model = keras.models.Model(inputs=inp, outputs=act)
model.compile(keras.optimizers.RMSprop(), keras.losses.LogCosh())
model.summary()

memory = rl.memory.basicmemory.BasicMemory(memory_size=memory_size)
agent = rl.agents.ddqn.Agent(model, memory, with_target=True)

# --- Live visualisation -------------------------------------------------------
stopEvent = threading.Event()  # GamePlot.run loops until this is set
gameStore = rl.callbacks.gamestore.GameStore()
plotter = GamePlot(grid_size, grid_size, gameStore.gameQueue)
histPlot = HistoryPlot(epochs)

display(ipywidgets.HBox([plotter.canvas, histPlot.avg_plt]))

stopEvent.clear()
plotter.start()

# --- Training -----------------------------------------------------------------
# try/finally so the replay thread is stopped even when training raises or the
# kernel is interrupted; otherwise it keeps running in the background, and
# re-running this cell would start another one.
try:
    agent.train(game, batch_size=256, epochs=epochs, episodes=50, train_freq=32, target_sync=512,
                epsilon_start=0.5, epsilon_decay=0.5, epsilon_final=0.0,
                gamma=0.95, reset_memory=False, observe=100, verbose=1,
                callbacks=[gameStore, histPlot])
finally:
    stopEvent.set()