In [1]:
from protosc import filter_model
from protosc.simulation import create_simulation_data, create_correlated_data, slightly_correlated_data
from protosc.filter_model import train_xvalidate, select_features
from protosc.final_selection import final_selection
from protosc.wrapper import Wrapper
from collections import defaultdict
import numpy as np
from sklearn.model_selection import train_test_split
import pandas as pd
from protosc.parallel import execute_parallel

In [2]:
# Create data: features (X) and categories (y)
np.random.seed(1928374)
X, y, ground_truth = create_correlated_data()
print(f'features: {X[:5]}')
print(f'categories: {y[:5]}')

features: [[ 1.35946304 -0.05495788 -1.25547915 ... -0.29250157  0.71268609
  -1.15811878]
 [ 1.16041638  0.11066386  0.96355358 ... -0.74329004  0.1515906
  -0.07321478]
 [ 0.30436521  0.61188844 -0.84293664 ... -0.04545197  1.07970547
  -1.35302247]
 [-0.44400432 -0.55848696  0.62296112 ...  0.09136625 -0.86405888
  -1.28561021]
 [-1.91110816 -2.27302481 -1.0157041  ... -0.01774011  0.13587279
  -0.88981648]]
categories: [0 0 0 1 1]


In [3]:
# Slow method (add immediately = False)
slow = Wrapper(X, y, fold_seed=1)
out_slow = slow.wrapper(n_jobs=-1)

# Print outcome in dataframe
df = pd.DataFrame([value for key, value in out_slow.items() if key != 'recurring']).T
df = df.rename(columns={0: 'Model', 1: 'Features', 2: 'Accuracy'})
try:
    df['Recurring features'] = [out_slow['recurring']] * len(df)
except KeyError:
    pass
df

100%|██████████| 8/8 [00:04<00:00,  1.62it/s]


Unnamed: 0,Model,Features,Accuracy,Recurring features
0,"[[862, 485, 497, 788, 473], [663, 106, 27, 371...","[862, 485, 497, 788, 473, 663, 106, 27, 371, 5...",0.758621,[]
1,"[[485, 862, 788, 473, 497], [175, 81, 486, 344...","[485, 862, 788, 473, 497, 175, 81, 486, 344, 5...",0.839286,[]
2,"[[175, 486, 81, 344, 591], [424, 381, 563, 465...","[175, 486, 81, 344, 591, 424, 381, 563, 465, 2...",0.741379,[]
3,"[[862, 485, 497, 788, 473], [895, 555, 19, 768...","[862, 485, 497, 788, 473, 895, 555, 19, 768, 8...",0.75,[]
4,"[[485, 862, 497, 788, 473], [293, 379, 80, 836...","[485, 862, 497, 788, 473, 293, 379, 80, 836, 216]",0.759259,[]
5,"[[371, 27, 663, 106, 528], [496, 955, 193, 903...","[371, 27, 663, 106, 528, 496, 955, 193, 903, 934]",0.683333,[]
6,"[[635, 769, 682, 693, 863], [175, 81, 344, 486...","[635, 769, 682, 693, 863, 175, 81, 344, 486, 5...",0.827586,[]
7,"[[302, 899, 732, 124, 867]]","[302, 899, 732, 124, 867]",0.68,[]


In [4]:
# Fast method (add immediately = True)
fast = Wrapper(X, y, add_im=True, fold_seed=1)
out_fast = fast.wrapper(n_jobs=-1)

# Print outcome in dataframe
df = pd.DataFrame([value for key, value in out_fast.items() if key != 'recurring']).T
df = df.rename(columns={0: 'Model', 1: 'Features', 2: 'Clusters', 3: 'Accuracy'})
try:
    df['Recurring features'] = [out_fast['recurring']] * len(df)
except KeyError:
    pass
df

100%|██████████| 8/8 [00:04<00:00,  1.61it/s]


