In [None]:
# Run this cell first: it imports the dependencies and defines the DuckDBWrapper helper class.
import sys
from great_tables import GT
from IPython.display import display, HTML
try:
    # IPython is present only if a Jupyter environment is installed
    from IPython.display import display, HTML
    IN_NOTEBOOK = True
except ImportError:
    IN_NOTEBOOK = False

try:
    # Great Tables provides styling support for Polars
    from great_tables import loc, style
    HAS_GREAT_TABLES = True
except ImportError:
    HAS_GREAT_TABLES = False

import duckdb
from pathlib import Path
import polars as pl
import polars.selectors as cs
from rich.console import Console
from rich.table import Table
import sys
try:
    # IPython is present only if a Jupyter environment is installed
    from IPython.display import display, HTML
    IN_NOTEBOOK = True
except ImportError:
    IN_NOTEBOOK = False


class DuckDBWrapper:
    """
    Thin convenience layer around a DuckDB connection.

    Data files are exposed as views, queries come back as Polars DataFrames,
    and results/schemas can be rendered as Rich tables (or as a Great Tables
    HTML table when the module-level IN_NOTEBOOK and HAS_GREAT_TABLES flags
    are both true).

    The wrapper can be used as a context manager so the connection is closed
    deterministically:

        with DuckDBWrapper() as con:
            con.run_query("SELECT 1")
    """

    # File extension -> DuckDB table function used to read that format.
    _READERS = {
        ".parquet": "read_parquet",
        ".csv": "read_csv_auto",
        ".json": "read_json_auto",
    }

    # Maximum number of query characters shown in table titles/subtitles.
    _PREVIEW_CHARS = 50

    def __init__(self, duckdb_path=None):
        """
        Initialize a DuckDB connection.
        If duckdb_path is provided, a persistent DuckDB database will be used.
        Otherwise, it creates an in-memory database.

        Args:
            duckdb_path (str or Path, optional): Path to a DuckDB database file.
        """
        if duckdb_path:
            self.con = duckdb.connect(str(duckdb_path), read_only=False)
        else:
            self.con = duckdb.connect(database=':memory:', read_only=False)
        self.registered_tables = []

        # httpfs is only needed for remote paths, so treat it as best-effort:
        # failing to install/load it (e.g. no network access) must not prevent
        # working with local files.
        try:
            self.con.execute("INSTALL httpfs;")
            self.con.execute("LOAD httpfs;")
        except duckdb.Error as exc:
            print(f"Warning: httpfs extension unavailable; remote paths may not work ({exc})")

    def __enter__(self):
        """Return the wrapper itself so it can be used in a `with` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection when leaving the `with` block."""
        self.close()

    def close(self):
        """Close the underlying DuckDB connection."""
        self.con.close()

    @staticmethod
    def _sql_literal(value):
        """Return `value` as a single-quoted SQL string literal with embedded quotes doubled."""
        return "'" + str(value).replace("'", "''") + "'"

    @classmethod
    def _query_preview(cls, sql_query):
        """
        Return a short one-line preview of a SQL query for table titles.

        Runs of whitespace (including leading newlines) are collapsed to single
        spaces, and '...' is appended only when the query was actually truncated.
        """
        flattened = " ".join(sql_query.split())
        if len(flattened) <= cls._PREVIEW_CHARS:
            return flattened
        return flattened[:cls._PREVIEW_CHARS] + "..."

    def _remember_table(self, table_name):
        """Record a registered view name once, even if it is re-registered."""
        if table_name not in self.registered_tables:
            self.registered_tables.append(table_name)

    def register_data(self, paths, table_names):
        """
        Registers local data files (Parquet, CSV, JSON) in DuckDB by creating views.
        Automatically detects the file type based on the file extension.

        Views are created with CREATE OR REPLACE, so re-running a registration
        (e.g. re-executing a notebook cell) refreshes the view instead of failing.
        All inputs are validated before any view is created, so an unsupported
        file does not leave a partially registered set behind.

        Args:
            paths (list): List of paths (strings or Path objects) to data files.
            table_names (list): List of table names corresponding to the paths.

        Raises:
            ValueError: If the two lists differ in length or a file extension
                is not one of .parquet, .csv, .json.
        """
        paths = list(paths)
        table_names = list(table_names)
        if len(paths) != len(table_names):
            raise ValueError("The number of paths must match the number of table names.")

        statements = []
        for path, table_name in zip(paths, table_names):
            path_str = str(path)
            file_extension = Path(path_str).suffix.lower()
            reader = self._READERS.get(file_extension)
            if reader is None:
                raise ValueError(f"Unsupported file type '{file_extension}' for file: {path_str}")
            statements.append((
                table_name,
                f"CREATE OR REPLACE VIEW {table_name} AS "
                f"SELECT * FROM {reader}({self._sql_literal(path_str)})",
            ))

        for table_name, statement in statements:
            self.con.execute(statement)
            self._remember_table(table_name)

    def bulk_register_data(self, repo_root, base_path, table_names, wildcard="*.parquet"):
        """
        Constructs paths for each table based on a shared base path plus the table name,
        then appends a wildcard for file matching (e.g., '*.parquet'), and registers the data.

        Args:
            repo_root (Path): The root path of your repository.
            base_path (str): The relative path from repo_root to your data directory.
            table_names (list): The table names (and implicitly the folder names) to register.
            wildcard (str, optional): A wildcard pattern for the files (default '*.parquet').
        """
        table_names = list(table_names)
        data_root = Path(repo_root) / base_path
        paths = [data_root / table_name / wildcard for table_name in table_names]
        self.register_data(paths, table_names)

    def run_query(self, sql_query, show_results=False):
        """
        Runs a SQL query on the registered tables in DuckDB and returns a Polars DataFrame.
        Optionally displays the result. If running in a .ipynb environment and show_results=True,
        we attempt to style the Polars DataFrame using Great Tables for better visuals.

        Args:
            sql_query (str): The SQL query string to execute.
            show_results (bool): If True, display the result in a table (notebook or terminal).

        Returns:
            pl.DataFrame: Query result as a Polars DataFrame.
        """
        # Convert the DuckDB query results to Polars via Arrow
        arrow_table = self.con.execute(sql_query).arrow()
        df = pl.DataFrame(arrow_table)

        if show_results:
            preview = self._query_preview(sql_query)
            if IN_NOTEBOOK and HAS_GREAT_TABLES:
                # 1) Create a styled table with Great Tables
                styled = (
                    df.style
                    # Add a title & subtitle
                    .tab_header(
                        title="DuckDB Query Results",
                        subtitle=preview
                    )
                    # Example format: limit decimal places for numeric columns
                    .fmt_number(cs.numeric(), decimals=3)
                )

                # 2) Convert styled table to HTML
                styled_html = styled._repr_html_()

                # 3) Wrap the HTML in a scrollable <div> for horizontal scroll
                scrollable_html = f"""
                <div style="max-width:100%; overflow-x:auto; white-space:nowrap;">
                    {styled_html}
                </div>
                """

                # 4) Display in notebook
                display(HTML(scrollable_html))
            else:
                # If not in a notebook environment or great_tables isn't installed,
                # fallback to existing Rich-based table printing
                self.print_query_results(df, title=f"Query: {preview}")

        return df

    def print_query_results(self, df, title="Query Results"):
        """
        Prints a Polars DataFrame as a Rich table. Uses a pager for vertical/horizontal scrolling in the terminal.

        Args:
            df (pl.DataFrame): The Polars DataFrame to display.
            title (str): Title for the Rich table.
        """
        console = Console()

        # Use a pager for scrolling (works in many terminals, not always in notebooks)
        with console.pager(styles=True):
            table = Table(title=title, title_style="bold green", show_lines=True)
            # Add columns with some styling
            for column in df.columns:
                table.add_column(str(column), style="bold cyan", overflow="fold")

            # Rows come back as tuples in column order; Rich needs every cell as text
            for row in df.iter_rows():
                table.add_row(*map(str, row), style="white on black")

            console.print(table)

    def _construct_path(self, path, base_path, file_name, extension):
        """
        Constructs the full file path based on input parameters.

        An explicit `path` always wins. Otherwise the file is named
        "<file_name>.<extension>" (stem defaults to "output") and placed in
        `base_path` when given, else in the current directory. Supplying only
        one of `base_path` / `file_name` is honoured rather than ignored.

        Returns:
            Path: The resolved output location (not made absolute).
        """
        if path:
            return Path(path)

        name = f"{file_name or 'output'}.{extension}"
        return Path(base_path) / name if base_path else Path(name)

    @staticmethod
    def _to_polars(result):
        """
        Coerce an export input into a Polars DataFrame.

        Accepts a Polars DataFrame as-is, or any object exposing `to_arrow()`
        or `to_arrow_table()` (the latter is what DuckDB relations provide).

        Raises:
            ValueError: If the object offers none of the supported conversions.
        """
        if isinstance(result, pl.DataFrame):
            return result
        for method_name in ("to_arrow", "to_arrow_table"):
            converter = getattr(result, method_name, None)
            if callable(converter):
                return pl.DataFrame(converter())
        raise ValueError(
            "Unsupported result type. Must be a Polars DataFrame or have a "
            "'to_arrow()' or 'to_arrow_table()' method."
        )

    def export(self, result, file_type, path=None, base_path=None, file_name=None, with_header=True):
        """
        Exports a Polars DataFrame (or anything convertible to Polars via Arrow)
        to the specified file type.

        Args:
            result (pl.DataFrame or DuckDB query result): The data to export.
            file_type (str): Type of file to export ('parquet', 'csv', 'json').
            path (str): Full path to the file (optional).
            base_path (str): Directory path (optional).
            file_name (str): Name of the file (without extension) (optional).
            with_header (bool): Include header row for CSV (default: True).

        Raises:
            ValueError: If file_type is unsupported or result cannot be converted.
        """
        file_type = file_type.lower()
        if file_type not in ("parquet", "csv", "json"):
            raise ValueError("file_type must be one of 'parquet', 'csv', or 'json'.")

        # Validate/convert the data before touching the filesystem so a bad
        # input does not leave empty directories behind.
        df = self._to_polars(result)

        full_path = self._construct_path(path, base_path, file_name, file_type)
        full_path.parent.mkdir(parents=True, exist_ok=True)

        # Export based on file type
        if file_type == "parquet":
            df.write_parquet(str(full_path))
        elif file_type == "csv":
            # Use Polars parameter include_header instead of header or has_header
            df.write_csv(str(full_path), separator=",", include_header=with_header)
        else:
            df.write_ndjson(str(full_path))  # NDJSON format

        print(f"File written to: {full_path}")

    def show_tables(self):
        """
        Displays the table names and types currently registered in the catalog
        in a Rich-styled table (similar to show_schema).
        """
        query = """
        SELECT table_name, table_type
        FROM information_schema.tables
        WHERE table_schema='main'
        """
        df = self.run_query(query)  # Polars DataFrame
        console = Console()
        table = Table(title="Registered Tables", title_style="bold green", show_lines=True)
        table.add_column("Table Name", justify="left", style="bold yellow")
        table.add_column("Table Type", justify="left", style="bold cyan")

        for row in df.to_dicts():
            table.add_row(row["table_name"], row["table_type"], style="white on black")

        console.print(table)

    def show_schema(self, table_name):
        """
        Displays the schema of the specified DuckDB table or view,
        using Polars for the query result and printing in a Rich table.

        Args:
            table_name (str): Name of the table/view whose schema is to be displayed.
        """
        # The name is compared as a string value, so it is escaped as a SQL
        # literal; a quote in the name can no longer break the statement.
        query = f"""
        SELECT 
            column_name, 
            data_type
        FROM 
            information_schema.columns 
        WHERE 
            table_name = {self._sql_literal(table_name)}
        """
        df = self.run_query(query)
        console = Console()
        schema_table = Table(title=f"Schema for '{table_name}'", title_style="bold green")
        schema_table.add_column("Column Name", justify="left", style="bold yellow", no_wrap=True)
        schema_table.add_column("Data Type", justify="left", style="bold cyan")

        for row in df.to_dicts():
            schema_table.add_row(row["column_name"], str(row["data_type"]), style="white on black")

        console.print(schema_table)

    def show_parquet_schema(self, file_path):
        """
        Reads a Parquet file directly using Polars, and prints its schema
        in a visually appealing table using Rich. Includes row count info.

        Args:
            file_path (str or Path): Parquet file to inspect.
        """
        # NOTE: this loads the whole file into memory just to report schema and row count.
        df = pl.read_parquet(file_path)

        console = Console()
        schema_table = Table(title="Parquet Schema", title_style="bold green")
        schema_table.add_column("Column Name", justify="left", style="bold yellow", no_wrap=True)
        schema_table.add_column("Data Type", justify="left", style="bold cyan")

        for col_name, col_dtype in df.schema.items():
            schema_table.add_row(col_name, str(col_dtype), style="white on black")

        console.print(schema_table)
        console.print(f"[bold magenta]\nNumber of rows:[/] [bold white]{df.height}[/]")

    def register_partitioned_data(self, base_path, table_name, wildcard="*/*/*.parquet"):
        """
        Registers partitioned Parquet data using Hive partitioning by creating a view.

        Args:
            base_path (str or Path): The base directory where partitioned files are located.
            table_name (str): Name of the view to be created.
            wildcard (str, optional): Glob pattern to locate the parquet files (default '*/*/*.parquet').
        """
        path_str = str(Path(base_path) / wildcard)
        query = f"""
        CREATE OR REPLACE VIEW {table_name} AS 
        SELECT * FROM read_parquet({self._sql_literal(path_str)}, hive_partitioning=true)
        """
        self.con.execute(query)
        self._remember_table(table_name)
        print(f"Partitioned view '{table_name}' created for files at '{path_str}'.")



