In [5]:
import sys
import os
import logging
import logging.config
import json
import shutil
import matplotlib as mpl
import matplotlib.pyplot as plt
import pickle as pkl

# mpl.use('Agg')

import numpy as np

import copy
import sqlite3
import git
import torch
import torchvision
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from misc.database import Database

from prediction_model import utils
from prediction_model.models.lstm_l0 import L0SeqLSTM
from prediction_model.models.fully_connected_l0 import L0FCN
from prediction_model.models.ntm_aio_l0 import EncapsulatedNTML0
from prediction_model.models.ntm_aio import EncapsulatedNTM
from prediction_model.objectives.sse import SSE
from prediction_model.objectives.logistic_loss import LogisticLoss

# plotting tools
from bokeh.io import output_notebook, show
from bokeh.layouts import gridplot
from bokeh.plotting import figure
from bokeh.models import CustomJS, Slider, ColumnDataSource, Whisker

from datetime import datetime
from types import SimpleNamespace

output_notebook()

In [6]:
def model_forward_fcn(X_train, y_train, model_p, objective_p, optimizer_p, l0_reg, backward=True, enable_gpu=False):
    model_out = []
    total_log_loss = None
    total_penalty = None
    
    X_train = torch.FloatTensor(X_train)
    y_train = torch.FloatTensor(y_train)
    if enable_gpu:
        X_train = X_train.cuda()
        y_train = y_train.cuda()

    out, (hidden, penalty) = model_p.forward(X_train)
#         ht, ct = hidden
    # print(x_t1, out)
    out = torch.clamp(out, -8, 8)
    out = F.sigmoid(out)
    # print(out)

#     print(out.shape)
#     print(y_train.shape)
    log_loss = objective_p(out, y_train)

    if total_log_loss is None:
        total_log_loss = log_loss
        total_penalty = penalty
    else:
        total_log_loss += log_loss
        total_penalty += penalty

    model_out.append(out.cpu().data.numpy())

    total_loss = total_log_loss + l0_reg * total_penalty

    if backward:
        model.zero_grad()
        total_loss.backward()
        optimizer_p.step()

    if backward:
        total_penalty = np.sum(total_penalty.cpu().data.numpy())

    return model, \
           (np.sum(total_log_loss.cpu().data.numpy()), total_penalty), \
           model_out

In [7]:
def logistic_map(x, alpha):
    return alpha*x*(1-x)

def linear_map(x, alpha):
    return alpha*x

In [None]:
seed = 0
alpha = 1
enable_gpu = True
l0_reg = 0.001

# for seed in range(args.n_seeds):
utils.common.set_global_seeds(seed)

model = L0FCN(input_bins=1,
              output_bins=1,
              fc_config=[20]*2,
              enable_gpu=True,
              test_mode=False)

if enable_gpu:
    model = model.cuda()

optimizer = optim.Adam(model.parameters(),
                       lr=1e-4,
                       weight_decay=0)

objective = SSE()

prev_loss = 0
epoch = 0
best_model = None
best_acc = 100
best_nparams = 0

while True:
    X_train = np.random.rand(50)
    y_train = np.array(list(map(lambda x: logistic_map(x, alpha), X_train)))

    X_train = X_train.reshape((-1, 1))
    y_train = y_train.reshape((-1, 1))

    model.train()
    model, (log_loss, penalty), model_out = model_forward_fcn(X_train,
                                                              y_train,
                                                              model,
                                                              objective,
                                                              optimizer,
                                                              l0_reg,
                                                              enable_gpu=enable_gpu)
    total_loss = (log_loss + l0_reg * penalty)

    # Validation
    model.eval()

    _, _, model_out = model_forward_fcn(X_train,
                                        y_train,
                                        model,
                                        objective,
                                        optimizer,
                                        l0_reg,
                                        enable_gpu=enable_gpu,
                                        backward=False)
