In [None]:
# Optional overrides for running this notebook against a remote (k3d-based)
# deployment instead of the default in-process one. Uncomment before "Run All".
# import os
# os.environ["ORCHESTRA_DEPLOYMENT_TYPE"] = "remote"
# os.environ["DEV_MODE"] = "True"
# os.environ["TEST_EXTERNAL_REGISTRY"] = "k3d-registry.localhost:5800"

In [None]:
# stdlib
import os
import time

# third party
from helpers import Timeout
from helpers import get_email_server

# syft absolute
import syft as sy

In [None]:
# Deployment target: defaults to the in-process "python" deployment; the scaling
# cells below only run when this is "remote".
environment = os.environ.get("ORCHESTRA_DEPLOYMENT_TYPE", "python")

# Baseline number of consumer workers for the default pool.
num_workers = int(os.environ.get("NUM_TEST_WORKERS", "1"))

# Root admin credentials for the test server (test fixture values).
ROOT_EMAIL, ROOT_PASSWORD = "admin@bigquery.org", "bqpw"

environment

### Launch server & login

In [None]:
# Launch (or connect to) the high-side server that the rest of the notebook uses.
server = sy.orchestra.launch(
    name="bigquery-high",
    server_side_type="high",
    dev_mode=True,
    port="8080",
    # Worker settings: spawn `num_workers` consumers and allow creating more.
    create_producer=True,
    n_consumers=num_workers,
)

In [None]:
# Start a fresh test email server; keep the SMTP handle so it can be stopped at the end.
(email_server, smtp_server) = get_email_server(reset=True)

In [None]:
# Log in as root admin on the server launched above (port 8080).
high_client = sy.login(
    email=ROOT_EMAIL,
    password=ROOT_PASSWORD,
    url="http://localhost:8080",
)

In [None]:
# Display all worker pools currently registered on the server.
high_client.worker_pools

In [None]:
# Grab a handle on the default pool; it is scaled, deleted and relaunched below.
DEFAULT_POOL_NAME = "default-pool"
default_worker_pool = high_client.worker_pools.get_by_name(DEFAULT_POOL_NAME)
default_worker_pool

### Scale the worker pool

##### Scale up

In [None]:
# Bring the default pool to the configured baseline of `num_workers` workers
# before the scale-up test (only exercised on remote deployments).
if environment == "remote":
    high_client.api.services.worker_pool.scale(
        number=num_workers, pool_name=default_worker_pool.name
    )

In [None]:
# Display the first worker pool (by index) to inspect its current state.
high_client.api.services.worker_pool[0]

In [None]:
# Scale the default pool up to SCALE_UP_WORKERS so several long-running jobs can
# be processed concurrently (only exercised on remote deployments).
SCALE_UP_WORKERS = 5

if environment == "remote":
    scale_up_result = high_client.api.services.worker_pool.scale(
        number=SCALE_UP_WORKERS, pool_name=default_worker_pool.name
    )
    assert scale_up_result, scale_up_result

    # The pool's configured size should now reflect the new target.
    assert (
        high_client.api.services.worker_pool[default_worker_pool.name].max_count
        == SCALE_UP_WORKERS
    )

##### Give workers some long-running jobs


In [None]:
# Long-running job #1: sleeps 1000s so it keeps one worker of the default pool
# busy while the pool is scaled down / deleted below. The import lives inside
# the function because the body executes on the server-side worker.
@sy.syft_function_single_use(worker_pool_name=default_worker_pool.name)
def wait_1000_seconds_1():
    # stdlib
    import time

    time.sleep(1000)


# Long-running job #2: identical to job #1; a separate function is needed so a
# separate request/job is created and another worker is occupied.
@sy.syft_function_single_use(worker_pool_name=default_worker_pool.name)
def wait_1000_seconds_2():
    # stdlib
    import time

    time.sleep(1000)


# Long-running job #3: identical to jobs #1 and #2; three jobs are needed so the
# busy-worker check below can require at least three distinct workers.
@sy.syft_function_single_use(worker_pool_name=default_worker_pool.name)
def wait_1000_seconds_3():
    # stdlib
    import time

    time.sleep(1000)

In [None]:
# Submit one code request per long-running function, then approve them all
# (the root admin is both requester and approver here).
for syft_func in (wait_1000_seconds_1, wait_1000_seconds_2, wait_1000_seconds_3):
    high_client.code.request_code_execution(syft_func)

# Fetch the request list once so the count check and approvals see the same items.
pending_requests = list(high_client.requests)
assert len(pending_requests) == 3, pending_requests
for request in pending_requests:
    request.approve()

