In [None]:
import csv
import os
import warnings
import pickle
import pandas as pd
import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator, MaxNLocator
import matplotlib.ticker as ticker
from matplotlib import colors
import matplotlib.ticker as plticker
import sklearn
from sklearn.neighbors import KernelDensity
import matplotlib as mpl
import matplotlib.gridspec as grid_spec
from wordcloud import WordCloud
import matplotlib.colors as mcolors
import seaborn as sns
from textwrap import wrap
from scipy.stats import pearsonr
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from sklearn.cross_decomposition import PLSRegression
from sklearn.model_selection import GridSearchCV
from matplotlib.ticker import FormatStrFormatter
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

In [None]:
# Match loadings to correlations and filter, do not reindex
loadings_match_to_corr = pd.merge(
    corr_full_sorted,
    weighted_loadings_full_matched,
    left_index=True,
    right_index=True,
    how='left'
)

negative_corr = loadings_match_to_corr[loadings_match_to_corr['correlations'] < 0]
positive_corr = loadings_match_to_corr[loadings_match_to_corr['correlations'] >= 0]
negative_load = loadings_match_to_corr[loadings_match_to_corr['correlations'] < 0]
positive_load = loadings_match_to_corr[loadings_match_to_corr['correlations'] >= 0]

In [None]:
corr_full_sorted_renamed_df = pd.read_csv('/mental_health/single_split/corr_full_sorted_renamed_df.csv')
significant_corr_renamed_df = pd.read_csv('/mental_health/single_split/significant_corr_renamed_df.csv')
loadings_renamed_df = pd.read_csv('/mental_health/single_split/loadings_renamed_df.csv')
corr_full_sorted = pd.read_csv('/mental_health/single_split/mh_pls_model_16comp_corr_full_sorted.csv', index_col=0)
weighted_loadings_full_matched = pd.read_csv('/mental_health/single_split/mh_pls_model_16comp_weighted_loadings_full_matched.csv', index_col=0)

In [None]:
# Rename diagnosis features to readable labels for the figure.
# One shared mapping (previously pasted four times) keeps all panels consistent.
DIAGNOSIS_LABELS = {
    "Diagnoses 'G'": "Diseases of the nervous system",
    "Diagnoses: Neurological problem, NS injury, epilepsy": "Neurological problems, nervous system injury, epilepsy",
    "Diagnoses 'F'": "Mental and behavioural disorders",
    "Diagnoses: Stress, insomnia, migraine, nervous/mental problems": "Stress, insomnia, migraine, nervous/mental problems",
    'Diagnoses: Depression': 'Depression',
}

negative_corr_renamed = negative_corr.rename(index=DIAGNOSIS_LABELS)
negative_load_renamed = negative_load.rename(index=DIAGNOSIS_LABELS)
positive_corr_renamed = positive_corr.rename(index=DIAGNOSIS_LABELS)
positive_load_renamed = positive_load.rename(index=DIAGNOSIS_LABELS)

In [None]:
# Number of strongest features (by absolute Pearson r) shown on each side.
TOP_N = 20

# Select the TOP_N features with the largest absolute correlation on each side;
# the loading frames are reindexed to the same features so both bar layers align.
negative_corr_sorted = negative_corr_renamed.reindex(
    negative_corr_renamed['correlations'].abs().sort_values(ascending=False).head(TOP_N).index
)
negative_load_sorted = negative_load_renamed.reindex(negative_corr_sorted.index)
positive_corr_sorted = positive_corr_renamed.reindex(
    positive_corr_renamed['correlations'].abs().sort_values(ascending=False).head(TOP_N).index
)
positive_load_sorted = positive_load_renamed.reindex(positive_corr_sorted.index)


fig, ax = plt.subplots(1, 2, figsize=(10, 40), dpi=200, sharey='col')


def get_edge_color(value):
    """Edge colour for a loading bar.

    Both positive and negative loadings are currently drawn in black; the hook is
    kept so a sign-dependent colour can be reintroduced in one place.
    """
    return 'black'


negative_edge_colors_sorted = negative_load_sorted['Loadings'].apply(get_edge_color).tolist()
positive_edge_colors_sorted = positive_load_sorted['Loadings'].apply(get_edge_color).tolist()

font_size = 55
height = 0.7

# Left panel: negative correlations; right panel: non-negative correlations.
# Each panel overlays filled correlation bars with outlined loading bars.
for panel_ax, corr_df, load_df, edge_colors in (
    (ax[0], negative_corr_sorted, negative_load_sorted, negative_edge_colors_sorted),
    (ax[1], positive_corr_sorted, positive_load_sorted, positive_edge_colors_sorted),
):
    panel_ax.barh(corr_df.index, corr_df['correlations'],
                  color='#79AF9799', alpha=0.4, height=height)
    panel_ax.barh(load_df.index, load_df['Loadings'],
                  color='none', edgecolor=edge_colors, linewidth=1, height=height)

ax[0].tick_params(axis='y', labelsize=font_size)
ax[0].tick_params(axis='x', labelsize=font_size, rotation=50)
ax[0].spines['left'].set_visible(False)
ax[0].invert_yaxis()

ax[1].tick_params(axis='y', labelsize=font_size, labelright=True, labelleft=False)
ax[1].tick_params(axis='x', labelsize=font_size, rotation=50)
ax[1].invert_yaxis()

# Remove the y-axis ticks on the right side while keeping the y-axis line
ax[1].yaxis.set_ticks_position('none')

# Hide the spines for a cleaner look
for axs in ax:
    for side in ["top", "bottom", "right"]:
        axs.spines[side].set_visible(False)

# Common x-axis label for both correlations and loadings
fig.text(0.55, 0.03, "Pearson's $r$ and Loadings", ha='center', fontsize=70)

# x-limits leave headroom beyond the most extreme bar on each side.
min_negative = min(negative_corr_sorted['correlations'].min(), negative_load_sorted['Loadings'].min())
max_positive = max(positive_corr_sorted['correlations'].max(), positive_load_sorted['Loadings'].max())
ax[0].set_xlim(min_negative * 1.5, 0)
ax[1].set_xlim(-0.05, max_positive * 1.1)

# Format the x-axis labels to two decimal places
ax[0].xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
ax[1].xaxis.set_major_formatter(FormatStrFormatter('%.2f'))

ax[0].axvline(x=0, color='grey', linewidth=1.5)
ax[1].axvline(x=0, color='grey', linewidth=1.5)

ax[0].spines['right'].set_visible(False)
ax[1].spines['left'].set_visible(False)

# Set x-tick intervals
ax[0].xaxis.set_major_locator(MultipleLocator(0.05))
ax[1].xaxis.set_major_locator(MultipleLocator(0.1))

# Adjust the spacing between the subplots
fig.subplots_adjust(wspace=0.1)

plt.show()

## Group features into mental health domains

