# In [0]:
from datetime import date, timedelta
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import (
    col, coalesce, lit, to_date, row_number, when,
    max as spark_max, datediff, ltrim, rtrim,
    current_date, year, month, concat, upper, trim
)
from delta.tables import DeltaTable

# ── Spark session and tuning ────────────────────────────────────────────────────
# Reuse the active session when there is one; otherwise a new one is built.
spark = SparkSession.builder.getOrCreate()

# Session-level settings chosen for large volumes, applied in the order listed.
_session_settings = (
    ("spark.sql.shuffle.partitions", "200"),
    ("spark.databricks.optimizer.dynamicPartitionPruning", "true"),
    ("spark.sql.adaptive.enabled", "true"),
)
for _conf_key, _conf_value in _session_settings:
    spark.conf.set(_conf_key, _conf_value)

# Directory registered for checkpoint files.
spark.sparkContext.setCheckpointDir("/tmp/checkpoints")

# Runtime mode: 'full' (initial backfill) or 'update' (monthly incremental).
VALID_RUN_MODES = ("full", "update")


def normalize_run_mode(raw_mode):
    """Return the canonical run mode ('full' or 'update') for a raw widget value.

    Surrounding whitespace and letter case are ignored, so a value such as
    ``" Update "`` is honoured instead of being treated as unknown.  Anything
    unrecognised (including ``None`` or an empty string) still falls back to
    ``'full'``, but a warning is printed because the fallback mode is the one
    that overwrites the whole target dataset.
    """
    candidate = (raw_mode or "").strip().lower()
    if candidate in VALID_RUN_MODES:
        return candidate
    print(f"WARNING: unrecognised RUN_MODE {raw_mode!r}; falling back to 'full'")
    return "full"


if 'dbutils' in globals():
    # Running inside Databricks: expose the mode as a notebook widget.
    dbutils.widgets.text('RUN_MODE', 'full')
    RUN_MODE = normalize_run_mode(dbutils.widgets.get('RUN_MODE'))
else:
    # No Databricks utilities in scope: use the default mode.
    RUN_MODE = "full"

# The financial year begins on 1 April (UK NHS). All values below are Spark
# column expressions, evaluated by Spark rather than on the driver.
today = current_date()

# From April onwards the FY started this calendar year; before April, last year.
_from_april_onwards = month(today) >= 4
fy_year = when(_from_april_onwards, year(today)).otherwise(year(today) - 1)
fy_start_expr = to_date(concat(fy_year.cast('string'), lit('-04-01')))

# A FULL run goes back to 2019-01-01.
full_start_expr = to_date(lit("2019-01-01"))

# Earliest date in scope: FY start for 'update', the fixed backfill date otherwise.
_is_update_run = lit(RUN_MODE) == lit("update")
fin_yearStart = when(_is_update_run, fy_start_expr).otherwise(full_start_expr)

# ── 0. ADLS Gen2 storage roots ───────────────────────────────────────────────────
mesh_base = "abfss://reporting@udalstdatacuratedprod.dfs.core.windows.net/"
ref_base  = "abfss://analytics-projects@udalstdataanalysisprod.dfs.core.windows.net/"

# ── 1. Table locations, keyed by the short name used for each DataFrame ─────────
# MHSDS published tables all live under one MESH/MHSDS folder.
_mhsds_root = f"{mesh_base}restricted/patientlevel/MESH/MHSDS/"
_mhsds_folders = {
    "HospSpell":       "MHS501HospProvSpell_Published",
    "CareContact":     "MHS201CareContact_Published",
    "SubmissionFlags": "MHSDS_SubmissionFlags_Published",
    "Referral":        "MHS101Referral_Published",
    "ServiceType":     "MHS102ServiceTypeReferredTo_Published",
    "TeamDetails":     "MHS902ServiceTeamDetails_Published",
    "MPI":             "MHS001MPI_Published",
    "GP":              "MHS002GP_Published",
}
mesh_map = {
    short_name: f"{_mhsds_root}{folder}/"
    for short_name, folder in _mhsds_folders.items()
}

# PATLondon reference files (DateDim is the only one stored as Delta).
_lookup_root = f"{ref_base}PATLondon/MHUEC_Reference_Files/"
_lookup_folders = {
    "ConsultationMechanism":  "Care_Contact_Consultation_Mechanism",
    "DateDim":                "Date_Dimension",
    "ProfGroup":              "Referring_Care_Professional_Staff_Group",
    "SourceOfReferral":       "Source_Of_Referral_for_Mental_Health_Services",
    "ReasonOAR":              "Reason_for_Out_Of_Area_Referral",
    "ServiceTeamTypeLookup":  "Care_Contact_Service_or_Team_Type_Referred_to",
    "PrimaryReasonReferral":  "Primary_Reason_For_Referral",
}
lookup_map = {
    short_name: f"{_lookup_root}{folder}/"
    for short_name, folder in _lookup_folders.items()
}

# Core output tables produced by this pipeline.
core_data_base = f"{ref_base}PATLondon/MHSDS/Core_Tables/Core_Tables/"
core_map = {
    "ReferralsWithContacts": core_data_base + "MH_Referrals_with_Care_Contacts_London/",
}

