In [0]:
# Databricks notebook  
from datetime import datetime
from typing import Optional

from pyspark.sql import functions as F
  
# Config
CATALOG = "4_prod"
SCHEMA = "raw"  # schema whose tables are enumerated and patched further down
# Fully-qualified lookup sources used to derive ENCNTR_ID and ORGANIZATION_ID
RAW_CE = f"`{CATALOG}`.`raw`.`mill_clinical_event`"
RAW_ENC = f"`{CATALOG}`.`raw`.`mill_encounter`"

# Modes:
FULL_REPAIR = False   # Set True for the initial historical backfill
LOOKBACK_DAYS = 3     # Only used when FULL_REPAIR == False

# ORGANIZATION_ID values that are labelled Trust = 'Barts'.
# Each id is listed once (9163579 previously appeared twice).
# NOTE(review): 0 is included as an organisation id - confirm this is intentional.
BARTS_ORG_IDS = [
    873843, 8367658, 669849, 9073614, 2681833, 4401825,
    3203824, 2681830, 8061679, 669848, 9163579, 8467812,
    2681824, 9161976, 2619824, 2681827, 3203825, 691988,
    3125827, 8061682, 8061694, 2641824, 2641827, 669847,
    8056759, 8061685, 2641830, 3201824, 691989, 669845,
    669843, 8061691, 669846, 3199824, 669850, 6333825,
    669844, 0, 8397458, 8152502, 671843, 613843,
    9161983,
]

# ORGANIZATION_ID values that are labelled Trust = 'BHRUT'.
# NOTE(review): -1 looks like a placeholder until real ids are known - confirm.
BHRUT_ORG_IDS = [-1]

# One timestamp captured up front; it is written to every log row of this run
# and used again at the end to select this run's rows for the summary.
RUN_TIMESTAMP = datetime.now()

# Make the configured catalog the session default
spark.sql("USE CATALOG `" + CATALOG + "`")

# Ensure the Delta table that receives the BHRUT / NULL-trust log rows exists
_LOG_TABLE_DDL = """
CREATE TABLE IF NOT EXISTS `6_mgmt`.`logs`.`bhrt_updates` (
    table_name STRING,
    event_id BIGINT,
    encntr_id BIGINT,
    organization_id BIGINT,
    trust STRING,
    processed_timestamp TIMESTAMP,
    run_timestamp TIMESTAMP
)
USING DELTA
"""
spark.sql(_LOG_TABLE_DDL)
  
# EVENT_ID -> ENCNTR_ID lookup, built in two steps:
#   1. v_event_to_encntr: clinical_event rows with both ids present whose
#      valid_until_dt_tm is still in the future.
#   2. v_event_to_encntr_dedup: one row per EVENT_ID, keeping MAX(ENCNTR_ID)
#      so the mapping is single-valued.
_EVENT_TO_ENCNTR_SQL = f"""  
CREATE OR REPLACE TEMP VIEW v_event_to_encntr AS  
SELECT EVENT_ID, ENCNTR_ID  
FROM {RAW_CE}  
WHERE EVENT_ID IS NOT NULL 
  AND ENCNTR_ID IS NOT NULL
  AND valid_until_dt_tm > current_timestamp()
"""
_EVENT_TO_ENCNTR_DEDUP_SQL = """
CREATE OR REPLACE TEMP VIEW v_event_to_encntr_dedup AS
SELECT EVENT_ID, 
       MAX(ENCNTR_ID) as ENCNTR_ID
FROM v_event_to_encntr
GROUP BY EVENT_ID
"""
# The dedup view reads from the first view, so the order matters
for _view_sql in (_EVENT_TO_ENCNTR_SQL, _EVENT_TO_ENCNTR_DEDUP_SQL):
    spark.sql(_view_sql)
  
# ENCNTR_ID -> ORGANIZATION_ID lookup, built in two steps:
#   1. v_encntr_to_org: encounter rows with both ids present.
#   2. v_encntr_to_org_dedup: one row per ENCNTR_ID, keeping
#      MAX(ORGANIZATION_ID) so the mapping is single-valued.
_ENCNTR_TO_ORG_SQL = f"""  
CREATE OR REPLACE TEMP VIEW v_encntr_to_org AS  
SELECT ENCNTR_ID, ORGANIZATION_ID  
FROM {RAW_ENC}  
WHERE ENCNTR_ID IS NOT NULL 
  AND ORGANIZATION_ID IS NOT NULL  
"""
_ENCNTR_TO_ORG_DEDUP_SQL = """
CREATE OR REPLACE TEMP VIEW v_encntr_to_org_dedup AS
SELECT ENCNTR_ID, 
       MAX(ORGANIZATION_ID) as ORGANIZATION_ID
FROM v_encntr_to_org
GROUP BY ENCNTR_ID
"""
# The dedup view reads from the first view, so the order matters
for _view_sql in (_ENCNTR_TO_ORG_SQL, _ENCNTR_TO_ORG_DEDUP_SQL):
    spark.sql(_view_sql)
  
