# Databricks notebook source
# =====================================================
# SIMPLE CARE GAPS ETL - BEGINNER FRIENDLY
# Read from ADLS, Create Delta Tables
# =====================================================

print("Starting Care Gaps ETL...")

# COMMAND ----------

# MAGIC %md
# MAGIC ## 1. Configuration (edit the storage settings below for your environment)

# COMMAND ----------

# Get parameters passed from ADF pipeline.
# NOTE: the *_count widgets only need to exist so ADF can pass those
# parameters; this notebook recomputes every row count itself and never
# reads them.
dbutils.widgets.text("run_date", "", "Run Date (yyyy-MM-dd)")
dbutils.widgets.text("environment", "dev", "Environment")
dbutils.widgets.text("care_gaps_count", "0", "Care Gaps Row Count")
dbutils.widgets.text("appointments_count", "0", "Appointments Row Count")
dbutils.widgets.text("patient_summary_count", "0", "Patient Summary Row Count")
dbutils.widgets.text("provider_metrics_count", "0", "Provider Metrics Row Count")
dbutils.widgets.text("campaign_opportunities_count", "0", "Campaign Opportunities Row Count")
# Location of the ADLS account key in a Databricks secret scope (see STORAGE_KEY).
# NOTE(review): these defaults are placeholders — confirm the scope/secret names
# that exist in your workspace (or override them from ADF).
dbutils.widgets.text("secret_scope", "adls", "Secret Scope for Storage Key")
dbutils.widgets.text("storage_key_secret", "duse1achstdbx1-access-key", "Secret Name for Storage Key")

from datetime import datetime

RUN_DATE = dbutils.widgets.get("run_date").strip()
ENVIRONMENT = dbutils.widgets.get("environment")

# If run_date not provided (manual run), use today's date
if not RUN_DATE:
    RUN_DATE = datetime.now().strftime("%Y-%m-%d")

# RUN_DATE becomes part of the landing-zone path, so reject malformed values
# here with a clear message instead of failing later on a missing path.
try:
    datetime.strptime(RUN_DATE, "%Y-%m-%d")
except ValueError as err:
    raise ValueError(f"run_date must be in yyyy-MM-dd format, got {RUN_DATE!r}") from err

print(f"Run Date: {RUN_DATE}")
print(f"Environment: {ENVIRONMENT}")

# Storage account configuration - CHANGE THESE TO YOUR VALUES
STORAGE_ACCOUNT = "duse1achstdbx1"  # Your storage account name abfss://dev@duse1achstdbx1.dfs.core.windows.net/
CONTAINER = "dev"       # Your container name

# SECURITY: the storage account key is a credential and must never be committed
# to source. It is read from a Databricks secret scope instead (secret values
# are redacted in notebook output). The key previously hard-coded in this file
# is exposed to anyone with access to its history and should be rotated in
# Azure Portal -> Storage Account -> Access Keys.
STORAGE_KEY = dbutils.secrets.get(
    scope=dbutils.widgets.get("secret_scope"),
    key=dbutils.widgets.get("storage_key_secret"),
)

# Configure Spark to access ADLS
spark.conf.set(
    f"fs.azure.account.key.{STORAGE_ACCOUNT}.dfs.core.windows.net",
    STORAGE_KEY
)

print(f"✓ Configured access to storage account: {STORAGE_ACCOUNT}")

# COMMAND ----------

# MAGIC %md
# MAGIC ## 2. Define Paths (All in one place)

# COMMAND ----------

# Root URI of the ADLS Gen2 container.
BASE_PATH = f"abfss://{CONTAINER}@{STORAGE_ACCOUNT}.dfs.core.windows.net"

# Landing zone written by ADF; each run's Parquet extracts sit in a
# sub-folder named after RUN_DATE.
LANDING_PATH = "/".join([BASE_PATH, "landing", "chmca_custom", "caregaps", RUN_DATE])

# Unity Catalog namespaces. Despite the *_PATH names these are
# "catalog.schema" prefixes, not storage paths: tables are addressed as
# f"{BRONZE_PATH}.<table_name>".
CATALOG = "dev_kiddo"
BRONZE_SCHEMA, SILVER_SCHEMA, GOLD_SCHEMA = "bronze", "silver", "gold"