# ── 2. Read all static tables into dfs ───────────────────────────────────────────
# `dfs` maps each short table name to its DataFrame; a core table that has not
# been written yet is stored as None.
dfs = {}

# 2a) MESH tables (Parquet)
for name, path in mesh_map.items():
    dfs[name] = spark.read.parquet(path)

# 2b) Lookup tables: DateDim is Delta, the rest are Parquet
for name, path in lookup_map.items():
    if name == "DateDim":
        dfs[name] = spark.read.format("delta").load(path)
    else:
        dfs[name] = spark.read.parquet(path)

# 2c) Core data tables (Delta if present).
#     Existence is checked explicitly rather than by catching every exception,
#     so that a genuine read failure on an existing table surfaces instead of
#     silently making the table look absent.
for name, path in core_map.items():
    if DeltaTable.isDeltaTable(spark, path):
        dfs[name] = spark.read.format("delta").load(path)
    else:
        dfs[name] = None

# ── 3. Additional reference sources (GP data, ODS organisation references) ──────
_curated_unrestricted = "abfss://unrestricted@udalstdatacuratedprod.dfs.core.windows.net/"
_curated_reporting = "abfss://reporting@udalstdatacuratedprod.dfs.core.windows.net/"


def _tree_reader():
    """Return a reader with the 'header' and 'recursiveFileLookup' options set.

    NOTE(review): 'header' is normally a CSV option and presumably has no effect
    on the Delta/Parquet sources below; it is kept to leave the reads unchanged.
    """
    return (
        spark.read
             .option("header", "true")
             .option("recursiveFileLookup", "true")
    )


# 3a) GP practice data (Delta)
dfs["GPData"] = (
    _tree_reader()
        .format("delta")
        .load(f"{ref_base}PATLondon/MHUEC_Reference_Files/GP_Data/")
)

# 3b) Commissioner code changes (Parquet)
dfs["CodeChanges"] = _tree_reader().parquet(
    _curated_unrestricted + "reference/Internal/Reference/ComCodeChanges/Published/"
)

# 3c) Commissioner hierarchies (Parquet)
dfs["CommissionerHierarchies"] = _tree_reader().parquet(
    _curated_reporting + "unrestricted/reference/UKHD/ODS/Commissioner_Hierarchies_ICB/"
)

# 3d) Organisation reference with role (Parquet), restricted to Is_Latest = 1
_org_ref_all = _tree_reader().parquet(
    _curated_unrestricted
    + "reference/UKHD/ODS_API/"
    + "vwOrganisation_SCD_IsLatestEqualsOneWithRole/Published/1/"
)
dfs["OrgRef"] = _org_ref_all.filter(col("Is_Latest") == 1)

# 3e) All providers (Parquet), restricted to Is_Latest = 1 straight away
_all_providers = spark.read.parquet(
    _curated_unrestricted + "reference/UKHD/ODS/All_Providers_SCD/Published/1/"
)
dfs["AllProviders"] = _all_providers.filter(col("Is_Latest") == 1)

# 3f) All organisation codes (Parquet), unfiltered
dfs["AllCodes"] = spark.read.parquet(
    _curated_unrestricted + "reference/UKHD/ODS/All_Codes/Published/1/"
)




# ── 4. Inpatient flag source (#HOc) ─────────────────────────────────────────────
# Distinct (referral id, person id, MHS501 id) combinations taken from the
# hospital provider spell table.
HospSpell = dfs["HospSpell"]
_hoc_columns = ["UniqServReqID", "Der_Person_ID", "MHS501UniqID"]
tempHOc = HospSpell.select(*_hoc_columns).distinct()

# ── 5. Build ExistRef (#ExistRef) from existing “ReferralsWithContacts” ─────────
# ExistRef lists the referral ids that the downstream filter treats as already
# loaded (it keeps referrals that are NOT in ExistRef).
#
# Only an incremental ('update') run should consult the existing output. A
# 'full' run overwrites the whole target, so excluding ids already present
# there would drop previously loaded referrals from the rebuilt dataset.
_existing_output = dfs.get("ReferralsWithContacts")

if RUN_MODE == "update" and _existing_output is not None:
    ExistRef = _existing_output.select("UniqServReqID").distinct()
else:
    # Full rebuild, or no output table yet: nothing counts as already loaded.
    ExistRef = spark.createDataFrame([], "UniqServReqID string")

# ── 7. Build CCPre (#CCPre) ─────────────────────────────────────────────────────
# CCPre holds the distinct referral ids that have at least one attended care
# contact on or after fin_yearStart with an in-scope provider.
CareContact   = dfs["CareContact"]
SubmissionF   = dfs["SubmissionFlags"]
Referral      = dfs["Referral"]
ConsMechanism = dfs["ConsultationMechanism"]

# Provider organisation codes in scope for this dataset.
prov_list = ["RAT","RKL","RPG","RQY","RRP","RV3","RV5","RWK","TAF","RKE","G6V2S"]

