In [None]:
# stdlib
from collections import Counter
import os

# third party
import pandas as pd

# syft absolute
import syft as sy
from syft.client.syncing import compare_clients
from syft.util.test_helpers.email_helpers import get_email_server
from syft.util.test_helpers.email_helpers import load_users
from syft.util.test_helpers.job_helpers import get_job_emails
from syft.util.test_helpers.job_helpers import get_request_for_job_info
from syft.util.test_helpers.job_helpers import load_jobs
from syft.util.test_helpers.job_helpers import save_jobs

In [None]:
# Test credentials for the admin (used to log in below) and the root account.
ADMIN_EMAIL = "admin2@bigquery.org"
ADMIN_PW = "bqpw2"
ROOT_EMAIL = "admin@bigquery.org"
ROOT_PASSWORD = "bqpw"

# Deployment settings; each one can be overridden through an environment variable.
environment = os.environ.get("ORCHESTRA_DEPLOYMENT_TYPE", "python")
high_port = os.environ.get("CLUSTER_HTTP_PORT_HIGH", "9081")
low_port = os.environ.get("CLUSTER_HTTP_PORT_LOW", "9083")
num_jobs = int(os.environ.get("NUM_TEST_JOBS", 10))

print(environment, low_port)

# Launch servers and log in

In [None]:
# Options shared by both servers: dev mode, a single consumer, and a producer.
shared_launch_kwargs = {
    "dev_mode": True,
    "n_consumers": 1,
    "create_producer": True,
}

# Launch the low side first, then the high side, each on its own port.
server_low = sy.orchestra.launch(
    name="bigquery-low",
    server_side_type="low",
    port=low_port,
    **shared_launch_kwargs,
)
server_high = sy.orchestra.launch(
    name="bigquery-high",
    server_side_type="high",
    port=high_port,
    **shared_launch_kwargs,
)

In [None]:
# Start the test email server; smtp_server is stopped again in the shutdown cell.
(email_server, smtp_server) = get_email_server()

In [None]:
# Log in as the admin on both servers; the generator is consumed in order,
# so the low side is logged in before the high side.
low_client, high_client = (
    sy.login(url=f"http://localhost:{port}", email=ADMIN_EMAIL, password=ADMIN_PW)
    for port in (low_port, high_port)
)

# Sync UserCode and Requests to High Side

In [None]:
# Build the low -> high sync widget (keyword form, matching the high -> low sync later on).
widget = sy.sync(from_client=low_client, to_client=high_client)

# Ignore batches we don't want to sync

In [None]:
def _batch_runs_broken_query(widget_item) -> bool:
    """Return True when the batch's low-side root request targets a 'broken' function."""
    root_request = widget_item.obj_diff_batch.root.low_obj
    return (
        root_request is not None
        and "broken" in root_request.code.service_func_name
    )


# Collect the broken batches first, then deny and ignore each one so it is not synced.
idxs_to_ignore = [
    idx for idx in range(len(widget)) if _batch_runs_broken_query(widget[idx])
]

for idx in idxs_to_ignore:
    widget[idx].deny_and_ignore("query is broken")

In [None]:
diffs = compare_clients(low_client, high_client)
# With the broken batches ignored, every remaining low -> high batch should be
# rooted at a Request (UserCode is presumably carried along as a dependency of
# its Request — confirm).
batch_root_types = {diff.root_diff.obj_type.__qualname__ for diff in diffs.batches}
assert batch_root_types == {"Request"}, (
    f"expected only Request-rooted batches in the low->high diff, got {batch_root_types}"
)

In [None]:
# Sync the low -> high widget's batches to the high side. The broken batches
# were denied and ignored above; presumably _sync_all skips them — confirm.
# Unlike the high -> low sync further down, _share_all() is not called here.
widget._sync_all()

In [None]:
# syft absolute
from syft.service.request.request import RequestStatus

In [None]:
# Denying the broken batches should leave at least one rejected request on the low side.
low_side_has_rejection = any(
    low_request.status == RequestStatus.REJECTED
    for low_request in low_client.requests
)
assert low_side_has_rejection

# Check that requests synced over to the high side

In [None]:
# Show how many UserCode objects are present on the high side.
high_code_objects = high_client.code.get_all()
len(high_code_objects)

