In [0]:
from pyspark.sql import SparkSession
from sklearn.model_selection import train_test_split
import pandas as pd

# ---------------------------------------------------------------------------
# Stratified train/test split of the telecom churn Delta table.
#
# Reads the source table, splits it with scikit-learn (stratified on the
# target column so both splits keep the same churn ratio), and writes the two
# splits back as Delta tables in the same schema. Table names and split
# settings are defined once here instead of being repeated in every statement.
# ---------------------------------------------------------------------------
SCHEMA = "kusha_solutions.telecom_churn_ml"
SOURCE_TABLE = f"{SCHEMA}.telco_customer_churn"
TRAIN_TABLE = f"{SCHEMA}.telecom_train"
TEST_TABLE = f"{SCHEMA}.telecom_test"
TARGET_COL = "Churn"
TEST_SIZE = 0.2      # fraction of rows held out for testing (80% / 20%)
RANDOM_STATE = 42    # fixed seed so the split is reproducible across runs

# 1️⃣ Initialize Spark session (getOrCreate reuses an active session if present)
spark = SparkSession.builder.appName("TelecomChurnSplit").getOrCreate()

# 2️⃣ Read data from the Delta table inside your schema
df_spark = spark.table(SOURCE_TABLE)

# Convert Spark DF → Pandas for sklearn split.
# toPandas() collects the entire table onto the driver, so this approach only
# suits tables that fit in driver memory.
df = df_spark.toPandas()

# 3️⃣ Separate features (X) and target (y)
X = df.drop(columns=TARGET_COL)
y = df[TARGET_COL]

# 4️⃣ Split into training and testing data, stratified on the target
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=TEST_SIZE, stratify=y, random_state=RANDOM_STATE
)

# 5️⃣ Re-attach the target as the last column of each split.
# X_* and y_* come from the same split and share the same row index, so the
# assignment aligns row by row.
train_df = X_train.copy()
train_df[TARGET_COL] = y_train

test_df = X_test.copy()
test_df[TARGET_COL] = y_test

print("✅ Split complete:")
print("Train size:", train_df.shape)
print("Test size:", test_df.shape)
print("\nTraining churn distribution:")
print(train_df[TARGET_COL].value_counts(normalize=True))
# Print the test distribution too, so the stratification can be checked
# side by side with the training distribution.
print("\nTest churn distribution:")
print(test_df[TARGET_COL].value_counts(normalize=True))

# 6️⃣ Convert back to Spark DataFrames
# NOTE(review): Spark infers column types from the pandas dtypes here, so they
# may not match the source table's schema exactly — check the saved schema if
# downstream code depends on specific types.
train_spark = spark.createDataFrame(train_df)
test_spark = spark.createDataFrame(test_df)

# 7️⃣ Save as Delta tables inside the same schema.
# These tables are fully regenerated on every run. overwriteSchema lets a rerun
# succeed even if the source table's columns changed since the last run;
# without it, Delta rejects an overwrite whose schema differs from the
# existing table's.
(
    train_spark.write.format("delta")
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable(TRAIN_TABLE)
)
(
    test_spark.write.format("delta")
    .mode("overwrite")
    .option("overwriteSchema", "true")
    .saveAsTable(TEST_TABLE)
)

print(f"✅ Train and Test tables saved successfully in schema `{SCHEMA}`")

# 8️⃣ Optional: Verify
spark.sql(f"SHOW TABLES IN {SCHEMA}").show()
spark.sql(f"SELECT COUNT(*) FROM {TRAIN_TABLE}").show()
spark.sql(f"SELECT COUNT(*) FROM {TEST_TABLE}").show()
