# Boilerplate

In [1]:
from typing import Union
from allensdk.brain_observatory.ecephys.behavior_ecephys_session import BehaviorEcephysSession


from allensdk.brain_observatory.behavior.behavior_project_cache import (
    VisualBehaviorNeuropixelsProjectCache,
)
import brain_observatory_utilities.datasets.behavior.data_formatting as behavior_utils

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns

from hmmlearn import hmm, vhmm
from sklearn.decomposition import PCA, KernelPCA, FactorAnalysis
from sklearn.preprocessing import StandardScaler

import utils

import warnings
warnings.filterwarnings("ignore")

# Random seed
SEED = 42

%matplotlib inline

In [2]:
cache_dir = utils.get_data_root()
cache = VisualBehaviorNeuropixelsProjectCache.from_local_cache(
    cache_dir=cache_dir, use_static_cache=True
)

In [3]:
ephys_table = cache.get_ecephys_session_table()
ephys_table = ephys_table.loc[(~ephys_table["behavior_session_id"].isna()) & 
                              (ephys_table["experience_level"] == "Familiar")
                             ]

In [4]:
# session_id = 1124507277
# session_id = 1069461581
# session = cache.get_ecephys_session(session_id)
session_ids = np.random.choice(ephys_table.index, 20)

# Defining Behavioral Metrics

In [5]:
def align_to_stimulus(df: pd.DataFrame, session: BehaviorEcephysSession, active: bool = True) -> pd.DataFrame:
    if "timestamps" not in df.columns:
        raise ValueError("column timestamps must be present in df.")
    
    stim_presentations = session.stimulus_presentations
    if active:
        stim_presentations = stim_presentations.loc[stim_presentations["active"]]
    
    df = df.loc[(stim_presentations["start_time"].min() <= df["timestamps"]) 
                & (df["timestamps"] <= stim_presentations["end_time"].max())]
    bins = pd.concat([pd.Series([0]), stim_presentations["end_time"]])
    labels = stim_presentations.index
    stimulus_id_aligned = pd.cut(df["timestamps"], bins=bins, labels=labels, include_lowest=True, right=False)
    df = pd.concat([pd.Series(stimulus_id_aligned, name="stimulus_id"), df], axis=1)
    return df

In [6]:
def get_behavior_metrics(
    session: BehaviorEcephysSession, 
    center: bool = True,
) -> pd.DataFrame:
    eye = session.eye_tracking
    eye = eye.loc[(eye["likely_blink"] != True)]
    eye_metrics = eye[["timestamps", "pupil_area"]]
    eye_metrics = align_to_stimulus(eye_metrics, session)
    
    running_metrics = session.running_speed
    running_metrics = align_to_stimulus(running_metrics, session)


    rewards = (rewards := session.rewards).loc[~rewards["auto_rewarded"]]
    rewards_metric = align_to_stimulus(rewards, session)[["stimulus_id", "timestamps", "volume"]]
    rewards_metric["volume"] = rewards_metric["volume"].cumsum()
    
    metrics = (
        eye_metrics.
        merge(running_metrics, on="stimulus_id").
        groupby("stimulus_id").
        aggregate({"pupil_area": "mean", "speed": "mean"})
    )

    rolling_perf = session.get_rolling_performance_df()[["hit_rate"]]
    stimulus_presentations = session.stimulus_presentations
    metrics = (
        metrics.merge(stimulus_presentations["trials_id"], left_on="stimulus_id", right_index=True)
        .merge(rolling_perf, left_on="trials_id", right_index=True)
        .merge(rewards_metric[["stimulus_id", "volume"]], on="stimulus_id", how="left")
        .drop(columns=["trials_id", "stimulus_id"])
    )
    # Assign stimulus presentations that weren't rewarded to make volume a step function
    csum = metrics["volume"].notnull().cumsum()
    metrics["volume"] = metrics["volume"].fillna(0).groupby(csum).transform('sum')
    metrics = metrics.loc[(metrics.isna().sum(axis=1) == 0)]
    if center:
        metrics[["pupil_area", "speed"]] -= metrics[["pupil_area", "speed"]].mean(axis=0)
        metrics = metrics.loc[(metrics["pupil_area"] <= 3500)]
    
    metrics = metrics.loc[:, ~metrics.columns.str.startswith("timestamps")]
    return metrics

In [None]:
%%time
parallel = False
if parallel:
    metrics = utils.parallel_session_map(get_behavior_metrics, session_ids, "ephys")
    metrics = [session[1] for session in metrics["sessions"]]
else:
    sessions = [cache.get_ecephys_session(session_id) for session_id in session_ids]
    metrics = [get_behavior_metrics(session, center=True) for session in sessions]

# HMM Model Fitting and Plotting

In [None]:
X = pd.concat(metrics)
X_lens = [len(x) for x in metrics]

