In [None]:
# stdlib
import os

# third party
# set to use the live APIs
# os.environ["TEST_BIGQUERY_APIS_LIVE"] = "True"
import pandas as pd

# syft absolute
import syft as sy
from syft import test_settings
from syft.client.syncing import compare_clients
from syft.service.action.pandas import PandasDataFrameObject
from syft.util.test_helpers.apis import make_schema
from syft.util.test_helpers.apis import make_submit_query
from syft.util.test_helpers.apis import make_test_query
from syft.util.test_helpers.email_helpers import get_email_server

In [None]:
# Credentials: the admin account logs in to both servers; the root address is
# only used later to check delivered emails.
ADMIN_EMAIL = "admin2@bigquery.org"
ADMIN_PW = "bqpw2"
ROOT_EMAIL = "admin@bigquery.org"
ROOT_PW = "bqpw"

# Deployment type and ports, each overridable through an environment variable.
environment, high_port, low_port = (
    os.environ.get(env_var, fallback)
    for env_var, fallback in (
        ("ORCHESTRA_DEPLOYMENT_TYPE", "python"),
        ("CLUSTER_HTTP_PORT_HIGH", "9081"),
        ("CLUSTER_HTTP_PORT_LOW", "9083"),
    )
)
print(environment, high_port, low_port)

# Launch servers and log in

In [None]:
# Both servers are launched with identical dev/consumer/producer settings;
# only the name, side type and port differ.
shared_launch_options = {
    "dev_mode": True,
    "n_consumers": 1,
    "create_producer": True,
}

server_low = sy.orchestra.launch(
    name="bigquery-low",
    server_side_type="low",
    port=low_port,
    **shared_launch_options,
)

server_high = sy.orchestra.launch(
    name="bigquery-high",
    server_side_type="high",
    port=high_port,
    **shared_launch_options,
)

### Email Server

In [None]:
# Email test fixtures: `email_server` is queried for delivered mail in the
# "Test emails" section, and `smtp_server` is stopped during clean up.
email_server, smtp_server = get_email_server()

In [None]:
# Display the SMTP server's controller object as a quick visual check.
smtp_server.controller

In [None]:
# Log in to both servers with the same admin account (low first, then high).
low_client, high_client = (
    sy.login(url=f"http://localhost:{port}", email=ADMIN_EMAIL, password=ADMIN_PW)
    for port in (low_port, high_port)
)

In [None]:
# Each server (high checked first, then low) must report exactly two worker pools.
assert all(
    len(client.worker_pools.get_all()) == 2 for client in (high_client, low_client)
)

In [None]:
# Worker pool passed as `worker_pool_name` to the custom API endpoints created below.
this_worker_pool_name: str = "bigquery-pool"

# Load database information from test_settings

In [None]:
# Dataset/table names fall back to their own setting key; column names have
# explicit fallbacks.
dataset_1, dataset_2, table_1, table_2 = (
    test_settings.get(key, default=key)
    for key in ("dataset_1", "dataset_2", "table_1", "table_2")
)
table_2_col_id = test_settings.get("table_2_col_id", default="table_id")
table_2_col_score = test_settings.get("table_2_col_score", default="colname")

# Create and test different endpoints

----

### Create `bigquery.schema` endpoint

In [None]:
# Schema endpoint configured with calls_per_min=5, running on the shared worker pool.
schema_function = make_schema(
    worker_pool_name=this_worker_pool_name,
    settings=dict(calls_per_min=5),
)

In [None]:
# Register the schema endpoint on the high server.
high_client.custom_api.add(endpoint=schema_function)

In [None]:
# Call the schema endpoint directly; no job is created for this call.
result = high_client.api.services.bigquery.schema()

In [None]:
# NOTE(review): 23 is the expected length of the schema result for the test
# dataset — confirm where this number comes from if the fixture changes.
assert len(result) == 23, f"bigquery.schema returned {len(result)} entries, expected 23"

