In [8]:
from pathlib import Path 
import numpy as np 
import matplotlib.pyplot as plt 
import pandas as pd
from collections import OrderedDict
import sys
import os
import seaborn as sns
import researchpy as rp
import statsmodels.formula.api as smf
import scipy.stats as stats
import warnings

from statsmodels.nonparametric.smoothers_lowess import lowess

#sys.path.append('/Users/alina/Desktop/MIT/code/ADHD/MTA/helper')
from helper import rr, prep, var_dict

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
# Figure output directory (not used in the visible cells; presumably consumed by later plotting cells).
# NOTE(review): hardcoded, user-specific absolute path — consider deriving it from a configurable root.
save_path = Path('/Users/alina/Desktop/MIT/code/data/output/figures/mediator_regression_14months')

In [10]:
# Choose the data locations: use the external drive when it is mounted,
# otherwise fall back to the local copy. Both values stay plain strings.
data_root, data_derived = (
    (
        '/Volumes/Samsung_T5/MIT/mta',
        '/Volumes/Samsung_T5/MIT/mta/output/derived_data',
    )
    if Path('/Volumes/Samsung_T5/MIT/mta').exists()
    else (
        '/Users/alina/Desktop/MIT/code/data',
        '/Users/alina/Desktop/MIT/code/data/output/derived_data',
    )
)

In [11]:
# Identifier/demographic columns requested from every outcome questionnaire.
baseline_var = ['src_subject_id', 'interview_date', 'interview_age', 'sex', 'site', 'days_baseline']
# Reduced set used when selecting columns from the mediator tables.
baseline_var_short = ['src_subject_id', 'days_baseline']
# Target dtype per shared column; applied only to columns that exist in a given table.
dtypes_baseline = { 'src_subject_id' : 'str',
                    'interview_date': 'str' , 
                    'interview_age' : 'int64' ,
                    'sex' : 'str', 
                    'site' : 'int64' ,
                    'days_baseline':  'int64',
                    'version_form': 'str'}

version_form = ['version_form']

# Questionnaire names; the order must match the list of frames zipped into data_dict.
qsts = ['snap', 'ssrs',  'masc', 'pc']  # 'wechsler' is disabled; note from the author: masc has too much missing data

In [12]:
# File names of the raw questionnaire exports and of the derived treatment-group table.
snap_file = 'snap01.txt'
ssrs_file = 'ssrs01.txt'
#masc_file = 'masc_p01.txt'
parent_child_file = 'pcrc01.txt'
wechsler_file = 'wiat_iiip201.txt'
treat_group_file = 'treatment_groups.csv'
# Outcome variables per questionnaire
snap_vars = ['snainatx', "snahypax", 'snahix', 'snaoddx'] # author's note: inattention mean, hyperactivity mean (other columns undocumented)
ssrs_vars = ['sspintx', 'ssptossx'] # author's note: social skills mean, internalizing mean
masc_vars = ['masc_masctotalt']
pc_vars = ['pcrcpax', 'pcrcprx'] # author's note: power assertion, personal closeness
#wechsler_vars = ['w1readb','w2math','w3spell' ]
# masc is not part of outcomes_dict even though it is listed in qsts.
outcomes_dict  = {'snap' : snap_vars, 'ssrs' : ssrs_vars,  'pc': pc_vars} #, 'wechsler': wechsler_vars}

interaction_predictors = ['days_baseline', 'site', 'trtname'] # time, site, treatment group

raters = ['Teacher', 'Parent']

# Treatment-group assignment table read from the derived-data folder.
treat_group = pd.read_csv(Path(data_derived, treat_group_file))

In [13]:
odd_cd_vars =[ 'cdorodd'] # CD or ODD 
# Anxiety-related diagnosis columns; later collapsed into a single 'anx' flag.
anx_vars = ['pso', 'psoi', 'pag', 'pagi', 'pga', 'pgai' ,'psa', 'psai'] # author's note: poa (overanxious disorder) — check whether to include it in the analysis
# All comorbidity columns selected from the diagnosis table.
comorb_vars = np.concatenate([odd_cd_vars, anx_vars])

