In [2]:
from sklearn.model_selection import KFold
from lightgbm import LGBMRegressor
from sklearn.compose import make_column_transformer, make_column_selector, TransformedTargetRegressor
from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import KFold


from ssri_interactions.config import ExperimentInfo, Config
from ssri_interactions.io import load_derived_generic
from ssri_interactions.decoding.loaders import FSDecodeDataLoader, FSFastDecodeDataLoader
from ssri_interactions.decoding.preprocessors import DecodePreprocessor
from ssri_interactions.decoding.encoders import StateEncoder
from ssri_interactions.decoding.runners import EncodeRunner
from ssri_interactions.decoding.shuffle import shuffle_X

In [3]:
# Select the HAMILTON sessions for the CIT/SAL groups and build the encoding
# runner around FSDecodeDataLoader (its results are saved as "fs_slow - ...").
CV_SEED = 42  # seeds the shuffled KFold (and LightGBM) so reruns give identical folds

neuron_types = load_derived_generic("neuron_types.csv").query("group in ('CIT', 'SAL')")
sessions = neuron_types.query("experiment_name == 'HAMILTON'").session_name.unique().tolist()
assert sessions, "no HAMILTON sessions found for groups CIT/SAL in neuron_types.csv"
session = sessions[0]

loader = FSDecodeDataLoader(session_name=session)
preprocessor = DecodePreprocessor(thresh_empty=2)
# Load and preprocess the first session once. `spikes`/`states` are not used by
# the runner below (presumably an interactive sanity check — verify).
spikes, states = loader()
spikes, states = preprocessor(spikes, states)

# Standardize numeric feature columns; one-hot encode object (categorical)
# columns, dropping one level when a column has exactly two categories.
ct = make_column_transformer(
    (StandardScaler(), make_column_selector(dtype_exclude=object)),
    (OneHotEncoder(drop="if_binary"), make_column_selector(dtype_include=object)),
)

estimator = make_pipeline(
    ct,
    LGBMRegressor(
        n_estimators=20,
        n_jobs=-1,
        force_row_wise=True,
        reg_lambda=0.8,
        random_state=CV_SEED,
    ),
)
# Standardize the target too; predictions are mapped back to the original scale.
estimator = TransformedTargetRegressor(regressor=estimator, transformer=StandardScaler())

# Shuffled KFold (default 5 splits). Without random_state the folds, and hence
# every reported score, change on each run.
cv = KFold(shuffle=True, random_state=CV_SEED)
encoder = StateEncoder(estimator=estimator, cv=cv, verbose=True)
runner = EncodeRunner(
    loader=loader,
    preprocessor=preprocessor,
    encoder=encoder,
)

In [6]:
# Run every encoding analysis over all selected sessions with the runner
# configured above (FSDecodeDataLoader). The call order is kept as-is because the
# shuffled controls may draw from a shared RNG — reordering could change results.
session_kw = {"sessions": sessions}
LIMIT_MIN_FEATURES, LIMIT_MAX_FEATURES = 1, 15  # feature-count range for the limit sweep

pop = runner.run_multiple_pop(**session_kw)
pop_shuffle = runner.run_multiple_pop(**session_kw, shuffle=True)

state = runner.run_multiple_state(**session_kw)
state_shuffle = runner.run_multiple_state(**session_kw, shuffle=True)

limit = runner.run_multiple_limit(
    **session_kw,
    min_features=LIMIT_MIN_FEATURES,
    max_features=LIMIT_MAX_FEATURES,
)
dropout = runner.run_multiple_dropout(**session_kw)

100%|██████████| 17/17 [03:19<00:00, 11.75s/it]
100%|██████████| 19/19 [04:17<00:00, 13.53s/it]
100%|██████████| 20/20 [04:31<00:00, 13.59s/it]
100%|██████████| 10/10 [00:56<00:00,  5.61s/it]
100%|██████████| 84/84 [1:57:44<00:00, 84.10s/it]   
100%|██████████| 56/56 [33:23<00:00, 35.77s/it]
100%|██████████| 61/61 [1:14:57<00:00, 73.73s/it] 
100%|██████████| 55/55 [30:21<00:00, 33.11s/it]
100%|██████████| 52/52 [24:55<00:00, 28.76s/it]
100%|██████████| 3/3 [00:49<00:00, 16.61s/it]
100%|██████████| 3/3 [00:52<00:00, 17.39s/it]
100%|██████████| 3/3 [00:53<00:00, 17.68s/it]
100%|██████████| 3/3 [00:25<00:00,  8.57s/it]
100%|██████████| 3/3 [04:05<00:00, 81.73s/it]
100%|██████████| 3/3 [02:14<00:00, 44.86s/it]
100%|██████████| 3/3 [02:45<00:00, 55.30s/it]
100%|██████████| 3/3 [05:29<00:00, 109.88s/it]
100%|██████████| 3/3 [02:32<00:00, 50.92s/it]


In [7]:
# Persist the FSDecodeDataLoader ("fs_slow") results.
# NOTE: `limit` and `dropout` are reassigned by the fast-loader run later in the
# notebook, so this cell must run before that one or it will write fast results
# under "fs_slow" names.
SAVE_SLOW_POP_STATE = False  # set True to also (over)write the pop/state CSVs

dd = Config.derived_data_dir / "encoding"
dd.mkdir(parents=True, exist_ok=True)

