In [1]:
from IPython.display import display, clear_output
import numpy as np
import pyopencl as cl
from PIL import Image
import time
import signal
import re

In [35]:
from skvideo.io import FFmpegWriter
from IPython.display import HTML
import base64
import io

In [2]:
# Interrupt flag: flipped to True by the SIGINT handler and polled by the
# long-running simulation loop so it can exit cleanly.
signal_done = False

def signal_handler(signal, frame):
    """SIGINT callback: record that an interrupt was requested.

    Both arguments (signal number and current stack frame, as passed by the
    ``signal`` module) are ignored.
    """
    globals()["signal_done"] = True

def stop_on_signal():
    """Clear the interrupt flag, then route SIGINT to ``signal_handler``."""
    globals()["signal_done"] = False
    signal.signal(signal.SIGINT, signal_handler)

In [3]:
# OpenCL objects shared by every Buffer and World defined below.
mf = cl.mem_flags                # shorthand for the buffer access/copy flags
ctx = cl.create_some_context()   # pyopencl selects the platform/device
queue = cl.CommandQueue(ctx)     # single queue used for all kernels and copies

In [4]:
def exp_prob(one_step_prob, n_steps):
    """Return the chance of at least one success in ``n_steps`` trials.

    Each trial succeeds with probability ``one_step_prob``; the result is the
    complement of failing every time, i.e. ``1 - (1 - p)**n``.
    """
    miss_every_step = (1.0 - one_step_prob)**n_steps
    return 1.0 - miss_every_step

In [5]:
class Buffer:
    """A host numpy array paired with one (optionally two) OpenCL buffers.

    Attributes:
        ro: whether the device buffer was created with ``READ_ONLY`` access.
        dual: whether a second device buffer (``dbuf``) was allocated.
        host: the numpy array used to initialise the device buffer(s) and
            filled again by ``load``.
        buf: the active device buffer.
        dbuf: the second device buffer; only present when ``dual`` is true.
    """

    def __init__(self, nparray, ro=False, dual=False):
        """Create the device buffer(s), initialised from ``nparray``."""
        self.ro = ro
        self.dual = dual
        self.host = nparray
        access = mf.READ_ONLY if ro else mf.READ_WRITE
        # COPY_HOST_PTR: the device buffer starts as a copy of the host data.
        flags = access | mf.COPY_HOST_PTR
        self.buf = cl.Buffer(ctx, flags, hostbuf=self.host)
        if dual:
            self.dbuf = cl.Buffer(ctx, flags, hostbuf=self.host)

    def swap(self):
        """Exchange ``buf`` and ``dbuf``; does nothing for single buffers."""
        if not self.dual:
            return
        self.buf, self.dbuf = self.dbuf, self.buf

    def load(self):
        """Copy the active device buffer back into ``host``."""
        cl.enqueue_copy(queue, self.host, self.buf)