CCPre = (
    CareContact.alias("cc")
      .join(
          SubmissionF.alias("f"),
          col("f.NHSEUniqSubmissionID") == col("cc.NHSEUniqSubmissionID"),
          how="left"
      )
      # Keep only contacts from the latest submission. The flag is read from the
      # SubmissionFlags alias and compared with "Y", exactly as the other two
      # SubmissionFlags filters in this file do; the previous unqualified
      # comparison with the integer 1 did not agree with them.
      .filter(col("f.Der_IsLatest") == "Y")
      .join(
          Referral.alias("g"),
          col("g.UniqServReqID") == col("cc.UniqServReqID"),
          how="inner"
      )
      .filter(
          # Attendance status 5 or 6 only.
          col("cc.AttendStatus").isin("5","6") &
          (
            # Consultation mechanism: 01/02/04 in any month; 03 before month
            # 1459 and 11 from month 1459 onwards.
            col("cc.ConsMechanismMH").isin("01","02","04") |
            ((col("cc.UniqMonthID") < 1459) & (col("cc.ConsMechanismMH") == "03")) |
            ((col("cc.UniqMonthID") >= 1459) & (col("cc.ConsMechanismMH") == "11"))
          ) &
          # FY start in update mode, fixed backfill date in full mode.
          (col("cc.CareContDate") >= fin_yearStart) &
          (col("g.OrgIDProv").isNotNull() & col("g.OrgIDProv").isin(prov_list))
      )
      .select(col("cc.UniqServReqID").alias("UniqServReqID"))
      .distinct()
)

# ── 8. Build tempUR_all & tempUR ────────────────────────────────────────────────
# Short names for the source DataFrames used by the referral build.
Referral        = dfs["Referral"]
ServiceType     = dfs["ServiceType"]
TeamDetails     = dfs["TeamDetails"]
DateDim         = dfs["DateDim"]
MPI             = dfs["MPI"]
GP              = dfs["GP"]
GPData          = dfs["GPData"]
CodeChanges     = dfs["CodeChanges"]
Commissioners   = dfs["CommissionerHierarchies"]
OrgRef          = dfs["OrgRef"]
ProfGroup       = dfs["ProfGroup"]
SrcOfReferral   = dfs["SourceOfReferral"]
ReasonOAR       = dfs["ReasonOAR"]

# DateDim join key: prefer "Date", then "Calendar_Day".
date_cols = [f.name for f in DateDim.schema.fields]
date_col = next(
    (candidate for candidate in ("Date", "Calendar_Day") if candidate in date_cols),
    None,
)
if date_col is None:
    # Fail fast: without a key column the DateDim join condition downstream no
    # longer references DateDim at all, so every referral with a non-null
    # received date would be paired with every DateDim row.
    raise ValueError(
        "DateDim has neither a 'Date' nor a 'Calendar_Day' column; "
        f"cannot join referrals to the date dimension (columns: {date_cols})"
    )
