In [1]:
import sys 
sys.path.insert(0, '/home/yikuan/project/git/CPRD')

from utils.yaml_act import yaml_load
from utils.arg_parse import arg_paser
from CPRD.config.spark import spark_init, read_parquet, read_txt
import pyspark.sql.functions as F
from CPRD.functions import tables, merge, cohort_select,risk_prediction, modalities, MedicalDictionary, risk_prediction, predictor_extractor
from CPRD.functions import merge
from utils.utils import save_obj, load_obj
import pandas as pd
from pyspark.sql import Window
import seaborn as sns
import matplotlib.pyplot as plt

from CPRD.config.utils import cvt_str2time
from CPRD.config.utils import check_time
from CPRD.config.utils import RangeExtract

class dotdict(dict):
    """A ``dict`` whose keys can also be read, written and deleted as attributes.

    Reading a missing attribute returns ``None`` (lookup delegates to ``dict.get``)
    rather than raising ``AttributeError``; deleting a missing attribute raises
    ``KeyError``.
    """

    def __getattr__(self, key):
        return dict.get(self, key)

    def __setattr__(self, key, value):
        dict.__setitem__(self, key, value)

    def __delattr__(self, key):
        dict.__delitem__(self, key)

In [2]:
# Load the project YAML config and start a Spark session from its 'pyspark' section.
args = dotdict(params='/home/yikuan/project/git/CPRD/config/config.yaml')
params = yaml_load(args.params)
spark_params = params['pyspark']
spark = spark_init(spark_params)
# 'file_path' is indexed below by source name (e.g. file['test'], file['death']).
file = params['file_path']

  cfg = yaml.load(ymlfile)


# Load the cohort (demographics) file

In [None]:
# Demographics of the cohort; 'study_entry' is dropped and this notebook derives
# its own 'baseline' column later on.
demographics = (
    read_parquet(spark.sqlContext, '/home/shared/yikuan/HF_Valid/data/demographics.parquet/')
    .drop("study_entry")
)

# Preprocess diagnoses

In [3]:
# Diagnoses for the exclusion criteria and predictors, harmonised into one
# (patid, eventdate, code) table: CPRD rows keep their medcode, HES rows their ICD code.
diagnoses = read_parquet(spark.sqlContext, '/home/shared/yikuan/HF_Valid/data/diagnosis.parquet')
diag_cprd = (diagnoses
             .filter(F.col('source') == 'CPRD')
             .select('patid', 'eventdate', F.col('medcode').alias('code')))
diag_hes = (diagnoses
            .filter(F.col('source') == 'HES')
            .select('patid', 'eventdate', F.col('ICD').alias('code')))
diagnoses = diag_cprd.union(diag_hes)

In [4]:
# Medication records plus 'bnf6', the first 6 characters of the BNF code.
medications = (
    read_parquet(spark.sqlContext, '/home/shared/yikuan/HF_Valid/data/medications.parquet')
    .select('patid', 'prodcode', 'eventdate', 'bnfcode', 'code')
    .withColumn('bnf6', F.col('bnfcode').substr(1, 6))
)

# Find patients who have heart failure in their records

In [None]:
condition_query = MedicalDictionary.MedicalDictionaryRiskPrediction(file, spark)

# Flatten every code list stored under 'heart failure' in the dictionary result.
hf = condition_query.queryDisease(['Heart failure'], merge=True)
hf_code = {'hf': [c for codes in hf['heart failure'].values() for c in codes]}
condition = hf_code['hf']

# One row per patient: date of the earliest HF record and the code recorded on it.
w = Window.partitionBy('patid').orderBy('eventdate')
hf = (
    diagnoses
    .filter(F.col('code').isin(*condition))
    .withColumn('code', F.first('code').over(w))
    .groupBy('patid')
    .agg(F.min('eventdate').alias('eventdate'), F.first('code').alias('code'))
)

# Keep patients with an HF record.
# NOTE(review): dropna() has no subset, so it also removes HF patients that have a
# null in any other demographics column; an inner join would keep them — confirm intent.
demographics = demographics.join(hf, 'patid', 'left').dropna()

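# Selection criteria
1. Baseline is set 2 years before the incident HF diagnosis
2. At least 2 years of registration with the GP practice at baseline
3. At least 35 years old

(Note: the cell below only requires the registration start to precede baseline; criteria 2 and 3 are not enforced there.)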

In [None]:
# Baseline = incident HF date minus 2 years; keep patients whose 'start'
# (presumably the registration start) lies before baseline.
# NOTE(review): the ">= 2 years registration" and ">= 35 years old" criteria from the
# markdown are not applied here — confirm whether that is intended.
demographics = (
    demographics
    .withColumn('baseline', F.col('eventdate') - F.expr('INTERVAL 2 YEARS'))
    .filter(F.col('baseline') > F.col('start'))
)

demographics.write.parquet('/home/shared/yikuan/Mo/demographics.parquet')

# Patient predictor selection
1. age
2. renal function
3. blood pressure
4. sodium
5. ejection fraction
6. gender
7. BNP
8. NYHA class
9. diabetes
10. BMI
11. exercise tolerance

In [None]:
# Columns needed to derive the baseline predictors below.
cohort = (
    read_parquet(spark.sqlContext, '/home/shared/yikuan/Mo/demographics.parquet')
    .select('patid', 'gender', 'dob', 'baseline')
)

def code_merge(code_dict, name):
    """Concatenate every code list stored under ``code_dict[name]``.

    ``code_dict[name]`` is a mapping whose values are iterables of codes (the shape
    returned by ``queryDisease(..., merge=True)`` in this notebook). The lists are
    joined in the mapping's iteration order. Raises ``KeyError`` if ``name`` is absent.
    """
    return [code for codes in code_dict[name].values() for code in codes]

predictor = predictor_extractor.PredictorExtractorBase()

# age calculation
# Age in completed years at baseline. Fix: the previous version divided the seconds
# between dob and baseline by 3600*24*30*12, i.e. it assumed 360-day years, which
# inflates ages by ~1.5% (roughly +1 year for most patients in their 60s-70s).
# months_between follows the calendar; the integer cast truncates to whole years.
age_cal = F.months_between('baseline', 'dob') / 12
cohort = cohort.withColumn('age', age_cal.cast('integer'))

# blood pressure
# Systolic readings (usable_range 60-200, 1985-2015) summarised with type='mean'
# over the 24 months before baseline.
bp = modalities.retrieve_systolic_bp_measurement(file, spark, duration=(1985, 2015), usable_range=(60, 200)).select(['patid', 'eventdate', 'systolic'])
bp = predictor.predictor_extract(bp, cohort, 'systolic', col_baseline='baseline', span_before_baseline_month=24, type='mean')
cohort = cohort.join(bp, 'patid', 'left')

# sodium
# enttype 196 records from the 'test' source, 'data2' restricted to 70-200,
# events 1985-2015, summarised with type='mean' over the 24 months before baseline.
sodium = tables.retrieve_additional(dir=file['test'], spark=spark).filter(F.col('enttype') == 196)
sodium = RangeExtract(sodium, 'data2', (70, 200))
sodium = sodium.withColumn('eventdate', cvt_str2time(sodium, 'eventdate', year_first=True))
sodium = check_time(sodium, 'eventdate', time_a=1985, time_b=2015).select(['patid','eventdate','data2']).withColumnRenamed('data2', 'sodium').withColumn('sodium', F.col('sodium').astype("float"))
sodium = predictor.predictor_extract(sodium, cohort, 'sodium', col_baseline='baseline', span_before_baseline_month=24,type='mean')
cohort = cohort.join(sodium, 'patid', 'left')

# diabetes
# Dictionary codes for diabetes checked against `diagnoses` relative to baseline.
diabetes = condition_query.queryDisease(['Diabetes'], merge=True)
diabetes = code_merge(diabetes, 'diabetes')
diabetes = predictor.predictor_check_exist(diabetes, diagnoses, cohort, col='code', col_baseline='baseline').withColumnRenamed('code', 'diabetes')
cohort = cohort.join(diabetes, 'patid', 'left')

cohort.write.parquet('/home/shared/yikuan/Mo/predictor.parquet')

In [None]:
cohort = read_parquet(spark.sqlContext, '/home/shared/yikuan/Mo/predictor.parquet')

# bmi: usable_range 16-50, 1985-2015, type='mean' over the 24 months before baseline.
bmi = modalities.retrieve_bmi(file, spark, duration=(1985, 2015), usable_range=(16, 50))
bmi = predictor.predictor_extract(
    bmi, cohort, 'BMI',
    col_baseline='baseline', span_before_baseline_month=24, type='mean',
)
cohort = cohort.join(bmi, 'patid', 'left')

# hdl ratio: enttype 338 records from the 'test' source, 'data2' restricted to 0-50,
# events 1985-2015, type='mean' over the 24 months before baseline.
hdl_r = tables.retrieve_additional(dir=file['test'], spark=spark).filter(F.col('enttype') == 338)
hdl_r = RangeExtract(hdl_r, 'data2', (0, 50))
hdl_r = hdl_r.withColumn('eventdate', cvt_str2time(hdl_r, 'eventdate', year_first=True))
hdl_r = (
    check_time(hdl_r, 'eventdate', time_a=1985, time_b=2015)
    .select('patid', 'eventdate', F.col('data2').cast('float').alias('hdl_r'))
)
hdl_r = predictor.predictor_extract(
    hdl_r, cohort, 'hdl_r',
    col_baseline='baseline', span_before_baseline_month=24, type='mean',
)
cohort = cohort.join(hdl_r, 'patid', 'left')

# Chronic kidney disease (stage 4 or 5) and major chronic renal disease (including nephrotic syndrome, chronic glomerulonephritis, chronic pyelonephritis, renal dialysis, and renal transplant)
# The codes are matched against the combined `diagnoses.code` column (CPRD medcode or
# HES ICD), so the list mixes ICD-style strings ('N18x', 'T861', ...) with numeric
# medcode strings. Repeated entries (e.g. 'N18', 'N07') are redundant but harmless for isin().
# NOTE(review): read as ICD-10, 'N181'-'N183' are CKD stages 1-3, which is broader than
# the "stage 4 or 5" definition above — confirm the intended code set.
ckd = ['T86', 'T861', 'N18', 'N189', 'N18', 'N185', 'N18', 'N184', 'N18', 'N183', 'N18', 'N182', 'N18', 'N181', 'N11', 'N110', 'N07', 'N074', 'N07', 'N073', 'N07', 'N072', 'N03', 'N030', 'N00', 'N000', '99644', '99631', '97980', '97979', '97978', '97758', '97734', '97683', '97587', '97388', '95572', '95571', '95546', '95508', '95408', '95406', '95405', '95188', '95180', '95179', '95178', '95177', '95176', '95175', '95146', '95145', '95123', '95122', '95121', '94965', '94793', '94789', '94373', '94350', '93922', '91738', '89924', '88597', '85991', '73026', '72877', '72303', '72004', '71709', '71314', '71174', '70874', '68659', '68364', '68114', '68112', '67995', '67486', '67232', '67197', '67193', '66872', '66714', '66705', '66505', '66062', '65400', '65064', '64828', '64622', '64571', '63786', '63615', '63466', '63000', '62980', '62520', '62320', '61811', '61494', '61344', '61317', '61145', '60960', '60857', '60856', '60796', '60484', '60198', '60128', '59365', '59031', '59018', '58750', '58671', '58164', '58060', '57621', '57568', '57278', '57168', '57072', '56987', '56893', '56852', '55151', '54990', '53940', '53852', '52969', '52303', '51113', '50728', '50472', '50331', '50305', '50225', '50200', '49642', '48855', '48111', '47922', '47672', '47582', '47342', '46963', '46438', '46145', '45867', '45499', '45160', '44270', '43935', '41881', '41676', '41285', '41239', '41148', '41013', '40413', '40349', '39840', '39649', '38572', '36342', '36205', '36125', '35360', '35107', '35105', '34998', '34648', '34637', '32423', '30323', '30301', '30294', '29638', '29013', '28684', '26054', '25394', '25055', '24836', '24384', '24361', '23913', '22852', '22252', '22205', '21989', '21983', '21947', '21837', '21297', '21158', '20629', '20516', '20073', '19316', '18777', '18774', '18390', '18209', '17365', '16929', '16502', '16008', '15917', '15106', '15097', '13279', '12720', '12640', '12586', '12585', '12566', '12479', '12465', '11875', '11773', '11745', '11553', '10809', 
'10647', '10418', '9840', '9240', '8330', '8037', '7804', '7190', '6712', '5911', '5504', '4669', '4668', '4654', '4503', '2997', '2996', '2475', '2471', '1803', '512']
# Per-patient CKD column computed by predictor_check_exist relative to 'baseline'
# (exact semantics live in PredictorExtractorBase), exposed as 'ckd'.
ckd = predictor.predictor_check_exist(ckd, diagnoses, cohort, col='code', col_baseline='baseline').withColumnRenamed('code', 'ckd')
cohort = cohort.join(ckd, 'patid', 'left')

# Family history of coronary heart disease in a first degree relative aged less than 60 years
# Approximated by enttype 87 records (1985-2015) whose 'data1' value is one of the
# dictionary codes for "coronary heart disease not otherwise specified".
# NOTE(review): neither the "relative aged < 60" condition nor a before-baseline
# restriction is applied here — confirm this is intended.
history = (
    modalities.retrieve_by_enttype(file, spark, 87, (1985, 2015))
    .withColumnRenamed('data1', 'history')
    .select('patid', 'history')
)
chd = code_merge(
    condition_query.queryDisease(['coronary heart disease not otherwise specified'], merge=True),
    'coronary heart disease not otherwise specified',
)
history = (
    history
    .filter(F.col('history').isin(*chd))
    .groupby('patid')
    .agg(F.first('history').alias('chd_history'))
)
cohort = cohort.join(history, 'patid', 'left')

cohort.write.parquet('/home/shared/yikuan/Mo/predictor_1.parquet')

In [None]:
cohort = read_parquet(spark.sqlContext, '/home/shared/yikuan/Mo/predictor_1.parquet')

# Treated hypertension (diagnosis of hypertension and treatment with at least one antihypertensive drug)
# Implemented as: any product code listed under antihypertensives -> 'prod',
# checked against `medications.prodcode` relative to baseline.
# NOTE(review): the hypertension-diagnosis part of the definition is not checked here.
antihyp = condition_query.queryMedication(['antihypertensives'])
antihyp = antihyp.get('antihypertensives').get('prod')
antihyp = (
    predictor
    .predictor_check_exist(antihyp, medications, cohort, col='prodcode', col_baseline='baseline')
    .withColumnRenamed('prodcode', 'antihtn')
)
cohort = cohort.join(antihyp, 'patid', 'left')

# smoking status
## smoking status 0: not recorded 1: non smoker 2: ex 3: light (1-9) 4:moderate (10-19) 5: heavy (>=20)
def categorise_smoke(x):
    """Map a current smoker's cigarettes-per-day value to a smoking category.

    Returns 3 (light, 1-9/day), 4 (moderate, 10-19/day) or 5 (heavy, >=20/day);
    values below 1 give ``None``. ``x`` is converted with ``int()``, which raises
    (e.g. ``ValueError``) for values it cannot parse.
    """
    cigs = int(x)
    if cigs >= 20:
        return 5
    if cigs >= 10:
        return 4
    if cigs >= 1:
        return 3
    return None
    
def smoke_dict(x):
    """Recode a non-current smoking status onto the 0-2 part of the smoking scale.

    '0' -> 0 (not recorded), '2' -> 1 (non-smoker), '3' -> 2 (ex-smoker).
    ``x`` is looked up as ``str(x)``; any other value raises ``KeyError``.
    """
    return {'0': 0, '2': 1, '3': 2}[str(x)]

# Row-wise UDFs; with no returnType given, F.udf returns strings.
smoke_cat = F.udf(lambda x: categorise_smoke(x))
smoke_map = F.udf(lambda x: smoke_dict(x))

smoke = modalities.retrieve_smoking_status(file, spark, duration=(1985, 2015))

# Current smokers (status 1) with a cigarettes/day value -> categories 3/4/5.
# Fix: union() matches columns by position, so the column order is pinned explicitly
# instead of relying on the order returned by retrieve_smoking_status.
smoke_current = (
    smoke
    .filter(F.col('smoke') == 1)
    .filter(F.col('cig_per_day') != '')
    .withColumn('cig_per_day', smoke_cat('cig_per_day'))
    .drop('smoke')
    .withColumnRenamed('cig_per_day', 'smoke')
    .select(['patid', 'eventdate', 'smoke'])
)

# All other statuses -> categories 0/1/2 via smoke_dict.
smoke_other = smoke.filter(F.col('smoke') != 1).withColumn('smoke', smoke_map('smoke')).select(['patid', 'eventdate', 'smoke'])

# Drop category 0 (not recorded); then type='last' over the 24 months before baseline.
smoke = smoke_current.union(smoke_other)
smoke = smoke.filter(F.col('smoke') != 0)
smoke = predictor.predictor_extract(smoke, cohort, 'smoke', col_baseline='baseline', span_before_baseline_month=24, type='last')
cohort = cohort.join(smoke, 'patid', 'left')

# Deprivation: per-patient IMD record, joined without any baseline window.
deprivation = modalities.retrieve_imd(file, spark)
cohort = cohort.join(deprivation, 'patid', 'left')

# AF: dictionary codes for atrial fibrillation checked against `diagnoses` relative to baseline.
AF = code_merge(condition_query.queryDisease(['Atrial fibrillation'], merge=True), 'atrial fibrillation')
AF = (
    predictor
    .predictor_check_exist(AF, diagnoses, cohort, col='code', col_baseline='baseline')
    .withColumnRenamed('code', 'atrial_fibrillation')
)
cohort = cohort.join(AF, 'patid', 'left')

cohort.write.parquet('/home/shared/yikuan/Mo/predictor_2.parquet')

# Clinical outcome of interest
1. Mortality within 5 years of the incident HF diagnosis

In [None]:
cohort = read_parquet(spark.sqlContext, '/home/shared/yikuan/Mo/demographics.parquet/')
death = tables.retrieve_death(dir=file['death'], spark=spark).select(['patid', 'dod'])
cohort = cohort.join(death, 'patid', 'left')

# End of follow-up: baseline + 7 years, i.e. 5 years after the incident HF date
# (baseline was set 2 years before HF).
cohort = cohort.withColumn('endfollowupdate', cohort.baseline + F.expr('INTERVAL 7 YEARS')).cache()

# No recorded death: censored at the earlier of 'enddate' and the end of follow-up.
cohort_no_event = cohort.filter(F.col('dod').isNull()) \
    .withColumn('event', F.lit(0)) \
    .withColumn('time', F.least(F.col('enddate'), F.col('endfollowupdate')))
# Recorded deaths after baseline; deaths on or before baseline leave the cohort.
cohort_with_event = cohort.filter(F.col('dod').isNotNull()).filter(F.col('dod') > F.col('baseline')).cache()

# Death after the follow-up window: censored at the end of follow-up.
# NOTE(review): 'enddate' is not applied to this group, unlike cohort_no_event.
cohort_with_event_a = cohort_with_event.filter(F.col('dod') > F.col('endfollowupdate')) \
    .withColumn('event', F.lit(0)) \
    .withColumn('time', F.col('endfollowupdate'))
# Death within (baseline, endfollowupdate]: event. Fix: '<=' instead of '<' — with the
# old strict bounds a death exactly on endfollowupdate matched neither this filter nor
# the one above, so those patients silently disappeared from the cohort.
cohort_with_event_b = cohort_with_event.filter(F.col('dod') <= F.col('endfollowupdate')) \
    .withColumn('event', F.lit(1)) \
    .withColumn('time', F.col('dod'))

# The three parts share the same column order, so the positional union lines up.
cohort = cohort_no_event.union(cohort_with_event_a).union(cohort_with_event_b).drop('endfollowupdate')

# Time from baseline to event/censoring in 30-day months; non-positive times are dropped.
time2eventdiff = F.unix_timestamp('time', "yyyy-MM-dd") - F.unix_timestamp('baseline', "yyyy-MM-dd")
cohort = cohort.withColumn('time', time2eventdiff).withColumn('time', (F.col('time') / 3600 / 24 / 30).cast('integer'))
cohort = cohort.filter(F.col('time') > 0)
cohort.write.parquet('/home/shared/yikuan/Mo/cohort.parquet')

In [None]:
# Assemble the table to impute: outcome (event, time) + expert predictors.
# Fix: read predictor_2.parquet, the last cumulative predictor stage. The earlier
# predictor.parquet only has gender/dob/baseline/age/systolic/sodium/diabetes, so
# selecting hdl_r, ckd, chd_history, antihtn, smoke, imd2015_5 or atrial_fibrillation
# from it fails. The frame is named expert_predictors so it no longer shadows the
# `predictor` extractor object.
cohort = read_parquet(spark.sqlContext, '/home/shared/yikuan/Mo/cohort.parquet')\
    .select(['patid', 'event', 'time'])

expert_predictors = read_parquet(spark.sqlContext, '/home/shared/yikuan/Mo/predictor_2.parquet')\
    .select(['patid', 'gender', 'age', 'systolic', 'sodium', 'diabetes', 'hdl_r', 'ckd', 'chd_history', 'antihtn',
            'smoke', 'imd2015_5', 'atrial_fibrillation'])

# Patients without a deprivation record are dropped; other gaps are left for imputation.
cohort = cohort.join(expert_predictors, 'patid', 'left').filter(F.col('imd2015_5').isNotNull())
cohort.write.parquet('/home/shared/yikuan/Mo/mortality.parquet')

In [None]:
# MICE imputation of the expert predictors with sklearn's IterativeImputer.
from sklearn.experimental import enable_iterative_imputer  # must precede the IterativeImputer import
from sklearn.impute import IterativeImputer
import pandas as pd
import numpy as np

# (An ethnicity -> code mapping for 'gen_ethnicity' used to be applied here; it is disabled.)
b = pd.read_parquet('/home/shared/yikuan/Mo/mortality.parquet')

feature = ['gender', 'age', 'systolic', 'sodium', 'diabetes', 'hdl_r', 'ckd', 'chd_history', 'antihtn','smoke', 'imd2015_5', 'atrial_fibrillation']
# Columns passed through untouched (patid, event, time).
rest = [column for column in b.columns if column not in feature]

imp = IterativeImputer(max_iter=5, random_state=0)
imp.fit(b[feature])
imput = pd.DataFrame(imp.transform(b[feature]), columns=feature)

# Re-attach the untouched columns; join aligns on the row index.
b_rest = b[rest].join(imput)

b_rest.to_parquet('/home/shared/yikuan/Mo/mortality_imp.parquet')

# BEHRT

In [None]:
# Outcome table reduced to the columns format_behrt uses below (dob, baseline, event, time).
cohort = (
    read_parquet(spark.sqlContext, '/home/shared/yikuan/Mo/cohort.parquet')
    .select('patid', 'dob', 'baseline', 'event', 'time')
)

In [6]:
# retrieve medications: 'MED' + code, rows with any null removed.
medications = (
    read_parquet(spark.sqlContext, '/home/shared/yikuan/HF_Valid/data/medications.parquet')
    .select('patid', 'eventdate', 'code')
    .withColumn('code', F.concat(F.lit('MED'), F.col('code')))
    .dropna()
)

In [7]:
# retrieve procedures: OPCS codes prefixed with 'PRO', rows with any null removed.
procedure = (
    read_parquet(spark.sqlContext, '/home/shared/yikuan/HF_Valid/data/procedure.parquet')
    .select('patid', 'eventdate', F.col('OPCS').alias('code'))
    .withColumn('code', F.concat(F.lit('PRO'), F.col('code')))
    .dropna()
)

In [8]:
# Lab tests: medcodes prefixed with 'LAB', rows with any null removed.
lab_test = (
    read_parquet(spark.sqlContext, '/home/shared/yikuan/HF_Valid/data/lab_test.parquet')
    .select('patid', 'eventdate', F.col('medcode').alias('code'))
    .withColumn('code', F.concat(F.lit('LAB'), F.col('code')))
    .dropna()
)

In [None]:
# retrieve diagnoses: ICD codes prefixed with 'DIA' (rows without an ICD code drop out).
# NOTE(review): this rebinds the notebook-wide `diagnoses` used for the predictor
# lookups; any later cell needing the CPRD/HES code table must rebuild it.
diagnoses = (
    read_parquet(spark.sqlContext, '/home/shared/yikuan/HF_Valid/data/diagnosis.parquet')
    .select('patid', 'eventdate', F.col('ICD').alias('code'))
    .withColumn('code', F.concat(F.lit('DIA'), F.col('code')))
    .dropna()
)

In [None]:
# BEHRT sequence formatter from the project's predictor_extractor module.
behrt_formater = predictor_extractor.BEHRT()

In [None]:
# All event sources stacked (union is positional: patid, eventdate, code).
data = medications.union(procedure).union(lab_test).union(diagnoses)
data = behrt_formater.format_behrt(
    data, cohort,
    col_entry='baseline', col_yob='dob', age_col_name='age', col_code='code',
    label=None, event='event', time='time',
).dropna()
data.write.parquet('/home/shared/yikuan/Mo/BEHRT.parquet')

In [None]:
# Quick look at the formatted BEHRT records (first 20 rows).
df = read_parquet(spark.sqlContext, '/home/shared/yikuan/Mo/BEHRT.parquet')
df.show(n=20)

# Collect each patient's records from the year after the HF diagnosis

In [11]:
# Medication events keyed by product code: 'MED' + prodcode, rows with any null removed.
medications = (
    read_parquet(spark.sqlContext, '/home/shared/yikuan/HF_Valid/data/medications.parquet')
    .select('patid', 'eventdate', 'prodcode')
    .withColumn('code', F.concat(F.lit('MED'), F.col('prodcode')))
    .dropna()
    .drop('prodcode')
)

In [13]:
# Cohort reduced to the HF date ('hf_date') and the end of the one-year window ('checkdate').
cohort = read_parquet(spark.sqlContext, '/home/shared/yikuan/Mo/cohort.parquet/')
cohort = cohort.withColumn('checkdate', cohort.eventdate + F.expr('INTERVAL 1 YEARS')).select(['patid', 'eventdate', 'checkdate']).withColumnRenamed('eventdate', 'hf_date')
# Medication, procedure and lab-test events (positional union of patid, eventdate, code);
# `procedure` and `lab_test` are the 'PRO'/'LAB'-prefixed frames from the BEHRT section.
data = medications.union(procedure).union(lab_test)

# keep records between event date and the check date
# i.e. events in (hf_date, hf_date + 1 year]; events on the HF date itself are excluded.
data = data.join(cohort, 'patid', 'inner').dropna() \
    .filter((F.col('eventdate') <= F.col('checkdate')) & (F.col('eventdate') > F.col('hf_date')))

# collect the codes for each patient
# -> one row per (patid, eventdate) holding that day's codes as an array.
data = data.groupby(['patid', 'eventdate']).agg(F.collect_list('code').alias('code'))

# sort and merge code
# The running collect_list over the date-ordered window gives each row the list of
# daily arrays up to its date. eventdate is unique per patient after the groupby
# above, so every such list is a prefix of the patient's final list. Spark orders
# arrays element-wise, with a shorter prefix sorting first, so F.max picks the
# complete date-ordered list. flatten then produces a single code sequence per patient.
w = Window.partitionBy('patid').orderBy('eventdate')
data = data.withColumn('code', F.collect_list('code').over(w)) \
    .groupBy('patid').agg(F.max('code').alias('code'))
data = data.withColumn('code', F.flatten(F.col('code')))


data.write.parquet('/home/shared/yikuan/Mo/hf_records_one_year_v1.parquet')


# Implement other events of interest as outcomes

In [None]:
# import random
# from CPRD.functions import tables, merge
# import pyspark.sql.functions as F
# from CPRD.config.utils import *
# from CPRD.functions.modalities import *
# from pyspark.sql import Window
# from CPRD.config.spark import read_parquet
# from typing import Any
# import datetime
# from pyspark.sql.types import IntegerType

# def prepareDeathCause(death, condition, column):
#     cause_cols = ['cause', 'cause1', 'cause2', 'cause3', 'cause4', 'cause5', 'cause6', 'cause7',
#                   'cause8', 'cause9', 'cause10', 'cause11', 'cause12', 'cause13', 'cause14', 'cause15']
#     cause_cols = [F.col(each) for each in cause_cols]
#     death = death.withColumn("cause", F.array(cause_cols)).select(['patid', 'cause', 'dod'])
#     rm_dot = F.udf(lambda x: x.replace(".", ""))
#     death = death.withColumn('cause', F.explode('cause')) \
#         .withColumn('cause', rm_dot('cause')) \
#         .filter(F.col('cause').isin(*condition))
#     death = death.groupBy('patid').agg(F.first('dod').alias('eventdate'), F.first('cause').alias(column))
#     return death

# def process_death_diagnoses(diagnoses, death, condition, column):
#     death = prepareDeathCause(death, condition, column)
#     diagnoses = diagnoses.filter(F.col(column).isin(*condition))
#     source = diagnoses.union(death)
#     return source

# def event2time(cohort, source, baseline, column, follow_up_duration, event_col, time_col):
#     tmp = cohort.select(['patid', baseline])
#     source = source.join(tmp, 'patid', 'left').filter(F.col('eventdate')>F.col(baseline)).dropna()
    
#     w = Window.partitionBy('patid').orderBy('eventdate')
#     source = source.withColumn(column, F.first(column).over(w)).groupBy('patid').agg(
#         F.first('eventdate').alias('eventdate'),
#         F.first(column).alias(column)
#     )
    
#     cohort = cohort.join(source, 'patid', 'left').drop(column)
#     cohort = cohort.withColumn('endfollowupdate', cohort.baseline + F.expr('INTERVAL {} MONTHS'.format(follow_up_duration))).cache()

#     cohort_no_event = cohort.filter(F.col('eventdate').isNull()).withColumn(event_col, F.lit(0)).withColumn(time_col, F.least(F.col('enddate'), F.col('endfollowupdate')))
#     cohort_with_event = cohort.filter(F.col('eventdate').isNotNull()).cache()
    
#     cohort_with_event_a = cohort_with_event.filter(F.col('eventdate') > F.col('endfollowupdate')).withColumn(event_col, F.lit(0)).withColumn(time_col, F.col('endfollowupdate'))
#     cohort_with_event_b = cohort_with_event.filter((F.col('eventdate') <= F.col('endfollowupdate'))).withColumn(event_col, F.lit(1)).withColumn(time_col, F.col('eventdate'))
    
#     cohort = cohort_no_event.union(cohort_with_event_a).union(cohort_with_event_b)


#     time2eventdiff = F.unix_timestamp(time_col, "yyyy-MM-dd") - F.unix_timestamp(baseline, "yyyy-MM-dd")
#     cohort = cohort.withColumn(time_col, time2eventdiff).withColumn(time_col, (F.col(time_col) / 3600 / 24 / 30).cast('integer')).drop('eventdate').drop('endfollowupdate')
   
#     return cohort

## acute conditions
1. ischemic stroke
2. acute kidney injury
3. pulmonary embolism
4. Abdominal aortic aneurysm

In [None]:
# cohort_all = read_parquet(spark.sqlContext, '/home/shared/yikuan/Mo/cohort.parquet').select(['patid', 'dob', 'baseline', 'start', 'startdate', 'end', 'enddate', 'time', 'event'])
# cohort_all = cohort_all.withColumnRenamed('event', 'mortality').withColumnRenamed('time', 'mortality_time')
# cohort_all = cohort_all.filter(F.col('baseline')<F.col('enddate'))

# # surv = SurvRiskPredictionBase(60)
# death = tables.retrieve_death(dir=file['death'], spark=spark)
# condition_query = MedicalDictionary.MedicalDictionaryRiskPrediction(file, spark)

# # diagnosis
# diagnoses = read_parquet(spark.sqlContext, '/home/shared/yikuan/HF_Valid/data/diagnosis.parquet')
# diag_cprd = diagnoses.filter(F.col('source')=='CPRD').select(['patid', 'eventdate', 'medcode']).withColumnRenamed('medcode', 'code')
# diag_hes = diagnoses.filter(F.col('source')=='HES').select(['patid', 'eventdate', 'ICD']).withColumnRenamed('ICD', 'code')
# diagnoses = diag_cprd.union(diag_hes)

In [None]:
# # ischemic stroke
# query = condition_query.queryDisease(['Ischaemic stroke'], merge=True)
# code = []
# for k,v in query['ischaemic stroke'].items():
#     code.extend(v)
    
# source = process_death_diagnoses(diagnoses, death, code, 'code')
# cohort_all = event2time(cohort_all, source, 'baseline', 'code', 60, 'ischaemic_stroke', 'ischaemic_stroke_time')

# # acute kidney injury
# query = condition_query.queryDisease(['Acute Kidney Injury'], merge=True)
# code = []
# for k,v in query['acute kidney injury'].items():
#     code.extend(v)
    
# source = process_death_diagnoses(diagnoses, death, code, 'code')
# cohort_all = event2time(cohort_all, source, 'baseline', 'code', 60, 'acute_kidney_injury', 'acute_kidney_injury_time')

# # pulmonary embolism
# query = condition_query.queryDisease(['Pulmonary embolism'], merge=True)
# code = []
# for k,v in query['pulmonary embolism'].items():
#     code.extend(v)

# source = process_death_diagnoses(diagnoses, death, code, 'code')
# cohort_all = event2time(cohort_all, source, 'baseline', 'code', 60, 'pulmonary_embolism', 'pulmonary_embolism_time')


# # Abdominal aortic aneurysm
# query = condition_query.queryDisease(['Abdominal aortic aneurysm'], merge=True)
# code = []
# for k,v in query['abdominal aortic aneurysm'].items():
#     code.extend(v)
    
# source = process_death_diagnoses(diagnoses, death, code, 'code')
# cohort_all = event2time(cohort_all, source, 'baseline', 'code', 60, 'abdominal_aortic_aneurysm', 'abdominal_aortic_aneurysm_time')

# cohort_all.write.parquet('/home/shared/yikuan/Mo/cohort_v1.parquet')

# Merge outcomes with the BEHRT data and the imputed predictors

In [None]:
# drop_col = ['dob', 'baseline', 'start', 'startdate', 'end', 'enddate']
# cohort = read_parquet(spark.sqlContext, '/home/shared/yikuan/Mo/cohort_v1.parquet')

# for col in drop_col:
#     cohort = cohort.drop(col)

# cohort_expert = read_parquet(spark.sqlContext, '/home/shared/yikuan/Mo/mortality_imp.parquet').drop('event').drop('time')
# cohort_expert = cohort_expert.join(cohort, 'patid', 'left')

# cohort_expert.write.parquet('/home/shared/yikuan/Mo/expert_predictor.parquet')

In [None]:
# cohort_behrt =  read_parquet(spark.sqlContext, '/home/shared/yikuan/Mo/BEHRT.parquet').drop('event').drop('time')
# cohort_behrt = cohort_behrt.join(cohort, 'patid', 'left')
# cohort_behrt.write.parquet('/home/shared/yikuan/Mo/BEHRT_v1.parquet')

# analysis

In [5]:
# Evaluation cohort: imputed expert predictors plus each patient's baseline date.
cohort = read_parquet(spark.sqlContext, '/home/shared/yikuan/Mo/mortality_imp.parquet')
baseline = read_parquet(spark.sqlContext, '/home/shared/yikuan/Mo/cohort.parquet').select(['patid', 'baseline'])
cohort = cohort.join(baseline, 'patid', 'left')

condition_query = MedicalDictionary.MedicalDictionaryRiskPrediction(file, spark)

# Rebuild the combined diagnoses table (CPRD medcode + HES ICD in one 'code' column)
# that the condition flags in this section are matched against.
# Fix: the BEHRT section earlier in the notebook rebinds `diagnoses` to 'DIA'-prefixed
# ICD codes, so on Restart & Run All the lookups here would find no dictionary code.
# Rebuilding it makes this section independent of which cells ran before.
diagnoses = read_parquet(spark.sqlContext, '/home/shared/yikuan/HF_Valid/data/diagnosis.parquet')
diag_cprd = diagnoses.filter(F.col('source')=='CPRD').select(['patid', 'eventdate', 'medcode']).withColumnRenamed('medcode', 'code')
diag_hes = diagnoses.filter(F.col('source')=='HES').select(['patid', 'eventdate', 'ICD']).withColumnRenamed('ICD', 'code')
diagnoses = diag_cprd.union(diag_hes)

def code_merge(code_dict, name):
    """Return one flat list with all codes stored under ``code_dict[name]``.

    The values of ``code_dict[name]`` are concatenated in iteration order; a missing
    ``name`` raises ``KeyError``.
    NOTE(review): duplicates the helper of the same name in the predictor section;
    whichever definition runs last is the one in effect.
    """
    merged = []
    for codes in code_dict[name].values():
        merged += codes
    return merged

predictor = predictor_extractor.PredictorExtractorBase()

# Conditions whose dictionary key equals the requested name. Each one is checked
# against `diagnoses` relative to baseline and stored in a column named after the
# condition, with spaces replaced by '_'.
for name in ['obesity', 'alcohol problems', 'hypo or hyperthyroidism']:
    code = code_merge(condition_query.queryDisease([name], merge=True), name)
    code = (
        predictor
        .predictor_check_exist(code, diagnoses, cohort, col='code', col_baseline='baseline')
        .withColumnRenamed('code', '_'.join(name.split()))
    )
    cohort = cohort.join(code, 'patid', 'left')

cohort.write.parquet('/home/shared/yikuan/Mo/evaluation/Predictor_evaluation_1_1.parquet')


In [6]:
cohort = read_parquet(spark.sqlContext, '/home/shared/yikuan/Mo/evaluation/Predictor_evaluation_1_1.parquet')

# Composite conditions: several dictionary entries are queried together and their
# codes are read from the 'merged' key of the result.
composite_conditions = [
    ('valvular disease', ['Rheumatic valve dz', 'Nonrheumatic aortic valve disorders', 'Nonrheumatic mitral valve disorders', 'Multiple valve dz']),
    ('cardiomyopathy', ['Dilated cardiomyopathy', 'Hypertrophic Cardiomyopathy', 'Other Cardiomyopathy']),
]
for name, diseases in composite_conditions:
    code = code_merge(condition_query.queryDisease(diseases, merge=True), 'merged')
    code = (
        predictor
        .predictor_check_exist(code, diagnoses, cohort, col='code', col_baseline='baseline')
        .withColumnRenamed('code', '_'.join(name.split()))
    )
    cohort = cohort.join(code, 'patid', 'left')

# Ischaemic heart disease
name= 'ischaemic heart disease'
# Explicit code list: Read-style codes plus ICD-10 I20-I25, T822 and Z955.
# Fix: the exported list had the literal 'G30A.00' broken across two lines
# ('G' / '30A.00'), which is a SyntaxError; it is restored as a single string here.
# NOTE(review): CPRD rows of `diagnoses` carry `medcode` values. If those are numeric
# medcodes, the Read-style entries here never match and only the HES (ICD) rows
# contribute — confirm.
code = [
    '14A3.00', '14A4.00', '14A5.00', '14AA.00', '14AH.00', '14AJ.00', '14AL.00', '14AT.00',
    '14AW.00', '182A.00', '187..00', '1J61.00', '3213111', '322..00', '3222', '322Z.00',
    '323..00', '3232', '3233', '3234', '3235', '3236', '323Z.00', '32B..00',
    '32B2.00', '32B3.00', '32BZ.00', '32E4.00', '44H3.00', '44H3000', '44HJ.00', '44MH.00',
    '44p2.00', '5533', '5543', '5C11.00', '661M000', '662..00', '662..11', '662K.00',
    '662K000', '662K100', '662K200', '662K300', '662Kz00', '662N.00', '662Z.00', '66f..00',
    '66f1.00', '6A2..00', '6A4..00', '792..11', '7920', '7920.11', '7920000', '7920100',
    '7920200', '7920300', '7920y00', '7920z00', '7921', '7921.11', '7921000', '7921100',
    '7921200', '7921300', '7921y00', '7921z00', '7922', '7922.11', '7922000', '7922100',
    '7922200', '7922300', '7922y00', '7922z00', '7923', '7923.11', '7923000', '7923100',
    '7923200', '7923300', '7923z00', '7924', '7924000', '7924100', '7924200', '7924y00',
    '7924z00', '7925', '7925.11', '7925000', '7925100', '7925300', '7925311', '7925312',
    '7925400', '7925y00', '7925z00', '7926', '7926000', '7926200', '7926300', '7926z00',
    '7927500', '7928', '7928.11', '7928000', '7928100', '7928200', '7928300', '7928y00',
    '7928z00', '7929000', '7929100', '7929111', '7929300', '7929400', '7929500', '7929600',
    '792B000', '792C.00', '792C000', '792Cy00', '792Cz00', '792D.00', '792Dy00', '792Dz00',
    '793G.00', '793G000', '793G100', '793G200', '793G300', '793Gy00', '793Gz00', '7A4B800',
    '7A54000', '7A54500', '7A54700', '7A54800', '7A56000', '7A56400', '7A6G100', '7A6H300',
    '7A6H400', '7A6S300', '889A.00', '8B27.00', '8B3k.00', '8BGC.00', '8CMP.00', '8F9..00',
    '8F90.00', '8F91.00', '8F92.00', '8F93.00', '8H2V.00', '8H7v.00', '8I37.00', '8I3a.00',
    '8IEY.00', '8L40.00', '8L41.00', '8LF..00', '8T04.00', '9Ob..00', '9Ob0.00', '9Ob1.00',
    '9Ob2.00', '9Ob3.00', '9Ob4.00', '9Ob5.00', '9Ob6.00', '9Ob8.00', '9Ob9.00', 'G....12',
    'G....13', 'G3...00', 'G3...11', 'G3...12', 'G3...13', 'G30..00', 'G30..11', 'G30..12',
    'G30..13', 'G30..14', 'G30..15', 'G30..16', 'G30..17', 'G300.00', 'G301.00', 'G301000',
    'G301100', 'G301z00', 'G302.00', 'G303.00', 'G304.00', 'G305.00', 'G306.00', 'G307.00',
    'G307000', 'G307100', 'G308.00', 'G309.00', 'G30A.00', 'G30B.00', 'G30X.00', 'G30X000',
    'G30y.00', 'G30y000', 'G30y100', 'G30y200', 'G30yz00', 'G30z.00', 'G31..00', 'G310.00',
    'G310.11', 'G311.00', 'G311.11', 'G311.12', 'G311.13', 'G311.14', 'G311000', 'G311011',
    'G311100', 'G311200', 'G311300', 'G311400', 'G311500', 'G311z00', 'G312.00', 'G31y.00',
    'G31y000', 'G31y100', 'G31y200', 'G31y300', 'G31yz00', 'G32..00', 'G32..11', 'G32..12',
    'G33..00', 'G330.00', 'G330000', 'G330z00', 'G331.11', 'G332.00', 'G33z.00', 'G33z000',
    'G33z100', 'G33z200', 'G33z300', 'G33z400', 'G33z500', 'G33z600', 'G33z700', 'G33zz00',
    'G34..00', 'G340.00', 'G340.11', 'G340.12', 'G340000', 'G340100', 'G341.00', 'G341.11',
    'G341000', 'G341100', 'G341z00', 'G342.00', 'G343.00', 'G344.00', 'G34y.00', 'G34y000',
    'G34y100', 'G34yz00', 'G34z.00', 'G34z000', 'G35..00', 'G350.00', 'G351.00', 'G353.00',
    'G35X.00', 'G36..00', 'G360.00', 'G361.00', 'G362.00', 'G363.00', 'G364.00', 'G365.00',
    'G366.00', 'G38..00', 'G380.00', 'G381.00', 'G383.00', 'G384.00', 'G38z.00', 'G39..00',
    'G3y..00', 'G3z..00', 'G5...00', 'G501.00', 'G574000', 'G5y..00', 'G5yyz00', 'G5yz.00',
    'G5z..00', 'Gyu3.00', 'Gyu3000', 'Gyu3200', 'Gyu3300', 'Gyu3400', 'Gyu3600', 'Gyu5.00',
    'Gyu7000', 'SP00300', 'SP07600', 'Z677.00', 'ZL22200', 'ZV45700', 'ZV45800', 'ZV45K00',
    'ZV45K11', 'ZV45L00', 'ZV57900',
    'I20', 'I200', 'I201', 'I208', 'I209',
    'I21', 'I210', 'I211', 'I212', 'I213', 'I214', 'I219',
    'I22', 'I220', 'I221', 'I228', 'I229',
    'I23', 'I230', 'I231', 'I232', 'I233', 'I234', 'I235', 'I236', 'I238',
    'I24', 'I240', 'I241', 'I248', 'I249',
    'I25', 'I250', 'I251', 'I252', 'I255', 'I256', 'I258', 'I259',
    'T822', 'Z955',
]
code = predictor.predictor_check_exist(code, diagnoses, cohort, col='code', col_baseline='baseline').withColumnRenamed('code', '_'.join(name.split()))
cohort = cohort.join(code, 'patid', 'left')

cohort.write.parquet('/home/shared/yikuan/Mo/evaluation/Predictor_evaluation_1_2.parquet')

In [7]:
cohort = read_parquet(spark.sqlContext, '/home/shared/yikuan/Mo/evaluation/Predictor_evaluation_1_2.parquet')

# Conditions whose dictionary key equals the requested name; the flag column is the
# name with spaces replaced by '_'.
for name in ('hypertension', 'other interstitial pulmonary diseases with fibrosis'):
    code = code_merge(condition_query.queryDisease([name], merge=True), name)
    code = (
        predictor
        .predictor_check_exist(code, diagnoses, cohort, col='code', col_baseline='baseline')
        .withColumnRenamed('code', '_'.join(name.split()))
    )
    cohort = cohort.join(code, 'patid', 'left')

# Chronic obstructive pulmonary disease
name = 'chronic obstructive pulmonary disease'
# Explicit code list mixing Read-style codes ('H3...'), numeric medcode strings and
# ICD-10 J40-J44; matched against the combined CPRD/HES `diagnoses.code` column.
code = ['14B3.12','H3...00','H3...11','H31..00','H310.00','H310000','H310z00','H311.00','H311000','H311100','H311z00','H312.00','H312000','H312011','H312100','H312200','H312300','H312z00','H313.00','H31y.00','H31y100','H31yz00','H31z.00','H32..00','H320.00','H320000','H320100','H320200','H320300','H320z00','H321.00','H322.00','H32y.00','H32y000','H32y100','H32y111','H32y200','H32yz00','H32z.00','H36..00','H37..00','H38..00','H39..00','H3A..00','H3y..00','H3y..11','H3y0.00','H3y1.00','H3z..00','H3z..11','H464000','H464100','H583200','Hyu3000','Hyu3100','1001','103494','104608','106650','10802','10863','10980','11150','12166','1446','14798','15157','15626','16410','21061','23492','24248','25603','26125','26306','27819','3243','33450','37247','37959','40159','40788','44525','45089','46578','56860','5710','5798','5909','59263','60188','61118','61513','63216','63479','64721','65733','66043','66058','67040','68066','68662','70787','7884','794','92955','93568','9876','99536','998','J40','J41','J42','J43','J44']
# Flag column 'chronic_obstructive_pulmonary_disease' (spaces -> '_'), checked relative to baseline.
code = predictor.predictor_check_exist(code, diagnoses, cohort, col='code', col_baseline='baseline').withColumnRenamed('code', '_'.join(name.split()))
cohort = cohort.join(code, 'patid', 'left')

# Final stage of the evaluation predictors (all comorbidity flags added so far).
cohort.write.parquet('/home/shared/yikuan/Mo/evaluation/Predictor_evaluation_1.parquet')

In [8]:
# Comorbidity flags to attach to the BEHRT records, named as they were created
# (spaces replaced by '_').
conditions = ['obesity', 'alcohol problems', 'diabetes', 'hypo or hyperthyroidism',
              'valvular disease', 'cardiomyopathy', 'ischaemic heart disease',
              'hypertension', 'other interstitial pulmonary diseases with fibrosis',
              'chronic obstructive pulmonary disease']
feature = ['patid', *('_'.join(condition.split()) for condition in conditions)]

# Left join keeps every BEHRT patient, with null flags where no predictor row exists.
data = read_parquet(spark.sqlContext, '/home/shared/yikuan/Mo/evaluation/Predictor_evaluation_1.parquet').select(feature)
cohort = read_parquet(spark.sqlContext, '/home/shared/yikuan/Mo/BEHRT.parquet').join(data, 'patid', 'left')

cohort.write.parquet('/home/shared/yikuan/Mo/evaluation/BEHRT_evaluation_1.parquet')