# Test of the error bound and the backward feature-selection algorithm

In [1]:
import numpy as np
import gym

In [2]:
%load_ext autoreload
%autoreload 2

In [220]:
from src.algorithm.backward_feature_selection import BackwardFeatureSelector
from src.algorithm.feature_selection import Bound
from src.algorithm.info_theory.entropy import LeveOneOutEntropyEstimator, NNEntropyEstimator
from src.wenvs import WrapperEnv
from src.algorithm.utils import episodes_with_len

In [4]:
# Entropy estimator instance handed to every BackwardFeatureSelector built in this notebook.
est = NNEntropyEstimator()

In [5]:
# Build the CartPole-v1 environment and sample the trajectories used by the first selection run.
env = gym.make('CartPole-v1')
env = WrapperEnv(env, continuous_state=True)
# NOTE(review): `env` is already a WrapperEnv at this point, so it gets wrapped a second
# time here — confirm the double wrapping is intended.
# presumably n_fake_features / n_fake_actions add one extra state feature and one extra
# action dimension that carry no information — verify against WrapperEnv.
wenv = WrapperEnv(env, n_fake_features=1, n_fake_actions=1, continuous_state=True)
wenv.seed(0)

# k is reused below as the first argument of fs.selectOnError;
# presumably it is the length of each sampled episode — verify against episodes_with_len.
k = 10
num_ep = 1000
trajectories = episodes_with_len(wenv, num_ep, k)

In [6]:
# Build a backward feature selector from the shared estimator and the CartPole trajectories.
fs = BackwardFeatureSelector(est, trajectories)
# NOTE(review): 0.9 looks like a discount factor and 1 like the error threshold —
# confirm against selectOnError's signature. The returned value is shown as the cell output.
fs.selectOnError(k, 0.9, 1)

HBox(children=(IntProgress(value=0), HTML(value='')))

HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))




{0, 1, 2, 4, 6}

In [7]:
# Display the value returned by computeError() for the selector after the run above.
fs.computeError()

1.0113339562391774

## LQG n-dim

In [6]:
from src.envs import lqgNdim

In [9]:
# Two-dimensional LQG environment (n_dim=2), wrapped with continuous states and actions.
# NOTE(review): the positional 0.99 is presumably the discount factor — confirm against LQG_nD.
env = lqgNdim.LQG_nD(0.99, n_dim=2)
wenv = WrapperEnv(env, continuous_state=True, continuous_actions=True)

In [10]:
# Run one rendered episode without passing a policy, then close the environment.
hist = wenv.run_episode(render=True)
wenv.close()

In [11]:
# Gain matrix returned by the environment's computeOptimalK().
K = env.computeOptimalK()
# Snapshot the action bound now: `env` is rebound by later cells, and a policy that
# looked up env.max_action at call time would silently use the new environment's bound.
max_action = env.max_action


def pi_opt(x):
    """Return the action ``K @ x`` clipped to ``[-max_action, max_action]``."""
    return np.clip(K @ x, -max_action, max_action)

In [12]:
# Roll out one episode under pi_opt and display the sum of hist[2].
# presumably hist[2] holds the per-step rewards — verify against WrapperEnv.run_episode.
hist = wenv.run_episode(policy=pi_opt)
hist[2].sum()

-34.846069838517174

## Infinite CartPole

In [6]:
from src.envs import cartpole

In [14]:
# CartPoleInfinite environment from src.envs.cartpole, wrapped with continuous_state=True.
env = cartpole.CartPoleInfinite()
wenv = WrapperEnv(env, continuous_state=True)

In [15]:
# Run one rendered episode without passing a policy (the result is discarded), then close.
wenv.run_episode(render=True)
wenv.close()

## Real test

In [7]:
# Rebuild the 2-dimensional LQG environment; the cells below sample trajectories from it.
# NOTE(review): the positional 0.99 is presumably the discount factor — confirm against LQG_nD.
env = lqgNdim.LQG_nD(0.99, n_dim=2)
wenv = WrapperEnv(env, continuous_state=True, continuous_actions=True)

In [8]:
# Seed NumPy's global RNG and the wrapped environment before sampling.
np.random.seed(0)
wenv.seed(0)

# k is also passed to the fs.selectOnError calls below;
# presumably it is the length of each sampled episode — verify against episodes_with_len.
k = 20
num_ep = 1000
trajectories = episodes_with_len(wenv, num_ep, k)

In [9]:
# Selector over the 2-D LQG trajectories, reusing the shared entropy estimator.
fs = BackwardFeatureSelector(est, trajectories)

