Adds Pandas DF SQL reader and writer #353

Merged: 13 commits, Sep 17, 2023
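
To show where the new adapters surface at the driver level, here is a minimal sketch of the `to.sql` materializer added by this PR, modeled on the example script changed below. The toy functions, the `ad_hoc_utils.create_temporary_module` setup, and the in-memory database are illustrative assumptions and are not part of this diff:

import sqlite3

import pandas as pd

from hamilton import ad_hoc_utils, base, driver
from hamilton.io.materialization import to


def spend() -> pd.Series:
    # toy column so the sketch is runnable on its own
    return pd.Series([10.0, 20.0, 30.0])


def signups() -> pd.Series:
    # toy column so the sketch is runnable on its own
    return pd.Series([1, 2, 3])


# stand-in for the my_functions module used by the real example script
toy_module = ad_hoc_utils.create_temporary_module(spend, signups, module_name="toy_funcs")

dr = driver.Driver({}, toy_module)
df_builder = base.PandasDataFrameResult()
conn = sqlite3.connect(":memory:")  # illustrative; the example script writes to df.db

materializers = [
    # persist the combined dataframe to the "test" table via the new SQL writer
    to.sql(
        id="df_to_sql",
        dependencies=["spend", "signups"],
        table_name="test",
        db_connection=conn,  # str URI, sqlite3 connection, or SQLAlchemy engine
        combine=df_builder,
    ),
]
materialization_results, additional_outputs = dr.materialize(
    *materializers,
    additional_vars=["df_to_sql_build_result"],  # the combined dataframe comes back here
)
print(materialization_results)
print(additional_outputs["df_to_sql_build_result"])
conn.close()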
15 changes: 15 additions & 0 deletions examples/pandas/materialization/my_script.py
@@ -1,4 +1,5 @@
import logging
import sqlite3
import sys

import pandas as pd
@@ -32,6 +33,9 @@
"spend_zero_mean_unit_variance",
]

# set up db connection for sql materializer below
conn = sqlite3.connect("df.db")

materializers = [
# materialize the dataframe to a pickle file
to.pickle(
@@ -46,6 +50,13 @@
filepath_or_buffer="./df.json",
combine=df_builder,
),
to.sql(
dependencies=output_columns,
id="df_to_sql",
table_name="test",
db_connection=conn,
combine=df_builder,
),
]
# Visualize what is happening
dr.visualize_materialization(
@@ -61,9 +72,13 @@
additional_vars=[
"df_to_pickle_build_result",
"df_to_json_build_result",
"df_to_sql_build_result",
], # because combine is used, we can get that result here.
inputs=initial_columns,
)
print(materialization_results)
print(additional_outputs["df_to_pickle_build_result"])
print(additional_outputs["df_to_json_build_result"])
print(additional_outputs["df_to_sql_build_result"])

conn.close()
163 changes: 121 additions & 42 deletions examples/pandas/materialization/notebook.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions examples/pandas/materialization/requirements.txt
@@ -0,0 +1,2 @@
pandas
sf-hamilton
28 changes: 27 additions & 1 deletion hamilton/io/utils.py
@@ -1,6 +1,8 @@
import os
from datetime import datetime
from typing import Any, Dict
from typing import Any, Dict, Union

import pandas as pd


def get_file_metadata(path: str) -> Dict[str, Any]:
@@ -17,3 +19,27 @@ def get_file_metadata(path: str) -> Dict[str, Any]:
"last_modified": os.path.getmtime(path),
"timestamp": datetime.now().utcnow().timestamp(),
}


def get_sql_metadata(query_or_table: str, results: Union[int, pd.DataFrame]) -> Dict[str, Any]:
"""Gives metadata from reading a SQL table or writing to SQL db.
This includes:
- the number of rows read, added, or to add.
- the sql query (e.g., "SELECT foo FROM bar")
- the table name (e.g., "bar")
- the current time
"""
query = query_or_table if "SELECT" in query_or_table else None
table_name = query_or_table if "SELECT" not in query_or_table else None
if isinstance(results, int):
rows = results
elif isinstance(results, pd.DataFrame):
rows = len(results)
else:
rows = None
return {
"rows": rows,
"query": query,
"table_name": table_name,
"timestamp": datetime.now().utcnow().timestamp(),
}
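
For a concrete sense of what this helper returns, a small illustration grounded in the implementation above and the test added later in this PR (timestamps elided):

import pandas as pd

from hamilton.io.utils import get_sql_metadata

# writing: pandas' DataFrame.to_sql reports rows affected as an int
get_sql_metadata("test", 5)
# {"rows": 5, "query": None, "table_name": "test", "timestamp": ...}

# reading: a DataFrame result is measured with len()
get_sql_metadata("SELECT foo FROM bar", pd.DataFrame({"foo": ["bar"]}))
# {"rows": 1, "query": "SELECT foo FROM bar", "table_name": None, "timestamp": ...}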
123 changes: 121 additions & 2 deletions hamilton/plugins/pandas_extensions.py
@@ -4,14 +4,17 @@
from collections.abc import Hashable
from io import BufferedReader, BytesIO
from pathlib import Path
from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Collection, Dict, Iterator, List, Optional, Tuple, Type, Union

try:
import pandas as pd
except ImportError:
raise NotImplementedError("Pandas is not installed.")

from pandas._typing import Dtype
from sqlite3 import Connection

from pandas._typing import NpDtype
from pandas.core.dtypes.dtypes import ExtensionDtype

from hamilton import registry
from hamilton.io import utils
@@ -21,6 +24,8 @@
COLUMN_TYPE = pd.Series

JSONSerializable = Optional[Union[str, float, bool, List, Dict]]
IndexLabel = Optional[Union[Hashable, Iterator[Hashable]]]
Dtype = Union[ExtensionDtype, NpDtype]


@registry.get_column.register(pd.DataFrame)
@@ -383,6 +388,118 @@ def name(cls) -> str:
return "json"


@dataclasses.dataclass
class PandasSqlReader(DataLoader):
"""Class specifically to handle loading SQL data using Pandas.

Disclaimer: We're exposing all the *current* params from the Pandas read_sql method.
Some of these params may get deprecated or new params may be introduced. In the event that
the params/kwargs below become outdated, please raise an issue or submit a pull request.

Should map to https://pandas.pydata.org/docs/reference/api/pandas.read_sql.html
Requires optional Pandas dependencies. See https://pandas.pydata.org/docs/getting_started/install.html#sql-databases.
"""

query_or_table: str
db_connection: Union[str, Connection] # can pass in SQLAlchemy engine/connection
# kwarg
chunksize: Optional[int] = None
coerce_float: bool = True
columns: Optional[List[str]] = None
dtype: Optional[Union[Dtype, Dict[Hashable, Dtype]]] = None
dtype_backend: Optional[str] = None
index_col: Optional[Union[str, List[str]]] = None
params: Optional[Union[List, Tuple, Dict]] = None
parse_dates: Optional[Union[List, Dict]] = None

@classmethod
def applicable_types(cls) -> Collection[Type]:
return [DATAFRAME_TYPE]

def _get_loading_kwargs(self) -> Dict[str, Any]:
kwargs = {}
if self.chunksize is not None:
kwargs["chunksize"] = self.chunksize
if self.coerce_float is not None:
kwargs["coerce_float"] = self.coerce_float
if self.columns is not None:
kwargs["columns"] = self.columns
if self.dtype is not None:
kwargs["dtype"] = self.dtype
if self.dtype_backend is not None:
kwargs["dtype_backend"] = self.dtype_backend
if self.index_col is not None:
kwargs["index_col"] = self.index_col
if self.params is not None:
kwargs["params"] = self.params
if self.parse_dates is not None:
kwargs["parse_dates"] = self.parse_dates
return kwargs

def load_data(self, type_: Type) -> Tuple[DATAFRAME_TYPE, Dict[str, Any]]:
df = pd.read_sql(self.query_or_table, self.db_connection, **self._get_loading_kwargs())
metadata = utils.get_sql_metadata(self.query_or_table, df)
return df, metadata

@classmethod
def name(cls) -> str:
return "sql"


@dataclasses.dataclass
class PandasSqlWriter(DataSaver):
"""Class specifically to handle saving DataFrames to SQL databases using Pandas.

Disclaimer: We're exposing all the *current* params from the Pandas DataFrame.to_sql method.
Some of these params may get deprecated or new params may be introduced. In the event that
the params/kwargs below become outdated, please raise an issue or submit a pull request.

Should map to https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_sql.html
Requires optional Pandas dependencies. See https://pandas.pydata.org/docs/getting_started/install.html#sql-databases.
"""

table_name: str
db_connection: Union[str, Connection] # can pass in SQLAlchemy engine/connection
# kwargs
chunksize: Optional[int] = None
dtype: Optional[Union[Dtype, Dict[Hashable, Dtype]]] = None
if_exists: str = "fail"
index: bool = True
index_label: Optional[IndexLabel] = None
method: Optional[Union[str, Callable]] = None
schema: Optional[str] = None

@classmethod
def applicable_types(cls) -> Collection[Type]:
return [DATAFRAME_TYPE]

def _get_saving_kwargs(self) -> Dict[str, Any]:
kwargs = {}
if self.chunksize is not None:
kwargs["chunksize"] = self.chunksize
if self.dtype is not None:
kwargs["dtype"] = self.dtype
if self.if_exists is not None:
kwargs["if_exists"] = self.if_exists
if self.index is not None:
kwargs["index"] = self.index
if self.index_label is not None:
kwargs["index_label"] = self.index_label
if self.method is not None:
kwargs["method"] = self.method
if self.schema is not None:
kwargs["schema"] = self.schema
return kwargs

def save_data(self, data: DATAFRAME_TYPE) -> Dict[str, Any]:
results = data.to_sql(self.table_name, self.db_connection, **self._get_saving_kwargs())
return utils.get_sql_metadata(self.table_name, results)

@classmethod
def name(cls) -> str:
return "sql"


def register_data_loaders():
"""Function to register the data loaders for this extension."""
for loader in [
@@ -393,6 +510,8 @@ def register_data_loaders():
PandasPickleWriter,
PandasJsonReader,
PandasJsonWriter,
PandasSqlReader,
PandasSqlWriter,
]:
registry.register_adapter(loader)

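Taken together, a minimal round trip with the two new adapters looks like the sketch below, mirroring the tests added later in this PR (an in-memory SQLite connection is used here purely for illustration):

import sqlite3

import pandas as pd

from hamilton.plugins.pandas_extensions import PandasSqlReader, PandasSqlWriter

conn = sqlite3.connect(":memory:")

# write a one-row dataframe to the "test" table; metadata comes from get_sql_metadata
writer = PandasSqlWriter(table_name="test", db_connection=conn)
write_metadata = writer.save_data(pd.DataFrame({"foo": ["bar"]}))

# read it back with a query; load_data returns (dataframe, metadata)
reader = PandasSqlReader(query_or_table="SELECT foo FROM test", db_connection=conn)
df, read_metadata = reader.load_data(pd.DataFrame)

print(df)                      # one row: foo == "bar"
print(read_metadata["rows"])   # 1

conn.close()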
2 changes: 2 additions & 0 deletions requirements-test.txt
@@ -3,3 +3,5 @@ networkx
pyarrow
pytest
pytest-cov
sqlalchemy==1.4.49; python_version == '3.7.*'
sqlalchemy; python_version >= '3.8'
18 changes: 18 additions & 0 deletions tests/io/test_utils.py
@@ -0,0 +1,18 @@
import pandas as pd

from hamilton.io.utils import get_sql_metadata


def test_get_sql_metadata():
results = 5
table = "foo"
query = "SELECT foo FROM bar"
df = pd.DataFrame({"foo": ["bar"]})
metadata1 = get_sql_metadata(table, df)
metadata2 = get_sql_metadata(query, results)
metadata3 = get_sql_metadata(query, "foo")
assert metadata1["table_name"] == table
assert metadata1["rows"] == 1
assert metadata2["query"] == query
assert metadata2["rows"] == 5
assert metadata3["rows"] is None
92 changes: 88 additions & 4 deletions tests/plugins/test_pandas_extensions.py
@@ -1,14 +1,26 @@
import pathlib
import sqlite3
import sys
from typing import Union

import pandas as pd
import pytest
from sqlalchemy import create_engine

from hamilton.plugins.pandas_extensions import (
PandasJsonReader,
PandasJsonWriter,
PandasPickleReader,
PandasPickleWriter,
PandasSqlReader,
PandasSqlWriter,
)

DB_RELATIVE_PATH = "tests/resources/data/test.db"
DB_MEMORY_PATH = "sqlite://"
DB_DISK_PATH = f"{DB_MEMORY_PATH}/"
DB_ABSOLUTE_PATH = f"{DB_DISK_PATH}{DB_RELATIVE_PATH}"


def test_pandas_pickle(tmp_path: pathlib.Path) -> None:
data = {
@@ -32,23 +44,95 @@ def test_pandas_pickle(tmp_path: pathlib.Path) -> None:
assert len(list(tmp_path.iterdir())) == 1, "Unexpected number of files in tmp_path directory."


def test_pandas_json_reader(tmp_path: pathlib.Path) -> None:
def test_pandas_json_reader() -> None:
file_path = "tests/resources/data/test_load_from_data.json"
reader = PandasJsonReader(filepath_or_buffer=file_path, encoding="utf-8")
kwargs = reader._get_loading_kwargs()
df, metadata = reader.load_data(pd.DataFrame)

assert PandasJsonReader.applicable_types() == [pd.DataFrame]
assert kwargs["encoding"] == "utf-8"
assert df.shape == (3, 1)
assert metadata["path"] == file_path


def test_pandas_json_writer(tmp_path: pathlib.Path) -> None:
file_path = tmp_path / "test.json"
writer = PandasJsonWriter(filepath_or_buffer=file_path, indent=4)
kwargs = writer._get_saving_kwargs()
metadata = writer.save_data(pd.DataFrame({"foo": ["bar"]}))
writer.save_data(pd.DataFrame({"foo": ["bar"]}))

assert PandasJsonWriter.applicable_types() == [pd.DataFrame]
assert kwargs["indent"] == 4
assert file_path.exists()
assert metadata["path"] == file_path


@pytest.mark.parametrize(
"conn",
[
DB_ABSOLUTE_PATH,
sqlite3.connect(DB_RELATIVE_PATH),
create_engine(DB_ABSOLUTE_PATH),
],
)
def test_pandas_sql_reader(conn: Union[str, sqlite3.Connection]) -> None:
reader = PandasSqlReader(query_or_table="SELECT foo FROM bar", db_connection=conn)
kwargs = reader._get_loading_kwargs()
df, metadata = reader.load_data(pd.DataFrame)

assert PandasSqlReader.applicable_types() == [pd.DataFrame]
assert kwargs["coerce_float"] is True
assert metadata["rows"] == 1
assert df.shape == (1, 1)

if hasattr(conn, "close"):
conn.close()


@pytest.mark.skipif(sys.version_info < (3, 8), reason="Test requires Python 3.8 or higher")
@pytest.mark.parametrize(
"conn",
[
DB_MEMORY_PATH,
sqlite3.connect(":memory:"),
create_engine(DB_MEMORY_PATH),
],
)
def test_pandas_sql_writer(conn: Union[str, sqlite3.Connection]) -> None:
writer = PandasSqlWriter(table_name="test", db_connection=conn)
kwargs = writer._get_saving_kwargs()
metadata = writer.save_data(pd.DataFrame({"foo": ["bar"]}))

assert PandasSqlWriter.applicable_types() == [pd.DataFrame]
assert kwargs["if_exists"] == "fail"
assert metadata["rows"] == 1

if hasattr(conn, "close"):
conn.close()


@pytest.mark.skipif(sys.version_info >= (3, 8), reason="Test requires Python 3.7.x")
@pytest.mark.parametrize(
"conn",
[
sqlite3.connect(":memory:"),
create_engine(DB_MEMORY_PATH),
],
)
def test_pandas_sql_writer_py37(conn: Union[str, sqlite3.Connection]) -> None:
"""Workaround for py37, pandas v1.3.5 since pandas DataFrame.to_sql
doesn't return the number of rows inserted. Also, we skip testing the str URI
since the pandas read_sql and to_sql will treat these as two separate dbs.
"""
df = pd.DataFrame({"foo": ["bar"]})
writer = PandasSqlWriter(table_name="test", db_connection=conn)
reader = PandasSqlReader(query_or_table="SELECT foo FROM test", db_connection=conn)
kwargs = writer._get_saving_kwargs()
writer.save_data(df)
df2, metadata = reader.load_data(pd.DataFrame)

assert PandasSqlWriter.applicable_types() == [pd.DataFrame]
assert kwargs["if_exists"] == "fail"
assert df.equals(df2)

if hasattr(conn, "close"):
conn.close()
Binary file added tests/resources/data/test.db