In [None]:
import pandas as pd
import numpy as np
from numpy import mean
from numpy import std
from matplotlib import pyplot as plt
import seaborn as sns
import os

import sklearn
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import confusion_matrix
from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_curve
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import permutation_test_score

In [None]:
## Base seaborn theme for every figure: white background with grid lines and the
## 'paper' context with fonts scaled by 1.5
sns.set_style('whitegrid')
sns.set_context('paper', font_scale=1.5)

## Two-colour palettes (light, dark) passed as `cmap` to the confusion-matrix heatmaps
mypaletteBlue = sns.color_palette(['#C2F3F2', '#1D9694'], as_cmap=True)
mypalettePink = ['#f78f8f', '#a83246']

In [None]:
## Work from inside the data folder so the CSV files below can be opened by bare file name.
## The folder name carries no hard-coded backslash (works on every OS) and the guard
## keeps the cell re-runnable: a second run no longer fails because we are already inside
datadir = 'processedData'
if os.path.basename(os.getcwd()) != datadir:
    os.chdir(datadir)

In [None]:
## Folder (relative to the data folder) that receives all saved figures.
## Built with os.path.join so the separator matches the operating system; the empty last
## component keeps the trailing separator, because file names are appended with `+` below.
## On Windows this evaluates to the same string as before ('..\\Plots\\')
plotdir = os.path.join('..', 'Plots', '')

In [None]:
## Explicit seaborn style parameters: white axes/figure background with a light grey grid,
## dark grey text and ticks, no tick marks, all four axes spines visible
sans_serif_fonts = ['Arial',
                    'DejaVu Sans',
                    'Liberation Sans',
                    'Bitstream Vera Sans',
                    'sans-serif']

style_overrides = {
    # axes, figure and grid appearance
    'axes.facecolor': 'white',
    'axes.edgecolor': '.15',
    'axes.grid': True,
    'axes.axisbelow': True,
    'axes.labelcolor': '.15',
    'figure.facecolor': 'white',
    'grid.color': '.8',
    'grid.linestyle': '-',
    # text and tick colours / directions
    'text.color': '.15',
    'xtick.color': '.15',
    'ytick.color': '.15',
    'xtick.direction': 'out',
    'ytick.direction': 'out',
    # lines, patches and default image colormap
    'lines.solid_capstyle': 'projecting',
    'patch.edgecolor': 'w',
    'patch.force_edgecolor': True,
    'image.cmap': 'rocket',
    # fonts
    'font.family': ['sans-serif'],
    'font.sans-serif': sans_serif_fonts,
    # tick marks are switched off on every side
    'xtick.bottom': False,
    'xtick.top': False,
    'ytick.left': False,
    'ytick.right': False,
    # all four spines are drawn
    'axes.spines.left': True,
    'axes.spines.bottom': True,
    'axes.spines.right': True,
    'axes.spines.top': True,
}

sns.set_style(style_overrides)

In [None]:
## Explicit seaborn plotting-context parameters: font sizes, line widths and tick geometry
context_overrides = {
    # font sizes
    'font.size': 15.0,
    'axes.labelsize': 'medium',
    'axes.titlesize': 'large',
    'xtick.labelsize': 'medium',
    'ytick.labelsize': 'medium',
    'legend.fontsize': 'medium',
    'legend.title_fontsize': None,
    # line widths and marker size
    'axes.linewidth': 0.8,
    'grid.linewidth': 0.8,
    'lines.linewidth': 1.5,
    'lines.markersize': 6.0,
    'patch.linewidth': 1.0,
    # tick widths
    'xtick.major.width': 0.8,
    'ytick.major.width': 0.8,
    'xtick.minor.width': 0.6,
    'ytick.minor.width': 0.6,
    # tick lengths
    'xtick.major.size': 3.5,
    'ytick.major.size': 3.5,
    'xtick.minor.size': 2.0,
    'ytick.minor.size': 2.0,
}

sns.set_context(context_overrides)

# 1. Model Training

