In [None]:
# stdlib
import os

# syft absolute
import syft as sy
from syft.util.test_helpers.email_helpers import Timeout

In [None]:
# Deployment settings are read from the environment; each one falls back to
# the default shown here when its variable is unset.
environment = os.getenv("ORCHESTRA_DEPLOYMENT_TYPE", "python")
high_port = os.getenv("CLUSTER_HTTP_PORT_HIGH", "9081")
low_port = os.getenv("CLUSTER_HTTP_PORT_LOW", "9083")
print(environment, high_port, low_port)

In [None]:
# Worker count the default pool is scaled to and re-launched with; defaults
# to 1 when NUM_TEST_WORKERS is unset.
num_workers = int(os.getenv("NUM_TEST_WORKERS", "1"))

# Root account used to log in to the low-side server.
# NOTE(review): credentials are hard-coded — presumably test-only defaults;
# confirm they are never used against a real deployment.
ROOT_EMAIL = "admin@bigquery.org"
ROOT_PASSWORD = "bqpw"

### Launch server & login

In [None]:
# Launch the low-side server through the orchestra helper, listening on
# `low_port`, with a single consumer and its own producer.
low_server_options = {
    "name": "bigquery-low",
    "server_side_type": "low",
    "dev_mode": True,
    "n_consumers": 1,
    "create_producer": True,
    "port": low_port,
}
server_low = sy.orchestra.launch(**low_server_options)

In [None]:
# Log in as the root user against the server's HTTP endpoint on `low_port`.
low_server_url = f"http://localhost:{low_port}"
low_client = sy.login(
    url=low_server_url,
    email=ROOT_EMAIL,
    password=ROOT_PASSWORD,
)

In [None]:
# Exactly one worker pool is expected to exist at this point.
existing_pools = low_client.worker_pools.get_all()
assert len(existing_pools) == 1

In [None]:
# Fetch the pool by name; this handle (its `.name` and `.image`) is reused by
# the scaling, deletion and re-launch cells that follow.
default_pool_name = "default-pool"
default_worker_pool = low_client.worker_pools.get_by_name(default_pool_name)
default_worker_pool

### Scale worker pool

#### Scale up

In [None]:
# Bring the default pool to the configured worker count (`num_workers`, which
# is 1 unless NUM_TEST_WORKERS overrides it). Scaling is only exercised when
# running against a remote deployment.
if environment == "remote":
    # Go through `api.services.worker_pool`, the accessor used by the other
    # worker-pool calls in this notebook.
    low_client.api.services.worker_pool.scale(
        number=num_workers, pool_name=default_worker_pool.name
    )

In [None]:
# Display the first entry of the worker-pool listing.
first_worker_pool = low_client.api.services.worker_pool[0]
first_worker_pool

In [None]:
# Scale the default pool up and confirm the pool reports the new size.
# Only exercised against a remote deployment.
if environment == "remote":
    scaled_up_count = 5

    scale_up_result = low_client.api.services.worker_pool.scale(
        number=scaled_up_count, pool_name=default_worker_pool.name
    )
    # Passing the result as the message surfaces it if the call was not
    # successful.
    assert scale_up_result, scale_up_result

    assert (
        low_client.api.services.worker_pool[default_worker_pool.name].max_count
        == scaled_up_count
    )

#### Scale down

In [None]:
# Scale the pool back down to the configured worker count; this gracefully
# shuts down the surplus consumers. Only exercised against a remote deployment.
if environment == "remote":
    scale_down_result = low_client.api.services.worker_pool.scale(
        number=num_workers, pool_name=default_worker_pool.name
    )
    # Passing the result as the message surfaces it if the call was not
    # successful.
    assert scale_down_result, scale_down_result

In [None]:
if environment == "remote":

    def has_worker_scaled_down():
        """Return True once the default pool's `max_count` equals `num_workers`."""
        # Read through `api.services.worker_pool`, the same accessor the
        # follow-up assertion uses, so both check the same value.
        return (
            low_client.api.services.worker_pool[default_worker_pool.name].max_count
            == num_workers
        )

    # Give the scale-down up to 20 seconds to take effect.
    # NOTE(review): `Timeout.run_with_timeout` is defined outside this
    # notebook; presumably it re-evaluates the condition until it holds or the
    # budget elapses — confirm against the helper.
    worker_scale_timeout = Timeout(timeout_duration=20)
    worker_scale_timeout.run_with_timeout(has_worker_scaled_down)

In [None]:
# After waiting, the pool must report the configured worker count again.
if environment == "remote":
    scaled_down_pool = low_client.api.services.worker_pool[
        default_worker_pool.name
    ]
    assert scaled_down_pool.max_count == num_workers

#### Delete Worker Pool

In [None]:
# Delete the default pool by name and display the service's response.
pool_to_delete = default_worker_pool.name
pool_delete_result = low_client.api.services.worker_pool.delete(
    pool_name=pool_to_delete
)
pool_delete_result

In [None]:
# Once deleted, looking the pool up by name is expected to raise KeyError.
deleted_pool_name = default_worker_pool.name
with sy.raises(KeyError):
    _ = low_client.api.services.worker_pool[deleted_pool_name]

#### Re-launch the default worker pool

In [None]:
# `default_worker_pool` was fetched before the pool was deleted, so the image
# is read from that earlier handle; its id is what the re-launch below uses.
default_worker_image = default_worker_pool.image

In [None]:
# Re-create the pool under its original name, from the image captured before
# deletion, with the configured number of workers.
relaunch_options = {
    "pool_name": default_worker_pool.name,
    "image_uid": default_worker_image.id,
    "num_workers": num_workers,
}
launch_result = low_client.api.services.worker_pool.launch(**relaunch_options)

In [None]:
# The re-launched pool must be retrievable by name again and report the
# configured worker count.
relaunched_pool_name = default_worker_pool.name
assert low_client.api.services.worker_pool[relaunched_pool_name]
relaunched_max_count = low_client.api.services.worker_pool[
    relaunched_pool_name
].max_count
assert relaunched_max_count == num_workers

In [None]:
# Only a server launched from this notebook is landed; a remote deployment is
# left untouched.
running_remotely = environment == "remote"
if not running_remotely:
    server_low.land()