diff --git a/examples/async_optimization_example.py b/examples/async_optimization_example.py index 9455d93a..56536112 100644 --- a/examples/async_optimization_example.py +++ b/examples/async_optimization_example.py @@ -9,6 +9,12 @@ - Coordinate multiple async trace operations """ +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + import asyncio import time import random @@ -367,4 +373,4 @@ async def main(): if __name__ == "__main__": # Run the async main function - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/examples/battleship.py b/examples/battleship.py index a18d99c7..8ce5c4ab 100644 --- a/examples/battleship.py +++ b/examples/battleship.py @@ -1,5 +1,10 @@ import random import copy +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) def create_battleship_board(width, height): diff --git a/examples/bbh/run_prompt_bigbench_dspy.py b/examples/bbh/run_prompt_bigbench_dspy.py index 7b6a3e98..3e63fa76 100644 --- a/examples/bbh/run_prompt_bigbench_dspy.py +++ b/examples/bbh/run_prompt_bigbench_dspy.py @@ -1,5 +1,9 @@ +import os import sys +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + import dspy from datasets import load_dataset from dspy.evaluate import Evaluate diff --git a/examples/bbh/run_prompt_bigbench_trace.py b/examples/bbh/run_prompt_bigbench_trace.py index 23564649..994d09f6 100644 --- a/examples/bbh/run_prompt_bigbench_trace.py +++ b/examples/bbh/run_prompt_bigbench_trace.py @@ -1,3 +1,9 @@ +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + from opto.trace.nodes import node, GRAPH, ParameterNode from textwrap import dedent from opto.optimizers import OptoPrime diff --git a/examples/greeting.py b/examples/greeting.py index 280bceb7..ccfabe8b 100644 --- a/examples/greeting.py +++ b/examples/greeting.py @@ -1,3 +1,9 @@ +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + from opto import trace from opto.trace import node, bundle, model, ExecutionError from opto.optimizers import OptoPrime diff --git a/examples/gsm8k_trainer_example.py b/examples/gsm8k_trainer_example.py index 9b234d14..156ed904 100644 --- a/examples/gsm8k_trainer_example.py +++ b/examples/gsm8k_trainer_example.py @@ -1,3 +1,9 @@ +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + import datasets import numpy as np from opto import trainer diff --git a/examples/minibatch_bbh_aynsc/run_bigbench_trace_async.py b/examples/minibatch_bbh_aynsc/run_bigbench_trace_async.py index c5689a97..f9a0e87a 100644 --- a/examples/minibatch_bbh_aynsc/run_bigbench_trace_async.py +++ b/examples/minibatch_bbh_aynsc/run_bigbench_trace_async.py @@ -1,3 +1,9 @@ +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + import re from opto.trace.nodes import node, GRAPH, ParameterNode from textwrap import dedent diff --git a/examples/priority_search_example.py b/examples/priority_search_example.py index b63e04a7..eb62f8d3 100644 --- a/examples/priority_search_example.py +++ b/examples/priority_search_example.py @@ -1,3 +1,9 @@ +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + import datasets import numpy as np from opto import trace diff --git a/examples/priority_search_on_convex_fn.py b/examples/priority_search_on_convex_fn.py index 8122dde7..a48f06c8 100644 --- a/examples/priority_search_on_convex_fn.py +++ b/examples/priority_search_on_convex_fn.py @@ -1,5 +1,10 @@ -import re +import os import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + +import re import string import numpy as np from opto.trace.utils import dedent @@ -258,4 +263,4 @@ def get_feedback(self, query: str, response: str, reference=None, **kwargs) -> T num_proposals=4, memory_update_frequency=2, optimizer_kwargs={'objective':"You have a task of guessing two numbers. You should make sure your guess minimizes y.", 'memory_size': 10} -) \ No newline at end of file +) diff --git a/examples/priority_search_on_convex_fn_BENCH.py b/examples/priority_search_on_convex_fn_BENCH.py index 8f1a974e..c85017ea 100644 --- a/examples/priority_search_on_convex_fn_BENCH.py +++ b/examples/priority_search_on_convex_fn_BENCH.py @@ -1,5 +1,10 @@ -import re +import os import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + +import re import string import numpy as np import time @@ -215,4 +220,4 @@ def run_algorithm_comparison(): if __name__ == "__main__": - results = run_algorithm_comparison() \ No newline at end of file + results = run_algorithm_comparison() diff --git a/examples/search_algo_example.py b/examples/search_algo_example.py index 5b04d11c..64561f54 100644 --- a/examples/search_algo_example.py +++ b/examples/search_algo_example.py @@ -1,5 +1,10 @@ # Standard library imports import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + import time import argparse from typing import Any, Tuple @@ -348,4 +353,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/textgrad_examples/evals/textgrad_prompt_optimization.py b/examples/textgrad_examples/evals/textgrad_prompt_optimization.py index e87b3b9e..c61ca80c 100644 --- a/examples/textgrad_examples/evals/textgrad_prompt_optimization.py +++ b/examples/textgrad_examples/evals/textgrad_prompt_optimization.py @@ -1,5 +1,11 @@ # This script applies Trace to optimize the workflow in TextGrad's prompt_optimization.py. +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + from opto import trace from opto.optimizers import OptoPrime, TextGrad import time @@ -217,4 +223,4 @@ def concat(*items): import os os.makedirs("textgrad_figures", exist_ok=True) with open(f"./textgrad_figures/results_{args.task}_{args.test_engine}_{args.algo}_{args.seed}.json", "w") as f: - json.dump(results, f) \ No newline at end of file + json.dump(results, f) diff --git a/examples/textgrad_examples/evals/textgrad_solution_optimization.py b/examples/textgrad_examples/evals/textgrad_solution_optimization.py index abd2cfc5..d5f38458 100644 --- a/examples/textgrad_examples/evals/textgrad_solution_optimization.py +++ b/examples/textgrad_examples/evals/textgrad_solution_optimization.py @@ -1,5 +1,11 @@ # This script applies Trace to optimize the workflow in TextGrad's solution_optimization.py. +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + from opto import trace from opto.optimizers import OptoPrime, TextGrad, OptoPrimeMulti diff --git a/examples/train_model.py b/examples/train_model.py index 23361c74..53841638 100644 --- a/examples/train_model.py +++ b/examples/train_model.py @@ -1,3 +1,9 @@ +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + import datasets import numpy as np from opto import trainer diff --git a/examples/train_single_node.py b/examples/train_single_node.py index 13a903fc..515318bb 100644 --- a/examples/train_single_node.py +++ b/examples/train_single_node.py @@ -1,3 +1,9 @@ +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + from opto import trace, trainer def main(): diff --git a/examples/train_single_node_multi_optimizers.py.py b/examples/train_single_node_multi_optimizers.py.py index 6bbcb7b2..0e47c925 100644 --- a/examples/train_single_node_multi_optimizers.py.py +++ b/examples/train_single_node_multi_optimizers.py.py @@ -1,3 +1,9 @@ +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + from opto import trace, trainer from opto.optimizers.optoprime_v2 import OptimizerPromptSymbolSet diff --git a/examples/virtualhome.py b/examples/virtualhome.py index ef392569..4aab560b 100644 --- a/examples/virtualhome.py +++ b/examples/virtualhome.py @@ -5,6 +5,10 @@ import re from collections import defaultdict from difflib import SequenceMatcher +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) import autogen diff --git a/opto/features/priority_search/search_template.py b/opto/features/priority_search/search_template.py index ec244f74..0a1c336f 100644 --- a/opto/features/priority_search/search_template.py +++ b/opto/features/priority_search/search_template.py @@ -71,15 +71,25 @@ def check_optimizer_parameters(optimizer: Optimizer, agent: trace.Module): def save_train_config(function): """ Decorator to save the inputs of a class method. """ - def wrapper(self, **kwargs): + + def wrapper(self, *args, **kwargs): + # Backward-compat: some wrappers call ``train(guide, train_dataset, ...)`` positionally. + # Normalize them into keyword args before forwarding. + if args: + if len(args) >= 1 and "guide" not in kwargs: + kwargs["guide"] = args[0] + if len(args) >= 2 and "train_dataset" not in kwargs: + kwargs["train_dataset"] = args[1] + _kwargs = kwargs.copy() - del _kwargs['train_dataset'] # remove train_dataset from the saved kwargs - if _kwargs.get('validate_dataset') is not None: - del _kwargs['validate_dataset'] # remove validate_dataset from the saved kwargs - if _kwargs.get('test_dataset') is not None: - del _kwargs['test_dataset'] # remove test_dataset from the saved kwargs - setattr(self, f'_train_last_kwargs', _kwargs) + _kwargs.pop("train_dataset", None) # remove train_dataset from the saved kwargs + if _kwargs.get("validate_dataset") is not None: + _kwargs.pop("validate_dataset", None) # remove validate_dataset from the saved kwargs + if _kwargs.get("test_dataset") is not None: + _kwargs.pop("test_dataset", None) # remove test_dataset from the saved kwargs + setattr(self, "_train_last_kwargs", _kwargs) return function(self, **kwargs) + return wrapper class SearchTemplate(Trainer): diff --git a/tests/unit_tests/test_logger_override.py b/tests/unit_tests/test_logger_override.py new file mode 100644 index 00000000..c7cc5876 --- /dev/null +++ b/tests/unit_tests/test_logger_override.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from opto.features.priority_search.search_template import save_train_config + + +class _DummyTrainer: + @save_train_config + def train(self, *, guide, train_dataset, batch_size=0, validate_dataset=None, test_dataset=None): + return { + "guide": guide, + "train_dataset": train_dataset, + "batch_size": batch_size, + "validate_dataset": validate_dataset, + "test_dataset": test_dataset, + } + + +def test_save_train_config_accepts_positional_guide_and_dataset(): + trainer = _DummyTrainer() + train_dataset = {"inputs": [], "infos": []} + + result = trainer.train("guide", train_dataset, batch_size=1) + + assert result["guide"] == "guide" + assert result["train_dataset"] is train_dataset + assert result["batch_size"] == 1 + assert trainer._train_last_kwargs == { + "guide": "guide", + "batch_size": 1, + }