In [1]:
from pyspark.sql import SparkSession
from IPython.display import display, HTML
from pyspark.sql.functions import col, lpad

import os
os.environ["PYSPARK_SUBMIT_ARGS"] = "--driver-memory 8g pyspark-shell"

spark = SparkSession.builder \
    .appName("VSCodeTest") \
    .master("local[*]") \
    .getOrCreate()

print("✅ Spark session running in VS Code!")
print("Spark version:", spark.version)

✅ Spark session running in VS Code!
Spark version: 4.0.0


In [2]:
from pyspark.sql.functions import col, lpad

# Load the datasets
food_facts_raw = (spark.read \
    .option("header", "true") \
    .parquet("C:/Users/jverc/Downloads/food_fact_dataset.parquet")
)


In [3]:
food_facts_raw.createOrReplaceTempView("food_facts_raw")

In [4]:
spark.table("food_facts_raw").describe().show()

+-------+--------------------+-----------------+-------+--------------------+--------------------+---------+---------+--------------------+--------------------+------------------+------------------+------------------+--------------------+--------------------+-------------------+------------------+---------------+------------------+--------------+--------------------+--------------------+--------------------+-------------------+--------------------+
|summary|        product_name|product_name_lang| brands|             barcode|          categories|countries|image_url|    ingredients_text|    energy_kcal_100g|          fat_100g|saturated_fat_100g|carbohydrates_100g|         sugars_100g|          fiber_100g|      proteins_100g|         salt_100g|nutrition_grade|        nova_group|ecoscore_grade|              labels|           allergens|           packaging|          created_t|     last_modified_t|
+-------+--------------------+-----------------+-------+--------------------+-----------------

In [5]:
# List of nutrition columns to check
cols_to_check = [
    "energy_kcal_100g", "fat_100g", "saturated_fat_100g",
    "carbohydrates_100g", "sugars_100g", "fiber_100g",
    "proteins_100g", "salt_100g"
]

# Build the condition: non-null and non-zero for all columns
non_zero_condition = " AND ".join(
    [f"({col} IS NOT NULL AND {col} != 0)" for col in cols_to_check]
)

# Apply the filter
food_facts = food_facts_raw.filter(non_zero_condition)

# Preview the result
food_facts.select(cols_to_check).show()

+----------------+------------------+------------------+------------------+------------------+------------------+------------------+--------------------+
|energy_kcal_100g|          fat_100g|saturated_fat_100g|carbohydrates_100g|       sugars_100g|        fiber_100g|     proteins_100g|           salt_100g|
+----------------+------------------+------------------+------------------+------------------+------------------+------------------+--------------------+
|           485.0|             26.25|              9.75|             58.75|               6.5| 4.199999809265137|              8.75|0.029999999329447746|
|           960.0|             112.0|  73.5999984741211|             240.0|             198.0|               8.0|              32.0|   1.600000023841858|
|           915.0|             125.0| 67.19999694824219|             229.0|             224.0| 9.600000381469727|28.799999237060547|  0.5690000057220459|
|           461.0|21.399999618530273|1.7999999523162842| 55.20000076293945|5

In [6]:
print("Rows before filtering:", food_facts_raw.count())
print("Rows after filtering:", food_facts.count())

Rows before filtering: 1262295
Rows after filtering: 342911


In [7]:
food_facts.createOrReplaceTempView("food_facts")

In [8]:
spark.table("food_facts").describe().show()

+-------+--------------------+-----------------+-----------------+--------------------+--------------------+---------+---------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+--------------------+-------------------+--------------------+---------------+------------------+--------------+--------------------+----------------+--------------------+-------------------+--------------------+
|summary|        product_name|product_name_lang|           brands|             barcode|          categories|countries|image_url|    ingredients_text|    energy_kcal_100g|            fat_100g|  saturated_fat_100g|  carbohydrates_100g|         sugars_100g|          fiber_100g|      proteins_100g|           salt_100g|nutrition_grade|        nova_group|ecoscore_grade|              labels|       allergens|           packaging|          created_t|     last_modified_t|
+-------+--------------------+-----------------+-----------------+

