In [1]:
from copy import deepcopy
from decimal import Decimal
import os
from pathlib import Path

from astro.config import Config
from astro.constants import SESSIONS
from astro.load import Loader
from astro.preprocess import Preprocessor, GroupedEventPreprocessor
from astro.responders.rotated import RespondersConfig, RespondersSaver, run_responders
from astro.transforms import GroupSplitter

from trace_minder.align import GroupedAligner
from trace_minder.preprocess import TracePreprocessor
from trace_minder.trace_aggregation import PrePostAggregator

In [2]:
# Resolve the project root as two directories above the kernel's working
# directory (assumes the notebook is launched from its own folder, two levels
# below the root) and read the path configuration from the root's `.env` file.
notebook_path = Path.cwd()
root_path = notebook_path.parent.parent
env_path = root_path / ".env"

PATHS = Config.from_env(env_path)

In [5]:
# Analysis hyperparameters. The run loop writes every key/value pair of this dict
# into the output filename suffix, so settings used below are derived from the
# same constants to keep the filenames truthful.
N_FIRST_EVENTS = 5  # number of leading events kept per group (mouse)
HYPERPARAMS = dict(
    FREQ=0.1,
    TPRE=5,
    TPOST=5,
    EVENTS=f"first-{N_FIRST_EVENTS}",  # previously a separate literal "first-5"
    N_BOOT=100,
)

# Loader without a session; in this cell it only supplies the mouse and neuron
# tables for the group splitter (the run loop creates per-session loaders).
loader = Loader(data_dir=PATHS.data_dir)

# Trace preprocessing applied at load time: standardize=True, no median filter,
# resample_frequency=HYPERPARAMS["FREQ"], drop_na=True.
trace_preprocessor = TracePreprocessor(
    standardize=True,
    medfilt_kernel_size=None,
    resample_frequency=HYPERPARAMS["FREQ"],
    drop_na=True,
)
# Event preprocessing grouped by mouse_name, ordered by start_time; presumably
# keeps only the first N_FIRST_EVENTS events per mouse — confirm in astro.
grouped_event_preprocessor = GroupedEventPreprocessor(
    df_events_group_col="mouse_name",
    df_events_event_time_col="start_time",
    first_x_events=N_FIRST_EVENTS,
)
loader_preprocessor = Preprocessor(
    trace_preprocessor=trace_preprocessor,
    grouped_event_preprocessor=grouped_event_preprocessor,
)

# Group splitter built from the mouse and neuron metadata tables; the
# "VEH-VEH" group is passed as excluded. Its neurons_by_mouse() mapping is
# used by the aligner. Keyword order keeps load_mice() before load_neurons().
group_splitter = GroupSplitter(
    # mouse table and its columns
    df_mice=loader.load_mice(),
    df_mice_mouse_col="mouse_name",
    df_mice_group_col="group",
    # neuron table and its columns
    df_neurons=loader.load_neurons(),
    df_neurons_mouse_col="mouse_name",
    df_neurons_neuron_col="cell_id",
    # time column of the trace frames
    df_traces_time_col="time",
    excluded_groups=["VEH-VEH"],
)


# Number of decimals in the resampling step HYPERPARAMS["FREQ"]
# (0.1 -> 1, 0.05 -> 2, 0.25 -> 2, 1 or 10 -> 0). The previous rule
# `1 if FREQ < 1 else 0` under-counted for steps such as 0.05, 0.25 or 1.5.
# str() gives the shortest exact repr of the float; normalize() strips trailing
# zeros so e.g. 1.0 -> 0 decimals and 10 -> exponent +1 -> clamped to 0.
# NOTE(review): assumes GroupedAligner rounds aligned times to this many
# decimals and FREQ is the sample spacing — confirm against trace_minder.
round_precision = max(
    0, -(Decimal(str(HYPERPARAMS["FREQ"])).normalize().as_tuple().exponent)
)

# Align traces around each event's start_time, from t_before=TPRE to
# t_after=TPOST, using the per-mouse neuron mapping from the group splitter.
aligner = GroupedAligner(
    t_before=HYPERPARAMS["TPRE"],
    t_after=HYPERPARAMS["TPOST"],
    df_wide_group_mapper=group_splitter.neurons_by_mouse(),
    df_events_event_time_col="start_time",
    df_events_group_col="mouse_name",
    round_precision=round_precision,
)

# Preprocessor for averaged traces, constructed with TracePreprocessor defaults.
average_trace_preprocessor = TracePreprocessor()

# Pre/post aggregation; no event-index column is supplied.
aggregator = PrePostAggregator(event_idx_col=None)

# Saver rooted at <derived_data_dir>/rotated_responders.
saver = RespondersSaver(
    root_data_dir=PATHS.derived_data_dir / "rotated_responders"
)

# Pipeline components shared by every session run, with n_boot from HYPERPARAMS.
RUN_CONFIG = RespondersConfig(
    n_boot=HYPERPARAMS["N_BOOT"],
    loader_preprocessor=loader_preprocessor,
    aligner=aligner,
    aggregator=aggregator,
    average_trace_preprocessor=average_trace_preprocessor,
)

In [4]:
BLOCK = "CS"

# The filename suffix depends only on HYPERPARAMS, so build it once, outside the
# loop, e.g. "__FREQ-0.1__TPRE-5__TPOST-5__EVENTS-first-5__N_BOOT-100".
fn_suffix = "".join(f"__{key}-{value}" for key, value in HYPERPARAMS.items())

# NOTE(review): SESSIONS[1:] skips the first entry of SESSIONS — document why.
for session in SESSIONS[1:]:
    name = f"{session} - {BLOCK}"
    print(name)

    # Deep-copy the shared saver so set_fn_suffix() configures a per-run copy
    # instead of mutating `saver` itself.
    run_saver = deepcopy(saver)
    run_saver.set_fn_suffix(fn_suffix)

    # Loader restricted to this session and block, with the shared load-time
    # preprocessing from RUN_CONFIG.
    loader = Loader(
        data_dir=PATHS.data_dir,
        session_name=session,
        block_group=BLOCK,
        preprocessor=RUN_CONFIG.loader_preprocessor,
    )

    run_responders(
        name=name,
        loader=loader,
        responders_config=RUN_CONFIG,
        saver=run_saver,
    )

ret - CS


  if _pandas_api.is_sparse(col):


ext - CS
diff-ret - CS


  if _pandas_api.is_sparse(col):


late-ret - CS


  if _pandas_api.is_sparse(col):
  if _pandas_api.is_sparse(col):


renewal - CS


  if _pandas_api.is_sparse(col):
