In [None]:
# third party

# syft absolute
import syft as sy
from syft import test_settings

In [None]:
# Launch the low-side and high-side servers used throughout this notebook.
# reset=True / local_db=True are passed, presumably so each run starts from a
# fresh local database; each server gets one consumer and a producer.
server_low = sy.orchestra.launch(
    name="bigquery-low",
    server_side_type="low",
    dev_mode=True,
    reset=True,
    local_db=True,
    n_consumers=1,
    create_producer=True,
)

server_high = sy.orchestra.launch(
    name="bigquery-high",
    server_side_type="high",
    dev_mode=True,
    reset=True,
    local_db=True,
    n_consumers=1,
    create_producer=True,
)

# Log in and launch worker pools

In [None]:
# Log in to each server (presumably with the default admin credentials; this
# client is later used as the data-owner/admin side).
low_client = server_low.login(email="info@openmined.org", password="changethis")

In [None]:
high_client = server_high.login(email="info@openmined.org", password="changethis")

In [None]:
# A freshly reset server is expected to start with exactly one worker pool
# (presumably the default pool) before the bigquery pool is added.
assert len(high_client.worker_pools.get_all()) == 1

In [None]:
def launch_worker_pool(client, pool_name):
    """Create a worker pool named ``pool_name`` on ``client`` unless it exists.

    Submits the prebuilt ``openmined/bigquery:<syft version>`` worker image
    from the configured external registry (``test_settings`` key
    ``external_registry``, default ``docker.io``), registers that registry,
    and launches one worker from the image.

    Args:
        client: a logged-in (admin) client for the server to configure.
        pool_name: name of the worker pool to create.

    Returns:
        The result of ``worker_pool.launch`` when a pool was launched, or
        ``None`` (after printing "Already exists") when a pool with that name
        is already present.
    """
    if any(pool.name == pool_name for pool in client.worker_pools):
        print("Already exists")
        return None

    external_registry = test_settings.get("external_registry", default="docker.io")
    worker_docker_tag = f"openmined/bigquery:{sy.__version__}"
    client.api.services.worker_image.submit(
        worker_config=sy.PrebuiltWorkerConfig(
            tag=f"{external_registry}/{worker_docker_tag}"
        )
    )
    # NOTE(review): index 1 assumes the image just submitted is the second
    # image on the server (index 0 is read as the base worker image later in
    # this notebook); confirm get_all() ordering if more images are added.
    worker_image = client.images.get_all()[1]
    client.api.services.image_registry.add(external_registry)
    return client.api.services.worker_pool.launch(
        pool_name=pool_name,
        image_uid=worker_image.id,
        num_workers=1,
    )

In [None]:
# Name of the worker pool created on both servers; the endpoints below are
# also configured to run on this pool.
pool_name = "bigquery-pool"

In [None]:
launch_worker_pool(high_client, pool_name)

In [None]:
launch_worker_pool(low_client, pool_name)

In [None]:
# Optional: scale the pool up (uncomment to use).
# result = high_client.worker_pools.scale(number=5, pool_name=pool_name)
# result

In [None]:
# Each server should now have its initial pool plus "bigquery-pool".
assert len(high_client.worker_pools.get_all()) == 2
assert len(low_client.worker_pools.get_all()) == 2

In [None]:
# The first image on the high side, displayed for reference (used only by
# the commented-out Dockerfile cells below).
base_worker_image = high_client.images.get_all()[0]
base_worker_image

# Register DS

In [None]:
# Register the data scientist account on the low side.
low_client.register(
    email="data_scientist@openmined.org",
    password="verysecurepassword",
    password_verify="verysecurepassword",
    name="John Doe",
)

In [None]:
# Disable guest sign-up on the high side.
high_client.settings.allow_guest_signup(enable=False)

In [None]:
# Expected: the admin account plus the data scientist registered above.
assert len(low_client.api.services.user.get_all()) == 2

In [None]:
# worker_dockerfile = f"""
# FROM {str(base_worker_image.image_identifier)}

# RUN uv pip install db-dtypes google-cloud-bigquery

# """.strip()
# worker_dockerfile

In [None]:
# docker_tag = str(base_worker_image.image_identifier).replace(
#     "backend", "worker-bigquery"
# )
# docker_tag

# Twin endpoints