In [None]:
def full_session_time(metric_df: pd.DataFrame) -> pd.DataFrame:
    t_steps = np.arange(metric_df.index.max() + 1)
    full_session = pd.DataFrame(index=t_steps, columns=metric_df.columns)
    full_session.loc[metric_df.index] = metric_df
    return full_session

In [None]:
startprob = np.array([.5, 0.5])
transmat = np.array([[0.99, 0.01],
                     [0.01, 0.99]
                    ])
model_2 = hmm.GaussianHMM(
    n_components=2, 
    n_iter=100000, 
    random_state=SEED, 
    init_params="m", 
    covariance_type="full", 
    startprob_prior=startprob,
    transmat_prior=transmat,
    covars_prior=X.cov().values,
    algorithm="map",
    implementation="scaling"
)

In [None]:
model_2.fit(X, X_lens);

In [None]:
metric_idx = 1

In [None]:
sns.pairplot(metrics[metric_idx], corner=True);

In [None]:
m_result = metrics[metric_idx].copy()
m_result["state_2"] = model_2.predict(metrics[metric_idx])
m_result = full_session_time(m_result)

In [None]:
fig = plt.figure(constrained_layout=True, figsize=(10,6))
subfigs = fig.subfigures(nrows=1, ncols=1)
state = 2
subfigs.suptitle(f"State transition ({state} states)")
ax = subfigs.subplots()
ax.plot(np.arange(m_result.shape[0]),m_result[f"state_{state}"])
ax.set_yticks(np.arange(state), labels=np.arange(state).astype(str));

In [None]:
fig = plt.figure(constrained_layout=True, figsize=(10,12))
subfig = fig.subfigures(nrows=1, ncols=1)
state = 2
subfig.suptitle(f"{state} State variable comparisons")
axes = subfig.subplots(4, 1, sharex=True)

sns.scatterplot(
    m_result,
    x=np.arange(m_result.shape[0]),
    y="pupil_area",
    hue=m_result[f"state_{state}"], ax=axes[0])
sns.scatterplot(
    m_result,
    x=np.arange(m_result.shape[0]),
    y="hit_rate",
    hue=m_result[f"state_{state}"], ax=axes[1])
sns.scatterplot(
    m_result,
    x=np.arange(m_result.shape[0]),
    y="speed",
    hue=m_result[f"state_{state}"], ax=axes[2])
sns.scatterplot(
    m_result,
    x=np.arange(m_result.shape[0]),
    y="volume",
    hue=m_result[f"state_{state}"], ax=axes[3]);

In [None]:
sns.pairplot(m_result, corner=True, hue="state_2");

In [None]:
fig = plt.figure(constrained_layout=True, figsize=(10,6))
subfig = fig.subfigures(nrows=1, ncols=1)
state = 2
subfig.suptitle(f"Volume vs Hit Rate ({state} states)")
ax = subfig.subplots()
sns.scatterplot(m_result, x="volume", y="hit_rate", hue=m_result[f"state_{state}"], ax=ax);

In [None]:
pca = PCA(n_components=3)
pca_m = m_result.loc[m_result.isna().sum(axis=1) == 0]
embeddings = pca.fit_transform(pca_m.drop(columns=["state_2", "volume"]))

fig = plt.figure(constrained_layout=True, figsize=(10,6))
subfig = fig.subfigures(nrows=1, ncols=1)
axes = subfig.subplots(2, 1)

axes[0].scatter(embeddings[:, 0], embeddings[:, 2], alpha=0.2, c=pca_m.index, cmap="viridis")
axes[0].set_title("PCA projection colored by time")

axes[1].scatter(embeddings[:, 0], embeddings[:, 2], c=pca_m["state_2"], alpha=0.2)
axes[1].set_title("PCA projection colored by state");

axes[0].set_xlim(axes[1].get_xlim())
axes[0].set_ylim(axes[1].get_ylim());


# Getting Ephys Data

In [None]:
unit_table = cache.get_unit_table()

In [None]:
session = cache.get_ecephys_session(session_ids[metric_idx])

In [None]:
def get_spike_rates(session: BehaviorEcephysSession) -> tuple[np.array, pd.DataFrame]:
    """Get spike rates over stimulus presentations for a session
    
    Parameters
    ----------
    session
        The ecephys session to get spike rates for.
    
    Returns
    -------
    rates
        (n_units x n_stimuli) array of spiking rates.
    rates_df
        The ``rates`` array as a dataframe with an extra column
        indicating the region the unit is in.
    """
    
    spikes = session.spike_times
    units = session.get_units().join(unit_table["structure_acronym"])
    stimuli = session.stimulus_presentations
    stimuli = stimuli.loc[(stimuli["active"])]
    units = units[(units.isi_violations < 0.5) 
                    & (units.amplitude_cutoff < 0.1) 
                     & (units.presence_ratio > 0.9)
                ]
    rates = np.zeros((units.shape[0], stimuli.shape[0]))
    for i, unit_id in enumerate(units.index):
        unit_data = pd.DataFrame({"timestamps": spikes[unit_id], "spikes": np.ones(len(spikes[unit_id]))})
        counts = align_to_stimulus(unit_data, session).groupby("stimulus_id").sum()["spikes"]
        lengths = stimuli.end_time - stimuli.start_time
        rate = counts/lengths
        rates[i] = rate

    columns = [f"t_{timestep}" for timestep in stimuli.index]
    rates_df = pd.DataFrame(rates, index=units.index, columns=columns)
    rates_df = rates_df.join(units["structure_acronym"])
    
    return rates, rates_df

