In [None]:
# stdlib
import os

# syft absolute
import syft as sy

In [None]:
# Launch the local development domain node used throughout this walkthrough.
# NOTE(review): reset=True presumably discards state left by earlier runs — confirm.
node = sy.orchestra.launch(
    name="test-domain-1",
    port="auto",
    dev_mode=True,
    reset=True,
)

# Admin

In [None]:
# Admin session (these look like the stock dev-mode admin credentials).
domain_client = node.login(
    email="info@openmined.org",
    password="changethis",
)

In [None]:
# syft absolute
from syft.client.api import APIRegistry

In [None]:
# Display what APIRegistry currently holds — a quick check of which node APIs
# this session has registered after logging in.
APIRegistry.get_all_api()

In [None]:
# The service-account private key must never be pasted into (and committed with)
# this notebook. Supply it through the VERTEX_PRIVATE_KEY environment variable;
# when unset this falls back to the empty placeholder, as before.
# Keys stored on one line in .env files usually encode newlines as a literal
# "\n", so those are turned back into real newlines (a real PEM key never
# contains a literal backslash-n, so this is a no-op otherwise).
PRIVATE_KEY = os.environ.get("VERTEX_PRIVATE_KEY", "").replace("\\n", "\n")

In [None]:
# Google service-account info handed to the private Vertex endpoint through its
# settings. The private key itself comes from PRIVATE_KEY rather than being
# written here.
SERVICE_ACCOUNT = dict(
    type="service_account",
    project_id="project-enigma-415021",
    private_key_id="0bd7cdd831f456f905fa98ad570740948bf7b7b9",
    private_key=PRIVATE_KEY,
    client_email="vertex-test@project-enigma-415021.iam.gserviceaccount.com",
    client_id="113559790781665979367",
    auth_uri="https://accounts.google.com/o/oauth2/auth",
    token_uri="https://oauth2.googleapis.com/token",
    auth_provider_x509_cert_url="https://www.googleapis.com/oauth2/v1/certs",
    client_x509_cert_url="https://www.googleapis.com/robot/v1/metadata/x509/vertex-test%40project-enigma-415021.iam.gserviceaccount.com",
    universe_domain="googleapis.com",
)

In [None]:
@sy.mock_api_endpoint()
def mock_run_vertex(
    context,
    prompt: str,
    max_tokens: int = 50,
    temperature: float = 0.1,
    top_p: float = 1.0,
    top_k: int = 1,
    raw_response: bool = False,
) -> str:
    """Mock twin of the Vertex endpoint: answers with a canned prediction.

    Only ``raw_response`` changes the reply. Every other argument is accepted
    solely so this signature mirrors the private endpoint's.

    NOTE(review): ``-> str`` does not match the dict actually returned. It is
    left untouched because it mirrors the private endpoint's signature, which
    TwinAPIEndpoint presumably compares against this one.
    """
    canned = (
        "You get back a raw result" if raw_response else "You get back a result"
    )
    return {"prediction": canned}

In [None]:
# The private endpoint below imports google-cloud-aiplatform; install it if missing:
# !uv pip install google-cloud-aiplatform

In [None]:
@sy.private_api_endpoint(
    settings={"SERVICE_ACCOUNT": SERVICE_ACCOUNT},
)
def private_run_vertex(
    context,
    prompt: str,
    max_tokens: int = 50,
    temperature: float = 0.1,
    top_p: float = 1.0,
    top_k: int = 1,
    raw_response: bool = False,
) -> str:
    """Run one prompt against the deployed Vertex AI endpoint.

    Credentials are built from ``context.settings["SERVICE_ACCOUNT"]``. The
    endpoint location is no longer fixed in code. ``PROJECT_ID``, ``REGION`` and
    ``ENDPOINT_ID`` may be supplied through the endpoint settings; each defaults
    to the previously hard-coded deployment when absent.

    Returns:
        ``{"prediction": ...}`` holding the first prediction returned by the
        endpoint, or the string ``"Error: Please try again?"`` if any step fails.

    NOTE(review): ``-> str`` does not match the dict returned. It is kept
    unchanged because it mirrors the mock endpoint's signature.
    """
    # third party
    from google.cloud import aiplatform
    from google.oauth2 import service_account

    # A single request instance carrying every generation argument verbatim.
    instances = [
        {
            "prompt": prompt,
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "raw_response": raw_response,
        },
    ]

    try:
        settings = context.settings
        credentials = service_account.Credentials.from_service_account_info(
            settings["SERVICE_ACCOUNT"]
        )

        # Defaults reproduce the deployment this endpoint was written for.
        project_id = settings.get("PROJECT_ID", "project-enigma-415021")
        region = settings.get("REGION", "us-west1")
        endpoint_id = settings.get("ENDPOINT_ID", "3213239169291649024")
        aip_endpoint_name = (
            f"projects/{project_id}/locations/{region}/endpoints/{endpoint_id}"
        )
        endpoint_vllm = aiplatform.Endpoint(aip_endpoint_name, credentials=credentials)
        response = endpoint_vllm.predict(instances=instances)
        prediction = response.predictions[0]
    except Exception:
        # Deliberately broad best-effort: the caller only ever receives a generic
        # message, never internal details such as credential or network errors.
        prediction = "Error: Please try again?"
    return {"prediction": prediction}