# Expose the Barts organisation ids as temp view v_barts_orgs.
# With an empty list, fall back to a zero-row view that still has a BIGINT
# ORGANIZATION_ID column so later IN (...) subqueries keep working.
if not BARTS_ORG_IDS:
    spark.sql(
        "CREATE OR REPLACE TEMP VIEW v_barts_orgs AS "
        "SELECT CAST(NULL AS BIGINT) AS ORGANIZATION_ID WHERE FALSE"
    )
else:
    barts_rows = [(int(org_id),) for org_id in BARTS_ORG_IDS]
    barts_df = spark.createDataFrame(barts_rows, ["ORGANIZATION_ID"])
    barts_df.createOrReplaceTempView("v_barts_orgs")

# Same treatment for the BHRUT organisation ids (temp view v_bhrut_orgs)
if not BHRUT_ORG_IDS:
    spark.sql(
        "CREATE OR REPLACE TEMP VIEW v_bhrut_orgs AS "
        "SELECT CAST(NULL AS BIGINT) AS ORGANIZATION_ID WHERE FALSE"
    )
else:
    bhrut_rows = [(int(org_id),) for org_id in BHRUT_ORG_IDS]
    bhrut_df = spark.createDataFrame(bhrut_rows, ["ORGANIZATION_ID"])
    bhrut_df.createOrReplaceTempView("v_bhrut_orgs")
  
# Helpers  
def get_columns_and_types(table_name: str):
    """Return a ``{COLUMN_NAME: DATA_TYPE}`` mapping for one table.

    The lookup runs against ``system.information_schema.columns`` filtered on
    the module-level CATALOG and SCHEMA plus *table_name*. Both keys and
    values are upper-cased. The mapping is empty when no column rows match.
    """
    columns_query = f"""  
        SELECT column_name, data_type  
        FROM system.information_schema.columns  
        WHERE table_catalog = '{CATALOG}'  
          AND table_schema = '{SCHEMA}'  
          AND table_name = '{table_name}'  
    """
    column_types = {}
    for row in spark.sql(columns_query).collect():
        column_types[row.column_name.upper()] = row.data_type.upper()
    return column_types
  
def add_column_if_missing(fqn: str, col_name: str, data_type: str, cols_upper: set):
    """Add column *col_name* to table *fqn* unless it is already known.

    *cols_upper* is the caller's set of upper-cased column names. When the
    upper-cased *col_name* is already in it, nothing happens. Otherwise an
    ``ALTER TABLE ... ADD COLUMNS`` is issued with *data_type* and the
    upper-cased name is added to *cols_upper* in place.
    """
    normalized = col_name.upper()
    if normalized in cols_upper:
        return
    spark.sql(f"ALTER TABLE {fqn} ADD COLUMNS (`{col_name}` {data_type})")
    cols_upper.add(normalized)
  
def adc_cond(alias: str, col_types: dict) -> str:
    """Build the optional ``ADC_UPDT`` recency predicate for a SQL WHERE clause.

    Returns an empty string when FULL_REPAIR is set or when *col_types* has no
    ``ADC_UPDT`` key. Otherwise returns a fragment starting with `` AND `` that
    keeps rows whose ``<alias>.ADC_UPDT`` falls within the last LOOKBACK_DAYS
    days.
    """
    incremental_run = not FULL_REPAIR
    if incremental_run and "ADC_UPDT" in col_types:
        # The column is compared against a timestamp, so it is expected to be temporal
        return f" AND {alias}.ADC_UPDT >= current_timestamp() - INTERVAL {LOOKBACK_DAYS} DAYS"
    return ""

def trust_cond() -> str:
    """Return the SQL fragment that guards existing Trust values.

    With FULL_REPAIR set the fragment is empty, so existing assignments are
    overwritten. Otherwise it restricts the statement to rows whose Trust is
    still NULL.
    """
    return "" if FULL_REPAIR else " AND (Trust IS NULL)"

