Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/reference/model_configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -274,5 +274,8 @@ Options specified within the `kind` property's `csv_settings` property (override
| `skipinitialspace` | Skip spaces after delimiter. More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | bool | N |
| `lineterminator` | Character used to denote a line break. More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | str | N |
| `encoding` | Encoding to use for UTF when reading/writing (ex. 'utf-8'). More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | str | N |
| `na_values` | An array of values that should be recognized as NA/NaN. To specify a separate array for each column, pass a mapping of the form `(col1 = (v1, v2, ...), col2 = ...)` instead. Values can be integers, strings, booleans or NULL, and are converted to their corresponding Python values. More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | array[value] \| array[key = array[value]] | N |
| `keep_default_na` | Whether or not to include the default NaN values when parsing the data. More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | bool | N |


Python model kind `name` enum value: [ModelKindName.SEED](https://sqlmesh.readthedocs.io/en/stable/_readthedocs/html/sqlmesh/core/model/kind.html#ModelKindName)
3 changes: 2 additions & 1 deletion sqlmesh/core/model/kind.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,9 +640,10 @@ def to_expression(

@property
def data_hash_values(self) -> t.List[t.Optional[str]]:
    """Return the values that contribute to this kind's data hash.

    The parent's hash values come first, followed by one entry per CSV
    setting value. Settings default to an empty ``CsvSettings()`` when none
    were configured.

    Returns:
        A list whose entries are all ``str`` or ``None``. Setting values of
        any other type (e.g. booleans, lists or dicts) are converted with
        ``str()`` so that the result matches the declared return type.
    """
    csv_setting_values = (self.csv_settings or CsvSettings()).dict().values()
    return [
        *super().data_hash_values,
        # Each setting is emitted exactly once, already normalised to str/None.
        *(v if isinstance(v, (str, type(None))) else str(v) for v in csv_setting_values),
    ]

@property
Expand Down
41 changes: 40 additions & 1 deletion sqlmesh/core/model/seed.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
import typing as t
import zlib
from io import StringIO
Expand All @@ -8,12 +9,18 @@
import pandas as pd
from sqlglot import exp
from sqlglot.dialects.dialect import UNESCAPED_SEQUENCES
from sqlglot.helper import seq_get
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers

from sqlmesh.core.model.common import parse_bool
from sqlmesh.utils.pandas import columns_to_types_from_df
from sqlmesh.utils.pydantic import PydanticModel, field_validator

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# A list of NA markers; each item is an int, a str, a bool or None.
NaHashables = t.List[t.Union[int, str, bool, t.Literal[None]]]
# Either a single list of NA markers, or a dict mapping a string key
# (presumably a column name) to its own list of markers.
NaValues = t.Union[NaHashables, t.Dict[str, NaHashables]]


class CsvSettings(PydanticModel):
"""Settings for CSV seeds."""
Expand All @@ -25,8 +32,10 @@ class CsvSettings(PydanticModel):
skipinitialspace: t.Optional[bool] = None
lineterminator: t.Optional[str] = None
encoding: t.Optional[str] = None
na_values: t.Optional[NaValues] = None
keep_default_na: t.Optional[bool] = None

@field_validator("doublequote", "skipinitialspace", mode="before")
@field_validator("doublequote", "skipinitialspace", "keep_default_na", mode="before")
@classmethod
def _bool_validator(cls, v: t.Any) -> t.Optional[bool]:
if v is None:
Expand All @@ -46,6 +55,36 @@ def _str_validator(cls, v: t.Any) -> t.Optional[str]:
v = v.this
return UNESCAPED_SEQUENCES.get(v, v)

@field_validator("na_values", mode="before")
@classmethod
def _na_values_validator(cls, v: t.Any) -> t.Optional[NaValues]:
    """Coerce a parsed ``na_values`` expression into plain Python values.

    Args:
        v: The raw field value. ``None`` and anything that is not a sqlglot
            expression are returned unchanged.

    Returns:
        * a dict mapping each left-hand name to a list of Python values when
          the first entry is an ``=`` expression (``col = (v1, v2, ...)``);
        * otherwise a list with the Python value of each entry;
        * ``None`` when a value cannot be converted (a warning is logged).
    """
    if v is None or not isinstance(v, exp.Expression):
        return v

    try:
        # A parenthesised value or any non-collection expression is treated
        # as a collection holding that single entry.
        if isinstance(v, exp.Paren) or not isinstance(v, (exp.Tuple, exp.Array)):
            v = exp.Tuple(expressions=[v.unnest()])

        expressions = v.expressions
        if isinstance(seq_get(expressions, 0), (exp.PropertyEQ, exp.EQ)):
            return {
                e.left.name: [
                    rhs_val.to_py()
                    for rhs_val in (
                        # Only real collections expose their members through
                        # `.expressions`; a bare or parenthesised scalar such
                        # as `col = 1` or `col = ('foo')` is a single value.
                        # Reading `.expressions` on a scalar would yield an
                        # empty list and silently drop the value.
                        e.right.expressions
                        if isinstance(e.right, (exp.Tuple, exp.Array))
                        else [e.right.unnest()]
                    )
                ]
                for e in expressions
            }

        return [e.to_py() for e in expressions]
    except ValueError as e:
        logger.warning(f"Failed to coerce na_values '{v}', proceeding with defaults. {str(e)}")

    return None


class CsvSeedReader:
def __init__(self, content: str, dialect: str, settings: CsvSettings):
Expand Down
66 changes: 65 additions & 1 deletion tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,8 @@ def test_seed_csv_settings():
csv_settings (
quotechar = '''',
escapechar = '\\',
keep_default_na = false,
na_values = (id = [1, '2', false, null], alias = ('foo'))
),
),
columns (
Expand All @@ -910,7 +912,39 @@ def test_seed_csv_settings():
model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql"))

assert isinstance(model.kind, SeedKind)
assert model.kind.csv_settings == CsvSettings(quotechar="'", escapechar="\\")
assert model.kind.csv_settings == CsvSettings(
quotechar="'",
escapechar="\\",
na_values={"id": [1, "2", False, None], "alias": ["foo"]},
keep_default_na=False,
)
assert model.kind.data_hash_values == [
"SEED",
"'",
"\\",
"{'id': [1, '2', False, None], 'alias': ['foo']}",
"False",
]

expressions = d.parse(
"""
MODEL (
name db.seed,
kind SEED (
path '../seeds/waiter_names.csv',
csv_settings (
na_values = ('#N/A', 'other')
),
),
);
"""
)

model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql"))

assert isinstance(model.kind, SeedKind)
assert model.kind.csv_settings == CsvSettings(na_values=["#N/A", "other"])
assert model.kind.data_hash_values == ["SEED", "['#N/A', 'other']"]


def test_seed_marker_substitution():
Expand Down Expand Up @@ -7755,3 +7789,33 @@ def get_current_date(evaluator):
FROM "discount_promotion_dates" AS "discount_promotion_dates"
""",
)


def test_seed_dont_coerce_na_into_null(tmp_path):
    """A seed value of "NA" stays a string once NaN coercion is switched off."""
    csv_file = (tmp_path / "model.csv").absolute()
    # One header row followed by a single "NA" value.
    csv_file.write_text("code\nNA", encoding="utf-8")

    model_sql = f"""
MODEL (
name db.seed,
kind SEED (
path '{str(csv_file)}',
csv_settings (
-- override NaN handling, such that no value can be coerced into NaN
keep_default_na = false,
na_values = (),
),
),
);
"""
    parsed = d.parse(model_sql)
    sql_path = Path("./examples/sushi/models/test_model.sql")
    model = load_sql_based_model(parsed, path=sql_path)

    assert isinstance(model.kind, SeedKind)
    assert model.seed is not None
    assert len(model.seed.content) > 0

    first_frame = next(model.render(context=None))
    assert first_frame.to_dict() == {"code": {0: "NA"}}