In [None]:
# Pair the real (private) and mock implementations under one API path; data
# scientists reach it as `api.services.vertex.run`.
new_endpoint = sy.TwinAPIEndpoint(
    path="vertex.run",
    description="Run vertex model",
    mock_function=mock_run_vertex,
    private_function=private_run_vertex,
)
new_endpoint

In [None]:
# Delete an existing "vertex.run" endpoint so it can be added again
# (there is a separate update API for changing an endpoint in place).
domain_client.api.services.api.delete(endpoint_path="vertex.run")

In [None]:
# Register the twin endpoint on the domain and show the service's response.
domain_client.api.services.api.add(endpoint=new_endpoint)

## Create Data Scientist

In [None]:
# Create the data scientist's account (the password is a demo-only value).
domain_client.register(
    email="jimmy@caltech.edu",
    name="Jimmy Doe",
    institution="Caltech",
    website="https://www.caltech.edu/",
    password="abc123",
    password_verify="abc123",
)

# Data Scientist

In [None]:
# Look up the data scientist and show whether they may execute mock endpoints.
# NOTE(review): assumes the search finds this user; users[0] fails if it does not.
users = domain_client.api.services.user.search(name="Jimmy Doe")
user = users[0]
user.mock_execution_permission

In [None]:
# Grant permission to execute mock endpoints (presumably required for the
# `vertex.run.mock(...)` call made as the data scientist).
user.update(mock_execution_permission=True)

In [None]:
# Re-fetch the user to confirm the permission update was persisted.
users = domain_client.api.services.user.search(name="Jimmy Doe")
user = users[0]
user.mock_execution_permission

In [None]:
# Log in as the data scientist registered by the admin.
jimmy_client = node.login(email="jimmy@caltech.edu", password="abc123")

In [None]:
# Call only the mock side of the endpoint; with raw_response=True the mock
# replies with its "raw result" message.
jimmy_client.api.services.vertex.run.mock(prompt="test", raw_response=True)

## Create Input Policy

In [None]:
# Input policy for the code request: `func` is bound to the vertex.run endpoint
# object, while the remaining arguments are presumably constrained only by the
# listed types (confirm against MixedInputPolicy).
input_policy = sy.MixedInputPolicy(
    func=jimmy_client.api.services.vertex.run,
    prompt=str,
    max_tokens=int,
    temperature=float,
    top_p=float,
    top_k=int,
    raw_response=bool,
)
input_policy

In [None]:
# assert False

## Create Output Policy

In [None]:
class RateLimiter(sy.CustomOutputPolicy):
    """Output policy that releases results for at most ``n_calls`` executions.

    Only the keys named in ``downloadable_output_args`` are copied out of the
    function's output, together with a ``calls_remaining`` counter. Once the
    budget is spent, a plain rate-limit message is returned instead.
    """

    # maximum number of executions whose output may be released
    n_calls: int = 0
    # output keys the data scientist may download
    downloadable_output_args: list[str] = []
    # bookkeeping; "counts" holds the number of executions released so far
    state: dict = {}

    def __init__(self, n_calls=1, downloadable_output_args: list[str] = None):
        if downloadable_output_args is None:
            downloadable_output_args = []
        self.downloadable_output_args = downloadable_output_args
        self.n_calls = n_calls
        self.state = {"counts": 0}

    def public_state(self):
        # Number of executions released so far.
        return self.state["counts"]

    def update_policy(self, context, outputs):
        # Record one more released execution.
        self.state["counts"] += 1

    def apply_to_output(self, context, outputs, update_policy=True):
        # Unwrap action objects so the output can be indexed by key.
        if hasattr(outputs, "syft_action_data"):
            outputs = outputs.syft_action_data

        # Budget exhausted: refuse before copying anything out.
        if self.state["counts"] >= self.n_calls:
            return "You've hit the rate limit. Please contact the administrator."

        released = {key: outputs[key] for key in self.downloadable_output_args}
        if update_policy:
            self.update_policy(context, outputs)
        released["calls_remaining"] = self.n_calls - self.state["counts"]
        return released

    def _is_valid(self, context):
        # The policy stays valid while the execution budget is not used up.
        return self.state["counts"] < self.n_calls

