In [3]:
%load_ext Cython

The Cython extension is already loaded. To reload it, use:
  %reload_ext Cython


In [16]:
%%cython

import sys
import time

import numpy as np
cimport numpy as np

import interface as bb
cimport interface as bb

from scipy.linalg.cython_blas cimport sgemm

cimport cython

# Index lists into the 36-value state vector used to build extra features:
# entries named in f_sq are squared, entries named in f_intr are multiplied
# by state[35] (see fast_target).
f_sq = [1,2]
f_intr = [1,4,11,12,13]

cdef:
    # sgemm scalars: with alpha = 1 and beta = 0 the call computes y = s . W.
    float alpha = 1.0
    float beta = 0.0
    # Fortran-ordered work buffers shared by the functions below.
    float[::1,:] s, y, w0, wpos, wneg
    int n_sq = len(f_sq)
    int n_intr = len(f_intr)

# Feature row: 36 raw state values, the squares, the interactions and a
# trailing constant 1 -> 37 + n_sq + n_intr columns.
s = np.empty((1, 37 + n_sq + n_intr), dtype=np.float32, order="F")

# One weight matrix per bucket of round(state[35]*10): zero, positive, negative.
# They hold uninitialised memory until dump_weights() fills them.
w0 = np.empty((37 + n_sq + n_intr, 4), dtype=np.float32, order="F")
wpos = np.empty((37 + n_sq + n_intr, 4), dtype=np.float32, order="F")
wneg = np.empty((37 + n_sq + n_intr, 4), dtype=np.float32, order="F")

# Output row: one predicted value per action (4 actions).
y = np.empty((1, 4), dtype=np.float32, order="F")

cdef:
    # NUM_CACHE: capacity of the state -> prediction cache;
    # NUM_ROLLOUT: total number of steps simulated per action in crollout.
    int NUM_CACHE = 51
    int NUM_ROLLOUT = 50
    # cache_i: ring index maintained by fast_target;
    # cache_n: number of valid cache entries (never exceeds NUM_CACHE).
    int cache_i = 0
    int cache_n = 0
    float[::1,:] cache_s, cache_y

# Cached raw states (36 values each) and the matching 4 predicted values.
cache_s = np.empty((NUM_CACHE, 36), dtype=np.float32, order="F")
cache_y = np.empty((NUM_CACHE, 4), dtype=np.float32, order="F")

@cython.boundscheck(False)
cdef void fast_target(float *state, int use_cache = 0):
    # Predict the 4 per-action values for `state` and leave them in the
    # module-level buffer y[0, 0:4].
    #
    # state     -- pointer to at least 36 floats (indices 0..35 are read).
    # use_cache -- 1 to consult/update the small ring cache keyed on the 36
    #              raw state values; any other value bypasses the cache.
    #
    # The feature row s is: the 36 raw values, state[j]**2 for j in f_sq,
    # state[j]*state[35] for j in f_intr, and a trailing constant 1.  The
    # weight matrix is picked by the sign of round(state[35]*10).
    global cache_i, cache_n
    cdef int i, c, m, n, k, lda, ldb, ldc
    cdef float s35_sign

    if use_cache == 1:
        # Linear scan over the cache_n valid slots for an exact match.
        c = 0
        while c < cache_n:
            i = 0
            while i < 36:
                if cache_s[c,i] != state[i]:
                    break
                i += 1
            if i == 36:
                for i in xrange(4):
                    y[0,i] = cache_y[c,i]
                return
            c += 1
        # Miss: store the key in slot cache_i *before* advancing the index.
        # (Advancing first put entries in slots 1..N while the scan above
        # covers 0..cache_n-1, so the newest entry was invisible and slot 0 --
        # uninitialised or left over from before the last reset -- was
        # compared instead.)
        for i in xrange(36):
            cache_s[cache_i,i] = state[i]

    # Build the feature row (identical for the cached and uncached paths).
    for i in xrange(36):
        s[0,i] = state[i]
    for i in xrange(n_sq):
        s[0,i+36] = state[f_sq[i]]**2
    for i in xrange(n_intr):
        s[0,i+36 + n_sq] = state[f_intr[i]]*state[35]
    s[0,36 + n_sq + n_intr] = 1.

    # y (1 x 4) = s (1 x k) . W (k x 4), all column-major.
    lda = 1
    ldb = 37 + n_sq + n_intr
    ldc = 1
    m = 1
    n = 4
    k = 37 + n_sq + n_intr

    s35_sign = round(state[35]*10)
    if s35_sign == 0:
        sgemm("N", "N", &m, &n, &k, &alpha, &s[0,0], &lda, &w0[0,0], &ldb, &beta, &y[0,0], &ldc)
    elif s35_sign > 0:
        sgemm("N", "N", &m, &n, &k, &alpha, &s[0,0], &lda, &wpos[0,0], &ldb, &beta, &y[0,0], &ldc)
    elif s35_sign < 0:
        sgemm("N", "N", &m, &n, &k, &alpha, &s[0,0], &lda, &wneg[0,0], &ldb, &beta, &y[0,0], &ldc)

    if use_cache == 1:
        # Complete the entry, then advance the ring index and the fill count.
        for i in xrange(4):
            cache_y[cache_i,i] = y[0,i]
        cache_i += 1
        if cache_i == NUM_CACHE:
            cache_i = 0
        if cache_n < NUM_CACHE:
            cache_n += 1
    