In [None]:
## Training features: the reduced dataset (redundant data such as means and totals removed)
## that has been harmonized for scan site and matched for age and scanner, ordered by subject
Train = pd.read_csv('TrainAgePredFeaturesHarmonised.csv').sort_values(by='subjectkey')

## Demographic data, ordered by subject and cut down to the subject key and the
## menarche variable
pds = (
    pd.read_csv('menarcheTrain_harm_red_matchedAgeScannerONLYSMRI.csv')
    .sort_values(by='subjectkey')
    [['subjectkey', 'pds_f5_y_P']]
)

## Attach the menarche variable to the training features.
## NOTE(review): an outer join keeps subjects that appear in only one of the two files
## (with NaN in the missing columns) - confirm both files hold exactly the same subjects
Train = pd.merge(Train, pds, on='subjectkey', how='outer')

In [None]:
## Columns that are not MRI data: the subject identifier and the menarche label
non_mri_columns = ['subjectkey', 'pds_f5_y_P']

## labels: menarche data (1=pre, 4=post)
y_Train = Train['pds_f5_y_P']
## features: every remaining column
X_TrainPred = Train.drop(columns=non_mri_columns)

In [None]:
## Encode the target labels as integers between 0 and n_classes-1
le = LabelEncoder()
Y = le.fit_transform(y_Train)

## Standardize the features: learn mean and variance per column, then remove the mean
## and scale to unit variance
sc = StandardScaler()
X = sc.fit(X_TrainPred).transform(X_TrainPred)

In [None]:
## Grid search to optimise the parameters of an LDA model (one discriminant component)
model = LinearDiscriminantAnalysis(n_components=1)

## Nested stratified 10-fold cross-validation: the inner folds select the hyper-parameters,
## the outer folds score the whole tuning procedure on data it has not seen
inner_cv = StratifiedKFold(n_splits=10)
outer_cv = StratifiedKFold(n_splits=10)