In [None]:
@sy.syft_function(
    input_policy=input_policy,
    output_policy=RateLimiter(n_calls=3, downloadable_output_args=["prediction"]),
)
def my_vertex_func(
    func,
    prompt: str,
    max_tokens: int = 50,
    temperature: float = 0.1,
    top_p: float = 1.0,
    top_k: int = 1,
    raw_response: bool = False,
):
    """Forward every generation argument, unchanged, to the endpoint ``func``.

    The attached RateLimiter releases only the ``prediction`` key, for up to
    three executions.
    """
    generation_kwargs = {
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "raw_response": raw_response,
    }
    return func(**generation_kwargs)

In [None]:
# Inspect what @sy.syft_function produced — the object submitted for approval.
my_vertex_func

In [None]:
# @sy.syft_function()
#     input_policy=input_policy,
#     output_policy=RateLimiter(n_calls=3, downloadable_output_args=["prediction"]),
# )
# def my_vertex_func(
#     func,
#     prompt: str,
#     max_tokens: int,
#     temperature: float,
#     top_p: float,
#     top_k: int,
#     raw_response: bool
# ):
#     return func(
#         prompt=prompt,
#         max_tokens=max_tokens,
#         temperature=temperature,
#         top_p=top_p,
#         top_k=top_k,
#         raw_response=raw_response
#     )

In [None]:
# my_vertex_func(func=jane_client.api.services.vertex.run, prompt="my prompt")

In [None]:
# def test_func(*args, **kwargs):
#     print(kwargs)
#     return {"prediction": f"Test API: {kwargs['prompt']}"}

In [None]:
# my_vertex_func(func=test_func, prompt="my prompt")

In [None]:
# assert False

In [None]:
# Wrap the request in a project with the data scientist as its member, then
# submit my_vertex_func for the admin to review.
new_project = sy.Project(
    members=[jimmy_client],
    name="Vertex Model Access",
    description="Hi, I want to use this model 3 times",
)
new_project.create_code_request(my_vertex_func, jimmy_client)

## Admin approves

In [None]:
# Admin view: list the requests waiting on this domain.
domain_client.requests

In [None]:
# Pick the most recent request for review.
# NOTE(review): assumes requests are ordered oldest-first so [-1] is the newest.
request = domain_client.requests[-1]
request

In [None]:
# Approve the request so the data scientist can run my_vertex_func for real.
request.approve()

## Data Scientist runs

In [None]:
# Refresh the data scientist's client (presumably so the approval is visible).
jimmy_client.refresh()

In [None]:
# Inspect the approved code as seen from the data scientist's side.
jimmy_client.code.my_vertex_func

In [None]:
# Execution 1 of the 3 allowed by RateLimiter(n_calls=3) (on a fresh run).
result = jimmy_client.code.my_vertex_func(
    func=jimmy_client.api.services.vertex.run,
    prompt="Who are you now?",
    max_tokens=50,
    temperature=0.1,
    top_p=1.0,
    top_k=1,
    raw_response=False,
)
result

In [None]:
# Execution 2 of the 3 allowed by RateLimiter(n_calls=3) (on a fresh run).
result = jimmy_client.code.my_vertex_func(
    func=jimmy_client.api.services.vertex.run,
    prompt="Who are you now?",
    max_tokens=50,
    temperature=0.1,
    top_p=1.0,
    top_k=1,
    raw_response=False,
)
result

In [None]:
# Execution 3 of 3 — the last one the RateLimiter should release (on a fresh run).
result = jimmy_client.code.my_vertex_func(
    func=jimmy_client.api.services.vertex.run,
    prompt="Who are you now?",
    max_tokens=50,
    temperature=0.1,
    top_p=1.0,
    top_k=1,
    raw_response=False,
)
result