#     predictions = (np.array(model_out) > 0.5).astype(int)
    predictions = np.array(model_out).flatten()
    actual = y_train.flatten()
    # print(actual == predictions)
    accuracy = np.sum(actual == predictions) / len(actual)
    model_nparams = model.get_l0_norm()

    # if accuracy > best_acc or (accuracy == best_acc and best_nparams > model_nparams):
    if log_loss <= best_acc and not (np.isnan(log_loss) or np.isnan(penalty)):
        best_model = copy.deepcopy(model.state_dict())
        best_acc = log_loss
        best_nparams = model_nparams
    div = X_train.shape[0]
    # print diagnostics
    if epoch % 100 == 0:
        model_out = np.array(model_out).flatten()
        print("")
        print("epoch: {} SSE loss: {}, l0 penalty: {} total loss: {}".format(epoch,
                                                                              log_loss / div,
                                                                              penalty / div,
                                                                              total_loss))
        print("accuracy: {}, actual: {}, predicted: {}, n_params: {} model_out: {}".format(accuracy,
                                                                                                 X_train.flatten(),
                                                                                                 predictions,
                                                                                                 model.get_l0_norm(),
                                                                                                 model_out))

    if np.abs(total_loss - prev_loss) <= 1e-9:
        print("Training converged: {}".format(np.abs(total_loss - prev_loss)))
        break

    prev_loss = total_loss
    epoch += 1

    if np.isnan(log_loss) or np.isnan(penalty):
        print("NaN detected")
        break
#     break

fc0
Gate mode: unique
fc1
Gate mode: unique
Gate mode: unique