## Search over all solvers and over the shrinkage options: none, automatic ('auto') and
## fixed values over the possible range of 0 to 1 in increments of 0.1.
## Shrinkage is only implemented for the 'lsqr' and 'eigen' solvers; 'svd' raises
## NotImplementedError for any shrinkage other than None. The grid is therefore split into
## two sub-grids so that no candidate is fitted only to fail and be scored as NaN.
## The remaining candidates keep their original relative order
shrinkage_options = [None, 'auto', 0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
grid_vals = [
    {'solver': ['svd'], 'shrinkage': [None]},
    {'solver': ['lsqr', 'eigen'], 'shrinkage': shrinkage_options},
]

grid_lr_balanced = GridSearchCV(estimator=model, param_grid=grid_vals, scoring='balanced_accuracy',
                                cv=inner_cv, refit=True, return_train_score=True, verbose=3)

## Cross-validate the performance of the tuned model; the score is the balanced accuracy
## configured in the grid search.
## NOTE(review): X was standardised on all training rows before this split, so the scaler
## has seen every outer test fold - consider moving the scaler into a Pipeline
nested_score_balanced = cross_val_score(grid_lr_balanced, X=X, y=Y, cv=outer_cv)

## Fit the search on the complete training features and keep the best refitted model
grid_lr_balanced.fit(X, Y)
mymodel_balanced = grid_lr_balanced.best_estimator_

In [None]:
## Balanced accuracy of each outer fold of the nested cross-validation, and their mean
print(f'CV scores: \n {nested_score_balanced}')
print(f'CV mean: \n {mean(nested_score_balanced)}')

# 2. Model validation with holdout test data

In [None]:
## Holdout (test) features, ordered by subject
Test = pd.read_csv('TestAgePredFeaturesHarmonised.csv').sort_values(by='subjectkey')

## remember the subject keys of the test subjects
holdout_subs = Test['subjectkey']

In [None]:
## Demographic data of the test subjects, ordered by subject and cut down to the subject
## key and the menarche variable
pdstest = (
    pd.read_csv('harmonizedTestdata_plusscannerdfONLYSMRI.csv')
    .sort_values(by='subjectkey')
    [['subjectkey', 'pds_f5_y_P']]
)

## Attach the menarche variable to the test features.
## NOTE(review): an outer join keeps subjects that appear in only one of the two files
## (with NaN in the missing columns) - confirm both files hold exactly the same subjects
Test = pd.merge(Test, pdstest, on='subjectkey', how='outer')

In [None]:
## Columns that are not MRI data: the subject identifier and the menarche label
non_mri_columns_test = ['subjectkey', 'pds_f5_y_P']

## labels: menarche data (1=pre, 4=post)
y_Test = Test['pds_f5_y_P']
## features: every remaining column
X_TestPred = Test.drop(columns=non_mri_columns_test)

In [None]:
## Standardize the holdout features with the scaler that was fitted on the TRAINING data.
## Fitting a second scaler on the holdout set would centre and scale the test features with
## test-set statistics, so the model would see inputs on a different scale than the one it
## was trained on (and the result would depend on the composition of the test set).
## transform() also verifies that the holdout columns match the training columns
x_holdout = sc.transform(X_TestPred)

## Encode the labels with the training encoder so that every class gets the same integer
## as during training; a label that never occurred in training raises an error instead of
## being silently turned into a new class
y_holdout = le.transform(y_Test)

In [None]:
## Binary prediction of menarche status for every holdout subject, made by the best
## model found in the grid search
preds = mymodel_balanced.predict(X=x_holdout)

# 3. Visualising Model Performance

In [None]:
## Confusion matrix of the model on the holdout data (rows: actual, columns: predicted)
cm1 = confusion_matrix(y_holdout, preds)
tn, fp, fn, tp = cm1.ravel()
print(cm1)

## Floating-point copy in which each row is divided by its row total, i.e. the fraction
## of each actual class that received each prediction
cm1p = cm1.astype(float)
for row in range(2):
    cm1p[row, :] = cm1[row, :] / (cm1[row, 0] + cm1[row, 1])

## Balanced accuracy on the holdout data
ac_preds = balanced_accuracy_score(y_holdout, preds)
print(ac_preds)

In [None]:
## Draw the row-normalised confusion matrix as an annotated heatmap (percentages with two
## decimals, black cell borders, no colour bar) and save it to the plot folder
class_names = ['pre', 'post']
heatmap_options = dict(cmap=mypalettePink, center=0.5, annot=True, fmt='.2%',
                       linewidths=1, linecolor="Black", cbar=False,
                       xticklabels=class_names, yticklabels=class_names)
hm = sns.heatmap(data=cm1p, **heatmap_options)

confusion_matrix_png = plotdir + 'confusionMatrix1102.png'
plt.savefig(confusion_matrix_png)

In [None]:
## Extract the class probabilities for further analysis
classprobsFinal = mymodel_balanced.predict_proba(x_holdout)

## Data frame of the holdout subjects and their class probabilities (fresh 0..n-1 index)
classprobsdfFinal = pd.DataFrame(classprobsFinal, columns=['prob pre', 'prob post'])

## Add the subject keys of the holdout subjects.
## They are taken from the rows of `Test` (the frame the features were built from) and
## assigned as a plain array, i.e. BY POSITION. Assigning the Series `holdout_subs` would
## align on its index instead, and that index still carries the row numbers of the unsorted
## CSV - the keys would then be attached to the wrong probabilities
classprobsdfFinal['subjectkey'] = Test['subjectkey'].to_numpy()

## Add the actual menarche status (encoded label)
classprobsdfFinal['Actual'] = y_holdout

## Add the predicted menarche status (encoded label)
classprobsdfFinal['Predicted'] = preds

In [None]:
## Boolean masks for TP, FN, FP and TN; encoded class 1 is treated as the positive class
actual = classprobsdfFinal['Actual']
predicted = classprobsdfFinal['Predicted']
conditions = [
    (actual == 1) & (predicted == 1),
    (actual == 1) & (predicted == 0),
    (actual == 0) & (predicted == 1),
    (actual == 0) & (predicted == 0)
    ]

## Names of the outcomes, in the same order as the conditions above
values = ['True Positive', 'False Negative', 'False Positive', 'True Negative']

## Add a column holding the type of classification (error) that occurred.
## The default is given as a string: np.select's built-in default is the integer 0, which
## has no common dtype with the string choices and raises a TypeError on recent NumPy
## versions. '0' is what older NumPy versions produced for rows matching no condition
classprobsdfFinal['tpfptnfn'] = np.select(conditions, values, default='0')

In [None]:
## Probability of the post-menarche class: second column of the predict_proba output
post_column = 1
classprobsFinalPost = classprobsFinal[:, post_column]

In [None]:
## Chance level: the same score for every subject carries no information and yields the
## diagonal from (0, 0) to (1, 1). It is evaluated on the HOLDOUT labels (not the training
## labels) so that both curves in the plot describe the same subjects
random_probs = np.zeros(len(y_holdout))
ns_fpr, ns_tpr, _ = roc_curve(y_holdout, random_probs)

## ROC curve of the model on the holdout data, scored by the post-menarche probability
ts_fpr, ts_tpr, _ = roc_curve(y_holdout, classprobsFinalPost)

In [None]:
## ROC plot: dashed chance-level line first, then the model's curve in green with a dot
## at every threshold
plt.plot(ns_fpr, ns_tpr, ls='--')
plt.plot(ts_fpr, ts_tpr, color='Green', marker='.')

## axis labels
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')

plt.show()

In [None]:
## Permutation test: the labels of the training data are shuffled, a model is trained on
## the shuffled data and then applied to the holdout data. The resulting (random) accuracies
## are compared to the accuracy obtained by training on the non-shuffled training data

observed_accuracyTestData = balanced_accuracy_score(y_holdout, preds)

## 1000 permutations and a list to collect the results
n_permutations = 1000
permuted_accuraciesTestData = []

## Seeded generator so the permutation distribution (and the p-value) is reproducible
permutation_rng = np.random.default_rng(0)

## Use the same parameters as the original model: they are read from the grid search
## instead of being typed in by hand, so they cannot drift from the model actually selected
random_model = LinearDiscriminantAnalysis(n_components=1, **grid_lr_balanced.best_params_)

for _ in range(n_permutations):
    ## shuffle the labels of the training data
    shuffled_labels = permutation_rng.permutation(Y)
    ## train the model on the shuffled data
    random_model.fit(X, shuffled_labels)
    ## use the model to classify menarche status in the test data
    predsRandom = random_model.predict(x_holdout)
    ## calculate the balanced accuracy and append it to the list
    permuted_accuracy = balanced_accuracy_score(y_holdout, predsRandom)
    permuted_accuraciesTestData.append(permuted_accuracy)

## p-value: share of models trained on shuffled labels that performed as well as or better
## than the original model (+1 in numerator and denominator counts the observed model).
## The list is converted to an array first: comparing a plain list with a scalar only works
## element-wise when the scalar happens to be a NumPy type and fails for a Python float
n_at_least_as_good = np.sum(np.asarray(permuted_accuraciesTestData) >= observed_accuracyTestData)
p_valuePermTest = (n_at_least_as_good + 1) / (n_permutations + 1)
print(f"Observed Accuracy: {observed_accuracyTestData}")
print(mean(permuted_accuraciesTestData))
print(f"P-value: {p_valuePermTest}")

In [None]:
## Histogram of the balanced accuracies from the permutation test, with the observed
## balanced accuracy marked by a red vertical line
sns.set_style('whitegrid')
sns.set_context("paper", font_scale=1.5)

fig, ax = plt.subplots()
sns.histplot(permuted_accuraciesTestData, bins=20, color='#A4B7D6', ax=ax)
ax.axvline(observed_accuracyTestData, color="red")

## fixed x-range and axis label
ax.set_xlim([0.1, 0.9])
ax.set_xlabel("Balanced accuracy")

permutation_png = plotdir + 'permutationTest_holdout_new1102.png'
plt.savefig(permutation_png)

In [None]:
## Figure 1: a) row-normalised confusion matrix on the holdout data,
##           b) permutation-test distribution with the observed balanced accuracy in red
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(11, 5), constrained_layout=True)

