Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add row_count on Table #1074

Merged
merged 8 commits into from
Oct 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions python-sdk/src/astro/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any

from attr import define, field, fields_dict
from sqlalchemy import Column, MetaData
from sqlalchemy import Column, MetaData, func, select

from astro.airflow.datasets import Dataset
from astro.databases import create_database
Expand Down Expand Up @@ -129,10 +129,12 @@ def name(self, value: str) -> None:
@property
def row_count(self) -> Any:
    """
    Return the total number of rows in the underlying database table.

    Resolves a database handle from this table's ``conn_id``, reflects the
    table into a SQLAlchemy ``Table`` object, and executes a
    ``SELECT COUNT(*)`` against it.

    :return: row count as a scalar reported by the database driver
    """
    db = create_database(self.conn_id)
    sqla_table = db.get_sqla_table(table=self)
    # COUNT("*") tallies a non-null constant per row — equivalent to COUNT(*).
    query = select(func.count("*")).select_from(sqla_table)
    # run_sql returns a result proxy; .scalar() extracts the single count value.
    return db.run_sql(query).scalar()

def to_json(self):
return {
Expand Down
55 changes: 42 additions & 13 deletions python-sdk/tests/extractors/test_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from astro.sql import AppendOperator, MergeOperator
from astro.sql.operators.load_file import LoadFileOperator
from astro.table import Metadata, Table
from tests.utils.airflow import create_context

TEST_FILE_LOCATION = "gs://astro-sdk/workspace/sample_pattern"
TEST_TABLE = "test-table"
Expand Down Expand Up @@ -98,12 +99,24 @@ def test_append_op_extract_on_complete():
Test extractor ``extract_on_complete`` get called and collect lineage for append operator
"""
task_id = "append_table"
src = Table(conn_id="bigquery", metadata=Metadata(schema="astro"))
target = Table(conn_id="bigquery", metadata=Metadata(schema="astro"))

src_table = LoadFileOperator(
task_id="load_file",
input_file=File(path="gs://astro-sdk/workspace/sample_pattern", filetype=FileType.CSV),
output_table=Table(conn_id="gcp_conn"),
).execute({})

target_table = LoadFileOperator(
task_id="load_file",
input_file=File(path="gs://astro-sdk/workspace/sample_pattern", filetype=FileType.CSV),
output_table=Table(conn_id="gcp_conn"),
).execute({})

op = AppendOperator(
source_table=src,
target_table=target,
source_table=src_table,
target_table=target_table,
)

tzinfo = pendulum.timezone("UTC")
execution_date = timezone.datetime(2022, 1, 1, 1, 0, 0, tzinfo=tzinfo)
task_instance = TaskInstance(task=op, run_id=execution_date)
Expand All @@ -114,8 +127,7 @@ def test_append_op_extract_on_complete():

task_meta = python_sdk_extractor.extract_on_complete(task_instance)
assert task_meta.name == f"adhoc_airflow.{task_id}"

assert task_meta.inputs[0].name == f"astronomer-dag-authoring.astro.{src.name}"
assert task_meta.inputs[0].name == f"astronomer-dag-authoring.astronomer-dag-authoring.{src_table.name}"
assert task_meta.inputs[0].namespace == "bigquery"
assert task_meta.inputs[0].facets is not None
assert len(task_meta.job_facets) > 0
Expand All @@ -128,11 +140,20 @@ def test_merge_op_extract_on_complete():
Test extractor ``extract_on_complete`` get called and collect lineage for merge operator
"""
task_id = "merge"
src = Table(conn_id="bigquery", metadata=Metadata(schema="astro"))
target = Table(conn_id="bigquery", metadata=Metadata(schema="astro"))
src_table = LoadFileOperator(
task_id="load_file",
input_file=File(path="gs://astro-sdk/workspace/sample_pattern", filetype=FileType.CSV),
output_table=Table(conn_id="gcp_conn", metadata=Metadata(schema="astro")),
).execute({})

target_table = LoadFileOperator(
task_id="load_file",
input_file=File(path="gs://astro-sdk/workspace/sample_pattern", filetype=FileType.CSV),
output_table=Table(conn_id="gcp_conn", metadata=Metadata(schema="astro")),
).execute({})
op = MergeOperator(
source_table=src,
target_table=target,
source_table=src_table,
target_table=target_table,
target_conflict_columns=["id"],
columns=["id", "name"],
if_conflicts="update",
Expand All @@ -147,7 +168,7 @@ def test_merge_op_extract_on_complete():

task_meta = python_sdk_extractor.extract_on_complete(task_instance)
assert task_meta.name == f"adhoc_airflow.{task_id}"
assert task_meta.inputs[0].name == f"astronomer-dag-authoring.astro.{src.name}"
assert task_meta.inputs[0].name == f"astronomer-dag-authoring.astro.{src_table.name}"
assert task_meta.inputs[0].namespace == "bigquery"
assert task_meta.inputs[0].facets is not None
assert len(task_meta.job_facets) > 0
Expand All @@ -161,8 +182,15 @@ def test_python_sdk_transform_extract_on_complete():
operator's metadata that needs to be extracted as per OpenLineage
for TransformOperator.
"""
imdb_table = (Table(name="imdb", conn_id="sqlite_default"),)
output_table = Table(name="test_name", conn_id="sqlite_default")
imdb_table = LoadFileOperator(
task_id="load_file",
input_file=File(
path="https://raw.githubusercontent.com/astronomer/astro-sdk/main/tests/data/imdb_v2.csv"
),
output_table=Table(conn_id="gcp_conn", metadata=Metadata(schema="astro")),
).execute({})

output_table = Table(name="test_name", conn_id="gcp_conn", metadata=Metadata(schema="astro"))
task_id = "top_five_animations"

@aql.transform
Expand All @@ -171,6 +199,7 @@ def top_five_animations(input_table: Table) -> str:

task = top_five_animations(input_table=imdb_table, output_table=output_table)

task.operator.execute(context=create_context(task.operator))
tzinfo = pendulum.timezone("UTC")
execution_date = timezone.datetime(2022, 1, 1, 1, 0, 0, tzinfo=tzinfo)
task_instance = TaskInstance(task=task.operator, run_id=execution_date)
Expand Down
20 changes: 19 additions & 1 deletion python-sdk/tests/utils/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import pytest

from astro.files import File
from astro.sql import LoadFileOperator
from astro.sql.operators.transform import TransformOperator
from astro.table import BaseTable, Table
from astro.table import BaseTable, Metadata, Table
from astro.utils.table import find_first_table


Expand Down Expand Up @@ -89,3 +91,19 @@ def test_find_first_table(kwargs, return_type):
@mock.patch("airflow.models.xcom_arg.PlainXComArg.resolve", return_value=Table())
def test_find_first_table_with_xcom_arg(xcom_arg_resolve, kwargs, return_type):
assert isinstance(find_first_table(context={}, **kwargs), return_type)


@pytest.mark.integration
def test_row_count():
    """
    Load a CSV file into a BigQuery table and verify that ``Table.row_count``
    reports a positive number of rows.

    Integration test: requires a configured ``gcp_conn`` Airflow connection
    and network access to fetch the sample CSV.
    """
    imdb_table = LoadFileOperator(
        task_id="load_file",
        input_file=File(
            path="https://raw.githubusercontent.com/astronomer/astro-sdk/main/tests/data/imdb_v2.csv"
        ),
        output_table=Table(conn_id="gcp_conn", metadata=Metadata(schema="astro")),
    ).execute({})

    # The sample dataset is non-empty, so a correct COUNT(*) must be > 0.
    assert imdb_table.row_count > 0