@cython.boundscheck(False)
cdef int fast_action(float *state, int use_cache = 0):
    # Return the index (0..3) of the action with the largest predicted value
    # for `state`; ties keep the lowest index.  Returns -1 when no prediction
    # exceeds -1e9.  `use_cache` is forwarded to fast_target.
    cdef int i, best_act = -1
    # Typed explicitly: an untyped `cdef best_val` is a Python object, which
    # boxes every comparison in this loop.
    cdef float best_val = -1e9
    fast_target(state, use_cache)
    for i in xrange(4):
        if y[0,i] > best_val:
            best_val = y[0,i]
            best_act = i
    return best_act


@cython.boundscheck(False)
cdef float fast_value(float *state):
    # Return the largest of the 4 predicted action values for `state`
    # (or -1e9 if none exceeds it).  Always goes through the cache.
    cdef int i
    # Typed explicitly: an untyped `cdef best_val` is a Python object, which
    # boxes every comparison and the return value.
    cdef float best_val = -1e9
    fast_target(state, 1)
    for i in xrange(4):
        if y[0,i] > best_val:
            best_val = y[0,i]
    return best_val

@cython.boundscheck(False)
def dump_weights(weights):
    """Copy per-bucket weight matrices into the module's C-level buffers.

    ``weights`` maps a bucket key to a 2-D indexable of shape
    (37 + n_sq + n_intr, 4): key 0 fills ``w0``, -1 fills ``wneg`` and
    1 fills ``wpos``.  Entries under any other key are ignored.
    """
    cdef int col, row
    cdef float[::1,:] target
    for key, mat in weights.iteritems():
        # Resolve the destination buffer once per key.
        if key == 0:
            target = w0
        elif key == -1:
            target = wneg
        elif key == 1:
            target = wpos
        else:
            continue
        for col in xrange(4):
            for row in xrange(37 + n_sq + n_intr):
                target[row,col] = mat[row,col]


def prepare_bbox(level='train', verbose=0):
    """Empty the prediction cache and (re)load a level into the emulator.

    The level file is read from ``../levels/<level>_level.data``; a level
    that is already loaded is reset first.  ``verbose`` is passed through
    to ``bb.load_level``.
    """
    global cache_i, cache_n
    # Only the counters are reset; the cache arrays keep their old contents.
    cache_n = 0
    cache_i = 0
    if bb.is_level_loaded():
        bb.reset_level()
    level_file = '../levels/' + level + '_level.data'
    bb.load_level(level_file, verbose)


cdef:
    # Per-action outputs of crollout(): accumulated reward for each of the
    # 4 actions, and a 0/1 flag telling whether that action was rolled out.
    float _rewards[4]
    float _mask[4]

@cython.boundscheck(False)
cdef crollout(int epoch=0, float curriculum=0.7):
    # Evaluate each of the 4 actions from the current emulator state by
    # simulation and store the results in the module arrays _rewards/_mask.
    #
    # For every action a: apply it, then (if the episode continues) play up
    # to NUM_ROLLOUT-1 further steps, accumulating score differences into r.
    # Follow-up actions come from fast_action() when epoch > 0, otherwise the
    # fixed action 3.  If the episode is still running afterwards and
    # epoch > 0, fast_value() of the final state is added to r.
    #
    # An action is rolled out only when it changes state[35], or otherwise
    # with probability `curriculum`; skipped actions keep reward 0 / mask 0.
    # The emulator is restored from a checkpoint after each action, so the
    # caller's state is unchanged on return.
    cdef:
        int i, a, action, has_next, checkpoint_id
        float r, prev_score, init_state35, next_state35
        float *state

    init_state35 = bb.c_get_state()[35]
    checkpoint_id = bb.create_checkpoint()

    for a in xrange(4):

        _rewards[a] = 0
        _mask[a] = 0

        prev_score = bb.c_get_score()
        has_next = bb.c_do_action(a)
        state = bb.c_get_state()
        next_state35 = state[35]

        if init_state35 != next_state35 or np.random.rand() < curriculum:

            # Immediate reward of the first action.
            r = bb.c_get_score() - prev_score
            prev_score = bb.c_get_score()

            if has_next == 1:
                for i in xrange(NUM_ROLLOUT-1):
                    if epoch > 0:
                        action = fast_action(state, 1)
                    else:
                        action = 3

                    has_next = bb.c_do_action(action)
                    r += bb.c_get_score() - prev_score
                    state = bb.c_get_state()
                    prev_score = bb.c_get_score()
                    if has_next == 0:
                        break

                # Bootstrap with the model's value estimate of the last state.
                if has_next == 1 and epoch > 0:
                    r += fast_value(state)

            _rewards[a] = r
            _mask[a] = 1

        # Rewind so the next action starts from the same state.
        bb.load_from_checkpoint(checkpoint_id)
    bb.clear_all_checkpoints()


