# Test of bound and algorithm

In [1]:
import numpy as np
import gym

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
from src.algorithm.backward_feature_selection import BackwardFeatureSelector
from src.algorithm.info_theory.entropy import LeveOneOutEntropyEstimator, NNEntropyEstimator
from src.wenvs import WrapperEnv
from src.algorithm.utils import episodes_with_len



In [4]:
# Entropy estimator instance; it is passed to every BackwardFeatureSelector built in this notebook.
est = NNEntropyEstimator()

In [5]:
# CartPole-v1 wrapped with one fake state feature and one fake action.
env = gym.make('CartPole-v1')
# NOTE(review): `env` is wrapped in WrapperEnv here and then wrapped again on the
# next line — confirm the double wrapping is intended.
env = WrapperEnv(env, continuous_state=True)
wenv = WrapperEnv(env, n_fake_features=1, n_fake_actions=1, continuous_state=True)
# NOTE(review): only the environment is seeded here; the LQG cells further down
# also call np.random.seed(0) before sampling.
wenv.seed(0)

# NOTE(review): presumably num_ep episodes of length k — confirm against episodes_with_len.
k = 10
num_ep = 1000
trajectories = episodes_with_len(wenv, num_ep, k)

In [9]:
# Backward feature selection over the CartPole trajectories.
fs = BackwardFeatureSelector(est, trajectories)
# Returns a (selected id set, error) pair, as the recorded output shows.
# NOTE(review): the meaning of the positional arguments (k, 0.9, 1) is not visible
# here — confirm against selectOnError's signature.
fs.selectOnError(k, 0.9, 1)

HBox(children=(IntProgress(value=0), HTML(value='')))

HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

({1, 3, 4, 6}, 0.9224154806901057)

In [10]:
# Called with no arguments; the recorded value matches the error returned by the
# preceding selectOnError call.
fs.computeError()

0.9224154806901057

## LQG n-dim

In [11]:
from src.envs import lqgNdim

In [12]:
# 2-dimensional LQG environment, wrapped with continuous states and actions.
# NOTE(review): the first positional argument (0.9) is presumably the discount
# factor — confirm against LQG_nD.
env = lqgNdim.LQG_nD(0.9, n_dim=2)
wenv = WrapperEnv(env, continuous_state=True, continuous_actions=True)

In [13]:
# Run one rendered episode without passing a policy, then close the environment.
hist = wenv.run_episode(render=True)
wenv.close()

In [14]:
# Gain matrix as computed by the environment itself.
K = env.computeOptimalK()


def pi_opt(x):
    """Linear policy ``K @ x`` clipped to ``[-env.max_action, env.max_action]``."""
    return np.clip(K @ x, -env.max_action, env.max_action)

In [15]:
# Run one rendered episode under pi_opt, close the environment, and display the
# sum of the third element of the returned history.
# NOTE(review): hist[2] is presumably the per-step reward array — confirm against
# WrapperEnv.run_episode.
hist = wenv.run_episode(policy=pi_opt, render=True)
wenv.close()
hist[2].sum()

-85.7677713276748

## Infinite CartPole

In [16]:
from src.envs import cartpole

In [17]:
# Rebinds `env` and `wenv` to the CartPoleInfinite environment; the LQG objects
# created earlier are no longer reachable through these names.
env = cartpole.CartPoleInfinite()
wenv = WrapperEnv(env, continuous_state=True)

In [27]:
# Run one rendered episode without passing a policy; the return value is discarded.
wenv.run_episode(render=True)
wenv.close()

## Real test

In [28]:
# Fresh 2-dimensional LQG environment for the feature-selection runs that follow.
env = lqgNdim.LQG_nD(0.9, n_dim=2)
wenv = WrapperEnv(env, continuous_state=True, continuous_actions=True)

In [29]:
# Seed both NumPy's global RNG and the wrapped environment before sampling.
np.random.seed(0)
wenv.seed(0)

