# Run A Mixed Effects Model

### Authors: Calvin Howard.

#### Last updated: July 6, 2023

Use this to assess whether a predictor's relationship to the predictee differs between two groups.

Notes:
- To best use this notebook, you should be familiar with mixed effects models.

# 00 - Import CSV with All Data
**The CSV is expected to be in this format**
- ID and absolute paths to niftis are critical
```
+-----+----------------------------+--------------+--------------+--------------+
| ID  | Nifti_File_Path            | Covariate_1  | Covariate_2  | Covariate_3  |
+-----+----------------------------+--------------+--------------+--------------+
| 1   | /path/to/file1.nii.gz      | 0.5          | 1.2          | 3.4          |
| 2   | /path/to/file2.nii.gz      | 0.7          | 1.4          | 3.1          |
| 3   | /path/to/file3.nii.gz      | 0.6          | 1.5          | 3.5          |
| 4   | /path/to/file4.nii.gz      | 0.9          | 1.1          | 3.2          |
| ... | ...                        | ...          | ...          | ...          |
+-----+----------------------------+--------------+--------------+--------------+
```

In [None]:
# Specify the path to your input spreadsheet containing NIFTI paths.
# NOTE(review): this value points at an .xlsx workbook rather than a CSV, and it is an
# absolute, machine-specific path — change it to match your own system.
input_csv_path = '/Users/cu135/Dropbox (Partners HealthCare)/studies/cognition_2023/metadata/master_list_proper_subjects.xlsx'

In [None]:
# Specify where you want to save your results to.
# This directory is handed to CalvinStatsmodelsPalm as output_dir and is also where the
# shaded-sector figure is written at the end of the notebook.
# NOTE(review): absolute, machine-specific path — change it to match your own system.
out_dir = '/Users/cu135/Library/CloudStorage/OneDrive-Personal/OneDrive_Documents/Research/2023/subiculum_cognition_and_age/figures/Figures/supplementary_patient_binarization'

In [None]:
from calvin_utils.permutation_analysis_utils.statsmodels_palm import CalvinStatsmodelsPalm
# Instantiate CalvinStatsmodelsPalm with the input spreadsheet, the output directory, and the sheet to read
cal_palm = CalvinStatsmodelsPalm(input_csv_path=input_csv_path, output_dir=out_dir, sheet='master_list_proper_subjects')
# Call read_and_display_data; the returned object is used as a pandas DataFrame in the cells below
data_df = cal_palm.read_and_display_data()


# 01 - Preprocess Your Data

**Handle NANs**
- Set drop_nans=True if you would like to remove NaNs from data
- Provide a column name or a list of column names to remove NaNs from

In [None]:
# Show the available column names so you can choose which ones to put in drop_list
data_df.columns

In [None]:
# Columns to check for NaNs; passed to drop_nans_from_columns in the next cell
drop_list = ['Age', 'Subiculum_Connectivity_T_Redone']

In [None]:
# Remove NaNs based on the columns named in drop_list, then show the resulting table.
# NOTE(review): presumably rows with a NaN in any listed column are dropped — confirm against
# CalvinStatsmodelsPalm.drop_nans_from_columns.
data_df = cal_palm.drop_nans_from_columns(columns_to_drop_from=drop_list)
display(data_df)

**Drop Row Based on Value of Column**

Define the column, condition, and value for dropping rows
- column = 'your_column_name'
- condition = 'above'  # Options: 'equal', 'above', 'below'

In [None]:
# Show the available column names so you can choose the column to filter rows on
data_df.columns

Set the parameters for dropping rows

In [None]:
column = 'City'  # The column you'd like to evaluate
condition = 'not'  # The condition to check; documented options are 'equal', 'above', 'below'. NOTE(review): 'not' is not among them — confirm drop_rows_based_on_value supports it
value = 'Toronto'  # The value to compare against

In [None]:
# Apply the row filter configured above. Two objects come back: the first replaces data_df;
# the second (other_df) is presumably the set of rows that were removed — TODO confirm.
data_df, other_df = cal_palm.drop_rows_based_on_value(column, condition, value)
display(data_df)

**Standardize Data**
- Enter Columns you Don't want to standardize into a list

In [None]:
# Remove anything you don't want to standardize
# Columns to leave unstandardized; passed to standardize_columns in the next cell
cols_not_to_standardize = ['Age']

In [None]:
# Standardize the data, skipping the columns listed in cols_not_to_standardize, and display the result.
# NOTE(review): the exact transform (e.g. z-scoring) is defined in CalvinStatsmodelsPalm — confirm.
data_df = cal_palm.standardize_columns(cols_not_to_standardize)
data_df

Descriptive Stats

In [None]:
# Summary statistics of the preprocessed table
data_df.describe()

