# Boilerplate

In [1]:
from typing import Union
from allensdk.brain_observatory.ecephys.behavior_ecephys_session import BehaviorEcephysSession


from allensdk.brain_observatory.behavior.behavior_project_cache import (
    VisualBehaviorNeuropixelsProjectCache,
)
import brain_observatory_utilities.datasets.behavior.data_formatting as behavior_utils

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns

from hmmlearn import hmm, vhmm
from sklearn.decomposition import PCA, KernelPCA, FactorAnalysis
from sklearn.preprocessing import StandardScaler

import utils

import warnings
warnings.filterwarnings("ignore")

# Random seed shared by the stochastic steps of this notebook
# (it is passed as `random_state` when the HMM is constructed).
SEED = 42

%matplotlib inline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the Visual Behavior Neuropixels project cache from the local data root
# returned by utils.get_data_root(). Every table and session used in this
# notebook is read through `cache`.
cache_dir = utils.get_data_root()
cache = VisualBehaviorNeuropixelsProjectCache.from_local_cache(
    cache_dir=cache_dir, use_static_cache=True
)

In [3]:
# Session table restricted to sessions that have a behavior session id
# and whose experience level is "Familiar".
ephys_table = cache.get_ecephys_session_table()
has_behavior_session = ephys_table["behavior_session_id"].notna()
is_familiar = ephys_table["experience_level"].eq("Familiar")
ephys_table = ephys_table.loc[has_behavior_session & is_familiar]

In [None]:
# Session ids that were selected by hand earlier, kept for reference:
# session_id = 1124507277
# session_id = 1069461581
# session_ids = [1124507277, 1069461581, ephys_table.index[10], 1069461581]

# Draw the sessions from a generator seeded with SEED and without replacement,
# so a fresh "Restart & Run All" always loads the same, duplicate-free set of
# sessions (later cells index into this list by position).
N_SESSIONS = 10
rng = np.random.default_rng(SEED)
session_ids = rng.choice(
    ephys_table.index,
    size=min(N_SESSIONS, len(ephys_table)),  # never ask for more sessions than exist
    replace=False,
)
sessions = [cache.get_ecephys_session(session_id) for session_id in session_ids]

# Defining Behavioral Metrics

In [None]:
def align_to_stimulus(df: pd.DataFrame, session: BehaviorEcephysSession, active: bool = True) -> pd.DataFrame:
    """Label every timestamped row of ``df`` with a stimulus presentation id.

    Rows are first restricted to the span between the earliest ``start_time``
    and the latest ``end_time`` of the considered presentations (both ends
    inclusive). The remaining timestamps are binned with half-open intervals
    ``[previous end_time, end_time)``, the first interval starting at 0, and
    each interval is labelled with the index value of the presentation whose
    ``end_time`` closes it.

    Parameters
    ----------
    df
        Frame with a ``timestamps`` column.
    session
        Object exposing a ``stimulus_presentations`` frame with ``start_time``,
        ``end_time`` and (when ``active`` is true) ``active`` columns.
    active
        If true, only presentations whose ``active`` value is truthy are used.

    Returns
    -------
    pd.DataFrame
        The retained rows of ``df`` with a categorical ``stimulus_id`` column
        prepended.

    Raises
    ------
    ValueError
        If ``df`` has no ``timestamps`` column.
    """
    if "timestamps" not in df.columns:
        raise ValueError("column timestamps must be present in df.")

    presentations = session.stimulus_presentations
    if active:
        presentations = presentations.loc[presentations["active"]]

    first_start = presentations["start_time"].min()
    last_end = presentations["end_time"].max()
    within_span = df.loc[df["timestamps"].between(first_start, last_end)]

    # One more edge than labels: a leading 0 followed by every end_time.
    edges = pd.concat([pd.Series([0]), presentations["end_time"]])
    stimulus_id = pd.cut(
        within_span["timestamps"],
        bins=edges,
        labels=presentations.index,
        include_lowest=True,
        right=False,
    ).rename("stimulus_id")

    return pd.concat([stimulus_id, within_span], axis=1)

In [None]:
def get_saccade_lengths(session_eye: pd.DataFrame) -> pd.DataFrame:
    """Euclidean distance the pupil centre moved between consecutive rows.

    Parameters
    ----------
    session_eye
        Frame with ``timestamps``, ``pupil_center_x`` and ``pupil_center_y``
        columns.

    Returns
    -------
    pd.DataFrame
        Two columns, ``timestamps`` and ``saccade_length``. The first row has
        no predecessor, so its ``saccade_length`` is NaN.
    """
    deltas = session_eye[["pupil_center_x", "pupil_center_y"]].diff()
    dx, dy = deltas["pupil_center_x"], deltas["pupil_center_y"]
    lengths = np.sqrt(np.square(dx) + np.square(dy)).rename("saccade_length")
    return pd.concat([session_eye["timestamps"], lengths], axis=1)

In [None]:
def _mean_per_stimulus(aligned: pd.DataFrame) -> pd.DataFrame:
    """Average every non-timestamp column of an aligned frame within each ``stimulus_id``."""
    values = aligned.loc[:, ~aligned.columns.str.startswith("timestamps")]
    return values.groupby("stimulus_id", observed=True).mean()


