In [None]:
import os
import json
import time

import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from pandas.api.types import is_numeric_dtype
from scipy import sparse
from tableone import TableOne
from tqdm import tqdm

# Display options for inspecting patient tables.
pd.set_option("display.max_rows", 100)
pd.set_option("display.max_columns", 100)
pd.set_option("display.min_rows", 50)
pd.set_option("display.precision", 8)
# Silence SettingWithCopyWarning: later cells assign into filtered frames via `.loc`.
pd.options.mode.chained_assignment = None

data_dir = './data'  # NOTE(review): not referenced by the cells in this notebook
# Working directory; os.getcwd() returns it WITHOUT a trailing separator.
basedir = os.getcwd()

In [None]:
# Reload the project helper modules so edits to them are picked up without
# restarting the kernel.
# NOTE(review): later cells import from `clinspokeprediction.*`, not
# `clinprediction.*`; confirm which package name is correct, otherwise the modules
# reloaded here are not the ones actually used.
import importlib
import clinprediction.omop_fx
import clinprediction.patient_fx
import clinprediction.match_fx
import clinprediction.model_fx

# Previously `match_fx` was reloaded twice and `model_fx` never.
for _helper_module in (clinprediction.omop_fx, clinprediction.patient_fx,
                       clinprediction.match_fx, clinprediction.model_fx):
    importlib.reload(_helper_module)

## Parameters

In [None]:
load_options = None  # e.g. ['testdate','AD','control'] to reuse the options of a saved run
# os.path.join(..., '') keeps a trailing separator: basedir has none, and these
# directories are extended with plain string concatenation further down.
in_dir = os.path.join(basedir, 'data', '')
pdir = os.path.join(basedir, 'cohort_selection', '')
output_dir = os.path.join(basedir, 'cohort_selection_out', '')

if load_options is not None:
    # NOTE(review): `load_in_options` is not imported in this notebook; presumably it
    # comes from one of the clinprediction helper modules -- confirm.
    options = load_in_options(*load_options)
    odir_main = options['odir_main']
    demo_cols = options['demo_cols']
    cols_match_num = options['match_params']['cols_match_num']
    cols_match_vis = options['match_params']['cols_match_vis']
    cols_match_cat = options['match_params']['cols_match_cat']
else:
    demo_cols = ['Sex', 'date_age']  # demographic columns to use in model
    cols_match_num = ['year_of_birth', 'min_date_age', 'date_age', 'yrs_in_ehr']  # numerical columns to use in matching
    cols_match_vis = ['log_n_prev_visits', 'log_n_concepts', 'log_neg_earliestday']  # visit related columns to use in matching
    cols_match_cat = ['Sex', 'RaceEthnicity']  # categorical columns to use in matching

    options = {
        'ndate': 'testdate',          # date identifier
        'input_dir': in_dir,          # input directory
        'output_dir': output_dir,     # output directory
        'ratio': 8,                   # controls per case (matching and control subsampling)
        'timefilt_range': [-365*7, -365*5, -365*3, -365*1, -1],  # time points before onset, presumably days
        'dxgroup': 'AD',              # label column name
        'comparison': 'control',
        'demo_cols': demo_cols,       # read back as options['demo_cols'] by the feature loop
        'match_cohorts': True,
        'match_params': {'cols_match_cat': cols_match_cat, 'cols_match_num': cols_match_num,
                         'cols_match_vis': cols_match_vis,
                         'match_scale_features': False, 'match_with_replacement': False,
                         'match_with_pscore': True},
        'random_state': 500, 'random_seed': 500,
        'cv': 5,
    }

    odir_main = output_dir + '{}_{}v{}/'.format(options['ndate'], options['dxgroup'], options['comparison'])
    options['odir_main'] = odir_main
    os.makedirs(odir_main, exist_ok = True)

    # Save options
    with open(odir_main + 'options.json', 'w') as fp:
        json.dump(options, fp)

# Short names used by later cells (previously referenced but never defined).
dxgroup = options['dxgroup']  # label column: 1 for cohort patients, 0 for controls
ratio = options['ratio']      # controls per case

## Load in Data

In [None]:
from clinspokeprediction.omop_fx import OMOPData

start = time.time()
omopdata = OMOPData(omopdir = 'data/omop/')

# Read OMOP tables for the patients of interest: try the compressed cache first and
# fall back to the raw CSVs (writing the cache) if that fails.
try:
    omopdata.load_compressed(pdir)
except Exception as err:  # was a bare `except:`, which also swallowed KeyboardInterrupt
    print('Could not load compressed OMOP data ({!r}); reading CSVs instead.'.format(err))
    omopdata.read_in_omop_csv(directory = pdir, read_in_controls = True)
    omopdata.save_compressed()

# Concept look-up dictionary: concept_id -> concept_name
conceptdict = omopdata.concepts.set_index('concept_id')['concept_name'].to_dict()

print(omopdata.size_of_data_col('iscontrol'))
print('Finished reading in OMOP data. took {} minutes'.format((time.time() - start) / 60))

In [None]:
# load in patients, basic filtering
from clinspokeprediction.patient_fx import read_in_patients, age_visit_timefilt, filter_pts
from clinspokeprediction.match_fx import match_patients

timefilt_min = np.array(options['timefilt_range']).min()

if 'allpts_timefiltmin_file' in options:
    allpts = pd.read_csv(options['allpts_timefiltmin_file'])
else: 
    # read in patients
    cohortpts, controlpts, all_visits = read_in_patients(options)
    print('cohortpts shape: {}\ncontrolpts shape: {}'.format(cohortpts.shape, controlpts.shape))

    # cohort mindatept: first AD, dementia, or cognitive drug. 
    cohortpts, controlpts = filter_pts(cohortpts, controlpts)
    print('cohortpts shape: {}\ncontrolpts shape: {}'.format(cohortpts.shape, controlpts.shape))
    controlpts[options['dxgroup']]=0 # 0 for the label

    # filter by index 0 age
    index0age_thresh = 55
    print('remove if index 0 age is less than {} years'.format(index0age_thresh))
    print('remove {} cases'.format((cohortpts.mindatept_age >= index0age_thresh).sum()))
    cohortpts = cohortpts[cohortpts.mindatept_age >= index0age_thresh]
    print('remove {} controls'.format((controlpts.mindatept_age >= index0age_thresh).sum()))
    controlpts = controlpts[controlpts.mindatept_age >= index0age_thresh]

    print('cohortpts shape: {}\ncontrolpts shape: {}'.format(cohortpts.shape, controlpts.shape))

    cohort_dxf = 1-cohortpts.set_index('person_id')['dxf']\
        .apply(lambda x: pd.Series(x.split(','))).unstack().dropna()\
        .reset_index().pivot(index = 'person_id', columns = 0, values = 0)\
        .isna().astype(int)

    # look at cohort
    TableOne(cohortpts[cohortpts[dxgroup]==1]\
                 .merge(cohort_dxf[\
                                   np.setdiff1d(cohort_dxf.columns,cohortpts.columns)].reset_index(), 
                                     on = 'person_id', how = 'left', validate = '1:1'),
             columns = ['Sex','RaceEthnicity','n_visits','e_age','mindatept_age',
                        '90old','lds_birthyear'] + list(cohort_dxf.columns))

    # filter patient based upon minimum time point for model
    timefilt_min = np.array(options['timefilt_range']).min()

    # get information for patients using earliest timepoint of interest 
    print('earliest timepoint: ', timefilt_min)
    cohortpts, controlpts = age_visit_timefilt(cohortpts, controlpts, all_visits, timefilt_min)

    # Use patients with 'timefilt_min' visits for prediction model
    cohortpts = cohortpts[cohortpts['date_age'] >= cohortpts['min_date_age']]
    controlpts = controlpts[controlpts['date_age'] >= controlpts['min_date_age']]
    print('cohortpts shape: {}\ncontrolpts shape: {}'.format(cohortpts.shape, controlpts.shape))
    allpts = cohortpts.append(controlpts)

    allpts.to_csv(odir_main + 'allpts_timefiltmin.csv', index = False)
    options['allpts_timefiltmin_file'] = odir_main + 'allpts_timefiltmin.csv'
    save_updated_options(options)

display_table(allpts.reset_index(), groupby = 'AD', options = options)

In [None]:
# Train-test split
tosplit_person_id = allpts['person_id'].to_numpy() # all person_id
tosplit_dxgroup = allpts[options['dxgroup']].to_numpy() # all labels

if 'person_id_train_all_file' not in options:
    person_id_train, person_id_test, _, _ = \
    train_test_split(tosplit_person_id, tosplit_dxgroup, stratify = tosplit_dxgroup, 
                     test_size = .3, random_state = options['random_state']) 

    person_id_train = np.sort(person_id_train)
    person_id_test = np.sort(person_id_test)

    options['person_id_train_all_file'] = odir_main + 'train_personid_all.npy'
    np.save(options['person_id_train_all_file'], person_id_train)
    
    options['person_id_test_all_file'] = odir_main + 'test_personid_all.npy'
    np.save(options['person_id_test_all_file'], person_id_test)
else:
    person_id_train = np.load(options['person_id_train_all_file'])
    person_id_test = np.load(options['person_id_test_all_file'])
    
print('train number pid:', len(person_id_train))
print('test number pid:', len(person_id_test))

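## Models for each time point prior to AD onset

For each time point prior to onset, we build per-patient features from the OMOP records available at that point and train a random forest on the unmatched training cohort.\
For interpretability, we then match cases to controls on demographic and healthcare-utilization variables and train a second random forest on the matched training cohort. Both models are evaluated on the same held-out test patients.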

In [None]:
from clinspokeprediction.omop_fx import filter_omopdata_by_time
save_long = True  # also write each time point's long-format concept table to CSV

# For each time point: build per-patient features, fit an unmatched random forest,
# then fit a random forest on a matched training cohort.
for timefilt in options['timefilt_range']:
    odir_tf = options['odir_main'] + str(timefilt) + '/'
    os.makedirs(odir_tf, exist_ok = True)
    dx_col = options['dxgroup']  # label column: 1 for cohort patients, 0 for controls

    # ---- Per-patient OMOP counts, visit info and demographics (cached) ----
    tf_features_file = odir_tf + 'omop_count_demo_visit.joblib'
    if os.path.isfile(tf_features_file):
        print(odir_tf + 'omop_count_demo_visit.joblib exists.')
        print('getting information at timefilt:{}'.format(timefilt))
        omop_pt_tf_input = joblib.load(tf_features_file)
        allpt_tf = omop_pt_tf_input['allpt_tf']
        allpt_tf_train = omop_pt_tf_input['allpt_tf_train']
        allpt_tf_test = omop_pt_tf_input['allpt_tf_test']
        pts_train_tf = omop_pt_tf_input['train_pts']
        pts_test_tf = omop_pt_tf_input['test_pts']
        feat_concepts = omop_pt_tf_input['feat_concepts']

        # NOTE(review): this CSV exists only if save_long was True when the cache was
        # written, and `datediff` is read back as text rather than a timedelta.
        allptomop = pd.read_csv(odir_tf + 'patient_sentence_long.csv')
        del omop_pt_tf_input
    else:
        print('getting information at timefilt:{}'.format(timefilt))
        cohortpts_tf, controlpts_tf = age_visit_timefilt(cohortpts, controlpts, all_visits, timefilt)
        # pd.concat replaces DataFrame.append (removed in pandas 2.0); same result.
        pts_tf = pd.concat([cohortpts_tf, controlpts_tf])

        # One-hot demographics. Falls back to the parameter-cell `demo_cols` when the
        # options dict has no 'demo_cols' entry (it previously raised KeyError).
        model_demo_cols = options.get('demo_cols', demo_cols)
        if model_demo_cols is not None:
            demo_data = pd.get_dummies(pts_tf.reset_index()[['person_id'] + model_demo_cols])\
                .set_index('person_id')
        else:
            demo_data = None

        pts_train_tf = np.intersect1d(pts_tf.person_id, person_id_train)
        pts_test_tf = np.intersect1d(pts_tf.person_id, person_id_test)

        print('train number pid: ', len(pts_train_tf))
        print('test number pid: ', len(pts_test_tf))

        print('extracting OMOP information: conditions, drugs, measures... ')
        conditions, drugs, measures = filter_omopdata_by_time(omopdata, pts_tf, timefilt = timefilt)
        measures_ab = process_abnormal_measures(measures)

        # Long format: one row per (person_id, concept_id, datediff) across drugs,
        # conditions and abnormal measurements, tagged with the concept's domain.
        allptomop = pd.concat([
            drugs.rename({'drug_concept_id': 'concept_id'}, axis=1),
            conditions.rename({'condition_concept_id': 'concept_id'}, axis=1),
            measures_ab.rename({'measurement_concept_id': 'concept_id'}, axis=1),
        ])[['person_id', 'concept_id', 'datediff']]
        allptomop = allptomop.merge(omopdata.concepts[['concept_id', 'domain_id']],
                                    on = 'concept_id', how = 'left')

        # feature space
        feat_concepts = np.sort(allptomop.concept_id.unique())

        # Per patient: number of distinct concepts, earliest entry (in days) and
        # number of entries per domain.
        allpt_omopcount = allptomop.groupby('person_id')['concept_id'].nunique().to_frame('n_concepts')\
            .merge(allptomop.groupby('person_id')['datediff'].min().dt.days.rename('earliest_day'),
                   left_index = True, right_index = True, how = 'outer')\
            .merge(allptomop.groupby('person_id')['domain_id'].value_counts().unstack().fillna(0),
                   left_index = True, right_index = True, how = 'outer')

        # Combine patient info with OMOP counts and, if configured, demographics.
        # (Merging with demo_data=None used to raise.)
        allpt_tf = pts_tf[['person_id', 'min_date_age', 'date_age',
                           'yrs_in_ehr', 'n_prev_visits', 'Sex',
                           'RaceEthnicity', 'year_of_birth',
                           dx_col]].set_index('person_id').sort_index()\
            .merge(allpt_omopcount, left_index = True, right_index = True, how = 'left')
        if demo_data is not None:
            allpt_tf = allpt_tf.merge(demo_data, left_index = True, right_index = True,
                                      suffixes = ('', '_'), how = 'left')

        # Cap years in EHR at 50 and add log-transformed utilization features.
        allpt_tf.loc[allpt_tf.yrs_in_ehr > 50, 'yrs_in_ehr'] = 50
        allpt_tf.loc[:, 'log_n_prev_visits'] = allpt_tf.n_prev_visits.apply(lambda x: np.log(x + .1))
        allpt_tf.loc[:, 'log_n_concepts'] = allpt_tf.n_concepts.apply(lambda x: 0 if pd.isna(x) else np.log(x + .1))
        # Patients without concepts (NaN earliest_day) get 0, as for log_n_concepts.
        # The former `.fillna(.1)` made that branch unreachable and produced
        # log(-0.1) = NaN in a matching covariate.
        allpt_tf.loc[:, 'log_neg_earliestday'] = allpt_tf.earliest_day\
            .apply(lambda x: 0 if pd.isna(x) else np.log(-x))

        print('AD after computing timepoint info...{}'.format(timefilt))
        display(pts_tf[dx_col].value_counts(dropna = False))  # after computing timepoint info

        print()
        print('AD after merging with omop info...{}'.format(timefilt))
        display(allpt_tf[dx_col].value_counts(dropna = False))  # after merging with omop info

        print('get train/test')
        allpt_tf_train = allpt_tf.loc[pts_train_tf]  # TAKE TRAINING PATIENTS
        allpt_tf_test = allpt_tf.loc[pts_test_tf]  # TAKE TESTING PATIENTS

        print()
        print('AD in train{}'.format(timefilt))
        display(allpt_tf_train[dx_col].value_counts(dropna = False))

        print()
        print('AD in test{}'.format(timefilt))
        display(allpt_tf_test[dx_col].value_counts(dropna = False))

        if save_long:
            # NOTE(review): assumes get_dummies produced a 'Sex_Female' column.
            allptomop.merge(allpt_tf[['min_date_age', 'date_age', 'yrs_in_ehr',
                                      'n_prev_visits', 'Sex_Female', 'RaceEthnicity', dx_col]],
                            left_on = 'person_id', right_index = True, how = 'left')\
                .to_csv(odir_tf + 'patient_sentence_long.csv', index = False)

        joblib.dump({'allpt_tf': allpt_tf, 'allpt_tf_train': allpt_tf_train,
                     'allpt_tf_test': allpt_tf_test,
                     'train_pts': pts_train_tf, 'test_pts': pts_test_tf,
                     'feat_concepts': feat_concepts},
                    tf_features_file)

    #### INITIAL MODELS WITHOUT MATCHING

    # ---- Patient set shared by all time points ----
    # The train/test patients that have concepts at the earliest time point
    # (timefilt_min) are saved while timefilt_min is processed, then reloaded for
    # every other time point so all models use the same patients. Previously the
    # first (timefilt_min) iteration of a fresh run called np.load on an options key
    # that did not exist yet, raising KeyError.
    dx_col = options['dxgroup']  # label column: 1 for cohort patients, 0 for controls
    mintf_train_file = odir_main + 'train_personid_mintimefilt.npy'
    mintf_test_file = odir_main + 'test_personid_mintimefilt.npy'
    have_mintf_ids = os.path.isfile(mintf_train_file) and os.path.isfile(mintf_test_file)
    if have_mintf_ids:
        if 'train_personid_mintimefilt_file' not in options:
            options['train_personid_mintimefilt_file'] = mintf_train_file
            options['test_personid_mintimefilt_file'] = mintf_test_file
            save_updated_options(options)
        train_personid = np.load(mintf_train_file)
        test_personid = np.load(mintf_test_file)
        print('train:', pts_train_tf.shape, train_personid.shape)
        print('test:', pts_test_tf.shape, test_personid.shape)
    elif timefilt != timefilt_min:
        raise FileNotFoundError(
            '{} not found: timefilt_min ({}) must be processed before timefilt {}'.format(
                mintf_train_file, timefilt_min, timefilt))
    # otherwise: timefilt_min on a fresh run -- the ids are selected below.

    # ---- Unmatched model inputs (cached) ----
    unmatched_inputs_file = odir_tf + 'model_unmatched_input_data.joblib'
    if os.path.isfile(unmatched_inputs_file):
        print('loading in saved model inputs... ')
        unmatched_model_inputs = joblib.load(unmatched_inputs_file)

        X_train = unmatched_model_inputs['X_train']
        X_test = unmatched_model_inputs['X_test']
        feature_names = unmatched_model_inputs['feature_names']
        varthresh = unmatched_model_inputs['varthresh']
        y_train = unmatched_model_inputs['y_train']
        y_test = unmatched_model_inputs['y_test']
        if not have_mintf_ids:
            # Only reachable at timefilt_min: the cached matrices are indexed by the
            # selected patients, so recover the ids from them.
            train_personid = X_train.index
            test_personid = X_test.index

        print('variance thresholded n features:', varthresh.get_support(1).shape)
        feature_names_var = feature_names[varthresh.get_support(1)]
    else:
        print('preparing X_train')
        allptomop_pivot_train = pivot_omop(allptomop, pts_train_tf)
        if have_mintf_ids:
            X_train = allptomop_pivot_train.loc[train_personid]
            print(X_train.shape)
        else:
            X_train = allptomop_pivot_train
            print('X_train: ', X_train.shape)
            train_personid = X_train.index

        feature_names = X_train.columns

        # remove features with 0 variance
        # NOTE(review): VarianceThreshold (scikit-learn) is not imported in this notebook.
        varthresh = VarianceThreshold().fit(X_train)
        print('variance thresholded n features:', varthresh.get_support(1).shape)
        feature_names_var = feature_names[varthresh.get_support(1)]

        # Test data prep
        print('preparing X_test')
        allptomop_pivot_test = pivot_omop(allptomop, pts_test_tf, feature_names_var)
        if have_mintf_ids:
            X_test = allptomop_pivot_test.loc[test_personid, feature_names_var]
            print(X_test.shape)
        else:
            X_test = allptomop_pivot_test
            test_personid = X_test.index
            print('X_test: ', X_test.shape)

        print('load y_train and y_test')
        y_train = allpt_tf.loc[train_personid][dx_col].to_numpy()
        y_test = allpt_tf.loc[test_personid][dx_col].to_numpy()

        print('save.')
        joblib.dump({'X_train': X_train, 'X_test': X_test, 'y_train': y_train, 'y_test': y_test,
                     'feature_names': feature_names, 'varthresh': varthresh},
                    unmatched_inputs_file)

    # Concept metadata (first row per concept_id) for every candidate feature.
    # name= is passed by keyword: the first positional parameter of Index.to_frame is `index`.
    feature_name_info = feature_names.to_frame(name = 'concept_id').rename_axis('')\
        .merge(omopdata.concepts.groupby('concept_id').head(1),
               on = 'concept_id', how = 'left').set_index('concept_id')
    print('X_train shape: {}, X_test shape: {}'.format(X_train.shape, X_test.shape))
    print('length of y_train: {}. \n\tsum of y_train: {}. \n\tmean of y_train: {:0.07f}'.format(\
                    len(y_train), y_train.sum(), y_train.mean()))
    print('length of y_test: {}. \n\tsum of y_test: {}. \n\tmean of y_test: {:0.07f}'.format(\
                len(y_test), y_test.sum(), y_test.mean()))

    # Persist the patient set selected at timefilt_min (same condition as the loader above).
    if timefilt == timefilt_min and not have_mintf_ids:
        print('Save patients with concepts if timefilt_min')
        options['train_personid_mintimefilt_file'] = mintf_train_file
        options['test_personid_mintimefilt_file'] = mintf_test_file
        np.save(mintf_train_file, train_personid)
        np.save(mintf_test_file, test_personid)
        save_updated_options(options)

    allpt_tf['visits_per_yr'] = allpt_tf['n_prev_visits']/(allpt_tf['yrs_in_ehr']+.1)
    # demographics of the patients used by the models
    print('train:', pts_train_tf.shape, train_personid.shape)
    print('test:', pts_test_tf.shape, test_personid.shape)
    display_table(allpt_tf.loc[np.concatenate((train_personid, test_personid))].reset_index(),
                  groupby = dx_col, options = options)
    N_FEATURES = varthresh.get_support().sum()
    X_train2 = varthresh.transform(X_train)
    X_train_sparse = sparse.csr_matrix(X_train2)
    X_test2 = X_test.to_numpy()  # already restricted to feature_names_var

    ### RANDOM FOREST MODEL (unmatched training cohort)
    fname = odir_tf + 'rf_unmatched_model.joblib'
    if os.path.isfile(fname):
        print(fname, 'exists, loading in... ')
        rf_unmatched_dict = joblib.load(fname)
        # Same variable name as the training branch (previously `rf_feat_import`,
        # which the matched section overwrites anyway).
        rf_feat_import_unmatched = feature_context(rf_unmatched_dict['feat_import'], feature_name_info,
                                                   import_col = 'rf_import', modelkind = 'rf_unmatched')
    else:
        # Train on every case plus options['ratio'] randomly drawn controls per case
        # (`ratio` was previously an undefined name here).
        # NOTE(review): np.random.choice samples WITH replacement by default, so a
        # control can be drawn more than once -- confirm this is intended.
        np.random.seed(1100)
        case_rows = np.where(y_train)[0]
        control_rows = np.where(1 - y_train)[0]
        n_controls = int(y_train.sum()) * options['ratio']
        pt_choice = np.concatenate((case_rows, np.random.choice(control_rows, n_controls)))

        rf_unmatched_dict = rf_model(X_train2[pt_choice], y_train[pt_choice],
                                     X_test2, y_test, feature_names_var, options,
                                     odir_tf = odir_tf, modelsuffix = '_unmatched',
                                     n_its = 300)

        rf_feat_import_unmatched = feature_context(rf_unmatched_dict['feat_import'], feature_name_info,
                                                   import_col = 'rf_import', modelkind = 'rf_unmatched',
                                                   odir_tf = odir_tf)
        # Only a newly trained model needs saving; a loaded one is already on disk.
        joblib.dump(rf_unmatched_dict, fname)
    
    
    ## PREPROCESS MATCHED PATIENTS
    dx_col = options['dxgroup']  # label column: 1 for cohort patients, 0 for controls
    # Numeric matching covariates = numeric + visit-derived columns from match_params.
    cols_match_num = options['match_params']['cols_match_num'] + options['match_params']['cols_match_vis']
    cols_match_cat = options['match_params']['cols_match_cat']

    # Training data prep: match controls to cases among the training patients (cached).
    matched_train_file = odir_tf + 'cohort_control_pt_train_tf.joblib'
    if os.path.isfile(matched_train_file):
        print('loading in saved model inputs... ')
        loaded_matched_pts = joblib.load(matched_train_file)

        cohortpts_tf_train = loaded_matched_pts['cohortpts_tf_train']
        controlpts_tf_train = loaded_matched_pts['controlpts_tf_train']
        allpt_tf_matched = loaded_matched_pts['allpt_tf_matched']

        del loaded_matched_pts
    else:
        cohortpts_tf_train = allpt_tf_train[allpt_tf_train[dx_col] == 1].reset_index()
        controlpts_tf_train = allpt_tf_train[allpt_tf_train[dx_col] == 0].reset_index()
        print('cohortpts shape (prior): {}\ncontrolpts shape: {}'.format(cohortpts_tf_train.shape,
                                                                         controlpts_tf_train.shape))

        # Restrict to the patient set shared across time points.
        cohortpts_tf_train = cohortpts_tf_train[cohortpts_tf_train.person_id.isin(train_personid)]
        controlpts_tf_train = controlpts_tf_train[controlpts_tf_train.person_id.isin(train_personid)]
        print('cohortpts shape: {}\ncontrolpts shape: {}'.format(cohortpts_tf_train.shape,
                                                                 controlpts_tf_train.shape))

        # match patients
        cohortpts_tf_train, controlpts_tf_train, _ = \
            match_patients(cohortpts_tf_train, controlpts_tf_train, dx_col,
                           cols_match_cat = cols_match_cat,
                           cols_match_num = cols_match_num,
                           ratio = options['ratio'], return_split = True)

        # pd.concat replaces DataFrame.append (removed in pandas 2.0); same result.
        allpt_tf_matched = pd.concat([cohortpts_tf_train, controlpts_tf_train]).set_index('person_id')
        joblib.dump({'cohortpts_tf_train': cohortpts_tf_train, 'controlpts_tf_train': controlpts_tf_train,
                     'allpt_tf_matched': allpt_tf_matched}, matched_train_file)

        # Balance table of the matching covariates.
        mytable = TableOne(allpt_tf_matched, columns = cols_match_num + cols_match_cat,
                           groupby = dx_col, categorical = cols_match_cat, smd = True,
                           pval = True)
        display(mytable)
        mytable.to_csv(odir_tf + 'allpt_train_tf_matched.csv')

    # Test data prep: the same matching within the test patients (cached).
    # NOTE(review): allpt_tf_test_matched does not appear to be used afterwards; the
    # test matrix is built from all test_personid -- confirm this is intended.
    matched_test_file = odir_tf + 'cohort_control_pt_test_tf.joblib'
    if os.path.isfile(matched_test_file):
        print('loading in saved model inputs... ')
        loaded_matched_pts = joblib.load(matched_test_file)

        cohortpts_tf_test = loaded_matched_pts['cohortpts_tf_test']
        controlpts_tf_test = loaded_matched_pts['controlpts_tf_test']
        allpt_tf_test_matched = loaded_matched_pts['allpt_tf_test_matched']

        del loaded_matched_pts
    else:
        # match test patients
        cohortpts_tf_test = allpt_tf_test[allpt_tf_test[dx_col] == 1].reset_index()
        controlpts_tf_test = allpt_tf_test[allpt_tf_test[dx_col] == 0].reset_index()
        print('cohortpts shape (prior): {}\ncontrolpts shape: {}'.format(cohortpts_tf_test.shape, controlpts_tf_test.shape))
        cohortpts_tf_test = cohortpts_tf_test[cohortpts_tf_test.person_id.isin(test_personid)]
        controlpts_tf_test = controlpts_tf_test[controlpts_tf_test.person_id.isin(test_personid)]
        print('cohortpts shape: {}\ncontrolpts shape: {}'.format(cohortpts_tf_test.shape, controlpts_tf_test.shape))

        # match patients
        cohortpts_tf_test, controlpts_tf_test, _ = match_patients(
            cohortpts_tf_test, controlpts_tf_test, dx_col,
            cols_match_cat = cols_match_cat, cols_match_num = cols_match_num,
            ratio = options['ratio'], return_split = True)
        allpt_tf_test_matched = pd.concat([cohortpts_tf_test, controlpts_tf_test]).set_index('person_id')

        joblib.dump({'cohortpts_tf_test': cohortpts_tf_test, 'controlpts_tf_test': controlpts_tf_test,
                     'allpt_tf_test_matched': allpt_tf_test_matched},
                    matched_test_file)

        # Balance table of the matching covariates.
        mytable = TableOne(allpt_tf_test_matched, columns = cols_match_num + cols_match_cat,
                           groupby = dx_col, categorical = cols_match_cat, smd = True,
                           pval = True)
        display(mytable)
        mytable.to_csv(odir_tf + 'allpt_test_tf_matched.csv')
        
    # Pivoted concept matrices for this time point. They already exist only if the
    # unmatched inputs were rebuilt earlier in THIS iteration; otherwise build them.
    # They are deleted at the end of this block: previously a later iteration could
    # silently reuse another time point's matrices via the try/bare-except.
    try:
        allptomop_pivot_train
    except NameError:
        allptomop_pivot_train = pivot_omop(allptomop, pts_train_tf)
    try:
        allptomop_pivot_test
    except NameError:
        allptomop_pivot_test = pivot_omop(allptomop, pts_test_tf, feature_names_var)

    dx_col = options['dxgroup']  # label column: 1 for cohort patients, 0 for controls
    # Matched training matrix: rows restricted to the matched training patients.
    X_train = allptomop_pivot_train.loc[allpt_tf_matched.index]
    y_train = allpt_tf_matched[dx_col].to_numpy()

    print('matched X_train.shape ', X_train.shape)
    print('matched y_train.shape ', len(y_train))
    print('matched train:', pts_train_tf.shape, train_personid.shape, allpt_tf_matched.shape, X_train.shape)
    feature_names = X_train.columns
    train_personid_matched = allpt_tf_matched.index

    # Re-select features: keep those with non-zero variance in the matched training set.
    varthresh = VarianceThreshold().fit(X_train)
    N_FEATURES = varthresh.get_support().sum()
    print('varthresh n features: ', varthresh.get_support(1).shape)
    feature_names_var = feature_names[varthresh.get_support(1)]

    # Concept metadata (first row per concept_id); previously computed twice.
    # name= is passed by keyword: the first positional parameter of Index.to_frame is `index`.
    feature_name_info = feature_names.to_frame(name = 'concept_id').rename_axis('')\
        .merge(omopdata.concepts.groupby('concept_id').head(1), on = 'concept_id', how = 'left')\
        .set_index('concept_id')

    # Test matrix over all held-out test patients (not the matched test cohort).
    X_test = allptomop_pivot_test.loc[test_personid, feature_names_var].fillna(0)
    y_test = allpt_tf.loc[test_personid][dx_col].to_numpy()
    print('X_test.shape ', X_test.shape)
    X_train2 = varthresh.transform(X_train)
    X_test2 = X_test.to_numpy()
    X_train_sparse = sparse.csr_matrix(X_train2)

    # Drop this time point's pivots so the next iteration cannot reuse them.
    del allptomop_pivot_train, allptomop_pivot_test

    ## MATCHED COHORT RANDOM FOREST MODELS
    fname = odir_tf + 'rf_matched_model.joblib'
    if os.path.isfile(fname):
        print(fname, 'exists, loading in... ')
        rf_matched_dict = joblib.load(fname)
        rf_feat_import = feature_context(rf_matched_dict['feat_import'], feature_name_info,
                                         import_col = 'rf_import', modelkind = 'rf_matched')
    else:
        rf_matched_dict = rf_model(X_train2, y_train,
                                   X_test2, y_test, feature_names_var, options,
                                   odir_tf = odir_tf, modelsuffix = '_matched',
                                   n_its = 300)

        # Called once; the previous duplicate call repeated identical work with
        # identical arguments.
        rf_feat_import = feature_context(rf_matched_dict['feat_import'], feature_name_info,
                                         import_col = 'rf_import', modelkind = 'rf_matched', odir_tf = odir_tf)

        # Persist the trained model, as done for the unmatched model, so the cache
        # check above can find it on the next run.
        joblib.dump(rf_matched_dict, fname)