# Propensity Score Matching for ABIDE

In [1]:
### Load python libraries
import os
import numpy as np
import pandas as pd
import rpy2.robjects as robjects
from matplotlib import pyplot as plt

In [2]:
%matplotlib inline

In [17]:
# Absolute path of the phenotype CSV that is loaded into `pheno`
pheno_path = '/data1/abide/Pheno/full_merged_pheno.csv'
# Project label; used as prefix for the names of the output files written later on
proj_name = 'abide_site'

In [18]:
# Read the phenotype table found at pheno_path into a DataFrame
pheno = pd.read_csv(pheno_path)

In [19]:
# Restrict the pheno table to the sites of interest
include = ['USM', 'PITT', 'NYU', 'UCLA_1', 'UCLA_2']
# Boolean mask: True for every row whose SITE_ID is one of the included sites
site_mask = pheno['SITE_ID'].isin(include)
pheno = pheno[site_mask]

In [20]:
# Write the current `pheno` table to a temporary CSV file
# (pandas also writes the row index as the first, unnamed column by default)
pheno.to_csv('/data1/abide/Pheno/temp_pheno.csv')

In [30]:
### Variables to be passed to R
# Working directory (R reads its input from here and writes its output here)
work_dir = "/data1/abide/Pheno/"
# CSV file name, relative to work_dir.
# This has to be the site-filtered table that this notebook writes to
# temp_pheno.csv; pointing it at full_merged_pheno.csv would run the matching
# on the unfiltered sample and ignore the site selection made above.
CSV_file = "temp_pheno.csv"
# Name of the CSV file (relative to work_dir) that holds the matching result
out_name = "{}_prop.csv".format(proj_name)
# Declare a list of variables that need to be categorical (using the names from CSV; assuming NOSPACE)
categories = robjects.StrVector(['SITE_ID', 'SEX', 'DX_GROUP'])
# Declare Formula for Mahalanobis distance matching (using the names from CSV)
## Format: Disease variable ~  What to Match by 1 + What to Match by 2 + ...
## All variables matched by must be NUMERIC
Mah_formula = 'DX_GROUP ~ AGE_AT_SCAN + FD_scrubbed'
# Declare Caliper Width (as a fraction of the SD of the Propensity Scores)
cal_width = 0.5
# Declare Formula for PSM distance matching (using the names from CSV)
## Format: Disease variable ~  What to Match by 1 + What to Match by 2 + ...
PSM_formula = 'DX_GROUP ~ AGE_AT_SCAN + FD_scrubbed + SITE_ID + SEX'

In [31]:
### Hand the Python-side settings over to the global environment of R
# Each pair is (name of the variable inside R, value taken from Python)
for _r_name, _r_value in [
    ("work_dir", work_dir),
    ("CSV_file", CSV_file),
    ("categories", categories),
    ("Mah_formula", Mah_formula),
    ("cal_width", cal_width),
    ("PSM_formula", PSM_formula),
    ("out_name", out_name),
]:
    robjects.globalenv[_r_name] = _r_value

In [32]:
# Run R script
## The R script writes a CSV file named *out_name* into *work_dir*.
## It contains the input table plus two extra columns:
##   match - the result of fullmatch for each row (NA for rows that are left out)
##   keep  - 1 to keep the row and 0 to leave it out
robjects.r('''
  # Load R libraries
  library(optmatch)

  # Set working directory
  setwd(work_dir)

  # Read CSV
  data <- read.csv(CSV_file)

  # Make all categories factors
  # seq_along() yields an empty loop for an empty vector, 1:length() would not
  for (ff in seq_along(categories)) {
    data[[categories[ff]]] <- as.factor(data[[categories[ff]]])
  }

  # Performs Matching
  # Distance = Mahalanobis distance + caliper on the propensity score distance
  Matching <- fullmatch(
  match_on( as.formula(Mah_formula),
           data = data ) +
    caliper( match_on( as.formula(PSM_formula),
                     data = data ),
            width = cal_width ),
  data = data )

  # Make a data frame with a column with 1 to keep an observation and 0 to leave out
  save_data <- cbind(data,match=Matching)
  save_data$keep <- as.numeric(!is.na(save_data$match))
  # Write to out_name, which is the file the Python side loads afterwards
  # (the name used to be built from CSV_file, so out_name was never created)
  write.csv(save_data,file=out_name)
''')

