# Explore Data

### Authors: Calvin Howard.

#### Last updated: July 6, 2023

Use this notebook to assess whether a predictor's relationship to the predictee differs between two groups.

Notes:
- To best use this notebook, you should be familiar with mixed effects models.
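
As a rough illustration of the question this notebook addresses, the sketch below simulates two groups whose slopes differ and fits a least-squares model with a predictor-by-group interaction term. It uses plain NumPy rather than a mixed effects model, and the data, variable names, and effect sizes are all made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200

# Simulated data (hypothetical): slope is 1.0 in group 0 and 2.5 in group 1
predictor = rng.normal(size=n)
group = rng.integers(0, 2, size=n)
outcome = 1.0 * predictor + 1.5 * predictor * group + rng.normal(scale=0.5, size=n)

# Design matrix: intercept, predictor, group, predictor x group
X = np.column_stack([np.ones(n), predictor, group, predictor * group])
coefs, *_ = np.linalg.lstsq(X, outcome, rcond=None)

# The last coefficient estimates how much the slope differs between groups
interaction = coefs[3]
print(f"slope in group 0: {coefs[1]:.2f}")
print(f"slope difference (interaction): {interaction:.2f}")
```

An interaction coefficient far from zero suggests the predictor relates to the outcome differently in the two groups. A real analysis would also need standard errors and, for repeated measures, random effects.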

# 00 - Import CSV with All Data
**The CSV is expected to be in this format**
- The ID column and absolute paths to the NIfTI files are critical
```
+-----+----------------------------+--------------+--------------+--------------+
| ID  | Nifti_File_Path            | Covariate_1  | Covariate_2  | Covariate_3  |
+-----+----------------------------+--------------+--------------+--------------+
| 1   | /path/to/file1.nii.gz      | 0.5          | 1.2          | 3.4          |
| 2   | /path/to/file2.nii.gz      | 0.7          | 1.4          | 3.1          |
| 3   | /path/to/file3.nii.gz      | 0.6          | 1.5          | 3.5          |
| 4   | /path/to/file4.nii.gz      | 0.9          | 1.1          | 3.2          |
| ... | ...                        | ...          | ...          | ...          |
+-----+----------------------------+--------------+--------------+--------------+
```
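
Before loading a real file, it can help to check that a table follows this layout. The sketch below is one possible check, not part of `calvin_utils`: the helper name `check_layout` is hypothetical, the column names are taken from the example table above, and the absolute-path test assumes POSIX-style paths.

```python
import os
import pandas as pd

# Toy table in the expected layout (paths and values are placeholders)
example_df = pd.DataFrame({
    "ID": [1, 2, 3],
    "Nifti_File_Path": ["/path/to/file1.nii.gz", "/path/to/file2.nii.gz", "relative/file3.nii.gz"],
    "Covariate_1": [0.5, 0.7, 0.6],
})

def check_layout(df, id_col="ID", path_col="Nifti_File_Path"):
    """Hypothetical helper: return a list of human-readable problems found in the table."""
    problems = []
    for col in (id_col, path_col):
        if col not in df.columns:
            problems.append(f"missing column: {col}")
    if id_col in df.columns and df[id_col].duplicated().any():
        problems.append("duplicate IDs")
    if path_col in df.columns:
        # Flag any path that is not absolute
        not_abs = df.loc[~df[path_col].map(os.path.isabs), path_col]
        problems.extend(f"not an absolute path: {p}" for p in not_abs)
    return problems

problems = check_layout(example_df)
print(problems)
```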

In [None]:
# Specify the path to your CSV (or Excel) file containing NIfTI paths
input_csv_path = '/Users/cu135/Library/CloudStorage/OneDrive-Personal/OneDrive_Documents/Work/KiTH_Solutions/Research/Clinical Trial/study_metadata/all_performances.xlsx'

In [None]:
# Specify where you want to save your results to
out_dir = '/Users/cu135/Library/CloudStorage/OneDrive-Personal/OneDrive_Documents/Research/2023/autonomous_cognitive_examination_rct/figures/covariates_and_acoe_socre/age_interact_status/adjustment'
sheet = 'study_results'

In [None]:
from calvin_utils.permutation_analysis_utils.statsmodels_palm import CalvinStatsmodelsPalm
# Instantiate the CalvinStatsmodelsPalm class
cal_palm = CalvinStatsmodelsPalm(input_csv_path=input_csv_path, output_dir=out_dir, sheet=sheet)
# Read the data and display it
data_df = cal_palm.read_and_display_data()


# 01 - Preprocess Your Data

**Handle NANs**
- Set drop_nans=True if you would like to remove NaNs from the data
- Provide a column name or a list of column names to remove NaNs from
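
For readers unfamiliar with column-wise NaN removal, the sketch below shows the plain pandas equivalent on a toy table. It assumes `drop_nans_from_columns` behaves like `DataFrame.dropna(subset=...)`, which may not match the actual implementation. The toy values are made up.

```python
import numpy as np
import pandas as pd

toy_df = pd.DataFrame({
    "Randomization_Group": ["A", np.nan, "B", "A"],
    "Age": [70, 65, np.nan, 81],
})

# Drop rows only where the listed columns are missing; NaNs elsewhere are kept
drop_list = ["Randomization_Group"]
cleaned = toy_df.dropna(subset=drop_list)
print(cleaned)
```

Note that the row with a missing `Age` survives, because `Age` is not in `drop_list`.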

In [None]:
data_df.columns

In [None]:
drop_list = ['Randomization_Group']

In [None]:
data_df = cal_palm.drop_nans_from_columns(columns_to_drop_from=drop_list)
display(data_df)

**Drop Row Based on Value of Column**

Define the column, condition, and value for dropping rows
- column = 'your_column_name'
- condition = 'above'  # Options: 'equal', 'above', 'below'
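
The three listed conditions can be sketched with pandas boolean masks. The helper `split_rows` below is hypothetical and only illustrates the idea. It assumes the two returned tables are the kept rows and the dropped rows, which may differ from what `drop_rows_based_on_value` actually returns.

```python
import pandas as pd

def split_rows(df, column, condition, value):
    """Hypothetical helper: return (kept, dropped) after dropping rows that meet the condition."""
    if condition == "equal":
        mask = df[column] == value
    elif condition == "above":
        mask = df[column] > value
    elif condition == "below":
        mask = df[column] < value
    else:
        raise ValueError(f"unknown condition: {condition}")
    return df[~mask], df[mask]

# Toy data (made up): drop everyone older than 65
toy_df = pd.DataFrame({"ID": [1, 2, 3, 4], "Age": [55, 62, 71, 80]})
kept, dropped = split_rows(toy_df, column="Age", condition="above", value=65)
print(kept)
```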

In [None]:
data_df.columns

Set the parameters for dropping rows

In [None]:
column = 'City'  # The column you'd like to evaluate
condition = 'not'  # The condition to check (documented options: 'equal', 'above', 'below'; 'not' is used here)
value = "Wurzburg"  # The value to compare against

In [None]:
data_df, other_df = cal_palm.drop_rows_based_on_value(column, condition, value)
display(data_df)

**Standardize Data**
- Enter the columns you don't want to standardize into a list
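
Standardizing here is taken to mean z-scoring, that is, subtracting the column mean and dividing by its standard deviation. The sketch below shows this on a toy table while skipping excluded columns. The helper `zscore_columns` is hypothetical, and it uses the pandas default sample standard deviation (`ddof=1`), which may differ from the convention `standardize_columns` uses.

```python
import pandas as pd

def zscore_columns(df, exclude=()):
    """Hypothetical helper: z-score numeric columns not listed in `exclude`."""
    out = df.copy()
    numeric_cols = out.select_dtypes(include="number").columns
    for col in numeric_cols.difference(list(exclude)):
        out[col] = (out[col] - out[col].mean()) / out[col].std()
    return out

# Toy data (made up): Age is excluded, Sex is non-numeric and left alone
toy_df = pd.DataFrame({"Age": [60, 70, 80], "Score": [10.0, 20.0, 30.0], "Sex": ["F", "M", "F"]})
standardized = zscore_columns(toy_df, exclude=["Age"])
print(standardized)
```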

In [None]:
# Remove anything you don't want to standardize
cols_not_to_standardize = ['Age']

In [None]:
data_df = cal_palm.standardize_columns(cols_not_to_standardize)
data_df

# 1.5 - Descriptive Stats

In [None]:
import pandas as pd
pd.set_option('display.max_columns', None)  # None means unlimited
data_df.describe()

Visualize Descriptive Stats

In [None]:
data_df.columns

In [None]:
import pandas as pd
# Reload the raw sheet; this discards the preprocessing applied above
df = pd.read_excel(input_csv_path, sheet_name=sheet)
data_df = df

Select columns to plot

In [None]:
data_df = data_df.loc[:, ['Cohort', 'Educational Status',
       'Randomization Group', 'Age', 'Sex', 'Ethnicity', 'Cognitive Status', 'Total']]

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Create the pairplot
pairplot = sns.pairplot(data_df)

# Save the plot to a specific directory
plt.savefig('/Users/cu135/Library/CloudStorage/OneDrive-Personal/OneDrive_Documents/Research/2023/autonomous_cognitive_examination_rct/figures/covariate_relationships/pairplot.svg')  # Change this path to your desired folder

# Show the plot if needed (this line can be omitted if only saving is required)
plt.show()

# 02 - Visualize 3D Data

In [None]:
import plotly.express as px

def generate_scatterplot(dataframe, data_dict, correlation, palette, out_dir):
    """
    Generates a 3D scatter plot from the given DataFrame and saves it to the specified directory.

    Parameters:
    - dataframe: pandas DataFrame containing the data.
    - data_dict: Dictionary with one key-value pair, where the key is the dependent variable 
                 and the value is a list of independent variables (length should be 2 for 3D scatter).
    - correlation: Whether to display correlation information on the plot (True/False).
                   Currently a placeholder; no correlation is computed yet.
    - palette: Color palette for the plot.
    - out_dir: Directory path where the plot image will be saved.
    """
    dependent_var = list(data_dict.keys())[0]
    independent_vars = data_dict[dependent_var]

    if len(independent_vars) != 2:
        raise ValueError("Independent variable list must contain exactly two elements for 3D scatter plot.")

    fig = px.scatter_3d(dataframe, x=independent_vars[0], y=independent_vars[1], z=dependent_var,
                        color=dependent_var, color_continuous_scale=palette)

    # Set the labels
    fig.update_layout(scene=dict(
        xaxis_title=independent_vars[0],
        yaxis_title=independent_vars[1],
        zaxis_title=dependent_var
    ))

    # Optionally, add correlation info as annotation
    if correlation:
        # Compute and display correlation (requires additional implementation)
        pass

    # Save the plot to the output directory
    fig.write_image(f"{out_dir}/3d_scatter_plot.png")

    return fig
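
The `correlation` branch above is left unimplemented. One way it might be filled in is sketched below: compute the Pearson correlation between the dependent variable and each independent variable, then format the result as text. The helper `correlation_summary` and the toy data are hypothetical and not part of the original notebook.

```python
import pandas as pd

def correlation_summary(dataframe, data_dict):
    """Hypothetical helper: Pearson r between the dependent variable and each independent variable."""
    dependent_var = list(data_dict.keys())[0]
    lines = []
    for var in data_dict[dependent_var]:
        r = dataframe[var].corr(dataframe[dependent_var])  # Pearson by default
        lines.append(f"r({var}, {dependent_var}) = {r:.2f}")
    # Plotly text fields accept <br> as a line break
    return "<br>".join(lines)

# Toy data (made up): x1 rises with z, x2 falls with z
toy_df = pd.DataFrame({"x1": [1, 2, 3, 4], "x2": [4, 3, 2, 1], "z": [2, 4, 6, 8]})
summary = correlation_summary(toy_df, {"z": ["x1", "x2"]})
print(summary)
```

Inside `generate_scatterplot`, the resulting string could then be attached to the figure, for example with `fig.update_layout(title_text=summary)`.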

In [None]:
data_df.columns

In [None]:
data_dict = {'ventricleVisualCH': ['SubcortexSurface', 'SubcortexCSF']}

In [None]:
# Example usage
plot = generate_scatterplot(dataframe=data_df, 
                            data_dict=data_dict,
                            correlation=True,
                            palette='Reds',
                            out_dir=out_dir)
plot

# 03 - Visualize Distribution

In [None]:
data_df.columns

In [None]:
scatter_variable = 'Z_Scored_Cognitive_Baseline'
categorical_variable = 'Disease'

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import numpy as np
import os

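def create_custom_stripplot(data, x_variable, y_variable, hue_change_value, out_dir=None):
    """
    Jittered scatter of y_variable by the categorical x_variable, colored with a
    diverging colormap centered at hue_change_value.

    hue_change_value must lie strictly between the minimum and maximum of
    data[y_variable], as TwoSlopeNorm requires vmin < vcenter < vmax.
    Saves PNG and SVG files to out_dir if provided; otherwise shows the plot.
    """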
    sns.set(style="white")

    # Get/set params for the color mapping
    vcenter = hue_change_value
    vmin, vmax = data[y_variable].min(), data[y_variable].max()
    normalize = mcolors.TwoSlopeNorm(vcenter=vcenter, vmin=vmin, vmax=vmax)
    colormap = cm.coolwarm

    # Create a scatter plot with custom color mapping
    plt.figure(figsize=(6, 6))
    ax = sns.scatterplot(
        x=x_variable,
        y=y_variable,
        data=data,
        c=data[y_variable],
        norm=normalize,
        cmap=colormap,
    )
    ax.axhline(y=vcenter, color='gray', linestyle='--')

    # Tweak the points to mimic `sns.stripplot`
    pts = ax.collections[0]
    pts.set_offsets(pts.get_offsets() + np.c_[np.random.uniform(-0.1, 0.1, len(data)), np.zeros(len(data))])
    ax.margins(x=0.15)

    scalarmappable = cm.ScalarMappable(norm=normalize, cmap=colormap)
    scalarmappable.set_array(data[y_variable])

    # Add a colorbar
    cbar = plt.colorbar(scalarmappable, ax=ax)
    cbar.set_label(y_variable)

    plt.title(f'Strip Plot of {y_variable} by {x_variable} with Hue Change at {hue_change_value}')
    plt.xticks(rotation=45)
    plt.tight_layout()
    sns.despine()
    
    # Save the figure if out_dir is provided
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
        plt.savefig(f"{out_dir}/strip_plot.png", bbox_inches='tight')
        plt.savefig(f"{out_dir}/strip_plot.svg", bbox_inches='tight')
        print(f'Saved to {out_dir}/strip_plot.svg')
    else:
        plt.show()

In [None]:
# Example usage:
create_custom_stripplot(data_df, categorical_variable, scatter_variable, 0, out_dir=out_dir)

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import os

def create_distribution_violin_plot(data, x_variable, y_variable, hue_change_value, out_dir=None):
    """
    Violin plot of y_variable by the categorical x_variable with a swarm plot overlaid.

    hue_change_value is not used by this function. The swarm palette lists two
    colors, so x_variable is expected to have two levels.
    Saves PNG and SVG files to out_dir if provided; otherwise shows the plot.
    """
    sns.set(style="white")
    plt.figure(figsize=(5, 6))

    # Create the violin plot with specified parameters
    sns.violinplot(data=data, x=x_variable, y=y_variable, hue=x_variable, width=.9, cut=2, split=False, inner='box', palette='Greys', alpha=0.1)

    # Create the swarm plot on top of the violin plot
    sns.swarmplot(data=data, x=x_variable, y=y_variable, hue=x_variable, palette=['red', 'blue'], alpha=0.7)

    plt.title(f'Violin and Swarm Plot of {y_variable} by {x_variable}')
    plt.xticks(rotation=45)
    plt.tight_layout()
    
    # Save the figure if out_dir is provided
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
        plt.savefig(f"{out_dir}/distribution_violin.png", bbox_inches='tight')
        plt.savefig(f"{out_dir}/distribution_violin.svg", bbox_inches='tight')
        print(f'Saved to {out_dir}/distribution_violin.svg')
    else:
        plt.show()



In [None]:
# Example usage:
create_distribution_violin_plot(data_df, categorical_variable, scatter_variable, None, out_dir=out_dir)

# 04 - Create Pie Chart

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

def create_pie_chart(series, out_dir=None):
    """
    Takes a pandas Series and generates a pie chart using the Tab10 colormap.
    
    Parameters:
    series (pandas.Series): The input series where the index will be used as labels and the values as sizes.
    out_dir (str, optional): Directory to save the pie chart images. Saves in both PNG and SVG formats if provided.
    """
    # Check if the input is a pandas Series
    if not isinstance(series, pd.Series):
        raise ValueError("The input must be a pandas Series.")
    
    # Set the Tab10 colormap
    cmap = plt.get_cmap("tab10")
    colors = cmap(range(len(series)))
    
    # Generate the pie chart
    fig, ax = plt.subplots(figsize=(8, 8))  # Set the figure size for better visibility
    wedges, texts, autotexts = ax.pie(series, labels=series.index, autopct='%1.1f%%', startangle=140, colors=colors)
    
    # Add legend, which can be edited later by accessing the `legend` object
    ax.legend(wedges, series.index, title="Categories", loc="center left", bbox_to_anchor=(1, 0, 0.5, 1))
    
    plt.setp(autotexts, size=8, weight="bold")
    ax.set_title("Pie Chart")
    plt.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle.
    
    # Save figures if out_dir is provided
    if out_dir is not None:
        os.makedirs(out_dir, exist_ok=True)
        plt.savefig(os.path.join(out_dir, 'pie.png'))
        plt.savefig(os.path.join(out_dir, 'pie.svg'))

    plt.show()

In [None]:
create_pie_chart(data_df['Percent'], out_dir=out_dir)
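
`create_pie_chart` uses the Series index as wedge labels and the values as wedge sizes. If the data holds a categorical column rather than a ready-made `Percent` column, a suitable Series can be derived as sketched below. The `Cohort` values here are made up, and the choice of column is only an example.

```python
import pandas as pd

# Toy data (made up)
toy_df = pd.DataFrame({"Cohort": ["A", "A", "B", "C", "A", "B", "A", "B"]})

# Share of rows in each category, expressed as percentages
percent = toy_df["Cohort"].value_counts(normalize=True).mul(100).rename("Percent")
print(percent)
```

The resulting Series can be passed directly, as in `create_pie_chart(percent, out_dir=out_dir)`.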