In [9]:
spark.table("food_facts").show(5)

+--------------------+-----------------+--------------------+-------------+--------------------+----------------+---------+--------------------+----------------+------------------+------------------+------------------+------------------+------------------+------------------+--------------------+---------------+----------+--------------+------+--------------------+---------+-------------+---------------+
|        product_name|product_name_lang|              brands|      barcode|          categories|       countries|image_url|    ingredients_text|energy_kcal_100g|          fat_100g|saturated_fat_100g|carbohydrates_100g|       sugars_100g|        fiber_100g|     proteins_100g|           salt_100g|nutrition_grade|nova_group|ecoscore_grade|labels|           allergens|packaging|    created_t|last_modified_t|
+--------------------+-----------------+--------------------+-------------+--------------------+----------------+---------+--------------------+----------------+------------------+------

In [10]:
df_display = spark.sql("""
                       
WITH food_facts AS (
SELECT 
  CAST(product_name AS STRING) AS product_name,
  CAST(product_name_lang AS STRING) AS language,
  CAST(brands AS STRING) AS brand,
  CAST(countries AS STRING) AS country,
  ROUND(CAST(energy_kcal_100g AS DOUBLE), 4) AS energy_kcal_100g,
  ROUND(CAST(fat_100g AS DOUBLE), 4) AS fat_100g,
  ROUND(CAST(saturated_fat_100g AS DOUBLE), 4) AS saturated_fat_100g,
  ROUND(CAST(carbohydrates_100g AS DOUBLE), 4) AS carbohydrates_100g,
  ROUND(CAST(sugars_100g AS DOUBLE), 4) AS sugars_100g,
  ROUND(CAST(fiber_100g AS DOUBLE), 4) AS fiber_100g,
  ROUND(CAST(proteins_100g AS DOUBLE), 4) AS proteins_100g,
  ROUND(CAST(salt_100g AS DOUBLE), 4) AS salt_100g,
  CAST(nutrition_grade AS STRING) AS nutrition_grade
FROM food_facts
)
SELECT 
  *
FROM food_facts
                       
WHERE
  energy_kcal_100g > 0 AND
  fat_100g > 0 AND
  saturated_fat_100g > 0 AND
  carbohydrates_100g > 0 AND
  sugars_100g > 0 AND
  fiber_100g > 0 AND
  proteins_100g > 0 AND
  salt_100g > 0 AND
  product_name IS NOT NULL
                      
""")

df_display.show()

df_display.createOrReplaceTempView("cleaned_food_facts")
spark.catalog.cacheTable("cleaned_food_facts")

+--------------------+--------+--------------------+----------------+----------------+--------+------------------+------------------+-----------+----------+-------------+---------+---------------+
|        product_name|language|               brand|         country|energy_kcal_100g|fat_100g|saturated_fat_100g|carbohydrates_100g|sugars_100g|fiber_100g|proteins_100g|salt_100g|nutrition_grade|
+--------------------+--------+--------------------+----------------+----------------+--------+------------------+------------------+-----------+----------+-------------+---------+---------------+
|Piasten, Chocolat...|      en|Goode's Bakery & ...|en:united-states|           485.0|   26.25|              9.75|             58.75|        6.5|       4.2|         8.75|     0.03|              d|
|Knusperflakes Mit...|      en|Ritter Sport,  Al...|en:united-states|           960.0|   112.0|              73.6|             240.0|      198.0|       8.0|         32.0|      1.6|              e|
|Milk Chocolate

In [11]:
spark.table("cleaned_food_facts").describe().show()

