In [0]:
import re
import calendar
from datetime import datetime
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, when, lit, current_timestamp, row_number,
    udf, struct, collect_list, create_map, expr
)
from pyspark.sql.window import Window
from pyspark.sql.types import StringType, StructType, StructField, LongType
import os

def is_valid_nhs_number(nhs_number):
    """Return True if ``nhs_number`` is 10 digits with a valid Modulus 11 check digit.

    Whitespace and dashes anywhere in the value are ignored, so
    ``"943 476 5919"`` and ``"943-476-5919"`` are treated like ``"9434765919"``.

    The first nine digits are weighted 10 down to 2 and summed; the expected
    check digit is ``11 - (sum % 11)``, where 11 is mapped to 0 and 10 means
    the number can never be valid.

    Args:
        nhs_number: Candidate value. Anything that is not a ``str`` is rejected.

    Returns:
        bool: True when the tenth digit equals the computed check digit.
    """
    if not isinstance(nhs_number, str):
        return False

    # Remove spaces and dashes
    nhs_digits = re.sub(r'[\s-]', '', nhs_number)

    # isdecimal() rather than isdigit(): isdigit() also accepts characters such
    # as superscript "²" that int() cannot parse, which would raise ValueError
    # below instead of returning False.
    if len(nhs_digits) != 10 or not nhs_digits.isdecimal():
        return False

    *body, check = (int(ch) for ch in nhs_digits)
    total = sum(digit * weight for digit, weight in zip(body, range(10, 1, -1)))

    expected = 11 - total % 11
    if expected == 11:
        expected = 0
    # expected == 10 cannot equal a single digit, so that case yields False here.
    return expected == check

def _build_dob_patterns(dob_dt):
    """Build regex patterns matching common written forms of one date of birth.

    Covered renderings (all intended to be applied case-insensitively):

    * separated numeric: day-month-year with a 4- or 2-digit year, and
      year-month-day, using ``.``, ``-``, ``/`` or whitespace as separator.
      A single-digit day or month matches with or without a leading zero,
      independently of each other (``5/3/1980``, ``05/03/1980``, ``5/03/1980``).
    * compact numeric: ``ddmmyyyy``, ``yyyymmdd``, ``ddmmyy``.
    * month-name forms: ``5th March 1980``, ``5 of March, 1980``,
      ``March 05, 1980``, ``Mar-5-80`` and similar, with an optional ordinal
      suffix on the day.

    Month names come from ``calendar.month_name`` / ``calendar.month_abbr``;
    "Sept" is added for September.

    Args:
        dob_dt: Object exposing ``day``, ``month`` and ``year`` attributes
            (e.g. ``datetime.date`` / ``datetime.datetime``), or None.

    Returns:
        list[str]: Regex pattern strings, full-year forms before 2-digit-year
        forms. Empty when ``dob_dt`` is None.
    """
    if dob_dt is None:
        return []

    day, month, year = dob_dt.day, dob_dt.month, dob_dt.year

    dd = f"{day:02d}"
    mm = f"{month:02d}"
    yyyy = str(year)
    yy = f"{year % 100:02d}"

    # A single-digit day/month may be written with or without its leading zero,
    # independently of the other field, so make the zero optional instead of
    # pairing d with m and dd with mm only.
    day_re = f"0?{day}" if day < 10 else str(day)
    month_re = f"0?{month}" if month < 10 else str(month)

    # Longest alternative first; dict.fromkeys drops the duplicate for "May".
    month_names = [calendar.month_name[month]]      # e.g., "January"
    if month == 9:
        month_names.append("Sept")                  # common 4-letter abbreviation
    month_names.append(calendar.month_abbr[month])  # e.g., "Jan"
    months_regex = "(?:" + "|".join(re.escape(n) for n in dict.fromkeys(month_names)) + ")"

    # Separators between numeric components (space, slash, dash, dot)
    sep = r"[.\-/\s]"
    # Separators around a month name / its year: the above plus comma
    name_sep = r"[.,\-/\s]+"
    # Between day and month name: separators, or the word "of"
    day_to_month = fr"(?:{sep}+|\s+of\s+)"
    # Ordinal suffix for day (st, nd, rd, th)
    ord_suffix = r"(?:st|nd|rd|th)?"

    patterns = [
        # d/m/yyyy and d/m/yy (any zero-padding mix, any separator)
        fr"(?<!\d){day_re}{sep}{month_re}{sep}{yyyy}(?!\d)",
        fr"(?<!\d){day_re}{sep}{month_re}{sep}{yy}(?!\d)",
        # yyyy/m/d
        fr"(?<!\d){yyyy}{sep}{month_re}{sep}{day_re}(?!\d)",
        # Compact numeric forms (ddmmyyyy, yyyymmdd, ddmmyy)
        fr"(?<!\d){dd}{mm}{yyyy}(?!\d)",
        fr"(?<!\d){yyyy}{mm}{dd}(?!\d)",
        fr"(?<!\d){dd}{mm}{yy}(?!\d)",
    ]

    # Month-name forms; 4-digit year first so it is consumed before the
    # 2-digit variant is tried.
    for year_re in (yyyy, yy):
        patterns += [
            # Day [of] Month Year
            fr"\b{day_re}{ord_suffix}{day_to_month}{months_regex}{name_sep}{year_re}\b",
            # Month Day Year
            fr"\b{months_regex}{name_sep}{day_re}{ord_suffix}{name_sep}{year_re}\b",
        ]

    return patterns

