In [31]:
import json
import os
from urllib.parse import quote

import branca.colormap as cm
import folium as fl
import geopandas as gpd
import ipywidgets as widgets
import pandas as pd
import plotly.express as px
from folium.features import GeoJson, GeoJsonTooltip
from IPython.display import display
from sqlalchemy import create_engine, Engine, text
from sqlalchemy import bindparam

# Database connection settings.
# Each value can be overridden with an URBANMOBILITY_* environment variable so
# that real credentials never have to be written into the notebook; the
# fallbacks are the local development defaults.
DB_USER = os.environ.get("URBANMOBILITY_DB_USER", "postgres")
DB_PASS = os.environ.get("URBANMOBILITY_DB_PASS", "mobility")
DB_HOST = os.environ.get("URBANMOBILITY_DB_HOST", "localhost")
DB_PORT = os.environ.get("URBANMOBILITY_DB_PORT", "5434")
DB_NAME = os.environ.get("URBANMOBILITY_DB_NAME", "urbanmobility")

In [32]:
# Create a SQLAlchemy engine using the database configuration
def create_sqlalchemy_engine(db_user, db_pass, db_host, db_port, db_name):
    """
    Builds a PostgreSQL connection URL and returns a SQLAlchemy engine for it.

    The user name and password are percent-encoded before they are placed in
    the URL, so credentials containing reserved characters such as '@', ':'
    or '/' do not corrupt the URL. Pass them as plain text, not pre-encoded.

    Parameters:
    - db_user: Database user name.
    - db_pass: Database password (plain text).
    - db_host: Host name or address of the database server.
    - db_port: Port of the database server.
    - db_name: Name of the database to connect to.

    Returns:
    - The engine produced by sqlalchemy.create_engine for the assembled URL.
    """
    # safe="" also escapes '/', which quote() would otherwise leave untouched.
    user = quote(str(db_user), safe="")
    password = quote(str(db_pass), safe="")
    db_url = (
        f"postgresql://{user}:{password}@"
        f"{db_host}:{db_port}/{db_name}"
    )
    return create_engine(db_url)


In [33]:
# Fetch data from PostgreSQL and return a GeoDataFrame
def fetch_segment_speed_data(engine: Engine):
    """
    Loads the average speed of every stop-to-stop segment as a GeoDataFrame.

    Speed is computed in SQL as distance_m divided by the elapsed seconds
    between start_time_actual and end_time_actual, multiplied by 3.6, and
    averaged per (geometry, start_stop_id, end_stop_id) group. Before the
    main query runs, the number of trip_segments rows passing the same
    validity filter is printed as a sanity check.

    Parameters:
    - engine: SQLAlchemy engine for database connection.

    Returns:
    - GeoDataFrame with the averaged speed (aliased speedKMH in the query),
      the segment geometry transformed to SRID 4326, and the start/end stop ids.
    """
    # Counts trip_segments rows only (no join against segments).
    valid_rows_sql = """
                  SELECT COUNT(*)
                  FROM trip_segments ts
                  WHERE ts.start_time_actual IS NOT NULL
                    AND EXTRACT(EPOCH FROM (ts.end_time_actual - ts.start_time_actual)) > 0; \
                  """
    speed_sql = """
                SELECT AVG(s.distance_m / EXTRACT(EPOCH FROM (ts.end_time_actual - ts.start_time_actual)) *
                           3.6) AS speedKMH,
                       ST_Transform(s.geometry, 4326) AS geometry,
                       ts.start_stop_id,
                       ts.end_stop_id
                FROM trip_segments ts
                         JOIN segments s
                              ON ts.start_stop_id = s.start_stop_id AND ts.end_stop_id = s.end_stop_id
                WHERE ts.start_time_actual IS NOT NULL
                  AND EXTRACT(EPOCH FROM (ts.end_time_actual - ts.start_time_actual)) > 0
                GROUP BY s.geometry, ts.start_stop_id, ts.end_stop_id; \
                """

    # The single cell of the one-row result is the count.
    count_frame = pd.read_sql(valid_rows_sql, engine)
    count_result = count_frame.iloc[0, 0]
    print(f"Total segments with valid speed data: {count_result}")

    segment_speeds = gpd.read_postgis(speed_sql, engine, geom_col='geometry')
    return segment_speeds


