diff --git a/servers/corpus/src/corpus.py b/servers/corpus/src/corpus.py
index c87198b..24b351d 100644
--- a/servers/corpus/src/corpus.py
+++ b/servers/corpus/src/corpus.py
@@ -14,6 +14,48 @@
 from tqdm import tqdm
 from ultrarag.server import UltraRAG_MCP_Server
 
+
+def _validate_path(user_path: str, allowed_base: Optional[str] = None) -> Path:
+    """Resolve a user-supplied path and optionally confine it to a base directory.
+
+    Without ``allowed_base`` the path is only normalized to an absolute
+    path (``..`` segments and symlinks are resolved); no containment is
+    enforced. Pass ``allowed_base`` to reject anything that resolves
+    outside that directory.
+
+    Args:
+        user_path: User-provided file path
+        allowed_base: Optional base directory to restrict paths to
+
+    Returns:
+        Resolved, absolute Path object
+
+    Raises:
+        ValueError: If the path is invalid or resolves outside ``allowed_base``
+    """
+    try:
+        # resolve() collapses ".." and follows symlinks, so the containment
+        # check below sees the real target rather than the raw string.
+        safe_path = Path(user_path).resolve()
+        base_path = Path(allowed_base).resolve() if allowed_base else None
+    except (OSError, RuntimeError) as e:
+        # RuntimeError: symlink loop on Python versions before 3.13
+        raise ValueError(f"Invalid path: {user_path}") from e
+
+    if base_path is not None:
+        try:
+            # relative_to() raises ValueError when safe_path is not under base_path
+            safe_path.relative_to(base_path)
+        except ValueError:
+            raise ValueError(
+                f"Path traversal detected: '{user_path}' is outside allowed directory '{allowed_base}'"
+            ) from None
+
+    # No substring checks for ".." or "/etc/": after resolve() a ".." segment
+    # cannot remain, and substring matching rejects valid names such as
+    # "notes..txt" while still letting "/etc/passwd" through.
+    return safe_path
+
 
 app = UltraRAG_MCP_Server("corpus")
 
