From 94d10c3ab9bbe05bfee532dbbb6328a4d00cd171 Mon Sep 17 00:00:00 2001 From: destefy Date: Thu, 11 Jul 2024 14:04:30 -0400 Subject: [PATCH 1/6] use pydantic BaseSettings --- centml/sdk/config.py | 13 +++++++------ requirements.txt | 1 + 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/centml/sdk/config.py b/centml/sdk/config.py index a277ae9..b4724cb 100644 --- a/centml/sdk/config.py +++ b/centml/sdk/config.py @@ -1,11 +1,12 @@ import os +from pydantic_settings import BaseSettings -class Config: - centml_web_url = "https://main.d1tz9z8hgabab9.amplifyapp.com/" - centml_config_dir = os.getenv("CENTML_CONFIG_PATH", default=os.path.expanduser("~/.centml")) - centml_cred_file = centml_config_dir + "/" + os.getenv("CENTML_CRED_FILE", default="credential") +class Config(BaseSettings): + centml_web_url: str = "https://main.d1tz9z8hgabab9.amplifyapp.com/" + centml_config_dir: str = os.getenv("CENTML_CONFIG_PATH", default=os.path.expanduser("~/.centml")) + centml_cred_file: str = centml_config_dir + "/" + os.getenv("CENTML_CRED_FILE", default="credential") - platformapi_url = "https://api.centml.org" + platformapi_url: str = "https://api.centml.org" - firebase_api_key = "AIzaSyBXSNjruNdtypqUt_CPhB8QNl8Djfh5RXI" + firebase_api_key: str = "AIzaSyBXSNjruNdtypqUt_CPhB8QNl8Djfh5RXI" diff --git a/requirements.txt b/requirements.txt index 7ae42c9..fb3af44 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ torch>=2.0.0 fastapi>=0.103.0 uvicorn>=0.23.0 python-multipart>=0.0.6 +pydantic-settings==2.0.* Requests==2.32.2 tabulate>=0.9.0 pyjwt>=2.8.0 From 3d3426e3db537b0ae0b965ab19b41f1f65dd67cf Mon Sep 17 00:00:00 2001 From: destefy Date: Thu, 11 Jul 2024 15:24:30 -0400 Subject: [PATCH 2/6] use Config instance instead of class --- centml/cli/login.py | 7 ++++--- centml/sdk/api.py | 4 ++-- centml/sdk/auth.py | 16 ++++++++-------- centml/sdk/config.py | 2 ++ tests/test_backend.py | 4 ++-- 5 files changed, 18 insertions(+), 15 deletions(-) diff --git a/centml/cli/login.py b/centml/cli/login.py index 5396a1c..8502dfa 100644 --- a/centml/cli/login.py +++ b/centml/cli/login.py @@ -1,6 +1,7 @@ import click -from centml.sdk import auth, config +from centml.sdk import auth +from centml.sdk.config import settings @click.command(help="Login to CentML") @@ -10,7 +11,7 @@ def login(token_file): auth.store_centml_cred(token_file) if auth.load_centml_cred(): - click.echo(f"Authenticating with credentials from {config.Config.centml_cred_file}\n") + click.echo(f"Authenticating with credentials from {settings.centml_cred_file}\n") click.echo("Login successful") else: click.echo("Login with CentML authentication token") @@ -18,7 +19,7 @@ def login(token_file): choice = click.confirm("Do you want to download the token?") if choice: - click.launch(f"{config.Config.centml_web_url}?isCliAuthenticated=true") + click.launch(f"{settings.centml_web_url}?isCliAuthenticated=true") else: click.echo("Login unsuccessful") diff --git a/centml/sdk/api.py b/centml/sdk/api.py index 72142af..3e776b0 100644 --- a/centml/sdk/api.py +++ b/centml/sdk/api.py @@ -3,13 +3,13 @@ from platform_api_client.models.deployment_status import DeploymentStatus from centml.sdk import auth -from centml.sdk.config import Config +from centml.sdk.config import settings from centml.sdk.utils import client_certs @contextlib.contextmanager def get_api(): - configuration = platform_api_client.Configuration(host=Config.platformapi_url, access_token=auth.get_centml_token()) + configuration = platform_api_client.Configuration(host=settings.platformapi_url, access_token=auth.get_centml_token()) with platform_api_client.ApiClient(configuration) as api_client: api_instance = platform_api_client.EXTERNALApi(api_client) diff --git a/centml/sdk/auth.py b/centml/sdk/auth.py index 53c4939..857100e 100644 --- a/centml/sdk/auth.py +++ b/centml/sdk/auth.py @@ -5,11 +5,11 @@ import requests import jwt -from centml.sdk.config import Config +from centml.sdk.config import settings def refresh_centml_token(refresh_token): - api_key = Config.firebase_api_key + api_key = settings.firebase_api_key cred = requests.post( f"https://securetoken.googleapis.com/v1/token?key={api_key}", @@ -18,7 +18,7 @@ def refresh_centml_token(refresh_token): timeout=3, ).json() - with open(Config.centml_cred_file, 'w') as f: + with open(settings.centml_cred_file, 'w') as f: json.dump(cred, f) return cred @@ -27,7 +27,7 @@ def refresh_centml_token(refresh_token): def store_centml_cred(token_file): try: with open(token_file, 'r') as f: - os.makedirs(Config.centml_config_dir, exist_ok=True) + os.makedirs(settings.centml_config_dir, exist_ok=True) refresh_token = json.load(f)["refreshToken"] refresh_centml_token(refresh_token) @@ -38,8 +38,8 @@ def store_centml_cred(token_file): def load_centml_cred(): cred = None - if os.path.exists(Config.centml_cred_file): - with open(Config.centml_cred_file, 'r') as f: + if os.path.exists(settings.centml_cred_file): + with open(settings.centml_cred_file, 'r') as f: cred = json.load(f) return cred @@ -60,5 +60,5 @@ def get_centml_token(): def remove_centml_cred(): - if os.path.exists(Config.centml_cred_file): - os.remove(Config.centml_cred_file) + if os.path.exists(settings.centml_cred_file): + os.remove(settings.centml_cred_file) diff --git a/centml/sdk/config.py b/centml/sdk/config.py index b4724cb..50bce49 100644 --- a/centml/sdk/config.py +++ b/centml/sdk/config.py @@ -10,3 +10,5 @@ class Config(BaseSettings): platformapi_url: str = "https://api.centml.org" firebase_api_key: str = "AIzaSyBXSNjruNdtypqUt_CPhB8QNl8Djfh5RXI" + +settings = Config() \ No newline at end of file diff --git a/tests/test_backend.py b/tests/test_backend.py index 19034c7..c35bd17 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -132,7 +132,7 @@ def test_invalid_status(self, mock_requests): mock_requests.get.assert_called_once() self.assertIn("Status check: request failed, exception from server", str(context.exception)) - @patch("centml.compiler.config.Config.COMPILING_SLEEP_TIME", new=0) + @patch("centml.compiler.config.settings.COMPILING_SLEEP_TIME", new=0) @patch("centml.compiler.backend.Runner._compile_model") @patch("centml.compiler.backend.requests") def test_max_tries(self, mock_requests, mock_compile): @@ -148,7 +148,7 @@ def test_max_tries(self, mock_requests, mock_compile): self.assertEqual(mock_compile.call_count, config_instance.MAX_RETRIES + 1) self.assertIn("Waiting for status: compilation failed too many times", str(context.exception)) - @patch("centml.compiler.config.Config.COMPILING_SLEEP_TIME", new=0) + @patch("centml.compiler.config.settings.COMPILING_SLEEP_TIME", new=0) @patch("centml.compiler.backend.requests") def test_wait_on_compilation(self, mock_requests): COMPILATION_STEPS = 10 From 83262249fd481e60f2fe96cbb5940e0b48b8e4f2 Mon Sep 17 00:00:00 2001 From: destefy Date: Tue, 16 Jul 2024 16:57:13 -0400 Subject: [PATCH 3/6] use pydantic base class --- centml/cli/login.py | 4 ++-- centml/compiler/backend.py | 6 +++--- centml/compiler/config.py | 14 +++++++------- centml/compiler/server.py | 2 +- centml/sdk/api.py | 4 +++- centml/sdk/auth.py | 14 +++++++------- centml/sdk/config.py | 14 ++++++++------ 7 files changed, 31 insertions(+), 27 deletions(-) diff --git a/centml/cli/login.py b/centml/cli/login.py index 8502dfa..e3cfdb7 100644 --- a/centml/cli/login.py +++ b/centml/cli/login.py @@ -11,7 +11,7 @@ def login(token_file): auth.store_centml_cred(token_file) if auth.load_centml_cred(): - click.echo(f"Authenticating with credentials from {settings.centml_cred_file}\n") + click.echo(f"Authenticating with credentials from {settings.CENTML_CRED_FILE_PATH}\n") click.echo("Login successful") else: click.echo("Login with CentML authentication token") @@ -19,7 +19,7 @@ def login(token_file): choice = click.confirm("Do you want to download the token?") if choice: - click.launch(f"{settings.centml_web_url}?isCliAuthenticated=true") + click.launch(f"{settings.CENTML_WEB_URL}?isCliAuthenticated=true") else: click.echo("Login unsuccessful") diff --git a/centml/compiler/backend.py b/centml/compiler/backend.py index 7ecd3ad..be77545 100644 --- a/centml/compiler/backend.py +++ b/centml/compiler/backend.py @@ -80,7 +80,7 @@ def _get_model_id(self) -> str: def _download_model(self, model_id: str): download_response = requests.get( - url=f"{config_instance.SERVER_URL}/download/{model_id}", timeout=config_instance.TIMEOUT + url=f"{config_instance.CENTML_SERVER_URL}/download/{model_id}", timeout=config_instance.TIMEOUT ) if download_response.status_code != HTTPStatus.OK: raise Exception( @@ -104,7 +104,7 @@ def _compile_model(self, model_id: str): with open(self.serialized_model_path, 'rb') as model_file, open(self.serialized_input_path, 'rb') as input_file: compile_response = requests.post( - url=f"{config_instance.SERVER_URL}/submit/{model_id}", + url=f"{config_instance.CENTML_SERVER_URL}/submit/{model_id}", files={"model": model_file, "inputs": input_file}, timeout=config_instance.TIMEOUT, ) @@ -119,7 +119,7 @@ def _wait_for_status(self, model_id: str) -> bool: while True: # get server compilation status status_response = requests.get( - f"{config_instance.SERVER_URL}/status/{model_id}", timeout=config_instance.TIMEOUT + f"{config_instance.CENTML_SERVER_URL}/status/{model_id}", timeout=config_instance.TIMEOUT ) if status_response.status_code != HTTPStatus.OK: raise Exception( diff --git a/centml/compiler/config.py b/centml/compiler/config.py index f858f29..efd49c8 100644 --- a/centml/compiler/config.py +++ b/centml/compiler/config.py @@ -1,5 +1,6 @@ import os from enum import Enum +from pydantic_settings import BaseSettings class CompilationStatus(Enum): @@ -8,20 +9,19 @@ class CompilationStatus(Enum): DONE = "DONE" -class Config: +class Config(BaseSettings): TIMEOUT: int = 10 MAX_RETRIES: int = 3 COMPILING_SLEEP_TIME: int = 15 - CACHE_PATH: str = os.getenv("CENTML_CACHE_DIR", default=os.path.expanduser("~/.cache/centml")) + CENTML_CACHE_DIR: str = "~/.cache/centml" + BACKEND_BASE_PATH: str = os.path.join(CENTML_CACHE_DIR, "backend") + SERVER_BASE_PATH: str = os.path.join(CENTML_CACHE_DIR, "server") - SERVER_URL: str = os.getenv("CENTML_SERVER_URL", default="http://0.0.0.0:8090") - - BACKEND_BASE_PATH: str = os.path.join(CACHE_PATH, "backend") - SERVER_BASE_PATH: str = os.path.join(CACHE_PATH, "server") + CENTML_SERVER_URL: str = "http://0.0.0.0:8090" # Use a constant path since torch.save uses the given file name in it's zipfile. - # Thus, a different filename would result in a different hash. + # Using a different filename would result in a different hash. SERIALIZED_MODEL_FILE: str = "serialized_model.zip" SERIALIZED_INPUT_FILE: str = "serialized_input.zip" PICKLE_PROTOCOL: int = 4 diff --git a/centml/compiler/server.py b/centml/compiler/server.py index 569e6f6..4985192 100644 --- a/centml/compiler/server.py +++ b/centml/compiler/server.py @@ -110,7 +110,7 @@ async def download_handler(model_id: str): def run(): - parsed = urlparse(config_instance.SERVER_URL) + parsed = urlparse(config_instance.CENTML_SERVER_URL) uvicorn.run(app, host=parsed.hostname, port=parsed.port) diff --git a/centml/sdk/api.py b/centml/sdk/api.py index 3e776b0..02281cb 100644 --- a/centml/sdk/api.py +++ b/centml/sdk/api.py @@ -9,7 +9,9 @@ @contextlib.contextmanager def get_api(): - configuration = platform_api_client.Configuration(host=settings.platformapi_url, access_token=auth.get_centml_token()) + configuration = platform_api_client.Configuration( + host=settings.PLATFORM_API_URL, access_token=auth.get_centml_token() + ) with platform_api_client.ApiClient(configuration) as api_client: api_instance = platform_api_client.EXTERNALApi(api_client) diff --git a/centml/sdk/auth.py b/centml/sdk/auth.py index 857100e..65dc834 100644 --- a/centml/sdk/auth.py +++ b/centml/sdk/auth.py @@ -9,7 +9,7 @@ def refresh_centml_token(refresh_token): - api_key = settings.firebase_api_key + api_key = settings.FIREBASE_API_KEY cred = requests.post( f"https://securetoken.googleapis.com/v1/token?key={api_key}", @@ -18,7 +18,7 @@ def refresh_centml_token(refresh_token): timeout=3, ).json() - with open(settings.centml_cred_file, 'w') as f: + with open(settings.CENTML_CRED_FILE_PATH, 'w') as f: json.dump(cred, f) return cred @@ -27,7 +27,7 @@ def refresh_centml_token(refresh_token): def store_centml_cred(token_file): try: with open(token_file, 'r') as f: - os.makedirs(settings.centml_config_dir, exist_ok=True) + os.makedirs(settings.CENTML_CONFIG_PATH, exist_ok=True) refresh_token = json.load(f)["refreshToken"] refresh_centml_token(refresh_token) @@ -38,8 +38,8 @@ def store_centml_cred(token_file): def load_centml_cred(): cred = None - if os.path.exists(settings.centml_cred_file): - with open(settings.centml_cred_file, 'r') as f: + if os.path.exists(settings.CENTML_CRED_FILE_PATH): + with open(settings.CENTML_CRED_FILE_PATH, 'r') as f: cred = json.load(f) return cred @@ -60,5 +60,5 @@ def get_centml_token(): def remove_centml_cred(): - if os.path.exists(settings.centml_cred_file): - os.remove(settings.centml_cred_file) + if os.path.exists(settings.CENTML_CRED_FILE_PATH): + os.remove(settings.CENTML_CRED_FILE_PATH) diff --git a/centml/sdk/config.py b/centml/sdk/config.py index 50bce49..409edc2 100644 --- a/centml/sdk/config.py +++ b/centml/sdk/config.py @@ -3,12 +3,14 @@ class Config(BaseSettings): - centml_web_url: str = "https://main.d1tz9z8hgabab9.amplifyapp.com/" - centml_config_dir: str = os.getenv("CENTML_CONFIG_PATH", default=os.path.expanduser("~/.centml")) - centml_cred_file: str = centml_config_dir + "/" + os.getenv("CENTML_CRED_FILE", default="credential") + CENTML_WEB_URL: str = "https://main.d1tz9z8hgabab9.amplifyapp.com/" + CENTML_CONFIG_PATH: str = os.path.expanduser("~/.centml") + CENTML_CRED_FILE: str = "credential" + CENTML_CRED_FILE_PATH: str = CENTML_CONFIG_PATH + "/" + CENTML_CRED_FILE - platformapi_url: str = "https://api.centml.org" + PLATFORM_API_URL: str = "https://api.centml.org" - firebase_api_key: str = "AIzaSyBXSNjruNdtypqUt_CPhB8QNl8Djfh5RXI" + FIREBASE_API_KEY: str = "AIzaSyBXSNjruNdtypqUt_CPhB8QNl8Djfh5RXI" -settings = Config() \ No newline at end of file + +settings = Config() From 4cb9d108e83baddac4c202e6cd6f72503cd99e5f Mon Sep 17 00:00:00 2001 From: destefy Date: Tue, 16 Jul 2024 16:59:33 -0400 Subject: [PATCH 4/6] rename config_instance to settings --- centml/compiler/backend.py | 24 ++++++++++++------------ centml/compiler/config.py | 2 +- centml/compiler/server.py | 12 ++++++------ centml/compiler/utils.py | 12 ++++++------ tests/test_backend.py | 4 ++-- 5 files changed, 27 insertions(+), 27 deletions(-) diff --git a/centml/compiler/backend.py b/centml/compiler/backend.py index be77545..37482c2 100644 --- a/centml/compiler/backend.py +++ b/centml/compiler/backend.py @@ -11,7 +11,7 @@ import requests import torch from torch.fx import GraphModule -from centml.compiler.config import config_instance, CompilationStatus +from centml.compiler.config import settings, CompilationStatus from centml.compiler.utils import get_backend_compiled_forward_path @@ -54,13 +54,13 @@ def inputs(self): def _serialize_model_and_inputs(self): self.serialized_model_dir = TemporaryDirectory() # pylint: disable=consider-using-with - self.serialized_model_path = os.path.join(self.serialized_model_dir.name, config_instance.SERIALIZED_MODEL_FILE) - self.serialized_input_path = os.path.join(self.serialized_model_dir.name, config_instance.SERIALIZED_INPUT_FILE) + self.serialized_model_path = os.path.join(self.serialized_model_dir.name, settings.SERIALIZED_MODEL_FILE) + self.serialized_input_path = os.path.join(self.serialized_model_dir.name, settings.SERIALIZED_INPUT_FILE) # torch.save saves a zip file full of pickled files with the model's states. try: - torch.save(self.module, self.serialized_model_path, pickle_protocol=config_instance.PICKLE_PROTOCOL) - torch.save(self.inputs, self.serialized_input_path, pickle_protocol=config_instance.PICKLE_PROTOCOL) + torch.save(self.module, self.serialized_model_path, pickle_protocol=settings.PICKLE_PROTOCOL) + torch.save(self.inputs, self.serialized_input_path, pickle_protocol=settings.PICKLE_PROTOCOL) except Exception as e: raise Exception(f"Failed to save module or inputs with torch.save: {e}") from e @@ -71,7 +71,7 @@ def _get_model_id(self) -> str: sha_hash = hashlib.sha256() with open(self.serialized_model_path, "rb") as serialized_model_file: # Read in chunks to not load too much into memory - for block in iter(lambda: serialized_model_file.read(config_instance.HASH_CHUNK_SIZE), b""): + for block in iter(lambda: serialized_model_file.read(settings.HASH_CHUNK_SIZE), b""): sha_hash.update(block) model_id = sha_hash.hexdigest() @@ -80,7 +80,7 @@ def _get_model_id(self) -> str: def _download_model(self, model_id: str): download_response = requests.get( - url=f"{config_instance.CENTML_SERVER_URL}/download/{model_id}", timeout=config_instance.TIMEOUT + url=f"{settings.CENTML_SERVER_URL}/download/{model_id}", timeout=settings.TIMEOUT ) if download_response.status_code != HTTPStatus.OK: raise Exception( @@ -104,9 +104,9 @@ def _compile_model(self, model_id: str): with open(self.serialized_model_path, 'rb') as model_file, open(self.serialized_input_path, 'rb') as input_file: compile_response = requests.post( - url=f"{config_instance.CENTML_SERVER_URL}/submit/{model_id}", + url=f"{settings.CENTML_SERVER_URL}/submit/{model_id}", files={"model": model_file, "inputs": input_file}, - timeout=config_instance.TIMEOUT, + timeout=settings.TIMEOUT, ) if compile_response.status_code != HTTPStatus.OK: @@ -119,7 +119,7 @@ def _wait_for_status(self, model_id: str) -> bool: while True: # get server compilation status status_response = requests.get( - f"{config_instance.CENTML_SERVER_URL}/status/{model_id}", timeout=config_instance.TIMEOUT + f"{settings.CENTML_SERVER_URL}/status/{model_id}", timeout=settings.TIMEOUT ) if status_response.status_code != HTTPStatus.OK: raise Exception( @@ -138,10 +138,10 @@ def _wait_for_status(self, model_id: str) -> bool: else: tries += 1 - if tries > config_instance.MAX_RETRIES: + if tries > settings.MAX_RETRIES: raise Exception("Waiting for status: compilation failed too many times.\n") - time.sleep(config_instance.COMPILING_SLEEP_TIME) + time.sleep(settings.COMPILING_SLEEP_TIME) def remote_compilation(self): self._serialize_model_and_inputs() diff --git a/centml/compiler/config.py b/centml/compiler/config.py index efd49c8..5a462b1 100644 --- a/centml/compiler/config.py +++ b/centml/compiler/config.py @@ -32,4 +32,4 @@ class Config(BaseSettings): MINIMUM_GZIP_SIZE: int = 1000 -config_instance = Config() +settings = Config() diff --git a/centml/compiler/server.py b/centml/compiler/server.py index 4985192..86e08d2 100644 --- a/centml/compiler/server.py +++ b/centml/compiler/server.py @@ -10,15 +10,15 @@ from fastapi.middleware.gzip import GZipMiddleware from centml.compiler.server_compilation import hidet_backend_server from centml.compiler.utils import dir_cleanup -from centml.compiler.config import config_instance, CompilationStatus +from centml.compiler.config import settings, CompilationStatus from centml.compiler.utils import get_server_compiled_forward_path app = FastAPI() -app.add_middleware(GZipMiddleware, minimum_size=config_instance.MINIMUM_GZIP_SIZE) # type: ignore +app.add_middleware(GZipMiddleware, minimum_size=settings.MINIMUM_GZIP_SIZE) # type: ignore def get_status(model_id: str): - if not os.path.isdir(os.path.join(config_instance.SERVER_BASE_PATH, model_id)): + if not os.path.isdir(os.path.join(settings.SERVER_BASE_PATH, model_id)): return CompilationStatus.NOT_FOUND if not os.path.isfile(get_server_compiled_forward_path(model_id)): @@ -50,7 +50,7 @@ def background_compile(model_id: str, tfx_graph, example_inputs): # To avoid this, we write to a tmp file and rename it to the correct path after saving. save_path = get_server_compiled_forward_path(model_id) tmp_path = save_path + ".tmp" - torch.save(compiled_graph_module, tmp_path, pickle_protocol=config_instance.PICKLE_PROTOCOL) + torch.save(compiled_graph_module, tmp_path, pickle_protocol=settings.PICKLE_PROTOCOL) os.rename(tmp_path, save_path) except Exception as e: logging.getLogger(__name__).exception(f"Saving graph module failed: {e}") @@ -93,7 +93,7 @@ async def compile_model_handler(model_id: str, model: UploadFile, inputs: Upload return Response(status_code=200) # This effectively sets the model's status to COMPILING - os.makedirs(os.path.join(config_instance.SERVER_BASE_PATH, model_id)) + os.makedirs(os.path.join(settings.SERVER_BASE_PATH, model_id)) tfx_graph, example_inputs = read_upload_files(model_id, model, inputs) @@ -110,7 +110,7 @@ async def download_handler(model_id: str): def run(): - parsed = urlparse(config_instance.CENTML_SERVER_URL) + parsed = urlparse(settings.CENTML_SERVER_URL) uvicorn.run(app, host=parsed.hostname, port=parsed.port) diff --git a/centml/compiler/utils.py b/centml/compiler/utils.py index a8706ea..f4380f4 100644 --- a/centml/compiler/utils.py +++ b/centml/compiler/utils.py @@ -1,21 +1,21 @@ import os import shutil -from centml.compiler.config import config_instance +from centml.compiler.config import settings def get_backend_compiled_forward_path(model_id: str): - os.makedirs(os.path.join(config_instance.BACKEND_BASE_PATH, model_id), exist_ok=True) - return os.path.join(config_instance.BACKEND_BASE_PATH, model_id, "compilation_return.pkl") + os.makedirs(os.path.join(settings.BACKEND_BASE_PATH, model_id), exist_ok=True) + return os.path.join(settings.BACKEND_BASE_PATH, model_id, "compilation_return.pkl") def get_server_compiled_forward_path(model_id: str): - os.makedirs(os.path.join(config_instance.SERVER_BASE_PATH, model_id), exist_ok=True) - return os.path.join(config_instance.SERVER_BASE_PATH, model_id, "compilation_return.pkl") + os.makedirs(os.path.join(settings.SERVER_BASE_PATH, model_id), exist_ok=True) + return os.path.join(settings.SERVER_BASE_PATH, model_id, "compilation_return.pkl") # This function will delete the storage_path/{model_id} directory def dir_cleanup(model_id: str): - dir_path = os.path.join(config_instance.SERVER_BASE_PATH, model_id) + dir_path = os.path.join(settings.SERVER_BASE_PATH, model_id) if not os.path.exists(dir_path): return # Directory does not exist, return diff --git a/tests/test_backend.py b/tests/test_backend.py index c35bd17..aa7e6e3 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -6,7 +6,7 @@ from parameterized import parameterized_class from torch.fx import GraphModule from centml.compiler.backend import Runner -from centml.compiler.config import CompilationStatus, config_instance +from centml.compiler.config import CompilationStatus, settings from .test_helpers import MODEL_SUITE @@ -145,7 +145,7 @@ def test_max_tries(self, mock_requests, mock_compile): with self.assertRaises(Exception) as context: self.runner._wait_for_status(model_id) - self.assertEqual(mock_compile.call_count, config_instance.MAX_RETRIES + 1) + self.assertEqual(mock_compile.call_count, settings.MAX_RETRIES + 1) self.assertIn("Waiting for status: compilation failed too many times", str(context.exception)) @patch("centml.compiler.config.settings.COMPILING_SLEEP_TIME", new=0) From 11380f2331e4876d431e9407656c073223b446bb Mon Sep 17 00:00:00 2001 From: destefy Date: Tue, 16 Jul 2024 17:15:32 -0400 Subject: [PATCH 5/6] expand path --- centml/compiler/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/centml/compiler/config.py b/centml/compiler/config.py index 5a462b1..42d336a 100644 --- a/centml/compiler/config.py +++ b/centml/compiler/config.py @@ -14,7 +14,7 @@ class Config(BaseSettings): MAX_RETRIES: int = 3 COMPILING_SLEEP_TIME: int = 15 - CENTML_CACHE_DIR: str = "~/.cache/centml" + CENTML_CACHE_DIR: str = os.path.expanduser("~/.cache/centml") BACKEND_BASE_PATH: str = os.path.join(CENTML_CACHE_DIR, "backend") SERVER_BASE_PATH: str = os.path.join(CENTML_CACHE_DIR, "server") From 4cb57287df907694dc76a90f98568b9b67b77ef9 Mon Sep 17 00:00:00 2001 From: destefy Date: Tue, 16 Jul 2024 17:16:01 -0400 Subject: [PATCH 6/6] lint --- centml/compiler/backend.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/centml/compiler/backend.py b/centml/compiler/backend.py index 37482c2..3398c77 100644 --- a/centml/compiler/backend.py +++ b/centml/compiler/backend.py @@ -118,9 +118,7 @@ def _wait_for_status(self, model_id: str) -> bool: tries = 0 while True: # get server compilation status - status_response = requests.get( - f"{settings.CENTML_SERVER_URL}/status/{model_id}", timeout=settings.TIMEOUT - ) + status_response = requests.get(f"{settings.CENTML_SERVER_URL}/status/{model_id}", timeout=settings.TIMEOUT) if status_response.status_code != HTTPStatus.OK: raise Exception( f"Status check: request failed, exception from server:\n{status_response.json().get('detail')}"