In [0]:
from datetime import datetime, timedelta
from pyspark.sql import functions as F
from pyspark.sql.functions import broadcast
import time

# ==============================================================================
# CONFIGURATION
# ==============================================================================
# Catalog/schemas: raw source tables live in {CATALOG}.{RAW_SCHEMA}; the working
# tables owned by this pipeline live in {CATALOG}.{TMP_SCHEMA}.
CATALOG = "4_prod"
RAW_SCHEMA = "raw"
TMP_SCHEMA = "tmp"

# Log locations
LOG_CATALOG = "6_mgmt"
LOG_SCHEMA = "logs"

# Module-level side effect: switches the session's current catalog on import/run.
spark.sql(f"USE CATALOG {CATALOG}")

# Captured once at start-up (naive local time from datetime.now()).
RUN_TS = datetime.now()
IS_WEEKLY_RUN = (RUN_TS.weekday() == 6)  # Sunday = Deep Clean

# Retention & Settings
# LOOKUP_RETENTION_DAYS: refresh_hub() deletes hub / enc-org rows whose
#   last_updated is older than this many days.
# METADATA_FRESHNESS_DAYS: get_table_scopes() only queues a table for CDC when
#   it was modified within this many days (unless FORCE_RUN is set).
# INITIAL_LOOKBACK_VERSIONS: how many versions get_cdc_window() looks back when
#   a table has no checkpoint yet.
# FILTER_BHRUT: master switch read by can_delete_bhrut().
# BACKWINDOW_DAYS: window used by track_newly_mapped_orgs() for "recently added" orgs.
# FALLBACK_LOOKBACK_HOURS: time window used by get_fallback_filter() when no
#   previous checkpoint timestamp exists.
LOOKUP_RETENTION_DAYS = 14
METADATA_FRESHNESS_DAYS = 30
INITIAL_LOOKBACK_VERSIONS = 25
FILTER_BHRUT = True
BACKWINDOW_DAYS = 7
FALLBACK_LOOKBACK_HOURS = 48

# Tables (fully qualified working tables; created by ensure_setup())
TRUST_MAP_TBL = f"{CATALOG}.{TMP_SCHEMA}.org_to_trust_map"
ORG_HIST_TBL = f"{CATALOG}.{TMP_SCHEMA}.org_mapping_history"
HUB_TBL = f"{CATALOG}.{TMP_SCHEMA}.sw_mapping_hub"
ENC_ORG_TBL = f"{CATALOG}.{TMP_SCHEMA}.sw_enc_org_map"
CHANGED_ENC_TBL = f"{CATALOG}.{TMP_SCHEMA}.sw_changed_encounters"
CONTROL_TBL = f"{CATALOG}.{TMP_SCHEMA}.incr_updt_trust_control"

# Logging tables
FLAG_TBL = f"{LOG_CATALOG}.{LOG_SCHEMA}.organization_flags"
AUDIT_TBL = f"{LOG_CATALOG}.{LOG_SCHEMA}.bhrt_updates"
# Aggregated audit summary table
AUDIT_SUMMARY_TBL = f"{LOG_CATALOG}.{LOG_SCHEMA}.trust_audit_summary"
# Phase 2c version tracking, read by should_skip_phase2c() for the skip optimization
PHASE2C_TRACKER_TBL = f"{CATALOG}.{TMP_SCHEMA}.phase2c_version_tracker"

# Org Config: organization ids belonging to each trust. ensure_setup() rewrites
# TRUST_MAP_TBL from these two lists on every run.
BARTS_ORGS = [
    873843, 8367658, 669849, 9073614, 2681833, 4401825, 3203824, 2681830,
    8061679, 669848, 8467812, 2681824, 2619824, 2681827, 3203825, 691988,
    3125827, 8061682, 8061694, 2641824, 2641827, 669847, 8056759, 8061685,
    2641830, 3201824, 691989, 669845, 669843, 8061691, 669846, 3199824,
    669850, 6333825, 669844, 8397458, 8152502, 671843, 613843
]
BHRUT_ORGS = [9161976, 9163579, 9161983, 723896, 9161987, 9161988, 9163583, 9161989]

# Key Lookups Configuration
# Each entry is (key column, fully qualified source table, SQL expression that
# yields ENCNTR_ID from that table). An expression of "NULL" marks keys that
# refresh_hub() resolves through a join to another table instead
# (SURG_CASE_PROC_ID and IM_ACQUIRED_STUDY_ID).
KEY_LOOKUPS = [
    ("SURG_CASE_ID", f"{CATALOG}.{RAW_SCHEMA}.mill_surgical_case", "ENCNTR_ID"),
    ("SURG_CASE_PROC_ID", f"{CATALOG}.{RAW_SCHEMA}.mill_surg_case_procedure", "NULL"),
    ("DCP_FORMS_ACTIVITY_ID", f"{CATALOG}.{RAW_SCHEMA}.mill_dcp_forms_activity", "ENCNTR_ID"),
    ("EPISODE_ID", f"{CATALOG}.{RAW_SCHEMA}.mill_episode_encntr_reltn", "ENCNTR_ID"),
    ("ORDER_ID", f"{CATALOG}.{RAW_SCHEMA}.mill_orders", "ENCNTR_ID"),
    ("PROBLEM_ID", f"{CATALOG}.{RAW_SCHEMA}.mill_problem", "COALESCE(ORIGINATING_ENCNTR_ID, UPDATE_ENCNTR_ID)"),
    ("SCH_EVENT_ID", f"{CATALOG}.{RAW_SCHEMA}.mill_sch_event_patient", "ENCNTR_ID"),
    ("SCHEDULE_ID", f"{CATALOG}.{RAW_SCHEMA}.mill_sch_schedule", "ENCNTR_ID"),
    ("IM_STUDY_ID", f"{CATALOG}.{RAW_SCHEMA}.mill_im_study", "ENCNTR_ID"),
    ("IM_ACQUIRED_STUDY_ID", f"{CATALOG}.{RAW_SCHEMA}.mill_im_acquired_study", "NULL"),
    ("CV_PROC_ID", f"{CATALOG}.{RAW_SCHEMA}.mill_cv_proc", "ENCNTR_ID")
]

# Dependency Map for Chained Updates
# Shape appears to be: parent key -> (child table, join column in child, child key).
# NOTE(review): the consumer of this map is not in this section - confirm the tuple order there.
CHAINED_DEPENDENCIES = {
    "IM_STUDY_ID": ("mill_im_acquired_study", "MATCHED_STUDY_ID", "IM_ACQUIRED_STUDY_ID"),
    "SURG_CASE_ID": ("mill_surg_case_procedure", "SURG_CASE_ID", "SURG_CASE_PROC_ID")
}

# Tables protected from deletion: can_delete_bhrut() returns False for any table
# whose upper-cased name is in this set.
TABLES_SKIP_NULL_TRUST_LOGGING = {
    "MILL_PERSON_ORG_RELTN", "MILL_ORG_ORG_RELTN", "MILL_ORGANIZATION_ALIAS",
    "MILL_ORGANIZATION", "MILL_PRSNL_ORG_RELTN", "MILL_ORG_TYPE_RELTN",
    "MILL_SCH_LOCATION", "MILL_SCH_EVENT", "MILL_SCH_APPT",
    "MILL_SCH_EVENT_ALIAS", "MILL_SCH_SCHEDULE"
}

# Source of Truth Tables - don't overwrite ENCNTR_ID on these
SOURCE_OF_TRUTH_TABLES = {
    "MILL_CLINICAL_EVENT", "MILL_ENCOUNTER", "MILL_ORDERS",
    "MILL_SURGICAL_CASE", "MILL_DCP_FORMS_ACTIVITY",
    "MILL_EPISODE_ENCNTR_RELTN", "MILL_SCH_EVENT_PATIENT",
    "MILL_SCH_SCHEDULE", "MILL_IM_STUDY", "MILL_CV_PROC",
    "MILL_PROBLEM"
}

