In [104]:
#DataFusionWrapper Class- open the box to see the underlying code
from pathlib import Path
from datafusion import SessionContext

class DataFusionWrapper:
    def __init__(self):
        self.con = SessionContext()
        self.registered_tables = []

    def register_data(self, paths, table_names):
        """
        Registers data files (Parquet, CSV, JSON) to DataFusion context with specified table names.
        Automatically detects the file type based on the file extension.
        
        Args:
            paths (list): List of paths to data files.
            table_names (list): List of table names corresponding to the paths.
        """
        if len(paths) != len(table_names):
            raise ValueError("The number of paths must match the number of table names.")
        
        for path, table_name in zip(paths, table_names):
            file_extension = Path(path).suffix.lower()

            if file_extension == ".parquet":
                self.con.register_parquet(table_name, path)
            elif file_extension == ".csv":
                self.con.register_csv(table_name, path)
            elif file_extension == ".json":
                self.con.register_json(table_name, path)
            else:
                raise ValueError(f"Unsupported file type '{file_extension}' for file: {path}")

            self.registered_tables.append(table_name)

    def run_query(self, sql_query):
        """
        Runs a SQL query on the registered tables in the DataFusion context.
        Args:
            sql_query (str): The SQL query string to execute.
        Returns:
            DataFrame: Query result as a DataFusion DataFrame.
        """
        return self.con.sql(sql_query)

    def _construct_path(self, path, base_path, file_name, extension):
        """
        Constructs the full file path based on input parameters.
        """
        if path:
            return Path(path)
        elif base_path and file_name:
            return Path(base_path) / f"{file_name}.{extension}"
        else:
            # Default file path: "output.<extension>" in the current directory
            return Path(f"output.{extension}")

    def export(self, dataframe, file_type, path=None, base_path=None, file_name=None, with_header=True):
        """
        Exports a DataFrame to the specified file type.
        Args:
            dataframe: DataFusion DataFrame.
            file_type (str): Type of file to export ('parquet', 'csv', 'json').
            path: Full path to the file (optional).
            base_path: Directory path (optional).
            file_name: Name of the file (without extension) (optional).
            with_header: Include header row for CSV files (default: True).
        """
        file_type = file_type.lower()
        if file_type not in ["parquet", "csv", "json"]:
            raise ValueError("file_type must be one of 'parquet', 'csv', or 'json'.")

        # Construct file path
        full_path = self._construct_path(path, base_path, file_name, file_type)

        # Export based on file type
        if file_type == "csv":
            dataframe.write_csv(full_path, with_header=with_header)
        elif file_type == "json":
            dataframe.write_json(full_path)
        elif file_type == "parquet":
            dataframe.write_parquet(full_path)
        
        print(f"File written to: {full_path}")

    def show_tables(self):
        """
        Displays the table names and types currently registered in the catalog.
        """
        query = """
        SELECT table_name, table_type 
        FROM information_schema.tables 
        WHERE table_schema <> 'information_schema'
        """
        result_df = self.run_query(query)
        result_df.show()

    def show_schema(self, table_name):
        """
        Displays the schema of the specified table.
        
        Args:
            table_name (str): Name of the table whose schema is to be displayed.
        """
        query = f"""
        SELECT 
            table_name, 
            column_name, 
            data_type
        FROM 
            information_schema.columns 
        WHERE 
            table_name = '{table_name}'
        """
        result_df = self.run_query(query)
        result_df.show()


In [105]:
# Initialize the wrapper
con = DataFusionWrapper()


In [106]:
# Load local files into DataFusion
# Dynamically find the root of the repo (assuming the .ipynb is located within the repo)
repo_root = Path.cwd().resolve().parents[0]   # Adjust to locate the repo root

