# Test of bound and algorithm

In [3]:
import numpy as np
import gym

In [5]:
# IPython autoreload extension in mode 2: modules are re-imported before each
# cell executes, so edits to the project's `src` package are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [4]:
from src.algorithm.backward_feature_selection import BackwardFeatureSelector
from src.algorithm.info_theory.entropy import LeveOneOutEntropyEstimator, NNEntropyEstimator
from src.wenvs import WrapperEnv
from src.algorithm.utils import episodes_with_len



In [6]:
# Entropy estimator instance handed to every BackwardFeatureSelector built in this notebook.
est = NNEntropyEstimator()

In [5]:
# NOTE(review): the gym env is wrapped in WrapperEnv twice (on the second and
# third lines of this cell) — confirm the inner wrap is intended and not a leftover.
env = gym.make('CartPole-v1')
env = WrapperEnv(env, continuous_state=True)
# The outer wrapper requests one fake feature and one fake action
# (presumably irrelevant dimensions the selector should discard — TODO confirm).
wenv = WrapperEnv(env, n_fake_features=1, n_fake_actions=1, continuous_state=True)
# Only the environment is seeded here; NumPy's global RNG is not seeded in this cell.
wenv.seed(0)

# Passed to episodes_with_len as (env, num_ep, k); presumably num_ep episodes of
# length k — TODO confirm against src.algorithm.utils.
k = 10
num_ep = 1000
trajectories = episodes_with_len(wenv, num_ep, k)

In [9]:
# Build the selector from the estimator and the sampled trajectories.
fs = BackwardFeatureSelector(est, trajectories)
# The recorded output is a (set of feature ids, error) pair.
# NOTE(review): 0.9 is presumably the discount factor and 1 the error threshold —
# confirm against BackwardFeatureSelector.selectOnError.
fs.selectOnError(k, 0.9, 1)

HBox(children=(IntProgress(value=0), HTML(value='')))

HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

({1, 3, 4, 6}, 0.9224154806901057)

In [10]:
# Called with no arguments; the recorded value equals the error returned by
# selectOnError in the previous cell.
fs.computeError()

0.9224154806901057

## LQG n-dim

In [7]:
from src.envs import lqgNdim

In [12]:
# LQG environment with n_dim=2; the positional 0.9 is presumably the discount
# factor — TODO confirm against lqgNdim.LQG_nD.
env = lqgNdim.LQG_nD(0.9, n_dim=2)
wenv = WrapperEnv(env, continuous_state=True, continuous_actions=True)

In [13]:
# Run a single rendered episode without passing a policy, then call close()
# on the wrapper.
hist = wenv.run_episode(render=True)
wenv.close()

In [14]:
# K is obtained from the environment's computeOptimalK() and applied as K @ x below.
K = env.computeOptimalK()


def pi_opt(x):
    """Return K @ x clipped element-wise to [-env.max_action, env.max_action].

    `x` is whatever run_episode passes to its `policy` callable
    (presumably the current state — TODO confirm).
    """
    return np.clip(K@x, -env.max_action, env.max_action)

In [15]:
# Rendered episode driven by pi_opt.
hist = wenv.run_episode(policy=pi_opt, render=True)
wenv.close()
# Sum of the third element of the episode history (presumably the per-step
# rewards, making this the episode return — TODO confirm).
hist[2].sum()

-85.7677713276748

## Infinite CartPole

In [16]:
from src.envs import cartpole

In [17]:
# Project-local CartPoleInfinite environment, wrapped with continuous_state=True.
env = cartpole.CartPoleInfinite()
wenv = WrapperEnv(env, continuous_state=True)

In [27]:
# Run a single rendered episode without passing a policy, then call close()
# on the wrapper. The episode history is not kept.
wenv.run_episode(render=True)
wenv.close()

## Real test

In [6]:
# LQG environment with n_dim=2; the positional 0.9 is presumably the discount
# factor — TODO confirm against lqgNdim.LQG_nD.
env = lqgNdim.LQG_nD(0.9, n_dim=2)
wenv = WrapperEnv(env, continuous_state=True, continuous_actions=True)

In [7]:
# Seed both NumPy's global RNG and the environment before sampling.
np.random.seed(0)
wenv.seed(0)

# Passed to episodes_with_len as (env, num_ep, k); presumably num_ep episodes of
# length k — TODO confirm against src.algorithm.utils.
k = 20
num_ep = 1000
trajectories = episodes_with_len(wenv, num_ep, k)

In [8]:
# Selector built from the shared estimator and the trajectories sampled in the previous cell.
fs = BackwardFeatureSelector(est, trajectories)