# ==============================================================================
# 1. INFRASTRUCTURE & HELPERS
# ==============================================================================
def time_op(name, fn):
    """Run ``fn()`` and print how long it took.

    Args:
        name: Label printed alongside the duration.
        fn: Zero-argument callable to execute.

    Returns:
        Whatever ``fn()`` returns.

    Raises:
        Exception: Any exception raised by ``fn`` is logged (message truncated
            to 500 characters) and re-raised unchanged.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # (or go negative) when the wall clock is adjusted mid-operation.
    started = time.perf_counter()
    try:
        res = fn()
    except Exception as e:
        duration = time.perf_counter() - started
        print(f"  {name} FAILED ({duration:.2f}s): {str(e)[:500]}")
        # Bare raise keeps the original traceback intact.
        raise
    else:
        duration = time.perf_counter() - started
        print(f"  {name}: {duration:.2f}s")
        return res


def ensure_setup():
    """Create all required schemas and tables.

    Every statement is idempotent (``IF NOT EXISTS``) except the trust map
    load, which overwrites TRUST_MAP_TBL from BARTS_ORGS / BHRUT_ORGS on each
    call. Returns nothing; all effects are Spark SQL side effects.
    """
    spark.sql(f"CREATE SCHEMA IF NOT EXISTS {CATALOG}.{TMP_SCHEMA}")

    # Per-organization flag log.
    spark.sql(f"""
        CREATE TABLE IF NOT EXISTS {FLAG_TBL} (
            organization_id LONG, alert INT, event_time TIMESTAMP,
            table_name STRING, first_seen_timestamp TIMESTAMP
        ) USING DELTA
    """)
    # CDC checkpoints: read by get_cdc_window(), written by update_checkpoint().
    spark.sql(f"""
        CREATE TABLE IF NOT EXISTS {CONTROL_TBL} (
            table_name STRING, last_version LONG, last_timestamp TIMESTAMP, updated_at TIMESTAMP
        ) USING DELTA
    """)
    # Original detailed audit table (keep for backward compatibility, but we won't insert to it by default)
    spark.sql(f"""
        CREATE TABLE IF NOT EXISTS {AUDIT_TBL} (
            table_name STRING, event_id BIGINT, encntr_id BIGINT,
            organization_id BIGINT, trust STRING, processed_timestamp TIMESTAMP,
            run_timestamp TIMESTAMP, adc_timestamp TIMESTAMP
        ) USING DELTA
    """)
    # Aggregated audit summary - much lower overhead than the detailed table
    spark.sql(f"""
        CREATE TABLE IF NOT EXISTS {AUDIT_SUMMARY_TBL} (
            run_timestamp TIMESTAMP,
            table_name STRING,
            trust STRING,
            operation STRING,
            record_count BIGINT,
            min_pk BIGINT,
            max_pk BIGINT
        ) USING DELTA
    """)
    # History of when each configured organization was first seen
    # (maintained by track_newly_mapped_orgs()).
    spark.sql(f"""
        CREATE TABLE IF NOT EXISTS {ORG_HIST_TBL} (
            organization_id LONG, trust STRING, first_seen_timestamp TIMESTAMP
        ) USING DELTA
    """)
    spark.sql(f"CREATE TABLE IF NOT EXISTS {TRUST_MAP_TBL} (organization_id LONG, trust STRING) USING DELTA")

    # Rebuild the org -> trust map from the in-code configuration so the table
    # always mirrors BARTS_ORGS / BHRUT_ORGS exactly.
    values = ",".join([f"({o}, 'Barts')" for o in BARTS_ORGS] + [f"({o}, 'BHRUT')" for o in BHRUT_ORGS])
    spark.sql(f"INSERT OVERWRITE {TRUST_MAP_TBL} SELECT * FROM VALUES {values}")

    # Encounter -> organization map (populated by refresh_hub()).
    spark.sql(f"""
        CREATE TABLE IF NOT EXISTS {ENC_ORG_TBL} (
            ENCNTR_ID LONG, ORGANIZATION_ID LONG, last_updated TIMESTAMP
        ) USING DELTA
    """)
    # Generic (key_type, key_id) -> encounter map (populated by refresh_hub()).
    spark.sql(f"""
        CREATE TABLE IF NOT EXISTS {HUB_TBL} (
            key_type STRING, key_id LONG, ENCNTR_ID LONG, last_updated TIMESTAMP
        ) USING DELTA
    """)
    # Encounters touched per run, partitioned by run_date.
    spark.sql(f"""
        CREATE TABLE IF NOT EXISTS {CHANGED_ENC_TBL} (
            ENCNTR_ID LONG, ORGANIZATION_ID LONG, trust STRING,
            change_timestamp TIMESTAMP, run_date DATE
        ) USING DELTA PARTITIONED BY (run_date)
    """)
    
    # Phase 2c version tracker - tracks which version we last updated each table to
    spark.sql(f"""
        CREATE TABLE IF NOT EXISTS {PHASE2C_TRACKER_TBL} (
            table_name STRING,
            last_updated_version LONG,
            hub_version LONG,
            enc_org_version LONG,
            updated_at TIMESTAMP
        ) USING DELTA
    """)


def track_newly_mapped_orgs():
    """Record newly configured organizations and report the recent ones.

    Any organization present in BARTS_ORGS / BHRUT_ORGS but missing from
    ORG_HIST_TBL is appended to the history table with the current timestamp
    as ``first_seen_timestamp``.

    Returns:
        list: ``organization_id`` values from ORG_HIST_TBL whose
        ``first_seen_timestamp`` falls within the last BACKWINDOW_DAYS days.
    """
    all_org_ids = [(o, 'Barts') for o in BARTS_ORGS] + [(o, 'BHRUT') for o in BHRUT_ORGS]
    df_config = spark.createDataFrame(all_org_ids, ["organization_id", "trust"])
    df_hist = spark.table(ORG_HIST_TBL)

    new_orgs = df_config.join(df_hist, "organization_id", "left_anti") \
        .withColumn("first_seen_timestamp", F.current_timestamp())

    # Count exactly once, before the append. new_orgs is lazy, so counting it
    # again after the write would re-run the anti-join (an extra Spark job)
    # and could report 0 because the history table now contains these rows.
    new_org_count = new_orgs.count()
    if new_org_count > 0:
        new_orgs.write.format("delta").mode("append").saveAsTable(ORG_HIST_TBL)
        print(f"    Added {new_org_count} new organizations to history")

    recent = spark.sql(f"""
        SELECT organization_id FROM {ORG_HIST_TBL}
        WHERE first_seen_timestamp >= current_timestamp() - INTERVAL {BACKWINDOW_DAYS} DAYS
    """).collect()
    return [row['organization_id'] for row in recent]


def can_delete_bhrut(table_upper: str) -> bool:
    """Tell whether BHRUT filtering may delete rows from the given table.

    Deletion is allowed only when FILTER_BHRUT is enabled and the table name
    (compared upper-cased) is not in TABLES_SKIP_NULL_TRUST_LOGGING.
    """
    if not FILTER_BHRUT:
        # Filtering switched off globally: hand back the (falsy) flag itself.
        return FILTER_BHRUT
    normalized = table_upper.upper()
    return normalized not in TABLES_SKIP_NULL_TRUST_LOGGING


def get_table_columns(fqn: str) -> set:
    """Return the upper-cased column names of table ``fqn``.

    Best-effort: if the table cannot be read for any reason, an empty set is
    returned instead of raising.
    """
    try:
        return {c.upper() for c in spark.table(fqn).columns}
    except Exception:
        # Deliberately broad (missing table, permissions, ...), but no longer a
        # bare ``except`` so KeyboardInterrupt / SystemExit still propagate.
        return set()


def pick_time_col(fqn: str):
    """Pick a suitable timestamp column for CDC fallback filters.

    Returns:
        The first of ``ADC_UPDT``, ``UPDT_DT_TM``, ``CLINSIG_UPDT_DT_TM`` that
        exists on the table (case-insensitive match), or ``None`` when the
        table has none of them or cannot be read.
    """
    try:
        cols = {c.upper() for c in spark.table(fqn).columns}
    except Exception:
        # Best-effort lookup; avoid a bare ``except`` so KeyboardInterrupt /
        # SystemExit are not swallowed.
        return None

    # Ordered by preference: the first match wins.
    for candidate in ("ADC_UPDT", "UPDT_DT_TM", "CLINSIG_UPDT_DT_TM"):
        if candidate in cols:
            return candidate
    return None


def log_status(table_name, status, reason=""):
    """Print one aligned, ANSI-colored status line for a table.

    QUEUE is blue, PROC green, ERR red and SKIP yellow; any other status is
    printed without a color (reset code only).
    """
    reset = "\033[0m"
    palette = {
        "QUEUE": "\033[94m",
        "PROC": "\033[92m",
        "ERR": "\033[91m",
        "SKIP": "\033[93m",
    }
    tint = palette.get(status, reset)

    print(f"    {tint}[{status}] {table_name:<35} : {reason}{reset}")


# ==============================================================================
# 1b. VERSION TRACKING HELPERS (NEW)
# ==============================================================================
def get_table_current_version(fqn: str) -> int:
    """Get the current version of a Delta table.

    Reads the first row of ``DESCRIBE HISTORY`` (this relies on the history
    being listed most-recent-first).

    Returns:
        The version number, or ``-1`` as an "unknown" sentinel when the
        history cannot be read.
    """
    try:
        return spark.sql(f"DESCRIBE HISTORY {fqn} LIMIT 1").collect()[0]['version']
    except Exception:
        # Best-effort; narrowed from a bare ``except`` so KeyboardInterrupt /
        # SystemExit are not turned into a -1.
        return -1


def get_hub_versions() -> tuple:
    """Return ``(hub_version, enc_org_version)`` for HUB_TBL and ENC_ORG_TBL."""
    # Tuple elements are evaluated left to right: hub first, then enc-org.
    return get_table_current_version(HUB_TBL), get_table_current_version(ENC_ORG_TBL)


def should_skip_phase2c(table_name: str, fqn: str, current_hub_ver: int, current_enc_org_ver: int) -> tuple:
    """
    Check if we can skip Phase 2c processing for this table.

    Args:
        table_name: Key of the table's row in PHASE2C_TRACKER_TBL.
        fqn: Fully qualified table name, used to read its current version.
        current_hub_ver: Current version of the hub table.
        current_enc_org_ver: Current version of the encounter->org table.

    Returns (should_skip: bool, reason: str)

    Skip conditions (all must hold):
    1. Our pipeline was the last to update this table (table version unchanged)
    2. The hub mappings haven't changed since our last update
    3. The encounter->org mappings haven't changed since our last update

    This means: if no new data came into the table AND no new encounter mappings
    were created, there's nothing new to propagate.

    Never raises: any failure while checking yields ``(False, reason)`` so the
    table is processed rather than silently skipped.
    """
    try:
        row = spark.sql(f"""
            SELECT last_updated_version, hub_version, enc_org_version
            FROM {PHASE2C_TRACKER_TBL}
            WHERE table_name = '{table_name}'
        """).collect()

        if not row:
            return False, "No previous tracking record"

        tracked = row[0]
        last_table_ver = tracked['last_updated_version']
        last_hub_ver = tracked['hub_version']
        last_enc_org_ver = tracked['enc_org_version']

        current_table_ver = get_table_current_version(fqn)

        # get_table_current_version() reports -1 when a version cannot be read.
        # A stored -1 from an earlier failed lookup would compare equal to a
        # current -1 and wrongly look like "unchanged", so an unknown version
        # must never justify a skip.
        if current_table_ver < 0 or current_hub_ver < 0 or current_enc_org_ver < 0:
            return False, (
                f"Version unavailable (table v{current_table_ver}, "
                f"hub v{current_hub_ver}, enc-org v{current_enc_org_ver})"
            )

        # Condition 1: Table version unchanged (we were the last updater)
        if current_table_ver != last_table_ver:
            return False, f"Table changed (v{last_table_ver} -> v{current_table_ver})"

        # Condition 2: Hub mappings unchanged (no new key->encounter mappings)
        if current_hub_ver != last_hub_ver:
            return False, f"Hub changed (v{last_hub_ver} -> v{current_hub_ver})"

        # Condition 3: Encounter->Org mappings unchanged
        if current_enc_org_ver != last_enc_org_ver:
            return False, f"Enc-Org map changed (v{last_enc_org_ver} -> v{current_enc_org_ver})"

        # All conditions met - safe to skip
        return True, f"No changes since last update (table v{current_table_ver}, hub v{current_hub_ver})"

    except Exception as e:
        # Fail open: when in doubt, process the table.
        return False, f"Tracking check failed: {str(e)[:50]}"


def update_phase2c_tracker(table_name: str, fqn: str, hub_ver: int, enc_org_ver: int):
    """Record that we updated this table in Phase 2c with current hub versions.

    Upserts one row per ``table_name`` into PHASE2C_TRACKER_TBL holding the
    table's version as read now, plus the supplied hub / enc-org versions.

    Args:
        table_name: Tracker key for the table.
        fqn: Fully qualified table name whose current version is stored.
        hub_ver: Hub table version to store.
        enc_org_ver: Encounter->org table version to store.
    """
    # Read after our own write so the stored version includes this pipeline's update.
    # NOTE(review): get_table_current_version() returns -1 on failure, and that
    # value would be stored as-is - confirm this is acceptable.
    table_ver = get_table_current_version(fqn)
    
    spark.sql(f"""
        MERGE INTO {PHASE2C_TRACKER_TBL} t
        USING (SELECT 
            '{table_name}' as table_name, 
            {table_ver} as last_updated_version,
            {hub_ver} as hub_version,
            {enc_org_ver} as enc_org_version,
            current_timestamp() as updated_at
        ) s
        ON t.table_name = s.table_name
        WHEN MATCHED THEN UPDATE SET
            last_updated_version = s.last_updated_version,
            hub_version = s.hub_version,
            enc_org_version = s.enc_org_version,
            updated_at = s.updated_at
        WHEN NOT MATCHED THEN INSERT *
    """)


# ==============================================================================
# 2. MATERIALIZED HUB VIEWS (NEW - Performance Optimization)
# ==============================================================================
def materialize_hub_views():
    """
    Create temp views for the deduplicated hub and enc_org mappings.
    These are reused throughout the pipeline instead of recalculating.

    Views created (session-scoped, replaced if they already exist):
      - enc_org_deduped:  latest ORGANIZATION_ID per ENCNTR_ID
      - hub_deduped:      latest ENCNTR_ID per (key_type, key_id)
      - trust_map:        organization_id -> trust
      - enc_trust_lookup: enc_org_deduped joined to trust_map

    NOTE(review): these are plain TEMP VIEWs, so presumably the dedup query is
    re-evaluated by each statement that reads them - confirm whether caching
    or a real table was intended by "materialize".
    """
    print("  Materializing hub views...")

    # Deduplicated Encounter -> Org mapping (row with the newest last_updated wins)
    spark.sql(f"""
        CREATE OR REPLACE TEMP VIEW enc_org_deduped AS
        SELECT ENCNTR_ID, ORGANIZATION_ID
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY ENCNTR_ID ORDER BY last_updated DESC) as rn
            FROM {ENC_ORG_TBL}
        )
        WHERE rn = 1
    """)

    # Deduplicated Hub (all key types; newest last_updated wins per key)
    spark.sql(f"""
        CREATE OR REPLACE TEMP VIEW hub_deduped AS
        SELECT key_type, key_id, ENCNTR_ID
        FROM (
            SELECT *, ROW_NUMBER() OVER (PARTITION BY key_type, key_id ORDER BY last_updated DESC) as rn
            FROM {HUB_TBL}
        )
        WHERE rn = 1
    """)

    # Trust map as temp view for easy joining
    spark.sql(f"""
        CREATE OR REPLACE TEMP VIEW trust_map AS
        SELECT organization_id, trust FROM {TRUST_MAP_TBL}
    """)

    # Pre-joined enc -> trust lookup (frequently used).
    # Inner join: encounters whose organization is not in trust_map are absent here.
    spark.sql(f"""
        CREATE OR REPLACE TEMP VIEW enc_trust_lookup AS
        SELECT eo.ENCNTR_ID, eo.ORGANIZATION_ID, tm.trust
        FROM enc_org_deduped eo
        JOIN trust_map tm ON eo.ORGANIZATION_ID = tm.organization_id
    """)


# ==============================================================================
# 3. DYNAMIC DISCOVERY (With Scopes)
# ==============================================================================
def get_table_scopes():
    """Discover candidate tables in the raw schema, logging each decision.

    Only ``mill_*`` tables (except ``mill_long_text``) that expose a TRUST
    column are considered.

    Returns:
        tuple: ``(cdc_candidates, propagation_candidates)``; both map table
        name -> set of upper-cased column names. Every table with a TRUST
        column is a propagation candidate; it is also a CDC candidate when it
        was modified within METADATA_FRESHNESS_DAYS or FORCE_RUN is set.
        Returns ``({}, {})`` if the schema cannot be listed.
    """
    print("  Scanning schema for candidate tables...")

    try:
        listing = spark.sql(f"SHOW TABLES IN {CATALOG}.{RAW_SCHEMA}").collect()
        all_tables = sorted(entry.tableName for entry in listing)
    except Exception as e:
        print(f"    Error listing tables: {e}")
        return {}, {}

    cdc_candidates, propagation_candidates = {}, {}
    # Optional module-level override that queues every table regardless of age.
    is_forced = globals().get('FORCE_RUN', False)

    print(f"  Found {len(all_tables)} total tables. Analyzing...")

    for t in all_tables:
        if t.lower() == "mill_long_text" or not t.startswith("mill_"):
            continue

        fqn = f"{CATALOG}.{RAW_SCHEMA}.{t}"

        try:
            cols = get_table_columns(fqn)
            if "TRUST" not in cols:
                log_status(t, "SKIP", "No 'TRUST' column found")
                continue

            # Registered before the metadata read, so a later failure still
            # leaves the table eligible for propagation.
            propagation_candidates[t] = cols

            details = spark.sql(f"DESCRIBE DETAIL {fqn}").collect()[0]
            age = datetime.now() - details['lastModified']
            hours_old = age.total_seconds() / 3600

            if not is_forced and age.days > METADATA_FRESHNESS_DAYS:
                log_status(t, "SKIP", f"Stale data (Last mod: {hours_old:.1f}h ago > {METADATA_FRESHNESS_DAYS}d limit)")
                continue

            cdc_candidates[t] = cols
            status_msg = f"Queued for CDC (Last mod: {hours_old:.1f}h ago)"
            if is_forced:
                status_msg += " [FORCED]"
            log_status(t, "QUEUE", status_msg)

        except Exception as e:
            log_status(t, "ERR", f"Metadata read failed: {str(e)[:50]}")

    print(f"  Summary: {len(cdc_candidates)} CDC Candidates, {len(propagation_candidates)} Propagation Candidates.")
    return cdc_candidates, propagation_candidates


# ==============================================================================
# 4. CDC UTILITIES WITH FALLBACK
# ==============================================================================
def get_cdc_window(fqn, checkpoint_key=None):
    """Work out which table versions still need processing.

    Args:
        fqn: Fully qualified table name whose history is inspected.
        checkpoint_key: Key of the checkpoint row in CONTROL_TBL; defaults to
            ``fqn`` when not given.

    Returns: one of
        - ``(start_version, end_version)``: version range to read;
        - ``("FALLBACK", last_timestamp)``: version-based reading is not
          possible (history truncated past the checkpoint, or the check
          failed); ``last_timestamp`` may be ``None``;
        - ``(None, None)``: nothing new since the checkpoint.
    """
    lookup_key = checkpoint_key if checkpoint_key else fqn

    try:
        hist_df = spark.sql(f"DESCRIBE HISTORY {fqn}")
        hist_summary = hist_df.select(
            F.min("version").alias("min_ver"),
            F.max("version").alias("max_ver")
        ).collect()[0]
        min_avail_ver, curr_ver = hist_summary['min_ver'], hist_summary['max_ver']

        row = spark.sql(f"""
            SELECT last_version, last_timestamp FROM {CONTROL_TBL}
            WHERE table_name = '{lookup_key}'
        """).collect()
        last_processed = row[0]['last_version'] if row else None
        last_ts = row[0]['last_timestamp'] if row else None

        if last_processed is None:
            # No version checkpoint yet: start a bounded number of versions back.
            start_ver = max(curr_ver - INITIAL_LOOKBACK_VERSIONS, min_avail_ver, 0)
        elif last_processed < min_avail_ver:
            # The checkpointed version is older than the retained history.
            print(f"      [FALLBACK] History truncated for {fqn.split('.')[-1]} (Key: {lookup_key})")
            return ("FALLBACK", last_ts)
        else:
            start_ver = last_processed + 1

        if start_ver > curr_ver:
            return (None, None)
        return (start_ver, curr_ver)
    except Exception as e:
        print(f"      [FALLBACK] CDF Check failed for {fqn.split('.')[-1]}: {str(e)[:80]}")
        try:
            row = spark.sql(f"SELECT last_timestamp FROM {CONTROL_TBL} WHERE table_name = '{lookup_key}'").collect()
            last_ts = row[0]['last_timestamp'] if row else None
            return ("FALLBACK", last_ts)
        except Exception:
            # Narrowed from a bare ``except`` so KeyboardInterrupt / SystemExit
            # are not converted into a fallback result.
            return ("FALLBACK", None)


def update_checkpoint(fqn, version=None, timestamp=None, checkpoint_key=None):
    """Upsert the checkpoint row for a table in the control table.

    Does nothing when neither ``version`` nor ``timestamp`` is supplied.
    The row key is ``checkpoint_key`` when given, otherwise ``fqn``. On an
    existing row a missing ``version`` leaves ``last_version`` untouched; a
    missing ``timestamp`` records the current time instead.
    """
    nothing_to_record = version is None and timestamp is None
    if nothing_to_record:
        return

    control_key = checkpoint_key or fqn
    version_sql = "NULL" if version is None else version
    ts_sql = "current_timestamp()" if not timestamp else f"timestamp('{timestamp}')"

    merge_stmt = f"""
        MERGE INTO {CONTROL_TBL} t
        USING (SELECT '{control_key}' as n, {version_sql} as v, {ts_sql} as ts) s
        ON t.table_name = s.n
        WHEN MATCHED THEN UPDATE SET
            last_version = COALESCE(s.v, t.last_version),
            last_timestamp = s.ts,
            updated_at = current_timestamp()
        WHEN NOT MATCHED THEN INSERT (table_name, last_version, last_timestamp, updated_at)
            VALUES (s.n, s.v, s.ts, current_timestamp())
    """
    spark.sql(merge_stmt)


def get_fallback_filter(fqn: str, last_ts):
    """Build the SQL WHERE fragment used when reading a table in fallback mode.

    With a usable time column the filter selects rows newer than ``last_ts``,
    or rows from the last FALLBACK_LOOKBACK_HOURS hours when ``last_ts`` is
    empty. Without one, key-source tables get a full scan (``1=1``) and every
    other table gets an always-false filter (``1=0``).
    """
    time_col = pick_time_col(fqn)

    if time_col is not None:
        if last_ts:
            return f"{time_col} > timestamp('{last_ts}')"
        return f"{time_col} >= current_timestamp() - INTERVAL {FALLBACK_LOOKBACK_HOURS} HOURS"

    # No time column: decide between scanning everything and scanning nothing.
    is_key_source = (
        any(entry[1] == fqn for entry in KEY_LOOKUPS)
        or "mill_clinical_event" in fqn.lower()
        or "mill_encounter" in fqn.lower()
    )
    if not is_key_source:
        print(f"      [WARNING] No time column for {fqn}. Skipping fallback scan for safety.")
        return "1=0"

    print(f"      [WARNING] No time column for {fqn} but it is a Key Source. Defaulting to FULL SCAN.")
    return "1=1"


# ==============================================================================
# 5. PHASE 1: MAINTAIN THE HUB
# ==============================================================================
def refresh_hub():
    """Refresh mapping hub with Runtime Fallback logic and Checkpoint Isolation.

    Three sync stages run in order, each with its own checkpoint key
    (``<source fqn>_HUB``) in the control table:
      1. mill_encounter      -> ENC_ORG_TBL (encounter -> organization) and
                                CHANGED_ENC_TBL (encounters touched this run)
      2. each KEY_LOOKUPS row -> HUB_TBL (key -> encounter)
      3. mill_clinical_event -> HUB_TBL (EVENT_ID -> encounter)

    Each stage first tries the version-based ``table_changes`` read; if that is
    unavailable or raises, it re-runs using a timestamp filter from
    get_fallback_filter() and checkpoints by timestamp instead of version.
    Finally, hub and enc-org rows older than LOOKUP_RETENTION_DAYS are deleted.

    Returns:
        set: key column names whose sync was attempted in this run
        (``'EVENT_ID'`` is included when clinical events were synced).
    """
    print("Phase 1: Refreshing Hub...")
    processed_keys = set()

    # Purge change-tracking rows whose run_date is more than one day old
    # (rows from today and yesterday are kept).
    spark.sql(f"DELETE FROM {CHANGED_ENC_TBL} WHERE run_date < current_date() - INTERVAL 1 DAY")

    # ==========================================================================
    # 1. SYNC ENCOUNTERS (Enc -> Org Map)
    # ==========================================================================
    enc_fqn = f"{CATALOG}.{RAW_SCHEMA}.mill_encounter"
    hub_key = f"{enc_fqn}_HUB"

    # sv is a start version, the string "FALLBACK", or None (nothing to do).
    sv, ev = get_cdc_window(enc_fqn, checkpoint_key=hub_key)

    if sv is not None:
        print(f"    [PROC] MILL_ENCOUNTER                : Syncing Hub (v{sv} to v{ev})")

        def run_enc_sync(is_fallback):
            # Build the "changed encounters" SELECT either from a time window
            # (fallback) or from the change feed between versions sv..ev.
            if is_fallback:
                last_ts_row = spark.sql(f"SELECT last_timestamp FROM {CONTROL_TBL} WHERE table_name = '{hub_key}'").collect()
                ts_val = last_ts_row[0]['last_timestamp'] if last_ts_row else None
                time_filter = get_fallback_filter(enc_fqn, ts_val)
                sql = f"""
                    SELECT ENCNTR_ID, ORGANIZATION_ID, current_timestamp() as last_updated
                    FROM {enc_fqn}
                    WHERE {time_filter} AND ENCNTR_ID IS NOT NULL AND ORGANIZATION_ID IS NOT NULL
                """
            else:
                sql = f"""
                    SELECT ENCNTR_ID, ORGANIZATION_ID, current_timestamp() as last_updated
                    FROM table_changes('{enc_fqn}', {sv}, {ev})
                    WHERE _change_type IN ('insert', 'update_postimage')
                      AND ENCNTR_ID IS NOT NULL AND ORGANIZATION_ID IS NOT NULL
                """

            # Collapse to one row per encounter before merging; when an encounter
            # appears with several organizations the largest ORGANIZATION_ID wins.
            spark.sql(f"""
                MERGE INTO {ENC_ORG_TBL} tgt
                USING (
                    SELECT ENCNTR_ID, MAX(ORGANIZATION_ID) as ORGANIZATION_ID, MAX(last_updated) as last_updated
                    FROM ({sql}) GROUP BY ENCNTR_ID
                ) src
                ON tgt.ENCNTR_ID = src.ENCNTR_ID
                WHEN MATCHED THEN UPDATE SET
                    ORGANIZATION_ID = src.ORGANIZATION_ID,
                    last_updated = src.last_updated
                WHEN NOT MATCHED THEN INSERT (ENCNTR_ID, ORGANIZATION_ID, last_updated)
                    VALUES (src.ENCNTR_ID, src.ORGANIZATION_ID, src.last_updated)
            """)

            # Record the touched encounters for this run. LEFT JOIN keeps
            # encounters whose organization has no trust mapping (trust is NULL).
            # NOTE(review): the same source query is evaluated a second time here.
            spark.sql(f"""
                INSERT INTO {CHANGED_ENC_TBL}
                SELECT DISTINCT c.ENCNTR_ID, c.ORGANIZATION_ID, tm.trust, current_timestamp(), current_date()
                FROM ({sql}) c
                LEFT JOIN {TRUST_MAP_TBL} tm ON c.ORGANIZATION_ID = tm.organization_id
            """)

        try:
            # The exception is used as control flow to jump straight to the
            # time-window path when no version range is available.
            if sv == "FALLBACK":
                raise Exception("Force Fallback Mode")
            run_enc_sync(is_fallback=False)
            update_checkpoint(enc_fqn, version=ev, checkpoint_key=hub_key)
        except Exception as e:
            # NOTE(review): this also catches a failure of update_checkpoint()
            # after a successful sync, which would re-run the sync in fallback mode.
            print(f"      [Runtime Fallback] Encounters Sync switched to Time-Window. Reason: {str(e)[:100]}")
            run_enc_sync(is_fallback=True)
            update_checkpoint(enc_fqn, timestamp=datetime.now().isoformat(), checkpoint_key=hub_key)
    else:
        print(f"    [SKIP] MILL_ENCOUNTER                : Hub up to date (Checked v{ev})")

    # ==========================================================================
    # 2. SYNC KEYS (Specific ID -> Enc Map)
    # ==========================================================================
    for key_col, source_fqn, enc_expr in KEY_LOOKUPS:
        table_name = source_fqn.split('.')[-1].upper()
        hub_key = f"{source_fqn}_HUB"

        sv, ev = get_cdc_window(source_fqn, checkpoint_key=hub_key)

        if sv is None:
            print(f"    [SKIP] {table_name:<29} : Hub up to date (Checked v{ev})")
            continue

        # Added before the sync runs, so the key counts as processed even if
        # the sync below ends up raising.
        processed_keys.add(key_col)
        print(f"    [PROC] {table_name:<29} : Syncing Key {key_col} (v{sv} to v{ev})")

        # Defined per iteration; it closes over this iteration's key_col,
        # source_fqn, enc_expr, hub_key, sv and ev and is called before the
        # loop advances.
        def run_key_sync(is_fallback):
            if is_fallback:
                last_ts_row = spark.sql(f"SELECT last_timestamp FROM {CONTROL_TBL} WHERE table_name = '{hub_key}'").collect()
                ts_val = last_ts_row[0]['last_timestamp'] if last_ts_row else None
                time_filter = get_fallback_filter(source_fqn, ts_val)
                changes_view = f"(SELECT * FROM {source_fqn} WHERE {time_filter})"
            else:
                changes_view = f"""(SELECT * FROM table_changes('{source_fqn}', {sv}, {ev})
                                   WHERE _change_type IN ('insert', 'update_postimage'))"""

            if key_col == "SURG_CASE_PROC_ID":
                # No encounter on the procedure row itself: resolve it through
                # the parent surgical case.
                key_enc_sql = f"""
                    SELECT c.SURG_CASE_PROC_ID as key_id, sc.ENCNTR_ID
                    FROM {changes_view} c
                    JOIN {CATALOG}.{RAW_SCHEMA}.mill_surgical_case sc ON c.SURG_CASE_ID = sc.SURG_CASE_ID
                    WHERE sc.ENCNTR_ID IS NOT NULL
                """
            elif key_col == "IM_STUDY_ID":
                # Prefer the study's own ENCNTR_ID when the column exists,
                # otherwise fall back to the linked mill_cv_proc row.
                im_cols = get_table_columns(source_fqn)
                has_direct_enc = "ENCNTR_ID" in im_cols
                enc_select = "COALESCE(c.ENCNTR_ID, cv.ENCNTR_ID)" if has_direct_enc else "cv.ENCNTR_ID"

                key_enc_sql = f"""
                    SELECT c.IM_STUDY_ID as key_id, {enc_select} as ENCNTR_ID
                    FROM {changes_view} c
                    LEFT JOIN {CATALOG}.{RAW_SCHEMA}.mill_cv_proc cv ON c.ORIG_ENTITY_ID = cv.CV_PROC_ID
                    WHERE {enc_select} IS NOT NULL
                """
            elif key_col == "IM_ACQUIRED_STUDY_ID":
                # Acquired study -> matched study -> (study encounter or cv_proc encounter).
                im_fqn = f"{CATALOG}.{RAW_SCHEMA}.mill_im_study"
                im_cols = get_table_columns(im_fqn)
                has_direct_enc = "ENCNTR_ID" in im_cols
                enc_select = "COALESCE(st.ENCNTR_ID, cv.ENCNTR_ID)" if has_direct_enc else "cv.ENCNTR_ID"

                key_enc_sql = f"""
                    SELECT acq.IM_ACQUIRED_STUDY_ID as key_id, {enc_select} as ENCNTR_ID
                    FROM {changes_view} acq
                    JOIN {im_fqn} st ON acq.MATCHED_STUDY_ID = st.IM_STUDY_ID
                    LEFT JOIN {CATALOG}.{RAW_SCHEMA}.mill_cv_proc cv ON st.ORIG_ENTITY_ID = cv.CV_PROC_ID
                    WHERE {enc_select} IS NOT NULL
                """
            else:
                # Generic case: the encounter expression from KEY_LOOKUPS is
                # evaluated directly on the changed rows.
                key_enc_sql = f"""
                    SELECT {key_col} as key_id, {enc_expr} as ENCNTR_ID
                    FROM {changes_view}
                    WHERE {enc_expr} IS NOT NULL
                """

            # One row per key; when a key maps to several encounters the
            # largest ENCNTR_ID wins.
            spark.sql(f"""
                MERGE INTO {HUB_TBL} tgt
                USING (
                    SELECT '{key_col}' as key_type, key_id, MAX(ENCNTR_ID) as ENCNTR_ID,
                           current_timestamp() as last_updated
                    FROM ({key_enc_sql}) GROUP BY key_id
                ) src
                ON tgt.key_type = src.key_type AND tgt.key_id = src.key_id
                WHEN MATCHED THEN UPDATE SET
                    ENCNTR_ID = src.ENCNTR_ID,
                    last_updated = src.last_updated
                WHEN NOT MATCHED THEN INSERT (key_type, key_id, ENCNTR_ID, last_updated)
                    VALUES (src.key_type, src.key_id, src.ENCNTR_ID, src.last_updated)
            """)

        try:
            # Same pattern as the encounter sync: version-based first, then
            # time-window on any failure.
            if sv == "FALLBACK":
                raise Exception("Force Fallback Mode")
            run_key_sync(is_fallback=False)
            update_checkpoint(source_fqn, version=ev, checkpoint_key=hub_key)
        except Exception as e:
            print(f"      [Runtime Fallback] Key Sync ({key_col}) switched to Time-Window. Reason: {str(e)[:100]}")
            run_key_sync(is_fallback=True)
            update_checkpoint(source_fqn, timestamp=datetime.now().isoformat(), checkpoint_key=hub_key)

    # ==========================================================================
    # 3. CLINICAL EVENTS (Event -> Enc Map)
    # ==========================================================================
    ce_fqn = f"{CATALOG}.{RAW_SCHEMA}.mill_clinical_event"
    hub_key = f"{ce_fqn}_HUB"

    sv, ev = get_cdc_window(ce_fqn, checkpoint_key=hub_key)

    if sv is not None:
        print(f"    [PROC] MILL_CLINICAL_EVENT           : Syncing Events (v{sv} to v{ev})")

        def run_ce_sync(is_fallback):
            # Both variants keep only rows that are still valid
            # (VALID_UNTIL_DT_TM is NULL or in the future).
            if is_fallback:
                last_ts_row = spark.sql(f"SELECT last_timestamp FROM {CONTROL_TBL} WHERE table_name = '{hub_key}'").collect()
                ts_val = last_ts_row[0]['last_timestamp'] if last_ts_row else None
                time_filter = get_fallback_filter(ce_fqn, ts_val)
                sql = f"""
                    SELECT EVENT_ID as key_id, ENCNTR_ID
                    FROM {ce_fqn}
                    WHERE {time_filter}
                      AND EVENT_ID IS NOT NULL AND ENCNTR_ID IS NOT NULL
                      AND (VALID_UNTIL_DT_TM IS NULL OR VALID_UNTIL_DT_TM > current_timestamp())
                """
            else:
                sql = f"""
                    SELECT EVENT_ID as key_id, ENCNTR_ID
                    FROM table_changes('{ce_fqn}', {sv}, {ev})
                    WHERE _change_type IN ('insert', 'update_postimage')
                      AND EVENT_ID IS NOT NULL AND ENCNTR_ID IS NOT NULL
                      AND (VALID_UNTIL_DT_TM IS NULL OR VALID_UNTIL_DT_TM > current_timestamp())
                """

            spark.sql(f"""
                MERGE INTO {HUB_TBL} tgt
                USING (
                    SELECT 'EVENT_ID' as key_type, key_id, MAX(ENCNTR_ID) as ENCNTR_ID,
                           current_timestamp() as last_updated
                    FROM ({sql}) GROUP BY key_id
                ) src
                ON tgt.key_type = src.key_type AND tgt.key_id = src.key_id
                WHEN MATCHED THEN UPDATE SET
                    ENCNTR_ID = src.ENCNTR_ID,
                    last_updated = src.last_updated
                WHEN NOT MATCHED THEN INSERT (key_type, key_id, ENCNTR_ID, last_updated)
                    VALUES (src.key_type, src.key_id, src.ENCNTR_ID, src.last_updated)
            """)

        try:
            if sv == "FALLBACK":
                raise Exception("Force Fallback Mode")
            run_ce_sync(is_fallback=False)
            update_checkpoint(ce_fqn, version=ev, checkpoint_key=hub_key)
            processed_keys.add('EVENT_ID')
        except Exception as e:
            print(f"      [Runtime Fallback] Event Sync switched to Time-Window. Reason: {str(e)[:100]}")
            run_ce_sync(is_fallback=True)
            update_checkpoint(ce_fqn, timestamp=datetime.now().isoformat(), checkpoint_key=hub_key)
            # Unlike stage 2, EVENT_ID is only recorded once a sync has succeeded.
            processed_keys.add('EVENT_ID')
    else:
        print(f"    [SKIP] MILL_CLINICAL_EVENT           : Hub up to date (Checked v{ev})")

    # Cleanup Old Cache: drop mappings not refreshed within LOOKUP_RETENTION_DAYS
    spark.sql(f"DELETE FROM {HUB_TBL} WHERE last_updated < current_timestamp() - INTERVAL {LOOKUP_RETENTION_DAYS} DAYS")
    spark.sql(f"DELETE FROM {ENC_ORG_TBL} WHERE last_updated < current_timestamp() - INTERVAL {LOOKUP_RETENTION_DAYS} DAYS")

    return processed_keys


# ==============================================================================
# 6. PHASE 2: UPDATE TARGETS WITH JOIN-BASED RESOLUTION (OPTIMIZED)
# ==============================================================================
def build_trust_resolution_sql_v2(table_name: str, cols: set, pk_col: str, changes_sql: str):
    """
    Assemble the SELECT that resolves encounter, organisation and trust for changed rows.

    The generated statement uses plain LEFT JOINs (no correlated subqueries):
    one join against ``hub_deduped`` per hub-resolved key column, then joins
    against ``enc_trust_lookup`` and ``trust_map``. Those three relations are
    referenced by name only; presumably they are temp views created earlier in
    the run (verify against materialize_hub_views).

    Args:
        table_name: Not used when building the SQL; kept for caller compatibility.
        cols: Column names of the target table (matched as upper-case literals).
        pk_col: Column that identifies a row; projected through every CTE.
        changes_sql: SELECT producing the changed rows (becomes the ``changes`` CTE).

    Returns:
        The SQL text yielding ``pk_col, new_enc_id, new_org_id, new_trust``, or
        ``None`` when the table has neither a linkable key column nor
        ORGANIZATION_ID.
    """
    has_direct_enc = "ENCNTR_ID" in cols
    has_org = "ORGANIZATION_ID" in cols

    # Hub-resolved keys, in COALESCE priority order: EVENT_ID, then configured lookups.
    hub_keys = ["EVENT_ID"] if "EVENT_ID" in cols else []
    hub_keys.extend(k for k, _, _ in KEY_LOOKUPS if k in cols)

    # Nothing links this table to an encounter or an organisation.
    if not (has_direct_enc or hub_keys or has_org):
        return None

    # A direct ENCNTR_ID always wins over any hub-derived encounter.
    enc_sources = ["src.ENCNTR_ID"] if has_direct_enc else []
    hub_joins = []
    for idx, hub_key in enumerate(hub_keys):
        hub_alias = f"h{idx}"
        hub_joins.append(f"""
                LEFT JOIN hub_deduped {hub_alias}
                    ON {hub_alias}.key_type = '{hub_key}' AND {hub_alias}.key_id = src.{hub_key}
            """)
        enc_sources.append(f"{hub_alias}.ENCNTR_ID")

    if enc_sources:
        enc_coalesce = "COALESCE(" + ", ".join(enc_sources) + ")"
    else:
        enc_coalesce = "CAST(NULL AS BIGINT)"
    joins_sql = "\n".join(hub_joins)

    # Organisation: the row's own value (exposed as r.direct_org by the
    # 'resolved' CTE) takes precedence over the encounter's organisation.
    org_sources = []
    if has_org:
        org_sources.append("r.direct_org")
    if enc_sources:
        org_sources.append("etl.ORGANIZATION_ID")

    if org_sources:
        org_coalesce = "COALESCE(" + ", ".join(org_sources) + ")"
    else:
        org_coalesce = "CAST(NULL AS BIGINT)"

    direct_org_expr = "src.ORGANIZATION_ID" if has_org else "CAST(NULL AS BIGINT)"

    return f"""
        WITH changes AS ({changes_sql}),
        resolved AS (
            SELECT
                src.{pk_col},
                {enc_coalesce} as resolved_enc,
                {direct_org_expr} as direct_org
            FROM changes src
            {joins_sql}
        ),
        with_trust AS (
            SELECT
                r.{pk_col},
                COALESCE(r.resolved_enc, etl.ENCNTR_ID) as new_enc_id,
                {org_coalesce} as new_org_id,
                COALESCE(tm_direct.trust, etl.trust) as new_trust
            FROM resolved r
            LEFT JOIN enc_trust_lookup etl ON r.resolved_enc = etl.ENCNTR_ID
            LEFT JOIN trust_map tm_direct ON r.direct_org = tm_direct.organization_id
            WHERE COALESCE(tm_direct.trust, etl.trust) IS NOT NULL
        )
        SELECT {pk_col}, new_enc_id, new_org_id, new_trust FROM with_trust
    """


def audit_changes_aggregated(table_name: str, pk_col: str, trust_counts: dict, *, enabled: bool = False):
    """
    Log an aggregated audit summary (one row per trust) instead of individual rows.

    Summary logging is switched off by default. It used to be disabled by a
    bare ``return`` as the first statement, which left the rest of the body
    unreachable; the switch is now the explicit keyword-only ``enabled`` flag.
    With ``enabled=False`` nothing is written and Spark is not touched.

    Args:
        table_name: Table name stored in each summary row.
        pk_col: Not used by this function; kept for caller compatibility.
        trust_counts: Mapping ``trust -> stats`` where ``stats`` provides the
            keys ``'count'``, ``'min_pk'`` and ``'max_pk'``.
        enabled: Set to ``True`` to actually insert into the summary table.

    Returns:
        None. Insert failures are printed as a warning and not raised.
    """
    if not enabled:
        return
    try:
        for trust, stats in trust_counts.items():
            # Python None must be rendered as the SQL keyword NULL; interpolating
            # it directly yields the identifier `None` (UNRESOLVED_COLUMN error).
            min_pk_val = stats['min_pk'] if stats['min_pk'] is not None else "NULL"
            max_pk_val = stats['max_pk'] if stats['max_pk'] is not None else "NULL"

            spark.sql(f"""
                INSERT INTO {AUDIT_SUMMARY_TBL}
                VALUES (
                    timestamp('{RUN_TS.isoformat()}'),
                    '{table_name}',
                    '{trust}',
                    'UPDATE',
                    {stats['count']},
                    {min_pk_val},
                    {max_pk_val}
                )
            """)
    except Exception as e:
        # Audit logging is best-effort: never fail the pipeline because of it.
        print(f"      [Audit Warning] Could not log summary for {table_name}: {e}")


def execute_merge_deduped_v2(fqn: str, source_sql: str, pk_col: str, can_delete: bool, cols: set, processed_pks: set = None):
    """
    MERGE resolved trust values into ``fqn``, collapsing the source to one row per PK.

    Args:
        fqn: Fully qualified target table; its last dotted part selects the lineage rules.
        source_sql: SELECT yielding ``pk_col, new_enc_id, new_org_id, new_trust``.
        pk_col: Column used both to de-duplicate the source and to match target rows.
        can_delete: Whether rows resolving to 'BHRUT' may be deleted from this table
            (only honoured while FILTER_BHRUT is on).
        cols: Column names of the target table; decides which columns are updated.
        processed_pks: Not used by this function.

    Returns:
        Always an empty set -- no PKs are collected here, so callers cannot
        rely on the result for de-duplication.
    """
    table_key = fqn.rsplit('.', 1)[-1].upper()

    assignments = ["tgt.Trust = src.new_trust"]

    # Lineage guard: never overwrite ENCNTR_ID on source-of-truth tables, nor
    # ORGANIZATION_ID on MILL_ENCOUNTER. Elsewhere only NULL values are filled in.
    if "ENCNTR_ID" in cols and table_key not in SOURCE_OF_TRUTH_TABLES:
        assignments.append("tgt.ENCNTR_ID = COALESCE(tgt.ENCNTR_ID, src.new_enc_id)")
    if "ORGANIZATION_ID" in cols and table_key != "MILL_ENCOUNTER":
        assignments.append("tgt.ORGANIZATION_ID = COALESCE(tgt.ORGANIZATION_ID, src.new_org_id)")
    if "ADC_UPDT" in cols:
        assignments.append("tgt.ADC_UPDT = current_timestamp()")
    update_set = ", ".join(assignments)

    # The DELETE clause precedes the UPDATE clause, so BHRUT rows are removed
    # rather than updated when deletion is allowed.
    if FILTER_BHRUT and can_delete:
        delete_clause = "WHEN MATCHED AND src.new_trust = 'BHRUT' THEN DELETE"
    else:
        delete_clause = ""

    # One source row per PK: MAX() picks a single value when a PK occurs several times.
    deduped_source = f"""
        SELECT {pk_col}, MAX(new_enc_id) as new_enc_id, MAX(new_org_id) as new_org_id,
               MAX(new_trust) as new_trust
        FROM ({source_sql}) GROUP BY {pk_col}
    """

    merge_sql = f"""
        MERGE INTO {fqn} tgt
        USING ({deduped_source}) src
        ON tgt.{pk_col} = src.{pk_col}
        {delete_clause}
        WHEN MATCHED AND (tgt.Trust IS NULL OR tgt.Trust = '' OR tgt.Trust != src.new_trust)
        THEN UPDATE SET {update_set}
    """
    spark.sql(merge_sql)

    return set()


def determine_pk_column(table_name: str, cols: set) -> str:
    """
    Pick the column used as row identifier for ``table_name``.

    Resolution order:
      1. the table-specific key from the built-in mapping, if present in ``cols``;
      2. the first of EVENT_ID, ENCNTR_ID, ORDER_ID, PROBLEM_ID found in ``cols``;
      3. the first KEY_LOOKUPS key column found in ``cols``.

    Returns the column name, or ``None`` when nothing matches.
    """
    known_pks = {
        "MILL_ENCOUNTER": "ENCNTR_ID",
        "MILL_CLINICAL_EVENT": "EVENT_ID",
        "MILL_ORDERS": "ORDER_ID",
        "MILL_PROBLEM": "PROBLEM_ID",
        "MILL_SURGICAL_CASE": "SURG_CASE_ID",
        "MILL_SURG_CASE_PROCEDURE": "SURG_CASE_PROC_ID",
        "MILL_DCP_FORMS_ACTIVITY": "DCP_FORMS_ACTIVITY_ID",
        "MILL_EPISODE_ENCNTR_RELTN": "EPISODE_ID",
        "MILL_SCH_EVENT_PATIENT": "SCH_EVENT_ID",
        "MILL_SCH_SCHEDULE": "SCHEDULE_ID",
        "MILL_IM_STUDY": "IM_STUDY_ID",
        "MILL_IM_ACQUIRED_STUDY": "IM_ACQUIRED_STUDY_ID",
        "MILL_CV_PROC": "CV_PROC_ID",
    }

    # 1. Table-specific key (lookup is case-insensitive on the table name).
    preferred = known_pks.get(table_name.upper())
    if preferred is not None and preferred in cols:
        return preferred

    # 2. Generic identifiers, most specific first.
    generic_candidates = ("EVENT_ID", "ENCNTR_ID", "ORDER_ID", "PROBLEM_ID")
    generic_hit = next((c for c in generic_candidates if c in cols), None)
    if generic_hit is not None:
        return generic_hit

    # 3. Any configured lookup key; None when the table has none of them.
    return next((k for k, _, _ in KEY_LOOKUPS if k in cols), None)


def force_refresh_new_orgs(new_orgs):
    """
    Re-stamp trust on encounters that belong to newly mapped organisations.

    Two steps, both limited to ``new_orgs``:
      1. MERGE the mapped trust into mill_encounter for rows whose Trust is
         missing, empty or different from the mapping.
      2. Append every encounter of those organisations to the changed-encounter
         table (regardless of whether step 1 modified it).

    Does nothing when ``new_orgs`` is empty or None.
    """
    if not new_orgs:
        return

    print(f"    Forcing refresh for {len(new_orgs)} newly mapped organizations...")

    org_list = ",".join(map(str, new_orgs))
    enc_fqn = f"{CATALOG}.{RAW_SCHEMA}.mill_encounter"
    enc_cols = get_table_columns(enc_fqn)

    # Step 1: encounters whose stored trust disagrees with the (new) mapping.
    stale_trust_sql = f"""
        SELECT e.ENCNTR_ID, e.ENCNTR_ID as new_enc_id, e.ORGANIZATION_ID as new_org_id, tm.trust as new_trust
        FROM {enc_fqn} e
        JOIN {TRUST_MAP_TBL} tm ON e.ORGANIZATION_ID = tm.organization_id
        WHERE e.ORGANIZATION_ID IN ({org_list})
          AND (e.Trust IS NULL OR e.Trust = '' OR e.Trust != tm.trust)
    """
    execute_merge_deduped_v2(
        enc_fqn,
        stale_trust_sql,
        "ENCNTR_ID",
        can_delete_bhrut("MILL_ENCOUNTER"),
        enc_cols,
    )

    # Step 2: record the encounters as changed for today's run.
    register_changed_sql = f"""
        INSERT INTO {CHANGED_ENC_TBL}
        SELECT DISTINCT e.ENCNTR_ID, e.ORGANIZATION_ID, tm.trust, current_timestamp(), current_date()
        FROM {enc_fqn} e
        JOIN {TRUST_MAP_TBL} tm ON e.ORGANIZATION_ID = tm.organization_id
        WHERE e.ORGANIZATION_ID IN ({org_list})
    """
    spark.sql(register_changed_sql)


def update_targets(cdc_tables, new_orgs):
    """
    Phase 2: apply trust resolution to each CDC-enabled table from its own changes.

    For every table the change window comes from ``get_cdc_window``:
      * start version ``None``       -> table is skipped as up to date;
      * start version ``"FALLBACK"`` -> the time-window path is used directly;
      * otherwise the Delta change feed between the two versions is read, and
        any exception on that path triggers one retry via the time-window path.
    An exception raised by the time-window retry itself is not caught here.

    Args:
        cdc_tables: Mapping ``table_name -> set of column names``.
        new_orgs: Newly mapped organisation ids, forwarded to
            ``force_refresh_new_orgs`` before the loop.

    Returns:
        dict ``{table_name: value returned by the merge step}`` for every table
        that was processed (skipped tables are absent).
    """
    print("Phase 2: Updating Target Tables (Direct CDC)...")
    force_refresh_new_orgs(new_orgs)

    results = {}

    for table_name, cols in cdc_tables.items():
        fqn = f"{CATALOG}.{RAW_SCHEMA}.{table_name}"
        sv, ev = get_cdc_window(fqn)

        if sv is None:
            log_status(table_name, "SKIP", f"Up to date (Checked v{ev})")
            continue

        pk_col = determine_pk_column(table_name, cols)
        if not pk_col:
            log_status(table_name, "SKIP", "Could not determine Primary Key")
            continue

        log_status(table_name, "PROC", f"Processing updates (v{sv} to v{ev})")

        def changes_query(use_time_window):
            """Return the SELECT that yields this table's changed rows."""
            if not use_time_window:
                # Version-based: inserted rows and post-update images from the change feed.
                return f"""
                    SELECT * FROM table_changes('{fqn}', {sv}, {ev})
                    WHERE _change_type IN ('insert', 'update_postimage')
                """
            # Time-based: filter the table itself from the last stored checkpoint timestamp.
            checkpoint_rows = spark.sql(
                f"SELECT last_timestamp FROM {CONTROL_TBL} WHERE table_name = '{fqn}'"
            ).collect()
            since = checkpoint_rows[0]['last_timestamp'] if checkpoint_rows else None
            window_filter = get_fallback_filter(fqn, since)
            return f"SELECT * FROM {fqn} WHERE {window_filter}"

        def apply_changes(use_time_window):
            """Resolve trust for the changed rows and merge it into the table."""
            merge_source = build_trust_resolution_sql_v2(
                table_name, cols, pk_col, changes_query(use_time_window)
            )
            if not merge_source:
                print("      [Info] No resolution logic available (Missing link columns)")
                return set()
            deletable = can_delete_bhrut(table_name.upper())
            return execute_merge_deduped_v2(
                fqn, merge_source, pk_col, deletable, cols, processed_pks=set()
            )

        try:
            if sv == "FALLBACK":
                # Reuse the except-branch below as the single fallback implementation.
                raise Exception("Force Fallback Mode")
            results[table_name] = apply_changes(False)
            update_checkpoint(fqn, version=ev)
        except Exception as exc:
            print(f"      [Runtime Fallback] {table_name} Update switched to Time-Window. Reason: {str(exc)[:200]}")
            results[table_name] = apply_changes(True)
            update_checkpoint(fqn, timestamp=datetime.now().isoformat())

    return results


