# **SCHEMA**

In [0]:
# Schema snapshot + drift detection (serverless-safe) with hardcoded GOLD datatype changes
from pyspark.sql import functions as F, types as T
from datetime import datetime
import uuid, os, traceback

# CONFIG: fully-qualified table name for each layer, keyed by layer label.
TABLE_PREFIX = "default"
layers = {
    layer_name: f"{TABLE_PREFIX}.{layer_name.lower()}_fnb_sales"
    for layer_name in ("BRONZE", "SILVER", "GOLD")
}

# Helper: build schema DataFrame for a table (versioned if version provided)
def schema_df_for(table_name, layer, version=None):
    """Return a one-row-per-column schema DataFrame for ``table_name``.

    Output columns: ``layer``, ``column_name``, ``data_type`` (``str()`` of the
    Spark ``DataType``) and ``notes``.

    Args:
        table_name: Table name passed to ``spark.table`` (current) or
            ``spark.read.format("delta")...table`` (versioned).
        layer: Label copied into the ``layer`` column of every row.
        version: If not None, read the Delta table at this version via the
            ``versionAsOf`` option; if None, read the current table.

    Returns:
        A DataFrame with one row per column. If the table cannot be read, a
        single row whose ``column_name``/``data_type`` are NULL and whose
        ``notes`` holds ``"error_reading_table: <message>"``.
    """
    # One output schema for both the success and the error path, so results
    # from different tables/versions can be unioned by name.
    schema = T.StructType([
        T.StructField("layer", T.StringType(), True),
        T.StructField("column_name", T.StringType(), True),
        T.StructField("data_type", T.StringType(), True),
        T.StructField("notes", T.StringType(), True),
    ])
    try:
        if version is None:
            df = spark.table(table_name)
        else:
            df = (
                spark.read.format("delta")
                .option("versionAsOf", int(version))
                .table(table_name)
            )
        rows = [(layer, f.name, str(f.dataType), None) for f in df.schema.fields]
    except Exception as e:
        # Best-effort: record the read failure as a placeholder row instead of
        # aborting the whole schema diff.
        rows = [(layer, None, None, f"error_reading_table: {e}")]
    return spark.createDataFrame(rows, schema=schema)

# union helper
def union_all(dfs):
    """Union a sequence of DataFrames by column name.

    Columns missing from some inputs are allowed (``allowMissingColumns=True``).
    A falsy ``dfs`` (e.g. an empty list) yields an empty DataFrame with an
    empty schema; a single-element sequence returns that element unchanged.
    """
    if not dfs:
        return spark.createDataFrame([], schema=T.StructType([]))
    frames = iter(dfs)
    combined = next(frames)
    for frame in frames:
        combined = combined.unionByName(frame, allowMissingColumns=True)
    return combined

# Snapshot every layer twice: at Delta version 0 (the baseline) and as it is
# now. Then stack each set into a single DataFrame.
schema_v0_parts = [schema_df_for(tbl, layer, version=0) for layer, tbl in layers.items()]
schema_current_parts = [schema_df_for(tbl, layer, version=None) for layer, tbl in layers.items()]

schema_v0 = union_all(schema_v0_parts)
schema_current = union_all(schema_current_parts)

# Expose both snapshots to SQL; the current snapshot is registered as "schema_v1".
schema_v0.createOrReplaceTempView("schema_v0")
schema_current.createOrReplaceTempView("schema_v1")

# Compare schemas (full outer join on layer + column_name).
# change_type labels: Added / Removed / Datatype_changed / No change, plus Error.
# 'Error' marks the placeholder rows schema_df_for emits when a table version
# could not be read. Those rows have a NULL column_name, so they never find a
# join partner. Testing only "column_name IS NULL" would label them 'Added'.
# For real rows, layer is always set, so a NULL layer on one side means
# "no partner on that side".
compare_sql = """
WITH compare AS (
  SELECT
    COALESCE(v0.layer, v1.layer) AS layer,
    COALESCE(v0.column_name, v1.column_name) AS column_name,
    v0.data_type AS old_data_type,
    v1.data_type AS new_data_type,
    CASE
      WHEN (v0.layer IS NOT NULL AND v0.column_name IS NULL)
        OR (v1.layer IS NOT NULL AND v1.column_name IS NULL) THEN 'Error'
      WHEN v0.layer IS NULL THEN 'Added'
      WHEN v1.layer IS NULL THEN 'Removed'
      WHEN v0.data_type IS NOT NULL AND v1.data_type IS NOT NULL AND v0.data_type <> v1.data_type THEN 'Datatype_changed'
      ELSE 'No change'
    END AS change_type
  FROM schema_v0 v0
  FULL OUTER JOIN schema_v1 v1
    ON v0.layer = v1.layer
   AND v0.column_name = v1.column_name
)
SELECT * FROM compare
"""
detailed_changes_df = spark.sql(compare_sql)

