diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 1f6ae0b96b..10e9361140 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -6,6 +6,8 @@ """DB-API Connection for the Google Cloud Spanner.""" +import warnings + from google.api_core.gapic_v1.client_info import ClientInfo from google.cloud import spanner_v1 as spanner @@ -15,19 +17,115 @@ from google.cloud.spanner_dbapi.version import PY_VERSION -class Connection(object): - """DB-API Connection to a Google Cloud Spanner database. +AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode" + + +class Connection: + """Representation of a DB-API connection to a Cloud Spanner database. + + You most likely don't need to instantiate `Connection` objects + directly, use the `connect` module function instead. + + :type instance: :class:`~google.cloud.spanner_v1.instance.Instance` + :param instance: Cloud Spanner instance to connect to. :type database: :class:`~google.cloud.spanner_v1.database.Database` :param database: The database to which the connection is linked. """ def __init__(self, instance, database): - self.instance = instance - self.database = database + self._instance = instance + self._database = database + self._ddl_statements = [] + + self._transaction = None + self._session = None + self.is_closed = False + self._autocommit = False + + @property + def autocommit(self): + """Autocommit mode flag for this connection. + + :rtype: bool + :returns: Autocommit mode flag value. + """ + return self._autocommit + + @autocommit.setter + def autocommit(self, value): + """Change this connection autocommit mode. + + :type value: bool + :param value: New autocommit mode state. + """ + if value and not self._autocommit: + self.commit() + + self._autocommit = value + + @property + def database(self): + """Database to which this connection relates. + + :rtype: :class:`~google.cloud.spanner_v1.database.Database` + :returns: The related database object. + """ + return self._database + + @property + def instance(self): + """Instance to which this connection relates. + + :rtype: :class:`~google.cloud.spanner_v1.instance.Instance` + :returns: The related instance object. + """ + return self._instance + + def _session_checkout(self): + """Get a Cloud Spanner session from the pool. + + If there is already a session associated with + this connection, it'll be used instead. + + :rtype: :class:`google.cloud.spanner_v1.session.Session` + :returns: Cloud Spanner session object ready to use. + """ + if not self._session: + self._session = self.database._pool.get() + + return self._session + + def _release_session(self): + """Release the currently used Spanner session. + + The session will be returned into the sessions pool. + """ + self.database._pool.put(self._session) + self._session = None + + def transaction_checkout(self): + """Get a Cloud Spanner transaction. - self.ddl_statements = [] + Begin a new transaction, if there is no transaction in + this connection yet. Return the begun one otherwise. + + The method is non operational in autocommit mode. + + :rtype: :class:`google.cloud.spanner_v1.transaction.Transaction` + :returns: A Cloud Spanner transaction object, ready to use. + """ + if not self.autocommit: + if ( + not self._transaction + or self._transaction.committed + or self._transaction.rolled_back + ): + self._transaction = self._session_checkout().transaction() + self._transaction.begin() + + return self._transaction def _raise_if_closed(self): """Helper to check the connection state before running a query. @@ -41,20 +139,33 @@ def _raise_if_closed(self): def close(self): """Closes this connection. - The connection will be unusable from this point forward. + The connection will be unusable from this point forward. If the + connection has an active transaction, it will be rolled back. """ - self.database = None + if ( + self._transaction + and not self._transaction.committed + and not self._transaction.rolled_back + ): + self._transaction.rollback() + self.is_closed = True def commit(self): """Commits any pending transaction to the database.""" - self._raise_if_closed() - - self.run_prior_DDL_statements() + if self._autocommit: + warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) + elif self._transaction: + self._transaction.commit() + self._release_session() def rollback(self): - """A no-op, raising an error if the connection is closed.""" - self._raise_if_closed() + """Rollback all the pending transactions.""" + if self._autocommit: + warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) + elif self._transaction: + self._transaction.rollback() + self._release_session() def cursor(self): """Factory to create a DB-API Cursor.""" @@ -65,9 +176,9 @@ def cursor(self): def run_prior_DDL_statements(self): self._raise_if_closed() - if self.ddl_statements: - ddl_statements = self.ddl_statements - self.ddl_statements = [] + if self._ddl_statements: + ddl_statements = self._ddl_statements + self._ddl_statements = [] return self.database.update_ddl(ddl_statements).result() @@ -80,15 +191,20 @@ def __exit__(self, etype, value, traceback): def connect( - instance_id, database_id, project=None, credentials=None, user_agent=None + instance_id, + database_id, + project=None, + credentials=None, + pool=None, + user_agent=None, ): """Creates a connection to a Google Cloud Spanner database. :type instance_id: str - :param instance_id: ID of the instance to connect to. + :param instance_id: The ID of the instance to connect to. :type database_id: str - :param database_id: The name of the database to connect to. + :param database_id: The ID of the database to connect to. :type project: str :param project: (Optional) The ID of the project which owns the @@ -102,8 +218,13 @@ def connect( attempt to ascertain the credentials from the environment. + :type pool: Concrete subclass of + :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool`. + :param pool: (Optional). Session pool to be used by database. + :type user_agent: str - :param user_agent: (Optional) Prefix to the user agent header. + :param user_agent: (Optional) User agent to be used with this connection's + requests. :rtype: :class:`google.cloud.spanner_dbapi.connection.Connection` :returns: Connection object associated with the given Google Cloud Spanner @@ -125,7 +246,7 @@ def connect( if not instance.exists(): raise ValueError("instance '%s' does not exist." % instance_id) - database = instance.database(database_id, pool=spanner.pool.BurstyPool()) + database = instance.database(database_id, pool=pool) if not database.exists(): raise ValueError("database '%s' does not exist." % database_id) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index f6ba16d216..6997752a42 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -26,8 +26,7 @@ from google.cloud.spanner_dbapi import parse_utils from google.cloud.spanner_dbapi.parse_utils import get_param_types - -from .utils import PeekIterator +from google.cloud.spanner_dbapi.utils import PeekIterator _UNSET_COUNT = -1 @@ -150,7 +149,7 @@ def execute(self, sql, args=None): try: classification = parse_utils.classify_stmt(sql) if classification == parse_utils.STMT_DDL: - self.connection.ddl_statements.append(sql) + self.connection._ddl_statements.append(sql) return # For every other operation, we've got to ensure that @@ -158,6 +157,19 @@ def execute(self, sql, args=None): # self._run_prior_DDL_statements() self.connection.run_prior_DDL_statements() + if not self.connection.autocommit: + transaction = self.connection.transaction_checkout() + + sql, params = parse_utils.sql_pyformat_args_to_spanner( + sql, args + ) + + self._result_set = transaction.execute_sql( + sql, params, param_types=get_param_types(params) + ) + self._itr = PeekIterator(self._result_set) + return + if classification == parse_utils.STMT_NON_UPDATING: self._handle_DQL(sql, args or None) elif classification == parse_utils.STMT_INSERT: diff --git a/tests/spanner_dbapi/test_connect.py b/tests/spanner_dbapi/test_connect.py new file mode 100644 index 0000000000..fb4d89c373 --- /dev/null +++ b/tests/spanner_dbapi/test_connect.py @@ -0,0 +1,135 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""connect() module function unit tests.""" + +import unittest +from unittest import mock + +import google.auth.credentials +from google.api_core.gapic_v1.client_info import ClientInfo +from google.cloud.spanner_dbapi import connect, Connection +from google.cloud.spanner_v1.pool import FixedSizePool + + +def _make_credentials(): + class _CredentialsWithScopes( + google.auth.credentials.Credentials, google.auth.credentials.Scoped + ): + pass + + return mock.Mock(spec=_CredentialsWithScopes) + + +class Test_connect(unittest.TestCase): + def test_connect(self): + PROJECT = "test-project" + USER_AGENT = "user-agent" + CREDENTIALS = _make_credentials() + CLIENT_INFO = ClientInfo(user_agent=USER_AGENT) + + with mock.patch( + "google.cloud.spanner_dbapi.spanner_v1.Client" + ) as client_mock: + with mock.patch( + "google.cloud.spanner_dbapi.google_client_info", + return_value=CLIENT_INFO, + ) as client_info_mock: + + connection = connect( + "test-instance", + "test-database", + PROJECT, + CREDENTIALS, + user_agent=USER_AGENT, + ) + + self.assertIsInstance(connection, Connection) + client_info_mock.assert_called_once_with(USER_AGENT) + + client_mock.assert_called_once_with( + project=PROJECT, + credentials=CREDENTIALS, + client_info=CLIENT_INFO, + ) + + def test_instance_not_found(self): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=False, + ) as exists_mock: + + with self.assertRaises(ValueError): + connect("test-instance", "test-database") + + exists_mock.assert_called_once_with() + + def test_database_not_found(self): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=False, + ) as exists_mock: + + with self.assertRaises(ValueError): + connect("test-instance", "test-database") + + exists_mock.assert_called_once_with() + + def test_connect_instance_id(self): + INSTANCE = "test-instance" + + with mock.patch( + "google.cloud.spanner_v1.client.Client.instance" + ) as instance_mock: + connection = connect(INSTANCE, "test-database") + + instance_mock.assert_called_once_with(INSTANCE) + + self.assertIsInstance(connection, Connection) + + def test_connect_database_id(self): + DATABASE = "test-database" + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.database" + ) as database_mock: + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + connection = connect("test-instance", DATABASE) + + database_mock.assert_called_once_with(DATABASE, pool=mock.ANY) + + self.assertIsInstance(connection, Connection) + + def test_default_sessions_pool(self): + with mock.patch("google.cloud.spanner_v1.instance.Instance.database"): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + self.assertIsNotNone(connection.database._pool) + + def test_sessions_pool(self): + database_id = "test-database" + pool = FixedSizePool() + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.database" + ) as database_mock: + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + connect("test-instance", database_id, pool=pool) + database_mock.assert_called_once_with(database_id, pool=pool) diff --git a/tests/spanner_dbapi/test_connection.py b/tests/spanner_dbapi/test_connection.py new file mode 100644 index 0000000000..24260de12e --- /dev/null +++ b/tests/spanner_dbapi/test_connection.py @@ -0,0 +1,79 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""Connection() class unit tests.""" + +import unittest +from unittest import mock + +# import google.cloud.spanner_dbapi.exceptions as dbapi_exceptions + +from google.cloud.spanner_dbapi import Connection, InterfaceError +from google.cloud.spanner_dbapi.connection import AUTOCOMMIT_MODE_WARNING +from google.cloud.spanner_v1.database import Database +from google.cloud.spanner_v1.instance import Instance + + +class TestConnection(unittest.TestCase): + instance_name = "instance-name" + database_name = "database-name" + + def _make_connection(self): + # we don't need real Client object to test the constructor + instance = Instance(self.instance_name, client=None) + database = instance.database(self.database_name) + return Connection(instance, database) + + def test_ctor(self): + connection = self._make_connection() + + self.assertIsInstance(connection.instance, Instance) + self.assertEqual(connection.instance.instance_id, self.instance_name) + + self.assertIsInstance(connection.database, Database) + self.assertEqual(connection.database.database_id, self.database_name) + + self.assertFalse(connection.is_closed) + + def test_close(self): + connection = self._make_connection() + + self.assertFalse(connection.is_closed) + connection.close() + self.assertTrue(connection.is_closed) + + with self.assertRaises(InterfaceError): + connection.cursor() + + @mock.patch("warnings.warn") + def test_transaction_autocommit_warnings(self, warn_mock): + connection = self._make_connection() + connection.autocommit = True + + connection.commit() + warn_mock.assert_called_with( + AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 + ) + connection.rollback() + warn_mock.assert_called_with( + AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 + ) + + def test_database_property(self): + connection = self._make_connection() + self.assertIsInstance(connection.database, Database) + self.assertEqual(connection.database, connection._database) + + with self.assertRaises(AttributeError): + connection.database = None + + def test_instance_property(self): + connection = self._make_connection() + self.assertIsInstance(connection.instance, Instance) + self.assertEqual(connection.instance, connection._instance) + + with self.assertRaises(AttributeError): + connection.instance = None diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 5710ba6ce6..f3ee345e15 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -5,18 +5,291 @@ # https://developers.google.com/open-source/licenses/bsd import unittest +import os +from google.api_core import exceptions -class TestSpannerDjangoDBAPI(unittest.TestCase): - def setUp(self): - # TODO: Implement this method - pass +from google.cloud.spanner import Client +from google.cloud.spanner import BurstyPool +from google.cloud.spanner_dbapi.connection import Connection + +from test_utils.retry import RetryErrors +from test_utils.system import unique_resource_id + + +CREATE_INSTANCE = ( + os.getenv("GOOGLE_CLOUD_TESTS_CREATE_SPANNER_INSTANCE") is not None +) +USE_EMULATOR = os.getenv("SPANNER_EMULATOR_HOST") is not None + +if CREATE_INSTANCE: + INSTANCE_ID = "google-cloud" + unique_resource_id("-") +else: + INSTANCE_ID = os.environ.get( + "GOOGLE_CLOUD_TESTS_SPANNER_INSTANCE", "google-cloud-python-systest" + ) +EXISTING_INSTANCES = [] + +DDL_STATEMENTS = ( + """CREATE TABLE contacts ( + contact_id INT64, + first_name STRING(1024), + last_name STRING(1024), + email STRING(1024) + ) + PRIMARY KEY (contact_id)""", +) + + +class Config(object): + """Run-time configuration to be modified at set-up. + + This is a mutable stand-in to allow test set-up to modify + global state. + """ + + CLIENT = None + INSTANCE_CONFIG = None + INSTANCE = None + + +def _list_instances(): + return list(Config.CLIENT.list_instances()) + + +def setUpModule(): + if USE_EMULATOR: + from google.auth.credentials import AnonymousCredentials + + emulator_project = os.getenv("GCLOUD_PROJECT", "emulator-test-project") + Config.CLIENT = Client( + project=emulator_project, credentials=AnonymousCredentials() + ) + else: + Config.CLIENT = Client() + + retry = RetryErrors(exceptions.ServiceUnavailable) + + configs = list(retry(Config.CLIENT.list_instance_configs)()) + + instances = retry(_list_instances)() + EXISTING_INSTANCES[:] = instances + + if CREATE_INSTANCE: + if not USE_EMULATOR: + # Defend against back-end returning configs for regions we aren't + # actually allowed to use. + configs = [config for config in configs if "-us-" in config.name] + + if not configs: + raise ValueError("List instance configs failed in module set up.") + + Config.INSTANCE_CONFIG = configs[0] + config_name = configs[0].name + + Config.INSTANCE = Config.CLIENT.instance(INSTANCE_ID, config_name) + created_op = Config.INSTANCE.create() + created_op.result(30) # block until completion + else: + Config.INSTANCE = Config.CLIENT.instance(INSTANCE_ID) + Config.INSTANCE.reload() + + +def tearDownModule(): + """Delete the test instance, if it was created.""" + if CREATE_INSTANCE: + Config.INSTANCE.delete() + + +class TestTransactionsManagement(unittest.TestCase): + """Transactions management support tests.""" + + DATABASE_NAME = "db-api-transactions-management" + + @classmethod + def setUpClass(cls): + """Create a test database.""" + cls._db = Config.INSTANCE.database( + cls.DATABASE_NAME, + ddl_statements=DDL_STATEMENTS, + pool=BurstyPool(labels={"testcase": "database_api"}), + ) + cls._db.create().result(30) # raises on failure / timeout. + + @classmethod + def tearDownClass(cls): + """Delete the test database.""" + cls._db.drop() def tearDown(self): - # TODO: Implement this method - pass + """Clear the test table after every test.""" + self._db.run_in_transaction(clear_table) + + def test_commit(self): + """Test committing a transaction with several statements.""" + want_row = ( + 1, + "updated-first-name", + "last-name", + "test.email_updated@domen.ru", + ) + # connect to the test database + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + # execute several DML statements within one transaction + cursor.execute( + """ +INSERT INTO contacts (contact_id, first_name, last_name, email) +VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + cursor.execute( + """ +UPDATE contacts +SET first_name = 'updated-first-name' +WHERE first_name = 'first-name' +""" + ) + cursor.execute( + """ +UPDATE contacts +SET email = 'test.email_updated@domen.ru' +WHERE email = 'test.email@domen.ru' +""" + ) + conn.commit() + + # read the resulting data from the database + cursor.execute("SELECT * FROM contacts") + got_rows = cursor.fetchall() + conn.commit() + + self.assertEqual(got_rows, [want_row]) + + cursor.close() + conn.close() + + def test_rollback(self): + """Test rollbacking a transaction with several statements.""" + want_row = (2, "first-name", "last-name", "test.email@domen.ru") + # connect to the test database + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + cursor.execute( + """ +INSERT INTO contacts (contact_id, first_name, last_name, email) +VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + conn.commit() + + # execute several DMLs with one transaction + cursor.execute( + """ +UPDATE contacts +SET first_name = 'updated-first-name' +WHERE first_name = 'first-name' +""" + ) + cursor.execute( + """ +UPDATE contacts +SET email = 'test.email_updated@domen.ru' +WHERE email = 'test.email@domen.ru' +""" + ) + conn.rollback() + + # read the resulting data from the database + cursor.execute("SELECT * FROM contacts") + got_rows = cursor.fetchall() + conn.commit() + + self.assertEqual(got_rows, [want_row]) + + cursor.close() + conn.close() + + def test_autocommit_mode_change(self): + """Test auto committing a transaction on `autocommit` mode change.""" + want_row = ( + 2, + "updated-first-name", + "last-name", + "test.email@domen.ru", + ) + # connect to the test database + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + cursor.execute( + """ +INSERT INTO contacts (contact_id, first_name, last_name, email) +VALUES (2, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + cursor.execute( + """ +UPDATE contacts +SET first_name = 'updated-first-name' +WHERE first_name = 'first-name' +""" + ) + conn.autocommit = True + + # read the resulting data from the database + cursor.execute("SELECT * FROM contacts") + got_rows = cursor.fetchall() + + self.assertEqual(got_rows, [want_row]) + + cursor.close() + conn.close() + + def test_rollback_on_connection_closing(self): + """ + When closing a connection all the pending transactions + must be rollbacked. Testing if it's working this way. + """ + want_row = (1, "first-name", "last-name", "test.email@domen.ru") + # connect to the test database + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + cursor.execute( + """ +INSERT INTO contacts (contact_id, first_name, last_name, email) +VALUES (1, 'first-name', 'last-name', 'test.email@domen.ru') + """ + ) + conn.commit() + + cursor.execute( + """ +UPDATE contacts +SET first_name = 'updated-first-name' +WHERE first_name = 'first-name' +""" + ) + conn.close() + + # connect again, as the previous connection is no-op after closing + conn = Connection(Config.INSTANCE, self._db) + cursor = conn.cursor() + + # read the resulting data from the database + cursor.execute("SELECT * FROM contacts") + got_rows = cursor.fetchall() + conn.commit() + + self.assertEqual(got_rows, [want_row]) + + cursor.close() + conn.close() + - def test_api(self): - # An dummy stub to avoid `exit code 5` errors - # TODO: Replace this with an actual system test method - self.assertTrue(True) +def clear_table(transaction): + """Clear the test table.""" + transaction.execute_update("DELETE FROM contacts WHERE true") diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 2cbd6ac1ed..d545472c57 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -7,6 +7,7 @@ """Cloud Spanner DB-API Connection class unit tests.""" import unittest +import warnings from unittest import mock @@ -33,6 +34,100 @@ def _get_client_info(self): return ClientInfo(user_agent=self.USER_AGENT) + def _make_connection(self): + from google.cloud.spanner_dbapi import Connection + from google.cloud.spanner_v1.instance import Instance + + # We don't need a real Client object to test the constructor + instance = Instance(self.INSTANCE, client=None) + database = instance.database(self.DATABASE) + return Connection(instance, database) + + def test_property_autocommit_setter(self): + from google.cloud.spanner_dbapi import Connection + + connection = Connection(self.INSTANCE, self.DATABASE) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.commit" + ) as mock_commit: + connection.autocommit = True + mock_commit.assert_called_once_with() + self.assertEqual(connection._autocommit, True) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.commit" + ) as mock_commit: + connection.autocommit = False + mock_commit.assert_not_called() + self.assertEqual(connection._autocommit, False) + + def test_property_database(self): + from google.cloud.spanner_v1.database import Database + + connection = self._make_connection() + self.assertIsInstance(connection.database, Database) + self.assertEqual(connection.database, connection._database) + + def test_property_instance(self): + from google.cloud.spanner_v1.instance import Instance + + connection = self._make_connection() + self.assertIsInstance(connection.instance, Instance) + self.assertEqual(connection.instance, connection._instance) + + def test__session_checkout(self): + from google.cloud.spanner_dbapi import Connection + + with mock.patch( + "google.cloud.spanner_v1.database.Database", + ) as mock_database: + mock_database._pool = mock.MagicMock() + mock_database._pool.get = mock.MagicMock( + return_value="db_session_pool" + ) + connection = Connection(self.INSTANCE, mock_database) + + connection._session_checkout() + mock_database._pool.get.assert_called_once_with() + self.assertEqual(connection._session, "db_session_pool") + + connection._session = "db_session" + connection._session_checkout() + self.assertEqual(connection._session, "db_session") + + def test__release_session(self): + from google.cloud.spanner_dbapi import Connection + + with mock.patch( + "google.cloud.spanner_v1.database.Database", + ) as mock_database: + mock_database._pool = mock.MagicMock() + mock_database._pool.put = mock.MagicMock() + connection = Connection(self.INSTANCE, mock_database) + connection._session = "session" + + connection._release_session() + mock_database._pool.put.assert_called_once_with("session") + self.assertIsNone(connection._session) + + def test_transaction_checkout(self): + from google.cloud.spanner_dbapi import Connection + + connection = Connection(self.INSTANCE, self.DATABASE) + connection._session_checkout = mock_checkout = mock.MagicMock( + autospec=True + ) + connection.transaction_checkout() + mock_checkout.assert_called_once_with() + + connection._transaction = mock_transaction = mock.MagicMock() + mock_transaction.committed = mock_transaction.rolled_back = False + self.assertEqual(connection.transaction_checkout(), mock_transaction) + + connection._autocommit = True + self.assertIsNone(connection.transaction_checkout()) + def test_close(self): from google.cloud.spanner_dbapi import connect, InterfaceError @@ -53,32 +148,73 @@ def test_close(self): with self.assertRaises(InterfaceError): connection.cursor() - def test_commit(self): - from google.cloud.spanner_dbapi import Connection, InterfaceError + connection._transaction = mock_transaction = mock.MagicMock() + mock_transaction.committed = mock_transaction.rolled_back = False + mock_transaction.rollback = mock_rollback = mock.MagicMock() + connection.close() + mock_rollback.assert_called_once_with() + + @mock.patch.object(warnings, "warn") + def test_commit(self, mock_warn): + from google.cloud.spanner_dbapi import Connection + from google.cloud.spanner_dbapi.connection import ( + AUTOCOMMIT_MODE_WARNING, + ) connection = Connection(self.INSTANCE, self.DATABASE) with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection.run_prior_DDL_statements" - ) as run_ddl_mock: + "google.cloud.spanner_dbapi.connection.Connection._release_session" + ) as mock_release: connection.commit() - run_ddl_mock.assert_called_once_with() + mock_release.assert_not_called() - connection.is_closed = True + connection._transaction = mock_transaction = mock.MagicMock() + mock_transaction.commit = mock_commit = mock.MagicMock() - with self.assertRaises(InterfaceError): + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection._release_session" + ) as mock_release: connection.commit() + mock_commit.assert_called_once_with() + mock_release.assert_called_once_with() - def test_rollback(self): + connection._autocommit = True + connection.commit() + mock_warn.assert_called_once_with( + AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 + ) + + @mock.patch.object(warnings, "warn") + def test_rollback(self, mock_warn): from google.cloud.spanner_dbapi import Connection + from google.cloud.spanner_dbapi.connection import ( + AUTOCOMMIT_MODE_WARNING, + ) connection = Connection(self.INSTANCE, self.DATABASE) with mock.patch( - "google.cloud.spanner_dbapi.connection.Connection._raise_if_closed" - ) as check_closed_mock: + "google.cloud.spanner_dbapi.connection.Connection._release_session" + ) as mock_release: + connection.rollback() + mock_release.assert_not_called() + + connection._transaction = mock_transaction = mock.MagicMock() + mock_transaction.rollback = mock_rollback = mock.MagicMock() + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection._release_session" + ) as mock_release: connection.rollback() - check_closed_mock.assert_called_once_with() + mock_rollback.assert_called_once_with() + mock_release.assert_called_once_with() + + connection._autocommit = True + connection.rollback() + mock_warn.assert_called_once_with( + AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 + ) def test_run_prior_DDL_statements(self): from google.cloud.spanner_dbapi import Connection, InterfaceError @@ -92,7 +228,7 @@ def test_run_prior_DDL_statements(self): mock_database.update_ddl.assert_not_called() ddl = ["ddl"] - connection.ddl_statements = ddl + connection._ddl_statements = ddl connection.run_prior_DDL_statements() mock_database.update_ddl.assert_called_once_with(ddl) @@ -151,3 +287,32 @@ def test_connect_database_not_found(self): ): with self.assertRaises(ValueError): connect("test-instance", "test-database") + + def test_default_sessions_pool(self): + from google.cloud.spanner_dbapi import connect + + with mock.patch("google.cloud.spanner_v1.instance.Instance.database"): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + self.assertIsNotNone(connection.database._pool) + + def test_sessions_pool(self): + from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_v1.pool import FixedSizePool + + database_id = "test-database" + pool = FixedSizePool() + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.database" + ) as database_mock: + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + connect("test-instance", database_id, pool=pool) + database_mock.assert_called_once_with(database_id, pool=pool) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 337a645736..09288df94e 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -122,6 +122,18 @@ def test_execute_attribute_error(self): with self.assertRaises(AttributeError): cursor.execute(sql="") + def test_execute_autocommit_off(self): + from google.cloud.spanner_dbapi.utils import PeekIterator + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor.connection._autocommit = False + cursor.connection.transaction_checkout = mock.MagicMock(autospec=True) + + cursor.execute("sql") + self.assertIsInstance(cursor._result_set, mock.MagicMock) + self.assertIsInstance(cursor._itr, PeekIterator) + def test_execute_statement(self): from google.cloud.spanner_dbapi import parse_utils @@ -135,7 +147,7 @@ def test_execute_statement(self): sql = "sql" cursor.execute(sql=sql) mock_classify_stmt.assert_called_once_with(sql) - self.assertEqual(cursor.connection.ddl_statements, [sql]) + self.assertEqual(cursor.connection._ddl_statements, [sql]) with mock.patch( "google.cloud.spanner_dbapi.parse_utils.classify_stmt", @@ -145,6 +157,7 @@ def test_execute_statement(self): "google.cloud.spanner_dbapi.cursor.Cursor._handle_DQL", return_value=parse_utils.STMT_NON_UPDATING, ) as mock_handle_ddl: + connection.autocommit = True sql = "sql" cursor.execute(sql=sql) mock_handle_ddl.assert_called_once_with(sql, None) @@ -163,6 +176,18 @@ def test_execute_statement(self): connection, sql, None ) + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + return_value="other_statement", + ): + cursor.connection._database = mock_db = mock.MagicMock() + mock_db.run_in_transaction = mock_run_in = mock.MagicMock() + sql = "sql" + cursor.execute(sql=sql) + mock_run_in.assert_called_once_with( + cursor._do_execute_update, sql, None + ) + def test_execute_integrity_error(self): from google.api_core import exceptions from google.cloud.spanner_dbapi.exceptions import IntegrityError