# ==============================================================================
# 7. PHASE 2b: CHAINED DEPENDENCIES
# ==============================================================================
def propagate_chained_updates(processed_keys, all_targets):
    """
    Phase 2b: push trust to child tables that hang off a refreshed parent key.

    For each CHAINED_DEPENDENCIES entry whose parent key is in
    ``processed_keys``, the child table is joined to ``hub_deduped`` (on the
    parent key), ``enc_org_deduped`` and ``trust_map``, and rows whose Trust is
    missing, empty or different are merged. Child tables without a TRUST
    column are skipped. A failure on one child is printed and does not stop
    the others.

    Args:
        processed_keys: Key types refreshed earlier in the run.
        all_targets: Not used by this function.
    """
    print("Phase 2b: Propagating Chained Updates...")

    for parent_key, (child_table, join_col, child_pk) in CHAINED_DEPENDENCIES.items():
        # Only parents whose hub entries were refreshed need propagating.
        if parent_key not in processed_keys:
            continue

        child_fqn = f"{CATALOG}.{RAW_SCHEMA}.{child_table}"
        child_cols = get_table_columns(child_fqn)

        # The MERGE writes Trust, so a child without that column cannot be updated.
        if "TRUST" not in child_cols:
            print(f"    Skipping {child_table} - no TRUST column")
            continue

        print(f"    Triggering updates for {child_table} via {parent_key}")

        stale_children_sql = f"""
            SELECT child.{child_pk}, h.ENCNTR_ID as new_enc_id,
                   eo.ORGANIZATION_ID as new_org_id, tm.trust as new_trust
            FROM {child_fqn} child
            JOIN hub_deduped h ON h.key_type = '{parent_key}' AND child.{join_col} = h.key_id
            JOIN enc_org_deduped eo ON h.ENCNTR_ID = eo.ENCNTR_ID
            JOIN trust_map tm ON eo.ORGANIZATION_ID = tm.organization_id
            WHERE child.Trust IS NULL OR child.Trust = '' OR child.Trust != tm.trust
        """
        try:
            execute_merge_deduped_v2(
                child_fqn,
                stale_children_sql,
                child_pk,
                can_delete_bhrut(child_table.upper()),
                child_cols,
            )
        except Exception as exc:
            print(f"      Error: {str(exc)[:150]}")


