# Test of bound and algorithm

In [2]:
import numpy as np
import gym

In [23]:
%load_ext autoreload
%autoreload 2

In [36]:
from src.algorithm.backward_feature_selection import BackwardFeatureSelector
from src.algorithm.info_theory.entropy import LeveOneOutEntropyEstimator, NNEntropyEstimator
from src.wenvs import WrapperEnv
from src.algorithm.utils import episodes_with_len

In [4]:
# Entropy estimator shared by every BackwardFeatureSelector built in this notebook.
est = NNEntropyEstimator()

In [5]:
# Build the CartPole-v1 environment and wrap it for the experiment.
env = gym.make('CartPole-v1')
env = WrapperEnv(env, continuous_state=True)
# NOTE(review): `env` is already a WrapperEnv at this point, so the line below
# wraps it a second time — confirm the double wrapping is intended.
# The n_fake_* arguments presumably add dummy feature/action dimensions — TODO confirm.
wenv = WrapperEnv(env, n_fake_features=1, n_fake_actions=1, continuous_state=True)
wenv.seed(0)

# Collect the trajectories used for feature selection; presumably num_ep
# episodes of length k each (inferred from the helper's name — verify).
k = 10
num_ep = 1000
trajectories = episodes_with_len(wenv, num_ep, k)

In [9]:
# Build the selector from the shared estimator and the collected trajectories.
fs = BackwardFeatureSelector(est, trajectories)
# Arguments: k, then 0.9 (presumably the discount factor) and 1 (presumably the
# error threshold) — TODO confirm against selectOnError's signature.
# The recorded output is a (selected id set, error) pair.
fs.selectOnError(k, 0.9, 1)

HBox(children=(IntProgress(value=0), HTML(value='')))

HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

({1, 3, 4, 6}, 0.9224154806901057)

In [10]:
# Error for the selector's current state; the recorded value equals the error
# returned by the selectOnError call in the previous cell.
fs.computeError()

0.9224154806901057

## LQG n-dim

In [5]:
from src.envs import lqgNdim

In [12]:
# 2-dimensional LQG environment; the leading 0.9 is presumably the discount
# factor — TODO confirm against LQG_nD's signature.
env = lqgNdim.LQG_nD(0.9, n_dim=2)
wenv = WrapperEnv(env, continuous_state=True, continuous_actions=True)

In [13]:
# Run one rendered episode without passing a policy, then close the environment.
hist = wenv.run_episode(render=True)
wenv.close()

In [14]:
# Matrix returned by the environment's computeOptimalK(); used as a linear
# feedback gain by pi_opt below.
K = env.computeOptimalK()


def pi_opt(x):
    """Return the action K @ x clipped to [-env.max_action, env.max_action].

    `K` and `env` are looked up at call time, exactly as the former lambda did.
    """
    return np.clip(K @ x, -env.max_action, env.max_action)

In [15]:
# Run one rendered episode driven by pi_opt, then close the environment.
hist = wenv.run_episode(policy=pi_opt, render=True)
wenv.close()
# Sum of hist[2]; presumably the per-step rewards, i.e. the episode return — TODO confirm.
hist[2].sum()

-85.7677713276748

## Infinite CartPole

In [16]:
from src.envs import cartpole

In [17]:
# Project-local CartPoleInfinite environment, wrapped with continuous states.
env = cartpole.CartPoleInfinite()
wenv = WrapperEnv(env, continuous_state=True)

In [27]:
# Run one rendered episode without passing a policy (result discarded), then
# close the environment.
wenv.run_episode(render=True)
wenv.close()

## Real test

In [6]:
# Fresh 2-dimensional LQG environment for the feature-selection runs below;
# the leading 0.9 is presumably the discount factor — TODO confirm.
env = lqgNdim.LQG_nD(0.9, n_dim=2)
wenv = WrapperEnv(env, continuous_state=True, continuous_actions=True)

