In [None]:
import os
import time
from functools import cache, wraps
from io import BytesIO
from pathlib import Path
from zipfile import ZipFile

import duckdb
import folium
import httpx
import matplotlib as plt
import pandas as pd
import plotly.graph_objs as go
from plotly.subplots import make_subplots

In [None]:
def download_if_outdated(threshold_seconds):
    """Decorator factory that skips a download when cached CSV files are recent enough.

    The wrapped function is only called when no ``*/*.csv`` file exists under the
    extract directory, or when the most recently modified one is at least
    ``threshold_seconds`` old.

    Args:
        threshold_seconds: Maximum age, in seconds, for the cached files to be reused.

    Returns:
        A decorator. The decorated function accepts the same arguments as the
        original; an optional ``extract_dir`` keyword selects the cache directory
        (default ``../data/road_traffic_counts_hourly_permanent``) and is also
        forwarded to the wrapped function. On a cache hit the wrapper returns the
        sorted POSIX paths of the cached CSV files instead of calling the function.
    """

    def decorator(download_func):
        @wraps(download_func)
        def wrapper(*args, **kwargs):
            extract_dir = Path(kwargs.get("extract_dir", "../data/road_traffic_counts_hourly_permanent"))
            # Ensure the directory exists so the glob below never fails on a fresh checkout
            extract_dir.mkdir(parents=True, exist_ok=True)

            # Judge freshness from the files that are actually returned, not from an
            # arbitrary directory entry (which may be a stray file or a sub-directory).
            cached_csvs = sorted(extract_dir.glob("*/*.csv"))
            if cached_csvs:
                newest_mtime = max(os.path.getmtime(path) for path in cached_csvs)

                # If the newest file is younger than the threshold, skip the download
                if time.time() - newest_mtime < threshold_seconds:
                    print("Files are up-to-date. Skipping download.")
                    return [path.as_posix() for path in cached_csvs]

            # No cached CSVs, or they are too old: call the download function
            return download_func(*args, **kwargs)

        return wrapper

    return decorator

In [None]:
# DuckDB table names used when creating and querying the tables in this notebook.
table_name_counts = "road_traffic_counts"
table_name_stations = "station_reference"


# Local directory the hourly-count zip archive is extracted into.
csv_files_path = "../data/road_traffic_counts_hourly_permanent/"

# Alternative source: a local copy of the station reference instead of the URL below.
# station_reference_csv = "../data/road_traffic_counts_station_reference.csv"

# Remote sources hosted on opendata.transport.nsw.gov.au: the zipped hourly counts
# and the station reference CSV.
hourly_road_count_zip = "https://opendata.transport.nsw.gov.au/dataset/ef2b0bd2-db1e-48f3-9ea1-2bb9e6bc6504/resource/bca06c7e-30be-4a90-bc8b-c67428c0823a/download/road_traffic_counts_hourly_permanent.zip"
station_reference_csv = "https://opendata.transport.nsw.gov.au/dataset/ef2b0bd2-db1e-48f3-9ea1-2bb9e6bc6504/resource/c65ad7b4-0257-4cc6-953e-5299ac8d27ba/download/road_traffic_counts_station_reference.csv"

In [None]:
@download_if_outdated(threshold_seconds=2 * 24 * 60 * 60)  # 2 days in seconds
def download_extract_hourly_road_count_data():
    """Download the hourly road-count archive and unpack it under ``csv_files_path``.

    The ``download_if_outdated`` decorator may skip this function entirely and
    return the already-extracted file list when the cached data is less than two
    days old.

    Returns:
        list[str]: Sorted POSIX paths of the extracted ``*/*.csv`` files.

    Raises:
        httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
        zipfile.BadZipFile: If the downloaded payload is not a valid zip archive.
    """
    extract_dir = Path(csv_files_path)
    extract_dir.mkdir(parents=True, exist_ok=True)

    # httpx does not follow redirects unless asked to.
    with httpx.Client(follow_redirects=True) as client:
        response = client.get(hourly_road_count_zip)
        # Fail with a clear HTTP error instead of handing an error page to ZipFile.
        response.raise_for_status()

    with ZipFile(BytesIO(response.content)) as zip_ref:
        zip_ref.extractall(extract_dir)

    # Sorted so the file that defines the table schema downstream is deterministic.
    return [path.as_posix() for path in sorted(extract_dir.glob("*/*.csv"))]

In [None]:
# Fetch the hourly count CSVs, or reuse the extracted copies when the decorator deems them fresh.
csv_files = download_extract_hourly_road_count_data()

