In [None]:
from pathlib import Path
import pandas as pd
import numpy as np
import seaborn as sns
import pingouin as pg
import matplotlib.pyplot as plt

from typing import List, Tuple, Dict, Optional, Union

In [None]:
def get_only_matching_xlsx_files(dir_path: Path, paradigm_id: str, week_id: Optional[int]=None) -> List[Path]:
    """
    Collect the Excel result files in `dir_path` that belong to a paradigm (and week).

    A file is selected if its name ends with '.xlsx' and contains `paradigm_id`.
    If `week_id` is given, the name must additionally contain 'week-<week_id>.'
    (the trailing dot prevents e.g. week 1 from also matching week 12).

    Args:
        dir_path: directory whose direct children are inspected (no recursion).
        paradigm_id: substring that has to be part of the filename.
        week_id: optional week number; if None, files of all weeks are returned.

    Returns:
        The matching filepaths, sorted so that the result is reproducible.
    """
    week_tag = None if week_id is None else f'week-{week_id}.'
    filtered_filepaths = []
    # iterdir() yields entries in arbitrary, OS-dependent order; sorting keeps
    # everything that is built from this list deterministic across runs.
    for filepath in sorted(dir_path.iterdir()):
        filename = filepath.name
        if filename.startswith('~$'):
            # Owner/lock files that Excel places next to an opened workbook end with
            # '.xlsx' as well, but they are not readable result files.
            continue
        if not filename.endswith('.xlsx') or paradigm_id not in filename:
            continue
        if week_tag is None or week_tag in filename:
            filtered_filepaths.append(filepath)
    return filtered_filepaths

In [None]:
def get_metadata_from_filename(filepath_session_results: Path, filepath_group_assignment: Path=Path('test_data/group_assignment.xlsx')) -> Dict:
    """
    Derive the session metadata from a result filename and look up the subject's group.

    The filename has to consist of exactly four underscore-separated parts:
    '<line_id>_<mouse_id>_<paradigm_id>_<week string>', where the week id is the
    text between the first '-' and the first '.' of the last part
    (e.g. '206_F2-1_OTR_week-12.xlsx' -> week_id '12').

    Args:
        filepath_session_results: path of the exported session results file; only
            its name is evaluated.
        filepath_group_assignment: Excel table with the columns 'subject_id' and
            'group_id' and, optionally, 'alternative_subject_id'.

    Returns:
        Dict with the keys 'line_id', 'paradigm_id', 'subject_id', 'week_id' and
        'group_id'. All values parsed from the filename are strings (including
        'week_id'). 'group_id' is 'unknown' if the subject is not listed in the
        group assignment table.

    Raises:
        ValueError: if the filename does not consist of four underscore-separated parts.
    """
    filename = filepath_session_results.name
    filename_parts = filename.split('_')
    if len(filename_parts) != 4:
        raise ValueError(
            f'Cannot parse metadata from "{filename}": expected exactly four underscore-separated '
            'parts ("<line_id>_<mouse_id>_<paradigm_id>_week-<week_id>.xlsx").'
        )
    line_id, mouse_id, paradigm_id, week_string_with_file_extension = filename_parts
    subject_id = f'{line_id}_{mouse_id}'
    week_id = week_string_with_file_extension[week_string_with_file_extension.find('-') + 1:week_string_with_file_extension.find('.')]
    # Key order is kept stable, since callers may add these entries as columns in this order.
    metadata = {'line_id': line_id,
                'paradigm_id': paradigm_id,
                'subject_id': subject_id,
                'week_id': week_id,
                'group_id': 'unknown'}
    df_group_assignment = pd.read_excel(filepath_group_assignment)
    # The primary id column takes precedence; the alternative one is only consulted
    # if it exists and the primary lookup did not yield a match.
    id_columns = ['subject_id']
    if 'alternative_subject_id' in df_group_assignment.columns:
        id_columns.append('alternative_subject_id')
    for id_column in id_columns:
        matching_group_ids = df_group_assignment.loc[df_group_assignment[id_column] == subject_id, 'group_id']
        if not matching_group_ids.empty:
            metadata['group_id'] = matching_group_ids.iloc[0]
            break
    return metadata