# tempUR_all: referral rows (alias "a") enriched by a chain of left joins and
# ranked within each (UniqServReqID, OrgIDProv) pair.
#   * A referral is kept when it is absent from ExistRef, belongs to a provider
#     in prov_list and was received on/after fin_yearStart, OR when its id is
#     present in CCPre.
#   * RowOrder = 1 marks the row with the highest UniqMonthID, then the highest
#     referral and service-type UniqSubmissionID, then the latest Effective_From.
#   * Flag columns built with when(...) and no otherwise(...) are NULL when the
#     condition does not hold.
tempUR_all = (
    Referral.alias("a")
      # (1) SubmissionFlags. The filter that follows requires sf.Der_IsLatest = "Y",
      #     so referrals without a matching flag row are dropped despite the left join.
      .join(
          SubmissionF.alias("sf"),
          (col("sf.NHSEUniqSubmissionID") == col("a.NHSEUniqSubmissionID")),
          how="left"
      )
      .filter(col("sf.Der_IsLatest") == "Y")
      # (2) ServiceType, matched on record number and referral id
      .join(
          ServiceType.alias("st"),
          (col("st.RecordNumber") == col("a.RecordNumber"))
          & (col("a.UniqServReqID") == col("st.UniqServReqID")),
          how="left"
      )
      # (3) TeamDetails, matched on team id within the same month
      .join(
          TeamDetails.alias("rtd"),
          (col("rtd.UniqCareProfTeamLocalID") == col("a.UniqCareProfTeamLocalID"))
          & (col("rtd.UniqMonthID") == col("a.UniqMonthID")),
          how="left"
      )
      # (4) DateDim, matched on the referral received date via date_col.
      #     NOTE(review): when date_col is None the right-hand side is the referral
      #     date itself, so the condition no longer constrains dt and every
      #     non-null referral date matches every DateDim row.
      .join(
          DateDim.alias("dt"),
          to_date(col("a.ReferralRequestReceivedDate")) == (
              col(f"dt.{date_col}") if date_col else to_date(col("a.ReferralRequestReceivedDate"))
          ),
          how="left"
      )
      # (5) MPI (patient record), matched on person, submission, month and record
      .join(
          MPI.alias("b"),
          (
            (col("b.Person_ID") == col("a.Person_ID"))
            & (col("b.UniqSubmissionID") == col("a.UniqSubmissionID"))
            & (col("b.UniqMonthID") == col("a.UniqMonthID"))
            & (col("b.RecordNumber") == col("a.RecordNumber"))
          ),
          how="left"
      )
      # (6) GP registration, matched on the MPI record number and referral submission
      .join(
          GP.alias("gp"),
          (col("gp.RecordNumber") == col("b.RecordNumber"))
          & (col("gp.UniqSubmissionID") == col("a.UniqSubmissionID")),
          how="left"
      )
      # (7) GPData, matched on practice code
      .join(
          GPData.alias("gpd"),
          col("gpd.Practice_Code") == col("gp.GMPReg"),
          how="left"
      )
      # (8) CodeChanges, matched on the residence commissioner code
      #     (Sub-ICB location first, CCG as fallback). Note the alias "cc" refers
      #     to CodeChanges in this chain.
      .join(
          CodeChanges.alias("cc"),
          col("cc.Org_Code")
          == coalesce(col("b.OrgIDSubICBLocResidence"), col("b.OrgIDCCGRes")),
          how="left"
      )
      # (9) Commissioners, matched on the first non-null of: new code, CCG code,
      #     Sub-ICB location code.
      #     NOTE(review): the CCG/Sub-ICB precedence here is the reverse of join (8)
      #     — confirm that is intended.
      .join(
          Commissioners.alias("c"),
          coalesce(
            col("cc.New_Code"),
            col("b.OrgIDCCGRes"),
            col("b.OrgIDSubICBLocResidence")
          ) == col("c.Organisation_Code"),
          how="left"
      )
      # (10) Inpatient flag source, matched on person and referral id
      .join(
          tempHOc.alias("ho"),
          (
            (col("ho.Der_Person_ID") == col("a.Der_Person_ID"))
            & (col("ho.UniqServReqID") == col("a.UniqServReqID"))
          ),
          how="left"
      )
      # (11) OrgRef, describing the referring organisation
      .join(
          OrgRef.alias("ORef"),
          col("ORef.ODS_Code") == col("a.OrgIDReferringOrg"),
          how="left"
      )
      # (12) ProfGroup lookup
      .join(
          ProfGroup.alias("pg"),
          col("pg.Code") == col("a.ReferringCareProfessionalStaffGroup"),
          how="left"
      )
      # (13) SourceOfReferral lookup
      .join(
          SrcOfReferral.alias("sor"),
          col("sor.Code") == col("a.SourceOfReferralMH"),
          how="left"
      )
      # (14) ReasonOAR lookup; both sides are cast to string before comparing
      .join(
          ReasonOAR.alias("oop"),
          col("oop.Code").cast("string") == col("a.ReasonOAT").cast("string"),
          how="left"
      )
      # (15) ExistRef: a non-null er.UniqServReqID means the referral is in ExistRef
      .join(
          ExistRef.alias("er"),
          col("er.UniqServReqID") == col("a.UniqServReqID"),
          how="left"
      )
      # (16) CCPre, joined as a DataFrame instead of a driver-side isin(list)
      .join(
          CCPre.alias("ccpre"),
          col("ccpre.UniqServReqID") == col("a.UniqServReqID"),
          "left"
      )
      # Scope: (not in ExistRef AND in-scope provider AND received on/after
      # fin_yearStart) OR present in CCPre.
      .filter(
          (
            col("er.UniqServReqID").isNull()
            & (col("a.OrgIDProv").isNotNull() & col("a.OrgIDProv").isin(prov_list))
            & (to_date(col("a.ReferralRequestReceivedDate")) >= fin_yearStart)
          )
          |
          col("ccpre.UniqServReqID").isNotNull()
      )
      .select(
          col("dt.Financial_Year").alias("Referral_Fin_Year"),
          col("dt.Month_Start_Date").alias("Referral_Month"),
          # Rank within (referral id, provider); rank 1 is the most recent row.
          row_number().over(
            Window()
              .partitionBy(col("a.UniqServReqID"), col("a.OrgIDProv"))
              .orderBy(
                col("a.UniqMonthID").desc(),
                col("a.UniqSubmissionID").desc(),
                col("st.UniqSubmissionID").desc(),
                col("st.Effective_From").desc()
              )
          ).alias("RowOrder"),
          col("a.UniqServReqID"),
          col("a.Der_Person_ID"),
          col("a.RecordNumber"),
          col("a.Person_ID"),
          col("a.UniqSubmissionID"),
          col("a.UniqMonthID"),
          col("a.FirstContactEverDate"),
          col("a.ReferralRequestReceivedDate"),
          col("a.ReferralRequestReceivedTime"),
          # 1 when a rejection date is present, otherwise NULL
          when(col("a.ReferRejectionDate").isNotNull(), lit(1)).alias("Referral_Rejected_Flag"),
          col("a.ReferRejectionDate"),
          col("a.ServDischDate"),
          col("b.Der_Pseudo_NHS_Number"),
          col("b.LSOA2011").alias("Patient_LSOA"),
          col("b.DefaultPostcode").alias("Patient_PostCode"),
          col("b.Gender"),
          col("b.EmploymentNationalLatest"),
          col("b.AccommodationNationalLatest"),
          col("b.EthnicCategory"),
          col("b.NHSDEthnicity"),
          col("a.AgeServReferRecDate"),
          # 1 when a hospital spell row matched in join (10), otherwise NULL
          when(col("ho.MHS501UniqID").isNotNull(), lit(1)).alias("Inpatient_Services_Flag"),
          # Region code Y56 is labelled London; anything else, including no match,
          # is labelled out of London / not recorded.
          when(col("c.Region_Code") == "Y56", lit("London_Patient"))
           .otherwise(lit("Out_of_London_or_Not_Recorded"))
           .alias("Patient_Region"),
          col("a.PrimReasonReferralMH"),
          # Older, wider team-type list kept alongside the current A06-only flag
          when(
            ltrim(rtrim(col("st.ServTeamTypeRefToMH"))).isin(
              "A05","A06","A08","A09","A12","A13","A16","C03","C10"
            ),
            lit(1)
          ).alias("Core_Community_Service_Team_Flag_OLD"),
          when(ltrim(rtrim(col("st.ServTeamTypeRefToMH"))) == "A06", lit(1))
            .alias("Core_Community_Service_Team_Flag"),
          col("gpd.Practice_Code").alias("ODS_GPPrac_OrgCode"),
          col("gpd.PCDS_NoGaps").alias("ODS_GPPrac_PostCode"),
          col("gpd.GP_Code").alias("MPI_GP_Code"),
          col("gpd.GP_Name").alias("Registered_GP_Practice_Name"),
          col("gpd.Local_Authority_Name").alias("GP_Local_Authority"),
          col("gpd.GP_Region_Name"),
          col("gpd.Lower_Super_Output_Area_Code").alias("GP_LSOA"),
          col("gpd.Longitude").alias("GP_Longitude"),
          col("gpd.Latitude").alias("GP_Latitude"),
          col("a.OrgIDReferringOrg").alias("OrgIDReferring"),
          col("ORef.Name").alias("Referring_Organisation"),
          col("ORef.role").alias("Referring_Org_Type"),
          col("rtd.serviceTypeName").alias("Type_of_Service_Referred_to"),
          col("a.SourceOfReferralMH"),
          col("sor.Description").alias("Source_of_Referral"),
          col("a.OrgIDProv"),
          # Detailed grouping of the source-of-referral code; codes not listed
          # (and NULL) fall through to "Missing/Invalid".
          when(col("a.SourceOfReferralMH") == "H1", lit("Emergency_Department"))
            .when(col("a.SourceOfReferralMH") == "H2", lit("Acute_Secondary_Care"))
            .when(col("a.SourceOfReferralMH").isin("A1","A2","A3","A4"), lit("Primary_Care"))
            .when(col("a.SourceOfReferralMH").isin("B1","B2"), lit("Self"))
            .when(col("a.SourceOfReferralMH").isin("E1","E2","E3","E4","E5","E6"), lit("Justice"))
            .when(
               col("a.SourceOfReferralMH").isin(
                 "F1","F2","F3","G1","G2","G3","G4","I1","I2",
                 "M1","M2","M3","M4","M5","M6","M7",
                 "C1","C2","C3","D1","D2","N3"
               ), lit("Other")
            )
            .when(col("a.SourceOfReferralMH") == "P1", lit("Internal"))
            .otherwise(lit("Missing/Invalid"))
            .alias("Source_of_Referral_Derived"),
          # Three-way grouping: H1, H2, everything else
          when(col("a.SourceOfReferralMH") == "H1", lit("Emergency_Department"))
            .when(col("a.SourceOfReferralMH") == "H2", lit("Acute_Secondary_Care"))
            .otherwise(lit("Other"))
            .alias("Source_of_Referral_Simplified"),
          # Priority code to label; "2" and "U" both map to Urgent
          when(col("a.ClinRespPriorityType") == "1", lit("Emergency"))
            .when(col("a.ClinRespPriorityType").isin("2","U"), lit("Urgent"))
            .when(col("a.ClinRespPriorityType") == "3", lit("Routine"))
            .when(col("a.ClinRespPriorityType") == "4", lit("Very_Urgent"))
            .otherwise(lit("Unknown"))
            .alias("Clinical_Response_Priority_Type"),
          col("pg.Description").alias("Referring_Care_Professional_Staff_Group"),
          col("oop.Description").alias("Reason_for_Out_of_Area_Referral"),
          # Residence commissioner code: Sub-ICB location first, CCG as fallback
          coalesce(col("b.OrgIDSubICBLocResidence"), col("b.OrgIDCCGRes")).alias("OrgIDCCGRes"),
          col("cc.New_Code"),
          col("sf.Der_IsLatest").alias("Der_IsLatest")
      )
)