In [None]:
@cache
def load_data_duckdb():
    """Load the station reference and every hourly-count CSV into DuckDB.

    The result is memoised with ``functools.cache``, so calling this again in the
    same kernel returns the same connection and DataFrames without reloading.

    Returns:
        tuple: ``(df, stats, stats_all, con)`` where ``df`` is the full counts table
        as a DataFrame, ``stats`` is ``df.describe()``, ``stats_all`` is
        ``df.describe(include="all")`` and ``con`` is the DuckDB connection holding
        both tables.

    Raises:
        ValueError: If ``csv_files`` is empty, i.e. there is nothing to load.
    """
    # Check up front: otherwise the failure is an opaque IndexError raised only
    # after the station reference has already been fetched.
    if not csv_files:
        raise ValueError(f"No hourly count CSV files found under {csv_files_path!r}")

    con = duckdb.connect()

    con.execute(
        f"CREATE TABLE {table_name_stations} AS SELECT * FROM read_csv_auto('{station_reference_csv}')"
    )

    # The first file defines the table structure; the rest are appended to it.
    first_csv, *remaining_csvs = csv_files
    con.execute(f"CREATE TABLE {table_name_counts} AS SELECT * FROM read_csv_auto('{first_csv}')")

    for csv_file in remaining_csvs:
        con.execute(f"INSERT INTO {table_name_counts} SELECT * FROM read_csv_auto('{csv_file}')")

    df = con.sql(f"SELECT * FROM {table_name_counts}").to_df()
    # describe() already returns a DataFrame, so no extra wrapping is needed.
    stats = df.describe()
    stats_all = df.describe(include="all")

    return df, stats, stats_all, con

In [None]:
# Build the DuckDB tables once; `con` is the connection the later query cells use.
df, stats, stats_all, con = load_data_duckdb()

In [None]:
def show_schema(table_name):
    """Return the column name, column type and nullability of ``table_name`` as a DataFrame."""
    wanted_columns = ["column_name", "column_type", "null"]
    described = con.execute(f"DESCRIBE {table_name}").fetch_df()
    return described[wanted_columns]

In [None]:
# Column names, types and nullability of the hourly counts table.
show_schema(table_name_counts)

In [None]:
# Column names, types and nullability of the station reference table.
show_schema(table_name_stations)

In [None]:
def df_station_id():
    """Return the station reference rows whose full name mentions Victoria Road.

    Three station IDs are hard-coded as exclusions and filtered out in SQL.

    Returns:
        pandas.DataFrame: The matching rows of the station reference table.
    """
    excluded_ids = ("18031", "11139", "19035")
    quoted_ids = [f"'{station_id}'" for station_id in excluded_ids]
    exclude_station_id_sql = ", ".join(quoted_ids)

    station_sql_query = f"""
        SELECT * FROM {table_name_stations}
        WHERE full_name ILIKE '%Victoria Road%'
        AND station_id NOT IN ({exclude_station_id_sql});
    """

    return con.sql(station_sql_query).to_df()

In [None]:
# Victoria Road stations (minus the excluded IDs) and their distinct station keys.
station_df = df_station_id()
vic_rd_stations = station_df["station_key"].unique().tolist()

In [None]:
# Sanity check: the DataFrame loaded earlier must have as many rows as the DuckDB table.
result = con.execute(f"SELECT COUNT(*) FROM {table_name_counts}").fetchone()

# fetchone() returns a one-element tuple; its first (and only) element is the row count
row_count = result[0]

print(f"Number of rows in {table_name_counts}: {row_count}")
assert len(df) == row_count

In [None]:
# Convert year_start to datetime
# Default start date for the plotting functions below.
year_start_datetime = pd.to_datetime("2018-01-01")

# Timestamp captured once, when this cell runs; used as the right-hand x-axis limit.
current_date_datetime = pd.to_datetime("now")

In [None]:
def plot_counts_for_station_key_by_hour(station_key, hour, year_start=year_start_datetime):
    """Plot one hour-of-day count column over time for a single station.

    Only rows with ``classification_seq = 2`` are queried. The y-axis is capped at
    the 99.9th percentile of the hour column (rounded to the nearest 100) so that
    isolated spikes do not flatten the rest of the series.

    Args:
        station_key: Station key interpolated into the SQL ``WHERE`` clause.
        hour: Hour of day as an integer; selects the ``hour_XX`` column.
        year_start: Left-hand limit of the x-axis.

    Returns:
        pandas.DataFrame: The queried ``date``, ``hour_XX`` and ``daily_total``
        columns. It is returned without plotting when the hour column holds no values.
    """
    hour_col = f"hour_{hour:02}"
    df = con.sql(
        f"SELECT date, {hour_col}, daily_total FROM {table_name_counts} "
        f"WHERE station_key = {station_key} AND classification_seq = 2 ORDER BY date ASC"
    ).to_df()

    quantile = df[hour_col].quantile(0.999)
    # An empty or all-null column gives a NaN quantile, which cannot be rounded.
    if pd.isna(quantile):
        print(f"Station key: {station_key} has no data for {hour_col}")
        return df

    quantile_max = round(quantile / 100) * 100  # round to nearest 100
    # A cap that rounds down to 0 would collapse the axis; let matplotlib autoscale instead.
    ylim = [0, quantile_max] if quantile_max > 0 else None
    df.plot(x="date", y=hour_col, xlim=[year_start, current_date_datetime], ylim=ylim)
    return df

In [None]:
# Disabled exploration: would plot all 24 hour-of-day series for station key 99990010.
# df_all = {}
# quantiles = {}
# for hour in range(0, 24):
#     df_all[hour] = plot_counts_for_station_key_by_hour("99990010", hour)

