In [7]:
import os, sys
import time

import pandas as pd
import numpy as np
import tensorflow as tf

import seaborn as sns
import matplotlib.pyplot as plt

sys.path.append('../')
from det_rnn import *
import det_rnn.analysis as da
import det_rnn.train as dt

In [4]:
# Root directory for the per-level SavedModel checkpoints written by the boosting cell below.
# Raw string: in a plain literal '\p', '\d', '\o', '\m' are invalid escape sequences
# (DeprecationWarning; SyntaxWarning on Python 3.12+), and any future path component that
# starts with e.g. 'n' or 't' would silently turn into a newline/tab.
# NOTE(review): machine-specific absolute path — consider making it configurable.
model_dir = r'D:\proj\det_rnn\output\model_200621_test'
os.makedirs(model_dir, exist_ok=True)

In [None]:
# boost_rnn
# Curriculum ("boosting") training: train the RNN on the current task design until
# dt.level_up_criterion is satisfied, checkpoint the model, then lengthen the delay epoch
# by one more `boost_step` and keep training. Stops once the last extension in
# `extend_time` has been passed, or after N_boost_max iterations.
#
# NOTE(review): par['design'] is mutated in place when levelling up, so re-running this
# cell in the same kernel may start from the last *extended* design rather than the base
# one — TODO confirm / reset par['design'] before re-running.
par = update_parameters(par)
stimulus = Stimulus(par)  # pass par explicitly, exactly as the level-up branch does
ti_spec = dt.gen_ti_spec(stimulus.generate_trial())

model_performance = {'perf': [], 'loss': [], 'perf_loss': [], 'spike_loss': []}

# Boosting RNN
N_boost_max = 100000
perf_crit   = 0.95 # Human mean performance level
recency     = 50   # Number of 'recent' epochs to be assayed
boost_step  = 1.5  # How much step should we increase

# Delay extensions for successive levels: boost_step, 2*boost_step, ... (all < 15.5)
extend_time = np.arange(boost_step, 15.5, step=boost_step)
mileage_lim = len(extend_time)
milestones  = np.zeros((mileage_lim,), dtype=np.int64)  # iteration at which each level was reached
timestones  = np.zeros((mileage_lim,))                  # elapsed seconds at which each level was reached


def _save_checkpoint(model, save_root, level, model_performance, milestones, timestones):
    """Attach the training history to `model` and save it to <save_root>/model_level<level>.

    Side effects: overwrites model.model_performance / model.milestones / model.timestones
    and writes a TensorFlow SavedModel to disk.
    """
    model.model_performance = dt.tensorize_model_performance(model_performance)
    model.milestones = tf.Variable(milestones, trainable=False)
    model.timestones = tf.Variable(timestones, trainable=False)
    save_path = os.path.join(save_root, "model_level" + str(level))
    os.makedirs(save_path, exist_ok=True)
    tf.saved_model.save(model, save_path)


def _extend_delay(par, extension):
    """Set the task design so the delay epoch ends `extension` later (estim shifted to match).

    Mutates par['design'] in place and returns update_parameters(par).
    """
    par['design'].update({'iti': (0, 1.5),
                          'stim': (1.5, 3.0),
                          'delay': (3.0, 4.5 + extension),
                          'estim': (4.5 + extension, 6.0 + extension)})
    return update_parameters(par)


# Start boosting
model = dt.initialize_rnn(ti_spec) # initialize RNN to be boosted