# ==============================================================================
# 8. PHASE 2c: REVERSE PROPAGATION WITH VERSION-BASED SKIP OPTIMIZATION
# ==============================================================================
def _phase2c_exclusion_clause(table_name, pk_col, already_processed, inline_limit=10000):
    """
    Build the predicate that skips PKs already handled in Phase 2 (table alias ``t``).

    Small sets are inlined as a literal ``NOT IN`` list. Sets with
    ``inline_limit`` or more entries are registered as the temp view
    ``processed_<table_name>`` and excluded through a subquery, so the SQL text
    stays small. Returns an empty string when there is nothing to exclude.
    """
    if not already_processed:
        return ""
    if len(already_processed) < inline_limit:
        pk_list = ",".join(str(pk) for pk in already_processed)
        return f"AND t.{pk_col} NOT IN ({pk_list})"
    view_name = f"processed_{table_name}"
    spark.createDataFrame(
        [(pk,) for pk in already_processed],
        [pk_col]
    ).createOrReplaceTempView(view_name)
    return f"AND t.{pk_col} NOT IN (SELECT {pk_col} FROM {view_name})"


def propagate_encounter_changes(all_targets, processed_tables):
    """
    Phase 2c: propagate today's changed encounters to downstream tables.

    Steps:
      1. Count today's rows in the changed-encounter table; stop if there are none
         (a failure to count is reported and treated as "none").
      2. Expose today's changed encounters as the temp view ``changed_enc_cached``.
      3. For every target except ``mill_encounter``: skip it when
         ``should_skip_phase2c`` says so, otherwise merge the new trust either
         directly via ENCNTR_ID or, for tables without that column, via the
         hub keys (KEY_LOOKUPS columns and EVENT_ID).
      4. After a successful merge, record the versions with ``update_phase2c_tracker``.

    Errors on a single table are printed and do not stop the loop.

    Args:
        all_targets: Mapping ``table_name -> set of column names``.
        processed_tables: Mapping ``table_name -> PKs already handled in Phase 2``;
            those PKs are excluded from the merge source.
    """
    print("Phase 2c: Propagating Encounter Changes to Downstream Tables...")

    try:
        changed_count = spark.sql(f"""
            SELECT COUNT(DISTINCT ENCNTR_ID) as cnt
            FROM {CHANGED_ENC_TBL}
            WHERE run_date = current_date()
        """).collect()[0]['cnt']
    except Exception as e:
        # Best-effort: an unreadable table means there is nothing to propagate,
        # but say so instead of hiding the reason.
        print(f"    [Warning] Could not count changed encounters: {str(e)[:150]}")
        changed_count = 0

    if changed_count == 0:
        print("    No encounter changes to propagate")
        return

    print(f"    Found {changed_count} changed encounters. Propagating to {len(all_targets)} potential tables.")

    # Get current hub versions for skip optimization
    current_hub_ver, current_enc_org_ver = get_hub_versions()
    print(f"    Current versions - Hub: v{current_hub_ver}, Enc-Org: v{current_enc_org_ver}")

    # Track statistics for skip optimization
    skipped_count = 0
    processed_count = 0

    # Today's changed encounters that carry a trust value
    spark.sql(f"""
        CREATE OR REPLACE TEMP VIEW changed_enc_broadcast AS
        SELECT /*+ BROADCAST(ce) */
            ce.ENCNTR_ID, ce.ORGANIZATION_ID, ce.trust as new_trust
        FROM {CHANGED_ENC_TBL} ce
        WHERE ce.run_date = current_date() AND ce.trust IS NOT NULL
    """)

    # De-duplicated alias used by the per-table queries below
    spark.sql("CREATE OR REPLACE TEMP VIEW changed_enc_cached AS SELECT DISTINCT ENCNTR_ID, ORGANIZATION_ID, new_trust FROM changed_enc_broadcast")

    for table_name, cols in all_targets.items():
        # mill_encounter is the origin of these changes, not a downstream table
        if table_name == "mill_encounter":
            continue

        fqn = f"{CATALOG}.{RAW_SCHEMA}.{table_name}"
        pk_col = determine_pk_column(table_name, cols)
        if not pk_col:
            continue

        # Version-based skip optimization
        should_skip, skip_reason = should_skip_phase2c(
            table_name, fqn, current_hub_ver, current_enc_org_ver
        )

        if should_skip:
            log_status(table_name, "SKIP", f"Version check: {skip_reason}")
            skipped_count += 1
            continue

        # Get PKs already processed in Phase 2 for this table
        already_processed = processed_tables.get(table_name, set())

        # Track if we actually updated anything (for version tracking)
        table_was_updated = False

        # Direct Link via ENCNTR_ID
        if "ENCNTR_ID" in cols:
            # should_skip is always False here, so the reason is printed as-is
            print(f"      Propagating to {table_name} (via ENCNTR_ID) - {skip_reason}")
            try:
                exclusion_clause = _phase2c_exclusion_clause(table_name, pk_col, already_processed)

                source_sql = f"""
                    SELECT /*+ BROADCAST(ce) */
                        t.{pk_col}, ce.ENCNTR_ID as new_enc_id,
                        ce.ORGANIZATION_ID as new_org_id, ce.new_trust
                    FROM {fqn} t
                    JOIN changed_enc_cached ce ON t.ENCNTR_ID = ce.ENCNTR_ID
                    WHERE (t.Trust IS NULL OR t.Trust = '' OR t.Trust != ce.new_trust)
                    {exclusion_clause}
                """
                can_delete = can_delete_bhrut(table_name.upper())
                execute_merge_deduped_v2(fqn, source_sql, pk_col, can_delete, cols)
                table_was_updated = True
            except Exception as e:
                print(f"        Error: {str(e)[:150]}")

        # Key Link (if not direct ENCNTR_ID)
        else:
            key_cols_present = [k for k, _, _ in KEY_LOOKUPS if k in cols]
            if "EVENT_ID" in cols:
                key_cols_present.append("EVENT_ID")

            if not key_cols_present:
                continue

            print(f"      Propagating to {table_name} (via {', '.join(key_cols_present)})")
            try:
                # Build JOIN-based encounter resolution: one hub join per key column
                join_clauses = []
                enc_parts = []
                for idx, k in enumerate(key_cols_present):
                    alias = f"hk{idx}"
                    join_clauses.append(f"""
                        LEFT JOIN hub_deduped {alias}
                            ON {alias}.key_type = '{k}' AND {alias}.key_id = t.{k}
                    """)
                    enc_parts.append(f"{alias}.ENCNTR_ID")

                enc_coalesce = f"COALESCE({', '.join(enc_parts)})"
                joins_sql = "\n".join(join_clauses)

                # Same exclusion rule as the direct branch; large PK sets are
                # excluded through a temp view instead of being silently ignored.
                exclusion_clause = _phase2c_exclusion_clause(table_name, pk_col, already_processed)

                source_sql = f"""
                    WITH resolved AS (
                        SELECT t.{pk_col}, {enc_coalesce} as resolved_enc
                        FROM {fqn} t
                        {joins_sql}
                        WHERE (t.Trust IS NULL OR t.Trust = '' OR t.Trust NOT IN ('Barts', 'BHRUT'))
                        {exclusion_clause}
                    )
                    SELECT /*+ BROADCAST(ce) */
                        r.{pk_col}, ce.ENCNTR_ID as new_enc_id,
                        ce.ORGANIZATION_ID as new_org_id, ce.new_trust
                    FROM resolved r
                    JOIN changed_enc_cached ce ON r.resolved_enc = ce.ENCNTR_ID
                """
                can_delete = can_delete_bhrut(table_name.upper())
                execute_merge_deduped_v2(fqn, source_sql, pk_col, can_delete, cols)
                table_was_updated = True
            except Exception as e:
                print(f"        Error: {str(e)[:150]}")

        # Update version tracker after a successful merge
        if table_was_updated:
            try:
                update_phase2c_tracker(table_name, fqn, current_hub_ver, current_enc_org_ver)
            except Exception as e:
                print(f"        [Tracker Warning] Could not update version tracker: {str(e)[:50]}")

        # Counts every table that was attempted, including those whose merge failed
        processed_count += 1

    print(f"    Phase 2c Summary: {processed_count} tables processed, {skipped_count} tables skipped (version unchanged)")