+-------+--------------------+--------+------------------+---------+------------------+------------------+------------------+------------------+--------------------+------------------+------------------+------------------+---------------+
|summary|        product_name|language|             brand|  country|  energy_kcal_100g|          fat_100g|saturated_fat_100g|carbohydrates_100g|         sugars_100g|        fiber_100g|     proteins_100g|         salt_100g|nutrition_grade|
+-------+--------------------+--------+------------------+---------+------------------+------------------+------------------+------------------+--------------------+------------------+------------------+------------------+---------------+
|  count|              336696|  336692|            289386|   336539|            336696|            336696|            336696|            336696|              336696|            336696|            336696|            336696|         336696|
|   mean| 3.26146399532588E12|    NULL|2238.

In [12]:
cleaned_food_facts = spark.sql("SELECT * FROM cleaned_food_facts")

In [13]:
from pyspark.sql.functions import min, max

cleaned_food_facts.groupBy("language").agg(
    min("energy_kcal_100g").alias("min_energy"),
    max("energy_kcal_100g").alias("max_energy"),
    min("fat_100g").alias("min_fat"),
    max("fat_100g").alias("max_fat"),
    min("saturated_fat_100g").alias("min_sat_fat"),
    max("saturated_fat_100g").alias("max_sat_fat"),
    min("carbohydrates_100g").alias("min_carbs"),
    max("carbohydrates_100g").alias("max_carbs"),
    min("sugars_100g").alias("min_sugars"),
    max("sugars_100g").alias("max_sugars"),
    min("fiber_100g").alias("min_fiber"),
    max("fiber_100g").alias("max_fiber"),
    min("proteins_100g").alias("min_protein"),
    max("proteins_100g").alias("max_protein"),
    min("salt_100g").alias("min_salt"),
    max("salt_100g").alias("max_salt")
).orderBy("language").show(100, truncate=False)


+--------+----------+----------+-------+-------+-----------+-----------+---------+---------+----------+--------------------+---------+---------+-----------+-----------+--------+--------+
|language|min_energy|max_energy|min_fat|max_fat|min_sat_fat|max_sat_fat|min_carbs|max_carbs|min_sugars|max_sugars          |min_fiber|max_fiber|min_protein|max_protein|min_salt|max_salt|
+--------+----------+----------+-------+-------+-----------+-----------+---------+---------+----------+--------------------+---------+---------+-----------+-----------+--------+--------+
|NULL    |27.0      |554.0     |1.0    |34.0   |0.2        |23.7       |3.5      |60.9     |0.5       |60.2                |1.0      |17.0     |0.6        |16.0       |0.2     |0.9     |
|ES      |278.0     |278.0     |5.9    |5.9    |0.8        |0.8        |42.0     |42.0     |3.2       |3.2                 |6.0      |6.0      |11.0       |11.0       |0.98    |0.98    |
|aa      |61.0      |503.0     |1.5    |34.0   |0.2        |8.5  

## Removing Outliers

In [14]:
cleaned_food_facts = cleaned_food_facts.filter(
    (col("energy_kcal_100g") > 20) &             # Remove unrealistically low energy values.
    (col("energy_kcal_100g") <= 900) &           # Realistic upper bound for energy (pure fat = 900 kcal/100g).
    
    (col("fat_100g") <= 60) &                    # Cap fat to exclude extremely dense products.
    (col("saturated_fat_100g") <= 30) &          # Saturated fat rarely exceeds this even in cheese/chocolate.
    
    (col("carbohydrates_100g") >= 5) &           # Avoid division-by-near-zero in sugar_ratio.
    (col("carbohydrates_100g") <= 100) &         # Cap based on observed high-carb, high-sugar entries.
    
    (col("sugars_100g") <= 70) &                 # Trim a few extremely sugary items above 80g.
    (col("fiber_100g") <= 30) &                  # Fiber rarely exceeds this even in legumes/fiber bars.
    
    (col("proteins_100g") <= 70) &               # Cap protein; 70g covers supplements and high-protein foods.
    (col("salt_100g") <= 4.5)  &                 # Slightly tighter than 5g to trim peak processed items.

    (col("brand").rlike("^[\\x00-\\x7F]+$")) &   # Ensure brand names are ASCII only.
    (col("product_name").rlike("^[\\x00-\\x7F]+$")) # Ensure product names are ASCII only.
)

