In [None]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import input_file_name

from pyspark.sql.functions import isnull
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [None]:
# Create the Spark session used by every cell below (getOrCreate reuses an
# already-active session instead of starting a second one).
spark = (
    SparkSession.builder
    .appName("pyspark-notebook")
    .getOrCreate()
)
        

In [None]:
# Prepare data: read the Parquet dataset stored under this HDFS directory.
# NOTE(review): the path has "//" right after the port — presumably a single "/"
# was intended; confirm before changing it.
logs_path = "hdfs://namenode:9000//user/data/spark_ml_101/ec_web_logs_analysis/data/"
logs = spark.read.parquet(logs_path)

In [None]:
# Preprocessing and feature engineering:
# keep the three candidate feature columns plus the `age_group` label, and
# discard rows with a null label (they cannot be used to train or evaluate).
has_age_group = ~isnull("age_group")
feature_prep = (
    logs
    .select("product_category_id", "device_type", "connect_type", "age_group")
    .where(has_age_group)
)

In [None]:
# Preview the first 20 rows of the selected columns (long values truncated).
feature_prep.show(n=20, truncate=True)

In [None]:
# Encode the string label `age_group` as the numeric `age_group_indexed` column
# used as the classifier's label.
# Drop any existing `age_group_indexed` first: StringIndexer refuses to write an
# output column that already exists, so without this re-running the cell fails.
# DataFrame.drop is a no-op when the column is absent, so the first run is unaffected.
feature_prep = feature_prep.drop("age_group_indexed")
# Keep the fitted StringIndexerModel: it holds the index -> age_group mapping
# needed to turn predicted indices back into readable labels.
age_group_indexer = StringIndexer(inputCol="age_group", outputCol="age_group_indexed") \
                    .fit(feature_prep)
feature_prep = age_group_indexer.transform(feature_prep)

In [None]:
# Assemble the feature columns into the single vector column `features`
# that the Spark ML classifier reads.
# NOTE(review): VectorAssembler only accepts numeric, boolean or vector inputs;
# if device_type / connect_type are strings they would need indexing first —
# confirm against the schema. Only null `age_group` rows are filtered earlier;
# nulls in these feature columns would presumably raise under the default
# handleInvalid="error".
feature_cols = ["product_category_id", "device_type", "connect_type"]
assembler = VectorAssembler(outputCol="features", inputCols=feature_cols)
final_data = assembler.transform(feature_prep)

In [None]:
# Preview the first 20 rows including the assembled `features` vector.
final_data.show(n=20, truncate=True)

In [None]:
# Split data into train and test sets (weights 0.7 / 0.3, so roughly 70% / 30%).
# A fixed seed makes the split, and therefore the reported accuracy,
# reproducible across Restart & Run All; without it every run differs.
SPLIT_SEED = 42
train_data, test_data = final_data.randomSplit([0.7, 0.3], seed=SPLIT_SEED)

In [None]:
# Model training: fit a single decision tree that predicts the indexed age group
# from the assembled feature vector; maxDepth limits the tree to 10 levels.
classifier = DecisionTreeClassifier(
    labelCol="age_group_indexed",
    featuresCol="features",
    maxDepth=10,
)
model = classifier.fit(train_data)

In [None]:
# Save the model.
# `model.save(path)` raises when the path already exists, so re-running the
# notebook would fail here; `write().overwrite()` replaces the previous model.
model_path = "hdfs://namenode:9000//user/data/spark_ml_101/ec_web_logs_analysis/models/model_age_group_prediction/"
model.write().overwrite().save(model_path)

In [None]:
# Score the held-out rows with the trained tree; the returned DataFrame carries
# the model's output columns, including `prediction`.
predicted_test_data = model.transform(dataset=test_data)

In [None]:
# Evaluate the model performance: accuracy is the fraction of test rows whose
# `prediction` equals the true `age_group_indexed` label.
evaluator_accuracy = MulticlassClassificationEvaluator(labelCol='age_group_indexed',
                                                       predictionCol='prediction',
                                                       metricName='accuracy')
test_accuracy = evaluator_accuracy.evaluate(predicted_test_data)
# Format the value into the message; passing it as a second print() argument
# would print the literal "{}" followed by the number.
print("Accuracy: {}".format(test_accuracy))

In [None]:
# Stop the Spark session (and its underlying SparkContext); `spark` cannot be
# used after this cell, so it must stay last in the notebook.
spark.stop()