In [None]:
import time, threading, collections, io

import ipywidgets
import PIL.Image

import numpy as np
import keras
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)

import catch, snake

In [None]:
# State shared between the training loop (producer) and the GamePlot thread (consumer).
stopEvent = threading.Event()  # once set, the GamePlot replay thread leaves its loop
currentGame = collections.deque(maxlen=1)  # holds only the most recently queued game; older ones are dropped
class GamePlot(threading.Thread):
    """Background thread that replays finished games in an ``ipywidgets.Image``.

    It repeatedly pops a list of frames from ``currentGame``. It shows each frame
    for 0.1 s and pauses 0.5 s between games. It stops once ``stopEvent`` is set.
    """
    def __init__(self, ratio=1.0):
        """Create and display the image widget, then draw a blank frame.

        :param ratio: width/height ratio of the displayed image (height is 256 px).
        """
        threading.Thread.__init__(self, name="GamePlot")
        # Daemon thread: a forgotten stopEvent must not keep the kernel process alive.
        self.daemon = True
        self.ratio = ratio
        self.imbuf = io.BytesIO()
        self.img = ipywidgets.Image(width=int(ratio*256), height=256)
        display(self.img)
        self.plot_frame(np.zeros((2,2), np.uint8))

    def run(self):
        """Replay queued games until ``stopEvent`` is set."""
        while not stopEvent.is_set():
            try:
                frames = currentGame.pop()
            except IndexError:
                time.sleep(0.1)  # nothing queued yet; poll again shortly
                continue
            for frame in frames:
                if stopEvent.is_set():
                    return  # stop promptly instead of finishing the replay
                self.plot_frame(frame)
                time.sleep(0.1)
            time.sleep(0.5)  # short pause between games

    def plot_frame(self, frame):
        """Render one frame as a PNG into the image widget.

        The frame is multiplied by 255 and cast to uint8, so values are presumably
        in [0, 1] (TODO confirm against ``game.get_frame()``). The caller's array is
        not modified.
        """
        # Build a new array rather than scaling in place: the frame may be shared
        # with the recorder or the game itself.
        scaled = (np.asarray(frame) * 255).astype('uint8')
        image = PIL.Image.fromarray(scaled).resize((int(self.ratio*256), 256))
        # Rewind AND truncate: seek(0) alone would leave stale bytes from a previous,
        # longer PNG at the end of getvalue().
        self.imbuf.seek(0)
        self.imbuf.truncate()
        image.save(self.imbuf, 'png')
        self.img.value = self.imbuf.getvalue()
        
class GameStore:
    """Training callback that records each episode and queues it for on-screen replay.

    Constructing it starts a :class:`GamePlot` thread. ``game_start`` and
    ``game_frame`` collect the frames of the running episode. ``game_over``
    publishes them through the ``currentGame`` deque for the plotter to pick up.
    """
    def __init__(self, ratio=1.0):
        # Queue an empty game first; the plotter pops it and replays nothing.
        currentGame.append([])
        self.plotter = GamePlot(ratio)
        self.plotter.start()

    def game_start(self, frame):
        """Begin a fresh recording whose first frame is ``frame``."""
        recording = []
        recording.append(frame)
        self.gamestore = recording

    def game_frame(self, frame):
        """Append ``frame`` to the current recording."""
        self.gamestore += [frame]

    def game_over(self):
        """Hand the finished recording to the plotter (the deque keeps only the newest game)."""
        finished = self.gamestore
        currentGame.append(finished)

In [None]:
# Environment: the Catch game on a square grid, one frame per state.
grid_size = 12
width = height = grid_size
nb_frames = 1
game = catch.Catch(grid_size, movement_cost=-0.01)

# Alternative environment (uncomment to train on Snake instead):
#grid_size, width, height, nb_frames = 10, 10, 10, 1
#game = snake.Snake(grid_size)

learning_rate = 0.001  # passed to RMSprop when the policy models are compiled
def loss_fn(y_true, y_pred):
    """Policy-gradient loss: negative sum of ``y_true``-weighted log-probabilities.

    As used by ``train``, ``y_true`` holds the discounted return in the column of
    the taken action (zeros elsewhere). ``y_pred`` is the softmax output.
    The 1e-20 offset keeps ``log`` finite when a probability is 0.
    The result is summed over the whole batch (not averaged).
    """
    K = keras.backend
    log_probs = K.log(y_pred + 1e-20)
    return -K.sum(y_true * log_probs)

def model_dnn():
    """Build, compile and summarise a small fully-connected policy network.

    Input: ``nb_frames`` RGB frames of ``height`` x ``width`` (module-level config).
    Hidden: two 16-unit ReLU layers.
    Output: softmax over ``game.nb_actions``.
    The model is compiled with ``loss_fn`` and RMSprop at ``learning_rate``.
    """
    state_in = keras.layers.Input(shape=(nb_frames, height, width, 3), name='state_input')
    hidden = keras.layers.Flatten(name='flat')(state_in)
    for layer_name in ('dense1', 'dense2'):
        hidden = keras.layers.Dense(16, activation='relu', name=layer_name)(hidden)
    policy = keras.layers.Dense(game.nb_actions, activation='softmax', name='actions')(hidden)

    net = keras.models.Model(inputs=state_in, outputs=policy, name='PG_DNN')
    net.compile(optimizer=keras.optimizers.RMSprop(learning_rate), loss=loss_fn)
    net.summary()
    return net