assist_vars = ['demo61'] # public assistance 

prev_med_vars = ['hi_24'] # previous medication 

accept_vars = ['d2dresp'] # initial acceptance 
# Final mediator/moderator columns: the individual anxiety columns are replaced by the derived 'anx'.
med_mod_list = np.concatenate([odd_cd_vars, ['anx'], assist_vars, prev_med_vars, accept_vars])


In [14]:
# Load each outcome questionnaire through prep.get_data with timepoint splitting enabled.
# Author's note: rows with a missing date and duplicates are dropped
# (presumably inside prep.get_data — confirm against the helper).

snap_file = 'snap01.txt'
ssrs_file = 'ssrs01.txt'
masc_file = 'masc_p01.txt'
parent_child_file = 'pcrc01.txt'
wechsler_file = 'wiat_iiip201.txt'
treat_group_file = 'treatment_groups.csv'

treat_group = pd.read_csv(Path(data_derived, treat_group_file))

# snap and ssrs additionally request the 'version_form' column.
snap = prep.get_data(
    Path(data_root, snap_file),
    columns=[baseline_var, snap_vars, version_form],
    treat_group=treat_group,
    set_dtypes=True,
    version_form=True,
    split_timepoints=True,
)
ssrs = prep.get_data(
    Path(data_root, ssrs_file),
    columns=[baseline_var, ssrs_vars, version_form],
    treat_group=treat_group,
    set_dtypes=True,
    version_form=True,
    split_timepoints=True,
)
pc = prep.get_data(
    Path(data_root, parent_child_file),
    columns=[baseline_var, pc_vars],
    treat_group=treat_group,
    set_dtypes=True,
    version_form=False,
    split_timepoints=True,
)
masc = prep.get_data(
    Path(data_root, masc_file),
    columns=[baseline_var, masc_vars],
    treat_group=treat_group,
    set_dtypes=True,
    version_form=False,
    split_timepoints=True,
)
#wechsler = prep.get_data(Path(data_root, wechsler_file), columns= [baseline_var, wechsler_vars], set_dtypes= True, version_form= False, split_timepoints= True)

# Questionnaire name -> loaded data; the tuple order must follow the order of qsts.
data_dict = {name: loaded for name, loaded in zip(qsts, (snap, ssrs, masc, pc))}

Success
Success
Success
Success


In [15]:
# Raw tables that supply the mediator/moderator variables.
diags1_file  = 'diagpsx01.txt' # comorbid anxiety and CD/ODD 
demog_file = 'demgr01.txt' # public assistance 
health_qst_file  = 'health01.txt' # previous medication 
initial_sat_file = 'debrief01.txt' # initial acceptance of treatment arm 

# Tab-separated reads; skiprows=[1] skips the second line of each file
# (presumably a column-description row — confirm against the raw files).
diags1 = pd.read_csv(Path(data_root, diags1_file), delimiter = '\t', skiprows=[1])
demog = pd.read_csv(Path(data_root, demog_file), delimiter= '\t', skiprows=[1])
health_qst = pd.read_csv(Path(data_root, health_qst_file), delimiter='\t', skiprows=[1])
init_sat = pd.read_csv(Path(data_root, initial_sat_file), delimiter='\t', skiprows=[1])

# NOTE(review): the key 'heath_qst' looks like a typo for 'health_qst'; it is left
# unchanged here because it is a runtime dictionary key that other cells may use.
med_qsts = ['diags', 'demog', 'heath_qst', 'init_sat']
med_qsts_dict = dict(zip(med_qsts, [diags1, demog, health_qst, init_sat]))

  diags1 = pd.read_csv(Path(data_root, diags1_file), delimiter = '\t', skiprows=[1])
  health_qst = pd.read_csv(Path(data_root, health_qst_file), delimiter='\t', skiprows=[1])


