In [None]:
# syft absolute
import syft as sy

In [None]:
# server = sy.orchestra.launch(
#     name="bigquery-high",
#     dev_mode=False,
#     server_side_type="high",
#     port="8080",
#     n_consumers=1, # How many workers to be spawned
#     create_producer=True # Can produce more workers
# )

In [None]:
# Log in to the high-side server as the admin user.
# NOTE(review): these look like default placeholder credentials — change them
# for any real deployment.
high_client = sy.login(
    email="info@openmined.org",
    password="changethis",
    url="http://localhost:8080",
)

In [None]:
# Add your values to secrets.json in this folder.
# The dict below lists the expected keys together with the values passed
# alongside them (presumably used when a key is missing — confirm against
# sy.get_nb_secrets).
secret_defaults = {
    "service_account_bigquery_private": {},
    "service_account_bigquery_mock": {},
    "region_bigquery": "",
    "project_id": "",
    "dataset_1": "dataset1",
    "table_1": "table1",
    "table_2": "table2",
}
secrets = sy.get_nb_secrets(secret_defaults)

In [None]:
# Display the worker pools currently configured on the high-side server.
high_client.worker_pools

In [None]:
# !pip list | grep bigquery

In [None]:
# !pip install db-dtypes google-cloud-bigquery

In [None]:
# third party
from google.cloud import bigquery

In [None]:
@sy.api_endpoint_method(
    settings={
        "credentials": secrets["service_account_bigquery_private"],
        "region": secrets["region_bigquery"],
        "project_id": secrets["project_id"],
    }
)
def private_query_function(
    context,
    sql_query: str,
) -> str:
    """Run ``sql_query`` against BigQuery with the private service account.

    Settings (bound by the decorator):
        credentials: service-account info dict used to authenticate.
        region: location passed to ``bigquery.Client``.
        project_id: project the query job runs in.

    Returns the query result as a DataFrame. Returns a ``SyftError`` when the
    result has more than 1,000,000 rows or the query fails.
    """
    # third party
    from google.cloud import bigquery
    from google.oauth2 import service_account

    # syft absolute
    from syft.service.response import SyftError

    # BigQuery error reasons whose message may be forwarded to the caller.
    # Any other reason is replaced by a generic "contact the domain owner" message.
    forwarded_reasons = {
        "badRequest",
        "blocked",
        "duplicate",
        "invalidQuery",
        "invalid",
        "jobBackendError",
        "jobInternalError",
        "notFound",
        "notImplemented",
        "rateLimitExceeded",
        "resourceInUse",
        "resourcesExceeded",
        "tableUnavailable",
        "timeout",
    }

    # Authenticate to BigQuery with the service-account info from settings.
    credentials = service_account.Credentials.from_service_account_info(
        context.settings["credentials"]
    )
    scoped_credentials = credentials.with_scopes(
        ["https://www.googleapis.com/auth/cloud-platform"]
    )

    client = bigquery.Client(
        credentials=scoped_credentials,
        location=context.settings["region"],
    )

    try:
        rows = client.query_and_wait(
            sql_query,
            project=context.settings["project_id"],
        )

        # Guard against a missing row count so the size check cannot raise.
        if (rows.total_rows or 0) > 1_000_000:
            return SyftError(
                message="Please only write queries that gather aggregate statistics"
            )

        return rows.to_dataframe()
    except Exception as e:
        # We MUST handle here the errors that should be visible to the data
        # scientists. Any exception not caught is visible only to the data owner.
        if not hasattr(e, "_errors"):
            # Not a BigQuery API error.
            # NOTE(review): this echoes the raw exception text to the caller;
            # for production consider returning only the generic message below.
            output = f"got exception e: {type(e)} {str(e)}"
            return SyftError(
                message=f"An error occured executing the API call {output}"
            )

        # google-api-core errors may carry an empty error list. Without this
        # guard, e._errors[0] would raise IndexError inside this handler.
        first_error = e._errors[0] if e._errors else {}
        if first_error.get("reason") in forwarded_reasons:
            return SyftError(
                message="Error occured during the call: "
                + str(first_error.get("message", ""))
            )
        return SyftError(
            message="An error occured executing the API call, please contact the domain owner."
        )

In [None]:
# Define any helper methods for our rate limiter


