# In [0]:
import re
import calendar
from datetime import datetime
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, when, lit, current_timestamp, row_number,
    udf, struct, collect_list, collect_set, create_map, expr, 
    max as spark_max, array, coalesce
)
from pyspark.sql.window import Window
from pyspark.sql.types import StringType, StructType, StructField, LongType, ArrayType
import os

def is_valid_nhs_number(nhs_number):
    """Return True when ``nhs_number`` passes the modulus-11 check-digit test.

    Whitespace and dashes between digits are removed before validation.
    Non-string input, or input that does not reduce to exactly ten decimal
    digits, is rejected.

    Args:
        nhs_number: Candidate NHS number as text, e.g. ``"943 476 5919"``.

    Returns:
        bool: True if the tenth digit equals the check digit computed from
        the first nine digits, otherwise False.
    """
    if not isinstance(nhs_number, str):
        return False

    # Remove spaces and dashes
    nhs_digits = re.sub(r'[\s-]', '', nhs_number)

    # isdecimal() rather than isdigit(): isdigit() also accepts characters
    # such as superscript "²" that int() cannot parse, which would raise
    # ValueError below instead of returning False.
    if len(nhs_digits) != 10 or not nhs_digits.isdecimal():
        return False

    # Weights 10..2 are applied to the first nine digits; zip() stops after
    # the nine weights, so the tenth (check) digit is not included.
    weights = range(10, 1, -1)
    total = sum(int(digit) * weight for digit, weight in zip(nhs_digits, weights))
    check_digit = 11 - (total % 11)

    if check_digit == 11:
        check_digit = 0
    elif check_digit == 10:
        # A computed check digit of 10 can never equal a single digit.
        return False

    return check_digit == int(nhs_digits[9])

def _build_dob_patterns(dob_dt):
    """Return the regex strings used to spot one date of birth in free text.

    Args:
        dob_dt: Object exposing ``day``, ``month`` and ``year`` attributes,
            or None.

    Returns:
        list[str]: Patterns in a fixed order: numeric day-first and
        year-first layouts (with and without separators), then layouts that
        spell the month by its ``calendar`` full or abbreviated name with an
        optional ordinal suffix on the day. Empty when ``dob_dt`` is None.
    """
    if dob_dt is None:
        return []

    # Each date component in plain and zero-padded form.
    d, dd = str(dob_dt.day), f"{dob_dt.day:02d}"
    m, mm = str(dob_dt.month), f"{dob_dt.month:02d}"
    yyyy, yy = str(dob_dt.year), f"{dob_dt.year % 100:02d}"

    name_options = (calendar.month_name[dob_dt.month], calendar.month_abbr[dob_dt.month])
    months_regex = "(?:" + "|".join(re.escape(option) for option in name_options) + ")"

    sep = r"[.\-/\s]"
    ord_suffix = r"(?:st|nd|rd|th)?"

    gap = sep + "+"               # one or more separator characters
    of_gap = r"\s+(?:of\s+)?"     # whitespace, optionally followed by "of"
    loose = r"[,\s\-]+"           # commas, whitespace or dashes

    def numeric(*parts, joiner=sep):
        # Digits-only layout that must not touch neighbouring digits.
        return r"(?<!\d)" + joiner.join(parts) + r"(?!\d)"

    def day_first(day_txt, lead_gap, tail_gap, year_txt):
        # "<day>[suffix] <month name> <year>"
        return rf"\b{day_txt}{ord_suffix}{lead_gap}{months_regex}{tail_gap}{year_txt}\b"

    def month_first(day_txt, any_gap, year_txt):
        # "<month name> <day>[suffix] <year>"
        return rf"\b{months_regex}{any_gap}{day_txt}{ord_suffix}{any_gap}{year_txt}\b"

    patterns = [
        numeric(dd, mm, yyyy),
        numeric(d, m, yyyy),
        numeric(dd, mm, yy),
        numeric(d, m, yy),
        numeric(yyyy, mm, dd),
        numeric(yyyy, m, d),
        numeric(dd, mm, yyyy, joiner=""),
        numeric(yyyy, mm, dd, joiner=""),
        numeric(dd, mm, yy, joiner=""),
    ]

    patterns.extend([
        day_first(d, gap, gap, yyyy),
        day_first(dd, gap, gap, yyyy),
        day_first(d, of_gap, gap, yyyy),
        day_first(dd, of_gap, gap, yyyy),
        month_first(d, gap, yyyy),
        month_first(dd, gap, yyyy),
        day_first(d, gap, gap, yy),
        day_first(dd, gap, gap, yy),
        month_first(d, gap, yy),
        month_first(dd, gap, yy),
    ])

    patterns.extend([
        day_first(d, of_gap, loose, yyyy),
        month_first(d, loose, yyyy),
        day_first(d, of_gap, loose, yy),
        month_first(d, loose, yy),
    ])

    return patterns

