Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions centml/cli/login.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import click

from centml.sdk import auth, config
from centml.sdk import auth
from centml.sdk.config import settings


@click.command(help="Login to CentML")
Expand All @@ -10,15 +11,15 @@ def login(token_file):
auth.store_centml_cred(token_file)

if auth.load_centml_cred():
click.echo(f"Authenticating with credentials from {config.Config.centml_cred_file}\n")
click.echo(f"Authenticating with credentials from {settings.CENTML_CRED_FILE_PATH}\n")
click.echo("Login successful")
else:
click.echo("Login with CentML authentication token")
click.echo("Usage: centml login TOKEN_FILE\n")
choice = click.confirm("Do you want to download the token?")

if choice:
click.launch(f"{config.Config.centml_web_url}?isCliAuthenticated=true")
click.launch(f"{settings.CENTML_WEB_URL}?isCliAuthenticated=true")
else:
click.echo("Login unsuccessful")

Expand Down
26 changes: 12 additions & 14 deletions centml/compiler/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import requests
import torch
from torch.fx import GraphModule
from centml.compiler.config import config_instance, CompilationStatus
from centml.compiler.config import settings, CompilationStatus
from centml.compiler.utils import get_backend_compiled_forward_path


Expand Down Expand Up @@ -54,13 +54,13 @@ def inputs(self):

def _serialize_model_and_inputs(self):
self.serialized_model_dir = TemporaryDirectory() # pylint: disable=consider-using-with
self.serialized_model_path = os.path.join(self.serialized_model_dir.name, config_instance.SERIALIZED_MODEL_FILE)
self.serialized_input_path = os.path.join(self.serialized_model_dir.name, config_instance.SERIALIZED_INPUT_FILE)
self.serialized_model_path = os.path.join(self.serialized_model_dir.name, settings.SERIALIZED_MODEL_FILE)
self.serialized_input_path = os.path.join(self.serialized_model_dir.name, settings.SERIALIZED_INPUT_FILE)

# torch.save saves a zip file full of pickled files with the model's states.
try:
torch.save(self.module, self.serialized_model_path, pickle_protocol=config_instance.PICKLE_PROTOCOL)
torch.save(self.inputs, self.serialized_input_path, pickle_protocol=config_instance.PICKLE_PROTOCOL)
torch.save(self.module, self.serialized_model_path, pickle_protocol=settings.PICKLE_PROTOCOL)
torch.save(self.inputs, self.serialized_input_path, pickle_protocol=settings.PICKLE_PROTOCOL)
except Exception as e:
raise Exception(f"Failed to save module or inputs with torch.save: {e}") from e

Expand All @@ -71,7 +71,7 @@ def _get_model_id(self) -> str:
sha_hash = hashlib.sha256()
with open(self.serialized_model_path, "rb") as serialized_model_file:
# Read in chunks to not load too much into memory
for block in iter(lambda: serialized_model_file.read(config_instance.HASH_CHUNK_SIZE), b""):
for block in iter(lambda: serialized_model_file.read(settings.HASH_CHUNK_SIZE), b""):
sha_hash.update(block)

model_id = sha_hash.hexdigest()
Expand All @@ -80,7 +80,7 @@ def _get_model_id(self) -> str:

