# Test of bound and algorithm

In [1]:
import numpy as np
import gym

In [2]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to the project code are picked up.
%load_ext autoreload
%autoreload 2

In [2]:
from src.algorithm.backward_feature_selection import BackwardFeatureSelector
from src.algorithm.feature_selection import Bound
from src.algorithm.info_theory.entropy import LeveOneOutEntropyEstimator, NNEntropyEstimator
from src.wenvs import WrapperEnv
from src.algorithm.utils import episodes_with_len



In [3]:
# Entropy estimator instance passed to every BackwardFeatureSelector built below.
est = NNEntropyEstimator()

In [5]:
# Build a CartPole environment and wrap it with one fake feature and one fake action.
env = gym.make('CartPole-v1')
env = WrapperEnv(env, continuous_state=True)
# NOTE(review): `env` is already a WrapperEnv at this point, so it is wrapped twice — confirm this is intended.
wenv = WrapperEnv(env, n_fake_features=1, n_fake_actions=1, continuous_state=True)
# NOTE(review): only the environment is seeded here; later cells also call np.random.seed(0) — confirm whether that is needed here too.
wenv.seed(0)

k = 10
num_ep = 1000
# Arguments are the wrapped env, num_ep and k (presumably episode count and episode length — confirm).
trajectories = episodes_with_len(wenv, num_ep, k)

In [6]:
# Build the backward feature selector from the estimator and the sampled trajectories.
fs = BackwardFeatureSelector(est, trajectories)
# In the recorded output this returns a (set of feature ids, error) pair.
# Presumably 0.9 is the discount factor and 1 the error threshold — TODO confirm against selectOnError's signature.
fs.selectOnError(k, 0.9, 1)

HBox(children=(IntProgress(value=0), HTML(value='')))

HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))




({1, 3, 4, 6}, 0.9100094568123525)

In [8]:
# Called without arguments; in the recorded run this gives the same error value that selectOnError reported.
fs.computeError()

0.9100094568123525

## LQG n-dim

In [4]:
from src.envs import lqgNdim

In [14]:
# LQG environment with n_dim=2, wrapped with continuous states and continuous actions.
# The first positional argument 0.99 is presumably the discount factor — TODO confirm.
env = lqgNdim.LQG_nD(0.99, n_dim=2)
wenv = WrapperEnv(env, continuous_state=True, continuous_actions=True)

In [15]:
# Run one rendered episode without passing a policy, then close the wrapped environment.
hist = wenv.run_episode(render=True)
wenv.close()

In [16]:
# Matrix returned by the environment's computeOptimalK (presumably the optimal
# feedback gain — confirm against the LQG_nD implementation).
K = env.computeOptimalK()


def pi_opt(x):
    """Return K @ x clipped element-wise to [-env.max_action, env.max_action]."""
    bound_value = env.max_action
    return np.clip(K @ x, -bound_value, bound_value)

In [17]:
# Run one episode under pi_opt and sum the third element of the returned history
# (presumably the per-step rewards — TODO confirm against run_episode).
hist = wenv.run_episode(policy=pi_opt)
hist[2].sum()

-57.860930209418036

## Infinite CartPole

In [18]:
from src.envs import cartpole

In [19]:
# CartPoleInfinite environment from the project's cartpole module, wrapped with continuous states.
env = cartpole.CartPoleInfinite()
wenv = WrapperEnv(env, continuous_state=True)

In [20]:
# Run one rendered episode without passing a policy, then close the wrapped environment.
wenv.run_episode(render=True)
wenv.close()

## Real test

In [5]:
# Recreate the n_dim=2 LQG environment and its continuous state/action wrapper for the experiments below.
env = lqgNdim.LQG_nD(0.99, n_dim=2)
wenv = WrapperEnv(env, continuous_state=True, continuous_actions=True)

In [6]:
# Seed both NumPy's global RNG and the environment before sampling.
np.random.seed(0)
wenv.seed(0)

k = 20
num_ep = 1000
# Arguments are the wrapped env, num_ep and k (presumably episode count and episode length — confirm).
trajectories = episodes_with_len(wenv, num_ep, k)

In [7]:
# New selector over the LQG trajectories, reusing the estimator `est`.
fs = BackwardFeatureSelector(est, trajectories)