def _redact_dob(text, dob_dt):
    """Swap every recognised rendering of ``dob_dt`` in ``text`` for a placeholder.

    Args:
        text: Text to scan.
        dob_dt: Date of birth passed to ``_build_dob_patterns``; a falsy
            value leaves ``text`` untouched.

    Returns:
        The text with matches replaced by ``[[DATE OF BIRTH]]``. A clock time
        directly after a placeholder (``H:MM`` or ``H:MM:SS``) is removed too.
    """
    if dob_dt:
        for dob_regex in _build_dob_patterns(dob_dt):
            text = re.compile(dob_regex, re.IGNORECASE).sub("[[DATE OF BIRTH]]", text)

        # Strip a time of day that trails a redacted date.
        text = re.sub(r"(\[\[DATE OF BIRTH\]\])\s+\d{1,2}:\d{2}(?::\d{2})?", r"\1", text)

    return text

def simple_phi_redaction(text, first_names=None, middle_names=None, last_names=None,
                        dob=None, addresses=None, aliases=None, whitelist=None):
    """
    Redact patient identifiers from free text using regular expressions only.

    Replacements are applied in this order: checksum-valid NHS numbers,
    aliases, first names, middle names, surnames, date of birth, then address
    components. Matching is case-insensitive and only whole terms are
    replaced (a term is not matched inside a longer word).

    Args:
        text: Text to redact. None or '' is returned unchanged.
        first_names, middle_names, last_names: Iterables of name variants.
        dob: Date of birth handed to ``_redact_dob``; skipped when falsy.
        addresses: Iterable of dicts, or objects with attributes, exposing
            STREET_ADDR..STREET_ADDR4, CITY, COUNTY, STATE, COUNTRY, ZIPCODE
            and POSTAL_IDENTIFIER.
        aliases: Iterable of patient identifiers.
        whitelist: Words that must never be redacted. A term is skipped when
            the whole term equals a whitelisted word (case-insensitive).

    Returns:
        The redacted text.

    Notes:
        Terms of two characters or fewer (after trimming) are ignored.
    """
    if text is None or text == '':
        return text

    excluded = {str(word).strip().lower() for word in (whitelist or []) if word}

    def replace_nhs_number(match):
        # Only ten-digit runs that pass the checksum are treated as NHS numbers.
        raw = match.group()
        return "[[NHS Number]]" if is_valid_nhs_number(raw) else raw

    text = re.sub(r'(?<!\d)(?:\d[ -]?){9}\d(?!\d)', replace_nhs_number, text)

    def redact_term(current, term, placeholder):
        """Replace whole-term occurrences of ``term`` unless short or whitelisted."""
        if not term:
            return current
        needle = str(term).strip()
        if len(needle) <= 2 or needle.lower() in excluded:
            return current
        # Lookarounds instead of \b: \b cannot match next to punctuation, so a
        # term such as "12 High St." would never be found with \b anchors.
        pattern = r'(?<!\w)' + re.escape(needle) + r'(?!\w)'
        return re.sub(pattern, placeholder, current, flags=re.IGNORECASE)

    def redact_all(current, terms, placeholder):
        """Redact each term, longest first."""
        # Longest first so "Ann Marie" is replaced before "Ann" can split it.
        ordered = sorted((t for t in (terms or []) if t),
                         key=lambda t: len(str(t)), reverse=True)
        for term in ordered:
            current = redact_term(current, term, placeholder)
        return current

    text = redact_all(text, aliases, "[[PATIENT IDENTIFIER]]")
    text = redact_all(text, first_names, "[[PATIENT FORENAME]]")
    text = redact_all(text, middle_names, "[[PATIENT MIDDLE NAME]]")
    text = redact_all(text, last_names, "[[PATIENT SURNAME]]")

    # Redact DOB (the helper is a no-op for a falsy DOB, so skip the call).
    if dob:
        text = _redact_dob(text, dob)

    street_fields = ('STREET_ADDR', 'STREET_ADDR2', 'STREET_ADDR3', 'STREET_ADDR4')
    locality_fields = (
        ('CITY', '[[CITY]]'), ('COUNTY', '[[COUNTY]]'),
        ('STATE', '[[STATE]]'), ('COUNTRY', '[[COUNTRY]]'),
        ('ZIPCODE', '[[POSTCODE]]'), ('POSTAL_IDENTIFIER', '[[POSTAL IDENTIFIER]]'),
    )

    def read_field(addr, field):
        # Each address may be a dict or an object with attributes.
        if isinstance(addr, dict):
            return addr.get(field)
        return getattr(addr, field, None)

    for addr in addresses or []:
        if not addr:
            continue
        for i, field in enumerate(street_fields, 1):
            text = redact_term(text, read_field(addr, field), f"[[STREET ADDRESS {i}]]")
        for field, placeholder in locality_fields:
            text = redact_term(text, read_field(addr, field), placeholder)

    return text