rpy2.rinterface.NULL

In [None]:
# Read the matched pheno file from the working directory
matched_path = os.path.join(work_dir, out_name)
matched = pd.read_csv(matched_path)
# Rows without a value in the 'match' column are dropped
keep = matched.dropna(subset=['match'])
# Store the retained rows as <proj_name>_matched.csv in the working directory
keep_path = os.path.join(work_dir, '{}_matched.csv'.format(proj_name))
keep.to_csv(keep_path)

In [None]:
# Split the matched sample by diagnosis group
dx_values = keep.DX_GROUP.values
# DX_GROUP == 1 selects the patients, DX_GROUP == 2 selects the controls
patient_idx = dx_values == 1
control_idx = dx_values == 2
patients, controls = keep[patient_idx], keep[control_idx]

In [None]:
def plot_sample(ados, ctrl):
    """Scatter the ages of patients and controls, one jittered column per site.

    Parameters
    ----------
    ados : pandas.DataFrame
        Rows of the patient group; the columns 'SITE_ID' and 'AGE_AT_SCAN'
        are read.
    ctrl : pandas.DataFrame
        Rows of the control group; same columns as `ados`.

    Returns
    -------
    None
        The figure is created through pyplot as a side effect.

    Notes
    -----
    Every site that occurs in either frame gets a column. A site that is
    missing from one of the two groups simply shows no points for that group
    (looking it up with get_group used to raise a KeyError in that case).
    The horizontal jitter is drawn from the global numpy random state.
    """
    def ages_per_site(frame):
        # Map each site name to the array of ages recorded at that site
        return {site: rows['AGE_AT_SCAN'].values
                for site, rows in frame.groupby('SITE_ID')}

    def jittered_column(position, n_points):
        # x positions spread uniformly within +/- 0.2 around the column centre
        return position + (np.random.random(n_points) * 2 - 1) * 0.2

    def joined(chunks):
        # np.concatenate refuses an empty list, so fall back to an empty array
        return np.concatenate(chunks) if chunks else np.array([])

    ados_ages = ages_per_site(ados)
    ctrl_ages = ages_per_site(ctrl)
    # Union of the sites of both groups, in a stable (sorted) order
    site_names = sorted(set(ados_ages) | set(ctrl_ages))
    no_ages = np.array([])

    x_ados, y_ados = [], []
    x_ctrl, y_ctrl = [], []
    for idx, site in enumerate(site_names):
        # Columns are numbered from 1; patients are handled before controls
        for lookup, xs, ys in ((ados_ages, x_ados, y_ados),
                               (ctrl_ages, x_ctrl, y_ctrl)):
            ages = lookup.get(site, no_ages)
            xs.append(jittered_column(idx + 1, len(ages)))
            ys.append(ages)

    f = plt.figure(figsize=(10,4))
    ax = f.add_subplot(111)
    ad = ax.scatter(joined(x_ados), joined(y_ados), c='r', alpha=0.5)
    ct = ax.scatter(joined(x_ctrl), joined(y_ctrl), c='y', alpha=0.5)
    ax.legend((ad, ct),
              ('Patient', 'Control'),
              scatterpoints=1,
              loc='upper right',
              ncol=1,
              fontsize=10)
    ax.set_xticks(np.arange(len(site_names)) + 1)
    ax.set_xticklabels(site_names, rotation=70)
    ax.set_ylabel('Age')
    ax.set_title('Simple Matching ADOS sample')

In [None]:
# Draw the age-by-site scatter plot for the matched patients and controls
plot_sample(patients, controls)