In [34]:
# Create a colormap based on speed
def create_colormap(min_speed, max_speed):
    """
    Builds the linear red-to-white color scale used to shade speed values.

    Parameters:
    - min_speed: Lower bound of the scale (passed as vmin).
    - max_speed: Upper bound of the scale (passed as vmax).

    Returns:
    - A branca LinearColormap over ['red', 'white'] with the given bounds.
    """
    gradient = ['red', 'white']
    speed_scale = cm.LinearColormap(gradient, vmin=min_speed, vmax=max_speed)
    return speed_scale


In [35]:
def visualize_speed_map(gdf, cutoff=35):
    """
    Visualizes the average speed per segment using folium.

    Every segment is added as its own GeoJson layer on an OpenStreetMap base
    map and colored on a red-to-white scale. Speeds above `cutoff` are drawn
    with the color of `cutoff`; the tooltip still shows the real value.
    Rows without a speed or without a geometry cannot be drawn and are skipped.

    Parameters:
    - gdf: GeoDataFrame with the columns 'speedkmh', 'geometry',
      'start_stop_id' and 'end_stop_id'.
    - cutoff: Maximum speed value for the color gradient.

    Returns:
    - A folium map with segments color-coded by speedkmh. When there is
      nothing to draw, the bare base map (no segments, no legend) is returned.
    """
    transport_map = fl.Map(
        location=[60.21441, 24.92397],  # Centered on the Helsinki area
        zoom_start=11,
        tiles="https://{s}.tile.openstreetmap.org/{z}/{x}/{y}.png",
        attr="OpenStreetMap",
    )

    # A missing speed cannot be mapped to a color and a missing geometry has
    # no __geo_interface__, so such rows are left out up front.
    drawable = gdf[gdf['speedkmh'].notna() & gdf['geometry'].notna()]
    if drawable.empty:
        # min()/max() of an empty column are NaN; skip the color scale entirely.
        return transport_map

    # Create a colormap for the speeds
    min_speed = drawable['speedkmh'].min()
    max_speed = min(drawable['speedkmh'].max(), cutoff)
    colormap = create_colormap(min_speed, max_speed)

    # Add each segment to the map
    for _, segment in drawable.iterrows():
        # Clamp only the value used for coloring, not the one shown in the tooltip.
        speed_kmh = min(segment['speedkmh'], cutoff)
        color = colormap(speed_kmh)

        geo_json = GeoJson(
            data={
                "type": "Feature",
                "geometry": segment['geometry'].__geo_interface__,
                "properties": {
                    "speedkmh": round(segment['speedkmh'], 2),
                    "from_stop_id": segment['start_stop_id'],
                    "to_stop_id": segment['end_stop_id']
                }
            },
            # Bind the color as a default argument so each layer keeps its own
            # value instead of the loop variable's final one.
            style_function=lambda x, color=color: {
                'color': color,
                'weight': 3,
                'opacity': 0.7
            },
            tooltip=GeoJsonTooltip(
                fields=['speedkmh', 'from_stop_id', 'to_stop_id'],
                aliases=['Speed (km/h)', 'From Stop', 'To Stop'],
                localize=True
            )
        )
        geo_json.add_to(transport_map)

    # Add the colormap to the map
    colormap.add_to(transport_map)
    return transport_map


In [36]:
# Build the database engine from the connection settings
engine = create_sqlalchemy_engine(
    db_user=DB_USER,
    db_pass=DB_PASS,
    db_host=DB_HOST,
    db_port=DB_PORT,
    db_name=DB_NAME,
)


In [37]:
# Fetch segment speed data
# (the call also prints how many trip_segments rows passed the validity filter)
segment_speed_gdf = fetch_segment_speed_data(engine)


Total segments with valid speed data: 4943


In [38]:
# Visualize the average speed per segment.
# The cutoff is set to the fastest observed segment, so the color scale spans
# the whole speed range and no value is clamped. Series.max() is used instead
# of the builtin max(): it is vectorized and ignores missing values.
speed_map = visualize_speed_map(
    segment_speed_gdf,
    cutoff=segment_speed_gdf['speedkmh'].max(),
)


In [39]:
# Save the map as a standalone HTML file.
# The path is defined once so the confirmation message cannot drift from the
# file that was actually written.
map_output_path = "segment_speed_map.html"
speed_map.save(map_output_path)
print(f"Map saved as '{map_output_path}'.")

Map saved as 'segment_speed_map.html'.


In [40]:
# Render the interactive map inline in the notebook
display(speed_map)

# Graph trip delay

