In [68]:
from pathlib import Path 
import numpy as np 
import matplotlib.pyplot as plt 
import pandas as pd
from collections import OrderedDict
import sys
import os
import seaborn as sns
import researchpy as rp
import statsmodels.formula.api as smf
import scipy.stats as stats
import warnings

from statsmodels.nonparametric.smoothers_lowess import lowess

#sys.path.append('/Users/alina/Desktop/MIT/code/ADHD/MTA/helper')
from helper import rr, prep, var_dict, plot

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [35]:
ylim_dict = {'snap' : [0,3], 'ssrs': [0,2], 'masc' : [0,3], 'pc': [0,4]}

In [36]:
trtnames = ['M', 'P', 'C', 'A']
timepoints = [46 , 168, 319, 500]
delta = 50
timepoints_range = [[time - delta, time + delta] for time in timepoints]
timepoints_range
y_lim = [0,3]
x_lim = [-10, 450]

In [37]:
save_path = Path("/Volumes/NO NAME/alina/MIT/mta/derived_data")

In [38]:
if Path('/Volumes/Samsung_T5/MIT/mta').exists():
    data_root =     '/Volumes/Samsung_T5/MIT/mta'
    data_derived = '/Volumes/Samsung_T5/MIT/mta/output/derived_data'
else: 
    data_root = '/Users/alina/Desktop/MIT/code/data'
    data_derived = '/Users/alina/Desktop/MIT/code/data/output/derived_data'

In [39]:

baseline_var = ['src_subject_id', 'interview_date', 'interview_age', 'sex', 'site', 'days_baseline']
dtypes_baseline = { 'src_subject_id' : 'str',
                    'interview_date': 'str' , 
                    'interview_age' : 'int64' ,
                    'sex' : 'str', 
                    'site' : 'int64' ,
                    'days_baseline':  'int64',
                    'version_form': 'str'}

version_form = ['version_form']

qsts = ['snap', 'ssrs', 'masc', 'pc','wechsler']#, 'masc']##, 'wechsler'] #masc to many missing data 

In [40]:
snap = pd.read_csv(Path(data_derived, 'snap_14_mediators.csv'))
ssrs = pd.read_csv(Path(data_derived, 'ssrs_14_mediators.csv'))
masc = pd.read_csv(Path(data_derived, 'masc_14_mediators.csv'))
pc = pd.read_csv(Path(data_derived, 'pc_14_mediators.csv'))
qsts =['snap', 'ssrs', 'masc', 'pc']
data_dict = dict(zip(['snap', 'ssrs', 'masc', 'pc'], [snap, ssrs, masc, pc]))

In [41]:
timepoints

[46, 168, 319, 500]

In [42]:
ssrs.loc[ssrs['version_form'].str.startswith('Teacher'), 'version_form'] = 'Teacher'
ssrs.loc[ssrs['version_form'].str.startswith('Parent'), 'version_form'] = 'Parent'

#wechsler = prep.get_data(Path(data_root, wechsler_file), columns= [baseline_var, wechsler_vars], treat_group= treat_group, set_dtypes= True, version_form= False, split_timepoints= None)
data_dict = dict(zip(qsts, [snap, ssrs, masc, pc]))#, wechsler]))
qsts =['snap', 'ssrs', 'masc', 'pc']

data_dict = dict(zip(qsts, [snap, ssrs, masc, pc]))
for qst in qsts:
    data_dict[qst] = data_dict[qst][data_dict[qst]['trtname'] != 'L']
    
    # Optionally, remove 'L' from the category list if 'trtname' is categorical
    if data_dict[qst]['trtname'].dtype.name == 'category':
        data_dict[qst]['trtname'] = data_dict[qst]['trtname'].cat.remove_categories(['L'])

# Verify the changes
for qst in qsts[:2]:
    print(data_dict[qst].trtname.unique())

['C' 'M' 'A' 'P']
['C' 'M' 'A' 'P']


In [43]:
#outcome variablles 
snap_vars = ['snainatx', 'snahix', 'snaoddx'] #inattention_mean, hyperactie mean
ssrs_vars = ['sspintx', 'ssptossx']# social skills mean, internalizing mean 
masc_vars = ['masc_masctotalt']
pc_vars = ['pcrcpax', 'pcrcprx'] # power assertion, personal closeness
#wechsler_vars = ['w1readb','w2math','w3spell' ]
outcomes = np.concatenate([snap_vars, ssrs_vars, masc_vars, pc_vars])