def _redact_dob(text, dob_dt):
    """Replace every recognised rendering of ``dob_dt`` in ``text`` with a placeholder.

    Each pattern from ``_build_dob_patterns`` is applied in turn,
    case-insensitively, substituting ``[[DATE OF BIRTH]]``. A clock time
    (``H:MM`` or ``H:MM:SS``) that follows a placeholder after whitespace is
    then removed as well.

    Args:
        text: Text to scan.
        dob_dt: Date of birth; when falsy, ``text`` is returned unchanged.

    Returns:
        The text with DOB occurrences redacted.
    """
    if dob_dt:
        redacted = text
        for dob_pattern in _build_dob_patterns(dob_dt):
            redacted = re.sub(dob_pattern, "[[DATE OF BIRTH]]", redacted, flags=re.IGNORECASE)

        # Drop a time-of-day that trails a freshly inserted placeholder.
        return re.sub(
            r"(\[\[DATE OF BIRTH\]\])\s+\d{1,2}:\d{2}(?::\d{2})?",
            r"\1",
            redacted,
        )

    return text

# patient_info key -> placeholder, in the order the names are redacted.
_NAME_FIELDS = (
    ('NAME_FIRST', "[[PATIENT FORENAME]]"),
    ('NAME_MIDDLE', "[[PATIENT MIDDLE NAME]]"),
    ('NAME_LAST', "[[PATIENT SURNAME]]"),
)

# Street lines are numbered in the placeholder: [[STREET ADDRESS 1]] .. 4.
_STREET_FIELDS = ('STREET_ADDR', 'STREET_ADDR2', 'STREET_ADDR3', 'STREET_ADDR4')

# Remaining address keys -> placeholder, in the order they are redacted.
_ADDRESS_FIELDS = (
    ('CITY', "[[CITY]]"),
    ('COUNTY', "[[COUNTY]]"),
    ('STATE', "[[STATE]]"),
    ('COUNTRY', "[[COUNTRY]]"),
    ('ZIPCODE', "[[POSTCODE]]"),
    ('POSTAL_IDENTIFIER', "[[POSTAL IDENTIFIER]]"),
)


def _redact_term(text, term, placeholder, whitelist=()):
    """Replace whole-token, case-insensitive occurrences of ``term`` with ``placeholder``.

    ``term`` is ignored when it is not a string, is 2 characters or shorter
    after stripping, or equals a ``whitelist`` entry (lower-cased comparison).
    """
    if not isinstance(term, str):
        return text
    term = term.strip()
    if len(term) <= 2 or term.lower() in whitelist:
        return text

    # Any run of whitespace in the text matches a gap between words of the term.
    body = r"\s+".join(re.escape(piece) for piece in term.split())
    # Lookarounds instead of \b: \b needs a word character on one side, so a
    # term ending in punctuation ("12 High St.") would never match before a
    # space or at end of text.
    pattern = r"(?<!\w)" + body + r"(?!\w)"
    # Callable replacement so the placeholder is inserted literally.
    return re.sub(pattern, lambda _match: placeholder, text, flags=re.IGNORECASE)


