# 1. Deeplabcut analysis (Fig. 5 E,F,H)

In [None]:
from gaitAnalysis import *

In [None]:
# Root of the open-field dataset. Built with os.path.join so that the path
# resolves on POSIX systems as well as on Windows (a literal backslash is only
# a path separator on Windows).
root_dir = os.path.join('data', 'openfield_Cntnap')
dataFolder = os.path.join(root_dir, 'Data')

fps = 40  # frame rate handed to DLCSummary below

# Analyses covered in this section:
#   1. moving distance
#   2. running speed distribution/average
#   3. angular speed distribution/average
#   4. time spent in the middle (future): a GUI defining the boundaries of the
#      field based on user input

savemotionpath = os.path.join(root_dir, 'Summary', 'DLC')
groups = ['Ctrl', 'Exp']
behavior = 'openfield'
DLCSum = DLCSummary(root_dir, fps, groups, behavior)

# basic motor-related analysis; both calls receive the summary output folder
DLCSum.center_analysis(savemotionpath)
DLCSum.motion_analysis(savemotionpath)

# 2. Keypoint-MoSeq analysis (Fig. 5 G,I,O)

## Project setup
Create a new project directory with a keypoint-MoSeq `config.yml` file.

In [None]:
import keypoint_moseq as kpms

project_dir = r'project/folder'


def config():
    """Return `kpms.load_config(project_dir)`, read again on every call."""
    return kpms.load_config(project_dir)

### Option 3: Manual setup


In [None]:

# Open-field keypoints, listed in the order passed to keypoint-MoSeq.
bodyparts = [
    'nose', 'left ear', 'right ear', 'head',
    'left hand', 'right hand',
    'spine 1', 'spine 2', 'spine 3',
    'left foot', 'right foot',
    'tail 1', 'tail 2', 'tail 3',
]

# Open-field skeleton: every edge is a two-element list naming two bodyparts.
skeleton = [
    list(edge)
    for edge in (
        ('tail 3', 'tail 2'),
        ('tail 2', 'tail 1'),
        ('tail 1', 'spine 3'),
        ('spine 3', 'left foot'),
        ('spine 3', 'right foot'),
        ('spine 3', 'spine 2'),
        ('spine 2', 'spine 1'),
        ('spine 1', 'left hand'),
        ('spine 1', 'right hand'),
        ('spine 1', 'head'),
        ('nose', 'head'),
        ('left ear', 'head'),
        ('right ear', 'head'),
    )
]

# Create the project; overwrite=True is passed, so the call is made even if
# the project already exists.
video_dir = r'/project/folder/videos/'
kpms.setup_project(
    project_dir,
    video_dir=video_dir,
    bodyparts=bodyparts,
    skeleton=skeleton,
    overwrite=True,
)

## Edit the config file

The config can be edited in a text editor or using the function `kpms.update_config`, as shown below. In general, the following parameters should be specified for each project:

- `bodyparts` (name of each keypoint; automatically imported from SLEAP/DeepLabCut)
- `use_bodyparts` (subset of bodyparts to use for modeling, set to all bodyparts by default; for mice we recommend excluding the tail)
- `anterior_bodyparts` and `posterior_bodyparts` (used for rotational alignment)
- `video_dir` (directory with videos of each experiment)