In [3]:
# Initialize the DuckDBWrapper with an in-memory DuckDB instance.
# To use a persistent database instead, pass the path to a DuckDB file,
# e.g. con = DuckDBWrapper("path/to/database.duckdb")
con = DuckDBWrapper()

In [4]:
# Register the Parquet datasets as DuckDB views
from pathlib import Path

# Parent of the current working directory; adjust if the repo root is elsewhere
repo_root = Path.cwd().resolve().parents[0]

BASE_PATH = "data/opendata"

# Each table name is also the name of its folder under BASE_PATH
bulk_table_names = [
    "mta_operations_statement",
    "mta_hourly_subway_socrata",
    "mta_daily_ridership",
    "mta_bus_wait_time",
    "mta_bus_speeds",
    "daily_weather_asset",
    "hourly_weather_asset",
]

# One call registers a view per table from <repo_root>/<BASE_PATH>/<table>/*.parquet
con.bulk_register_data(
    repo_root=repo_root,
    base_path=BASE_PATH,
    table_names=bulk_table_names,
    wildcard="*.parquet",
)

# Example of registering arbitrary CSV, Parquet, and JSON files by path:
# paths = [
#     "your/computer/data/exports/row_count.csv",
#     "your/computer/data/exports/row_count.json",
#     "your/computer/data/exports/row_count.parquet",
# ]
# table_names = [
#     "row_count_csv_table",
#     "row_count_json_table",
#     "row_count_parquet_table",
# ]
# con.register_data(paths, table_names)

