In [0]:
%run ./project_config

In [0]:
%run ./parameters

In [0]:
from pyspark.sql import functions as f, DataFrame
from pyspark.sql.window import Window
from functions import load_table, save_table, read_csv_file, map_column_values
from functions.functions import union_dataframe_list

# 1 Load tables

In [0]:
# Number of rows shown by each preview in this cell.
_n_preview = 100

# Study cohort table, produced by an earlier step of the pipeline.
cohort_filtered = load_table('cohort_filtered')
display(cohort_filtered.limit(_n_preview))

# HES APC procedure records used as the outcome event source.
hes_apc_procedure = load_table('hes_apc_procedure')
display(hes_apc_procedure.limit(_n_preview))


# 2 Codelists

In [0]:
# Map each outcome phenotype to the CSV holding its OPCS-4 procedure codelist.
dict_codelists_opcs = {
    "post_mi_pci": "./codelists/percutaneous_coronary_intervention_opcs4.csv",
    "post_mi_cabg": "./codelists/coronary_artery_bypass_grafts_opcs4.csv"
}

# Read every codelist and tag its rows with the phenotype it defines.
list_codelists_opcs = [
    read_csv_file(codelist_path)
    .withColumn('phenotype', f.lit(phenotype))
    for phenotype, codelist_path in dict_codelists_opcs.items()
]

# The per-phenotype frames are Spark DataFrames (they support .withColumn), so their
# union is presumably a Spark DataFrame already. It must NOT be wrapped in
# spark.createDataFrame(...): PySpark rejects DataFrame input with a TypeError.
codelist_opcs = union_dataframe_list(list_codelists_opcs)

display(codelist_opcs)


# 3 Prepare datasets

In [0]:
# Project procedure records onto the common event schema used for matching:
# person_id, code, date, data_source, source_priority.
hes_apc_procedure_prepared = hes_apc_procedure.select(
    f.col('person_id'),
    f.col('code'),
    f.col('procedure_date').alias('date'),
    f.lit('hes_apc_procedure').alias('data_source'),
    f.lit(1).alias('source_priority'),
)

# 4 Cohort dates

In [0]:
# Per-person outcome window: from the day after index_mi_date up to the earlier of
# date_of_death and follow_up_end_date. f.least skips nulls, so a missing
# date_of_death falls back to follow_up_end_date.
_window_start = f.date_add('index_mi_date', 1)
_window_end = f.least(f.col('date_of_death'), f.col('follow_up_end_date'))

cohort_prepared = cohort_filtered.select(
    'person_id',
    _window_start.alias('min_date'),
    _window_end.alias('max_date'),
)

# 5 Perform matching

In [0]:
# Keep procedure records whose code appears on a codelist (broadcast-join hint on
# the codelist), then restrict to the person's [min_date, max_date] window.
_procedures_on_codelist = hes_apc_procedure_prepared.join(
    f.broadcast(codelist_opcs), on='code', how='inner'
)

# Column.between is inclusive at both ends, matching the >= / <= bounds.
hes_apc_procedure_matched = (
    _procedures_on_codelist
    .join(cohort_prepared, on='person_id', how='inner')
    .filter(f.col('date').between(f.col('min_date'), f.col('max_date')))
)

In [0]:
# All matched outcome events. Only one source is used at present; the
# data_source / source_priority columns allow more sources to be unioned in here.
outcomes_all_events = hes_apc_procedure_matched
save_table(outcomes_all_events, 'outcomes_all_events')

# 6 Aggregate

In [0]:
outcomes_all_events = load_table('outcomes_all_events')

# Pivot values are given explicitly for two reasons:
# - every phenotype gets its columns even when it has no matched events;
# - Spark does not need an extra job to discover the distinct values.
# Sorted to match the column order Spark would produce when inferring the values.
_phenotypes = sorted(dict_codelists_opcs)

# Earliest event per person and phenotype. Date ties are broken by source_priority,
# then by code, so the selected row is deterministic when several codes share a date.
_win = (
    Window
    .partitionBy('person_id', 'phenotype')
    .orderBy(f.col('date').asc(), f.col('source_priority').asc(), f.col('code').asc())
)

# Long format: at most one row per (person_id, phenotype).
outcomes_first_event_long = (
    outcomes_all_events
    .withColumn('rank', f.row_number().over(_win))
    .filter(f.col('rank') == 1)
    .withColumn('flag', f.lit(1))
)

# Wide format: one row per person_id. Spark names the pivoted columns
# <phenotype>_flag, <phenotype>_date, <phenotype>_code and <phenotype>_source.
outcomes_first_event = (
    outcomes_first_event_long
    .groupBy('person_id')
    .pivot('phenotype', _phenotypes)
    .agg(
        f.first('flag').alias('flag'),
        f.first('date').alias('date'),
        f.first('code').alias('code'),
        f.first('data_source').alias('source')
    )
)

save_table(outcomes_first_event, 'outcomes_first_event')

# 7 Join outcomes to cohort and save table

In [0]:
# Reload inputs from storage so this cell does not depend on in-memory state.
cohort_filtered = load_table('cohort_filtered')
outcomes_first_event = load_table('outcomes_first_event')

# Left join from the full cohort: people with no matched event keep a row,
# with null outcome columns.
_cohort_ids = cohort_filtered.select('person_id')
cohort_outcomes = _cohort_ids.join(outcomes_first_event, on='person_id', how='left')

save_table(cohort_outcomes, 'cohort_outcomes')

# 8 Display

In [0]:
# Preview the persisted table (re-read from storage) rather than the in-memory plan.
cohort_outcomes = load_table('cohort_outcomes')
_preview = cohort_outcomes.limit(100)
display(_preview)