# First Login

In [None]:
# syft absolute
import syft as sy

In [None]:
# Launch a local development node and log in with the admin account.
# NOTE(review): reset=True presumably discards any data from a previous run,
# and n_consumers / create_producer presumably configure the node's job-queue
# workers — confirm against the sy.orchestra documentation.
node = sy.orchestra.launch(
    name="test-domain-1",
    port="auto",
    dev_mode=True,
    reset=True,
    n_consumers=1,
    create_producer=True,
)
# NOTE(review): these look like the node's default admin credentials; they are
# changed in the cells below.
domain_client = node.login(email="info@openmined.org", password="changethis")

# Modify Admin Email and Password

In [None]:
# The currently logged-in (admin) user; the bare expression displays it.
user = domain_client.me
user

In [None]:
# Change the admin's email address.
user.set_email(email="party@beach.edu")

In [None]:
# Change the admin's password. NOTE(review): "test" is a throwaway demo
# password; confirm=False presumably skips an interactive confirmation prompt.
user.set_password(new_password="test", confirm=False)

In [None]:
# Both the email and the password were changed above, so logging in with the
# old credentials is presumably expected to fail here.
domain_client = node.login(email="info@openmined.org", password="changethis")
domain_client

In [None]:
# Log in again using the new credentials.
domain_client = node.login(email="party@beach.edu", password="test")
domain_client

# Create an API Endpoint Adapter for Vertex AI

In [None]:
# stdlib
import os

# Private key of the Vertex service account. It is read from the environment
# so the secret never has to be pasted into (and committed with) the notebook;
# an empty string is kept as the fallback when the variable is not set.
PRIVATE_KEY = os.environ.get("VERTEX_PRIVATE_KEY", "")

In [None]:
# Google Cloud service-account credentials in the key-file layout passed to
# google.oauth2.service_account.Credentials.from_service_account_info inside
# the vertex.run endpoint (via its "SERVICE_ACCOUNT" setting). The private key
# itself comes from PRIVATE_KEY.
# NOTE(review): the remaining identifiers are project-specific; consider
# loading the whole key file from a secure location instead of inlining it.
SERVICE_ACCOUNT = {
    "type": "service_account",
    "project_id": "project-enigma-415021",
    "private_key_id": "0bd7cdd831f456f905fa98ad570740948bf7b7b9",
    "private_key": PRIVATE_KEY,
    "client_email": "vertex-test@project-enigma-415021.iam.gserviceaccount.com",
    "client_id": "113559790781665979367",
    "auth_uri": "https://accounts.google.com/o/oauth2/auth",
    "token_uri": "https://oauth2.googleapis.com/token",
    "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
    "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/vertex-test%40project-enigma-415021.iam.gserviceaccount.com",
    "universe_domain": "googleapis.com",
}

In [None]:
def rate_limit(user_email: str, api_state: dict, user_settings: dict):
    # stdlib
    import time

    # syft absolute
    import syft as sy

    def filter_history(call_history: dict, period_secs: int) -> None:
        now = time.time()
        filtered_call_history = {}
        for execution_time_str, call_item in call_history.items():
            execution_time = float(execution_time_str)

            diff = now - execution_time
            if diff > period_secs:
                # it has expired so trim
                pass
            else:
                filtered_call_history[execution_time] = call_item
        return filtered_call_history

    def add_call_history(call_history: dict | None) -> None:
        if call_history is None:
            call_history = {}
        now = time.time()
        call_history[str(now)] = True
        return call_history

    def is_valid(call_history: dict, n_calls: int, period_secs: int):
        if call_history is None:
            call_history = {}
        if len(call_history) >= n_calls:
            # syft absolute
            import syft as sy

            return sy.SyftError(
                message=f"You have hit the rate limit of {n_calls} calls in {period_secs} seconds."
            )
        return True

    # get user_state
    user_state = api_state.get(user_email, {})

    if user_email not in user_settings:
        valid = sy.SyftError(
            message=f"Email {user_email} is not allowed to use this API Endpoint."
        )
        return user_state, valid

    # get settings for user
    settings = user_settings.get(user_email, {})
    n_calls = settings.get("n_calls", 3)  # defaults
    period_secs = settings.get("period_secs", 60)  # defaults

    # filter old calls
    user_state = filter_history(call_history=user_state, period_secs=period_secs)

    # check if its still valid
    valid = is_valid(call_history=user_state, n_calls=n_calls, period_secs=period_secs)

    if valid:
        # record a call if we are allowed
        user_state = add_call_history(call_history=user_state)

    return user_state, valid

In [None]:
# Per-user rate limits passed to the vertex.run endpoint as its "user_settings"
# setting. Each allowed email maps to at most ``n_calls`` calls per
# ``period_secs`` seconds; emails not listed here are refused by rate_limit.
# Here every listed user may make 1 call per 10 seconds.
user_settings = {
    "paul@arrakis.net": {"n_calls": 1, "period_secs": 10},
    "info@openmined.org": {"n_calls": 1, "period_secs": 10},
    "party@beach.edu": {"n_calls": 1, "period_secs": 10},
}