In [295]:
class World:
    """Simulation state plus the OpenCL kernels that advance and render it.

    All data lives in ``self.buffers`` (name -> ``Buffer``). Buffers named
    ``w_*`` are shaped per world cell (``w_shape``), buffers named ``a_*`` are
    shaped per agent (``a_shape``). The kernels come from ``simple-rnn.cl``,
    compiled after the values in ``self.constants`` have been appended to the
    matching ``#define`` lines of that file.

    Attributes:
        size: ``(size_x, size_y)`` as passed in; also used as the global work
            size of the per-cell kernels.
        n_agents: ``(n_agents,)``; global work size of the per-agent kernels.
        param: the parameter dict passed to the constructor.
        step_count: number of completed ``step`` calls.
        avg_score, avg_varexp, plants_count: statistics refreshed only by
            ``fetch_stats``.
    """

    def __init__(self, size, n_agents, param, net=None):
        """Allocate every buffer and build the OpenCL program.

        Args:
            size: ``(size_x, size_y)`` extent of the world in cells.
            n_agents: number of agents.
            param: dict of simulation parameters; the keys read are the ones
                used below and in ``step``.
            net: optional 2-D array of network weights, one row per network
                (see ``init_net``). ``None`` draws random weights.
        """
        self.size = size
        self.n_agents = (n_agents,)
        self.param = param

        self.step_count = 0
        self.avg_score = 0.0
        self.avg_varexp = 0.0
        self.plants_count = 0

        # numpy arrays are indexed (row, column) == (y, x).
        self.w_shape = (size[1], size[0])
        self.a_shape = (n_agents,)

        self.buffers = {}

        # screen: one RGB byte triple per cell
        self.buffers["w_screen"] = Buffer(np.zeros((*self.w_shape, 3), dtype=np.uint8))

        # random: one 32-bit random word per agent and per cell
        self.buffers["a_random"] = Buffer(np.random.randint(1<<32, size=self.a_shape, dtype=np.uint32))
        self.buffers["w_random"] = Buffer(np.random.randint(1<<32, size=self.w_shape, dtype=np.uint32))

        # world
        self.init_world()
        self.init_agent_outer()

        # network
        net_param = {
            "rnn_sx": 11,
            "rnn_sh": 16,
            "rnn_sy": 5,
        }
        self.init_net(net_param, net=net)
        self.init_agent_inner()

        # parameters: compile-time constants injected into the kernel source.
        # The SIZE_* values are the per-element record lengths of the buffers.
        self.constants = {
            "WORLD_SIZE_X": self.size[0],
            "WORLD_SIZE_Y": self.size[1],
            "SIZE_A_AGENT_I": self.buffers["a_agents_i"].host.shape[-1],
            "SIZE_A_AGENT_F": self.buffers["a_agents_f"].host.shape[-1],
            "SIZE_W_CACHE_I": self.buffers["w_cache_i"].host.shape[-1],
            "SIZE_W_CACHE_F": self.buffers["w_cache_f"].host.shape[-1],
            "SIZE_W_TRACE_F": self.buffers["w_trace_f"].host.shape[-1],
            "SIZE_W_OBJECT_I": self.buffers["w_object_i"].host.shape[-1],
            "SIZE_A_RNN_F": self.buffers["a_rnn_f"].host.shape[-1],
            "RNN_SIZE_X": self.net_param["rnn_sx"],
            "RNN_SIZE_H": self.net_param["rnn_sh"],
            "RNN_SIZE_Y": self.net_param["rnn_sy"],
            "AGENT_SELECT_N": self.param["agent_selection_size"],
        }

        # Run-time parameters passed to every kernel as read-only arrays.
        # NOTE(review): the element order presumably has to match the indices
        # used in simple-rnn.cl -- confirm before reordering.
        self.buffers["PAR_I"] = Buffer(np.array([
            self.param["animal_sensor_length"],
        ], dtype=np.int32), ro=True)

        # The world kernels only run every `wperiod` agent steps, so the
        # per-step probabilities are rescaled to "at least once in wperiod
        # steps" with exp_prob.
        wperiod = self.param["world_step_period"]
        self.buffers["PAR_F"] = Buffer(np.array([
            exp_prob(self.param["trace_fade_factor"], wperiod),
            exp_prob(self.param["trace_diffusion_factor"], wperiod),
            self.param["trace_animal_factor"],
            exp_prob(self.param["trace_plant_factor"], wperiod),
            exp_prob(self.param["plant_appear_prob"], wperiod),
            self.param["weight_variation_factor"],
            self.param["selection_prob"],
            self.param["softmax_temperature"],
            self.param["selection_temperature"],
        ], dtype=np.float32), ro=True)

        self.build()

    def init_world(self):
        """(Re)create the zero-filled per-cell buffers."""
        self.buffers["w_object_i"] = Buffer(np.zeros((*self.w_shape, 1), dtype=np.int32))
        self.buffers["w_trace_f"] = Buffer(np.zeros((*self.w_shape, 3), dtype=np.float32))
        self.buffers["w_cache_i"] = Buffer(np.zeros((*self.w_shape, 4), dtype=np.int32))
        self.buffers["w_cache_f"] = Buffer(np.zeros((*self.w_shape, 4), dtype=np.float32))

    def init_agent_outer(self):
        """(Re)create ``a_agents_i``: random position/direction, zero scores.

        Row layout: ``[x, y, direction (0..3), score, score]``.
        """
        a_pos = np.stack((
            np.random.randint(0, self.size[0], size=self.a_shape),
            np.random.randint(0, self.size[1], size=self.a_shape),
        ), axis=1)
        a_dir = np.random.randint(0, 4, size=(*self.a_shape, 1))
        a_score = np.zeros((*self.a_shape, 2))
        self.buffers["a_agents_i"] = Buffer(np.concatenate((a_pos, a_dir, a_score), axis=1).astype(np.int32))

    def init_agent_inner(self):
        """Create ``a_agents_f``: one leading value plus the RNN hidden state.

        Row layout: ``[ve, h_0 .. h_{rnn_sh-1}]``, all zero. Column 0 is what
        ``fetch_stats`` reports as ``avg_varexp``.
        """
        a_ve = np.zeros((*self.a_shape, 1))
        a_h = np.zeros((*self.a_shape, self.net_param["rnn_sh"]))
        self.buffers["a_agents_f"] = Buffer(np.concatenate((a_ve, a_h), axis=1).astype(np.float32))

    def init_net(self, param, net=None):
        """Create ``a_rnn_f``, the per-agent network weights.

        Args:
            param: dict with the layer sizes ``rnn_sx``, ``rnn_sh``, ``rnn_sy``;
                stored as ``self.net_param``.
            net: optional 2-D weight array. Each agent receives a copy of a
                uniformly chosen row (rows may repeat). When ``None`` the
                weights are Gaussian with standard deviation 0.1.

        Each row is the concatenation of three flattened blocks of sizes
        ``(sx+1)*sh``, ``sh*sh`` and ``(sh+1)*sy``.
        """
        self.net_param = param
        rnn_sx, rnn_sh, rnn_sy = param["rnn_sx"], param["rnn_sh"], param["rnn_sy"]
        if net is None:
            rnn_wim = 1e-1
            wxh = rnn_wim*np.random.randn(*self.a_shape, (rnn_sx+1)*rnn_sh)
            whh = rnn_wim*np.random.randn(*self.a_shape, rnn_sh*rnn_sh)
            why = rnn_wim*np.random.randn(*self.a_shape, (rnn_sh+1)*rnn_sy)
            self.buffers["a_rnn_f"] = Buffer(np.concatenate((wxh, whh, why), axis=1).astype(np.float32))
        else:
            assert net.shape[1] == (rnn_sx+1)*rnn_sh + rnn_sh*rnn_sh + (rnn_sh+1)*rnn_sy
            idxs = np.random.randint(net.shape[0], size=self.a_shape)
            self.buffers["a_rnn_f"] = Buffer(np.copy(net[idxs]).astype(np.float32))

    @staticmethod
    def _substitute_constants(source, constants):
        """Append each constant's value to its ``#define NAME`` in ``source``.

        Only whole identifiers are matched, so ``FOO`` never rewrites
        ``FOOBAR``. The replacement is built by a function rather than a
        template string, so neither names nor values are subject to regex
        escape processing.
        """
        for name, value in constants.items():
            pattern = r"#define *%s(?!\w)" % re.escape(name)
            source = re.sub(
                pattern,
                lambda m, value=value: "%s %s" % (m.group(0), value),
                source,
            )
        return source

    def build(self):
        """Read ``simple-rnn.cl``, inject the constants and compile it."""
        with open("simple-rnn.cl", "r") as f:
            source = f.read()
        # Compile after the file is closed; the build may take a while.
        source = self._substitute_constants(source, self.constants)
        self.program = cl.Program(ctx, source).build()

    def _bufs(self, *names):
        """Return the device buffers registered under ``names``, in order."""
        return [self.buffers[name].buf for name in names]

    def _is_due(self, period_key):
        """Tell whether the periodic event ``period_key`` fires on this step.

        A period of 0 disables the event. Every enabled event fires on step 0.
        """
        period = self.param[period_key]
        return period != 0 and self.step_count % period == 0

    def step(self):
        """Advance the simulation by one agent step.

        In order: optional selection kernel, optional "disaster" (world and
        agent positions/scores re-initialised, networks kept), optional world
        update (read phase then write phase), then the agent kernel.
        """
        if self._is_due("selection_period"):
            self.program.a_select(
                queue, self.n_agents, None,
                *self._bufs(
                    "PAR_I", "PAR_F",
                    "a_random",
                    "a_agents_i", "a_agents_f", "a_rnn_f",
                )
            )

        if self._is_due("disaster_period"):
            self.init_world()
            self.init_agent_outer()

        # A zero world_step_period now means "never update the world",
        # matching the other periods, instead of raising ZeroDivisionError.
        if self._is_due("world_step_period"):
            world_args = self._bufs(
                "PAR_I", "PAR_F",
                "w_random",
                "w_cache_i", "w_cache_f",
                "w_object_i", "w_trace_f",
            )
            self.program.w_step_read(queue, self.size, None, *world_args)
            self.program.w_step_write(queue, self.size, None, *world_args)

        self.program.a_step(
            queue, self.n_agents, None,
            *self._bufs(
                "PAR_I", "PAR_F",
                "a_random",
                "a_agents_i", "a_agents_f", "a_rnn_f",
                "w_object_i", "w_trace_f",
            )
        )

        self.step_count += 1

    def draw(self):
        """Render the world, then the agents, and return the screen array.

        Returns:
            The ``w_screen`` host array of shape ``(size_y, size_x, 3)``,
            dtype uint8. The same array object is reused on every call.
        """
        self.program.w_draw(
            queue, self.size, None,
            *self._bufs("PAR_I", "PAR_F", "w_object_i", "w_trace_f", "w_screen")
        )

        self.program.a_draw(
            queue, self.n_agents, None,
            *self._bufs("PAR_I", "PAR_F", "a_agents_i", "a_agents_f", "w_screen")
        )

        screen = self.buffers["w_screen"]
        screen.load()
        return screen.host

    def fetch_stats(self):
        """Copy agent/world state to the host and refresh the statistics.

        Sets ``avg_score`` (mean of column 3 of ``a_agents_i``),
        ``avg_varexp`` (mean of column 0 of ``a_agents_f``) and
        ``plants_count`` (number of non-zero entries in ``w_object_i``).
        """
        self.buffers["a_agents_i"].load()
        self.buffers["a_agents_f"].load()
        self.avg_score = np.mean(self.buffers["a_agents_i"].host[:,3])
        self.avg_varexp = np.mean(self.buffers["a_agents_f"].host[:,0])

        self.buffers["w_object_i"].load()
        self.plants_count = np.sum(self.buffers["w_object_i"].host != 0)

    def dump_net(self, count):
        """Return a copy of ``count`` agent networks sampled with replacement.

        Returns:
            A float32 array of shape ``(count, weights_per_agent)`` that is
            accepted by the ``net`` argument of the constructor.
        """
        self.buffers["a_rnn_f"].load()
        net = self.buffers["a_rnn_f"].host
        idxs = np.random.randint(net.shape[0], size=(count,))
        return np.copy(net[idxs])