def get_eligible_person_ids(limit=100000, batch_size=10000):
    """
    Get up to ``limit`` person_ids that have non-anonymized blobs eligible for anonymization.

    A blob row is eligible when STATUS is "Decoded", ``anon_text`` is NULL or
    empty, and BLOB_TEXT is neither NULL nor empty. Eligible EVENT_IDs are
    mapped to PERSON_IDs through the clinical-event rows whose
    VALID_UNTIL_DT_TM is in the future and the encounter table.

    Args:
        limit: Maximum number of person_ids to return (default 100,000). No
            ordering is applied before the limit, so which ids are returned
            is not defined by this function.
        batch_size: Accepted for backward compatibility; not used by this
            function.

    Returns:
        List of person_ids that need anonymization
    """
    print(f"Finding first {limit:,} person_ids with non-anonymized blobs...")

    # Only EVENT_ID is needed downstream, so distinct() gives the same result
    # as ranking rows per event and keeping rank 1, without the per-partition
    # sort that a row_number() window requires.
    eligible_blobs = spark.table("4_prod.bronze.mill_blob_text") \
        .filter(col("STATUS") == "Decoded") \
        .filter((col("anon_text").isNull()) | (col("anon_text") == "")) \
        .filter(col("BLOB_TEXT").isNotNull()) \
        .filter(col("BLOB_TEXT") != "") \
        .select("EVENT_ID") \
        .distinct()

    person_ids_with_blobs = eligible_blobs \
        .join(
            spark.table("4_prod.raw.mill_clinical_event")
                .filter(col("VALID_UNTIL_DT_TM") > current_timestamp())
                .select("EVENT_ID", "ENCNTR_ID"),
            on="EVENT_ID",
            how="inner"
        ) \
        .join(
            spark.table("4_prod.raw.mill_encounter")
                .select("ENCNTR_ID", "PERSON_ID"),
            on="ENCNTR_ID",
            how="inner"
        ) \
        .select("PERSON_ID") \
        .distinct() \
        .limit(limit)

    person_ids = [row.PERSON_ID for row in person_ids_with_blobs.collect()]

    print(f"Found {len(person_ids):,} eligible person_ids for anonymization")

    return person_ids

