In [3]:
from sklearn.model_selection import KFold
from lightgbm import LGBMRegressor
from sklearn.compose import make_column_transformer, make_column_selector, TransformedTargetRegressor
from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import KFold


from ssri_interactions.config import ExperimentInfo, Config
from ssri_interactions.io import load_derived_generic
from ssri_interactions.decoding.loaders import FSDecodeDataLoader, FSFastDecodeDataLoader
from ssri_interactions.decoding.preprocessors import DecodePreprocessor
from ssri_interactions.decoding.encoders import StateEncoder
from ssri_interactions.decoding.runners import EncodeRunner
from ssri_interactions.decoding.shuffle import shuffle_X

  return warn(


In [4]:
def make_slow_ts_runner(
    loader,
    thresh_empty=2,
    n_estimators=20,
    reg_lambda=0.8,
    random_state=None,
):
    """Build an ``EncodeRunner`` that encodes activity with a LightGBM model.

    The runner pairs ``loader`` with a ``DecodePreprocessor`` and a
    ``StateEncoder`` whose estimator is::

        ColumnTransformer(
            StandardScaler        on non-object columns,
            OneHotEncoder(drop="if_binary") on object columns,
        ) -> LGBMRegressor
        wrapped in TransformedTargetRegressor(transformer=StandardScaler())

    and which is evaluated with shuffled ``KFold`` cross-validation.

    Parameters
    ----------
    loader
        Data loader handed to ``EncodeRunner`` (this notebook passes
        ``FSDecodeDataLoader`` / ``FSFastDecodeDataLoader`` instances).
    thresh_empty : int, default 2
        Forwarded to ``DecodePreprocessor(thresh_empty=...)``.
    n_estimators : int, default 20
        Number of boosting rounds for ``LGBMRegressor``.
    reg_lambda : float, default 0.8
        L2 regularisation strength for ``LGBMRegressor``.
    random_state : int or None, default None
        Seed for both the shuffled ``KFold`` split and ``LGBMRegressor``.
        ``None`` keeps the previous unseeded behaviour; pass an int to make
        the real and shuffled runs reproducible.

    Returns
    -------
    EncodeRunner
        Runner exposing the ``run_multiple_*`` methods used in this notebook.
    """
    preprocessor = DecodePreprocessor(thresh_empty=thresh_empty)

    # Scale numeric features and one-hot encode categorical (object) features;
    # drop="if_binary" keeps a single column for two-level categories.
    ct = make_column_transformer(
        (StandardScaler(), make_column_selector(dtype_exclude=object)),
        (OneHotEncoder(drop="if_binary"), make_column_selector(dtype_include=object)),
    )

    model = make_pipeline(
        ct,
        LGBMRegressor(
            n_estimators=n_estimators,
            n_jobs=-1,
            force_row_wise=True,
            reg_lambda=reg_lambda,
            random_state=random_state,
        ),
    )

    # Standardise the regression target too; TransformedTargetRegressor
    # inverse-transforms predictions back to the original scale.
    estimator = TransformedTargetRegressor(model, transformer=StandardScaler())
    # Seeding here (when random_state is given) makes the CV folds identical
    # across repeated runs of the same analysis.
    cv = KFold(shuffle=True, random_state=random_state)
    encoder = StateEncoder(estimator=estimator, cv=cv, verbose=True)
    return EncodeRunner(
        loader=loader,
        preprocessor=preprocessor,
        encoder=encoder,
    )

In [5]:
# Keep only neurons from the CIT and SAL groups, then list the distinct
# sessions belonging to the HAMILTON experiment.
neuron_types = load_derived_generic("neuron_types.csv").query(
    "group in ('CIT', 'SAL')"
)
sessions = (
    neuron_types.query("experiment_name == 'HAMILTON'")["session_name"]
    .unique()
    .tolist()
)

# Loaders are bound to the first session; the run_multiple_* calls in the
# cells below are given the full `sessions` list.
session = sessions[0]
loader_shock_only = FSDecodeDataLoader(t_stop=600, session_name=session)
# NOTE(review): the "post_included" results are saved as "fs_slow - ..." files,
# yet this uses the fast loader class (the same one used for the "fs_fast"
# cell). Confirm FSFastDecodeDataLoader is intended here rather than
# FSDecodeDataLoader.
loader_post_included = FSFastDecodeDataLoader(t_stop=1200, session_name=session)

In [6]:
runner_shock_only = make_slow_ts_runner(loader_shock_only)

# Output directory for encoding results; parents=True so this also works when
# intermediate derived-data directories do not exist yet.
dd = Config.derived_data_dir / "encoding"
dd.mkdir(parents=True, exist_ok=True)
out_prefix = "fs_slow - shock_only"

# These runs are slow. Write each result as soon as it is computed so that a
# failure in a later run does not discard the earlier ones.
pop = runner_shock_only.run_multiple_pop(sessions=sessions)
pop.to_csv(dd / f"{out_prefix} - pop.csv", index=False)

pop_shuffle = runner_shock_only.run_multiple_pop(sessions=sessions, shuffle=True)
pop_shuffle.to_csv(dd / f"{out_prefix} - pop shuffle.csv", index=False)

state = runner_shock_only.run_multiple_state(sessions=sessions)
state.to_csv(dd / f"{out_prefix} - state.csv", index=False)

state_shuffle = runner_shock_only.run_multiple_state(sessions=sessions, shuffle=True)
state_shuffle.to_csv(dd / f"{out_prefix} - state shuffle.csv", index=False)

limit = runner_shock_only.run_multiple_limit(sessions=sessions, min_features=1, max_features=15)
limit.to_csv(dd / f"{out_prefix} - limit.csv", index=False)

dropout = runner_shock_only.run_multiple_dropout(sessions=sessions)
dropout.to_csv(dd / f"{out_prefix} - dropout.csv", index=False)
    

100%|██████████| 17/17 [00:03<00:00,  5.37it/s]
100%|██████████| 19/19 [00:03<00:00,  6.10it/s]
100%|██████████| 20/20 [00:03<00:00,  5.97it/s]
100%|██████████| 10/10 [00:01<00:00,  6.39it/s]
100%|██████████| 84/84 [00:22<00:00,  3.70it/s]
100%|██████████| 56/56 [00:12<00:00,  4.60it/s]
100%|██████████| 61/61 [00:13<00:00,  4.50it/s]
100%|██████████| 55/55 [00:11<00:00,  4.84it/s]
100%|██████████| 52/52 [00:09<00:00,  5.54it/s]
100%|██████████| 17/17 [00:02<00:00,  7.62it/s]
100%|██████████| 19/19 [00:02<00:00,  7.66it/s]
100%|██████████| 20/20 [00:02<00:00,  7.31it/s]
100%|██████████| 10/10 [00:01<00:00,  7.88it/s]
100%|██████████| 84/84 [00:16<00:00,  4.97it/s]
100%|██████████| 56/56 [00:09<00:00,  5.61it/s]
100%|██████████| 61/61 [00:11<00:00,  5.49it/s]
100%|██████████| 55/55 [00:09<00:00,  5.57it/s]
100%|██████████| 52/52 [00:08<00:00,  5.92it/s]
100%|██████████| 17/17 [00:00<00:00, 17.94it/s]
100%|██████████| 19/19 [00:01<00:00, 17.66it/s]
100%|██████████| 20/20 [00:01<00:00, 17.

In [8]:
runner_post_included = make_slow_ts_runner(loader_post_included)

# Output directory for encoding results; parents=True so this also works when
# intermediate derived-data directories do not exist yet.
dd = Config.derived_data_dir / "encoding"
dd.mkdir(parents=True, exist_ok=True)
out_prefix = "fs_slow - post_included"

# These runs are slow. Write each result as soon as it is computed so that a
# failure in a later run does not discard the earlier ones.
pop = runner_post_included.run_multiple_pop(sessions=sessions)
pop.to_csv(dd / f"{out_prefix} - pop.csv", index=False)

pop_shuffle = runner_post_included.run_multiple_pop(sessions=sessions, shuffle=True)
pop_shuffle.to_csv(dd / f"{out_prefix} - pop shuffle.csv", index=False)

state = runner_post_included.run_multiple_state(sessions=sessions)
state.to_csv(dd / f"{out_prefix} - state.csv", index=False)

state_shuffle = runner_post_included.run_multiple_state(sessions=sessions, shuffle=True)
state_shuffle.to_csv(dd / f"{out_prefix} - state shuffle.csv", index=False)

limit = runner_post_included.run_multiple_limit(sessions=sessions, min_features=1, max_features=15)
limit.to_csv(dd / f"{out_prefix} - limit.csv", index=False)

dropout = runner_post_included.run_multiple_dropout(sessions=sessions)
dropout.to_csv(dd / f"{out_prefix} - dropout.csv", index=False)

100%|██████████| 17/17 [00:03<00:00,  5.44it/s]
100%|██████████| 19/19 [00:02<00:00,  6.79it/s]
100%|██████████| 20/20 [00:04<00:00,  4.84it/s]
100%|██████████| 10/10 [00:00<00:00, 12.90it/s]
100%|██████████| 84/84 [00:58<00:00,  1.44it/s]
100%|██████████| 56/56 [00:42<00:00,  1.33it/s]
100%|██████████| 61/61 [00:39<00:00,  1.56it/s]
100%|██████████| 55/55 [00:48<00:00,  1.14it/s]
100%|██████████| 52/52 [00:19<00:00,  2.60it/s]
100%|██████████| 17/17 [00:02<00:00,  5.69it/s]
100%|██████████| 19/19 [00:02<00:00,  6.89it/s]
100%|██████████| 20/20 [00:04<00:00,  4.81it/s]
100%|██████████| 10/10 [00:00<00:00, 12.97it/s]
100%|██████████| 84/84 [00:59<00:00,  1.41it/s]
100%|██████████| 56/56 [00:42<00:00,  1.31it/s]
100%|██████████| 61/61 [00:39<00:00,  1.55it/s]
100%|██████████| 55/55 [00:48<00:00,  1.13it/s]
100%|██████████| 52/52 [00:20<00:00,  2.55it/s]
100%|██████████| 17/17 [00:01<00:00, 12.36it/s]
100%|██████████| 19/19 [00:01<00:00, 15.08it/s]
100%|██████████| 20/20 [00:01<00:00, 11.

In [9]:
loader = FSFastDecodeDataLoader(session_name=session)
preprocessor = DecodePreprocessor(thresh_empty=2)

# Preprocessed data for the first session, kept for interactive inspection.
spikes, states = loader()
spikes, states = preprocessor(spikes, states)

# Reuse the shared factory rather than a copy-pasted pipeline so the fast and
# slow analyses cannot silently drift apart in configuration. The factory is
# loader-agnostic despite its name. It builds its own DecodePreprocessor with
# the same thresh_empty=2.
runner = make_slow_ts_runner(loader)

In [10]:
pop = runner.run_multiple_pop(sessions=sessions)
pop_shuffle = runner.run_multiple_pop(sessions=sessions, shuffle=True)
state = runner.run_multiple_state(sessions=sessions)
state_shuffle = runner.run_multiple_state(sessions=sessions, shuffle=True)
limit = runner.run_multiple_limit(sessions=sessions, min_features=1, max_features=15)
dropout = runner.run_multiple_dropout(sessions=sessions)

100%|██████████| 17/17 [00:03<00:00,  5.53it/s]
100%|██████████| 19/19 [00:02<00:00,  6.68it/s]
100%|██████████| 20/20 [00:04<00:00,  4.87it/s]
100%|██████████| 10/10 [00:00<00:00, 12.61it/s]
100%|██████████| 84/84 [00:58<00:00,  1.45it/s]
100%|██████████| 56/56 [00:42<00:00,  1.31it/s]
100%|██████████| 61/61 [00:38<00:00,  1.57it/s]
100%|██████████| 55/55 [00:48<00:00,  1.14it/s]
100%|██████████| 52/52 [00:20<00:00,  2.56it/s]
100%|██████████| 17/17 [00:03<00:00,  5.56it/s]
100%|██████████| 19/19 [00:02<00:00,  6.71it/s]
100%|██████████| 20/20 [00:04<00:00,  4.81it/s]
100%|██████████| 10/10 [00:00<00:00, 12.63it/s]
100%|██████████| 84/84 [00:58<00:00,  1.42it/s]
100%|██████████| 56/56 [00:43<00:00,  1.30it/s]
100%|██████████| 61/61 [00:39<00:00,  1.55it/s]
100%|██████████| 55/55 [00:48<00:00,  1.13it/s]
100%|██████████| 52/52 [00:20<00:00,  2.50it/s]
100%|██████████| 17/17 [00:01<00:00, 11.30it/s]
100%|██████████| 19/19 [00:01<00:00, 14.14it/s]
100%|██████████| 20/20 [00:01<00:00, 10.

In [11]:
# Persist the fast-timescale results. The shuffle baselines computed in the
# previous cell are saved alongside the real results, using the same
# "<name> shuffle.csv" naming as the fs_slow outputs.
dd = Config.derived_data_dir / "encoding"
# parents=True so this also works when intermediate directories are missing.
dd.mkdir(parents=True, exist_ok=True)
pop.to_csv(dd / "fs_fast - pop.csv", index=False)
pop_shuffle.to_csv(dd / "fs_fast - pop shuffle.csv", index=False)
state.to_csv(dd / "fs_fast - state.csv", index=False)
state_shuffle.to_csv(dd / "fs_fast - state shuffle.csv", index=False)
limit.to_csv(dd / "fs_fast - limit.csv", index=False)
dropout.to_csv(dd / "fs_fast - dropout.csv", index=False)
    