outcomes_dict  = {'snap' : snap_vars, 'ssrs' : ssrs_vars, 'masc': masc_vars, 'pc': pc_vars}#, 'wechsler': wechsler_vars}

In [44]:
interaction_predictors = ['days_baseline', 'site', 'trtname'] #time, site, treatment group

# mediator variables
comorb_mediators  = ['cdorodd' , 'pso', 'psoi', 'pag', 'pagi', 'pga', 'pgai' ,'psa'] #ODD/CD or anx excluding specific phobia 
services_mediators =  ['demo61'] #reciept of public assistance 
prev_med_mediators = ['hi_24'] #medication intake prior to study 

#moderator variables #
#accept_moderator = ['d2dresp'] # initail acceptance of treatment 
accept_moderator = ['d2dresp'] # initail acceptance of treatment # binary
raters = ['Teacher', 'Parent']

med_mod_list = np.concatenate([[comorb_mediators[0]], ['anx'], services_mediators, prev_med_mediators, accept_moderator, ['sex']])
med_mod_list

array(['cdorodd', 'anx', 'demo61', 'hi_24', 'd2dresp', 'sex'], dtype='<U7')

In [45]:
#outcome variables 
outcomes_written = ['SNAP Inattention', 'SNAP Hyperactivity', 'SNAP Hyperactivity-Impulsivity','SNAP Aggressive','SSRS Internalizing', 'SSRS Social Skills', 'MASC total Score', 'Parent-Child Power Assertion', 'Parent-Child Personal Closeness' ]
outcomes_dict_fig = dict(zip(outcomes, outcomes_written))
outcomes_dict_fig

# treatment names 
trt_dict = {'M': 'Medication Management', 'P': 'Behavioral Treatment', 'C': 'Combined Treatment', "A": "Assement (control)"}

#mediators 
med_written = ['CD or ODD', 'Anxiety', 'Public Assistance', 'Prior Medication', 'Initial Acceptance of Treatment Arm', 'Sex']
med_options = [['No', ''],['No', ''] ,['No', ''],['No', ''], ['Low', 'High'], ['Male', 'Female']]

# mediator variable names spelled out 
med_dict_fig = {} #for figure titles 
options_dict = {}
for i, med in enumerate(med_mod_list):
    med_dict_fig[med] = [med_options[i][j] + ' ' + med_written[i] for j in range(2)]
med_dict_fig

# outcome values for mediators and moderators 
values_possible = [[0,1],[0,1],[1,2],[1,2],[1,2,3,4,5,6] ,['M', 'F']]
values_possible

med_values = dict(zip(med_mod_list, values_possible))
med_values

{'cdorodd': [0, 1],
 'anx': [0, 1],
 'demo61': [1, 2],
 'hi_24': [1, 2],
 'd2dresp': [1, 2, 3, 4, 5, 6],
 'sex': ['M', 'F']}

In [46]:
data_dict['snap']

Unnamed: 0.1,Unnamed: 0,src_subject_id,interview_date,interview_age,sex,snainatx,snahypax,snaoddx,snahix,days_baseline,site,version_form,trtname,cdorodd,anx,hi_24,demo61,d2dresp
0,0,P1002,1997-05-02,155,M,2.00,1.17,1.63,1.56,-32,1,Parent,C,0.0,0.0,2.0,,
1,1,P1002,1997-06-03,120,M,1.00,0.50,1.25,0.78,0,1,Parent,C,0.0,0.0,2.0,,
2,2,P1002,1997-06-03,120,M,1.11,1.00,0.88,1.00,0,1,Parent,C,0.0,0.0,2.0,,
3,3,P1002,1997-09-19,124,M,0.33,0.33,0.63,0.44,108,1,Parent,C,0.0,0.0,2.0,,
4,4,P1002,1998-01-03,127,M,0.44,0.17,0.13,0.44,214,1,Parent,C,0.0,0.0,2.0,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5242,5242,P1842,1995-08-26,78,M,2.11,1.83,1.00,1.56,186,4,Teacher,M,,0.0,,2.0,3.0
5243,5243,P1844,1995-03-27,71,M,2.78,2.50,2.13,2.56,-28,4,Teacher,A,0.0,0.0,,1.0,3.0
5244,5244,P1844,1995-04-24,72,M,3.00,2.33,1.38,2.44,0,4,Teacher,A,0.0,0.0,,1.0,3.0
5245,5245,P1844,1995-10-04,77,M,2.11,1.00,0.13,1.00,163,4,Teacher,A,0.0,0.0,,1.0,3.0