# NOTE(review): presumably num_ep episodes of length k — confirm against episodes_with_len.
k = 20
num_ep = 1000
trajectories = episodes_with_len(wenv, num_ep, k)

In [30]:
# Selector over the 2-D LQG trajectories; reuses the estimator `est` created earlier.
fs = BackwardFeatureSelector(est, trajectories)

In [31]:
# Print each (id set, error) pair yielded by try_remove_all.
# NOTE(review): the second argument (0.9) is presumably the discount factor — confirm.
for S, err in fs.try_remove_all(k, 0.9):
    print(S, err)

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 2} 279.3653845306717


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 1} 577.8104948625996


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 1095.7883801490254


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 1436.7654514784522


In [32]:
# 4x4 matrix with 0.9 on the diagonal; the lower-right 2x2 block is then zeroed,
# so only the first two diagonal entries stay non-zero.
Q = np.eye(4) * 0.9
Q[2:,2:] = 0
# R is an independent copy of Q.
R = Q.copy()
# 4-dimensional LQG environment built with the Q and R matrices above.
env = lqgNdim.LQG_nD(0.9, n_dim=4, Q=Q, R=R)
wenv = WrapperEnv(env, continuous_state=True, continuous_actions=True)

In [33]:
# Seed both NumPy's global RNG and the wrapped environment before sampling.
np.random.seed(0)
wenv.seed(0)

# NOTE(review): presumably num_ep episodes of length k — confirm against episodes_with_len.
k = 20
num_ep = 1000
trajectories = episodes_with_len(wenv, num_ep, k)

In [49]:
# Selector over the 4-D LQG trajectories.
# NOTE(review): this cell's execution count (49) is out of sequence with the cells
# around it — re-run the notebook top to bottom to confirm the recorded outputs.
fs = BackwardFeatureSelector(est, trajectories)

In [37]:
# Returns a (selected id set, error) pair, as the recorded output shows.
# NOTE(review): the first argument (1) is presumably the number of features to
# keep — confirm against selectNfeatures.
fs.selectNfeatures(1, k, 0.9)

HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

({0}, 1287.8814699059678)

In [35]:
# Print each (id set, error) pair yielded by try_remove_all with second argument 0.9.
# NOTE(review): that argument is presumably the discount factor — confirm.
for S, err in fs.try_remove_all(k, 0.9):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 6, 7} 107.94313342376819


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 3, 4, 6, 7} 200.56008605917916


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 3, 4, 6} 306.94553030880456


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 3, 6} 430.94939527721243


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 3} 594.7502954867308


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 3} 834.9930574611521


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 1287.8814699059678


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 1617.071818706301


In [39]:
# Same sweep as the 0.9 run, with the second argument lowered to 0.5.
# NOTE(review): that argument is presumably the discount factor — confirm.
for S, err in fs.try_remove_all(k, 0.5):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 6} 47.650691490526135


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 2, 3, 4, 5} 86.0673720048022


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 3, 4, 5} 131.06136473027607


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 4, 5} 182.25310744363526


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 5} 265.05860800345243


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 1} 388.5777943322571


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 624.79460244185


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 793.7746281815378


In [41]:
# Same sweep as the 0.9 run, with the second argument raised to 0.99.
# NOTE(review): that argument is presumably the discount factor — confirm.
for S, err in fs.try_remove_all(k, 0.99):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 6, 7} 516.9057308827088


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 2, 3, 6, 7} 834.3704771110381


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 2, 3, 6} 1138.8994571670305


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 2, 3} 1489.687133110923


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 3} 1941.6646200628427


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 3} 2607.1735113560176


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{3} 3920.7153411643185


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 4767.027217642368


In [40]:
# Same sweep as the 0.9 run, with the second argument set to 0.95.
# NOTE(review): that argument is presumably the discount factor — confirm.
for S, err in fs.try_remove_all(k, 0.95):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 6, 7} 166.7304679252279


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 3, 4, 6, 7} 300.04627952882015


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 3, 4, 6} 446.6570174421815


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 3, 6} 615.6639997900515


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 3} 840.6763637025495


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 3} 1169.8488581222111


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 1796.095847236575


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 2242.3448397507905