@cython.boundscheck(False)
def rollout(epoch=0, curriculum=0.7):
    """Python-callable wrapper around ``crollout``.

    Runs ``crollout(epoch, curriculum)`` and returns ``(rewards, mask)``:
    two fresh float32 arrays of length 4 copied from the module-level
    ``_rewards`` and ``_mask`` C arrays.
    """
    cdef int a
    crollout(epoch, curriculum)
    rewards = np.empty(4, dtype=np.float32)
    mask = np.empty(4, dtype=np.float32)
    for a in xrange(4):
        rewards[a] = _rewards[a]
    for a in xrange(4):
        mask[a] = _mask[a]
    return rewards, mask


def solve_lsq(X, y, lmd = 1):
    """Solve a (ridge-regularised) least-squares problem via the normal equations.

    Parameters
    ----------
    X : 2-D array, one sample per row.
    y : targets; first dimension matches the rows of ``X``.
    lmd : ridge strength.  When ``lmd > 0`` the penalty ``lmd * I`` is added
        to ``X^T X`` with the last diagonal entry zeroed, so the coefficient
        of the last column is not penalised (presumably the bias column --
        confirm against the caller).  Otherwise no penalty is applied.

    Returns
    -------
    The coefficient array ``w`` solving ``(X^T X + penalty) w = X^T y``.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the system matrix is singular.
    """
    gram = X.T.dot(X)
    rhs = X.T.dot(y)
    if lmd > 0:
        penalty = np.eye(gram.shape[0])
        penalty[-1, -1] = 0
        gram = gram + lmd*penalty
    # Solve the linear system directly rather than forming an explicit inverse.
    return np.linalg.solve(gram, rhs)

def train_epoch(X, Y, M):
    """Fit one regularised linear model per (bucket, action) pair.

    Samples are split into three buckets by the sign of
    ``round(X[:, 35] * 10)`` (keys -1, 0, 1).  Within each bucket, and for
    each of the 4 actions ``i``, ``solve_lsq`` is fitted on the rows where
    ``M[:, i]`` is truthy, using ``Y[:, i]`` as the target.

    Parameters
    ----------
    X : 2-D array of samples; column 35 selects the bucket.
    Y : 2-D array of targets, one column per action.
    M : 2-D array of per-action validity flags (bool or 0/1 numeric).

    Returns
    -------
    dict mapping -1, 0, 1 to a float32 array with one column per action.
    """
    # NOTE(review): fast_target buckets with the builtin round(); its
    # tie-breaking at exactly .5 may differ from np.round -- confirm.
    s35 = np.round(X[:, 35]*10)
    buckets = {-1: s35 < 0, 0: s35 == 0, 1: s35 > 0}

    weights = {}
    for k in (-1, 0, 1):
        in_bucket = buckets[k]
        per_action = []
        for i in xrange(4):
            # Coerce to bool: `&` raises on a float mask and would turn an
            # integer mask into positional (fancy) indexing.
            rows = np.asarray(M[:, i], dtype=bool) & in_bucket
            per_action.append(solve_lsq(X[rows], Y[rows, i]))

        weights[k] = np.array(per_action).T.astype(np.float32)
    return weights

@cython.boundscheck(False)
def get_train_data(weights):
    """Play the 'train' and 'test' levels greedily and record rollout data.

    ``weights`` is installed with ``dump_weights``.  For each level, every
    step records the current state and the 4 rollout rewards obtained from
    ``rollout(100, 1)``, then takes the action chosen by ``fast_action``
    until the emulator reports the end of the level.

    Returns a dict with keys ``'trainX'``, ``'trainY'``, ``'testX'`` and
    ``'testY'``: the X arrays hold one state per row, the Y arrays are the
    transposed rewards (one row per action); both are cast to float16.
    """
    cdef int action

    dump_weights(weights)
    data = {}

    for lvl in ('train', 'test'):
        prepare_bbox(lvl)
        states = []
        targets = []

        running = True
        while running:
            # rollout() leaves the emulator where it was, so the state
            # recorded next is the one the rewards were computed for.
            rewards = rollout(100, 1)[0]
            states.append(bb.get_state().copy())
            targets.append(rewards)
            action = fast_action(bb.c_get_state(), 1)
            running = bb.c_do_action(action) != 0

        bb.finish(verbose=1)

        data[lvl+'X'] = np.array(states).astype(np.float16)
        data[lvl+'Y'] = np.array(targets).astype(np.float16).T

    return data

In [17]:
import cPickle
import numpy as np
# Load the previously fitted weights that get_train_data() expects.
# NOTE: unpickling executes arbitrary code -- only load files you created.
with open('weights_reg1.pkl', 'rb') as weights_file:
    W = cPickle.load(weights_file)

In [18]:
# Play both levels with the loaded weights and collect the regression data;
# bb.finish(verbose=1) is called once per level inside get_train_data.
d = get_train_data(W)

Level score= 3035.154053
Level score= 3550.278076


In [25]:
# Persist the collected float16 arrays; protocol -1 selects the highest
# pickle protocol available.
with open('regr_data16.pkl', 'wb') as data_file:
    cPickle.dump(d, data_file, -1)