# Syncing helpers

In [1]:
# stdlib

# syft absolute
import syft as sy
from syft.client.domain_client import DomainClient
from syft.client.syncing import compare_clients
from syft.service.code.user_code import UserCode
from syft.service.job.job_stash import Job
from syft.service.job.job_stash import JobStatus
from syft.service.request.request import Request
from syft.service.sync.diff_state import ObjectDiffBatch


def is_request_to_sync(batch: ObjectDiffBatch) -> bool:
    """Decide whether a diff batch holds a low-side request to auto-sync.

    A batch qualifies when its root low-side object is a ``Request``, the
    batch status is ``"NEW"``, and the request carries the ``"autosync"`` tag.
    """
    # TODO add condition for sql requests/usercodes
    candidate = batch.root.low_obj
    if not isinstance(candidate, Request):
        return False
    if batch.status != "NEW":
        return False
    return "autosync" in candidate.tags


def is_job_to_sync(batch: ObjectDiffBatch) -> bool:
    """Return True if ``batch`` holds a new high-side job that has finished.

    Only batches whose status is ``"NEW"`` qualify, and the root high-side
    object must be a ``Job`` whose status is ``JobStatus.ERRORED`` or
    ``JobStatus.COMPLETED``.
    """
    if batch.status != "NEW":
        return False
    job = batch.root.high_obj
    if not isinstance(job, Job):
        return False
    # Both successful and failed jobs are synced back, so the low side
    # learns about errors too.
    return job.status in (JobStatus.ERRORED, JobStatus.COMPLETED)


def sync_new_requests(
    client_low: DomainClient,
    client_high: DomainClient,
) -> dict[sy.UID, sy.SyftSuccess | sy.SyftError] | sy.SyftError:
    """Sync new ``autosync``-tagged requests from the low side to the high side.

    Args:
        client_low: Client for the low-side domain (sync source).
        client_high: Client for the high-side domain (sync target).

    Returns:
        A dict mapping each synced request's UID to the result of
        ``click_sync()``, or the ``SyftError`` returned by ``compare_clients``
        when the diff could not be computed.
    """
    diff = compare_clients(
        from_client=client_low, to_client=client_high, include_types=["request"]
    )
    if isinstance(diff, sy.SyftError):
        # Propagate the error, as the return annotation promises, instead of
        # masking it as an empty result ("0 requests synced"). The caller
        # (sync_and_execute_new_requests) prints SyftError results itself.
        return diff

    sync_request_results = {}
    for batch in diff.batches:
        if not is_request_to_sync(batch):
            continue
        request_id = batch.root.low_obj.id
        widget = batch.resolve()
        sync_request_results[request_id] = widget.click_sync()
    return sync_request_results


def execute_requests(
    client_high: DomainClient, request_ids: list[sy.UID]
) -> dict[sy.UID, Job]:
    """Start a non-blocking job on the high side for each given request.

    A request is skipped when any of the following holds:
    - it cannot be fetched as a ``Request``;
    - its code is not a ``UserCode``;
    - its function is not exposed on ``client_high.code``;
    - launching it returns a ``SyftError`` (the error is printed).

    Args:
        client_high: Client for the high-side domain.
        request_ids: UIDs of requests that have been synced to the high side.

    Returns:
        A dict mapping request UID to the job started for it.
    """
    jobs_by_request_id = {}
    for request_id in request_ids:
        request = client_high.requests.get_by_uid(request_id)
        if not isinstance(request, Request):
            continue

        code = request.code
        if not isinstance(code, UserCode):
            continue

        # Reuse the already-validated ``code`` instead of reading
        # ``request.code`` a second time.
        api_func = getattr(client_high.code, code.service_func_name, None)
        if api_func is None:
            continue

        job = api_func(blocking=False)
        if isinstance(job, sy.SyftError):
            # A failed launch is not a started job; report it and move on.
            print(job)
            continue
        jobs_by_request_id[request_id] = job

    return jobs_by_request_id


