In [None]:
# stdlib
import json
import os

# syft absolute
import syft as sy
from syft.client.syncing import compare_states
from syft.client.syncing import resolve_single

In [None]:
# Connection settings come from environment variables so secrets never have to be
# typed into (and accidentally committed with) the notebook. Each one falls back to
# "" when its variable is unset, matching the previous placeholder values.
BIGQUERY_CREDENTIALS_FILE = os.environ.get("BIGQUERY_CREDENTIALS_FILE", "")
URL_HIGH = os.environ.get("URL_HIGH", "")
PASSWORD_HIGH = os.environ.get("PASSWORD_HIGH", "")

URL_LOW = os.environ.get("URL_LOW", "")
PASSWORD_LOW = os.environ.get("PASSWORD_LOW", "")

In [None]:
# Local node + Client (alternative to the remote `sy.login` cell that follows).
# Uncommenting this launches a low-side and a high-side node in-process via
# `sy.orchestra.launch` (with reset=True, so any previous local state is wiped)
# and logs in to both with the info@openmined.org account. Use it instead of the
# remote login cell, not in addition to it.
# node_low = sy.orchestra.launch(
#     name="reddit_l",
#     node_side_type="low",
#     dev_mode=True,
#     reset=True,
#     local_db=True,
#     n_consumers=1,
#     create_producer=True,
# )

# node_high = sy.orchestra.launch(
#     name="reddit_h",
#     node_side_type="high",
#     dev_mode=True,
#     reset=True,
#     local_db=True,
#     n_consumers=1,
#     create_producer=True,
# )

# client_low = node_low.login(email="info@openmined.org", password="changethis")
# client_high = node_high.login(email="info@openmined.org", password="changethis")

In [None]:
# Log in to the high-side node first, then the low-side node, both with the
# info@openmined.org account and the side-specific URL/password from the config cell.
client_high = sy.login(url=URL_HIGH, email="info@openmined.org", password=PASSWORD_HIGH)
client_low = sy.login(url=URL_LOW, email="info@openmined.org", password=PASSWORD_LOW)

In [None]:
# Create the data-scientist account on the low-side node; the next cell logs in with it.
client_low.register(
    name="John Doe",
    email="newuser@openmined.org",
    password="pw",
    password_verify="pw",
)

In [None]:
# Log in to the low side as the newly registered data scientist.
client_low_ds = sy.login(url=URL_LOW, email="newuser@openmined.org", password="pw")

# Set up the twin API

In [None]:
# Load the BigQuery service-account key. It is passed as `settings` to both the mock
# and the private endpoint below. JSON is UTF-8 by specification, so decode explicitly
# rather than relying on the platform's default encoding. (`json` is imported in the
# top import cell.)
with open(BIGQUERY_CREDENTIALS_FILE, encoding="utf-8") as f:
    bq_credentials = json.load(f)

In [None]:
# Mock Behavior


@sy.mock_api_endpoint(
    # settings={}
    settings=bq_credentials,
)
def mock_query_function(
    context,  # Variable used to track user session, user information, user activities and settings
    sql_query: str,
) -> str:
    # Mock side of the `reddit.query` twin endpoint.
    # Runs `sql_query` against BigQuery using the service-account info in
    # `context.settings`, then overwrites known identifying columns with placeholder
    # values and returns the resulting DataFrame.
    #
    # NOTE(review): annotated `-> str` but actually returns a pandas DataFrame; the
    # annotation is left as-is in case syft relies on the declared signature.
    # NOTE(review): columns not listed in `mock_values` are returned with their REAL
    # values, and the full credentials in `settings` must be available wherever this
    # mock runs (it is called from the low side later) — confirm both are intended.

    # third party
    from google.cloud import bigquery
    from google.oauth2 import service_account

    credentials = service_account.Credentials.from_service_account_info(
        context.settings
    )
    scoped_credentials = credentials.with_scopes(
        ["https://www.googleapis.com/auth/cloud-platform"]
    )

    client = bigquery.Client(
        credentials=scoped_credentials,
        location="us-west1",
    )

    rows = client.query_and_wait(
        sql_query,
        project=context.settings["project_id"],
    )

    # Replacing private values with mocked ones. Only columns the query actually
    # returned are overwritten. This keeps the mock's columns identical to the
    # private result, instead of adding fabricated columns to queries such as
    # `SELECT name ...` or `SELECT COUNT(*) ...`.
    mock_values = {
        "int64_field_0": 0,
        "id": "Private",
        "name": "Private",
        "subscribers_count": 0,
        "permalink": "Private",
        "nsfw": "NaN",
        "spam": False,
    }
    result = rows.to_dataframe()
    for column, mock_value in mock_values.items():
        if column in result.columns:
            result[column] = mock_value
    return result