In [31]:
# Print every (feature id set, error) pair yielded by try_remove_all.
# In the recorded output the set loses one id per step until it is empty.
# NOTE(review): 0.9 is presumably the discount factor — confirm.
for S, err in fs.try_remove_all(k, 0.9):
    print(S, err)

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 2} 279.3653845306717


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 1} 577.8104948625996


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 1095.7883801490254


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 1436.7654514784522


In [8]:
# 4x4 diagonal matrix with entries 0.9, 0.9, 0.1, 0.1; R is an independent copy of Q.
# NOTE(review): Q and R are presumably the state and action cost matrices of
# the LQG — confirm against lqgNdim.LQG_nD.
Q = np.diag([0.9, 0.9, 0.1, 0.1])
R = Q.copy()
env = lqgNdim.LQG_nD(0.9, n_dim=4, Q=Q, R=R)
wenv = WrapperEnv(env, continuous_state=True, continuous_actions=True)

In [10]:
# Seed both NumPy's global RNG and the environment before sampling.
np.random.seed(0)
wenv.seed(0)

# Passed to episodes_with_len as (env, num_ep, k); presumably num_ep episodes of
# length k — TODO confirm against src.algorithm.utils.
k = 20
num_ep = 1000
trajectories = episodes_with_len(wenv, num_ep, k)

In [11]:
# Selector built from the shared estimator and the trajectories sampled in the previous cell.
fs = BackwardFeatureSelector(est, trajectories)

In [25]:
# The recorded output is a (feature id set, error) pair with a single id left.
# NOTE(review): the first argument (1) is presumably the number of features to
# keep and 0.9 the discount factor — confirm against selectNfeatures.
fs.selectNfeatures(1, k, 0.9)

HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))




({0}, 979.8012837511585)

In [12]:
# Print every (feature id set, error) pair yielded by try_remove_all with
# second argument 0.9 (presumably the discount factor — TODO confirm).
for S, err in fs.try_remove_all(k, 0.9):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 6} 107.07231011244805


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 3, 4, 5, 6} 174.00726516882884


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 3, 4, 5} 234.26185270077661


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 3, 4} 310.7818667932845


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 4} 414.8414850077089


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 1} 590.07813177702


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 979.8012837511585


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 982.091241732819



In [13]:
# Same sweep with second argument 0.5 (presumably the discount factor — TODO confirm).
for S, err in fs.try_remove_all(k, 0.5):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 6} 46.344420660041685


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 3, 4, 5, 6} 71.91600174565912


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 3, 4, 5} 100.17368530244394


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 4, 5} 132.55352412941536


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 4} 189.53036875325657


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 1} 282.75846453516414


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 487.11503795271466


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 484.99859245775195



In [15]:
# Same sweep with second argument 0.99 (presumably the discount factor — TODO confirm).
for S, err in fs.try_remove_all(k, 0.99):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 6} 535.549686447864


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 2, 3, 5, 6} 822.0449981436032


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 2, 3, 6} 1093.6216906649445


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 2, 3} 1384.6471874523916


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 3} 1771.2680755457216


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 3} 2360.3173623655207


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{3} 3594.1663323472017


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 4140.0789347970485



In [14]:
# Same sweep with second argument 0.95 (presumably the discount factor — TODO confirm).
for S, err in fs.try_remove_all(k, 0.95):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 6} 171.434566338078


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 2, 3, 4, 5} 275.48126561022275


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 3, 4, 5} 369.649577275583


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 3, 4} 478.3536708809162


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 3} 627.2045219614683


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 3} 869.7079197538777


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 1416.577200211781


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 1471.6420622535484



In [29]:
# Sweep with first argument k-5 and the sampling="decaying", freq=50 options.
# The recorded output starts with an array of 15 indices printed before the
# progress bars (presumably the sampled time steps — TODO confirm).
for S, err in fs.try_remove_all(k-5, 0.9, sampling="decaying", freq=50):
    print(S, err)

[ 0  1  2  3  4  5  6  7  8 10 11 12 13 14 16]


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 6} 113.05238035635055


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 3, 4, 5, 6} 183.09369959963337


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 3, 4, 5} 245.27320714853684


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 3, 4} 323.2912973466417


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 4} 429.0575134144756


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 1} 605.0533013209249


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 994.8054944514857


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 1013.3399446617613



### Test with optimal policy

In [9]:
# Seed both NumPy's global RNG and the environment before sampling.
# `env` and `wenv` are whatever the most recently executed environment cell defined.
np.random.seed(0)
wenv.seed(0)

