32 changes: 16 additions & 16 deletions airflow/providers/google/cloud/example_dags/example_bigtable.py
@@ -60,22 +60,22 @@
 from airflow.utils.dates import days_ago
 
 GCP_PROJECT_ID = getenv('GCP_PROJECT_ID', 'example-project')
-CBT_INSTANCE_ID = getenv('CBT_INSTANCE_ID', 'some-instance-id')
-CBT_INSTANCE_DISPLAY_NAME = getenv('CBT_INSTANCE_DISPLAY_NAME', 'Human-readable name')
+CBT_INSTANCE_ID = getenv('GCP_BIG_TABLE_INSTANCE_ID', 'some-instance-id')
+CBT_INSTANCE_DISPLAY_NAME = getenv('GCP_BIG_TABLE_INSTANCE_DISPLAY_NAME', 'Human-readable name')
 CBT_INSTANCE_DISPLAY_NAME_UPDATED = getenv(
-    "CBT_INSTANCE_DISPLAY_NAME_UPDATED", "Human-readable name - updated"
+    "GCP_BIG_TABLE_INSTANCE_DISPLAY_NAME_UPDATED", f"{CBT_INSTANCE_DISPLAY_NAME} - updated"
 )
-CBT_INSTANCE_TYPE = getenv('CBT_INSTANCE_TYPE', '2')
-CBT_INSTANCE_TYPE_PROD = getenv('CBT_INSTANCE_TYPE_PROD', '1')
-CBT_INSTANCE_LABELS = getenv('CBT_INSTANCE_LABELS', '{}')
-CBT_INSTANCE_LABELS_UPDATED = getenv('CBT_INSTANCE_LABELS', '{"env": "prod"}')
-CBT_CLUSTER_ID = getenv('CBT_CLUSTER_ID', 'some-cluster-id')
-CBT_CLUSTER_ZONE = getenv('CBT_CLUSTER_ZONE', 'europe-west1-b')
-CBT_CLUSTER_NODES = getenv('CBT_CLUSTER_NODES', '3')
-CBT_CLUSTER_NODES_UPDATED = getenv('CBT_CLUSTER_NODES_UPDATED', '5')
-CBT_CLUSTER_STORAGE_TYPE = getenv('CBT_CLUSTER_STORAGE_TYPE', '2')
-CBT_TABLE_ID = getenv('CBT_TABLE_ID', 'some-table-id')
-CBT_POKE_INTERVAL = getenv('CBT_POKE_INTERVAL', '60')
+CBT_INSTANCE_TYPE = getenv('GCP_BIG_TABLE_INSTANCE_TYPE', '2')
+CBT_INSTANCE_TYPE_PROD = getenv('GCP_BIG_TABLE_INSTANCE_TYPE_PROD', '1')
+CBT_INSTANCE_LABELS = getenv('GCP_BIG_TABLE_INSTANCE_LABELS', '{}')
+CBT_INSTANCE_LABELS_UPDATED = getenv('GCP_BIG_TABLE_INSTANCE_LABELS_UPDATED', '{"env": "prod"}')
+CBT_CLUSTER_ID = getenv('GCP_BIG_TABLE_CLUSTER_ID', 'some-cluster-id')
+CBT_CLUSTER_ZONE = getenv('GCP_BIG_TABLE_CLUSTER_ZONE', 'europe-west1-b')
+CBT_CLUSTER_NODES = getenv('GCP_BIG_TABLE_CLUSTER_NODES', '3')
+CBT_CLUSTER_NODES_UPDATED = getenv('GCP_BIG_TABLE_CLUSTER_NODES_UPDATED', '5')
+CBT_CLUSTER_STORAGE_TYPE = getenv('GCP_BIG_TABLE_CLUSTER_STORAGE_TYPE', '2')
+CBT_TABLE_ID = getenv('GCP_BIG_TABLE_TABLE_ID', 'some-table-id')
+CBT_POKE_INTERVAL = getenv('GCP_BIG_TABLE_POKE_INTERVAL', '60')
 
 
 with models.DAG(
@@ -93,8 +93,8 @@
         instance_display_name=CBT_INSTANCE_DISPLAY_NAME,
         instance_type=int(CBT_INSTANCE_TYPE),
         instance_labels=json.loads(CBT_INSTANCE_LABELS),
-        cluster_nodes=int(CBT_CLUSTER_NODES),
-        cluster_storage_type=CBT_CLUSTER_STORAGE_TYPE,
+        cluster_nodes=None,
+        cluster_storage_type=int(CBT_CLUSTER_STORAGE_TYPE),
         task_id='create_instance_task',
     )
     create_instance_task2 = BigtableCreateInstanceOperator(
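Note on the example change above: CBT_INSTANCE_TYPE defaults to '2', which is enums.Instance.Type.DEVELOPMENT in google-cloud-bigtable, and the Bigtable Admin API does not accept an explicit node count for development instances, so the example now passes cluster_nodes=None. A minimal sketch of the mapping (the assertions are illustrative, not part of this PR):

from google.cloud.bigtable import enums

# Instance.Type is an IntEnum in google.cloud.bigtable.enums, so the
# example DAG can pass int(CBT_INSTANCE_TYPE) straight through.
assert enums.Instance.Type.PRODUCTION == 1   # CBT_INSTANCE_TYPE_PROD default '1'
assert enums.Instance.Type.DEVELOPMENT == 2  # CBT_INSTANCE_TYPE default '2'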
9 changes: 8 additions & 1 deletion airflow/providers/google/cloud/hooks/bigtable.py
@@ -169,7 +169,14 @@ def create_instance(
             instance_labels,
         )
 
-        clusters = [instance.cluster(main_cluster_id, main_cluster_zone, cluster_nodes, cluster_storage_type)]
+        cluster_kwargs = dict(
+            cluster_id=main_cluster_id,
+            location_id=main_cluster_zone,
+            default_storage_type=cluster_storage_type,
+        )
+        if instance_type != enums.Instance.Type.DEVELOPMENT and cluster_nodes:
+            cluster_kwargs["serve_nodes"] = cluster_nodes
+        clusters = [instance.cluster(**cluster_kwargs)]
         if replica_cluster_id and replica_cluster_zone:
             warnings.warn(
                 "The replica_cluster_id and replica_cluster_zone parameter have been deprecated."
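The hook change above is the core fix: serve_nodes is forwarded to Instance.cluster() only when the instance is not a DEVELOPMENT instance and a node count was actually supplied, since the Admin API rejects serve_nodes for development instances. A hedged usage sketch (connection id, project, and identifiers are placeholders, not from this PR):

from airflow.providers.google.cloud.hooks.bigtable import BigtableHook
from google.cloud.bigtable import enums

hook = BigtableHook(gcp_conn_id='google_cloud_default')

# With this patch the call below builds the main cluster without
# serve_nodes, so creating a DEVELOPMENT instance no longer fails.
instance = hook.create_instance(
    instance_id='dev-instance',         # placeholder
    main_cluster_id='dev-instance-c1',  # placeholder
    main_cluster_zone='europe-west1-b',
    project_id='example-project',
    instance_type=enums.Instance.Type.DEVELOPMENT,
    cluster_nodes=None,
    cluster_storage_type=enums.StorageType.SSD,
)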
58 changes: 55 additions & 3 deletions tests/providers/google/cloud/hooks/test_bigtable.py
@@ -309,7 +309,7 @@ def test_create_instance(self, get_client, instance_create, mock_project_id):
     @mock.patch('google.cloud.bigtable.instance.Instance.cluster')
     @mock.patch('google.cloud.bigtable.instance.Instance.create')
     @mock.patch('airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_client')
-    def test_create_instance_with_one_replica_cluster(
+    def test_create_instance_with_one_replica_cluster_production(
         self, get_client, instance_create, cluster, mock_project_id
     ):
         operation = mock.Mock()
@@ -325,10 +325,57 @@ def test_create_instance_with_one_replica_cluster(
             cluster_nodes=1,
             cluster_storage_type=enums.StorageType.SSD,
             project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
+            instance_type=enums.Instance.Type.PRODUCTION,
         )
         cluster.assert_has_calls(
             [
-                unittest.mock.call(CBT_CLUSTER, CBT_ZONE, 1, enums.StorageType.SSD),
+                unittest.mock.call(
+                    cluster_id=CBT_CLUSTER,
+                    location_id=CBT_ZONE,
+                    serve_nodes=1,
+                    default_storage_type=enums.StorageType.SSD,
+                ),
                 unittest.mock.call(
                     CBT_REPLICA_CLUSTER_ID, CBT_REPLICA_CLUSTER_ZONE, 1, enums.StorageType.SSD
                 ),
+            ],
+            any_order=True,
+        )
+        get_client.assert_called_once_with(project_id='example-project')
+        instance_create.assert_called_once_with(clusters=mock.ANY)
+        assert res.instance_id == 'instance'
+
+    @mock.patch(
+        'airflow.providers.google.common.hooks.base_google.GoogleBaseHook.project_id',
+        new_callable=PropertyMock,
+        return_value=GCP_PROJECT_ID_HOOK_UNIT_TEST,
+    )
+    @mock.patch('google.cloud.bigtable.instance.Instance.cluster')
+    @mock.patch('google.cloud.bigtable.instance.Instance.create')
+    @mock.patch('airflow.providers.google.cloud.hooks.bigtable.BigtableHook._get_client')
+    def test_create_instance_with_one_replica_cluster_development(
+        self, get_client, instance_create, cluster, mock_project_id
+    ):
+        operation = mock.Mock()
+        operation.result_return_value = Instance(instance_id=CBT_INSTANCE, client=get_client)
+        instance_create.return_value = operation
+
+        res = self.bigtable_hook_default_project_id.create_instance(
+            instance_id=CBT_INSTANCE,
+            main_cluster_id=CBT_CLUSTER,
+            main_cluster_zone=CBT_ZONE,
+            replica_cluster_id=CBT_REPLICA_CLUSTER_ID,
+            replica_cluster_zone=CBT_REPLICA_CLUSTER_ZONE,
+            cluster_nodes=1,
+            cluster_storage_type=enums.StorageType.SSD,
+            project_id=GCP_PROJECT_ID_HOOK_UNIT_TEST,
+            instance_type=enums.Instance.Type.DEVELOPMENT,
+        )
+        cluster.assert_has_calls(
+            [
+                unittest.mock.call(
+                    cluster_id=CBT_CLUSTER, location_id=CBT_ZONE, default_storage_type=enums.StorageType.SSD
+                ),
+                unittest.mock.call(
+                    CBT_REPLICA_CLUSTER_ID, CBT_REPLICA_CLUSTER_ZONE, 1, enums.StorageType.SSD
+                ),
@@ -365,7 +412,12 @@ def test_create_instance_with_multiple_replica_clusters(
         )
         cluster.assert_has_calls(
             [
-                unittest.mock.call(CBT_CLUSTER, CBT_ZONE, 1, enums.StorageType.SSD),
+                unittest.mock.call(
+                    cluster_id=CBT_CLUSTER,
+                    location_id=CBT_ZONE,
+                    serve_nodes=1,
+                    default_storage_type=enums.StorageType.SSD,
+                ),
                 unittest.mock.call('replica-1', 'us-west1-a', 1, enums.StorageType.SSD),
                 unittest.mock.call('replica-2', 'us-central1-f', 1, enums.StorageType.SSD),
                 unittest.mock.call('replica-3', 'us-east1-d', 1, enums.StorageType.SSD),
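A side note on the assertion style in these tests: assert_has_calls matches recorded calls exactly, so the development-instance test proves serve_nodes was omitted simply by expecting a cluster() call without that keyword. A standalone illustration (not part of the diff):

from unittest import mock

cluster = mock.Mock()
cluster(cluster_id='c1', location_id='zone-a', default_storage_type=2)

# Passes: the recorded call carries no serve_nodes kwarg.
cluster.assert_has_calls(
    [mock.call(cluster_id='c1', location_id='zone-a', default_storage_type=2)]
)

# Would raise AssertionError, because serve_nodes was never passed:
# cluster.assert_has_calls(
#     [mock.call(cluster_id='c1', location_id='zone-a', serve_nodes=1,
#                default_storage_type=2)]
# )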
tests/providers/google/cloud/operators/test_bigtable_system.py
@@ -15,16 +15,13 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import os
 
 import pytest
 
+from airflow.providers.google.cloud.example_dags.example_bigtable import CBT_INSTANCE_ID, GCP_PROJECT_ID
 from tests.providers.google.cloud.utils.gcp_authenticator import GCP_BIGTABLE_KEY
 from tests.test_utils.gcp_system_helpers import CLOUD_DAG_FOLDER, GoogleSystemTest, provide_gcp_context
 
-GCP_PROJECT_ID = os.environ.get('GCP_PROJECT_ID', 'example-project')
-CBT_INSTANCE = os.environ.get('CBT_INSTANCE_ID', 'testinstance')
-
 
 @pytest.mark.backend("mysql", "postgres")
 @pytest.mark.credential_file(GCP_BIGTABLE_KEY)
@@ -45,7 +42,7 @@ def tearDown(self):
                 '--verbosity=none',
                 'instances',
                 'delete',
-                CBT_INSTANCE,
+                CBT_INSTANCE_ID,
             ],
             key=GCP_BIGTABLE_KEY,
         )