In [None]:
# import os
# os.environ["ORCHESTRA_DEPLOYMENT_TYPE"] = "remote"
# os.environ["DEV_MODE"] = "True"
# os.environ["TEST_EXTERNAL_REGISTRY"] = "k3d-registry.localhost:5800"

#### Helpers

In [None]:
# stdlib
import builtins
import time


class TimeoutError(builtins.TimeoutError):
    """Raised by ``Timeout.run_with_timeout`` when the deadline passes.

    This class reuses the builtin name, so it subclasses
    ``builtins.TimeoutError``. Handlers written against the builtin (and
    handlers for ``Exception``) therefore still catch it.
    """


class Timeout:
    """Repeatedly evaluate a condition until it holds or a deadline passes."""

    def __init__(self, timeout_duration):
        """
        Args:
            timeout_duration: Maximum number of seconds to keep polling.
                Must not exceed 60.

        Raises:
            ValueError: If ``timeout_duration`` exceeds 60 seconds.
        """
        if timeout_duration > 60:
            raise ValueError("Timeout duration cannot exceed 60 seconds.")
        self.timeout_duration = timeout_duration

    def run_with_timeout(self, condition_func, *args, **kwargs):
        """Call ``condition_func(*args, **kwargs)`` once per second until truthy.

        Returns:
            The truthy value returned by ``condition_func`` once the condition
            is met. Returns ``None`` if ``condition_func`` raised: the
            exception is printed and polling stops (deliberately best-effort).

        Raises:
            TimeoutError: If the condition is still not met after
                ``timeout_duration`` seconds.
        """
        # Monotonic clock: elapsed time is unaffected by wall-clock changes.
        start_time = time.monotonic()

        while True:
            elapsed_time = time.monotonic() - start_time
            if elapsed_time > self.timeout_duration:
                raise TimeoutError(
                    f"Function execution exceeded {self.timeout_duration} seconds."
                )

            # Check if the condition is met; extra arguments are forwarded.
            try:
                result = condition_func(*args, **kwargs)
                met = bool(result)
            except Exception as e:
                print(f"Exception in target function: {e}")
                return None  # Exit if an exception occurs in the function
            if met:
                print("Condition met, exiting early.")
                return result
            time.sleep(1)

### Import lib

In [None]:
# stdlib
import os

# Deployment type for this run, read from ORCHESTRA_DEPLOYMENT_TYPE;
# falls back to "python" when the variable is not set.
environment = os.getenv("ORCHESTRA_DEPLOYMENT_TYPE", "python")
environment

In [None]:
# Number of workers to use, overridable via NUM_TEST_WORKERS (default 1).
num_workers = int(os.getenv("NUM_TEST_WORKERS", "1"))

In [None]:
# stdlib

# syft absolute
import syft as sy

In [None]:
# third party
# run email server
from helpers import EmailServer
from helpers import SMTPTestServer

# Reset any stored emails, then start an SMTP test server wired to
# email_server. It is stopped later via smtp_server.stop().
email_server = EmailServer()
email_server.reset_emails()
smtp_server = SMTPTestServer(email_server)
smtp_server.start()

In [None]:
# Launch the "bigquery-high" high-side server on port 8080.
# The server is landed (shut down) at the end of the notebook.
server = sy.orchestra.launch(
    name="bigquery-high",
    server_side_type="high",
    dev_mode=True,
    port="8080",
    create_producer=True,  # allow the pool to produce additional workers
    n_consumers=num_workers,  # number of workers to spawn
)

In [None]:
# Root credentials passed to sy.login.
ROOT_EMAIL, ROOT_PASSWORD = "admin@bigquery.org", "bqpw"

In [None]:
# Log in to the server on localhost:8080 with the root credentials.
high_client = sy.login(
    email=ROOT_EMAIL,
    password=ROOT_PASSWORD,
    url="http://localhost:8080",
)

In [None]:
# Show the client's worker pools (bare expression -> notebook output).
high_client.worker_pools

In [None]:
# Fetch the pool named "default-pool"; the following cells scale, delete and
# re-launch it by this object's name.
default_worker_pool = high_client.worker_pools.get_by_name("default-pool")
default_worker_pool

### Scale Worker pool

##### Scale up

In [None]:
# Scale the default pool to `num_workers` workers (remote deployments only).
if environment == "remote":
    high_client.api.worker_pool.scale(
        number=num_workers, pool_name=default_worker_pool.name
    )

In [None]:
# Show the worker pool at index 0 (bare expression -> notebook output).
high_client.api.services.worker_pool[0]

In [None]:
# Scale up workers: grow the default pool to 5 (remote deployments only) and
# check that the scale call succeeded and max_count reflects the new size.
if environment == "remote":
    scale_up_result = high_client.api.worker_pool.scale(
        number=5, pool_name=default_worker_pool.name
    )
    assert scale_up_result, scale_up_result

    assert (
        high_client.api.services.worker_pool[default_worker_pool.name].max_count
        == 5
    )

##### Scale down

In [None]:
# Scale the default pool back down to `num_workers` (remote only); this
# gracefully shuts down the surplus consumers.
if environment == "remote":
    scale_down_result = high_client.api.worker_pool.scale(
        pool_name=default_worker_pool.name,
        number=num_workers,
    )
    assert scale_down_result, scale_down_result

In [None]:
if environment == "remote":

    def has_worker_scaled_down():
        """Return True once the default pool's max_count equals num_workers."""
        # Same lookup path (api.services) as the other pool checks here.
        pool = high_client.api.services.worker_pool[default_worker_pool.name]
        return pool.max_count == num_workers

    # Poll until the scale-down is visible, for at most 20 seconds. If the
    # lookup raises, Timeout prints it and stops; the next cell asserts.
    worker_scale_timeout = Timeout(timeout_duration=20)
    worker_scale_timeout.run_with_timeout(has_worker_scaled_down)

In [None]:
# Confirm the scale-down took effect (remote deployments only).
if environment == "remote":
    assert (
        num_workers
        == high_client.api.services.worker_pool[default_worker_pool.name].max_count
    )

#### Delete Worker Pool

In [None]:
# Delete the default pool by name and display the service's response.
pool_delete_result = high_client.api.services.worker_pool.delete(
    pool_name=default_worker_pool.name
)
pool_delete_result

In [None]:
# After deletion, looking the pool up by name must raise KeyError.
with sy.raises(KeyError):
    _ = high_client.api.services.worker_pool[default_worker_pool.name]

#### Re-launch the default worker pool

In [None]:
# Image used to re-launch the pool below. NOTE(review): this reads `.image`
# from the pool object fetched before deletion — confirm it does not require
# the pool to still exist server-side.
default_worker_image = default_worker_pool.image

In [None]:
# Re-create the deleted pool under the same name and image, with
# `num_workers` workers.
launch_result = high_client.api.services.worker_pool.launch(
    num_workers=num_workers,
    image_uid=default_worker_image.id,
    pool_name=default_worker_pool.name,
)

In [None]:
# The re-launched pool must exist and be sized to `num_workers`.
assert high_client.api.services.worker_pool[default_worker_pool.name]
assert (
    num_workers
    == high_client.api.services.worker_pool[default_worker_pool.name].max_count
)

In [None]:
# Stop the SMTP test server started during setup.
smtp_server.stop()

In [None]:
# Land (shut down) the server started with sy.orchestra.launch.
server.land()