Change data_schemas.to_schema to directly compile spark schema metadata into a schema

Signed-off-by: Avi Shinnar <shinnar@us.ibm.com>
shinnar committed Feb 3, 2023
1 parent 660bbf9 commit 104e19d
Showing 2 changed files with 102 additions and 4 deletions.
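In effect, to_schema on a Spark DataFrame no longer materializes the data via toPandas(); it compiles the JSON schema straight from the DataFrame's StructType metadata. A minimal sketch of the resulting behavior, assuming pyspark is installed; pandas2spark is lale's helper for wrapping a pandas frame as a SparkDataFrameWithIndex, and the exact nullability flags depend on how Spark infers the schema:

import pandas as pd

from lale.datasets import pandas2spark
from lale.datasets.data_schemas import to_schema

# Spark typically marks columns created from pandas as nullable, so each
# column schema is wrapped in an anyOf that also admits None.
pdf = pd.DataFrame({"name": ["widget", "gadget"], "qty": [3, 7]})
sdf = pandas2spark(pdf)

print(to_schema(sdf))
# Roughly:
# {"type": "array",
#  "items": {"type": "array", "description": "rows",
#            "minItems": 2, "maxItems": 2,
#            "items": [{"anyOf": [{"type": "string"}, {"enum": [None]}],
#                       "description": "name"},
#                      {"anyOf": [{"type": "integer"}, {"enum": [None]}],
#                       "description": "qty"}]}}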
91 changes: 90 additions & 1 deletion lale/datasets/data_schemas.py
@@ -164,6 +164,10 @@ def index_names(self) -> List[str]:
    def toPandas(self, *args, **kwargs) -> DataFrame:
        raise ValueError("pyspark is not installed")  # type: ignore

    @property
    def schema(self) -> Any:
        raise ValueError("pyspark is not installed")  # type: ignore


def add_schema(obj, schema=None, raise_on_failure=False, recalc=False) -> Any:
    from lale.settings import disable_data_schema_validation
@@ -611,6 +615,91 @@ def liac_arff_to_schema(larff) -> JSON_TYPE:
    return result


def make_optional_schema(schema: JSON_TYPE) -> JSON_TYPE:
    return {"anyOf": [schema, {"enum": [None]}]}


def _spark_df_to_schema(df) -> JSON_TYPE:
    assert spark_installed, """Your Python environment does not have spark installed. You can install it with
    pip install pyspark
"""
    assert isinstance(df, SparkDataFrameWithIndex)

    import pyspark.sql.types as stypes
    from pyspark.sql.types import StructField, StructType

    def maybe_make_optional(schema: JSON_TYPE, is_option: bool) -> JSON_TYPE:
        if is_option:
            return make_optional_schema(schema)
        return schema

    def spark_datatype_to_json_schema(dtype: stypes.DataType) -> JSON_TYPE:
        if isinstance(dtype, stypes.ArrayType):
            return {
                "type": "array",
                "items": maybe_make_optional(
                    spark_datatype_to_json_schema(dtype.elementType), dtype.containsNull
                ),
            }
        if isinstance(dtype, stypes.BooleanType):
            return {"type": "boolean"}
        if isinstance(dtype, stypes.DoubleType):
            return {"type": "number"}
        if isinstance(dtype, stypes.FloatType):
            return {"type": "number"}
        if isinstance(dtype, stypes.IntegerType):
            return {"type": "integer"}
        if isinstance(dtype, stypes.LongType):
            return {"type": "integer"}
        if isinstance(dtype, stypes.ShortType):
            return {"type": "integer"}
        if isinstance(dtype, stypes.NullType):
            return {"enum": [None]}
        if isinstance(dtype, stypes.StringType):
            return {"type": "string"}

        return {}

    def spark_struct_field_to_json_schema(f: StructField) -> JSON_TYPE:
        type_schema = spark_datatype_to_json_schema(f.dataType)
        result = maybe_make_optional(type_schema, f.nullable)

        if f.name is not None:
            result["description"] = f.name
        return result

    def spark_struct_to_json_schema(
        s: StructType, index_names, table_name: Optional[str] = None
    ) -> JSON_TYPE:
        items = [
            spark_struct_field_to_json_schema(f) for f in s if f.name not in index_names
        ]
        num_items = len(items)
        result = {
            "type": "array",
            "items": {
                "type": "array",
                "description": "rows",
                "minItems": num_items,
                "maxItems": num_items,
                "items": items,
            },
        }

        if table_name is not None:
            result["description"] = table_name

        return result

    return spark_struct_to_json_schema(df.schema, df.index_names, get_table_name(df))


def spark_df_to_schema(df) -> JSON_TYPE:
    result = _spark_df_to_schema(df)
    lale.type_checking.validate_is_schema(result)
    return result


def _to_schema(obj) -> JSON_TYPE:
    result = None
    if obj is None:
@@ -630,7 +719,7 @@ def _to_schema(obj) -> JSON_TYPE:
    elif isinstance(obj, list):
        result = _list_tensor_to_schema(obj)
    elif _is_spark_df(obj):
-        result = _dataframe_to_schema(obj.toPandas())
+        result = _spark_df_to_schema(obj)
    elif lale.type_checking.is_schema(obj):
        result = obj
    # Does not need to validate again the schema
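The nullable handling above all funnels through make_optional_schema, which wraps a schema in an anyOf that also accepts None; a quick sketch of what it produces:

from lale.datasets.data_schemas import make_optional_schema

# A nullable number compiles to a choice between a number and None.
assert make_optional_schema({"type": "number"}) == {
    "anyOf": [{"type": "number"}, {"enum": [None]}]
}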
15 changes: 12 additions & 3 deletions test/test_relational.py
@@ -39,7 +39,12 @@
from test import EnableSchemaValidation  # pylint:disable=wrong-import-order

from lale.datasets import pandas2spark
-from lale.datasets.data_schemas import add_table_name, get_index_name, get_table_name
+from lale.datasets.data_schemas import (
+    add_table_name,
+    get_index_name,
+    get_table_name,
+    make_optional_schema,
+)
from lale.datasets.multitable import multitable_train_test_split
from lale.datasets.multitable.fetch_datasets import fetch_go_sales_dataset
from lale.expressions import (  # pylint:disable=redefined-builtin
@@ -2036,7 +2041,9 @@ def expr(X):

            if s is None:
                ret["unknown_" + c] = it[c]
-            elif type_checking.is_subschema(s, {"type": "number"}):
+            elif type_checking.is_subschema(
+                s, make_optional_schema({"type": "number"})
+            ):
                ret["num_" + c] = it[c]
                ret["shifted_" + c] = it[c] + 5
            else:
@@ -2107,7 +2114,9 @@ def test_dynamic_trainable(self):
    def test_project(self):
        from lale.lib.lale import Project

-        pipeline = Scan(table=it.go_products) >> Project(columns={"type": "number"})
+        pipeline = Scan(table=it.go_products) >> Project(
+            columns=make_optional_schema({"type": "number"})
+        )
        for _tgt, datasets in self.tgt2datasets.items():
            datasets = datasets["go_sales"]
            result = pipeline.fit(datasets).transform(datasets)
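The test updates follow from the new compiler: a nullable numeric Spark column now compiles to the anyOf-with-None form rather than a bare {"type": "number"}, so subschema checks and Project column schemas must use the optional form. A hedged sketch of the check, assuming the is_subschema(sub, super) semantics used in the test:

from lale import type_checking
from lale.datasets.data_schemas import make_optional_schema

nullable_number = make_optional_schema({"type": "number"})

# The optional form admits None, so it is not a subschema of a bare number...
assert not type_checking.is_subschema(nullable_number, {"type": "number"})
# ...but it does match the optional expected schema the tests now use.
assert type_checking.is_subschema(
    nullable_number, make_optional_schema({"type": "number"})
)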