In [None]:
# Private Behavior


@sy.private_api_endpoint(
    # settings={}
    settings=bq_credentials,
)
def private_query_function(
    context,
    sql_query: str,
) -> str:
    # Private side of the `reddit.query` twin endpoint.
    # Runs `sql_query` against BigQuery using the service-account info in
    # `context.settings` and returns the unmodified result as a DataFrame.
    # (The leftover debug prints of "test", the client and the row iterator were
    # removed; they cluttered the output of every call.)
    #
    # NOTE(review): annotated `-> str` but returns a DataFrame; the annotation is kept
    # in case syft relies on the declared signature.

    # third party
    from google.cloud import bigquery
    from google.oauth2 import service_account

    credentials = service_account.Credentials.from_service_account_info(
        context.settings
    )
    scoped_credentials = credentials.with_scopes(
        ["https://www.googleapis.com/auth/cloud-platform"]
    )

    client = bigquery.Client(
        credentials=scoped_credentials,
        location="us-west1",
    )

    rows = client.query_and_wait(
        sql_query,
        project=context.settings["project_id"],
    )

    return rows.to_dataframe()

In [None]:
# Bundle the private and mock implementations into one twin endpoint at
# `reddit.query` and register it on the high-side node.
new_endpoint = sy.TwinAPIEndpoint(
    path="reddit.query",
    private_function=private_query_function,
    mock_function=mock_query_function,
    # Shown to users of the endpoint, so describe what it actually does
    # (this replaces a lorem-ipsum placeholder).
    description=(
        "Run a SQL query against the Reddit BigQuery dataset. The mock version "
        "replaces identifying columns with placeholder values; the private "
        "version returns the real rows."
    ),
)
client_high.api.services.api.add(endpoint=new_endpoint)

In [None]:
# Refresh the high-side client after adding the endpoint (presumably needed for
# `reddit.query` to appear under `client_high.api.services`).
client_high.refresh()

In [None]:
# Smoke-test the private implementation directly on the high side.
client_high.api.services.reddit.query.private(
    sql_query="SELECT *  FROM test_1gb.subreddits LIMIT 100"
)

In [None]:
# Find the endpoint registered above by its path, rather than assuming it is the
# first endpoint on the node (the remote high-side node is not reset, so other
# endpoints may exist).
twin_api_obj = next(
    (
        endpoint
        for endpoint in client_high.api.services.api.api_endpoints()
        if endpoint.path == "reddit.query"
    ),
    None,
)
assert twin_api_obj is not None, "reddit.query endpoint not found on the high side"

twin_api_obj

## Sync the twin API to the low side

In [None]:
# NOTE(review): stray note with no code; presumably refers to the object type the
# endpoint is synced as — confirm, or delete this cell.
#CustomEndpointActionObject

In [None]:
# Snapshot both nodes. So far the twin endpoint has only been added on the high side.
high_state = client_high.get_sync_state()
low_state = client_low.get_sync_state()

high_state

In [None]:
# Diff high -> low. The first argument is the side changes are synced from
# (the later low -> high sync reverses the order).
diff = compare_states(high_state, low_state)

diff

In [None]:
# NOTE(review): assumes the endpoint is the first (only) diff entry; check the
# `diff` output above before syncing.
widget = resolve_single(diff[0])

widget

In [None]:
widget.click_sync()

In [None]:
# Inspect the low-side state after the sync.
client_low.get_sync_state()

# Create request

## Use mock endpoint

In [None]:
# Refresh both low-side clients so they pick up the endpoint synced above.
client_low.refresh()
client_low_ds.refresh()

In [None]:
# As the data scientist on the low side, call the endpoint directly; per this
# section's heading this exercises the mock implementation.
client_low_ds.api.services.reddit.query(
    sql_query="SELECT *  FROM test_1gb.subreddits LIMIT 100"
)

