In [None]:
# Import necessary libraries and initialize SparkSession
from pyspark.sql import SparkSession
from pyspark.sql.functions import explode, col, broadcast, avg, lit, count
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType
from pyspark.sql.functions import rand
import pandas as pd
import matplotlib.pyplot as plt
import time
import json
import folium

# Execute the shared connection notebook inside this kernel.
# NOTE(review): presumably this is what defines the `spark` session used
# below — no SparkSession is built in the visible cells; confirm against
# 00-spark-connection.ipynb.
%run "/usr/local/spark/notebooks/00-spark-connection.ipynb"

# Set shuffle partitions to match the number of cores for better performance.
# NOTE(review): the value is hard-coded to 8, so it only matches the core
# count on an 8-core setup — adjust it for the machine/cluster actually used.
spark.conf.set("spark.sql.shuffle.partitions", "8")  # Adjust as needed

In [None]:
def load_parquet_data(file_path):
    """Open a Parquet dataset as a Spark DataFrame and print how long it took.

    Args:
        file_path: Location of the Parquet data, passed unchanged to
            ``spark.read.parquet``.

    Returns:
        The DataFrame returned by ``spark.read.parquet``.

    Note:
        Spark evaluates reads lazily, so the printed duration reflects opening
        the dataset rather than scanning all of its rows.
    """
    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # intervals; time.time() can jump when the system clock is adjusted.
    started = time.perf_counter()
    df = spark.read.parquet(file_path)
    print(f"Loaded Parquet data in {time.perf_counter() - started:.2f} seconds")
    return df