In [16]:
# Re-read the four mediator source tables as tab-separated files, this time
# keeping every line, so the row at index 0 is still present in each frame.
diags1, demog, health_qst, init_sat = [
    pd.read_csv(Path(data_root, raw_file), delimiter='\t')
    for raw_file in (diags1_file, demog_file, health_qst_file, initial_sat_file)
]

  diags1 = pd.read_csv(Path(data_root, diags1_file), delimiter = '\t')
  health_qst = pd.read_csv(Path(data_root, health_qst_file), delimiter='\t')


In [17]:
# Keep only the identifier columns plus the variables of interest from each raw table,
# and drop the row labelled 0 (presumably the column-description row — confirm).
comorb = diags1[[*baseline_var_short, *version_form, *comorb_vars]].drop(0)
assist = demog[[*baseline_var_short, *assist_vars]].drop(0)
prev_med = health_qst[[*baseline_var_short, *prev_med_vars]].drop(0)
init_acc = init_sat[[*baseline_var_short, *accept_vars]].drop(0)

# Mediator name -> reduced table (the "init_sat" key holds the init_acc frame).
med_list = ['comorb', 'assist', 'prev_med', "init_sat"]
med_dict = {name: table for name, table in zip(med_list, (comorb, assist, prev_med, init_acc))}

In [18]:
# Collapse the individual anxiety-diagnosis columns into one 'anx' flag:
#   1   -> at least one anxiety column is non-zero
#   0   -> every anxiety column is exactly zero
#   NA  -> anything else (no positive value and at least one missing value)
comorb['anx'] = pd.NA
anx_vars = [ 'pso', 'psoi', 'pag', 'pagi', 'pga', 'pgai', 'psa', 'psai']
# The raw columns are object dtype (the file is read with mixed types), so a value
# such as the string '0' would compare as "not equal to 0" and wrongly count as a
# diagnosis. Coerce to numbers first; unparseable entries become NaN (missing).
# NOTE(review): assumes the diagnosis columns are numerically coded — confirm.
anx_numeric = comorb[anx_vars].apply(pd.to_numeric, errors='coerce')
mask = (anx_numeric.fillna(0) != 0).any(axis=1)  # any anxiety disorder present -> anx = 1
comorb.loc[mask, 'anx'] = 1
mask = (anx_numeric == 0).all(axis=1)  # none of the disorders present -> anx = 0, otherwise stays NA
comorb.loc[mask, 'anx'] = 0
# Rebinds `comorb` to a new frame without the individual columns; the frame that was
# mutated above (with 'anx' added) is left untouched by this drop.
comorb = comorb.drop(columns=anx_vars)

In [19]:
# days_baseline boundaries handed to prep.split_data_from_timepoints when splitting tables
# (presumably upper bounds of the assessment windows — confirm against the helper).
timepoints = [50, 213, 578, 912] 

In [20]:
def find_unique_subjects(df, lower_bound_time=None, select='min'):
    """Return `df` reduced to a single row per `src_subject_id`.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'src_subject_id' and 'days_baseline'.
    lower_bound_time : number, optional
        When given, only rows with `days_baseline` strictly greater than this
        value are considered.
    select : {'min', 'max'}
        For subjects that occur more than once, keep the row with the smallest
        ('min') or largest ('max') `days_baseline`.

    Returns
    -------
    pandas.DataFrame
        Rows of `df` selected by index label, ordered by sorted index.

    Raises
    ------
    ValueError
        If `select` is neither 'min' nor 'max'.
    """
    if lower_bound_time is not None:
        df = df.loc[df['days_baseline'] > lower_bound_time].copy()

    # Subjects that appear more than once, and every row belonging to them.
    occurrences = df['src_subject_id'].value_counts()
    repeated_ids = occurrences[occurrences > 1].index
    repeated_rows = df[df['src_subject_id'].isin(repeated_ids)]

    if select not in ('min', 'max'):
        raise ValueError('Please specify how to select unique subjects')

    # One index label per repeated subject: the row to keep.
    days_by_subject = repeated_rows.groupby('src_subject_id')['days_baseline']
    rows_to_keep = days_by_subject.idxmin() if select == 'min' else days_by_subject.idxmax()

    # Rows of subjects seen only once are kept as they are.
    single_rows = df.index.difference(repeated_rows.index)
    return df.loc[single_rows.union(rows_to_keep)]