In [None]:
def plotly_hourly_count(df, year_start=year_start_datetime):
    """Build a Plotly figure with one trace per hour of day and a slider to pick the hour.

    Args:
        df: DataFrame with a ``date`` column and ``hour_00`` ... ``hour_23`` columns.
        year_start: Rows dated before this value are left out of the figure.

    Returns:
        The Plotly figure. Only the trace for the slider's selected hour is
        visible; the slider starts at hour 0.
    """
    hours = range(24)

    # Filter first and plot from the filtered frame, so year_start actually takes effect.
    df_plot = df[df["date"] >= year_start]

    fig = make_subplots(specs=[[{"secondary_y": True}]])

    # One trace per hour; only hour 0 starts visible, matching the slider's initial position.
    for hour in hours:
        fig.add_trace(
            go.Scatter(
                x=df_plot["date"],
                y=df_plot[f"hour_{hour:02}"],
                name=f"Hour {hour}",
                visible=(hour == 0),
            ),
            secondary_y=False,
        )

    # Each slider step shows exactly one trace and hides the other 23.
    steps = [
        dict(
            method="update",
            args=[{"visible": [other == hour for other in hours]}],
            label=f"Hour {hour}",
        )
        for hour in hours
    ]

    sliders = [dict(active=0, currentvalue={"prefix": "Hour: "}, pad={"t": 50}, steps=steps)]

    fig.update_layout(sliders=sliders)
    return fig

In [18]:
# plotly_hourly_count() requires a DataFrame with `date` and hour_00..hour_23 columns;
# calling it with no argument raises TypeError. Query a single station rather than
# passing the whole counts table.
# NOTE(review): the station key is taken from the disabled exploratory cell above —
# confirm it is the station this figure was meant to show.
hourly_station_key = 99990010
df_hourly = con.sql(
    f"SELECT * FROM {table_name_counts} "
    f"WHERE station_key = {hourly_station_key} AND classification_seq = 2 ORDER BY date ASC"
).to_df()

fig = plotly_hourly_count(df_hourly)
fig.show()

In [None]:
def plot_counts_for_station_key(station_key):
    """Plot the daily total over time for one station and return the underlying data.

    Only rows with ``classification_seq = 2`` are queried.

    Args:
        station_key: Station key interpolated into the SQL ``WHERE`` clause.

    Returns:
        The queried DataFrame (``date``, ``daily_total``) when rows exist; otherwise
        ``station_key`` itself, so callers can collect the keys that have no data.
    """
    df = con.sql(
        f"SELECT date, daily_total FROM {table_name_counts} WHERE station_key = {station_key} AND classification_seq = 2 ORDER BY date ASC"
    ).to_df()

    if df.empty:
        print(f"Station key: {station_key} has no data")
        return station_key

    # `legend` is an on/off flag in DataFrame.plot, so a label passed there is never
    # shown; put the station key in the title instead.
    df.plot(x="date", y="daily_total", title=f"Station key: {station_key}")
    return df

In [None]:
# Plot every Victoria Road station, keeping the data frames that came back and
# recording the station keys that returned no rows.
df_all = {}
station_key_no_data = []

for station_key in vic_rd_stations:
    result = plot_counts_for_station_key(station_key)
    if not isinstance(result, pd.DataFrame):
        # The plotting function hands back the bare key when there is nothing to plot.
        station_key_no_data.append(result)
        continue
    df_all[station_key] = result

In [None]:
# Station keys for which the counts query returned no rows.
station_key_no_data

In [None]:
# Number of stations that did return data.
len(df_all)

In [None]:
# Names of the tables currently present on the DuckDB connection.
print([row[0] for row in con.execute("SHOW TABLES;").fetchall()])

In [None]:
def map_selected_stations(station_df):
    """Build a folium map with one marker per station, zoomed to fit all markers.

    Args:
        station_df: DataFrame with ``wgs84_latitude``, ``wgs84_longitude``,
            ``station_id``, ``station_key`` and ``full_name`` columns. It is not
            modified.

    Returns:
        folium.Map: The map with a popup on every marker. Stations without
        coordinates are skipped; with no markers at all the map keeps folium's
        default view.
    """
    coordinate_columns = ["wgs84_latitude", "wgs84_longitude"]
    # A marker cannot be placed without both coordinates, so drop incomplete rows.
    located = station_df.dropna(subset=coordinate_columns)

    m = folium.Map()
    fg = folium.FeatureGroup()  # groups the markers so their joint bounds can be computed

    for _, row in located.iterrows():
        popup_text = f"Station ID (Key): {row['station_id']} ({row['station_key']})<br>Full Name: {row['full_name']}"
        marker = folium.Marker(
            [row["wgs84_latitude"], row["wgs84_longitude"]],
            popup=folium.Popup(popup_text, max_width=450),
        )
        fg.add_child(marker)
    m.add_child(fg)

    # With no markers there are no bounds to fit.
    if not located.empty:
        m.fit_bounds(fg.get_bounds())
    return m

In [None]:
# Build the folium map for the selected Victoria Road stations.
m = map_selected_stations(station_df)

In [None]:
# Render the map as the cell output.
m