In [None]:
# Import necessary libraries and initialize SparkSession
from pyspark.sql.functions import explode, col, broadcast, avg, lit, count
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType
from pyspark.sql.functions import rand
from IPython.display import display, HTML
import pandas as pd
import matplotlib.pyplot as plt
import time
import json
import folium

%run "/usr/local/spark/notebooks/00-spark-connection.ipynb"

# Use 8 shuffle partitions for Spark SQL shuffles (joins, aggregations, sorts).
# NOTE(review): the stated intent is to match the number of available cores;
# confirm that 8 is the right value for the cluster this notebook runs on.
spark.conf.set("spark.sql.shuffle.partitions", "8")  # Adjust as needed

In [None]:
def load_parquet_data(file_path):
    """Open a Parquet dataset as a Spark DataFrame and print how long that took.

    Relies on a module-level ``spark`` session (presumably created by the
    connection notebook that is %run at the top of this notebook).

    Args:
        file_path: Path of the Parquet file or directory to read.

    Returns:
        The DataFrame returned by ``spark.read.parquet``.
    """
    began_at = time.time()
    parquet_df = spark.read.parquet(file_path)
    elapsed = time.time() - began_at
    print(f"Loaded Parquet data in {elapsed:.2f} seconds")
    return parquet_df

def load_json_data(file_path):
    """Parse a JSON file from the local filesystem and print the elapsed time.

    The file is always decoded as UTF-8. Without an explicit encoding,
    ``open`` falls back to the platform's locale encoding, which can garble or
    reject non-ASCII text (for example stop names) on some systems.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The decoded JSON document (dict, list, ...), as produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    start_time = time.time()
    with open(file_path, "r", encoding="utf-8") as file:
        json_data = json.load(file)
    print(f"Loaded JSON data in {time.time() - start_time:.2f} seconds")
    return json_data

def create_dataframe_from_json(json_data, schema, row_mapper):
    """Build a Spark DataFrame from already-parsed JSON data.

    Relies on a module-level ``spark`` session (presumably created by the
    connection notebook that is %run at the top of this notebook).

    Args:
        json_data: Parsed JSON content handed to ``row_mapper``.
        schema: Schema passed to ``spark.createDataFrame``.
        row_mapper: Callable that converts ``json_data`` into the rows of the
            DataFrame.

    Returns:
        The DataFrame created from the mapped rows and ``schema``.
    """
    return spark.createDataFrame(row_mapper(json_data), schema)

def join_dataframes(df1, df2, join_condition, select_columns, how="left"):
    """Join two DataFrames and project the result onto the requested columns.

    Args:
        df1: Left-hand DataFrame.
        df2: Right-hand DataFrame.
        join_condition: Join expression (or column name/list) passed to
            ``df1.join``.
        select_columns: Iterable of columns to keep from the joined result.
        how: Join type passed to ``df1.join``. Defaults to ``"left"``, which is
            the behaviour this helper always had.

    Returns:
        The joined and projected DataFrame.

    Note:
        Spark evaluates joins lazily, so the printed time covers building the
        query plan only; the join itself runs when an action is triggered on
        the returned DataFrame.
    """
    start_time = time.time()
    joined_df = df1.join(df2, join_condition, how).select(*select_columns)
    print(f"Joined data in {time.time() - start_time:.2f} seconds")
    return joined_df

# Enriched departures are read from a Parquet file
df = load_parquet_data("data/enriched_01.parquet")

# Line metadata is read from a JSON object keyed by line id
line_json_data = load_json_data("data/relevantLines_with_stops.json")

# Every column of the line DataFrame is a nullable string
line_schema = StructType(
    [StructField(column, StringType(), True) for column in ("line_id", "name", "product")]
)


def _line_rows(data):
    """Convert the {line_id: line_info} mapping into (line_id, name, product) tuples."""
    return [(line_id, line_info['name'], line_info['product']) for line_id, line_info in data.items()]


# Build the line DataFrame and mark it for caching, since it is reused
lines_df = create_dataframe_from_json(line_json_data, line_schema, _line_rows)
lines_df.cache()

# Attach the line name and product to every departure (left join on the line id)
line_join_condition = df.lineId == lines_df.line_id
line_select_columns = [
    df["*"],
    lines_df.name.alias("line_name"),
    lines_df.product.alias("line_product"),
]
joined_df = join_dataframes(df, lines_df, line_join_condition, line_select_columns)

# Stop metadata is read from a JSON document with a top-level "stops" list
stops_json_data = load_json_data("data/stops.json")

# Every column of the stops DataFrame, including the coordinates, is a nullable string
stops_schema = StructType(
    [StructField(column, StringType(), True) for column in ("stop_id", "stop_name", "latitude", "longitude")]
)


def _stop_rows(data):
    """Convert the list of stop records into (id, name, latitude, longitude) tuples."""
    return [(stop['id'], stop['name'], stop['latitude'], stop['longitude']) for stop in data]


# Build the stops DataFrame and mark it for caching, since it is reused
stops_df = create_dataframe_from_json(stops_json_data['stops'], stops_schema, _stop_rows)
stops_df.cache()

# Attach the stop name and coordinates to the line-enriched departures (left join on the stop id)
stop_join_condition = joined_df.stopId == stops_df.stop_id
stop_select_columns = [
    joined_df["*"],
    stops_df.stop_name.alias("stop_name"),
    stops_df.latitude.alias("stop_latitude"),
    stops_df.longitude.alias("stop_longitude"),
]
final_df = join_dataframes(joined_df, stops_df, stop_join_condition, stop_select_columns)

In [None]:
# Pick one random tripId among the trips that have at least 3 different added delay values.
# groupBy("tripId") already yields a single row per trip, so no distinct() pass is needed.
valid_trips = final_df.groupBy("tripId").agg(F.countDistinct("added_delay").alias("distinct_delays")) \
                      .filter(col("distinct_delays") >= 3) \
                      .select("tripId") \
                      .orderBy(rand()) \
                      .limit(1) \
                      .collect()

if valid_trips:
    random_trip = valid_trips[0]["tripId"]

    # Filter the dataframe for the selected trip and order by plannedWhen
    trip_stops = final_df.filter(col("tripId") == random_trip) \
                         .select("stop_name", "stop_latitude", "stop_longitude", "plannedWhen", "when", "delay", "added_delay", "line_name", "line_product", "direction") \
                         .orderBy("plannedWhen")

    # Collect the stop data
    stops_data = trip_stops.collect()
else:
    # No trip qualifies: indexing the empty result would raise IndexError,
    # so report it and skip the visualisation instead.
    random_trip = None
    stops_data = []
    print("No trip with at least 3 different added delay values was found.")

if stops_data:
    start_stop = stops_data[0]
    end_stop = stops_data[-1]

    # The individual delays are labelled as seconds below, so their mean is
    # reported in seconds as well.
    avg_delay = sum(float(stop["delay"]) for stop in stops_data) / len(stops_data)

    # Planned duration between the first and the last stop, in minutes
    start_time = pd.to_datetime(start_stop["plannedWhen"])
    end_time = pd.to_datetime(end_stop["plannedWhen"])
    duration = (end_time - start_time).total_seconds() / 60

    # Per-stop table
    stops_table_data = {
        "Stop Name": [stop["stop_name"] for stop in stops_data],
        "Planned When": [stop["plannedWhen"] for stop in stops_data],
        "When": [stop["when"] for stop in stops_data],
        "Delay": [stop["delay"] for stop in stops_data],
        "Added Delay": [stop["added_delay"] for stop in stops_data]
    }
    stops_table_df = pd.DataFrame(stops_table_data)

    # Trip summary table
    additional_info_data = {
        "Info": [
            "Trip ID", "Line Name", "Direction", "Line Product",
            "Start Stop", "Start Planned When", "Start Actual When",
            "End Stop", "End Planned When", "End Actual When",
            "Average Delay", "Start Delay", "End Delay", "Duration",
        ],
        "Value": [
            random_trip, start_stop['line_name'], start_stop['direction'], start_stop['line_product'],
            start_stop['stop_name'], start_stop['plannedWhen'], start_stop['when'],
            end_stop['stop_name'], end_stop['plannedWhen'], end_stop['when'],
            f"{avg_delay:.2f} seconds", f"{start_stop['delay']} seconds", f"{end_stop['delay']} seconds", f"{duration:.2f} minutes",
        ]
    }
    additional_info_df = pd.DataFrame(additional_info_data)

    # Display the tables side by side
    display_html = f"""
    <div style="display: flex; justify-content: space-around;">
        <div>{additional_info_df.to_html(index=False)}</div>
        <div>{stops_table_df.to_html(index=False)}</div>
    </div>
    """
    display(HTML(display_html))

    # Stop coordinates in visiting order; reused for the map centre, the
    # markers and the coloured line.
    coords = [[float(stop["stop_latitude"]), float(stop["stop_longitude"])] for stop in stops_data]

    # Centre the map on the mean position of the stops
    avg_latitude = sum(lat for lat, _ in coords) / len(coords)
    avg_longitude = sum(lon for _, lon in coords) / len(coords)
    m = folium.Map(location=[avg_latitude, avg_longitude], zoom_start=13)

    # Create a feature group for the stops
    stops_group = folium.FeatureGroup(name='Stops')

    # Function to map delay to a color for stops
    def stop_delay_to_color(delay):
        """Return the marker color for a stop delay given in seconds."""
        if delay < 0:
            return 'blue'      # ahead of schedule
        elif delay <= 0:
            return 'green'     # exactly on time
        elif delay <= 60:
            return 'lightred'
        elif delay <= 120:
            return 'orange'
        else:
            return 'red'

    # Add one marker per stop; only the icon differs between the start stop,
    # the end stop and the intermediate stops.
    last_index = len(stops_data) - 1
    for i, (stop, location) in enumerate(zip(stops_data, coords)):
        if i == 0:
            icon_name = 'play'      # start stop
        elif i == last_index:
            icon_name = 'stop'      # end stop
        else:
            icon_name = 'circle'    # intermediate stop
        folium.Marker(
            location=location,
            popup=f"{stop['stop_name']} ({stop['plannedWhen']})",
            icon=folium.Icon(color=stop_delay_to_color(float(stop["delay"])), icon=icon_name, prefix='fa')
        ).add_to(stops_group)

    # Add the stops group to the map
    stops_group.add_to(m)

    # One color value per segment: the added delay recorded at the segment's arrival stop
    colors = [float(stop["added_delay"]) for stop in stops_data[1:]]

    # Scale by a reference delay of 300 seconds, capping the value at 1
    max_delay = 300
    colors = [min(delay / max_delay, 1) for delay in colors]

    # Add a ColorLine to the map
    folium.ColorLine(
        positions=coords,
        colors=colors,
        colormap=["green", "yellow", "orange", "red"],
        weight=5
    ).add_to(m)

    # Display the map
    display(m)
elif valid_trips:
    print("No stops found for the selected trip.")

In [None]:
# Gather every distinct line name in sorted order, then list them one per line
distinct_line_names_df = final_df.select("line_name").distinct()
line_names = distinct_line_names_df.orderBy("line_name").collect()
line_name_options = [row["line_name"] for row in line_names]

print("Available Line Names:")
for line_name in line_name_options:
    print(line_name)

In [None]:
# Line to analyse; it is compared for equality against the line_name column of final_df
line_name = "STR 1"

def _build_line_map(trip_stops, avg_delays):
    """Build the folium map for one trip of a line.

    Args:
        trip_stops: Non-empty list of collected stop rows of a single trip, in
            visiting order. Each row must provide stopId, stop_name,
            stop_latitude, stop_longitude and delay.
        avg_delays: Collected rows with stopId, avg_delay and avg_added_delay.

    Returns:
        A ``folium.Map`` with one marker per stop and, when the trip has at
        least two stops, a line colored by the average added delay.
    """
    # Stop coordinates in visiting order; reused for the centre, markers and line
    coords = [[float(stop["stop_latitude"]), float(stop["stop_longitude"])] for stop in trip_stops]

    # Centre the map on the mean position of the stops
    avg_latitude = sum(lat for lat, _ in coords) / len(coords)
    avg_longitude = sum(lon for _, lon in coords) / len(coords)
    m = folium.Map(location=[avg_latitude, avg_longitude], zoom_start=13)

    # Create a feature group for the stops
    stops_group = folium.FeatureGroup(name='Stops')

    # Function to map delay to a color for stops
    def stop_delay_to_color(delay):
        if delay < 60:
            return 'blue'
        elif delay <= 120:
            return 'green'
        elif delay <= 180:
            return 'lightred'
        elif delay <= 240:
            return 'orange'
        else:
            return 'red'

    # Lookups from stopId to the line-wide average delay / average added delay
    avg_delays_dict = {row['stopId']: row['avg_delay'] for row in avg_delays}
    avg_added_delay_dict = {row['stopId']: row['avg_added_delay'] for row in avg_delays}

    # Add one marker per stop. The marker color comes from this trip's own
    # delay, while the popup shows the averages. Only the icon differs between
    # the start stop, the end stop and the intermediate stops.
    last_index = len(trip_stops) - 1
    for i, (stop, location) in enumerate(zip(trip_stops, coords)):
        stop_color = stop_delay_to_color(float(stop["delay"]))
        avg_delay = avg_delays_dict.get(stop["stopId"], 0)
        avg_added_delay = avg_added_delay_dict.get(stop["stopId"], 0)
        popup_text = f"{stop['stop_name']} (Avg Delay: {avg_delay:.2f} seconds, Avg Added Delay: {avg_added_delay:.2f} seconds)"
        if i == 0:
            icon_name = 'play'      # start stop
        elif i == last_index:
            icon_name = 'stop'      # end stop
        else:
            icon_name = 'circle'    # intermediate stop
        folium.Marker(
            location=location,
            popup=popup_text,
            icon=folium.Icon(color=stop_color, icon=icon_name, prefix='fa')
        ).add_to(stops_group)

    # Add the stops group to the map
    stops_group.add_to(m)

    # A line needs at least two stops; with a single stop there is no segment
    # to color and min()/max() below would fail on an empty list.
    if len(coords) >= 2:
        # One value per segment: the average added delay of the segment's arrival stop
        colors = [avg_added_delay_dict.get(stop["stopId"], 0) for stop in trip_stops[1:]]

        # Normalize to [0, 1] for the colormap; a flat profile maps to 0
        min_delay = min(colors)
        max_delay = max(colors)
        colors = [(delay - min_delay) / (max_delay - min_delay) if max_delay > min_delay else 0 for delay in colors]

        # Add a ColorLine to the map
        folium.ColorLine(
            positions=coords,
            colors=colors,
            colormap=["green", "yellow", "orange", "red"],
            weight=5
        ).add_to(m)

    return m


def update_output(line_name):
    """Print per-stop average delays for a line and show a map of its first trip.

    Reads the module-level ``final_df``. The rows of the requested line are
    cached for the duration of the call and released again afterwards.

    Args:
        line_name: Value compared for equality against the ``line_name`` column.

    Returns:
        None. Output is printed and displayed in the notebook.
    """
    start_time = time.time()

    # Filter the dataframe for the selected line name and cache it. cache() is
    # lazy: the data is materialised by the first action below, so the time
    # printed here does not include the actual filtering.
    line_stops = final_df.filter(col("line_name") == line_name) \
                         .select("stopId", "stop_name", "stop_latitude", "stop_longitude", "plannedWhen", "when", "delay", "added_delay", "line_name", "line_product", "direction", "tripId") \
                         .cache()
    try:
        print(f"Filtered and cached line stops in {time.time() - start_time:.2f} seconds")

        # Calculate average delays for distinct stopId
        avg_delays = line_stops.groupBy("stopId", "stop_name").agg(
            avg("delay").alias("avg_delay"),
            avg("added_delay").alias("avg_added_delay")
        ).orderBy("stopId").collect()

        # Print the average delays with stop names
        print("Average Delays for Stops:")
        for row in avg_delays:
            print(f"Stop ID: {row['stopId']}, Stop Name: {row['stop_name']}, Avg Delay: {row['avg_delay']:.2f} seconds, Avg Added Delay: {row['avg_added_delay']:.2f} seconds")

        # Get the first tripId (smallest in sort order). An unknown line name
        # yields no rows at all, which must not be indexed.
        first_trip_rows = line_stops.select("tripId").distinct().orderBy("tripId").limit(1).collect()
        if not first_trip_rows:
            print("No stops found for the selected line name.")
            return
        first_trip_id = first_trip_rows[0]["tripId"]

        # Only the stops of this one trip need to be in chronological order,
        # so sort after filtering instead of sorting the whole line.
        trip_stops = line_stops.filter(col("tripId") == first_trip_id).orderBy("plannedWhen").collect()
        print(f"Collected {len(trip_stops)} stop data for the first trip in {time.time() - start_time:.2f} seconds")
    finally:
        # Release the cached data so repeated calls do not accumulate cached DataFrames
        line_stops.unpersist()

    if not trip_stops:
        print("No stops found for the selected line name.")
        return

    # Display the map
    display(_build_line_map(trip_stops, avg_delays))
    print(f"Displayed map in {time.time() - start_time:.2f} seconds")

# Print the per-stop delay averages and show the map for the hardcoded line name
update_output(line_name)