In [None]:
# A fourth execution exceeds the n_calls=3 budget and is expected to be refused
# by the rate limiter (presumably with its rate-limit message) — confirm.
result = jimmy_client.code.my_vertex_func(
    func=jimmy_client.api.services.vertex.run,
    prompt="Who are you now?",
    max_tokens=50,
    temperature=0.1,
    top_p=1.0,
    top_k=1,
    raw_response=False,
)
result

## Advanced

In [None]:
# Deliberate stop: "Run All" halts here so the experimental Advanced section
# below is not executed. NOTE(review): `assert` is skipped under `python -O`.
assert False

In [None]:
# Allow api endpoint code to generate a policy and code submission object

In [None]:
@sy.private_api_endpoint()
def private_user_function_creator(
    context,
    api_func,
    n_calls: int,
    name: str,
) -> str:
    """Experimental: build a rate-limited Vertex code submission inside an endpoint.

    The input policy, the output-policy class and the syft function are all
    defined in this body. Presumably this is because the endpoint's source runs
    on the node and cannot see notebook globals — confirm.

    Args:
        context: endpoint context supplied by syft (not used by this body).
        api_func: endpoint the generated function will call (e.g. ``vertex.run``).
        n_calls: number of executions the generated RateLimiter allows.
        name: intended name for the generated function.

    Returns:
        Whatever ``@sy.syft_function`` produced for ``my_vertex_func``.
        NOTE(review): annotated ``-> str`` although that is not what is returned.
    """
    # syft absolute
    import syft as sy

    # create input policy
    input_policy = sy.MixedInputPolicy(
        func=api_func,
        prompt=str,
        max_tokens=int,
        temperature=float,
        top_p=float,
        top_k=int,
        raw_response=bool,
    )

    # Local copy of the notebook-level RateLimiter: releases the listed output
    # keys for at most `n_calls` executions, then returns a rate-limit message.
    class RateLimiter(sy.CustomOutputPolicy):
        n_calls: int = 0
        downloadable_output_args: list[str] = []
        state: dict = {}

        def __init__(self, n_calls=1, downloadable_output_args: list[str] = None):
            self.downloadable_output_args = (
                downloadable_output_args if downloadable_output_args is not None else []
            )
            self.n_calls = n_calls
            self.state = {"counts": 0}

        def public_state(self):
            return self.state["counts"]

        def update_policy(self, context, outputs):
            self.state["counts"] += 1

        def apply_to_output(self, context, outputs, update_policy=True):
            # unwrap action objects so outputs can be indexed by key
            if hasattr(outputs, "syft_action_data"):
                outputs = outputs.syft_action_data
            output_dict = {}
            if self.state["counts"] < self.n_calls:
                for output_arg in self.downloadable_output_args:
                    output_dict[output_arg] = outputs[output_arg]
                if update_policy:
                    self.update_policy(context, outputs)
            else:
                return "You've hit the rate limit. Please contact the administrator."

            output_dict["calls_remaining"] = self.n_calls - self.state["counts"]
            return output_dict

        def _is_valid(self, context):
            return self.state["counts"] < self.n_calls

    # The generated function simply forwards its arguments to `func`; only the
    # "prediction" key is released, at most `n_calls` times.
    @sy.syft_function(
        input_policy=input_policy,
        output_policy=RateLimiter(
            n_calls=n_calls, downloadable_output_args=["prediction"]
        ),
    )
    def my_vertex_func(
        func,
        prompt: str,
        max_tokens: int = 50,
        temperature: float = 0.1,
        top_p: float = 1.0,
        top_k: int = 1,
        raw_response: bool = False,
    ):
        return func(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            raw_response=raw_response,
        )

    # NOTE(review): after decoration `my_vertex_func` is presumably a syft
    # submission object rather than a plain function. Confirm that setting
    # `__name__` actually renames the submitted code.
    my_vertex_func.__name__ = name
    return my_vertex_func