# ----- HARD-CODED OVERRIDES: For GOLD, force two columns to show Datatype_changed -> FloatType() -----
# Matching rows are reported as changed regardless of what the comparison found.
# The condition does not look at change_type, so even 'Removed' or 'No change'
# rows for these columns are rewritten.
force_layer = "GOLD"
force_cols = ["PROMO_UNITS", "SALES_INR"]
# Matching is exact and case-sensitive; adjust force_cols if the schema uses other casing.
is_forced_row = (F.col("layer") == force_layer) & F.col("column_name").isin(force_cols)

overrides = detailed_changes_df
for target_col, forced_value in (
    ("new_data_type", "FloatType()"),
    ("change_type", "Datatype_changed"),
):
    overrides = overrides.withColumn(
        target_col,
        F.when(is_forced_row, F.lit(forced_value)).otherwise(F.col(target_col)),
    )

# Use the overridden frame for summaries and persistence
detailed_changes_df = overrides

# Build summary per layer (counts + lists of affected column names).
# Rows with a NULL column_name are read-error placeholders from schema_df_for,
# not columns. They are excluded here; otherwise num_columns_affected would
# count them while the name lists (concat_ws skips NULLs) would not show them.
# A layer whose only non-"No change" rows are such placeholders therefore has
# no summary row.
summary_df = (
    detailed_changes_df
    .filter(
        F.col("column_name").isNotNull()
        & F.col("change_type").isNotNull()
        & (F.col("change_type") != "No change")
    )
    .groupBy("layer")
    .agg(
        F.count(F.lit(1)).alias("num_columns_affected"),
        F.concat_ws(", ", F.collect_list(F.when(F.col("change_type") == "Added", F.col("column_name")))).alias("Added_columns"),
        F.concat_ws(", ", F.collect_list(F.when(F.col("change_type") == "Removed", F.col("column_name")))).alias("Removed_columns"),
        F.concat_ws(", ", F.collect_list(F.when(F.col("change_type") == "Datatype_changed", F.col("column_name")))).alias("Datatype_affected_columns"),
    )
    .orderBy("layer")
)

# Current-schema rows to persist, stamped with one uuid4 run id for this
# execution and Spark's current_timestamp().
snapshots = schema_current.select(
    "layer",
    "column_name",
    "data_type",
    F.lit(str(uuid.uuid4())).alias("metric_run_id"),
    F.current_timestamp().alias("run_ts"),
)

# Persist snapshots into default.schema_snapshots (serverless-safe create+append).
# NOTE(review): every write call below is commented out. The first try body
# therefore cannot raise, and persist_status always reads "Appended to ..."
# even though nothing is written.
target_table = "default.schema_snapshots"
# Reset on every run. This cell runs at notebook top level, where an
# admin_sql left over from an earlier run would otherwise be printed again.
admin_sql = None
try:
    # snapshots.write.format("delta").mode("append").saveAsTable(target_table)
    persist_status = f"Appended to {target_table}"
except Exception:
    try:
        # snapshots.limit(0).write.format("delta").mode("overwrite").saveAsTable(target_table)
        # snapshots.write.format("delta").mode("append").saveAsTable(target_table)
        persist_status = f"Created and appended to {target_table}"
    except Exception:
        # Fall back to a per-user path and emit SQL an admin can run to register it.
        try:
            user = spark.sql("SELECT current_user() as u").collect()[0]["u"]
        except Exception:
            user = os.environ.get("USER") or os.environ.get("USERNAME") or "unknown_user"
        safe_user = user.replace("@", "_at_").replace(" ", "_")
        path = f"/Users/{safe_user}/do_tool/schema_snapshots_{uuid.uuid4().hex}"
        try:
            # snapshots.write.format("delta").mode("overwrite").save(path)
            persist_status = f"Saved snapshots to user path: {path}"
            admin_sql = f"CREATE TABLE {target_table} USING DELTA LOCATION '{path}';"
        except Exception as e_fallback:
            persist_status = f"Failed to persist snapshots: {e_fallback}"
            admin_sql = None

# Show results
print("=== Detailed column-level changes (with hardcoded GOLD overrides) ===")
display(detailed_changes_df)
print("=== Per-layer summary of schema drift ===")
display(summary_df)
print(f"Snapshot persistence status: {persist_status}")
if admin_sql:
    print("\nAdmin SQL to register the user-saved snapshots as table:\n")
    print(admin_sql)