In [7]:
# Seed NumPy's global RNG and the wrapped environment before sampling.
np.random.seed(0)
wenv.seed(0)

# Collect the trajectories; presumably num_ep episodes of length k each
# (inferred from the helper's name — verify).
k = 20
num_ep = 1000
trajectories = episodes_with_len(wenv, num_ep, k)

In [8]:
# Selector over the 2-D LQG trajectories, reusing the shared estimator.
fs = BackwardFeatureSelector(est, trajectories)

In [31]:
# Print every (id set, error) pair yielded by try_remove_all; in the recorded
# output the set loses one id per step until it is empty.
# 0.9 is presumably the discount factor — TODO confirm.
for S, err in fs.try_remove_all(k, 0.9):
    print(S, err)

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 2} 279.3653845306717


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 1} 577.8104948625996


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 1095.7883801490254


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 1436.7654514784522


In [25]:
# Diagonal matrix with 0.9 on the first two entries and 0.1 on the last two;
# presumably the state-cost weights of the LQG — TODO confirm.
Q = np.diag([0.9, 0.9, 0.1, 0.1])
# R holds the same values as Q but is a separate array.
R = Q.copy()
# 4-dimensional LQG with the custom Q and R; 0.9 is presumably the discount factor.
env = lqgNdim.LQG_nD(0.9, n_dim=4, Q=Q, R=R)
wenv = WrapperEnv(env, continuous_state=True, continuous_actions=True)

In [18]:
# Seed NumPy's global RNG and the wrapped environment before sampling.
np.random.seed(0)
wenv.seed(0)

# Collect the trajectories; presumably num_ep episodes of length k each
# (inferred from the helper's name — verify).
k = 20
num_ep = 1000
trajectories = episodes_with_len(wenv, num_ep, k)

In [19]:
# Selector over the 4-D LQG trajectories, reusing the shared estimator.
fs = BackwardFeatureSelector(est, trajectories)

In [35]:
# Print every (id set, error) pair yielded by try_remove_all with second
# argument 0.9 (presumably the discount factor — TODO confirm).
for S, err in fs.try_remove_all(k, 0.9):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 6} 117.30383130660906


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 3, 4, 5, 6} 181.78509211359966


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 3, 5, 6} 259.49585172711676


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 3, 6} 340.2697462808104


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 3} 456.6661504300719


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 3} 648.9755061740755


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 1082.352247294386


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 1089.3974225722643


In [34]:
# Called with 1, k and 0.9; the leading 1 is presumably the number of features
# to keep — TODO confirm. The recorded output is a (id set, error) pair whose
# set has a single element.
fs.selectNfeatures(1, k, 0.9)

HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

({0}, 1082.352247294386)

In [33]:
# Same sweep as before with 0.5 instead of 0.9 as the second argument
# (presumably the discount factor — TODO confirm).
for S, err in fs.try_remove_all(k, 0.5):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 6} 51.914792844858574


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 3, 4, 5, 6} 79.75722079758164


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 4, 5, 6} 114.07728580689876


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 4, 5} 148.08680117143035


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 5} 211.1113889293652


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 1} 313.7885023453356


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 538.1767018348105


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 538.8996027664483


In [32]:
# Same sweep with 0.99 as the second argument (presumably the discount
# factor — TODO confirm).
for S, err in fs.try_remove_all(k, 0.99):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 6} 583.9001904427909


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 2, 3, 4, 6} 896.0029472496045


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 2, 3, 6} 1174.0242139219858


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 2, 3} 1516.0411658471542


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 3} 1946.0606606989038


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 3} 2601.5063898443227


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{3} 3969.6318914849726


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 4574.7674982023555


In [31]:
# Same sweep with 0.95 as the second argument (presumably the discount
# factor — TODO confirm).
for S, err in fs.try_remove_all(k, 0.95):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 6} 187.6434622534076


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 3, 4, 5, 6} 291.6962315382681


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 3, 5, 6} 402.72555485857214


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 3, 6} 519.4343968275731


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 3} 687.2046561792389


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 3} 958.612824955195


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 1564.821140483597


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 1629.4139605823493