BRONZE_PATH, SILVER_PATH, GOLD_PATH = (
    f"{CATALOG}.{schema_name}"
    for schema_name in (BRONZE_SCHEMA, SILVER_SCHEMA, GOLD_SCHEMA)
)

print("Paths configured:")
for _label, _value in (
    ("  Landing: ", LANDING_PATH),
    ("  Bronze:  ", BRONZE_PATH),
    ("  Silver:  ", SILVER_PATH),
    ("  Gold:    ", GOLD_PATH),
):
    print(f"{_label}{_value}")

# COMMAND ----------

# MAGIC %md
# MAGIC ## 3. Read Parquet Files from Landing Zone

# COMMAND ----------

_banner = "=" * 60
print("\n" + _banner)
print("STEP 1: Reading Parquet files from landing zone...")
print(_banner)


def _read_landing_file(file_name, label):
    """Load one ADF extract from LANDING_PATH, print its row count, return (df, count)."""
    frame = spark.read.parquet(f"{LANDING_PATH}/{file_name}")
    n_rows = frame.count()
    print(f"✓ {label}: {n_rows:,} rows")
    return frame, n_rows


# File names match ADF sink activity output names
df_care_gaps, care_gaps_count = _read_landing_file(
    "CareGaps_daily.parquet", "Care Gaps"
)
df_appointments, appointments_count = _read_landing_file(
    "Appointments_daily.parquet", "Appointments"
)
df_patient_summary, patient_summary_count = _read_landing_file(
    "PatientGapsSummary_daily.parquet", "Patient Summary"
)
df_provider_metrics, provider_metrics_count = _read_landing_file(
    "ProviderMetrics_daily.parquet", "Provider Metrics"
)
df_campaign_opportunities, campaign_opportunities_count = _read_landing_file(
    "CampaignOpportunities_daily.parquet", "Campaign Opportunities"
)

print("\n✓ All files read successfully!")

# COMMAND ----------

# MAGIC %md
# MAGIC ## 4. Create Bronze Delta Tables (Raw Data)

# COMMAND ----------

_banner = "=" * 60
print("\n" + _banner)
print("STEP 2: Creating Bronze Delta tables...")
print(_banner)

# Raw landing DataFrames paired with their Bronze table names, written in order.
_bronze_targets = (
    (df_care_gaps, "care_gaps_daily"),
    (df_appointments, "appointments_daily"),
    (df_patient_summary, "patient_summary_daily"),
    (df_provider_metrics, "provider_metrics_daily"),
    (df_campaign_opportunities, "campaign_opportunities_daily"),
)

for _frame, _table in _bronze_targets:
    # Full overwrite each run: Bronze keeps only the latest raw snapshot.
    _frame.write.format("delta").mode("overwrite").saveAsTable(
        f"{BRONZE_PATH}.{_table}"
    )
    print(f"✓ Bronze: {_table} created")

# COMMAND ----------

# MAGIC %md
# MAGIC ## 5. Create Silver Delta Tables (Cleaned Data)

# COMMAND ----------

# Explicit imports instead of `from pyspark.sql.functions import *`: the star
# import replaced Python builtins such as `sum`, `max`, `abs` and `round` with
# Spark column functions for every later cell (e.g. builtin-style
# `sum(generator)` then raises). Only the names this notebook uses are
# imported. NOTE: `min` and `count` are still the Spark versions from here on
# (used below and in the Gold cell).
from pyspark.sql.functions import col, count, countDistinct, min, row_number, when
from pyspark.sql.window import Window
from delta.tables import DeltaTable

print("\n" + "="*60)
print("STEP 3: Creating Silver Delta tables...")
print("="*60)

# Silver: Care Gaps (cleaned) — drop rows missing PAT_ID or GAP_TYPE and add a
# readable label: PRIORITY_LEVEL 1 -> Critical, 2 -> Important, else Routine.
df_care_gaps_clean = df_care_gaps \
    .filter(col("PAT_ID").isNotNull()) \
    .filter(col("GAP_TYPE").isNotNull()) \
    .withColumn("PRIORITY_NAME",
                when(col("PRIORITY_LEVEL") == 1, "Critical")
                .when(col("PRIORITY_LEVEL") == 2, "Important")
                .otherwise("Routine"))