In [None]:
@sy.mock_api_endpoint()
def mock_user_function_creator(
    context,
    api_func,
    n_calls: int,
    name: str,
) -> str:
    """Experimental mock twin of the code-request factory endpoint.

    The body currently builds exactly what the private twin builds. It creates a
    MixedInputPolicy bound to ``api_func``, a local RateLimiter allowing
    ``n_calls`` executions, and a syft function forwarding its arguments to
    ``func``.

    Args:
        context: endpoint context supplied by syft (not used by this body).
        api_func: endpoint the generated function will call (e.g. ``vertex.run``).
        n_calls: number of executions the generated RateLimiter allows.
        name: intended name for the generated function.

    Returns:
        Whatever ``@sy.syft_function`` produced for ``my_vertex_func``.
        NOTE(review): annotated ``-> str`` although that is not what is returned.
    """
    # syft absolute
    import syft as sy

    # create input policy
    input_policy = sy.MixedInputPolicy(
        func=api_func,
        prompt=str,
        max_tokens=int,
        temperature=float,
        top_p=float,
        top_k=int,
        raw_response=bool,
    )

    # Local copy of the notebook-level RateLimiter: releases the listed output
    # keys for at most `n_calls` executions, then returns a rate-limit message.
    class RateLimiter(sy.CustomOutputPolicy):
        n_calls: int = 0
        downloadable_output_args: list[str] = []
        state: dict = {}

        def __init__(self, n_calls=1, downloadable_output_args: list[str] = None):
            self.downloadable_output_args = (
                downloadable_output_args if downloadable_output_args is not None else []
            )
            self.n_calls = n_calls
            self.state = {"counts": 0}

        def public_state(self):
            return self.state["counts"]

        def update_policy(self, context, outputs):
            self.state["counts"] += 1

        def apply_to_output(self, context, outputs, update_policy=True):
            # unwrap action objects so outputs can be indexed by key
            if hasattr(outputs, "syft_action_data"):
                outputs = outputs.syft_action_data
            output_dict = {}
            if self.state["counts"] < self.n_calls:
                for output_arg in self.downloadable_output_args:
                    output_dict[output_arg] = outputs[output_arg]
                if update_policy:
                    self.update_policy(context, outputs)
            else:
                return "You've hit the rate limit. Please contact the administrator."

            output_dict["calls_remaining"] = self.n_calls - self.state["counts"]
            return output_dict

        def _is_valid(self, context):
            return self.state["counts"] < self.n_calls

    # The generated function simply forwards its arguments to `func`; only the
    # "prediction" key is released, at most `n_calls` times.
    @sy.syft_function(
        input_policy=input_policy,
        output_policy=RateLimiter(
            n_calls=n_calls, downloadable_output_args=["prediction"]
        ),
    )
    def my_vertex_func(
        func,
        prompt: str,
        max_tokens: int = 50,
        temperature: float = 0.1,
        top_p: float = 1.0,
        top_k: int = 1,
        raw_response: bool = False,
    ):
        return func(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            raw_response=raw_response,
        )

    # NOTE(review): after decoration `my_vertex_func` is presumably a syft
    # submission object rather than a plain function. Confirm that setting
    # `__name__` actually renames the submitted code.
    my_vertex_func.__name__ = name
    return my_vertex_func

In [None]:
# Twin endpoint exposing the code-request factory as `vertex.create_code_request`.
creator_endpoint = sy.TwinAPIEndpoint(
    path="vertex.create_code_request",
    private_function=private_user_function_creator,
    mock_function=mock_user_function_creator,
    description="Create a vertex code request",
)
creator_endpoint

In [None]:
# Remove any earlier registration of the creator endpoint so it can be re-added.
response = domain_client.api.services.api.delete(
    endpoint_path="vertex.create_code_request"
)
response

In [None]:
# Register the creator endpoint and show the service's response.
response = domain_client.api.services.api.add(endpoint=creator_endpoint)
response

In [None]:
# Refresh the admin client (presumably so the new endpoint shows up under api.services).
domain_client.refresh()

In [None]:
# TODO: make RemoteFunction serializable and change what the admin API returns to
# the user side, or add a `context.submit` that submits on behalf of the user.

In [None]:
# Experimental: ask the creator endpoint for a 2-call rate-limited submission
# named "myfunc" that calls `vertex.run`.
# NOTE(review): work in progress — returning the submission object to the caller
# may not work yet.
domain_client.api.services.vertex.create_code_request(
    api_func=domain_client.api.services.vertex.run, n_calls=2, name="myfunc"
)

In [None]:
# Refresh the data scientist's client (presumably to pick up the new endpoint).
jimmy_client.refresh()

In [None]:
# jimmy_client.api.services.vertex.create_code_request(api_func=jimmy_client.api.services.vertex.run, n_calls=2, name="myfunc")

In [None]:
# new_project = sy.Project(
#     name="Vertex Model Access",
#     description="Hi, I want to use this model 3 times",
#     members=[jane_client],
# )

# new_project.create_code_request(code_obj, jane_client)