## Prep

In [None]:
import pandas as pd
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
import os
from sklearn.preprocessing import StandardScaler
import warnings
import seaborn as sns


# Suppress every warning for the rest of the session so that library notices
# do not clutter the cell output.
warnings.filterwarnings(action='ignore')

# Lift pandas' display limits: DataFrames are shown with all of their columns
# and all of their rows instead of being truncated.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

In [None]:
# Load the preprocessed long-format table (one row per subject per visit).
dat = pd.read_csv("data/processed/jhs_preprocess_0914.csv")

# Visit 1 keeps every column; visits 2 and 3 contribute only the outcome `y`,
# renamed so the three outcomes can sit side by side after merging.
dat_v1 = dat[dat['visit'].eq(1)]
dat_v2 = dat.loc[dat['visit'].eq(2), ['subjid', 'y']].rename(columns={'y': 'y2'})
dat_v3 = dat.loc[dat['visit'].eq(3), ['subjid', 'y']].rename(columns={'y': 'y3'})

# Default (inner) joins on subjid: only subjects that have a row at all three
# visits remain in the merged table.
merged_df = dat_v1.merge(dat_v2, on='subjid').merge(dat_v3, on='subjid')

# Combined outcome: element-wise OR of the three per-visit outcomes, as int.
# NOTE(review): presumably `y` is coded 0/1 so this means "event at any
# visit" -- confirm against the preprocessing step.
merged_df['y_tot'] = (merged_df['y'] | merged_df['y2'] | merged_df['y3']).astype(int)

# Source column -> short display label used on the plot axes. The dict order
# is also the column order of `dat_plt`.
plot_labels = {
    'y_tot': 'Y  ',
    'nbSESpc2score': 'Nb SES',
    'N_UNFAV_CT00': 'Nb unf food store',
    'nbK3paFacilities': 'Nb phys act fac',
    'G_bla_rk': 'Nb rac seg (Black)',
    'nutrition3cat': 'Ind nut categ',
    'PA3cat': 'Ind phys act categ',
}
dat_plt = merged_df[list(plot_labels)]
dat_plt = dat_plt.rename(columns=plot_labels)


## Plot

In [None]:
from itertools import combinations

from scipy.stats import pearsonr

df = dat_plt.copy()

# Pairwise Pearson correlations; pandas excludes missing values pair by pair.
correlation_matrix = df.corr()

# Matching matrix of two-sided p-values. The diagonal (and any pair that
# cannot be tested) stays NaN.
p_values = pd.DataFrame(index=df.columns, columns=df.columns, dtype=float)

# The test is symmetric, so each unordered pair is evaluated once and the
# result is mirrored into both triangles.
for col1, col2 in combinations(df.columns, 2):
    # Use the same pairwise-complete rows that df.corr() uses; pearsonr does
    # not skip NaNs itself, so raw columns with missing values would give a
    # NaN p-value (or an error) that does not match the plotted coefficient.
    pair = df[[col1, col2]].dropna()
    if len(pair) < 2:
        # pearsonr needs at least two observations; leave the cell as NaN.
        continue
    _, p = pearsonr(pair[col1], pair[col2])
    p_values.loc[col1, col2] = p
    p_values.loc[col2, col1] = p

# Create a heatmap of the correlation matrix
plt.figure(figsize=(80, 60))
sns.set(font_scale=14)  # Large font scale to suit the very large figure
ax = sns.heatmap(correlation_matrix, annot=False, cmap="Greens", cbar=True,
                 xticklabels=correlation_matrix.columns,
                 yticklabels=correlation_matrix.columns)

# Annotate every off-diagonal cell with its p-value; '*' marks p < 0.05.
# Look-ups go by label so the text always belongs to the cell it is drawn on,
# even if correlation_matrix and p_values were to differ in column set/order.
cell_labels = list(correlation_matrix.columns)
for i, row_label in enumerate(cell_labels):
    for j, col_label in enumerate(cell_labels):
        if i == j:
            continue

        p_val = p_values.loc[row_label, col_label]
        if p_val < 0.005:
            # Two-decimal formatting would print the impossible "p=0.00".
            text = "p<0.01*"
        elif p_val < 0.05:
            text = f"p={p_val:.2f}*"
        else:
            # Also reached for NaN p-values, which render as "p=nan".
            text = f"p={p_val:.2f}"

        # Heatmap cell (i, j) spans [j, j+1] x [i, i+1]; +0.5 centres the text.
        ax.text(j + 0.5, i + 0.5, text, ha="center", va="center", fontsize=115)

plt.show()

 