def sync_and_execute_new_requests(
    client_low: DomainClient, client_high: DomainClient
) -> None:
    """Sync new autosync requests to the high side, then launch their code.

    If syncing fails with a SyftError, the error is printed and nothing is
    executed. Otherwise, the number of successfully synced requests and the
    number of started jobs are printed.
    """
    results = sync_new_requests(client_low, client_high)
    if not isinstance(results, sy.SyftError):
        synced_ids = []
        for uid, outcome in results.items():
            if isinstance(outcome, sy.SyftSuccess):
                synced_ids.append(uid)
        print(f"Synced {len(synced_ids)} new requests")

        started = execute_requests(client_high, synced_ids)
        print(f"Started {len(started)} new jobs")
    else:
        print(results)


def sync_finished_jobs(
    client_low: DomainClient,
    client_high: DomainClient,
) -> dict[sy.UID, sy.SyftError | sy.SyftSuccess] | sy.SyftError:
    """Share and sync finished high-side jobs back to the low side.

    For each new high-side job that has COMPLETED or ERRORED, all private
    data is shared first and then the job is synced. If sharing fails, the
    sharing error is recorded and the sync for that job is skipped.

    Returns:
        A dict mapping job UID to either the sharing error or the sync
        result. If ``compare_clients`` fails, its ``SyftError`` is printed
        and returned instead. The printed count includes jobs whose
        sharing or sync failed.
    """
    diff = compare_clients(
        from_client=client_high, to_client=client_low, include_types=["job"]
    )
    if isinstance(diff, sy.SyftError):
        print(diff)
        return diff

    results = {}
    finished = (batch for batch in diff.batches if is_job_to_sync(batch))
    for batch in finished:
        job_id = batch.root.high_obj.id
        resolver = batch.resolve()
        outcome = resolver.click_share_all_private_data()
        if not isinstance(outcome, sy.SyftError):
            outcome = resolver.click_sync()
        results[job_id] = outcome

    print(f"Sharing {len(results)} new results")
    return results


def auto_sync(client_low: DomainClient, client_high: DomainClient) -> None:
    """Run one auto-sync round between the low and the high side.

    First syncs new requests to the high side and starts their jobs, then
    shares finished high-side jobs back to the low side.
    """
    print("Starting auto sync")
    for step in (sync_and_execute_new_requests, sync_finished_jobs):
        step(client_low, client_high)
    print("Finished auto sync")

# Create Nodes

In [2]:
# third party
from google.oauth2 import service_account

In [3]:
# Launch a low-side and a high-side dev node, each backed by a local DB.
# reset=True presumably starts each node from a clean state — confirm.
# n_consumers appears to set the size of the default worker pool (the launch
# output reports "workers=1" for the low side).
low_side = sy.orchestra.launch(
    name="auto-sync-low",
    node_side_type="low",
    local_db=True,
    reset=True,
    n_consumers=1,
    create_producer=True,
    dev_mode=True,
)

# The high side runs the submitted jobs, so it gets more consumers (4).
high_side = sy.orchestra.launch(
    name="high-side",
    node_side_type="high",
    local_db=True,
    reset=True,
    n_consumers=4,
    create_producer=True,
    dev_mode=True,
)

Staging Protocol Changes...
Document Store's SQLite DB path: /var/folders/pn/f6xkq7mx683g5jkyt91gqyzw0000gn/T/syft/579f2ebaf61545e4bead94c215ea3f88/db/579f2ebaf61545e4bead94c215ea3f88.sqlite
Action Store's SQLite DB path: /var/folders/pn/f6xkq7mx683g5jkyt91gqyzw0000gn/T/syft/579f2ebaf61545e4bead94c215ea3f88/db/579f2ebaf61545e4bead94c215ea3f88.sqlite
Creating default worker image with tag='local-dev'
Setting up worker poolname=default-pool workers=1 image_uid=b5fa6320676a4ba78a4dc18fd1abd9ac in_memory=True
Created default worker pool.
Data Migrated to latest version !!!
Staging Protocol Changes...
Document Store's SQLite DB path: /var/folders/pn/f6xkq7mx683g5jkyt91gqyzw0000gn/T/syft/083dfc0ecd744d17ad21a36a6477565e/db/083dfc0ecd744d17ad21a36a6477565e.sqlite
Action Store's SQLite DB path: /var/folders/pn/f6xkq7mx683g5jkyt91gqyzw0000gn/T/syft/083dfc0ecd744d17ad21a36a6477565e/db/083dfc0ecd744d17ad21a36a6477565e.sqlite
Creating default worker image with tag='local-dev'
Setting up worker poo