Unnamed: 0,Model,Features,Clusters,Recurring features
0,"[[175, 81, 486, 344, 591], [862, 485, 497, 788...","[175, 81, 486, 344, 591, 862, 485, 497, 788, 4...",0.741379,"[591, 81, 175, 788, 344, 862, 473, 485, 486, 497]"
1,"[[485, 862, 788, 473, 497], [175, 81, 486, 344...","[485, 862, 788, 473, 497, 175, 81, 486, 344, 5...",0.910714,"[591, 81, 175, 788, 344, 862, 473, 485, 486, 497]"
2,"[[485, 862, 497, 473, 788], [635, 682, 769, 69...","[485, 862, 497, 473, 788, 635, 682, 769, 693, ...",0.758621,"[591, 81, 175, 788, 344, 862, 473, 485, 486, 497]"
3,"[[486, 175, 81, 344, 591], [862, 485, 497, 788...","[486, 175, 81, 344, 591, 862, 485, 497, 788, 4...",0.821429,"[591, 81, 175, 788, 344, 862, 473, 485, 486, 497]"
4,"[[81, 175, 486, 344, 591], [485, 862, 497, 788...","[81, 175, 486, 344, 591, 485, 862, 497, 788, 4...",0.796296,"[591, 81, 175, 788, 344, 862, 473, 485, 486, 497]"
5,"[[485, 862, 788, 497, 473], [81, 486, 344, 175...","[485, 862, 788, 497, 473, 81, 486, 344, 175, 5...",0.8,"[591, 81, 175, 788, 344, 862, 473, 485, 486, 497]"
6,"[[485, 862, 788, 473, 497], [175, 81, 344, 486...","[485, 862, 788, 473, 497, 175, 81, 344, 486, 5...",0.862069,"[591, 81, 175, 788, 344, 862, 473, 485, 486, 497]"
7,"[[485, 862, 788, 497, 473], [451, 251, 468, 76...","[485, 862, 788, 497, 473, 451, 251, 468, 766, ...",0.7,"[591, 81, 175, 788, 344, 862, 473, 485, 486, 497]"


In [5]:
# Fast method (add immediately = True) with exclusion (excl=True)
fast_excl = Wrapper(X, y, add_im=True, excl=True, fold_seed=1)
out_fast_excl = fast_excl.wrapper(n_jobs=-1)

# Print outcome in dataframe
df = pd.DataFrame([value for key, value in out_fast_excl.items() if key != 'recurring']).T
df = df.rename(columns={0: 'Model', 1: 'Features', 2: 'Clusters', 3: 'Accuracy'})
try:
    df['Recurring features'] = [out_fast_excl['recurring']] * len(df)
except KeyError:
    pass
df

100%|██████████| 8/8 [00:08<00:00,  1.09s/it]


Unnamed: 0,Model,Features,Clusters,Recurring features
0,"[[175, 81, 486, 344, 591], [862, 485, 497, 788...","[175, 81, 486, 344, 591, 862, 485, 497, 788, 4...",0.741379,"[591, 81, 175, 788, 344, 862, 473, 485, 486, 497]"
1,"[[485, 862, 788, 473, 497], [175, 81, 486, 344...","[485, 862, 788, 473, 497, 175, 81, 486, 344, 5...",0.910714,"[591, 81, 175, 788, 344, 862, 473, 485, 486, 497]"
2,"[[485, 862, 497, 473, 788], [635, 682, 769, 69...","[485, 862, 497, 473, 788, 635, 682, 769, 693, ...",0.758621,"[591, 81, 175, 788, 344, 862, 473, 485, 486, 497]"
3,"[[486, 175, 81, 344, 591], [862, 485, 497, 788...","[486, 175, 81, 344, 591, 862, 485, 497, 788, 4...",0.821429,"[591, 81, 175, 788, 344, 862, 473, 485, 486, 497]"
4,"[[81, 175, 486, 344, 591], [485, 862, 497, 788...","[81, 175, 486, 344, 591, 485, 862, 497, 788, 4...",0.796296,"[591, 81, 175, 788, 344, 862, 473, 485, 486, 497]"
5,"[[485, 862, 788, 497, 473], [81, 486, 344, 175...","[485, 862, 788, 497, 473, 81, 486, 344, 175, 5...",0.8,"[591, 81, 175, 788, 344, 862, 473, 485, 486, 497]"
6,"[[485, 862, 788, 473, 497], [175, 81, 344, 486...","[485, 862, 788, 473, 497, 175, 81, 344, 486, 5...",0.862069,"[591, 81, 175, 788, 344, 862, 473, 485, 486, 497]"
7,"[[485, 862, 788, 497, 473], [451, 251, 468, 76...","[485, 862, 788, 497, 473, 451, 251, 468, 766, ...",0.7,"[591, 81, 175, 788, 344, 862, 473, 485, 486, 497]"
