In [None]:
import os, sys, io, threading, collections, time
module_path = os.path.abspath(os.path.join('../..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
import ipywidgets
import PIL, PIL.Image
import tensorflow as tf
import tensorflow.keras as keras
tf.compat.v1.disable_eager_execution()
tf.get_logger().setLevel('ERROR')

import KerasTools.rl.games.catch
import KerasTools.rl.games.fruit
import KerasTools.rl.games.snake
import KerasTools.rl.games.tromis
import KerasTools.rl.tools.gui
import KerasTools.rl.callbacks.callbacks
import KerasTools.rl.agents.ddqn
import KerasTools.rl.memory.basicmemory
import KerasTools.rl.memory.uniqmemory

In [None]:
class GameStore(KerasTools.rl.callbacks.callbacks.Callback):
    """Training callback that records episodes and publishes the latest finished one.

    Each episode is a list of ``(frame, action, reward, isOver)`` tuples.
    The first entry is written by ``game_start`` with ``action=None``,
    ``reward=0.0`` and ``isOver=False``. On ``game_over`` the episode is pushed
    into ``gameQueue`` (after clearing it), where a consumer such as
    ``GamePlot`` pops it.

    Parameters
    ----------
    gameQueue : collections.deque
        Shared container for finished episodes; it is cleared on construction.
    """
    def __init__(self, gameQueue):
        self.gameQueue = gameQueue
        self.gameQueue.clear()
        # Created up front so game_step/game_over never raise AttributeError,
        # even if they are invoked before the first game_start.
        self.episode = []

    def game_start(self, frame):
        # NOTE(review): frames are stored by reference. If a game reuses and
        # mutates one frame buffer, every stored step would show the same
        # state. Confirm against the game implementations.
        self.episode = [(frame, None, 0.0, False)]

    def game_step(self, frame, action, reward, isOver):
        self.episode.append((frame, action, reward, isOver))

    def game_over(self):
        # Nothing recorded since the last publish: leave the queue alone.
        if not self.episode:
            return
        # Keep only the most recent episode; clear() also covers an unbounded deque.
        self.gameQueue.clear()
        self.gameQueue.append(self.episode)
        # Hand the published list off to the consumer thread. Later steps go
        # into a fresh list instead of mutating the one being replayed.
        self.episode = []
        
class GamePlot(threading.Thread):
    """Background thread that replays the latest stored episode in an ipywidgets.Image.

    Episodes are popped from ``gameQueue`` (as filled by ``GameStore``). Each
    episode is a sequence of tuples whose first element is the frame to draw.

    Parameters
    ----------
    width, height : int
        Game board size. Used only to size the canvas so its longer side is 256 px.
    gameQueue : collections.deque
        Shared episode container; this thread pops from it.
    stopEvent : threading.Event, optional
        Event that ends the replay loop. If omitted, the module-level
        ``stopEvent`` is looked up when ``run`` starts (original behaviour).
    """
    FRAME_DELAY = 0.1  # seconds between frames of an episode
    GAME_DELAY = 0.5   # pause after the last frame of an episode
    IDLE_DELAY = 0.1   # poll interval while no episode is queued

    def __init__(self, width, height, gameQueue, stopEvent=None):
        super().__init__(name="GamePlot")
        self.gameQueue = gameQueue
        self.stopEvent = stopEvent
        self.imbuf = io.BytesIO()
        rx, ry = self._fit(width, height)
        # plot_frame encodes GIF, so declare that format rather than the default png.
        self.canvas = ipywidgets.Image(format='gif', width=rx, height=ry)

    @staticmethod
    def _fit(width, height, size=256):
        """Return ``(w, h)`` pixels fitting ``width x height`` into ``size x size``, keeping aspect."""
        if width > height:
            return size, max(1, int(height * size / width))
        return max(1, int(width * size / height)), size

    def run(self):
        """Replay queued episodes until the stop event is set."""
        # Fall back to the notebook-level global for callers that do not pass one.
        stop = self.stopEvent if self.stopEvent is not None else stopEvent
        while not stop.is_set():
            try:
                game = self.gameQueue.pop()
            except IndexError:
                # Event.wait instead of time.sleep so a stop request is seen immediately.
                stop.wait(self.IDLE_DELAY)
                continue
            for step in game:
                self.plot_frame(step[0])
                if stop.wait(self.FRAME_DELAY):
                    return  # stopped mid-episode: do not finish the replay
            stop.wait(self.GAME_DELAY)

    def plot_frame(self, frame):
        """Render one frame into the canvas as a nearest-neighbour upscaled GIF.

        ``frame`` is indexed as ``(rows, cols, ...)`` and multiplied by 255
        before the cast to uint8. It is presumably a float array in [0, 1];
        TODO confirm against the game classes.
        """
        rows, cols = frame.shape[0], frame.shape[1]
        image = PIL.Image.fromarray((frame * 255).astype('uint8'))
        image = image.resize(self._fit(cols, rows), resample=PIL.Image.NEAREST)
        # Rewind AND truncate. Without truncate(), a GIF shorter than the
        # previous one leaves stale trailing bytes in getvalue().
        self.imbuf.seek(0)
        self.imbuf.truncate()
        image.save(self.imbuf, 'gif')
        self.canvas.value = self.imbuf.getvalue()

In [None]:
def build_model(nb_frames, width, height, nb_actions):
    """Build and compile a 3-D convolutional network with one output per action.

    Input: a stack of ``nb_frames`` RGB frames, shape
    ``(nb_frames, height, width, 3)``.

    Layers:

    - ``Conv3D(32)`` then ``Conv3D(64)``, each with 3x3x3 kernels,
      'same' padding, stride 1 and ReLU.
    - ``Flatten``, then ``Dense(128, relu)``.
    - ``Dense(nb_actions, linear)``, presumably Q-values for the DDQN agent.

    Compiled with RMSprop and log-cosh loss.
    """
    frames = keras.layers.Input(shape=(nb_frames, height, width, 3))

    features = frames
    for n_filters in (32, 64):
        conv = keras.layers.Conv3D(n_filters, 3, padding='same', strides=1, activation='relu')
        features = conv(features)

    flat = keras.layers.Flatten()(features)
    hidden = keras.layers.Dense(128, activation='relu')(flat)
    q_values = keras.layers.Dense(nb_actions, activation='linear')(hidden)

    model = keras.models.Model(inputs=frames, outputs=q_values)
    model.compile(optimizer=keras.optimizers.RMSprop(), loss=keras.losses.LogCosh())
    return model

def build_drqn(nb_frames, width, height, nb_actions):
    """Build and compile a recurrent (per-frame conv encoder + GRU) network.

    - Each of the ``nb_frames`` RGB frames (``height x width x 3``) goes
      through a shared encoder via ``TimeDistributed``. The encoder is two
      3x3 ``Conv2D`` layers (16 and 32 filters, ReLU, 'same' padding)
      followed by global max pooling.
    - A ``GRU(32)`` summarises the sequence.
    - A linear ``Dense`` layer emits one value per action.

    Compiled with RMSprop and log-cosh loss.
    """
    def frame_encoder():
        # Shared per-frame feature extractor, wrapped as its own sub-model.
        frame = keras.layers.Input(shape=(height, width, 3))
        h = frame
        for n_filters in (16, 32):
            h = keras.layers.Conv2D(n_filters, (3, 3), activation='relu', padding='same')(h)
        pooled = keras.layers.GlobalMaxPooling2D()(h)
        return keras.models.Model(inputs=frame, outputs=pooled)

    # The encoder is built before the outer Input so auto-generated layer names
    # match the original construction order.
    encoder = frame_encoder()

    frames = keras.layers.Input(shape=(nb_frames, height, width, 3))
    per_frame = keras.layers.TimeDistributed(encoder)(frames)
    summary = keras.layers.GRU(32, return_sequences=False)(per_frame)
    q_values = keras.layers.Dense(nb_actions, activation='linear')(summary)

    model = keras.models.Model(inputs=frames, outputs=q_values)
    model.compile(optimizer=keras.optimizers.RMSprop(), loss=keras.losses.LogCosh())
    return model


In [None]:
# --- Configuration ---------------------------------------------------------
# Catch / Fruit / Snake use a square grid_size x grid_size board.
nb_frames, grid_size = 1, 12
width = height = grid_size
memory_size = 4096
game = KerasTools.rl.games.catch.Catch(grid_size, with_penalty=False, split_reward=False)
#game = KerasTools.rl.games.fruit.Fruit(grid_size, max_turn=grid_size*2, fixed=False, with_border=True, with_poison=True, with_penalty=False)
#game = KerasTools.rl.games.snake.Snake(grid_size, max_turn=64)

# Tromis has an asymmetric game board; uncomment these lines to use it instead.
#nb_frames, width, height = 4, 6, 9
#memory_size = 8196  # NOTE(review): possibly meant 8192
#game = KerasTools.rl.games.tromis.Tromis(width, height, max_turn=128)
#model = build_drqn(nb_frames, width, height, game.nb_actions)

# Comment this out when switching to the Tromis / build_drqn lines above.
model = build_model(nb_frames, width, height, game.nb_actions)

model.summary()
m = KerasTools.rl.memory.uniqmemory.UniqMemory(memory_size=memory_size)
a = KerasTools.rl.agents.ddqn.Agent(model=model, mem=m, with_target=True)

# --- Live replay of the latest episode -------------------------------------
# GameStore (training callback) puts each finished episode into a one-slot
# deque; GamePlot (background thread) pops it and animates it.
stopEvent = threading.Event()  # module-level name: GamePlot.run polls it
gameQueue = collections.deque(maxlen=1)
gameStore = GameStore(gameQueue)
plotter = GamePlot(width, height, gameQueue)
display(plotter.canvas)
plotter.start()

# --- Training --------------------------------------------------------------
try:
    a.train(game, batch_size=64, epochs=50, episodes=32, target_sync=512,
            epsilon_start=1.0, epsilon_decay=0.5, epsilon_final=0.0,
            gamma=0.98, reset_memory=False, observe=256, verbose=1,
            callbacks=[gameStore])
finally:
    # Stop the replay thread even if training fails or is interrupted
    # (KeyboardInterrupt), so it does not keep running in the kernel.
    stopEvent.set()
    plotter.join()