def get_unique_subjects_split(df, timepoints_unique = None, select='min', timepoints_split = None):
    """Split `df` by timepoint and de-duplicate subjects within each split.

    Parameters
    ----------
    df : pandas.DataFrame
        Table passed to `prep.split_data_from_timepoints`.
    timepoints_unique : iterable, optional
        One `lower_bound_time` per split, forwarded to `find_unique_subjects`.
        When None, no lower bound is applied.
    select : {'min', 'max'}
        Forwarded to `find_unique_subjects`.
    timepoints_split : optional
        Forwarded to `prep.split_data_from_timepoints`.

    Returns
    -------
    list of pandas.DataFrame
        One de-duplicated frame per split. For each frame a line is printed:
        'Success' when no subject repeats, otherwise the number of repeated subjects.
    """
    windows = list(prep.split_data_from_timepoints(df, timepoints_split).values())

    # Without explicit bounds every split is processed with lower_bound_time=None.
    if timepoints_unique is None:
        lower_bounds = [None] * len(windows)
    else:
        lower_bounds = timepoints_unique

    df_split_unique = [
        find_unique_subjects(window, bound, select)
        for window, bound in zip(windows, lower_bounds)
    ]

    # Report, per split, how many subjects still occur more than once.
    for window in df_split_unique:
        n_repeated = (window['src_subject_id'].value_counts() > 1).sum()
        if n_repeated == 0:
            print('Success')
        else:
            print('Found {} duplicates remaining.'.format(n_repeated))

    return df_split_unique


In [21]:
def set_baseline_dtypes_reduced(df, dtypes_baseline):
    """Cast, in place, every column of `df` that is listed in `dtypes_baseline`.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to modify; columns absent from `dtypes_baseline` are left alone.
    dtypes_baseline : dict
        Column name -> dtype. Entries for columns missing from `df` are ignored.

    Returns
    -------
    None
        The resulting `df.dtypes` is printed.
    """
    columns_to_cast = [col for col in dtypes_baseline.keys() if col in df.columns]
    for col in columns_to_cast:
        df[col] = df[col].astype(dtypes_baseline[col])
    print(df.dtypes)

In [22]:
# Cast the shared baseline columns that exist in `comorb` to their configured dtypes.
for col, dtype in dtypes_baseline.items():
    if col in comorb.columns:
        comorb[col] = comorb[col].astype(dtype)

In [23]:
# Cast the shared baseline columns of every mediator table in place;
# the helper prints each table's dtypes after casting.
for med in med_dict.values():
    set_baseline_dtypes_reduced(med, dtypes_baseline)

src_subject_id    object
days_baseline      int64
version_form      object
cdorodd           object
pso               object
psoi              object
pag               object
pagi              object
pga               object
pgai              object
psa               object
psai              object
anx               object
dtype: object
src_subject_id    object
days_baseline      int64
demo61            object
dtype: object
src_subject_id    object
days_baseline      int64
hi_24             object
dtype: object
src_subject_id    object
days_baseline      int64
d2dresp           object
dtype: object


In [24]:
# Display the split boundaries for reference.
timepoints

[50, 213, 578, 912]

In [25]:
# Split the comorbidity table using the timepoint boundaries, then print the largest
# days_baseline found in each split as a sanity check.
comorb_split = prep.split_data_from_timepoints(comorb, timepoints)
for data in comorb_split.values():
    print(data['days_baseline'].max())

26
212
576
906


In [26]:

