In [None]:
from drn_interactions.io import load_spikes, load_events, load_derived_generic
from drn_interactions.decoding.offset_decoder import OffsetDecoder
from drn_interactions.transforms.nbox_transforms import align_to_data_by

from spiketimes.df.surrogates import shuffled_isi_spiketrains_by
from scipy.stats import zscore
from binit.bin import which_bin
import numpy as np
import pandas as pd
from tqdm import tqdm

from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.pipeline import make_pipeline
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.linear_model import LogisticRegressionCV
from sklearn.decomposition import NMF, PCA
from sklearn.model_selection import cross_val_score, KFold, StratifiedKFold
from sklearn.feature_extraction.text import TfidfTransformer
from sklearn.base import clone
from sklearn.metrics.pairwise import pairwise_distances



In [None]:
def get_aligned(
    df_spikes,
    df_events,
    time_before_event=0.5,
    time_after_event=1.5,
    bin_width=0.1,
):
    """Build an (event, bin) x neuron matrix of peri-event spike counts.

    Spikes are aligned to events per session with ``align_to_data_by``,
    the aligned times are assigned to bins spanning the peri-event window,
    and the spikes of each neuron are counted per event and bin.

    Parameters
    ----------
    df_spikes : pandas.DataFrame
        Spike table with ``neuron_id``, ``session_name`` and ``spiketimes``
        columns.
    df_events : pandas.DataFrame
        Event table with ``session_name`` and ``event_s`` columns.
    time_before_event : float, default 0.5
        Window length before each event; also the (negated) left edge of
        the first bin.
    time_after_event : float, default 1.5
        Window length after each event; bin edges stop before this value.
    bin_width : float, default 0.1
        Step between consecutive bin edges.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``(event, bin)`` with one column per ``neuron_id``.
        Neuron/event/bin combinations with no spikes are filled with 0.
    """
    df = align_to_data_by(
        df_spikes,
        df_events,
        time_before_event=time_before_event,
        time_after_event=time_after_event,
        df_data_cell_col="neuron_id",
        df_data_group_col="session_name",
        df_events_group_colname="session_name",
        df_events_timestamp_col="event_s",
        df_data_time_col="spiketimes",
    )
    # Derive the bin edges from the same window that was used for alignment
    # so the two can never get out of sync.
    bins = np.arange(-time_before_event, time_after_event, bin_width)
    # Rounding removes float noise from the arange-generated edges so that
    # equal bins compare equal when grouping.
    # NOTE(review): two decimals assumes bin edges are multiples of 0.01.
    df["bin"] = np.round(which_bin(df["aligned"].values, bins), 2)
    counts = (
        df.groupby(["neuron_id", "event", "bin"])
        .size()
        .to_frame("counts")
        .reset_index()
    )
    wide = counts.pivot(index=["event", "bin"], columns="neuron_id", values="counts")
    return wide.fillna(0)

def offset_decode(df_spikes, df_events, sessions, estimator, cv, scoring="f1_macro"):
    """Run the offset decoder per session on real and ISI-shuffled spikes.

    For every session the session's spike trains are ISI-shuffled, both the
    real and the shuffled spikes are converted to peri-event count matrices
    with ``get_aligned``, and the pair is handed to
    ``OffsetDecoder.fit_models``.

    Parameters
    ----------
    df_spikes : pandas.DataFrame
        Spike table with ``neuron_id``, ``session_name`` and ``spiketimes``
        columns.
    df_events : pandas.DataFrame
        Event table passed unchanged to ``get_aligned``.
    sessions : iterable
        Session names to decode; each is matched against
        ``df_spikes["session_name"]``.
    estimator : sklearn estimator
        Cloned for each session so no fitted state is shared.
    cv : object
        Cross-validation argument forwarded to ``OffsetDecoder``.
    scoring : str, default "f1_macro"
        Scoring argument forwarded to ``OffsetDecoder``.

    Returns
    -------
    list
        One result per session, as returned by ``fit_models`` with a
        ``session_name`` column added.
    """
    neurons_sub = df_spikes[["neuron_id", "session_name"]].drop_duplicates()
    out = []
    for session in tqdm(sessions):
        # Boolean masks instead of a string-built query: works for names
        # containing quotes and for non-string session identifiers.
        session_spikes = df_spikes[df_spikes["session_name"] == session]
        session_neurons = neurons_sub[neurons_sub["session_name"] == session]
        # Shuffle only this session's spike trains; shuffling the whole
        # dataset on every iteration repeats work that is then discarded.
        # The merge re-attaches session_name, as in the full-dataset version.
        shuffled = shuffled_isi_spiketrains_by(
            session_spikes, by_col="neuron_id"
        ).merge(session_neurons)
        spikes_true = get_aligned(session_spikes, df_events)
        spikes_fake = get_aligned(shuffled, df_events)
        decoder = OffsetDecoder(estimator=clone(estimator), cv=cv, scoring=scoring)
        res = decoder.fit_models([(spikes_true, spikes_fake)])
        out.append(res.assign(session_name=session))
    return out