In [None]:
# Mental-health domains -> the (renamed) feature labels that belong to each.
# Labels must match the `features` column of the correlation/loading tables
# exactly, so spelling here is load-bearing. Insertion order is the domain order
# used downstream (e.g. for palette assignment).
# NOTE(review): 'Lifertime frequency of taking cannabis' and
# 'Recent feelings or nervousness or anxiety' look like typos; check them against
# the data labels (and the "Missing features" check) before correcting them.
mh_domains = {
    'Mental distress': [
        'Ever sought or received professional help for mental distress',
        'Any distress',
        'Ever suffered mental distress preventing usual activities',
        'Seen doctor (GP) for nerves, anxiety, tension or depression',
        'Seen a psychiatrist for nerves, anxiety, tension or depression',
    ],
    'Depression': [
        'Depression single episode',
        'RDS-4',
        'PHQ-9',
        'Subthreshold depression',
        'Depression triggered by loss',
        'Current severe depression',
        'Current depression',
        'Ever unenthusiastic/disinterested for a whole week',
        'Depression ever',
        'Recurrent depression',
        'Ever had prolonged feelings of sadness or depression',
        'Recent poor appetite or overeating',
        'PDS',
        'Ever had prolonged loss of interest in normal activities',
        'Recent changes in speed/amount of moving or speaking',
        'Recent feelings of depression',
        'Recent feelings of inadequacy',
        'Recent feelings of tiredness or low energy',
        'Recent lack of interest or pleasure in doing things',
        'Recent thoughts of suicide or self-harm',
        'Recent trouble concentrating on things',
        'Trouble falling or staying asleep, or sleeping too much',
        'Frequency of depressed mood in last 2 weeks',
        'Frequency of unenthusiasm / disinterest in last 2 weeks',
        'Frequency of tenseness / restlessness in last 2 weeks',
        'Frequency of tiredness / lethargy in last 2 weeks',
        'Ever depressed for a whole week',
    ],
    'Diagnoses': [
        'Neurological problems, nervous system injury, epilepsy',
        'Mental and behavioural disorders',
        'Diseases of the nervous system',
        'Anxiety/panic attacks',
        'Stress, insomnia, migraine, nervous/mental problems',
        'Depression',
    ],
    'Mania': [
        'Ever had period of extreme irritability',
        'Ever had period of mania / excitability',
        'Ever manic/hyper for 2 days',
        'Ever highly irritable/argumentative for 2 days',
        'Bipolar I',
        'Bipolar II',
    ],
    'Anxiety': [
        'Ever felt worried, tense, or anxious for most of a month or longer',
        'Ever worried more than most people would in similar situation',
        'Recent easy annoyance or irritability',
        'Recent feelings of foreboding',
        'Recent feelings or nervousness or anxiety',
        'Recent inability to stop or control worrying',
        'Recent restlessness',
        'Recent trouble relaxing',
        'Recent worrying too much about different things',
        'GAD ever',
        'GAD-7',
        'Current GAD',
        'Current GAD moderate',
        'Current GAD mild',
        'Current GAD severe',
    ],
    'Neuroticism': [
        'N-12',
        'Mood swings',
        'Miserableness',
        'Irritability',
        'Sensitivity / hurt feelings',
        'Fed-up feelings',
        'Nervous feelings',
        'Worrier / anxious feelings',
        "Tense / 'highly strung'",
        'Worry too long after embarrassment',
        "Suffer from 'nerves'",
        'Loneliness, isolation',
        'Guilty feelings',
        'Risk taking',
    ],
    'Addictions': [
        'Ever addicted to any substance or behaviour',
        'Substance addiction',
        'Current addiction',
        'Physical alcohol dependence ever',
    ],
    'Alcohol and cannabis use': [
        'Cannabis ever',
        'Lifertime frequency of taking cannabis',
        '(log)AUDIT',
        '(log)AUDIT-C',
        '(log)AUDIT-P',
        'Alcohol dependence (AUDIT≥15)',
        'Hazardous alcohol use (AUDIT≥8)',
        'Frequency of consuming six or more units of alcohol',
        'Frequency of drinking alcohol',
        'Amount of alcohol drunk on a typical drinking day',
        'Cannabis daily',
    ],
    'Unusual/psychotic experiences': [
        'Ever believed in an un-real conspiracy against self',
        'Ever believed in un-real communications or signs',
        'Ever heard an un-real voice',
        'Ever seen an un-real vision',
        'Recent unusual experience',
        'Unusual experience',
    ],
    'Traumatic events': [
        'Catastrophic trauma',
        'Childhood adverse events',
        'Adult adverse events',
        'Felt loved as a child',
        'Physically abused by family as a child',
        'Sexually molested as a child',
        'Someone to take to doctor when needed as a child',
        'Been in a confiding relationship as an adult',
        'Physical violence by partner or ex-partner as an adult',
        'Belittlement by partner or ex-partner as an adult',
        'Sexual interference by partner or ex-partner without consent as an adult',
        'Able to pay rent/mortgage as an adult',
        'Victim of sexual assault',
        'Victim of physically violent crime',
        'Been in serious accident believed to be life-threatening',
        'Witnessed sudden violent death',
        'Diagnosed with life-threatening illness',
        'Been involved in combat or exposed to war-zone',
        'Repeated disturbing thoughts of stressful experience in past month',
        'Felt very upset when reminded of stressful experience in past month',
        'PTSD',
        'PCL-6',
        'Avoided activities or situations because of previous stressful experience in past month',
        'Felt hated by family member as a child',
    ],
    'Self-harm behaviours': [
        'Lifetime frequency of contemplating self-harm',
        'Self-harm: Ever thought life not worth living',
        'Ever self-harmed',
        'Ever self-harmed (non-suicidal)',
        'Ever attempted suicide',
        "Frequency of 'life not worth living' thoughts",
    ],
    'Happiness and subjective well-being': [
        'Belief that own life is meaningful',
        'Health satisfaction',
        'General happiness',
        'General happiness with own health',
        'Happiness',
        'Family relationship satisfaction',
        'Friendships satisfaction',
        'Financial situation satisfaction',
        'Wellbeing',
    ],
}

# Total number of feature labels across all domains (displayed as the cell output).
sum(map(len, mh_domains.values()))

In [None]:
# Correlation + Loadings: attach each feature's value to its mental-health domain.
def _collect_domain_values(domains, source_df, source_col, out_col):
    """Return one row per (domain, feature) pair whose feature occurs in ``source_df``.

    Parameters
    ----------
    domains : dict[str, list[str]]
        Domain name -> feature labels (e.g. ``mh_domains``); iteration order is kept.
    source_df : DataFrame
        Must have a ``'features'`` column and ``source_col``.
    source_col : str
        Column of ``source_df`` holding the value to copy.
    out_col : str
        Name of the value column in the result.

    Returns
    -------
    DataFrame with columns ``['feature', out_col, 'domain']``. Features missing
    from ``source_df`` are skipped; if a feature occurs several times, its first
    occurrence is used.
    """
    # One hash lookup per feature instead of a full boolean scan per feature.
    lookup = (
        source_df
        .drop_duplicates('features', keep='first')
        .set_index('features')[source_col]
    )
    records = [
        {'feature': feature, out_col: lookup[feature], 'domain': domain}
        for domain, features in domains.items()
        for feature in features
        if feature in lookup.index
    ]
    # Explicit columns keep the schema stable even when nothing matches.
    return pd.DataFrame(records, columns=['feature', out_col, 'domain'])