print("Row count after filtering:", cleaned_food_facts.count())

Row count after filtering: 176690


In [15]:
from pyspark.sql.functions import min, max

cleaned_food_facts.groupBy("language").agg(
    min("energy_kcal_100g").alias("min_energy"),
    max("energy_kcal_100g").alias("max_energy"),
    min("fat_100g").alias("min_fat"),
    max("fat_100g").alias("max_fat"),
    min("saturated_fat_100g").alias("min_sat_fat"),
    max("saturated_fat_100g").alias("max_sat_fat"),
    min("carbohydrates_100g").alias("min_carbs"),
    max("carbohydrates_100g").alias("max_carbs"),
    min("sugars_100g").alias("min_sugars"),
    max("sugars_100g").alias("max_sugars"),
    min("fiber_100g").alias("min_fiber"),
    max("fiber_100g").alias("max_fiber"),
    min("proteins_100g").alias("min_protein"),
    max("proteins_100g").alias("max_protein"),
    min("salt_100g").alias("min_salt"),
    max("salt_100g").alias("max_salt")
).orderBy("language").show(100, truncate=False)


+--------+----------+----------+-------+-------+-----------+-----------+---------+---------+----------+----------+---------+---------+-----------+-----------+--------+--------+
|language|min_energy|max_energy|min_fat|max_fat|min_sat_fat|max_sat_fat|min_carbs|max_carbs|min_sugars|max_sugars|min_fiber|max_fiber|min_protein|max_protein|min_salt|max_salt|
+--------+----------+----------+-------+-------+-----------+-----------+---------+---------+----------+----------+---------+---------+-----------+-----------+--------+--------+
|NULL    |233.0     |554.0     |5.2    |34.0   |0.3        |23.7       |22.0     |60.9     |0.5       |60.2      |1.6      |17.0     |3.2        |16.0       |0.2     |0.9     |
|aa      |61.0      |503.0     |1.7    |26.0   |0.2        |6.5        |11.0     |69.0     |5.0       |35.0      |0.5      |11.35    |0.5        |9.61       |0.1     |1.2     |
|af      |59.0      |516.0     |1.6    |29.0   |0.3        |16.1       |5.6      |71.0     |0.49      |51.6      |0

## Feature Engineering

In [16]:
from pyspark.sql.functions import col, when

engineered_food_facts = (
    cleaned_food_facts

    # Binary label: Healthy = A or B nutrition grade
    .withColumn("is_healthy", when(col("nutrition_grade").isin("a", "b"), 1).otherwise(0))

    # Sugar ratio (sugars / carbs)
    .withColumn("sugar_ratio", col("sugars_100g") / col("carbohydrates_100g"))

    # Fat to energy ratio (fat * 9 kcal/g divided by energy)
    .withColumn("fat_to_energy_ratio", (col("fat_100g") * 9) / col("energy_kcal_100g"))

    # Sugar to fiber ratio (with smoothing)
    .withColumn("sugar_to_fiber_ratio", col("sugars_100g") / (col("fiber_100g") + 1))

    # High sodium flag: salt > 1.5g/100g
    .withColumn("high_sodium_flag", when(col("salt_100g") > 1.5, 1).otherwise(0))

    # Protein density (protein / energy)
    .withColumn("protein_density", col("proteins_100g") / col("energy_kcal_100g"))
)

# Register as a temp table for SQL
engineered_food_facts.createOrReplaceTempView("engineered_food_facts")