# ==============================================================================
# 9. LOGGING & ALERTS
# ==============================================================================
def flag_unknown_organizations():
    """
    Phase 3: record organisation ids seen on encounters that have no trust mapping.

    Inserts one flag row per unmapped ORGANIZATION_ID found in mill_encounter
    (within the fallback time window) unless the same id was already flagged in
    the last 7 days, then prints how many distinct ids were flagged during the
    last day. Any failure is reported as a warning and never raised.
    """
    print("Phase 3: Flagging Unknown Organizations...")
    try:
        encounter_tbl = f"{CATALOG}.{RAW_SCHEMA}.mill_encounter"

        # "1=0" would select nothing; in that case scan the whole table instead.
        # NOTE(review): presumably "1=0" is the helper's "no window available" sentinel -- confirm.
        window_filter = get_fallback_filter(encounter_tbl, None)
        window_filter = "1=1" if window_filter == "1=0" else window_filter

        insert_flags_sql = f"""
            INSERT INTO {FLAG_TBL}
            SELECT DISTINCT
                e.ORGANIZATION_ID,
                1 AS alert,
                current_timestamp() AS event_time,
                'MILL_ENCOUNTER' AS table_name,
                current_timestamp() AS first_seen_timestamp
            FROM {encounter_tbl} e
            WHERE {window_filter}
              AND e.ORGANIZATION_ID IS NOT NULL
              AND NOT EXISTS (SELECT 1 FROM {TRUST_MAP_TBL} m WHERE m.organization_id = e.ORGANIZATION_ID)
              AND NOT EXISTS (
                  SELECT 1 FROM {FLAG_TBL} f
                  WHERE f.organization_id = e.ORGANIZATION_ID
                    AND f.event_time >= current_timestamp() - INTERVAL 7 DAYS
              )
        """
        spark.sql(insert_flags_sql)

        # Counts every id flagged in the last day, not only those inserted just now.
        recent_flags = spark.sql(f"""
            SELECT COUNT(DISTINCT organization_id) as cnt
            FROM {FLAG_TBL}
            WHERE event_time >= current_timestamp() - INTERVAL 1 DAY
        """).collect()
        flagged_total = recent_flags[0]['cnt']
        if flagged_total > 0:
            print(f"    Flagged {flagged_total} unknown organization IDs")
    except Exception as exc:
        print(f"    Warning: Org Flagging failed: {str(exc)[:150]}")


