Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions src/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from backend.server.scheduler_manage import SchedulerManage
from backend.server.server_args import parse_args
from backend.server.static_config import get_model_list, get_node_join_command
from common.file_util import get_project_root
from parallax_utils.ascii_anime import display_parallax_run
from parallax_utils.logging_config import get_logger, set_log_level

Expand Down Expand Up @@ -113,7 +114,7 @@ async def openai_v1_chat_completions(raw_request: Request):
# Disable caching for index.html
@app.get("/")
async def serve_index():
response = FileResponse("src/frontend/dist/index.html")
response = FileResponse(str(get_project_root()) + "/src/frontend/dist/index.html")
# Disable cache
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate, max-age=0"
response.headers["Pragma"] = "no-cache"
Expand All @@ -122,7 +123,11 @@ async def serve_index():


# mount the frontend
app.mount("/", StaticFiles(directory="src/frontend/dist", html=True), name="static")
app.mount(
"/",
StaticFiles(directory=str(get_project_root() / "src" / "frontend" / "dist"), html=True),
name="static",
)

if __name__ == "__main__":
args = parse_args()
Expand Down
2 changes: 1 addition & 1 deletion src/backend/server/server_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--dht-prefix", type=str, default="gradient", help="Prefix for DHT keys")

# Scheduler configuration
parser.add_argument("--port", type=int, default=5000, help="Port to listen on")
parser.add_argument("--port", type=int, default=3001, help="Port to listen on")
parser.add_argument(
"--log-level",
type=str,
Expand Down
14 changes: 14 additions & 0 deletions src/common/file_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from pathlib import Path


def get_project_root():
"""Get the project root directory."""
# Search for the project root by looking for pyproject.toml in parent directories
current_dir = Path(__file__).parent
while current_dir != current_dir.parent:
if (current_dir / "pyproject.toml").exists():
return current_dir
current_dir = current_dir.parent

# If not found, fallback to current working directory
return Path.cwd()
15 changes: 1 addition & 14 deletions src/parallax/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
import signal
import subprocess
import sys
from pathlib import Path

from common.file_util import get_project_root
from common.static_config import get_relay_params
from parallax_utils.logging_config import get_logger

Expand All @@ -29,19 +29,6 @@ def check_python_version():
sys.exit(1)


def get_project_root():
"""Get the project root directory."""
# Search for the project root by looking for pyproject.toml in parent directories
current_dir = Path(__file__).parent
while current_dir != current_dir.parent:
if (current_dir / "pyproject.toml").exists():
return current_dir
current_dir = current_dir.parent

# If not found, fallback to current working directory
return Path.cwd()


def _flag_present(args_list: list[str], flag_names: list[str]) -> bool:
"""Return True if any of the given flags is present in args_list.

Expand Down
6 changes: 4 additions & 2 deletions src/parallax_utils/ascii_anime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import math
import os

from common.file_util import get_project_root


class HexColorPrinter:
COLOR_MAP = {
Expand Down Expand Up @@ -198,7 +200,7 @@ def display_ascii_animation_join(animation_data, model_name):


def display_parallax_run():
file_path = "./src/parallax_utils/anime/parallax_run.json"
file_path = str(get_project_root()) + "/src/parallax_utils/anime/parallax_run.json"
try:
with open(file_path, "r") as f:
animation_data = json.load(f)
Expand All @@ -212,7 +214,7 @@ def display_parallax_run():


def display_parallax_join(model_name):
file_path = "./src/parallax_utils/anime/parallax_join.json"
file_path = str(get_project_root()) + "/src/parallax_utils/anime/parallax_join.json"
try:
with open(file_path, "r") as f:
animation_data = json.load(f)
Expand Down