df_care_gaps_clean.write.format("delta").mode("overwrite").saveAsTable(f"{SILVER_PATH}.care_gaps_cleaned")

silver_care_gaps_count = df_care_gaps_clean.count()
print(f"✓ Silver: care_gaps_cleaned ({silver_care_gaps_count:,} rows)")

# Silver: Patient 360 — patient summary left-joined with each patient's
# earliest appointment date and smallest DAYS_UNTIL_APPT.
df_patient_360 = df_patient_summary.join(
    df_appointments.groupBy("PAT_ID").agg(
        min("APPT_DATE").alias("FIRST_APPT_DATE"),
        min("DAYS_UNTIL_APPT").alias("DAYS_UNTIL_FIRST_APPT")
    ),
    "PAT_ID",
    "left"
)

df_patient_360.write.format("delta").mode("overwrite").saveAsTable(f"{SILVER_PATH}.patient_360")

silver_patient_360_count = df_patient_360.count()
print(f"✓ Silver: patient_360 ({silver_patient_360_count:,} rows)")

# ------------------------------------------------------------------
# Silver: Campaign Opportunities — MERGE to preserve llm_message & status
# ------------------------------------------------------------------
campaign_table_name = f"{SILVER_PATH}.campaign_opportunities"

df_campaign_clean = df_campaign_opportunities \
    .filter(col("patient_mrn").isNotNull()) \
    .filter(col("campaign_type").isNotNull())

# Deduplicate source: keep one row per (patient_mrn, subject_mrn, campaign_type)
# — the one with the latest appointment_date — to avoid the
# "multiple source rows matched same target row" MERGE error.
# Window partitioning groups NULL subject_mrn values together, which matches
# the null-safe merge condition below.
w = Window.partitionBy("patient_mrn", "subject_mrn", "campaign_type") \
          .orderBy(col("appointment_date").desc())
df_campaign_dedup = df_campaign_clean \
    .withColumn("_rn", row_number().over(w)) \
    .filter(col("_rn") == 1) \
    .drop("_rn")

# Null-safe equality (<=>): subject_mrn is not NULL-filtered above, and with
# plain "=" a NULL key never matches, so those opportunities would be
# re-inserted as duplicate rows on every run.
campaign_merge_condition = """target.patient_mrn <=> source.patient_mrn
       AND target.subject_mrn <=> source.subject_mrn
       AND target.campaign_type <=> source.campaign_type"""

if spark.catalog.tableExists(campaign_table_name):
    # MERGE: refresh the staging columns but PRESERVE llm_message and status,
    # which are written by later steps and must survive the daily reload.
    preserve_cols = {"llm_message", "status"}
    update_set = {
        c: f"source.{c}"
        for c in df_campaign_dedup.columns
        if c.lower() not in preserve_cols
    }

    delta_table = DeltaTable.forName(spark, campaign_table_name)
    delta_table.alias("target").merge(
        df_campaign_dedup.alias("source"),
        campaign_merge_condition,
    ).whenMatchedUpdate(
        set=update_set
    ).whenNotMatchedInsertAll(
    ).execute()

    silver_campaign_count = spark.table(campaign_table_name).count()
    print(f"✓ Silver: campaign_opportunities MERGED ({silver_campaign_count:,} rows) — llm_message & status preserved")
else:
    # First run — create the table
    df_campaign_dedup.write.format("delta").mode("overwrite").saveAsTable(campaign_table_name)
    silver_campaign_count = df_campaign_dedup.count()
    print(f"✓ Silver: campaign_opportunities CREATED ({silver_campaign_count:,} rows)")

print("\n✓ All Silver tables created!")

# COMMAND ----------

# MAGIC %md
# MAGIC ## 5b. Generate LLM Messages (Run Independently)
# MAGIC 
# MAGIC **This cell is self-contained.** Run it any time to generate LLM messages
# MAGIC for campaign opportunities that don't have one yet.
# MAGIC 
# MAGIC - Reads directly from `dev_kiddo.silver.campaign_opportunities`
# MAGIC - Only processes unique (patient_mrn, subject_mrn, campaign_type) combinations that still need a message
# MAGIC - Uses Delta MERGE with deduplicated source to avoid conflicts
# MAGIC - Rate-limited to 3 concurrent workers with retry for 429 errors
# MAGIC - Safe to re-run — skips rows that already have messages
# MAGIC
# MAGIC **Prerequisites:** Run `%pip install openai` in a cell above if not already installed.

