# Run A Mixed Effects Model

### Authors: Calvin Howard.

#### Last updated: July 6, 2023

Use this notebook to assess whether a predictor's relationship to the outcome (the predictee) differs between two groups.

Notes:
- To get the most out of this notebook, you should be familiar with mixed effects models.
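As background for the model named in the title, here is a minimal sketch of how a group difference in a predictor's slope is usually tested with a mixed effects model. It assumes long-format data with one row per observation, a subject identifier, and a two-level group column. The column names (`subject`, `group`, `x`, `y`) and the simulated data are hypothetical, and `statsmodels.formula.api.mixedlm` is used here as one common choice; it is not necessarily what this notebook's own utilities call. The `x:C(group)[T.B]` coefficient estimates how much the slope of `x` in group B differs from the slope in group A, and each subject gets a random intercept.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

# Hypothetical long-format data: 20 subjects x 5 observations, half in group "A", half in "B"
rng = np.random.default_rng(0)
n_subjects, n_obs = 20, 5
subject = np.repeat(np.arange(n_subjects), n_obs)
group = np.where(subject < n_subjects // 2, "A", "B")
x = rng.normal(size=subject.size)
# Each subject gets its own random intercept
subject_intercept = rng.normal(scale=0.5, size=n_subjects)[subject]
# Simulated truth: the slope of x is 1.0 in group A and 2.0 in group B
slope = np.where(group == "A", 1.0, 2.0)
y = subject_intercept + slope * x + rng.normal(scale=0.3, size=subject.size)
df = pd.DataFrame({"subject": subject, "group": group, "x": x, "y": y})

# Fixed effects: x, group, and their interaction; random intercept for each subject
model = smf.mixedlm("y ~ x * C(group)", data=df, groups=df["subject"])
result = model.fit()
print(result.summary())

# The interaction term is the estimated difference in the x slope (group B minus group A)
interaction = "x:C(group)[T.B]"
print(result.params[interaction], result.pvalues[interaction])
```

A significant interaction term is the mixed-model counterpart of "the predictor's relationship to the outcome differs between the groups"; the random intercept accounts for repeated observations from the same subject.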

# 00 - Import CSV with All Data
**The CSV is expected to be in this format**
- The `ID` column and absolute paths to the NIfTI files are required
```
+-----+----------------------------+--------------+--------------+--------------+
| ID  | Nifti_File_Path            | Covariate_1  | Covariate_2  | Covariate_3  |
+-----+----------------------------+--------------+--------------+--------------+
| 1   | /path/to/file1.nii.gz      | 0.5          | 1.2          | 3.4          |
| 2   | /path/to/file2.nii.gz      | 0.7          | 1.4          | 3.1          |
| 3   | /path/to/file3.nii.gz      | 0.6          | 1.5          | 3.5          |
| 4   | /path/to/file4.nii.gz      | 0.9          | 1.1          | 3.2          |
| ... | ...                        | ...          | ...          | ...          |
+-----+----------------------------+--------------+--------------+--------------+
```
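If you want to sanity-check a table before handing it to the notebook's loader, a minimal pandas sketch could look like the one below. The helper `load_subject_table` is hypothetical (it is not part of `calvin_utils`); it reads a CSV, or an Excel sheet when `sheet` is given, and checks only the two requirements stated above: the `ID` and `Nifti_File_Path` columns exist and every path is absolute.

```python
import io
import os

import pandas as pd

REQUIRED_COLUMNS = ["ID", "Nifti_File_Path"]  # column names taken from the table above


def load_subject_table(source, sheet=None):
    """Hypothetical helper: read a CSV (or an Excel sheet) and validate required columns."""
    if sheet is not None:
        df = pd.read_excel(source, sheet_name=sheet)
    else:
        df = pd.read_csv(source)
    missing = [c for c in REQUIRED_COLUMNS if c not in df.columns]
    if missing:
        raise ValueError(f"Missing required columns: {missing}")
    # Flag rows whose NIfTI path is not absolute
    is_abs = df["Nifti_File_Path"].astype(str).map(os.path.isabs)
    relative_ids = df.loc[~is_abs, "ID"].tolist()
    if relative_ids:
        raise ValueError(f"Non-absolute NIfTI paths for IDs: {relative_ids}")
    return df


# Example with an in-memory CSV mirroring the layout above
csv_text = """ID,Nifti_File_Path,Covariate_1
1,/path/to/file1.nii.gz,0.5
2,/path/to/file2.nii.gz,0.7
"""
table = load_subject_table(io.StringIO(csv_text))
print(table.shape)
```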

In [None]:
# Specify the path to your CSV (or Excel) file containing the NIfTI paths
input_csv_path = '/Users/cu135/Dropbox (Partners HealthCare)/studies/cognition_2023/metadata/master_list_proper_subjects.xlsx'

In [None]:
# Specify where you want to save your results to
out_dir = '/Users/cu135/Partners HealthCare Dropbox/Calvin Howard/studies/cognition_2023/revisions/stripplot_of_outcomes'

In [None]:
from calvin_utils.permutation_analysis_utils.statsmodels_palm import CalvinStatsmodelsPalm
# Instantiate the CalvinStatsmodelsPalm class (sheet is the Excel sheet to read)
cal_palm = CalvinStatsmodelsPalm(input_csv_path=input_csv_path, output_dir=out_dir, sheet='master_list_proper_subjects')
# Read the data and display it
data_df = cal_palm.read_and_display_data()


# 01 - Preprocess Your Data

**Handle NANs**
- Put the name of a column, or a list of column names, in `drop_list`
- Rows with NaNs in those columns are removed from the data
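For reference, a plain-pandas sketch of the same idea uses `DataFrame.dropna(subset=...)`. It assumes `drop_nans_from_columns` removes a row when any of the listed columns is NaN; that behavior is not confirmed by this notebook, and `drop_nan_rows` is a hypothetical name.

```python
import numpy as np
import pandas as pd


def drop_nan_rows(df, columns):
    """Hypothetical helper: drop rows that have a NaN in any of `columns`."""
    if isinstance(columns, str):
        columns = [columns]
    return df.dropna(subset=columns)


df = pd.DataFrame({
    "ID": [1, 2, 3],
    "Percent_Cognitive_Improvement": [12.0, np.nan, -4.0],
    "Age": [70, 68, np.nan],
})
# Only the listed column is checked, so ID 3 (NaN in Age) is kept
clean = drop_nan_rows(df, "Percent_Cognitive_Improvement")
print(clean["ID"].tolist())
```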

In [None]:
data_df.columns

In [None]:
drop_list = ['Percent_Cognitive_Improvement']

In [None]:
data_df = cal_palm.drop_nans_from_columns(columns_to_drop_from=drop_list)
display(data_df)

**Drop Rows Based on a Column's Value**

Define the column, condition, and value used to drop rows:
- column = 'your_column_name'
- condition = 'above'  # Options: 'equal', 'not', 'above', 'below'
- value = the value to compare against
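The filtering logic can be sketched in plain pandas as follows. This is an assumption about what `drop_rows_based_on_value` does: here the rows that satisfy the condition are the ones dropped, and the second returned frame holds those dropped rows. Neither detail is confirmed by the notebook, and `split_rows` is a hypothetical name.

```python
import pandas as pd


def split_rows(df, column, condition, value):
    """Hypothetical sketch: drop rows where `column` meets `condition` relative to `value`.

    Returns (kept_rows, dropped_rows).
    """
    comparisons = {
        "equal": lambda s: s == value,
        "not": lambda s: s != value,
        "above": lambda s: s > value,
        "below": lambda s: s < value,
    }
    if condition not in comparisons:
        raise ValueError(f"Unknown condition: {condition!r}")
    drop_mask = comparisons[condition](df[column])
    return df[~drop_mask], df[drop_mask]


df = pd.DataFrame({"City": ["Toronto", "Boston", "Toronto"], "Score": [1.0, 2.0, 3.0]})
# 'not' with 'Toronto' drops every row whose City is not Toronto
kept, dropped = split_rows(df, "City", "not", "Toronto")
print(kept["City"].tolist(), dropped["City"].tolist())
```

Using a dictionary of lambdas means only the requested comparison is evaluated, so an ordering comparison such as `above` is never attempted on a column where it would not make sense unless you ask for it.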

In [None]:
data_df.columns

Set the parameters for dropping rows

In [None]:
column = 'City'  # The column you'd like to evaluate
condition = 'not'  # The condition to check (e.g. 'equal', 'not', 'above', 'below')
value = 'Toronto'  # The value to compare against

In [None]:
data_df, other_df = cal_palm.drop_rows_based_on_value(column, condition, value)
display(data_df)

**Standardize Data**
- List the columns you do not want to standardize
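A minimal z-scoring sketch in pandas is shown below, assuming that "standardize" means subtracting each column's mean and dividing by its sample standard deviation, applied to numeric columns only. `zscore_columns` is a hypothetical name and may differ from what `cal_palm.standardize_columns` actually does.

```python
import pandas as pd


def zscore_columns(df, cols_not_to_standardize=()):
    """Hypothetical helper: z-score every numeric column except the excluded ones."""
    out = df.copy()
    excluded = set(cols_not_to_standardize)
    numeric = out.select_dtypes(include="number").columns
    to_scale = [c for c in numeric if c not in excluded]
    # pandas .std() uses the sample standard deviation (ddof=1) by default
    out[to_scale] = (out[to_scale] - out[to_scale].mean()) / out[to_scale].std()
    return out


df = pd.DataFrame({"Age": [60, 70, 80], "Score": [1.0, 2.0, 3.0], "City": ["A", "B", "C"]})
z = zscore_columns(df, ["Age"])
print(z)
```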

In [None]:
# List any columns you do not want to standardize
# cols_not_to_standardize = ['Age']

In [None]:
# data_df = cal_palm.standardize_columns(cols_not_to_standardize)
# data_df

**Descriptive Statistics**

In [None]:
data_df.describe()
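Because the notebook compares groups, it can also help to summarize the outcome within each group. A minimal pandas sketch with illustrative toy values follows; the column names mirror the cells in this notebook, but the numbers are not real data.

```python
import pandas as pd

df = pd.DataFrame({
    "City": ["Toronto", "Toronto", "Boston", "Boston"],
    "Percent_Cognitive_Improvement": [10.0, 20.0, -5.0, 5.0],
})
# Count, mean, and sample standard deviation of the outcome within each group
summary = df.groupby("City")["Percent_Cognitive_Improvement"].agg(["count", "mean", "std"])
print(summary)
```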

# 02 - Plot

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def plot_group_values(df, group_col='Group', value_col='Value', jitter_range=(-0.01, 0.01), title="Value by Group", out_dir=None):
    """Jittered strip-style scatter of value_col by group_col, with per-group KDE margins.

    Saves the figure as <out_dir>/<title>.svg when out_dir is given.
    """
    # Work on a copy so the helper columns are not added to the caller's DataFrame
    df = df.copy()

    # Map groups to numeric positions and add a small horizontal jitter for a strip effect
    group_mapping = {group: i for i, group in enumerate(sorted(df[group_col].unique()))}
    df['Group_num'] = df[group_col].map(group_mapping)
    df['Group_jitter'] = df['Group_num'] + np.random.uniform(jitter_range[0], jitter_range[1], size=len(df))

    # Create jointplot: the main panel shows jittered points (acting as a strip plot),
    # while the margins display KDE distributions (the default when hue is set).
    g = sns.jointplot(
        data=df,
        x="Group_jitter",
        y=value_col,
        hue=group_col,
        kind="scatter",
        height=6,
        ratio=6
    )
    # Update x-axis: place ticks at the group centers and label them with the group names
    g.ax_joint.set_xticks(list(group_mapping.values()))
    g.ax_joint.set_xticklabels(list(group_mapping.keys()))

    # With a single group, fix the x-limits so the tiny jitter is not stretched across the axis
    if len(group_mapping) == 1:
        g.ax_joint.set_xlim(-0.2, 0.2)

    g.set_axis_labels(group_col, value_col)
    g.ax_joint.set_title(title, pad=80)
    if out_dir is not None:
        os.makedirs(out_dir, exist_ok=True)
        plt.savefig(os.path.join(out_dir, f"{title}.svg"))
    plt.show()

In [None]:
data_df.columns

In [None]:
plot_group_values(df=data_df, group_col='City', value_col='Percent_Cognitive_Improvement', title="Alzheimer", out_dir=out_dir)
