In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import math
import os
sns.set_theme()  # apply seaborn's default theme (sns.set() is an alias of set_theme())

In [None]:
## Work inside the processed-data folder. A plain folder name works on both Windows
## and POSIX; the guard keeps a re-run of this cell from trying to enter
## processedData/processedData.
DATA_DIR = 'processedData'
if os.path.basename(os.getcwd()) != DATA_DIR:
    os.chdir(DATA_DIR)

In [None]:
## plotting look: seaborn 'whitegrid' style, 'paper' context with fonts scaled 2x
sns.set_style(style='whitegrid')
sns.set_context(context='paper', font_scale=2)

In [None]:
## Training data: subjects are matched on age and scanner so the pre/post groups
## end up the same size.
TRAIN_CSV = 'harmonizedTraindata_plusscannerdfONLYSMRI.csv'
menarcheDF = pd.read_csv(TRAIN_CSV)

## Age and Scanner matching

In [None]:
## Propensity-score features: one-hot encoded scanner serial number plus age.
## The target is the menarche group code pds_f5_y_P (1 = pre, 4 = post).
X = pd.get_dummies(menarcheDF['mri_info_deviceserialnumber']).assign(age=menarcheDF['interview_age'])
y = menarcheDF['pds_f5_y_P']

In [None]:
## Propensity-score model: standardise the features, then fit a logistic regression.
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

steps = [
    ('scaler', StandardScaler()),
    ('logistic_classifier', LogisticRegression()),
]
pipe = Pipeline(steps=steps)
pipe.fit(X, y)

In [None]:
## Propensity score = predicted probability of belonging to the post-menarche group
## (pds_f5_y_P == 4; pre-menarche is coded 1).
pred_binary = pipe.predict(X)  # predicted class labels (1 = pre, 4 = post), not 0/1
# Select the probability column by class label rather than assuming it is column 1.
post_col = list(pipe.classes_).index(4)
menarcheDF['PS'] = pipe.predict_proba(X)[:, post_col]
# calculate the logit of the propensity score for matching
def logit(p):
    """Return the log-odds ln(p / (1 - p)) of a probability ``p``.

    ``p`` must lie strictly between 0 and 1. ``p == 1`` raises
    ZeroDivisionError. Any other value outside (0, 1) gives non-positive odds,
    and math.log then raises ValueError.
    """
    odds = p / (1 - p)
    return math.log(odds)

## log-odds of the propensity score; matching distances are measured on this scale
menarcheDF['PS_LOGIT'] = menarcheDF['PS'].apply(logit)

menarcheDF.head()

In [None]:
## Check how much the propensity-score logits of the pre and post groups overlap;
## if they barely overlap, the matching won't work.
## Legend labels come from an explicit label column (1 = Pre, 4 = Post) rather than
## relabelling plotted artists by position with plt.legend([...]), which is easy
## to get out of sync with the colours.
ps_plot_df = menarcheDF.assign(Menarche=menarcheDF['pds_f5_y_P'].map({1: 'Pre', 4: 'Post'}))
sns.histplot(data=ps_plot_df, x='PS_LOGIT', hue='Menarche', hue_order=['Pre', 'Post'],
             palette=['red', 'blue'])
plt.title('Propensity Scores of Pre and Post Menarche Group', size = 15)
plt.xlabel('Logit of Propensity Score')
# os.path.join keeps the path valid on POSIX as well as Windows
plt.savefig(os.path.join('..', 'Plots', 'overlapPrePostMen.png'))

In [None]:
## Caliper = 30% of the standard deviation of the propensity-score logit.
## NOTE: NearestNeighbors.kneighbors() does NOT apply `radius` (only
## radius_neighbors() does), so the caliper must be checked against the returned
## distances when matching.
## The k closest neighbours are retrieved for every observation; relaxing the
## caliper and increasing k can provide more matches.

from sklearn.neighbors import NearestNeighbors

caliper = np.std(menarcheDF.PS_LOGIT) * 0.3
print(f'caliper (radius) is: {caliper:.4f}')

# kneighbors() raises if k exceeds the number of fitted samples
n_neighbors = min(100, len(menarcheDF))

# setup knn (radius is stored for reference / radius_neighbors(), not used by kneighbors())
knn = NearestNeighbors(n_neighbors=n_neighbors, radius=caliper)

PS_LOGIT = menarcheDF[['PS_LOGIT']]  # double brackets: a 2-D (n_samples, 1) frame for sklearn
knn.fit(PS_LOGIT)

In [None]:
## For every observation, kneighbors() returns the distances to its n_neighbors
## closest observations and their positional row indexes, closest first.
distances, neighbor_indexes = knn.kneighbors(PS_LOGIT)

print(neighbor_indexes.shape)

## Inspect the neighbours of the first observation. The observation itself is
## normally among them, at distance 0.
for first_obs_values in (distances[0], neighbor_indexes[0]):
    print(first_obs_values)

In [None]:
## Greedy 1:1 matching without replacement.
## - Each post-menarche subject (pds_f5_y_P != 1) is paired with the closest
##   not-yet-matched pre-menarche subject (pds_f5_y_P == 1) among its k nearest
##   neighbours in PS_LOGIT.
## - The neighbour list can contain post subjects and the subject itself; these
##   are skipped.
## - A pair is only accepted if its distance is within the caliper.
##   kneighbors() ignores `radius`, so the caliper is enforced here explicitly.
## - neighbor_indexes / distances are positional (rows 0..n-1 of menarcheDF), so
##   lookups use positions; positions are converted to row labels only when stored.