In [4]:
# Log in as info@openmined.org on both sides (presumably the default admin
# account of a freshly launched node — confirm).
client_high = high_side.login(email="info@openmined.org", password="changethis")
client_low = low_side.login(email="info@openmined.org", password="changethis")
# Register a data-scientist account on the low side and log in with it.
client_low.register(
    email="newuser@openmined.org", name="John Doe", password="pw", password_verify="pw"
)
client_low_ds = low_side.login(email="newuser@openmined.org", password="pw")

Logged into <high-side: High side Domain> as <info@openmined.org>


Logged into <auto-sync-low: Low side Domain> as <info@openmined.org>


Logged into <auto-sync-low: Low side Domain> as <newuser@openmined.org>


# Create Query endpoints

6. We are not yet limiting the size of the result (TODO: implement).

In [5]:
# stdlib
import json

# Service-account credentials for BigQuery; they are passed as the settings
# of the private query endpoint. JSON text is UTF-8 (RFC 8259), so decode it
# explicitly instead of relying on the platform's default encoding.
with open("./credentials.json", encoding="utf-8") as f:
    BQ_CREDENTIALS = json.load(f)

In [6]:
# Mock API


@sy.api_endpoint_method(settings={})
def mock_query_function(
    context,
    sql_query: str,
) -> str:
    """Mock side of the ``reddit.query`` endpoint: return fake tabular data.

    ``sql_query`` is ignored. Every call returns the same 10-row DataFrame
    with the columns Name, Age, Email, JoinDate and Salary. If building the
    data fails, a generic SyftError is returned instead.

    NOTE(review): the ``-> str`` annotation does not match what is actually
    returned (a pandas DataFrame or a SyftError).
    """
    # Imports are done inside the function, presumably because syft runs the
    # endpoint's source on the node rather than in this notebook — confirm.
    # third party
    import numpy as np
    import pandas as pd

    # syft absolute
    from syft.service.response import SyftError

    # Set the seed for reproducibility
    # (this reseeds NumPy's global RNG, so Age/Salary are identical per call)
    np.random.seed(42)
    try:
        # Generate mock data
        data = {
            "Name": [f"Name_{i}" for i in range(1, 11)],
            "Age": np.random.randint(20, 50, size=10),
            "Email": [f"name_{i}@example.com" for i in range(1, 11)],
            # NOTE(review): pandas >= 2.2 deprecates freq="M" in favour of "ME";
            # check the pandas version used on the node.
            "JoinDate": pd.date_range(start="2023-01-01", periods=10, freq="M")
            .strftime("%Y-%m-%d")
            .tolist(),
            "Salary": np.random.randint(40000, 120000, size=10),
        }

        # Create DataFrame
        return pd.DataFrame(data)
    except Exception:
        # Any failure is reported as a generic error; details are not exposed.
        return SyftError(
            message="Ops! Something went wrong. please, contact your admin"
        )