@@ -171,7 +213,15 @@ async def build_text_corpus(
     PMLIKE_EXT = [".pdf", ".xps", ".oxps", ".epub", ".mobi", ".fb2"]
     DOCX_EXT = [".docx"]
 
-    in_path = os.path.abspath(parse_file_path)
+    # Validate and sanitize path to prevent path traversal
+    try:
+        safe_path = _validate_path(parse_file_path)
+        in_path = str(safe_path)
+    except ValueError as e:
+        err_msg = f"Invalid file path: {e}"
+        app.logger.error(err_msg)
+        raise ToolError(err_msg)
+
     if not os.path.exists(in_path):
         err_msg = f"Input path not found: {in_path}"
         app.logger.error(err_msg)
@@ -224,6 +274,7 @@ def process_one_file(fp: str) -> None:
                 app.logger.error(err_msg)
                 raise ToolError(err_msg)
             try:
+                doc = None
                 with suppress_stdout():
                     doc = pymupdf.open(fp)
                 texts = []
@@ -235,6 +286,13 @@ def process_one_file(fp: str) -> None:
                 content = "\n\n".join(texts)
             except Exception as e:
                 app.logger.warning(f"PDF read failed: {fp} | {e}")
+            finally:
+                # Ensure PDF document is closed to prevent memory leaks
+                if doc is not None:
+                    try:
+                        doc.close()
+                    except Exception:
+                        pass
         else:
             warn_msg = f"Unsupported file type, skip: {fp}"
             app.logger.warning(warn_msg)
@@ -291,13 +349,28 @@ async def build_image_corpus(
         app.logger.error(err_msg)
         raise ToolError(err_msg)
 
-    in_path = os.path.abspath(parse_file_path)
+    # Validate and sanitize path to prevent path traversal
+    try:
+        safe_path = _validate_path(parse_file_path)
+        in_path = str(safe_path)
+    except ValueError as e:
+        err_msg = f"Invalid file path: {e}"
+        app.logger.error(err_msg)
+        raise ToolError(err_msg)
+
     if not os.path.exists(in_path):
         err_msg = f"Input path not found: {in_path}"
         app.logger.error(err_msg)
         raise ToolError(err_msg)
 
-    corpus_jsonl = os.path.abspath(image_corpus_save_path)
+
# Validate output path + try: + safe_output_path = _validate_path(image_corpus_save_path) + corpus_jsonl = str(safe_output_path) + except ValueError as e: + err_msg = f"Invalid output path: {e}" + app.logger.error(err_msg) + raise ToolError(err_msg) out_root = os.path.dirname(corpus_jsonl) or os.getcwd() base_img_dir = os.path.join(out_root, "image") os.makedirs(base_img_dir, exist_ok=True) @@ -329,6 +402,7 @@ async def build_image_corpus( out_img_dir = os.path.join(base_img_dir, stem) os.makedirs(out_img_dir, exist_ok=True) + doc = None try: with suppress_stdout(): doc = pymupdf.open(pdf_path) @@ -337,6 +411,9 @@ async def build_image_corpus( app.logger.warning(warn_msg) continue + if doc is None: + continue + if getattr(doc, "is_encrypted", False): try: doc.authenticate("") @@ -393,6 +470,13 @@ async def build_image_corpus( } ) gid += 1 + + # Ensure PDF document is closed to prevent memory leaks + if doc is not None: + try: + doc.close() + except Exception: + pass _save_jsonl(valid_rows, corpus_jsonl) info_msg = ( @@ -429,7 +513,15 @@ async def mineru_parse( app.logger.error(err_msg) raise ToolError(err_msg) - in_path = os.path.abspath(parse_file_path) + # Validate and sanitize path to prevent path traversal + try: + safe_path = _validate_path(parse_file_path) + in_path = str(safe_path) + except ValueError as e: + err_msg = f"Invalid file path: {e}" + app.logger.error(err_msg) + raise ToolError(err_msg) + if not os.path.exists(in_path): err_msg = f"Input path not found: {in_path}" app.logger.error(err_msg) diff --git a/servers/prompt/src/prompt.py b/servers/prompt/src/prompt.py index 6ba62c5..04d6853 100644 --- a/servers/prompt/src/prompt.py +++ b/servers/prompt/src/prompt.py @@ -4,6 +4,8 @@ from typing import Any, Dict, List, Optional, Union from jinja2 import Template +from jinja2.sandbox import SandboxedEnvironment +from markupsafe import escape from fastmcp.prompts import PromptMessage from ultrarag.server import UltraRAG_MCP_Server @@ -11,24 +13,86 @@ app = 
UltraRAG_MCP_Server("prompt")
 
 
+# Sandboxed Jinja2 environment; autoescape stays off because prompts are plain text, not HTML
+_sandboxed_env = SandboxedEnvironment(autoescape=False)
+
+
+def _validate_template_path(template_path: Union[str, Path]) -> Path:
+    """Validate template path to prevent path traversal.
+
+    Args:
+        template_path: Path to template file
+
+    Returns:
+        Validated Path object
+
+    Raises:
+        ValueError: If path is invalid or contains traversal attempts
+    """
+    path = Path(template_path)
+
+    # Reject parent-directory segments; component-wise so "v1..2.jinja" stays valid
+    if ".." in path.parts:
+        raise ValueError(f"Path traversal detected in template path: {template_path}")
+
+    # Resolve to absolute path
+    try:
+        resolved = path.resolve()
+    except (OSError, RuntimeError) as e:
+        raise ValueError(f"Invalid template path: {template_path}") from e
+
+    return resolved
+
 
 def load_prompt_template(template_path: Union[str, Path]) -> Template:
-    """Load Jinja2 template from file.
+    """Load Jinja2 template from file with security validation.
 
     Args:
         template_path: Path to template file
 
     Returns:
-        Jinja2 Template object
+        Jinja2 Template object from sandboxed environment
 
     Raises:
         FileNotFoundError: If template file doesn't exist
+        ValueError: If template path is invalid
     """
-    if not os.path.exists(template_path):
+    # Validate path to prevent traversal
+    safe_path = _validate_template_path(template_path)
+
+    if not safe_path.exists():
         raise FileNotFoundError(f"Template file not found: {template_path}")
-    with open(template_path, "r", encoding="utf-8") as f:
+
+    # Load template using sandboxed environment
+    with open(safe_path, "r", encoding="utf-8") as f:
         template_content = f.read()
-    return Template(template_content)
+
+    # Use sandboxed environment to prevent code injection
+    return _sandboxed_env.from_string(template_content)
+
+
+def _safe_render(template: Template, **kwargs: Any) -> str:
+    """Render a template with the given variables, passed through verbatim.
+
+    Prompts are plain text for an LLM, not HTML, so values are NOT
+    HTML-escaped: escaping would corrupt questions and passages (for
+    example "What's" -> "What&#39;s", "a < b" -> "a &lt; b"). Protection
+    against template injection comes from the sandboxed environment the
+    template was compiled in; variable values are data and are never
+    evaluated as template code. Lists and other non-string values are
+    forwarded as-is as well.
+
+    Args:
+        template: Jinja2 Template object
+        **kwargs: Template variables, forwarded unchanged
+
+    Returns:
+        Rendered template string
+    """
+    # NOTE: intentionally no markupsafe.escape() here. Code that embeds a
+    # rendered prompt in HTML must escape at that output boundary instead.
+    return template.render(**kwargs)
 
 
 @app.prompt(output="q_ls,template->prompt_ls")
@@ -45,7 +109,7 @@ def qa_boxed(q_ls: List[str], template: Union[str, Path]) -> List[PromptMessage]
     template: Template = load_prompt_template(template)
     ret = []
     for q in q_ls:
-        p = template.render(question=q)
+        p = _safe_render(template, question=q)
         ret.append(p)
     return ret
 
@@ -71,7 +135,7 @@ def qa_boxed_multiple_choice(
     CHOICES: List[str] = list(string.ascii_uppercase)  # A, B, ..., Z
     for q, choices in zip(q_ls, choices_ls):
         choices_text = "\n".join(f"{CHOICES[i]}: {c}" for i, c in enumerate(choices))
-        p = template.render(question=q, choices=choices_text)
+        p = _safe_render(template, question=q, choices=choices_text)
         ret.append(p)
     return ret
 
@@ -94,7 +158,7 @@ def qa_rag_boxed(
     ret = []
     for q, psg in zip(q_ls, ret_psg):
         passage_text = "\n".join(psg)
-        p = template.render(question=q, documents=passage_text)
+        p = _safe_render(template, question=q, documents=passage_text)
         ret.append(p)
     return ret
 
@@ -123,7 +187,7 @@ def qa_rag_boxed_multiple_choice(
     for q, psg, choices in zip(q_ls, ret_psg, choices_ls):
         passage_text = "\n".join(psg)
         choices_text = "\n".join(f"{CHOICES[i]}: {c}" for i, c in enumerate(choices))
-        p = template.render(question=q, documents=passage_text, choices=choices_text)
+        p = _safe_render(template, question=q, documents=passage_text, choices=choices_text)
         ret.append(p)
     return ret
 
@@ -148,7 +212,7 @@ def RankCoT_kr(
     ret = []
     for q, psg in zip(q_ls,
ret_psg): passage_text = "\n".join(psg) - p = template.render(question=q, documents=passage_text) + p = _safe_render(template, question=q, documents=passage_text) ret.append(p) return ret @@ -172,7 +236,7 @@ def RankCoT_qa( template: Template = load_prompt_template(template) ret = [] for q, cot in zip(q_ls, kr_ls): - p = template.render(question=q, CoT=cot) + p = _safe_render(template, question=q, CoT=cot) ret.append(p) return ret @@ -202,7 +266,7 @@ def ircot_next_prompt( continue passage_text = "" if psg is None else "\n".join(psg) ret.append( - template.render(documents=passage_text, question=q, cur_answer="") + _safe_render(template, documents=passage_text, question=q, cur_answer="") ) return ret # Multi turn @@ -228,7 +292,7 @@ def ircot_next_prompt( cur_answer = " ".join(all_cots).strip() q = memory_q_ls[0][i] ret.append( - template.render(documents=passage_text, question=q, cur_answer=cur_answer) + _safe_render(template, documents=passage_text, question=q, cur_answer=cur_answer) ) return ret @@ -252,7 +316,7 @@ def webnote_init_page( template: Template = load_prompt_template(template) all_prompts = [] for q, plan in zip(q_ls, plan_ls): - p = template.render(question=q, plan=plan) + p = _safe_render(template, question=q, plan=plan) all_prompts.append(p) return all_prompts @@ -274,7 +338,7 @@ def webnote_gen_plan( template: Template = load_prompt_template(template) all_prompts = [] for q in q_ls: - p = template.render(question=q) + p = _safe_render(template, question=q) all_prompts.append(p) return all_prompts @@ -300,7 +364,7 @@ def webnote_gen_subq( template: Template = load_prompt_template(template) all_prompts = [] for q, plan, page in zip(q_ls, plan_ls, page_ls): - p = template.render(question=q, plan=plan, page=page) + p = _safe_render(template, question=q, plan=plan, page=page) all_prompts.append(p) return all_prompts @@ -332,8 +396,8 @@ def webnote_fill_page( template: Template = load_prompt_template(template) all_prompts = [] for q, plan, page, subq, 
psg in zip(q_ls, plan_ls, page_ls, subq_ls, psg_ls): - p = template.render( - question=q, plan=plan, sub_question=subq, docs_text=psg, page=page + p = _safe_render( + template, question=q, plan=plan, sub_question=subq, docs_text=psg, page=page ) all_prompts.append(p) return all_prompts @@ -358,7 +422,7 @@ def webnote_gen_answer( template: Template = load_prompt_template(template) all_prompts = [] for q, page in zip(q_ls, page_ls): - p = template.render(page=page, question=q) + p = _safe_render(template, page=page, question=q) all_prompts.append(p) return all_prompts @@ -387,7 +451,7 @@ def search_r1_gen( passages = psg[:3] passage_text = "\n".join(passages) _pro = prompt.content.text - p = template.render(history=_pro, answer=ans, passages=passage_text) + p = _safe_render(template, history=_pro, answer=ans, passages=passage_text) ret.append(p) return ret @@ -416,7 +480,7 @@ def r1_searcher_gen( passages = psg[:5] passage_text = "\n".join(passages) _pro = prompt.content.text - p = template.render(history=_pro, answer=ans, passages=passage_text) + p = _safe_render(template, history=_pro, answer=ans, passages=passage_text) ret.append(p) return ret @@ -439,7 +503,7 @@ def search_o1_init( ret = [] for q in q_ls: - p = template.render(question=q) + p = _safe_render(template, question=q) ret.append(p) return ret @@ -482,7 +546,8 @@ def search_o1_reasoning_indocument( ] formatted_history_str = "\n\n".join(formatted_history_parts) - p = template.render( + p = _safe_render( + template, prev_reasoning=formatted_history_str, search_query=squery, document=passage_text, @@ -515,7 +580,7 @@ def search_o1_insert( template: Template = load_prompt_template(template) prompt_ls = [] for q in q_ls: - p = template.render(question=q) + p = _safe_render(template, question=q) prompt_ls.append(p) ret = [] @@ -559,7 +624,7 @@ def gen_subq( all_prompts = [] for q, psg in zip(q_ls, ret_psg): passage_text = "\n".join(psg) - p = template.render(question=q, documents=passage_text) + p = 
_safe_render(template, question=q, documents=passage_text) all_prompts.append(p) return all_prompts @@ -584,7 +649,7 @@ def check_passages( all_prompts = [] for q, psg in zip(q_ls, ret_psg): passage_text = "\n".join(psg) - p = template.render(question=q, documents=passage_text) + p = _safe_render(template, question=q, documents=passage_text) all_prompts.append(p) return all_prompts @@ -606,7 +671,7 @@ def evisrag_vqa( template: Template = load_prompt_template(template) ret = [] for q, psg in zip(q_ls, ret_psg): - p = template.render(question=q) + p = _safe_render(template, question=q) p = p.replace("", "" * len(psg)) ret.append(p) return ret @@ -902,7 +967,8 @@ def surveycpm_search( else: survey_str = _print_tasknote(survey, abbr=True) - p = template.render( + p = _safe_render( + template, user_query=instruction, current_outline=survey_str, current_instruction=f"You need to update {cursor}", @@ -933,7 +999,7 @@ def surveycpm_init_plan( ret = [] for instruction, retrieved_info in zip(instruction_ls, retrieved_info_ls): info = retrieved_info if retrieved_info != "" else "" - p = template.render(user_query=instruction, current_information=info) + p = _safe_render(template, user_query=instruction, current_information=info) ret.append(p) return ret @@ -972,7 +1038,8 @@ def surveycpm_write( ) info = retrieved_info if retrieved_info != "" else "" survey_str = _print_tasknote_hire(survey, last_detail=True) - p = template.render( + p = _safe_render( + template, user_query=instruction, current_survey=survey_str, current_instruction=f"You need to update {cursor}", @@ -1008,7 +1075,7 @@ def surveycpm_extend_plan( json.loads(survey_json) if survey_json and survey_json != "" else {} ) survey_str = _print_tasknote(survey, abbr=False) - p = template.render(user_query=instruction, current_survey=survey_str) + p = _safe_render(template, user_query=instruction, current_survey=survey_str) ret.append(p) return ret diff --git a/servers/retriever/src/index_backends/milvus_backend.py 
b/servers/retriever/src/index_backends/milvus_backend.py
index 68e4507..3e96042 100644
--- a/servers/retriever/src/index_backends/milvus_backend.py
+++ b/servers/retriever/src/index_backends/milvus_backend.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import os
+import re
 from pathlib import Path
 from typing import Any, List, Optional, Sequence
 
@@ -62,6 +63,25 @@ def __init__(
 
         self.client = None
 
+    @staticmethod
+    def _validate_collection_name(name: str) -> bool:
+        """Check that a collection name satisfies Milvus naming rules.
+
+        Args:
+            name: Collection name to validate
+
+        Returns:
+            True if valid, False otherwise
+        """
+        if not name or not isinstance(name, str):
+            return False
+        # Milvus caps collection names at 255 characters
+        if len(name) > 255:
+            return False
+        # Milvus allows only letters, digits and underscores, starting with a
+        # letter or underscore; fullmatch() also rejects a trailing newline.
+        return bool(re.fullmatch(r"[A-Za-z_][A-Za-z0-9_]*", name))
+
     def _resolve_index_path(self, index_path: Optional[str]) -> str:
         """Resolve Milvus URI from config.
 
@@ -115,7 +135,15 @@ def _ensure_collection(
 
         Raises:
             RuntimeError: If collection creation fails
+            ValueError: If collection name is invalid
         """
+        # Validate collection name before it reaches the Milvus client
+        if not self._validate_collection_name(collection_name):
+            raise ValueError(
+                f"[milvus] Invalid collection name: '{collection_name}'. "
+                "Collection names must start with a letter or underscore and contain only letters, digits, and underscores."
+            )
+
         client = self._client_connect()
 
         has_collection = client.has_collection(collection_name)
@@ -208,6 +236,13 @@ def build_index(
 
         client = self._client_connect()
         target_collection = kwargs.get("collection_name", self.collection_name)
+
+        # Validate collection name before it reaches the Milvus client
+        if target_collection and not self._validate_collection_name(target_collection):
+            raise ValueError(
+                f"[milvus] Invalid collection name: '{target_collection}'. "
+                "Collection names must start with a letter or underscore and contain only letters, digits, and underscores."
+            )
 
         passed_contents = kwargs.get("contents", None)
         passed_metadatas = kwargs.get("metadatas", None)
@@ -293,6 +328,13 @@ def search(
 
         client = self._client_connect()
         target_collection = kwargs.get("collection_name", self.collection_name)
+
+        # Validate collection name before it reaches the Milvus client
+        if target_collection and not self._validate_collection_name(target_collection):
+            raise ValueError(
+                f"[milvus] Invalid collection name: '{target_collection}'. "
+                "Collection names must start with a letter or underscore and contain only letters, digits, and underscores."
+            )
 
         query_embeddings = np.asarray(query_embeddings, dtype=np.float32, order="C")
         if query_embeddings.ndim != 2:
diff --git a/ui/backend/app.py b/ui/backend/app.py
index 76883ad..eb11985 100644
--- a/ui/backend/app.py
+++ b/ui/backend/app.py
@@ -1430,6 +1430,23 @@ def parse_ai_actions(content: str, context: Dict) -> list:
 
 
 if __name__ == "__main__":
+    import os
+
     logging.basicConfig(level=logging.INFO)
     app = create_app()
-    app.run(host="0.0.0.0", port=5050, debug=True)
+
+    # Security: debug mode and bind address come from the environment.
+    # Accept the usual truthy spellings; Flask's own convention is FLASK_DEBUG=1.
+    debug_mode = os.getenv("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}
+    host = os.getenv("FLASK_HOST", "127.0.0.1" if debug_mode else "0.0.0.0")
+    port = int(os.getenv("FLASK_PORT", "5050"))
+
+    # Additional safety: Never allow debug=True with host='0.0.0.0'
+    if debug_mode and host == "0.0.0.0":
+        app.logger.warning(
+            "Security warning: Debug mode should not be enabled with host='0.0.0.0'. "
+            "Using host='127.0.0.1' instead."
+        )
+        host = "127.0.0.1"
+
+    app.run(host=host, port=port, debug=debug_mode)
diff --git a/ui/backend/pipeline_manager.py b/ui/backend/pipeline_manager.py
index 13bdd09..a130e28 100644
--- a/ui/backend/pipeline_manager.py
+++ b/ui/backend/pipeline_manager.py
@@ -144,7 +144,8 @@ def _make_safe_collection_name(display_name: str) -> tuple[str, str]:
 
     base = _normalize_collection_name(display_name)
     if not base:
-        digest = hashlib.sha1(display_name.encode("utf-8")).hexdigest()[:8]
+        # Use SHA-256 instead of SHA-1 for security
+        digest = hashlib.sha256(display_name.encode("utf-8")).hexdigest()[:8]
         base = f"kb_{digest}"
 
     safe_name = base