# ==============================================================================
# 10. WEEKLY DEEP CLEAN (OPTIMIZED)
# ==============================================================================
def run_weekly_clean(all_targets):
    """
    Phase 4: weekly deep clean of every target table using JOIN-based trust resolution.

    Per table (errors are printed and the loop continues):
      1. If BHRUT filtering is on and deletion is allowed for the table, delete
         rows already marked 'BHRUT' whose resolved trust is also 'BHRUT'.
      2. Fill in Trust where it is NULL/empty and a trust can be resolved; for
         tables with ADC_UPDT only rows touched in the last 90 days are considered.
    Afterwards the hub and encounter-org lookup tables are optimised and
    vacuumed, and the Phase 2c version tracker is emptied so the next run
    re-evaluates every table.

    Args:
        all_targets: Mapping ``table_name -> set of column names``.
    """
    print("Phase 4: Running Weekly Deep Clean...")

    def build_join_clauses(cols, table_alias):
        """
        Return ``(join_clauses, trust_parts)`` for ``table_alias``.

        ``trust_parts`` lists the trust expressions in priority order: direct
        organisation mapping, direct encounter, EVENT_ID via hub, then every
        KEY_LOOKUPS key via hub.
        """
        join_clauses = []
        trust_parts = []

        if "ORGANIZATION_ID" in cols:
            join_clauses.append(f"LEFT JOIN trust_map tm_direct ON {table_alias}.ORGANIZATION_ID = tm_direct.organization_id")
            trust_parts.append("tm_direct.trust")

        if "ENCNTR_ID" in cols:
            join_clauses.append(f"LEFT JOIN enc_trust_lookup etl_direct ON {table_alias}.ENCNTR_ID = etl_direct.ENCNTR_ID")
            trust_parts.append("etl_direct.trust")

        if "EVENT_ID" in cols:
            join_clauses.append(f"""
                LEFT JOIN hub_deduped h_event ON h_event.key_type = 'EVENT_ID' AND h_event.key_id = {table_alias}.EVENT_ID
                LEFT JOIN enc_trust_lookup etl_event ON h_event.ENCNTR_ID = etl_event.ENCNTR_ID
            """)
            trust_parts.append("etl_event.trust")

        for k, _, _ in KEY_LOOKUPS:
            if k in cols:
                safe_k = k.lower()
                join_clauses.append(f"""
                    LEFT JOIN hub_deduped h_{safe_k} ON h_{safe_k}.key_type = '{k}' AND h_{safe_k}.key_id = {table_alias}.{k}
                    LEFT JOIN enc_trust_lookup etl_{safe_k} ON h_{safe_k}.ENCNTR_ID = etl_{safe_k}.ENCNTR_ID
                """)
                trust_parts.append(f"etl_{safe_k}.trust")

        return join_clauses, trust_parts

    for table_name, cols in all_targets.items():
        fqn = f"{CATALOG}.{RAW_SCHEMA}.{table_name}"
        pk_col = determine_pk_column(table_name, cols)

        if not pk_col:
            print(f"    Skipping {table_name} - no PK column found")
            continue

        print(f"    Deep cleaning {table_name}...")

        try:
            # Build JOIN clauses using 't' as alias for main queries
            join_clauses, trust_parts = build_join_clauses(cols, "t")

            # No column links this table to a trust; nothing to clean
            if not trust_parts:
                continue

            trust_coalesce = f"COALESCE({', '.join(trust_parts)})"
            joins_sql = "\n".join(join_clauses)

            has_adc = "ADC_UPDT" in cols
            time_filter = "t.ADC_UPDT >= current_date() - INTERVAL 90 DAYS" if has_adc else "1=1"
            adc_update = ", tgt.ADC_UPDT = current_timestamp()" if has_adc else ""

            # 1. BHRUT deletion: collect the PKs in a temp view, then delete by IN-subquery
            if FILTER_BHRUT and can_delete_bhrut(table_name.upper()):
                temp_view_name = f"bhrut_delete_pks_{table_name.replace('.', '_')}"
                spark.sql(f"""
                    CREATE OR REPLACE TEMP VIEW {temp_view_name} AS
                    SELECT t.{pk_col}
                    FROM {fqn} t
                    {joins_sql}
                    WHERE t.Trust = 'BHRUT'
                      AND {trust_coalesce} = 'BHRUT'
                """)

                spark.sql(f"""
                    DELETE FROM {fqn}
                    WHERE {pk_col} IN (SELECT {pk_col} FROM {temp_view_name})
                """)

            # 2. Update missing Trust. The source is collapsed to one row per PK
            #    (as execute_merge_deduped_v2 does): Delta rejects a MERGE when
            #    several source rows match the same target row, which happens as
            #    soon as pk_col is not unique or a join fans out.
            spark.sql(f"""
                MERGE INTO {fqn} tgt
                USING (
                    SELECT t.{pk_col}, MAX({trust_coalesce}) as resolved_trust
                    FROM {fqn} t
                    {joins_sql}
                    WHERE (t.Trust IS NULL OR t.Trust = '')
                      AND {time_filter}
                      AND {trust_coalesce} IS NOT NULL
                    GROUP BY t.{pk_col}
                ) src
                ON tgt.{pk_col} = src.{pk_col}
                WHEN MATCHED THEN UPDATE SET tgt.Trust = src.resolved_trust {adc_update}
            """)

        except Exception as e:
            print(f"      Error: {str(e)[:150]}")

    print("    Optimizing lookup tables...")
    try:
        spark.sql(f"OPTIMIZE {HUB_TBL} ZORDER BY (key_type, key_id)")
        spark.sql(f"OPTIMIZE {ENC_ORG_TBL} ZORDER BY (ENCNTR_ID)")
        spark.sql(f"VACUUM {HUB_TBL} RETAIN 168 HOURS")
        spark.sql(f"VACUUM {ENC_ORG_TBL} RETAIN 168 HOURS")
    except Exception as e:
        print(f"      Optimization warning: {str(e)[:100]}")

    # Clear version tracker on weekly clean to force full re-evaluation next run
    try:
        spark.sql(f"DELETE FROM {PHASE2C_TRACKER_TBL}")
        print("    Cleared Phase 2c version tracker for fresh start")
    except Exception as e:
        # Still best-effort, but a tracker that could not be cleared means stale
        # skip decisions next run, so make the failure visible.
        print(f"      Tracker reset warning: {str(e)[:100]}")