In [17]:
from pyspark.sql.functions import min, max, sum
engineered_food_facts.select(

    min("is_healthy").alias("min_is_healthy"),
    max("is_healthy").alias("max_is_healthy"),
    sum("is_healthy").alias("count_healthy"),

    min("high_sodium_flag").alias("min_high_sodium_flag"),
    max("high_sodium_flag").alias("max_high_sodium_flag"),
    sum("high_sodium_flag").alias("count_high_sodium"),

    min("sugar_ratio").alias("min_sugar_ratio"),
    max("sugar_ratio").alias("max_sugar_ratio"),

    min("fat_to_energy_ratio").alias("min_fat_to_energy_ratio"),
    max("fat_to_energy_ratio").alias("max_fat_to_energy_ratio"),

    min("sugar_to_fiber_ratio").alias("min_sugar_to_fiber_ratio"),
    max("sugar_to_fiber_ratio").alias("max_sugar_to_fiber_ratio"),

    min("protein_density").alias("min_protein_density"),
    max("protein_density").alias("max_protein_density")
).show(truncate=False)

+--------------+--------------+-------------+--------------------+--------------------+-----------------+--------------------+-----------------+-----------------------+-----------------------+------------------------+------------------------+--------------------+-------------------+
|min_is_healthy|max_is_healthy|count_healthy|min_high_sodium_flag|max_high_sodium_flag|count_high_sodium|min_sugar_ratio     |max_sugar_ratio  |min_fat_to_energy_ratio|max_fat_to_energy_ratio|min_sugar_to_fiber_ratio|max_sugar_to_fiber_ratio|min_protein_density |max_protein_density|
+--------------+--------------+-------------+--------------------+--------------------+-----------------+--------------------+-----------------+-----------------------+-----------------------+------------------------+------------------------+--------------------+-------------------+
|0             |1             |36556        |0                   |1                   |24598            |2.079002079002079E-6|4.836065573770492|3.33

In [18]:
healthy = engineered_food_facts.filter("is_healthy = 1").count()
unhealthy = engineered_food_facts.filter("is_healthy = 0").count()

print(f"Healthy: {healthy}")
print(f"Unhealthy: {unhealthy}")

Healthy: 36556
Unhealthy: 140134


In [19]:
engineered_food_facts.filter("is_healthy = 1").show()

+--------------------+--------+--------------------+--------------------+----------------+--------+------------------+------------------+-----------+----------+-------------+---------+---------------+----------+--------------------+--------------------+--------------------+----------------+--------------------+
|        product_name|language|               brand|             country|energy_kcal_100g|fat_100g|saturated_fat_100g|carbohydrates_100g|sugars_100g|fiber_100g|proteins_100g|salt_100g|nutrition_grade|is_healthy|         sugar_ratio| fat_to_energy_ratio|sugar_to_fiber_ratio|high_sodium_flag|     protein_density|
+--------------------+--------+--------------------+--------------------+----------------+--------+------------------+------------------+-----------+----------+-------------+---------+---------------+----------+--------------------+--------------------+--------------------+----------------+--------------------+
|Organic Tortellin...|      en|         Nuovo Pasta|    en:un

In [20]:
# engineered_food_facts = engineered_food_facts.drop("x", "y", "z")

In [21]:
engineered_food_facts.describe().show()

+-------+--------------------+--------+--------------------+---------+------------------+------------------+------------------+------------------+------------------+------------------+-----------------+------------------+---------------+-------------------+--------------------+--------------------+--------------------+------------------+--------------------+
|summary|        product_name|language|               brand|  country|  energy_kcal_100g|          fat_100g|saturated_fat_100g|carbohydrates_100g|       sugars_100g|        fiber_100g|    proteins_100g|         salt_100g|nutrition_grade|         is_healthy|         sugar_ratio| fat_to_energy_ratio|sugar_to_fiber_ratio|  high_sodium_flag|     protein_density|
+-------+--------------------+--------+--------------------+---------+------------------+------------------+------------------+------------------+------------------+------------------+-----------------+------------------+---------------+-------------------+--------------------+

In [24]:
import os