domain_corr_df = _collect_domain_values(mh_domains, corr_full_sorted_renamed_df, 'correlations', 'correlation')
domain_loadings_df = _collect_domain_values(mh_domains, loadings_renamed_df, 'Loadings', 'loading')

domain_corr_df.to_csv('/mental_health/single_split/domain_corr_df.csv', index=False)
domain_loadings_df.to_csv('/mental_health/single_split/domain_loadings_df.csv', index=False)

In [None]:
# Reload the domain tables written above so later cells can run without recomputing them.
domain_corr_df, domain_loadings_df = (
    pd.read_csv(path)
    for path in (
        '/mental_health/single_split/domain_corr_df.csv',
        '/mental_health/single_split/domain_loadings_df.csv',
    )
)

In [None]:
# Sanity check: list domain features that have no match in the correlation table
# (such features are silently skipped when building domain_corr_df).
mh_features = {feature for features in mh_domains.values() for feature in features}

# Features available in the correlation table
corr_features = set(corr_full_sorted_renamed_df['features'])

missing_features = mh_features - corr_features
# Sorted so the printed output is stable across runs (set order depends on string hashing).
print("Missing features:", sorted(missing_features))

In [None]:
# Domain names, in order of first appearance
pd.unique(domain_corr_df['domain'])

In [None]:
# Pool out-of-fold predictions and observed g across the CV folds for the scatterplot.
# Rows are stacked fold by fold in the order of `folds`.
folds = ["0", "1", "2", "3", "4"]
y_pred_pooled = [
    pd.read_csv(f'/mental_health/folds/fold_{fold}/g_pred/g_pred_mh_fold_{fold}.csv')[['g_pred_mh']]
    for fold in folds
]
y_true_pooled = [
    pd.read_csv(f'/mental_health/folds/fold_{fold}/suppl/g_test_matched_fold_{fold}.csv')['g']
    for fold in folds
]

# Concatenate once after reading (instead of re-concatenating on every iteration).
y_pred_pooled_df = pd.concat(y_pred_pooled, ignore_index=True).rename(columns={'g_pred_mh': 'y_pred'})
y_true_pooled_df = pd.concat(y_true_pooled, ignore_index=True).to_frame('y_true')
# Predictions and targets are paired by row position, so their lengths must agree.
assert len(y_pred_pooled_df) == len(y_true_pooled_df), "pooled predictions and targets differ in length"

y_true = y_true_pooled_df['y_true'].copy()
y_pred = y_pred_pooled_df['y_pred'].copy()

# Visualize loadings of the g-factor

In [None]:
# ESEM loadings in long format (item / latent / value columns) plus a wide item x factor view.
g_loadings = pd.read_csv("/Plots_and_Tables/esem_loadings.csv")
g_loadings_wide = (
    g_loadings
    .pivot(index='item', columns='latent', values='value')
    .reset_index()
)

In [None]:
# Shorten the cognitive-test item names so the labels fit on the circular plot.
rename_dict = {
    '(log)Reaction time': "(log)RT",
    'Fluid intelligence score': "Fluid intelligence",
    'Numeric memory: Maximum digits remembered correctly': "Numeric memory",
    '(log)Trail making test: Duration to complete numeric path': "(log)TMT duration, numeric",
    '(log)Trail making test: Duration to complete alphabetic path': "(log)TMT duration, alphabetic",
    'Symbol digit substitution: Number of correct matches': "SDS: N matches correct",
    'Paired associate learning: Number of correct pairs': "Paired associate learning: N correct",
    'Tower rearranging: Number of puzzles correct': "Tower rearranging: N correct",
    'Matrix pattern completion: Number of puzzles correct': "Matrix completion: N correct",
    '(logx+1)Pairs matching: Incorrect matches': "(log$_{x+1}$)Pairs matching",
    'Picture vocabulary: Specific cognitive ability': "Picture vocabulary",
    'Prospective memory: Initial answer': "Prospective memory",
}
g_loadings_rename = (
    g_loadings
    .set_index('item')
    .rename(index=rename_dict)
    .reset_index()
)

# Factor labels also carry each factor's loading (hard-coded values — presumably
# the factor-on-g loadings from the ESEM output; confirm before reusing).
g_loadings_rename_factor = (
    g_loadings_rename
    .set_index('latent')
    .rename(index={
        'Factor 1': 'Factor 1\n-0.75',
        'Factor 2': 'Factor 2\n0.7',
        'Factor 3': 'Factor 3\n0.37',
        'Factor 4': 'Factor 4\n0.6',
    })
    .reset_index()
)

In [None]:
# Inputs for the circular (polar) loading plot.
# NOTE(review): bars are assigned to factor groups by position (see the IDXS loop in
# the plotting cell), which assumes the rows of g_loadings are ordered by `latent`
# in the same sorted order groupby uses — confirm.
GROUP = g_loadings_rename_factor["latent"].values          # factor label of each item
VALUES = g_loadings_rename["value"].values.astype(float)   # signed item loadings
LABELS = [
    f"{item} ({loading:.2f})"
    for item, loading in zip(g_loadings_rename['item'], g_loadings_rename['value'])
]
OFFSET = np.pi / 2  # theta offset for the polar axes and label rotation

# PAD empty bar slots per factor group separate the groups visually
# (the plotting cell places them before each group's bars).
PAD = 3
ANGLES_N = len(VALUES) + PAD * len(np.unique(GROUP))
ANGLES = np.linspace(0, 2 * np.pi, num=ANGLES_N, endpoint=False)
WIDTH = 2 * np.pi / len(ANGLES)

# Number of items per factor, in sorted factor order (groupby sorts its keys).
GROUPS_SIZE = g_loadings_rename.groupby("latent").size().tolist()

def get_label_rotation(angle, offset):
    """Return ``(rotation_deg, alignment)`` for a radial label placed at ``angle``.

    ``angle`` and ``offset`` are in radians; the rotation is returned in degrees,
    as ``ax.text`` expects. Labels with ``angle <= pi`` are flipped by 180 degrees
    and reported as right-aligned; all others are left-aligned.
    """
    rotation = np.rad2deg(angle + offset)
    if angle <= np.pi:
        return rotation + 180, "right"
    return rotation, "left"