# Define relative paths from the repo root and corresponding table names
paths = [
    repo_root / "data/examples/mta_operations_statement/file_1.parquet",
    repo_root / "data/examples/mta_hourly_subway_socrata/*.parquet",
    repo_root / "data/examples/mta_daily_ridership/*.parquet",
    repo_root / "data/examples/mta_bus_wait_time/*.parquet",
    repo_root / "data/examples/daily_weather_asset/*.parquet",
    repo_root / "data/examples/csv_example/*.csv",
    repo_root / "data/examples/json_example/*.json",
]
table_names = [
    "mta_operations_statement",
    "mta_hourly_subway_socrata",
    "mta_daily_ridership",
    "mta_bus_wait_time",
    "daily_weather_asset",
    "operations_csv",
    "operations_json"
]

# Register data files (automatically detects file type)
con.register_data(paths, table_names)

In [107]:
# Quickly see all of the tables that are loaded and their names
con.show_tables()

DataFrame()
+---------------------------+------------+
| table_name                | table_type |
+---------------------------+------------+
| mta_daily_ridership       | BASE TABLE |
| operations_json           | BASE TABLE |
| mta_operations_statement  | BASE TABLE |
| daily_weather_asset       | BASE TABLE |
| mta_bus_wait_time         | BASE TABLE |
| operations_csv            | BASE TABLE |
| mta_hourly_subway_socrata | BASE TABLE |
+---------------------------+------------+


In [108]:
#Show the schema of a table
con.show_schema("mta_operations_statement")

DataFrame()
+--------------------------+---------------------+-----------+
| table_name               | column_name         | data_type |
+--------------------------+---------------------+-----------+
| mta_operations_statement | fiscal_year         | Int64     |
| mta_operations_statement | timestamp           | Date32    |
| mta_operations_statement | scenario            | LargeUtf8 |
| mta_operations_statement | financial_plan_year | Int64     |
| mta_operations_statement | expense_type        | LargeUtf8 |
| mta_operations_statement | agency              | LargeUtf8 |
| mta_operations_statement | type                | LargeUtf8 |
| mta_operations_statement | subtype             | LargeUtf8 |
| mta_operations_statement | general_ledger      | LargeUtf8 |
| mta_operations_statement | amount              | Float64   |
| mta_operations_statement | agency_full_name    | LargeUtf8 |
+--------------------------+---------------------+-----------+


In [109]:
# Write a simple SQL query against the tables and print result as a dataframe
sql_query = """
select count(*) as total_rows 
from mta_hourly_subway_socrata
"""

# Run query and show results
result_df = con.run_query(sql_query)
result_df.show()


DataFrame()
+------------+
| total_rows |
+------------+
| 67763465   |
+------------+


In [110]:
# Export Dataframes into Parquet, CSV, or JSON
file_type = "csv"
file_name = "row_count"
base_path = repo_root / "data/exports"

# Export results of result_df
con.export(result_df, file_type=file_type, base_path=base_path, file_name=file_name)

File written to: /home/christianocean/datafusion-quickstart/data/exports/row_count.csv


In [111]:
# Write a more complciated SQL query
sql_query = """
WITH payment_stats AS (
    SELECT
        general_ledger,
        MIN(timestamp) AS first_payment_date,
        MAX(timestamp) AS last_payment_date,
        AVG(amount) AS average_payment,
        COUNT(DISTINCT timestamp) AS total_payments
    FROM
        mta_operations_statement
    WHERE
        scenario = 'Actual'
        AND type = 'Debt Service Expenses'
    GROUP BY
        general_ledger
),
first_payment AS (
    SELECT
        general_ledger,
        MIN(timestamp) AS first_payment_date,
        MIN(amount) AS first_payment_amount
    FROM
        mta_operations_statement
    WHERE
        scenario = 'Actual'
        AND type = 'Debt Service Expenses'
    GROUP BY
        general_ledger
),
last_payment AS (
    SELECT
        general_ledger,
        MAX(timestamp) AS last_payment_date,
        MAX(amount) AS last_payment_amount
    FROM
        mta_operations_statement
    WHERE
        scenario = 'Actual'
        AND type = 'Debt Service Expenses'
    GROUP BY
        general_ledger
)
SELECT DISTINCT
    ps.general_ledger,
    fp.first_payment_date,
    fp.first_payment_amount,
    ps.last_payment_date,
    lp.last_payment_amount,
    ps.average_payment,
    ps.total_payments
FROM
    payment_stats ps
LEFT JOIN
    first_payment fp ON ps.general_ledger = fp.general_ledger
LEFT JOIN
    last_payment lp ON ps.general_ledger = lp.general_ledger
ORDER BY
    ps.general_ledger;

"""

