In [0]:
# Get data: raw churn CSV from the ml_prod volume.
# First row is the header; Spark infers the column types from the data.
df = spark.read.csv(
    "/Volumes/ml_prod/default/churn_data/Churn Modeling.csv",
    header=True,
    inferSchema=True,
)

In [0]:
%sql
-- Target schema for the ml_test write in the Python cell below; no-op if it already exists.
CREATE SCHEMA IF NOT EXISTS `ml_test`.`churn`

In [0]:
%sql
-- Target schema for the ml_prod write in the Python cell below; no-op if it already exists.
CREATE SCHEMA IF NOT EXISTS `ml_prod`.`churn`

In [0]:
%sql
-- Target schema for the ml_dev write in the Python cell below; no-op if it already exists.
CREATE SCHEMA IF NOT EXISTS `ml_dev`.`churn`

In [0]:
import pyspark.sql.functions as F

# Schema and table name used for the output table in every target catalog.
schema, table = "churn", "churn_dataset"


def transform(df_in, sample_pct=None, salt=None):
    """Pseudonymise and optionally down-sample the churn dataset.

    Args:
        df_in: Spark DataFrame containing a ``customerId`` column.
        sample_pct: If not None, keep only rows whose ``customerId`` hashes into
            one of the first ``sample_pct`` of 100 buckets. This gives a
            deterministic sample of roughly ``sample_pct`` percent of customers,
            the same customers on every run. Must be within [0, 100].
        salt: Optional string prepended to ``customerId`` before hashing.
            When None, the raw id is hashed without a salt.

    Returns:
        The DataFrame with ``customerId`` dropped and replaced by
        ``customerId_hashed``, the SHA-256 hex digest of the (optionally
        salted) id.

    Raises:
        ValueError: If ``sample_pct`` is outside [0, 100].
    """
    df_t = df_in

    if sample_pct is not None:
        if not 0 <= sample_pct <= 100:
            raise ValueError(f"sample_pct must be between 0 and 100, got {sample_pct!r}")
        # pmod keeps the bucket in [0, 99] even when xxhash64 is negative.
        # Keep the buckets *below* sample_pct so that sample_pct is the share kept
        # (keeping buckets above it would retain the complement instead).
        bucket = F.pmod(F.xxhash64(F.col("customerId")), F.lit(100))
        df_t = df_t.where(bucket < F.lit(sample_pct))

    customer_id = F.col("customerId").cast("string")
    if salt is None:
        to_hash = customer_id
    else:
        # Only prepend the salt when one is given: concat() yields NULL if any
        # input is NULL, so concatenating a NULL salt would null every hash.
        to_hash = F.concat(F.lit(salt), customer_id)

    df_t = df_t.withColumn("customerId_hashed", F.sha2(to_hash, 256))
    return df_t.drop("customerId")

# Optional salt prepended to customerId before hashing (None = unsalted).
# NOTE(review): an unsalted SHA-256 of a customer id can be reversed by hashing every
# candidate id, so it offers little protection if these tables are shared widely.
# TODO: consider loading a salt from a secret store instead of hard-coding None.
salt = None

# (catalog, sample_pct) per environment. sample_pct is passed to transform();
# None means no sampling.
targets = [
    ("ml_dev", None),
    ("ml_test", 30),
    ("ml_prod", None),
]

for catalog, pct in targets:
    target_table = f"{catalog}.{schema}.{table}"

    # Idempotent. Keeps the target schema in sync with the `schema` variable instead
    # of relying on separately run, hard-coded CREATE SCHEMA cells.
    spark.sql(f"CREATE SCHEMA IF NOT EXISTS {catalog}.{schema}")

    df_out = transform(df, sample_pct=pct, salt=salt)

    (df_out.write
     .mode("overwrite")
     # Delta keeps the existing table schema on a plain overwrite and rejects a
     # mismatching one; allow the schema to change (e.g. renamed columns or
     # different types picked by inferSchema).
     .option("overwriteSchema", "true")
     .format("delta")
     .saveAsTable(target_table))

    print(f"Wrote {target_table}")