# tempUR: the RowOrder = 1 row of every tempUR_all partition.
_is_top_ranked = col("RowOrder") == 1
tempUR = tempUR_all.filter(_is_top_ranked)

# ── 9. Build CC (unique contacts) ───────────────────────────────────────────────
# r: distinct referral attributes used to attach contacts to referrals;
# referrals without a received date are left out.
_referral_match_columns = [
    "UniqServReqID",
    "Der_Person_ID",
    "Type_of_Service_Referred_to",
    "ReferralRequestReceivedDate",
    "Clinical_Response_Priority_Type",
]
_distinct_referrals = tempUR.select(*_referral_match_columns).distinct()
r = _distinct_referrals.filter(col("ReferralRequestReceivedDate").isNotNull())

# 9b) Attended care contacts for the referrals in r, ranked so that each
#     (referral, contact date, contact time, mechanism) keeps one row.

# Predicates on the care-contact alias "cc", shared by the filter and the
# derived columns.
_cc_mechanism = col("cc.ConsMechanismMH")
_cc_attended = col("cc.AttendStatus").isin("5","6")
_cc_f2f_tel_talk = _cc_mechanism.isin("01","02","04")
_cc_video = (
    ((col("cc.UniqMonthID") < 1459) & (_cc_mechanism == "03"))
    | ((col("cc.UniqMonthID") >= 1459) & (_cc_mechanism == "11"))
)