# Define the code, project, and request

In [None]:
@sy.syft_function_single_use(
    reddit_query=client_low_ds.api.services.reddit.query,
)
def my_research_pipeline(reddit_query):
    # Run the same fixed sample query used in the cells above through whichever
    # `reddit_query` endpoint is passed in (low-side mock or high-side endpoint).
    return reddit_query(sql_query="SELECT *  FROM test_1gb.subreddits LIMIT 100")

In [None]:
# Create a project with the data scientist as its member, then submit
# `my_research_pipeline` as a code request through the data scientist's client.
new_project = sy.Project(
    name="Reddit Research Studies",
    description="Hi, I want to get information about your data.",
    members=[client_low_ds],
)

new_project.create_code_request(my_research_pipeline, client_low_ds)

In [None]:
# NOTE(review): commented-out cells for inspecting the newest low-side request and
# its code; either uncomment them for debugging or delete them.
# low_request = client_low.requests[-1]

In [None]:
# low_request

In [None]:
# low_code = low_request.code

In [None]:
# low_code

## Sync the code request to the high side

In [None]:
# Snapshot both nodes. The code request was submitted on the low side only.
low_state = client_low.get_sync_state()
high_state = client_high.get_sync_state()

In [None]:
low_state

In [None]:
# Diff low -> high: this time the low side is the source of the changes.
diff_state = compare_states(low_state, high_state)

In [None]:
diff_state

### Sync UserCode

In [None]:
# NOTE(review): assumes the UserCode change is the first diff entry; check the
# `diff_state` output above.
code_diff = diff_state[0]
widget = resolve_single(code_diff)

widget

In [None]:
widget.click_sync()

### Sync Request

In [None]:
# NOTE(review): assumes the Request change is the second diff entry; check the
# `diff_state` output above.
request_diff = diff_state[1]
widget = resolve_single(request_diff)

widget

In [None]:
widget.click_sync()

# High side: Run and sync back

## Run on high side

In [None]:
# Refresh so the request and code synced above are visible on the high side.
client_high.refresh()

In [None]:
client_high.requests

In [None]:
# NOTE(review): takes the first request; assumes the synced one is the only (or
# first) request on the high side. The commented low-side cells used `[-1]`.
request = client_high.requests[0]

In [None]:
# request.code

In [None]:
# Execute the submitted function on the high side, passing the high-side endpoint
# (presumably resolving to the private implementation here).
res = client_high.code.my_research_pipeline(
    reddit_query=client_high.api.services.reddit.query
)

In [None]:
res

In [None]:
# Approve the request by depositing the high-side result. If the run above failed,
# stop here instead of depositing the error object as if it were the result.
if isinstance(res, sy.SyftError):
    raise RuntimeError(f"High-side execution failed; not accepting the request: {res}")
request.accept_by_depositing_result(res)

## Sync the results back to the low side

In [None]:
# Snapshot both nodes after the high-side run and approval.
low_state = client_low.get_sync_state()
high_state = client_high.get_sync_state()

In [None]:
high_state

In [None]:
# Diff high -> low to push the approval and its outputs back to the low side.
diff_state_2 = compare_states(high_state, low_state)

In [None]:
diff_state_2

In [None]:
# NOTE(review): the next three cells index positionally and assume the order
# code, request, job; check the `diff_state_2` output above.
code_batch = diff_state_2[0]

widget = resolve_single(code_batch)
widget

In [None]:
widget.click_sync()

In [None]:
request_batch = diff_state_2[1]

widget = resolve_single(request_batch)
widget

In [None]:
widget.click_sync()

In [None]:
job_batch = diff_state_2[2]

widget = resolve_single(job_batch)
widget

In [None]:
# Presumably marks the job's private data (its result) as shareable so it is
# included in the sync below.
widget.click_share_all_private_data()

In [None]:
widget.click_sync()

In [None]:
# Inspect the low-side state after the sync.
client_low.get_sync_state()

# Run on low side

In [None]:
# As the data scientist, refresh and run the approved function from the low side,
# passing the low-side endpoint. Presumably this returns the result deposited on
# the high side and synced back above.
client_low_ds.refresh()
client_low_ds.code.my_research_pipeline(
    reddit_query=client_low_ds.api.services.reddit.query
)