# COMMAND ----------

import os
import time
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
import pandas as pd
from openai import OpenAI
from delta.tables import DeltaTable

# Suppress noisy MLflow tracing warnings
logging.getLogger("mlflow.tracing").setLevel(logging.ERROR)

print("\n" + "="*60)
print("GENERATING LLM MESSAGES FOR CAMPAIGN OPPORTUNITIES")
print("="*60)

# ---- Configuration ----
CAMPAIGN_TABLE = "dev_kiddo.silver.campaign_opportunities"
LLM_ENDPOINT = "databricks-meta-llama-3-3-70b-instruct"
MAX_WORKERS = 3          # Stay under workspace QPS limit
MAX_RETRIES = 5          # Attempts per row when the endpoint rate-limits (429)
BASE_DELAY = 2.0         # Initial backoff delay in seconds (doubles each attempt)

# ---- Read UNIQUE rows that need messages (with full context) ----
df_needs_messages = spark.sql(f"""
    SELECT patient_mrn, subject_mrn, campaign_type,
           FIRST(patient_name) AS patient_name,
           FIRST(subject_name) AS subject_name,
           FIRST(appointment_date) AS appointment_date,
           FIRST(appointment_location) AS appointment_location,
           FIRST(has_asthma) AS has_asthma,
           FIRST(last_flu_vaccine_date) AS last_flu_vaccine_date,
           FIRST(llm_prompt_context) AS llm_prompt_context
    FROM {CAMPAIGN_TABLE}
    WHERE llm_prompt_context IS NOT NULL
      AND TRIM(llm_prompt_context) != ''
      AND (llm_message IS NULL OR TRIM(llm_message) = '')
    GROUP BY patient_mrn, subject_mrn, campaign_type
""")

rows_to_process = df_needs_messages.count()
print(f"  Unique rows needing LLM messages: {rows_to_process}")

if rows_to_process == 0:
    print("✓ All rows already have LLM messages — nothing to do")