def simple_phi_redaction(text, patient_info=None, address_info=None, alias_info=None, whitelist=None):
    """
    Simplified PHI redaction without spacy dependencies.

    Redaction runs in this order: valid NHS numbers (digits optionally
    separated by spaces or dashes), patient aliases, forenames, middle names,
    surnames, date of birth, then address components. Matching of names,
    aliases and address parts is case-insensitive and whole-token; values of
    2 characters or fewer are skipped.

    Args:
        text: Text to redact. None or '' is returned unchanged.
        patient_info: Optional dict with 'NAME_FIRST', 'NAME_MIDDLE' and
            'NAME_LAST' (each a string or an iterable of strings, so several
            name variants can be supplied) and 'BIRTH_DT_TM'.
        address_info: Optional iterable of dicts with the keys 'STREET_ADDR',
            'STREET_ADDR2', 'STREET_ADDR3', 'STREET_ADDR4', 'CITY', 'COUNTY',
            'STATE', 'COUNTRY', 'ZIPCODE', 'POSTAL_IDENTIFIER'.
        alias_info: Optional iterable of alias strings.
        whitelist: Optional iterable of words that must never be redacted. A
            name, alias or address value equal to an entry (case-insensitive)
            is left in the text. NHS numbers and dates of birth are not
            affected by the whitelist.

    Returns:
        The redacted text.
    """
    if text is None or text == '':
        return text

    # Lower-cased once for case-insensitive membership tests.
    allowed = frozenset(
        word.strip().lower() for word in (whitelist or ()) if isinstance(word, str)
    )

    # First, find and replace all valid NHS numbers (allowing spaces or dashes)
    def replace_nhs_number(match):
        raw = match.group()
        if is_valid_nhs_number(raw):
            return "[[NHS Number]]"
        return raw

    # Match 10 digits possibly separated by spaces or dashes (exactly 10 digits total)
    text = re.sub(r'(?<!\d)(?:\d[ -]?){9}\d(?!\d)', replace_nhs_number, text)

    # Patient aliases
    for alias in alias_info or ():
        text = _redact_term(text, alias, "[[PATIENT IDENTIFIER]]", allowed)

    if patient_info:
        for field, placeholder in _NAME_FIELDS:
            variants = patient_info.get(field) or ()
            if isinstance(variants, str):
                variants = (variants,)  # single name rather than a collection
            # Longest first (ties alphabetical) so the outcome does not depend
            # on set iteration order.
            ordered = sorted(
                (v for v in variants if isinstance(v, str)),
                key=lambda v: (-len(v), v),
            )
            for name in ordered:
                text = _redact_term(text, name, placeholder, allowed)

        # Redact DOB
        text = _redact_dob(text, patient_info.get('BIRTH_DT_TM'))

    # Address components, one address record at a time
    for addr in address_info or ():
        for i, field in enumerate(_STREET_FIELDS, 1):
            text = _redact_term(text, addr.get(field), f"[[STREET ADDRESS {i}]]", allowed)
        for field, placeholder in _ADDRESS_FIELDS:
            text = _redact_term(text, addr.get(field), placeholder, allowed)

    return text