In [47]:
gen_interact_formula = ' C(site) + C(trtname, Treatment(reference="A")) * days_baseline' #
re_formula = None
formulas, formulas_dict = rr.get_rr_formulas(gen_interact_formula, med_mod_list, outcomes_dict)


In [48]:
formulas

[[['snainatx ~  C(site) + C(trtname, Treatment(reference="A")) * days_baseline * cdorodd',
   'snahix ~  C(site) + C(trtname, Treatment(reference="A")) * days_baseline * cdorodd',
   'snaoddx ~  C(site) + C(trtname, Treatment(reference="A")) * days_baseline * cdorodd'],
  ['snainatx ~  C(site) + C(trtname, Treatment(reference="A")) * days_baseline * anx',
   'snahix ~  C(site) + C(trtname, Treatment(reference="A")) * days_baseline * anx',
   'snaoddx ~  C(site) + C(trtname, Treatment(reference="A")) * days_baseline * anx'],
  ['snainatx ~  C(site) + C(trtname, Treatment(reference="A")) * days_baseline * demo61',
   'snahix ~  C(site) + C(trtname, Treatment(reference="A")) * days_baseline * demo61',
   'snaoddx ~  C(site) + C(trtname, Treatment(reference="A")) * days_baseline * demo61'],
  ['snainatx ~  C(site) + C(trtname, Treatment(reference="A")) * days_baseline * hi_24',
   'snahix ~  C(site) + C(trtname, Treatment(reference="A")) * days_baseline * hi_24',
   'snaoddx ~  C(site) + C

In [49]:
def perform_rr_analysis(data_dict, interaction_predictors, formulas,raters, outcomes_dict , re_formula = None):

    cols = np.concatenate([['src_subject_id'], interaction_predictors])
    print(cols)
    results_s = {}
    for i, qst in enumerate(qsts) :
        if qst == 'snap' or qst == 'ssrs':
            results_rater = {}
            for rater in raters: 
                results_ = {}
                
                for j, med in enumerate(med_mod_list):
                    
                    for k, var in enumerate(outcomes_dict[qst]):
                        cols_total = np.concatenate([cols, [med, var]])
                        print(cols_total)
                        formula = formulas[i][j][k]
                        data = data_dict[qst][(data_dict[qst]['version_form'] == rater) & (data_dict[qst][cols_total].notna().all(axis=1))]
                        print(data.shape)
                        groups = data['src_subject_id']
                        re_formula = re_formula#"~days_baseline" # alow for random slope per subject 
                        result = smf.mixedlm(formula, data, groups = groups, re_formula= re_formula).fit()
                        results_[str(med) + '_' + str(var) ] = result
                        
                results_rater[rater] = results_
                
            results_s[qst] = results_rater
        elif qst == 'masc' or qst == 'pc':
            results_m = {}
            for i, qst in enumerate(['masc', 'pc']) :
                results_ = {}
                for j, med in enumerate(med_mod_list):

                    for k, var in enumerate(outcomes_dict[qst]):
                        cols_total = np.concatenate([cols, [med, var]])
                        print(cols_total)
                        formula = formulas[i+2][j][k]
                        data = data_dict[qst][cols_total].dropna()
                        groups = data['src_subject_id']
                        re_formula = re_formula #"~days_baseline"
                        result = smf.mixedlm(formula, data, groups = groups, re_formula= re_formula).fit()
                        results_[str(med) + '_' + str(var) ] = result
                results_m[qst] = results_
    return  {'snap': results_s['snap'], 'ssrs' : results_s['ssrs'], 'masc': results_m['masc'], 'pc': results_m['pc'] }

In [50]:
results = perform_rr_analysis(data_dict, interaction_predictors, formulas, raters, outcomes_dict)

['src_subject_id' 'days_baseline' 'site' 'trtname']
['src_subject_id' 'days_baseline' 'site' 'trtname' 'cdorodd' 'snainatx']
(1987, 18)
['src_subject_id' 'days_baseline' 'site' 'trtname' 'cdorodd' 'snahix']
(1987, 18)
['src_subject_id' 'days_baseline' 'site' 'trtname' 'cdorodd' 'snaoddx']
(1987, 18)
['src_subject_id' 'days_baseline' 'site' 'trtname' 'anx' 'snainatx']
(1900, 18)
['src_subject_id' 'days_baseline' 'site' 'trtname' 'anx' 'snahix']
(1900, 18)
['src_subject_id' 'days_baseline' 'site' 'trtname' 'anx' 'snaoddx']
(1900, 18)
['src_subject_id' 'days_baseline' 'site' 'trtname' 'demo61' 'snainatx']
(1837, 18)
['src_subject_id' 'days_baseline' 'site' 'trtname' 'demo61' 'snahix']
(1837, 18)
['src_subject_id' 'days_baseline' 'site' 'trtname' 'demo61' 'snaoddx']
(1837, 18)
['src_subject_id' 'days_baseline' 'site' 'trtname' 'hi_24' 'snainatx']
(279, 18)
['src_subject_id' 'days_baseline' 'site' 'trtname' 'hi_24' 'snahix']
(279, 18)
['src_subject_id' 'days_baseline' 'site' 'trtname' 'hi_2

In [51]:
cols_replication_pred = ['predicted_anx_snahix_P',
                    "predicted_hi_24_ssptossx_P",
                     'predicted_cdorodd_ssptossx_T'
                    'predicted_sex_snainatx_P',
                    'predicted_anx_sspintx_P',
                    "predicted_anx_masc_masctotalt",
                    'predicted_demo61_pcrcprx',
                    'predicted_d2dresp_snainatx_P']


cols_replication=['anx_snahix',
                    "hi_24_ssptoss",
                     'cdorodd_ssptossx'
                    'sex_snainatx',
                    'anx_sspintx',
                    "anx_masc_masctotal",
                    'demo61_pcrcprx',
                    'd2dresp_snainatx']


In [52]:
for qst in qsts:
    rater = raters[0] if (qst == 'snap' or qst == 'ssrs') else None
    result_ = results[qst][rater] if (qst == 'snap' or qst == 'ssrs') else results[qst]

    for key in result_.keys():

        if key in cols_replication: 
            splitt = key.split("_")
            if key.startswith('hi'):
                var_mod = prep.split_on_occurrence(key, '_', 2)[0]
                var_out = prep.split_on_occurrence(key, '_', 2)[1]
            else:
                var_mod = splitt[0]
                var_out = splitt[1]
            var_mod = "sex[T.M]" if var_mod == "sex" else var_mod

            print(key)
            result = result_[key]
      
            coeffs = rr.get_slopes(result, var_mod) #* result.params['days_baseline']
            var_mod = "sex" if var_mod == "sex[T.M]" else var_mod
            print(coeffs)
        

anx_snahix
           no_mod   yes_mod
trtname                    
P       -0.002441 -0.281845
M       -0.003245 -0.249815
C       -0.002749 -0.100784
A       -0.002199  0.222227
d2dresp_snainatx
           no_mod   yes_mod
trtname                    
P       -0.002034  0.069447
M       -0.002050  0.019125
C       -0.003444 -0.068737
A       -0.001878 -0.019833
anx_sspintx
           no_mod   yes_mod
trtname                    
P       -0.000614  0.375008
M       -0.000455  0.103049
C       -0.000272 -0.006755
A       -0.000250 -0.052124
demo61_pcrcprx
           no_mod   yes_mod
trtname                    
P        0.000697 -0.071556
M       -0.002082 -0.081107
C        0.000871  0.132020
A        0.000809 -0.039256


In [58]:
raters = ['Teacher', 'Parent']

In [59]:
rr.extract_prediction(results, data_dict, qsts, raters)


SNAP
SSRS
MASC
PC


In [62]:
timepoints_range = rr.get_timepoints_range(timepoints=timepoints, delta= 20)

In [75]:
data = data_dict["snap"]
var_mod = 'anx'
var_out = 'snahix'
plot_bool = False
pred_col_name = 'predicted_anx_snahix_P'
rater = "Parent"
window  = 50 
ylim = [0,3]
xlim = [-10,450]
type_plot = 'exp' 
# options type plot : ploy_fit, mov_av, smooth 

In [76]:
pnt_av = rr.get_point_av(data, pred_col_name, timepoints_range)

In [77]:
   
plot.plot_RR_results(data, var_mod, var_out, med_values, type_plot=type_plot, window= window,  trt_dict=trt_dict, rater=rater, show = plot_bool, save_path=save_path)

TypeError: plot_RR_results() missing 4 required positional arguments: 'timepoints', 'delta', 'med_dict_fig', and 'outcomes_dict_fig'