In [None]:
# Every test job's UserCode should have been synced to the high side.
n_high_code = len(high_client.code.get_all())
assert n_high_code == num_jobs, (
    f"expected {num_jobs} UserCode objects on the high side, found {n_high_code}"
)

In [None]:
# Bind and display the high side's pending requests in a single expression.
(requests := high_client.requests.get_all_pending())

In [None]:
users = load_users(low_client)
jobs_data = load_jobs(users, low_client)
all_requests = high_client.requests

# Only submitted jobs have a request waiting for admin review.
submitted_jobs_data = [job for job in jobs_data if job.is_submitted]

# Number of emails recorded per entry returned by get_job_emails.
job_emails = get_job_emails(submitted_jobs_data, high_client, email_server)
n_emails_per_job_user = {
    job_user: len(emails) for job_user, emails in job_emails.items()
}

# Run or deny the submitted requests

In [None]:
# Partition the submitted jobs by expected outcome in a single pass,
# keeping the original order inside each group.
submitted_jobs_data_should_succeed = []
submitted_jobs_data_should_fail = []
for submitted_job in submitted_jobs_data:
    if submitted_job.should_succeed:
        submitted_jobs_data_should_succeed.append(submitted_job)
    else:
        submitted_jobs_data_should_fail.append(submitted_job)

In [None]:
# Run every job expected to succeed on the high side and check it yields a DataFrame.
for job in submitted_jobs_data_should_succeed:
    request = get_request_for_job_info(all_requests, job)
    # Start the code without blocking, then wait for the job and fetch its result.
    job_handle = request.code(blocking=False)
    result = job_handle.wait().get()
    assert isinstance(result, pd.DataFrame), (
        f"job {job.func_name} returned {type(result).__name__}, expected a DataFrame"
    )
    job.admin_reviewed = True

In [None]:
# Deny every job expected to fail and give the user a reason. This list was
# built from jobs with `not should_succeed`, so re-asserting that here is redundant.
for job in submitted_jobs_data_should_fail:
    request = get_request_for_job_info(all_requests, job)
    response = request.deny(
        reason=f"Your request {job.func_name} looks wrong, try again."
    )
    assert isinstance(response, sy.SyftSuccess), (
        f"denying request for {job.func_name} failed: {response}"
    )
    job.admin_reviewed = True

# Sync job results to the low side

In [None]:
# Build the high -> low sync widget carrying the job results and denials.
widget = sy.sync(
    to_client=low_client,
    from_client=high_client,
)

In [None]:
# NOTE(review): the next cell recomputes exactly this diff, so this cell looks
# redundant and could be removed.
diffs = sy.compare_clients(high_client, low_client)
batch_root_strs = []
for diff_batch in diffs.batches:
    batch_root_strs.append(diff_batch.root_diff.obj_type.__qualname__)

In [None]:
diffs = sy.compare_clients(high_client, low_client)
batch_root_strs = [x.root_diff.obj_type.__qualname__ for x in diffs.batches]
root_str_counts = Counter(batch_root_strs)
# Successful jobs should appear as Job-rooted batches; denied jobs as Request-rooted ones.
n_expected_jobs = len(submitted_jobs_data_should_succeed)
n_expected_requests = len(submitted_jobs_data_should_fail)
assert root_str_counts["Job"] == n_expected_jobs, (
    f"expected {n_expected_jobs} Job-rooted batches, got {dict(root_str_counts)}"
)
assert root_str_counts["Request"] == n_expected_requests, (
    f"expected {n_expected_requests} Request-rooted batches, got {dict(root_str_counts)}"
)

In [None]:
# NOTE(review): presumably marks the high-side objects in every batch (e.g. job
# results) as shareable so their data reaches the low side on sync — confirm.
widget._share_all()

In [None]:
# Sync the batches of the high -> low widget onto the low side.
widget._sync_all()

# Save state

In [None]:
# Persist jobs_data, including the admin_reviewed flags set in the run/deny
# cells above (the storage location is determined by save_jobs).
save_jobs(jobs_data)

# Shutdown

In [None]:
# Tear everything down. The nested try/finally blocks ensure the low server and
# the SMTP server are still stopped if an earlier shutdown step raises.
try:
    if environment != "remote":
        try:
            server_high.land()
        finally:
            server_low.land()
finally:
    smtp_server.stop()