# Contacts from submissions flagged Der_IsLatest = "Y".
_latest_contacts = (
    CareContact.alias("cc")
      .join(
          SubmissionF.alias("f"),
          col("f.NHSEUniqSubmissionID") == col("cc.NHSEUniqSubmissionID"),
          "left",
      )
      .filter(col("f.Der_IsLatest") == "Y")
)

# Add the service type recorded for the same referral, month and submission.
_contacts_with_service = _latest_contacts.join(
    ServiceType.alias("st"),
    (col("st.UniqServReqID") == col("cc.UniqServReqID"))
    & (col("st.UniqMonthID") == col("cc.UniqMonthID"))
    & (col("st.UniqSubmissionID") == col("cc.UniqSubmissionID")),
    "left",
)

# Keep contacts of selected referrals that fall on/after the referral date and
# were attended through an in-scope consultation mechanism.
_contacts_in_scope = (
    _contacts_with_service
      .join(
          r.alias("r"),
          (col("r.UniqServReqID") == col("cc.UniqServReqID"))
          & (col("r.Der_Person_ID") == col("cc.Der_Person_ID"))
          & (col("r.ReferralRequestReceivedDate") <= col("cc.CareContDate")),
          "inner",
      )
      .filter(_cc_attended & (_cc_f2f_tel_talk | _cc_video))
)

# Rank duplicates of the same contact; the highest UniqSubmissionID ranks first.
_per_contact = (
    Window.partitionBy("cc.UniqServReqID","cc.CareContDate","cc.CareContTime","cc.ConsMechanismMH")
          .orderBy(col("cc.UniqSubmissionID").desc())
)

_contact_type_desc = (
    when(_cc_f2f_tel_talk, lit("face to face, telephone or talk type"))
    .when(_cc_video, lit("video"))
    .otherwise(lit(None))
)
# Each flag below is 1 when its condition holds and NULL otherwise.
_der_contact = when(
    _cc_mechanism.isin("01","02","04","11")
    | ((col("cc.OrgIDProv") == "DFC") & _cc_mechanism.isin("05","09","10","13")),
    lit(1),
)
_der_direct_contact = when(
    col("cc.AttendStatus").isin("5","6") & _cc_mechanism.isin("01","02","04","11"),
    lit(1),
)
_der_face_to_face = when(
    col("cc.AttendStatus").isin("5","6") & _cc_mechanism.isin("01","11"),
    lit(1),
)

CC_with_row2 = (
    _contacts_in_scope
      .withColumn("RowOrder", row_number().over(_per_contact))
      .select(
          col("RowOrder"),
          col("cc.UniqServReqID").alias("CC_UniqServReqID"),
          col("cc.RecordNumber"),
          col("cc.Der_Person_ID"),
          col("st.ServTeamTypeRefToMH").alias("Service_Team_Type_Code"),
          col("r.Type_of_Service_Referred_to").alias("Type_of_Service_Referred_to"),
          col("r.ReferralRequestReceivedDate").alias("ReferralRequestReceivedDate"),
          col("cc.CareContDate"),
          col("cc.CareContTime"),
          col("cc.UniqSubmissionID"),
          col("cc.MHS201UniqID"),
          col("cc.AttendStatus"),
          col("cc.ConsMechanismMH"),
          col("Clinical_Response_Priority_Type"),
          _contact_type_desc.alias("ContactTypeDesc"),
          _der_contact.alias("Der_Contact"),
          _der_direct_contact.alias("Der_DirectContact"),
          _der_face_to_face.alias("Der_FacetoFaceContact"),
      )
)

# 9c) CC: one row per contact (RowOrder = 1)
CC = CC_with_row2.filter(col("RowOrder") == 1)

# ── 10. Build RefCC base (#d): lookups for demographics, reason and postcode ────
_data_dictionary_root = (
    "abfss://unrestricted@udalstdatacuratedprod.dfs.core.windows.net/"
    + "reference/UKHD/Data_Dictionary/"
)
_mhuec_reference_root = (
    "abfss://analytics-projects@udalstdataanalysisprod.dfs.core.windows.net/"
    + "PATLondon/MHUEC_Reference_Files/"
)


def _latest_rows(frame, alias_name):
    """Alias *frame* as *alias_name* and keep rows whose Is_Latest equals 1."""
    aliased = frame.alias(alias_name)
    return aliased.filter(col(f"{alias_name}.Is_Latest") == lit(1))


# Gender identity codes (Parquet)
genderDF = spark.read.parquet(
    _data_dictionary_root + "Gender_Identity_Code_SCD/Published/1/"
)

# Ethnic category codes (Parquet)
df_EthnicityLondon = (
    spark.read
         .option("header", "true")
         .option("recursiveFileLookup", "true")
         .parquet(_data_dictionary_root + "Ethnic_Category_Code_SCD/Published/1/")
)

