# inspect

> Collection of functions that enable the visualization and inspection of obtained results in just a few lines of code

In [None]:
#| default_exp inspect

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from pathlib import Path, PosixPath
from typing import List, Tuple, Dict, Optional

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

## Example usage:

### Helper functions:

Before starting with the example usage, let's quickly define two helper functions that we will need:

In [None]:
#| export

def get_only_matching_xlsx_files(dir_path: Path, paradigm_id: str, week_id: Optional[int]=None) -> List[Path]:
    """Return the .xlsx files in `dir_path` that match a paradigm (and optionally a week).

    A file matches if its name ends with '.xlsx' and contains `paradigm_id`.
    If `week_id` is given, the name must additionally contain 'week-{week_id}.'.
    The trailing dot is part of the pattern so that e.g. week 1 does not also
    match the files of week 12 or week 14.

    The result is sorted, so repeated calls return the files in a reproducible
    order (the order of `Path.iterdir()` is arbitrary).
    """
    # `is None` (not `!= None` or truthiness), so that week 0 is still filtered for.
    week_tag = None if week_id is None else f'week-{week_id}.'
    return sorted(
        filepath
        for filepath in dir_path.iterdir()
        if filepath.name.endswith('.xlsx')
        and paradigm_id in filepath.name
        and (week_tag is None or week_tag in filepath.name)
    )

In [None]:
#| export

def get_metadata_from_filename(filepath_session_results: Path, group_assignment_filepath: Optional[Path]=None) -> Dict:
    """Parse the metadata encoded in the filename of an exported results file.

    The filename is expected to follow the schema
    '{line_id}_{mouse_id}_{paradigm_id}_week-{week_id}.xlsx'. The paradigm id
    may itself contain underscores (e.g. 'paradigm_a'): the first two
    underscore-separated parts are the line and mouse id, the last part is the
    week string, and everything in between is the paradigm id.

    If `group_assignment_filepath` is given, the Excel file is read and the
    'group_id' of the row whose 'subject_id' (or, failing that,
    'alternative_subject_id') equals the parsed subject id is used. If no row
    matches, or no group assignment file is passed, 'group_id' is 'unknown'.

    Returns a dict with the keys 'line_id', 'paradigm_id', 'subject_id',
    'week_id' (as string) and 'group_id'.

    Raises a ValueError if the filename has fewer than four
    underscore-separated parts.
    """
    name_parts = filepath_session_results.name.split('_')
    if len(name_parts) < 4:
        raise ValueError(
            f'Cannot parse metadata from "{filepath_session_results.name}": expected a filename like '
            '"{line_id}_{mouse_id}_{paradigm_id}_week-{week_id}.xlsx".'
        )
    line_id, mouse_id, *paradigm_parts, week_string_with_file_extension = name_parts
    subject_id = f'{line_id}_{mouse_id}'
    # The week id is whatever sits between the first '-' and the first '.', e.g. '4' in 'week-4.xlsx'.
    week_id_start = week_string_with_file_extension.find('-') + 1
    week_id_end = week_string_with_file_extension.find('.')
    metadata = {
        'line_id': line_id,
        'paradigm_id': '_'.join(paradigm_parts),
        'subject_id': subject_id,
        'week_id': week_string_with_file_extension[week_id_start:week_id_end],
        'group_id': 'unknown',
    }
    if group_assignment_filepath is not None:
        df_group_assignment = pd.read_excel(group_assignment_filepath)
        # The primary id column takes precedence over the alternative one.
        for id_column in ('subject_id', 'alternative_subject_id'):
            matching_group_ids = df_group_assignment.loc[df_group_assignment[id_column] == subject_id, 'group_id']
            if not matching_group_ids.empty:
                metadata['group_id'] = matching_group_ids.iloc[0]
                break
    return metadata

### 1) Load all data:

This assumes that you used the `gait_analysis` package to analyze your 2D tracking data and created the corresponding result exports. To get some quick insights into your data, feel free to use the following collection of functions.

First, you need to provide the path to the directory that contains the exported Excel files as `root_dir_path`, as a Path object (see example below). You can also filter right away for a certain set of weeks or paradigms (you can also just pass a list with a single value for each). Please note that you have to provide the exact name of the Excel tab you are interested in as `sheet_name`.

Note: The group_assignment Excel Sheet is still quite customized & should be replaced by a more generalizable configs file (e.g. .yaml) in future versions.

In [None]:
#| export