hm = sns.heatmap(data=cm1p, cmap=mypalettePink, center=0.5, annot=True, linewidths=1,
                 linecolor="Black", cbar=False, xticklabels=['pre', 'post'],
                 yticklabels=['pre', 'post'], fmt='.2%', ax=ax1)

## Panel labels are placed in axes coordinates (0 = left edge, 1 = top edge), just above
## the top-left corner of each panel. A data coordinate such as y=116 only fits one
## particular histogram: the bar heights change with every run of the permutation test
ax1.text(0, 1.02, "a)", transform=ax1.transAxes)

p1 = sns.histplot(permuted_accuraciesTestData, bins=20, color='#A4B7D6', ax=ax2)

p1.axvline(observed_accuracyTestData, color="red")

p1.set_xlim([0.1, 0.9])

p1.set_xlabel("Balanced accuracy")

ax2.text(0, 1.02, "b)", transform=ax2.transAxes)


plt.savefig(plotdir + 'Figure1.png', dpi=1000)

In [None]:
## Write the class-probability data frame to CSV in the current folder, without the index
classprobs_csv = 'classprobabilitesDFfinal_independently_harmonizedScanMatchedAgePredFeaturesShrink071102.csv'
classprobsdfFinal.to_csv(classprobs_csv, index=False)

