In [0]:
# Clear the Spark SQL cache so the notebook starts without previously cached tables/views.
spark.sql('CLEAR CACHE')
# Legacy flag: by its name it permits creating a managed table at a non-empty location
# (presumably required when the save helper overwrites existing outputs - TODO confirm).
spark.conf.set('spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation', 'true')

In [0]:
import pyspark.sql.functions as f
import pyspark.sql.types as t
from pyspark.sql import Window

from functools import reduce

import databricks.koalas as ks
import pandas as pd
import numpy as np

import re
import io
import datetime

import matplotlib
import matplotlib.pyplot as plt
from matplotlib import dates as mdates
import seaborn as sns

# Record plotting-library versions and the run timestamp in the cell output for provenance.
print("Matplotlib version: ", matplotlib.__version__)
print("Seaborn version: ", sns.__version__)
# Kept as a datetime object; the commented-out strftime would turn it into a YYYYMMDD string.
_datetimenow = datetime.datetime.now() # .strftime("%Y%m%d")
print(f"_datetimenow:  {_datetimenow}")

In [0]:
%run "../SHDS/common/functions"

# 0. Parameters

In [0]:
%run "./CCU056-01-parameters"

# 1. Data

In [0]:
# Source tables for the outcomes pipeline.
# `dsa`, `proj` and the `path_cur_*` names are not defined in this notebook; they are
# expected to be provided by the notebooks pulled in with %run above.
codelist = spark.table(f'{dsa}.{proj}_out_codelist_covariates')
cohort = spark.table(f'{dsa}.{proj}_tmp_main_cohort_final2')
deaths_long = spark.table(path_cur_deaths_long)
hes_apc_long = spark.table(path_cur_hes_apc_long)

In [0]:
display(deaths_long)

In [0]:
display(hes_apc_long)

In [0]:
display(codelist)

### CVD Death Codes

In [0]:
# Keep the codelist rows whose `codelist` label contains "Post Intervention",
# then drop the label column itself.
codelist_outcomes_cvd_death = (
    codelist
    .where(f.col("codelist").contains("Post Intervention"))
    .drop("codelist")
)

In [0]:
display(codelist_outcomes_cvd_death)

In [0]:
display(codelist_outcomes_cvd_death.select("name").distinct())

### Hospitalisation with HF Codes

In [0]:
# HF codes restricted to the ICD10 terminology.
# Codes are normalised by first removing a trailing 'X' and then removing any decimal
# points, dashes and whitespace. Every remaining row is ICD10, so the clean-up is applied
# unconditionally.
codelist_outcomes_hf = (
    codelist
    .where((f.col("name") == "HF") & (f.col("terminology") == "ICD10"))
    .drop("codelist")
    .withColumn(
        'code',
        f.regexp_replace(
            f.regexp_replace(f.col('code'), r'X$', ''),
            r'[\.\-\s]',
            '',
        ),
    )
)

In [0]:
display(codelist_outcomes_hf)

### Hospitalisation with CVD Codes

In [0]:
# CVD hospitalisation codelist = CVD death codelist without "sudden_death", stacked on top
# of the HF codelist. `union` matches columns by position and keeps duplicate rows.
codelist_outcomes_cvd_hosp = (
    codelist_outcomes_cvd_death
    .where(f.col("name") != "sudden_death")
    .union(codelist_outcomes_hf)
)
display(codelist_outcomes_cvd_hosp.orderBy("name", "code"))

# 2. Death

## 2.1 Checks

In [0]:
print('--------------------------------------------------------------------------------------')
print('individual_censor_dates')
print('--------------------------------------------------------------------------------------')

# Censor window used for the checks in this section:
#   CENSOR_DATE_START = date of birth
#   CENSOR_DATE_END   = date of operation
# Note: for the codelist match (section 2.2) the window becomes
#   CENSOR_DATE_START = date of operation, CENSOR_DATE_END = study end date.

# A DOB recorded as "Unknown" is replaced by the sentinel date 1800-01-01 before parsing.
_dob_or_sentinel = f.when(f.col("DOB") == "Unknown", "1800-01-01").otherwise(f.col("DOB"))

individual_censor_dates = (
  cohort
  .withColumn("DOB", f.to_date(_dob_or_sentinel))
  .withColumn("OPERATION_DATE", f.to_date(f.col("OPERATION_DATE")))
  .withColumnRenamed("DOB", "CENSOR_DATE_START")
  .withColumnRenamed("OPERATION_DATE", "CENSOR_DATE_END")
)

# check
count_var(individual_censor_dates, 'PERSON_ID')
print()
print(individual_censor_dates.limit(10).toPandas().to_string())
print()

In [0]:
print('--------------------------------------------------------------------------------------')
print('deaths')
print('--------------------------------------------------------------------------------------')

# Keep only the underlying cause-of-death records and the columns needed downstream.
_deaths = (
    deaths_long
    .filter(f.col('DIAG_POSITION') == 'UNDERLYING')
    .select('PERSON_ID', 'DATE', 'CODE', 'DIAG_POSITION')
)

# check
count_var(_deaths, 'PERSON_ID')
print()
tmpt = tab(_deaths, 'DIAG_POSITION')
print()

# Attach each person's censor dates (the inner join restricts deaths to cohort members)
# and rename the death date to DOD.
_deaths = (
    _deaths
    .drop('DIAG_POSITION')
    .join(individual_censor_dates, on='PERSON_ID', how='inner')
    .withColumnRenamed("DATE", "DOD")
)

# check
count_var(_deaths, 'PERSON_ID')
print()

In [0]:
# Deaths dated strictly before the start of the censor window.
_deaths_before_start = _deaths.filter(f.col('DOD') < f.col('CENSOR_DATE_START'))

# check
count_var(_deaths_before_start, 'PERSON_ID')
print()

In [0]:
# Deaths dated strictly before the end of the censor window.
_deaths_before_end = _deaths.filter(f.col('DOD') < f.col('CENSOR_DATE_END'))

# check
count_var(_deaths_before_end, 'PERSON_ID')
print()

In [0]:
display(_deaths_before_end)

In [0]:
# Deaths dated exactly on the end of the censor window.
_deaths_on_end = _deaths.filter(f.col('DOD') == f.col('CENSOR_DATE_END'))

# check
count_var(_deaths_on_end, 'PERSON_ID')
print()

In [0]:
display(_deaths_on_end)

In [0]:
# Note that there can be cases where a person has > 1 underlying cause of death - this will
# be because the code has been truncated in one version.
# Show every such person instead of filtering on one hard-coded PERSON_ID, so that no
# individual-level identifier is embedded in the notebook source.
_persons_multi_underlying = (
    _deaths
    .groupBy('PERSON_ID')
    .agg(f.countDistinct('CODE').alias('n_underlying_codes'))
    .where(f.col('n_underlying_codes') > 1)
)

display(
    _deaths
    .join(_persons_multi_underlying, on='PERSON_ID', how='inner')
    .orderBy('PERSON_ID', 'DOD', 'CODE')
)

In [0]:
# The restriction to deaths on/after CENSOR_DATE_START is currently commented out,
# so `_deaths` is reassigned to itself unchanged and the count below matches the
# earlier one.
_deaths = (
    _deaths
#     .where(
#         (f.col('DOD') >= f.col('CENSOR_DATE_START'))
#   )
)

 # check
count_var(_deaths, 'PERSON_ID'); print()

In [0]:
display(_deaths)

## 2.2 Codelist Match - CVD and All Cause Death

In [0]:
display(study_end_date)

In [0]:
print('--------------------------------------------------------------------------------------')
print('individual_censor_dates')
print('--------------------------------------------------------------------------------------')

# Post-operation censor window:
#   CENSOR_DATE_START = date of operation
#   CENSOR_DATE_END   = study end date (same literal for every person)
individual_censor_dates_post = cohort.select(
  "PERSON_ID",
  f.to_date(f.col("OPERATION_DATE")).alias("CENSOR_DATE_START"),
  f.lit(study_end_date).alias("CENSOR_DATE_END"),
)

# check
count_var(individual_censor_dates_post, 'PERSON_ID')
print()
print(individual_censor_dates_post.limit(10).toPandas().to_string())
print()

In [0]:
# Swap the DOB/operation censor window used for the checks above for the post-operation
# window: restore the DATE column name, drop the old censor columns (and
# LSOA_Outside_England), then attach the new censor dates.
_deaths_clean = (
    _deaths
    .withColumnRenamed("DOD", "DATE")
    .drop("CENSOR_DATE_START", "CENSOR_DATE_END", "LSOA_Outside_England")
    .join(individual_censor_dates_post, on='PERSON_ID', how='inner')
)

count_var(_deaths_clean, 'PERSON_ID')
print()

In [0]:
# Some people had both a cvd death and a non cvd death found, because some of their
# codes appear in deaths with 4 digits and also with 3 digits: the 4 digit code was
# matched against the codelist but the 3 digit code was not (hence non cvd).
#
# One row is kept per person. Ordering by the description puts 'cvd_death' before
# 'non_cvd_death' (alphabetical), so a person with both is classified as cvd. DATE and
# CODE are added as tie-breakers so that the retained row is the same on every run;
# ordering on the description alone leaves the choice between tied rows arbitrary.
window_spec = (
    Window
    .partitionBy('PERSON_ID')
    .orderBy(f.col('outcomes_death_description'), f.col('DATE'), f.col('CODE'))
)

# Distinct CVD death codes, each flagged as 'cvd_death'. De-duplicating the lookup stops a
# code that is listed more than once from multiplying death records in the left join.
_cvd_death_codes = (
    codelist_outcomes_cvd_death
    .select(f.col("code").alias("CODE"))
    .distinct()
    .withColumn("outcomes_death_description", f.lit("cvd_death"))
)

_deaths_clean = (
    _deaths_clean  # underlying causes of death
    # define those which are cvd or not: unmatched codes become 'non_cvd_death'
    .join(_cvd_death_codes, on="CODE", how="left")
    .withColumn(
        "outcomes_death_description",
        f.coalesce(f.col("outcomes_death_description"), f.lit("non_cvd_death")),
    )

    # Censor-window filters are currently disabled:
    # .where((f.col('DATE') >= f.col('CENSOR_DATE_START')))
    # .where((f.col('DATE') <= f.col('CENSOR_DATE_END')))

    #.select("PERSON_ID",f.col("DATE").alias("DOD"),"outcomes_death_description")
    .distinct()
    # for those with both cvd and non cvd rows keep cvd only
    .withColumn('rownum', f.row_number().over(window_spec))
    .where(f.col('rownum') == 1)
    .drop('rownum')
)

In [0]:
display(
    _deaths_clean
      #.where((f.col('DATE') < f.col('CENSOR_DATE_START'))) # 3 people
     #.where((f.col('DATE') > f.col('CENSOR_DATE_END'))).count() #10000 people have died since the study end date
    )

In [0]:
save_table(df=_deaths_clean, out_name=f'{proj}_outcomes_deaths',save_previous=False)

# 3. Hospitalisation with HF

## 3.0 Adjust Censor Dates for DISDATE

We need to shift CENSOR_DATE_START from the date of surgery to the date the patient was discharged from hospital after surgery (DISDATE).

ADMIDATE is present on all records, so for the surgery date (which is taken from EPISTART) get the corresponding ADMIDATE.
Then, for this unique ADMIDATE, find the EPIKEY that has DISDATE filled in. Not all episodes have a DISDATE - only the final episode of a spell.

In [0]:
%%script echo skipping
# NOTE: this cell is skipped by the `%%script echo skipping` magic above; the table it
# builds is read back from storage in a later cell.
# Pull the 'hes_apc' dataset via the project helper (presumably the archived batch
# configured in parameters_df_datasets - TODO confirm).
hes_apc = extract_batch_from_archive(parameters_df_datasets, 'hes_apc')

# Keep the person/episode identifiers, the episode/admission/discharge dates and every
# diagnosis column named DIAG_3_NN or DIAG_4_NN.
tmp_hes_apc = (
  hes_apc  
  .select(['PERSON_ID_DEID', 'EPIKEY', 'EPISTART', 'ADMIDATE', 'DISDATE'] 
          + [col for col in list(hes_apc.columns) if re.match(r'^DIAG_(3|4)_\d\d$', col)])
  .withColumnRenamed('PERSON_ID_DEID', 'PERSON_ID')
  .orderBy('PERSON_ID', 'EPIKEY')
)

# First reshape: one row per diagnosis POSITION holding DIAG_4_ and DIAG_3_ side by side.
# The helper columns are consistency checks:
#   _chk  - DIAG_3_ compared (null-safe) with the first 3 characters of DIAG_4_
#   _chk2 - 1 when DIAG_4_ is null or at most 4 characters long, else 0
tmp_hes_apc_long = (
  reshape_wide_to_long_multi(tmp_hes_apc, i=['PERSON_ID', 'EPIKEY', 'EPISTART', 'ADMIDATE', 'DISDATE'], j='POSITION', stubnames=['DIAG_4_', 'DIAG_3_'])
  .withColumn('_tmp', f.substring(f.col('DIAG_4_'), 1, 3))
  .withColumn('_chk', udf_null_safe_equality('DIAG_3_', '_tmp').cast(t.IntegerType()))
  .withColumn('_DIAG_4_len', f.length(f.col('DIAG_4_')))
  .withColumn('_chk2', f.when((f.col('_DIAG_4_len').isNull()) | (f.col('_DIAG_4_len') <= 4), 1).otherwise(0))
)

# Second reshape: stack the DIAG_ columns into a single column keyed by DIAG_DIGITS, then
# tidy: strip one leading zero from the position, strip underscores from DIAG_DIGITS,
# remove a trailing 'X' and any '.', ',', '-' or whitespace from the code, and drop
# null/empty codes.
tmp_hes_apc_long = reshape_wide_to_long_multi(tmp_hes_apc_long, i=['PERSON_ID', 'EPIKEY', 'EPISTART', 'ADMIDATE', 'DISDATE', 'POSITION'], j='DIAG_DIGITS', stubnames=['DIAG_'])\
  .withColumnRenamed('POSITION', 'DIAG_POSITION')\
  .withColumn('DIAG_POSITION', f.regexp_replace('DIAG_POSITION', r'^[0]', ''))\
  .withColumn('DIAG_DIGITS', f.regexp_replace('DIAG_DIGITS', r'[_]', ''))\
  .withColumn('DIAG_', f.regexp_replace('DIAG_', r'X$', ''))\
  .withColumn('DIAG_', f.regexp_replace('DIAG_', r'[.,\-\s]', ''))\
  .withColumnRenamed('DIAG_', 'CODE')\
  .where((f.col('CODE').isNotNull()) & (f.col('CODE') != ''))\
  .orderBy(['PERSON_ID', 'EPIKEY', 'DIAG_DIGITS', 'DIAG_POSITION'])

# adding in cohort to make table smaller and quicker to save
# (left join from the cohort: cohort members without any HES APC record are kept, with
# nulls in the HES columns)
tmp_hes_apc_long = (
  cohort.select("PERSON_ID").join(tmp_hes_apc_long,on="PERSON_ID",how="left")
  .select('PERSON_ID', f.col('EPISTART'), 'ADMIDATE','DISDATE','CODE', 'DIAG_POSITION', 'DIAG_DIGITS','EPIKEY')
)

In [0]:
%%script echo skipping
save_table(df=tmp_hes_apc_long, out_name=f'{proj}_outcomes_tmp_hes_apc_long',save_previous=False)

In [0]:
tmp_hes_apc_long = spark.table(f'{dsa}.{proj}_outcomes_tmp_hes_apc_long')

In [0]:
# codelists_inclusions = spark.table(f'{dsa}.{proj}_out_codelists_inclusions')

# Procedure records that link each person's operation date (OPDATE) to an EPIKEY.
procedure_codes_op_dates = spark.table(f'{dsa}.{proj}_tmp_cases_procedure_codes_operation_dates')

# Per person and operation date, rank candidate rows by DISDATE descending. Spark's
# default for a descending sort places nulls last, so a row with a non-null DISDATE is
# preferred when one exists. Rows tied on DISDATE are ranked in no guaranteed order.
# NOTE: this reuses (overwrites) the `window_spec` name defined earlier in the notebook.
window_spec = Window.partitionBy('PERSON_ID','OPDATE').orderBy(f.col('DISDATE').desc())

# One row per person/operation carrying the admission and discharge dates of the episode
# linked to the operation. The trailing numeric comments are row counts noted by the
# author at the time of writing.
discharge_dates = (
    cohort #182318
    .select("PERSON_ID",f.col("OPERATION_DATE").alias("OPDATE"))
    .join((procedure_codes_op_dates.select("PERSON_ID","OPDATE","EPIKEY")),on=["PERSON_ID","OPDATE"],how="left") #Now have EPIKEY
    .distinct() #183079
    .join(tmp_hes_apc_long,on=["PERSON_ID","EPIKEY"],how="left")
    .distinct()
    #.join(codelists_inclusions,on="CODE",how="right")
    .select("PERSON_ID","EPIKEY","OPDATE","ADMIDATE","DISDATE")
    .distinct()
    # keep the top-ranked row within each PERSON_ID/OPDATE partition
    .withColumn('rownum', f.row_number().over(window_spec)).where(f.col('rownum') == 1).drop('rownum')

    )

display(discharge_dates)

In [0]:
# Rows with a usable discharge date: DISDATE present and strictly after the operation date.
# Author's note: 161556 of the 182318 have a DISDATE (20762, 11%, have no DISDATE).
_has_valid_discharge = f.col("DISDATE").isNotNull() & (f.col("DISDATE") > f.col("OPDATE"))

discharge_dates_post = (
    discharge_dates
    .where(_has_valid_discharge)
    .select("PERSON_ID", "DISDATE", "OPDATE")
    .distinct()
)

In [0]:
# Summary (mean, median, min, max) of the number of days from operation to discharge.
_days_op_to_discharge = f.datediff(f.col("DISDATE"), f.col("OPDATE"))

display(
    discharge_dates_post
    .withColumn("Diff", _days_op_to_discharge)
    .agg(
        f.avg(f.col('Diff')).alias('avg'),
        f.median(f.col('Diff')).alias('median'),
        f.min(f.col('Diff')).alias('min'),
        f.max(f.col('Diff')).alias('max'),
    )
)

In [0]:
# Row-level view of the days from operation to discharge.
display(
    discharge_dates_post.withColumn(
        "Diff",
        f.datediff(f.col("DISDATE"), f.col("OPDATE")),
    )
)


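For those with no DISDATE (or with a DISDATE earlier than the operation date), the discharge date is imputed by adding the median number of days between operation and discharge to the operation date. The cell below uses a fixed value of 7 days for this.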

In [0]:
# Number of days added to the operation date when the discharge date has to be imputed.
# NOTE(review): the notebook narrative describes this as the median operation-to-discharge
# gap; confirm it against the summary statistics computed above.
IMPUTED_DISCHARGE_LAG_DAYS = 7

# A discharge date is imputed when it is missing or precedes the operation date.
_needs_imputation = f.col("DISDATE").isNull() | (f.col("DISDATE") < f.col("OPDATE"))

# One discharge date per person: the recorded DISDATE where usable, otherwise
# OPDATE + IMPUTED_DISCHARGE_LAG_DAYS.
cohort_discharge_dates = (
    discharge_dates
    .withColumn(
        "DISDATE_NEW",
        f.when(
            _needs_imputation,
            f.date_add(f.to_date(f.col("OPDATE")), IMPUTED_DISCHARGE_LAG_DAYS),
        ).otherwise(f.col("DISDATE")),
    )
    .select("PERSON_ID", f.col("DISDATE_NEW").alias("DISDATE"))
    .distinct()  # 182318 as expected
)

In [0]:
%%script echo skipping
save_table(df=cohort_discharge_dates, out_name=f'{proj}_cohort_discharge_dates',save_previous=False)

In [0]:
cohort_discharge_dates = spark.table(f'{dsa}.{proj}_cohort_discharge_dates')

In [0]:
print('--------------------------------------------------------------------------------------')
print('individual_censor_dates')
print('--------------------------------------------------------------------------------------')

# Post-discharge censor window:
#   CENSOR_DATE_START = date of discharge after surgery
#   CENSOR_DATE_END   = study end date (same literal for every person)
individual_censor_dates_discharge = cohort_discharge_dates.select(
  "PERSON_ID",
  f.to_date(f.col("DISDATE")).alias("CENSOR_DATE_START"),
  f.lit(study_end_date).alias("CENSOR_DATE_END"),
)

# check
count_var(individual_censor_dates_discharge, 'PERSON_ID')
print()
print(individual_censor_dates_discharge.limit(10).toPandas().to_string())
print()

In [0]:
# Hospital diagnosis records restricted to each person's follow-up window:
#   1. reduce to the needed columns, using EPISTART as the record DATE
#   2. attach the individual censor dates (the inner join keeps cohort members only)
#   3. keep records dated from the censor start (discharge after the operation) up to and
#      including the censor end (study end date)
_in_follow_up = (
    (f.col('DATE') >= f.col('CENSOR_DATE_START'))
    & (f.col('DATE') <= f.col('CENSOR_DATE_END'))
)

hes_apc_long_prepared = (
  hes_apc_long
  .select('PERSON_ID', f.col('EPISTART').alias('DATE'), 'CODE', 'DIAG_POSITION', 'DIAG_DIGITS')
  .join(individual_censor_dates_discharge, on='PERSON_ID', how='inner')
  .where(_in_follow_up)
)

In [0]:


display(hes_apc_long_prepared)

In [0]:
display(codelist_outcomes_hf)

## 3.2 Codelist Match - HF and All Cause Hospitalisations

In [0]:
# Wide view: one row per person, date and diagnosis position, with one column per distinct
# DIAG_DIGITS value. The pivoted columns "3" and "4" are renamed to DIAG_3 / DIAG_4.
_hes_codes = hes_apc_long_prepared.select(
    "PERSON_ID", "DATE", "CODE", "DIAG_POSITION", "DIAG_DIGITS"
)

hes_wide = (
    _hes_codes
    .groupBy("PERSON_ID", "DATE", "DIAG_POSITION")
    .pivot("DIAG_DIGITS")
    .agg(f.first("CODE"))
    .withColumnRenamed("3", "DIAG_3")
    .withColumnRenamed("4", "DIAG_4")
)

In [0]:
# display(hes_wide.filter(f.col("DiAG_4").isNull()))

In [0]:
# Distinct HF codes used as the lookup. De-duplicating matters because a code that occurs
# more than once in the codelist (e.g. two source codes that become identical after the
# trailing 'X' / punctuation is removed) would otherwise duplicate every matching hospital
# record in the left join below.
# NOTE(review): matches against the HF codelist are labelled 'cvd_hosp' / 'non_cvd_hosp';
# the labels are kept unchanged because they are written to the saved table.
_hf_hosp_codes = (
    codelist_outcomes_hf
    .select(f.col("code").alias("CODE"))
    .distinct()
    .withColumn("outcomes_hosp_description", f.lit("cvd_hosp"))
)

# Post-discharge hospital records in the first diagnosis position, flagged by whether the
# code is in the HF codelist, with the cohort columns attached.
outcomes_hospitalisations_raw = (
    hes_apc_long_prepared
    .filter(f.col("DIAG_POSITION") == 1)  # first position only
    .join(_hf_hosp_codes, on="CODE", how="left")
    # codes without a codelist match are labelled as the complement
    .withColumn(
        "outcomes_hosp_description",
        f.coalesce(f.col("outcomes_hosp_description"), f.lit("non_cvd_hosp")),
    )
    .join(cohort, on="PERSON_ID", how="left")
)

In [0]:
save_table(df=outcomes_hospitalisations_raw, out_name=f'{proj}_outcomes_hospitalisations_raw',save_previous=False)

In [0]:
outcomes_hospitalisations_raw = spark.table(f'{dsa}.{proj}_outcomes_hospitalisations_raw')

In [0]:
display(outcomes_hospitalisations_raw)

## 3.3 Codelist Match - CVD and All Cause Hospitalisations

In [0]:
# Distinct CVD hospitalisation codes used as the lookup. The CVD hospitalisation codelist
# is built with `union`, which keeps duplicates, so a code present in both source
# codelists would otherwise duplicate every matching hospital record in the left join.
_cvd_hosp_codes = (
    codelist_outcomes_cvd_hosp
    .select(f.col("code").alias("CODE"))
    .distinct()
    .withColumn("outcomes_hosp_description", f.lit("cvd_hosp"))
)

# Post-discharge hospital records in the first diagnosis position, flagged as CVD or
# non-CVD hospitalisations, with the cohort columns attached.
outcomes_hospitalisations_raw_cvd = (
    hes_apc_long_prepared
    .filter(f.col("DIAG_POSITION") == 1)  # first position only
    .join(_cvd_hosp_codes, on="CODE", how="left")
    # codes without a codelist match are labelled non-CVD
    .withColumn(
        "outcomes_hosp_description",
        f.coalesce(f.col("outcomes_hosp_description"), f.lit("non_cvd_hosp")),
    )
    .join(cohort, on="PERSON_ID", how="left")
)

In [0]:
save_table(df=outcomes_hospitalisations_raw_cvd, out_name=f'{proj}_outcomes_hospitalisations_raw_cvd',save_previous=False)