def is_within_rate_limit(context):
    """Rate limiter for custom API calls made by users.

    Counts the timestamps stored in ``context.state[<user email>]`` that are
    less than 60 seconds old. Returns True while that count is below
    ``context.settings["CALLS_PER_MIN"]``. A user with no recorded calls is
    always within the limit. The endpoints in this notebook append a
    timestamp to the user's list for every accepted call.
    """
    # stdlib
    import datetime

    call_times = context.state.get(context.user.email, [])
    limit = context.settings["CALLS_PER_MIN"]
    current_time = datetime.datetime.now()

    # Use total_seconds(): timedelta.seconds is only the seconds *component*
    # (0-86399) and ignores whole days. With .seconds, a call made one day and
    # ten seconds ago would be counted as falling within the last minute.
    calls_last_min = sum(
        1 for call_time in call_times
        if (current_time - call_time).total_seconds() < 60
    )

    return calls_last_min < limit


# Define a mock endpoint that the researchers can use for testing


@sy.api_endpoint_method(
    settings={
        # NOTE(review): the mock endpoint authenticates with the *private*
        # service account, whereas bigquery.schema uses
        # "service_account_bigquery_mock". If the mock is meant to expose only
        # mock data, confirm whether the mock service account belongs here.
        "credentials": secrets["service_account_bigquery_private"],
        "region": secrets["region_bigquery"],
        "project_id": secrets["project_id"],
        "CALLS_PER_MIN": 10,
    },
    helper_functions=[is_within_rate_limit],
)
def mock_query_function(
    context,
    sql_query: str,
) -> str:
    """Run ``sql_query`` against BigQuery for testing, rate limited per user.

    Each user (keyed by email) may make at most ``CALLS_PER_MIN`` calls per
    minute. Timestamps of accepted calls are recorded in ``context.state``.

    Returns the query result as a DataFrame. Returns a ``SyftError`` when the
    rate limit is hit, the result has more than 1,000,000 rows, or the query
    fails.
    """
    # stdlib
    import datetime

    # third party
    from google.cloud import bigquery
    from google.oauth2 import service_account

    # syft absolute
    from syft.service.response import SyftError

    # BigQuery error reasons whose message may be forwarded to the data
    # scientist. Any other reason is replaced by a generic message.
    forwarded_reasons = {
        "badRequest",
        "blocked",
        "duplicate",
        "invalidQuery",
        "invalid",
        "jobBackendError",
        "jobInternalError",
        "notFound",
        "notImplemented",
        "rateLimitExceeded",
        "resourceInUse",
        "resourcesExceeded",
        "tableUnavailable",
        "timeout",
    }

    # Authenticate to BigQuery with the service-account info from settings.
    credentials = service_account.Credentials.from_service_account_info(
        context.settings["credentials"]
    )
    scoped_credentials = credentials.with_scopes(
        ["https://www.googleapis.com/auth/cloud-platform"]
    )

    client = bigquery.Client(
        credentials=scoped_credentials,
        location=context.settings["region"],
    )

    # Store a list of call timestamps for each user, keyed by email.
    if context.user.email not in context.state:
        context.state[context.user.email] = []

    if not context.code.is_within_rate_limit(context):
        return SyftError(message="Rate limit of calls per minute has been reached.")

    try:
        context.state[context.user.email].append(datetime.datetime.now())

        rows = client.query_and_wait(
            sql_query,
            project=context.settings["project_id"],
        )

        # Guard against a missing row count so the size check cannot raise.
        if (rows.total_rows or 0) > 1_000_000:
            return SyftError(
                message="Please only write queries that gather aggregate statistics"
            )

        return rows.to_dataframe()

    except Exception as e:
        # Errors handled here are forwarded to the data scientists.
        # By default, any uncaught exception is only visible to the data owner.
        if not hasattr(e, "_errors"):
            # Not a BigQuery API error.
            # NOTE(review): this echoes the raw exception text to the caller;
            # for production consider returning only the generic message below.
            output = f"got exception e: {type(e)} {str(e)}"
            return SyftError(
                message=f"An error occured executing the API call {output}"
            )

        # google-api-core errors may carry an empty error list. Without this
        # guard, e._errors[0] would raise IndexError inside this handler.
        first_error = e._errors[0] if e._errors else {}
        if first_error.get("reason") in forwarded_reasons:
            return SyftError(
                message="Error occured during the call: "
                + str(first_error.get("message", ""))
            )
        return SyftError(
            message="An error occured executing the API call, please contact the domain owner."
        )

In [None]:
# Look up the worker pools and identify the one that has the required
# packages (google-cloud-bigquery, db-dtypes). The endpoints below are bound
# to that pool by name ("bigquery-pool").


high_client.worker_pools