In [None]:
def plot_areas(
    session_rates: pd.DataFrame, 
    behavior_metric: np.array, 
    areas: Union[str, list[str]] = None, 
):
    """
    Plot the mean activity of a brain region along with individual unit activity along with
    behavioral state.
    
    If no areas are provided, all areas will be plotted.
    
    Parameters
    ----------
    session_rates
        The firing rates to plot.
    behavior_metric
        The behavior metric vector to plot.
    areas
        Single or list of brain regions to plot.
        Default behavior is to plot all regions
    
    Returns
    -------
    fig, ax
        Matplotlib plot
    """
    if areas is None:
        areas = session_rates["structure_acronym"].unique()
    elif isinstance(areas, str):
        areas = [areas]
    
    
    fig, axes = plt.subplots(len(areas), 2, figsize=(10, 4 * len(areas)));   
    axes = np.expand_dims(axes, 0) if len(areas) == 1 else axes
    
    session_rates_t = session_rates.iloc[:, :-1].T
    
    behavior_metric /= behavior_metric.max()
    behavior_state = np.empty(session_rates_t.shape[0])
    behavior_state[:] = np.nan
    behavior_state[behavior_metric.index] = behavior_metric
    
    
    for i, area in enumerate(areas):
        area_activity_idx = session_rates.loc[session_rates["structure_acronym"] == area].index
        (area_rates := session_rates_t.loc[:, area_activity_idx]).mean(axis=1).plot(ax=axes[i,0]);
        
    
        axes[i, 0].set_title(f"(mean) activity for {area}");
        axes[i, 0].plot(behavior_state * area_rates.to_numpy().mean(axis=(0,1)) , linewidth=3);
        axes[i, 0].set_ylabel("Firing rate (over stimulus presentation")
        
        area_rates.plot(ax=axes[i,1], alpha=0.4);
        axes[i, 1].plot(behavior_state * area_rates.to_numpy().max(axis=(0,1))/5, color="orange", linewidth=3);
        axes[i, 1].set_title(f"activity for {area}");
        axes[i, 1].legend().remove();
        
    fig.tight_layout();
    return fig, axes

# Plotting data

In [None]:
rates, rates_df = get_spike_rates(session)

In [None]:
plot_areas(rates_df, m_result["state_2"]);

In [None]:
def plot_area_units(
    session_rates: pd.DataFrame, 
    behavior_metric: np.array, 
    area: str, 
) -> None:
    """
    Plot the individual unit activity for a single brain region.
        
    Parameters
    ----------
    session_rates
        The firing rates to plot.
    behavior_state
        The behavior metric vector to plot.
    area
        Region to plot units for.
    
    Returns
    -------
    fig, ax
        Matplotlib plot
    """
    
    area_rates = session_rates.loc[session_rates["structure_acronym"] == area]
    
    area_rates_t = area_rates.iloc[:, :-1].T
    
    behavior_metric /= behavior_metric.max()
    behavior_state = np.empty(area_rates_t.shape[0])
    behavior_state[:] = np.nan
    behavior_state[behavior_metric.index] = behavior_metric
    
    fig, axes = plt.subplots((n_units := area_rates_t.shape[1]), 1, figsize=(8, 6 * n_units), sharex=True);
    axes = np.reshape(axes, -1) if n_units == 1 else axes
    for i, ax in enumerate(axes):
        area_rates_t.iloc[:, i].plot(ax=ax)
        ax.plot(behavior_state * area_rates_t.iloc[:, i].to_numpy().max(axis=(0))/5, color="orange", linewidth=3);
        ax.set_title(f"activity for {area_rates_t.columns[i]}");
        ax.set_ylabel("Firing rate (over stimulus presentation")

    return fig, axes

In [73]:
plot_area_units(rates_df, m_result["state_2"], "TH");

ValueError: Number of rows must be a positive integer, not 0

SystemError: tile cannot extend outside image

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous view', 'arrow-left', 'back'), ('Forward', 'Forward to next view', 'arrow-right', 'forward'), ('Pan', 'Left button pans, Right button zooms\nx/y fixes axis, CTRL fixes aspect', 'arrows', 'pan'), ('Zoom', 'Zoom to rectangle\nx/y fixes axis', 'square-o', 'zoom'), ('Download', 'Download plot', 'floppy-o', 'save_figure')]))