# ==============================================================================
# 11. SUMMARY REPORTING
# ==============================================================================
def print_summary(start_time):
    """
    Print the end-of-run report.

    Three independent, best-effort sections are queried: today's changed
    encounters, the aggregated audit summary for this run, and the number of
    rows in the Phase 2c tracker. A section that cannot be read is reported
    with a one-line warning and never aborts the summary. The run mode and the
    total duration are always printed.

    Args:
        start_time: ``time.time()`` value captured when the run started.
    """
    print("\n" + "=" * 60 + "\nRUN SUMMARY\n" + "=" * 60)

    try:
        enc_stats = spark.sql(f"""
            SELECT COUNT(DISTINCT ENCNTR_ID) as encounters_changed,
                   COUNT(DISTINCT ORGANIZATION_ID) as orgs_involved
            FROM {CHANGED_ENC_TBL}
            WHERE run_date = current_date()
        """).collect()[0]
        print(f"Encounters Changed: {enc_stats['encounters_changed']}")
        print(f"Organizations Involved: {enc_stats['orgs_involved']}")
    except Exception as e:
        # Catch Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
        print(f"  [Summary Warning] Encounter stats unavailable: {str(e)[:100]}")

    # Report from aggregated audit summary
    try:
        audit_stats = spark.sql(f"""
            SELECT trust, SUM(record_count) as total_records, COUNT(DISTINCT table_name) as tables_affected
            FROM {AUDIT_SUMMARY_TBL}
            WHERE run_timestamp >= timestamp('{RUN_TS.isoformat()}')
            GROUP BY trust
        """).collect()
        print("\nAudit Summary (Aggregated):")
        for r in audit_stats:
            print(f"  {r['trust']}: {r['total_records']} records across {r['tables_affected']} tables")
    except Exception as e:
        print(f"  [Summary Warning] Audit summary unavailable: {str(e)[:100]}")

    # Report Phase 2c skip statistics
    try:
        tracker_count = spark.sql(f"SELECT COUNT(*) as cnt FROM {PHASE2C_TRACKER_TBL}").collect()[0]['cnt']
        print(f"\nPhase 2c Tracker: {tracker_count} tables tracked for version-based skipping")
    except Exception as e:
        print(f"  [Summary Warning] Phase 2c tracker unavailable: {str(e)[:100]}")

    duration = time.time() - start_time
    print(f"\nMode: {'WEEKLY DEEP CLEAN' if IS_WEEKLY_RUN else 'INCREMENTAL'}")
    print(f"Total Duration: {duration:.2f}s")
    print("=" * 60)


