In [1]:
import observational_fear.load as load
from observational_fear.events import get_freeze_starts
from neurobox.wide_transforms import resample
from neurobox.compose import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.base import clone
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
import sklearn.metrics as metrics
from observational_fear.stats import auc
from pathlib import Path
import pandas as pd
import numpy as np
from observational_fear.events import get_freeze_stops



DATA_DIR = Path(r"D:\OFL\one-p experiment\data")
FIG_DIR = Path(r"C:\Users\roryl\repos\observational-fear\figs")

In [58]:
# resample
ROLE = "obs"

def resample_both(df_traces, df_freeze, df_cells, role, interval="300ms"):
    df_traces = resample(df_traces.set_index("time").copy(), interval)
    idx = df_traces.index
    df_traces = df_traces.reset_index().melt(id_vars=["time"]).merge(df_cells)
    df_freeze["time"] = df_freeze["time"].round(2)
    df_freeze = df_freeze.pivot(columns="mouse_name", values=f"was_freezing_{role}", index="time")
    df_freeze = resample(df_freeze, interval).round()
    df_freeze = df_freeze.reindex(idx, method="bfill")
    return df_traces, df_freeze


def preprocess(df_traces, df_freeze, mouse):
    dft = df_traces.loc[lambda x: x.mouse == mouse].pivot(index="time", columns="new_id", values="value").dropna()
    dff = df_freeze[mouse].loc[lambda x: x.index <= dft.index.max()]
    return dft, dff


def fit_model(dft, dff, pipe):
    X_train, X_test, y_train, y_test = train_test_split(dft, dff, shuffle=False)
    p = clone(pipe)
    p.fit(X_train, y_train)
    y_hat = p.predict(X_test)
    return metrics.f1_score(y_test, y_hat)


def model_boot_reps(dft, dff, n_boot, pipe):
    return np.array([fit_model(dft, dff.sample(frac=1), pipe) for _ in range(n_boot)])

def decode(df_traces, df_freeze, df_cells, role, pipe, n_boot, interval="300ms",):
    df_traces, df_freeze = resample_both(df_traces, df_freeze, df_cells, role=role, interval=interval)
    data = []
    for mouse in df_freeze.columns.to_list():
        dft, dff = preprocess(df_traces, df_freeze, mouse)

        score = fit_model(dft, dff, clone(pipe))
        reps = model_boot_reps(dft, dff, n_boot, pipe)
        p = np.mean(reps >= score)
        res = (mouse, score, p)
        data.append(res)
    return pd.DataFrame(data=data, columns=["mouse", "f1_score", "p"])



In [22]:
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, ExtraTreesClassifier
from sklearn.naive_bayes import GaussianNB

# clf = LogisticRegression(C=1, penalty="elasticnet", solver='saga', l1_ratio=0.1)
# clf = GradientBoostingClassifier(n_estimators=500)
clf = GaussianNB()
pipe = Pipeline([
    ("impute", SimpleImputer()),
    ("scale", StandardScaler()),
    ("clf", clf)
])


In [23]:
# load data
SESSION = "day2"
df_freeze = load.load_freeze(DATA_DIR, session=SESSION)
df_cells = load.load_cells(DATA_DIR)
df_cells["new_id"] = df_cells["new_id"].astype("str") 
df_traces = load.load_traces(DATA_DIR, session=SESSION)

In [59]:
res = decode(df_traces, df_freeze, df_cells, role="obs", pipe=pipe, n_boot=50, interval="300ms")

KeyboardInterrupt: 

In [34]:
res

Unnamed: 0,mouse,f1_score,p
0,B17274,0.65202,0.0
1,B17276,0.877496,0.0
2,B43396,0.745247,0.0
3,PL-OFL-2,0.517437,0.0
4,PL-OFL-4,0.745853,0.0
5,PL-OFL-5,0.820669,0.0
6,PL-OFL-6,0.787282,0.0
7,PL-OFL-7,0.771311,0.0


In [60]:
res = decode(df_traces, df_freeze, df_cells, role="dem", pipe=pipe, n_boot=50, interval="300ms", lag=1)

ValueError: Input contains NaN, infinity or a value too large for dtype('float64').

In [None]:
res

Unnamed: 0,mouse,f1_score,p
0,B17274,0.927835,0.0
1,B17276,0.766267,0.36
2,B43396,0.730271,0.18
3,PL-OFL-2,0.644628,0.96
4,PL-OFL-4,0.739535,0.22
5,PL-OFL-5,0.873073,0.0
6,PL-OFL-6,0.714504,0.0
7,PL-OFL-7,0.760388,0.0


In [44]:
res.shift(-3).ffill()

Unnamed: 0,mouse,f1_score,p
0,PL-OFL-2,0.644628,0.912
1,PL-OFL-4,0.739535,0.182
2,PL-OFL-5,0.873073,0.0
3,PL-OFL-6,0.714504,0.016
4,PL-OFL-7,0.760388,0.0
5,PL-OFL-7,0.760388,0.0
6,PL-OFL-7,0.760388,0.0
7,PL-OFL-7,0.760388,0.0