def get_eligible_person_ids(limit=100000, batch_size=10000):
    """
    Collect up to ``limit`` distinct PERSON_IDs that still have blob text to anonymize.

    A blob row qualifies when STATUS is "Decoded", ``anon_text`` is NULL or
    empty, and BLOB_TEXT is neither NULL nor empty. Qualifying events are
    joined to clinical events whose VALID_UNTIL_DT_TM lies in the future, and
    from there to encounters, to obtain the PERSON_ID.

    Args:
        limit: Maximum number of person_ids to return (default 100,000)
        batch_size: Accepted for interface compatibility; not referenced by
            the body of this function.

    Returns:
        List of person_ids that need anonymization
    """
    print(f"Finding first {limit:,} person_ids with non-anonymized blobs...")

    # Rank blob versions within each event, newest first
    newest_first = Window.partitionBy("event_id").orderBy(
        col("valid_until_dt_tm").desc(),
        col("updt_dt_tm").desc()
    )

    # Blobs that are decoded, have text, and carry no anonymized text yet
    pending_blobs = (
        spark.table("4_prod.bronze.mill_blob_text")
        .filter(col("STATUS") == "Decoded")
        .filter((col("anon_text").isNull()) | (col("anon_text") == ""))
        .filter(col("BLOB_TEXT").isNotNull())
        .filter(col("BLOB_TEXT") != "")
    )

    # Keep only the top-ranked row per event, then just its EVENT_ID
    eligible_blobs = (
        pending_blobs
        .withColumn("row", row_number().over(newest_first))
        .filter(col("row") == 1)
        .drop("row")
        .select("EVENT_ID")
    )

    # EVENT_ID -> ENCNTR_ID for clinical events that are still valid
    current_events = (
        spark.table("4_prod.raw.mill_clinical_event")
        .filter(col("VALID_UNTIL_DT_TM") > current_timestamp())
        .select("EVENT_ID", "ENCNTR_ID")
    )

    # ENCNTR_ID -> PERSON_ID
    encounters = (
        spark.table("4_prod.raw.mill_encounter")
        .select("ENCNTR_ID", "PERSON_ID")
    )

    person_ids_with_blobs = (
        eligible_blobs
        .join(current_events, on="EVENT_ID", how="inner")
        .join(encounters, on="ENCNTR_ID", how="inner")
        .select("PERSON_ID")
        .distinct()
        .limit(limit)
    )

    # Bring the ids back to the driver as a plain list
    collected = person_ids_with_blobs.collect()
    person_ids = [record.PERSON_ID for record in collected]

    print(f"Found {len(person_ids):,} eligible person_ids for anonymization")

    return person_ids

