In [None]:
from pyspark.sql import SparkSession
from delta.tables import *
from pyspark.sql.functions import *

# Input locations: the eCR Delta table, the MII Delta table, and the JSON
# config holding the COVID test-type codes and positive-result values.
ECR_DELTA_TABLE_FILE_PATH = "ecr-datastore"
MII_DELTA_TABLE_FILE_PATH = "MII"
COVID_IDENTIFICATION_CONFIG_FILE_PATH = "covid_identification_config.json"

spark = SparkSession.builder.getOrCreate()

# Read in data
ecr = spark.read.format("delta").load(ECR_DELTA_TABLE_FILE_PATH)

# Only three MII columns are kept, each renamed with an `_mci` suffix
# (presumably so they stay distinguishable from eCR columns after a join).
mci = (
    spark.read.format("delta")
    .load(MII_DELTA_TABLE_FILE_PATH)
    .select(
        col("incident_id").alias("incident_id_mci"),
        col("person_id").alias("person_id_mci"),
        col("specimen_collection_date").alias("specimen_collection_date_mci"),
    )
)

# Covid identification data
df = spark.read.json(COVID_IDENTIFICATION_CONFIG_FILE_PATH)

# Pull both config values from the first row of the config DataFrame in a
# single collect; an empty config still fails with IndexError on `[0]`.
config_row = df.select("covid_test_type_codes", "covid_positive_results").collect()[0]
covid_test_type_codes = config_row["covid_test_type_codes"]
covid_positive_results = config_row["covid_positive_results"]

In [None]:
# Add `comparison_date` column to ecr data ahead of join with mci to find positive covid tests.
#
# The eCR rows expose numbered column triples
# (test_type_code_N, test_result_N, specimen_collection_date_N) for N = 1..NUM_TEST_SLOTS.
# `comparison_date` is the specimen collection date of the lowest-numbered slot
# whose lower-cased type code is in `covid_test_type_codes` AND whose lower-cased
# result is in `covid_positive_results`; it is null when no slot qualifies.
#
# The chain is generated in a loop so every slot is guaranteed to check its own
# three columns (a hand-written chain previously paired test_type_code_12 with
# slot 13's result and date).
#
# NOTE(review): values are lower-cased before the membership test, so this
# assumes the config lists are stored in lower case -- confirm against the
# config file.
NUM_TEST_SLOTS = 20

comparison_date = None
for slot in range(1, NUM_TEST_SLOTS + 1):
    is_positive_covid_test = (
        lower(ecr["test_type_code_{}".format(slot)]).isin(covid_test_type_codes)
        & lower(ecr["test_result_{}".format(slot)]).isin(covid_positive_results)
    )
    collection_date = ecr["specimen_collection_date_{}".format(slot)]

    # The first slot starts the CASE expression; later slots append branches,
    # so earlier slots take precedence when several tests are positive.
    if comparison_date is None:
        comparison_date = when(is_positive_covid_test, collection_date)
    else:
        comparison_date = comparison_date.when(is_positive_covid_test, collection_date)

ecr = ecr.withColumn("comparison_date", comparison_date.otherwise(lit(None)))

In [None]:
# Join MCI and ECR to get ecr updates (positive covid tests)
same_person = ecr.iris_id == mci.person_id_mci

# Days from the MCI specimen collection date to the eCR comparison date
# (positive when the eCR date is later). Rows with a null comparison_date
# produce a null condition and are dropped by the inner join.
days_after_mci_specimen = datediff(ecr.comparison_date, mci.specimen_collection_date_mci)

# NOTE(review): `<= 90` has no lower bound, so an eCR date any amount of time
# BEFORE the MCI specimen date also matches -- confirm whether a 0..90 or
# +/-90 day window was intended.
# NOTE(review): one iris_id can match several MCI rows here, yielding multiple
# rows per iris_id in the result -- verify downstream consumers tolerate that.
ecr_updates = (
    ecr.join(mci, same_person & (days_after_mci_specimen <= 90), "inner")
    .select("iris_id", "incident_id_mci")
    # Re-declares the same two column names; looks redundant after the select
    # but is kept as in the original -- TODO confirm whether it is needed.
    .toDF("iris_id", "incident_id_mci")
)


In [None]:
# Load ecr delta table
ecr_main = DeltaTable.forPath(spark, ECR_DELTA_TABLE_FILE_PATH)

# Merge in ecr updates such that the incident_id is updated.
# Only a matched-update clause is defined: target rows whose person_id equals
# an update's iris_id get their incident_id overwritten; nothing is inserted
# or deleted.
#
# NOTE(review): the match is on ecr.person_id, whereas ecr_updates.iris_id is
# taken from the eCR table's own iris_id column -- confirm person_id (not
# iris_id) is the intended target key.
# NOTE(review): Delta merge can fail when more than one source row matches the
# same target row -- verify ecr_updates has at most one row per iris_id.
incident_id_merge = (
    ecr_main.alias("ecr")
    .merge(
        ecr_updates.alias("ecr_updates"),
        "ecr.person_id = ecr_updates.iris_id",
    )
    .whenMatchedUpdate(set={"incident_id": "ecr_updates.incident_id_mci"})
)
incident_id_merge.execute()
