# New Bound: backward feature selection on LunarLander-v2

In [1]:
import gym
import numpy as np

from src.algorithm.backward_feature_selection import BackwardFeatureSelector
from src.algorithm.info_theory.combo_estimators import FastNNEntropyEstimator, NpeetEstimator
from src.wenvs import WrapperEnv
from src.algorithm.utils import episodes_with_len



In [2]:
# Two entropy estimators to compare; each one drives its own
# BackwardFeatureSelector in the cells below (built left to right, as before).
fest, nest = FastNNEntropyEstimator(), NpeetEstimator()

In [3]:
# Raw gym environment, wrapped with continuous states and discrete actions
# (as selected by the WrapperEnv flags).
env = gym.make('LunarLander-v2')
wenv = WrapperEnv(env,
                  continuous_actions=False,
                  continuous_state=True)

In [4]:
# Seed numpy's global RNG and the wrapped environment with the same value so the
# sampled trajectories are reproducible after a kernel restart.
seed = 0
np.random.seed(seed)
wenv.seed(seed)

k = 20         # passed as the first argument to try_remove_all in the cells below
num_ep = 1000  # presumably the number of episodes to sample — TODO confirm
ep_len = 100   # presumably the length of each episode (per `episodes_with_len`) — TODO confirm
# NOTE(review): policy=None presumably means a default/random policy inside
# episodes_with_len — confirm against its implementation.
trajectories = episodes_with_len(wenv, num_ep, ep_len, policy=None)

## Continuous

In [37]:
# Backward selector over the collected trajectories, scored with the NPEET
# estimator (`discrete` left at its default). NOTE(review): nproc=None
# semantics are defined in BackwardFeatureSelector — not visible here.
nfs = BackwardFeatureSelector(
    nest,
    trajectories,
    nproc=None,
)

In [38]:
# Judging from the saved output, each step yields the remaining feature set and
# its error. NOTE(review): the meaning of 0.9 is not visible here — TODO document.
steps = nfs.try_remove_all(k, 0.9)
for subset, error in steps:
    print(subset, error)

HBox(children=(IntProgress(value=0, max=9), HTML(value='')))

HBox(children=(IntProgress(value=0, max=189), HTML(value='')))

{1, 2, 3, 4, 5, 6, 7, 8} 113.5091319997684


HBox(children=(IntProgress(value=0, max=168), HTML(value='')))

{2, 3, 4, 5, 6, 7, 8} 113.50997865456098


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

{2, 3, 5, 6, 7, 8} 113.51082530303853


HBox(children=(IntProgress(value=0, max=126), HTML(value='')))

{2, 3, 6, 7, 8} 113.51167194520119


HBox(children=(IntProgress(value=0, max=105), HTML(value='')))

{8, 2, 3, 7} 113.5125185810491


HBox(children=(IntProgress(value=0, max=84), HTML(value='')))

{8, 2, 3} 113.51336521058239


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))

{8, 2} 113.78252632958164


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))

{2} 114.19975424121543


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

set() 114.71136970800558


In [7]:
# Same backward selector as above, but scored with the fast NN entropy
# estimator (`discrete` left at its default).
ffs = BackwardFeatureSelector(
    fest,
    trajectories,
    nproc=None,
)

In [8]:
# NOTE(review): in the saved output the error drops from {2} (~118.03) to
# set() (~114.25), unlike the NPEET run where it grows monotonically —
# worth checking whether this is estimator noise.
steps = ffs.try_remove_all(k, 0.9)
for subset, error in steps:
    print(subset, error)

HBox(children=(IntProgress(value=0, max=9), HTML(value='')))

HBox(children=(IntProgress(value=0, max=189), HTML(value='')))

{0, 1, 2, 3, 4, 5, 7, 8} 113.50828533866067


HBox(children=(IntProgress(value=0, max=168), HTML(value='')))

{0, 1, 2, 3, 4, 5, 8} 113.50828533866067


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

{1, 2, 3, 4, 5, 8} 113.50829948674408


HBox(children=(IntProgress(value=0, max=126), HTML(value='')))

{1, 2, 3, 5, 8} 113.50831848285793


HBox(children=(IntProgress(value=0, max=105), HTML(value='')))

{8, 2, 3, 5} 113.50838863229308


HBox(children=(IntProgress(value=0, max=84), HTML(value='')))

{8, 2, 3} 113.51609151265514


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))

{2, 3} 113.99948395945647


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))

{2} 118.02769973612376


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

set() 114.24685651772253



## Discrete

In [9]:
# NPEET-scored backward selector in discrete mode.
dnfs = BackwardFeatureSelector(
    nest,
    trajectories,
    discrete=True,
    nproc=None,
)

In [11]:
# NOTE(review): every error in the saved output is nan for NpeetEstimator with
# discrete=True, so this run does not rank features meaningfully — investigate.
steps = dnfs.try_remove_all(k, 0.9)
for subset, error in steps:
    print(subset, error)

HBox(children=(IntProgress(value=0, max=9), HTML(value='')))

HBox(children=(IntProgress(value=0, max=189), HTML(value='')))

{0, 1, 2, 3, 5, 6, 7, 8} nan


HBox(children=(IntProgress(value=0, max=168), HTML(value='')))

{1, 2, 3, 5, 6, 7, 8} nan


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

{2, 3, 5, 6, 7, 8} nan


HBox(children=(IntProgress(value=0, max=126), HTML(value='')))

{2, 3, 6, 7, 8} nan


HBox(children=(IntProgress(value=0, max=105), HTML(value='')))