def update_blob_text_for_persons(person_ids=None, whitelist=None, limit=100000):
    """
    Anonymize pending blob text for a set of patients and merge it back.

    For the given person_ids this collects name variants, date of birth,
    active addresses and active aliases to the driver, runs
    ``simple_phi_redaction`` over the newest pending blob of each of their
    events in a Spark UDF, stages the result in a temporary table and MERGEs
    it into ``4_prod.bronze.mill_blob_text``. The MERGE updates ``anon_text``
    on every "Decoded" row of a staged EVENT_ID whose ``anon_text`` is NULL
    or empty. The staging table is dropped afterwards, whether or not the
    write or the MERGE succeeded.

    Args:
        person_ids: List of person_ids to process (optional). If None, the
            first ``limit`` eligible person_ids are looked up automatically.
        whitelist: List of words to exclude from anonymization; passed through
            to ``simple_phi_redaction``. Defaults to a built-in list.
        limit: If person_ids is None, number of eligible person_ids to find
            (default 100,000)

    Returns:
        int: Number of distinct EVENT_IDs staged and merged by this call, or
        0 when there were no person_ids or no rows to anonymize.
    """
    if whitelist is None:
        whitelist = ['Lady', 'Barts', 'Bartshealth', 'Newham', 'Homerton', 'Hospital']

    # If no person_ids provided, find eligible ones
    if person_ids is None:
        person_ids = get_eligible_person_ids(limit=limit)

    if not person_ids:
        print("No eligible person_ids found for anonymization")
        return 0

    print(f"Processing {len(person_ids):,} person IDs...")

    # person_id -> {'PERSON_ID', 'NAME_FIRST': set, 'NAME_MIDDLE': set,
    #               'NAME_LAST': set, and 'BIRTH_DT_TM' once known}
    patient_names_dict = {}

    def _patient_entry(pid):
        """Return the info dict for ``pid``, creating an empty one on first use."""
        return patient_names_dict.setdefault(pid, {
            'PERSON_ID': pid,
            'NAME_FIRST': set(),
            'NAME_MIDDLE': set(),
            'NAME_LAST': set()
        })

    # Step 1: Get ALL patient name variants for all person_ids
    print("Fetching patient names...")
    patient_names_df = spark.table("4_prod.raw.mill_person_name") \
        .filter(col("PERSON_ID").isin(person_ids)) \
        .select("PERSON_ID", "NAME_FIRST", "NAME_MIDDLE", "NAME_LAST") \
        .collect()

    # Sets drop duplicate variants; NULL and blank names are skipped
    for row in patient_names_df:
        entry = _patient_entry(row['PERSON_ID'])
        for field in ('NAME_FIRST', 'NAME_MIDDLE', 'NAME_LAST'):
            value = row[field]
            if value and value.strip():
                entry[field].add(value.strip())

    # Step 1b: Get patient DOBs from mill_person and add to patient_names_dict
    print("Fetching patient DOBs...")
    dob_rows = spark.table("4_prod.raw.mill_person") \
        .filter(col("PERSON_ID").isin(person_ids)) \
        .select("PERSON_ID", "BIRTH_DT_TM") \
        .collect()

    for row in dob_rows:
        _patient_entry(row['PERSON_ID'])['BIRTH_DT_TM'] = row['BIRTH_DT_TM']

    # Step 2: Get addresses for all person_ids
    print("Fetching patient addresses...")
    addresses_df = spark.table("4_prod.raw.mill_address") \
        .filter(col("PARENT_ENTITY_NAME") == "PERSON") \
        .filter(col("PARENT_ENTITY_ID").isin(person_ids)) \
        .filter(col("ACTIVE_IND") == 1) \
        .select("PARENT_ENTITY_ID", "STREET_ADDR", "STREET_ADDR2", "STREET_ADDR3",
                "STREET_ADDR4", "CITY", "COUNTY", "STATE", "COUNTRY",
                "ZIPCODE", "POSTAL_IDENTIFIER") \
        .collect()

    # person_id -> list of address dicts (a person can have several addresses)
    addresses_dict = {}
    for row in addresses_df:
        addresses_dict.setdefault(row['PARENT_ENTITY_ID'], []).append(row.asDict())

    # Step 3: Get patient aliases for all person_ids
    print("Fetching patient aliases...")
    aliases_df = spark.table("4_prod.raw.mill_person_alias") \
        .filter(col("PERSON_ID").isin(person_ids)) \
        .filter(col("ACTIVE_IND") == 1) \
        .filter(col("ALIAS").isNotNull()) \
        .select("PERSON_ID", "ALIAS") \
        .collect()

    # person_id -> list of stripped, non-empty aliases
    aliases_dict = {}
    for row in aliases_df:
        person_aliases = aliases_dict.setdefault(row['PERSON_ID'], [])
        alias = row['ALIAS']
        if alias and alias.strip():
            person_aliases.append(alias.strip())

    # Step 4: Get the encounter information to link person_id with event_id
    print("Getting encounter mappings...")
    encounter_df = spark.table("4_prod.raw.mill_clinical_event") \
        .filter(col("VALID_UNTIL_DT_TM") > current_timestamp()) \
        .select("EVENT_ID", "ENCNTR_ID") \
        .join(
            spark.table("4_prod.raw.mill_encounter")
                .filter(col("PERSON_ID").isin(person_ids))
                .select("ENCNTR_ID", "PERSON_ID"),
            on="ENCNTR_ID",
            how="inner"
        ) \
        .select("EVENT_ID", "PERSON_ID") \
        .distinct()

    # Step 5: Identify rows to update based on lookup_blob_content logic
    print("Identifying rows to update...")
    window = Window.partitionBy("event_id").orderBy(
        col("valid_until_dt_tm").desc(),
        col("updt_dt_tm").desc()
    )

    # Newest pending blob per event_id. The BLOB_TEXT filters mirror
    # get_eligible_person_ids: without them an event whose newest version has
    # no text would be "anonymized" to NULL/'' and stay pending indefinitely.
    rows_to_update = spark.table("4_prod.bronze.mill_blob_text") \
        .filter(col("STATUS") == "Decoded") \
        .filter((col("anon_text").isNull()) | (col("anon_text") == "")) \
        .filter(col("BLOB_TEXT").isNotNull()) \
        .filter(col("BLOB_TEXT") != "") \
        .withColumn("row", row_number().over(window)) \
        .filter(col("row") == 1) \
        .drop("row") \
        .join(encounter_df, on="EVENT_ID", how="inner") \
        .select("EVENT_ID", "PERSON_ID", "BLOB_TEXT")

    # Count rows to process
    row_count = rows_to_update.count()
    print(f"Found {row_count:,} rows to anonymize")

    if row_count == 0:
        print("No rows to update")
        return 0

    # Step 6: Create UDF for anonymization with address, alias, and DOB support.
    # The lookup dicts built above are captured by this closure.
    def anonymize_udf(blob_text, person_id):
        if blob_text is None or blob_text == '':
            return blob_text
        patient_info = patient_names_dict.get(person_id, {})
        address_info = addresses_dict.get(person_id, [])
        alias_info = aliases_dict.get(person_id, [])
        return simple_phi_redaction(blob_text, patient_info, address_info, alias_info, whitelist)

    # Register UDF
    anonymize_text_udf = udf(anonymize_udf, StringType())

    # Step 7: Apply anonymization
    print("Applying anonymization...")
    anonymized_df = rows_to_update \
        .withColumn("anon_text", anonymize_text_udf(col("BLOB_TEXT"), col("PERSON_ID"))) \
        .select("EVENT_ID", "anon_text")

    # Step 8: Update the original table using a temporary table
    print("Updating the table...")

    temp_table_name = f"temp_updates_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    staging_table = f"4_prod.bronze.{temp_table_name}"

    merge_query = f"""
        MERGE INTO 4_prod.bronze.mill_blob_text AS target
        USING {staging_table} AS source
        ON target.EVENT_ID = source.EVENT_ID
           AND target.STATUS = 'Decoded'
           AND (target.anon_text IS NULL OR target.anon_text = '')
        WHEN MATCHED THEN
            UPDATE SET
                anon_text = source.anon_text
        """

    try:
        # The write is inside the try so that a partially written staging
        # table is dropped too if the write itself fails.
        anonymized_df.write.mode("overwrite").saveAsTable(staging_table)
        spark.sql(merge_query)
        print("Update completed successfully!")

        # Count what this run staged and merged, read before the table is
        # dropped. A table-wide count of non-null anon_text would also include
        # rows anonymized by earlier runs.
        updated_count = spark.table(staging_table) \
            .select("EVENT_ID") \
            .distinct() \
            .count()
    finally:
        spark.sql(f"DROP TABLE IF EXISTS {staging_table}")

    print(f"Successfully updated {updated_count:,} rows")

    # Print summary of anonymization including name variants
    print("\nAnonymization summary:")
    print(f"- Patients processed: {len(patient_names_dict):,}")

    # Count total unique name variants
    total_first_names = sum(len(p.get('NAME_FIRST', set())) for p in patient_names_dict.values())
    total_middle_names = sum(len(p.get('NAME_MIDDLE', set())) for p in patient_names_dict.values())
    total_last_names = sum(len(p.get('NAME_LAST', set())) for p in patient_names_dict.values())

    print("- Name variants found:")
    print(f"  - First names: {total_first_names:,}")
    print(f"  - Middle names: {total_middle_names:,}")
    print(f"  - Last names: {total_last_names:,}")
    print(f"- Addresses processed: {sum(len(addrs) for addrs in addresses_dict.values()):,}")
    print(f"- Aliases processed: {sum(len(aliases) for aliases in aliases_dict.values()):,}")

    return updated_count

# Example usage: run the automatic mode three times in a row. Each call looks
# up its own eligible person_ids, capped at the given limit.
if __name__ == "__main__":
    for batch_limit in (100000, 100000, 25000):
        update_blob_text_for_persons(limit=batch_limit)

    