In [None]:
from pyspark.sql import SparkSession
from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier, RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Train three baseline classifiers on the precomputed CRM feature table and
# report held-out accuracy and weighted F1 for each.
spark = SparkSession.builder.appName("CRMClassifier-Models").getOrCreate()
# NOTE(review): path is relative to the notebook's working directory — confirm.
# Assumes the table has a vector column "features" and a numeric "label"
# column (the names the estimators below read) — TODO confirm schema.
df = spark.read.parquet("../data/crm_features.parquet")

# Single seed shared by the split and the tree learners so a rerun reproduces
# the same split and the same models.
SEED = 42

# 80/20 train/test split. Both halves are cached: `train` is scanned by every
# fit below and `test` by every transform/evaluate, and without caching each
# pass re-reads the parquet source and recomputes the split.
train, test = df.randomSplit([0.8, 0.2], seed=SEED)
train = train.cache()
test = test.cache()

# The tree learners get an explicit seed. PySpark's default seed is derived
# from hash() of the class name (pyspark.ml.param.shared.HasSeed), which varies
# between Python processes under string hash randomization, so results would
# otherwise not be reproducible run to run.
models = {
    "LogisticRegression": LogisticRegression(featuresCol="features", labelCol="label", maxIter=20),
    "DecisionTree": DecisionTreeClassifier(featuresCol="features", labelCol="label", seed=SEED),
    "RandomForest": RandomForestClassifier(featuresCol="features", labelCol="label", numTrees=10, seed=SEED),
}

evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")

# Metrics per model name, kept for inspection in later cells.
results = {}
for name, clf in models.items():
    model = clf.fit(train)
    predictions = model.transform(test)
    acc = evaluator.evaluate(predictions)
    # Accuracy alone can look good if the label is imbalanced. Weighted F1 is
    # computed with the same evaluator via a per-call param-map override, so
    # the evaluator's configured metric stays "accuracy".
    f1 = evaluator.evaluate(predictions, {evaluator.metricName: "f1"})
    results[name] = {"accuracy": acc, "f1": f1}
    print(f"{name} Accuracy: {acc:.4f}")
    print(f"{name} F1: {f1:.4f}")