def _download_model(self, model_id: str):
download_response = requests.get(
url=f"{config_instance.SERVER_URL}/download/{model_id}", timeout=config_instance.TIMEOUT
url=f"{settings.CENTML_SERVER_URL}/download/{model_id}", timeout=settings.TIMEOUT
)
if download_response.status_code != HTTPStatus.OK:
raise Exception(
Expand All @@ -104,9 +104,9 @@ def _compile_model(self, model_id: str):

with open(self.serialized_model_path, 'rb') as model_file, open(self.serialized_input_path, 'rb') as input_file:
compile_response = requests.post(
url=f"{config_instance.SERVER_URL}/submit/{model_id}",
url=f"{settings.CENTML_SERVER_URL}/submit/{model_id}",
files={"model": model_file, "inputs": input_file},
timeout=config_instance.TIMEOUT,
timeout=settings.TIMEOUT,
)

if compile_response.status_code != HTTPStatus.OK:
Expand All @@ -118,9 +118,7 @@ def _wait_for_status(self, model_id: str) -> bool:
tries = 0
while True:
# get server compilation status
status_response = requests.get(
f"{config_instance.SERVER_URL}/status/{model_id}", timeout=config_instance.TIMEOUT
)
status_response = requests.get(f"{settings.CENTML_SERVER_URL}/status/{model_id}", timeout=settings.TIMEOUT)
if status_response.status_code != HTTPStatus.OK:
raise Exception(
f"Status check: request failed, exception from server:\n{status_response.json().get('detail')}"
Expand All @@ -138,10 +136,10 @@ def _wait_for_status(self, model_id: str) -> bool:
else:
tries += 1

if tries > config_instance.MAX_RETRIES:
if tries > settings.MAX_RETRIES:
raise Exception("Waiting for status: compilation failed too many times.\n")

time.sleep(config_instance.COMPILING_SLEEP_TIME)
time.sleep(settings.COMPILING_SLEEP_TIME)

def remote_compilation(self):
self._serialize_model_and_inputs()
Expand Down
16 changes: 8 additions & 8 deletions centml/compiler/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from enum import Enum
from pydantic_settings import BaseSettings


class CompilationStatus(Enum):
Expand All @@ -8,20 +9,19 @@ class CompilationStatus(Enum):
DONE = "DONE"


class Config:
class Config(BaseSettings):
TIMEOUT: int = 10
MAX_RETRIES: int = 3
COMPILING_SLEEP_TIME: int = 15

CACHE_PATH: str = os.getenv("CENTML_CACHE_DIR", default=os.path.expanduser("~/.cache/centml"))
CENTML_CACHE_DIR: str = os.path.expanduser("~/.cache/centml")
BACKEND_BASE_PATH: str = os.path.join(CENTML_CACHE_DIR, "backend")
SERVER_BASE_PATH: str = os.path.join(CENTML_CACHE_DIR, "server")

SERVER_URL: str = os.getenv("CENTML_SERVER_URL", default="http://0.0.0.0:8090")

BACKEND_BASE_PATH: str = os.path.join(CACHE_PATH, "backend")
SERVER_BASE_PATH: str = os.path.join(CACHE_PATH, "server")
CENTML_SERVER_URL: str = "http://0.0.0.0:8090"

# Use a constant path since torch.save uses the given file name in it's zipfile.
# Thus, a different filename would result in a different hash.
# Using a different filename would result in a different hash.
SERIALIZED_MODEL_FILE: str = "serialized_model.zip"
SERIALIZED_INPUT_FILE: str = "serialized_input.zip"
PICKLE_PROTOCOL: int = 4
Expand All @@ -32,4 +32,4 @@ class Config:
MINIMUM_GZIP_SIZE: int = 1000


config_instance = Config()
settings = Config()
12 changes: 6 additions & 6 deletions centml/compiler/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
from fastapi.middleware.gzip import GZipMiddleware
from centml.compiler.server_compilation import hidet_backend_server
from centml.compiler.utils import dir_cleanup
from centml.compiler.config import config_instance, CompilationStatus
from centml.compiler.config import settings, CompilationStatus
from centml.compiler.utils import get_server_compiled_forward_path

app = FastAPI()
app.add_middleware(GZipMiddleware, minimum_size=config_instance.MINIMUM_GZIP_SIZE) # type: ignore
app.add_middleware(GZipMiddleware, minimum_size=settings.MINIMUM_GZIP_SIZE) # type: ignore


def get_status(model_id: str):
if not os.path.isdir(os.path.join(config_instance.SERVER_BASE_PATH, model_id)):
if not os.path.isdir(os.path.join(settings.SERVER_BASE_PATH, model_id)):
return CompilationStatus.NOT_FOUND

if not os.path.isfile(get_server_compiled_forward_path(model_id)):
Expand Down Expand Up @@ -50,7 +50,7 @@ def background_compile(model_id: str, tfx_graph, example_inputs):
# To avoid this, we write to a tmp file and rename it to the correct path after saving.
save_path = get_server_compiled_forward_path(model_id)
tmp_path = save_path + ".tmp"
torch.save(compiled_graph_module, tmp_path, pickle_protocol=config_instance.PICKLE_PROTOCOL)
torch.save(compiled_graph_module, tmp_path, pickle_protocol=settings.PICKLE_PROTOCOL)
os.rename(tmp_path, save_path)
except Exception as e:
logging.getLogger(__name__).exception(f"Saving graph module failed: {e}")
Expand Down Expand Up @@ -93,7 +93,7 @@ async def compile_model_handler(model_id: str, model: UploadFile, inputs: Upload
return Response(status_code=200)

# This effectively sets the model's status to COMPILING
os.makedirs(os.path.join(config_instance.SERVER_BASE_PATH, model_id))
os.makedirs(os.path.join(settings.SERVER_BASE_PATH, model_id))

tfx_graph, example_inputs = read_upload_files(model_id, model, inputs)

Expand All @@ -110,7 +110,7 @@ async def download_handler(model_id: str):


def run():
parsed = urlparse(config_instance.SERVER_URL)
parsed = urlparse(settings.CENTML_SERVER_URL)
uvicorn.run(app, host=parsed.hostname, port=parsed.port)


Expand Down
12 changes: 6 additions & 6 deletions centml/compiler/utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import os
import shutil
from centml.compiler.config import config_instance
from centml.compiler.config import settings


def get_backend_compiled_forward_path(model_id: str):
os.makedirs(os.path.join(config_instance.BACKEND_BASE_PATH, model_id), exist_ok=True)
return os.path.join(config_instance.BACKEND_BASE_PATH, model_id, "compilation_return.pkl")
os.makedirs(os.path.join(settings.BACKEND_BASE_PATH, model_id), exist_ok=True)
return os.path.join(settings.BACKEND_BASE_PATH, model_id, "compilation_return.pkl")


def get_server_compiled_forward_path(model_id: str):
os.makedirs(os.path.join(config_instance.SERVER_BASE_PATH, model_id), exist_ok=True)
return os.path.join(config_instance.SERVER_BASE_PATH, model_id, "compilation_return.pkl")
os.makedirs(os.path.join(settings.SERVER_BASE_PATH, model_id), exist_ok=True)
return os.path.join(settings.SERVER_BASE_PATH, model_id, "compilation_return.pkl")


# This function will delete the storage_path/{model_id} directory
def dir_cleanup(model_id: str):
dir_path = os.path.join(config_instance.SERVER_BASE_PATH, model_id)
dir_path = os.path.join(settings.SERVER_BASE_PATH, model_id)
if not os.path.exists(dir_path):
return # Directory does not exist, return

Expand Down
6 changes: 4 additions & 2 deletions centml/sdk/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from platform_api_client.models.deployment_status import DeploymentStatus

from centml.sdk import auth
from centml.sdk.config import Config
from centml.sdk.config import settings
from centml.sdk.utils import client_certs


@contextlib.contextmanager
def get_api():
configuration = platform_api_client.Configuration(host=Config.platformapi_url, access_token=auth.get_centml_token())
configuration = platform_api_client.Configuration(
host=settings.PLATFORM_API_URL, access_token=auth.get_centml_token()
)

with platform_api_client.ApiClient(configuration) as api_client:
api_instance = platform_api_client.EXTERNALApi(api_client)
Expand Down
16 changes: 8 additions & 8 deletions centml/sdk/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import requests
import jwt

from centml.sdk.config import Config
from centml.sdk.config import settings


def refresh_centml_token(refresh_token):
api_key = Config.firebase_api_key
api_key = settings.FIREBASE_API_KEY

cred = requests.post(
f"https://securetoken.googleapis.com/v1/token?key={api_key}",
Expand All @@ -18,7 +18,7 @@ def refresh_centml_token(refresh_token):
timeout=3,
).json()

with open(Config.centml_cred_file, 'w') as f:
with open(settings.CENTML_CRED_FILE_PATH, 'w') as f:
json.dump(cred, f)

return cred
Expand All @@ -27,7 +27,7 @@ def refresh_centml_token(refresh_token):
def store_centml_cred(token_file):
try:
with open(token_file, 'r') as f:
os.makedirs(Config.centml_config_dir, exist_ok=True)
os.makedirs(settings.CENTML_CONFIG_PATH, exist_ok=True)
refresh_token = json.load(f)["refreshToken"]

refresh_centml_token(refresh_token)
Expand All @@ -38,8 +38,8 @@ def store_centml_cred(token_file):
def load_centml_cred():
cred = None

if os.path.exists(Config.centml_cred_file):
with open(Config.centml_cred_file, 'r') as f:
if os.path.exists(settings.CENTML_CRED_FILE_PATH):
with open(settings.CENTML_CRED_FILE_PATH, 'r') as f:
cred = json.load(f)

return cred
Expand All @@ -60,5 +60,5 @@ def get_centml_token():


def remove_centml_cred():
if os.path.exists(Config.centml_cred_file):
os.remove(Config.centml_cred_file)
if os.path.exists(settings.CENTML_CRED_FILE_PATH):
os.remove(settings.CENTML_CRED_FILE_PATH)
17 changes: 11 additions & 6 deletions centml/sdk/config.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import os
from pydantic_settings import BaseSettings


class Config:
centml_web_url = "https://main.d1tz9z8hgabab9.amplifyapp.com/"
centml_config_dir = os.getenv("CENTML_CONFIG_PATH", default=os.path.expanduser("~/.centml"))
centml_cred_file = centml_config_dir + "/" + os.getenv("CENTML_CRED_FILE", default="credential")
class Config(BaseSettings):
CENTML_WEB_URL: str = "https://main.d1tz9z8hgabab9.amplifyapp.com/"
CENTML_CONFIG_PATH: str = os.path.expanduser("~/.centml")
CENTML_CRED_FILE: str = "credential"
CENTML_CRED_FILE_PATH: str = CENTML_CONFIG_PATH + "/" + CENTML_CRED_FILE

platformapi_url = "https://api.centml.org"
PLATFORM_API_URL: str = "https://api.centml.org"

firebase_api_key = "AIzaSyBXSNjruNdtypqUt_CPhB8QNl8Djfh5RXI"
FIREBASE_API_KEY: str = "AIzaSyBXSNjruNdtypqUt_CPhB8QNl8Djfh5RXI"


settings = Config()
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ torch>=2.0.0
fastapi>=0.103.0
uvicorn>=0.23.0
python-multipart>=0.0.6
pydantic-settings==2.0.*
Requests==2.32.2
tabulate>=0.9.0
pyjwt>=2.8.0
Expand Down
8 changes: 4 additions & 4 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from parameterized import parameterized_class
from torch.fx import GraphModule
from centml.compiler.backend import Runner
from centml.compiler.config import CompilationStatus, config_instance
from centml.compiler.config import CompilationStatus, settings
from .test_helpers import MODEL_SUITE


Expand Down Expand Up @@ -132,7 +132,7 @@ def test_invalid_status(self, mock_requests):
mock_requests.get.assert_called_once()
self.assertIn("Status check: request failed, exception from server", str(context.exception))

@patch("centml.compiler.config.Config.COMPILING_SLEEP_TIME", new=0)
@patch("centml.compiler.config.settings.COMPILING_SLEEP_TIME", new=0)
@patch("centml.compiler.backend.Runner._compile_model")
@patch("centml.compiler.backend.requests")
def test_max_tries(self, mock_requests, mock_compile):
Expand All @@ -145,10 +145,10 @@ def test_max_tries(self, mock_requests, mock_compile):
with self.assertRaises(Exception) as context:
self.runner._wait_for_status(model_id)

self.assertEqual(mock_compile.call_count, config_instance.MAX_RETRIES + 1)
self.assertEqual(mock_compile.call_count, settings.MAX_RETRIES + 1)
self.assertIn("Waiting for status: compilation failed too many times", str(context.exception))

@patch("centml.compiler.config.Config.COMPILING_SLEEP_TIME", new=0)
@patch("centml.compiler.config.settings.COMPILING_SLEEP_TIME", new=0)
@patch("centml.compiler.backend.requests")
def test_wait_on_compilation(self, mock_requests):
COMPILATION_STEPS = 10
Expand Down