{8, 2, 3, 7} nan


HBox(children=(IntProgress(value=0, max=84), HTML(value='')))

{8, 2, 3} nan


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))

{2, 3} nan


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))

{2} nan


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

set() nan


In [14]:
# Fast-NN-scored backward selector in discrete mode.
dffs = BackwardFeatureSelector(
    fest,
    trajectories,
    discrete=True,
    nproc=None,
)

In [15]:
# Print the remaining feature set and error after each backward-removal step.
steps = dffs.try_remove_all(k, 0.9)
for subset, error in steps:
    print(subset, error)

HBox(children=(IntProgress(value=0, max=9), HTML(value='')))

HBox(children=(IntProgress(value=0, max=189), HTML(value='')))

{0, 1, 2, 3, 4, 5, 7, 8} 0.0


HBox(children=(IntProgress(value=0, max=168), HTML(value='')))

{0, 1, 2, 3, 4, 5, 8} 0.0


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

{1, 2, 3, 4, 5, 8} 0.281899676973021


HBox(children=(IntProgress(value=0, max=126), HTML(value='')))

{1, 2, 3, 5, 8} 0.43148365128717003


HBox(children=(IntProgress(value=0, max=105), HTML(value='')))

{8, 2, 3, 5} 0.7751435093205548


HBox(children=(IntProgress(value=0, max=84), HTML(value='')))

{8, 2, 3} 6.606189309282113


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))

{2, 3} 49.02178188037771


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))

{2} 119.09087346357659


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

set() 122.64901173828765


### History (per-step scores via `all_scores=True`)

In [57]:
# With all_scores=True each step also yields `scores`; print its first three
# entries. From the saved output, scores[0] looks like candidate feature
# indices ordered by ascending scores[1] — TODO confirm, including what
# scores[2] holds. NOTE(review): the final step shows a negative scores[1]
# (~-439.1) — check whether that is expected.
for subset, error, scores in dffs.try_remove_all(k, 0.9, all_scores=True):
    print(subset, error)
    for idx in range(3):
        print(scores[idx])

HBox(children=(IntProgress(value=0, max=9), HTML(value='')))

HBox(children=(IntProgress(value=0, max=189), HTML(value='')))

{0, 1, 2, 3, 4, 5, 7, 8} 0.0
[6 7 0 4 1 5 8 2 3]
[0.00000000e+00 0.00000000e+00 1.60592479e-03 2.15640552e-03
 7.96252790e-03 8.74253163e-01 5.50381255e+01 5.97919577e+01
 2.29223276e+02]
[0.00000000e+00 0.00000000e+00 3.81277891e-02 5.11955471e-02
 1.99361128e-01 2.06421328e+01 1.12492749e+03 1.24054377e+03
 2.94412801e+03]


HBox(children=(IntProgress(value=0, max=168), HTML(value='')))

{0, 1, 2, 3, 4, 5, 8} 0.0
[7 0 4 1 5 8 2 3]
[0.00000000e+00 1.60592479e-03 2.15640552e-03 7.96252790e-03
 8.74253163e-01 5.50381255e+01 5.97919577e+01 2.29223276e+02]
[0.00000000e+00 3.81277891e-02 5.11955471e-02 1.99361128e-01
 2.06421328e+01 1.12492749e+03 1.24054377e+03 2.94412801e+03]


HBox(children=(IntProgress(value=0, max=147), HTML(value='')))

{1, 2, 3, 4, 5, 8} 0.281899676973021
[0 4 1 5 8 2 3]
[1.60592479e-03 2.15640552e-03 7.96252790e-03 8.74253163e-01
 5.50381255e+01 5.97919577e+01 2.29223276e+02]
[3.81277891e-02 5.11955471e-02 1.99361128e-01 2.06421328e+01
 1.12492749e+03 1.24054377e+03 2.94412801e+03]


HBox(children=(IntProgress(value=0, max=126), HTML(value='')))

{1, 2, 3, 5, 8} 0.43148365128717003
[4 1 5 8 2 3]
[2.15621676e-03 7.96259630e-03 8.74364931e-01 5.50380371e+01
 5.98325676e+01 2.29222368e+02]
[8.93269291e-02 2.37493606e-01 2.06824417e+01 1.12496528e+03
 1.24131530e+03 2.94416987e+03]


HBox(children=(IntProgress(value=0, max=105), HTML(value='')))

{8, 2, 3, 5} 0.7751435093205548
[1 5 8 2 3]
[7.96254689e-03 8.74515328e-01 5.50379183e+01 5.98872568e+01
 2.29221150e+02]
[2.88699042e-01 2.07365700e+01 1.12501602e+03 1.24235383e+03
 2.94422610e+03]


HBox(children=(IntProgress(value=0, max=84), HTML(value='')))

{8, 2, 3} 6.606189309282113
[5 8 2 3]
[  0.8743712   55.0376268   59.88563226 378.28873722]
[  20.9347727  1125.21723579 1242.55511168 4290.27709216]


HBox(children=(IntProgress(value=0, max=63), HTML(value='')))

{2, 3} 49.02178188037771
[8 3 2]
[ 54.98965536 379.73643273 482.75704009]
[1145.69179811 4330.54468939 5828.3910154 ]


HBox(children=(IntProgress(value=0, max=42), HTML(value='')))

{2} 119.09087346357659
[3 2]
[467.32778099 576.18230926]
[6568.11453892 8201.26032407]


HBox(children=(IntProgress(value=0, max=21), HTML(value='')))

set() 122.64901173828765
[2]
[-439.09684041]
[7437.28334836]