def model_rcnn():
    """Build, compile and summarise a convolutional + recurrent policy network.

    A per-frame conv base (two stride-2 3x3 convs, 32 then 64 filters, flattened) is
    applied to each of the ``nb_frames`` input frames via ``TimeDistributed``.
    A 32-unit SimpleRNN then runs over the sequence, followed by a softmax over
    ``game.nb_actions``.
    The model is compiled with ``loss_fn`` and RMSprop at ``learning_rate``.
    """
    base_filters = 32
    frame_in = keras.layers.Input(shape=(height, width, 3), name='conv_input')
    features = frame_in
    for filters, layer_name in ((base_filters, 'conv1'), (base_filters*2, 'conv2')):
        features = keras.layers.Conv2D(filters, 3, padding='same', strides=2,
                                       activation='relu', name=layer_name)(features)
    features = keras.layers.Flatten(name='flatten')(features)
    conv_base = keras.models.Model(inputs=frame_in, outputs=features, name='CONV_BASE')
    conv_base.summary()

    rnn_units = base_filters
    state_in = keras.layers.Input(shape=(nb_frames, height, width, 3), name='state_input')
    seq = keras.layers.TimeDistributed(conv_base, name='conv_distributed')(state_in)
    seq = keras.layers.SimpleRNN(rnn_units, return_sequences=False, name='rnn')(seq)
    policy = keras.layers.Dense(game.nb_actions, activation='softmax', name='actions')(seq)

    net = keras.models.Model(inputs=state_in, outputs=policy, name='PG_RCNN')
    net.compile(loss=loss_fn, optimizer=keras.optimizers.RMSprop(learning_rate))
    net.summary()
    return net

def _notify(callbacks, method, *args):
    """Call ``method(*args)`` on every callback, ignoring return values.

    A plain loop is used on purpose: ``all(...)`` would stop at the first callback
    that returns a falsy value such as ``None``.
    """
    for callback in callbacks:
        getattr(callback, method)(*args)


def _discounted_returns(rewards, gamma):
    """Return ``G_t = sum_{k>=t} gamma**(k-t) * r_k`` for every step ``t`` of an episode.

    Uses the backward recursion ``G_t = r_t + gamma * G_{t+1}`` (O(T) instead of O(T^2)).
    """
    returns = np.zeros(len(rewards))
    running = 0.0
    for t in range(len(rewards) - 1, -1, -1):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def train(model, game, episodes=512, log_freq = 100, gamma=0.98, callbacks=None):
    """Train a policy network with REINFORCE (Monte-Carlo policy gradient).

    Each episode is played by sampling actions from ``model.predict``. Afterwards
    the model is fitted once on the whole episode: the target for each step is the
    discounted return placed in the column of the action taken.

    :param model: compiled model; ``predict`` receives a ``(1, 1, *state.shape)``
        array and must return one probability per action in row 0.
    :param game: environment providing ``reset``, ``get_frame``, ``get_state``,
        ``play(action) -> (state, reward, done)``, ``is_won``, ``nb_actions``
        and ``max_turn``.
    :param episodes: number of episodes to play.
    :param log_freq: print win rate and mean loss every ``log_freq`` episodes.
    :param gamma: reward discount factor.
    :param callbacks: objects with ``game_start(frame)``, ``game_frame(frame)`` and
        ``game_over()`` hooks (e.g. ``GameStore``); every one of them is notified.
    """
    callbacks = list(callbacks) if callbacks is not None else []
    win_stats = []
    loss_stats = []

    for episode in range(episodes):
        game.reset()
        _notify(callbacks, 'game_start', game.get_frame())
        curr_state = game.get_state()
        done = False
        transitions = []  # (state, action, reward) for each step

        for t in range(game.max_turn):  # while in episode
            act_prob = model.predict(np.expand_dims(np.asarray([curr_state]), axis=0))
            # Renormalise in float64: a float32 softmax may not sum to exactly 1.
            probs = np.asarray(act_prob[0], dtype=np.float64)
            probs = probs / probs.sum()
            action = np.random.choice(game.nb_actions, p=probs)
            prev_state = curr_state
            curr_state, reward, done = game.play(action)
            _notify(callbacks, 'game_frame', game.get_frame())
            transitions.append((prev_state, action, reward))
            if done:
                _notify(callbacks, 'game_over')
                break

        # Every episode counts toward the win rate; hitting max_turn counts as not won.
        win_stats.append(1 if done and game.is_won() else 0)

        # Optimize policy network with the full episode.
        if transitions:
            states, actions, rewards = zip(*transitions)
            ep_len = len(transitions)
            targets = np.zeros((ep_len, game.nb_actions))
            targets[np.arange(ep_len), np.asarray(actions)] = _discounted_returns(rewards, gamma)
            # Add the per-sample time axis (nb_frames == 1) expected by the model input.
            train_states = np.asarray([[state] for state in states])
            loss_stats.append(model.train_on_batch(train_states, targets))

        if len(win_stats) >= log_freq:
            mean_loss = sum(loss_stats) / float(len(loss_stats)) if loss_stats else float('nan')
            print("Episode {: 6d} Win perc {: 4.2%} Loss {: 2.8f}".format(
                episode+1, sum(win_stats)/float(len(win_stats)), mean_loss)
            )
            win_stats = []
            loss_stats = []
        

In [None]:
%%time
model = model_rcnn()
game.max_turn=16
stopEvent.clear()
train(model, game, episodes=10000, log_freq=1000, gamma=0.95, callbacks = [GameStore(ratio=width/height)])
stopEvent.set()
print()