def update_blob_text_for_persons(person_ids=None, whitelist=None, limit=100000):
    """
    Main function to update blob text for person IDs.
    Spark Connect compatible - uses DataFrame operations instead of broadcast variables.

    For each selected person, gathers name variants, date of birth, active
    addresses and active aliases. It then redacts them from the latest
    not-yet-anonymized "Decoded" blob of each of the person's events and
    merges the result into ``4_prod.bronze.mill_blob_text.anon_text`` via a
    temporary table.

    Args:
        person_ids: List of person_ids to process (optional)
        whitelist: List of words to exclude from anonymization
        limit: If person_ids is None, number of eligible person_ids to find (default 100,000)

    Returns:
        int: Number of distinct EVENT_IDs from this batch that carry a
        non-NULL ``anon_text`` after the merge; 0 when there was nothing to
        process.
    """
    if whitelist is None:
        whitelist = ['Lady', 'Barts', 'Bartshealth', 'Newham', 'Homerton', 'Hospital']

    if person_ids is None:
        person_ids = get_eligible_person_ids(limit=limit)

    if not person_ids:
        print("No eligible person_ids found for anonymization")
        return 0

    print(f"Processing {len(person_ids):,} person IDs...")

    # Step 1: Aggregate patient names - collect all variants per person
    print("Fetching patient names...")
    patient_names_agg = spark.table("4_prod.raw.mill_person_name") \
        .filter(col("PERSON_ID").isin(person_ids)) \
        .filter(
            (col("NAME_FIRST").isNotNull() & (col("NAME_FIRST") != "")) |
            (col("NAME_MIDDLE").isNotNull() & (col("NAME_MIDDLE") != "")) |
            (col("NAME_LAST").isNotNull() & (col("NAME_LAST") != ""))
        ) \
        .groupBy("PERSON_ID") \
        .agg(
            collect_set(when(col("NAME_FIRST").isNotNull(), col("NAME_FIRST"))).alias("first_names"),
            collect_set(when(col("NAME_MIDDLE").isNotNull(), col("NAME_MIDDLE"))).alias("middle_names"),
            collect_set(when(col("NAME_LAST").isNotNull(), col("NAME_LAST"))).alias("last_names")
        )

    # Step 2: Get patient DOBs
    print("Fetching patient DOBs...")
    patient_dobs = spark.table("4_prod.raw.mill_person") \
        .filter(col("PERSON_ID").isin(person_ids)) \
        .select("PERSON_ID", col("BIRTH_DT_TM").alias("dob"))

    # Step 3: Aggregate addresses per person
    print("Fetching patient addresses...")
    patient_addresses_agg = spark.table("4_prod.raw.mill_address") \
        .filter(col("PARENT_ENTITY_NAME") == "PERSON") \
        .filter(col("PARENT_ENTITY_ID").isin(person_ids)) \
        .filter(col("ACTIVE_IND") == 1) \
        .select(
            col("PARENT_ENTITY_ID").alias("PERSON_ID"),
            struct(
                "STREET_ADDR", "STREET_ADDR2", "STREET_ADDR3", "STREET_ADDR4",
                "CITY", "COUNTY", "STATE", "COUNTRY", "ZIPCODE", "POSTAL_IDENTIFIER"
            ).alias("address")
        ) \
        .groupBy("PERSON_ID") \
        .agg(collect_list("address").alias("addresses"))

    # Step 4: Aggregate aliases per person
    print("Fetching patient aliases...")
    patient_aliases_agg = spark.table("4_prod.raw.mill_person_alias") \
        .filter(col("PERSON_ID").isin(person_ids)) \
        .filter(col("ACTIVE_IND") == 1) \
        .filter(col("ALIAS").isNotNull()) \
        .filter(col("ALIAS") != "") \
        .groupBy("PERSON_ID") \
        .agg(collect_set("ALIAS").alias("aliases"))

    # Step 5: Get encounter information
    print("Getting encounter mappings...")
    encounter_df = spark.table("4_prod.raw.mill_clinical_event") \
        .filter(col("VALID_UNTIL_DT_TM") > current_timestamp()) \
        .select("EVENT_ID", "ENCNTR_ID") \
        .join(
            spark.table("4_prod.raw.mill_encounter")
                .filter(col("PERSON_ID").isin(person_ids))
                .select("ENCNTR_ID", "PERSON_ID"),
            on="ENCNTR_ID",
            how="inner"
        ) \
        .select("EVENT_ID", "PERSON_ID") \
        .distinct()

    print("Identifying rows to update...")
    event_ids_df = encounter_df.select("EVENT_ID").distinct()

    # Newest version of each event's blob: latest VALID_UNTIL_DT_TM, ties
    # broken by latest UPDT_DT_TM.
    window = Window.partitionBy("EVENT_ID").orderBy(
        col("VALID_UNTIL_DT_TM").desc(),
        col("UPDT_DT_TM").desc()
    )

    latest_meta = (
        spark.table("4_prod.bronze.mill_blob_text")
            .filter(col("STATUS") == "Decoded")
            .filter((col("anon_text").isNull()) | (col("anon_text") == ""))
            .join(event_ids_df, on="EVENT_ID", how="inner")  # Pre-filter to our batch
            .select("EVENT_ID", "VALID_UNTIL_DT_TM", "UPDT_DT_TM")
            .withColumn("row", row_number().over(window))
            .filter(col("row") == 1)
            .drop("row")
    )

    # If an event maps to several persons, keep the lowest PERSON_ID so each
    # event is anonymized exactly once.
    window_person = Window.partitionBy("EVENT_ID").orderBy(col("PERSON_ID"))

    rows_to_update = (
        latest_meta
            .join(
                spark.table("4_prod.bronze.mill_blob_text")
                    .select("EVENT_ID", "VALID_UNTIL_DT_TM", "UPDT_DT_TM", "BLOB_TEXT"),
                on=["EVENT_ID", "VALID_UNTIL_DT_TM", "UPDT_DT_TM"],
                how="inner"
            )
            .join(encounter_df, on="EVENT_ID", how="inner")
            .withColumn("row_num", row_number().over(window_person))
            .filter(col("row_num") == 1)
            .drop("row_num")
            .select("EVENT_ID", "PERSON_ID", "BLOB_TEXT")
    )

    # Join all patient info to rows_to_update
    print("Joining patient information...")
    rows_with_info = rows_to_update \
        .join(patient_names_agg, on="PERSON_ID", how="left") \
        .join(patient_dobs, on="PERSON_ID", how="left") \
        .join(patient_addresses_agg, on="PERSON_ID", how="left") \
        .join(patient_aliases_agg, on="PERSON_ID", how="left") \
        .repartition(2000, col("EVENT_ID"))  # Keep partitions small

    row_count = rows_with_info.count()
    print(f"Found {row_count:,} rows to anonymize")

    if row_count == 0:
        print("No rows to update")
        return 0

    # Summary statistics are collected BEFORE the merge: rows_with_info is
    # lazy and filters on anon_text being empty, so evaluating it after the
    # merge would presumably no longer see the rows just updated.
    # Rows are reduced to one per patient so totals count each patient's
    # variants once rather than once per event.
    stats = rows_with_info.select(
        col("PERSON_ID"),
        expr("size(coalesce(first_names, array()))").alias("first_count"),
        expr("size(coalesce(middle_names, array()))").alias("middle_count"),
        expr("size(coalesce(last_names, array()))").alias("last_count"),
        expr("size(coalesce(addresses, array()))").alias("address_count"),
        expr("size(coalesce(aliases, array()))").alias("alias_count")
    ).dropDuplicates(["PERSON_ID"]).agg(
        expr("count(distinct PERSON_ID)").alias("patient_count"),
        expr("coalesce(sum(first_count), 0)").alias("total_first"),
        expr("coalesce(sum(middle_count), 0)").alias("total_middle"),
        expr("coalesce(sum(last_count), 0)").alias("total_last"),
        expr("coalesce(sum(address_count), 0)").alias("total_addresses"),
        expr("coalesce(sum(alias_count), 0)").alias("total_aliases")
    ).collect()[0]

    # Step 6: Create UDF that accepts arrays
    def anonymize_udf_impl(blob_text, first_names, middle_names, last_names,
                          dob, addresses, aliases):
        if blob_text is None or blob_text == '':
            return blob_text
        return simple_phi_redaction(
            blob_text,
            first_names or [],
            middle_names or [],
            last_names or [],
            dob,
            addresses or [],
            aliases or [],
            whitelist
        )

    anonymize_text_udf = udf(anonymize_udf_impl, StringType())

    # Step 7: Apply anonymization
    print("Applying anonymization...")
    anonymized_df = rows_with_info \
        .withColumn(
            "anon_text",
            anonymize_text_udf(
                col("BLOB_TEXT"),
                col("first_names"),
                col("middle_names"),
                col("last_names"),
                col("dob"),
                col("addresses"),
                col("aliases")
            )
        ) \
        .select("EVENT_ID", "anon_text") \
        .dropDuplicates(["EVENT_ID"])  # Ensure one row per EVENT_ID for MERGE

    # Step 8: Update the original table
    print("Updating the table...")
    temp_table_name = f"temp_updates_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    try:
        # The write is inside the try so a failed or partial write is still
        # cleaned up by the finally clause.
        anonymized_df.write.mode("overwrite").saveAsTable(f"4_prod.bronze.{temp_table_name}")

        merge_query = f"""
        MERGE INTO 4_prod.bronze.mill_blob_text AS target
        USING 4_prod.bronze.{temp_table_name} AS source
        ON target.EVENT_ID = source.EVENT_ID
           AND target.STATUS = 'Decoded'
           AND (target.anon_text IS NULL OR target.anon_text = '')
        WHEN MATCHED THEN
            UPDATE SET anon_text = source.anon_text
        """
        spark.sql(merge_query)
        print("Update completed successfully!")

        # Count only events from this batch, not every anonymized event in
        # the table; this must run while the temp table still exists.
        updated_count = spark.sql(f"""
        SELECT COUNT(DISTINCT target.EVENT_ID) AS count
        FROM 4_prod.bronze.mill_blob_text AS target
        INNER JOIN 4_prod.bronze.{temp_table_name} AS source
          ON target.EVENT_ID = source.EVENT_ID
        WHERE target.STATUS = 'Decoded'
          AND target.anon_text IS NOT NULL
        """).collect()[0]['count']
    finally:
        spark.sql(f"DROP TABLE IF EXISTS 4_prod.bronze.{temp_table_name}")

    print(f"Successfully updated {updated_count:,} rows")

    # Print summary
    print("\nAnonymization summary:")
    print(f"- Patients processed: {stats['patient_count']:,}")
    print(f"- Name variants found:")
    print(f"  - First names: {stats['total_first']:,}")
    print(f"  - Middle names: {stats['total_middle']:,}")
    print(f"  - Last names: {stats['total_last']:,}")
    print(f"- Addresses processed: {stats['total_addresses']:,}")
    print(f"- Aliases processed: {stats['total_aliases']:,}")

    return updated_count

if __name__ == "__main__":
    # Run two consecutive batches of eligible person_ids. A failure in one
    # batch is reported and does not stop the next batch from running.
    for failure_prefix in ("Error processing first 100,000",
                           "Error processing second 100,000"):
        try:
            update_blob_text_for_persons(limit=100000)
        except Exception as e:
            print(f"{failure_prefix}: {e}")

    