group_codes = menarcheDF['pds_f5_y_P'].to_numpy()
row_labels = menarcheDF.index

menarcheDF['MATCHED'] = np.nan  # reset, so a re-run never keeps matches from a previous run
matched_control = []  # row labels of the matched pre subjects, in matching order
matched_control_set = set()  # same labels, for O(1) "already matched?" checks

for pos in range(len(menarcheDF)):
    if group_codes[pos] == 1:  # pre subjects are only used as matches (MATCHED stays NaN)
        continue
    for dist, idx in zip(distances[pos], neighbor_indexes[pos]):
        if dist > caliper:
            break  # neighbours are sorted by distance, so all remaining ones are farther
        if idx == pos or group_codes[idx] != 1:
            continue  # don't match a subject to itself or to another post subject
        label = row_labels[idx]
        if label not in matched_control_set:  # this pre subject has not been matched yet
            menarcheDF.loc[row_labels[pos], 'MATCHED'] = label  # record the matching
            matched_control.append(label)
            matched_control_set.add(label)
            break

In [None]:
## group sizes: all post-menarche subjects vs. the pre-menarche subjects matched to them
n_post = (menarcheDF['pds_f5_y_P'] == 4).sum()
print(f'total observations in post menarche group: {n_post}')
print(f'total matched observations in pre menarche group: {len(matched_control)}')

In [None]:
## Post subjects without a pre match still have MATCHED == NaN; keep only the matched ones.
treatment_matchedBoth = menarcheDF.dropna(subset=['MATCHED'])

## Row labels of the pre subjects paired with them. MATCHED is stored as float
## because the column contains NaNs.
control_matched_idx = treatment_matchedBoth['MATCHED'].astype(int)
control_matchedBoth = menarcheDF.loc[control_matched_idx, :]

## matched post rows first, followed by their matched pre rows
df_matchedBoth = pd.concat([treatment_matchedBoth, control_matchedBoth])

df_matchedBoth['pds_f5_y_P'].value_counts()

In [None]:
## reset_index() moves the original menarcheDF row labels into an 'index' column
## (which is written to the output CSV) and gives a fresh 0..n-1 index.
## It is guarded and not done in place, so re-running this cell does not add
## 'level_0' columns (or raise on the third run).
if 'index' not in df_matchedBoth.columns:
    df_matchedBoth = df_matchedBoth.reset_index()

## mean age per group after matching
df_matchedBoth.groupby(['pds_f5_y_P'])['interview_age'].mean()

In [None]:
## Age distribution of the pre and post groups after matching.
## Legend labels come from an explicit label column (1 = Pre, 4 = Post) rather than
## relabelling plotted artists by position with plt.legend([...]).
age_plot_df = df_matchedBoth.assign(Menarche=df_matchedBoth['pds_f5_y_P'].map({1: 'Pre', 4: 'Post'}))
sns.histplot(data=age_plot_df, x='interview_age', hue='Menarche', hue_order=['Pre', 'Post'],
             palette=['red', 'blue'])
plt.xlabel('Age')
plt.title('Age Distribution in Training Data after Matching', size = 15)
# os.path.join keeps the path valid on POSIX as well as Windows
plt.savefig(os.path.join('..', 'Plots', 'agedistAfterMatching.png'))

In [None]:
## subjects per (group, scanner) BEFORE matching, smallest counts first
df_mean = (
    menarcheDF
    .groupby(['pds_f5_y_P', 'mri_info_deviceserialnumber'])
    .size()
    .reset_index(name='count')
    .sort_values('count', ascending=True)
)

In [None]:
## subjects per (group, scanner) AFTER matching, smallest counts first
df_mean2 = (
    df_matchedBoth
    .groupby(['pds_f5_y_P', 'mri_info_deviceserialnumber'])
    .size()
    .reset_index(name='count')
    .sort_values('count', ascending=True)
)

In [None]:
## Subjects per scanner and group, before (left) and after (right) matching.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
hueorder = [4, 1]         # 4 = post, 1 = pre (same coding as the matching step)
labels = ['Post', 'Pre']  # legend text, in hueorder order

# Both panels are drawn the same way; only the count table differs.
for ax, counts in ((ax1, df_mean), (ax2, df_mean2)):
    sns.barplot(data=counts, x='mri_info_deviceserialnumber', y='count', hue='pds_f5_y_P',
                hue_order=hueorder, palette='rocket', ax=ax)
    ax.set_xlabel('MRI Scanner')
    ax.set_xticklabels('')  # hide the per-scanner tick labels
    h, l = ax.get_legend_handles_labels()
    ax.legend(h, labels, title="Menarche")

# Lay out only after the axes are populated, so tick labels and legends are accounted for.
fig.tight_layout(pad=1.8)
# os.path.join keeps the path valid on POSIX as well as Windows
fig.savefig(os.path.join('..', 'Plots', 'scannerprepostmatching.png'))

In [None]:
## save the matched training set (the frame's own index is not written)
df_matchedBoth.to_csv(path_or_buf='menarcheTrain_harm_red_matchedAgeScannerONLYSMRI.csv', index=False)