# 4. Checking that the results on the unbalanced test data hold when the test data is balanced

In [None]:
## Drop 188 random pre-menarche subjects (label 1) to balance the data.
## NOTE(review): 188 is presumably the surplus of pre- over post-menarche subjects in the
## holdout set - confirm, or derive it from the class counts
indices_to_drop = Test[Test['pds_f5_y_P'] == 1].sample(n=188, axis='rows', random_state=1).index
TestBalanced = Test.drop(indices_to_drop)

## remember the subject keys of the remaining test subjects
holdout_subsBalanced = TestBalanced['subjectkey']

## drop every column that isn't MRI data
X_TestAgePredBalanced = TestBalanced.drop(columns=['subjectkey', 'pds_f5_y_P'])

## menarche data (1=pre, 4=post) as labels
y_TestBalanced = TestBalanced['pds_f5_y_P']

## Standardize with the scaler fitted on the TRAINING data. A scaler fitted on the balanced
## subset would rescale the features with statistics of that subset, so the model would
## receive inputs on a different scale than the one it was trained on
x_holdoutBalanced = sc.transform(X_TestAgePredBalanced)

## Encode the labels with the training encoder so the classes keep their training integers
y_holdoutBalanced = le.transform(y_TestBalanced)

## perform binary prediction of menarche status
predsBalanced = mymodel_balanced.predict(x_holdoutBalanced)

In [None]:
## Confusion matrix of the model applied to the balanced test data
## (rows: actual, columns: predicted)
cm2 = confusion_matrix(y_holdoutBalanced, predsBalanced)
print(cm2)

## Floating-point copy in which each row is divided by its row total
cm2p = cm2.astype(float)
for row in range(2):
    cm2p[row, :] = cm2[row, :] / (cm2[row, 0] + cm2[row, 1])

## Individual cell counts and the balanced accuracy
tn, fp, fn, tp = cm2.ravel()
ac_preds2 = balanced_accuracy_score(y_holdoutBalanced, predsBalanced)
print(ac_preds2)
print(tn, fp, fn, tp)

In [None]:
## Annotated heatmap of the row-normalised confusion matrix for the balanced test data
## (percentages with two decimals, black cell borders, no colour bar)
class_names_balanced = ['pre', 'post']
heatmap_options_balanced = dict(cmap=mypaletteBlue, center=0.5, annot=True, fmt='.2%',
                                linewidths=1, linecolor="Black", cbar=False,
                                xticklabels=class_names_balanced,
                                yticklabels=class_names_balanced)