# Same sampling call as in the earlier cells, but actions come from env.optimalPolicy().
k = 20
num_ep = 1000
trajectories = episodes_with_len(wenv, num_ep, k, policy=env.optimalPolicy())

In [10]:
# Selector built from the shared estimator and the trajectories sampled in the previous cell.
fs = BackwardFeatureSelector(est, trajectories)

In [59]:
# Print every (feature id set, error) pair yielded by try_remove_all with
# second argument 0.9 (presumably the discount factor — TODO confirm).
for S, err in fs.try_remove_all(k, 0.9):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 6} 11.90474118124679


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 2, 3, 4, 5} 19.433641233189626


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 2, 3, 4} 36.55428189883335


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 2, 3} 48.98077883374398


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 3} 117.06815802530032


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 1} 235.51673126010547


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 625.9457548432281


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 221.59986083387997


In [11]:
# Same sweep with sum_cmi=False. The recorded sets and errors agree with the
# previous cell's output up to floating-point rounding.
for S, err in fs.try_remove_all(k, 0.9, sum_cmi=False):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 6} 11.90474118124679


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 2, 3, 4, 5} 19.433641233189626


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 2, 3, 4} 36.554281898833366


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 2, 3} 48.980778833743976


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 3} 117.06815802530029


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 1} 235.5167312601055


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 625.9457548432281


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 221.59986083387986



### Subset enumeration

In [None]:
from itertools import chain, combinations
# tqdm.auto picks the notebook widget when available; the top-level
# `tqdm_notebook` alias is deprecated.
from tqdm.auto import tqdm
# fs.idSet is the id set the selector works on (presumably all feature ids — TODO confirm).
s = fs.idSet
# Every subset of s, as tuples, for sizes 0 through len(s).
powerset = set(chain.from_iterable(combinations(s, r) for r in range(len(s)+1)))

In [None]:
def compute_subset_error(S, k, gamma, bound=None):
    """Compute the selector's error when exactly the ids in ``S`` are kept.

    Uses the notebook-level selector ``fs``.

    Parameters
    ----------
    S : iterable
        Ids to keep; every other id in ``fs.idSet`` is passed as the
        complementary set.
    k : int
        Number of per-step scores computed; one extra final score is appended,
        so the score vector has ``k + 1`` entries.
    gamma : float
        Forwarded to ``fs._get_weights``.
    bound : optional
        Forwarded to ``fs._funOfBound``, ``fs._get_weights`` and
        ``fs.computeError``. When left as ``None``, the notebook-level variable
        ``bound`` is used, which keeps existing three-argument calls working.

    Returns
    -------
    The value of ``fs.computeError`` with the weighted score ``w @ score``
    passed as ``residual``.

    Raises
    ------
    NameError
        If ``bound`` is not given and no notebook-level ``bound`` exists.
    """
    if bound is None:
        # The original body read `bound` as a free variable that no visible
        # cell defines. Resolve it explicitly so that a fresh kernel fails
        # with an actionable message instead of relying on hidden state.
        try:
            bound = globals()["bound"]
        except KeyError:
            raise NameError(
                "compute_subset_error needs a bound: pass bound=... "
                "or define a notebook-level variable named 'bound'"
            ) from None

    fun_t, fun_k = fs._funOfBound(bound)
    w = fs._get_weights(k, gamma, bound)
    score = np.zeros(k+1)

    S = frozenset(S)
    no_S = fs.idSet.difference(S)

    # Entries 0..k-1 hold the per-step scores; entry k holds the final score.
    for t in range(k):
        score[t] = fun_t(no_S, S, t)
    score[k] = fun_k(no_S, S)

    return fs.computeError(bound=bound, residual=w @ score)

In [None]:
# Evaluate compute_subset_error (with k and 0.9) for every subset in powerset,
# behind a tqdm progress bar; each entry of rank is a (subset, error) tuple.
rank = [(s,compute_subset_error(s, k, 0.9)) for s in tqdm(powerset)]

In [None]:
# Order by subset size first, then by error within each size
# (a single stable sort on a composite key).
sorted(rank, key=lambda entry: (len(entry[0]), entry[1]))

Except for one case, the greedy algorithm always makes the best choice:

{0, 1, 2, 4, 5, 6, 7} 108.00216084944557 

{0, 1, 2, 4, 5, 7} 167.35674911495767

{0, 1, 2, 4, 5} 228.54114149996968

{0, 1, 4, 5} 300.6002932042114

{0, 1, 4} 419.9287663231109 => the subset {0, 1, 3} is better, with error ≈ 416

{0, 1} 593.4443694638235

{0} 985.9146148635155

set() 992.1915099307594