# Private API
@sy.api_endpoint_method(settings=BQ_CREDENTIALS)
def private_query_function(
    context,
    sql_query: str,
) -> str:
    """Private side of the ``reddit.query`` endpoint: run SQL on BigQuery.

    Authenticates with the service-account info stored in the endpoint
    settings (``context.settings``) and runs ``sql_query``. The result is
    returned as a pandas DataFrame. Queries whose result has more than
    40,000 rows are rejected with a SyftError.

    NOTE(review): the ``-> str`` annotation does not match what is actually
    returned (a DataFrame or a SyftError).
    """
    # Every dependency is imported inside the function: the endpoint body is
    # run by the node, where notebook-level imports (e.g. a top-level
    # ``service_account``) are not guaranteed to be in scope.
    # third party
    from google.cloud import bigquery
    from google.oauth2 import service_account

    # syft absolute
    from syft.service.response import SyftError

    # Build scoped credentials from the service-account info in the settings.
    credentials = service_account.Credentials.from_service_account_info(
        context.settings
    )
    scoped_credentials = credentials.with_scopes(
        ["https://www.googleapis.com/auth/cloud-platform"]
    )

    client = bigquery.Client(
        credentials=scoped_credentials,
        location="us-west1",
    )
    # Run the query against the real BigQuery project and wait for the rows.
    rows = client.query_and_wait(
        sql_query,
        project="reddit-testing-415005",
    )
    # Coarse guard against exporting raw data: refuse large result sets.
    if rows.total_rows > 40000:
        return SyftError(
            message="Please only write queries that gather aggregate statistics"
        )
    # Create DataFrame
    res = rows.to_dataframe()
    return res


# Create a new Twin API endpoint at path "reddit.query" that pairs the mock
# function with the private (BigQuery) function.
# NOTE(review): no worker pool is passed here (an earlier comment mentioned a
# "bigquery-pool"), so the endpoint presumably runs on the default pool — confirm.
new_endpoint = sy.TwinAPIEndpoint(
    path="reddit.query",
    description="Ask SQL Queries using our BQ",
    private_function=private_query_function,
    mock_function=mock_query_function,
)

# Register the endpoint on the high side; it is synced to the low side later.
client_high.custom_api.add(endpoint=new_endpoint)

In [7]:
# Disabled sanity check: run a sample query directly against the private
# (BigQuery) side of the endpoint as admin. Change to `if True:` to run it.
if False:
    client_high.api.services.reddit.query.private(
        sql_query="SELECT * from data_10gb.comments LIMIT 40"
    ).head()

# Sync TwinAPI to LowSide

In [8]:
# Diff the high side against the low side and sync the item at index 0
# (here, presumably the new Twin API endpoint) to the low side.
widget = compare_clients(from_client=client_high, to_client=client_low).resolve()
widget.click_sync(0)

Decision: Syncing 1 objects


# Create Function factory

In [9]:
@sy.api_endpoint(path="reddit.submit_query")
def submit_query(
    context,
    func_name: str,
    query: str,
) -> str:
    """Let a data scientist submit a SQL query as a code-execution request.

    Wraps ``query`` in a syft function named ``func_name`` that calls the
    private side of ``reddit.query``. The function is submitted as a
    code-execution request through the calling user's client and tagged
    ``"autosync"`` so the auto-sync helpers pick it up.

    Returns a confirmation message. Returns a SyftError if ``func_name`` is
    not purely alphabetic or the request could not be created.

    NOTE(review): ``str.isalpha`` still accepts Python keywords (e.g. "def"),
    which would make ``client.code.<func_name>()`` unusable. Consider also
    rejecting ``keyword.iskeyword(func_name)``.
    """
    # syft absolute
    import syft as sy

    # func_name becomes the function's name on ``client.code``.
    if not func_name.isalpha():
        return sy.SyftError(
            message="Please only use alphabetic characters for your func_name"
        )

    # The endpoint and the query are bound as constants in the input policy.
    # The DS therefore presumably runs ``client.code.<func_name>()`` with no
    # arguments.
    @sy.syft_function(
        name=func_name,
        input_policy=sy.MixedInputPolicy(
            endpoint=sy.Constant(val=context.admin_client.api.services.reddit.query),
            query=sy.Constant(val=query),
            client=context.admin_client,
        ),
    )
    def execute_query(query: str, endpoint):
        res = endpoint.private(sql_query=query)
        return res

    request = context.user_client.code.request_code_execution(execute_query)
    if isinstance(request, sy.SyftError):
        return request
    # Tag via the admin client so is_request_to_sync() selects this request.
    # NOTE(review): the result of set_tags is not checked; if tagging fails,
    # the request is silently never auto-synced.
    context.admin_client.requests.set_tags(request, ["autosync"])

    return (
        f"Query submitted {request}, use `client.code.{func_name}()` to run your query"
    )