# De-duplicate subjects inside every split, keeping each repeated subject's row with
# the largest days_baseline. The timepoints are only zipped along; no lower bound is used.
comorb_split_unique = [
    find_unique_subjects(window_df, None, select='max')
    for window_df, _ in zip(comorb_split.values(), timepoints)
]
# Print, per split, how many subjects still appear more than once.
for df in comorb_split_unique:
    n_repeated = (df['src_subject_id'].value_counts() > 1).sum()
    print(n_repeated)

0
0
0
0


In [27]:
# Re-check the largest days_baseline per split after de-duplication.
for data in comorb_split_unique:
    print(data['days_baseline'].max())

26
212
576
906


In [28]:
#[213, 578, 912, 1095]

In [29]:
# Comorbidity: split by timepoint and keep, per subject and split, the row with the largest days_baseline.
comorb_split_unique = get_unique_subjects_split(comorb, None, select='max', timepoints_split=timepoints)
# Previous medication and public assistance: one row per subject (default select='min', i.e. smallest days_baseline).
# NOTE(review): init_acc is not de-duplicated here — confirm it already has one row per subject.
prev_med_unique = find_unique_subjects(prev_med)
assist_unique = find_unique_subjects(assist)

Success
Success
Success
Success


In [30]:
# Mediator name -> de-duplicated table. Uses the de-duplicated previous-medication
# frame (prev_med_unique) and the reduced initial-acceptance frame (init_acc), matching
# the frames used for med_dict and dfs_clean, instead of the raw prev_med / init_sat tables.
unique_med_dict = dict(zip(med_list, [comorb_split_unique, assist_unique, prev_med_unique, init_acc]))

## Comorb is the only mediator with different values per timepoint (stored as a list of frames)

In [31]:
# Show the columns of each per-timepoint comorbidity frame.
print([df.columns for df in comorb_split_unique] )

[Index(['src_subject_id', 'days_baseline', 'version_form', 'cdorodd', 'anx'], dtype='object'), Index(['src_subject_id', 'days_baseline', 'version_form', 'cdorodd', 'anx'], dtype='object'), Index(['src_subject_id', 'days_baseline', 'version_form', 'cdorodd', 'anx'], dtype='object'), Index(['src_subject_id', 'days_baseline', 'version_form', 'cdorodd', 'anx'], dtype='object')]


In [32]:
# Drop the bookkeeping columns so only src_subject_id and the comorbidity variables remain for merging.
comorb_clean = [df.drop(columns= ['days_baseline', "version_form"]) for df in comorb_split_unique]

In [33]:
# One list per comorbidity timepoint, each holding fresh copies of the
# timepoint-independent mediator tables without their days_baseline column.
dfs_clean = [
    [
        static_df.drop(columns='days_baseline')
        for static_df in (prev_med_unique, assist_unique, init_acc)
    ]
    for _ in comorb_clean
]

In [34]:
# Put each timepoint's comorbidity frame at the front of that timepoint's mediator list.
for i, listt in enumerate(dfs_clean):
    # Guard keeps the cell idempotent: re-running it must not insert the same
    # comorbidity frame a second time (which would duplicate columns on merge).
    if not listt or listt[0] is not comorb_clean[i]:
        listt.insert(0, comorb_clean[i])
    

In [35]:
def merge_data_mediators(data_list_1_timpoint, list_mediator_df):
    """Merge every outcome frame with all mediator frames on 'src_subject_id'.

    The previous body ignored `data_list_1_timpoint` and merged against the
    notebook-level `dfs_clean[k]` entries, which are lists rather than frames,
    so `pd.merge` could not succeed. The function now works purely from its
    arguments.

    Parameters
    ----------
    data_list_1_timpoint : iterable of pandas.DataFrame
        Outcome frames (for one timepoint), each with a 'src_subject_id' column.
    list_mediator_df : iterable of pandas.DataFrame
        Mediator frames, each with a 'src_subject_id' column; merged in order.

    Returns
    -------
    list of pandas.DataFrame
        One merged frame per outcome frame. Merges are inner joins (the
        `pd.merge` default), so subjects missing from any mediator are dropped.
    """
    # Materialise once so a generator argument can be reused for every outcome frame.
    mediator_frames = list(list_mediator_df)

    merged_frames = []
    for outcome_df in data_list_1_timpoint:
        merged = outcome_df
        for mediator_df in mediator_frames:
            merged = pd.merge(merged, mediator_df, on='src_subject_id')
        merged_frames.append(merged)
    return merged_frames

