From 946d587138fac78e62c758f6a161516d2e44578c Mon Sep 17 00:00:00 2001 From: akshay18iitg Date: Tue, 31 Mar 2026 10:16:33 -0700 Subject: [PATCH 1/4] Adding high level memory --- configs/examples/task_segments.json | 26 ++ src/opentau/scripts/pi_mem_data_generator.py | 353 +++++++++++++++++++ 2 files changed, 379 insertions(+) create mode 100644 configs/examples/task_segments.json create mode 100644 src/opentau/scripts/pi_mem_data_generator.py diff --git a/configs/examples/task_segments.json b/configs/examples/task_segments.json new file mode 100644 index 00000000..069c36bf --- /dev/null +++ b/configs/examples/task_segments.json @@ -0,0 +1,26 @@ +[ + { + "time": 12.0, + "subtask": "pick up the cup and place in the tray", + "success": true, + "prompt": "Pick up a blue bottle and a cup and place it in the tray." + }, + { + "time": 20, + "subtask": "pick up the bottle and place in tray", + "success": false, + "prompt": "Pick up a blue bottle and a cup and place it in the tray." + }, + { + "time": 28, + "subtask": "pick up the bottle and place in tray", + "success": true, + "prompt": "Pick up a blue bottle and a cup and place it in the tray." + }, + { + "time": 34, + "subtask": "reset", + "success": true, + "prompt": "Pick up a blue bottle and a cup and place it in the tray." + } +] diff --git a/src/opentau/scripts/pi_mem_data_generator.py b/src/opentau/scripts/pi_mem_data_generator.py new file mode 100644 index 00000000..4b97a3b3 --- /dev/null +++ b/src/opentau/scripts/pi_mem_data_generator.py @@ -0,0 +1,353 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Enrich a JSON array with OpenAI chat completions (one call per field per object). + +Top-level JSON must be a **list of objects**. For each object, ``subtask`` and the +indicator are read from that object: use ``indicator`` if present, otherwise +``success``. Those keys are prepended to every user message and are not sent as +separate per-field API calls. Replies go under ``memory`` as +``{ field_name: model_reply, ... }``. + +Examples:: + + export OPENAI_API_KEY=sk-... + python -m opentau.scripts.pi_mem_data_generator task_segments.json + + # Pass the key explicitly (overrides .env / environment): + python -m opentau.scripts.pi_mem_data_generator task_segments.json --api-key sk-... + + # Verify .env is found and OPENAI_API_KEY is loaded: + python -m opentau.scripts.pi_mem_data_generator --check-env +""" + +from __future__ import annotations + +import argparse +import json +import logging +import os +import sys +import time +from pathlib import Path +from typing import Any + +try: + from dotenv import load_dotenv +except ImportError: + load_dotenv = None # type: ignore[misc, assignment] + +from openai import OpenAI + +logger = logging.getLogger(__name__) + + +def _apply_env_file_lines(path: Path, *, override: bool) -> None: + """Minimal ``.env`` reader when ``python-dotenv`` is not installed.""" + try: + text = path.read_text(encoding="utf-8") + except OSError as e: + logger.warning("Could not read %s: %s", path, e) + return + for line in text.splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + if line.startswith("export "): + line = line[7:].strip() + if "=" not in line: + continue + key, _, value = line.partition("=") + key = key.strip() + value = value.strip().strip('"').strip("'") + if not key: + continue + if override or key not in os.environ: + os.environ[key] = value + + +def _resolve_dotenv_path() -> Path | None: + """First ``.env`` file found walking up from this script (repo root typically).""" + script = Path(__file__).resolve() + for d in (script.parent, *script.parents): + candidate = d / ".env" + if candidate.is_file(): + return candidate + return None + + +def _load_env_file() -> Path | None: + """Load ``.env`` walking up from this script; ``override=True`` so ``.env`` beats shell. + + Returns the path to the loaded ``.env`` if found in the walk, else ``None``. + """ + path = _resolve_dotenv_path() + if path is not None: + if load_dotenv: + load_dotenv(path, override=True) + else: + _apply_env_file_lines(path, override=True) + logger.debug("Loaded environment file %s", path) + return path + if load_dotenv: + load_dotenv(override=True) + return None + + +def _mask_api_key_preview(key: str) -> str: + """Safe one-line description (never print the full key).""" + n = len(key) + if n <= 12: + return f"length {n} (too short to show prefix/suffix safely)" + return f"prefix {key[:8]}… suffix …{key[-4:]} (length {n})" + + +def _run_check_env(*, api_key_override: str | None) -> int: + """Print whether ``.env`` and ``OPENAI_API_KEY`` are visible after the same load as a normal run.""" + dotenv_path = _resolve_dotenv_path() + if dotenv_path is not None: + print(f"OK: found .env at {dotenv_path}") + else: + print("No .env file in any parent directory of the script (only shell env applies).") + + if api_key_override: + os.environ["OPENAI_API_KEY"] = api_key_override.strip() + + key = _normalize_openai_api_key() + if key: + print(f"OK: OPENAI_API_KEY is set — {_mask_api_key_preview(key)}") + return 0 + + print( + "FAIL: OPENAI_API_KEY missing or empty after load. " + "Use one line in .env: OPENAI_API_KEY=sk-... (no spaces around =).", + ) + return 1 + + +def _normalize_openai_api_key() -> str | None: + raw = os.environ.get("OPENAI_API_KEY") + if raw is None: + return None + key = raw.strip().strip('"').strip("'") + if not key: + return None + os.environ["OPENAI_API_KEY"] = key + return key + + +SYSTEM_PROMPT = """\ +You are the memory module of a robotic manipulation system. You receive a log \ +of subtasks that have ALREADY been executed and must produce a compact \ +plain-text summary. + +Critical rules: +- ONLY mention actions that appear in the log below. If an action is not in \ +the log, it has NOT happened — do NOT mention it, do NOT infer it, do NOT \ +speculate about it. You have zero knowledge beyond the log entries provided. +- Write simple, plain sentences. No bullet points, no numbered lists, \ +no labels, no markdown, no structured formatting. +- If the same action failed earlier but succeeded later in the log, just \ +mention the success. Drop the resolved failure. +- If a failure is the last entry for that action in the log, mention it. +- Merge completed actions into short phrases where possible. +- Omit timestamps. +- Keep it under 50 words.\ +""" + +USER_PROMPT_TEMPLATE = """\ +Here is the complete log of actions executed so far: +{subtask_log} + +Write a plain-text summary covering ONLY the actions listed above. \ +Do not mention or infer any action that is not in this log.\ +""" + + +def _call_openai( + client: OpenAI, + *, + model: str, + system_prompt: str | None, + user_content: str, +) -> str: + messages: list[dict[str, str]] = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": user_content}) + resp = client.chat.completions.create(model=model, messages=messages) + choice = resp.choices[0].message + text = choice.content + if not text: + raise RuntimeError("OpenAI returned empty message content") + return text.strip() + + +def _build_subtask_log(data: list[dict[str, Any]], up_to: int) -> str: + """Format subtasks 0..up_to (inclusive) as a numbered list for the prompt.""" + lines: list[str] = [] + for idx in range(up_to + 1): + item = data[idx] + subtask = item.get("subtask", "unknown") + success = item.get("success") + outcome = "SUCCESS" if success else "FAILED" if success is False else "UNKNOWN" + t = item.get("time") + time_str = f" (t={t}s)" if t is not None else "" + lines.append(f" {idx + 1}. [{outcome}]{time_str} {subtask}") + return "\n".join(lines) + + +def _enrich_list( + data: list[Any], + *, + client: OpenAI, + model: str, + system_prompt: str, + output_key: str, + skip_existing: bool, + delay_s: float, +) -> None: + """For each item, build a cumulative subtask log and request a memory summary.""" + for i, item in enumerate(data): + if not isinstance(item, dict): + raise ValueError(f"Item at index {i} must be an object, got {type(item).__name__}") + if skip_existing and output_key in item and item[output_key]: + logger.info("Skipping index %d (existing %s)", i, output_key) + continue + + subtask_log = _build_subtask_log(data, up_to=i) + user_content = USER_PROMPT_TEMPLATE.format(subtask_log=subtask_log) + logger.info("Calling API for item %d (subtask: %s)", i, item.get("subtask")) + item[output_key] = _call_openai( + client, model=model, system_prompt=system_prompt, user_content=user_content + ) + if delay_s > 0: + time.sleep(delay_s) + + +def main(argv: list[str] | None = None) -> int: + _load_env_file() + + p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + p.add_argument( + "--check-env", + action="store_true", + help="Load .env like a normal run, print whether OPENAI_API_KEY is set (masked), then exit.", + ) + p.add_argument( + "json_path", + type=Path, + nargs="?", + default=None, + help="Path to JSON file to read and update in place", + ) + p.add_argument( + "--system-prompt", + type=str, + default=None, + help="Optional system message for the chat completion.", + ) + p.add_argument( + "--model", + type=str, + default=os.environ.get("OPENAI_MODEL", "gpt-4o-mini"), + help="Chat model (default: gpt-4o-mini or OPENAI_MODEL env).", + ) + p.add_argument( + "--output-key", + type=str, + default="up_to_date_memory", + help="Key for per-field replies: a dict mapping each source field name to model text.", + ) + p.add_argument( + "--skip-existing", + action="store_true", + help="Do not call API if output-key already set (non-empty).", + ) + p.add_argument( + "--delay", + type=float, + default=0.0, + help="Seconds to sleep between API calls (rate limiting).", + ) + p.add_argument( + "--api-key", + type=str, + default=None, + help="OpenAI API key (default: OPENAI_API_KEY from .env or environment).", + ) + p.add_argument("-v", "--verbose", action="store_true", help="DEBUG logging") + args = p.parse_args(argv) + + logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) + + if args.check_env: + return _run_check_env(api_key_override=args.api_key) + + if args.json_path is None: + p.error("json_path is required (unless using --check-env)") + + if args.api_key: + os.environ["OPENAI_API_KEY"] = args.api_key.strip() + + api_key = _normalize_openai_api_key() + if not api_key: + logger.error( + "Missing OPENAI_API_KEY. Put it in repo .env, export it, or pass --api-key. " + "Run with --check-env to diagnose.", + ) + return 1 + + path = args.json_path.resolve() + if not path.is_file(): + logger.error("Not a file: %s", path) + return 1 + + text = path.read_text(encoding="utf-8") + try: + data = json.loads(text) + except json.JSONDecodeError as e: + logger.error("Invalid JSON: %s", e) + return 1 + + client = OpenAI(api_key=api_key) + if args.verbose: + tail = api_key[-4:] if len(api_key) >= 4 else "****" + logger.debug("OpenAI client using API key ending in …%s (length %d)", tail, len(api_key)) + + try: + if not isinstance(data, list): + logger.error("Top-level JSON must be a list of objects") + return 1 + system = args.system_prompt if args.system_prompt else SYSTEM_PROMPT + _enrich_list( + data, + client=client, + model=args.model, + system_prompt=system, + output_key=args.output_key, + skip_existing=args.skip_existing, + delay_s=args.delay, + ) + except Exception as e: + logger.exception("OpenAI or processing failed: %s", e) + return 1 + + path.write_text(json.dumps(data, indent=2) + "\n", encoding="utf-8") + logger.info("Wrote %s", path) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) From fefdb80437493769616d7a1b9725aa7671d6bc38 Mon Sep 17 00:00:00 2001 From: akshay18iitg Date: Tue, 31 Mar 2026 13:08:28 -0700 Subject: [PATCH 2/4] Passing prev_memeory and high level prompt to gpt --- src/opentau/scripts/pi_mem_data_generator.py | 40 +++++++++----------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/src/opentau/scripts/pi_mem_data_generator.py b/src/opentau/scripts/pi_mem_data_generator.py index 4b97a3b3..fe5eed2b 100644 --- a/src/opentau/scripts/pi_mem_data_generator.py +++ b/src/opentau/scripts/pi_mem_data_generator.py @@ -25,9 +25,6 @@ export OPENAI_API_KEY=sk-... python -m opentau.scripts.pi_mem_data_generator task_segments.json - # Pass the key explicitly (overrides .env / environment): - python -m opentau.scripts.pi_mem_data_generator task_segments.json --api-key sk-... - # Verify .env is found and OPENAI_API_KEY is loaded: python -m opentau.scripts.pi_mem_data_generator --check-env """ @@ -113,7 +110,7 @@ def _mask_api_key_preview(key: str) -> str: return f"prefix {key[:8]}… suffix …{key[-4:]} (length {n})" -def _run_check_env(*, api_key_override: str | None) -> int: +def _run_check_env() -> int: """Print whether ``.env`` and ``OPENAI_API_KEY`` are visible after the same load as a normal run.""" dotenv_path = _resolve_dotenv_path() if dotenv_path is not None: @@ -121,9 +118,6 @@ def _run_check_env(*, api_key_override: str | None) -> int: else: print("No .env file in any parent directory of the script (only shell env applies).") - if api_key_override: - os.environ["OPENAI_API_KEY"] = api_key_override.strip() - key = _normalize_openai_api_key() if key: print(f"OK: OPENAI_API_KEY is set — {_mask_api_key_preview(key)}") @@ -167,6 +161,10 @@ def _normalize_openai_api_key() -> str | None: """ USER_PROMPT_TEMPLATE = """\ +Previous memory: {prev_memory} + +Task prompt: {task_prompt} + Here is the complete log of actions executed so far: {subtask_log} @@ -219,19 +217,27 @@ def _enrich_list( delay_s: float, ) -> None: """For each item, build a cumulative subtask log and request a memory summary.""" + prev_memory = "(none)" for i, item in enumerate(data): if not isinstance(item, dict): raise ValueError(f"Item at index {i} must be an object, got {type(item).__name__}") if skip_existing and output_key in item and item[output_key]: logger.info("Skipping index %d (existing %s)", i, output_key) + prev_memory = item[output_key] continue subtask_log = _build_subtask_log(data, up_to=i) - user_content = USER_PROMPT_TEMPLATE.format(subtask_log=subtask_log) + task_prompt = item.get("prompt", "(not provided)") + user_content = USER_PROMPT_TEMPLATE.format( + prev_memory=prev_memory, + task_prompt=task_prompt, + subtask_log=subtask_log, + ) logger.info("Calling API for item %d (subtask: %s)", i, item.get("subtask")) - item[output_key] = _call_openai( + prev_memory = _call_openai( client, model=model, system_prompt=system_prompt, user_content=user_content ) + item[output_key] = prev_memory if delay_s > 0: time.sleep(delay_s) @@ -267,7 +273,7 @@ def main(argv: list[str] | None = None) -> int: p.add_argument( "--output-key", type=str, - default="up_to_date_memory", + default="memory", help="Key for per-field replies: a dict mapping each source field name to model text.", ) p.add_argument( @@ -281,31 +287,21 @@ def main(argv: list[str] | None = None) -> int: default=0.0, help="Seconds to sleep between API calls (rate limiting).", ) - p.add_argument( - "--api-key", - type=str, - default=None, - help="OpenAI API key (default: OPENAI_API_KEY from .env or environment).", - ) p.add_argument("-v", "--verbose", action="store_true", help="DEBUG logging") args = p.parse_args(argv) logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) if args.check_env: - return _run_check_env(api_key_override=args.api_key) + return _run_check_env() if args.json_path is None: p.error("json_path is required (unless using --check-env)") - if args.api_key: - os.environ["OPENAI_API_KEY"] = args.api_key.strip() - api_key = _normalize_openai_api_key() if not api_key: logger.error( - "Missing OPENAI_API_KEY. Put it in repo .env, export it, or pass --api-key. " - "Run with --check-env to diagnose.", + "Missing OPENAI_API_KEY. Put it in repo .env or export it. Run with --check-env to diagnose.", ) return 1 From 7249f410a1845ff670c6290ccdb8f843b6e87f57 Mon Sep 17 00:00:00 2001 From: akshay18iitg Date: Tue, 31 Mar 2026 13:28:46 -0700 Subject: [PATCH 3/4] Passing prev_memeory and high level prompt to gpt --- src/opentau/scripts/pi_mem_data_generator.py | 32 +------------------- 1 file changed, 1 insertion(+), 31 deletions(-) diff --git a/src/opentau/scripts/pi_mem_data_generator.py b/src/opentau/scripts/pi_mem_data_generator.py index fe5eed2b..359858a7 100644 --- a/src/opentau/scripts/pi_mem_data_generator.py +++ b/src/opentau/scripts/pi_mem_data_generator.py @@ -40,40 +40,12 @@ from pathlib import Path from typing import Any -try: - from dotenv import load_dotenv -except ImportError: - load_dotenv = None # type: ignore[misc, assignment] - +from dotenv import load_dotenv from openai import OpenAI logger = logging.getLogger(__name__) -def _apply_env_file_lines(path: Path, *, override: bool) -> None: - """Minimal ``.env`` reader when ``python-dotenv`` is not installed.""" - try: - text = path.read_text(encoding="utf-8") - except OSError as e: - logger.warning("Could not read %s: %s", path, e) - return - for line in text.splitlines(): - line = line.strip() - if not line or line.startswith("#"): - continue - if line.startswith("export "): - line = line[7:].strip() - if "=" not in line: - continue - key, _, value = line.partition("=") - key = key.strip() - value = value.strip().strip('"').strip("'") - if not key: - continue - if override or key not in os.environ: - os.environ[key] = value - - def _resolve_dotenv_path() -> Path | None: """First ``.env`` file found walking up from this script (repo root typically).""" script = Path(__file__).resolve() @@ -93,8 +65,6 @@ def _load_env_file() -> Path | None: if path is not None: if load_dotenv: load_dotenv(path, override=True) - else: - _apply_env_file_lines(path, override=True) logger.debug("Loaded environment file %s", path) return path if load_dotenv: From 3aca44bdf307fe5a073a0bf10ae1a9b6f88f05e2 Mon Sep 17 00:00:00 2001 From: Shuheng Liu Date: Thu, 2 Apr 2026 11:18:50 -0700 Subject: [PATCH 4/4] chore: cursor review --- src/opentau/scripts/pi_mem_data_generator.py | 86 +++------- tests/scripts/__init__.py | 13 ++ tests/scripts/test_pi_mem_data_generator.py | 168 +++++++++++++++++++ 3 files changed, 208 insertions(+), 59 deletions(-) create mode 100644 tests/scripts/__init__.py create mode 100644 tests/scripts/test_pi_mem_data_generator.py diff --git a/src/opentau/scripts/pi_mem_data_generator.py b/src/opentau/scripts/pi_mem_data_generator.py index 359858a7..66256112 100644 --- a/src/opentau/scripts/pi_mem_data_generator.py +++ b/src/opentau/scripts/pi_mem_data_generator.py @@ -12,21 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Enrich a JSON array with OpenAI chat completions (one call per field per object). +"""Iterate over a JSON array of task segments and generate cumulative memory summaries. -Top-level JSON must be a **list of objects**. For each object, ``subtask`` and the -indicator are read from that object: use ``indicator`` if present, otherwise -``success``. Those keys are prepended to every user message and are not sent as -separate per-field API calls. Replies go under ``memory`` as -``{ field_name: model_reply, ... }``. +Top-level JSON must be a **list of objects**. For each object the script builds a +cumulative subtask log (subtask names and ``success`` outcomes up to that point), +then calls OpenAI to produce a running plain-text memory summary. The summary is +written back to each object under the ``memory`` key (configurable via +``--output-key``). Examples:: export OPENAI_API_KEY=sk-... python -m opentau.scripts.pi_mem_data_generator task_segments.json - - # Verify .env is found and OPENAI_API_KEY is loaded: - python -m opentau.scripts.pi_mem_data_generator --check-env """ from __future__ import annotations @@ -36,12 +33,18 @@ import logging import os import sys +import tempfile import time from pathlib import Path from typing import Any -from dotenv import load_dotenv from openai import OpenAI +from openai.types.chat import ChatCompletionMessageParam + +try: + from dotenv import load_dotenv +except ImportError: + load_dotenv = None logger = logging.getLogger(__name__) @@ -49,7 +52,7 @@ def _resolve_dotenv_path() -> Path | None: """First ``.env`` file found walking up from this script (repo root typically).""" script = Path(__file__).resolve() - for d in (script.parent, *script.parents): + for d in script.parents: candidate = d / ".env" if candidate.is_file(): return candidate @@ -63,43 +66,15 @@ def _load_env_file() -> Path | None: """ path = _resolve_dotenv_path() if path is not None: - if load_dotenv: + if load_dotenv is not None: load_dotenv(path, override=True) logger.debug("Loaded environment file %s", path) return path - if load_dotenv: + if load_dotenv is not None: load_dotenv(override=True) return None -def _mask_api_key_preview(key: str) -> str: - """Safe one-line description (never print the full key).""" - n = len(key) - if n <= 12: - return f"length {n} (too short to show prefix/suffix safely)" - return f"prefix {key[:8]}… suffix …{key[-4:]} (length {n})" - - -def _run_check_env() -> int: - """Print whether ``.env`` and ``OPENAI_API_KEY`` are visible after the same load as a normal run.""" - dotenv_path = _resolve_dotenv_path() - if dotenv_path is not None: - print(f"OK: found .env at {dotenv_path}") - else: - print("No .env file in any parent directory of the script (only shell env applies).") - - key = _normalize_openai_api_key() - if key: - print(f"OK: OPENAI_API_KEY is set — {_mask_api_key_preview(key)}") - return 0 - - print( - "FAIL: OPENAI_API_KEY missing or empty after load. " - "Use one line in .env: OPENAI_API_KEY=sk-... (no spaces around =).", - ) - return 1 - - def _normalize_openai_api_key() -> str | None: raw = os.environ.get("OPENAI_API_KEY") if raw is None: @@ -150,7 +125,7 @@ def _call_openai( system_prompt: str | None, user_content: str, ) -> str: - messages: list[dict[str, str]] = [] + messages: list[ChatCompletionMessageParam] = [] if system_prompt: messages.append({"role": "system", "content": system_prompt}) messages.append({"role": "user", "content": user_content}) @@ -216,16 +191,9 @@ def main(argv: list[str] | None = None) -> int: _load_env_file() p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) - p.add_argument( - "--check-env", - action="store_true", - help="Load .env like a normal run, print whether OPENAI_API_KEY is set (masked), then exit.", - ) p.add_argument( "json_path", type=Path, - nargs="?", - default=None, help="Path to JSON file to read and update in place", ) p.add_argument( @@ -262,17 +230,9 @@ def main(argv: list[str] | None = None) -> int: logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO) - if args.check_env: - return _run_check_env() - - if args.json_path is None: - p.error("json_path is required (unless using --check-env)") - api_key = _normalize_openai_api_key() if not api_key: - logger.error( - "Missing OPENAI_API_KEY. Put it in repo .env or export it. Run with --check-env to diagnose.", - ) + logger.error("Missing OPENAI_API_KEY. Put it in repo .env or export it.") return 1 path = args.json_path.resolve() @@ -310,7 +270,15 @@ def main(argv: list[str] | None = None) -> int: logger.exception("OpenAI or processing failed: %s", e) return 1 - path.write_text(json.dumps(data, indent=2) + "\n", encoding="utf-8") + # Atomic os.replace (on Unix) to avoid partial JSON updates. + fd, tmp_path = tempfile.mkstemp(dir=path.parent, suffix=".tmp") + try: + with os.fdopen(fd, "w", encoding="utf-8") as f: + f.write(json.dumps(data, indent=2) + "\n") + os.replace(tmp_path, path) + except BaseException: + os.unlink(tmp_path) + raise logger.info("Wrote %s", path) return 0 diff --git a/tests/scripts/__init__.py b/tests/scripts/__init__.py new file mode 100644 index 00000000..787f750f --- /dev/null +++ b/tests/scripts/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/scripts/test_pi_mem_data_generator.py b/tests/scripts/test_pi_mem_data_generator.py new file mode 100644 index 00000000..084cce27 --- /dev/null +++ b/tests/scripts/test_pi_mem_data_generator.py @@ -0,0 +1,168 @@ +# Copyright 2026 Tensor Auto Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from opentau.scripts.pi_mem_data_generator import _build_subtask_log, _enrich_list, main + +SAMPLE_DATA = [ + {"time": 12.0, "subtask": "pick up cup", "success": True, "prompt": "Pick up a cup."}, + {"time": 20, "subtask": "pick up bottle", "success": False, "prompt": "Pick up a cup."}, + {"time": 28, "subtask": "pick up bottle", "success": True, "prompt": "Pick up a cup."}, +] + + +class TestBuildSubtaskLog: + def test_single_entry(self): + log = _build_subtask_log(SAMPLE_DATA, up_to=0) + assert log == " 1. [SUCCESS] (t=12.0s) pick up cup" + + def test_cumulative_entries(self): + log = _build_subtask_log(SAMPLE_DATA, up_to=2) + lines = log.split("\n") + assert len(lines) == 3 + assert "[SUCCESS]" in lines[0] + assert "[FAILED]" in lines[1] + assert "[SUCCESS]" in lines[2] + + def test_missing_success_shows_unknown(self): + data = [{"subtask": "test", "time": 1}] + log = _build_subtask_log(data, up_to=0) + assert "[UNKNOWN]" in log + + def test_missing_subtask_shows_unknown(self): + data = [{"success": True, "time": 5}] + log = _build_subtask_log(data, up_to=0) + assert "unknown" in log + + def test_missing_time_omits_timestamp(self): + data = [{"subtask": "reset", "success": True}] + log = _build_subtask_log(data, up_to=0) + assert "(t=" not in log + assert "reset" in log + + +def _make_mock_client(replies: list[str]) -> MagicMock: + """Build a mock OpenAI client that returns ``replies`` in order.""" + client = MagicMock() + responses = [] + for text in replies: + msg = SimpleNamespace(content=text) + resp = SimpleNamespace(choices=[SimpleNamespace(message=msg)]) + responses.append(resp) + client.chat.completions.create.side_effect = responses + return client + + +class TestEnrichList: + def test_basic_enrichment(self): + data = [dict(d) for d in SAMPLE_DATA] + client = _make_mock_client(["mem0", "mem1", "mem2"]) + _enrich_list( + data, + client=client, + model="test-model", + system_prompt="sys", + output_key="memory", + skip_existing=False, + delay_s=0, + ) + assert data[0]["memory"] == "mem0" + assert data[1]["memory"] == "mem1" + assert data[2]["memory"] == "mem2" + assert client.chat.completions.create.call_count == 3 + + def test_skip_existing(self): + data = [dict(d) for d in SAMPLE_DATA] + data[0]["memory"] = "already-set" + client = _make_mock_client(["mem1", "mem2"]) + _enrich_list( + data, + client=client, + model="m", + system_prompt="s", + output_key="memory", + skip_existing=True, + delay_s=0, + ) + assert data[0]["memory"] == "already-set" + assert data[1]["memory"] == "mem1" + assert client.chat.completions.create.call_count == 2 + + def test_previous_memory_propagation(self): + """The reply from step N is passed as prev_memory in step N+1's prompt.""" + data = [dict(d) for d in SAMPLE_DATA[:2]] + client = _make_mock_client(["first-summary", "second-summary"]) + _enrich_list( + data, + client=client, + model="m", + system_prompt="s", + output_key="memory", + skip_existing=False, + delay_s=0, + ) + second_call_kwargs = client.chat.completions.create.call_args_list[1] + user_msg = second_call_kwargs.kwargs["messages"][-1]["content"] + assert "first-summary" in user_msg + + def test_non_dict_item_raises(self): + with pytest.raises(ValueError, match="must be an object"): + _enrich_list( + ["not-a-dict"], + client=MagicMock(), + model="m", + system_prompt="s", + output_key="memory", + skip_existing=False, + delay_s=0, + ) + + +class TestMainMissingApiKey: + def test_exits_with_error_when_key_missing(self, tmp_path, monkeypatch): + input_file = tmp_path / "data.json" + input_file.write_text(json.dumps(SAMPLE_DATA), encoding="utf-8") + monkeypatch.delenv("OPENAI_API_KEY", raising=False) + with patch("opentau.scripts.pi_mem_data_generator._load_env_file"): + rc = main([str(input_file)]) + assert rc == 1 + + +class TestMainAtomicWrite: + def test_write_is_atomic(self, tmp_path, monkeypatch): + """After a successful run the file is updated and no .tmp files remain.""" + input_file = tmp_path / "data.json" + input_file.write_text(json.dumps(SAMPLE_DATA), encoding="utf-8") + + monkeypatch.setenv("OPENAI_API_KEY", "test") # gitleaks:allow + mock_client = _make_mock_client(["m0", "m1", "m2"]) + + with ( + patch("opentau.scripts.pi_mem_data_generator._load_env_file"), + patch("opentau.scripts.pi_mem_data_generator.OpenAI", return_value=mock_client), + ): + rc = main([str(input_file)]) + + assert rc == 0 + result = json.loads(input_file.read_text(encoding="utf-8")) + assert result[0]["memory"] == "m0" + tmp_files = list(tmp_path.glob("*.tmp")) + assert tmp_files == []