def add_labels(angles, values, labels, offset, ax, radius_offset=1.6, fontsize=15,
               min_height=0.1):
    """Write one rotated text label per bar on the polar axes ``ax``.

    Parameters
    ----------
    angles, values, labels : sequences
        Bar angle (radians), bar value and label text; iteration stops at the
        shortest sequence (``zip`` semantics).
    offset : float
        Theta offset (radians) used to compute each label's rotation.
    ax : matplotlib polar Axes (anything with a ``text`` method).
    radius_offset, fontsize, min_height : float
        Layout knobs. The label radius is ``max(value, min_height) / 3 + radius_offset``.
    """
    for angle, value, label in zip(angles, values, labels):
        # The alignment from get_label_rotation is currently not used: labels are centred.
        rotation, _alignment = get_label_rotation(angle, offset)
        ax.text(
            x=angle,
            y=max(value, min_height) / 3 + radius_offset,
            s=label,
            ha='center',
            va='center',
            rotation=rotation,
            rotation_mode="anchor",
            fontsize=fontsize,
            color="black")


def add_factor_names(groups, angles, offset, ax, group_sizes=None, pad=None,
                     radius=0.9, fontsize=17):
    """Write each factor name at the mean angle of that factor's bars.

    Parameters
    ----------
    groups : sequence of str
        Factor names, in the same order as ``group_sizes``.
    angles : ndarray
        Angle (radians) of every bar slot, including the padding slots.
    offset : float
        Theta offset (radians) used for the text rotation.
    ax : matplotlib polar Axes (anything with a ``text`` method).
    group_sizes : sequence of int, optional
        Bars per group; defaults to the global ``GROUPS_SIZE``.
    pad : int, optional
        Empty slots placed before each group; defaults to the global ``PAD``.
    radius, fontsize : float
        Radial position and font size of the names.
    """
    if group_sizes is None:
        group_sizes = GROUPS_SIZE
    if pad is None:
        pad = PAD
    # Mirror the bar layout: every group is preceded by `pad` empty slots.
    # (Replaces a hard-coded cursor start of 2 that only worked for PAD == 3.)
    cursor = 0
    for group, size in zip(groups, group_sizes):
        start_idx = cursor + pad
        mid_angle = np.mean(angles[start_idx:start_idx + size])
        ax.text(
            x=mid_angle,
            y=radius,  # inside the inner hole of the plot
            s=group,
            ha='center',
            va='center',
            rotation=np.rad2deg(mid_angle + offset) - 90,
            rotation_mode="anchor",
            fontsize=fontsize,
            color="black"
        )
        cursor += size + pad


### Loading circle bar plot

In [None]:
fig, ax = plt.subplots(figsize=(20, 15), subplot_kw={"projection": "polar"})

# Bar slot indices: each factor group occupies `size` slots preceded by PAD empty ones.
offset = 0
IDXS = []
for size in GROUPS_SIZE:
    IDXS += list(range(offset + PAD, offset + size + PAD))
    offset += size + PAD

ax.set_theta_offset(OFFSET)
# Radial axis starts at 0.7 while bars start at r=1, leaving a hole in the middle.
ax.set_ylim(0.7, max(abs(VALUES)) + 1)
ax.set_frame_on(False)
ax.xaxis.grid(False)
ax.yaxis.grid(False)
ax.set_xticks([])
ax.set_yticks([])

# Colour by the sign of the loading; bar length uses |loading|.
# Named `bar_colors` so the `matplotlib.colors` module imported at the top is not shadowed.
bar_colors = ['#B2474599' if v > 0 else '#6A6599FF' for v in VALUES]

# Very small loadings get a minimum bar length so they remain visible.
MIN_BAR_HEIGHT = 0.1
ax.bar(
    ANGLES[IDXS], [max(abs(v), MIN_BAR_HEIGHT) for v in VALUES], width=WIDTH,
    color=bar_colors,
    edgecolor="white", linewidth=1, alpha=0.5,
    bottom=1,  # inner radius of the bars (creates the central hole)
)

# Add labels after adding the bars
add_labels(ANGLES[IDXS], abs(VALUES), LABELS, OFFSET, ax)
add_factor_names(np.unique(GROUP), ANGLES, OFFSET, ax)

# Italic 'g' in the centre: r=0.7 is the inner radial limit, i.e. the plot centre.
ax.text(
    x=0.7,
    y=0.7,
    s='g',
    ha='center',
    va='center',
    fontsize=50,
    fontstyle='italic',
    color='black'
)

ax.set_title('', y=1.25, fontsize=40)

plt.savefig("/ESEM.png",
            bbox_inches="tight",
            pad_inches=1,
            transparent=False,
            facecolor="w",
            edgecolor='w',
            orientation='landscape',
            format='png')

plt.show()


# Plot observed vs predicted g-factor

In [None]:
# Scatterplot: observed vs predicted g, one panel per held-out fold.
folds = ["0", "1", "2", "3", "4"]

# squeeze=False + flatten always yields a 1-D array of Axes (also for a single fold).
fig, axes = plt.subplots(1, len(folds), figsize=(18.54, 3.54), dpi=600, squeeze=False)
axes = axes.flatten()

for i, fold in enumerate(folds):
    y_pred = pd.read_csv(f'/mental_health/folds/fold_{fold}/g_pred/g_pred_mh_fold_{fold}.csv')['g_pred_mh']
    y_true = pd.read_csv(f'/mental_health/folds/fold_{fold}/suppl/g_test_matched_fold_{fold}.csv')['g']
    corr, p = pearsonr(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)

    # Colour each point by its distance from the centroid (mean observed, mean predicted);
    # the observed coordinate is centred on the observed mean.
    dist_i = np.sqrt((y_true - y_true.mean())**2 + (y_pred - y_pred.mean())**2)
    sns.scatterplot(x=y_true, y=y_pred, c=dist_i, cmap='Greens', s=40, alpha=0.8, ax=axes[i])
    sns.regplot(x=y_true, y=y_pred, line_kws={"color": "red", "linewidth": 1}, scatter=False, ax=axes[i])
    sns.despine(top=True, right=True, ax=axes[i])
    axes[i].set_xlabel('Observed Cognitive Ability (Z)', fontsize=15)
    axes[i].set_ylabel('Predicted Cognitive Ability (Z)', fontsize=15)
    axes[i].tick_params(axis='x', labelsize=16)
    axes[i].tick_params(axis='y', labelsize=16)
    axes[i].set_title(f'Fold {fold}', fontsize=20, y=1.1)

    # Annotate with Pearson r and R². Format specs give a fixed two decimals and work
    # for NumPy and plain Python floats alike (Python floats have no `.round` method).
    axes[i].text(0.05, 1.0, f'$r$ = {corr:.2f}', transform=axes[i].transAxes, fontsize=15)
    axes[i].text(0.05, 0.91, f'$R$² = {r2:.2f}', transform=axes[i].transAxes, fontsize=15)
    axes[i].xaxis.set_major_locator(MultipleLocator(1))

# Remove any unused panels (none when the grid matches len(folds)).
for extra_ax in axes[len(folds):]:
    fig.delaxes(extra_ax)

plt.tight_layout()

plt.savefig("/MH/greal_gpred_r_MH_5Folds.png",
            bbox_inches="tight",
            pad_inches=1,
            transparent=False,
            facecolor="w",
            edgecolor='w',
            orientation='landscape')

plt.show()

# Final Plot