In [36]:
# Preview the first frame of the first timepoint's mediator list.
dfs_clean[0][0]

Unnamed: 0,src_subject_id,cdorodd,anx
6664,P1001,0.0,0
6666,P1002,,0
6669,P1003,1.0,1
6673,P1004,0.0,1
6677,P1005,1.0,0
...,...,...,...
9324,P1864,0.0,0
9326,P1865,0.0,0
9328,P1866,0.0,0
9330,P1867,0.0,0


In [37]:

# Merge every questionnaire/timepoint frame with the mediator frames of the SAME timepoint.
# Previously the inner loop ran over all entries of dfs_clean and restarted the merge each
# time, so every timepoint ended up merged with the last timepoint's comorbidity data.
# NOTE(review): assumes the time keys of data_dict[qst] are ordered like dfs_clean
# (both come from the timepoint split) — confirm against prep.get_data.
qsts_dict_merged = {}
for qst in qsts:
    time_keys = list(data_dict[qst].keys())
    if len(time_keys) != len(dfs_clean):
        # Fail loudly rather than silently pairing frames from different timepoints.
        raise ValueError(
            'Cannot align {} timepoints of {!r} with {} mediator timepoints'.format(
                len(time_keys), qst, len(dfs_clean)
            )
        )

    time_dict_merged = {}
    for time_key, listt in zip(time_keys, dfs_clean):
        df_merged = data_dict[qst][time_key]
        # Inner joins: subjects missing from any mediator table are dropped.
        for df in listt:
            df_merged = pd.merge(df_merged, df, on='src_subject_id')
        time_dict_merged[time_key] = df_merged
    qsts_dict_merged[qst] = time_dict_merged
        

In [38]:
# Print the columns of every merged frame and write it to
# <data_derived>/<questionnaire>_<timepoint>_mediators.csv (the index is written too).
for qst in qsts:
    for time_key in data_dict[qst].keys():
        merged = qsts_dict_merged[qst][time_key]
        print(qst, time_key, merged.keys())
        csv_name = str(qst) + "_" + time_key + "_mediators.csv"
        merged.to_csv(Path(data_derived) / csv_name)

snap b Index(['src_subject_id', 'interview_date', 'interview_age', 'sex', 'snainatx',
       'snahypax', 'snaoddx', 'snahix', 'days_baseline', 'site',
       'version_form', 'trtname', 'cdorodd', 'anx', 'hi_24', 'demo61',
       'd2dresp'],
      dtype='object')
snap 14 Index(['src_subject_id', 'interview_date', 'interview_age', 'sex', 'snainatx',
       'snahypax', 'snaoddx', 'snahix', 'days_baseline', 'site',
       'version_form', 'trtname', 'cdorodd', 'anx', 'hi_24', 'demo61',
       'd2dresp'],
      dtype='object')
snap 24 Index(['src_subject_id', 'interview_date', 'interview_age', 'sex', 'snainatx',
       'snahypax', 'snaoddx', 'snahix', 'days_baseline', 'site',
       'version_form', 'trtname', 'cdorodd', 'anx', 'hi_24', 'demo61',
       'd2dresp'],
      dtype='object')
snap 36 Index(['src_subject_id', 'interview_date', 'interview_age', 'sex', 'snainatx',
       'snahypax', 'snaoddx', 'snahix', 'days_baseline', 'site',
       'version_form', 'trtname', 'cdorodd', 'anx', 'hi_2