# Convert Spark DataFrame to Pandas
pandas_df = engineered_food_facts.toPandas()

# Write to CSV on your Desktop
username = os.getlogin()
output_path = f"C:/Users/{username}/Desktop/engineered_food_facts.csv"
pandas_df.to_csv(output_path, index=False)

## OLD MARKDOWN CODE FROM LAST ASSIGNMENT BELOW

In [None]:
# VectorAssembler Assembles all of these columns into one single vector. To do this, set the input columns and output column. Then that assembler will be used to transform the prepped data to the final dataset.
from pyspark.ml.feature import StringIndexer

# Index categorical features
cut_indexer = StringIndexer(inputCol="cut", outputCol="cut_index")
color_indexer = StringIndexer(inputCol="color", outputCol="color_index")
clarity_indexer = StringIndexer(inputCol="clarity", outputCol="clarity_index")

# Transform the DataFrame with the indexers
indexed = cut_indexer.fit(engineered_diamonds).transform(engineered_diamonds)
indexed = color_indexer.fit(indexed).transform(indexed)
indexed = clarity_indexer.fit(indexed).transform(indexed)

# Save final DataFrame with indexed columns
mapping = indexed

indexed = indexed.drop("cut", "color", "clarity")
# indexed = indexed.withColumnRenamed("log_price", "price")
indexed.show()

NameError: name 'engineered_diamonds' is not defined

In [None]:
from pyspark.ml.feature import VectorAssembler

nonFeatureCols = [
    "price"
]
featureCols = [col for col in indexed.columns if col not in nonFeatureCols]

assembler = (VectorAssembler()
  .setInputCols(featureCols)
  .setOutputCol("features"))

finalPrep = assembler.transform(indexed)

In [None]:
training, test = finalPrep.randomSplit([0.7, 0.3])

#  Going to cache the data to make sure things stay snappy!
training.cache()
test.cache()

print(training.count()) # Why execute count here??
print(test.count())

37803
16117


In [None]:
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import RegressionEvaluator

from pyspark.ml import Pipeline

rfModel = (RandomForestRegressor()
  .setLabelCol("price")
  .setFeaturesCol("features"))

paramGrid = (ParamGridBuilder()
  .addGrid(rfModel.maxDepth, [5, 10])
  .addGrid(rfModel.numTrees, [20, 60])
  .build())
# Note, that this parameter grid will take a long time
# to run in the community edition due to limited number
# of workers available! Be patient for it to run!
# If you want it to run faster, remove some of
# the above parameters and it'll speed right up!

stages = [rfModel]

pipeline = Pipeline().setStages(stages)

cv = (CrossValidator() # you can feel free to change the number of folds used in cross validation as well
  .setEstimator(pipeline) # the estimator can also just be an individual model rather than a pipeline
  .setEstimatorParamMaps(paramGrid)
  .setEvaluator(RegressionEvaluator().setLabelCol("price")))

pipelineFitted = cv.fit(training)

In [None]:
print("The Best Parameters:\n--------------------")
print(pipelineFitted.bestModel.stages[0])
pipelineFitted.bestModel.stages[0].extractParamMap()

The Best Parameters:
--------------------
RandomForestRegressionModel: uid=RandomForestRegressor_c262131eff68, numTrees=20, numFeatures=8