In [None]:
# Reload every table used by the final figure so this section can run on its own.
(domain_corr_df,
 domain_loadings_df,
 corr_full_sorted_renamed_df,
 significant_corr_renamed_df,
 loadings_renamed_df) = (
    pd.read_csv(path)
    for path in (
        '/mental_health/single_split/domain_corr_df.csv',
        '/mental_health/single_split/domain_loadings_df.csv',
        '/mental_health/single_split/corr_full_sorted_renamed_df.csv',
        '/mental_health/single_split/significant_corr_renamed_df.csv',
        '/mental_health/single_split/loadings_renamed_df.csv',
    )
)
g_loadings = pd.read_csv("/Plots_and_Tables/esem_loadings.csv")
g_loadings_wide = (
    g_loadings
    .pivot(index='item', columns='latent', values='value')
    .reset_index()
)

In [None]:
# Shorten the cognitive-test item names so the labels fit on the circular plot.
rename_dict = {
    '(log)Reaction time': "(log)RT",
    'Fluid intelligence score': "Fluid intelligence",
    'Numeric memory: Maximum digits remembered correctly': "Numeric memory",
    '(log)Trail making test: Duration to complete numeric path': "(log)TMT duration, numeric",
    '(log)Trail making test: Duration to complete alphabetic path': "(log)TMT duration, alphabetic",
    'Symbol digit substitution: Number of correct matches': "SDS: N matches correct",
    'Paired associate learning: Number of correct pairs': "Paired associate learning: N correct",
    'Tower rearranging: Number of puzzles correct': "Tower rearranging: N correct",
    'Matrix pattern completion: Number of puzzles correct': "Matrix completion: N correct",
    '(logx+1)Pairs matching: Incorrect matches': "(log$_{x+1}$)Pairs matching",
    'Picture vocabulary: Specific cognitive ability': "Picture vocabulary",
    'Prospective memory: Initial answer': "Prospective memory",
}
g_loadings_rename = (
    g_loadings
    .set_index('item')
    .rename(index=rename_dict)
    .reset_index()
)

# Factor labels also carry each factor's loading (hard-coded values — presumably
# the factor-on-g loadings from the ESEM output; confirm before reusing).
g_loadings_rename_factor = (
    g_loadings_rename
    .set_index('latent')
    .rename(index={
        'Factor 1': 'Factor 1\n-0.75',
        'Factor 2': 'Factor 2\n0.7',
        'Factor 3': 'Factor 3\n0.37',
        'Factor 4': 'Factor 4\n0.6',
    })
    .reset_index()
)

## Define helper functions and data frames for the final figure

In [None]:
# Dot plot helpers

# Hand-tuned line breaks for labels that the generic rule below splits poorly.
_MANUAL_LABEL_BREAKS = {
    "Diseases of the nervous system": "Diseases of the\nnervous system",
    "Worry too long after embarrassment": "Worry too long\nafter embarrassment",
    "Ever addicted to any substance or behaviour": "Ever addicted to any\nsubstance or behaviour",
    "Ever believed in un-real communications or signs": "Ever believed in un-real\ncommunications or signs",
    "Frequency of 'life not worth living' thoughts": "Frequency of\n'life not worth living'\nthoughts",
    "Recent unusual experience": "Recent\nunusual experience",
    "Ever manic/hyper for 2 days": "Ever manic/hyper\nfor 2 days",
    'Lifetime frequency of contemplating self-harm': 'Lifetime frequency\nof contemplating\nself-harm',
    'Ever self-harmed (non-suicidal)': 'Ever self-harmed\n(non-suicidal)',
    'Ever had period of extreme irritability': 'Ever had period\nof extreme irritability',
    'Friendships satisfaction': 'Friendships\nsatisfaction',
}


def split_text(text):
    """Insert line breaks into a feature label so it fits on the dot plot.

    Labels listed in ``_MANUAL_LABEL_BREAKS`` get their hand-tuned version.
    Otherwise a label of more than three words is split at word boundaries into
    three lines of ``n // 3`` words each, the last line taking the remainder;
    labels of three words or fewer are returned unchanged.
    """
    if text in _MANUAL_LABEL_BREAKS:
        return _MANUAL_LABEL_BREAKS[text]
    words = text.split()
    if len(words) <= 3:
        return text
    chunk = len(words) // 3
    lines = (words[:chunk], words[chunk:2 * chunk], words[2 * chunk:])
    return '\n'.join(' '.join(line) for line in lines)

# JAMA-style palette: one colour per mental-health domain (in domain order).
jama_palette = [
    "#374E55FF", "#DF8F44FF", "#00A1D5FF", "#B24745FF",
    "#79AF97FF", "#6A6599FF", "#80796BFF", "#EFC000FF",
    "#7AA6DCFF", "#003C67FF", "#8F7700FF", "#A20056FF"]
domain_names = domain_corr_df['domain'].unique()
# zip() would silently drop domains beyond the palette length, leaving them without a colour.
if len(domain_names) > len(jama_palette):
    raise ValueError(
        f"{len(domain_names)} domains but jama_palette has only {len(jama_palette)} colours"
    )
palette_dict = dict(zip(domain_names, jama_palette))

# Flag features whose correlation is in the significant set, then keep only those rows.
domain_corr_df['significant'] = domain_corr_df['feature'].isin(significant_corr_renamed_df['features'])
significant_domain_corr_df = domain_corr_df[domain_corr_df['significant']]


###################### Scatterplot
# Pool out-of-fold predictions and observed g across the CV folds.
# Rows are stacked fold by fold in the order of `folds`.
folds = ["0", "1", "2", "3", "4"]
y_pred_pooled = [
    pd.read_csv(f'/mental_health/folds/fold_{fold}/g_pred/g_pred_mh_fold_{fold}.csv')[['g_pred_mh']]
    for fold in folds
]
y_true_pooled = [
    pd.read_csv(f'/mental_health/folds/fold_{fold}/suppl/g_test_matched_fold_{fold}.csv')['g']
    for fold in folds
]

# Concatenate once after reading (instead of re-concatenating on every iteration).
y_pred_pooled_df = pd.concat(y_pred_pooled, ignore_index=True).rename(columns={'g_pred_mh': 'y_pred'})
y_true_pooled_df = pd.concat(y_true_pooled, ignore_index=True).to_frame('y_true')
# Predictions and targets are paired by row position, so their lengths must agree.
assert len(y_pred_pooled_df) == len(y_true_pooled_df), "pooled predictions and targets differ in length"

y_true = y_true_pooled_df['y_true'].copy()
y_pred = y_pred_pooled_df['y_pred'].copy()

###################### Circular plot
# Inputs for the circular (polar) loading plot.
# NOTE(review): bars are assigned to factor groups by position (IDXS loop in the
# plotting cell), which assumes the rows of g_loadings are ordered by `latent`
# in the same sorted order groupby uses — confirm.
GROUP = g_loadings_rename_factor["latent"].values          # factor label of each item
VALUES = g_loadings_rename["value"].values.astype(float)   # signed item loadings
LABELS = [
    f"{item} ({loading:.2f})"
    for item, loading in zip(g_loadings_rename['item'], g_loadings_rename['value'])
]
OFFSET = np.pi / 2  # theta offset for the polar axes and label rotation