else:
    # ---- Set up LLM client (OpenAI-compatible Databricks serving endpoint) ----
    token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()
    workspace_url = spark.conf.get("spark.databricks.workspaceUrl")

    client = OpenAI(
        api_key=token,
        base_url=f"https://{workspace_url}/serving-endpoints"
    )

    # ---- Flu Vaccine Piggybacking System Prompt ----
    SYSTEM_PROMPT = (
        "You are a message writer for Akron Children's Hospital's "
        "Flu Vaccine Piggybacking campaign. The goal: when one child in a "
        "household has an upcoming appointment, we message the parent to "
        "bring a SIBLING who is overdue for their flu vaccine to that same visit. "
        "Generate a single SMS message (160 characters max). "
        "Be cheerful, positive, and professional. "
        "The message should focus on the FLU VACCINE opportunity for the sibling, "
        "NOT on the existing appointment itself. "
        "Output ONLY the message text, nothing else."
    )

    def build_user_prompt(row):
        """Build a structured prompt from opportunity row data."""
        patient = row.get("patient_name") or "your child"
        subject = row.get("subject_name") or ""
        appt_date = str(row.get("appointment_date") or "")
        appt_loc = str(row.get("appointment_location") or "")
        has_asthma = str(row.get("has_asthma") or "N").upper() == "Y"
        last_vax = str(row.get("last_flu_vaccine_date") or "")
        # Sibling piggybacking = patient is different from subject
        is_sibling = (patient.strip().lower() != subject.strip().lower()) if subject else False

        if is_sibling:
            prompt = (
                f"{patient} is overdue for a flu vaccine. "
                f"Their household member {subject} has an upcoming appointment "
                f"at {appt_loc} on {appt_date}. "
                f"Write an SMS to the parent suggesting they bring {patient} "
                f"for a flu shot during {subject}'s visit."
            )
        else:
            prompt = (
                f"{patient} is overdue for a flu vaccine and has an upcoming "
                f"appointment at {appt_loc} on {appt_date}. "
                f"Write an SMS suggesting they get their flu shot during that visit."
            )

        # str() of pandas missing values yields "nan"/"NaT", so filter those too.
        if last_vax and last_vax.lower() not in ("", "none", "never", "nan", "nat"):
            prompt += f" Their last flu vaccine was {last_vax} — remind them each vaccine only protects for one season."

        if has_asthma:
            prompt += f" IMPORTANT: {patient} has asthma, which puts them at higher risk for severe flu complications. Mention this."

        return prompt

    def generate_message(row_dict):
        """Call the LLM endpoint for one row, retrying on rate-limit errors.

        Returns the SMS text truncated to 160 characters, or None when the
        call fails, the model returns no text, or all attempts are rate-limited.
        """
        user_prompt = build_user_prompt(row_dict)
        for attempt in range(MAX_RETRIES):
            try:
                response = client.chat.completions.create(
                    model=LLM_ENDPOINT,
                    messages=[
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": user_prompt},
                    ],
                    max_tokens=80,
                    temperature=0.7,
                )
            except Exception as e:
                err_text = str(e)
                if "429" in err_text or "RATE_LIMIT" in err_text.upper():
                    # Exponential backoff; skip the pointless sleep after the last attempt.
                    if attempt < MAX_RETRIES - 1:
                        time.sleep(BASE_DELAY * (2 ** attempt))
                    continue
                print(f"  LLM error: {e}")
                return None

            # content can be None; treat that like an empty reply.
            msg = (response.choices[0].message.content or "").strip()
            # Models sometimes wrap the whole SMS in quotes; drop them.
            if len(msg) >= 2 and msg.startswith('"') and msg.endswith('"'):
                msg = msg[1:-1].strip()
            # An empty reply counts as a failure so the row stays eligible for
            # the next run instead of being stored as ''.
            return msg[:160] or None
        print(f"  Gave up after {MAX_RETRIES} retries (rate limited)")
        return None

    # ---- Generate messages with limited concurrency ----
    pdf = df_needs_messages.toPandas()
    messages = [None] * len(pdf)
    row_items = []

    # Use positional indices (not pandas index labels) so messages[pos] always
    # lines up with row `pos` of pdf when the list is assigned back below.
    for pos, (_, row) in enumerate(pdf.iterrows()):
        prompt_ctx = row.get("llm_prompt_context")
        if prompt_ctx and pd.notna(prompt_ctx) and str(prompt_ctx).strip():
            row_items.append((pos, row.to_dict()))

    print(f"  Processing {len(row_items)} prompts with {MAX_WORKERS} workers...")

    completed = 0
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        future_to_idx = {
            executor.submit(generate_message, row_dict): pos
            for pos, row_dict in row_items
        }
        for future in as_completed(future_to_idx):
            idx = future_to_idx[future]
            try:
                messages[idx] = future.result()
            except Exception as e:
                print(f"  Error for row {idx}: {e}")
            completed += 1
            if completed % 50 == 0:
                print(f"  Progress: {completed}/{len(row_items)}")

    pdf["llm_message"] = messages
    # Deliberately not using `sum(...)`: an earlier cell may have run
    # `from pyspark.sql.functions import *`, which replaces the builtin `sum`
    # with Spark's column aggregate, and sum(generator) then raises.
    generated_count = len([m for m in messages if m is not None])
    print(f"  Generated {generated_count}/{len(row_items)} messages")

    if generated_count == 0:
        # spark.createDataFrame cannot infer a schema from an empty pandas frame.
        print("  No messages generated — skipping MERGE")
    else:
        # ---- MERGE deduplicated source into Delta table ----
        df_updates = spark.createDataFrame(
            pdf.loc[pdf["llm_message"].notna(),
                    ["patient_mrn", "subject_mrn", "campaign_type", "llm_message"]]
        )

        # Null-safe equality (<=>): the GROUP BY above puts NULL subject_mrn
        # rows into one group, but plain "=" never matches NULL, so those
        # messages would be silently dropped and regenerated on every run.
        delta_table = DeltaTable.forName(spark, CAMPAIGN_TABLE)
        delta_table.alias("target").merge(
            df_updates.alias("source"),
            """target.patient_mrn <=> source.patient_mrn
               AND target.subject_mrn <=> source.subject_mrn
               AND target.campaign_type <=> source.campaign_type"""
        ).whenMatchedUpdate(
            set={"llm_message": "source.llm_message"}
        ).execute()

        print(f"✓ Updated rows with LLM messages in {CAMPAIGN_TABLE}")

