diff --git a/src/backend/main.py b/src/backend/main.py index 9bda30d3..c72b0a6e 100644 --- a/src/backend/main.py +++ b/src/backend/main.py @@ -13,6 +13,7 @@ from backend.server.scheduler_manage import SchedulerManage from backend.server.server_args import parse_args from backend.server.static_config import get_model_list, get_node_join_command +from common.file_util import get_project_root from parallax_utils.ascii_anime import display_parallax_run from parallax_utils.logging_config import get_logger, set_log_level @@ -113,7 +114,7 @@ async def openai_v1_chat_completions(raw_request: Request): # Disable caching for index.html @app.get("/") async def serve_index(): - response = FileResponse("src/frontend/dist/index.html") + response = FileResponse(str(get_project_root()) + "/src/frontend/dist/index.html") # Disable cache response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0" response.headers["Pragma"] = "no-cache" @@ -122,7 +123,11 @@ async def serve_index(): # mount the frontend -app.mount("/", StaticFiles(directory="src/frontend/dist", html=True), name="static") +app.mount( + "/", + StaticFiles(directory=str(get_project_root() / "src" / "frontend" / "dist"), html=True), + name="static", +) if __name__ == "__main__": args = parse_args() diff --git a/src/backend/server/server_args.py b/src/backend/server/server_args.py index 3117ef12..27207c7c 100644 --- a/src/backend/server/server_args.py +++ b/src/backend/server/server_args.py @@ -21,7 +21,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--dht-prefix", type=str, default="gradient", help="Prefix for DHT keys") # Scheduler configuration - parser.add_argument("--port", type=int, default=5000, help="Port to listen on") + parser.add_argument("--port", type=int, default=3001, help="Port to listen on") parser.add_argument( "--log-level", type=str, diff --git a/src/common/file_util.py b/src/common/file_util.py new file mode 100644 index 00000000..3efb24f8 --- /dev/null +++ b/src/common/file_util.py @@ -0,0 +1,14 @@ +from pathlib import Path + + +def get_project_root(): + """Get the project root directory.""" + # Search for the project root by looking for pyproject.toml in parent directories + current_dir = Path(__file__).parent + while current_dir != current_dir.parent: + if (current_dir / "pyproject.toml").exists(): + return current_dir + current_dir = current_dir.parent + + # If not found, fallback to current working directory + return Path.cwd() diff --git a/src/parallax/cli.py b/src/parallax/cli.py index 51ee06ef..64ce2c67 100644 --- a/src/parallax/cli.py +++ b/src/parallax/cli.py @@ -12,8 +12,8 @@ import signal import subprocess import sys -from pathlib import Path +from common.file_util import get_project_root from common.static_config import get_relay_params from parallax_utils.logging_config import get_logger @@ -29,19 +29,6 @@ def check_python_version(): sys.exit(1) -def get_project_root(): - """Get the project root directory.""" - # Search for the project root by looking for pyproject.toml in parent directories - current_dir = Path(__file__).parent - while current_dir != current_dir.parent: - if (current_dir / "pyproject.toml").exists(): - return current_dir - current_dir = current_dir.parent - - # If not found, fallback to current working directory - return Path.cwd() - - def _flag_present(args_list: list[str], flag_names: list[str]) -> bool: """Return True if any of the given flags is present in args_list. diff --git a/src/parallax_utils/ascii_anime.py b/src/parallax_utils/ascii_anime.py index 38ffbffe..38eae149 100755 --- a/src/parallax_utils/ascii_anime.py +++ b/src/parallax_utils/ascii_anime.py @@ -2,6 +2,8 @@ import math import os +from common.file_util import get_project_root + class HexColorPrinter: COLOR_MAP = { @@ -198,7 +200,7 @@ def display_ascii_animation_join(animation_data, model_name): def display_parallax_run(): - file_path = "./src/parallax_utils/anime/parallax_run.json" + file_path = str(get_project_root()) + "/src/parallax_utils/anime/parallax_run.json" try: with open(file_path, "r") as f: animation_data = json.load(f) @@ -212,7 +214,7 @@ def display_parallax_run(): def display_parallax_join(model_name): - file_path = "./src/parallax_utils/anime/parallax_join.json" + file_path = str(get_project_root()) + "/src/parallax_utils/anime/parallax_join.json" try: with open(file_path, "r") as f: animation_data = json.load(f)