# Run query and show results
result_df = con.run_query(sql_query)
result_df.show()


DataFrame()
+----------------------------------------------------------------------+--------------------+----------------------+-------------------+---------------------+--------------------+----------------+
| general_ledger                                                       | first_payment_date | first_payment_amount | last_payment_date | last_payment_amount | average_payment    | total_payments |
+----------------------------------------------------------------------+--------------------+----------------------+-------------------+---------------------+--------------------+----------------+
| 2 Broadway Certificates of Participation (2 Bdwy COPS w/o Int + Int) | 2019-01-01         | 0.0                  | 2023-12-01        | 434020.78           | 115023.14730392158 | 60             |
| Debt Service for Payroll Mobility Tax Bonds                          | 2021-01-01         | 0.0                  | 2024-06-01        | 34215833.25         | 5759055.434345238  | 42             |
| N

In [112]:
#Write a SQL query that includes a join between tables
sql_query = """
WITH weekly_ridership AS (
    SELECT 
        station_complex, 
        DATE_TRUNC('week', transit_timestamp) AS week_start,
        SUM(ridership) AS total_weekly_ridership,
        MIN(latitude) AS latitude,  -- Assuming latitude is the same for each station complex, use MIN() or MAX()
        MIN(longitude) AS longitude  -- Assuming longitude is the same for each station complex, use MIN() or MAX()
    FROM 
        mta_hourly_subway_socrata
    GROUP BY 
        station_complex, 
        DATE_TRUNC('week', transit_timestamp)
),
weekly_weather AS (
    SELECT 
        DATE_TRUNC('week', date) AS week_start,
        AVG(temperature_mean) AS avg_weekly_temperature,
        SUM(precipitation_sum) AS total_weekly_precipitation
    FROM 
        daily_weather_asset
    GROUP BY 
        DATE_TRUNC('week', date)
)
SELECT 
    wr.station_complex, 
    wr.week_start, 
    wr.total_weekly_ridership,
    wr.latitude,
    wr.longitude,
    ww.avg_weekly_temperature,
    ww.total_weekly_precipitation
FROM 
    weekly_ridership wr
LEFT JOIN 
    weekly_weather ww
ON 
    wr.week_start = ww.week_start
WHERE 
    wr.week_start < '2024-09-17'
ORDER BY 
    wr.station_complex, 
    wr.week_start

"""

# Run query and show results
result_df = con.run_query(sql_query)
result_df.show()


DataFrame()
+-----------------+---------------------+------------------------+-----------+-----------+------------------------+----------------------------+
| station_complex | week_start          | total_weekly_ridership | latitude  | longitude | avg_weekly_temperature | total_weekly_precipitation |
+-----------------+---------------------+------------------------+-----------+-----------+------------------------+----------------------------+
| 1 Av (L)        | 2022-01-31T00:00:00 | 44875                  | 40.730953 | -73.98163 | 26.785714285714285     | 36.9                       |
| 1 Av (L)        | 2022-02-07T00:00:00 | 93646                  | 40.730953 | -73.98163 | 37.357142857142854     | 14.7                       |
| 1 Av (L)        | 2022-02-14T00:00:00 | 99510                  | 40.730953 | -73.98163 | 32.51428571428571      | 15.200000000000001         |
| 1 Av (L)        | 2022-02-21T00:00:00 | 61907                  | 40.730953 | -73.98163 | 38.27142857142858      | 25