In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Directory containing the input CSV; a relative path, so it is resolved
# against the working directory the kernel was started in.
data_dir = 'data'

In [None]:
# Load the cleaned dataset. low_memory=False makes pandas infer each column's
# dtype from the whole file instead of chunk by chunk.
csv_path = os.path.join(data_dir, 'CSF_COG_PET_data_cleaned.csv')
df = pd.read_csv(csv_path, low_memory=False)

In [None]:
# Print the full per-column summary (dtype and non-null count for every
# column); verbose=True prevents pandas from abbreviating wide frames.
df.info(verbose=True, show_counts=True)

In [None]:
# Pairwise Pearson correlation between the numeric (and boolean) columns.
# The columns are selected explicitly because pandas >= 2.0 no longer drops
# non-numeric columns silently in DataFrame.corr() and raises instead;
# select_dtypes keeps the older behaviour and works on every pandas version.
correlation_matrix = df.select_dtypes(include=['number', 'bool']).corr()

In [None]:
# Hide the redundant half of the symmetric matrix. np.triu keeps the diagonal,
# so the self-correlations are masked out as well.
upper_triangle = np.triu(np.ones_like(correlation_matrix, dtype=bool))

# One very large canvas so that every annotated cell stays legible
fig, ax = plt.subplots(figsize=(100, 100))
diverging_cmap = sns.diverging_palette(230, 20, as_cmap=True)

# Annotated lower-triangle heatmap with square cells
sns.heatmap(
    correlation_matrix,
    mask=upper_triangle,
    cmap=diverging_cmap,
    annot=True,
    fmt=".2f",
    square=True,
    linewidths=.5,
    cbar_kws={"shrink": .8},
    ax=ax,
)

ax.set_title('Correlation heatmap of features')
plt.show()

In [None]:
# Diagnosis codes used in the 'DX' column and their display names
dx_labels = {0: 'AD', 1: 'MCI', 2: 'CN'}


def dx_legend_label(label):
    """Translate one seaborn legend entry into its diagnosis name.

    Seaborn renders numeric hue levels as text: an integer-coded column
    produces '0', a float-coded one produces '0.0'. Both spellings are looked
    up in ``dx_labels``. Anything that is not a whole number, or whose code is
    not in ``dx_labels``, is returned unchanged.

    Args:
        label: Legend text as returned by ``ax.get_legend_handles_labels()``.

    Returns:
        The mapped diagnosis name, or ``label`` itself when no mapping applies.
    """
    try:
        code = float(label)
    except (TypeError, ValueError):
        # Non-numeric entries (e.g. a legend title row) pass through as-is
        return label
    if not code.is_integer():
        # Covers fractional values as well as 'nan' / 'inf'
        return label
    return dx_labels.get(int(code), label)


def plot_pairwise_scatter_plot_with_highly_correlated_features(target, features, df):
    """Show one scatter plot per feature against ``target``, coloured by 'DX'.

    Args:
        target: Name of the column plotted on the y-axis.
        features: Iterable of column names; each one is plotted on the x-axis
            of its own figure.
        df: DataFrame containing ``target``, every entry of ``features`` and a
            'DX' column that drives the point colour.

    Returns:
        None. Each figure is displayed with ``plt.show()`` and then closed.
    """
    for col in features:
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.scatterplot(data=df, x=col, y=target, hue='DX', palette='Set1', ax=ax)
        ax.set_title(f'Pairwise scatter plot b/w {col} and {target}')
        ax.set_xlabel(col)
        ax.set_ylabel(target)

        # Swap the numeric DX codes in the legend for readable names
        handles, labels = ax.get_legend_handles_labels()
        new_labels = [dx_legend_label(text) for text in labels]
        ax.legend(handles=handles, labels=new_labels, title='DX')

        plt.show()
        # Release the figure so a long feature list does not pile up open figures
        plt.close(fig)

In [None]:
# Rank every other column by the absolute value of its correlation with 'DX',
# strongest first, then plot the ten leading columns against 'DX'.
dx_abs_corr = correlation_matrix['DX'].abs()
highly_correlated_features = dx_abs_corr.sort_values(ascending=False).drop('DX')
top_dx_features = highly_correlated_features.index[:10].tolist()
print(f"Top 10 features highly correlated with 'DX': {top_dx_features}")
plot_pairwise_scatter_plot_with_highly_correlated_features('DX', top_dx_features, df)

In [None]:
# Rank every other column by the absolute value of its correlation with
# 'EcogPtTotal_bl', strongest first, and report the ten leading columns.
ecog_abs_corr = correlation_matrix['EcogPtTotal_bl'].abs()
highly_correlated_features = ecog_abs_corr.sort_values(ascending=False).drop('EcogPtTotal_bl')
top_ecog_features = highly_correlated_features.index[:10].tolist()
print(f"Top 10 Features highly correlated with 'EcogPtTotal_bl': {top_ecog_features}")

In [None]:
# Rank every other column by the absolute value of its correlation with
# 'MMSE_bl', strongest first, and report the ten leading columns.
mmse_abs_corr = correlation_matrix['MMSE_bl'].abs()
highly_correlated_features = mmse_abs_corr.sort_values(ascending=False).drop('MMSE_bl')
top_mmse_features = highly_correlated_features.index[:10].tolist()
print(f"Top 10 Features highly correlated with 'MMSE_bl': {top_mmse_features}")

In [None]:
# Rank every other column by the absolute value of its correlation with
# 'LDELTOTAL_BL', strongest first, and report the ten leading columns.
ldel_abs_corr = correlation_matrix['LDELTOTAL_BL'].abs()
highly_correlated_features = ldel_abs_corr.sort_values(ascending=False).drop('LDELTOTAL_BL')
top_ldel_features = highly_correlated_features.index[:10].tolist()
print(f"Top 10 Features highly correlated with 'LDELTOTAL_BL': {top_ldel_features}")

In [None]:
# Rank every other column by the absolute value of its correlation with
# 'ABETA_bl', strongest first, then plot the ten leading columns against it.
abeta_abs_corr = correlation_matrix['ABETA_bl'].abs()
highly_correlated_features = abeta_abs_corr.sort_values(ascending=False).drop('ABETA_bl')
top_abeta_features = highly_correlated_features.index[:10].tolist()
print(f"Top 10 Features highly correlated with 'ABETA_bl': {top_abeta_features}")
plot_pairwise_scatter_plot_with_highly_correlated_features('ABETA_bl', top_abeta_features, df)

In [None]:
# Rank every other column by the absolute value of its correlation with
# 'PTAU_bl', strongest first, then plot the ten leading columns against it.
ptau_abs_corr = correlation_matrix['PTAU_bl'].abs()
highly_correlated_features = ptau_abs_corr.sort_values(ascending=False).drop('PTAU_bl')
top_ptau_features = highly_correlated_features.index[:10].tolist()
print(f"Top 10 Features highly correlated with 'PTAU_bl': {top_ptau_features}")
plot_pairwise_scatter_plot_with_highly_correlated_features('PTAU_bl', top_ptau_features, df)

In [None]:
# Rank every other column by the absolute value of its correlation with
# 'TAU_bl', strongest first, then plot the ten leading columns against it.
tau_abs_corr = correlation_matrix['TAU_bl'].abs()
highly_correlated_features = tau_abs_corr.sort_values(ascending=False).drop('TAU_bl')
top_tau_features = highly_correlated_features.index[:10].tolist()
print(f"Top 10 Features highly correlated with 'TAU_bl': {top_tau_features}")
plot_pairwise_scatter_plot_with_highly_correlated_features('TAU_bl', top_tau_features, df)