In [None]:
def check_data_availability(root_dir_path: Path, all_week_ids: List[int], all_paradigm_ids: List[str], filepath_group_assignment: Optional[Path]=None) -> pd.DataFrame:
    """
    Create an overview of the weeks for which result files exist per subject and paradigm.

    Args:
        root_dir_path: directory that contains the exported '.xlsx' result files.
        all_week_ids: week ids to check; they become the index of the returned dataframe.
        all_paradigm_ids: paradigm ids to check.
        filepath_group_assignment: optional path of the group assignment table that is
            forwarded to `get_metadata_from_filename()`. If None, that function's
            default path is used.

    Returns:
        Boolean dataframe with one row per week id and one column per
        '<subject_id>_<paradigm_id>' combination for which at least one file was
        found. A cell is True if a matching file exists for that week.
    """
    metadata_kwargs = {}
    if filepath_group_assignment is not None:
        metadata_kwargs['filepath_group_assignment'] = filepath_group_assignment
    # Maps each column name to the weeks with data. A dict keeps the insertion order,
    # so the columns appear in the order in which they were first encountered.
    weeks_with_data_per_column = {}
    for week_id in all_week_ids:
        for paradigm_id in all_paradigm_ids:
            all_matching_filepaths = get_only_matching_xlsx_files(dir_path = root_dir_path, paradigm_id = paradigm_id, week_id = week_id)
            for filepath in all_matching_filepaths:
                subject_id = get_metadata_from_filename(filepath, **metadata_kwargs)['subject_id']
                weeks_with_data_per_column.setdefault(f'{subject_id}_{paradigm_id}', set()).add(week_id)
    # Build the dataframe in a single step instead of inserting column after column.
    availability = {column_name: [week_id in weeks_with_data for week_id in all_week_ids]
                    for column_name, weeks_with_data in weeks_with_data_per_column.items()}
    return pd.DataFrame(availability, index=all_week_ids)

In [None]:
def collect_all_available_data(root_dir_path: Path, paradigm_ids: List[str], week_ids: List[Union[int, str]], sheet_name: str, filepath_group_assignment: Optional[Path]=None) -> pd.DataFrame:
    """
    Load one sheet from all matching result files and stack them into a single dataframe.

    Args:
        root_dir_path: directory that contains the exported '.xlsx' result files.
        paradigm_ids: paradigm ids whose files should be loaded.
        week_ids: week ids whose files should be loaded.
        sheet_name: exact name of the Excel sheet to read from every file.
        filepath_group_assignment: optional path of the group assignment table that is
            forwarded to `get_metadata_from_filename()`. If None, that function's
            default path is used.

    Returns:
        Dataframe with the rows of all loaded sheets and a fresh 0..n-1 index. The
        metadata derived from each filename is added as additional columns.

    Raises:
        ValueError: if no file matches the requested paradigm and week ids.
    """
    metadata_kwargs = {}
    if filepath_group_assignment is not None:
        metadata_kwargs['filepath_group_assignment'] = filepath_group_assignment
    all_recording_results_dfs = []
    for week_id in week_ids:
        for paradigm_id in paradigm_ids:
            matching_filepaths = get_only_matching_xlsx_files(dir_path = root_dir_path, paradigm_id = paradigm_id, week_id = week_id)
            for filepath in matching_filepaths:
                metadata = get_metadata_from_filename(filepath_session_results = filepath, **metadata_kwargs)
                # The context manager closes the workbook again once the sheet is read,
                # so no file handles stay open while looping over many files.
                with pd.ExcelFile(filepath) as workbook:
                    session_df = pd.read_excel(workbook, sheet_name = sheet_name, index_col = 0)
                for key, value in metadata.items():
                    session_df[key] = value
                all_recording_results_dfs.append(session_df)
    if len(all_recording_results_dfs) == 0:
        # pd.concat would raise a ValueError here as well, but without any hint at the cause.
        raise ValueError(f'No matching .xlsx files found in "{root_dir_path}" for '
                         f'paradigm_ids={paradigm_ids} and week_ids={week_ids}.')
    return pd.concat(all_recording_results_dfs, ignore_index = True)