Edit the config as follows for the [example DeepLabCut dataset](https://drive.google.com/drive/folders/1UNHQ_XCQEKLPPSjGspRopWBj6-YNDV6G?usp=share_link):

In [None]:

kpms.update_config(
    project_dir, 
    video_dir = video_dir,
    anterior_bodyparts=['nose'],
    posterior_bodyparts=['tail 1'],
    use_bodyparts=['nose','left ear', 'right ear','head','left hand', 'right hand','spine 1', 'spine 2',
    'spine 3','left foot','right foot', 'tail 1'])

## Load data

Data can be loaded from DeepLabCut, SLEAP or any other source, as long as it has the following format, where K is the number of keypoints and D is 2 or 3:
- `coordinates`: dict from session names to keypoint coordinate arrays of shape (T,K,D). Each key should start with its video name (e.g. `coordinates["experiment1_etc"]` would correspond to `experiment1.avi`). In general this will already be true if importing from SLEAP or DeepLabCut.
    
- `confidences`: dict from session names to **nonnegative** keypoint confidence arrays of shape (T,K). Confidences are optional (they are used to set the error prior for each observation).

In [None]:
keypoint_data_path= video_dir

# load data (e.g. from DeepLabCut)

coordinates, confidences, bodyparts = kpms.load_keypoints(keypoint_data_path, 'deeplabcut')

# format data for modeling
data, metadata = kpms.format_data(coordinates, confidences, **config())

In [None]:
# Scratch cell for inspecting the loaded objects; uncomment a line to display it.
#coordinates.keys()
#confidences
#bodyparts
#data.keys()
# NOTE(review): no earlier cell in this notebook assigns `labels`; unless the
# star import from gaitAnalysis provides it, this line raises NameError — confirm.
labels

## Calibration

The purpose of calibration is to learn the relationship between error and keypoint confidence scores. The resulting regression coefficients (`slope` and `intercept`) are used during modeling to set the noise prior on a per-frame, per-keypoint basis. One can also adjust the `confidence_threshold` parameter at this step, which is used to define outlier keypoints for PCA and model initialization. **This step can be skipped for the demo data** since the config already includes suitable regression coefficients.

- Run the cell below. A widget should appear with a video frame on the left.
    - *If the widget doesn't render, try using jupyter lab instead of jupyter notebook*
    
- Annotate each frame with the correct location of the labeled bodypart
    - Left click to specify the correct location - an "X" should appear.
    - Use the arrow buttons to annotate additional frames.
    - Each annotation adds a point to the right-hand scatter plot. 
    - Continue until the regression line stabilizes.
   
- At any point, adjust the confidence threshold by clicking on the scatter plot.

- Use the "save" button to update the config and store your annotations to disk.

In [None]:
kpms.noise_calibration(project_dir, coordinates, confidences, **config())

## Fit PCA

Run the cell below to fit a PCA model to aligned and centered keypoint coordinates. The model is saved to ``{project_dir}/pca.p`` and can be reloaded using ``kpms.load_pca``. Two plots are generated: a cumulative [scree plot](https://en.wikipedia.org/wiki/Scree_plot) and a depiction of each PC, where translucent nodes/edges represent the mean pose and opaque nodes/edges represent a perturbation in the direction of the PC. 

- After fitting, edit `latent_dimension` in the config. 
- A good heuristic is the number of dimensions needed to explain 90% of variance, or 10 dimensions - whichever is lower.  

In [None]:
# Fit PCA on the formatted keypoint data and save it into the project directory.
pca = kpms.fit_pca(**data, **config(),parallel_message_passing=False)
kpms.save_pca(pca, project_dir)

# Report how many components explain 90% of the variance, then draw the scree
# plot and the per-PC pose plots.
kpms.print_dims_to_explain_variance(pca, 0.9)
kpms.plot_scree(pca, project_dir=project_dir)
kpms.plot_pcs(pca, project_dir=project_dir, **config())

# use the following to load an already-fitted PCA model
# pca = kpms.load_pca(project_dir)

In [None]:
kpms.update_config(project_dir, latent_dim=10)

## Model fitting

Fitting a keypoint-MoSeq model involves:
1. **Initialization:** Auto-regressive (AR) parameters and syllable sequences are randomly initialized using pose trajectories from PCA.
2. **Fitting an AR-HMM:** The AR parameters, transition probabilities and syllable sequences are iteratively updated through Gibbs sampling. 
3. **Fitting the full model:** All parameters, including both the AR-HMM as well as centroid, heading, noise-estimates and continuous latent states (i.e. PCA trajectories) are iteratively updated through Gibbs sampling. This step is especially useful for noisy data.
4. **Apply the trained model:** The learned model parameters are used to infer a syllable sequence for each experiment. This step should always be applied at the end of model fitting, and it can also be used later on to infer syllable sequences for newly added data.

## Setting hyperparameters

There are two ways to change hyperparameters:
1. Update the config using `kpms.update_config` and then re-initialize the model
2. Change the model directly via `kpms.update_hypparams`

In general, the main hyperparameter that needs to be adjusted is **kappa**, which sets the time-scale of syllables. Higher kappa leads to longer syllables. For this tutorial we chose kappa values that yielded a median syllable duration of 400ms (12 frames at 30 fps; at the 40 fps used in this notebook, 400ms corresponds to 16 frames). In general, you will need to tune kappa for each new dataset based on the intended syllable time-scale. **You will need to pick two kappas: one for AR-HMM fitting and one for the full model.**
- We recommend iteratively updating kappa and refitting the model until the target syllable time-scale is attained.  
- Model fitting can be stopped at any time by interrupting the kernel, and kappa can be adjusted as described above.
- The full model will generally require a lower value of kappa to yield the same target syllable durations.


## Initialization

In [None]:
# optionally update kappa in the config before initializing 
# kpms.update_config(project_dir=project_dir, kappa=NUMBER)

# initialize the model
from jax_moseq.utils import set_mixed_map_iters
set_mixed_map_iters(6)

kpms.update_config(project_dir,kappa = 1e14)
model = kpms.init_model(data, pca=pca, **config())

## Fitting an AR-HMM

In addition to fitting an AR-HMM, the function below:
- generates a name for the model and a new directory in `project_dir`
- saves a checkpoint every 10 iterations from which fitting can be restarted
    - a single checkpoint file contains the full history of fitting, and can be used to restart fitting from any iteration
- plots the progress of fitting every 10 iterations, including
    - the distributions of syllable frequencies and durations for the most recent iteration
    - the change in median syllable duration across fitting iterations
    - the syllable sequence across iterations in a random window

In [None]:
num_ar_iters = 50

model, model_name = kpms.fit_model(
    model, data, metadata, project_dir,
    ar_only=True, num_iters=num_ar_iters)

## Fitting the full model

The following code fits a full keypoint-MoSeq model, using the results of AR-HMM fitting for initialization
- If using your own data, you may need to try a few values of kappa at this step. 
- Use `kpms.revert` to resume from the same starting point each time you restart fitting

In [None]:
# load model checkpoint saved at the end of AR-HMM fitting
num_ar_iters = 50
from jax_moseq.utils import set_mixed_map_iters
set_mixed_map_iters(6)
#model_name = '2024_05_12-07_36_17'
model, data, metadata, current_iter = kpms.load_checkpoint(
    project_dir, model_name, iteration=num_ar_iters)

# modify kappa to maintain the desired syllable time-scale
model = kpms.update_hypparams(model, kappa=1e9)

# run fitting for an additional 300 iters
# (fit_model returns a tuple; only its first element, the model, is kept)
model = kpms.fit_model(
    model, data, metadata, project_dir, model_name, ar_only=False, 
    start_iter=current_iter, num_iters=current_iter+300,parallel_message_passing=False)[0]
     

## Apply the trained model

The code below infers a syllable sequence for each experiment. The results are saved in `{project_dir}/{name}/results.h5`. 
- The default assumption is that `coordinates` and `confidences` are the same data that were used for model fitting.

To infer syllable sequences for new data:
- Run ``kpms.apply_model`` with `use_saved_states=False` and pass in a pca model
- Results for the new experiments will be added to the existing `results.h5` file.

In [None]:
# load the most recent model checkpoint and pca object

model_name = '2024_05_12-07_36_17'
project_dir = r'project/folder'
#config = lambda: kpms.load_config(project_dir)

In [None]:
import jax
jax.config.read('jax_enable_x64')

## Visualize syllables

## Trajectory plots
Generate plots showing the average trajectory of poses associated with each given syllable. 

In [None]:
kpms.reindex_syllables_in_checkpoint(project_dir, model_name)


In [None]:
model_name = '2024_05_12-07_36_17'

In [None]:
# load the most recent model checkpoint
model, data, metadata, current_iter = kpms.load_checkpoint(project_dir, model_name)

# extract results
results = kpms.extract_results(model, metadata, project_dir, model_name)

In [None]:
kpms.save_results_as_csv(results, project_dir, model_name)
#keypoint_moseq.viz.plot_duration_distribution
#kpms.plot_syllable_frequencies?

In [None]:
# load results


In [None]:
import numpy as np
# integer array holding the syllable indices 0..23
temp_syll_include = np.arange(24)

## Crowd & grid movies
Generate video clips showing examples of each syllable.

In [None]:
kpms.generate_trajectory_plots?
#results

In [None]:
kpms.generate_trajectory_plots(coordinates, results, project_dir, model_name,fps=40,**config())


In [None]:
kpms.generate_grid_movies(results, project_dir, model_name, pre=30, post=30,coordinates=coordinates,fps=40,overlay_keypoints=True,**config())
# rows=3, cols=5, 

In [None]:
# customized plot_similarity_dendrogram to save svg files
from keypoint_moseq.io import load_results, _get_path
from scipy.cluster.hierarchy import linkage, dendrogram, leaves_list
from scipy.spatial.distance import squareform, pdist
import matplotlib.pyplot as plt
import csv

def plot_similarity_dendrogram_customize(
    coordinates,
    results,
    project_dir=None,
    model_name=None,
    save_path=None,
    metric="cosine",
    pre=5,
    post=15,
    min_frequency=0.005,
    min_duration=3,
    bodyparts=None,
    use_bodyparts=None,
    density_sample=False,
    sampling_options=None,
    figsize=(6, 3),
    **kwargs,
):
    """Plot a dendrogram showing the similarity between syllable trajectories.

    Variant of the keypoint-MoSeq dendrogram plot that additionally writes
    the linkage matrix to `{save_path}.csv` and saves the figure as SVG as
    well as PDF and PNG. The base path is `{save_path}` if it is provided,
    or else `{project_dir}/{model_name}/similarity_dendrogram`. Plot-related
    parameters are described below. For the remaining parameters see
    (:py:func:`keypoint_moseq.util.get_typical_trajectories`)

    Parameters
    ----------
    coordinates: dict
        Dictionary mapping recording names to keypoint coordinates as
        ndarrays of shape (n_frames, n_bodyparts, [2 or 3]).

    results: dict
        Dictionary containing modeling results for a dataset (see
        :py:func:`keypoint_moseq.fitting.extract_results`).

    project_dir: str, default=None
        Project directory. Required to save figure if `save_path` is None.

    model_name: str, default=None
        Model name. Required to save figure if `save_path` is None.

    save_path: str, default=None
        Path to save the outputs (do not include an extension).
        If None, the outputs are saved to
        `{project_dir}/{name}/similarity_dendrogram.[csv/pdf/png/svg]`.

    metric: str, default='cosine'
        Distance metric to use. See :py:func:`scipy.spatial.pdist` for options.

    sampling_options: dict, default=None
        Forwarded to `kpms.util.syllable_similarity`. None is replaced by
        `{"n_neighbors": 50}`; a fresh dict is built on every call so no
        dict is shared between calls.

    figsize: tuple of float, default=(6,3)
        Size of the dendrogram plot.

    **kwargs
        Ignored. Accepted so the whole config can be passed as `**config()`.
    """
    if sampling_options is None:
        sampling_options = {"n_neighbors": 50}

    save_path = _get_path(project_dir, model_name, save_path, "similarity_dendrogram")

    distances, syllable_ixs = kpms.util.syllable_similarity(
        coordinates,
        results,
        metric,
        pre,
        post,
        min_frequency,
        min_duration,
        bodyparts,
        use_bodyparts,
        density_sample,
        sampling_options,
    )

    # complete-linkage clustering on the condensed form of the distance matrix
    Z = linkage(squareform(distances), "complete")

    # keep the raw linkage matrix next to the figures
    file_path = save_path + ".csv"
    with open(file_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerows(Z)

    print("Matrix saved to", file_path)

    fig, ax = plt.subplots(1, 1)
    labels = [f"Syllable {s}" for s in syllable_ixs]
    dendrogram(Z, labels=labels, leaf_font_size=10, ax=ax, leaf_rotation=90)

    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_color("lightgray")
    ax.set_title("Syllable similarity")
    fig.set_size_inches(figsize)
    fig.tight_layout()
    print(f"Saving dendrogram plot to {save_path}")
    for ext in ["pdf", "png", "svg"]:
        fig.savefig(save_path + "." + ext)

In [None]:
plot_similarity_dendrogram_customize(coordinates, results, project_dir, model_name, figsize=(6, 4),**config())


In [None]:
# Syllable indices listed in similarity order.
# Edit this sequence by hand whenever the similarity ordering changes.
syll_sim = [
    21, 20, 10, 23, 2, 24, 6,
    19, 3, 15, 8, 27, 5, 13,
    7, 17, 26, 4, 11, 25, 18,
    22, 16, 12, 0, 1, 9, 14,
]


In [None]:
# assign groups
index_file=kpms.interactive_group_setting(project_dir, model_name)

In [None]:
# NOTE(review): `moseq_df` is only assigned by a later cell
# (`kpms.compute_moseq_df`); run that cell before this one.
kpms.label_syllables(project_dir, model_name, moseq_df) 

In [None]:
index_file

In [None]:
moseq_df = kpms.compute_moseq_df(project_dir, model_name, smooth_heading=True) 
moseq_df

In [None]:
import os
save_dir = os.path.join(project_dir, model_name) # directory to save the moseq_df dataframe
moseq_df.to_csv(os.path.join(save_dir, 'moseq_df.csv'), index=False)
print('Saved `moseq_df` dataframe to', save_dir)

In [None]:
stats_df = kpms.compute_stats_df(
    project_dir,
    model_name,
    moseq_df, 
    min_frequency=0.005,       # threshold frequency for including a syllable in the dataframe
    groupby=['group', 'name'], # column(s) to group the dataframe by
    fps=40)                    # frame rate of the video from which keypoints were inferred

stats_df

In [None]:
import os
save_dir = os.path.join(project_dir, model_name)
stats_df.to_csv(os.path.join(save_dir, 'stats_df.csv'), index=False)
print('Saved `stats_df` dataframe to', save_dir)

In [None]:
from matplotlib.lines import Line2D
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
def plot_syll_stats_with_sem(
    stats_df,
    project_dir,
    model_name,
    save_dir=None,
    plot_sig=True,
    thresh=0.05,
    stat="frequency",
    order="stat",
    groups=None,
    ctrl_group=None,
    exp_group=None,
    colors=None,
    join=False,
    figsize=(8, 4),
):
    """Plot syllable statistics with standard error of the mean.

    Besides drawing the figure, the plotted means and error-bar limits are
    read back from the axes and written, together with a significance flag
    and the adjusted p-value of each syllable, to
    `{project_dir}/{model_name}/{stat}_{order}_stats.csv`. Rows of that file
    follow the plotted syllable order; the `syllable` column gives the
    syllable of each row.

    Parameters
    ----------
    stats_df : pandas.DataFrame
        the dataframe that contains kinematic data and the syllable label
    project_dir : str
        the project directory
    model_name : str
        the model directory name
    save_dir : str
        the path to save the analysis plots
    plot_sig : bool, optional
        whether to plot the significant syllables, by default True
    thresh : float, optional
        the threshold for significance, by default 0.05
    stat : str, optional
        the statistic to plot, by default 'frequency'
    order : str, optional
        the ordering of the syllables ('stat' or 'diff'), by default 'stat'
    groups : list, optional
        the list of groups to plot, by default None
    ctrl_group : str, optional
        the control group, by default None
    exp_group : str, optional
        the experimental group, by default None
    colors : list, optional
        the list of colors to use for each group, by default None
    join : bool, optional
        accepted for signature compatibility; not used by this function
    figsize : tuple, optional
        the figure size, by default (8, 4)

    Returns
    -------
    fig : matplotlib.figure.Figure
        the figure object
    legend : matplotlib.legend.Legend
        the legend object
    """

    # get significant syllables
    sig_sylls = None
    df_k_real = None  # stays None when no statistical test is run

    if plot_sig and len(stats_df["group"].unique()) > 1:
        # run kruskal wallis and dunn's test
        df_k_real, _, sig_pairs = kpms.run_kruskal(stats_df, statistic=stat, thresh=thresh)
        # plot significant syllables for control and experimental group
        if ctrl_group is not None and exp_group is not None:
            # the pair may be stored in either order
            if (ctrl_group, exp_group) in sig_pairs:
                sig_sylls = sig_pairs.get((ctrl_group, exp_group))
            else:
                sig_sylls = sig_pairs.get((exp_group, ctrl_group))
        else:
            print(
                "No control or experimental group specified. Not plotting significant syllables."
            )

    xlabel = f"Syllables sorted by {stat}"
    if order == "diff":
        xlabel += " difference"
    ordering, groups, colors, figsize = _validate_and_order_syll_stats_params(
        stats_df,
        stat=stat,
        order=order,
        groups=groups,
        ctrl_group=ctrl_group,
        exp_group=exp_group,
        colors=colors,
        figsize=figsize,
    )

    fig, ax = plt.subplots(1, 1, figsize=figsize)

    print(stat)
    # plot each group's stat data separately, computes groupwise SEM, and orders data based on the stat/ordering parameters
    hue = "group" if groups is not None else None
    ax = sns.pointplot(
        data=stats_df,
        x="syllable",
        y=stat,
        hue=hue,
        order=ordering,
        errorbar=("se"),
        ax=ax,
        hue_order=groups,
        palette=colors,
    )

    # extract data from plot
    # NOTE(review): this relies on the artist layout produced by sns.pointplot
    # (one line of means followed by one error-bar line per syllable, for each
    # of exactly two groups) — confirm against the installed seaborn version.
    lines = ax.get_lines()
    n_sylls = len(lines[0].get_xdata())

    def _read_group(first_line):
        # means, lower and upper error-bar limits for the group whose mean
        # line sits at index `first_line`
        means = lines[first_line].get_ydata()
        bars = [lines[first_line + 1 + k].get_ydata() for k in range(n_sylls)]
        return means, [bar[0] for bar in bars], [bar[1] for bar in bars]

    # The first plotted group is written to the KO_* columns and the second to
    # the WT_* columns; presumably these follow the order of `groups` — verify.
    KO_freq, KO_low, KO_high = _read_group(0)
    WT_freq, WT_low, WT_high = _read_group(n_sylls + 1)

    # Position k on the x-axis shows syllable ordering[k], so significance and
    # p-values are looked up by syllable rather than by plot position.
    plotted_sylls = list(ordering)
    sig_set = set() if sig_sylls is None else set(sig_sylls)
    if_sig = [s in sig_set for s in plotted_sylls]
    if df_k_real is not None:
        # assumes df_k_real is indexed by syllable, as the significant
        # syllables returned by run_kruskal are — TODO confirm
        p_values = df_k_real["p_adj"].reindex(plotted_sylls).to_numpy()
    else:
        p_values = [float("nan")] * len(plotted_sylls)

    # save to data frame
    df = {
        'WT_frequency':WT_freq,
        'WT_error_high':WT_high,
        'WT_error_low':WT_low,
        'KO_frequency':KO_freq,
        'KO_error_high':KO_high,
        'KO_error_low':KO_low,
        'significancy':if_sig,
        'p_value':p_values,
        'syllable':plotted_sylls,
    }
    df=pd.DataFrame(df)
    save_name = f"{stat}_{order}_stats.csv"
    save_path = os.path.join(project_dir, model_name,save_name)
    df.to_csv(save_path)
    # where some data has already been plotted to ax
    handles, _ = ax.get_legend_handles_labels()

    # add syllable labels if they exist
    syll_names = kpms.get_syllable_names(project_dir, model_name, ordering)
    plt.xticks(range(len(syll_names)), syll_names, rotation=90)

    # if a list of significant syllables is given, mark the syllables above the x-axis
    if sig_sylls is not None:
        markings = [ordering.index(s) for s in sig_sylls]
        plt.scatter(markings, [-0.005] * len(markings), color="r", marker="*")

        # manually define a new patch
        patch = Line2D(
            [],
            [],
            color="red",
            marker="*",
            markersize=9,
            label="Significant Syllable",
        )
        handles.append(patch)

    # add legend and axis labels
    legend = ax.legend(handles=handles, frameon=False, bbox_to_anchor=(1, 1))
    plt.xlabel(xlabel, fontsize=12)
    sns.despine()

    # save the figure
    plot_name = f"{stat}_{order}_stats"
    kpms.save_analysis_figure(fig, plot_name, project_dir, model_name, save_dir)
    return fig, legend

In [None]:
import numpy as np
def _validate_and_order_syll_stats_params(
    complete_df,
    stat="frequency",
    order="stat",
    groups=None,
    ctrl_group=None,
    exp_group=None,
    colors=None,
    figsize=(10, 5),
):
    """Validate the plotting options and compute the syllable ordering.

    Parameters
    ----------
    complete_df : pandas.DataFrame
        dataframe with a "group" column and a column named `stat`
    stat : str, optional
        the statistic to order by, by default 'frequency'
    order : str, optional
        'stat' to sort by the statistic, 'diff' to sort by the difference
        between `exp_group` and `ctrl_group`, by default 'stat'
    groups : list or str, optional
        groups to plot; falls back to all groups in `complete_df` when empty
        or when it names a group that is not present
    ctrl_group : str, optional
        the control group, required when `order` is 'diff'
    exp_group : str, optional
        the experimental group, required when `order` is 'diff'
    colors : list, optional
        one color per group; replaced by a seaborn palette when missing or
        of the wrong length
    figsize : tuple, optional
        the figure size; reset to (10, 5) when it is not a 2-element
        tuple/list

    Returns
    -------
    ordering, groups, colors, figsize

    Raises
    ------
    ValueError
        if `stat` is not a column of `complete_df`, if `order` is neither
        'stat' nor 'diff', or if `order` is 'diff' and the control or
        experimental group is missing.
    """
    if not isinstance(figsize, (tuple, list)) or len(figsize) != 2:
        print(
            "Invalid figsize. Input a integer-tuple or list of len(figsize) = 2. Setting figsize to (10, 5)"
        )
        figsize = (10, 5)

    unique_groups = complete_df["group"].unique()

    if groups is None or len(groups) == 0:
        groups = unique_groups
    elif isinstance(groups, str):
        groups = [groups]

    # fall back to every group when an unknown group was requested
    if isinstance(groups, (list, tuple, np.ndarray)):
        diff = set(groups) - set(unique_groups)
        if len(diff) > 0:
            groups = unique_groups

    if stat.lower() not in complete_df.columns:
        raise ValueError(
            f"Invalid stat entered: {stat}. Must be a column in the supplied dataframe."
        )

    if order == "stat":
        ordering, _ = kpms.sort_syllables_by_stat(complete_df, stat=stat)
    elif order == "diff":
        if (
            ctrl_group is None
            or exp_group is None
            or not np.all(np.isin([ctrl_group, exp_group], groups))
        ):
            raise ValueError(
                f"Attempting to sort by {stat} differences, but {ctrl_group} or {exp_group} not in {groups}."
            )
        ordering = kpms.sort_syllables_by_stat_difference(
            complete_df, ctrl_group, exp_group, stat=stat
        )
    else:
        # without this branch `ordering` would be unbound at the return below
        raise ValueError(
            f"Invalid order entered: {order}. Must be 'stat' or 'diff'."
        )

    if colors is None:
        colors = []
    if len(colors) == 0 or len(colors) != len(groups):
        colors = sns.color_palette(n_colors=len(groups))

    return ordering, groups, colors, figsize

In [None]:
import os
fig,legend=plot_syll_stats_with_sem(
    stats_df, project_dir, model_name,
    plot_sig=True,    # whether to mark statistical significance with a star
    thresh=0.05,      # significance threshold
    stat='frequency', # statistic to be plotted (e.g. 'duration' or 'velocity_px_s_mean')
    order='stat',  # order syllables by overall frequency ("stat") or degree of difference ("diff")
    ctrl_group='Ctrl',   # name of the control group for statistical testing
    exp_group='cKO',     # name of the experimental group for statistical testing
    colors = [(0,0,0), (237/255,28/255,36/255) ],# one RGB tuple per group (black and red)
    figsize=(8, 6),   # figure size
    groups=stats_df['group'].unique()
);
# earlier settings, kept for reference: ctrl_group='WT', exp_group='cntnap'
# save copies of the figure as SVG and PNG in the model's figures folder
savefigpath = os.path.join(project_dir, model_name, 'figures')
fig.tight_layout()
fig.savefig(os.path.join(savefigpath,'freq_stat_stats.svg'))
fig.savefig(os.path.join(savefigpath,'freq_stat_stats.png'))

In [None]:
# sex difference
# NOTE(review): `stats_df_male` is not created by any cell visible in this
# notebook, and `groups` is taken from `stats_df` rather than from
# `stats_df_male` — confirm both are intended. The group names below end in
# `_F` although the dataframe is named `_male`.
kpms.plot_syll_stats_with_sem(
    stats_df_male, project_dir, model_name,
    plot_sig=True,    # whether to mark statistical significance with a star
    thresh=0.05,      # significance threshold
    stat='frequency', # statistic to be plotted (e.g. 'duration' or 'velocity_px_s_mean')
    order='stat',     # order syllables by overall frequency ("stat") or degree of difference ("diff")
    ctrl_group='WT_F',   # name of the control group for statistical testing
    exp_group='Mut_F',    # name of the experimental group for statistical testing
    figsize=(8,8),   # figure size    
    groups=stats_df['group'].unique()
);

In [None]:
# use a temp stats_df, replace groups with different animal ID, then plot individually
# The animal ID is the slice [27:32] of each recording name. The slice is taken
# row by row with the .str accessor, so it does not depend on stats_df having
# a default 0..n-1 index.
subject = stats_df['name'].str[27:32].tolist()

sub_stats_df = stats_df.drop(columns=['group'])
sub_stats_df['group'] = subject
sub_stats_df

In [None]:

def visualize_transition_bigram_svg(
    project_dir,
    model_name,
    group,
    trans_mats,
    syll_include,
    save_dir=None,
    normalize="bigram",
    figsize=(12, 6),
    show_syllable_names=True,
):
    """Visualize the transition matrices for each group.

    One heat map is drawn per group, all on a common color scale, and the
    figure is saved as `transition_matrices.[pdf/png/svg]`.

    Parameters
    ----------
    project_dir : str
        the project directory
    model_name : str
        the model directory name
    group : list or np.ndarray
        the groups in the project
    trans_mats : list
        the list of transition matrices for each group
    syll_include : list or np.ndarray
        the syllables shown on the axes, in the row/column order of the
        matrices
    save_dir : str, optional
        directory for the figures; `{project_dir}/{model_name}/figures`
        when None
    normalize : str, optional
        the method to normalize the transition matrix, by default 'bigram'
    figsize : tuple, optional
        the figure size, by default (12,6)
    show_syllable_names : bool, optional
        whether to show just syllable indexes (False) or syllable indexes and
        names (True)
    """
    if show_syllable_names:
        # looked up through kpms, like the other plotting helpers in this notebook
        syll_names = kpms.get_syllable_names(project_dir, model_name, syll_include)
    else:
        syll_names = [f"{ix}" for ix in syll_include]

    # infer max_syllables
    max_syllables = trans_mats[0].shape[0]
    tick_positions = np.arange(len(syll_include))

    fig, ax = plt.subplots(
        1, len(group), figsize=figsize, sharex=False, sharey=True
    )
    title_map = dict(bigram="Bigram", columns="Incoming", rows="Outgoing")
    # shared upper color limit so the groups are directly comparable
    color_lim = max([x.max() for x in trans_mats])
    if len(group) == 1:
        axs = [ax]
    else:
        axs = ax.flat
    for i, g in enumerate(group):
        h = axs[i].imshow(
            trans_mats[i][:max_syllables, :max_syllables],
            cmap="cubehelix",
            vmax=color_lim,
        )
        if i == 0:
            # the y-axis is shared, so labelling the first panel is enough
            axs[i].set_ylabel("Incoming syllable")
            axs[i].set_yticks(tick_positions, syll_names)
        cb = fig.colorbar(h, ax=axs[i], fraction=0.046, pad=0.04)
        cb.set_label(f"{title_map[normalize]} transition probability")
        axs[i].set_xlabel("Outgoing syllable")
        axs[i].set_title(g)
        axs[i].set_xticks(tick_positions, syll_names, rotation=90)

    # saving the figure
    if save_dir is None:
        save_dir = os.path.join(project_dir, model_name, "figures")
    os.makedirs(save_dir, exist_ok=True)

    for ext in ("pdf", "png", "svg"):
        fig.savefig(os.path.join(save_dir, f"transition_matrices.{ext}"))

In [None]:
normalize = 'bigram'  # normalization method ("bigram", "rows" or "columns")

trans_mats, usages, groups, syll_include = kpms.generate_transition_matrices(
    project_dir,
    model_name,
    normalize=normalize,
    min_frequency=0.005,  # minimum syllable frequency to include
)
print(groups)

# rearrange trans_mats in the similarity groups: permute the rows, then the
# columns, of every group's matrix with syll_sim
new_syll_include = syll_sim
new_trans_mats = [mat[syll_sim][:, syll_sim] for mat in trans_mats]

visualize_transition_bigram_svg(
    project_dir,
    model_name,
    groups,
    new_trans_mats,
    new_syll_include,
    normalize=normalize,
    show_syllable_names=False,  # label syllables by index (False) or index and name (True)
)

In [None]:
# The labels in new_syll_include are in similarity order, so the matrices and
# usages passed with them must be permuted the same way; otherwise each node
# would carry another syllable's label.
# (assumes kpms labels the i-th row/usage with the i-th entry of the syllable
# list — TODO confirm)
kpms.plot_transition_graph_group(
    project_dir, model_name,
    groups,
    new_trans_mats,
    [np.asarray(u)[syll_sim] for u in usages],
    new_syll_include,
    layout='circular',        # transition graph layout ("circular" or "spring")
    show_syllable_names=False # label syllables by index (False) or index and name (True)
)


In [None]:
# Matrices and usages are passed in the same similarity order as the labels in
# new_syll_include so that every node keeps its own syllable label.
# (assumes kpms labels the i-th row/usage with the i-th entry of the syllable
# list — TODO confirm)
kpms.plot_transition_graph_difference(
    project_dir, model_name,
    groups,
    new_trans_mats,
    [np.asarray(u)[syll_sim] for u in usages],
    new_syll_include,
    layout='circular',  # transition graph layout ("circular" or "spring")
    show_syllable_names=False,
)

In [None]:
# Matrices and usages are passed in the same similarity order as the labels in
# new_syll_include so that every node keeps its own syllable label.
# (assumes kpms labels the i-th row/usage with the i-th entry of the syllable
# list — TODO confirm)
kpms.plot_transition_graph_difference(
    project_dir, model_name,
    groups,
    new_trans_mats,
    [np.asarray(u)[syll_sim] for u in usages],
    new_syll_include,
    layout='circular',
    show_syllable_names=False,
)

In [None]:
# make a new index_data containing subject ID 
import os
index_file = os.path.join(project_dir, "index.csv")
if not os.path.exists(index_file):
    # called through kpms like the other keypoint-MoSeq helpers in this
    # notebook; no cell here defines a bare `generate_index`
    kpms.generate_index(project_dir, model_name, index_file)

index_data = pd.read_csv(index_file, index_col=False)

# Use each recording name as its own group, so that group-wise functions
# produce one result per recording.
sub_index_data = pd.DataFrame({
    'name': index_data['name'],
    'group': index_data['name']
})
sub_index_file = os.path.join(project_dir, "sub_index.csv")
sub_index_data.to_csv(sub_index_file)

In [None]:
# function to calculate entropy
# also considering usage? (ref, Datta, 2015 paper)
def get_entropy(transition_matrix, usage, epsilon):
    """Return the usage-weighted entropy of a 2D transition matrix.

    Zero and NaN entries of `transition_matrix` are replaced by `epsilon`
    before the logarithm is taken.

    Returns
    -------
    entropy : float
        `-sum(usage * P * log2(P))`, with `usage` broadcast against `P`.
    entropy_ind : np.ndarray
        array of the same shape as `P` whose entry (i, j) is
        `-usage[i] * P[i, j] * log2(P[i, j])`.

    NOTE(review): for a 1-D `usage`, broadcasting weights column j by
    `usage[j]` in `entropy`, whereas `entropy_ind` weights row i by
    `usage[i]`; the two agree only in special cases — confirm which
    weighting is intended.
    """
    # zeros and NaNs would break log2; substitute epsilon for both
    probs = np.where(transition_matrix == 0, epsilon, transition_matrix)
    probs = np.where(np.isnan(probs), epsilon, probs)
    log_probs = np.log2(probs)

    # total entropy
    entropy = -np.sum(usage * probs * log_probs)

    # individual contribution of every transition, weighted by the usage of its row
    n_rows, n_cols = probs.shape[0], probs.shape[1]
    row_weights = -np.asarray(usage)
    entropy_ind = np.zeros((n_rows, n_cols))
    entropy_ind[:, :] = row_weights[:, None] * probs * log_probs
    return entropy, entropy_ind

# test_mat = np.full((25,25), 1/25)
# usage = np.full((1,25), 1/25)
# ent, ent_ind = get_entropy(test_mat, usage, 1e-12)
# print(ent)
# print(ent_ind)

In [None]:
# calculate entropy for the whole session
from keypoint_moseq.io import load_results
from jax_moseq.utils import get_durations, get_frequencies
import numpy as np

normalize = 'columns'  # normalization method ("bigram", "rows" or "columns")
min_frequency = 0.005
epsilon = 1e-12

# sub_index_data uses the recording name as its "group", so each "group"
# handled below corresponds to a single animal.
label_group = list(sub_index_data.group.values)
recordings = list(sub_index_data.name.values)
group = sorted(list(sub_index_data.group.unique()))

# load model results
save_results_path = os.path.join(project_dir, model_name, 'results.h5')
results_dict = load_results(path=save_results_path)

# filter out syllables by frequency (using overall results!)
model_labels = [results_dict[recording]["syllable"] for recording in recordings]
frequencies = get_frequencies(model_labels)
syll_include = np.where(frequencies > min_frequency)[0]

trans_mats, usage = kpms.get_group_trans_mats(
    model_labels,
    label_group,
    group,
    syll_include=syll_include,
    normalize=normalize,
)
groups = group

# syllable ordering by similarity (defined in an earlier cell)
new_syll_include = syll_sim
new_trans_mats = []

# Per-pair entropies have the shape of one transition matrix
# (syllables x syllables), stacked along a third axis with one slice per animal.
x, y = np.shape(trans_mats[0])
n_animals = len(sub_index_data)
entropyWhole = np.zeros((n_animals, 1))
entropy_ind = np.zeros((x, y, n_animals))

# trans_mats / usage follow the sorted order of `group`, whereas every
# consumer of entropyWhole / entropy_ind pairs the results with
# sub_index_data / index_data rows. Write each result into the row of the
# recording it belongs to, so the two orders can never get out of step.
row_of_recording = {}
for row, rec_name in enumerate(label_group):
    row_of_recording.setdefault(rec_name, row)

for ii, rec_name in enumerate(group):
    row = row_of_recording[rec_name]
    entropyWhole[row], entropy_ind[:, :, row] = get_entropy(trans_mats[ii], usage[ii], epsilon)

In [None]:
# save entropy: one whole-session value per animal
save_name = "entropy_wholesession.csv"
save_path = os.path.join(project_dir, model_name, save_name)
data = {'animal': sub_index_data['name'], 'entropy': entropyWhole[:, 0]}
dataDF = pd.DataFrame(data)
dataDF.to_csv(save_path)

# save entropy for individual syllable pairs:
# flatten the two matrix axes into rows and keep one column per animal
x, y, z = entropy_ind.shape
pair_entropy = entropy_ind.reshape(x * y, z)
data = {'animal': sub_index_data['name'], 'entropy': pair_entropy}
dataDF = pd.DataFrame(pair_entropy, columns=sub_index_data['name'])
save_name = "entropy_pairs_wholesession.csv"
save_path = os.path.join(project_dir, model_name, save_name)
dataDF.to_csv(save_path)
# save entropy for individual syllable pairs

In [None]:
# MGRPR data
# Compare whole-session transition entropy between the 'C' and 'K' groups
# (Mann-Whitney U test) and save a violin plot.
import matplotlib.pyplot as plt
import os
from scipy.stats import mannwhitneyu

savefigpath = os.path.join(project_dir, model_name, 'figures')
# savefig does not create missing directories
os.makedirs(savefigpath, exist_ok=True)

fig, ax = plt.subplots()
fig.suptitle("Entropy of transitions - whole session")
# make the plot
WTColor = (255 / 255, 189 / 255, 53 / 255)
MutColor = (63 / 255, 167 / 255, 150 / 255)
colors = [WTColor, MutColor]

# rows of entropyWhole are paired with the rows of index_data
mat1 = entropyWhole[index_data['group'] == 'C'][:, 0]
mat2 = entropyWhole[index_data['group'] == 'K'][:, 0]
statistic, p = mannwhitneyu(mat1, mat2)
print(p)
ax.violinplot([mat1, mat2], showmedians=True)
vp = ax.set_xticks([1, 2], ['Ctrl', 'cKO'])

fig.savefig(os.path.join(savefigpath, 'Entropy whole session.png'))
fig.savefig(os.path.join(savefigpath, 'Entropy whole session.svg'))

In [None]:
def get_group_trans_mats_custom(labels, label_group, group, syll_include, normalize="bigram"):
    """Get the transition matrices for each group.

    Parameters
    ----------
    labels : list or np.ndarray
        recording state lists
    label_group : list or np.ndarray
        the group labels for each recording
    group : list or np.ndarray
        the groups in the project
    syll_include : list or np.ndarray
        indices of the syllables to keep; used to subset both the rows and
        the columns of each transition matrix and the frequency vector
    normalize : str, optional
        the method to normalize the transition matrix, by default 'bigram'

    Returns
    -------
    trans_mats : list
        the list of transition matrices for each group
    frequencies : list
        the list of syllable frequencies for each group
    """
    # imported after the docstring so the string above really is the docstring
    from jax_moseq.utils import get_frequencies

    trans_mats = []
    frequencies = []

    # Computing transition matrices for each given group
    for plt_group in group:
        # list of syll labels in recordings in the group
        use_labels = [lbl for lbl, grp in zip(labels, label_group) if grp == plt_group]

        # transition matrix over all syllables, then subset only syllables included
        group_trans_mat = kpms.get_transition_matrix(
            use_labels, normalize=normalize, combine=True
        )
        trans_mats.append(group_trans_mat[syll_include, :][:, syll_include])

        # Getting frequency information for node scaling
        frequencies.append(get_frequencies(use_labels)[syll_include])
    return trans_mats, frequencies

In [None]:
def bootstrap(data, dim, dim0, n_sample=1000):
    """Bootstrap the median of a 2D data matrix along one axis.

    Parameters
    ----------
    data : array-like
        Data matrix to resample. The reductions below assume a 2D input that
        is resampled along its second axis.
    dim : int
        Axis to resample with replacement (1 for the layout described above).
    dim0 : int
        Length of the untouched axis (``data.shape[0]``); only used to size
        the all-NaN result returned for empty input.
    n_sample : int, optional
        Number of bootstrap resamples, by default 1000.

    Returns
    -------
    pd.DataFrame
        Columns ``bootAve`` (median over all resampled values), ``bootHigh``
        and ``bootLow`` (97.5th / 2.5th percentile of the per-resample
        medians), one row per entry of the untouched axis.
    """
    data = np.asarray(data)

    if data.size > 0:  # if input data is not empty
        n_obs = data.shape[dim]
        # each row of indices is one resample (with replacement) of the axis
        bootstrap_indices = np.random.choice(n_obs, size=(n_sample, n_obs), replace=True)
        # resulting layout: (untouched axis, resample, observation)
        bootstrapped_matrix = np.take(data, bootstrap_indices, axis=dim)
        # median of every individual resample
        medianBoot = np.nanmedian(bootstrapped_matrix, 2)
        bootAve = np.nanmedian(bootstrapped_matrix, axis=(1, 2))
        bootHigh = np.nanpercentile(medianBoot, 97.5, axis=1)
        bootLow = np.nanpercentile(medianBoot, 2.5, axis=1)
    else:  # return nans
        bootAve = np.full(dim0, np.nan)
        bootLow = np.full(dim0, np.nan)
        bootHigh = np.full(dim0, np.nan)

    # need to find a way to output raw bootstrap results
    tempData = {'bootAve': bootAve, 'bootHigh': bootHigh, 'bootLow': bootLow}
    bootRes = pd.DataFrame(tempData, index=np.arange(len(bootAve)))

    return bootRes

In [None]:
# examine grooming bouts
from itertools import groupby

fps = 40
frame_idx = np.unique(moseq_df['frame_index'])
t_step = 5 * 60 * fps  # number of frames in a 5-minute window
num_window = int(np.ceil(len(frame_idx) / t_step))

sessions = sub_index_data['name']

grooming_syl = [4, 11, 12, 13, 18, 26]
total_grooming_time = np.zeros(len(sessions))
total_grooming_bouts = np.zeros(len(sessions))
grooming_bouts_duration = {}  # a dictionary to store the duration
# NOTE: the 5-minute-window arrays are allocated here but not filled in this cell
grooming_time_5window = np.zeros((len(sessions), num_window))
grooming_bouts_5window = np.zeros((len(sessions), num_window))


def _grooming_bout_lengths(syllables, grooming_syllables):
    """Return the length (in frames) of every run of consecutive grooming syllables."""
    is_grooming = np.isin(np.asarray(syllables), grooming_syllables).tolist()
    return [sum(1 for _ in run) for flag, run in groupby(is_grooming) if flag]


for idx, ses in enumerate(sessions):
    # A bout is a maximal run of consecutive frames labelled with a grooming
    # syllable. Counting whole runs gives every bout (including the last one
    # and single-frame bouts) its exact length.
    bout_lengths = _grooming_bout_lengths(results[ses]['syllable'], grooming_syl)
    total_grooming_time[idx] = sum(bout_lengths) / fps  # seconds
    total_grooming_bouts[idx] = len(bout_lengths)
    grooming_bouts_duration[ses] = [n_frames / fps for n_frames in bout_lengths]

print(total_grooming_time)
print(total_grooming_bouts)
data = {
    'animal': sub_index_data['name'],
    'grooming_time': total_grooming_time,
    'grooming_bouts': total_grooming_bouts
}
dataDF = pd.DataFrame(data)
save_name = "grooming.csv"
save_path = os.path.join(project_dir, model_name, save_name)
dataDF.to_csv(save_path)
# bout durations (s) of the last session, as a quick sanity check
print(grooming_bouts_duration[ses])

In [None]:
## MGRPR
# Compare total grooming time between the 'C' and 'K' groups
# (Mann-Whitney U test) and save a violin plot.

# make plots
import os
from scipy.stats import mannwhitneyu
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
fig.suptitle("Total grooming time - whole session")
# make the plot
colors = [WTColor, MutColor]

# entries of total_grooming_time are paired with the rows of index_data
mat1 = total_grooming_time[index_data['group'] == 'C']
mat2 = total_grooming_time[index_data['group'] == 'K']
statistic, p = mannwhitneyu(mat1, mat2)
print(p)
ax.violinplot([mat1, mat2], showmedians=True)
vp = ax.set_xticks([1, 2], ['Ctrl', 'cKO'])

savefigpath = os.path.join(project_dir, model_name, 'figures')
# savefig does not create missing directories
os.makedirs(savefigpath, exist_ok=True)
fig.savefig(os.path.join(savefigpath, 'Total grooming time.png'))
fig.savefig(os.path.join(savefigpath, 'total grooming time.svg'))

# 3. Decoding analysis (Fig. 5P)

In [ ]:
# decode the label from moseq and movement data
# build the decoder on sessions? or different time frame of the sessions
import os
import numpy as np
import pandas as pd
from keypoint_moseq.io import load_results
from jax_moseq.utils import get_durations, get_frequencies

normalize = 'columns'  # normalization method ("bigram", "rows" or "columns")
min_frequency = 0.005

subject_geno_sex = index_data['group']

# sub_index_data uses the recording name as its "group": one matrix per animal
label_group = list(sub_index_data.group.values)
recordings = list(sub_index_data.name.values)
# Keep the animals in index-file order (not sorted) so that row i of
# trans_mats / usage belongs to the same animal as decoding_label[i] and
# movement[i], which are both built from index_data.
group = list(dict.fromkeys(label_group))

# load model results
save_results_path = os.path.join(project_dir, model_name, 'results.h5')
results_dict = load_results(path=save_results_path)

# filter out syllables by frequency (using overall results!)
model_labels = [results_dict[recording]["syllable"] for recording in recordings]
frequencies = get_frequencies(model_labels)
syll_include = np.where(frequencies > min_frequency)[0]

trans_mats, usage = kpms.get_group_trans_mats(
    model_labels,
    label_group,
    group,
    syll_include=syll_include,
    normalize=normalize,
)
trans_array = np.array(trans_mats)
n_subjects = len(usage)

# NOTE(review): the decoding cells below read `transitions`, which is not
# assigned in this cell. The sketch kept here suggests one flattened
# transition matrix per subject - confirm where `transitions` is defined.
#transition = np.full((n_subjects,  trans_array.shape[1]*trans_array.shape[2]),0)
#for i in range(n_subjects):
#    transition[i, :] = trans_mats[i].flatten()

animals = index_data['name']
move_data = ['centroid_x', 'centroid_y', 'heading', 'angular_velocity', 'velocity_px_s']
max_length = 150000

# One row per animal; each movement feature occupies a fixed slot of
# max_length samples, zero-padded at the end. The array must be float:
# an integer array would silently truncate every value written into it.
movement = np.zeros((len(animals), len(move_data) * max_length), dtype=float)
for aidx, aa in enumerate(animals):
    animal_rows = moseq_df['name'] == aa
    for idx, move in enumerate(move_data):
        # clip to the slot length so a long recording cannot spill into the
        # slot of the next feature
        values = moseq_df.loc[animal_rows, move].to_numpy(dtype=float)[:max_length]
        # NOTE(review): missing samples are zero-filled, like the padding;
        # confirm this is the intended treatment of NaNs for the decoder
        values = np.where(np.isnan(values), 0.0, values)
        start = idx * max_length
        movement[aidx, start:start + len(values)] = values

# decoding only 'WT' and 'KO'
decoding_label = []
for gg in subject_geno_sex:
    print(gg)
    if 'C' in gg:
        decoding_label.append('WT')
    elif 'K' in gg:
        decoding_label.append('KO')
    else:
        # skipping the animal would shift every following label by one row
        raise ValueError(f"Unrecognised group label {gg!r}: expected it to contain 'C' or 'K'")

#print(decoding_label)
#print(transition.shape)

In [ ]:
def find_nans(matrix):
    """Return the ``(row, column)`` positions of all NaN entries in ``matrix``.

    Parameters
    ----------
    matrix : iterable of iterables
        A 2D array or nested sequence. Entries that are not floating-point
        numbers (ints, strings, ...) are ignored.

    Returns
    -------
    list of tuple
        ``(i, j)`` index pairs in row-major order; empty if there is no NaN.
    """
    nan_indices = []
    for i, row in enumerate(matrix):
        for j, value in enumerate(row):
            # np.floating covers float16/float32 scalars, which are not
            # subclasses of the built-in float (only np.float64 is)
            if isinstance(value, (float, np.floating)) and np.isnan(value):
                nan_indices.append((i, j))
    return nan_indices

In [ ]:
# remove animal 8 for too many nans
# boolean mask that is False only at position 8
Index_to_keep = np.arange(len(decoding_label)) != 8
print(Index_to_keep)
# the same selection applied to the label list
decoding_label_new = [label for pos, label in enumerate(decoding_label) if pos != 8]
print(decoding_label_new)

In [ ]:

# run 20 times to get an average
import os
import pandas as pd

n_repeats = 20
classifier = 'RandomForest'

# every feature set is decoded with the same labels, classifier,
# seed (42) and number of repeats
shared_args = (decoding_label, classifier, 42, n_repeats)

decoder_transition = run_decoder(transitions, *shared_args)
print(decoder_transition)
print('Decode transition Done')

decoder_usage = run_decoder(usage, *shared_args)
print(decoder_usage)
print('Decode usage Done')

decoder_movement = run_decoder(movement, *shared_args)
print(decoder_movement)


In [ ]:
import random

n_repeats = 20

# positions of the WT and KO animals in decoding_label
animalIndex = np.arange(len(decoding_label))
label_array = np.array(decoding_label)
WTIdx = animalIndex[label_array == 'WT']
KOIdx = animalIndex[label_array == 'KO']


def _empty_scores():
    """Return fresh per-repeat accuracy arrays for one feature set."""
    return {'accuracy': np.zeros((n_repeats)), 'ctrl_accuracy': np.zeros((n_repeats))}


decode_transition = _empty_scores()
decode_usage = _empty_scores()
decode_movement = _empty_scores()

# draw as many WT animals as there are KO animals to balance the classes
WTIdx_touse = random.sample(list(WTIdx), len(KOIdx))
idx_touse = sorted(WTIdx_touse + list(KOIdx))
#print(decoding_label[idx_touse])


In [ ]:
def _record_scores(store, result, repeat):
    """Copy the single-repeat accuracies of `result` into slot `repeat` of `store`."""
    store['accuracy'][repeat] = result['accuracy'][0]
    store['ctrl_accuracy'][repeat] = result['ctrl_accuracy'][0]


# each repeat draws a new class-balanced subset of animals (seeded by the
# repeat number) and decodes it once per feature set
for rr in range(n_repeats):
    print(rr)
    random.seed(rr)
    WTIdx_touse = random.sample(list(WTIdx), len(KOIdx))
    idx_touse = sorted(WTIdx_touse + list(KOIdx))
    print(idx_touse)
    new_label = [decoding_label[pos] for pos in idx_touse]
    new_usage = [usage[pos] for pos in idx_touse]

    decoder_transition = run_decoder(transitions[idx_touse, :], new_label, classifier, rr, 1)
    _record_scores(decode_transition, decoder_transition, rr)

    decoder_usage = run_decoder(new_usage, new_label, classifier, rr, 1)
    _record_scores(decode_usage, decoder_usage, rr)

    decoder_movement = run_decoder(movement[idx_touse, :], new_label, classifier, rr, 1)
    _record_scores(decode_movement, decoder_movement, rr)


In [ ]:
# save the data: one row per repeat, one accuracy/control column pair per feature set
savefile = os.path.join(project_dir, model_name, 'decode_results_RF.csv')

data = dict(
    transition_decode_accuracy=decode_transition['accuracy'],
    transition_decode_control=decode_transition['ctrl_accuracy'],
    usage_decode_accuracy=decode_usage['accuracy'],
    usage_decode_control=decode_usage['ctrl_accuracy'],
    movement_decode_accuracy=decode_movement['accuracy'],
    movement_decode_ctrl=decode_movement['ctrl_accuracy'],
)
df = pd.DataFrame(data)
df.to_csv(savefile)

In [ ]:
# Box plot of the per-repeat decoding accuracy for each feature set.
import os
import seaborn as sns
import matplotlib.pyplot as plt

# The style has to be set before the axes are created: seaborn styles are
# applied through rcParams and do not restyle an existing plot.
sns.set(style="ticks")

data = [decode_transition['accuracy'],
        decode_usage['accuracy'],
        decode_movement['accuracy']]

# draw on an explicit figure so the saved file never depends on whichever
# figure happens to be current
fig, ax = plt.subplots()
sns.boxplot(data=data, ax=ax)

ax.set_xticklabels(['Transition', 'Usage', 'Movement'])
ax.set_ylabel('Decoding accuracy')

savefigpath = os.path.join(project_dir, model_name, 'figures')
# savefig does not create missing directories
os.makedirs(savefigpath, exist_ok=True)
fig.savefig(os.path.join(savefigpath, 'Decoding accuracy.png'))
fig.savefig(os.path.join(savefigpath, 'Decoding accuracy.svg'))