# PAD empty bar slots per factor group separate the groups visually
# (the plotting cell places them before each group's bars).
PAD = 3
ANGLES_N = len(VALUES) + PAD * len(np.unique(GROUP))
ANGLES = np.linspace(0, 2 * np.pi, num=ANGLES_N, endpoint=False)
WIDTH = 2 * np.pi / len(ANGLES)

# Number of items per factor, in sorted factor order (groupby sorts its keys).
GROUPS_SIZE = g_loadings_rename.groupby("latent").size().tolist()

def get_label_rotation(angle, offset):
    # Rotation must be specified in degrees :(
    rotation = np.rad2deg(angle + offset)
    if angle <= np.pi:
        alignment = "right"
        rotation = rotation + 180
    else: 
        alignment = "left"
    return rotation, alignment

def add_labels(angles, values, labels, offset, ax):
    # This is the space between the end of the bar and the label
    padding = 0.005
    MIN_BAR_HEIGHT = 0.1  # Define a minimum height for the bars
    # Iterate over angles, values, and labels, to add all of them.
    for angle, value, label in zip(angles, values, labels):
        # Obtain text rotation and alignment
        rotation, alignment = get_label_rotation(angle, offset)

        # And finally add the text
        ax.text(
            x=angle, 
            y=max(value, MIN_BAR_HEIGHT) / 3 + 1.8, #value / 2 + 1.5 / value + padding,  # Position the label inside the bar 
            s=label, 
            ha='center',  # Center alignment / = alignment
            va='center', 
            rotation=rotation, 
            rotation_mode="anchor",
            fontsize=27,
            color="black")

def add_factor_names(groups, angles, offset, ax, groups_size=None, pad=None,
                     radius=0.87, fontsize=30):
    """Write each factor (group) name at the angular centre of its bars.

    Bar slots are laid out group by group, and each group is preceded by
    ``pad`` empty slots. Group ``k`` therefore occupies the slots
    ``[start_k + pad, start_k + pad + size_k)``, where ``start_k`` is the sum
    of ``size + pad`` over the previous groups. This matches the IDXS layout
    used when drawing the bars. The previous implementation hard-coded a start
    of 2 plus ``PAD // 2``, which only hit these slots when PAD == 3. It also
    ignored the ``offset`` argument.

    Parameters
    ----------
    groups : sequence of str
        Group names, in the same order as ``groups_size``.
    angles : numpy.ndarray
        Angle (radians) of every bar slot, padding included.
    offset : float
        Theta offset in radians, used to rotate the names.
    ax : matplotlib Axes
        Axes to draw on. ``x`` is passed as an angle, so a polar Axes is
        expected.
    groups_size : sequence of int, optional
        Number of bars per group. Defaults to the module-level GROUPS_SIZE.
    pad : int, optional
        Empty slots preceding each group. Defaults to the module-level PAD.
    radius : float, optional
        Radial position of the names (default 0.87).
    fontsize : float, optional
        Font size of the names (default 30).
    """
    if groups_size is None:
        groups_size = GROUPS_SIZE
    if pad is None:
        pad = PAD
    # NOTE(review): zip silently truncates if groups and groups_size differ in length.
    slot_start = 0  # first slot (including padding) of the current group
    for group, size in zip(groups, groups_size):
        first_bar = slot_start + pad
        mid_angle = np.mean(angles[first_bar:first_bar + size])
        ax.text(
            x=mid_angle,
            y=radius,
            s=group,
            ha='center',
            va='center',
            # Rotate by -90 degrees relative to the radial direction (tangential text).
            rotation=np.rad2deg(mid_angle + offset) - 90,
            rotation_mode="anchor",
            fontsize=fontsize,
            color="black",
        )
        slot_start += size + pad


In [None]:
# All plots together: (a) g-factor loading circle, (b) predicted vs ESEM-derived
# g-factor, (c) correlations with mental-health features, (d) PLSR loadings.
fig = plt.figure(figsize=(40, 40))
gs = fig.add_gridspec(2, 2, height_ratios=[1, 1], width_ratios=[1.5, 2])  # width_ratios - columns 1 vs 2; height_ratios - rows 1 vs 2
ax1 = fig.add_subplot(gs[0, 0], projection='polar')
ax2 = fig.add_subplot(gs[1, 0])
ax3 = fig.add_subplot(gs[0, 1])
ax4 = fig.add_subplot(gs[1, 1])

################################################## LOADING PLOT FOR G-FACTOR

# Slot index of every bar: each group of GROUPS_SIZE bars is preceded by PAD
# empty slots, leaving a gap between factors.
offset = 0
IDXS = []
for size in GROUPS_SIZE:
    IDXS += list(range(offset + PAD, offset + size + PAD))
    offset += size + PAD

ax1.set_theta_offset(OFFSET)
# The lower radial limit (0.7) is below the bar base (r=1), which leaves a hole in the middle.
ax1.set_ylim(0.7, max(abs(VALUES)) + 1)
ax1.set_frame_on(False)
ax1.xaxis.grid(False)
ax1.yaxis.grid(False)
ax1.set_xticks([])
ax1.set_yticks([])

# Colours for positive and negative loadings. The name is `bar_colors`, not
# `colors`, so the `matplotlib.colors` module imported as `colors` at the top
# of the notebook is not shadowed.
bar_colors = ['#B2474599' if v > 0 else '#6A6599FF' for v in VALUES]

# Bars show |loading|, floored at MIN_BAR_HEIGHT so near-zero loadings stay visible.
MIN_BAR_HEIGHT = 0.1
bars = ax1.bar(
    ANGLES[IDXS], [max(abs(v), MIN_BAR_HEIGHT) for v in VALUES], width=WIDTH,
    color=bar_colors,
    edgecolor="white", linewidth=1, alpha=0.5,
    bottom=1,  # bars start at r=1, above the 0.7 lower limit
)

# Add labels after adding the bars
add_labels(ANGLES[IDXS], abs(VALUES), LABELS, OFFSET, ax1)
add_factor_names(np.unique(GROUP), ANGLES, OFFSET, ax1)

# Italic 'g' in the middle. r=0.7 equals the lower radial limit (the centre
# of the polar axes), so the angle value is irrelevant.
ax1.text(
    x=0.7,
    y=0.7,
    s='g',
    ha='center',
    va='center',
    fontsize=50,
    fontstyle='italic',
    color='black'
)

ax1.set_title('ESEM Model Structure', y=1.4, fontsize=70)

################################################## SCATTERPLOT

cmap = sns.color_palette("ch:s=-.2,r=.6", as_cmap=True)
# Colour each point by its distance from the centroid (mean predicted, mean observed).
dist_i = np.sqrt((y_true - y_true.mean())**2 + (y_pred - y_pred.mean())**2)
sns.scatterplot(x=y_pred, y=y_true, c=dist_i, cmap=cmap, s=200, alpha=0.6, ax=ax2)