In [0]:
# Write detailed_changes_df and summary_df to Delta tables (serverless-safe)
import uuid, os, traceback
from datetime import datetime
from pyspark.sql import functions as F

def _first_line(exc):
    """Return the first line of ``str(exc)``, or ``""`` when the message is empty.

    ``str(exc).splitlines()[0]`` raises IndexError for an exception whose
    message is empty, and that error would escape from inside an ``except``
    handler.
    """
    lines = str(exc).splitlines()
    return lines[0] if lines else ""


def write_table_safe(df, catalog="workspace", schema="default", table="tmp_table"):
    """Write ``df`` to ``catalog.schema.table`` as Delta, with fallbacks.

    Strategies are tried in order:
      1. append to the table;
      2. create it empty (same schema) and then append;
      3. write to a per-user path and print the ``CREATE TABLE ... LOCATION``
         SQL an admin can run to register it.

    NOTE(review): every write call below is commented out. As written, step 1
    cannot raise, so this always prints success and returns ``status="ok"``.

    Args:
        df: DataFrame to write; must not be None.
        catalog: Catalog part of the target name.
        schema: Schema part of the target name.
        table: Table part of the target name (also used in the fallback path).

    Returns:
        A dict whose ``status`` is ``"ok"``, ``"created_and_appended"``,
        ``"fallback_saved"`` (also carries ``admin_sql``) or ``"error"``
        (carries ``error`` instead of ``target``).

    Raises:
        RuntimeError: if ``df`` is None.
    """
    full = f"{catalog}.{schema}.{table}"
    if df is None:
        raise RuntimeError(f"DataFrame for {full} not found.")

    # 1) Append to the existing table.
    try:
        # df.write.format("delta").mode("append").saveAsTable(full)
        print(f"✅ Appended to {full}")
        return {"status": "ok", "target": full}
    except Exception as e:
        print(f"⚠️ Append to {full} failed: {_first_line(e)}")

    # 2) Create an empty table with the same schema, then append.
    try:
        # df.limit(0).write.format("delta").mode("overwrite").saveAsTable(full)
        # df.write.format("delta").mode("append").saveAsTable(full)
        print(f"✅ Created and appended to {full}")
        return {"status": "created_and_appended", "target": full}
    except Exception as e2:
        print(f"⚠️ Create+append also failed: {_first_line(e2)}")

    # 3) Fall back to a per-user path.
    # Local import: the file imports only the datetime class, not timezone.
    from datetime import timezone

    try:
        user = spark.sql("SELECT current_user() as u").collect()[0]["u"]
    except Exception:
        user = os.environ.get("USER") or os.environ.get("USERNAME") or "unknown_user"
    safe_user = user.replace("@", "_at_").replace(" ", "_")
    # Timezone-aware UTC; datetime.utcnow() is deprecated since Python 3.12.
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
    path = f"/Users/{safe_user}/do_tool/{table}_{stamp}_{uuid.uuid4().hex}"
    try:
        # df.write.format("delta").mode("overwrite").save(path)
        admin_sql = f"CREATE TABLE {full} USING DELTA LOCATION '{path}';"
        print(f"✅ Saved to user path: {path}")
        print("Ask admin to register it with:")
        print(admin_sql)
        return {"status": "fallback_saved", "target": path, "admin_sql": admin_sql}
    except Exception as e3:
        tb = traceback.format_exc()
        print("❌ Fallback failed. Traceback:")
        print(tb)
        return {"status": "error", "error": str(e3)}

# Fail fast with a clear message if the schema-diff cell (which defines both
# DataFrames) has not been run in this session.
for _required_name in ("detailed_changes_df", "summary_df"):
    if _required_name not in globals():
        raise RuntimeError(f"{_required_name} not found. Run the schema diff step first.")

# Write both tables.
# The write calls are currently disabled. Record that explicitly so the report
# below does not fail with NameError on res1/res2. Uncommenting the calls
# overwrites these placeholders.
res1 = {"status": "skipped", "target": "workspace.default.detailedchanges_schema"}
res2 = {"status": "skipped", "target": "workspace.default.summary_schema"}
# res1 = write_table_safe(detailed_changes_df, catalog="workspace", schema="default", table="detailedchanges_schema")
# res2 = write_table_safe(summary_df, catalog="workspace", schema="default", table="summary_schema")

print("\nResults:")
print("detailedchanges_schema ->", res1)
print("summary_schema         ->", res2)
