Skip to content

Commit

Permalink
[AIRFLOW-6896] AzureCosmosDBHook: Move DB call out of __init__ (#7520)
Browse files Browse the repository at this point in the history
* [AIRFLOW-6896] AzureCosmosDBHook: Move DB call out of __init__

* Fix tests

* fixup

* Fix test
  • Loading branch information
kaxil committed Feb 24, 2020
1 parent 4bec1cc commit f0e2421
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 27 deletions.
29 changes: 16 additions & 13 deletions airflow/providers/microsoft/azure/hooks/azure_cosmos.py
Expand Up @@ -25,7 +25,7 @@
"""
import uuid

import azure.cosmos.cosmos_client as cosmos_client
from azure.cosmos.cosmos_client import CosmosClient
from azure.cosmos.errors import HTTPFailure

from airflow.exceptions import AirflowBadRequest
Expand All @@ -46,28 +46,30 @@ class AzureCosmosDBHook(BaseHook):

def __init__(self, azure_cosmos_conn_id='azure_cosmos_default'):
self.conn_id = azure_cosmos_conn_id
self.connection = self.get_connection(self.conn_id)
self.extras = self.connection.extra_dejson
self._conn = None

self.endpoint_uri = self.connection.login
self.master_key = self.connection.password
self.default_database_name = self.extras.get('database_name')
self.default_collection_name = self.extras.get('collection_name')
self.cosmos_client = None
self.default_database_name = None
self.default_collection_name = None

def get_conn(self):
"""
Return a cosmos db client.
"""
if self.cosmos_client is not None:
return self.cosmos_client
if not self._conn:
conn = self.get_connection(self.conn_id)
extras = conn.extra_dejson
endpoint_uri = conn.login
master_key = conn.password

# Initialize the Python Azure Cosmos DB client
self.cosmos_client = cosmos_client.CosmosClient(self.endpoint_uri, {'masterKey': self.master_key})
self.default_database_name = extras.get('database_name')
self.default_collection_name = extras.get('collection_name')

return self.cosmos_client
# Initialize the Python Azure Cosmos DB client
self._conn = CosmosClient(endpoint_uri, {'masterKey': master_key})
return self._conn

def __get_database_name(self, database_name=None):
self.get_conn()
db_name = database_name
if db_name is None:
db_name = self.default_database_name
Expand All @@ -78,6 +80,7 @@ def __get_database_name(self, database_name=None):
return db_name

def __get_collection_name(self, collection_name=None):
self.get_conn()
coll_name = collection_name
if coll_name is None:
coll_name = self.default_collection_name
Expand Down
33 changes: 20 additions & 13 deletions tests/providers/microsoft/azure/hooks/test_azure_cosmos.py
Expand Up @@ -24,6 +24,7 @@
import uuid

import mock
from azure.cosmos.cosmos_client import CosmosClient

from airflow.exceptions import AirflowException
from airflow.models import Connection
Expand Down Expand Up @@ -53,25 +54,31 @@ def setUp(self):
)
)

@mock.patch('azure.cosmos.cosmos_client.CosmosClient')
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient', autospec=True)
def test_client(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
self.assertIsNone(hook._conn)
self.assertIsInstance(hook.get_conn(), CosmosClient)

@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_create_database(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
hook.create_database(self.test_database_name)
expected_calls = [mock.call().CreateDatabase({'id': self.test_database_name})]
mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)

@mock.patch('azure.cosmos.cosmos_client.CosmosClient')
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_create_database_exception(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
self.assertRaises(AirflowException, hook.create_database, None)

@mock.patch('azure.cosmos.cosmos_client.CosmosClient')
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_create_container_exception(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
self.assertRaises(AirflowException, hook.create_collection, None)

@mock.patch('azure.cosmos.cosmos_client.CosmosClient')
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_create_container(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
hook.create_collection(self.test_collection_name, self.test_database_name)
Expand All @@ -81,7 +88,7 @@ def test_create_container(self, mock_cosmos):
mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)

@mock.patch('azure.cosmos.cosmos_client.CosmosClient')
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_create_container_default(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
hook.create_collection(self.test_collection_name)
Expand All @@ -91,7 +98,7 @@ def test_create_container_default(self, mock_cosmos):
mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)

@mock.patch('azure.cosmos.cosmos_client.CosmosClient')
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_upsert_document_default(self, mock_cosmos):
test_id = str(uuid.uuid4())
mock_cosmos.return_value.CreateItem.return_value = {'id': test_id}
Expand All @@ -105,7 +112,7 @@ def test_upsert_document_default(self, mock_cosmos):
logging.getLogger().info(returned_item)
self.assertEqual(returned_item['id'], test_id)

@mock.patch('azure.cosmos.cosmos_client.CosmosClient')
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_upsert_document(self, mock_cosmos):
test_id = str(uuid.uuid4())
mock_cosmos.return_value.CreateItem.return_value = {'id': test_id}
Expand All @@ -125,7 +132,7 @@ def test_upsert_document(self, mock_cosmos):
logging.getLogger().info(returned_item)
self.assertEqual(returned_item['id'], test_id)

@mock.patch('azure.cosmos.cosmos_client.CosmosClient')
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_insert_documents(self, mock_cosmos):
test_id1 = str(uuid.uuid4())
test_id2 = str(uuid.uuid4())
Expand All @@ -149,17 +156,17 @@ def test_insert_documents(self, mock_cosmos):
{'data': 'data3', 'id': test_id3})]
logging.getLogger().info(returned_item)
mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)
mock_cosmos.assert_has_calls(expected_calls, any_order=True)

@mock.patch('azure.cosmos.cosmos_client.CosmosClient')
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_delete_database(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
hook.delete_database(self.test_database_name)
expected_calls = [mock.call().DeleteDatabase('dbs/test_database_name')]
mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)

@mock.patch('azure.cosmos.cosmos_client.CosmosClient')
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_delete_database_exception(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
self.assertRaises(AirflowException, hook.delete_database, None)
Expand All @@ -169,15 +176,15 @@ def test_delete_container_exception(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
self.assertRaises(AirflowException, hook.delete_collection, None)

@mock.patch('azure.cosmos.cosmos_client.CosmosClient')
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_delete_container(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
hook.delete_collection(self.test_collection_name, self.test_database_name)
expected_calls = [mock.call().DeleteContainer('dbs/test_database_name/colls/test_collection_name')]
mock_cosmos.assert_any_call(self.test_end_point, {'masterKey': self.test_master_key})
mock_cosmos.assert_has_calls(expected_calls)

@mock.patch('azure.cosmos.cosmos_client.CosmosClient')
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_delete_container_default(self, mock_cosmos):
hook = AzureCosmosDBHook(azure_cosmos_conn_id='azure_cosmos_test_key_id')
hook.delete_collection(self.test_collection_name)
Expand Down
Expand Up @@ -49,7 +49,7 @@ def setUp(self):
)
)

@mock.patch('azure.cosmos.cosmos_client.CosmosClient')
@mock.patch('airflow.providers.microsoft.azure.hooks.azure_cosmos.CosmosClient')
def test_insert_document(self, cosmos_mock):
test_id = str(uuid.uuid4())
cosmos_mock.return_value.CreateItem.return_value = {'id': test_id}
Expand Down

0 comments on commit f0e2421

Please sign in to comment.