In [1]:
from sklearn.model_selection import KFold
from lightgbm import LGBMRegressor
from sklearn.compose import make_column_transformer, make_column_selector, TransformedTargetRegressor
from sklearn.preprocessing import OneHotEncoder
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import KFold


from drn_interactions.config import ExperimentInfo

from drn_interactions.decoding.loaders import FSDecodeDataLoader, FSFastDecodeDataLoader
from drn_interactions.decoding.preprocessors import DecodePreprocessor
from drn_interactions.decoding.encoders import StateEncoder
from drn_interactions.decoding.runners import EncodeRunner
from drn_interactions.decoding.shuffle import shuffle_X

  return warn(


In [2]:
# Build the encoding runner around FSDecodeDataLoader (its results are later
# saved under the "fs_slow" file prefix).

# Fixed seed: KFold(shuffle=True) without a random_state draws different folds
# on every run, so results would not survive Restart & Run All unchanged.
RANDOM_STATE = 42

sessions = ExperimentInfo.foot_shock_sessions_10min
session = sessions[0]
loader = FSDecodeDataLoader(session_name=session)
preprocessor = DecodePreprocessor(
    thresh_empty=2,
)
# Load and preprocess the first session once. These arrays are not referenced
# again in this cell; the runner below receives the loader and preprocessor
# objects themselves.
spikes, states = loader()
spikes, states = preprocessor(spikes, states)

# Standardise every non-object column and one-hot encode object columns.
ct = make_column_transformer(
    (
        StandardScaler(),
        make_column_selector(dtype_exclude=object)
    ),
    (
        OneHotEncoder(drop="if_binary"),
        make_column_selector(dtype_include=object)
    ),
)

estimator = make_pipeline(ct,
    LGBMRegressor(
        n_estimators=20,
        n_jobs=-1,
        force_row_wise=True,
        reg_lambda=0.8,
        )
)

# The regression target is standardised before fitting and mapped back on predict.
estimator = TransformedTargetRegressor(estimator, transformer=StandardScaler())
cv = KFold(shuffle=True, random_state=RANDOM_STATE)
encoder = StateEncoder(estimator=estimator, cv=cv, verbose=True)
runner = EncodeRunner(
    loader=loader,
    preprocessor=preprocessor,
    encoder=encoder,
)

In [3]:
# Execute the four runner analyses over every session, in their original order.
# Each call shares the same `sessions` argument; only the feature-limit run
# takes extra bounds.
shared_args = dict(sessions=sessions)

pop = runner.run_multiple_pop(**shared_args)
state = runner.run_multiple_state(**shared_args)
limit = runner.run_multiple_limit(
    min_features=1,
    max_features=30,
    **shared_args,
)
dropout = runner.run_multiple_dropout(**shared_args)

100%|██████████| 16/16 [00:02<00:00,  6.28it/s]
100%|██████████| 15/15 [00:02<00:00,  6.24it/s]
100%|██████████| 52/52 [00:12<00:00,  4.30it/s]
100%|██████████| 33/33 [00:06<00:00,  4.97it/s]
100%|██████████| 37/37 [00:07<00:00,  4.93it/s]
100%|██████████| 51/51 [00:12<00:00,  4.23it/s]
100%|██████████| 24/24 [00:04<00:00,  5.20it/s]
100%|██████████| 31/31 [00:06<00:00,  4.94it/s]
100%|██████████| 16/16 [00:01<00:00, 12.84it/s]
100%|██████████| 15/15 [00:01<00:00, 13.04it/s]
100%|██████████| 52/52 [00:04<00:00, 12.87it/s]
100%|██████████| 33/33 [00:02<00:00, 12.41it/s]
100%|██████████| 37/37 [00:02<00:00, 12.73it/s]
100%|██████████| 51/51 [00:03<00:00, 12.89it/s]
100%|██████████| 24/24 [00:01<00:00, 13.00it/s]
100%|██████████| 31/31 [00:02<00:00, 13.05it/s]
100%|██████████| 16/16 [00:32<00:00,  2.04s/it]
100%|██████████| 15/15 [00:30<00:00,  2.04s/it]
100%|██████████| 52/52 [04:34<00:00,  5.28s/it]
100%|██████████| 33/33 [02:40<00:00,  4.85s/it]
100%|██████████| 37/37 [03:05<00:00,  5.

In [4]:

from drn_interactions.config import Config


# Persist each result table as CSV inside the derived-data "encoding" folder.
dd = Config.derived_data_dir / "encoding"

# Output filename -> table to write (insertion order matches the write order).
slow_outputs = {
    "fs_slow - pop.csv": pop,
    "fs_slow - state.csv": state,
    "fs_slow - limit.csv": limit,
    "fs_slow - dropout.csv": dropout,
}
for filename, table in slow_outputs.items():
    table.to_csv(dd / filename, index=False)
    

In [5]:
# Build the encoding runner around FSFastDecodeDataLoader (its results are
# meant to be saved under the "fs_fast" file prefix).

# Fixed seed: KFold(shuffle=True) without a random_state draws different folds
# on every run, so results would not survive Restart & Run All unchanged.
RANDOM_STATE = 42

sessions = ExperimentInfo.foot_shock_sessions_10min
session = sessions[0]
loader = FSFastDecodeDataLoader(session_name=session)
preprocessor = DecodePreprocessor(
    thresh_empty=2,
)
# Load and preprocess the first session once. These arrays are not referenced
# again in this cell; the runner below receives the loader and preprocessor
# objects themselves.
spikes, states = loader()
spikes, states = preprocessor(spikes, states)

# Standardise every non-object column and one-hot encode object columns.
ct = make_column_transformer(
    (
        StandardScaler(),
        make_column_selector(dtype_exclude=object)
    ),
    (
        OneHotEncoder(drop="if_binary"),
        make_column_selector(dtype_include=object)
    ),
)

estimator = make_pipeline(ct,
    LGBMRegressor(
        n_estimators=20,
        n_jobs=-1,
        force_row_wise=True,
        reg_lambda=0.8,
        )
)

# The regression target is standardised before fitting and mapped back on predict.
estimator = TransformedTargetRegressor(estimator, transformer=StandardScaler())
cv = KFold(shuffle=True, random_state=RANDOM_STATE)
encoder = StateEncoder(estimator=estimator, cv=cv, verbose=True)
runner = EncodeRunner(
    loader=loader,
    preprocessor=preprocessor,
    encoder=encoder,
)

In [6]:
# Re-run the four analyses with the current `runner` (built on the fast loader
# in the previous cell) before saving. Without this step `pop`, `state`,
# `limit` and `dropout` still hold the results of the earlier slow run, and
# those would be written under the "fs_fast" filenames.
pop = runner.run_multiple_pop(sessions=sessions)
state = runner.run_multiple_state(sessions=sessions)
limit = runner.run_multiple_limit(sessions=sessions, min_features=1, max_features=30)
dropout = runner.run_multiple_dropout(sessions=sessions)

# Persist each result table as CSV inside the derived-data "encoding" folder.
dd = Config.derived_data_dir / "encoding"

pop.to_csv(dd / "fs_fast - pop.csv", index=False)
state.to_csv(dd / "fs_fast - state.csv", index=False)
limit.to_csv(dd / "fs_fast - limit.csv", index=False)
dropout.to_csv(dd / "fs_fast - dropout.csv", index=False)
    