# Primary reason for referral (Parquet)
df_PrimaryReasonReferral = (
    spark.read
         .option("header", "true")
         .option("recursiveFileLookup", "true")
         .parquet(_mhuec_reference_root + "Primary_Reason_For_Referral/")
)

# Postcode to local authority lookup (Delta)
df_postcode_lookup = (
    spark.read
         .format("delta")
         .load(_mhuec_reference_root + "PostCode_to_LA/")
)

# Aliased views used by the joins that build `d`; the two code lists are
# limited to their latest rows.
ec_lu  = _latest_rows(df_EthnicityLondon, "ec")
pm_lu  = df_PrimaryReasonReferral.alias("pm")
gdf_lu = _latest_rows(genderDF, "gdf")

# d: tempUR plus ethnicity, primary-reason, gender and postcode descriptions.

# Postcode key: first four characters of the first non-null candidate column.
# NOTE(review): matching on a 4-character prefix can pair one referral with
# several lookup rows if the lookup holds one row per full postcode — confirm
# the lookup's grain.
_lookup_postcode_prefix = coalesce(
    col("pc.Postcode_3"), col("pc.Postcode_1"), col("pc.PCDS_NoGaps")
).substr(1, 4)
_referral_postcode_prefix = coalesce(
    col("d.Patient_PostCode"), col("d.ODS_GPPrac_PostCode")
).substr(1, 4)

# Collapsed ethnic grouping; categories not remapped keep the lookup's Category.
_derived_broad_ethnicity = (
    when(col("ec.Category") == "Asian or Asian British", lit("Asian"))
    .when(col("ec.Category") == "Black or Black British", lit("Black"))
    .when(
        col("ec.Main_Description").isin(
            "mixed", "Any other ethnic group", "White & Black Caribbean",
            "Any other mixed background", "Chinese"
        ),
        lit("Mixed/Other"),
    )
    .otherwise(col("ec.Category"))
)

_referrals_with_lookups = (
    tempUR.alias("d")
      .join(ec_lu, col("ec.Main_Code_Text") == col("d.EthnicCategory"), "left")
      .join(pm_lu, col("pm.Code") == col("d.PrimReasonReferralMH"), "left")
      .join(gdf_lu, col("gdf.Main_Code_Text") == col("d.Gender"), "left")
      .join(
          df_postcode_lookup.alias("pc"),
          _lookup_postcode_prefix == _referral_postcode_prefix,
          "left",
      )
)

d = (
    _referrals_with_lookups
      # Referrals without a received date are not carried forward.
      .filter(col("d.ReferralRequestReceivedDate").isNotNull())
      .select(
          col("d.*"),  # every tempUR column, unchanged
          col("ec.Main_Description").alias("Ethnic_Category"),
          col("ec.Category").alias("Broad_Ethnic_Category"),
          col("pc.Local_Authority_Name").alias("Patient_Postcode_Borough"),
          col("pc.Lower_Super_Output_Area_Code").alias("Patient_Postcode_LSOA"),
          col("pm.Description").alias("Primary_Reason_For_Referral"),
          col("gdf.Main_Description").alias("Gender_Description"),
          _derived_broad_ethnicity.alias("Derived_Broad_Ethnic_Category"),
      )
)
# Evaluates the column names/types of d; shown only when this is the last
# expression of a notebook cell.
d.dtypes

# Ethnicity population reference for London (Parquet), joined to referrals
# further down on Borough and Broad_Ethnic_Category.
df_ethnicity_population = (
    spark.read
         .option("header", "true")
         .option("recursiveFileLookup", "true")   # files may sit in subfolders
         .parquet(
             "abfss://analytics-projects@udalstdataanalysisprod.dfs.core.windows.net/"
             + "PATLondon/MHUEC_Reference_Files/Ethnicity_Population/London/"
         )
)

# Provider reference, already restricted to Is_Latest = 1 when it was loaded.
Providers = dfs["AllProviders"]

# Debug previews stay commented out, like the other display() calls in this
# file: display() exists only inside a notebook, so a live call raises
# NameError elsewhere, and it runs an extra Spark job on every execution.
#display(Providers)
#df_ethnicity_population.show(10, truncate=False)
# ── 11. Build final “newRowsDF” ─────────────────────────────────────────────────
# RefCC: every row of d with its care contacts (if any), the provider name and
# the borough/ethnicity population rate.

# Stand-in closure date for referrals with neither a rejection nor a discharge date.
MaxDate = to_date(lit("2099-12-31"))

# Age band: up to and including 18 is "CYP", above 18 is "Adult", NULL age stays NULL.
_age_at_referral = col("d.AgeServReferRecDate")
_age_group = (
    when((_age_at_referral <= 18) & _age_at_referral.isNotNull(), lit("CYP"))
    .when((_age_at_referral > 18) & _age_at_referral.isNotNull(), lit("Adult"))
    .otherwise(lit(None))
)

# 100000 / population; a zero population gives NULL instead of a division by zero.
_population = col("ep.Value").cast("float")
_rate_per_100000 = (lit(1) / when(_population != 0, _population)) * lit(100000)

# Closure date: rejection date, else discharge date, else MaxDate.
_closure_date = coalesce(col("d.ReferRejectionDate"), col("d.ServDischDate"), MaxDate)