def get_behavior_metrics(
    session: BehaviorEcephysSession,
    window: int = 20,
    center: bool = True,
) -> pd.DataFrame:
    """Build one row of behavioral metrics per stimulus presentation.

    Eye-tracking samples (blink frames excluded) and running-speed samples are
    aligned to stimulus presentations with ``align_to_stimulus``, averaged per
    presentation, and joined with the rolling d-prime of the trial each
    presentation belongs to. Presentations with any missing value are dropped.

    Parameters
    ----------
    session
        Session providing ``eye_tracking``, ``running_speed``,
        ``stimulus_presentations`` and ``get_rolling_performance_df()``.
    window
        Currently unused; kept so existing callers that pass it keep working.
    center
        If true, subtract the across-presentation mean from ``pupil_area``,
        ``saccade_length`` and ``speed`` (``rolling_dprime`` is left as is).

    Returns
    -------
    pd.DataFrame
        Indexed by ``stimulus_id`` with columns ``pupil_area``,
        ``saccade_length``, ``speed`` and ``rolling_dprime``.

    Notes
    -----
    Only the session's frames are read; none of them is modified. Reward and
    lick metrics are not part of the output.
    """
    # --- eye tracking -------------------------------------------------------
    eye = session.eye_tracking
    eye = eye.loc[eye["likely_blink"].ne(True)]
    eye_metrics = eye[["timestamps", "pupil_area", "pupil_center_x", "pupil_center_y"]]
    # NOTE(review): displacement is computed after blink frames are removed, so
    # the first sample after a blink measures the jump across the gap --
    # confirm this is intended.
    saccade_lengths = get_saccade_lengths(eye_metrics)
    eye_metrics = (
        eye_metrics
        .merge(saccade_lengths, on="timestamps")
        .drop(columns=["pupil_center_x", "pupil_center_y"])
    )
    eye_metrics = align_to_stimulus(eye_metrics, session)

    # --- running ------------------------------------------------------------
    running_metrics = align_to_stimulus(session.running_speed, session)

    # Average each stream per presentation *before* joining. Joining the raw
    # per-sample frames on stimulus_id would pair every eye sample with every
    # running sample of the same presentation, only to average them afterwards.
    metrics = _mean_per_stimulus(eye_metrics).merge(
        _mean_per_stimulus(running_metrics), left_index=True, right_index=True
    )

    # --- task performance ---------------------------------------------------
    rolling_perf = session.get_rolling_performance_df()[["rolling_dprime"]]
    stimulus_presentations = session.stimulus_presentations
    metrics = (
        metrics
        .merge(stimulus_presentations["trials_id"], left_on="stimulus_id", right_index=True)
        .merge(rolling_perf, left_on="trials_id", right_index=True)
    )

    # Keep only presentations where every metric is available.
    metrics = metrics.dropna().drop(columns="trials_id")

    if center:
        centered_columns = ["pupil_area", "saccade_length", "speed"]
        metrics[centered_columns] = metrics[centered_columns] - metrics[centered_columns].mean(axis=0)

    return metrics

In [None]:
# One metrics frame per loaded session, in the same order as `sessions`.
metrics = list(map(get_behavior_metrics, sessions))

In [None]:
# Position within `metrics` of the session that is plotted and decoded below.
metric_idx = 5

In [None]:
# Pairwise relationships between the metrics of the selected session.
sns.pairplot(data=metrics[metric_idx], corner=True);

In [None]:
# Two-state Gaussian HMM with a covariance shared across states ("tied").
# init_params="mc" together with the start/transition values assigned below:
# presumably hmmlearn then initialises only means and covariances itself and
# starts from these hand-set probabilities -- confirm against the hmmlearn docs.
model = hmm.GaussianHMM(
    n_components=2,
    covariance_type="tied",
    n_iter=100000,
    init_params="mc",
    random_state=SEED,
)
model.startprob_ = np.array([0.9, 0.1])
model.transmat_ = np.array(
    [
        [0.99, 0.01],
        [0.01, 0.99],
    ]
)
# model.covars_prior = (eye := np.ones(metrics[0].shape[1]))/2

In [None]:
# Fit on every session at once; the per-session lengths are passed alongside
# the stacked frame so the concatenated rows can be split back into sequences.
X = pd.concat(metrics)
X_lens = list(map(len, metrics))
model.fit(X, X_lens)

# Decode the state sequence of the selected session and attach it as a column
# of a copy, leaving the frame stored in `metrics` untouched.
selected_metrics = metrics[metric_idx]
m_result = selected_metrics.assign(state=model.predict(selected_metrics))

In [None]:
# Decoded state at each row of m_result, plotted against row position.
row_positions = np.arange(len(m_result))
plt.plot(row_positions, m_result["state"]);

In [None]:
# Same pairwise view as before, now coloured by the decoded state.
sns.pairplot(data=m_result, hue="state", corner=True);

In [None]:
# Rolling d-prime across the selected session, coloured by decoded state.
# `data` is passed by keyword: it is not accepted positionally by every
# seaborn release, whereas the keyword form always works.
sns.scatterplot(
    data=m_result,
    x=np.arange(m_result.shape[0]),
    y="rolling_dprime",
    hue="state",
).set(xlabel="row position in m_result");  # x is a bare array, so it has no label of its own

In [None]:
# Project the selected session's metrics onto their first two principal
# components and colour each point by its decoded state.
# NOTE(review): the metrics are passed to PCA without rescaling -- confirm
# this is intended, since the columns are different quantities.
pca = PCA(n_components=2)
embeddings = pca.fit_transform(metrics[metric_idx])
first_component, second_component = embeddings[:, 0], embeddings[:, 1]
plt.scatter(first_component, second_component, c=m_result["state"], alpha=0.2);