In [None]:
from pathlib import Path
import sys
import polars as pl

# ── Set up imports ──────────────────────────────────────────────
# Make the repo's `pipeline` package importable from this notebook.
# Normally the kernel starts in <repo>/notebooks, so the repo root is the
# parent directory; if the kernel was started in the repo root itself (it
# already contains `pipeline/`), use the working directory directly.
nb_dir   = Path.cwd()                    # usually …/notebooks
repo_dir = nb_dir if (nb_dir / "pipeline").is_dir() else nb_dir.parent
# Guard the insert so re-running this cell does not keep prepending
# duplicate entries to sys.path.
if str(repo_dir) not in sys.path:
    sys.path.insert(0, str(repo_dir))

from pipeline.datasets import (
    SINGLE_FILE_ASSETS_NAMES,
    PARTITIONED_ASSETS_NAMES,
)
from pipeline.utils.polars_sql_wrapper import PolarsSQLWrapper

# ── Create the context ─────────────────────────────────────────
# Single shared SQL/lazy-query wrapper; every later cell registers tables on
# it or queries through it.
ctx: PolarsSQLWrapper = PolarsSQLWrapper()


In [None]:
# ── Register single-file datasets ──────────────────────────────
# Root of the open-data parquet files; presumably resolved relative to
# `repo_root` by the wrapper — confirm in PolarsSQLWrapper.
BASE_PATH = "data/opendata"
ctx.bulk_register_data(
    repo_root=repo_dir,
    base_path=BASE_PATH,
    table_names=SINGLE_FILE_ASSETS_NAMES,
    wildcard="*.parquet",
)

# ── Register partitioned datasets ──────────────────────────────
# These tables are stored as parquet files under `year=…/month=…`
# sub-directories, hence the two-level wildcard.
ctx.bulk_register_partitioned_data(
    repo_root=repo_dir,
    base_path=BASE_PATH,
    table_names=PARTITIONED_ASSETS_NAMES,
    wildcard="year=*/month=*/*.parquet",
)
# Use the public `ctx.show_tables()` to inspect what was registered rather
# than reaching into the wrapper's private state.

In [None]:
# Show the tables registered on `ctx` so far (presumably both the
# single-file and partitioned sets from the registration cell).
ctx.show_tables()

In [None]:
# Sanity check: total number of rows in the hourly ridership table.
count_sql = "SELECT COUNT(*) AS n FROM mta_subway_hourly_ridership"
df = ctx.run_query(count_sql)
print(df)

In [None]:
# Lazy handle on the ridership table; later cells build a query plan on it
# and call `.collect()` to execute (presumably a polars LazyFrame).
RIDERSHIP_TABLE = "mta_subway_hourly_ridership"
lf = ctx.lazy(RIDERSHIP_TABLE)

In [None]:
# Total ridership per borough for calendar year 2023, busiest borough first.
is_2023 = pl.col("transit_timestamp").dt.year() == 2023
borough_totals_plan = (
    lf.filter(is_2023)
    .group_by("borough")
    .agg(pl.col("ridership").sum().alias("total_2023"))
    .sort("total_2023", descending=True)
)
# Nothing is read from disk until the plan is collected.
df_totals = borough_totals_plan.collect()
print(df_totals)

In [None]:
# Mean transfers per payment method over Q1 2023 (months 1-3, both ends
# inclusive), rounded to 2 decimals, highest first.
ts_col = pl.col("transit_timestamp")
in_q1_2023 = (ts_col.dt.year() == 2023) & (ts_col.dt.month().is_between(1, 3))
avg_transfers = pl.col("transfers").mean().round(2).alias("avg_transfers_q1")
df_avg = (
    lf.filter(in_q1_2023)
    .group_by("payment_method")
    .agg(avg_transfers)
    .sort("avg_transfers_q1", descending=True)
    .collect()
)
print(df_avg)

In [None]:
# SQL version of the 2023 per-borough ridership totals computed with the lazy
# API into `df_totals`; the result is expected to match it.
sql = """
SELECT borough, SUM(ridership) AS total_2023
  FROM mta_subway_hourly_ridership
 WHERE EXTRACT(year FROM transit_timestamp) = 2023
 GROUP BY borough
 ORDER BY total_2023 DESC
"""
df_sql = ctx.run_query(sql)
print(df_sql)