In [None]:
from sklearn.model_selection import StratifiedKFold
from lightgbm import LGBMClassifier
import numpy as np
import pandas as pd
from scipy.stats import wilcoxon
from drn_interactions.io import load_derived_generic
from drn_interactions.config import Config, ExperimentInfo
from drn_interactions.decoding.loaders import  FSFastDecodeDataLoaderTwoWindows
from drn_interactions.decoding.preprocessors import DecodePreprocessor
from drn_interactions.decoding.runners import DecodeRunner
from drn_interactions.decoding.decoders import Decoder
from drn_interactions.decoding.shuffle import shuffle_X

In [None]:
def make_fast_ts_runner(session, window_1, window_2, bin_width=0.05):
    """Assemble a ``DecodeRunner`` for two-window decoding of one session.

    The runner is composed of a ``FSFastDecodeDataLoaderTwoWindows`` loader,
    a ``DecodePreprocessor`` and a ``Decoder`` wrapping a LightGBM classifier
    evaluated with 5-fold stratified cross-validation (no fold shuffling).

    Parameters
    ----------
    session
        Forwarded to the loader as ``session_name``.
    window_1, window_2
        Forwarded unchanged to the loader. Presumably (start, stop) time
        windows relative to an event -- TODO confirm against the loader.
    bin_width
        Forwarded unchanged to the loader (default ``0.05``).

    Returns
    -------
    DecodeRunner
        Runner configured with ``nboot=1``.
    """
    return DecodeRunner(
        loader=FSFastDecodeDataLoaderTwoWindows(
            session_name=session,
            bin_width=bin_width,
            window_1=window_1,
            window_2=window_2,
        ),
        preprocessor=DecodePreprocessor(thresh_empty=1),
        decoder=Decoder(
            estimator=LGBMClassifier(n_estimators=40, max_depth=4, num_leaves=30),
            cv=StratifiedKFold(n_splits=5, shuffle=False),
            # shuffle_X is handed to the decoder as its shuffler; how it is
            # applied is defined by Decoder, not here.
            shuffler=shuffle_X,
        ),
        nboot=1,
    )

In [None]:
# Output directory for the decoding results written by a later cell.
dd = Config.derived_data_dir / "decoding"
# NOTE(review): `neuron_types` is not referenced by any other cell visible
# here -- confirm whether it is still needed. `.dropna()` presumably drops
# rows with missing values (assumes a DataFrame is returned).
neuron_types = load_derived_generic("neuron_types.csv").dropna()
sessions = ExperimentInfo.foot_shock_sessions_10min
# The two windows handed to the loader as window_1 / window_2.
# Presumably seconds relative to the foot shock -- TODO confirm.
window_pre = (-0.3, -0.1)
window_post = (0.05, 0.2)

# The runner is constructed with a single session (sessions[1]); the
# run_multiple* calls below are then given the full `sessions` list.
# bin_width=0.1 overrides the factory default of 0.05.
runner = make_fast_ts_runner(
    session=sessions[1],
    window_1=window_pre,
    window_2=window_post, 
    bin_width=0.1
)

# Three result objects: population, per-unit and per-unit-shuffled (going by
# the names; the exact contents are defined by DecodeRunner.run_multiple).
pop, unit, unit_shuff = runner.run_multiple(sessions)
# n_min / n_max are passed straight through; presumably bounds on the number
# of units included -- TODO confirm against DecodeRunner.
limit = runner.run_multiple_limit(sessions, n_min=1, n_max=30)
dropout = runner.run_multiple_dropout(sessions,)

In [None]:
# Persist each decoding result as a CSV (without the index) in the decoding
# output directory. Pairs are listed in the order they are written.
csv_targets = (
    ("fs_fast - unit.csv", unit),
    ("fs_fast - pop.csv", pop),
    ("fs_fast - limit.csv", limit),
    ("fs_fast - dropout.csv", dropout),
)
for csv_name, frame in csv_targets:
    frame.to_csv(dd / csv_name, index=False)

In [None]:
# Pair shuffled and real per-unit scores. No `on=` is given, so pandas joins
# on whatever columns the two renamed frames still have in common.
paired_scores = pd.merge(
    unit_shuff.rename(columns={"F1 Score": "shuff_score"}),
    unit.rename(columns={"F1 Score": "real_score"}),
)

# Paired Wilcoxon signed-rank test: shuffled vs. real F1 scores.
W, p = wilcoxon(paired_scores["shuff_score"], paired_scores["real_score"])
print(W, p)

# Medians of the unpaired score columns, shuffled first and then real.
print(unit_shuff["F1 Score"].median())
print(unit["F1 Score"].median())