# COMMAND ----------

# MAGIC %md
# MAGIC ## 6. Create Gold Delta Tables (Analytics)

# COMMAND ----------

# Import the Spark functions this cell uses so it does not depend on an
# earlier cell's imports having run in this session.
from pyspark.sql.functions import col, count, countDistinct, when

print("\n" + "="*60)
print("STEP 4: Creating Gold Delta tables...")
print("="*60)

# Gold: Gap Summary by Type — total gaps and distinct patients per
# (GAP_TYPE, PRIORITY_NAME), largest first.
df_gap_summary = df_care_gaps_clean.groupBy("GAP_TYPE", "PRIORITY_NAME") \
    .agg(
        count("*").alias("TOTAL_GAPS"),
        countDistinct("PAT_ID").alias("PATIENTS_AFFECTED")
    ) \
    .orderBy("TOTAL_GAPS", ascending=False)

df_gap_summary.write.format("delta").mode("overwrite").saveAsTable(f"{GOLD_PATH}.gap_summary")

gold_gap_summary_count = df_gap_summary.count()
print(f"✓ Gold: gap_summary ({gold_gap_summary_count:,} rows)")

# Gold: Provider Dashboard — GAP_RATE = TOTAL_GAPS / TOTAL_PATIENTS_WITH_GAPS.
# The denominator is guarded: under ANSI SQL mode x / 0 raises DIVIDE_BY_ZERO
# and fails the job; without ANSI it evaluates to NULL. The guard yields NULL
# in both modes.
df_provider_dashboard = df_provider_metrics.withColumn(
    "GAP_RATE",
    when(
        col("TOTAL_PATIENTS_WITH_GAPS") != 0,
        col("TOTAL_GAPS") / col("TOTAL_PATIENTS_WITH_GAPS"),
    ),
)

df_provider_dashboard.write.format("delta").mode("overwrite").saveAsTable(f"{GOLD_PATH}.provider_dashboard")

gold_provider_dashboard_count = df_provider_dashboard.count()
print(f"✓ Gold: provider_dashboard ({gold_provider_dashboard_count:,} rows)")

print("\n✓ All Gold tables created!")

# COMMAND ----------

# MAGIC %md
# MAGIC ## 7. Summary

# COMMAND ----------

_rule = "=" * 60

print("\n" + _rule)
print("ETL COMPLETE - SUMMARY")
print(_rule)
print(f"\nRun Date: {RUN_DATE}")

# Landing-zone row counts; labels carry their own padding for alignment.
print("\nData Loaded:")
for _label, _rows in (
    ("Care Gaps:      ", care_gaps_count),
    ("Appointments:   ", appointments_count),
    ("Patient Summary: ", patient_summary_count),
    ("Provider Metrics: ", provider_metrics_count),
    ("Campaign Opportunities: ", campaign_opportunities_count),
):
    print(f"  {_label}{_rows:,} rows")

print("\nDelta Tables Created:")

print("\nBronze Layer:")
for _table in (
    "care_gaps_daily",
    "appointments_daily",
    "patient_summary_daily",
    "provider_metrics_daily",
    "campaign_opportunities_daily",
):
    print(f"  {BRONZE_PATH}.{_table}")

print("\nSilver Layer:")
for _table, _rows in (
    ("care_gaps_cleaned", silver_care_gaps_count),
    ("patient_360", silver_patient_360_count),
    ("campaign_opportunities", silver_campaign_count),
):
    print(f"  {SILVER_PATH}.{_table} ({_rows:,} rows)")

print("\nGold Layer:")
for _table, _rows in (
    ("gap_summary", gold_gap_summary_count),
    ("provider_dashboard", gold_provider_dashboard_count),
):
    print(f"  {GOLD_PATH}.{_table} ({_rows:,} rows)")

print("\n" + _rule)
print("✓ SUCCESS - All data processed!")
print(_rule)

# COMMAND ----------

# MAGIC %md
# MAGIC ## 8. Preview Data (Optional)

# COMMAND ----------

# Uncomment to see sample data

# print("\nSample Care Gaps:")
# display(df_care_gaps_clean.limit(10))

# print("\nSample Patient 360:")
# display(df_patient_360.limit(10))

# print("\nGap Summary:")
# display(df_gap_summary)