def find_unique_key(cols_upper: set) -> Optional[str]:
    """Return the first well-known identifier column present in *cols_upper*.

    Candidates are checked in a fixed priority order, and the match is
    case-sensitive against the (upper-cased) names in *cols_upper*.

    Args:
        cols_upper: Upper-cased column names of the table.

    Returns:
        The first candidate name found, or ``None`` when the table has none
        of them.

    Note:
        Selection is purely by column name; nothing here checks that the
        column's values are actually unique in the table.
    """
    candidate_keys = (
        "ROW_ID",
        "RECORD_ID",
        "ID",
        "UNIQUE_ID",
        "PK_ID",
        "CLINICAL_EVENT_ID",
        "ENCNTR_ALIAS_ID",
    )
    return next((key for key in candidate_keys if key in cols_upper), None)

def update_encntr_id_safe(fqn: str, col_types: dict):
    """Fill NULL ENCNTR_ID values on *fqn* from the deduplicated event lookup.

    Issues a single UPDATE whose SET expression is a correlated, aggregated
    subquery on ``v_event_to_encntr_dedup``, so each row receives at most one
    value. Only rows with a non-NULL EVENT_ID and a NULL ENCNTR_ID are
    touched, further narrowed by the ADC_UPDT recency filter when one applies.
    """
    recency_filter = adc_cond(fqn, col_types)
    statement = f"""
        UPDATE {fqn}
        SET ENCNTR_ID = (
            SELECT MAX(m.ENCNTR_ID)
            FROM v_event_to_encntr_dedup AS m
            WHERE m.EVENT_ID = {fqn}.EVENT_ID
        )
        WHERE EVENT_ID IS NOT NULL
          AND ENCNTR_ID IS NULL
          {recency_filter}
    """
    spark.sql(statement)

def update_organization_id_safe(fqn: str, col_types: dict):
    """Fill NULL ORGANIZATION_ID values on *fqn* from the deduplicated encounter lookup.

    Issues a single UPDATE whose SET expression is a correlated, aggregated
    subquery on ``v_encntr_to_org_dedup``, so each row receives at most one
    value. Only rows with a non-NULL ENCNTR_ID and a NULL ORGANIZATION_ID are
    touched, further narrowed by the ADC_UPDT recency filter when one applies.
    """
    recency_filter = adc_cond(fqn, col_types)
    statement = f"""
        UPDATE {fqn}
        SET ORGANIZATION_ID = (
            SELECT MAX(m.ORGANIZATION_ID)
            FROM v_encntr_to_org_dedup AS m
            WHERE m.ENCNTR_ID = {fqn}.ENCNTR_ID
        )
        WHERE ENCNTR_ID IS NOT NULL
          AND ORGANIZATION_ID IS NULL
          {recency_filter}
    """
    spark.sql(statement)

def log_trust_records(fqn: str, table_name: str, col_types: dict, cols_upper: set):
    """Append this table's BHRUT rows and stale NULL-trust rows to the log table.

    Two kinds of rows from *fqn* are inserted into
    ``6_mgmt.logs.bhrt_updates``:

    * every row whose Trust is ``'BHRUT'``;
    * rows whose Trust is NULL, provided the table has an ENCNTR_ID and/or
      EVENT_ID column, at least one of those is non-NULL on the row, and
      (when the table has ADC_UPDT) the row is at least 3 days old.

    Tables with neither ENCNTR_ID nor EVENT_ID never contribute NULL-trust
    rows, because such rows could not have been resolved by lookup.

    Args:
        fqn: Fully-qualified, already-quoted name of the table to read.
        table_name: Plain table name stored in the log's ``table_name`` column.
        col_types: Upper-cased column name -> data type mapping of the table.
        cols_upper: Upper-cased column names currently on the table.

    Raises:
        Whatever ``spark.sql`` raises when the INSERT fails.
    """
    # Use the table's own id columns where present; otherwise emit a typed
    # NULL so the SELECT always yields the log table's three BIGINT id columns.
    select_cols = [
        col if col in cols_upper else f"CAST(NULL AS BIGINT) AS {col}"
        for col in ("EVENT_ID", "ENCNTR_ID", "ORGANIZATION_ID")
    ]
    select_list = ", ".join(select_cols)

    # Age filter for NULL-trust rows.
    # NOTE(review): 3 days equals the current LOOKBACK_DAYS; confirm whether
    # this threshold is meant to track that setting.
    date_cond = ""
    if "ADC_UPDT" in col_types:
        date_cond = "AND ADC_UPDT <= current_timestamp() - INTERVAL 3 DAYS"

    # A NULL-trust row is only worth logging if it carries an id that the
    # lookups could have used.
    eligibility_parts = [
        f"{col} IS NOT NULL"
        for col in ("ENCNTR_ID", "EVENT_ID")
        if col in cols_upper
    ]
    if eligibility_parts:
        eligibility_cond = " OR ".join(eligibility_parts)
        null_trust_cond = f"(Trust IS NULL AND ({eligibility_cond}) {date_cond})"
    else:
        # Table not eligible for lookup: never log its NULL-trust rows
        null_trust_cond = "FALSE"

    run_ts = RUN_TIMESTAMP.isoformat()

    # Name the target columns explicitly so the insert does not depend on the
    # physical column order of the log table.
    spark.sql(f"""
        INSERT INTO `6_mgmt`.`logs`.`bhrt_updates`
            (table_name, event_id, encntr_id, organization_id,
             trust, processed_timestamp, run_timestamp)
        SELECT
            '{table_name}' AS table_name,
            {select_list},
            Trust AS trust,
            current_timestamp() AS processed_timestamp,
            timestamp('{run_ts}') AS run_timestamp
        FROM {fqn}
        WHERE (
            Trust = 'BHRUT'
            OR {null_trust_cond}
        )
    """)