def collect_all_available_data(root_dir_path: Path, # Filepath to the directory that contains the exported results .xlsx files
                               group_assignment_filepath: Path, # Filepath to the group_assignments.xlsx file
                               paradigm_ids: List[str], # List of paradigms of which the results shall be loaded
                               week_ids: List[int],  # List of weeks from which the results shall be loaded
                               sheet_name: str # Tab name of the exported results sheet to load, e.g. "session_overview"
                              ) -> pd.DataFrame:
    """Load one sheet from every matching results file into a single DataFrame.

    For each combination of week and paradigm, all matching .xlsx files in
    `root_dir_path` are read (sheet `sheet_name`, first column as index). The
    metadata parsed from each filename (line, paradigm, subject, week, group)
    is added to every row as additional columns. The per-file DataFrames are
    concatenated and the index is reset to a fresh 0..n-1 range.

    Raises a ValueError if not a single matching file was found.
    """
    all_recording_results_dfs = []
    for week_id in week_ids:
        for paradigm_id in paradigm_ids:
            matching_filepaths = get_only_matching_xlsx_files(dir_path = root_dir_path, paradigm_id = paradigm_id, week_id = week_id)
            for filepath in matching_filepaths:
                metadata = get_metadata_from_filename(filepath_session_results = filepath, group_assignment_filepath = group_assignment_filepath)
                # Passing the path lets read_excel open and close the workbook itself,
                # so no file handle is left open per loaded file.
                tmp_df = pd.read_excel(filepath, sheet_name = sheet_name, index_col = 0)
                for key, value in metadata.items():
                    tmp_df[key] = value
                all_recording_results_dfs.append(tmp_df)
    if not all_recording_results_dfs:
        # pd.concat would fail with an unspecific "No objects to concatenate" otherwise.
        raise ValueError(
            f'No matching result files found in "{root_dir_path}" '
            f'for paradigm_ids={paradigm_ids} and week_ids={week_ids}.'
        )
    return pd.concat(all_recording_results_dfs, ignore_index = True)

For instance, if you'd like to inspect the overall session overview, use:

``` 
df = collect_all_available_data(root_dir_path = Path('/path/to/your/directory/with/the/exported/results/'),
                                group_assignment_filepath = Path('/filepath/to/your/group_assignment.xlsx'),
                                paradigm_ids = ['paradigm_a', 'paradigm_b', 'paradigm_c'],
                                week_ids = [1, 4, 8, 12, 14],
                                sheet_name = 'session_overview')
```


### 2) Filter your data (optional):

You might want to filter your data to only inspect a specific portion of it. You can do so quite easily by using the `filter_dataframe()` function (see example usage below):

In [None]:
#| export

def filter_dataframe(df: pd.DataFrame, filter_criteria: List[Tuple]) -> pd.DataFrame:
    """Return a copy of the rows of `df` that fulfill all filter criteria.

    Each criterion is a tuple `(column_name, comparison_method, reference_value)`.
    Implemented comparison methods are:

    - 'greater': column value > reference value
    - 'smaller': column value < reference value
    - 'equal_to': column value == reference value
    - 'is_in_list': column value is an element of the reference value (a list)
    - 'is_nan': column value is NaN (the reference value is ignored)

    Rows are kept in their original order. An empty list of criteria returns a
    copy of the entire DataFrame.

    Raises a KeyError if a column name does not exist in `df`, and a ValueError
    if a comparison method is not implemented.
    """
    # Positional boolean mask: unlike intersecting index labels, this keeps the row
    # order and also works if the index contains duplicate labels.
    keep_row = np.ones(df.shape[0], dtype=bool)
    for column_name, comparison_method, reference_value in filter_criteria:
        if column_name not in df.columns:
            raise KeyError(f'Column "{column_name}" does not exist in the DataFrame.')
        column = df[column_name]
        if comparison_method == 'greater':
            criterion_mask = column > reference_value
        elif comparison_method == 'smaller':
            criterion_mask = column < reference_value
        elif comparison_method == 'equal_to':
            criterion_mask = column == reference_value
        elif comparison_method == 'is_in_list':
            criterion_mask = column.isin(reference_value)
        elif comparison_method == 'is_nan':
            criterion_mask = column.isnull()
        else:
            raise ValueError(
                f'Unknown comparison method "{comparison_method}". Implemented methods are: '
                '"greater", "smaller", "equal_to", "is_in_list", "is_nan".'
            )
        # Comparisons on nullable dtypes can yield <NA>; such rows do not fulfill the criterion.
        keep_row &= criterion_mask.fillna(False).to_numpy(dtype=bool)
    return df.loc[keep_row, :].copy()

You need to specify your filter criteria in a list of tuples, for which each tuple follows this schema:
> (column_name, comparison_method, reference_value)

So for instance, if you'd like to filter your dataframe by selecting only the data of all freezing bouts, you would add a tuple that specifies that you want all rows from the column "bout_type" that have the value "all_freezing_bouts":
> `('bout_type', 'equal_to', 'all_freezing_bouts')`