{Param(parent='RandomForestRegressor_c262131eff68', name='bootstrap', doc='Whether bootstrap samples are used when building trees.'): True,
 Param(parent='RandomForestRegressor_c262131eff68', name='cacheNodeIds', doc='If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees. Users can set how often should the cache be checkpointed or disable it by setting checkpointInterval.'): False,
 Param(parent='RandomForestRegressor_c262131eff68', name='checkpointInterval', doc='set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. Note: this setting will be ignored if the checkpoint directory is not set in the SparkContext.'): 10,
 Param(parent='RandomForestRegressor_c262131eff68', name='featureSubsetStrategy', doc="The number of features to consider for splits at each tree node. Supporte

In [None]:
pipelineFitted.bestModel

PipelineModel_f7feee4cc687

In [None]:
holdout2 = (pipelineFitted.bestModel
  .transform(test)
  .selectExpr("prediction as raw_prediction", 
              "double(round(prediction)) as prediction",
              "price ", 
              "cut_index",
              "round(abs((prediction - price) / price) * 100, 2) AS pct_off",
              "round(100 - (abs((prediction - price) / price) * 100), 2) AS pct_accuracy"
  ))
  
holdout2.show()

+------------------+----------+-----+---------+-------+------------+
|    raw_prediction|prediction|price|cut_index|pct_off|pct_accuracy|
+------------------+----------+-----+---------+-------+------------+
|509.89059955622287|     510.0|367.0|      0.0|  38.93|       61.07|
| 552.4290700127754|     552.0|367.0|      0.0|  50.53|       49.47|
| 542.0659199708288|     542.0|404.0|      1.0|  34.17|       65.83|
| 516.0763758267157|     516.0|470.0|      1.0|    9.8|        90.2|
|469.13064742024716|     469.0|327.0|      3.0|  43.47|       56.53|
| 563.3670676379622|     563.0|438.0|      2.0|  28.62|       71.38|
| 644.6659498107296|     645.0|550.0|      1.0|  17.21|       82.79|
| 549.3501843743887|     549.0|530.0|      2.0|   3.65|       96.35|
| 464.4631461342551|     464.0|373.0|      2.0|  24.52|       75.48|
| 483.4722606040126|     483.0|402.0|      2.0|  20.27|       79.73|
| 568.6253397117944|     569.0|530.0|      2.0|   7.29|       92.71|
| 570.3690150620312|     570.0|505

In [None]:
from pyspark.ml.evaluation import RegressionEvaluator

# Evaluate Mean Squared Error
evaluator = RegressionEvaluator(
    labelCol="price", 
    predictionCol="raw_prediction", 
    metricName="r2"
)
mse = evaluator.evaluate(holdout2)

mae = evaluator.setMetricName("mae").evaluate(holdout2)
rmse = evaluator.setMetricName("rmse").evaluate(holdout2)
r2 = evaluator.setMetricName("r2").evaluate(holdout2)

print(f"MSE: {mse}")
print(f"MAE: {mae}")
print(f"RMSE: {rmse}")
print(f"R^2: {r2}")


MSE: 0.9720007234928092
MAE: 352.8628793743668
RMSE: 665.9413324221837
R^2: 0.9720007234928092


In [None]:
from pyspark.sql.functions import col, abs

# Mean Absolute Percentage Error
mape_df = holdout2.withColumn(
    "pct_error", abs(col("raw_prediction") - col("price")) / col("price")
)

mape = mape_df.selectExpr("avg(pct_error)").first()[0]
print(f"MAPE: {mape * 100:.2f}%")

MAPE: 9.95%


In [None]:
holdout2.selectExpr("avg(pct_accuracy) as avg_prediction_accuracy").show()

+-----------------------+
|avg_prediction_accuracy|
+-----------------------+
|      90.04846125209383|
+-----------------------+



In [None]:
from pyspark.sql.functions import avg

cut_mapping = mapping.select("cut", "cut_index").distinct()
holdout2_with_cut = holdout2.join(cut_mapping, on="cut_index", how="left")

# Group by cut index and calculate average accuracy
holdout2_with_cut.groupBy("cut").agg(avg("pct_accuracy").alias("avg_accuracy")).show()

+---------+-----------------+
|      cut|     avg_accuracy|
+---------+-----------------+
|  Premium|90.00256478566219|
|    Ideal|89.89090993214067|
|     Good|89.83024561403501|
|     Fair|84.95818763326223|
|Very Good|91.13138781163435|
+---------+-----------------+



In [None]:
# spark.stop(