diff --git a/google/cloud/spanner_dbapi/__init__.py b/google/cloud/spanner_dbapi/__init__.py
index 098b0bd786..014d82d3cc 100644
--- a/google/cloud/spanner_dbapi/__init__.py
+++ b/google/cloud/spanner_dbapi/__init__.py
@@ -93,7 +93,7 @@ def connect(
     if not database.exists():
         raise ValueError("database '%s' does not exist." % database_id)
 
-    return Connection(database)
+    return Connection(instance, database)
 
 
 __all__ = [
diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py
index 869586e363..cbd4b6bec3 100644
--- a/google/cloud/spanner_dbapi/connection.py
+++ b/google/cloud/spanner_dbapi/connection.py
@@ -4,141 +4,98 @@
 # license that can be found in the LICENSE file or at
 # https://developers.google.com/open-source/licenses/bsd
 
-from collections import namedtuple
-
-from google.cloud import spanner_v1 as spanner
+"""DB-API driver Connection implementation for Google Cloud Spanner.
 
-from .cursor import Cursor
-from .exceptions import InterfaceError
+    See
+    https://www.python.org/dev/peps/pep-0249/#connection-objects
+"""
 
-ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
+from collections import namedtuple
+from weakref import WeakSet
 
+from .cursor import Cursor
+from .exceptions import InterfaceError, Warning
+from .enums import AutocommitDMLModes, TransactionModes
 
-class Connection:
-    def __init__(self, db_handle):
-        self._dbhandle = db_handle
-        self._ddl_statements = []
-        self.is_closed = False
+ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
 
-    def cursor(self):
-        self._raise_if_closed()
 
-        return Cursor(self)
+class Connection(object):
+    """Representation of a connection to a Cloud Spanner database.
 
-    def _raise_if_closed(self):
-        """Raise an exception if this connection is closed.
+    You most likely don't need to instantiate `Connection` objects
+    directly, use the `connect` module function instead.
 
-        Helper to check the connection state before
-        running a SQL/DDL/DML query.
+    :type instance: :class:`~google.cloud.spanner_v1.instance.Instance`
+    :param instance: Cloud Spanner instance to connect to.
 
-        :raises: :class:`InterfaceError` if this connection is closed.
-        """
-        if self.is_closed:
-            raise InterfaceError("connection is already closed")
+    :type database: :class:`~google.cloud.spanner_v1.database.Database`
+    :param database: Cloud Spanner database to connect to.
+    """
 
-    def __handle_update_ddl(self, ddl_statements):
-        """
-        Run the list of Data Definition Language (DDL) statements on the underlying
-        database. Each DDL statement MUST NOT contain a semicolon.
-        Args:
-            ddl_statements: a list of DDL statements, each without a semicolon.
-        Returns:
-            google.api_core.operation.Operation.result()
-        """
-        self._raise_if_closed()
-        # Synchronously wait on the operation's completion.
-        return self._dbhandle.update_ddl(ddl_statements).result()
-
-    def read_snapshot(self):
-        self._raise_if_closed()
-        return self._dbhandle.snapshot()
-
-    def in_transaction(self, fn, *args, **kwargs):
-        self._raise_if_closed()
-        return self._dbhandle.run_in_transaction(fn, *args, **kwargs)
-
-    def append_ddl_statement(self, ddl_statement):
-        self._raise_if_closed()
-        self._ddl_statements.append(ddl_statement)
-
-    def run_prior_DDL_statements(self):
-        self._raise_if_closed()
-
-        if not self._ddl_statements:
-            return
-
-        ddl_statements = self._ddl_statements
-        self._ddl_statements = []
-
-        return self.__handle_update_ddl(ddl_statements)
-
-    def list_tables(self):
-        return self.run_sql_in_snapshot(
-            """
-            SELECT
-              t.table_name
-            FROM
-              information_schema.tables AS t
-            WHERE
-              t.table_catalog = '' and t.table_schema = ''
-            """
+    def __init__(self, instance, database):
+        self.instance = instance
+        self.database = database
+        self.autocommit = True
+        self.read_only = False
+        self.transaction_mode = (
+            TransactionModes.READ_ONLY
+            if self.read_only
+            else TransactionModes.READ_WRITE
         )
+        self.autocommit_dml_mode = AutocommitDMLModes.TRANSACTIONAL
+        self.timeout_secs = 0
+        self.read_timestamp = None
+        self.commit_timestamp = None
+        self._is_closed = False
+        self._inside_transaction = not self.autocommit
+        self._transaction_started = False
+        self._cursors = WeakSet()
+        self.read_only_staleness = {}
+
+    @property
+    def is_closed(self):
+        return self._is_closed
+
+    @property
+    def inside_transaction(self):
+        return self._inside_transaction
+
+    @property
+    def transaction_started(self):
+        return self._transaction_started
 
-    def run_sql_in_snapshot(self, sql, params=None, param_types=None):
-        # Some SQL e.g. for INFORMATION_SCHEMA cannot be run in read-write transactions
-        # hence this method exists to circumvent that limit.
-        self.run_prior_DDL_statements()
-
-        with self._dbhandle.snapshot() as snapshot:
-            res = snapshot.execute_sql(
-                sql, params=params, param_types=param_types
-            )
-            return list(res)
-
-    def get_table_column_schema(self, table_name):
-        rows = self.run_sql_in_snapshot(
-            """SELECT
-                COLUMN_NAME, IS_NULLABLE, SPANNER_TYPE
-            FROM
-                INFORMATION_SCHEMA.COLUMNS
-            WHERE
-                TABLE_SCHEMA = ''
-            AND
-                TABLE_NAME = @table_name""",
-            params={"table_name": table_name},
-            param_types={"table_name": spanner.param_types.STRING},
-        )
+    def cursor(self):
+        """Returns cursor for current connection"""
+        if self._is_closed:
+            raise InterfaceError("connection is already closed")
 
-        column_details = {}
-        for column_name, is_nullable, spanner_type in rows:
-            column_details[column_name] = ColumnDetails(
-                null_ok=is_nullable == "YES", spanner_type=spanner_type
-            )
-        return column_details
+        return Cursor(self)
 
     def close(self):
         """Close this connection.
 
         The connection will be unusable from this point forward.
         """
-        self.rollback()
-        self.__dbhandle = None
-        self.is_closed = True
+        self._is_closed = True
 
     def commit(self):
-        self._raise_if_closed()
-
-        self.run_prior_DDL_statements()
+        """Commit all the pending transactions."""
+        raise Warning(
+            "Cloud Spanner DB API always works in `autocommit` mode. "
+            "See https://github.com/googleapis/python-spanner-django#transaction-management-isnt-supported"
+        )
 
     def rollback(self):
-        self._raise_if_closed()
-
-        # TODO: to be added.
+        """Rollback all the pending transactions."""
+        raise Warning(
+            "Cloud Spanner DB API always works in `autocommit` mode. "
+            "See https://github.com/googleapis/python-spanner-django#transaction-management-isnt-supported"
+        )
 
     def __enter__(self):
         return self
 
-    def __exit__(self, etype, value, traceback):
-        self.commit()
+    def __exit__(self, exc_type, exc_value, traceback):
         self.close()
diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py
index 10e5184ed2..05b0091ade 100644
--- a/google/cloud/spanner_dbapi/cursor.py
+++ b/google/cloud/spanner_dbapi/cursor.py
@@ -12,6 +12,7 @@
     InternalServerError,
     InvalidArgument,
 )
+
 from google.cloud.spanner_v1 import param_types
 
 from .exceptions import (
@@ -34,7 +35,6 @@
 
 _UNSET_COUNT = -1
 
-
 # This table maps spanner_types to Spanner's data type sizes as per
 # https://cloud.google.com/spanner/docs/data-types#allowable-types
 # It is used to map `display_size` to a known type for Cursor.description
@@ -53,58 +53,110 @@
 }
 
 
-class Cursor:
-    """
-    Database cursor to manage the context of a fetch operation.
+class Cursor(object):
+    """Database cursor to manage the context of a fetch operation.
 
-    :type connection: :class:`spanner_dbapi.connection.Connection`
-    :param connection: Parent connection object for this Cursor.
+    :type connection: :class:`~google.cloud.spanner_dbapi.connection.Connection`
+    :param connection: A DB-API connection to Google Cloud Spanner.
     """
 
     def __init__(self, connection):
-        self._itr = None
-        self._res = None
-        self._row_count = _UNSET_COUNT
         self._connection = connection
         self._is_closed = False
-
-        # the number of rows to fetch at a time with fetchmany()
+        self._stream = None
+        self._itr = None
+        self._row_count = _UNSET_COUNT
         self.arraysize = 1
+        self._ddl_statements = []
 
-    def execute(self, sql, args=None):
+    @property
+    def is_closed(self):
+        """The cursor close indicator.
+        :rtype: bool
+        :returns: True if the cursor or the parent connection is closed,
+            otherwise False.
+        """
+        return self._is_closed or self._connection.is_closed
+
+    @property
+    def connection(self):
+        return self._connection
+
+    @property
+    def rowcount(self):
+        """The number of rows produced by the last `.execute()`."""
+        return self._row_count
+
+    @property
+    def lastrowid(self):
+        return None
+
+    @property
+    def description(self):
+        """Read-only attribute containing a sequence of the following items:
+        - ``name``
+        - ``type_code``
+        - ``display_size``
+        - ``internal_size``
+        - ``precision``
+        - ``scale``
+        - ``null_ok``
         """
-        Abstracts and implements execute SQL statements on Cloud Spanner.
-        Args:
-            sql: A SQL statement
-            *args: variadic argument list
-            **kwargs: key worded arguments
-        Returns:
-            None
+        if not (self._stream and self._stream.metadata):
+            return None
+
+        row_type = self._stream.metadata.row_type
+        columns = []
+        for field in row_type.fields:
+            columns.append(
+                ColumnInfo(
+                    name=field.name,
+                    type_code=field.type.code,
+                    # Size of the SQL type of the column.
+                    display_size=code_to_display_size.get(field.type.code),
+                    # Client perceived size of the column.
+                    internal_size=field.ByteSize(),
+                )
+            )
+        return tuple(columns)
+
+    def close(self):
+        """Closes this Cursor, making it unusable from this point forward."""
+        self._connection = None
+        self._is_closed = True
+
+    def callproc(self, procname, args=None):
+        """A no-op, raising an error if the cursor or connection is closed."""
+        self._raise_if_closed()
+
+    def execute(self, sql, args=None):
+        """Prepares and executes a Spanner database operation.
+
+        :type sql: str
+        :param sql: A SQL query statement.
+
+        :type args: list
+        :param args: Additional parameters to supplement the SQL query.
         """
         self._raise_if_closed()
         if not self._connection:
             raise ProgrammingError("Cursor is not connected to the database")
 
-        self._res = None
+        self._stream = None
 
         # Classify whether this is a read-only SQL statement.
        try:
             classification = classify_stmt(sql)
             if classification == STMT_DDL:
-                self._connection.append_ddl_statement(sql)
-                return
-
-            # For every other operation, we've got to ensure that
-            # any prior DDL statements were run.
-            self._run_prior_DDL_statements()
-
-            if classification == STMT_NON_UPDATING:
-                self.__handle_DQL(sql, args or None)
+                self._ddl_statements.append(sql)
+                self._run_ddl_statements(sql)
+            elif classification == STMT_NON_UPDATING:
+                self._handle_dql(sql, args or None)
             elif classification == STMT_INSERT:
-                self.__handle_insert(sql, args or None)
+                self._handle_insert(sql, args or None)
             else:
-                self.__handle_update(sql, args or None)
+                self._handle_update(sql, args or None)
         except (AlreadyExists, FailedPrecondition) as e:
             raise IntegrityError(e.details if hasattr(e, "details") else e)
         except InvalidArgument as e:
@@ -112,8 +164,109 @@ def execute(self, sql, args=None):
         except InternalServerError as e:
             raise OperationalError(e.details if hasattr(e, "details") else e)
 
-    def __handle_update(self, sql, params):
-        self._connection.in_transaction(self.__do_execute_update, sql, params)
+    def executemany(self, operation, seq_of_params):
+        """Execute the given SQL with every parameter set
+        from the given sequence of parameters.
+
+        :type operation: str
+        :param operation: SQL code to execute.
+
+        :type seq_of_params: list
+        :param seq_of_params: Sequence of additional parameters to run
+            the query with.
+        """
+        self._raise_if_closed()
+        if not self._connection:
+            raise ProgrammingError("Cursor is not connected to the database")
+
+        for params in seq_of_params:
+            self.execute(operation, params)
+
+    def fetchone(self):
+        """Fetch the next row of a query result set, returning a single
+        sequence, or None when no more data is available."""
+        self._raise_if_closed()
+        try:
+            return next(self)
+        except StopIteration:
+            return None
+
+    def fetchmany(self, size=None):
+        """Fetch the next set of rows of a query result, returning a sequence
+        of sequences. An empty sequence is returned when no more rows are available.
+
+        :type size: int
+        :param size: (Optional) The maximum number of results to fetch.
+
+        :raises InterfaceError:
+            if the previous call to .execute*() did not produce any result set
+            or if no call was issued yet.
+        """
+        self._raise_if_closed()
+
+        if size is None:
+            size = self.arraysize
+
+        items = []
+        for i in range(size):
+            try:
+                items.append(tuple(next(self)))
+            except StopIteration:
+                break
+
+        return items
+
+    def fetchall(self):
+        """Fetch all (remaining) rows of a query result, returning them as
+        a sequence of sequences.
+        """
+        self._raise_if_closed()
+
+        return list(iter(self))
+
+    def nextset(self):
+        """A no-op, raising an error if the cursor or connection is closed."""
+        self._raise_if_closed()
+
+    def setinputsizes(self, sizes):
+        """A no-op, raising an error if the cursor or connection is closed."""
+        self._raise_if_closed()
+
+    def setoutputsize(self, size, column=None):
+        """A no-op, raising an error if the cursor or connection is closed."""
+        self._raise_if_closed()
+
+    def __iter__(self):
+        if self._itr is None:
+            raise ProgrammingError("no results to return")
+        return self._itr
+
+    def __next__(self):
+        if self._itr is None:
+            raise ProgrammingError("no results to return")
+        return next(self._itr)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, etype, value, traceback):
+        self.close()
+
+    def _raise_if_closed(self):
+        """Raise an exception if this cursor is closed.
+        Helper to check this cursor's state before running a
+        SQL/DDL/DML query. If the parent connection is
+        already closed it also raises an error.
+        :raises: :class:`InterfaceError` if this cursor is closed.
+        """
+        if self.is_closed:
+            raise InterfaceError("Cursor and/or connection is already closed.")
+
+    def _handle_update(self, sql, params):
+        self._raise_if_closed()
+        self._connection.database.run_in_transaction(
+            self.__do_execute_update, sql, params
+        )
 
     def __do_execute_update(self, transaction, sql, params, param_types=None):
         sql = ensure_where_clause(sql)
@@ -128,7 +281,7 @@ def __do_execute_update(self, transaction, sql, params, param_types=None):
 
         return res
 
-    def __handle_insert(self, sql, params):
+    def _handle_insert(self, sql, params):
         parts = parse_insert(sql, params)
 
         # The split between the two styles exists because:
@@ -147,37 +300,38 @@ def __handle_insert(self, sql, params):
         if parts.get("homogenous"):
             # The common case of multiple values being passed in
             # non-complex pyformat args and need to be uploaded in one RPC.
-            return self._connection.in_transaction(
-                self.__do_execute_insert_homogenous, parts
+            return self._connection.database.run_in_transaction(
+                self._do_execute_insert_homogenous, parts
             )
         else:
             # All the other cases that are esoteric and need
             # transaction.execute_sql
             sql_params_list = parts.get("sql_params_list")
-            return self._connection.in_transaction(
-                self.__do_execute_insert_heterogenous, sql_params_list
+            return self._connection.database.run_in_transaction(
+                self._do_execute_insert_heterogenous, sql_params_list
            )
 
-    def __do_execute_insert_heterogenous(self, transaction, sql_params_list):
+    def _do_execute_insert_heterogenous(self, transaction, sql_params_list):
         for sql, params in sql_params_list:
             sql, params = sql_pyformat_args_to_spanner(sql, params)
             param_types = get_param_types(params)
             res = transaction.execute_sql(
                 sql, params=params, param_types=param_types
             )
-            # TODO: File a bug with Cloud Spanner and the Python client maintainers
-            # about a lost commit when res isn't read from.
-            _ = list(res)
 
-    def __do_execute_insert_homogenous(self, transaction, parts):
+    def _do_execute_insert_homogenous(self, transaction, parts):
         # Perform an insert in one shot.
-        table = parts.get("table")
-        columns = parts.get("columns")
-        values = parts.get("values")
+        table, columns, values = (
+            parts.get("table"),
+            parts.get("columns"),
+            parts.get("values"),
+        )
+
         return transaction.insert(table, columns, values)
 
-    def __handle_DQL(self, sql, params):
-        with self._connection.read_snapshot() as snapshot:
+    def _handle_dql(self, sql, params):
+        self._raise_if_closed()
+        with self._connection.database.snapshot() as snapshot:
             # Reference
             #  https://googleapis.dev/python/spanner/latest/session-api.html#google.cloud.spanner_v1.session.Session.execute_sql
             sql, params = sql_pyformat_args_to_spanner(sql, params)
@@ -196,158 +350,22 @@ def __handle_DQL(self, sql, params):
             # are for .fetchone() with those that would result in
             # many items returns a RuntimeError if .fetchone() is
             # invoked and vice versa.
-            self._res = res
+            self._stream = res
             # Read the first element so that StreamedResult can
             # return the metadata after a DQL statement. See issue #155.
-            self._itr = PeekIterator(self._res)
+            self._itr = PeekIterator(self._stream)
             # Unfortunately, Spanner doesn't seem to send back
             # information about the number of rows available.
             self._row_count = _UNSET_COUNT
 
-    def __enter__(self):
-        return self
-
-    def __exit__(self, etype, value, traceback):
-        self.__clear()
-
-    def __clear(self):
-        self._connection = None
-
-    @property
-    def description(self):
-        if not (self._res and self._res.metadata):
-            return None
-
-        row_type = self._res.metadata.row_type
-        columns = []
-        for field in row_type.fields:
-            columns.append(
-                Column(
-                    name=field.name,
-                    type_code=field.type.code,
-                    # Size of the SQL type of the column.
-                    display_size=code_to_display_size.get(field.type.code),
-                    # Client perceived size of the column.
-                    internal_size=field.ByteSize(),
-                )
-            )
-        return tuple(columns)
-
-    @property
-    def rowcount(self):
-        return self._row_count
-
-    @property
-    def is_closed(self):
-        """The cursor close indicator.
-
-        :rtype: :class:`bool`
-        :returns: True if this cursor or it's parent connection is closed, False
-                  otherwise.
-        """
-        return self._is_closed or self._connection.is_closed
-
-    def _raise_if_closed(self):
-        """Raise an exception if this cursor is closed.
-
-        Helper to check this cursor's state before running a
-        SQL/DDL/DML query. If the parent connection is
-        already closed it also raises an error.
-
-        :raises: :class:`InterfaceError` if this cursor is closed.
-        """
-        if self.is_closed:
-            raise InterfaceError("cursor is already closed")
-
-    def close(self):
-        """Close this cursor.
-
-        The cursor will be unusable from this point forward.
-        """
-        self.__clear()
-        self._is_closed = True
-
-    def executemany(self, operation, seq_of_params):
-        if not self._connection:
-            raise ProgrammingError("Cursor is not connected to the database")
-
-        for params in seq_of_params:
-            self.execute(operation, params)
-
-    def __next__(self):
-        if self._itr is None:
-            raise ProgrammingError("no results to return")
-        return next(self._itr)
-
-    def __iter__(self):
-        if self._itr is None:
-            raise ProgrammingError("no results to return")
-        return self._itr
-
-    def fetchone(self):
+    def _run_ddl_statements(self, sql):
         self._raise_if_closed()
+        return self._connection.database.update_ddl([sql]).result()
 
-        try:
-            return next(self)
-        except StopIteration:
-            return None
 
-    def fetchall(self):
-        self._raise_if_closed()
+class ColumnInfo:
+    """Row column description object."""
 
-        return list(self.__iter__())
-
-    def fetchmany(self, size=None):
-        """
-        Fetch the next set of rows of a query result, returning a sequence of sequences.
-        An empty sequence is returned when no more rows are available.
-
-        Args:
-            size: optional integer to determine the maximum number of results to fetch.
-
-
-        Raises:
-            Error if the previous call to .execute*() did not produce any result set
-            or if no call was issued yet.
-        """
-        self._raise_if_closed()
-
-        if size is None:
-            size = self.arraysize
-
-        items = []
-        for i in range(size):
-            try:
-                items.append(tuple(self.__next__()))
-            except StopIteration:
-                break
-
-        return items
-
-    @property
-    def lastrowid(self):
-        return None
-
-    def setinputsizes(sizes):
-        raise ProgrammingError("Unimplemented")
-
-    def setoutputsize(size, column=None):
-        raise ProgrammingError("Unimplemented")
-
-    def _run_prior_DDL_statements(self):
-        return self._connection.run_prior_DDL_statements()
-
-    def list_tables(self):
-        return self._connection.list_tables()
-
-    def run_sql_in_snapshot(self, sql):
-        return self._connection.run_sql_in_snapshot(sql)
-
-    def get_table_column_schema(self, table_name):
-        return self._connection.get_table_column_schema(table_name)
-
-
-class Column:
     def __init__(
         self,
         name,
@@ -366,48 +384,41 @@ def __init__(
         self.scale = scale
         self.null_ok = null_ok
 
+        self.fields = (
+            self.name,
+            self.type_code,
+            self.display_size,
+            self.internal_size,
+            self.precision,
+            self.scale,
+            self.null_ok,
+        )
+
     def __repr__(self):
         return self.__str__()
 
     def __getitem__(self, index):
-        if index == 0:
-            return self.name
-        elif index == 1:
-            return self.type_code
-        elif index == 2:
-            return self.display_size
-        elif index == 3:
-            return self.internal_size
-        elif index == 4:
-            return self.precision
-        elif index == 5:
-            return self.scale
-        elif index == 6:
-            return self.null_ok
+        return self.fields[index]
 
     def __str__(self):
-        rstr = ", ".join(
-            [
-                field
-                for field in [
+        str_repr = ", ".join(
+            filter(
+                lambda part: part is not None,
+                [
                     "name='%s'" % self.name,
                     "type_code=%d" % self.type_code,
-                    None
-                    if not self.display_size
-                    else "display_size=%d" % self.display_size,
-                    None
-                    if not self.internal_size
-                    else "internal_size=%d" % self.internal_size,
-                    None
-                    if not self.precision
-                    else "precision='%s'" % self.precision,
-                    None if not self.scale else "scale='%s'" % self.scale,
-                    None
-                    if not self.null_ok
-                    else "null_ok='%s'" % self.null_ok,
-                ]
-                if field
-            ]
+                    "display_size=%d" % self.display_size
+                    if self.display_size
+                    else None,
+                    "internal_size=%d" % self.internal_size
+                    if self.internal_size
+                    else None,
+                    "precision='%s'" % self.precision
+                    if self.precision
+                    else None,
+                    "scale='%s'" % self.scale if self.scale else None,
+                    "null_ok='%s'" % self.null_ok if self.null_ok else None,
+                ],
+            )
         )
-
-        return "Column(%s)" % rstr
+        return "ColumnInfo(%s)" % str_repr
diff --git a/google/cloud/spanner_dbapi/enums.py b/google/cloud/spanner_dbapi/enums.py
new file mode 100644
index 0000000000..350ecba0d8
--- /dev/null
+++ b/google/cloud/spanner_dbapi/enums.py
@@ -0,0 +1,19 @@
+# Copyright 2020 Google LLC
+#
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file or at
+# https://developers.google.com/open-source/licenses/bsd
+
+"""DBAPI enum types."""
+
+import enum
+
+
+class TransactionModes(enum.IntEnum):
+    READ_ONLY = 0
+    READ_WRITE = 1
+
+
+class AutocommitDMLModes(enum.IntEnum):
+    TRANSACTIONAL = 0
+    PARTITIONED_NON_ATOMIC = 1
diff --git a/tests/spanner_dbapi/test_connection.py b/tests/spanner_dbapi/test_connection.py
index e7cd3f361f..bf5bb0ec2e 100644
--- a/tests/spanner_dbapi/test_connection.py
+++ b/tests/spanner_dbapi/test_connection.py
@@ -7,26 +7,113 @@
 """Connection() class unit tests."""
 
 import unittest
-from unittest import mock
 
-from google.cloud.spanner_dbapi import connect, InterfaceError
+# import google.cloud.spanner_dbapi.exceptions as dbapi_exceptions
+
+from google.cloud.spanner_dbapi import Connection, InterfaceError, Warning
+from google.cloud.spanner_v1.database import Database
+from google.cloud.spanner_v1.instance import Instance
 
 
 class TestConnection(unittest.TestCase):
+    instance_name = "instance-name"
+    database_name = "database-name"
+
+    def _make_connection(self):
+        # we don't need real Client object to test the constructor
+        instance = Instance(self.instance_name, client=None)
+        database = instance.database(self.database_name)
+        return Connection(instance, database)
+
+    def test_ctor(self):
+        connection = self._make_connection()
+
+        self.assertIsInstance(connection.instance, Instance)
+        self.assertEqual(connection.instance.instance_id, self.instance_name)
+
+        self.assertIsInstance(connection.database, Database)
+        self.assertEqual(connection.database.database_id, self.database_name)
+
+        self.assertFalse(connection._is_closed)
+
     def test_close(self):
-        with mock.patch(
-            "google.cloud.spanner_v1.instance.Instance.exists",
-            return_value=True,
-        ):
-            with mock.patch(
-                "google.cloud.spanner_v1.database.Database.exists",
-                return_value=True,
-            ):
-                connection = connect("test-instance", "test-database")
-
-        self.assertFalse(connection.is_closed)
+        connection = self._make_connection()
+
+        self.assertFalse(connection._is_closed)
 
         connection.close()
 
-        self.assertTrue(connection.is_closed)
+        self.assertTrue(connection._is_closed)
 
         with self.assertRaises(InterfaceError):
             connection.cursor()
+
+    def test_transaction_management_warnings(self):
+        connection = self._make_connection()
+
+        with self.assertRaises(Warning):
+            connection.commit()
+
+        with self.assertRaises(Warning):
+            connection.rollback()
+
+    def test_connection_close_check_if_open(self):
+        connection = self._make_connection()
+
+        connection.cursor()
+        self.assertFalse(connection._is_closed)
+
+    def test_is_closed(self):
+        connection = self._make_connection()
+
+        self.assertEqual(connection._is_closed, connection.is_closed)
+        connection.close()
+        self.assertEqual(connection._is_closed, connection.is_closed)
+
+    def test_inside_transaction(self):
+        connection = self._make_connection()
+
+        self.assertEqual(
+            connection._inside_transaction, connection.inside_transaction,
+        )
+
+    def test_transaction_started(self):
+        connection = self._make_connection()
+
+        self.assertEqual(
+            connection.transaction_started, connection._transaction_started,
+        )
+
+    def test_cursor(self):
+        from google.cloud.spanner_dbapi.cursor import Cursor
+
+        connection = self._make_connection()
+        cursor = connection.cursor()
+
+        self.assertIsInstance(cursor, Cursor)
+        self.assertEqual(connection, cursor._connection)
+
+    def test_commit(self):
+        connection = self._make_connection()
+
+        with self.assertRaises(Warning):
+            connection.commit()
+
+    def test_rollback(self):
+        connection = self._make_connection()
+
+        with self.assertRaises(Warning):
+            connection.rollback()
+
+    def test_context_success(self):
+        connection = self._make_connection()
+
+        with connection as conn:
+            conn.cursor()
+        self.assertTrue(connection._is_closed)
+
+    def test_context_error(self):
+        connection = self._make_connection()
+
+        with self.assertRaises(Exception):
+            with connection:
+                raise Exception
+        self.assertTrue(connection._is_closed)
diff --git a/tests/spanner_dbapi/test_cursor.py b/tests/spanner_dbapi/test_cursor.py
index 722bbbcb8a..ad19202a51 100644
--- a/tests/spanner_dbapi/test_cursor.py
+++ b/tests/spanner_dbapi/test_cursor.py
@@ -9,11 +9,11 @@
 import unittest
 from unittest import mock
 
-from google.cloud.spanner_dbapi import connect, InterfaceError
-
 
 class TestCursor(unittest.TestCase):
-    def test_close(self):
+    def _make_cursor(self):
+        from google.cloud.spanner_dbapi import connect
+
         with mock.patch(
             "google.cloud.spanner_v1.instance.Instance.exists",
             return_value=True,
@@ -24,16 +24,11 @@ def test_close(self):
             ):
                 connection = connect("test-instance", "test-database")
 
-        cursor = connection.cursor()
-        self.assertFalse(cursor.is_closed)
+        return connection.cursor()
 
-        cursor.close()
-
-        self.assertTrue(cursor.is_closed)
-        with self.assertRaises(InterfaceError):
-            cursor.execute("SELECT * FROM database")
+    def test_close(self):
+        from google.cloud.spanner_dbapi import connect, InterfaceError
 
-    def test_connection_closed(self):
         with mock.patch(
             "google.cloud.spanner_v1.instance.Instance.exists",
             return_value=True,
@@ -47,8 +42,195 @@ def test_connection_closed(self):
             cursor = connection.cursor()
             self.assertFalse(cursor.is_closed)
 
-            connection.close()
+            cursor.close()
 
             self.assertTrue(cursor.is_closed)
             with self.assertRaises(InterfaceError):
                 cursor.execute("SELECT * FROM database")
+
+    def test_connection(self):
+        cursor = self._make_cursor()
+
+        self.assertEqual(cursor.connection, cursor._connection)
+
+        cursor._connection = "changed-connection"
+        self.assertEqual(cursor.connection, cursor._connection)
+
+    def test_description_if_not_stream(self):
+        cursor = self._make_cursor()
+        cursor._stream = None
+
+        self.assertIsNone(cursor.description)
+
+    def test_rowcount(self):
+        cursor = self._make_cursor()
+
+        self.assertEqual(cursor.rowcount, cursor._row_count)
+
+        cursor._row_count = 52
+        self.assertEqual(cursor.rowcount, cursor._row_count)
+
+    def test_lastrowid(self):
+        cursor = self._make_cursor()
+
+        self.assertIsNone(cursor.lastrowid)
+
+    def test_callproc(self):
+        from google.cloud.spanner_dbapi import InterfaceError
+
+        cursor = self._make_cursor()
+
+        self.assertIsNone(cursor.callproc("procname"))
+
+        cursor.close()
+        with self.assertRaises(InterfaceError):
+            cursor.callproc("procname")
+
+    def test_nextset(self):
+        from google.cloud.spanner_dbapi import InterfaceError
+
+        cursor = self._make_cursor()
+        self.assertIsNone(cursor.nextset())
+        cursor.close()
+        with self.assertRaises(InterfaceError):
+            cursor.nextset()
+
+    def test_setinputsizes(self):
+        from google.cloud.spanner_dbapi import InterfaceError
+
+        cursor = self._make_cursor()
+        self.assertIsNone(cursor.setinputsizes("sizes"))
+        cursor.close()
+        with self.assertRaises(InterfaceError):
+            cursor.setinputsizes("sizes")
+
+    def test_setoutputsize(self):
+        from google.cloud.spanner_dbapi import InterfaceError
+
+        cursor = self._make_cursor()
+        self.assertIsNone(cursor.setoutputsize("size"))
+        cursor.close()
+        with self.assertRaises(InterfaceError):
+            cursor.setoutputsize("size")
+
+    def test_execute_without_connection(self):
+        from google.cloud.spanner_dbapi import ProgrammingError
+
+        cursor = self._make_cursor()
+        cursor._connection = None
+
+        with self.assertRaises(ProgrammingError):
+            cursor.execute('SELECT * FROM table1 WHERE "col1" = @a1')
+
+    def test_executemany_without_connection(self):
+        from google.cloud.spanner_dbapi import ProgrammingError
+
+        cursor = self._make_cursor()
+        cursor._connection = None
+
+        with self.assertRaises(ProgrammingError):
+            cursor.executemany(
+                """SELECT * FROM table1 WHERE "col1" = @a1""", ()
+            )
+
+    def test_executemany_on_closed_cursor(self):
+        from google.cloud.spanner_dbapi import InterfaceError
+
+        cursor = self._make_cursor()
+        cursor.close()
+
+        with self.assertRaises(InterfaceError):
+            cursor.executemany(
+                """SELECT * FROM table1 WHERE "col1" = @a1""", ()
+            )
+
+    def test_executemany(self):
+        operation = """SELECT * FROM table1 WHERE "col1" = @a1"""
+        params_seq = ((1,), (2,))
+
+        cursor = self._make_cursor()
+        with mock.patch(
+            "google.cloud.spanner_dbapi.cursor.Cursor.execute"
+        ) as execute_mock:
+            cursor.executemany(operation, params_seq)
+
+        execute_mock.assert_has_calls(
+            (mock.call(operation, (1,)), mock.call(operation, (2,)))
+        )
+
+    def test_context_success(self):
+        cursor = self._make_cursor()
+
+        with cursor as c:
+            c.nextset()
+        self.assertTrue(cursor._is_closed)
+
+    def test_context_error(self):
+        cursor = self._make_cursor()
+
+        with self.assertRaises(Exception):
+            with cursor:
+                raise Exception
+        self.assertTrue(cursor._is_closed)
+
+
+class TestColumns(unittest.TestCase):
+    def test_ctor(self):
+        from google.cloud.spanner_dbapi.cursor import ColumnInfo
+
+        name = "col-name"
+        type_code = 8
+        display_size = 5
+        internal_size = 10
+        precision = 3
+        scale = None
+        null_ok = False
+
+        cols = ColumnInfo(
+            name,
+            type_code,
+            display_size,
+            internal_size,
+            precision,
+            scale,
+            null_ok,
+        )
+
+        self.assertEqual(cols.name, name)
+        self.assertEqual(cols.type_code, type_code)
+        self.assertEqual(cols.display_size, display_size)
+        self.assertEqual(cols.internal_size, internal_size)
+        self.assertEqual(cols.precision, precision)
+        self.assertEqual(cols.scale, scale)
+        self.assertEqual(cols.null_ok, null_ok)
+        self.assertEqual(
+            cols.fields,
+            (
+                name,
+                type_code,
+                display_size,
+                internal_size,
+                precision,
+                scale,
+                null_ok,
+            ),
+        )
+
+    def test___get_item__(self):
+        from google.cloud.spanner_dbapi.cursor import ColumnInfo
+
+        fields = ("col-name", 8, 5, 10, 3, None, False)
+        cols = ColumnInfo(*fields)
+
+        for i in range(0, 7):
+            self.assertEqual(cols[i], fields[i])
+
+    def test___str__(self):
+        from google.cloud.spanner_dbapi.cursor import ColumnInfo
+
+        cols = ColumnInfo("col-name", 8, None, 10, 3, None, False)
+
+        self.assertEqual(
+            str(cols),
+            "ColumnInfo(name='col-name', type_code=8, internal_size=10, precision='3')",
+        )
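
A minimal usage sketch of the DB-API surface this patch produces, assuming an existing Cloud Spanner instance and database (the IDs below are placeholders, not values from the change itself):

    from google.cloud.spanner_dbapi import connect, Warning

    # connect() verifies that the instance and the database exist and
    # returns Connection(instance, database).
    connection = connect("test-instance", "test-database")

    cursor = connection.cursor()
    cursor.execute("SELECT 1")
    print(cursor.fetchall())

    # The driver always runs in autocommit mode, so explicit transaction
    # control raises Warning.
    try:
        connection.commit()
    except Warning:
        pass

    cursor.close()
    connection.close()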