# SparkML with Dataproc Serverless

## Overview

This notebook tutorial demonstrates how to run Apache SparkML jobs on Dataproc Serverless. The example machine learning pipeline ingests the [NYC TLC (Taxi and Limousine Commission) Trips](https://console.cloud.google.com/marketplace/product/city-of-new-york/nyc-tlc-trips) dataset from your lakehouse, cleans the data, engineers features, and then trains and evaluates a model that predicts trip duration.

The tutorial uses the following Google Cloud products:
- `Dataproc`
- `BigQuery`
- `Vertex AI Training`
- `BigLake`

## Tutorial

### Set your project ID and location

In [None]:
# Retrieve the current active project and store it as a list of strings.
PROJECT_ID = !gcloud config get-value project

# Extract the project ID from the list.
PROJECT_ID = PROJECT_ID[0] if PROJECT_ID else None

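# Retrieve the zone of the Compute Engine instance that runs this notebook.
LOCATION = !gcloud compute instances list --project={PROJECT_ID} --format='get(ZONE)'

# Derive the region from the zone (for example, "us-central1-a" becomes "us-central1").
# split("/") also handles output that is a full zone URL.
LOCATION = LOCATION[0].split("/")[-1].rsplit("-", 1)[0] if LOCATION else None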

### Get a Cloud Storage bucket URI

In [None]:
# Define the prefix of the bucket created via Terraform.
BUCKET_PREFIX = "gcp-lakehouse-model"

# Retrieve the name of the Cloud Storage bucket that stores the machine learning model.
BUCKET_URI = !gcloud storage buckets list --format='value(name)' --filter='name:{BUCKET_PREFIX}*'

# Extract the first matching bucket name from the list.
BUCKET_URI = BUCKET_URI[0] if BUCKET_URI else None

### Import required libraries

In [None]:
import geopandas as gpd
import matplotlib.pyplot as plt
import seaborn as sns
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import GBTRegressor
# A SparkSession is the entry point for creating DataFrames and running Spark SQL.
from pyspark.sql import SparkSession
# PySpark functions
from pyspark.sql.functions import col, floor, unix_timestamp

### Initialize the SparkSession

Use the [spark-bigquery-connector](https://github.com/GoogleCloudDataproc/spark-bigquery-connector) to read and write data between Apache Spark and BigQuery.

In [None]:
VER = "0.34.0"
FILE_NAME = f"spark-bigquery-with-dependencies_2.12-{VER}.jar"
connector = f"gs://spark-lib/bigquery/{FILE_NAME}"

# Initialize the SparkSession.
spark = (
    SparkSession.builder.appName("spark-ml-taxi")
    .config("spark.jars", connector)
    .config("spark.logConf", "false")
    .getOrCreate()
)

### Fetch data

Load the table `gcp_primary_staging.new_york_taxi_trips_tlc_yellow_trips_2022`.

In [None]:
# Load the 2022 NYC Yellow Taxi trips table from the lakehouse staging dataset in BigQuery.
taxi_df = (
    spark.read.format("bigquery")
    .option(
        "table",
        f"{PROJECT_ID}.gcp_primary_staging.new_york_taxi_trips_tlc_yellow_trips_2022",
    )
    .load()
)

# Sample parameter. Increase or decrease to experiment with different data sizes.
FRACTION = 0.05

# Sample data to minimize the runtime.
taxi_df = taxi_df.sample(fraction=FRACTION, seed=42)
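`DataFrame.sample(fraction=...)` without replacement keeps each row independently with probability `fraction`, so the sample contains approximately, not exactly, 5% of the rows, and the `seed` makes the selection repeatable. The stdlib-only sketch below illustrates that Bernoulli sampling idea on synthetic row IDs; the `bernoulli_sample` helper is hypothetical and is not Spark's implementation.

```python
import random


def bernoulli_sample(rows, fraction, seed=42):
    """Hypothetical helper: keep each row independently with probability `fraction`."""
    rng = random.Random(seed)
    return [row for row in rows if rng.random() < fraction]


rows = list(range(100_000))  # Synthetic stand-in for table rows.
sample = bernoulli_sample(rows, fraction=0.05)

# The size is close to, but generally not exactly, 5% of the input.
print(len(sample), len(sample) / len(rows))
```

Rerunning with the same seed returns the same rows; changing `FRACTION` scales the expected sample size linearly.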

### Perform Exploratory Data Analysis (EDA)

Perform EDA to understand your data, starting with the table schema.

In [None]:
taxi_df.printSchema()

Derive the columns needed for modeling (zone IDs, Unix timestamps, and trip duration), then select them.

In [None]:
# Choose necessary columns.
COLUMNS_TO_SELECT = [
    "start_time",
    "end_time",
    "passenger_count",
    "trip_distance",
    "trip_duration",
    "fare_amount",
    "extra",
    "mta_tax",
    "tip_amount",
    "tolls_amount",
    "imp_surcharge",
    "airport_fee",
    "total_amount",
    "start_zone_id",
    "end_zone_id",
]

# Cast pickup_location_id and dropoff_location_id to integers for the zone join later on.
taxi_df = (
    taxi_df.withColumn("start_zone_id", col("pickup_location_id").cast("int"))
    .withColumn("end_zone_id", col("dropoff_location_id").cast("int"))
)

# Convert the pickup and dropoff datetimes to Unix timestamps (seconds since the epoch).
taxi_df = (
    taxi_df.withColumn("start_time", unix_timestamp(col("pickup_datetime")))
    .withColumn("end_time", unix_timestamp(col("dropoff_datetime")))
)

# Calculate trip_duration in seconds.
taxi_df = taxi_df.withColumn("trip_duration", col("end_time") - col("start_time"))

# Keep only the columns listed in COLUMNS_TO_SELECT.
taxi_df = taxi_df.select(*COLUMNS_TO_SELECT)

# Display summary statistics for the selected columns.
taxi_df.describe().show()
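Because `unix_timestamp` expresses each datetime as seconds since the Unix epoch, subtracting the pickup value from the dropoff value gives the trip duration in seconds. The stdlib sketch below reproduces that arithmetic for one hypothetical trip, assuming both times are in UTC; `to_unix_seconds` is an illustrative helper, not a Spark function.

```python
from datetime import datetime, timezone


def to_unix_seconds(ts: datetime) -> int:
    """Seconds since 1970-01-01T00:00:00Z, analogous to Spark's unix_timestamp."""
    return int(ts.timestamp())


# Hypothetical pickup and dropoff times for a single trip (UTC).
pickup = datetime(2022, 3, 1, 8, 15, 0, tzinfo=timezone.utc)
dropoff = datetime(2022, 3, 1, 8, 42, 30, tzinfo=timezone.utc)

start_time = to_unix_seconds(pickup)
end_time = to_unix_seconds(dropoff)
trip_duration = end_time - start_time
print(trip_duration)  # 1650 seconds (27.5 minutes)
```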

Build box plots and histograms to assess the distribution of each numeric column.

In [None]:
# Convert the Spark DataFrame into a Pandas DataFrame.
taxi_pd = taxi_df.toPandas()

# Define the columns to convert to a numeric type in Pandas and visualize.
PD_COLUMNS = [
    "trip_distance",
    "fare_amount",
    "extra",
    "mta_tax",
    "tip_amount",
    "tolls_amount",
    "imp_surcharge",
    "airport_fee",
    "total_amount",
]

# Convert columns of "object" type to float.
taxi_pd[PD_COLUMNS] = taxi_pd[PD_COLUMNS].astype(float)

# Draw a box plot and a histogram side by side for each column.
for column in PD_COLUMNS:
    fig, ax = plt.subplots(1, 2, figsize=(5, 2))
    taxi_pd[column].plot(kind="box", ax=ax[0])
    taxi_pd[column].plot(kind="hist", ax=ax[1])
    fig.suptitle(column)
plt.show()

The summary statistics and plots show that the 5% sample contains over 1 million Yellow Taxi trips from 2022.

However, some trips contain data anomalies. Trips longer than 10,000 miles are unrealistic, so exclude them. Null and negative values in the fare, tax, tip, surcharge, and toll columns are also inconsistent and can distort the analysis. Filter out these rows, and require the fare and total amounts to be strictly positive.
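In the box plots, the points drawn beyond the whiskers are outliers. By default, Matplotlib extends the whiskers to the most extreme data points within 1.5 × IQR (interquartile range) of the first and third quartiles, which is Tukey's rule. The stdlib sketch below computes those fences for a small made-up sample; `statistics.quantiles(..., method="inclusive")` uses the same linear interpolation as NumPy's default percentile. It only shows how to read the plots: the filter in this notebook uses fixed, domain-based thresholds instead.

```python
from statistics import quantiles


def tukey_fences(values, k=1.5):
    """Return (lower, upper) fences; points outside them are drawn as outliers."""
    q1, _, q3 = quantiles(values, n=4, method="inclusive")
    iqr = q3 - q1
    return q1 - k * iqr, q3 + k * iqr


# Made-up trip distances in miles, with one implausible value.
distances = [1.0, 2.0, 3.0, 4.0, 100.0]
low, high = tukey_fences(distances)
outliers = [d for d in distances if d < low or d > high]
print(low, high, outliers)  # -1.0 7.0 [100.0]
```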

In [None]:
taxi_df = taxi_df.where(
    (col("trip_distance") < 10000)
    & (col("fare_amount") > 0)
    & (col("extra") >= 0)
    & (col("mta_tax") >= 0)
    & (col("tip_amount") >= 0)
    & (col("tolls_amount") >= 0)
    & (col("imp_surcharge") >= 0)
    & (col("airport_fee") >= 0)
    & (col("total_amount") > 0)
).dropna()
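The filter above keeps a row only when every condition is true. In Spark, a comparison involving a null evaluates to null, which `where` treats as false, and `dropna()` then removes any remaining rows with a null in any column. The plain-Python sketch below mirrors that row-level rule on hypothetical dictionaries, with `None` standing in for null; `keep_trip` is an illustrative helper only.

```python
NON_NEGATIVE = ["extra", "mta_tax", "tip_amount", "tolls_amount", "imp_surcharge", "airport_fee"]
POSITIVE = ["fare_amount", "total_amount"]


def keep_trip(row: dict) -> bool:
    """Mirror the Spark where() + dropna() rule for a single row."""
    # Any null value drops the row (dropna() with its default how="any").
    if any(value is None for value in row.values()):
        return False
    return (
        row["trip_distance"] < 10000
        and all(row[c] > 0 for c in POSITIVE)
        and all(row[c] >= 0 for c in NON_NEGATIVE)
    )


good = {"trip_distance": 2.1, "fare_amount": 9.5, "extra": 0.5, "mta_tax": 0.5,
        "tip_amount": 2.0, "tolls_amount": 0.0, "imp_surcharge": 0.3,
        "airport_fee": 0.0, "total_amount": 12.8}
negative_fare = {**good, "fare_amount": -9.5, "total_amount": -12.8}
missing_fee = {**good, "airport_fee": None}

print([keep_trip(r) for r in (good, negative_fare, missing_fee)])  # [True, False, False]
```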

### Perform Feature Engineering

The taxi dataset covers trips in all NYC boroughs, but it records pickup and dropoff locations only as `NYC Taxi zones` IDs rather than coordinates. Use the `bigquery-public-data.new_york_taxi_trips.taxi_zone_geom` public dataset to calculate the longitude and latitude of each zone's centroid.
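GeoPandas' `centroid` returns the area-weighted center of each zone polygon in the geometry's own coordinate units, so with longitude/latitude WKT the result is a planar approximation, which is generally adequate for small areas. For a simple polygon, that center follows from the shoelace formula. The stdlib sketch below applies it to a single ring without holes; this is a simplification of what GeoPandas (through Shapely and GEOS) handles, and `polygon_centroid` is a hypothetical helper.

```python
def polygon_centroid(ring):
    """Area-weighted centroid of a simple polygon ring [(x, y), ...] without holes.

    The ring should not repeat its first vertex at the end.
    """
    # Shift coordinates to the first vertex to reduce floating-point cancellation.
    x_ref, y_ref = ring[0]
    pts = [(x - x_ref, y - y_ref) for x, y in ring]
    area2 = 0.0  # Twice the signed area (shoelace formula).
    cx = cy = 0.0
    for i in range(len(pts)):
        x0, y0 = pts[i]
        x1, y1 = pts[(i + 1) % len(pts)]
        cross = x0 * y1 - x1 * y0
        area2 += cross
        cx += (x0 + x1) * cross
        cy += (y0 + y1) * cross
    if area2 == 0:
        raise ValueError("degenerate polygon with zero area")
    # Centroid = sum / (6 * area) = sum / (3 * area2); then undo the shift.
    return x_ref + cx / (3 * area2), y_ref + cy / (3 * area2)


# A made-up rectangular "zone" in longitude/latitude degrees.
zone = [(-74.0, 40.7), (-73.9, 40.7), (-73.9, 40.8), (-74.0, 40.8)]
print(polygon_centroid(zone))  # approximately (-73.95, 40.75)
```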

In [None]:
# Load the GeoJSON format of NYC Taxi zones from the BigQuery public dataset.
geo_df = (
    spark.read.format("bigquery")
    .option("table", "bigquery-public-data.new_york_taxi_trips.taxi_zone_geom")
    .load()
)

# Convert Spark DataFrame into Pandas DataFrame to integrate with the GeoPandas library.
geo_pd = geo_df.toPandas()

# Compute the centroid of each taxi zone polygon and store its longitude and latitude.
centroids = gpd.GeoSeries.from_wkt(geo_pd["zone_geom"]).centroid
geo_pd["long"] = centroids.x
geo_pd["lat"] = centroids.y

# Drop unnecessary columns.
geo_pd = geo_pd[["zone_id", "long", "lat"]]

# Convert back to a Spark DataFrame and cast zone_id to an integer to match start_zone_id and end_zone_id.
geo_spark_df = spark.createDataFrame(geo_pd).withColumn("zone_id", col("zone_id").cast("int"))

# Build separate lookup tables for pickup and dropoff zones so the two joins don't share column references.
start_geo_df = (
    geo_spark_df.withColumnRenamed("zone_id", "start_zone_id")
    .withColumnRenamed("long", "start_long")
    .withColumnRenamed("lat", "start_lat")
)
end_geo_df = (
    geo_spark_df.withColumnRenamed("zone_id", "end_zone_id")
    .withColumnRenamed("long", "end_long")
    .withColumnRenamed("lat", "end_lat")
)

# Join taxi_df with the centroid coordinates of each start_zone_id and end_zone_id.
taxi_zone_df = taxi_df.join(start_geo_df, on="start_zone_id").join(end_geo_df, on="end_zone_id")

# Convert the Spark DataFrame into a Pandas DataFrame.
taxi_pd = taxi_df.toPandas()

# Convert trip_duration to float for plotting.
taxi_pd["trip_duration"] = taxi_pd["trip_duration"].astype(float)

# Draw a box plot and a histogram of trip_duration side by side.
fig, ax = plt.subplots(1, 2, figsize=(10, 4))
taxi_pd["trip_duration"].plot(kind="box", ax=ax[0])
taxi_pd["trip_duration"].plot(kind="hist", ax=ax[1])
fig.suptitle("trip_duration")
plt.show()

`trip_duration` also has some extreme values. Remove trips that last 8 hours or longer, as well as trips that start and end in the same taxi zone.

In [None]:
# Keep trips shorter than 28,800 seconds (8 hours) whose pickup and dropoff zones differ.
taxi_df = taxi_zone_df.where(
    (col("trip_duration") < 28800) & (col("start_zone_id") != col("end_zone_id"))
)

Create a scatter plot to see the relationship between `trip_distance` and `trip_duration`.

In [None]:
# Convert Spark DataFrame into a Pandas DataFrame.
taxi_pd = taxi_df.toPandas()

# Convert "trip_distance" column of "object" type to the float type.
taxi_pd["trip_distance"] = taxi_pd["trip_distance"].astype(float)

# Filter the DataFrame to include data within reasonable ranges.
taxi_pd_filtered = taxi_pd.query(
    "trip_distance > 0 and trip_distance < 20 "
    "and trip_duration > 0 and trip_duration < 10000"
)

# Scatter plot to visualize the relationship between trip_distance and trip_duration.
sns.relplot(
    data=taxi_pd_filtered,
    x="trip_distance",
    y="trip_duration",
    kind="scatter",
)

Takeaways here include:
  * the data is right-skewed
  * there is a positive correlation between `trip_distance` and `trip_duration`
  * most trips are completed in under 3600 seconds (one hour)

### Feature Selection

Use `VectorAssembler()` to consolidate feature columns into a vector column.
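Spark ML estimators such as `GBTRegressor` read all predictors from a single vector column. Conceptually, `VectorAssembler` concatenates the listed input columns, in order, into that vector for each row. The plain-Python sketch below mimics this for hypothetical rows stored as dictionaries; the real transformer produces a Spark `Vector` (dense or sparse) and, with its default `handleInvalid="error"`, fails on null inputs, which the sketch imitates with a `ValueError`.

```python
def assemble(row: dict, input_cols: list, output_col: str = "features") -> dict:
    """Hypothetical helper: return a copy of `row` with input columns concatenated into a list."""
    values = []
    for name in input_cols:
        value = row[name]
        if value is None:
            # Imitates VectorAssembler's default handleInvalid="error".
            raise ValueError(f"null value in column {name!r}")
        values.append(float(value))
    return {**row, output_col: values}


row = {"passenger_count": 1, "trip_distance": 2.5, "fare_amount": 11.0, "trip_duration": 840}
assembled = assemble(row, ["passenger_count", "trip_distance", "fare_amount"])
print(assembled["features"])  # [1.0, 2.5, 11.0]
```

The label (`trip_duration`) stays in its own column and is not part of the feature vector.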

In [None]:
# List of selected features for training the model.
feature_cols = [
    "passenger_count",
    "trip_distance",
    "start_time",
    "end_time",
    "start_long",
    "start_lat",
    "end_long",
    "end_lat",
    "total_amount",
    "fare_amount",
    "extra",
    "mta_tax",
    "tip_amount",
    "tolls_amount",
    "imp_surcharge",
    "airport_fee",
]

# Create a VectorAssembler with specified input and output columns.
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")

# Combine the feature columns into a single "features" vector column.
taxi_transformed_data = assembler.transform(taxi_df)

# Randomly split the transformed data into training (95%) and test (5%) sets; the seed makes the split reproducible.
(taxi_training_data, taxi_test_data) = taxi_transformed_data.randomSplit([0.95, 0.05], seed=42)
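`randomSplit` normalizes its weights and assigns each row to a split with a random draw, so the 95/5 proportions are approximate rather than exact. The stdlib sketch below illustrates that assignment rule on synthetic row IDs; `random_split` is a hypothetical helper and a simplification of Spark's per-partition implementation.

```python
import random
from itertools import accumulate


def random_split(rows, weights, seed=None):
    """Assign each row to one split with probability proportional to its weight."""
    total = sum(weights)
    bounds = list(accumulate(w / total for w in weights))  # Cumulative upper bounds.
    rng = random.Random(seed)
    splits = [[] for _ in weights]
    for row in rows:
        draw = rng.random()
        for index, upper in enumerate(bounds):
            # The last split also catches draws above a bound rounded just below 1.0.
            if draw < upper or index == len(bounds) - 1:
                splits[index].append(row)
                break
    return splits


train, test = random_split(range(20_000), [0.95, 0.05], seed=42)
print(len(train), len(test))  # Roughly 19000 and 1000.
```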

### Training the Model

Train a gradient-boosted tree regression model (`GBTRegressor`) to predict `trip_duration`.

In [None]:
# Define GBTRegressor model with specified input, output, and prediction columns.
gbt = GBTRegressor(
    featuresCol="features",
    labelCol="trip_duration",
    predictionCol="pred_trip_duration",
)

# Define an evaluator for calculating the R2 score.
evaluator_r2 = RegressionEvaluator(
    labelCol=gbt.getLabelCol(), predictionCol=gbt.getPredictionCol(), metricName="r2"
)

# Define an evaluator for calculating the root mean squared error (RMSE).
evaluator_rmse = RegressionEvaluator(
    labelCol=gbt.getLabelCol(), predictionCol=gbt.getPredictionCol(), metricName="rmse"
)
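`RegressionEvaluator` computes these metrics from the label and prediction columns. RMSE is the square root of the mean squared error and is expressed in the label's units (seconds here), while R2 is one minus the ratio of the residual sum of squares to the total sum of squares around the mean label. The stdlib sketch below applies those standard definitions to a few made-up durations; it is not the evaluator's source code.

```python
import math


def rmse(labels, predictions):
    """Root mean squared error, in the same units as the labels."""
    squared = [(y - p) ** 2 for y, p in zip(labels, predictions)]
    return math.sqrt(sum(squared) / len(squared))


def r2(labels, predictions):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(labels) / len(labels)
    ss_res = sum((y - p) ** 2 for y, p in zip(labels, predictions))
    ss_tot = sum((y - mean) ** 2 for y in labels)
    return 1 - ss_res / ss_tot


# Made-up trip durations (seconds) and predictions.
labels = [600.0, 900.0, 1200.0, 1500.0]
predictions = [700.0, 800.0, 1300.0, 1400.0]
print(rmse(labels, predictions), r2(labels, predictions))  # 100.0 and about 0.911
```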

In [None]:
# Train a Gradient Boosted Trees (GBT) model on the Taxi dataset. This process may take several minutes.
taxi_gbt_model = gbt.fit(taxi_training_data)

# Get predictions for the Taxi dataset using the trained GBT model.
taxi_gbt_predictions = taxi_gbt_model.transform(taxi_test_data)

In [None]:
# Evaluate the R2 score for the Taxi dataset predictions.
taxi_gbt_accuracy_r2 = evaluator_r2.evaluate(taxi_gbt_predictions)
print(f"Taxi Test GBT R2 = {taxi_gbt_accuracy_r2}")

# Evaluate the root mean squared error (RMSE), in seconds, for the Taxi dataset predictions.
taxi_gbt_accuracy_rmse = evaluator_rmse.evaluate(taxi_gbt_predictions)
print(f"Taxi Test GBT RMSE (seconds) = {taxi_gbt_accuracy_rmse}")

### View the result

Expect an R2 score of approximately 0.83 to 0.87 and a root mean squared error (RMSE) of roughly 200 to 300 seconds. This sample does not use [cross-validation](https://en.wikipedia.org/wiki/Cross-validation_%28statistics%29) to tune hyperparameters, which could improve model performance.
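Cross-validation trains on k - 1 folds and evaluates on the held-out fold, rotating through all k folds so that every row is used for validation exactly once. Spark ML provides this as `pyspark.ml.tuning.CrossValidator`, usually combined with `ParamGridBuilder` to compare hyperparameter settings. The stdlib sketch below shows only the fold assignment, using hypothetical row IDs and k=3 (the default `numFolds` of `CrossValidator`); it is not how `CrossValidator` splits data internally.

```python
import random


def k_fold_indices(n_rows, k=3, seed=42):
    """Yield (train_ids, validation_ids) pairs; each row is validated exactly once."""
    ids = list(range(n_rows))
    random.Random(seed).shuffle(ids)
    folds = [ids[i::k] for i in range(k)]  # Round-robin assignment to k folds.
    for i in range(k):
        validation = sorted(folds[i])
        train = sorted(id_ for j, fold in enumerate(folds) if j != i for id_ in fold)
        yield train, validation


for train_ids, validation_ids in k_fold_indices(10, k=3):
    print(len(train_ids), len(validation_ids))
```

With 10 rows and k=3, the folds hold 4, 3, and 3 rows, so each model trains on 6 or 7 rows.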

### Save the model to Cloud Storage for future use

Save the trained model to Cloud Storage so that you can reload it later.

In [None]:
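# Save the trained model to a subdirectory of the bucket. overwrite() replaces any existing
# content at this path, so avoid writing to the bucket root. The "taxi_gbt_model" name is arbitrary.
taxi_gbt_model.write().overwrite().save(f"gs://{BUCKET_URI}/taxi_gbt_model")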

### Delete the Dataproc session template

To delete the Dataproc Serverless session template used by this notebook, run the following command.

In [None]:
!gcloud beta dataproc session-templates delete sparkml-template --location='{LOCATION}' --quiet