In [None]:
# stdlib
import os
import sys

# Make the helper package imported below (`apis`) resolvable.
# `os.path.abspath(".")` is the kernel's working directory. `current_dir` is the
# *parent* of that directory, and it is the path prepended to `sys.path`.
notebook_dir = os.path.abspath(".")
current_dir = os.path.dirname(notebook_dir)

# Grandparent of the working directory; not used by the cells in this notebook.
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))

# Prepend only if it is not already first, so re-running this cell does not
# keep stacking duplicate entries onto sys.path.
if sys.path[:1] != [current_dir]:
    sys.path.insert(0, current_dir)

# To use the live APIs, uncomment the next line (before `apis` is imported):
# os.environ["TEST_BIGQUERY_APIS_LIVE"] = "True"
# local helpers (resolved through the sys.path entry inserted above)
from apis import make_schema
from apis import make_submit_query
from apis import make_test_query

In [None]:
# syft absolute
import syft as sy
from syft import test_settings

In [None]:
# Both dev servers are launched with identical options; only the name and the
# server side differ.
_shared_launch_kwargs = {
    "dev_mode": True,
    "n_consumers": 1,
    "create_producer": True,
    "port": "auto",
}

server_low = sy.orchestra.launch(
    name="bigquery-low", server_side_type="low", **_shared_launch_kwargs
)
server_high = sy.orchestra.launch(
    name="bigquery-high", server_side_type="high", **_shared_launch_kwargs
)

In [None]:
# Log in as admin on both servers. Credentials can be overridden through the
# environment; without overrides the hard-coded dev credentials are used.
admin_email = os.environ.get("SYFT_TEST_ADMIN_EMAIL", "info@openmined.org")
admin_password = os.environ.get("SYFT_TEST_ADMIN_PASSWORD", "changethis")

low_client = server_low.login(email=admin_email, password=admin_password)
high_client = server_high.login(email=admin_email, password=admin_password)

In [None]:
# Each server must already have exactly two worker pools before endpoints are
# attached to `this_worker_pool_name`.
# NOTE(review): presumably the default pool plus "bigquery-pool" created by an
# earlier setup step — confirm.
EXPECTED_WORKER_POOL_COUNT = 2

for side_label, side_client in (("high", high_client), ("low", low_client)):
    pool_count = len(side_client.worker_pools.get_all())
    assert pool_count == EXPECTED_WORKER_POOL_COUNT, (
        f"{side_label} side has {pool_count} worker pools, "
        f"expected {EXPECTED_WORKER_POOL_COUNT}"
    )

In [None]:
# Worker pool on which every custom endpoint below is registered.
this_worker_pool_name = "bigquery-pool"

In [None]:
# !pip list | grep bigquery

In [None]:
# !pip install db-dtypes google-cloud-bigquery

# Twin endpoints

In [None]:
# Mock side of the twin endpoint: settings enable the rate limiter at
# 10 calls per minute (as interpreted by `make_test_query`).
mock_settings = {"rate_limiter_enabled": True, "calls_per_min": 10}
mock_func = make_test_query(settings=mock_settings)

In [None]:
# Private side of the twin endpoint: the rate limiter is switched off.
private_func = make_test_query(settings=dict(rate_limiter_enabled=False))

In [None]:
# Twin endpoint pairing the rate-limited mock function with the private one,
# both executed on the shared worker pool.
new_endpoint = sy.TwinAPIEndpoint(
    path="bigquery.test_query",
    mock_function=mock_func,
    private_function=private_func,
    worker_pool=this_worker_pool_name,
    description="This endpoint allows to query Bigquery storage via SQL queries.",
)

high_client.custom_api.add(endpoint=new_endpoint)

In [None]:
# Raise the endpoint timeout to 120 s rather than the default (believed to be
# 60 s — NOTE(review): confirm), presumably to leave room for slow queries.
ENDPOINT_TIMEOUT_SECONDS = 120
high_client.api.services.api.update(
    endpoint_path="bigquery.test_query", endpoint_timeout=ENDPOINT_TIMEOUT_SECONDS
)

In [None]:
# Set hide_mock_definition on bigquery.test_query (presumably hides the mock
# function's source from non-admin users — confirm).
high_client.api.services.api.update(
    endpoint_path="bigquery.test_query", hide_mock_definition=True
)