In [8]:
# Print every (feature set, error) pair produced by try_remove_all with the Bound.cmi bound.
for S, err in fs.try_remove_all(k, 0.9, bound=Bound.cmi):
    print(S, err)

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 2} 275.1990224396083


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 2} 496.1075058088754


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 935.4788314555569


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 943.514515537091



In [9]:
# Same loop with the Bound.entropy bound; in the recorded run the sets and errors
# match the Bound.cmi run up to floating-point rounding.
for S, err in fs.try_remove_all(k, 0.9, bound=Bound.entropy):
    print(S, err)

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 2} 275.1990224396083


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 2} 496.1075058088752


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 935.4788314555569


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 943.5145155370909



In [10]:
# Same loop with the Bound.cmi_sqrt bound; the recorded errors are larger than with Bound.cmi.
for S, err in fs.try_remove_all(k, 0.9, bound=Bound.cmi_sqrt):
    print(S, err)

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 2} 1241.101718300192


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 2} 2234.4694059057388


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 4168.203698675913


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 4317.780379381366



In [11]:
# 4x4 diagonal matrix with 0.9 on the first two diagonal entries and 0 elsewhere;
# an independent copy is passed as R.
Q = np.diag([0.9, 0.9, 0.0, 0.0])
R = Q.copy()
env = lqgNdim.LQG_nD(0.99, n_dim=4, Q=Q, R=R)
wenv = WrapperEnv(env, continuous_state=True, continuous_actions=True)

In [12]:
# Seed both NumPy's global RNG and the environment before sampling.
np.random.seed(0)
wenv.seed(0)

k = 20
num_ep = 1000
# Arguments are the wrapped env, num_ep and k (presumably episode count and episode length — confirm).
trajectories = episodes_with_len(wenv, num_ep, k)

In [13]:
# New selector over the n_dim=4 LQG trajectories, reusing the estimator `est`.
fs = BackwardFeatureSelector(est, trajectories)

In [14]:
# Called with 1 as the first argument; the recorded output is a (single-feature set, error) pair.
fs.selectNfeatures(1, k, 0.9)

HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))




({0}, 985.9146148635155)

In [15]:
# Print every (feature set, error) pair from try_remove_all with the default bound.
# The second argument (0.9 here) is presumably the discount factor gamma — TODO confirm.
for S, err in fs.try_remove_all(k, 0.9):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 4, 5, 6, 7} 108.00216084944557


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 2, 4, 5, 7} 167.35674911495767


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 2, 4, 5} 228.54114149996968


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 4, 5} 300.6002932042114


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 4} 419.9287663231109


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 1} 593.4443694638235


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 985.9146148635155


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 992.1915099307594



In [18]:
# Same loop with 0.5 as the second argument (presumably gamma — TODO confirm).
for S, err in fs.try_remove_all(k, 0.5):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 7} 45.24535799669906


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 2, 4, 5, 7} 71.18228218225686


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 4, 5, 7} 98.9496011756822


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 4, 5} 130.4360593389985


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 4} 194.17999425708885


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 1} 285.24828747796295


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 490.6871943467559


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 491.31481546994667



In [21]:
# Same loop with 0.99 as the second argument (presumably gamma — TODO confirm).
for S, err in fs.try_remove_all(k, 0.99):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 7} 545.0460961674122


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 2, 3, 4, 7} 825.5590212798049


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 2, 3, 4} 1098.460291161602


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 2, 3} 1388.9583969799105


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 3} 1779.6226254877067


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 3} 2374.356759976569


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{3} 3619.5974031744663


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 4171.510703275914



In [22]:
# Same loop with 0.95 as the second argument (presumably gamma — TODO confirm).
for S, err in fs.try_remove_all(k, 0.95):
    print(S, err)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

{0, 1, 2, 3, 4, 5, 7} 174.5338703513713


HBox(children=(IntProgress(value=0, max=7), HTML(value='')))

{0, 1, 3, 4, 5, 7} 268.7735621112553


HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

{0, 1, 3, 4, 5} 362.65883506823303


HBox(children=(IntProgress(value=0, max=5), HTML(value='')))

{0, 1, 4, 5} 473.14190819180556


HBox(children=(IntProgress(value=0, max=4), HTML(value='')))

{0, 1, 4} 637.6746542816132


HBox(children=(IntProgress(value=0, max=3), HTML(value='')))

{0, 1} 880.0673789375687


HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

{0} 1425.6404762549394


HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

set() 1484.7182088144207