In [None]:
# Combine the private and mock implementations into a single twin endpoint,
# executed on the "bigquery-pool" worker pool, and register it on the server.
new_endpoint = sy.TwinAPIEndpoint(
    path="bigquery.test_query",
    description="This endpoint allows to query Bigquery storage via SQL queries.",
    mock_function=mock_query_function,
    private_function=private_query_function,
    worker_pool="bigquery-pool",
)
high_client.custom_api.add(endpoint=new_endpoint)

In [None]:
# Raise the timeout of the bigquery.test_query endpoint to 120 seconds so that
# slower queries are not cut off (the default is reportedly 60s — TODO confirm).
high_client.api.services.api.update(
    endpoint_path="bigquery.test_query", endpoint_timeout=120
)

In [None]:
# Hide the mock function's definition of bigquery.test_query (presumably so
# API users cannot read its source — confirm against the Syft docs).
high_client.api.services.api.update(
    endpoint_path="bigquery.test_query", hide_mock_definition=True
)

In [None]:
# Test the mock version with a small query over the first table
mock_sql = f"SELECT * FROM {secrets['dataset_1']}.{secrets['table_1']} LIMIT 10"
result = high_client.api.services.bigquery.test_query.mock(sql_query=mock_sql)
result

In [None]:
@sy.api_endpoint(
    path="bigquery.schema",
    description="This endpoint allows for visualising the metadata of tables available in BigQuery.",
    settings={
        "credentials": secrets["service_account_bigquery_mock"],
        "region": secrets["region_bigquery"],
        "project_id": secrets["project_id"],
        "CALLS_PER_MIN": 5,
        # The table names travel through settings because the endpoint body is
        # executed away from this notebook (note the function-local imports).
        # Notebook globals such as `secrets` are presumably unavailable there.
        "dataset_1": secrets["dataset_1"],
        "table_1": secrets["table_1"],
        "table_2": secrets["table_2"],
    },
    helper_functions=[
        is_within_rate_limit
    ],  # Adds ratelimit as this is also a method available to data scientists
    worker_pool="bigquery-pool",
)
def schema_function(
    context,
) -> str:
    """Describe the columns of the two configured BigQuery tables.

    Returns a DataFrame with one row per column of
    ``<dataset_1>.<table_1>`` and ``<dataset_1>.<table_2>``. Each row has the
    string fields project, dataset_id, table_id, schema_name, schema_field,
    description and num_rows. Calls are limited to ``CALLS_PER_MIN`` per user
    per minute. Returns a ``SyftError`` on rate limiting or failure.
    """
    # stdlib
    import datetime

    # third party
    from google.cloud import bigquery
    from google.oauth2 import service_account
    import pandas as pd

    # syft absolute
    from syft.service.response import SyftError

    # Authenticate to BigQuery with the service-account info from settings.
    credentials = service_account.Credentials.from_service_account_info(
        context.settings["credentials"]
    )
    scoped_credentials = credentials.with_scopes(
        ["https://www.googleapis.com/auth/cloud-platform"]
    )

    client = bigquery.Client(
        credentials=scoped_credentials,
        location=context.settings["region"],
    )

    # Store a list of call timestamps for each user, keyed by email.
    if context.user.email not in context.state:
        context.state[context.user.email] = []

    if not context.code.is_within_rate_limit(context):
        return SyftError(message="Rate limit of calls per minute has been reached.")

    try:
        context.state[context.user.email].append(datetime.datetime.now())

        # Formats the data schema in a data frame format
        # Warning: the only supported format types are primitives, np.ndarrays and pd.DataFrames
        settings = context.settings
        table_ids = [
            f"{settings['dataset_1']}.{settings['table_1']}",
            f"{settings['dataset_1']}.{settings['table_2']}",
        ]

        data_schema = []
        for table_id in table_ids:
            table = client.get_table(table_id)
            for schema in table.schema:
                data_schema.append(
                    {
                        "project": str(table.project),
                        "dataset_id": str(table.dataset_id),
                        "table_id": str(table.table_id),
                        "schema_name": str(schema.name),
                        "schema_field": str(schema.field_type),
                        "description": str(table.description),
                        "num_rows": str(table.num_rows),
                    }
                )
        return pd.DataFrame(data_schema)

    except Exception as e:
        if not hasattr(e, "_errors"):
            # Not a BigQuery API error.
            # NOTE(review): this echoes the raw exception text to the caller;
            # for production consider returning only the generic message below.
            output = f"got exception e: {type(e)} {str(e)}"
            return SyftError(
                message=f"An error occured executing the API call {output}"
            )

        # TODO: decide which BigQuery errors may be exposed to data scientists;
        # for now every BigQuery error gets the generic message.
        return SyftError(
            message="An error occured executing the API call, please contact the domain owner."
        )