In [None]:
# Schema endpoint on the shared pool; `calls_per_min` is passed through to
# `make_schema`'s settings.
schema_function = make_schema(
    worker_pool=this_worker_pool_name,
    settings=dict(calls_per_min=5),
)

In [None]:
# Register the schema endpoint, then refresh the client (presumably so the new
# endpoint is available under `high_client.api.services`).
high_client.custom_api.add(endpoint=schema_function)
high_client.refresh()

In [None]:
# Dataset / table / column names read from syft's test settings, each with a
# placeholder fallback.
_test_setting_defaults = [
    ("dataset_1", "dataset_1"),
    ("dataset_2", "dataset_2"),
    ("table_1", "table_1"),
    ("table_2", "table_2"),
    ("table_2_col_id", "table_id"),
    ("table_2_col_score", "colname"),
]
(
    dataset_1,
    dataset_2,
    table_1,
    table_2,
    table_2_col_id,
    table_2_col_score,
) = [
    test_settings.get(key, default=fallback)
    for key, fallback in _test_setting_defaults
]

In [None]:
# Smoke-test the mock side of bigquery.test_query with a small query.
result = high_client.api.services.bigquery.test_query.mock(
    sql_query=f"SELECT * FROM {dataset_1}.{table_1} LIMIT 10"
)
result

In [None]:
# Call the schema endpoint registered above and display its output.
high_client.api.services.bigquery.schema()

In [None]:
# Submit-query endpoint on the shared pool, with no extra settings.
submit_query_function = make_submit_query(
    worker_pool=this_worker_pool_name, settings=dict()
)

In [None]:
# Register the submit_query endpoint on the high side.
high_client.custom_api.add(endpoint=submit_query_function)

In [None]:
# Set hide_mock_definition on bigquery.submit_query, as done for test_query.
high_client.api.services.api.update(
    endpoint_path="bigquery.submit_query", hide_mock_definition=True
)

In [None]:
# Display the custom endpoints registered so far.
high_client.custom_api.api_endpoints()

In [None]:
# Three endpoints were added above: bigquery.test_query, the schema endpoint
# and bigquery.submit_query.
registered_endpoints = high_client.custom_api.api_endpoints()
assert len(registered_endpoints) == 3

In [None]:
# Both twin-style endpoints must be reachable through the client API.
assert high_client.api.services.bigquery.test_query
assert high_client.api.services.bigquery.submit_query

In [None]:
# Re-run the mock query now that submit_query has also been registered.
result = high_client.api.services.bigquery.test_query.mock(
    sql_query=f"SELECT * FROM {dataset_1}.{table_1} LIMIT 10"
)
result

In [None]:
# The mock endpoint must reject a table name that is not qualified with a
# dataset; the raised SyftException's public message is matched against the
# wildcard pattern (show=True presumably displays the error).
# TODO: since the new error-handling PR the exception message is printed
# multiple times; clean up the duplicate output.
with sy.raises(
    sy.SyftException(public_message="*must be qualified with a dataset*"), show=True
):
    high_client.api.services.bigquery.test_query.mock(
        sql_query="SELECT * FROM invalid_table LIMIT 1"
    )

In [None]:
# Run the same query on the private side of bigquery.test_query (its settings
# disable the rate limiter).
result = high_client.api.services.bigquery.test_query.private(
    sql_query=f"SELECT * FROM {dataset_1}.{table_1} LIMIT 10"
)
result

In [None]:
# Submit a query through the submit_query endpoint. `func_name` is used to name
# the resulting code (called as `my_func_<suffix>` a few cells below).
result = high_client.api.services.bigquery.submit_query(
    func_name="my_func",
    query=f"SELECT * FROM {dataset_1}.{table_1} LIMIT 1",
)

In [None]:
# The endpoint's return value should confirm the submission.
assert "Query submitted" in result
result

In [None]:
# Start the submitted code as a non-blocking job.
# NOTE(review): the `_d6d975` suffix is hard-coded. It presumably derives from
# the submitted function/query, so this cell breaks if the submit cell changes;
# consider looking the name up dynamically.
job = high_client.code.my_func_d6d975(blocking=False)