In [41]:
def get_all_trip_ids(db_engine):
    """
    Returns every distinct schedule_trip_id stored in trip_stops.

    Parameters:
    - db_engine: Object whose connect() can be used as a context manager and
      yields a connection accepted by pandas.read_sql.

    Returns:
    - A list of the distinct schedule_trip_id values, ordered by the database.
    """
    distinct_ids_sql = "SELECT DISTINCT schedule_trip_id FROM trip_stops ORDER BY schedule_trip_id;"
    with db_engine.connect() as conn:
        trip_id_column = pd.read_sql(distinct_ids_sql, conn)['schedule_trip_id']
    return trip_id_column.tolist()


In [42]:
def get_trip_data(trip_ids, db_engine):
    """Fetches delay data for a given list of trip IDs.

    Parameters:
    - trip_ids: Iterable of schedule_trip_id values (list, tuple, ...);
      None or an empty iterable yields an empty result without touching the
      database.
    - db_engine: SQLAlchemy engine used to open the connection.

    Returns:
    - DataFrame with the columns schedule_trip_id, schedule_time and delay,
      ordered by schedule_trip_id and then schedule_time.
    """
    trip_id_list = list(trip_ids) if trip_ids is not None else []
    if not trip_id_list:
        return pd.DataFrame(columns=['schedule_trip_id', 'schedule_time', 'delay'])

    # An "expanding" bind parameter lets SQLAlchemy render one placeholder per
    # id at execution time, rather than relying on the DB driver to adapt a
    # Python tuple into an SQL list.
    query = text("""
        SELECT schedule_trip_id, schedule_time, delay
        FROM trip_stops
        WHERE schedule_trip_id IN :trip_ids
        ORDER BY schedule_trip_id, schedule_time;
    """).bindparams(bindparam('trip_ids', expanding=True))

    with db_engine.connect() as connection:
        df = pd.read_sql(query, connection, params={'trip_ids': trip_id_list})
    return df


In [43]:
# Load the selectable trip IDs. On any failure the error is reported and the
# list stays empty so that the following cells can still run.
all_ids = []
try:
    all_ids = get_all_trip_ids(engine)
except Exception as e:
    print(f"Failed to connect to the database or fetch trip IDs: {e}")


In [44]:
# Multi-select list of trip IDs; the first ID is pre-selected when one exists.
trip_selector = widgets.SelectMultiple(
    description='Trip IDs',
    options=all_ids,
    value=all_ids[:1],  # [first id] for a non-empty list, [] otherwise
    layout={'width': '50%'},
    disabled=False,
)

In [45]:
# Output area that update_plot clears and writes its messages and figure into
plot_output = widgets.Output()

In [50]:
def update_plot(change):
    """Callback function to update the plot when the selection changes.

    Parameters:
    - change: The change notification delivered by `trip_selector.observe`,
      or any dict / object exposing the newly selected trip IDs as 'new'.
      Both item access (change['new']) and attribute access (change.new) are
      supported, so the callback can also be invoked by hand with a plain dict.

    Side effects:
    - Clears `plot_output` and writes either a hint, a "no data" message or
      the delay line chart for the selected trips into it.
    """
    # A plain dict has no `.new` attribute, so read dicts by key and fall back
    # to attribute access for everything else.
    selected_ids = change['new'] if isinstance(change, dict) else change.new

    with plot_output:
        # wait=True keeps the old content until the new output is ready.
        plot_output.clear_output(wait=True)
        if not selected_ids:
            print("Please select at least one trip ID.")
            return

        df = get_trip_data(selected_ids, engine)

        if not df.empty:
            fig = px.line(
                df,
                x='schedule_time',
                y='delay',
                color='schedule_trip_id',
                title='Trip Delay vs. Scheduled Time',
                markers=True,
                labels={'delay': 'Delay (minutes)', 'schedule_time': 'Scheduled Time'}
            )
            fig.show()
        else:
            print("No data found for the selected trip IDs.")


In [51]:
# Call update_plot whenever the selector's 'value' trait (the selected trip IDs) changes
trip_selector.observe(update_plot, names='value')


In [52]:
# Display the trip selector and the output area that update_plot writes into
# (nothing is plotted here; the plot appears once update_plot runs)
print("Select trip IDs from the list below to plot their delay.")
display(trip_selector, plot_output)


Select trip IDs from the list below to plot their delay.


SelectMultiple(description='Trip IDs', index=(8,), layout=Layout(width='50%'), options=('1013_20250709_Pe_1_21…

Output()

In [49]:
# Trigger the initial plot draw by calling the callback by hand with a plain
# dict that carries the current selection under the 'new' key.
# NOTE(review): update_plot must be able to read a plain dict for this call to
# work — confirm.
update_plot({'new': trip_selector.value})