In [None]:
# Start all three approved functions without blocking, so they run concurrently
# as server-side jobs.
jobs = [
    high_client.code.wait_1000_seconds_1(blocking=False),
    high_client.code.wait_1000_seconds_2(blocking=False),
    high_client.code.wait_1000_seconds_3(blocking=False),
]

assert len(list(high_client.jobs)) == 3

In [None]:
# Check that at least MIN_BUSY_WORKERS distinct workers are processing our jobs
# (the pool was scaled up on remote deployments). Workers may take a moment to
# accept queued jobs, so poll up to POLL_ATTEMPTS times, 1 second apart.
# NOTE(review): on the non-remote ("python") deployment no scaling happens, so
# this presumably needs NUM_TEST_WORKERS >= 3 — confirm intent.
MIN_BUSY_WORKERS = 3
POLL_ATTEMPTS = 3

busy_worker_ids = set()
for attempt in range(POLL_ATTEMPTS):
    busy_worker_ids = {
        job.worker.id for job in high_client.jobs if job.status == "processing"
    }
    if len(busy_worker_ids) >= MIN_BUSY_WORKERS:
        break
    # No point sleeping after the final attempt.
    if attempt < POLL_ATTEMPTS - 1:
        time.sleep(1)

assert len(busy_worker_ids) >= MIN_BUSY_WORKERS, busy_worker_ids

##### Scale down

In [None]:
# Scale the default pool back down to `num_workers`. This gracefully shuts down
# the surplus consumers; jobs that were running on them should end up
# "interrupted" (verified in the next cell).
if environment == "remote":
    scale_down_result = high_client.api.services.worker_pool.scale(
        number=num_workers, pool_name=default_worker_pool.name
    )
    assert scale_down_result, scale_down_result

In [None]:
if environment == "remote":

    def has_worker_scaled_down():
        """Return True once the scale-down has fully taken effect.

        That is: the default pool's ``max_count`` equals ``num_workers`` AND every
        job whose worker is no longer in the pool has status "interrupted".
        """
        # Fetch the pool once so max_count and the worker list are one snapshot.
        pool = high_client.api.services.worker_pool[default_worker_pool.name]
        if pool.max_count != num_workers:
            return False

        current_worker_ids = {worker.id for worker in pool.workers}
        # Jobs still assigned to a surviving worker are ignored; only jobs on
        # removed workers must have been interrupted.
        return all(
            job.status == "interrupted"
            for job in high_client.jobs
            if job.job_worker_id is not None
            and job.job_worker_id not in current_worker_ids
        )

    # Poll the condition via the helper; duration presumably in seconds — confirm.
    worker_scale_timeout = Timeout(timeout_duration=60)
    worker_scale_timeout.run_with_timeout(has_worker_scaled_down)

In [None]:
# After scale-down, the default pool should be back at its baseline size.
if environment == "remote":
    scaled_down_pool = high_client.api.services.worker_pool[default_worker_pool.name]
    assert scaled_down_pool.max_count == num_workers

#### Delete the worker pool

In [None]:
# Delete the default pool entirely. Fail fast if the server reports an error
# (same convention as the scale cells), then display the result.
pool_delete_result = high_client.api.services.worker_pool.delete(
    pool_name=default_worker_pool.name
)
assert pool_delete_result, pool_delete_result
pool_delete_result

In [None]:
# The pool was deleted, so looking it up by name must now raise KeyError.
deleted_pool_name = default_worker_pool.name
with sy.raises(KeyError):
    high_client.api.services.worker_pool[deleted_pool_name]

In [None]:
# Every job should now be "interrupted": the whole pool was deleted and all jobs
# had previously been assigned to its workers. Show the statuses on failure.
job_statuses = [job.status for job in high_client.jobs]
assert all(status == "interrupted" for status in job_statuses), job_statuses

#### Re-launch the default worker pool

In [None]:
# Reuse the deleted pool's image so the relaunched pool matches the original.
default_worker_image = default_worker_pool.image

In [None]:
# Re-create the default pool from its original image with `num_workers` workers,
# and display the launch result so failures are visible in the notebook output.
launch_result = high_client.api.services.worker_pool.launch(
    pool_name=default_worker_pool.name,
    image_uid=default_worker_image.id,
    num_workers=num_workers,
)
launch_result

In [None]:
# The relaunched pool must exist again and be sized to the baseline.
relaunched_pool = high_client.api.services.worker_pool[default_worker_pool.name]
assert relaunched_pool
assert relaunched_pool.max_count == num_workers

In [None]:
# Tear down the test SMTP server started at the top of the notebook.
smtp_server.stop()

In [None]:
# Shut down the server launched via sy.orchestra.launch.
server.land()