In [None]:
# Sweep with the extra options sampling="decaying" and freq=5; no output is
# recorded for this cell.
# NOTE(review): the semantics of `sampling` and `freq` are not visible here —
# confirm against try_remove_all.
for S, err in fs.try_remove_all(k, 0.9, sampling="decaying", freq=5):
    print(S, err)

### Subset enumeration

In [None]:
from itertools import chain, combinations
# tqdm.auto selects the notebook widget when it is available and falls back to the
# console bar otherwise; it replaces the deprecated `tqdm_notebook` alias.
from tqdm.auto import tqdm

# Every subset of the selector's id set, from the empty tuple up to the full set,
# each represented as a tuple of ids.
s = fs.idSet
powerset = set(chain.from_iterable(combinations(s, r) for r in range(len(s)+1)))

In [None]:
def compute_subset_error(S, k, gamma, bound=None):
    """Score one candidate subset ``S`` of ``fs.idSet`` and return its error.

    Relies on the notebook-level selector ``fs``. The per-step and final scoring
    functions come from ``fs._funOfBound(bound)`` and are evaluated on the pair
    (``fs.idSet - S``, ``S``). The resulting ``k + 1`` scores are combined with
    ``fs._get_weights(k, gamma, bound)`` and handed to ``fs.computeError`` as
    ``residual``.

    Parameters
    ----------
    S : iterable
        Ids forming the candidate subset; converted to a ``frozenset``.
    k : int
        Number of per-step scores; one extra final score is stored at index ``k``.
    gamma : float
        Forwarded to ``fs._get_weights`` (presumably the discount factor — confirm).
    bound : optional
        Bound identifier forwarded to the selector. When left as ``None`` the
        notebook-level variable ``bound`` is used, which is what this function
        previously read implicitly.

    Returns
    -------
    Whatever ``fs.computeError(bound=..., residual=...)`` returns.

    Raises
    ------
    NameError
        If ``bound`` is not passed and no notebook-level ``bound`` is defined.
    """
    if bound is None:
        # Backward-compatible fallback to the notebook-level name, with an
        # explicit message instead of a bare NameError from deep in the body.
        try:
            bound = globals()["bound"]
        except KeyError:
            raise NameError(
                "name 'bound' is not defined: pass bound=... to "
                "compute_subset_error or define `bound` in the notebook"
            ) from None

    step_score, final_score = fs._funOfBound(bound)
    weights = fs._get_weights(k, gamma, bound)

    kept = frozenset(S)
    removed = fs.idSet.difference(kept)

    # k per-step scores followed by one final score at index k.
    scores = np.zeros(k + 1)
    for t in range(k):
        scores[t] = step_score(removed, kept, t)
    scores[k] = final_score(removed, kept)

    return fs.computeError(bound=bound, residual=weights @ scores)

In [None]:
# Evaluate every subset in `powerset` with gamma fixed at 0.9, behind a progress
# bar; each entry of `rank` is a (subset, error) pair.
rank = [(s,compute_subset_error(s, k, 0.9)) for s in tqdm(powerset)]

In [None]:
# Order by subset size first, then by error within each size; the sort is stable,
# so entries with equal size and error keep their order from `rank`.
sorted(rank, key=lambda entry: (len(entry[0]), entry[1]))

Except for one case, the greedy algorithm always makes the best choice:

{0, 1, 2, 4, 5, 6, 7} 108.00216084944557 

{0, 1, 2, 4, 5, 7} 167.35674911495767

{0, 1, 2, 4, 5} 228.54114149996968

{0, 1, 4, 5} 300.6002932042114

{0, 1, 4} 419.9287663231109 => the subset {0, 1, 3}, with an error of about 416, is better

{0, 1} 593.4443694638235

{0} 985.9146148635155

set() 992.1915099307594