def load_json_data(file_path):
    """Read a JSON file from the local filesystem and print how long it took.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The parsed JSON document (whatever ``json.load`` produces for it).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # perf_counter() is the clock intended for measuring elapsed intervals.
    started = time.perf_counter()
    # Decode explicitly as UTF-8: without it open() falls back to the platform
    # locale, which can corrupt or reject non-ASCII text on some systems.
    with open(file_path, "r", encoding="utf-8") as file:
        json_data = json.load(file)
    print(f"Loaded JSON data in {time.perf_counter() - started:.2f} seconds")
    return json_data

def create_dataframe_from_json(json_data, schema, row_mapper):
    """Build a Spark DataFrame from parsed JSON and print a short preview.

    Args:
        json_data: Parsed JSON object, handed unchanged to ``row_mapper``.
        schema: Schema passed to ``spark.createDataFrame``.
        row_mapper: Callable that turns ``json_data`` into the rows given to
            ``spark.createDataFrame``.

    Returns:
        The newly created DataFrame; its first 5 rows are printed as a side
        effect.
    """
    frame = spark.createDataFrame(row_mapper(json_data), schema)
    frame.show(5)  # preview only
    return frame

def join_dataframes(df1, df2, join_condition, select_columns):
    """Left-join two DataFrames, project the result and print a timed preview.

    Args:
        df1: Left-hand DataFrame of the join.
        df2: Right-hand DataFrame of the join.
        join_condition: Join expression passed to ``DataFrame.join``.
        select_columns: Iterable of columns/expressions unpacked into
            ``select`` on the joined result.

    Returns:
        The joined and projected DataFrame. Because the join type is
        ``"left"``, rows of ``df1`` without a match keep nulls in the columns
        taken from ``df2``.
    """
    # perf_counter() is a monotonic, high-resolution clock meant for measuring
    # intervals; time.time() can jump when the system clock is adjusted.
    started = time.perf_counter()
    joined_df = df1.join(df2, join_condition, "left").select(*select_columns)
    # show() runs inside the timed section, so the reported duration includes
    # producing the 5 preview rows.
    joined_df.show(5)
    print(f"Joined data in {time.perf_counter() - started:.2f} seconds")
    return joined_df

# ---------------------------------------------------------------------------
# Departures: enriched records stored as Parquet
# ---------------------------------------------------------------------------
df = load_parquet_data("data/enriched_01.parquet")

# ---------------------------------------------------------------------------
# Lines: JSON object keyed by line id
# ---------------------------------------------------------------------------
line_json_data = load_json_data("data/relevantLines_with_stops.json")

# All line columns are nullable strings.
line_schema = StructType(
    [StructField(field, StringType(), True) for field in ("line_id", "name", "product")]
)


def _line_rows(data):
    """Flatten the ``{line_id: info}`` mapping into (line_id, name, product) tuples."""
    return [(line_id, info['name'], info['product']) for line_id, info in data.items()]


lines_df = create_dataframe_from_json(line_json_data, line_schema, _line_rows)

# Attach line name and product to every departure (left join on the line id).
joined_df = join_dataframes(
    df,
    lines_df,
    df.lineId == lines_df.line_id,
    [
        df["*"],
        lines_df.name.alias("line_name"),
        lines_df.product.alias("line_product"),
    ],
)

# ---------------------------------------------------------------------------
# Stops: JSON document with a "stops" list
# ---------------------------------------------------------------------------
stops_json_data = load_json_data("data/stops.json")

# All stop columns, including the coordinates, are declared as nullable strings.
stops_schema = StructType(
    [
        StructField(field, StringType(), True)
        for field in ("stop_id", "stop_name", "latitude", "longitude")
    ]
)


def _stop_rows(data):
    """Turn each stop record into an (id, name, latitude, longitude) tuple."""
    return [
        (stop['id'], stop['name'], stop['latitude'], stop['longitude'])
        for stop in data
    ]


stops_df = create_dataframe_from_json(stops_json_data['stops'], stops_schema, _stop_rows)

# Attach stop name and coordinates to the line-enriched departures
# (left join on the stop id).
final_df = join_dataframes(
    joined_df,
    stops_df,
    joined_df.stopId == stops_df.stop_id,
    [
        joined_df["*"],
        stops_df.stop_name.alias("stop_name"),
        stops_df.latitude.alias("stop_latitude"),
        stops_df.longitude.alias("stop_longitude"),
    ],
)

In [None]:
# Cell-local imports: folium builds the map, display() renders it inline.
import folium
from IPython.display import display

# Pick one trip at random. first() yields None for an empty result, so an
# empty departures DataFrame produces a clear error instead of an IndexError.
random_row = df.select("tripId").distinct().orderBy(rand()).first()
if random_row is None:
    raise ValueError("Cannot pick a random trip: the departures DataFrame is empty.")
random_trip = random_row["tripId"]

# Stops of the selected trip, ordered by plannedWhen.
trip_stops = (
    final_df.filter(col("tripId") == random_trip)
    .select("stop_name", "stop_latitude", "stop_longitude", "plannedWhen")
    .orderBy("plannedWhen")
)

# Collect once and reuse the rows for both the table size and the map, rather
# than running a separate count() job just to size show().
stops_data = trip_stops.collect()

print(f"Stops for trip {random_trip}:")
trip_stops.show(truncate=False, n=len(stops_data))

# Convert the coordinates once per stop. final_df comes from a left join on
# the stop id, so departures without a matching stop carry null coordinates;
# those (and values that do not parse as numbers) cannot be drawn and are
# skipped instead of crashing the cell.
located_stops = []
for stop in stops_data:
    try:
        position = [float(stop["stop_latitude"]), float(stop["stop_longitude"])]
    except (TypeError, ValueError):
        continue
    located_stops.append((stop, position))

skipped = len(stops_data) - len(located_stops)
if skipped:
    print(f"Skipped {skipped} stop(s) without usable coordinates.")

if located_stops:
    # Coordinates of the drawable stops, in trip order.
    coords = [position for _, position in located_stops]

    # Centre the map on the average position of the stops.
    avg_latitude = sum(lat for lat, _ in coords) / len(coords)
    avg_longitude = sum(lon for _, lon in coords) / len(coords)
    m = folium.Map(location=[avg_latitude, avg_longitude], zoom_start=13)

    # One marker per stop, grouped in a single feature group.
    stops_group = folium.FeatureGroup(name='Stops')
    for stop, position in located_stops:
        folium.Marker(
            location=position,
            popup=f"{stop['stop_name']} ({stop['plannedWhen']})"
        ).add_to(stops_group)
    stops_group.add_to(m)

    # Connect the stops in order with a line.
    folium.PolyLine(coords, color='blue').add_to(m)

    # Display the map
    display(m)
else:
    print("No stops found for the selected trip.")