In [None]:
@sy.api_endpoint_method(
    settings={
        "credentials": test_settings.gce_service_account.to_dict(),
        "region": test_settings.gce_region,
        "project_id": test_settings.gce_project_id,
    }
)
def private_query_function(
    context,
    sql_query: str,
) -> str:
    """Run ``sql_query`` on BigQuery for the private side of ``bigquery.test_query``.

    Authenticates with the service-account credentials in ``context.settings``
    and runs the query in ``context.settings["project_id"]``.

    Returns:
        The query result as a pandas DataFrame (``rows.to_dataframe()``).
        NOTE(review): the ``-> str`` annotation does not match; it is left
        unchanged in case syft relies on the declared signature.

    Raises:
        SyftException: if the result has more than 1,000,000 rows; with the
            BigQuery message when the first BigQuery error's reason is one of
            the forwarded reasons listed below; with the exception type and
            text for non-BigQuery exceptions; otherwise with a generic
            "contact the domain owner" message.
    """
    # Imports are kept inside the function, presumably because syft ships
    # this function's source to the worker and runs it there.
    # third party
    from google.cloud import bigquery  # noqa: F811
    from google.oauth2 import service_account

    # syft absolute
    from syft import SyftException

    # Authenticate to BigQuery with the service account from the settings.
    credentials = service_account.Credentials.from_service_account_info(
        context.settings["credentials"]
    )
    scoped_credentials = credentials.with_scopes(
        ["https://www.googleapis.com/auth/cloud-platform"]
    )

    client = bigquery.Client(
        credentials=scoped_credentials,
        location=context.settings["region"],
    )

    try:
        rows = client.query_and_wait(
            sql_query,
            project=context.settings["project_id"],
        )

        if rows.total_rows > 1_000_000:
            raise SyftException(
                public_message="Please only write queries that gather aggregate statistics"
            )

        return rows.to_dataframe()
    except SyftException:
        # Already carries the intended public message (e.g. the row-count
        # guard above); re-raise it unchanged instead of re-wrapping it below.
        raise
    except Exception as e:
        # Only errors converted into a SyftException public_message here are
        # meant to reach the data scientist; any exception not handled is
        # visible only to the data owner.

        # not a bigquery exception
        if not hasattr(e, "_errors"):
            output = f"got exception e: {type(e)} {str(e)}"
            raise SyftException(
                public_message=f"An error occured executing the API call {output}"
            )

        # BigQuery error reasons whose message is forwarded to the caller.
        forwarded_reasons = {
            "badRequest",
            "blocked",
            "duplicate",
            "invalidQuery",
            "invalid",
            "jobBackendError",
            "jobInternalError",
            "notFound",
            "notImplemented",
            "rateLimitExceeded",
            "resourceInUse",
            "resourcesExceeded",
            "tableUnavailable",
            "timeout",
        }
        # `_errors` can be empty (or None); then there is no reason to match
        # and the generic message below is used instead of an IndexError.
        first_error = e._errors[0] if e._errors else {}
        if first_error.get("reason") in forwarded_reasons:
            raise SyftException(
                public_message="Error occured during the call: "
                + first_error["message"]
            )
        raise SyftException(
            public_message="An error occured executing the API call, please contact the domain owner."
        )

In [None]:
# Define any helper methods for our rate limiter
def is_within_rate_limit(context):
    """Rate limiter for custom API calls made by users.

    Counts the calls recorded in ``context.state[<user email>]`` (a list of
    ``datetime`` timestamps appended by the endpoints) that happened less than
    60 seconds ago.

    Returns:
        True if that count is below ``context.settings["CALLS_PER_MIN"]``,
        False otherwise.
    """
    # stdlib
    import datetime

    state = context.state
    settings = context.settings
    email = context.user.email

    current_time = datetime.datetime.now()
    # Use total_seconds(): timedelta.seconds drops the days component, so a
    # call made N days plus a few seconds ago would otherwise count as recent.
    calls_last_min = sum(
        1
        for call_time in state[email]
        if (current_time - call_time).total_seconds() < 60
    )

    return calls_last_min < settings["CALLS_PER_MIN"]

In [None]:
# Define a mock endpoint that the researchers can use for testing