In [30]:
# Sweep with a shorter first argument (k-5) and the sampling="decaying",
# freq=50 options; their exact semantics are defined in try_remove_all and are
# not visible here — TODO document once confirmed.
for S, err in fs.try_remove_all(k-5, 0.9, sampling="decaying", freq=50):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 6} 163.4064735073718


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 3, 4, 5, 6} 255.26201478317884


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 3, 5, 6} 345.8307238397671


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 3, 6} 441.309175432398


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 3} 569.9589212828294


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 3} 775.4679063393768


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 1212.0335004601436


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 1345.921490384106


#### Subset enumeration

In [27]:
from itertools import chain, combinations
from tqdm import tqdm_notebook as tqdm
# Enumerate every subset of the selector's feature ids (the power set), as
# tuples ranging from the empty tuple up to the full id collection.
allSet = fs.idSet
# The size bound is taken from allSet itself: the former `len(s)` referred to a
# name `s` that no visible cell defines, so the cell failed with a NameError on
# a fresh kernel.
powerset = set(chain.from_iterable(combinations(allSet, r) for r in range(len(allSet) + 1)))

In [28]:
# Score every subset in the power set; each entry is (subset tuple, score).
# 0.9 is presumably the discount factor — TODO confirm. tqdm only adds the
# progress bar.
rank = [(S, fs.scoreSubset(k, 0.9, S)) for S in tqdm(powerset)]

HBox(children=(IntProgress(value=0, max=256), HTML(value='')))

In [29]:
# Order the ranking by subset size first and by score within each size; a
# single stable sort on the (size, score) pair gives the same ordering as
# sorting by score and then re-sorting by size.
sorted(rank, key=lambda entry: (len(entry[0]), entry[1]))

[((), 1089.3974225722643),
 ((0,), 1082.352247294386),
 ((3,), 1083.8249920882338),
 ((2,), 1085.5915765804527),
 ((7,), 1089.2724531846934),
 ((1,), 1092.7190244502624),
 ((6,), 1094.5835200014633),
 ((4,), 1095.735421897282),
 ((5,), 1104.7245176531167),
 ((0, 3), 648.9755061740756),
 ((1, 2), 651.5971370858912),
 ((0, 1), 651.7327508168125),
 ((1, 3), 653.8745403702983),
 ((1, 5), 654.0341137895666),
 ((2, 3), 654.3105708128404),
 ((0, 4), 655.7946536333411),
 ((1, 6), 656.1583081825768),
 ((0, 5), 656.2106484500819),
 ((0, 2), 658.2633684707639),
 ((3, 7), 658.6541756291876),
 ((1, 7), 660.0239140728168),
 ((1, 4), 660.8979865329584),
 ((2, 4), 662.4814648081768),
 ((3, 5), 662.6147033599873),
 ((3, 4), 662.9638083511946),
 ((0, 6), 663.1039504917479),
 ((5, 7), 666.2494396448685),
 ((0, 7), 666.2947177060692),
 ((5, 6), 667.6999637670235),
 ((2, 7), 668.213540407104),
 ((2, 5), 668.2882546398574),
 ((6, 7), 668.5327440907607),
 ((4, 7), 668.6283363151284),
 ((3, 6), 670.4145358448

### Test with optimal policy

In [8]:
# Seed NumPy's global RNG and the wrapped environment before sampling.
# NOTE(review): `env` and `wenv` are whatever an earlier cell left in the
# kernel; the recorded progress bars (max=8) suggest the 4-D LQG — confirm.
np.random.seed(0)
wenv.seed(0)

# Collect trajectories while acting with the policy returned by
# env.optimalPolicy() instead of the helper's default behaviour.
k = 20
num_ep = 1000
trajectories = episodes_with_len(wenv, num_ep, k, policy=env.optimalPolicy())

In [9]:
# Selector over the optimal-policy trajectories, reusing the shared estimator.
fs = BackwardFeatureSelector(est, trajectories)

In [10]:
# Print every (id set, error) pair yielded by try_remove_all on the
# optimal-policy data; 0.9 is presumably the discount factor — TODO confirm.
for S, err in fs.try_remove_all(k, 0.9):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 6} 11.90474118124679


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 2, 3, 4, 5} 19.433641233189626


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 2, 3, 4} 36.55428189883335


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 2, 3} 48.98077883374398


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 3} 117.06815802530032


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 1} 235.51673126010547


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 625.9457548432281


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 221.59986083387997