In [None]:
def filter_dataframe(df: pd.DataFrame, filter_criteria: List[Tuple]) -> pd.DataFrame:
    """
    Select the rows of `df` that fulfill all of the given filter criteria.

    Args:
        df: the dataframe to filter; it is not modified.
        filter_criteria: list of (column_name, comparison_method, reference_value)
            tuples. Implemented comparison methods are 'greater', 'smaller',
            'equal_to', 'is_in_list' (reference value has to be list-like) and
            'is_nan' (reference value is ignored, but has to be present).

    Returns:
        A copy of the rows that satisfy every criterion, in their original order.
        With an empty list of criteria, a copy of the entire dataframe is returned.

    Raises:
        ValueError: if a comparison method is not implemented.
        KeyError: if a column name does not exist in `df`.
    """
    comparison_methods = {
        'greater': lambda column, reference_value: column > reference_value,
        'smaller': lambda column, reference_value: column < reference_value,
        'equal_to': lambda column, reference_value: column == reference_value,
        'is_in_list': lambda column, reference_value: column.isin(reference_value),
        'is_nan': lambda column, reference_value: column.isnull(),
    }
    # Positional boolean mask: unlike an intersection of index labels, this also
    # selects the correct rows if the index of `df` contains duplicated labels.
    rows_to_keep = np.ones(df.shape[0], dtype=bool)
    for column_name, comparison_method, reference_value in filter_criteria:
        if comparison_method not in comparison_methods:
            raise ValueError(f'"{comparison_method}" is not an implemented comparison method. '
                             f'Please use one of: {list(comparison_methods.keys())}.')
        if column_name not in df.columns:
            raise KeyError(f'"{column_name}" is not a column of the dataframe.')
        criterion_mask = comparison_methods[comparison_method](df[column_name], reference_value)
        # Comparisons on nullable dtypes can yield <NA>; such rows count as "not matching".
        rows_to_keep &= criterion_mask.fillna(False).to_numpy(dtype=bool)
    return df.loc[rows_to_keep, :].copy()

In [None]:
def plot(df: pd.DataFrame, x_column: str, y_column: str, plot_type: str='violinplot', hue_column: Optional[str]=None, hide_legend: bool=True):
    """
    Plot `y_column` against `x_column` as a violin plot or a strip plot and show the figure.

    Args:
        df: dataframe that holds the data to plot.
        x_column: name of the column shown on the x-axis.
        y_column: name of the column shown on the y-axis.
        plot_type: 'violinplot' or 'stripplot'.
        hue_column: optional name of a column used to color-code the data.
        hide_legend: if True, no legend is displayed. If False and `hue_column` is
            given, the legend is placed to the right of the axes.

    Raises:
        ValueError: if `plot_type` is not one of the implemented plot types.
    """
    implemented_plot_types = ('violinplot', 'stripplot')
    # Validate before creating the figure, so that an invalid request does not
    # leave an empty figure behind.
    if plot_type not in implemented_plot_types:
        raise ValueError(f'"{plot_type}" is not an implemented plot type. '
                         f'Please use one of: {implemented_plot_types}.')
    fig = plt.figure(figsize = (8, 5), facecolor = 'white')
    ax = fig.add_subplot(111)
    # The axes are passed explicitly instead of relying on the "current axes" state.
    if plot_type == 'violinplot':
        sns.violinplot(data = df, x = x_column, y = y_column, hue = hue_column, ax = ax)
        # Individual data points are only overlaid for dataframes with fewer than
        # 2,000 rows (presumably to keep the figure legible and fast to render).
        if df.shape[0] < 2_000:
            sns.stripplot(data = df, x = x_column, y = y_column, hue = hue_column, dodge = True, color = 'black', alpha = 0.3, ax = ax)
    else:
        sns.stripplot(data = df, x = x_column, y = y_column, hue = hue_column, dodge = True, ax = ax)
    if hide_legend:
        # Only a legend that was actually created can (and needs to) be removed.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
    elif hue_column is not None:
        # Without a hue there are no labeled artists, so there is nothing to put in a legend.
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.show()