# Every table name that information_schema lists for the configured
# catalog and schema becomes a processing target.
_tables_query = f"""  
SELECT table_name  
FROM system.information_schema.tables  
WHERE table_catalog = '{CATALOG}'  
  AND table_schema = '{SCHEMA}'
"""
_table_rows = spark.sql(_tables_query).collect()
tables = [_row.table_name for _row in _table_rows]
  
def _merge_backfill(fqn, unique_key, join_col, fill_col, lookup_view, col_types):
    """MERGE values of *fill_col* into *fqn* from *lookup_view*, joined on *join_col*.

    The source keeps one row per (*unique_key*, *join_col*) pair, so no target
    row can be matched by more than one source row. The ON clause also
    requires the same *join_col* value and a still-NULL *fill_col* on the
    target. *unique_key* is chosen by column name only, so if it turns out not
    to be unique this stops the MERGE overwriting populated rows or assigning
    a value that belongs to a different *join_col*.
    """
    spark.sql(f"""
        MERGE INTO {fqn} AS tgt
        USING (
            SELECT {unique_key}, {join_col}, {fill_col}
            FROM (
                SELECT p.{unique_key}, p.{join_col}, m.{fill_col},
                       ROW_NUMBER() OVER (
                           PARTITION BY p.{unique_key}, p.{join_col}
                           ORDER BY m.{fill_col}
                       ) AS rn
                FROM {fqn} AS p
                JOIN {lookup_view} AS m ON p.{join_col} = m.{join_col}
                WHERE p.{join_col} IS NOT NULL
                  AND p.{fill_col} IS NULL
                  {adc_cond('p', col_types)}
            )
            WHERE rn = 1
        ) src
        ON tgt.{unique_key} = src.{unique_key}
           AND tgt.{join_col} = src.{join_col}
           AND tgt.{fill_col} IS NULL
        WHEN MATCHED THEN UPDATE SET tgt.{fill_col} = src.{fill_col}
    """)


def _backfill_column(fqn, unique_key, join_col, fill_col, lookup_view, col_types, fallback_update):
    """Backfill *fill_col*: MERGE when a key column exists, otherwise plain UPDATE.

    A failed MERGE is reported and then retried through *fallback_update*
    (called as ``fallback_update(fqn, col_types)``). Only ``Exception`` is
    intercepted, so a notebook cancel / interpreter exit is not mistaken for a
    MERGE failure. Errors raised by the fallback itself propagate to the caller.
    """
    if unique_key:
        try:
            _merge_backfill(fqn, unique_key, join_col, fill_col, lookup_view, col_types)
            return
        except Exception as merge_err:
            print(
                f"Warning: MERGE backfill of {fill_col} failed for {fqn}; "
                f"falling back to UPDATE: {merge_err}"
            )
    fallback_update(fqn, col_types)