In [302]:
# Simulation parameters consumed by World.__init__ and World.step.
world_param = dict(
    # World update cadence (in agent steps); the fade/diffusion/plant
    # probabilities below are rescaled by this period via exp_prob.
    world_step_period=100,
    plant_appear_prob=1e-7,
    animal_sensor_length=6,
    # Trace parameters.
    trace_animal_factor=0.5,
    trace_plant_factor=0.5,
    trace_fade_factor=0.0002,
    trace_diffusion_factor=0.002,
    # Selection: the a_select kernel runs every `selection_period` steps
    # (0 disables it).
    selection_prob=0.5,
    selection_period=1000,
    agent_selection_size=16,
    weight_variation_factor=1e-1,
    softmax_temperature=1.0,
    selection_temperature=1.0,
    # World and agent positions/scores are re-initialised every
    # `disaster_period` steps (0 disables it).
    disaster_period=50000,
)

In [303]:
# Main population: 4000x4000 cells, 1024 agents, randomly initialised networks.
# To continue from an existing population instead, pass
# net=world.dump_net(1024); that only works while a previous `world` is still
# alive in the kernel.
world = World(size=(4000, 4000), n_agents=1024, param=world_param)

In [311]:
# Run the simulation until the kernel is interrupted (SIGINT sets signal_done),
# refreshing the statistics -- and optionally a preview image -- every 10 s.
stop_on_signal()
last = time.time()
draw = False  # set to True to also render a downscaled preview of the world
while not signal_done:
    world.step()
    now = time.time()
    if now - last >= 10.0:
        clear_output(wait=True)
        if draw:
            img = Image.fromarray(world.draw())
            # Scale to a fixed height of 600 px, keeping the aspect ratio.
            h = 600
            w = int((h/img.size[1])*img.size[0])
            # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same
            # filter under its current name.
            img = img.resize((w, h), Image.LANCZOS)
            display(img)
        world.fetch_stats()
        print("steps elapsed: %s" % world.step_count)
        print("average score: %s" % world.avg_score)
        print("average varexp: %s" % world.avg_varexp)
        print("plants count: %s" % world.plants_count)
        last = now

