***
# Robust Journey Planning

**Link to project presentation:** https://youtube.com

***

## Table of Contents
* [0. Imports](#imports)
    * [0.1 HDFS/Hive](#hive)
    * [0.2 Spark](#Spark)
    * [0.3 Geospatial User Defined Functions](#udf)
* [1. Data](#data)
    * [1.1 Timetables](#timetables)
    * [1.2 Istdaten](#istdaten)
    * [1.3 Geo Shapes](#geoshapes)
    * [1.4 Optional: Weather Data](#weatherdata)
* [2. Data Preprocessing](#datapreprocessing)
    * [2.1 Preprocessing Timetable & Geostops](#preprocessingtimetablegeostops)
    * [2.2 Preprocessing Istdaten Data](#preprocessingactualdata)
* [3. Building the Transportation Graph](#transportationgraph)
    * [3.1 test](#test)
* [4. Modelling Delays](#modellingdelays)
* [4. Modelling Delays](#modellingdelays)

## 0. Imports <a class="anchor" id="imports"></a>
In this section we import the necessary packages, connect to HDFS/Hive and initialize the Spark environment used throughout the project. Finally, we add support for geospatial User Defined Functions.

In [1]:
# INSERT A REGION OBJECTID
# Region to analyse; presumably an objectid from the geo_shapes table (section 1.3) — TODO confirm.
OBJECTID: int = 1

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,User,Current session?
4336,application_1713270977862_4678,pyspark,idle,Link,Link,,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [2]:
# Cluster-based imports
from pyspark.ml.feature import StringIndexer, IndexToString
from pyspark.sql import SparkSession, Row, HiveContext, Window,  functions as F
from pyspark.sql.types import IntegerType, StringType, ArrayType, StructField, StructType

import getpass
import os
import pandas as pd

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [3]:
%%local
# Local imports for visualization and data manipulation
from pyarrow.fs import HadoopFileSystem, FileSelector
from pyhive import hive

import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import plotly.express as px
import warnings
warnings.simplefilter(action='ignore', category=UserWarning)

In [4]:
# model part 
from pyspark.ml.feature import VectorAssembler, StandardScaler, StringIndexer
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml import Pipeline
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql import functions as F

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

### 0.1 HDFS/Hive <a class="anchor" id="hive"></a>


In [5]:
%%local
# Environment variables setup
default_db = 'com490'
hive_server = os.environ.get('HIVE_SERVER', 'iccluster080.iccluster.epfl.ch:10000')
hadoop_fs = os.environ.get('HADOOP_DEFAULT_FS', 'hdfs://iccluster067.iccluster.epfl.ch:8020')
hdfs = HadoopFileSystem.from_uri(hadoop_fs)
username = os.environ.get('USER', 'anonym')
hive_host, hive_port = hive_server.split(':')

# Connect to Hive
conn = hive.connect(host=hive_host, port=int(hive_port), username=username)
cur = conn.cursor()

# Print connection details
print(f"Hadoop HDFS URL: {hadoop_fs}")
print(f"Username: {username}")
print(f"Connected to Hive at: {hive_host}:{hive_port}")

log4j:WARN No appenders could be found for logger (org.apache.hadoop.fs.FileSystem).
log4j:WARN Please initialize the log4j system properly.
log4j:WARN See http://logging.apache.org/log4j/1.2/faq.html#noconfig for more info.


Hadoop HDFS URL: hdfs://iccluster067.iccluster.epfl.ch:8020
Username: schiffer
Connected to Hive at: iccluster080.iccluster.epfl.ch:10000


In [6]:
%%local 

# Create directory for finalproject if it doesn't already exist
base_path = f"/user/{username}/"

# List all files and directories in the base path
selector = FileSelector(base_path, recursive=False)
file_info_list = hdfs.get_file_info(selector)

# Directory to check
target_directory = "finalproject"

# Check if the target directory exists among the listed files/directories
directory_exists = any(info.path.rstrip('/').split('/')[-1] == target_directory for info in file_info_list)

if not directory_exists:
    # Create the directory if it does not exist
    directory_path = f"{base_path}{target_directory}"
    hdfs.create_dir(directory_path, recursive=True)
    print(f"Directory created: {directory_path}")
else:
    print(f"Directory already exists: {base_path}{target_directory}")

Directory already exists: /user/schiffer/finalproject


In [7]:
%%local
# Create a new database
query = f"CREATE DATABASE IF NOT EXISTS {username} LOCATION '/user/{username}/finalproject'"
cur.execute(query)
print(f"Database {username} created or already exists.")

# Switch to the new database
query = f"USE {username}"
cur.execute(query)
print(f"Switched to database: {username}")

Database schiffer created or already exists.
Switched to database: schiffer


In [8]:
%%local
cur.execute(f"SHOW TABLES IN {username}")
cur.fetchall()

[('geo_shapes',),
 ('sbb_lausanne_trip_times',),
 ('sbb_orc_istdaten',),
 ('sbb_orc_stops',),
 ('sbb_stop_times_lausanne_region',),
 ('sbb_stop_to_stop_lausanne_region',),
 ('sbb_stops_lausanne',),
 ('sbb_stops_lausanne_region',)]

In [9]:
%%local

# Make sure to give rw access to Hive and Livy
# For each of the hive and livy users: the first command updates ACLs on the existing tree (-R),
# the "default:" variant sets default ACLs on the directory (intended to apply to newly created entries).
!hdfs dfs -setfacl -R -m user:hive:rwx /user/${USER}/finalproject
!hdfs dfs -setfacl -R -m default:user:hive:rwx /user/${USER}/finalproject
!hdfs dfs -setfacl -R -m user:livy:rwx /user/${USER}/finalproject
!hdfs dfs -setfacl -R -m default:user:livy:rwx /user/${USER}/finalproject

In [10]:
# Outside %%local this cell runs in the remote Spark session, where the user is 'livy'
fallback_user = getpass.getuser()
local_username = os.environ.get('USER', fallback_user)
local_username

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

'livy'

In [11]:
%%local
# Display the username resolved in the local kernel (from the USER environment variable)
username

'schiffer'

### 0.2 Spark <a class="anchor" id="Spark"></a>

In [12]:
# Initialize (or reuse) the Spark session. Outside %%local this runs in the remote
# Spark session, so getpass.getuser() resolves to the remote user ('livy' above).
sparkSession = SparkSession.builder.appName('final-project-{0}'.format(getpass.getuser())).getOrCreate()

# Print the active session and its type to confirm the session is up
print(sparkSession.getActiveSession())
print(type(sparkSession))

# Low-level SparkContext handle
sc = sparkSession.sparkContext

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

<pyspark.sql.session.SparkSession object at 0x7fcdc09ba220>
<class 'pyspark.sql.session.SparkSession'>

In [13]:
%%send_to_spark -i username -t str -n username

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Successfully passed 'username' as 'username' to Spark kernel

In [14]:
# Check that Spark has access to personal HDFS by listing the user's tables
user_tables = spark.sql(f"SHOW TABLES IN {username}")
user_tables.show()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+---------+--------------------+-----------+
|namespace|           tableName|isTemporary|
+---------+--------------------+-----------+
| schiffer|       sbb_orc_stops|      false|
| schiffer|  sbb_stops_lausanne|      false|
| schiffer|    sbb_orc_istdaten|      false|
| schiffer|sbb_stops_lausann...|      false|
| schiffer|sbb_stop_to_stop_...|      false|
| schiffer|sbb_stop_times_la...|      false|
| schiffer|sbb_lausanne_trip...|      false|
| schiffer|          geo_shapes|      false|
+---------+--------------------+-----------+

In [15]:
%%send_to_spark -i hadoop_fs -t str -n hadoop_fs

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

Successfully passed 'hadoop_fs' as 'hadoop_fs' to Spark kernel

### 0.3 Geospatial User Defined Functions <a class="anchor" id="udf"></a>


In [None]:
# Make the ESRI spatial-framework jars available to this Spark session
spark.sql(f"""
ADD JARS
    {hadoop_fs}/data/jars/esri-geometry-api-2.2.4.jar
    {hadoop_fs}/data/jars/spatial-sdk-hive-2.2.0.jar
    {hadoop_fs}/data/jars/spatial-sdk-json-2.2.0.jar
""")

# Expose each ESRI Hive UDF as a temporary SQL function under the same name
gis_functions = [
    "ST_Point", "ST_Distance", "ST_SetSRID", "ST_GeodesicLengthWGS84",
    "ST_LineString", "ST_AsBinary", "ST_PointFromWKB", "ST_GeomFromWKB", "ST_Contains"
]
create_udf_template = "CREATE OR REPLACE TEMPORARY FUNCTION {0} AS 'com.esri.hadoop.hive.{0}'"
for func_name in gis_functions:
    spark.sql(create_udf_template.format(func_name))

# Verify the registration: list only the functions whose name starts with 'st_'
functions_df = spark.sql("SHOW FUNCTIONS")
functions_df.where(F.col("function").startswith("st_")).show(truncate=False)

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

## 1. Data <a class="anchor" id="data"></a>

In this section, we load the raw datasets: first the timetable data, then the Istdaten (actual) data, the geo shapes and, optionally, the weather data.

In [None]:
# Sanity check: compare the remote session's USER with the username sent from the local kernel
remote_user = os.getenv('USER')
print(f"remote USER={remote_user}")
print(f"local USER={username}")

In [None]:
%%local
print(f"local USER={os.getenv('USER',None)}")

### 1.1 Timetables<a class="anchor" id="timetables"></a>


In [None]:
# Load the timetable tables from HDFS for a single snapshot date.
# The date is parameterized instead of being repeated in six hard-coded paths.
TIMETABLE_YEAR, TIMETABLE_MONTH, TIMETABLE_DAY = 2024, 5, 16

def load_timetable(table, year=TIMETABLE_YEAR, month=TIMETABLE_MONTH, day=TIMETABLE_DAY):
    """Read one timetable table (e.g. 'stops') from its ORC partition for the given date."""
    return spark.read.orc(f'/data/sbb/orc/timetables/{table}/year={year}/month={month}/day={day}')

stops = load_timetable('stops')
stop_times = load_timetable('stop_times')
trips = load_timetable('trips')
calendar = load_timetable('calendar')
routes = load_timetable('routes')
transfers = load_timetable('transfers')

# Preview two rows of each table
for timetable_df in (stops, stop_times, trips, calendar, routes, transfers):
    timetable_df.show(2)

In [None]:
transfers.show(n=20)  # larger preview of the transfers table

### 1.2 Istdaten <a class="anchor" id="istdaten"></a>

In [None]:
# Load the full Istdaten dataset (call istdaten.printSchema() to inspect its columns)
istdaten = spark.read.orc('/data/sbb/orc/istdaten')
istdaten.show(n=2)

### 1.3 Geo Shapes <a class="anchor" id="geoshapes"></a>


In [None]:
# (Re)create an external table over the GeoJSON region shapes, partitioned by country/region
spark.sql(f"DROP TABLE IF EXISTS {username}.geo_shapes")

geo_shapes_ddl = f"""
CREATE EXTERNAL TABLE {username}.geo_shapes(
    objectid INT,
    name     STRING,
    geometry BINARY
)
PARTITIONED BY(country STRING, region STRING)
ROW FORMAT SERDE 'com.esri.hadoop.hive.serde.GeoJsonSerDe' 
STORED AS INPUTFORMAT 'com.esri.json.hadoop.UnenclosedEsriJsonInputFormat'
OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'
LOCATION '/data/geo/json/'
"""
spark.sql(geo_shapes_ddl)

In [None]:
# Register the partitions found on disk, then preview the shapes ordered by objectid
geo_shapes_table = f"{username}.geo_shapes"
spark.sql(f"MSCK REPAIR TABLE {geo_shapes_table}")
shapes_preview = spark.sql(f"SELECT * FROM {geo_shapes_table} ORDER BY objectid ASC")
shapes_preview.show(3)

### 1.4 Optional: Weather Data <a class="anchor" id="weatherdata"></a>

In [None]:
# Optional weather data. TODO: add code from hw3
# Weather station metadata (CSV with a header row)
stations = spark.read.csv('/data/wunderground/csv/stations', header=True)
# Weather history records (JSON)
weather_data = spark.read.json('/data/wunderground/json/history')

## 2. Data Preprocessing <a class="anchor" id="datapreprocessing"></a>

### 2.1 Preprocessing Timetable & Geostops <a class="anchor" id="preprocessingtimetablegeostops"></a>


In [None]:
# Keep only the stops whose coordinates fall inside the selected region's geometry.
# The region comes from the OBJECTID constant set in the first cell (was hard-coded to 2,
# silently ignoring that setting).
object_id = OBJECTID

stops_region = spark.sql(f"""
SELECT
    a.stop_id,
    a.stop_lat,
    a.stop_lon
FROM {username}.sbb_orc_stops a JOIN {username}.geo_shapes b
ON ST_Contains(b.geometry, ST_Point(a.stop_lon, a.stop_lat))
WHERE b.objectid = {object_id}
""")

stops_region.show(5)

In [None]:
# Walking-transfer parameters
MAX_WALK_DISTANCE_M = 500   # stops at most this far apart are considered connected on foot
WALK_SPEED_M_PER_MIN = 50   # assumed walking speed
MIN_TRANSFER_MIN = 2        # fixed minimum for any transfer, even within the same location

# Register the stops_region dataframe as a temporary view to use in SQL
stops_region.createOrReplaceTempView("stops_region")

# All ordered pairs of distinct stops in the region that are closer than MAX_WALK_DISTANCE_M.
# The geodesic distance is defined once in the inner query and filtered on outside,
# instead of being written twice in two different ST_LineString forms.
stop_pairs_within_500m = spark.sql(f"""
SELECT *
FROM (
    SELECT
        a.stop_id AS stop_id1,
        b.stop_id AS stop_id2,
        a.stop_lat AS stop_lat1,
        a.stop_lon AS stop_lon1,
        b.stop_lat AS stop_lat2,
        b.stop_lon AS stop_lon2,
        ST_GeodesicLengthWGS84(ST_SetSRID(ST_LineString(a.stop_lon, a.stop_lat, b.stop_lon, b.stop_lat), 4326)) AS distance_meters
    FROM
        stops_region a
    CROSS JOIN
        stops_region b
    WHERE
        a.stop_id != b.stop_id
) pairs
WHERE distance_meters < {MAX_WALK_DISTANCE_M}
""")

# Walking transfer time in minutes: MIN_TRANSFER_MIN minutes minimum for a transfer within
# the same location, plus 1 minute per WALK_SPEED_M_PER_MIN metres of walking between stops.
stop_pairs_within_500m = stop_pairs_within_500m.withColumn(
    "transfer_time_minutes",
    F.lit(MIN_TRANSFER_MIN) + F.col("distance_meters") / WALK_SPEED_M_PER_MIN
)
# Show results
stop_pairs_within_500m.show(2)

In [None]:
# We choose to focus only on the weekday schedule.
# NOTE: calendar entries describe services (service_id), not individual trips, so the
# counts below are service counts (previously mislabelled as "Trips").
calendar.show(2)
print(f"Calendar entries (all services): {calendar.count()}")

# Keep services that are active on every day from Monday to Friday
WEEKDAYS = ["monday", "tuesday", "wednesday", "thursday", "friday"]
runs_every_weekday = F.col(WEEKDAYS[0]) == True
for weekday in WEEKDAYS[1:]:
    runs_every_weekday = runs_every_weekday & (F.col(weekday) == True)

weekday_calendar = calendar.filter(runs_every_weekday).select("service_id").distinct()

print(f"Weekday service_ids: {weekday_calendar.count()}")

In [None]:
# Weekday trips, annotated with the transportation type (bus, train etc.) from routes
weekday_calendar_trips = (
    trips
    .join(weekday_calendar, "service_id")
    .join(routes.select("route_id", "route_desc"), "route_id")
)

# trip_id links each stop_times row to its weekday trip
weekday_calendar_stop_times = stop_times.join(weekday_calendar_trips, "trip_id")

# Restrict to stops inside the selected region and drop unused columns
weekday_all_info = (
    weekday_calendar_stop_times
    .join(stops_region, "stop_id")
    .drop("pickup_type", "drop_off_type")
)
weekday_all_info.show(2)

1. Connections Table:
- This table should capture all possible connections between stops. It should include fields for departure stop, arrival stop, departure time, arrival time, and transport mode (bus, train, etc.), taking into account both direct transit connections and possible walking transfers.
- For walking transfers, you should consider all stops within 500 meters of each other as potentially connected. Calculate walking times based on the distance and a fixed walking speed (50m/1min).

2. Journey Table:
- Create a "journeys" table that aggregates data from stop_times, trips, routes, and calendar. It should include trip information (route, departure times, arrival times, etc.) and filter for typical business days using the calendar table.
- Incorporate service variations and trip cancellations to ensure accuracy.

3. 

1. **Connections Table**:
- This table should capture all possible connections between stops. It should include fields for departure stop, arrival stop, departure time, arrival time, and transport mode (bus, train, etc.), taking into account both direct transit connections and possible walking transfers.
- For walking transfers, you should consider all stops within 500 meters of each other as potentially connected. Calculate walking times based on the distance and a fixed walking speed (50m/1min).

2. **Journey Table**:
- Create a "journeys" table that aggregates data from stop_times, trips, routes, and calendar. It should include trip information (route, departure times, arrival times, etc.) and filter for typical business days using the calendar table.
- Incorporate service variations and trip cancellations to ensure accuracy.

3. **Transfers Table**:
- This table should specifically account for possible transfers between routes, considering both transit and walking transfers.
- You may include an estimated transfer time (including buffer times for realistic transfers between different modes of transport).

4. **Routes and Confidence Table**:
- This is a critical table that will compute potential routes from stop A to stop B considering your criteria (e.g., arrival before a specific time T with a confidence level Q).
- Use statistical methods or historical data to assess the reliability of each route segment to establish confidence levels.

In [None]:
# Connections table: expose the timetable DataFrames to Spark SQL as temporary views
timetable_views = {
    "stops": stops,
    "stop_times": stop_times,
    "trips": trips,
    "routes": routes,
    "calendar": calendar,
}
for view_name, frame in timetable_views.items():
    frame.createOrReplaceTempView(view_name)

In [None]:
# Direct transit connections: each pair of consecutive stops served by the same trip.
# Consecutive stops are found with LEAD over the trip's stop_sequence order. The earlier
# `st1.stop_sequence = st2.stop_sequence - 1` join assumed sequence numbers increase by
# exactly 1; in GTFS they only need to increase, so gaps would silently drop connections.
transit_connections = spark.sql("""
WITH ordered_stop_times AS (
    SELECT
        trip_id,
        stop_id,
        departure_time,
        LEAD(stop_id) OVER (PARTITION BY trip_id ORDER BY CAST(stop_sequence AS INT)) AS next_stop_id,
        LEAD(arrival_time) OVER (PARTITION BY trip_id ORDER BY CAST(stop_sequence AS INT)) AS next_arrival_time
    FROM stop_times
)
SELECT
    ost.stop_id AS departure_stop_id,
    ost.next_stop_id AS arrival_stop_id,
    ost.departure_time,
    ost.next_arrival_time AS arrival_time,
    rt.route_desc AS transport_mode
FROM
    ordered_stop_times ost
JOIN
    trips tr ON ost.trip_id = tr.trip_id
JOIN
    routes rt ON tr.route_id = rt.route_id
WHERE
    ost.next_stop_id IS NOT NULL
""")
transit_connections.show(5)

### 2.2 Preprocessing Istdaten Data <a class="anchor" id="preprocessingactualdata"></a>

## 3. Building the Transportation Graph <a class="anchor" id="transportationgraph"></a>
Simplifying assumptions:
1. ...

## 4. Modelling Delays <a class="anchor" id="modellingdelays"></a>
Simplifying assumptions:
1. **Delay Label Creation**: Delays greater than 30 minutes are not taken into account. Delays of up to 30 minutes are rounded up to the nearest whole minute (using `F.ceil`) and cast to an integer, giving one class per minute of delay. The underlying assumption is that delays of more than 30 minutes are rare, and that in such cases another connection would offer a more efficient route.

#### On Mock Data to test it 

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType
from pyspark.ml.feature import VectorAssembler, StandardScaler, StringIndexer
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml import Pipeline
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Seed for the train/test split so the mock run is reproducible
SPLIT_SEED = 42

# Define schema for mock data
schema = StructType([
    StructField("bpuic", IntegerType(), True),
    StructField("ankunftszeit", StringType(), True),
    StructField("linien_id", IntegerType(), True),
    StructField("avg_delay", DoubleType(), True),
    StructField("stddev_delay", DoubleType(), True),
    StructField("temperature", DoubleType(), True),
    StructField("precip_hrly", DoubleType(), True),
    StructField("an_prognose", StringType(), True)
])

# Create mock data
# NOTE(review): with only 5 rows, the 80/20 split plus 3-fold CV leaves ~1 row per fold;
# this only smoke-tests the pipeline wiring, not model quality.
data = [
    (1, "01.01.2024 08:00", 101, 5.0, 2.0, 15.0, 0.0, "01.01.2024 08:05:00"),
    (2, "01.01.2024 09:00", 102, 3.0, 1.5, 18.0, 0.1, "01.01.2024 09:10:00"),
    (3, "01.01.2024 10:00", 103, 2.0, 1.0, 20.0, 0.0, "01.01.2024 10:02:00"),
    (4, "01.01.2024 11:00", 104, 4.0, 2.0, 22.0, 0.2, "01.01.2024 11:15:00"),
    (5, "01.01.2024 12:00", 105, 1.0, 0.5, 25.0, 0.0, "01.01.2024 12:00:30")
]

# Create DataFrame
mock_df = spark.createDataFrame(data, schema=schema)

# Convert time str to timestamp type
mock_df = mock_df.withColumn("scheduled_arrival", F.to_timestamp("ankunftszeit", "dd.MM.yyyy HH:mm"))
mock_df = mock_df.withColumn("actual_arrival", F.to_timestamp("an_prognose", "dd.MM.yyyy HH:mm:ss"))

# Calculate delay in minutes
mock_df = mock_df.withColumn("arrival_delay", (F.unix_timestamp("actual_arrival") - F.unix_timestamp("scheduled_arrival")) / 60)

# Labels are whole minutes of delay from 0 to 30: round up, cap at 30, and map early arrivals
# (negative delay) to 0, since RandomForestClassifier rejects negative class labels.
# A null delay stays a null label.
mock_df = mock_df.withColumn(
    "delay_label",
    F.when(F.col("arrival_delay") >= 30, 30)
     .when(F.col("arrival_delay") <= 0, 0)
     .otherwise(F.ceil(F.col("arrival_delay")))
     .cast("int")
)

# Extract hour and minute from scheduled_arrival for feature engineering
mock_df = mock_df.withColumn("arrival_hour", F.hour("scheduled_arrival"))
mock_df = mock_df.withColumn("arrival_minute", F.minute("scheduled_arrival"))

# Define feature columns
feature_cols = ["bpuic", "linien_id", "avg_delay", "stddev_delay", "temperature", "precip_hrly", "arrival_hour", "arrival_minute"]

# Create stages for pipeline
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features", handleInvalid="skip")
scaler = StandardScaler(inputCol="features", outputCol="scaled_features", withStd=True, withMean=False)
classifier = RandomForestClassifier(labelCol="delay_label", featuresCol="scaled_features", numTrees=20)

# Create pipeline
pipeline = Pipeline(stages=[assembler, scaler, classifier])

# Define the parameter grid (numTrees/maxDepth from the grid override the defaults above)
paramGrid = ParamGridBuilder() \
    .addGrid(classifier.numTrees, [20, 50, 100, 200]) \
    .addGrid(classifier.maxDepth, [5, 10, 20]) \
    .build()

# Configure cross-validation
crossval = CrossValidator(estimator=pipeline,
                          estimatorParamMaps=paramGrid,
                          evaluator=MulticlassClassificationEvaluator(labelCol="delay_label", predictionCol="prediction", metricName="accuracy"),
                          numFolds=3)

# Split data into train and test sets (seeded for reproducibility)
(train_data, test_data) = mock_df.randomSplit([0.8, 0.2], seed=SPLIT_SEED)

# Run cross-validation, and choose the best set of parameters
cvModel = crossval.fit(train_data)

# Make predictions on test data
predictions = cvModel.transform(test_data)

# Select relevant columns and show predictions
predictions.select("bpuic", "ankunftszeit", "linien_id", "avg_delay", "stddev_delay", "temperature", "precip_hrly", "arrival_delay", "delay_label", "prediction").show()


### Preprocess istdaten

In [None]:
# Preprocess `istdaten` data: parse scheduled/actual arrival times and compute the delay in minutes
istdaten = istdaten.withColumn("scheduled_arrival", F.to_timestamp("ankunftszeit", "dd.MM.yyyy HH:mm"))
istdaten = istdaten.withColumn("actual_arrival", F.to_timestamp("an_prognose", "dd.MM.yyyy HH:mm:ss"))
istdaten = istdaten.withColumn("arrival_delay", (F.unix_timestamp("actual_arrival") - F.unix_timestamp("scheduled_arrival")) / 60)

# Drop rows without a computable delay and delays above 30 minutes (modelling assumption)
istdaten = istdaten.filter(F.col("arrival_delay").isNotNull() & (F.col("arrival_delay") <= 30))

# Labels are whole minutes in [0, 30]. Early arrivals (negative delay) map to 0, because
# RandomForestClassifier rejects negative class labels.
istdaten = istdaten.withColumn(
    "delay_label",
    F.when(F.col("arrival_delay") >= 30, 30)
     .when(F.col("arrival_delay") <= 0, 0)
     .otherwise(F.ceil(F.col("arrival_delay")))
     .cast("int")
)

# Join `istdaten` with `weekday_all_info` on stop and line.
# The column is `linien_id` (it was misspelled `linen_id`, which raised an AttributeError).
# NOTE(review): confirm that bpuic values match GTFS stop_id and linien_id matches route_id;
# the identifier formats may differ between the two datasets.
merged_data = istdaten.join(
    weekday_all_info,
    (istdaten.bpuic == weekday_all_info.stop_id) & (istdaten.linien_id == weekday_all_info.route_id),
    "inner",
)

### Prediction pipeline

In [None]:
# Time-of-day features derived from the scheduled arrival (same derivation as the mock-data cell)
merged_data = (
    merged_data
    .withColumn("arrival_hour", F.hour("scheduled_arrival"))
    .withColumn("arrival_minute", F.minute("scheduled_arrival"))
)

# Feature engineering: use only the requested features that actually exist in merged_data.
# avg_delay / stddev_delay / temperature / precip_hrly may not have been computed yet;
# VectorAssembler would fail on missing columns, so report and skip them instead.
requested_features = ["bpuic", "linien_id", "avg_delay", "stddev_delay", "temperature", "precip_hrly", "arrival_hour", "arrival_minute"]
available_columns = set(merged_data.columns)
feature_cols = [c for c in requested_features if c in available_columns]
missing_features = [c for c in requested_features if c not in available_columns]
if missing_features:
    print(f"WARNING: skipping features not present in merged_data: {missing_features}")

# NOTE(review): VectorAssembler only accepts numeric/boolean/vector inputs — confirm that
# bpuic and linien_id are numeric in istdaten; otherwise index or cast them first.
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features", handleInvalid="skip")

# Standardize features
scaler = StandardScaler(inputCol="features", outputCol="scaled_features", withStd=True, withMean=False)

# Use a multi-class classifier
classifier = RandomForestClassifier(labelCol="delay_label", featuresCol="scaled_features", numTrees=20)

# Create pipeline (the assembler runs inside the pipeline, so no separate transform is needed)
pipeline = Pipeline(stages=[assembler, scaler, classifier])

# Define the parameter grid
paramGrid = ParamGridBuilder() \
    .addGrid(classifier.numTrees, [20, 50, 100, 200]) \
    .addGrid(classifier.maxDepth, [5, 10, 20]) \
    .build()

# Configure cross-validation
crossval = CrossValidator(estimator=pipeline,
                          estimatorParamMaps=paramGrid,
                          evaluator=MulticlassClassificationEvaluator(labelCol="delay_label", predictionCol="prediction", metricName="accuracy"),
                          numFolds=3)

# Split data into train and test sets (seeded for reproducibility)
(train_data, test_data) = merged_data.randomSplit([0.8, 0.2], seed=42)

# Run cross-validation, and choose the best set of parameters
cvModel = crossval.fit(train_data)

# Persist the fitted model; overwrite() lets the cell be re-run without a "path already exists" error
model_path = f"/user/{username}/finalproject/model"
cvModel.write().overwrite().save(model_path)

# Make predictions on test data
predictions = cvModel.transform(test_data)

###  Model Evaluation

In [None]:
# Evaluate the multi-class delay classifier on the held-out predictions.
# areaUnderROC is a BinaryClassificationEvaluator metric; MulticlassClassificationEvaluator
# rejects it, so only metrics supported by the multi-class evaluator are computed.
evaluatorMulti = MulticlassClassificationEvaluator(labelCol="delay_label", predictionCol="prediction", metricName="accuracy")

metric_names = ["accuracy", "weightedPrecision", "weightedRecall", "f1"]
metrics = {name: evaluatorMulti.evaluate(predictions, {evaluatorMulti.metricName: name}) for name in metric_names}

accuracy = metrics["accuracy"]
precision = metrics["weightedPrecision"]
recall = metrics["weightedRecall"]
f1_score = metrics["f1"]

# Print metrics
print(f"Accuracy: {accuracy:.2f}")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"F1 Score: {f1_score:.2f}")

## Graveyard

In [None]:
# Create database and switch to it.
# The location uses the Python `username`: the previous '${USER}' inside an f-string evaluated
# the undefined name USER (NameError), and the path lacked its leading '/'.
spark.sql(f"CREATE DATABASE IF NOT EXISTS {username} LOCATION '/user/{username}/finalproject'")
spark.sql(f"USE {username}")

databases = spark.sql("SHOW DATABASES")

# Show the current database to verify (show() prints directly and returns None)
current_db = spark.sql("SELECT current_database()")
current_db.show()

In [48]:
# If we want to download a dataframe in the cluster
# into local context:
# %%spark -o VAR_NAME
#print(type(current_db))

#%%spark -o current_db

#%%local
#print(type(current_db))
#current_db
    
# Check the dataframe object locally,
# %%local
# type(twitter_lang)
# twitter_lang


An error was encountered:
Invalid status code '404' from http://iccluster080.iccluster.epfl.ch:28998/sessions/3968 with error payload: {"msg":"Session '3968' not found."}