# 02 Plot

**Grouped Barplot**
- Expects a DataFrame with a category column: the grouping variable that sets colour. 
- variable represents each thing to be plotted, like 'neuroimaging, bias, etc'. 
- metric is the value of the variable to be plotted.

In [None]:
import pandas as pd
import os
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import spearmanr, pearsonr, kendalltau


# def plot_scatter_with_kde(self, hline_value=None, vline_value=None):
#     fig, axes = plt.subplots(self.rows_per_fig, self.cols_per_fig, figsize=(self.cols_per_fig * 5, self.rows_per_fig * 5))
#     axes = axes.flatten()
#     current_ax = 0

#     for dependent_var, independent_vars in self.data_dict.items():
#         self.dependent_var = dependent_var
#         for independent_var in independent_vars:
#             if current_ax >= len(axes):
#                 self.figures.append(fig)
#                 fig, axes = plt.subplots(self.rows_per_fig, self.cols_per_fig, figsize=(self.cols_per_fig * 5, self.rows_per_fig * 5))
#                 axes = axes.flatten()
#                 current_ax = 0

#             sns.scatterplot(x=independent_var, y=dependent_var, data=self.dataframe, ax=axes[current_ax])

#             if hline_value is not None:
#                 axes[current_ax].axhline(y=hline_value, color='black', linestyle='--')

#             if vline_value is not None:
#                 axes[current_ax].axvline(x=vline_value, color='black', linestyle='--')

#             if self.ylim is not None:
#                 axes[current_ax].set_ylim(self.ylim[0], self.ylim[1])

#             axes[current_ax].set_title(independent_var)
#             axes[current_ax].set_xlabel(self.x_label)
#             axes[current_ax].set_ylabel(self.y_label)

#             current_ax += 1

#     self.figures.append(fig)  # Append the last figure


In [None]:
import os
import seaborn as sns
import matplotlib.pyplot as plt

def plot_scatter_with_shaded_sectors(dataframe, x_column, y_column, hline_value, vline_value, x_label='X Label', y_label='Y Label', title='Scatterplot with Shaded Sectors', out_dir=None, filename='shaded_sector_plot.svg'):
    """Scatter two columns and shade the upper-right and lower-left sectors.

    A dashed horizontal line at ``hline_value`` and a dashed vertical line at
    ``vline_value`` split the plot into four sectors. The sector to the right
    of the vertical line and above the horizontal line, and the sector to the
    left of the vertical line and below the horizontal line, are shaded green.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Table holding the columns to plot.
    x_column, y_column : str
        Names of the columns plotted on the x and y axes.
    hline_value, vline_value : float
        Positions of the horizontal and vertical reference lines.
    x_label, y_label, title : str, optional
        Axis labels and figure title.
    out_dir : str or None, optional
        Directory to save the figure in. It is created if missing. When None,
        nothing is written to disk.
    filename : str, optional
        File name used inside ``out_dir``. The image format is inferred from
        its extension; the default keeps the original SVG output name.

    Returns
    -------
    None
        The figure is shown with ``plt.show()`` and optionally saved.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.scatterplot(x=x_column, y=y_column, data=dataframe, ax=ax)

    # Reference lines that divide the plot into four sectors
    ax.axhline(y=hline_value, color='black', linestyle='--')
    ax.axvline(x=vline_value, color='black', linestyle='--')

    # Read the limits only after the lines are drawn so the shading reaches the plot edges
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()

    # Right of the vertical line and above the horizontal line
    ax.fill_betweenx(y=[hline_value, ylim[1]], x1=vline_value, x2=xlim[1], color='green', alpha=0.2)
    # Left of the vertical line and below the horizontal line
    ax.fill_betweenx(y=[ylim[0], hline_value], x1=xlim[0], x2=vline_value, color='green', alpha=0.2)

    # Restore the limits, as fill_betweenx might change them
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)

    ax.set_title(title)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)

    # Save before showing; make sure the output directory exists first
    if out_dir is not None:
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, filename))

    plt.show()


In [None]:
# Plot connectivity (x) against age (y) with reference lines at age 65.3 and connectivity 23.8,
# saving the figure into out_dir.
# NOTE(review): the two thresholds are hard-coded — document where they come from.
# NOTE(review): 'Subiculum_Connectivity_T_Redone' is not listed in cols_not_to_standardize, so if
# standardize_columns rescaled it, vline_value=23.8 may lie far outside the plotted data — confirm.
plot_scatter_with_shaded_sectors(data_df, 'Subiculum_Connectivity_T_Redone', 'Age', hline_value=65.3, vline_value=23.8,
                        x_label='Subiculum Connectivity', y_label='Age', title='Age-Connectivity Match Plot',
                        out_dir=out_dir)