In [None]:
@sy.api_endpoint(
    path="vertex.run",
    settings={"SERVICE_ACCOUNT": SERVICE_ACCOUNT, "user_settings": user_settings},
    helper_functions=[rate_limit],
)
def public_endpoint_method(
    context,
    prompt: str,
    max_tokens: int = 50,
    temperature: float = 0.1,
    top_p: float = 1.0,
    top_k: int = 1,
    raw_response: bool = False,
) -> str:
    """Send ``prompt`` to a Vertex AI endpoint, subject to a per-user rate limit.

    The ``rate_limit`` helper is applied first, and the caller's updated call
    history is written to ``context.state``. A refused non-admin caller gets
    the helper's error object back; admins are let through regardless.

    Args:
        context: Endpoint execution context (settings, state, helper functions
            and the calling user).
        prompt: Text sent to the model.
        max_tokens, temperature, top_p, top_k, raw_response: Forwarded
            unchanged in the single prediction-request instance.

    Returns:
        ``{"prediction": ...}`` holding the first prediction. If building the
        credentials or calling the endpoint raised, it holds an
        ``"Error: ..."`` string instead. A refused non-admin call returns the
        rate limiter's error object.
    """
    # NOTE(review): the ``-> str`` annotation does not match the dict / error
    # object actually returned; left unchanged because the signature is part
    # of the registered endpoint.

    # syft absolute
    import syft as sy

    caller_email = context.user_view.email
    limiter = context.code.helper_functions["rate_limit"]
    limits_by_user = context.settings["user_settings"]

    new_history, allowed = limiter(
        user_email=caller_email, api_state=context.state, user_settings=limits_by_user
    )
    # Persist the caller's call history before deciding whether to proceed.
    context.state[caller_email] = new_history

    caller_is_admin = context.user_view.role == sy.ServiceRole.ADMIN
    if not (caller_is_admin or allowed):
        # refused: hand the limiter's error message back to the caller
        return allowed

    # third party
    from google.cloud import aiplatform
    from google.oauth2 import service_account

    try:
        creds = service_account.Credentials.from_service_account_info(
            context.settings["SERVICE_ACCOUNT"]
        )
        project, region, endpoint_id = (
            "project-enigma-415021",
            "us-west1",
            "3213239169291649024",
        )
        vllm_endpoint = aiplatform.Endpoint(
            f"projects/{project}/locations/{region}/endpoints/{endpoint_id}",
            credentials=creds,
        )
        request_instance = dict(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            raw_response=raw_response,
        )
        prediction = vllm_endpoint.predict(instances=[request_instance]).predictions[0]
    except Exception as e:
        prediction = f"Error: Please try again? {e}"
    return {"prediction": prediction}

In [None]:
# Use this to delete the endpoint if you want to add it again (there is a
# separate update API). NOTE(review): when the endpoint does not exist yet
# (e.g. on a freshly reset node) this call presumably reports an error that
# can be ignored.
domain_client.api.services.api.delete(endpoint_path="vertex.run")

In [None]:
# Register the rate-limited Vertex endpoint on the node.
domain_client.api.services.api.add(endpoint=public_endpoint_method)

In [None]:
# Test the endpoint as a second, regular user.

In [None]:
# Register the test user's account.
domain_client.register(
    name="Paul Atreides",
    email="paul@arrakis.net",
    password="lisan-al-gaib",
    password_verify="lisan-al-gaib",
    institution="House Atreides",
    website="https://arrakis.net/",
)

In [None]:
# Get a client acting as the new user, via the admin client.
paul_client = domain_client.login_as(email="paul@arrakis.net")

In [None]:
# First call: allowed by the rate limiter (this user's limit is 1 call per
# 10 seconds in user_settings).
paul_client.api.services.vertex.run(prompt="hello")

In [None]:
# Second call within 10 seconds: expected to be refused by the rate limiter.
paul_client.api.services.vertex.run(prompt="hello")

# Create Rate Limit Updater

In [None]:
@sy.api_endpoint_method()
def private_set_user_rate_limit(
    context, endpoint_path: str, email: str, n_calls: int, period_secs: int
):
    """Set (or overwrite) one user's rate limit on another API endpoint.

    Reads the target endpoint's settings from its mock function. It stores
    ``{"n_calls": n_calls, "period_secs": period_secs}`` under
    ``settings["user_settings"][email]`` and writes the settings back with
    ``both=True``.

    Args:
        context: Endpoint execution context; ``context.admin_client`` is used
            to read and update the target endpoint.
        endpoint_path: Path of the endpoint to update, e.g. ``"vertex.run"``.
        email: Email of the user whose limit is set.
        n_calls: Value stored as the user's ``n_calls`` limit.
        period_secs: Value stored as the user's ``period_secs`` window.

    Returns:
        Whatever ``api.set_settings`` returns.
    """
    api_service = context.admin_client.api.services.api
    api_endpoint = api_service.get(api_path=endpoint_path)

    # The settings may be unset, and an endpoint that was never rate-limited
    # has no "user_settings" entry yet. Create both on demand instead of
    # failing with a TypeError / KeyError.
    settings = api_endpoint.mock_function.settings
    if settings is None:
        settings = {}
    per_user_limits = settings.get("user_settings") or {}
    per_user_limits[email] = {"n_calls": n_calls, "period_secs": period_secs}
    settings["user_settings"] = per_user_limits

    return api_service.set_settings(
        api_path=endpoint_path, settings=settings, both=True
    )