# Process every target table independently: a failure on one table is printed
# and the loop carries on with the next one.
for t in tables:
    fqn = f"`{CATALOG}`.`{SCHEMA}`.`{t}`"
    try:
        col_types = get_columns_and_types(t)
        cols_upper = set(col_types.keys())
        unique_key = find_unique_key(cols_upper)

        has_event = "EVENT_ID" in cols_upper
        has_enc = "ENCNTR_ID" in cols_upper

        # 1) ENCNTR_ID: ensure the column exists, then backfill it from EVENT_ID
        if has_event:
            add_column_if_missing(fqn, "ENCNTR_ID", "BIGINT", cols_upper)
            _backfill_column(
                fqn, unique_key, "EVENT_ID", "ENCNTR_ID",
                "v_event_to_encntr_dedup", col_types, update_encntr_id_safe,
            )
            has_enc = True  # now present

        # 2) ORGANIZATION_ID: ensure the column exists, then backfill it from ENCNTR_ID
        if has_enc:
            add_column_if_missing(fqn, "ORGANIZATION_ID", "BIGINT", cols_upper)
            _backfill_column(
                fqn, unique_key, "ENCNTR_ID", "ORGANIZATION_ID",
                "v_encntr_to_org_dedup", col_types, update_organization_id_safe,
            )

        # 3) Trust: always ensure the column exists
        add_column_if_missing(fqn, "Trust", "STRING", cols_upper)

        # Trust can only be derived when the table carries ORGANIZATION_ID
        if "ORGANIZATION_ID" in cols_upper:
            recency = adc_cond(fqn, col_types)

            # BHRUT first. trust_cond() limits the update to NULL Trust in
            # normal mode and is empty (overwrite) in FULL_REPAIR mode.
            spark.sql(f"""
                UPDATE {fqn}
                SET Trust = 'BHRUT'
                WHERE ORGANIZATION_ID IN (SELECT ORGANIZATION_ID FROM v_bhrut_orgs)
                  {trust_cond()}
                  {recency}
            """)

            # Barts second. In normal mode only NULL Trust rows are touched, so
            # rows just set to BHRUT are left alone. In FULL_REPAIR mode existing
            # values are overwritten, so organisations that are also on the BHRUT
            # list are excluded explicitly to keep BHRUT's priority.
            if FULL_REPAIR:
                barts_guard = (
                    " AND ORGANIZATION_ID NOT IN "
                    "(SELECT ORGANIZATION_ID FROM v_bhrut_orgs)"
                )
            else:
                barts_guard = trust_cond()
            spark.sql(f"""
                UPDATE {fqn}
                SET Trust = 'Barts'
                WHERE ORGANIZATION_ID IN (SELECT ORGANIZATION_ID FROM v_barts_orgs)
                  {barts_guard}
                  {recency}
            """)

        # Log AFTER the assignments. The Trust column is guaranteed to exist at
        # this point, so logging runs whether or not ORGANIZATION_ID is present.
        # Logging is best-effort: a failure is reported but does not fail the table.
        try:
            log_trust_records(fqn, t, col_types, cols_upper)
        except Exception as log_e:
            print(f"Warning: Failed to log trust records for {fqn}: {log_e}")

        print(f"Processed {fqn}")

    except Exception as e:
        print(f"Error processing {fqn}: {e}")

# ISO form of this run's timestamp, used to pick out the rows logged by this run
_run_ts_iso = RUN_TIMESTAMP.isoformat()

# Per-trust breakdown: rows logged and number of distinct tables involved
_per_trust_summary = spark.sql(f"""
    SELECT 
        trust,
        COUNT(*) as record_count,
        COUNT(DISTINCT table_name) as tables_affected
    FROM `6_mgmt`.`logs`.`bhrt_updates`
    WHERE run_timestamp = timestamp('{_run_ts_iso}')
    GROUP BY trust
    ORDER BY trust
""")
_per_trust_summary.show()

# Overall totals for the run, split into BHRUT rows and NULL-trust rows
_overall_summary = spark.sql(f"""
    SELECT 
        COUNT(*) as total_logged,
        COUNT(DISTINCT table_name) as total_tables_affected,
        SUM(CASE WHEN trust = 'BHRUT' THEN 1 ELSE 0 END) as bhrut_records,
        SUM(CASE WHEN trust IS NULL THEN 1 ELSE 0 END) as null_trust_records
    FROM `6_mgmt`.`logs`.`bhrt_updates`
    WHERE run_timestamp = timestamp('{_run_ts_iso}')
""")
_overall_summary.show()