**TODO:** When no job is created, `result` is a `syft.service.action.pandas.PandasDataFrameObject` rather than a pandas `DataFrame`; calling `.get()` on it returns the expected `pd.DataFrame`.

In [None]:
# When no job is created, the endpoint returns syft's PandasDataFrameObject
# wrapper rather than a plain DataFrame; `.get()` unwraps it to pandas.
assert isinstance(
    result, PandasDataFrameObject
), f"expected PandasDataFrameObject, got {type(result).__name__}"
assert isinstance(
    result.get(), pd.DataFrame
), "expected result.get() to return a pandas DataFrame"

____

### Create `bigquery.test_query` endpoint

In [None]:
# Mock side of the test_query endpoint: rate limiter enabled, calls_per_min=10.
mock_func = make_test_query(
    settings=dict(rate_limiter_enabled=True, calls_per_min=10)
)

In [None]:
# Private side of the test_query endpoint: rate limiter disabled.
private_func = make_test_query(settings=dict(rate_limiter_enabled=False))

In [None]:
# Twin endpoint pairing the rate-limited mock with the unthrottled private
# query function, executed on the shared worker pool.
test_query_path = "bigquery.test_query"
new_endpoint = sy.TwinAPIEndpoint(
    path=test_query_path,
    description="This endpoint allows to query Bigquery storage via SQL queries.",
    mock_function=mock_func,
    private_function=private_func,
    worker_pool_name=this_worker_pool_name,
)

high_client.custom_api.add(endpoint=new_endpoint)

#### Update endpoint settings

In [None]:
# Raise the bigquery.test_query timeout to 120s.
# NOTE(review): the default timeout is believed to be 60s — confirm.
high_client.api.services.api.update(
    endpoint_path="bigquery.test_query", endpoint_timeout=120
)

In [None]:
# Set `hide_mock_definition` on the test_query endpoint
# (presumably hides the mock's source from endpoint users — confirm).
high_client.api.services.api.update(
    endpoint_path="bigquery.test_query",
    hide_mock_definition=True,
)

#### Test the `bigquery.test_query` endpoint

In [None]:
# Exercise the rate-limited mock side with a 10-row query.
mock_sql = f"SELECT * FROM {dataset_1}.{table_1} LIMIT 10"
result = high_client.api.services.bigquery.test_query.mock(sql_query=mock_sql)
result

In [None]:
# The mock query above used LIMIT 10.
assert len(result) == 10, f"expected 10 rows from the mock query, got {len(result)}"

In [None]:
# The mock side must reject a table name that is not qualified with a dataset.
expected_error = sy.SyftException(public_message="*must be qualified with a dataset*")
with sy.raises(expected_error, show=True):
    high_client.api.services.bigquery.test_query.mock(
        sql_query="SELECT * FROM invalid_table LIMIT 1"
    )

In [None]:
# Exercise the private side (rate limiter disabled) with a 12-row query.
private_sql = f"SELECT * FROM {dataset_1}.{table_1} LIMIT 12"
result = high_client.api.services.bigquery.test_query.private(sql_query=private_sql)
result

In [None]:
# The private query above used LIMIT 12.
assert len(result) == 12, f"expected 12 rows from the private query, got {len(result)}"

____

### Create `bigquery.submit_query` endpoint

In [None]:
# submit_query endpoint with no extra settings, on the shared worker pool.
submit_query_function = make_submit_query(
    worker_pool_name=this_worker_pool_name,
    settings=dict(),
)

In [None]:
# Register the submit_query endpoint on the high server.
high_client.custom_api.add(endpoint=submit_query_function)

In [None]:
# Set `hide_mock_definition` on the submit_query endpoint
# (presumably hides the mock's source from endpoint users — confirm).
high_client.api.services.api.update(
    endpoint_path="bigquery.submit_query",
    hide_mock_definition=True,
)