In [None]:
# Wrap the updater in an endpoint with no mock function, so it only has a
# private implementation.
new_endpoint = sy.TwinAPIEndpoint(
    path="ratelimit.update",
    mock_function=None,
    private_function=private_set_user_rate_limit,
)

In [None]:
# Remove any previous version before re-adding. NOTE(review): presumably
# reports an error when the endpoint does not exist yet.
domain_client.api.services.api.delete(endpoint_path="ratelimit.update")

In [None]:
domain_client.api.services.api.add(endpoint=new_endpoint)

In [None]:
# Presumably reloads the client's API so the new endpoint is reachable below.
domain_client.refresh()

In [None]:
# n_calls=0 refuses every call: the limiter rejects once the number of recent
# calls reaches n_calls, and that always holds for 0.
domain_client.api.services.ratelimit.update(
    endpoint_path="vertex.run", email="paul@arrakis.net", n_calls=0, period_secs=10
)

In [None]:
# Expected to be refused now that this user's limit is 0 calls.
paul_client.api.services.vertex.run(prompt="hello")

# Create Rate Limit Getter

In [None]:
@sy.api_endpoint_method()
def private_get_user_rate_limit(context, endpoint_path: str, email: str):
    """Describe one user's rate limit on another API endpoint.

    Args:
        context: Endpoint execution context; ``context.admin_client`` is used
            to read the target endpoint.
        endpoint_path: Path of the endpoint to inspect, e.g. ``"vertex.run"``.
        email: Email of the user to look up.

    Returns:
        ``"Rate limit for <email> is <limit>"`` when the endpoint's
        ``settings["user_settings"]`` has an entry for ``email``. Otherwise
        ``"Rate limit for <email> does not exist"``.
    """
    api_endpoint = context.admin_client.api.services.api.get(api_path=endpoint_path)
    settings = api_endpoint.mock_function.settings or {}
    # An endpoint without rate limiting has no "user_settings" entry; treat
    # that as "no limit configured" instead of raising KeyError.
    per_user_limits = settings.get("user_settings") or {}
    if email in per_user_limits:
        user_rate_limit = per_user_limits[email]
        return f"Rate limit for {email} is {user_rate_limit}"
    return f"Rate limit for {email} does not exist"

In [None]:
# Wrap the getter in an endpoint with no mock function (private only).
new_endpoint = sy.TwinAPIEndpoint(
    path="ratelimit.get",
    mock_function=None,
    private_function=private_get_user_rate_limit,
)

In [None]:
# Remove any previous version before re-adding.
domain_client.api.services.api.delete(endpoint_path="ratelimit.get")

In [None]:
domain_client.api.services.api.add(endpoint=new_endpoint)

In [None]:
# Presumably reloads the client's API so the new endpoint is reachable below.
domain_client.refresh()

In [None]:
# Should report the limit set earlier for this user (n_calls=0, period_secs=10).
domain_client.api.services.ratelimit.get(
    endpoint_path="vertex.run", email="paul@arrakis.net"
)

# Create Reset State Endpoint

In [None]:
@sy.api_endpoint_method()
def private_reset_endpoint_state(context, endpoint_path: str):
    """Clear the stored state of another API endpoint.

    Calls ``api.set_state`` through the admin client with an empty dict for
    ``endpoint_path`` (``both=True``) and returns its result unchanged.

    Args:
        context: Endpoint execution context providing ``admin_client``.
        endpoint_path: Path of the endpoint whose state is cleared.
    """
    api_service = context.admin_client.api.services.api
    return api_service.set_state(api_path=endpoint_path, state={}, both=True)

In [None]:
# Wrap the state reset in an endpoint with no mock function (private only).
new_endpoint = sy.TwinAPIEndpoint(
    path="state.reset",
    mock_function=None,
    private_function=private_reset_endpoint_state,
)

In [None]:
# Remove any previous version before re-adding.
domain_client.api.services.api.delete(endpoint_path="state.reset")

In [None]:
domain_client.api.services.api.add(endpoint=new_endpoint)

In [None]:
# Presumably reloads the client's API so the new endpoint is reachable below.
domain_client.refresh()

In [None]:
# Show the state stored for vertex.run. NOTE(review): presumably this is the
# same state the endpoint writes per-user call histories into — confirm.
domain_client.api.services.api.get(api_path="vertex.run").mock_function.state

In [None]:
# Clear vertex.run's stored state (and with it the recorded call histories).
domain_client.api.services.state.reset(endpoint_path="vertex.run")

In [None]:
# Cleanup local domain server: only shut it down here when it runs as a
# "python" node type (presumably other node types are managed externally).
if node.node_type.value == "python":
    node.land()