_referrals_with_contacts = d.alias("d").join(
    CC.alias("c"),
    col("c.CC_UniqServReqID") == col("d.UniqServReqID"),
    "left",
)
_with_provider = _referrals_with_contacts.join(
    Providers.alias("pro"),
    (col("pro.Organisation_Code") == col("d.OrgIDProv"))
    & (col("pro.Is_Latest") == lit(1)),
    "left",
)
_with_population = _with_provider.join(
    df_ethnicity_population.alias("ep"),
    (col("ep.Borough") == col("d.Patient_Postcode_Borough"))
    & (col("ep.Broad_Ethnic_Category") == col("d.Derived_Broad_Ethnic_Category")),
    "left",
)

RefCC = _with_population.select(
    # Identifiers and patient geography
    col("d.UniqServReqID"),
    col("d.Der_Person_ID"),
    col("d.Der_Pseudo_NHS_Number"),
    col("d.Patient_Postcode_LSOA"),
    col("d.Patient_PostCode"),
    # Demographics
    col("d.Gender_Description").alias("Gender"),
    _age_at_referral.alias("Age_at_Referral"),
    _age_group.alias("Age_Group"),
    col("d.Ethnic_Category"),
    col("d.Broad_Ethnic_Category"),
    col("d.Derived_Broad_Ethnic_Category"),
    _rate_per_100000.alias("Ethnic_proportion_per_100000_of_London_Borough_2020"),
    col("d.EmploymentNationalLatest"),
    col("d.AccommodationNationalLatest"),
    # GP registration and commissioner
    coalesce(col("d.ODS_GPPrac_OrgCode"), col("d.MPI_GP_Code")).alias("Registered_GP_Practice_OrgCode"),
    col("d.Registered_GP_Practice_Name"),
    col("d.GP_Local_Authority"),
    col("d.OrgIDCCGRes"),
    col("d.GP_Region_Name"),
    # Referral details
    col("d.ReferralRequestReceivedDate"),
    col("d.Primary_Reason_For_Referral"),
    col("d.SourceOfReferralMH"),
    col("d.Source_of_Referral"),
    col("d.Source_of_Referral_Derived"),
    col("d.Source_of_Referral_Simplified"),
    # Provider and referring organisation
    col("d.OrgIDProv"),
    col("pro.Organisation_Name").alias("Provider"),
    col("d.OrgIDReferring").alias("OrgIDReferring"),
    col("d.Referring_Organisation"),
    col("d.Referring_Org_Type"),
    col("d.Referring_Care_Professional_Staff_Group"),
    col("d.Reason_for_Out_of_Area_Referral"),
    # Dates, contact details and outcome flags
    col("d.FirstContactEverDate"),
    col("d.ReferralRequestReceivedTime"),
    col("c.CareContDate"),
    col("c.CareContTime"),
    col("d.ServDischDate"),
    col("d.Referral_Rejected_Flag"),
    col("d.ReferRejectionDate"),
    col("c.ContactTypeDesc"),
    col("c.Der_Contact"),
    col("c.Der_DirectContact"),
    col("c.Der_FacetoFaceContact"),
    # Always-NULL placeholder
    lit(None).cast("string").alias("Face_to_Face_Order"),
    # Intervals in days
    datediff(col("c.CareContDate"), col("d.ReferralRequestReceivedDate")).alias("Days_Between_Referral_and_Care_Contact"),
    datediff(_closure_date, col("d.ReferralRequestReceivedDate")).alias("Days_Between_Referral_and_Closure"),
    # Always-NULL placeholders
    lit(None).cast("long").alias("Episode_Order"),
    lit(None).cast("long").alias("Overall_Order"),
    col("d.Type_of_Service_Referred_to"),
    col("d.Clinical_Response_Priority_Type"),
)

# The rows to be written to the target table.
newRowsDF = RefCC

# === Write / Merge Logic ===
# 'full'   : replace the whole target dataset with newRowsDF.
# 'update' : delete the target rows of every referral rebuilt in this run, then
#            append the rebuilt rows.
#            NOTE(review): the delete and the append are separate Delta commits,
#            so a failure between them leaves the deleted referrals missing
#            until the job is re-run.
target_path = core_map["ReferralsWithContacts"]

if RUN_MODE == "full":
    # Full rebuild (heavy): replace entire dataset, schema included
    (newRowsDF
        .coalesce(32)            # reduce small files
        .write
        .format("delta")
        .mode("overwrite")
        .option("overwriteSchema", "true")
        .save(target_path)
    )
else:
    # Incremental update without collecting ids to the driver
    idsDF = tempUR.select("UniqServReqID").distinct()
    idsDF.createOrReplaceTempView("upd_ids")

    # Delete affected referrals only when the target already exists. The earlier
    # "ensure target exists" step created a placeholder table with a single
    # untyped NULL column, a schema the append below cannot write into; when the
    # target is absent there is nothing to delete and the append creates the
    # table with newRowsDF's schema.
    if DeltaTable.isDeltaTable(spark, target_path):
        spark.sql(f"DELETE FROM delta.`{target_path}` WHERE UniqServReqID IN (SELECT UniqServReqID FROM upd_ids)")

    # Append rebuilt rows
    (newRowsDF
        .coalesce(16)
        .write
        .format("delta")
        .mode("append")
        .save(target_path)
    )
