In [1]:
import syft as sy


In [2]:
node = sy.orchestra.launch(
    name="reddit_h",
    node_side_type="high",
    dev_mode=True,
    reset=True,
    local_db=True,
    n_consumers=1,
    create_producer=True,
)

Staging Protocol Changes...
Creating default worker image with tag='local-dev'
Building default worker image with tag=local-dev
Setting up worker poolname=default-pool workers=1 image_uid=512703c6ea404905b2b721d21e3cd285 in_memory=True
Created default worker pool.
Data Migrated to latest version !!!


In [3]:
domain_client = node.login(email="info@openmined.org", password="changethis")

Logged into <reddit_h: High side Domain> as <info@openmined.org>


In [4]:
# TODO credentials
bq_credentials = {}

In [5]:
# Mock Behavior
@sy.mock_api_endpoint(
    # settings={}
    settings=bq_credentials,
)
def mock_query_function(
    context, # Variable used to track user session, user information, user activities and settings
    sql_query: str,  
) -> str:
    # third party
    from google.cloud import bigquery
    from google.oauth2 import service_account

    credentials = service_account.Credentials.from_service_account_info(
        context.settings
    )
    scoped_credentials = credentials.with_scopes(
        ["https://www.googleapis.com/auth/cloud-platform"]
    )

    client = bigquery.Client(
        credentials=scoped_credentials,
        location="us-west1",
    )

    rows = client.query_and_wait(
        sql_query,
        project=context.settings['project_id'],
    )

    # Replacing private values to mocked ones.
    result = rows.to_dataframe()
    result['int64_field_0'] = 0
    result['id'] = "Private"
    result['name'] = "Private"
    result['subscribers_count'] = 0
    result['permalink'] = "Private"
    result['nsfw'] = "NaN"
    result['spam'] = False
    return result
    
    
    

# Private Behavior
@sy.private_api_endpoint(
    # settings={}
    settings=bq_credentials,
)
def private_query_function(
    context,
    sql_query: str,  
) -> str:
    # third party
    from google.cloud import bigquery
    from google.oauth2 import service_account
    import pandas as pd
    print("test")
    credentials = service_account.Credentials.from_service_account_info(
        context.settings
    )
    scoped_credentials = credentials.with_scopes(
        ["https://www.googleapis.com/auth/cloud-platform"]
    )

    client = bigquery.Client(
        credentials=scoped_credentials,
        location="us-west1",
    )
    print(client)
    
    rows = client.query_and_wait(
        sql_query,
        project=context.settings["project_id"],
    )
    print(rows) 

    return rows.to_dataframe()

In [6]:
mock_query_function

```python
class PublicAPIEndpoint:
  id: str = c44649b415734ed5a4b524ca0c9562e7

```

In [7]:
from typing import Any


class RepeatedCallPolicy(sy.CustomOutputPolicy):
    n_calls: int = 0
    state: dict[Any, Any] = {}

    def __init__(self, n_calls=1):
        self.n_calls = n_calls
        self.state = {"counts": 0}

    def public_state(self):
        return self.state["counts"]

    def apply_output(self, context, outputs):
        if hasattr(outputs, "syft_action_data"):
            outputs = outputs.syft_action_data
        if self.state["counts"] < self.n_calls:
            self.state["counts"] += 1
        else:
            return None

        return outputs
    
    def _is_valid(self, context):
        return self.state['counts'] < self.n_calls

In [8]:
from syft.service.policy.policy import ExactMatch, SingleExecutionExactOutput

# just for testing
input_policy = ExactMatch()
output_policy = RepeatedCallPolicy(n_calls=1)


In [9]:
from syft.service.policy.policy import SubmitUserPolicy


new_endpoint = sy.TwinAPIEndpoint(
    path="reddit.query",
    private_function=private_query_function,
    mock_function=mock_query_function,
    description="Lorem ipsum dolor sit amet lorem adipiscing elit …",
    input_policy_type=ExactMatch,
    input_policy_init_kwargs={},
    output_policy_type=SubmitUserPolicy.from_obj(output_policy),
    output_policy_init_kwargs=output_policy.init_kwargs,
)


In [10]:
# from syft.service.api.api import TwinAPIEndpoint


# new_endpoint.to(TwinAPIEndpoint)

In [11]:
domain_client.api.services.api.add(endpoint=new_endpoint)

syft.service.api.api.CreateTwinAPIEndpoint
Domain: reddit_h - 4b600cd95dcc42a79be83b5fdfedb75b - domain

Services:
APIService, ActionService, BlobStorageService, CodeHistoryService, DataSubjectMemberService, DataSubjectService, DatasetService, EnclaveService, JobService, LogService, MetadataService, MigrateStateService, NetworkService, NotificationService, NotifierService, OutputService, PolicyService, ProjectService, QueueService, RequestService, SettingsService, SyftImageRegistryService, SyftWorkerImageService, SyftWorkerPoolService, SyncService, UserCodeService, UserCodeStatusService, UserService, WorkerService
Domain: reddit_h - 4b600cd95dcc42a79be83b5fdfedb75b - domain

Services:
APIService, ActionService, BlobStorageService, CodeHistoryService, DataSubjectMemberService, DataSubjectService, DatasetService, EnclaveService, JobService, LogService, MetadataService, MigrateStateService, NetworkService, NotificationService, NotifierService, OutputService, PolicyService, ProjectService,

In [12]:
domain_client.refresh()

In [13]:
domain_client.api.services.reddit.query(sql_query="SELECT *  FROM test_1gb.subreddits LIMIT 100")

SELECT *  FROM test_1gb.subreddits LIMIT 100
POLICY CHECK: True
test
<google.cloud.bigquery.client.Client object at 0x7a6ebfb2d3d0>
<google.cloud.bigquery.table.RowIterator object at 0x7a6ebf975ed0>
SyftSuccess: Endpoint STATE successfully updated. False


Unnamed: 0,int64_field_0,id,name,subscribers_count,permalink,nsfw,spam
0,4,t5_via1x,/r/mylittlepony,4323081,/r//r/mylittlepony,,False
1,5,t5_cv9gn,/r/polyamory,2425929,/r//r/polyamory,,False
2,10,t5_8p2tq,/r/Catholicism,4062607,/r//r/Catholicism,,False
3,16,t5_8fcro,/r/cordcutters,7543226,/r//r/cordcutters,,False
4,17,t5_td5of,/r/stevenuniverse,2692168,/r//r/stevenuniverse,,False
...,...,...,...,...,...,...,...
95,305,t5_jgydw,/r/cannabis,7703201,/r//r/cannabis,,False
96,311,t5_3mfau,/r/marvelmemes,4288492,/r//r/marvelmemes,,False
97,317,t5_ub3c8,/r/ghibli,6029127,/r//r/ghibli,,False
98,319,t5_fbgo3,/r/birdsarentreal,3416317,/r//r/birdsarentreal,,False


In [14]:
domain_client.api.services.reddit.query(sql_query="SELECT *  FROM test_1gb.subreddits LIMIT 100")

SELECT *  FROM test_1gb.subreddits LIMIT 100
POLICY CHECK: False
