diff --git a/pyproject.toml b/pyproject.toml index 495086681..464c64ee6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,9 +75,12 @@ dev = [ version = {attr = "xpk.core.config.__version__"} [tool.setuptools] -packages = ["xpk", "xpk.parser", "xpk.core", "xpk.commands", "xpk.api", "xpk.templates", "xpk.utils", "xpk.core.blueprint", "xpk.core.remote_state", "xpk.core.workload_decorators"] package-dir = {"" = "src"} -package-data = {"xpk.api" = ["storage_crd.yaml"], "xpk.templates" = ["storage.yaml"]} +packages = { find = { where = ["src"] } } + +[tool.setuptools.package-data] +"xpk" = ["templates/*"] +"xpk.api" = ["*.yaml"] [tool.pyink] # Formatting configuration to follow Google style-guide. diff --git a/src/xpk/commands/cluster.py b/src/xpk/commands/cluster.py index ef2715061..f1a02bda0 100644 --- a/src/xpk/commands/cluster.py +++ b/src/xpk/commands/cluster.py @@ -78,7 +78,7 @@ from . import cluster_gcluster from .common import set_cluster_command, validate_sub_slicing_system from jinja2 import Environment, FileSystemLoader -from ..utils.templates import TEMPLATE_PATH +from ..utils.templates import get_templates_absolute_path import shutil import os @@ -434,7 +434,9 @@ def cluster_cacheimage(args) -> None: system.accelerator_type ].accelerator_label - template_env = Environment(loader=FileSystemLoader(TEMPLATE_PATH)) + template_env = Environment( + loader=FileSystemLoader(searchpath=get_templates_absolute_path()) + ) cluster_preheat_yaml = template_env.get_template(CLUSTER_PREHEAT_JINJA_FILE) rendered_yaml = cluster_preheat_yaml.render( cachekey=args.cache_key, diff --git a/src/xpk/core/kueue_manager.py b/src/xpk/core/kueue_manager.py index 582e27519..921350db6 100644 --- a/src/xpk/core/kueue_manager.py +++ b/src/xpk/core/kueue_manager.py @@ -39,7 +39,7 @@ ) from ..utils.file import write_tmp_file from ..utils.console import xpk_print, xpk_exit -from ..utils.templates import TEMPLATE_PATH +from ..utils.templates import TEMPLATE_PATH, get_templates_absolute_path WAIT_FOR_KUEUE_TIMEOUT = "10m" CLUSTER_QUEUE_NAME = "cluster-queue" @@ -82,7 +82,12 @@ def __init__( template_path=TEMPLATE_PATH, ): self.kueue_version = kueue_version - self.template_env = Environment(loader=FileSystemLoader(template_path)) + + self.template_env = Environment( + loader=FileSystemLoader( + searchpath=get_templates_absolute_path(template_path) + ) + ) def install_or_upgrade( self, diff --git a/src/xpk/utils/templates.py b/src/xpk/utils/templates.py index 11e9cba24..eca9d2063 100644 --- a/src/xpk/utils/templates.py +++ b/src/xpk/utils/templates.py @@ -18,7 +18,7 @@ import ruamel.yaml -TEMPLATE_PATH = "src/xpk/templates/" +TEMPLATE_PATH = "templates" yaml = ruamel.yaml.YAML() @@ -28,3 +28,16 @@ def load(path: str) -> dict: with open(template_path, "r", encoding="utf-8") as file: data: dict = yaml.load(file) return data + + +def get_templates_absolute_path(templates_path: str = TEMPLATE_PATH) -> str: + """ + Return the absolute path to the templates folder + + Args: + templates_path: The path to the templates folder relative to the src/xpk directory + """ + current_file_path = os.path.abspath(__file__) + current_dir = os.path.dirname(current_file_path) + xpk_package_dir = os.path.dirname(current_dir) + return os.path.join(xpk_package_dir, templates_path)