## Example usage:

### 1) Load all data:

This assumes that you used the "run_2d_gait_analysis_with_top_cam.ipynb" notebook and created the corresponding result exports. You now have to provide the path to the directory that contains these Excel files as `root_dir_path`, as a `Path` object (see example below). You can also specify right away whether you'd like to load only a certain set of weeks or paradigms (you can also just pass a list with a single value for each). Please note that you have to provide the exact name of the Excel tab you are interested in as `sheet_name`. <br>
For instance, if you'd like to inspect the overall session overview, use:

In [None]:
# Load the 'session_overview' sheet of all available OTR / OTT / OTE result files
# for the weeks 1, 8, 12 and 14 and preview the first rows.
df = collect_all_available_data(
    root_dir_path=Path('/mnt/c/Users/dsege/Downloads/DLC_data/22_11_14/analyses/'),
    paradigm_ids=['OTR', 'OTT', 'OTE'],
    week_ids=[1, 8, 12, 14],
    sheet_name='session_overview',
)
df.head()

### 2) Filter your data (optional):

You might want to filter your data to inspect only a specific subset of it. You can do so quite easily by using the `filter_dataframe()` function. <br>
For this, you need to specify your filter criteria in a list of tuples, for which each tuple follows this schema:
> (`column_name`, `comparison_method`, `reference_value`)

So for instance, if you'd like to filter your dataframe by selecting only those rows in which the paradigm_id is "OTE" and in which the mouse line is one of a list (e.g. "206" and "209"), your filter criteria would look like this:

In [None]:
# One (column_name, comparison_method, reference_value) tuple per criterion.
filter_criteria = [
    ('paradigm_id', 'equal_to', 'OTE'),
    ('line_id', 'is_in_list', ['206', '209']),
]

You can add as many tuples (= criteria) as you'd like; only rows that fulfill all of them are selected. Currently implemented comparison methods include:
- "greater": selects only rows in which the values of the column are greater than the reference value
- "smaller": selects only rows in which the values of the column are smaller than the reference value
- "equal_to": selects only rows in which the values of the column are equal to the reference value
- "is_in_list": selects only rows in which the value of the column matches an element of the reference value (which has to be a list, in this case)
- "is_nan": selects only rows in which the value of the column is NaN (the reference value is ignored, but the tuple still needs a third element, e.g. `None`)

You can then pass the `filter_criteria` along with your dataframe to the `filter_dataframe()` function:

In [None]:
# Keep only the rows of `df` that fulfill every criterion defined above.
df_filtered = filter_dataframe(df=df, filter_criteria=filter_criteria)

### 3) Plotting:

Finally, you can also plot your data, filtering it even further beforehand if you'd like to. When you're using the `plot()` function, you can specify the following parameters:

- df: the dataframe you'd like to plot (e.g. your filtered dataframe)
- x_column: the column name of the data that should be visualized on the x-axis
- y_column: the column name of the data that should be visualized on the y-axis
- plot_type: currently only "violinplot" and "stripplot" are implemented
- hue_column (optional): if you'd like to use the data of a column to color-code the plotted data, you can specify it here (see example below)
- hide_legend (optional): pass along as `False` if you'd like the legend of the plot to be displayed

In [None]:
# Violin plot of 'CenterOfGravity_x_at_bout_start' per group, color-coded by week,
# with the legend displayed.
plot(
    df=df_filtered,
    x_column='group_id',
    y_column='CenterOfGravity_x_at_bout_start',
    plot_type='violinplot',
    hue_column='week_id',
    hide_legend=False,
)