In [11]:
# Same sweep with sum_cmi=False. In the recorded output the sets and errors
# match the default run above up to floating-point rounding.
for S, err in fs.try_remove_all(k, 0.9, sum_cmi=False):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 6} 11.90474118124679


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 2, 3, 4, 5} 19.433641233189626


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 2, 3, 4} 36.554281898833366


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 2, 3} 48.980778833743976


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 3} 117.06815802530029


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 1} 235.5167312601055


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 625.9457548432281


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 221.59986083387986



#### Subset enumeration

In [21]:
# Enumerate every subset of the selector's feature ids (the power set), as
# tuples ranging from the empty tuple up to the full id collection.
# `chain` and `combinations` are the itertools names imported in an earlier cell.
allSet = fs.idSet
# The size bound is taken from allSet itself: the former `len(s)` referred to a
# name `s` that no visible cell defines, so the cell failed with a NameError on
# a fresh kernel.
powerset = set(chain.from_iterable(combinations(allSet, r) for r in range(len(allSet) + 1)))

In [15]:
# Score every subset in the power set on the optimal-policy data; each entry
# is (subset tuple, score). 0.9 is presumably the discount factor — TODO
# confirm. tqdm only adds the progress bar.
rank = [(S, fs.scoreSubset(k, 0.9, S)) for S in tqdm(powerset)]

HBox(children=(IntProgress(value=0, max=256), HTML(value='')))




In [16]:
# Order the ranking by subset size first and by score within each size; a
# single stable sort on the (size, score) pair gives the same ordering as
# sorting by score and then re-sorting by size.
sorted(rank, key=lambda entry: (len(entry[0]), entry[1]))

[((), 221.5998608338799),
 ((0,), 625.9457548432282),
 ((1,), 635.3424755068335),
 ((2,), 642.7158633576871),
 ((3,), 644.7962109483158),
 ((4,), 660.1351901059298),
 ((5,), 670.2079399991328),
 ((6,), 676.1227449043595),
 ((7,), 678.4537399010048),
 ((0, 1), 235.5167312601055),
 ((1, 4), 254.68858677241468),
 ((0, 5), 257.5352041964055),
 ((4, 5), 273.7993482696362),
 ((0, 3), 293.498390421585),
 ((1, 2), 296.63678222277406),
 ((1, 3), 301.13734417316135),
 ((0, 2), 301.82855485052335),
 ((3, 4), 308.4231809691274),
 ((0, 7), 309.6576878970556),
 ((1, 6), 311.24110453816087),
 ((2, 5), 313.67527508174584),
 ((1, 7), 315.12442064288945),
 ((0, 6), 316.6779060771711),
 ((2, 3), 317.32540282114155),
 ((2, 4), 318.23823323021696),
 ((3, 5), 319.1750682186439),
 ((4, 7), 324.41078195823815),
 ((5, 6), 327.67282576166275),
 ((2, 7), 332.02519706381776),
 ((5, 7), 332.2462813383383),
 ((3, 6), 332.37827167814856),
 ((4, 6), 332.89067907738803),
 ((6, 7), 346.6599897876572),
 ((0, 4), 616.258