mileage = -1  # index of the last level reached; -1 means still on the base design
start_time = time.time()
print("RNN Booster started!")
# Loop variable is deliberately not named `iter`: in a notebook that name would shadow the
# builtin iter() for every cell executed afterwards in this kernel.
for iteration in range(N_boost_max):
    trial_info = dt.tensorize_trial(stimulus.generate_trial())
    Y, Loss = model(trial_info, dt.hp)
    model_performance = dt.append_model_performance(model_performance, trial_info, Y, Loss, par)

    if iteration % 30 == 0:
        dt.print_results(model_performance, iteration)

    # Iteration at which the current level began (0 for the base design). Made explicit
    # instead of relying on milestones[-1] happening to be the still-unfilled last slot.
    last_milestone = milestones[mileage] if mileage >= 0 else 0
    if not dt.level_up_criterion(iteration, perf_crit, recency, last_milestone, model_performance):
        continue

    check_time = time.time()
    mileage += 1
    if mileage >= mileage_lim:
        print("#"*80+"\nTraining criterion finally met!(Time Spent: {:0.2f}s)\t".format(check_time-start_time)+
              "Now climb down the mountain!\n"+"#"*80)
        # Checkpoint the fully trained model as well; previously the loop broke out before
        # saving, so the network that passed the final extension never reached disk.
        # (milestones/timestones have no slot for this final pass.)
        _save_checkpoint(model, model_dir, mileage, model_performance, milestones, timestones)
        break
    milestones[mileage] = iteration
    timestones[mileage] = check_time - start_time

    ## Attach the history to the model and save it
    _save_checkpoint(model, model_dir, mileage, model_performance, milestones, timestones)

    ## upgrade to higher level
    par = _extend_delay(par, extend_time[mileage])
    stimulus = Stimulus(par)

    ## modulate hyperparameters (optional toggle, currently disabled) ################
    # dt.hp['spike_cost'] /= 2.
    ###################################################################################

    ## Report an upgrade has been performed
    print("#"*80+"\nCriterion satisfied!(Time Spent: {:0.2f}s)\t".format(check_time-start_time)+
                 "Now extending: {:0.1f}\n".format(extend_time[mileage])+"#"*80)


RNN Booster started!
Iter.    0 | Performance -0.0081 | Loss 0.1286 | Spike loss 191.9973


  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


Iter.   30 | Performance 0.0058 | Loss 0.0724 | Spike loss 0.7578
Iter.   60 | Performance -0.0507 | Loss 0.0227 | Spike loss 0.7087
Iter.   90 | Performance 0.0024 | Loss 0.0034 | Spike loss 1.0294
Iter.  120 | Performance 0.0113 | Loss 0.0031 | Spike loss 1.0199
Iter.  150 | Performance -0.0777 | Loss 0.0031 | Spike loss 1.0035
Iter.  180 | Performance -0.0319 | Loss 0.0030 | Spike loss 0.9853
Iter.  210 | Performance -0.0700 | Loss 0.0029 | Spike loss 0.9601
Iter.  240 | Performance 0.0225 | Loss 0.0028 | Spike loss 0.9438
Iter.  270 | Performance 0.0884 | Loss 0.0027 | Spike loss 0.9153
Iter.  300 | Performance -0.0978 | Loss 0.0027 | Spike loss 0.8941
Iter.  330 | Performance 0.0757 | Loss 0.0025 | Spike loss 0.8825
Iter.  360 | Performance -0.0133 | Loss 0.0026 | Spike loss 0.8780
Iter.  390 | Performance 0.0413 | Loss 0.0025 | Spike loss 0.8649
Iter.  420 | Performance 0.0871 | Loss 0.0025 | Spike loss 0.8555
Iter.  450 | Performance -0.0217 | Loss 0.0025 | Spike loss 0.8486
Ite

Iter. 3750 | Performance 0.9145 | Loss 0.0007 | Spike loss 5.5845
Iter. 3780 | Performance 0.9195 | Loss 0.0007 | Spike loss 6.6774
Iter. 3810 | Performance 0.9041 | Loss 0.0007 | Spike loss 6.3911
Iter. 3840 | Performance 0.9325 | Loss 0.0006 | Spike loss 6.6461
Iter. 3870 | Performance 0.9379 | Loss 0.0007 | Spike loss 6.6841
Iter. 3900 | Performance 0.9345 | Loss 0.0006 | Spike loss 6.9670
Iter. 3930 | Performance 0.9374 | Loss 0.0006 | Spike loss 6.0139
Iter. 3960 | Performance 0.9270 | Loss 0.0007 | Spike loss 5.9826
Iter. 3990 | Performance 0.9301 | Loss 0.0007 | Spike loss 5.9967
Iter. 4020 | Performance 0.9290 | Loss 0.0007 | Spike loss 5.5465
Iter. 4050 | Performance 0.9229 | Loss 0.0007 | Spike loss 5.7685
Iter. 4080 | Performance 0.9360 | Loss 0.0006 | Spike loss 6.1586
Iter. 4110 | Performance 0.9274 | Loss 0.0007 | Spike loss 5.9451
Iter. 4140 | Performance 0.9364 | Loss 0.0006 | Spike loss 5.6158
Iter. 4170 | Performance 0.9444 | Loss 0.0005 | Spike loss 5.6624
Iter. 4200