In [None]:
# Fetch up to 20,000 rows of hourly subway data and print the resulting Polars DataFrame
query = """

SELECT * from mta_hourly_subway_socrata  limit 20000

"""
result = con.run_query(query)
print(result)


In [None]:
# For a better-looking table, pass show_results=True.
# Capping the query at about 50 rows is recommended when rendering the table.
query = """

SELECT * from mta_hourly_subway_socrata limit 50
"""
result = con.run_query(query, show_results=True)


In [None]:
# More complicated query: weekly ridership per station complex joined to weekly weather.
#   - weekly_ridership: sums ridership per station_complex and week (DATE_TRUNC on transit_timestamp)
#   - weekly_weather: averages temperature_mean and sums precipitation_sum per week
#   - LEFT JOIN on week_start keeps ridership weeks that have no matching weather row
#   - only weeks starting before 2024-09-17 are kept; the first 50 rows ordered by
#     station_complex and week_start are returned and displayed
# NOTE(review): the f-prefix on the string is unnecessary (there are no placeholders).

query = f"""

WITH weekly_ridership AS (
    SELECT 
        station_complex, 
        DATE_TRUNC('week', transit_timestamp) AS week_start,
        SUM(ridership) AS total_weekly_ridership,
        MIN(latitude) AS latitude,  -- Assuming latitude is the same for each station complex, use MIN() or MAX()
        MIN(longitude) AS longitude  -- Assuming longitude is the same for each station complex, use MIN() or MAX()
    FROM 
        mta_hourly_subway_socrata
    GROUP BY 
        station_complex, 
        DATE_TRUNC('week', transit_timestamp)
),
weekly_weather AS (
    SELECT 
        DATE_TRUNC('week', date) AS week_start,
        AVG(temperature_mean) AS avg_weekly_temperature,
        SUM(precipitation_sum) AS total_weekly_precipitation
    FROM 
        daily_weather_asset
    GROUP BY 
        DATE_TRUNC('week', date)
)
SELECT 
    wr.station_complex, 
    wr.week_start, 
    wr.total_weekly_ridership,
    wr.latitude,
    wr.longitude,
    ww.avg_weekly_temperature,
    ww.total_weekly_precipitation
FROM 
    weekly_ridership wr
LEFT JOIN 
    weekly_weather ww
ON 
    wr.week_start = ww.week_start
WHERE 
    wr.week_start < '2024-09-17'
ORDER BY 
    wr.station_complex, 
    wr.week_start
 

LIMIT 50
"""

result = con.run_query(query,show_results=True)


In [None]:
# Print the name and type of every table/view in DuckDB's 'main' schema as a Rich table
con.show_tables()


In [None]:
# Print the column names and data types of a specific table/view
con.show_schema("mta_hourly_subway_socrata")

In [None]:
# Write the most recent query result (`result`) to
# <repo_root>/data/exports/<file_name>.csv
repo_root = Path.cwd().resolve().parents[0]  # Adjust to locate the repo root
base_path = repo_root / "data/exports"
file_name = "mta_hourly_subway_socrata_data_sample"
file_type = "csv"

con.export(
    result,
    file_type=file_type,
    base_path=base_path,
    file_name=file_name,
)