In [5]:
from sklearn.model_selection import KFold
from lightgbm import LGBMClassifier

from ssri_interactions.io import load_derived_generic
from ssri_interactions.config import ExperimentInfo, Config
from ssri_interactions.decoding.loaders import StateDecodeDataLoader
from ssri_interactions.decoding.preprocessors import StateDecodePreprocessor
from ssri_interactions.decoding.runners import DecodeRunner
from ssri_interactions.decoding.decoders import Decoder
from ssri_interactions.decoding.shuffle import shuffle_X


%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
# load_derived_generic("lfp_states.csv").query("quality == 'good'")

In [14]:
# Decoding configuration. Fixed seeds make the CV splits and LightGBM fits
# reproducible under Restart & Run All.
SEED = 42
STATES_FILE = "lfp_states.csv"
T_STOP = 1800  # passed to the loader as t_stop; presumably seconds — TODO confirm units
N_BOOT = 75  # passed to DecodeRunner as nboot

# Sessions whose rows in the states table are marked quality == 'good'.
sessions = (
    load_derived_generic(STATES_FILE)
    .query("quality == 'good'")
    .session_name.unique()
    .tolist()
)
assert sessions, f"no sessions with quality == 'good' in {STATES_FILE}"

loader = StateDecodeDataLoader(
    session_name=sessions[0],
    block="pre",
    t_stop=T_STOP,
    states_path=Config.derived_data_dir / STATES_FILE,
)
preprocessor = StateDecodePreprocessor(
    thresh_empty=2,
)

# Smoke test: load and preprocess the first session once, so a broken loader or
# preprocessor fails fast before the long decoding runs. These results are not
# used further in this notebook.
spikes, states = loader()
spikes, states = preprocessor(spikes, states)

estimator = LGBMClassifier(
    n_estimators=50, max_depth=8, num_leaves=30, random_state=SEED
)
cv = KFold(n_splits=5, shuffle=True, random_state=SEED)
# NOTE(review): shuffle_X's randomness source is not visible here; if it draws
# from numpy's global RNG, the shuffled baseline is still unseeded — TODO confirm.
decoder = Decoder(estimator=estimator, cv=cv, shuffler=shuffle_X)
runner = DecodeRunner(
    loader=loader,
    preprocessor=preprocessor,
    decoder=decoder,
    nboot=N_BOOT,
)

In [15]:
# Run the three decoding analyses over every selected session. Judging by the
# captured progress bars, this cell takes hours; run it only when needed.
LIMIT_N_MIN, LIMIT_N_MAX = 1, 31
DROPOUT_NEURON_TYPES = ("SR", "SIR", "FF")

# Result is unpacked into a population-level and a unit-level table.
pop, unit = runner.run_multiple(sessions)

limit = runner.run_multiple_limit(
    sessions=sessions,
    n_min=LIMIT_N_MIN,
    n_max=LIMIT_N_MAX,
)

dropout = runner.run_multiple_dropout(
    sessions=sessions,
    neuron_types=DROPOUT_NEURON_TYPES,
)

100%|██████████| 75/75 [07:50<00:00,  6.28s/it]
100%|██████████| 75/75 [07:45<00:00,  6.21s/it]
100%|██████████| 75/75 [12:30<00:00, 10.01s/it]  
100%|██████████| 75/75 [23:59<00:00, 19.19s/it]   
100%|██████████| 75/75 [06:39<00:00,  5.33s/it]
100%|██████████| 75/75 [41:12<00:00, 32.97s/it] 
100%|██████████| 75/75 [06:39<00:00,  5.33s/it]
100%|██████████| 75/75 [06:30<00:00,  5.20s/it]
100%|██████████| 75/75 [06:07<00:00,  4.90s/it]
100%|██████████| 75/75 [05:45<00:00,  4.60s/it]
100%|██████████| 75/75 [06:25<00:00,  5.14s/it]
100%|██████████| 75/75 [08:10<00:00,  6.54s/it]
100%|██████████| 75/75 [09:31<00:00,  7.61s/it]


In [16]:
# Persist each decoding result table as CSV under <derived_data_dir>/decoding.
dd = Config.derived_data_dir / "decoding"
# parents=True so the cell does not fail when the parent directory is missing.
dd.mkdir(parents=True, exist_ok=True)

# File names keep the original "brain state - <label>.csv" pattern.
decoding_results = {"pop": pop, "unit": unit, "limit": limit, "dropout": dropout}
for label, frame in decoding_results.items():
    frame.to_csv(dd / f"brain state - {label}.csv", index=False)