@sy.api_endpoint_method(
    settings={
        "credentials": test_settings.gce_service_account.to_dict(),
        "region": test_settings.gce_region,
        "project_id": test_settings.gce_project_id,
        "CALLS_PER_MIN": 10,
    },
    helper_functions=[is_within_rate_limit],
)
def mock_query_function(
    context,
    sql_query: str,
) -> str:
    """Rate-limited query endpoint (mock side of ``bigquery.test_query``).

    Checks the per-user rate limit (``CALLS_PER_MIN`` calls per minute,
    tracked in ``context.state`` keyed by the caller's email), records the
    call time, then runs ``sql_query`` in ``context.settings["project_id"]``
    using the service-account credentials from the settings.

    Returns:
        The query result as a pandas DataFrame (``rows.to_dataframe()``).
        NOTE(review): the ``-> str`` annotation does not match; it is left
        unchanged in case syft relies on the declared signature.

    Raises:
        SyftException: when the rate limit is reached; if the result has more
            than 1,000,000 rows; with the BigQuery message when the first
            BigQuery error's reason is one of the forwarded reasons listed
            below; with the exception type and text for non-BigQuery
            exceptions; otherwise with a generic message.
    """
    # stdlib
    import datetime

    # third party
    from google.cloud import bigquery  # noqa: F811
    from google.oauth2 import service_account

    # syft absolute
    from syft import SyftException

    # Authenticate to BigQuery with the service account from the settings.
    credentials = service_account.Credentials.from_service_account_info(
        context.settings["credentials"]
    )
    scoped_credentials = credentials.with_scopes(
        ["https://www.googleapis.com/auth/cloud-platform"]
    )

    client = bigquery.Client(
        credentials=scoped_credentials,
        location=context.settings["region"],
    )

    # Store a dict with the calltimes for each user, via the email.
    if context.user.email not in context.state.keys():
        context.state[context.user.email] = []

    if not context.code.is_within_rate_limit(context):
        raise SyftException(
            public_message="Rate limit of calls per minute has been reached."
        )

    try:
        context.state[context.user.email].append(datetime.datetime.now())

        rows = client.query_and_wait(
            sql_query,
            project=context.settings["project_id"],
        )

        if rows.total_rows > 1_000_000:
            raise SyftException(
                public_message="Please only write queries that gather aggregate statistics"
            )

        return rows.to_dataframe()

    except SyftException:
        # Already carries the intended public message (e.g. the row-count
        # guard above); re-raise it unchanged instead of re-wrapping it below.
        raise
    except Exception as e:
        # not a bigquery exception
        if not hasattr(e, "_errors"):
            output = f"got exception e: {type(e)} {str(e)}"
            raise SyftException(
                public_message=f"An error occured executing the API call {output}"
            )

        # Errors whose BigQuery message should be forwarded to the data
        # scientists. By default, any other exception is only visible to the
        # data owner.
        forwarded_reasons = {
            "badRequest",
            "blocked",
            "duplicate",
            "invalidQuery",
            "invalid",
            "jobBackendError",
            "jobInternalError",
            "notFound",
            "notImplemented",
            "rateLimitExceeded",
            "resourceInUse",
            "resourcesExceeded",
            "tableUnavailable",
            "timeout",
        }
        # `_errors` can be empty (or None); then there is no reason to match
        # and the generic message below is used instead of an IndexError.
        first_error = e._errors[0] if e._errors else {}
        if first_error.get("reason") in forwarded_reasons:
            raise SyftException(
                public_message="Error occured during the call: "
                + first_error["message"]
            )
        raise SyftException(
            public_message="An error occured executing the API call, please contact the domain owner."
        )

In [None]:
# Register bigquery.test_query on the high side, pairing the private and mock
# implementations defined above and running them on the bigquery worker pool.
new_endpoint = sy.TwinAPIEndpoint(
    path="bigquery.test_query",
    description="This endpoint allows to query Bigquery storage via SQL queries.",
    private_function=private_query_function,
    mock_function=mock_query_function,
    worker_pool=pool_name,
)

high_client.custom_api.add(endpoint=new_endpoint)

In [None]:
# Raise the endpoint timeout to 120s (the default is presumably 60s — confirm
# against the syft API documentation).
high_client.api.services.api.update(
    endpoint_path="bigquery.test_query", endpoint_timeout=120
)

In [None]:
# Hide the mock function's source from users of the endpoint.
high_client.api.services.api.update(
    endpoint_path="bigquery.test_query", hide_mock_definition=True
)