hm = sns.heatmap(data=cm2p, **heatmap_options_balanced)
plt.title('confusion matrix of model applied to balanced test data', size=14)

In [None]:
## To avoid accidental bias when randomly dropping 188 participants, the process of
## balancing the holdout data and applying the model is repeated 1000 times.
## The model is NOT retrained: each round scores the same fitted model on a new subsample
n_rounds = 1000
balanced_accuracies = []

## The pool of pre-menarche subjects (label 1) is the same in every round, so it is
## selected once instead of on every iteration
pre_menarche_subjects = Test[Test['pds_f5_y_P'] == 1]

for i in range(n_rounds):
    ## balance the sample by dropping 188 random pre-menarche subjects; the round number
    ## seeds the draw, which makes every round reproducible
    indices_to_drop = pre_menarche_subjects.sample(n=188, axis='rows', random_state=i).index
    TestBalanced = Test.drop(indices_to_drop)

    ## split into MRI features and menarche labels
    X_TestPredBalanced = TestBalanced.drop(columns=['subjectkey', 'pds_f5_y_P'])
    y_TestBalanced = TestBalanced['pds_f5_y_P']

    ## Standardize and encode with the scaler / encoder fitted on the TRAINING data.
    ## Fitting new ones per subsample would rescale the features with statistics of that
    ## subsample, so the model would see differently scaled inputs in every round
    x_holdoutBalanced = sc.transform(X_TestPredBalanced)
    y_holdoutBalanced = le.transform(y_TestBalanced)

    ## predict and record the balanced accuracy of this round
    predsBalanced = mymodel_balanced.predict(x_holdoutBalanced)
    accs = balanced_accuracy_score(y_holdoutBalanced, predsBalanced)
    balanced_accuracies.append(accs)

print(mean(balanced_accuracies))

In [None]:
## Three panels side by side: confusion matrix on the holdout data, permutation-test
## distribution, and distribution of balanced accuracies over the balanced test sets
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(21, 6))
histogram_colour = '#A4B7D6'
panel_class_names = ['pre', 'post']

## left: row-normalised confusion matrix
hm = sns.heatmap(data=cm1p, ax=ax1, cmap=mypalettePink, center=0.5, annot=True, fmt='.2%',
                 linewidths=1, linecolor="Black", cbar=False,
                 xticklabels=panel_class_names, yticklabels=panel_class_names)

## middle: permutation test, observed balanced accuracy marked in red
p1 = sns.histplot(permuted_accuraciesTestData, ax=ax2, bins=20, color=histogram_colour)
p1.axvline(observed_accuracyTestData, color="red")
p1.set_xlim([0.1, 0.9])
p1.set_xlabel("Balanced accuracy")

## right: balanced accuracies of the repeated balanced test sets, red line at 0.5
p2 = sns.histplot(balanced_accuracies, ax=ax3, bins=10, color=histogram_colour)
p2.set_xlim(0, 1)
p2.set_xlabel("Balanced accuracy score - balanced test sets")
p2.axvline(0.5, color='red', linewidth=0.9)

## saving is currently switched off:
#plt.savefig(plotdir + 'plot2ConfMatPerm.png', dpi = 1000)

In [None]:
## highest balanced accuracy collected in `balanced_accuracies` (displayed as cell output)
max(balanced_accuracies)

In [None]:
## lowest balanced accuracy collected in `balanced_accuracies` (displayed as cell output)
min(balanced_accuracies)

In [None]:
## Histogram of all balanced accuracies obtained on the balanced holdout samples,
## with a thin dashed reference line at 0.5
sns.histplot(balanced_accuracies, color='blue', bins=20)

## x-axis range and label
plt.xlim(0.2, 0.8)
plt.xlabel("Balanced accuracy score")

plt.axvline(0.5, ls='--', linewidth=0.5, color='#AC2123')

balanced_accuracies_pdf = plotdir + '1000balanced_accuracies.pdf'
plt.savefig(balanced_accuracies_pdf, dpi=1000)