In [10]:
client_low.api.services.api.add(endpoint=submit_query)  # expose submit_query on the low side

# Submit request

In [11]:
submit_res = client_low_ds.api.services.reddit.submit_query(  # submit as the DS
    func_name="myquery", query="SELECT * from data_10gb.comments LIMIT 40"
)

Logged into <auto-sync-low: Low side Domain > as GUEST
Logged into <auto-sync-low: Low side Domain > as GUEST


In [12]:
client_low.requests[0].tags  # the submitted request should carry the "autosync" tag

['autosync']

In [13]:
# client_low_ds.code.myquery()

# Run Autosync

In [18]:
# stdlib
import time

# Run 5 auto-sync rounds, sleeping 5 seconds after each one. An exception in
# one round is printed and does not stop the remaining rounds.

for _ in range(5):
    try:
        auto_sync(client_low, client_high)
    except Exception as e:
        print(e)
    time.sleep(5)

Starting auto sync


Synced 0 new requests
Started 0 new jobs


Decision: Syncing 5 objects
Sharing 1 new results
Finished auto sync
Starting auto sync


Synced 0 new requests
Started 0 new jobs


Sharing 0 new results
Finished auto sync
Starting auto sync


Synced 0 new requests
Started 0 new jobs


Sharing 0 new results
Finished auto sync
Starting auto sync


Synced 0 new requests
Started 0 new jobs


Sharing 0 new results
Finished auto sync


# Run function as DS

In [19]:
res = client_low_ds.code.myquery()  # as the DS, run the synced function (presumably returns the shared job result)

In [21]:
res.get().head()  # fetch the result DataFrame and show its first rows

Unnamed: 0,id,post_id,parent_id,created_at,last_modified_at,body,author_id,gilded,score,upvote_ratio,deleted,collapsed_in_crowd_control,spam,subreddit_id,permalink
0,t1_jsrssaa,t3_3eq9p3r,t1_j0bm0qn,2020-02-05 13:15:44+00:00,NaT,WASHINGTON (AP) — The federal government groun...,t2_31y14bfh,False,3,0.65,False,False,False,t5_7i2tp,/r/t5_7i2tp/comments/eq9p3r/comment/jsrssaa
1,t1_z014wyn,t3_mtoy3vi,,2020-02-05 13:15:44+00:00,NaT,"He was indicted on 16 felony charges, includin...",t2_iemo2ikg,False,8,1.0,False,False,False,t5_xg19m,/r/t5_xg19m/comments/mtoy3vi/comment/z014wyn
2,t1_8ttp66l,t3_is0dk32,,2020-02-05 13:15:44+00:00,NaT,,t2_csenfqwl,False,6,1.0,False,False,False,t5_unjsw,/r/t5_unjsw/comments/is0dk32/comment/8ttp66l
3,t1_qhuklsm,t3_7ajgpje,,2020-02-05 13:15:44+00:00,NaT,These nachos are so sinful; it's hard to stop ...,t2_2ztp96r7,False,7,0.69,False,False,False,t5_91cqb,/r/t5_91cqb/comments/7ajgpje/comment/qhuklsm
4,t1_8nkh2zb,t3_oygwavx,t1_0mzt6bq,2020-02-05 13:15:44+00:00,NaT,"When we last checked in with Charles Platkin, ...",t2_o79jr0e0,False,5,1.0,True,False,False,t5_y71mw,/r/t5_y71mw/comments/oygwavx/comment/8nkh2zb
