diff --git a/cloudbuild.yaml b/cloudbuild.yaml index acb9367a0..fe18a65ab 100644 --- a/cloudbuild.yaml +++ b/cloudbuild.yaml @@ -41,3 +41,10 @@ steps: - 'NOX_SESSION=integration_bigquery' - 'PROJECT_ID=pso-kokoro-resources' waitFor: ['-'] +- id: integration_spanner + name: 'gcr.io/pso-kokoro-resources/python-multi' + args: ['bash', './ci/build.sh'] + env: + - 'NOX_SESSION=integration_spanner' + - 'PROJECT_ID=pso-kokoro-resources' + waitFor: ['-'] diff --git a/noxfile.py b/noxfile.py index e8dd0ec1c..a3fc3687c 100644 --- a/noxfile.py +++ b/noxfile.py @@ -137,3 +137,19 @@ def integration_bigquery(session): raise Exception("Expected Env Var: %s" % env_var) session.run("pytest", test_path, *session.posargs) + + +@nox.session(python=PYTHON_VERSIONS, venv_backend="venv") +def integration_spanner(session): + """Run Spanner integration tests. + Ensure Spanner validation is running as expected. + """ + _setup_session_requirements(session, extra_packages=[]) + + expected_env_vars = ["PROJECT_ID"] + for env_var in expected_env_vars: + if not os.environ.get(env_var, ""): + raise Exception("Expected Env Var: %s" % env_var) + + # TODO: Add tests for DVT data sources. See integration_bigquery. + session.run("pytest", "third_party/ibis/ibis_cloud_spanner/tests", *session.posargs) diff --git a/requirements.txt b/requirements.txt index b24beb0ea..aa6804968 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,5 +16,6 @@ pyarrow==3.0.0 pydata-google-auth==1.1.0 google-cloud-bigquery==2.7.0 google-cloud-bigquery-storage==2.2.1 +google-cloud-spanner==3.1.0 setuptools>=34.0.0 jellyfish==0.8.2 diff --git a/third_party/ibis/ibis_cloud_spanner/__init__.py b/third_party/ibis/ibis_cloud_spanner/__init__.py new file mode 100644 index 000000000..139597f9c --- /dev/null +++ b/third_party/ibis/ibis_cloud_spanner/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/third_party/ibis/ibis_cloud_spanner/api.py b/third_party/ibis/ibis_cloud_spanner/api.py new file mode 100644 index 000000000..d035b5320 --- /dev/null +++ b/third_party/ibis/ibis_cloud_spanner/api.py @@ -0,0 +1,80 @@ +# Copyright 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""CloudSpanner public API.""" + + +from third_party.ibis.ibis_cloud_spanner.client import CloudSpannerClient +from third_party.ibis.ibis_cloud_spanner.compiler import dialect + +import google.cloud.spanner # noqa: F401, fail early if spanner is missing +import ibis.common.exceptions as com + +__all__ = ("compile", "connect", "verify") + + +def compile(expr, params=None): + """Compile an expression for Cloud Spanner. 
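+ + Parameters + ---------- + expr : ibis.expr.types.Expr + The expression to compile. + params : Mapping[ibis.expr.types.ScalarParameter, object], optional + Values for any scalar parameters used by the expression.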
+ + Returns + ------- + compiled : str + + See Also + -------- + ibis.expr.types.Expr.compile + + """ + from third_party.ibis.ibis_cloud_spanner.compiler import to_sql + + return to_sql(expr, dialect.make_context(params=params)) + + +def verify(expr, params=None): + """Check if an expression can be compiled using Cloud Spanner.""" + try: + compile(expr, params=params) + return True + except com.TranslationError: + return False + + +def connect( + instance_id, + database_id, + project_id=None, +) -> CloudSpannerClient: + """Create a CloudSpannerClient for use with Ibis. + + Parameters + ---------- + instance_id : str + A Cloud Spanner Instance id. + database_id : str + A database id inside of the Cloud Spanner Instance. + project_id : str (Optional) + The ID of the project which owns the instances, tables and data. + + Returns + ------- + CloudSpannerClient + + """ + + return CloudSpannerClient( + instance_id=instance_id, + database_id=database_id, + project_id=project_id, + ) diff --git a/third_party/ibis/ibis_cloud_spanner/client.py b/third_party/ibis/ibis_cloud_spanner/client.py new file mode 100644 index 000000000..1f1692e69 --- /dev/null +++ b/third_party/ibis/ibis_cloud_spanner/client.py @@ -0,0 +1,387 @@ +# Copyright 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Cloud Spanner ibis client implementation.""" + +import datetime +from typing import Optional, Tuple + +import google.cloud.spanner as cs +from google.cloud import spanner +import pandas as pd +import re +from multipledispatch import Dispatcher + +import ibis +import ibis.common.exceptions as com +import ibis.expr.datatypes as dt +import ibis.expr.lineage as lin +import ibis.expr.operations as ops +import ibis.expr.types as ir +from third_party.ibis.ibis_cloud_spanner import compiler as comp +from third_party.ibis.ibis_cloud_spanner.datatypes import ( + ibis_type_to_cloud_spanner_type, +) +from ibis.client import Database, Query, SQLClient + +from third_party.ibis.ibis_cloud_spanner import table + +from google.cloud.spanner_v1 import TypeCode +from third_party.ibis.ibis_cloud_spanner.to_pandas import pandas_df + + +def parse_instance_and_dataset( + instance: str, dataset: Optional[str] = None +) -> Tuple[str, Optional[str]]: + + data_instance = instance + dataset = dataset + + return data_instance, dataset + + +class CloudSpannerTable(ops.DatabaseTable): + pass + + +def _find_scalar_parameter(expr): + """Find all :class:`~ibis.expr.types.ScalarParameter` instances. + + Parameters + ---------- + expr : ibis.expr.types.Expr + + Returns + ------- + Tuple[bool, object] + The operation and the parent expression's resolved name.
+ + """ + op = expr.op() + + if isinstance(op, ops.ScalarParameter): + result = op, expr.get_name() + else: + result = None + return lin.proceed, result + + +def convert_to_cs_type(dtype): + if dtype == "FLOAT64": + return spanner.param_types.FLOAT64 + elif dtype == "INT64": + return spanner.param_types.INT64 + elif dtype == "DATE": + return spanner.param_types.DATE + elif dtype == "TIMESTAMP": + return spanner.param_types.TIMESTAMP + elif dtype == "NUMERIC": + return spanner.param_types.NUMERIC + elif dtype == "INT64": + return spanner.param_types.INT64 + else: + return spanner.param_types.STRING + + +cloud_spanner_param = Dispatcher("cloud_spanner_param") + + +@cloud_spanner_param.register(ir.ArrayValue, list) +def cs_param_array(param, value): + param_type = param.type() + assert isinstance(param_type, dt.Array), str(param_type) + + try: + spanner_type = ibis_type_to_cloud_spanner_type(param_type.value_type) + except NotImplementedError: + raise com.UnsupportedBackendType(param_type) + else: + if isinstance(param_type.value_type, dt.Struct): + raise TypeError("ARRAY> is not supported in Cloud Spanner") + elif isinstance(param_type.value_type, dt.Array): + raise TypeError("ARRAY> is not supported in Cloud Spanner") + else: + query_value = value + + params = ({param.get_name(): query_value},) + param_types = {param.get_name(): convert_to_cs_type(spanner_type)} + final_dict = {"params": params, "param_types": param_types} + + return final_dict + + +@cloud_spanner_param.register( + ir.TimestampScalar, (str, datetime.datetime, datetime.date) +) +def cs_param_timestamp(param, value): + assert isinstance(param.type(), dt.Timestamp), str(param.type()) + + timestamp_value = pd.Timestamp(value, tz="UTC").to_pydatetime() + params = ({param.get_name(): timestamp_value},) + param_types = {param.get_name(): spanner.param_types.TIMESTAMP} + final_dict = {"params": params[0], "param_types": param_types} + return final_dict + + +@cloud_spanner_param.register(ir.StringScalar, str) +def cs_param_string(param, value): + params = ({param.get_name(): value},) + param_types = {param.get_name(): spanner.param_types.STRING} + final_dict = {"params": params[0], "param_types": param_types} + return final_dict + + +@cloud_spanner_param.register(ir.IntegerScalar, int) +def cs_param_integer(param, value): + params = ({param.get_name(): value},) + param_types = {param.get_name(): spanner.param_types.INT64} + final_dict = {"params": params[0], "param_types": param_types} + return final_dict + + +@cloud_spanner_param.register(ir.FloatingScalar, float) +def cs_param_double(param, value): + params = ({param.get_name(): value},) + param_types = {param.get_name(): spanner.param_types.FLOAT64} + final_dict = {"params": params[0], "param_types": param_types} + return final_dict + + +@cloud_spanner_param.register(ir.BooleanScalar, bool) +def cs_param_boolean(param, value): + params = ({param.get_name(): value},) + param_types = {param.get_name(): spanner.param_types.BOOL} + final_dict = {"params": params[0], "param_types": param_types} + return final_dict + + +@cloud_spanner_param.register(ir.DateScalar, str) +def cs_param_date_string(param, value): + params = ({param.get_name(): pd.Timestamp(value).to_pydatetime().date()},) + param_types = {param.get_name(): spanner.param_types.DATE} + final_dict = {"params": params[0], "param_types": param_types} + return final_dict + + +@cloud_spanner_param.register(ir.DateScalar, datetime.datetime) +def cs_param_date_datetime(param, value): + params = ({param.get_name(): value.date()},) + 
param_types = {param.get_name(): spanner.param_types.DATE} + final_dict = {"params": params[0], "param_types": param_types} + return final_dict + + +@cloud_spanner_param.register(ir.DateScalar, datetime.date) +def cs_param_date(param, value): + params = ({param.get_name(): value},) + param_types = {param.get_name(): spanner.param_types.DATE} + final_dict = {"params": params[0], "param_types": param_types} + return final_dict + + +class CloudSpannerQuery(Query): + def __init__(self, client, ddl, query_parameters=None): + super().__init__(client, ddl) + + # self.expr comes from the parent class + query_parameter_names = dict(lin.traverse(_find_scalar_parameter, self.expr)) + + self.query_parameters = [ + cloud_spanner_param( + param.to_expr().name(query_parameter_names[param]), value + ) + for param, value in (query_parameters or {}).items() + ] + + def execute(self): + dataframe_output = self.client._execute( + self.compiled_sql, results=True, query_parameters=self.query_parameters + ) + + return dataframe_output + + +class CloudSpannerDatabase(Database): + """A Cloud Spanner dataset.""" + + +class CloudSpannerClient(SQLClient): + """An ibis CloudSpanner client implementation.""" + + query_class = CloudSpannerQuery + database_class = CloudSpannerDatabase + table_class = CloudSpannerTable + + def __init__(self, instance_id, database_id, project_id=None, credentials=None): + """Construct a CloudSpannerClient. + + Parameters + ---------- + instance_id : str + An instance id. + database_id : str + A database id inside of the Cloud Spanner Instance. + project_id : str (Optional) + The ID of the project which owns the instances, tables and data. + + + """ + self.spanner_client = spanner.Client(project=project_id) + self.instance = self.spanner_client.instance(instance_id) + self.database_name = self.instance.database(database_id) + ( + self.data_instance, + self.dataset, + ) = parse_instance_and_dataset(instance_id, database_id) + self.client = cs.Client() + + def _parse_instance_and_dataset(self, dataset): + if not dataset and not self.dataset: + raise ValueError("Unable to determine Cloud Spanner dataset.") + instance, dataset = parse_instance_and_dataset( + self.data_instance, (dataset or self.dataset) + ) + + return instance, dataset + + def get_data_using_query(self, query, results=False): + return self._execute(query, results=results) + + @property + def instance_id(self): + return self.data_instance + + @property + def dataset_id(self): + return self.dataset + + def table(self, name, database=None): + t = super().table(name, database=database) + return t + + def _build_ast(self, expr, context): + result = comp.build_ast(expr, context) + return result + + def _get_query(self, dml, **kwargs): + return self.query_class(self, dml, query_parameters=dml.context.params) + + def _fully_qualified_name(self, name, database): + return name + + def _get_table_schema(self, qualified_name): + table = qualified_name + dataset = self.dataset_id + assert dataset is not None, "dataset is None" + return self.get_schema(table, database=dataset) + + @property + def current_database(self): + return self.database(self.dataset) + + def list_databases(self, like=None): + databases = self.instance.list_databases() + list_db = [] + for row in databases: + list_db.append((row.name).rsplit("/", 1)[1]) + return list_db + + def list_tables(self, like=None, database=None): + # TODO: use list_tables from the Database class when available.
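+ # Table names are read from INFORMATION_SCHEMA.TABLES; only tables + # whose SPANNER_STATE is 'COMMITTED' are returned.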
+ + if database is None: + db_value = self.dataset_id + else: + db_value = database + db = self.instance.database(db_value) + tables = [] + with db.snapshot() as snapshot: + query = "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE SPANNER_STATE = 'COMMITTED'" + results = snapshot.execute_sql(query) + for row in results: + tables.append(row[0]) + + if like: + tables = [ + table_name + for table_name in tables + if re.match(like, table_name) is not None + ] + return tables + + def exists_table(self, name, database=None): + + if database is None: + database = self.dataset_id + + db_value = self.instance.database(database) + result = table.Table(name, db_value).exists() + return result + + def get_schema(self, table_id, database=None): + if database is None: + database = self.dataset_id + db_value = self.instance.database(database) + table_schema = table.Table(table_id, db_value).schema + + t_schema = [] + for item in table_schema: + field_name = item.name + + if item.type_.code == TypeCode.ARRAY: + field_type = "array<{}>".format(item.type_.array_element_type.code.name) + elif item.type_.code == TypeCode.BYTES: + field_type = "binary" + elif item.type_.code == TypeCode.NUMERIC: + field_type = "decimal" + else: + field_type = item.type_.code.name + + final_item = (field_name, field_type) + + t_schema.append(final_item) + + return ibis.schema(t_schema) + + def _execute(self, stmt, results=True, query_parameters=None): + + spanner_client = spanner.Client() + instance_id = self.instance_id + instance = spanner_client.instance(instance_id) + database_id = self.dataset_id + database_1 = instance.database(database_id) + + with database_1.snapshot() as snapshot: + data_qry = pandas_df.to_pandas(snapshot, stmt, query_parameters) + return data_qry + + def database(self, name=None): + if name is None and self.dataset is None: + raise ValueError( + "Unable to determine Cloud Spanner dataset. Call " + "client.database('my_dataset') or set_database('my_dataset') " + "to assign your client a dataset." + ) + return self.database_class(name or self.dataset, self) + + def set_database(self, name): + self.data_instance, self.dataset = self._parse_instance_and_dataset(name) + + def dataset(self, database): + spanner_client = spanner.Client() + instance = spanner_client.instance(self.data_instance) + return instance.database(database) + + def exists_database(self, name): + return self.instance.database(name).exists() + + diff --git a/third_party/ibis/ibis_cloud_spanner/compiler.py b/third_party/ibis/ibis_cloud_spanner/compiler.py new file mode 100644 index 000000000..2f0044a6d --- /dev/null +++ b/third_party/ibis/ibis_cloud_spanner/compiler.py @@ -0,0 +1,57 @@ +# Copyright 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
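+ +"""Cloud Spanner SQL compiler, implemented on top of the ibis BigQuery compiler."""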
+ + import ibis.expr.operations as ops + from ibis.backends.bigquery import compiler as bigquery_compiler + + + def build_ast(expr, context): + builder = bigquery_compiler.BigQueryQueryBuilder(expr, context=context) + return builder.get_result() + + + def to_sql(expr, context): + query_ast = build_ast(expr, context) + compiled = query_ast.compile() + return compiled + + + def _array_index(translator, expr): + # OFFSET raises an error for out-of-bounds indexes; SAFE_OFFSET would return NULL instead + return "{}[OFFSET({})]".format(*map(translator.translate, expr.op().args)) + + + def _translate_pattern(translator, pattern): + # add 'r' to string literals to indicate to Cloud Spanner this is a raw string + return "r" * isinstance(pattern.op(), ops.Literal) + translator.translate(pattern) + + + def _regex_extract(translator, expr): + arg, pattern, index = expr.op().args + regex = _translate_pattern(translator, pattern) + result = "REGEXP_EXTRACT({}, {})".format(translator.translate(arg), regex) + return result + + + _operation_registry = bigquery_compiler._operation_registry.copy() + _operation_registry.update( + {ops.RegexExtract: _regex_extract, ops.ArrayIndex: _array_index} + ) + + + compiles = bigquery_compiler.BigQueryExprTranslator.compiles + rewrites = bigquery_compiler.BigQueryExprTranslator.rewrites + + + dialect = bigquery_compiler.BigQueryDialect diff --git a/third_party/ibis/ibis_cloud_spanner/datatypes.py b/third_party/ibis/ibis_cloud_spanner/datatypes.py new file mode 100644 index 000000000..4f1c2e093 --- /dev/null +++ b/third_party/ibis/ibis_cloud_spanner/datatypes.py @@ -0,0 +1,89 @@ +# Copyright 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
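+ +"""Translation from ibis datatypes to Cloud Spanner SQL type names."""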
+ +from multipledispatch import Dispatcher + +import ibis.expr.datatypes as dt + + +class TypeTranslationContext: + """A tag class to allow alteration of the way a particular type is + translated.""" + + __slots__ = () + + +ibis_type_to_cloud_spanner_type = Dispatcher("ibis_type_to_cloud_spanner_type") + + +@ibis_type_to_cloud_spanner_type.register(str) +def trans_string_default(datatype): + return ibis_type_to_cloud_spanner_type(dt.dtype(datatype)) + + +@ibis_type_to_cloud_spanner_type.register(dt.DataType) +def trans_default(t): + return ibis_type_to_cloud_spanner_type(t, TypeTranslationContext()) + + +@ibis_type_to_cloud_spanner_type.register(str, TypeTranslationContext) +def trans_string_context(datatype, context): + return ibis_type_to_cloud_spanner_type(dt.dtype(datatype), context) + + +@ibis_type_to_cloud_spanner_type.register(dt.Floating, TypeTranslationContext) +def trans_float64(t, context): + return "FLOAT64" + + +@ibis_type_to_cloud_spanner_type.register(dt.Integer, TypeTranslationContext) +def trans_integer(t, context): + return "INT64" + + +@ibis_type_to_cloud_spanner_type.register(dt.Array, TypeTranslationContext) +def trans_array(t, context): + return "ARRAY<{}>".format(ibis_type_to_cloud_spanner_type(t.value_type, context)) + + +@ibis_type_to_cloud_spanner_type.register(dt.Date, TypeTranslationContext) +def trans_date(t, context): + return "DATE" + + +@ibis_type_to_cloud_spanner_type.register(dt.Timestamp, TypeTranslationContext) +def trans_timestamp(t, context): + return "TIMESTAMP" + + +@ibis_type_to_cloud_spanner_type.register(dt.DataType, TypeTranslationContext) +def trans_type(t, context): + return str(t).upper() + + +@ibis_type_to_cloud_spanner_type.register(dt.UInt64, TypeTranslationContext) +def trans_lossy_integer(t, context): + raise TypeError( + "Conversion from uint64 to Cloud Spanner integer type (int64) is lossy" + ) + + +@ibis_type_to_cloud_spanner_type.register(dt.Decimal, TypeTranslationContext) +def trans_numeric(t, context): + if (t.precision, t.scale) != (38, 9): + raise TypeError( + "Cloud Spanner only supports decimal types with precision of 38 and " + "scale of 9" + ) + return "NUMERIC" diff --git a/third_party/ibis/ibis_cloud_spanner/table.py b/third_party/ibis/ibis_cloud_spanner/table.py new file mode 100644 index 000000000..68d8649eb --- /dev/null +++ b/third_party/ibis/ibis_cloud_spanner/table.py @@ -0,0 +1,113 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""User friendly container for Cloud Spanner Table.""" + +from google.cloud.exceptions import NotFound + +from google.cloud.spanner_v1 import Type +from google.cloud.spanner_v1 import TypeCode + + +_EXISTS_TEMPLATE = """ +SELECT EXISTS( + SELECT TABLE_NAME + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_NAME = @table_id +) +""" +_GET_SCHEMA_TEMPLATE = "SELECT * FROM {} LIMIT 0" + + +class Table(object): + """Representation of a Cloud Spanner Table. + :type table_id: str + :param table_id: The ID of the table. 
+ :type database: :class:`~google.cloud.spanner_v1.database.Database` + :param database: The database that owns the table. + """ + + def __init__(self, table_id, database): + self._table_id = table_id + self._database = database + + # Calculated properties. + self._schema = None + + @property + def table_id(self): + """The ID of the table used in SQL. + :rtype: str + :returns: The table ID. + """ + return self._table_id + + def exists(self): + """Test whether this table exists. + :rtype: bool + :returns: True if the table exists, else false. + """ + with self._database.snapshot() as snapshot: + return self._exists(snapshot) + + def _exists(self, snapshot): + """Query to check that the table exists. + :type snapshot: :class:`~google.cloud.spanner_v1.snapshot.Snapshot` + :param snapshot: snapshot to use for database queries + :rtype: bool + :returns: True if the table exists, else false. + """ + results = snapshot.execute_sql( + _EXISTS_TEMPLATE, + params={"table_id": self.table_id}, + param_types={"table_id": Type(code=TypeCode.STRING)}, + ) + return next(iter(results))[0] + + @property + def schema(self): + """The schema of this table. + :rtype: list of :class:`~google.cloud.spanner_v1.types.StructType.Field` + :returns: The table schema. + """ + if self._schema is None: + with self._database.snapshot() as snapshot: + self._schema = self._get_schema(snapshot) + return self._schema + + def _get_schema(self, snapshot): + """Get the schema of this table. + :type snapshot: :class:`~google.cloud.spanner_v1.snapshot.Snapshot` + :param snapshot: snapshot to use for database queries + :rtype: list of :class:`~google.cloud.spanner_v1.types.StructType.Field` + :returns: The table schema. + """ + query = _GET_SCHEMA_TEMPLATE.format(self.table_id) + results = snapshot.execute_sql(query) + # Start iterating to force the schema to download. + try: + next(iter(results)) + except StopIteration: + pass + return list(results.fields) + + def reload(self): + """Reload this table. + Refresh any configured schema into :attr:`schema`. + :raises NotFound: if the table does not exist + """ + with self._database.snapshot() as snapshot: + if not self._exists(snapshot): + raise NotFound("table '{}' does not exist".format(self.table_id)) + self._schema = self._get_schema(snapshot) diff --git a/third_party/ibis/ibis_cloud_spanner/tests/__init__.py b/third_party/ibis/ibis_cloud_spanner/tests/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/third_party/ibis/ibis_cloud_spanner/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/third_party/ibis/ibis_cloud_spanner/tests/conftest.py b/third_party/ibis/ibis_cloud_spanner/tests/conftest.py new file mode 100644 index 000000000..052e6f55e --- /dev/null +++ b/third_party/ibis/ibis_cloud_spanner/tests/conftest.py @@ -0,0 +1,146 @@ +# Copyright 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
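+ +"""Pytest fixtures that provision a temporary Cloud Spanner instance and test databases for the integration tests, and tear them down afterwards."""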
+ +import datetime +import random +import pathlib + +from google.cloud import spanner_v1 +import pytest +from third_party.ibis.ibis_cloud_spanner.api import connect + + +DATA_DIR = pathlib.Path(__file__).parent + +RANDOM_MAX = 0xFFFFFFFF +INSTANCE_ID_TEMPLATE = "data-validation-tool-{timestamp}" +DATABASE_ID_TEMPLATE = "db_{timestamp}_{randint}" + + +def load_sql(filename): + lines = [] + with open(DATA_DIR / filename) as sql_file: + for line in sql_file: + if line.startswith("--"): + continue + lines.append(line) + return [ + statement.strip() + for statement in "".join(lines).split(";") + if statement.strip() + ] + + +def insert_rows(transaction): + dml_statements = load_sql("dml.sql") + for dml in dml_statements: + transaction.execute_update(dml) + + +def insert_rows2(transaction): + dml_statements = load_sql("dml2.sql") + for dml in dml_statements: + transaction.execute_update(dml) + + +@pytest.fixture(scope="session") +def spanner_client(): + return spanner_v1.Client() + + +@pytest.fixture(scope="session") +def instance_id(spanner_client): + config_name = "{}/instanceConfigs/regional-us-central1".format( + spanner_client.project_name + ) + instance_id = INSTANCE_ID_TEMPLATE.format( + timestamp=datetime.datetime.utcnow().strftime("%Y%m%d-%H%M%S") + ) + instance = spanner_client.instance( + instance_id, + configuration_name=config_name, + display_name="Test for Data Validation Tool", + node_count=1, + ) + operation = instance.create() + operation.result() + yield instance_id + instance.delete() + + +@pytest.fixture(scope="session") +def database_id(spanner_client, instance_id): + database_id = DATABASE_ID_TEMPLATE.format( + timestamp=datetime.datetime.utcnow().strftime("%Y%m%d_%H%M%S"), + randint=random.randint(0, RANDOM_MAX), + ) + ddl_statements = load_sql("ddl.sql") + instance = spanner_client.instance(instance_id) + database = instance.database(database_id, ddl_statements=ddl_statements) + operation = database.create() + operation.result() + database.run_in_transaction(insert_rows) + yield database_id + database.drop() + + +@pytest.fixture(scope="session") +def database_id2(spanner_client, instance_id): + database_id = DATABASE_ID_TEMPLATE.format( + timestamp=datetime.datetime.utcnow().strftime("%Y%m%d_%H%M%S"), + randint=random.randint(0, RANDOM_MAX), + ) + ddl_statements = load_sql("ddl2.sql") + instance = spanner_client.instance(instance_id) + database = instance.database(database_id, ddl_statements=ddl_statements) + operation = database.create() + operation.result() + database.run_in_transaction(insert_rows2) + yield database_id + database.drop() + + +@pytest.fixture(scope="session") +def client(instance_id, database_id): + return connect(instance_id, database_id) + + +@pytest.fixture(scope="session") +def client2(instance_id, database_id): + return connect(instance_id, database_id) + + +@pytest.fixture(scope="session") +def alltypes(client): + return client.table("functional_alltypes") + + +@pytest.fixture(scope="session") +def df(alltypes): + return alltypes.execute() + + +@pytest.fixture(scope="session") +def students(client): + return client.table("students_pointer") + + +@pytest.fixture(scope="session") +def students_df(students): + return students.execute() + + +@pytest.fixture(scope="session") +def array_table(client): + return client.table("array_table") diff --git a/third_party/ibis/ibis_cloud_spanner/tests/ddl.sql b/third_party/ibis/ibis_cloud_spanner/tests/ddl.sql new file mode 100644 index 000000000..a6be8533e --- /dev/null +++ 
b/third_party/ibis/ibis_cloud_spanner/tests/ddl.sql @@ -0,0 +1,54 @@ +-- Copyright 2021 Google LLC +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +CREATE TABLE students_pointer +( + id INT64, + name STRING(30), + division INT64, + marks INT64, + exam STRING(30), + overall_pointer FLOAT64, + date_of_exam TIMESTAMP +) +PRIMARY KEY (id); + +CREATE TABLE functional_alltypes +( + id INT64, + bigint_col INT64, + bool_col BOOL, + date DATE, + date_string_col STRING(MAX), + double_col NUMERIC, + float_col NUMERIC, + index INT64, + int_col INT64, + month INT64, + smallint_col INT64, + string_col STRING(MAX), + timestamp_col TIMESTAMP, + tinyint_col INT64, + Unnamed0 INT64, + year INT64 +) +PRIMARY KEY (id); + +CREATE TABLE array_table +( + string_col ARRAY<STRING(MAX)>, + int_col ARRAY<INT64>, + id INT64, +) +PRIMARY KEY (id); diff --git a/third_party/ibis/ibis_cloud_spanner/tests/ddl2.sql b/third_party/ibis/ibis_cloud_spanner/tests/ddl2.sql new file mode 100644 index 000000000..ca90d1dea --- /dev/null +++ b/third_party/ibis/ibis_cloud_spanner/tests/ddl2.sql @@ -0,0 +1,20 @@ +-- Copyright 2021 Google LLC +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +CREATE TABLE awards +( + id INT64, + award_name STRING(20) +) +PRIMARY KEY (id); diff --git a/third_party/ibis/ibis_cloud_spanner/tests/dml.sql b/third_party/ibis/ibis_cloud_spanner/tests/dml.sql new file mode 100644 index 000000000..5849f7dac --- /dev/null +++ b/third_party/ibis/ibis_cloud_spanner/tests/dml.sql @@ -0,0 +1,78 @@ +-- Copyright 2021 Google LLC +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License.
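+ +-- Seed rows for the integration tests: students_pointer, functional_alltypes +-- and array_table.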
+ +INSERT INTO students_pointer + (id,name,division,marks,exam,overall_pointer,date_of_exam) +VALUES(101, 'Ross', 12, 500, 'Biology', 9.8, '2002-02-10 15:30:00+00'); + +INSERT INTO students_pointer + (id,name,division,marks,exam,overall_pointer,date_of_exam) +VALUES(102, 'Rachel', 14, 460, 'Chemistry', 9.9, '2018-04-22'); + +INSERT INTO students_pointer + (id,name,division,marks,exam,overall_pointer,date_of_exam) +VALUES(103, 'Chandler', 12, 480, 'Biology', 8.2, '2016-04-14'); + +INSERT INTO students_pointer + (id,name,division,marks,exam,overall_pointer,date_of_exam) +VALUES(104, 'Monica', 12, 390, 'Maths', 9.2, '2019-04-29'); + +INSERT INTO students_pointer + (id,name,division,marks,exam,overall_pointer,date_of_exam) +VALUES(105, 'Joey', 16, 410, 'Maths', 9.7, '2019-06-21'); + +INSERT INTO students_pointer + (id,name,division,marks,exam,overall_pointer,date_of_exam) +VALUES(106, 'Phoebe', 10, 490, 'Chemistry', 9.6, '2019-02-09'); + + +INSERT INTO functional_alltypes + (id ,bigint_col ,bool_col ,date ,date_string_col ,double_col ,float_col ,index ,int_col ,month ,smallint_col ,string_col ,timestamp_col ,tinyint_col ,Unnamed0 ,year ) +VALUES + (1, 10001, TRUE, '2016-02-09', '01/01/2001', 2.5, 12.16, 101, 21, 4, 16, 'David', '2002-02-10 15:30:00+00', 6, 99, 2010); + +INSERT INTO functional_alltypes + (id ,bigint_col ,bool_col ,date ,date_string_col ,double_col ,float_col ,index ,int_col ,month ,smallint_col ,string_col ,timestamp_col ,tinyint_col ,Unnamed0 ,year ) +VALUES + (2, 10002, FALSE, '2016-10-10', '02/02/2002', 2.6, 13.16, 102, 22, 5, 18, 'Ryan', '2009-02-12 10:06:00+00', 7, 98, 2012); + +INSERT INTO functional_alltypes + (id ,bigint_col ,bool_col ,date ,date_string_col ,double_col ,float_col ,index ,int_col ,month ,smallint_col ,string_col ,timestamp_col ,tinyint_col ,Unnamed0 ,year ) +VALUES + (3, 10003, TRUE, '2018-02-09', '03/03/2003', 9.5, 44.16, 201, 41, 6, 56, 'Steve', '2010-06-10 12:12:00+00', 12, 66, 2006); + +INSERT INTO functional_alltypes + (id ,bigint_col ,bool_col ,date ,date_string_col ,double_col ,float_col ,index ,int_col ,month ,smallint_col ,string_col ,timestamp_col ,tinyint_col ,Unnamed0 ,year ) +VALUES + (4, 10004, TRUE, '2018-10-10', '04/04/2004', 9.6, 45.16, 202, 42, 9, 58, 'Chandler', '2014-06-12 10:04:00+00', 14, 69, 2009); + +INSERT INTO functional_alltypes + (id ,bigint_col ,bool_col ,date ,date_string_col ,double_col ,float_col ,index ,int_col ,month ,smallint_col ,string_col ,timestamp_col ,tinyint_col ,Unnamed0 ,year ) +VALUES + (5, 10005, FALSE, '2020-06-12', '05/05/2005', 6.6, 66.12, 401, 62, 12, 98, 'Rose', '2018-02-10 10:06:00+00', 16, 96, 2012); + +INSERT INTO functional_alltypes + (id ,bigint_col ,bool_col ,date ,date_string_col ,double_col ,float_col ,index ,int_col ,month ,smallint_col ,string_col ,timestamp_col ,tinyint_col ,Unnamed0 ,year ) +VALUES + (6, 10006, TRUE, '2020-12-12', '06/06/2006', 6.9, 66.19, 402, 69, 14, 99, 'Rachel', '2019-04-12 12:09:00+00', 18, 99, 2014); + +INSERT INTO array_table + (id,string_col,int_col) +VALUES + (1, ['Peter','David'], [11,12]); + +INSERT INTO array_table + (id,string_col,int_col) +VALUES + (2, ['Raj','Dev','Neil'], [1,2,3]); diff --git a/third_party/ibis/ibis_cloud_spanner/tests/dml2.sql b/third_party/ibis/ibis_cloud_spanner/tests/dml2.sql new file mode 100644 index 000000000..c5a097562 --- /dev/null +++ b/third_party/ibis/ibis_cloud_spanner/tests/dml2.sql @@ -0,0 +1,23 @@ +-- Copyright 2021 Google LLC +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file 
except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +INSERT INTO awards + (id,award_name) +VALUES + (101, 'LOTUS'); + +INSERT INTO awards + (id,award_name) +VALUES + (102, 'ROSE'); diff --git a/third_party/ibis/ibis_cloud_spanner/tests/test_client.py b/third_party/ibis/ibis_cloud_spanner/tests/test_client.py new file mode 100644 index 000000000..32ca523df --- /dev/null +++ b/third_party/ibis/ibis_cloud_spanner/tests/test_client.py @@ -0,0 +1,497 @@ +# Copyright 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import numpy as np +import pandas as pd +import pandas.util.testing as tm +import pytest +import pytz + +import ibis +import ibis.expr.datatypes as dt +import ibis.expr.types as ir + +from third_party.ibis.ibis_cloud_spanner import api as cs_compile + +pytestmark = pytest.mark.cloud_spanner + +from third_party.ibis.ibis_cloud_spanner.tests.conftest import connect + + +def test_table(alltypes): + assert isinstance(alltypes, ir.TableExpr) + + +def test_column_execute(alltypes, df): + col_name = "float_col" + expr = alltypes[col_name] + result = expr.execute() + expected = df[col_name] + tm.assert_series_equal( + (result.sort_values(col_name).reset_index(drop=True)).iloc[:, 0], + expected.sort_values().reset_index(drop=True), + ) + + +def test_literal_execute(client): + expected = "1234" + expr = ibis.literal(expected) + result = (client.execute(expr)).iloc[0]["tmp"] + assert result == expected + + +def test_simple_aggregate_execute(alltypes, df): + col_name = "float_col" + expr = alltypes[col_name].sum() + result = expr.execute() + expected = df[col_name].sum() + final_result = result.iloc[0]["sum"] + assert final_result == expected + + +def test_list_tables(client): + tables = client.list_tables(like="functional_alltypes") + assert set(tables) == {"functional_alltypes"} + + +def test_current_database(client, database_id): + assert client.current_database.name == database_id + assert client.current_database.name == client.dataset_id + assert client.current_database.tables == client.list_tables() + + +def test_database(client): + database = client.database(client.dataset_id) + assert database.list_tables() == client.list_tables() + + +def test_compile_toplevel(): + t = ibis.table([("foo", "double")], name="t0") + + expr = t.foo.sum() + result = cs_compile.compile(expr) + + expected = """\ +SELECT sum(`foo`) AS `sum` +FROM t0""" # noqa + assert str(result) == expected + + +def test_count_distinct_with_filter(alltypes): + expr = alltypes.float_col.nunique(where=alltypes.float_col.cast("int64") > 1) + result = expr.execute() + result = result.iloc[:, 0] 
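+ # execute() returns a one-column DataFrame; reduce it to a scalar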
+ result = result.iloc[0] + + expected = alltypes.float_col.execute() + expected = expected[expected.astype("int64") > 1].nunique() + expected = expected.iloc[0] + assert result == expected + + +@pytest.mark.parametrize("type", ["date", dt.date]) +def test_cast_string_to_date(alltypes, df, type): + import toolz + + string_col = alltypes.date_string_col + month, day, year = toolz.take(3, string_col.split("/")) + + expr = ibis.literal("-").join([year, month, day]) + expr = expr.cast(type) + + result = ( + expr.execute() + .iloc[:, 0] + .astype("datetime64[ns]") + .sort_values() + .reset_index(drop=True) + .rename("date_string_col") + ) + expected = ( + pd.to_datetime(df.date_string_col) + .dt.normalize() + .sort_values() + .reset_index(drop=True) + ) + tm.assert_series_equal(result, expected) + + +def test_subquery_scalar_params(alltypes): + t = alltypes + param = ibis.param("timestamp").name("my_param") + expr = ( + t[["float_col", "timestamp_col", "int_col", "string_col"]][ + lambda t: t.timestamp_col < param + ] + .groupby("string_col") + .aggregate(foo=lambda t: t.float_col.sum()) + .foo.count() + ) + result = cs_compile.compile(expr, params={param: "20140101"}) + expected = """\ +SELECT count(`foo`) AS `count` +FROM ( + SELECT `string_col`, sum(`float_col`) AS `foo` + FROM ( + SELECT `float_col`, `timestamp_col`, `int_col`, `string_col` + FROM functional_alltypes + WHERE `timestamp_col` < @my_param + ) t1 + GROUP BY 1 +) t0""" + assert result == expected + + +def test_scalar_param_string(alltypes, df): + param = ibis.param("string") + expr = alltypes[alltypes.string_col == param] + + string_value = "David" + result = ( + expr.execute(params={param: string_value}) + .sort_values("id") + .reset_index(drop=True) + ) + expected = ( + df.loc[df.string_col == string_value].sort_values("id").reset_index(drop=True) + ) + tm.assert_frame_equal(result, expected) + + +def test_scalar_param_int64(alltypes, df): + param = ibis.param("int64") + expr = alltypes[alltypes.int_col == param] + + int64_value = 22 + result = ( + expr.execute(params={param: int64_value}) + .sort_values("id") + .reset_index(drop=True) + ) + expected = ( + df.loc[df.int_col == int64_value].sort_values("id").reset_index(drop=True) + ) + tm.assert_frame_equal(result, expected) + + +def test_scalar_param_double(alltypes, df): + param = ibis.param("double") + expr = alltypes[alltypes.double_col == param] + + double_value = 2.5 + result = ( + expr.execute(params={param: double_value}) + .sort_values("id") + .reset_index(drop=True) + ) + expected = ( + df.loc[df.double_col == double_value].sort_values("id").reset_index(drop=True) + ) + tm.assert_frame_equal(result, expected) + + +def test_scalar_param_boolean(alltypes, df): + param = ibis.param("boolean") + expr = alltypes[(alltypes.bool_col == param)] + + bool_value = True + result = ( + expr.execute(params={param: bool_value}) + .sort_values("id") + .reset_index(drop=True) + ) + expected = ( + df.loc[df.bool_col == bool_value].sort_values("id").reset_index(drop=True) + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "timestamp_value", ["2019-04-12 12:09:00+00:00"], +) +def test_scalar_param_timestamp(alltypes, df, timestamp_value): + param = ibis.param("timestamp") + expr = (alltypes[alltypes.timestamp_col <= param]).select(["timestamp_col"]) + + result = ( + expr.execute(params={param: timestamp_value}) + .sort_values("timestamp_col") + .reset_index(drop=True) + ) + value = pd.Timestamp(timestamp_value) + expected = ( + df.loc[df.timestamp_col <= 
value, ["timestamp_col"]] + .sort_values("timestamp_col") + .reset_index(drop=True) + ) + tm.assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "date_value", + ["2009-02-12", datetime.date(2009, 2, 12), datetime.datetime(2009, 2, 12)], +) +def test_scalar_param_date(alltypes, df, date_value): + param = ibis.param("date") + expr = alltypes[alltypes.timestamp_col.cast("date") <= param] + + result = ( + expr.execute(params={param: date_value}) + .sort_values("timestamp_col") + .reset_index(drop=True) + ) + value = pd.Timestamp(date_value) + value = pd.to_datetime(value).tz_localize("UTC") + expected = ( + df.loc[df.timestamp_col.dt.normalize() <= value] + .sort_values("timestamp_col") + .reset_index(drop=True) + ) + tm.assert_frame_equal(result, expected) + + +def test_raw_sql(client): + assert (client.raw_sql("SELECT 1")).iloc[0][0] == 1 + + +def test_scalar_param_scope(alltypes): + t = alltypes + param = ibis.param("timestamp") + mut = t.mutate(param=param).compile(params={param: "2017-01-01"}) + assert ( + mut + == """\ +SELECT *, @param AS `param` +FROM functional_alltypes""" + ) + + +def test_column_names(alltypes): + assert "bigint_col" in alltypes.columns + assert "string_col" in alltypes.columns + + +def test_column_names_in_schema(alltypes): + assert "int_col" in alltypes.schema() + + +def test_exists_table(client): + assert client.exists_table("functional_alltypes") + assert not client.exists_table("footable") + + +def test_exists_database(client, database_id): + assert client.exists_database(database_id) + assert not client.exists_database("foodataset") + + +def test_set_database(client2, database_id2): + client2.set_database(database_id2) + tables = client2.list_tables() + assert "awards" in tables + + +def test_exists_table_different_project(client): + name = "functional_alltypes" + assert client.exists_table(name) + assert not client.exists_table("foobar") + + +def test_repeated_project_name(instance_id, database_id): + con = connect(instance_id, database_id) + assert "functional_alltypes" in con.list_tables() + + +def test_large_timestamp(client): + huge_timestamp = datetime.datetime(2012, 10, 10, 10, 10, 10, 154117) + expr = ibis.timestamp("2012-10-10 10:10:10.154117") + result = client.execute(expr) + + huge_timestamp = (pd.to_datetime(huge_timestamp).tz_localize("UTC")).date() + result = (result["tmp"][0]).date() + assert result == huge_timestamp + + +def test_string_to_timestamp(client): + timestamp = pd.Timestamp( + datetime.datetime(year=2017, month=2, day=6), tz=pytz.timezone("UTC") + ) + expr = ibis.literal("2017-02-06").to_timestamp("%F") + result = client.execute(expr) + result = result.iloc[:, 0][0] + result = result.date() + timestamp = timestamp.date() + assert result == timestamp + + timestamp_tz = pd.Timestamp( + datetime.datetime(year=2017, month=2, day=6, hour=5), tz=pytz.timezone("UTC"), + ) + expr_tz = ibis.literal("2017-02-06").to_timestamp("%F", "America/New_York") + result_tz = client.execute(expr_tz) + result_tz = result_tz.iloc[:, 0][0] + result_tz = result_tz.date() + timestamp_tz = timestamp_tz.date() + assert result_tz == timestamp_tz + + +def test_client_sql_query(client): + expr = client.get_data_using_query("select * from functional_alltypes limit 20") + result = expr + expected = client.table("functional_alltypes").head(20).execute() + tm.assert_frame_equal(result, expected) + + +def test_prevent_rewrite(alltypes): + t = alltypes + expr = ( + t.groupby(t.string_col) + .aggregate(collected_double=t.double_col.collect()) + 
.pipe(ibis.prevent_rewrite) + .filter(lambda t: t.string_col != "wat") + ) + result = cs_compile.compile(expr) + expected = """\ +SELECT * +FROM ( + SELECT `string_col`, ARRAY_AGG(`double_col`) AS `collected_double` + FROM functional_alltypes + GROUP BY 1 +) t0 +WHERE `string_col` != 'wat'""" + assert result == expected + + +@pytest.mark.parametrize( + ("case", "dtype"), + [ + (datetime.date(2017, 1, 1), dt.date), + (pd.Timestamp("2017-01-01"), dt.date), + ("2017-01-01", dt.date), + (datetime.datetime(2017, 1, 1, 4, 55, 59), dt.timestamp), + ("2017-01-01 04:55:59", dt.timestamp), + (pd.Timestamp("2017-01-01 04:55:59"), dt.timestamp), + ], +) +def test_day_of_week(client, case, dtype): + date_var = ibis.literal(case, type=dtype) + expr_index = date_var.day_of_week.index() + result = client.execute(expr_index) + result = result["tmp"][0] + assert result == 6 + + expr_name = date_var.day_of_week.full_name() + result = client.execute(expr_name) + result = result["tmp"][0] + assert result == "Sunday" + + +def test_boolean_reducers(alltypes): + b = alltypes.bool_col + bool_avg = b.mean().execute() + bool_avg = bool_avg.iloc[:, 0] + bool_avg = bool_avg[0] + assert type(bool_avg) == np.float64 + + bool_sum = b.sum().execute() + bool_sum = bool_sum.iloc[:, 0] + bool_sum = bool_sum[0] + assert type(bool_sum) == np.int64 + + +def test_students_table_schema(students): + assert students.schema() == ibis.schema( + [ + ("id", dt.int64), + ("name", dt.string), + ("division", dt.int64), + ("marks", dt.int64), + ("exam", dt.string), + ("overall_pointer", dt.float64), + ("date_of_exam", dt.timestamp), + ] + ) + + +def test_numeric_sum(students): + t = students + expr = t.overall_pointer.sum() + result = expr.execute() + result = (result.iloc[:, 0])[0] + assert isinstance(result, np.float64) + + +def test_boolean_casting(alltypes): + t = alltypes + expr = t.groupby(k=t.string_col.nullif("1") == "9").count() + result = expr.execute().set_index("k") + count = result["count"] + assert count.loc[False] == 6 + + +def test_approx_median(alltypes): + m = alltypes.month + expected = m.execute().median() + expected = expected[0] + assert expected == 7.5 + + +def test_struct_field_access(array_table): + expr = array_table.string_col + result = expr.execute() + result = result.iloc[:, 0] + expected = pd.Series( + [["Peter", "David"], ["Raj", "Dev", "Neil"]], name="string_col", + ) + + tm.assert_series_equal(result, expected) + + +def test_array_index(array_table): + expr = array_table.string_col[1] + result = expr.execute() + result = result.iloc[:, 0] + expected = pd.Series(["David", "Dev",], name="tmp",) + tm.assert_series_equal(result, expected) + + +def test_array_concat(array_table): + c = array_table.string_col + expr = c + c + result = expr.execute() + result = result.iloc[:, 0] + expected = pd.Series( + [ + ["Peter", "David", "Peter", "David"], + ["Raj", "Dev", "Neil", "Raj", "Dev", "Neil"], + ], + name="tmp", + ) + tm.assert_series_equal(result, expected) + + +def test_array_length(array_table): + expr = array_table.string_col.length() + result = expr.execute() + result = result.iloc[:, 0] + expected = pd.Series([2, 3], name="tmp") + tm.assert_series_equal(result, expected) + + +def test_scalar_param_array(alltypes, df, client): + expr = alltypes.sort_by("id").limit(1).double_col.collect() + result = client.get_data_using_query(cs_compile.compile(expr)) + result = result["tmp"][0] + expected = [df.sort_values("id").double_col.iat[0]] + assert result == expected diff --git 
a/third_party/ibis/ibis_cloud_spanner/tests/test_compiler.py b/third_party/ibis/ibis_cloud_spanner/tests/test_compiler.py new file mode 100644 index 000000000..8e8e9d25e --- /dev/null +++ b/third_party/ibis/ibis_cloud_spanner/tests/test_compiler.py @@ -0,0 +1,474 @@ +# Copyright 2021 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime + +import pandas as pd +import pytest + +import ibis +from third_party.ibis.ibis_cloud_spanner import api as cs_compile +import third_party.ibis.ibis_cloud_spanner as cs +import ibis.expr.datatypes as dt +import ibis.expr.operations as ops +from ibis.expr.types import TableExpr + +pytestmark = pytest.mark.cloud_spanner + + +def test_timestamp_accepts_date_literals(alltypes): + date_string = "2009-03-01" + param = ibis.param(dt.timestamp).name("param_0") + expr = alltypes.mutate(param=param) + params = {param: date_string} + result = expr.compile(params=params) + expected = f"""\ +SELECT *, @param AS `param` +FROM functional_alltypes""" + assert result == expected + + +@pytest.mark.parametrize( + ("distinct", "expected_keyword"), [(True, "DISTINCT"), (False, "ALL")] +) +def test_union(alltypes, distinct, expected_keyword): + expr = alltypes.union(alltypes, distinct=distinct) + result = cs_compile.compile(expr) + expected = f"""\ +SELECT * +FROM functional_alltypes +UNION {expected_keyword} +SELECT * +FROM functional_alltypes""" + assert result == expected + + +def test_ieee_divide(alltypes): + expr = alltypes.double_col / 0 + result = cs_compile.compile(expr) + expected = f"""\ +SELECT IEEE_DIVIDE(`double_col`, 0) AS `tmp` +FROM functional_alltypes""" + assert result == expected + + +def test_identical_to(alltypes): + t = alltypes + pred = t.string_col.identical_to("a") & t.date_string_col.identical_to("b") + expr = t[pred] + result = cs_compile.compile(expr) + expected = f"""\ +SELECT * +FROM functional_alltypes +WHERE (((`string_col` IS NULL) AND ('a' IS NULL)) OR (`string_col` = 'a')) AND + (((`date_string_col` IS NULL) AND ('b' IS NULL)) OR (`date_string_col` = 'b'))""" # noqa: E501 + assert result == expected + + +@pytest.mark.parametrize("timezone", [None, "America/New_York"]) +def test_to_timestamp(alltypes, timezone): + expr = alltypes.date_string_col.to_timestamp("%F", timezone) + result = cs_compile.compile(expr) + if timezone: + expected = f"""\ +SELECT PARSE_TIMESTAMP('%F', `date_string_col`, 'America/New_York') AS `tmp` +FROM functional_alltypes""" + else: + expected = f"""\ +SELECT PARSE_TIMESTAMP('%F', `date_string_col`) AS `tmp` +FROM functional_alltypes""" + assert result == expected + + +@pytest.mark.parametrize( + ("case", "expected", "dtype"), + [ + (datetime.date(2017, 1, 1), "DATE '2017-01-01'", dt.date), + (pd.Timestamp("2017-01-01"), "DATE '2017-01-01'", dt.date,), + ("2017-01-01", "DATE '2017-01-01'", dt.date), + ( + datetime.datetime(2017, 1, 1, 4, 55, 59), + "TIMESTAMP '2017-01-01 04:55:59'", + dt.timestamp, + ), + ("2017-01-01 04:55:59", "TIMESTAMP '2017-01-01 04:55:59'", dt.timestamp,), + ( + 
pd.Timestamp("2017-01-01 04:55:59"), + "TIMESTAMP '2017-01-01 04:55:59'", + dt.timestamp, + ), + ], +) +def test_literal_date(case, expected, dtype): + expr = ibis.literal(case, type=dtype).year() + result = cs_compile.compile(expr) + assert result == f"SELECT EXTRACT(year from {expected}) AS `tmp`" + + +@pytest.mark.parametrize( + ("case", "expected", "dtype", "strftime_func"), + [ + (datetime.date(2017, 1, 1), "DATE '2017-01-01'", dt.date, "FORMAT_DATE",), + (pd.Timestamp("2017-01-01"), "DATE '2017-01-01'", dt.date, "FORMAT_DATE",), + ("2017-01-01", "DATE '2017-01-01'", dt.date, "FORMAT_DATE",), + ( + datetime.datetime(2017, 1, 1, 4, 55, 59), + "TIMESTAMP '2017-01-01 04:55:59'", + dt.timestamp, + "FORMAT_TIMESTAMP", + ), + ( + "2017-01-01 04:55:59", + "TIMESTAMP '2017-01-01 04:55:59'", + dt.timestamp, + "FORMAT_TIMESTAMP", + ), + ( + pd.Timestamp("2017-01-01 04:55:59"), + "TIMESTAMP '2017-01-01 04:55:59'", + dt.timestamp, + "FORMAT_TIMESTAMP", + ), + ], +) +def test_day_of_week(case, expected, dtype, strftime_func): + date_var = ibis.literal(case, type=dtype) + expr_index = date_var.day_of_week.index() + result = cs_compile.compile(expr_index) + assert result == f"SELECT MOD(EXTRACT(DAYOFWEEK FROM {expected}) + 5, 7) AS `tmp`" + + expr_name = date_var.day_of_week.full_name() + result = cs_compile.compile(expr_name) + if strftime_func == "FORMAT_TIMESTAMP": + assert result == f"SELECT {strftime_func}('%A', {expected}, 'UTC') AS `tmp`" + else: + assert result == f"SELECT {strftime_func}('%A', {expected}) AS `tmp`" + + +@pytest.mark.parametrize( + ("case", "expected", "dtype"), + [ + ( + datetime.datetime(2017, 1, 1, 4, 55, 59), + "TIMESTAMP '2017-01-01 04:55:59'", + dt.timestamp, + ), + ("2017-01-01 04:55:59", "TIMESTAMP '2017-01-01 04:55:59'", dt.timestamp,), + ( + pd.Timestamp("2017-01-01 04:55:59"), + "TIMESTAMP '2017-01-01 04:55:59'", + dt.timestamp, + ), + (datetime.time(4, 55, 59), "TIME '04:55:59'", dt.time), + ("04:55:59", "TIME '04:55:59'", dt.time), + ], +) +def test_literal_timestamp_or_time(case, expected, dtype): + expr = ibis.literal(case, type=dtype).hour() + result = cs_compile.compile(expr) + assert result == f"SELECT EXTRACT(hour from {expected}) AS `tmp`" + + +def test_window_function(alltypes): + t = alltypes + w1 = ibis.window( + preceding=1, following=0, group_by="year", order_by="timestamp_col" + ) + expr = t.mutate(win_avg=t.float_col.mean().over(w1)) + result = cs_compile.compile(expr) + expected = f"""\ +SELECT *, + avg(`float_col`) OVER (PARTITION BY `year` ORDER BY `timestamp_col` ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) AS `win_avg` +FROM functional_alltypes""" # noqa: E501 + assert result == expected + + w2 = ibis.window( + preceding=0, following=2, group_by="year", order_by="timestamp_col" + ) + expr = t.mutate(win_avg=t.float_col.mean().over(w2)) + result = cs_compile.compile(expr) + expected = f"""\ +SELECT *, + avg(`float_col`) OVER (PARTITION BY `year` ORDER BY `timestamp_col` ROWS BETWEEN CURRENT ROW AND 2 FOLLOWING) AS `win_avg` +FROM functional_alltypes""" # noqa: E501 + assert result == expected + + w3 = ibis.window(preceding=(4, 2), group_by="year", order_by="timestamp_col") + expr = t.mutate(win_avg=t.float_col.mean().over(w3)) + result = cs_compile.compile(expr) + expected = f"""\ +SELECT *, + avg(`float_col`) OVER (PARTITION BY `year` ORDER BY `timestamp_col` ROWS BETWEEN 4 PRECEDING AND 2 PRECEDING) AS `win_avg` +FROM functional_alltypes""" # noqa: E501 + assert result == expected + + +@pytest.mark.parametrize( + ("distinct1", "distinct2", 
"expected1", "expected2"), + [ + (True, True, "UNION DISTINCT", "UNION DISTINCT"), + (True, False, "UNION DISTINCT", "UNION ALL"), + (False, True, "UNION ALL", "UNION DISTINCT"), + (False, False, "UNION ALL", "UNION ALL"), + ], +) +def test_union_cte(alltypes, distinct1, distinct2, expected1, expected2): + t = alltypes + expr1 = t.group_by(t.string_col).aggregate(metric=t.double_col.sum()) + expr2 = expr1.view() + expr3 = expr1.view() + expr = expr1.union(expr2, distinct=distinct1).union(expr3, distinct=distinct2) + result = cs_compile.compile(expr) + expected = f"""\ +WITH t0 AS ( + SELECT `string_col`, sum(`double_col`) AS `metric` + FROM functional_alltypes + GROUP BY 1 +) +SELECT * +FROM t0 +{expected1} +SELECT `string_col`, sum(`double_col`) AS `metric` +FROM functional_alltypes +GROUP BY 1 +{expected2} +SELECT `string_col`, sum(`double_col`) AS `metric` +FROM functional_alltypes +GROUP BY 1""" + assert result == expected + + +def test_projection_fusion_only_peeks_at_immediate_parent(): + schema = [ + ("file_date", "timestamp"), + ("PARTITIONTIME", "date"), + ("val", "int64"), + ] + table = ibis.table(schema, name="unbound_table") + table = table[table.PARTITIONTIME < ibis.date("2017-01-01")] + table = table.mutate(file_date=table.file_date.cast("date")) + table = table[table.file_date < ibis.date("2017-01-01")] + table = table.mutate(XYZ=table.val * 2) + expr = table.join(table.view())[table] + result = cs_compile.compile(expr) + expected = """\ +WITH t0 AS ( + SELECT * + FROM unbound_table + WHERE `PARTITIONTIME` < DATE '2017-01-01' +), +t1 AS ( + SELECT CAST(`file_date` AS DATE) AS `file_date`, `PARTITIONTIME`, `val` + FROM t0 +), +t2 AS ( + SELECT t1.* + FROM t1 + WHERE t1.`file_date` < DATE '2017-01-01' +), +t3 AS ( + SELECT *, `val` * 2 AS `XYZ` + FROM t2 +) +SELECT t3.* +FROM t3 + INNER JOIN t3 t4""" + assert result == expected + + +def test_bool_reducers(alltypes): + b = alltypes.bool_col + expr = b.mean() + result = cs_compile.compile(expr) + expected = f"""\ +SELECT avg(CAST(`bool_col` AS INT64)) AS `mean` +FROM functional_alltypes""" + assert result == expected + + expr2 = b.sum() + result = cs_compile.compile(expr2) + expected = f"""\ +SELECT sum(CAST(`bool_col` AS INT64)) AS `sum` +FROM functional_alltypes""" + assert result == expected + + +def test_bool_reducers_where(alltypes): + b = alltypes.bool_col + m = alltypes.month + expr = b.mean(where=m > 6) + result = cs_compile.compile(expr) + expected = f"""\ +SELECT avg(CASE WHEN `month` > 6 THEN CAST(`bool_col` AS INT64) ELSE NULL END) AS `mean` +FROM functional_alltypes""" # noqa: E501 + assert result == expected + + expr2 = b.sum(where=((m > 6) & (m < 10))) + result = cs_compile.compile(expr2) + expected = f"""\ +SELECT sum(CASE WHEN (`month` > 6) AND (`month` < 10) THEN CAST(`bool_col` AS INT64) ELSE NULL END) AS `sum` +FROM functional_alltypes""" # noqa: E501 + assert result == expected + + +def test_approx_nunique(alltypes): + d = alltypes.double_col + expr = d.approx_nunique() + result = cs_compile.compile(expr) + expected = f"""\ +SELECT APPROX_COUNT_DISTINCT(`double_col`) AS `approx_nunique` +FROM functional_alltypes""" + assert result == expected + + b = alltypes.bool_col + m = alltypes.month + expr2 = b.approx_nunique(where=m > 6) + result = cs_compile.compile(expr2) + expected = f"""\ +SELECT APPROX_COUNT_DISTINCT(CASE WHEN `month` > 6 THEN `bool_col` ELSE NULL END) AS `approx_nunique` +FROM functional_alltypes""" # noqa: E501 + assert result == expected + + +def test_approx_median(alltypes): + d = 
+    d = alltypes.double_col
+    expr = d.approx_median()
+    result = cs_compile.compile(expr)
+    expected = """\
+SELECT APPROX_QUANTILES(`double_col`, 2)[OFFSET(1)] AS `approx_median`
+FROM functional_alltypes"""
+    assert result == expected
+
+    m = alltypes.month
+    expr2 = d.approx_median(where=m > 6)
+    result = cs_compile.compile(expr2)
+    expected = """\
+SELECT APPROX_QUANTILES(CASE WHEN `month` > 6 THEN `double_col` ELSE NULL END, 2)[OFFSET(1)] AS `approx_median`
+FROM functional_alltypes"""  # noqa: E501
+    assert result == expected
+
+
+@pytest.mark.parametrize(
+    ("unit", "expected_unit", "expected_func"),
+    [
+        ("Y", "YEAR", "TIMESTAMP"),
+        ("Q", "QUARTER", "TIMESTAMP"),
+        ("M", "MONTH", "TIMESTAMP"),
+        ("W", "WEEK", "TIMESTAMP"),
+        ("D", "DAY", "TIMESTAMP"),
+        ("h", "HOUR", "TIMESTAMP"),
+        ("m", "MINUTE", "TIMESTAMP"),
+        ("s", "SECOND", "TIMESTAMP"),
+        ("ms", "MILLISECOND", "TIMESTAMP"),
+        ("us", "MICROSECOND", "TIMESTAMP"),
+        ("Y", "YEAR", "DATE"),
+        ("Q", "QUARTER", "DATE"),
+        ("M", "MONTH", "DATE"),
+        ("W", "WEEK", "DATE"),
+        ("D", "DAY", "DATE"),
+        ("h", "HOUR", "TIME"),
+        ("m", "MINUTE", "TIME"),
+        ("s", "SECOND", "TIME"),
+        ("ms", "MILLISECOND", "TIME"),
+        ("us", "MICROSECOND", "TIME"),
+    ],
+)
+def test_temporal_truncate(unit, expected_unit, expected_func):
+    t = ibis.table([("a", getattr(dt, expected_func.lower()))], name="t")
+    expr = t.a.truncate(unit)
+    result = cs_compile.compile(expr)
+    expected = f"""\
+SELECT {expected_func}_TRUNC(`a`, {expected_unit}) AS `tmp`
+FROM t"""
+    assert result == expected
+
+
+@pytest.mark.parametrize("kind", ["date", "time"])
+def test_extract_temporal_from_timestamp(kind):
+    t = ibis.table([("ts", dt.timestamp)], name="t")
+    expr = getattr(t.ts, kind)()
+    result = cs_compile.compile(expr)
+    expected = f"""\
+SELECT {kind.upper()}(`ts`) AS `tmp`
+FROM t"""
+    assert result == expected
+
+
+def test_now():
+    expr = ibis.now()
+    result = cs_compile.compile(expr)
+    expected = "SELECT CURRENT_TIMESTAMP() AS `tmp`"
+    assert result == expected
+
+
+def test_bucket():
+    t = ibis.table([("value", "double")], name="t")
+    buckets = [0, 1, 3]
+    expr = t.value.bucket(buckets).name("foo")
+    result = cs_compile.compile(expr)
+    expected = """\
+SELECT
+  CASE
+    WHEN (`value` >= 0) AND (`value` < 1) THEN 0
+    WHEN (`value` >= 1) AND (`value` <= 3) THEN 1
+    ELSE CAST(NULL AS INT64)
+  END AS `tmp`
+FROM t"""
+    assert result == expected
+
+
+@pytest.mark.parametrize(
+    ("kind", "begin", "end", "expected"),
+    [
+        ("preceding", None, 1, "UNBOUNDED PRECEDING AND 1 PRECEDING"),
+        ("following", 1, None, "1 FOLLOWING AND UNBOUNDED FOLLOWING"),
+    ],
+)
+def test_window_unbounded(kind, begin, end, expected):
+    t = ibis.table([("a", "int64")], name="t")
+    kwargs = {kind: (begin, end)}
+    expr = t.a.sum().over(ibis.window(**kwargs))
+    result = cs_compile.compile(expr)
+    assert (
+        result
+        == f"""\
+SELECT sum(`a`) OVER (ROWS BETWEEN {expected}) AS `tmp`
+FROM t"""
+    )
+
+
+def test_large_compile():
+    """
+    Test that compiling a large expression tree finishes
+    within a reasonable amount of time.
+    """
+    num_columns = 20
+    num_joins = 7
+
+    class MockCloudSpannerClient(cs_compile.CloudSpannerClient):
+        def __init__(self):
+            # Skip the real __init__ so no Spanner connection is needed
+            # just to exercise the compiler.
+            pass
+
+    names = [f"col_{i}" for i in range(num_columns)]
+    schema = ibis.Schema(names, ["string"] * num_columns)
+    ibis_client = MockCloudSpannerClient()
+    table = TableExpr(ops.SQLQueryResult("select * from t", schema, ibis_client))
+    for _ in range(num_joins):
+        table = table.mutate(dummy=ibis.literal(""))
+        table = table.left_join(table, ["dummy"])[[table]]
+
+    start = datetime.datetime.now()
+    cs_compile.compile(table)
+    delta = datetime.datetime.now() - start
+    assert delta.total_seconds() < 60
diff --git a/third_party/ibis/ibis_cloud_spanner/tests/test_datatypes.py b/third_party/ibis/ibis_cloud_spanner/tests/test_datatypes.py
new file mode 100644
index 000000000..f5490367f
--- /dev/null
+++ b/third_party/ibis/ibis_cloud_spanner/tests/test_datatypes.py
@@ -0,0 +1,61 @@
+# Copyright 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from multipledispatch.conflict import ambiguities
+from pytest import param
+
+import ibis.expr.datatypes as dt
+from third_party.ibis.ibis_cloud_spanner.datatypes import (
+    TypeTranslationContext,
+    ibis_type_to_cloud_spanner_type,
+)
+
+pytestmark = pytest.mark.cloud_spanner
+
+
+def test_no_ambiguities():
+    ambs = ambiguities(ibis_type_to_cloud_spanner_type.funcs)
+    assert not ambs
+
+
+@pytest.mark.parametrize(
+    ("datatype", "expected"),
+    [
+        (dt.float32, "FLOAT64"),
+        (dt.float64, "FLOAT64"),
+        (dt.uint8, "INT64"),
+        (dt.uint16, "INT64"),
+        (dt.uint32, "INT64"),
+        (dt.int8, "INT64"),
+        (dt.int16, "INT64"),
+        (dt.int32, "INT64"),
+        (dt.int64, "INT64"),
+        (dt.string, "STRING"),
+        (dt.Array(dt.int64), "ARRAY"),
+        (dt.Array(dt.string), "ARRAY"),
+        (dt.date, "DATE"),
+        (dt.timestamp, "TIMESTAMP"),
+        param(dt.Timestamp(timezone="US/Eastern"), "TIMESTAMP"),
+    ],
+)
+def test_simple(datatype, expected):
+    context = TypeTranslationContext()
+    assert ibis_type_to_cloud_spanner_type(datatype, context) == expected
+
+
+@pytest.mark.parametrize("datatype", [dt.uint64, dt.Decimal(8, 3)])
+def test_simple_failure_mode(datatype):
+    with pytest.raises(TypeError):
+        ibis_type_to_cloud_spanner_type(datatype)
diff --git a/third_party/ibis/ibis_cloud_spanner/to_pandas.py b/third_party/ibis/ibis_cloud_spanner/to_pandas.py
new file mode 100644
index 000000000..902f88334
--- /dev/null
+++ b/third_party/ibis/ibis_cloud_spanner/to_pandas.py
@@ -0,0 +1,67 @@
+# Copyright 2021 Google Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
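+
+# Convert the rows streamed back from a Cloud Spanner snapshot query into a
+# pandas DataFrame, mapping Spanner column types onto pandas dtypes.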
+
+from pandas import DataFrame
+
+
+class pandas_df:
+    @staticmethod
+    def to_pandas(snapshot, sql, query_parameters):
+        # Merge the per-query parameter dicts into a single params /
+        # param_types pair before executing the SQL.
+        if query_parameters:
+            param = {}
+            param_type = {}
+            for i in query_parameters:
+                param.update(i["params"])
+                param_type.update(i["param_types"])
+
+            data_qry = snapshot.execute_sql(sql, params=param, param_types=param_type)
+        else:
+            data_qry = snapshot.execute_sql(sql)
+
+        # Consume the streamed result set; its `fields` metadata is only
+        # available once rows have been read.
+        data = list(data_qry)
+
+        columns_dict = {}
+        for item in data_qry.fields:
+            columns_dict[item.name] = item.type_.code.name
+
+        # Create the list of columns to be mapped with the data
+        column_list = list(columns_dict)
+
+        # Create a pandas DataFrame from the rows and column list
+        df = DataFrame(data, columns=column_list)
+
+        # Dictionary to map each Spanner data type to a pandas-compatible dtype
+        SPANNER_TO_PANDAS_DTYPE = {
+            "INT64": "int64",
+            "STRING": "object",
+            "BOOL": "bool",
+            "BYTES": "object",
+            "ARRAY": "object",
+            "DATE": "datetime64[ns, UTC]",
+            "FLOAT64": "float64",
+            "NUMERIC": "object",
+            "TIMESTAMP": "datetime64[ns, UTC]",
+        }
+
+        for k, v in columns_dict.items():
+            try:
+                df[k] = df[k].astype(SPANNER_TO_PANDAS_DTYPE[v])
+            except KeyError:
+                # Leave the column as-is when the Spanner type has no mapping.
+                print("Spanner data type %s is not in the dtype mapping" % v)
+
+        return df
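
Usage sketch for the to_pandas helper above (a minimal illustration, not taken
from the patch): the project, instance, and database IDs and the query are
placeholder values, and the snapshot comes from the standard
google-cloud-spanner client library.

    from google.cloud import spanner
    from google.cloud.spanner_v1 import param_types

    from third_party.ibis.ibis_cloud_spanner.to_pandas import pandas_df

    client = spanner.Client(project="my-project")  # placeholder project ID
    database = client.instance("my-instance").database("my-database")

    with database.snapshot() as snapshot:
        # query_parameters is a list of {"params": ..., "param_types": ...}
        # dicts; to_pandas merges them before executing the SQL.
        df = pandas_df.to_pandas(
            snapshot,
            "SELECT id, name FROM users WHERE id > @min_id",
            [
                {
                    "params": {"min_id": 100},
                    "param_types": {"min_id": param_types.INT64},
                }
            ],
        )

    print(df.dtypes)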