epoch: 0 SSE loss: 0.10524124145507813, l0 penalty: 8.002297973632812 total loss: 5.662176971435547
accuracy: 0.0, actual: [0.5488135  0.71518937 0.60276338 0.54488318 0.4236548  0.64589411
 0.43758721 0.891773   0.96366276 0.38344152 0.79172504 0.52889492
 0.56804456 0.92559664 0.07103606 0.0871293  0.0202184  0.83261985
 0.77815675 0.87001215 0.97861834 0.79915856 0.46147936 0.78052918
 0.11827443 0.63992102 0.14335329 0.94466892 0.52184832 0.41466194
 0.26455561 0.77423369 0.45615033 0.56843395 0.0187898  0.6176355
 0.61209572 0.616934   0.94374808 0.6818203  0.3595079  0.43703195
 0.6976312  0.06022547 0.66676672 0.67063787 0.21038256 0.1289263
 0.31542835 0.36371077], predicted: [0.4910428  0.49057236 0.4908599  0.49105614 0.49139714 0.49071363
 0.4913607  0.49023575 0.49009866 0.49150643 0.49042645 0.4911103
 0.49097762 0.49017122 0.49235564 0.49231192 0.49249384 0.49034846
 0.49045232 0.49027723 0.49007136 0.4904123 


epoch: 500 SSE loss: 0.08815521240234375, l0 penalty: 7.97047607421875 total loss: 4.806284423828125
accuracy: 0.0, actual: [0.8812033  0.53122103 0.8680575  0.77755907 0.96932112 0.92571412
 0.16927561 0.57996402 0.02856338 0.68828411 0.54150551 0.14278145
 0.66640611 0.36708941 0.09459162 0.00299151 0.60631975 0.79011012
 0.47399445 0.36391856 0.75603037 0.0063267  0.80624946 0.14631303
 0.2934198  0.61383733 0.8059898  0.14708033 0.66029159 0.40847581
 0.25638414 0.39374154 0.38463804 0.83570583 0.53699062 0.84165466
 0.16601442 0.83759583 0.3109018  0.71662891 0.49894272 0.59874476
 0.64255778 0.64013713 0.93336891 0.68057536 0.57611589 0.27045168
 0.83392349 0.2913378 ], predicted: [0.47267935 0.4732891  0.4726977  0.4728248  0.47256207 0.47261703
 0.47463748 0.47313565 0.4750991  0.4729502  0.47325674 0.47474164
 0.47298092 0.4738727  0.47491825 0.47516906 0.47306532 0.4728071
 0.47346926 0.47388488 0.47285497 0.4751599  0.47278446 0.47472772
 0.4741549  0.47305474 0.47278482 0.


epoch: 1000 SSE loss: 0.08775465965270995, l0 penalty: 7.94111083984375 total loss: 4.784788524627686
accuracy: 0.0, actual: [0.30764535 0.06200521 0.98878508 0.7190903  0.75756964 0.26375811
 0.81423212 0.56902059 0.45944772 0.13070624 0.31822656 0.3471398
 0.04941394 0.96384525 0.96475062 0.1106671  0.33928814 0.56340746
 0.48681865 0.56492329 0.86105383 0.81537125 0.41010138 0.52310581
 0.18856046 0.45604781 0.76245352 0.25340822 0.24580262 0.99679214
 0.22599421 0.962532   0.24860247 0.71298981 0.04488238 0.67372927
 0.41643382 0.51367789 0.51061752 0.37990148 0.01794764 0.80771941
 0.41246825 0.28481926 0.89542746 0.51598704 0.98335094 0.4560161
 0.71131126 0.09003222], predicted: [0.44920164 0.44989774 0.44760108 0.4481951  0.4481208  0.4493266
 0.44801128 0.44848505 0.44876957 0.4497045  0.44917148 0.44908923
 0.44993317 0.4477223  0.44772053 0.44976082 0.44911155 0.4484959
 0.44869164 0.448493   0.44792086 0.44800907 0.44891    0.44858837
 0.44954064 0.44877923 0.44811133 0.44


epoch: 1500 SSE loss: 0.052677912712097166, l0 penalty: 7.913941650390625 total loss: 3.02959271812439
accuracy: 0.0, actual: [0.24392333 0.65495435 0.35150695 0.62498055 0.03401694 0.24540688
 0.22322093 0.0760436  0.83643934 0.88957214 0.18126889 0.78185092
 0.59343407 0.1847256  0.37908656 0.9475536  0.01381148 0.47050336
 0.78318807 0.62917569 0.36378382 0.02195515 0.8975815  0.50285664
 0.29962934 0.04863477 0.9932151  0.81687624 0.16885527 0.88573366
 0.06075629 0.8205812  0.25389633 0.99687541 0.04379909 0.15060614
 0.78009182 0.02238677 0.13359511 0.23338146 0.22705116 0.39122358
 0.77979902 0.18917251 0.32549157 0.44852835 0.3491879  0.64779936
 0.98710881 0.02330982], predicted: [0.41055182 0.40802258 0.40989172 0.4082081  0.4118406  0.4105427
 0.41067883 0.41158244 0.4068905  0.40642864 0.41093636 0.40723565
 0.40840343 0.41091514 0.4097226  0.40556988 0.41196468 0.409162
 0.40722728 0.4081821  0.40981644 0.41191468 0.40630996 0.4089637
 0.41020995 0.41175076 0.40488216 0.4


epoch: 2000 SSE loss: 0.039977831840515135, l0 penalty: 7.888450317382812 total loss: 2.3933141078948976
accuracy: 0.0, actual: [0.53525707 0.90404425 0.50239657 0.10087001 0.52758198 0.71122893
 0.31295428 0.05032535 0.12328206 0.77969075 0.94802187 0.71818603
 0.75389358 0.10964897 0.63814993 0.28945067 0.361702   0.00315978
 0.04352426 0.79419409 0.77746827 0.56411938 0.38736774 0.34514564
 0.41004102 0.06705749 0.97356847 0.06405922 0.10844153 0.47859636
 0.87148378 0.98030973 0.67780738 0.87716066 0.26757641 0.61586974
 0.44980164 0.61761576 0.80281463 0.03813337 0.29641567 0.33406014
 0.50907489 0.74504766 0.59307464 0.34360118 0.30266013 0.27967976
 0.92107048 0.87532113], predicted: [0.3529567  0.34606403 0.35338718 0.35866645 0.35305724 0.3504048
 0.35587364 0.35933346 0.35837087 0.3493985  0.34487417 0.3503037
 0.34978524 0.35855064 0.35146713 0.35618263 0.35523304 0.35995638
 0.35942325 0.34901768 0.34944317 0.35254478 0.354896   0.35545057
 0.3545984  0.35911262 0.3441839 


epoch: 2500 SSE loss: 0.018952257633209228, l0 penalty: 7.861978759765625 total loss: 1.3407118196487426
accuracy: 0.0, actual: [0.70234877 0.24126003 0.3909774  0.08272685 0.0125667  0.09097745
 0.88338182 0.73286425 0.88013022 0.92964604 0.4807014  0.40341409
 0.3506564  0.54586824 0.79552717 0.07982766 0.7188167  0.62327477
 0.28242463 0.76059583 0.50328887 0.33464501 0.76103674 0.06868315
 0.71545598 0.56427928 0.10325397 0.29758537 0.73224651 0.33388648
 0.910522   0.26858995 0.43681056 0.24196784 0.04470311 0.49783801
 0.40223873 0.9800457  0.47167284 0.66310306 0.4971505  0.80088507
 0.41198731 0.12023148 0.7414398  0.5854993  0.44173437 0.23205948
 0.10164505 0.48847288], predicted: [0.28598195 0.2964123  0.2931229  0.2999185  0.30147773 0.29973546
 0.27873567 0.28480873 0.27886602 0.27688506 0.2911619  0.2928506
 0.29400668 0.28974247 0.28227076 0.29998285 0.28537956 0.2880049
 0.29550576 0.28368384 0.29066944 0.29435804 0.283666   0.30023023
 0.28551623 0.2893422  0.29946318


epoch: 3000 SSE loss: 0.00557354986667633, l0 penalty: 7.83185302734375 total loss: 0.670270144701004
accuracy: 0.0, actual: [0.6226517  0.82421794 0.28428344 0.87839639 0.64116389 0.88815645
 0.76152797 0.37503038 0.92185572 0.29182224 0.41247687 0.4641012
 0.25166247 0.81322924 0.12431812 0.43852124 0.16661531 0.25279386
 0.25963365 0.6609033  0.49987259 0.6382574  0.46343688 0.22698828
 0.97664225 0.97426234 0.20306161 0.32619469 0.69963976 0.23739341
 0.39503501 0.74172975 0.94369034 0.47140019 0.0632019  0.61826634
 0.08297844 0.93694758 0.28686746 0.31118546 0.08903719 0.66314109
 0.37911571 0.92719766 0.84107261 0.86530349 0.19918857 0.81320992
 0.3319823  0.93641826], predicted: [0.23216783 0.22280088 0.24224412 0.22014028 0.23159674 0.21966341
 0.22590801 0.23956539 0.2180226  0.24202085 0.23846592 0.23695588
 0.24321207 0.2233433  0.2470156  0.23770326 0.24574783 0.24317846
 0.24297531 0.23085603 0.23591344 0.23168634 0.23697527 0.24394588
 0.21537398 0.21548857 0.24465896 0


epoch: 3500 SSE loss: 0.029899399280548095, l0 penalty: 7.797201538085938 total loss: 1.8848300409317016
accuracy: 0.0, actual: [0.93656379 0.29330629 0.71696113 0.97770735 0.46248868 0.95647445
 0.6050101  0.22903965 0.94625056 0.85270814 0.11594969 0.57132807
 0.03979628 0.49921246 0.69562301 0.17441155 0.23401466 0.01122773
 0.70659574 0.78306599 0.93497557 0.5777703  0.28601512 0.31964301
 0.31330001 0.99867422 0.50271547 0.69948505 0.19834866 0.67062492
 0.06411956 0.32203181 0.24092778 0.25941079 0.25054223 0.44311414
 0.58788874 0.54324192 0.51296574 0.23324817 0.77078055 0.28826467
 0.00248217 0.69585547 0.23473468 0.87028814 0.90305784 0.22837777
 0.75358128 0.01285757], predicted: [0.17650412 0.2042409  0.18799186 0.17437913 0.1985196  0.17547321
 0.19347425 0.20642999 0.17600197 0.18089803 0.21032421 0.19465776
 0.21297677 0.19721018 0.18906192 0.20830442 0.2062599  0.21397811
 0.18851107 0.18461159 0.17658655 0.19443098 0.2044884  0.2033488
 0.20356339 0.17330006 0.1970855


epoch: 4000 SSE loss: 0.006549602150917053, l0 penalty: 7.758446044921875 total loss: 0.7154024097919465
accuracy: 0.0, actual: [0.21776986 0.58724972 0.9876133  0.95086705 0.34297149 0.39389324
 0.58573127 0.31384096 0.8532363  0.53031352 0.73691596 0.23933497
 0.81544938 0.41063763 0.37707666 0.72059189 0.69629163 0.37351208
 0.8857305  0.76544058 0.43635873 0.96868552 0.92976244 0.15911103
 0.72552475 0.5989313  0.14331981 0.06272644 0.35662631 0.71137245
 0.3667841  0.23897153 0.33319889 0.53158815 0.56552545 0.25363734
 0.33481974 0.57202368 0.33461279 0.73643939 0.28074505 0.57774971
 0.68363098 0.78815434 0.86252649 0.68589606 0.40991839 0.0750447
 0.57870191 0.33163897], predicted: [0.18179825 0.16838315 0.14760847 0.14959009 0.17734429 0.17555745
 0.16843858 0.17837289 0.15474801 0.17047133 0.16075277 0.18102491
 0.15678003 0.17492743 0.17614596 0.16157904 0.16281539 0.17627093
 0.15301783 0.15931714 0.17396203 0.14862649 0.15069874 0.1839147
 0.16132899 0.16784543 0.18448775

In [None]:
model_test = L0FCN(input_bins=1,
                  output_bins=1,
                  fc_config=[20]*2,
                  enable_gpu=True,
                  test_mode=False)
model_test.load_state_dict(best_model)

In [None]:
model_test.get_l0_norm()

In [None]:
model_test = model_test.cuda()

x = np.arange(0, 1, 0.01)
y = list(map(lambda val: logistic_map(val, alpha), x))
y_pred = None
k = 1
for i in range(k):
    y_pred_k = []
    for x_val in x:
        x_val = torch.FloatTensor([x_val]).cuda()
        out, _ = model_test.forward(x_val)
        out = torch.clamp(out, -8, 8)
        out = F.sigmoid(out)
        y_pred_k.append(out.data.cpu().numpy()[0])
    
    if y_pred is None:
        y_pred = np.array(y_pred_k)
    else:
        y_pred += np.array(y_pred_k)

y_pred = y_pred/k
plot_options = dict(width=900,
                    plot_height=450,
                    tools='pan,wheel_zoom,reset,save')

plot = figure(**plot_options)
plot.line(x,
          y)
plot.line(x,
          y_pred,
         line_color="orange")
plot.x(X_train.flatten(),
       y_train.flatten())

plot.xaxis.axis_label = 'x'
plot.yaxis.axis_label = 'f(x)'

show(plot)

In [None]:
series_pred = []
series_pred_in = []
series_actual = []
lambda_series_pred = []
lambda_series_actual = []
x_t_m = 0.7
x_t_a = x_t_m
x_t_i = x_t_m
n_points = 500

step = 0
tot_steps = 5

for i in range(n_points):
    x_val = torch.FloatTensor([x_t_m]).cuda()
    out, _ = model_test.forward(x_val)
    out = torch.clamp(out, -8, 8)
    out = F.sigmoid(out)
    lambda_series_pred.append(out.data.cpu().numpy()[0]/(x_t_m*(1-x_t_m)))
    x_t_m = out.data.cpu().numpy()[0]
    series_pred.append(x_t_m)
    
    if step >= tot_steps:
        x_t_i = x_t_a
        step = 0
    
    x_val = torch.FloatTensor([x_t_i]).cuda()
    out, _ = model_test.forward(x_val)
    out = torch.clamp(out, -8, 8)
    out = F.sigmoid(out)
    x_t_i = out.data.cpu().numpy()[0]
    series_pred_in.append(x_t_i)
    step += 1
    
    lambda_series_actual.append(logistic_map(x_t_a, alpha=alpha)/(x_t_a*(1-x_t_a)))
    x_t_a = logistic_map(x_t_a, alpha=alpha)
    series_actual.append(x_t_a)

plot_options = dict(width=900,
                    plot_height=450,
                    tools='pan,xwheel_zoom,reset,save')

plot = figure(**plot_options)
plot.line(range(n_points),
          series_actual)
plot.line(range(n_points),
          series_pred,
          line_color="orange")
plot.line(range(n_points),
          series_pred_in,
          line_color="red")

plot.xaxis.axis_label = 'step'
plot.yaxis.axis_label = 'f(x)'

show(plot)

mean_lambda_pred = [ np.mean(lambda_series_pred[:i]) 
                    for i in range(1, len(lambda_series_pred)) ]

plot = figure(**plot_options)
plot.line(range(n_points),
          lambda_series_actual)
plot.line(range(n_points),
          lambda_series_pred,
          line_color="orange")
plot.line(range(n_points),
          mean_lambda_pred,
          line_color="red")

plot.xaxis.axis_label = 'step'
plot.yaxis.axis_label = 'lambda'

show(plot)

In [None]:
n_points = 500

step = 0
tot_steps = 10
step_stderr = []

for step in range(1, tot_steps):
    print(step)
    step_err = []
    for i in range(n_points):  
        x_t_m = series_actual[i]
        err = []
        for pred_step in range(step):
            x_val = torch.FloatTensor([x_t_m]).cuda()
            out, _ = model_test.forward(x_val)
            out = torch.clamp(out, -8, 8)
            out = F.sigmoid(out)
            lambda_series_pred.append(out.data.cpu().numpy()[0]/(x_t_m*(1-x_t_m)))
            x_t_m = out.data.cpu().numpy()[0]
            
            try:
                err.append(np.abs(series_actual[i+pred_step+1] - x_t_m))
            except IndexError:
                break
        step_err.append(np.mean(err))
    step_err = np.array(step_err)
    step_err = step_err[~np.isnan(step_err)]
    step_stderr.append(np.mean(np.array(step_err)))
#     print(step_err)
#     print(step_stderr)
            
plot_options = dict(width=900,
                    plot_height=450,
                    tools='pan,xwheel_zoom,reset,save')

plot = figure(**plot_options)
plot.line(range(tot_steps),
          step_stderr,
         line_width=2)

plot.xaxis.axis_label = 'prediction steps'
plot.yaxis.axis_label = 'standard_error'

show(plot)

In [None]:
step_stderr