In [None]:
# Test mock version
result = high_client.api.services.bigquery.test_query.mock(
    sql_query=f"SELECT * FROM {test_settings.dataset_1}.{test_settings.table_1} LIMIT 10"
)
result

In [None]:
@sy.api_endpoint(
    path="bigquery.schema",
    description="This endpoint allows for visualising the metadata of tables available in BigQuery.",
    settings={
        "credentials": test_settings.gce_service_account.to_dict(),
        "region": test_settings.gce_region,
        "project_id": test_settings.gce_project_id,
        "dataset_1": test_settings.dataset_1,
        "table_1": test_settings.table_1,
        "table_2": test_settings.table_2,
        "CALLS_PER_MIN": 5,
    },
    helper_functions=[
        is_within_rate_limit
    ],  # Adds ratelimit as this is also a method available to data scientists
    worker_pool=pool_name,
)
def schema_function(
    context,
) -> str:
    """Describe the columns of the two configured BigQuery tables.

    Looks up ``<dataset_1>.<table_1>`` and ``<dataset_1>.<table_2>`` (names
    from ``context.settings``) and returns one DataFrame row per schema
    field with the table's project, dataset id, table id, the field name and
    type, the table description and row count, all as strings. Calls are
    rate limited per user through ``is_within_rate_limit``.

    NOTE(review): both tables are resolved under ``dataset_1``, while
    ``LARGE_SAMPLE_QUERY`` in this notebook reads ``dataset_2.table_2`` —
    confirm which dataset ``table_2`` belongs to. The ``-> str`` annotation
    does not match the DataFrame actually returned.

    Raises:
        SyftException: when the rate limit is reached, or with a public
            message when the table lookup fails.
    """
    # stdlib
    import datetime

    # third party
    from google.cloud import bigquery  # noqa: F811
    from google.oauth2 import service_account
    import pandas as pd

    # syft absolute
    from syft import SyftException

    # Authenticate to BigQuery with the service account from the settings.
    credentials = service_account.Credentials.from_service_account_info(
        context.settings["credentials"]
    )
    scoped_credentials = credentials.with_scopes(
        ["https://www.googleapis.com/auth/cloud-platform"]
    )

    client = bigquery.Client(
        credentials=scoped_credentials,
        location=context.settings["region"],
    )

    # Per-user call timestamps, keyed by email, read by the rate limiter.
    if context.user.email not in context.state.keys():
        context.state[context.user.email] = []

    if not context.code.is_within_rate_limit(context):
        raise SyftException(
            public_message="Rate limit of calls per minute has been reached."
        )

    try:
        context.state[context.user.email].append(datetime.datetime.now())

        # Formats the data schema in a data frame format
        # Warning: the only supported format types are primitives, np.ndarrays and pd.DataFrames
        dataset = context.settings["dataset_1"]
        # Single quotes inside the f-strings: reusing the enclosing double
        # quotes there is only valid syntax from Python 3.12 (PEP 701), and
        # this source presumably also runs on the worker pool's interpreter.
        table_ids = [
            f"{dataset}.{context.settings['table_1']}",
            f"{dataset}.{context.settings['table_2']}",
        ]

        data_schema = []
        for table_id in table_ids:
            table = client.get_table(table_id)
            data_schema.extend(
                {
                    "project": str(table.project),
                    "dataset_id": str(table.dataset_id),
                    "table_id": str(table.table_id),
                    "schema_name": str(schema.name),
                    "schema_field": str(schema.field_type),
                    "description": str(table.description),
                    "num_rows": str(table.num_rows),
                }
                for schema in table.schema
            )
        return pd.DataFrame(data_schema)

    except Exception as e:
        # not a bigquery exception
        if not hasattr(e, "_errors"):
            output = f"got exception e: {type(e)} {str(e)}"
            raise SyftException(
                public_message=f"An error occured executing the API call {output}"
            )

        # Should add appropriate error handling for what should be exposed to the data scientists.
        raise SyftException(
            public_message="An error occured executing the API call, please contact the domain owner."
        )

In [None]:
# Register the bigquery.schema endpoint on the high side.
high_client.custom_api.add(endpoint=schema_function)

In [None]:
# Smoke-test the schema endpoint as the admin.
high_client.api.services.bigquery.schema()