In [None]:
# Submit a 2-row query as a code function named `my_func`.
submitted_sql = f"SELECT * FROM {dataset_1}.{table_1} LIMIT 2"
result = high_client.api.services.bigquery.submit_query(
    func_name="my_func", query=submitted_sql
)

In [None]:
# The endpoint should acknowledge the submission; show the response on failure.
assert "Query submitted" in result, f"unexpected submit_query response: {result}"
result

In [None]:
# Run the submitted `my_func` without blocking; the returned job is waited on below.
job = high_client.code.my_func(blocking=False)

In [None]:
# Wait for the job, then unwrap its result.
res = job.wait().get()
# Check the type first so a wrong return type is reported before a length mismatch.
assert isinstance(res, pd.DataFrame), f"expected a pandas DataFrame, got {type(res).__name__}"
assert len(res) == 2, f"expected 2 rows (query used LIMIT 2), got {len(res)}"

# Test endpoints

In [None]:
# Display the custom API endpoints registered on the high server.
high_client.custom_api.api_endpoints()

In [None]:
# schema, test_query and submit_query were added above.
high_endpoints = high_client.custom_api.api_endpoints()
assert len(high_endpoints) == 3, (
    f"expected 3 endpoints on the high server, got {len(high_endpoints)}"
)

In [None]:
# Both query endpoints must be reachable (and truthy) through the client API tree.
assert high_client.api.services.bigquery.test_query
assert high_client.api.services.bigquery.submit_query

# Syncing

In [None]:
# Diff the high server against the low server, including user code.
diff = compare_clients(
    from_client=high_client,
    to_client=low_client,
    hide_usercode=False,
)

In [None]:
# Build the resolve widget for the diff; it is driven programmatically in the next cell.
widget = diff.resolve()

In [None]:
# Drive the resolve widget from code: share all items, then sync all of them.
# NOTE(review): relies on private widget methods; semantics inferred from their names.
widget._share_all()
widget._sync_all()

In [None]:
# After syncing, the low server is expected to hold no jobs.
low_jobs = low_client.jobs.get_all()
assert len(low_jobs) == 0, (
    f"expected no jobs on the low server after syncing, got {len(low_jobs)}"
)

In [None]:
# All three custom endpoints should have been synced to the low server.
low_endpoints = low_client.custom_api.api_endpoints()
assert len(low_endpoints) == 3, (
    f"expected 3 endpoints on the low server, got {len(low_endpoints)}"
)

In [None]:
# Syncing must not change the endpoints registered on the high server.
high_endpoints = high_client.custom_api.api_endpoints()
assert len(high_endpoints) == 3, (
    f"expected 3 endpoints on the high server after syncing, got {len(high_endpoints)}"
)

# Test emails

In [None]:
# Each account should have received exactly one email during this run.
admin_inbox = email_server.get_emails_for_user(user_email=ADMIN_EMAIL)
root_inbox = email_server.get_emails_for_user(user_email=ROOT_EMAIL)
assert len(admin_inbox) == 1, f"expected 1 email for {ADMIN_EMAIL}, got {len(admin_inbox)}"
assert len(root_inbox) == 1, f"expected 1 email for {ROOT_EMAIL}, got {len(root_inbox)}"

In [None]:
# The admin's first email should be a job-failure notification.
first_admin_email = email_server.get_emails_for_user(user_email=ADMIN_EMAIL)[0]
assert "Job Failed" in first_admin_email.email_content

In [None]:
# The root account's first email should announce the newly submitted request.
first_root_email = email_server.get_emails_for_user(user_email=ROOT_EMAIL)[0]
assert (
    "A new request has been submitted and requires your attention"
    in first_root_email.email_content
)

# Clean up

In [None]:
# Shut everything down. Servers launched here are only landed for non-remote
# deployments. The nested try/finally ensures the low server and the SMTP
# server are still stopped if an earlier shutdown step raises.
try:
    if environment != "remote":
        try:
            server_high.land()
        finally:
            server_low.land()
finally:
    smtp_server.stop()