In [1]:
from astro.load import Loader
from astro.preprocess import (
    Preprocessor,
    TracePreprocessor,
    GroupedEventPreprocessor,
)
from astro.transforms.groups import GroupSplitter
from astro.constants import SESSIONS, SESSION_MAPPER
from astro.decoding_alltime.preprocess import ATDecodePreprocessor, latency_mask_factory

from trace_minder.align import GroupedAligner

from pathlib import Path

from dataclasses import dataclass, field
from copy import deepcopy
from typing import Optional, Dict, List, Tuple
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

In [2]:
from sklearn.preprocessing import (
    StandardScaler,
    RobustScaler,
    MinMaxScaler,
    PowerTransformer,
)
from sklearn.preprocessing import LabelEncoder
from sklearn.pipeline import Pipeline
from sklearn.model_selection import KFold, cross_val_score, TimeSeriesSplit

from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier


from sklearn.neighbors import KNeighborsClassifier
from sklearn.decomposition import PCA, KernelPCA, SparsePCA, TruncatedSVD, NMF
from sklearn.cross_decomposition import CCA, PLSCanonical, PLSRegression
from sklearn.manifold import TSNE, Isomap, LocallyLinearEmbedding

from sklearn.model_selection import BaseCrossValidator
from sklearn.base import BaseEstimator

In [3]:
DATA_DIR = Path("/Users/ruairiosullivan/Desktop/astro_data-02/dataset-02")
OUTPUT_DIR = Path("/Users/ruairiosullivan/Desktop/astro_data-02/Output-02")

In [4]:
loader = Loader(data_dir=DATA_DIR)

trace_preprocessor = TracePreprocessor(
    max_time=600,
    standardize=True,
    medfilt_kernel_size=None,
    resample_frequency=0.2,
)

event_preprocessor = GroupedEventPreprocessor(
    df_events_group_col="mouse_name",
    df_events_event_time_col="start_time",
    first_x_events=5,
)


loader_preprocessor = Preprocessor(
    trace_preprocessor=trace_preprocessor,
    grouped_event_preprocessor=event_preprocessor,
)

loader.set_preprocessor(loader_preprocessor)

In [5]:
group_info = GroupSplitter(
    df_mice=loader.load_mice(),
    df_neurons=loader.load_neurons(),
    df_traces_time_col="time",
    excluded_groups=["VEH-VEH"],
    df_neurons_mouse_col="mouse_name",
    df_mice_mouse_col="mouse_name",
    df_neurons_neuron_col="cell_id",
    df_mice_group_col="group",
)

In [6]:
def aligner_fac() -> GroupedAligner:
    aligner = GroupedAligner(
        t_before=30,
        t_after=30,
        df_wide_group_mapper=group_info.neurons_by_mouse(),
        df_events_event_time_col="start_time",
        df_events_group_col="mouse_name",
    )
    return aligner

In [7]:
def all_time_decode_pp_fac(
    window_1: Tuple[int, int], window_2: Tuple[int, int]
) -> ATDecodePreprocessor:
    min_window_checker = latency_mask_factory(t_min=window_1[0], t_max=window_1[1])
    max_window_checker = latency_mask_factory(t_min=window_2[0], t_max=window_2[1])

    preprocessor = ATDecodePreprocessor(
        aligner=aligner_fac(),
        latency_out_of_block=min_window_checker,
        latency_in_block=max_window_checker,
    )
    return preprocessor

In [8]:
SESSION_NAME = "ret"
GROUP = "CNO-VEH"


df_traces = loader.load_traces(session_name=SESSION_NAME)
df_events = loader.load_blockstarts(session_name=SESSION_NAME, block_group="CS")
trace_dict = group_info.traces_by_group(df_traces=df_traces)
df_traces_group = trace_dict[GROUP].copy()


preprocessor = all_time_decode_pp_fac(window_1=(-5, 0), window_2=(0, 5))
temporal_df, df, block_ts = preprocessor(
    df_traces=df_traces,
    block_starts=df_events,
)

In [9]:
from tsfresh import extract_features, extract_relevant_features, select_features
from tsfresh.utilities.dataframe_functions import impute
from tsfresh.utilities.dataframe_functions import roll_time_series

df["id"] = 1
df["time"] = temporal_df["time"]
df_roll = roll_time_series(
    df, column_id="id", column_sort="time", min_timeshift=10, max_timeshift=10
)

block_ts_aug = pd.concat([block_ts.to_frame("block"), temporal_df[["time"]]], axis=1)
block_ts_mapper = block_ts_aug.set_index("time").to_dict()
blocks_roll = df_roll["time"].map(block_ts_mapper["block"])

  df["id"] = 1
  df["time"] = temporal_df["time"]
Rolling: 100%|██████████| 20/20 [00:01<00:00, 14.96it/s]


In [11]:
from sklearn.base import BaseEstimator, TransformerMixin
from tsfresh import extract_features
from tsfresh.utilities.dataframe_functions import impute