In [None]:
@sy.api_endpoint(
    path="bigquery.submit_query",
    description="API endpoint that allows you to submit SQL queries to run on the private data.",
    worker_pool=pool_name,
    settings={"worker": pool_name},
)
def submit_query(
    context,
    func_name: str,
    query: str,
) -> str:
    """Submit ``query`` for approval as a syft function derived from ``func_name``.

    ``func_name`` is suffixed with ``_`` plus the first 6 hex digits of the
    SHA-256 of the caller's email (presumably so names from different users
    do not collide — confirm). The generated function calls the admin
    client's ``bigquery.test_query`` endpoint with ``query`` and runs on the
    worker pool stored in ``context.settings["worker"]``. A code-execution
    request is created as the calling user and tagged ``"autosync"`` through
    the admin client.

    Returns:
        A message naming the request and the ``client.code.<name>()`` call
        the data scientist can use to run the query once it is allowed.
    """
    # stdlib
    import hashlib

    # syft absolute
    import syft as sy

    hash_object = hashlib.new("sha256")

    # Deterministic per-user suffix: first 6 hex chars of sha256(email).
    hash_object.update(context.user.email.encode("utf-8"))
    func_name = func_name + "_" + hash_object.hexdigest()[:6]

    # Both the endpoint and the query are pinned as constants in the input
    # policy, so the approved code can only run this exact query.
    @sy.syft_function(
        name=func_name,
        input_policy=sy.MixedInputPolicy(
            endpoint=sy.Constant(
                val=context.admin_client.api.services.bigquery.test_query
            ),
            query=sy.Constant(val=query),
            client=context.admin_client,
        ),
        worker_pool_name=context.settings["worker"],
    )
    def execute_query(query: str, endpoint):
        res = endpoint(sql_query=query)
        return res

    request = context.user_client.code.request_code_execution(execute_query)
    # The "autosync" tag is set via the admin client; presumably consumed by
    # the sync tooling — confirm.
    context.admin_client.requests.set_tags(request, ["autosync"])

    return (
        f"Query submitted {request}. Use `client.code.{func_name}()` to run your query"
    )

In [None]:
# Register the bigquery.submit_query endpoint on the high side.
high_client.custom_api.add(endpoint=submit_query)

In [None]:
high_client.api.services.api.update(
    endpoint_path="bigquery.submit_query", hide_mock_definition=True
)

In [None]:
high_client.custom_api.api_endpoints()

In [None]:
# Three endpoints were added above: bigquery.test_query, bigquery.schema and
# bigquery.submit_query.
assert len(high_client.custom_api.api_endpoints()) == 3

In [None]:
assert (
    high_client.api.services.bigquery.test_query
    and high_client.api.services.bigquery.submit_query
)

In [None]:
# Test mock version
result = high_client.api.services.bigquery.test_query.mock(
    sql_query=f"SELECT * FROM {test_settings.dataset_1}.{test_settings.table_1} LIMIT 10"
)
assert len(result) == 10

In [None]:
# todo can we clean up the duplicate exception messages?

# Test mock version for wrong queries: an unqualified table name should be
# rejected by BigQuery and the reason forwarded as a SyftException.
with sy.raises(
    sy.SyftException(public_message="*must be qualified with a dataset*"), show=True
):
    high_client.api.services.bigquery.test_query.mock(
        sql_query="SELECT * FROM invalid_table LIMIT 1"
    )

In [None]:
# Test private version
result = high_client.api.services.bigquery.test_query.private(
    sql_query=f"SELECT * FROM {test_settings.dataset_1}.{test_settings.table_1} LIMIT 1"
)
result

assert len(result) == 1

In [None]:
widget = sy.sync(from_client=high_client, to_client=low_client)

In [None]:
# Sync the first three items shown in the widget to the low side
# (presumably the three endpoints — confirm the widget's ordering).
widget.click_sync(0)
widget.click_sync(1)
widget.click_sync(2)

# Low side research

In [None]:
# All three endpoints should now be visible on the low side.
assert len(low_client.custom_api.api_endpoints()) == 3

In [None]:
# NOTE(review): this table is hard-coded rather than taken from test_settings
# like the high-side cells; confirm it exists in the configured project.
result = low_client.api.services.bigquery.test_query.mock(
    sql_query="SELECT * from data_10gb.comments limit 10"
)
assert len(result) == 10

In [None]:
# The private side must not be callable directly from the low side.
with sy.raises(sy.SyftException, show=True):
    low_client.api.services.bigquery.test_query.private(
        sql_query="SELECT * from data_10gb.comments limit 10"
    )