# Register the bigquery.schema endpoint, then refresh the client so the new
# endpoint shows up under high_client.api.services.
high_client.custom_api.add(endpoint=schema_function)
high_client.refresh()

In [None]:
@sy.api_endpoint(
    path="bigquery.submit_query",
    description="API endpoint that allows you to submit SQL queries to run on the private data.",
    worker_pool="bigquery-pool",
)
def submit_query(
    context,
    func_name: str,
    query: str,
) -> str:
    """Wrap ``query`` in a syft function and request its execution.

    The syft function is named ``<func_name>_<suffix>``. The suffix is the
    first 6 hex characters of the SHA-256 digest of the caller's email
    (presumably so users choosing the same ``func_name`` do not collide). The
    function calls the ``bigquery.test_query`` endpoint with ``query`` bound
    as a constant input. The resulting request is tagged ``"autosync"``.

    Returns a confirmation string, or the ``SyftError`` from the submission.
    """
    # stdlib
    import hashlib

    # syft absolute
    import syft as sy

    email_digest = hashlib.sha256(context.user.email.encode("utf-8")).hexdigest()
    func_name = f"{func_name}_{email_digest[:6]}"

    @sy.syft_function(
        name=func_name,
        input_policy=sy.MixedInputPolicy(
            endpoint=sy.Constant(
                val=context.admin_client.api.services.bigquery.test_query
            ),
            query=sy.Constant(val=query),
            client=context.admin_client,
        ),
        worker_pool_name="bigquery-pool",
    )
    def execute_query(query: str, endpoint):
        res = endpoint(sql_query=query)
        return res

    request = context.user_client.code.request_code_execution(execute_query)

    # Propagate submission failures unchanged; otherwise tag the request.
    if not isinstance(request, sy.SyftError):
        context.admin_client.requests.set_tags(request, ["autosync"])
        return (
            f"Query submitted {request}. Use `client.code.{func_name}()` to run your query"
        )
    return request

In [None]:
# Register the bigquery.submit_query endpoint on the server.
high_client.custom_api.add(endpoint=submit_query)

In [None]:
# Hide the mock function's definition of bigquery.submit_query as well.
high_client.api.services.api.update(
    endpoint_path="bigquery.submit_query", hide_mock_definition=True
)

In [None]:
# List the custom API endpoints (expected: bigquery.test_query,
# bigquery.schema and bigquery.submit_query, all added above).
high_client.custom_api.api_endpoints()

In [None]:
# Display the bigquery.test_query endpoint.
high_client.api.services.bigquery.test_query

In [None]:
# Display the bigquery.submit_query endpoint.
high_client.api.services.bigquery.submit_query

In [None]:
# Test the mock version again, now that all endpoints are registered
mock_check_sql = (
    f"SELECT * FROM {secrets['dataset_1']}.{secrets['table_1']} LIMIT 10"
)
result = high_client.api.services.bigquery.test_query.mock(sql_query=mock_check_sql)
result

In [None]:
# Test the private version with the same 10-row query
private_sql = f"SELECT * FROM {secrets['dataset_1']}.{secrets['table_1']} LIMIT 10"
result = high_client.api.services.bigquery.test_query.private(sql_query=private_sql)
result

In [None]:
# Test the mock version with a query against a non-existent table; the
# endpoint should return a SyftError instead of raising.
bad_sql = "SELECT * FROM invalid_table LIMIT 1"
result = high_client.api.services.bigquery.test_query.mock(sql_query=bad_sql)
result

In [None]:
# Test the private version with a single-row query
single_row_sql = f"SELECT * FROM {secrets['dataset_1']}.{secrets['table_1']} LIMIT 1"
result = high_client.api.services.bigquery.test_query.private(
    sql_query=single_row_sql
)
result

In [None]:
# Inspect the context state on an endpoint. For the mock this holds the
# per-user call timestamps recorded by the rate limiter.
high_client.api.services.bigquery.test_query.mock.context.state

In [None]:
# Submit a single-row query for approval via bigquery.submit_query
submit_sql = f"SELECT * FROM {secrets['dataset_1']}.{secrets['table_1']} LIMIT 1"
result = high_client.api.services.bigquery.submit_query(
    func_name="my_func",
    query=submit_sql,
)

In [None]:
# Show the submission confirmation (or the SyftError) returned above.
result