In [59]:
# Selection with the Bound.cmi bound.
# NOTE(review): 0.9 looks like a discount factor and 5 like the error threshold — confirm.
fs.selectOnError(k, 0.9, 5, bound=Bound.cmi)

HBox(children=(IntProgress(value=0), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))




({0}, 3.42736738989731)

In [60]:
# Same call as the previous cell, but with the Bound.cmi_sqrt bound.
fs.selectOnError(k, 0.9, 5, bound=Bound.cmi_sqrt)

HBox(children=(IntProgress(value=0), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))




({0, 1, 2}, 3.8911497790336282)

In [61]:
# Q starts as 0.9 * I (4x4); zeroing its lower-right 2x2 block leaves diag(0.9, 0.9, 0, 0).
Q = np.eye(4) * 0.9
Q[2:,2:] = 0
# R is an independent copy of Q.
R = Q.copy()
# presumably Q and R are the state and action cost matrices of the LQG problem — verify against LQG_nD.
env = lqgNdim.LQG_nD(0.99, n_dim=4, Q=Q, R=R)
wenv = WrapperEnv(env, continuous_state=True, continuous_actions=True)

In [62]:
# Seed NumPy's global RNG and the wrapped 4-D environment before sampling.
np.random.seed(0)
wenv.seed(0)

# k is also passed to the selection calls below;
# presumably it is the length of each sampled episode — verify against episodes_with_len.
k = 20
num_ep = 1000
trajectories = episodes_with_len(wenv, num_ep, k)

In [63]:
# Selector over the 4-D LQG trajectories, reusing the shared entropy estimator.
fs = BackwardFeatureSelector(est, trajectories)

In [222]:
# Selection with Bound.cmi and a third positional argument of 1000.
# NOTE(review): that argument appears to act as an error threshold (compare the result
# with the run that passes 2000) — confirm against selectOnError.
fs.selectOnError(k, 0.9, 1000, bound=Bound.cmi)

HBox(children=(IntProgress(value=0), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

({0, 4}, 844.7188428300764)

In [221]:
# Selection with Bound.cmi and a third positional argument of 2000.
# NOTE(review): that argument appears to act as an error threshold (compare the result
# with the run that passes 1000) — confirm against selectOnError.
fs.selectOnError(k, 0.9, 2000, bound=Bound.cmi)

HBox(children=(IntProgress(value=0), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

({0}, 1298.8681463365601)

In [205]:
# presumably the leading 1 is the number of features to keep — verify against selectNfeatures.
fs.selectNfeatures(1, k, 0.9)

HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

({0}, 1298.8681463365601)

In [225]:
# Print every (feature set, error) pair produced by try_remove_all.
# NOTE(review): the second argument (0.9 here) is presumably the discount factor — confirm.
for S, err in fs.try_remove_all(k, 0.9):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 6} 105.35954766257254


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 3, 4, 5, 6} 201.7501110374587


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 3, 4, 5} 310.10104006706024


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 4, 5} 437.9462397505193


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 4} 604.6844375850555


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 4} 844.7188428300764


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 1298.8681463365601


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 1635.1879086351903


In [226]:
# Print every (feature set, error) pair produced by try_remove_all.
# NOTE(review): the second argument (0.99 here) is presumably the discount factor — confirm.
for S, err in fs.try_remove_all(k, 0.99):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 6, 7} 529.6550180980164


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 2, 3, 6, 7} 965.4759410373883


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 2, 3, 6} 1449.258311445125


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 2, 3} 2007.425372532614


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 3} 2682.378186197259


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 3} 3583.244539951975


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{3} 5096.0169336446625


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 6589.159475066558


In [227]:
# Print every (feature set, error) pair produced by try_remove_all.
# NOTE(review): the second argument (0.5 here) is presumably the discount factor — confirm.
for S, err in fs.try_remove_all(k, 0.5):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 6} 46.27689813001862


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 3, 4, 5, 6} 86.0593830208704


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 3, 4, 5} 132.2614956542015


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 4, 5} 186.2737385249489


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 4} 267.4219074530581


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 1} 390.94921012380325


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 627.5320537036407


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 797.0168505174531


In [228]:
# Print every (feature set, error) pair produced by try_remove_all.
# NOTE(review): the second argument (0.95 here) is presumably the discount factor — confirm.
for S, err in fs.try_remove_all(k, 0.95):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 6} 169.47267791680198


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 2, 3, 4, 6} 319.6438412413225


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 2, 3, 6} 486.397888348011


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 3, 6} 683.7539597499112


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 3} 927.1116140291637


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 3} 1273.3137235456109


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 1912.5353152400728


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 2422.414979429696