steps elapsed: 3485435
average score: 0.748046875
average varexp: -1.638028
plants count: 3899


In [318]:
# Persist 32 randomly sampled agent networks; np.save appends ".npy" to the name.
np.save(file="simple-rnn-net", arr=world.dump_net(32))

In [313]:
# Same parameters as the main world, but with the periodic reset switched off.
# Add "selection_period": 0 as well to disable selection.
tiny_world_param = {**world_param, "disaster_period": 0}

In [314]:
# Small 800x600 world with 32 agents whose networks are drawn from a sample of
# 256 networks taken from the trained main population.
tiny_world = World(
    size=(800, 600),
    n_agents=32,
    param=tiny_world_param,
    net=world.dump_net(256),
)

In [315]:
# Let the small world run for 5000 steps before any frames are recorded.
for j in range(0, 5000, 1):
    tiny_world.step()

In [316]:
# Record the small world to tmp.mp4: one frame every `stride` simulation steps.
# H.264 / yuv420p / baseline profile for broad player compatibility.
params = {
    "-vcodec": "libx264",
    "-pix_fmt": "yuv420p",
    "-profile:v": "baseline",
    "-level": "3"
}
video = FFmpegWriter("tmp.mp4", outputdict=params)
stride = 10
# Close the writer even if a step, draw or frame write fails, so the encoder
# is not left open with a half-written file.
try:
    # 20*24 frames in total.
    # NOTE(review): presumably 20 s at 24 fps, but no frame rate is passed to
    # the writer -- confirm.
    for i in range(20*24):
        for j in range(stride):
            tiny_world.step()
        img = tiny_world.draw()
        video.writeFrame(img)
finally:
    video.close()

In [None]:
# Embed the recorded video in the notebook as a base64 data URI.
with open("tmp.mp4", "rb") as f:
    vdata = f.read()
vbase64 = base64.b64encode(vdata).decode("ascii")
# The <video> end tag is mandatory in HTML; the original markup left it open.
HTML('<video controls src="data:video/mp4;base64,%s" type="video/mp4"></video>' % vbase64)