### Subset enumeration

In [71]:
from itertools import chain, combinations
# `tqdm.tqdm_notebook` is deprecated (scheduled for removal in tqdm 5);
# `tqdm.auto` picks the notebook widget when it is available.
from tqdm.auto import tqdm
# Every subset of the selector's feature ids, as tuples of size 0 .. len(s).
s = fs.idSet
powerset = set(chain.from_iterable(combinations(s, r) for r in range(len(s)+1)))

In [68]:
def compute_subset_error(S, k, gamma, bound=Bound.cmi):
    """Compute the selector's error for an explicitly given feature subset.

    Uses the notebook-level selector ``fs``: its per-step score function is
    evaluated for t = 0 .. k-1 and its final score function once, each on
    (complement of S in ``fs.idSet``, S). The k+1 scores are combined with the
    weights from ``fs._get_weights(k, gamma, bound)`` through a dot product,
    and the result is handed to ``fs.computeError`` as ``residual``.

    Parameters
    ----------
    S : iterable
        Feature ids of the subset; converted to a ``frozenset``.
    k : int
        Number of per-step scores to evaluate.
    gamma : float
        Forwarded to ``fs._get_weights`` (presumably the discount factor —
        confirm against the selector).
    bound : Bound
        Selects the score functions and weights; defaults to ``Bound.cmi``.

    Returns
    -------
    The value returned by ``fs.computeError(bound=bound, residual=...)``.
    """
    step_score, final_score = fs._funOfBound(bound)
    weights = fs._get_weights(k, gamma, bound)

    subset = frozenset(S)
    complement = fs.idSet.difference(subset)

    # First k entries hold the per-step scores, the last entry the final score.
    scores = np.zeros(k + 1)
    scores[:k] = [step_score(complement, subset, t) for t in range(k)]
    scores[k] = final_score(complement, subset)

    return fs.computeError(bound=bound, residual=weights @ scores)

In [72]:
# Pair every subset in the powerset with its error from compute_subset_error (gamma = 0.9), with a progress bar.
rank = [(s,compute_subset_error(s, k, 0.9)) for s in tqdm(powerset)]

HBox(children=(IntProgress(value=0, max=256), HTML(value='')))




In [78]:
# Order by subset size first, then by error within each size (stable for ties).
sorted(rank, key=lambda entry: (len(entry[0]), entry[1]))

[((), 992.1915099307596),
 ((0,), 985.9146148635153),
 ((2,), 987.5724091817738),
 ((3,), 989.04505046072),
 ((7,), 993.9485417887956),
 ((1,), 994.0373838851236),
 ((6,), 1007.7422011724366),
 ((5,), 1008.9713305048182),
 ((4,), 1014.2775229392186),
 ((0, 3), 591.393300570646),
 ((0, 1), 593.4443694638235),
 ((1, 2), 594.1638474200697),
 ((1, 3), 595.969590012388),
 ((2, 3), 597.0552717297661),
 ((1, 4), 597.4656849203376),
 ((3, 5), 598.3334570372426),
 ((0, 4), 599.3984488739156),
 ((0, 2), 599.5740180354807),
 ((3, 4), 600.3784244412752),
 ((0, 7), 600.4009206302219),
 ((2, 5), 602.274627256905),
 ((0, 6), 602.9615788358916),
 ((0, 5), 603.4226749249827),
 ((2, 7), 603.7287848627649),
 ((3, 6), 604.3716442943826),
 ((1, 5), 604.5542463182067),
 ((1, 6), 604.75872845718),
 ((5, 7), 607.2705831731931),
 ((2, 4), 608.2540276766148),
 ((1, 7), 608.5966335480879),
 ((4, 6), 610.5044991105497),
 ((6, 7), 612.7894767651916),
 ((3, 7), 613.2038586659663),
 ((2, 6), 614.0209086630333),
 ((5

Except for one case, the greedy algorithm always makes the best choice:

{0, 1, 2, 4, 5, 6, 7} 108.00216084944557 

{0, 1, 2, 4, 5, 7} 167.35674911495767

{0, 1, 2, 4, 5} 228.54114149996968

{0, 1, 4, 5} 300.6002932042114

{0, 1, 4} 419.9287663231109 => {0, 1, 3} would be better, with an error of about 416

{0, 1} 593.4443694638235

{0} 985.9146148635155

set() 992.1915099307594