# One regression line per CV fold, plus per-fold R^2 and Pearson r.
r2_scores = []
pearson_corrs = []
for fold in folds:
    y_pred_fold = pd.read_csv(f'/mental_health/folds/fold_{fold}/g_pred/g_pred_mh_fold_{fold}.csv')[['g_pred_mh']]
    y_true_fold = pd.read_csv(f'/mental_health/folds/fold_{fold}/suppl/g_test_matched_fold_{fold}.csv')['g']
    sns.regplot(x=y_pred_fold, y=y_true_fold, scatter=False, ax=ax2, label=f'Fold {fold}', line_kws={"color": "red", "linewidth": 0.8})
    r2 = r2_score(y_true_fold, y_pred_fold)
    r2_scores.append(r2)
    corr, _ = pearsonr(y_true_fold.squeeze(), y_pred_fold.squeeze())
    pearson_corrs.append(corr)

sns.despine(top=True, right=True, ax=ax2)
ax2.set_xlabel('Predicted $g$-factor ($z$)', fontsize=60)
ax2.set_ylabel('$g$-factor derived from ESEM ($z$)', fontsize=60)
ax2.tick_params(axis='x', labelsize=40)
ax2.tick_params(axis='y', labelsize=40)
ax2.set_title(f'$g$-factor Predicted\nfrom Mental Health', fontsize=70, y=1.2)

# Annotate with the mean (SD) of the per-fold Pearson r and R^2.
# np.std uses its default ddof=0 (population SD across folds).
r_mean = np.mean(pearson_corrs)
r_std = np.std(pearson_corrs)
r2_mean = np.mean(r2_scores)
r2_std = np.std(r2_scores)

ax2.text(0.05, 1.08, f'$r_{{mean}}$ = {r_mean:.2f} (SD={r_std:.2f})', transform=ax2.transAxes, fontsize=60)
ax2.text(0.05, 0.98, f'$R^2_{{mean}}$ = {r2_mean:.2f} (SD={r2_std:.2f})', transform=ax2.transAxes, fontsize=60)
ax2.xaxis.set_major_locator(MultipleLocator(1))
ax2.yaxis.set_major_locator(MultipleLocator(1))

################################################## CORR + LOADINGS PLSR
ax3.axhline(y=0, color='grey', linewidth=1.5, linestyle='--')

# Significant correlations are drawn opaque, non-significant ones faded.
alpha_values = [0.9 if sig else 0.3 for sig in domain_corr_df['significant']]
sns.scatterplot(data=domain_corr_df, x='feature', y='correlation', hue='domain', ax=ax3, s=900,
                alpha=alpha_values, palette=jama_palette, zorder=0)

ax4.axhline(y=0, color='grey', linewidth=1.5, linestyle='--')
sns.scatterplot(data=domain_loadings_df, x='feature', y='loading', hue='domain', ax=ax4, s=900,
                alpha=0.6, palette=jama_palette, zorder=0)

# Set the limits for the y-axis
ax3.set_ylim(-0.15, 0.25)
ax4.set_ylim(-0.2, 0.4)

# Shared styling of the two feature panels: no x ticks or x label,
# large y tick labels every 0.05.
for panel in (ax3, ax4):
    panel.set_xticks([])
    panel.set_xticklabels([])
    panel.tick_params(axis='y', labelsize=55)
    panel.yaxis.set_major_locator(MultipleLocator(0.05))
    panel.set_xlabel('')

ax3.set_ylabel("Pearson's $r$", fontsize=60)
ax4.set_ylabel("Loadings", fontsize=60)

# Hand-tuned text offsets (points) for the label of each domain's strongest
# significant correlation. Unlisted domains use (20, 20).
corr_label_offsets = {
    'Mental distress': (-30, 40),
    'Depression': (-60, 150),
    'Neuroticism': (-34, 70),
    'Diagnoses': (-25, -80),
    'Mania': (-200, 10),
    'Anxiety': (30, 0),
    'Addictions': (-30, 20),
    'Unusual/psychotic experiences': (-50, 50),
    'Alcohol and cannabis use': (40, -10),
    'Happiness and subjective well-being': (30, 30),
}
for domain in significant_domain_corr_df['domain'].unique():
    domain_subset = significant_domain_corr_df[significant_domain_corr_df['domain'] == domain]
    top_corr = domain_subset.loc[domain_subset['correlation'].abs().idxmax()]
    xytext = corr_label_offsets.get(domain, (20, 20))
    ax3.annotate(split_text(top_corr['feature']), (top_corr['feature'], top_corr['correlation']),
                 textcoords="offset points", xytext=xytext, ha='left', fontsize=35, color=palette_dict[domain], zorder=10)

# Hand-tuned text offsets (points) for the label of each domain's largest
# |loading|. Unlisted domains use (10, 20).
loading_label_offsets = {
    'Mental distress': (60, 20),
    'Depression': (50, 10),
    'Diagnoses': (-150, -100),
    'Anxiety': (120, -45),
    'Addictions': (-30, 20),
    'Mania': (40, 130),
    'Unusual/psychotic experiences': (-60, 120),
    'Alcohol and cannabis use': (150, -10),
    'Neuroticism': (135, -40),
    'Traumatic events': (60, 25),
    'Happiness and subjective well-being': (130, -60),
    'Self-harm behaviours': (120, 20),
}
for domain in domain_loadings_df['domain'].unique():
    domain_subset = domain_loadings_df[domain_loadings_df['domain'] == domain]
    top_loading = domain_subset.loc[domain_subset['loading'].abs().idxmax()]
    xytext = loading_label_offsets.get(domain, (10, 20))
    ax4.annotate(split_text(top_loading['feature']), (top_loading['feature'], top_loading['loading']),
                 textcoords="offset points", xytext=xytext, ha='center', fontsize=35, color=palette_dict[domain], zorder=10)

# Keep only the left spine on both feature panels.
for panel in (ax3, ax4):
    for spine in panel.spines.values():
        spine.set_visible(False)
    panel.spines['left'].set_visible(True)

# Remove the per-axes seaborn legends. A single figure-level legend is added
# below. (Calling ax.legend('') would draw a new, empty legend instead.)
for panel in (ax3, ax4):
    panel_legend = panel.get_legend()
    if panel_legend is not None:
        panel_legend.remove()

ax3.set_title('$g$-factor and Mental Health:\nCorrelations', fontsize=70, pad=20, y=1.22)
ax4.set_title('$g$-factor and Mental Health:\nPLSR Loadings', fontsize=70, pad=20, y=1.18)

# Figure-level domain legend built from ax3's labelled artists.
handles, labels = ax3.get_legend_handles_labels()
legend = fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, -0.02), ncol=3, fontsize=50, frameon=False)
for handle in legend.legend_handles:
    handle._sizes = [1000]  # enlarge legend markers (private attribute)