class TSFreshFeatureExtractor(BaseEstimator, TransformerMixin):
    def __init__(self, column_id, column_sort, tsfresh_params=None):
        self.column_id = column_id
        self.column_sort = column_sort
        self.tsfresh_params = tsfresh_params
        self.extracted_features_ = None

    def fit(self, X, y=None):
        # Use tsfresh to extract features from X
        extracted_features = extract_features(
            X,
            column_id=self.column_id,
            column_sort=self.column_sort,
            # default_fc_parameters=self.tsfresh_params,
            # n_jobs=8
        )

        # Impute to handle any NaNs in the extracted features
        impute(extracted_features)

        # Store the extracted features for use in transform
        self.extracted_features_ = extracted_features.columns
        return self

    def transform(self, X):
        # Extract the same features from new data
        new_features = extract_features(
            X,
            column_id=self.column_id,
            column_sort=self.column_sort,
            default_fc_parameters=self.tsfresh_params,
        )

        # Impute to handle any NaNs
        impute(new_features)

        # Ensure that only the features extracted during fit are returned
        return new_features[self.extracted_features_]

In [12]:
from sklearn.model_selection import train_test_split

y = LabelEncoder().fit_transform(blocks_roll)

X_train, X_test, y_train, y_test = train_test_split(
    df_roll, y, test_size=0.2, random_state=42, shuffle=False
)

In [14]:
from tsfresh.transformers import RelevantFeatureAugmenter
from sklearn.pipeline import Pipeline
from sklearn.model_selection import TimeSeriesSplit
from sklearn.metrics import accuracy_score, roc_auc_score, f1_score

pipe = Pipeline(
    [
        (
            "augmenter",
            TSFreshFeatureExtractor(
                column_id="id",
                column_sort="time",
            ),
        ),
        ("scaler", StandardScaler()),
        ("classifier", RandomForestClassifier()),
    ]
)

# X_fill = pd.DataFrame(index=X_train.index)

pipe.fit(X_train[:500], y_train[:500])
score = pipe.score(X_test, y_test)

Feature Extraction: 100%|██████████| 20/20 [01:53<00:00,  5.66s/it]
 '1050__partial_autocorrelation__lag_6'
 '1050__partial_autocorrelation__lag_7' ...
 '881__agg_linear_trend__attr_"stderr"__chunk_len_50__f_agg_"mean"'
 '881__agg_linear_trend__attr_"stderr"__chunk_len_50__f_agg_"var"'
 '881__query_similarity_count__query_None__threshold_0.0'] did not have any finite values. Filling with zeros.
Feature Extraction:   0%|          | 0/20 [00:22<?, ?it/s]


KeyboardInterrupt: 

In [None]:
X_train

In [95]:
len(df_roll) // 5

2178

In [42]:
X_fill

Unnamed: 0,1050,1051,1053,1054,1055,1056,1057,1059,106,1060,...,874,875,876,877,878,879,880,881,id,time
0,0.116642,-0.613115,-0.246569,-0.478063,-0.644784,-0.383528,-0.056925,-0.686820,-0.448807,-0.578487,...,-0.061082,0.138418,1.752519,-0.693676,1.085712,-0.679769,-0.414720,-0.312678,"(1, 161.0)",160.0
1,0.196386,-0.669786,0.264819,-0.465572,-0.556450,-0.379592,-0.186874,-0.656101,-0.304010,-0.244421,...,-0.087779,-0.787733,1.788122,-0.634733,1.071781,-0.544438,-0.345635,-0.192531,"(1, 161.0)",160.1
2,0.065824,-0.627307,0.729832,-0.150380,-0.579198,-0.367917,-0.021842,-0.732197,-0.350832,-0.024140,...,0.237805,-0.203025,1.779467,-0.315876,1.092088,-0.672063,-0.439554,-0.133469,"(1, 161.0)",160.2
3,-0.086976,-0.642197,-0.568010,-0.133038,-0.598893,-0.362637,-0.122581,-0.712941,-0.413132,-0.358385,...,0.344293,0.163872,1.947641,-0.489951,0.835358,-0.741365,-0.413940,-0.117640,"(1, 161.0)",160.3
4,-0.103283,-0.638236,-0.311149,-0.217314,-0.518474,-0.372620,-0.208191,-0.692670,-0.040957,-0.433270,...,0.413830,-0.060707,2.035557,-0.727974,0.852361,-0.454868,-0.539668,0.085718,"(1, 161.0)",160.4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
16412,0.264228,-0.590877,0.135353,-0.066490,2.528867,-0.314119,-0.545636,-0.791958,-0.129963,-0.196052,...,-0.685385,0.156685,-0.451156,-0.151528,0.031965,-0.956455,0.665049,2.173954,"(1, 370.2)",369.2
16413,0.014535,-0.447654,0.075967,-0.059001,2.504348,-0.316551,-0.453152,-0.563254,-0.360643,-0.082957,...,-0.845010,-0.405260,-0.461585,-0.220559,0.032617,-1.054943,0.685232,2.323323,"(1, 370.2)",369.3
16414,0.682235,-0.625413,0.329136,-0.154643,2.311841,-0.325180,-0.480219,-0.759900,-0.196956,-0.147195,...,-0.770554,-0.491761,-0.619808,-0.029801,-0.167523,-1.271149,0.589025,2.019487,"(1, 370.2)",369.4
16415,0.232931,-0.474859,0.656474,0.159828,2.266554,-0.231542,-0.522977,-0.653655,-0.365724,-0.107160,...,-0.732111,0.000079,-0.519478,-0.413701,-0.121853,-1.229845,0.572479,1.875447,"(1, 370.2)",369.5


In [43]:
y_train

0       out_of_block
1       out_of_block
2       out_of_block
3       out_of_block
4       out_of_block
            ...     
1495        in_block
1496        in_block
1497        in_block
1498        in_block
1499        in_block
Length: 1500, dtype: object