In [None]:
# Inspect the job's result; it may not be populated yet since the job was
# started with blocking=False.
job.result

In [None]:
# Wait for the job to finish before syncing (is_job_to_sync only treats
# ERRORED/COMPLETED jobs as ready).
job.wait()

In [None]:
# syft absolute
from syft.client.syncing import compare_clients
from syft.service.job.job_stash import Job
from syft.service.job.job_stash import JobStatus

In [None]:
def is_job_to_sync(batch):
    """Tell whether a diff batch carries a newly created, finished Job.

    Returns True only when ``batch.status`` is ``"NEW"``, the batch's
    high-side root object is a ``Job``, and that job's status is
    ``JobStatus.ERRORED`` or ``JobStatus.COMPLETED``. Otherwise returns False.
    ``batch.root`` is only inspected for NEW batches.
    """
    if batch.status == "NEW":
        high_side_obj = batch.root.high_obj
        if isinstance(high_side_obj, Job):
            finished_states = (JobStatus.ERRORED, JobStatus.COMPLETED)
            return high_side_obj.status in finished_states
    return False

In [None]:
def sync_new_objects(
    from_client, to_client, dry_run: bool = True, private_data: bool = False
):
    """Sync every NEW diff batch from ``from_client`` to ``to_client``.

    Args:
        from_client: Client whose objects are the source of the sync.
        to_client: Client that receives the synced objects.
        dry_run: When True (default), batches are resolved but never synced;
            each would-be change is printed and an empty list is returned.
        private_data: When True, ``click_share_all_private_data()`` is called
            on each resolved batch (also during a dry run).

    Returns:
        A list of ``"Synced <status> <root type name>"`` strings, one per batch
        synced (empty on a dry run). If ``compare_clients`` returns a
        ``sy.SyftError``, that error is returned unchanged.

    Raises:
        Any exception raised while resolving or syncing a batch is printed and
        re-raised with its original traceback.
    """
    prefix = "Simulating " if dry_run else ""
    header = f"{prefix}Syncing from {from_client.name} to {to_client.name}"
    if private_data:
        header += " WITH PRIVATE DATA"
    print(header)

    changes = []
    diff = compare_clients(
        from_client=from_client, to_client=to_client, hide_usercode=False
    )
    if isinstance(diff, sy.SyftError):
        return diff

    for batch in diff.batches:
        try:
            # NOTE(review): is_job_to_sync() can only be True for NEW batches, so
            # this condition is equivalent to `batch.status == "NEW"`. If the
            # intent was to sync only *finished* jobs, it needs rework; it is
            # kept as-is to preserve current behaviour.
            if not (is_job_to_sync(batch) or batch.status == "NEW"):
                continue
            widget = batch.resolve(build_state=False)
            if private_data:
                widget.click_share_all_private_data()
            if not dry_run:
                widget.click_sync()
            # Built after click_sync so the text reflects the post-sync batch.
            change_text = f"Synced {batch.status} {batch.root_type.__name__}"
            if dry_run:
                print(f"Would have run: {change_text}")
            else:
                changes.append(change_text)
        except Exception as e:
            print("sync_new_objects", e)
            # Bare raise re-raises the active exception untouched.
            raise
    return changes

In [None]:
# Dry run (the default): print what would be synced from high to low.
result = sync_new_objects(high_client, low_client)
result

In [None]:
# Actually sync the NEW batches from high to low (private data not shared).
result = sync_new_objects(high_client, low_client, dry_run=False)
result

In [None]:
# Expected sync log, in order: user code, request, job, then the three custom
# twin endpoints registered above.
expected_sync_log = ["Synced NEW UserCode", "Synced NEW Request", "Synced NEW Job"]
expected_sync_log += ["Synced NEW TwinAPIEndpoint"] * 3
assert expected_sync_log == result

In [None]:
# widget = sy.sync(from_client=high_client, to_client=low_client, hide_usercode=False)

In [None]:
# # TODO: ignore private function from high side in diff
# widget

In [None]:
# widget.click_sync(0)
# widget.click_sync(1)
# widget.click_sync(2)

In [None]:
# Some internal helper methods

# widget._share_all()
# widget._sync_all()

In [None]:
# Shut down both servers launched at the start of the notebook.
server_high.land()
server_low.land()