In [None]:
res = low_client.api.services.bigquery.schema()
# third party
import pandas as pd

assert isinstance(res.get(), pd.DataFrame)

In [None]:
FUNC_NAME = "large_sample"
LARGE_SAMPLE_QUERY = (
    f"SELECT * FROM {test_settings.dataset_2}.{test_settings.table_2} LIMIT 10000"
)

In [None]:
# Calling the endpoint directly (not .mock/.private) from the low side;
# presumably this resolves to the mock implementation for this client.
mock_res = low_client.api.services.bigquery.test_query(sql_query=LARGE_SAMPLE_QUERY)

In [None]:
# Submit the same query for approval; the response names the generated code.
submission = low_client.api.services.bigquery.submit_query(
    func_name=FUNC_NAME, query=LARGE_SAMPLE_QUERY
)

In [None]:
def extract_code_path(response):
    """Return the ``<name>`` in the first ``client.code.<name>()`` found in *response*.

    *response* is converted with ``str()`` first, so any object (e.g. the
    message returned by ``bigquery.submit_query``) may be passed. Returns
    ``None`` when no such call appears.
    """
    # stdlib
    import re

    found = re.search(r"client\.code\.(\w+)\(\)", str(response))
    return found.group(1) if found else None

In [None]:
# The name is not random: submit_query appends "_" plus the first 6 hex digits
# of the SHA-256 of the submitter's email, so recover it from the response.
func_name = extract_code_path(submission)

In [None]:
# Look up the submitted code on the low-side client (None if not found).
api_method = getattr(low_client.code, func_name, None)
api_method

In [None]:
# todo: this is very noisy, but it actually passes
# Running before approval must fail with the "please wait" message.
with sy.raises(
    sy.SyftException(
        public_message="*Please wait for the admin to allow the execution of this code*"
    ),
    show=True,
):
    result = api_method(blocking=False)

# Sync, approve, sync

In [None]:
# todo: this is way too noisy
widget = sy.sync(from_client=low_client, to_client=high_client)

In [None]:
# this is not great
widget.click_sync(0)

In [None]:
assert len(high_client.code.get_all()) == 1

In [None]:
request = high_client.requests[0]

In [None]:
# syft absolute
from syft.service.code.user_code import UserCode
from syft.service.request.request import Request

In [None]:
def execute_request(client, request) -> object:
    """Start the user code attached to ``request`` on ``client`` without blocking.

    Args:
        client: a client whose ``client.code`` exposes the request's user code
            under its ``service_func_name``.
        request: expected to be a syft ``Request`` whose ``code`` is a
            ``UserCode``.

    Returns:
        Whatever ``client.code.<service_func_name>(blocking=False)`` returns
        (presumably a syft Job — confirm), or one of the strings
        "This is not a request", "No usercode found" or
        "Code name was not found on the client." when a check fails.
    """
    if not isinstance(request, Request):
        return "This is not a request"

    code = request.code
    if not isinstance(code, UserCode):
        return "No usercode found"

    # Reuse the already-fetched ``code`` rather than reading request.code again.
    api_func = getattr(client.code, code.service_func_name, None)
    if api_func is None:
        return "Code name was not found on the client."

    return api_func(blocking=False)

In [None]:
# Run the requested code on the high side (as the admin), without blocking.
job = execute_request(high_client, request)

In [None]:
job

In [None]:
# third party
from sync_helpers import sync_finished_jobs

In [None]:
# Presumably syncs finished jobs' results from high to low
# (sync_finished_jobs is defined outside this notebook — confirm).
sync_finished_jobs(client_low=low_client, client_high=high_client)

In [None]:
low_client.requests

# DS: Execute

In [None]:
# The data scientist can now run the code on the low side.
job = api_method(blocking=False)

In [None]:
res = job.wait().get()

In [None]:
assert isinstance(res, pd.DataFrame)

In [None]:
# LARGE_SAMPLE_QUERY uses LIMIT 10000, so exactly that many rows are expected.
assert len(res) == 10000

In [None]:
FUNC_NAME

In [None]:
# Shut down both servers launched at the top of this notebook.
server_high.land()
server_low.land()

In [None]:
# !echo "$worker_dockerfile" | docker build -t $docker_tag -q -

In [None]:
# !docker image ls | grep bigquery