slow_results = {}
if SAVE_SLOW_POP_STATE:
    slow_results.update({
        "pop": pop,
        "pop shuffle": pop_shuffle,
        "state": state,
        "state shuffle": state_shuffle,
    })
slow_results.update({"limit": limit, "dropout": dropout})

for label, frame in slow_results.items():
    frame.to_csv(dd / f"fs_slow - {label}.csv", index=False)
    

In [8]:
# Rebuild the encoding runner around FSFastDecodeDataLoader (results saved as
# "fs_fast - ..."). The estimator and CV setup mirror the slow-loader cell so the
# two result sets are comparable. Reuses `session`/`sessions` from the earlier cell.
CV_SEED = 42  # seeds the shuffled KFold (and LightGBM) so reruns give identical folds

loader = FSFastDecodeDataLoader(session_name=session)
preprocessor = DecodePreprocessor(thresh_empty=2)
# Load and preprocess one session once. `spikes`/`states` are not used by the
# runner below (presumably an interactive sanity check — verify).
spikes, states = loader()
spikes, states = preprocessor(spikes, states)

# Standardize numeric feature columns; one-hot encode object (categorical)
# columns, dropping one level when a column has exactly two categories.
ct = make_column_transformer(
    (StandardScaler(), make_column_selector(dtype_exclude=object)),
    (OneHotEncoder(drop="if_binary"), make_column_selector(dtype_include=object)),
)

estimator = make_pipeline(
    ct,
    LGBMRegressor(
        n_estimators=20,
        n_jobs=-1,
        force_row_wise=True,
        reg_lambda=0.8,
        random_state=CV_SEED,
    ),
)
# Standardize the target too; predictions are mapped back to the original scale.
estimator = TransformedTargetRegressor(regressor=estimator, transformer=StandardScaler())

# Shuffled KFold (default 5 splits). Without random_state the folds, and hence
# every reported score, change on each run.
cv = KFold(shuffle=True, random_state=CV_SEED)
encoder = StateEncoder(estimator=estimator, cv=cv, verbose=True)
runner = EncodeRunner(
    loader=loader,
    preprocessor=preprocessor,
    encoder=encoder,
)

In [9]:
# Re-run every encoding analysis with the FSFastDecodeDataLoader runner. This
# overwrites the slow-loader variables of the same names (pop, state, limit, ...).
# The call order is kept as-is because the shuffled controls may draw from a
# shared RNG — reordering could change results.
session_kw = {"sessions": sessions}
LIMIT_MIN_FEATURES, LIMIT_MAX_FEATURES = 1, 15  # feature-count range for the limit sweep

pop = runner.run_multiple_pop(**session_kw)
pop_shuffle = runner.run_multiple_pop(**session_kw, shuffle=True)

state = runner.run_multiple_state(**session_kw)
state_shuffle = runner.run_multiple_state(**session_kw, shuffle=True)

limit = runner.run_multiple_limit(
    **session_kw,
    min_features=LIMIT_MIN_FEATURES,
    max_features=LIMIT_MAX_FEATURES,
)
dropout = runner.run_multiple_dropout(**session_kw)

100%|██████████| 17/17 [00:17<00:00,  1.05s/it]
100%|██████████| 19/19 [00:16<00:00,  1.17it/s]
100%|██████████| 20/20 [00:20<00:00,  1.00s/it]
100%|██████████| 10/10 [00:01<00:00,  5.90it/s]
100%|██████████| 84/84 [02:05<00:00,  1.50s/it]
100%|██████████| 56/56 [01:25<00:00,  1.53s/it]
100%|██████████| 61/61 [01:19<00:00,  1.31s/it]
100%|██████████| 55/55 [01:18<00:00,  1.42s/it]
100%|██████████| 52/52 [00:59<00:00,  1.14s/it]
100%|██████████| 17/17 [00:16<00:00,  1.02it/s]
100%|██████████| 19/19 [00:16<00:00,  1.15it/s]
100%|██████████| 20/20 [00:20<00:00,  1.03s/it]
100%|██████████| 10/10 [00:01<00:00,  5.84it/s]
100%|██████████| 84/84 [01:50<00:00,  1.32s/it]
100%|██████████| 56/56 [01:17<00:00,  1.38s/it]
100%|██████████| 61/61 [01:23<00:00,  1.37s/it]
100%|██████████| 55/55 [01:18<00:00,  1.43s/it]
100%|██████████| 52/52 [00:59<00:00,  1.15s/it]
100%|██████████| 17/17 [00:01<00:00,  8.53it/s]
100%|██████████| 19/19 [00:02<00:00,  8.95it/s]
100%|██████████| 20/20 [00:04<00:00,  4.

In [10]:
# Persist the FSFastDecodeDataLoader ("fs_fast") results. The shuffled-control
# runs are saved too, using the same "<analysis> shuffle" naming as the slow
# results; previously they were computed and then discarded.
dd = Config.derived_data_dir / "encoding"
dd.mkdir(parents=True, exist_ok=True)

fast_results = {
    "pop": pop,
    "pop shuffle": pop_shuffle,
    "state": state,
    "state shuffle": state_shuffle,
    "limit": limit,
    "dropout": dropout,
}
for label, frame in fast_results.items():
    frame.to_csv(dd / f"fs_fast - {label}.csv", index=False)
    