You can also add more criteria with additional tuples. For instance, to filter for specific mouse lines, your filter criterion would look like this:
> `('line_id', 'is_in_list', ['206', '209'])` 

Bringing it together, you'd define your filter_criteria in a list of these tuples:

```
filter_criteria = [('line_id', 'is_in_list', ['206', '209']),
                   ('bout_type', 'equal_to', 'all_freezing_bouts')]
```

You can add as many tuples (= criteria) as you'd like. Currently implemented comparison methods include:
- "greater": selects only rows in which the values of the column are greater than the reference value
- "smaller": selects only rows in which the values of the column are smaller than the reference value
- "equal_to": selects only rows in which the values of the column are equal to the reference value
- "is_in_list": selects only rows in which the values of the column are matching to an element in the reference value (which has to be a list, in this case)
- "is_nan": selects only rows in which the values of the column are NaN

You can then pass the `filter_criteria` along with your DataFrame to the `filter_dataframe()` function:

```
df_filtered = filter_dataframe(df = df, filter_criteria = filter_criteria)
```

### 3) Plotting:

Finally, you can also plot your data, filtering it even more if you'd like to. When you're using the `plot()` function, you can specify the following parameters (see function documentation and example usage below):

In [None]:
#| export

def plot(df: pd.DataFrame, # the DataFrame that contains the data you'd like to plot
         x_column: str, # the column name of the data that should be visualized on the x-axis
         y_column: str, # the column name of the data that should be visualized on the y-axis
         plot_type: str='violinplot', # currently only "violinplot" and "stripplot" are implemented
         hue_column: Optional[str]=None, # if you'd like to use the data of a column to color-code the plotted data, you can specify it here (see example below)
         hide_legend: bool=True # pass along as `False` if you'd like the legend of the plot to be displayed
        ) -> None:
    """Create and show a violin- or stripplot of `y_column` over `x_column`.

    For a 'violinplot', the individual data points are additionally overlaid
    as a semi-transparent stripplot, as long as `df` has fewer than 10,000 rows.
    The figure is displayed via `plt.show()`; nothing is returned.

    Raises a ValueError if `plot_type` is not one of the implemented plot types.
    """
    # Validate before creating a figure, so that a typo does not result in an empty plot.
    if plot_type not in ('violinplot', 'stripplot'):
        raise ValueError(f'Unknown plot_type "{plot_type}". Implemented plot types are: "violinplot", "stripplot".')
    fig = plt.figure(figsize = (8, 5), facecolor = 'white')
    ax = fig.add_subplot(111)
    # Pass `ax` explicitly instead of relying on it being matplotlib's current axes.
    if plot_type == 'violinplot':
        sns.violinplot(data = df, x = x_column, y = y_column, hue = hue_column, ax = ax)
        if df.shape[0] < 10_000:
            sns.stripplot(data = df, x = x_column, y = y_column, hue = hue_column, dodge = True, color = 'black', alpha = 0.3, ax = ax)
    else:
        sns.stripplot(data = df, x = x_column, y = y_column, hue = hue_column, dodge = True, ax = ax)
    if hide_legend:
        # seaborn adds a legend on its own when a hue is used; there is none otherwise.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
    else:
        # Place the legend outside of the axes, to the right of the plot.
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.show()
```
plot(df = df_filtered, x_column = 'group_id', y_column = 'total_duration', hue_column = 'week_id')
```

You can also use the following function to create a DataFrame that gives you an overview of your data availability:

In [None]:
#| export

def check_data_availability(root_dir_path: Path, all_week_ids: List[int], all_paradigm_ids: List[str]) -> pd.DataFrame:
    """Create an overview of which results files are available for which week.

    The returned DataFrame has one row per week id and one column per
    '{subject_id}_{paradigm_id}' combination for which at least one results
    file was found in `root_dir_path`. A cell is True if a matching file exists
    for that week, and False otherwise.
    """
    df = pd.DataFrame({}, index=all_week_ids)
    for week_id in all_week_ids:
        for paradigm_id in all_paradigm_ids:
            all_matching_filepaths = get_only_matching_xlsx_files(dir_path = root_dir_path, paradigm_id = paradigm_id, week_id = week_id)
            for filepath in all_matching_filepaths:
                # The subject id consists of the first two underscore-separated parts of the
                # filename ('{line_id}_{mouse_id}_...'). It is parsed directly, since only the
                # filename is needed here and no group assignment file is available.
                subject_id = '_'.join(filepath.name.split('_')[:2])
                column_name = f'{subject_id}_{paradigm_id}'
                if column_name not in df.columns:
                    df[column_name] = False
                df.loc[week_id, column_name] = True
    return df

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()