# ==============================================================================
# MAIN EXECUTION
# ==============================================================================
if __name__ == "__main__":
    # Pipeline entry point. The phases below run strictly in this order; each
    # one is wrapped in time_op with a label.
    # NOTE(review): time_op presumably times the callable and returns its
    # result (its return value is used below) -- confirm against its definition.
    start_time = time.time()
    print("=" * 60 + f"\nTrust Assignment Pipeline v3 - {RUN_TS.strftime('%Y-%m-%d %H:%M:%S')}\n" + "=" * 60)

    time_op("Setup Infrastructure", ensure_setup)

    # 1. Discover tables; stop cleanly (exit code 0) when there is nothing to do
    cdc_tables, all_targets = time_op("Discover Table Scopes", get_table_scopes)
    if not all_targets:
        print("No eligible tables found.")
        raise SystemExit(0)

    # 2. Track new orgs; the result feeds the forced refresh in Phase 2
    newly_mapped_orgs = time_op("Track Newly Mapped Orgs", track_newly_mapped_orgs)

    # 3. Refresh Hub (Phase 1); returns the key types that were refreshed
    mod_keys = time_op("Refresh Mapping Hub", refresh_hub)

    # 4. Materialize hub views for efficient joins; must precede every phase
    #    that references hub_deduped / enc_trust_lookup / trust_map
    time_op("Materialize Hub Views", materialize_hub_views)

    # 5. Update Targets - Direct CDC (Phase 2); returns per-table processed PKs
    processed_tables = time_op("Update Targets (Direct CDC)", lambda: update_targets(cdc_tables, newly_mapped_orgs))

    # 6. Chained Dependencies (Phase 2b); `or set()` guards against a None result
    time_op("Propagate Chained Updates", lambda: propagate_chained_updates(mod_keys or set(), all_targets))

    # 7. Reverse Propagation (Phase 2c) with version-based skip optimization;
    #    `or {}` guards against a None result from Phase 2
    time_op("Propagate Encounter Changes", lambda: propagate_encounter_changes(all_targets, processed_tables or {}))

    # 8. Flag Unknown Orgs (Phase 3)
    time_op("Flag Unknown Organizations", flag_unknown_organizations)

    # 9. Weekly Deep Clean (Phase 4); only when the run started on a Sunday
    if IS_WEEKLY_RUN:
        time_op("Weekly Deep Clean", lambda: run_weekly_clean(all_targets))

    print_summary(start_time)