# NOTE: when tight_layout succeeds it recomputes all subplot margins, which
# normally supersedes the subplots_adjust(top=0.85) call just before it.
fig.subplots_adjust(top=0.85)
plt.tight_layout(pad=3.0)

# Bold panel letters, placed in axes coordinates (hand-tuned per panel).
axes = [ax1, ax2, ax3, ax4]
panel_letters = ['a', 'b', 'c', 'd']
letter_positions = [(-0.15, 1.49), (-0.1, 1.33), (-0.05, 1.4), (-0.05, 1.35)]
for panel, letter, (x_pos, y_pos) in zip(axes, panel_letters, letter_positions):
    panel.text(x_pos, y_pos, letter, transform=panel.transAxes,
               fontsize=85, fontweight='bold', va='top', ha='right')

plt.savefig("/final/pdf/Fig2.pdf",
            bbox_inches="tight",
            pad_inches=1,
            transparent=False,
            facecolor="w",
            edgecolor='w',
            orientation='landscape',
            format='pdf')

plt.show()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import MultipleLocator
import numpy as np
import pandas as pd
from scipy.stats import pearsonr
from sklearn.metrics import r2_score

# Stand-alone version of panel (b): g-factor predicted from mental health vs
# g-factor derived from ESEM, with one regression line per CV fold.
# Explicit figure/axes objects instead of plt.figure() + plt.gca().
scatter_fig, ax = plt.subplots(figsize=(12, 10))

cmap = sns.color_palette("ch:s=-.2,r=.6", as_cmap=True)
# Colour each point by its distance from the centroid (mean predicted, mean observed).
dist_i = np.sqrt((y_true - y_true.mean())**2 + (y_pred - y_pred.mean())**2)
sns.scatterplot(x=y_pred, y=y_true, c=dist_i, cmap=cmap, s=100, alpha=0.6, ax=ax)

# Regression line, R^2 and Pearson r for each fold.
r2_scores = []
pearson_corrs = []
for fold in folds:
    y_pred_fold = pd.read_csv(f'/mental_health/folds/fold_{fold}/g_pred/g_pred_mh_fold_{fold}.csv')[['g_pred_mh']]
    y_true_fold = pd.read_csv(f'/mental_health/folds/fold_{fold}/suppl/g_test_matched_fold_{fold}.csv')['g']
    sns.regplot(x=y_pred_fold, y=y_true_fold, scatter=False, ax=ax,
                line_kws={"color": "red", "linewidth": 0.8, "alpha": 0.6})
    r2 = r2_score(y_true_fold, y_pred_fold)
    r2_scores.append(r2)
    corr, _ = pearsonr(y_true_fold.squeeze(), y_pred_fold.squeeze())
    pearson_corrs.append(corr)

# Style adjustments
sns.despine(top=True, right=True, ax=ax)
ax.set_xlabel('$g$-factor predicted from MH ($z$)', fontsize=40)
ax.set_ylabel('$g$-factor derived from ESEM ($z$)', fontsize=40)
ax.tick_params(axis='both', labelsize=14)

# Mean (SD) across folds. np.std uses its default ddof=0 (population SD).
r_mean = np.mean(pearson_corrs)
r_std = np.std(pearson_corrs)
r2_mean = np.mean(r2_scores)
r2_std = np.std(r2_scores)

ax.text(0.05, 1.08, f'$r_{{mean}}$ = {r_mean:.2f} (SD={r_std:.2f})',
        transform=ax.transAxes, fontsize=35)
ax.text(0.05, 0.98, f'$R^2_{{mean}}$ = {r2_mean:.2f} (SD={r2_std:.2f})',
        transform=ax.transAxes, fontsize=35)

# Integer tick spacing on both axes
ax.xaxis.set_major_locator(MultipleLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(1))

scatter_fig.tight_layout()

scatter_fig.savefig("/g_factor_scatterplot.png",
                    bbox_inches="tight",
                    pad_inches=0.3,
                    transparent=False,
                    facecolor="w",
                    edgecolor='w',
                    dpi=300)

plt.show()

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.ticker import MultipleLocator

# Stand-alone version of panel (d): PLSR loadings coloured by mental-health domain.
# A new figure with a single axes, which also becomes pyplot's current axes.
ax = plt.figure(figsize=(20, 18)).add_subplot()

# Zero reference line, then one point per feature
ax.axhline(y=0, color='grey', linewidth=1.5, linestyle='--')
sns.scatterplot(data=domain_loadings_df, x='feature', y='loading', hue='domain',
                ax=ax, s=900, alpha=0.6, palette=jama_palette, zorder=0)

# Axis limits, ticks and labels
ax.set_ylim(-0.2, 0.4)
ax.set_xticks([])
ax.set_xticklabels([])
ax.tick_params(axis='y', labelsize=40)
ax.yaxis.set_major_locator(MultipleLocator(0.05))
ax.set_ylabel("PLSR Loadings", fontsize=45)
ax.set_xlabel('')

# Hand-tuned text offsets (points) for the label of each domain's largest
# |loading|. Domains not listed fall back to (10, 20).
loading_label_offsets = {
    'Mental distress': (60, 20),
    'Depression': (50, 10),
    'Diagnoses': (-150, -100),
    'Anxiety': (120, -45),
    'Addictions': (-30, 20),
    'Mania': (40, 130),
    'Unusual/psychotic experiences': (-60, 120),
    'Alcohol and cannabis use': (150, -10),
    'Neuroticism': (135, -40),
    'Traumatic events': (60, 25),
    'Happiness and subjective well-being': (130, -60),
    'Self-harm behaviours': (120, 20),
}
for domain in domain_loadings_df['domain'].unique():
    domain_subset = domain_loadings_df[domain_loadings_df['domain'] == domain]
    top_loading = domain_subset.loc[domain_subset['loading'].abs().idxmax()]
    xytext = loading_label_offsets.get(domain, (10, 20))
    ax.annotate(split_text(top_loading['feature']),
                (top_loading['feature'], top_loading['loading']),
                textcoords="offset points",
                xytext=xytext,
                ha='center',
                fontsize=25,
                color=palette_dict[domain],
                zorder=10)

# Keep only the left spine visible
for spine in ax.spines.values():
    spine.set_visible(False)
ax.spines['left'].set_visible(True)

ax.set_title('$g$-factor and Mental Health:\nPLSR Loadings', fontsize=55, pad=30)

# Rebuild the domain legend below the plot, only when there is something to show
handles, labels = ax.get_legend_handles_labels()
if handles:
    legend = ax.legend(handles, labels, loc='upper center',
                       bbox_to_anchor=(0.5, -0.05),
                       ncol=3, fontsize=30, frameon=False)
    for handle in legend.legend_handles:
        handle._sizes = [800]  # enlarge legend markers (private attribute)

plt.tight_layout()

plt.savefig("/PLSR_loadings_plot.png",
            bbox_inches="tight",
            pad_inches=0.5,
            transparent=False,
            facecolor="w",
            edgecolor='w',
            dpi=300)

plt.show()