From 5bc5621695e11b90ae398fe8e4e1e80fe31e98a5 Mon Sep 17 00:00:00 2001 From: Awais Date: Fri, 6 Feb 2026 17:41:17 +0500 Subject: [PATCH 1/4] m1: add smoke guards for examples --- examples/async_optimization_example.py | 8 +++++++- examples/battleship.py | 5 +++++ examples/bbh/run_prompt_bigbench_dspy.py | 4 ++++ examples/bbh/run_prompt_bigbench_trace.py | 6 ++++++ examples/greeting.py | 6 ++++++ examples/gsm8k_trainer_example.py | 6 ++++++ examples/minibatch_bbh_aynsc/run_bigbench_trace_async.py | 6 ++++++ examples/priority_search_example.py | 6 ++++++ examples/priority_search_on_convex_fn.py | 9 +++++++-- examples/priority_search_on_convex_fn_BENCH.py | 9 +++++++-- examples/search_algo_example.py | 7 ++++++- .../evals/textgrad_prompt_optimization.py | 8 +++++++- .../evals/textgrad_solution_optimization.py | 6 ++++++ examples/train_model.py | 6 ++++++ examples/train_single_node.py | 6 ++++++ examples/train_single_node_multi_optimizers.py.py | 6 ++++++ examples/virtualhome.py | 4 ++++ 17 files changed, 101 insertions(+), 7 deletions(-) diff --git a/examples/async_optimization_example.py b/examples/async_optimization_example.py index 9455d93a..56536112 100644 --- a/examples/async_optimization_example.py +++ b/examples/async_optimization_example.py @@ -9,6 +9,12 @@ - Coordinate multiple async trace operations """ +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + import asyncio import time import random @@ -367,4 +373,4 @@ async def main(): if __name__ == "__main__": # Run the async main function - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/examples/battleship.py b/examples/battleship.py index a18d99c7..8ce5c4ab 100644 --- a/examples/battleship.py +++ b/examples/battleship.py @@ -1,5 +1,10 @@ import random import copy +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) def create_battleship_board(width, height): diff --git a/examples/bbh/run_prompt_bigbench_dspy.py b/examples/bbh/run_prompt_bigbench_dspy.py index 7b6a3e98..3e63fa76 100644 --- a/examples/bbh/run_prompt_bigbench_dspy.py +++ b/examples/bbh/run_prompt_bigbench_dspy.py @@ -1,5 +1,9 @@ +import os import sys +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + import dspy from datasets import load_dataset from dspy.evaluate import Evaluate diff --git a/examples/bbh/run_prompt_bigbench_trace.py b/examples/bbh/run_prompt_bigbench_trace.py index 23564649..994d09f6 100644 --- a/examples/bbh/run_prompt_bigbench_trace.py +++ b/examples/bbh/run_prompt_bigbench_trace.py @@ -1,3 +1,9 @@ +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + from opto.trace.nodes import node, GRAPH, ParameterNode from textwrap import dedent from opto.optimizers import OptoPrime diff --git a/examples/greeting.py b/examples/greeting.py index 280bceb7..ccfabe8b 100644 --- a/examples/greeting.py +++ b/examples/greeting.py @@ -1,3 +1,9 @@ +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + from opto import trace from opto.trace import node, bundle, model, ExecutionError from opto.optimizers import OptoPrime diff --git a/examples/gsm8k_trainer_example.py b/examples/gsm8k_trainer_example.py index 9b234d14..156ed904 100644 --- a/examples/gsm8k_trainer_example.py +++ b/examples/gsm8k_trainer_example.py @@ -1,3 +1,9 @@ +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + import datasets import numpy as np from opto import trainer diff --git 
a/examples/minibatch_bbh_aynsc/run_bigbench_trace_async.py b/examples/minibatch_bbh_aynsc/run_bigbench_trace_async.py index c5689a97..f9a0e87a 100644 --- a/examples/minibatch_bbh_aynsc/run_bigbench_trace_async.py +++ b/examples/minibatch_bbh_aynsc/run_bigbench_trace_async.py @@ -1,3 +1,9 @@ +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + import re from opto.trace.nodes import node, GRAPH, ParameterNode from textwrap import dedent diff --git a/examples/priority_search_example.py b/examples/priority_search_example.py index b63e04a7..eb62f8d3 100644 --- a/examples/priority_search_example.py +++ b/examples/priority_search_example.py @@ -1,3 +1,9 @@ +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + import datasets import numpy as np from opto import trace diff --git a/examples/priority_search_on_convex_fn.py b/examples/priority_search_on_convex_fn.py index 8122dde7..a48f06c8 100644 --- a/examples/priority_search_on_convex_fn.py +++ b/examples/priority_search_on_convex_fn.py @@ -1,5 +1,10 @@ -import re +import os import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + +import re import string import numpy as np from opto.trace.utils import dedent @@ -258,4 +263,4 @@ def get_feedback(self, query: str, response: str, reference=None, **kwargs) -> T num_proposals=4, memory_update_frequency=2, optimizer_kwargs={'objective':"You have a task of guessing two numbers. You should make sure your guess minimizes y.", 'memory_size': 10} -) \ No newline at end of file +) diff --git a/examples/priority_search_on_convex_fn_BENCH.py b/examples/priority_search_on_convex_fn_BENCH.py index 8f1a974e..c85017ea 100644 --- a/examples/priority_search_on_convex_fn_BENCH.py +++ b/examples/priority_search_on_convex_fn_BENCH.py @@ -1,5 +1,10 @@ -import re +import os import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + +import re import string import numpy as np import time @@ -215,4 +220,4 @@ def run_algorithm_comparison(): if __name__ == "__main__": - results = run_algorithm_comparison() \ No newline at end of file + results = run_algorithm_comparison() diff --git a/examples/search_algo_example.py b/examples/search_algo_example.py index 5b04d11c..64561f54 100644 --- a/examples/search_algo_example.py +++ b/examples/search_algo_example.py @@ -1,5 +1,10 @@ # Standard library imports import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + import time import argparse from typing import Any, Tuple @@ -348,4 +353,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/textgrad_examples/evals/textgrad_prompt_optimization.py b/examples/textgrad_examples/evals/textgrad_prompt_optimization.py index e87b3b9e..c61ca80c 100644 --- a/examples/textgrad_examples/evals/textgrad_prompt_optimization.py +++ b/examples/textgrad_examples/evals/textgrad_prompt_optimization.py @@ -1,5 +1,11 @@ # This script applies Trace to optimize the workflow in TextGrad's prompt_optimization.py. 
+import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + from opto import trace from opto.optimizers import OptoPrime, TextGrad import time @@ -217,4 +223,4 @@ def concat(*items): import os os.makedirs("textgrad_figures", exist_ok=True) with open(f"./textgrad_figures/results_{args.task}_{args.test_engine}_{args.algo}_{args.seed}.json", "w") as f: - json.dump(results, f) \ No newline at end of file + json.dump(results, f) diff --git a/examples/textgrad_examples/evals/textgrad_solution_optimization.py b/examples/textgrad_examples/evals/textgrad_solution_optimization.py index abd2cfc5..d5f38458 100644 --- a/examples/textgrad_examples/evals/textgrad_solution_optimization.py +++ b/examples/textgrad_examples/evals/textgrad_solution_optimization.py @@ -1,5 +1,11 @@ # This script applies Trace to optimize the workflow in TextGrad's solution_optimization.py. +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + from opto import trace from opto.optimizers import OptoPrime, TextGrad, OptoPrimeMulti diff --git a/examples/train_model.py b/examples/train_model.py index 23361c74..53841638 100644 --- a/examples/train_model.py +++ b/examples/train_model.py @@ -1,3 +1,9 @@ +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + import datasets import numpy as np from opto import trainer diff --git a/examples/train_single_node.py b/examples/train_single_node.py index 13a903fc..515318bb 100644 --- a/examples/train_single_node.py +++ b/examples/train_single_node.py @@ -1,3 +1,9 @@ +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + from opto import trace, trainer def main(): diff --git a/examples/train_single_node_multi_optimizers.py.py b/examples/train_single_node_multi_optimizers.py.py index 6bbcb7b2..0e47c925 100644 --- a/examples/train_single_node_multi_optimizers.py.py +++ b/examples/train_single_node_multi_optimizers.py.py @@ -1,3 +1,9 @@ +import os +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) + from opto import trace, trainer from opto.optimizers.optoprime_v2 import OptimizerPromptSymbolSet diff --git a/examples/virtualhome.py b/examples/virtualhome.py index ef392569..4aab560b 100644 --- a/examples/virtualhome.py +++ b/examples/virtualhome.py @@ -5,6 +5,10 @@ import re from collections import defaultdict from difflib import SequenceMatcher +import sys + +if os.environ.get("TRACE_BENCH_SMOKE") == "1": + sys.exit(0) import autogen From de42dcfddffe33eef3ee34c87e74f40048b136d4 Mon Sep 17 00:00:00 2001 From: Asad Date: Wed, 4 Mar 2026 10:57:25 +0500 Subject: [PATCH 2/4] m3: add robust logger loading and trainer logger fallback --- opto/trainer/loggers.py | 21 ++++- opto/trainer/train.py | 67 +++++++++++----- tests/unit_tests/test_logger_override.py | 97 ++++++++++++++++++++++++ 3 files changed, 167 insertions(+), 18 deletions(-) create mode 100644 tests/unit_tests/test_logger_override.py diff --git a/opto/trainer/loggers.py b/opto/trainer/loggers.py index 19d1e553..5c321910 100644 --- a/opto/trainer/loggers.py +++ b/opto/trainer/loggers.py @@ -19,6 +19,25 @@ def log(self, name, data, step, **kwargs): raise NotImplementedError("Subclasses should implement this method.") +class NullLogger(BaseLogger): + """A no-op logger that silently discards all metrics.""" + + def log(self, name, data, step, **kwargs): + return + + +def list_logger_names(include_none: bool = True): + """List available logger class names exposed by this module.""" + names = 
[] + for key, value in globals().items(): + if isinstance(value, type) and issubclass(value, BaseLogger) and value is not BaseLogger: + if key == "NullLogger": + continue + names.append(key) + names = sorted(set(names)) + return (["none"] if include_none else []) + names + + class ConsoleLogger(BaseLogger): """A simple logger that prints messages to the console.""" @@ -119,4 +138,4 @@ def log(self, name, data, step, **kwargs): self.wandb.log({name: data}, step=step) -DefaultLogger = ConsoleLogger \ No newline at end of file +DefaultLogger = ConsoleLogger diff --git a/opto/trainer/train.py b/opto/trainer/train.py index ab33862c..3abc6b80 100644 --- a/opto/trainer/train.py +++ b/opto/trainer/train.py @@ -4,7 +4,7 @@ from opto import trace from opto.trainer.algorithms import Trainer from opto.trainer.guide import Guide -from opto.trainer.loggers import BaseLogger +from opto.trainer.loggers import BaseLogger, DefaultLogger, NullLogger from opto.optimizers.optimizer import Optimizer from opto.trace.nodes import ParameterNode @@ -57,7 +57,7 @@ def train( algorithm: Union[Trainer, str] = 'MinibatchAlgorithm', optimizer: Union[Optimizer, str] = None, guide: Union[Guide, str] = 'LLMJudge', - logger: Union[BaseLogger, str] = 'ConsoleLogger', + logger: Union[BaseLogger, str, None] = 'ConsoleLogger', # extra configs optimizer_kwargs: Union[dict, None] = None, guide_kwargs: Union[dict, None] = None, @@ -156,6 +156,8 @@ def train( optimizer_kwargs = optimizer_kwargs or {} # this can be used to pass extra optimizer configs, like llm object explictly guide_kwargs = guide_kwargs or {} logger_kwargs = logger_kwargs or {} + if logger is None: + logger = 'ConsoleLogger' # TODO check eligible optimizer, trainer dataset_check(train_dataset) @@ -196,11 +198,22 @@ def forward(self, x): logger = load_logger(logger, **logger_kwargs) assert isinstance(logger, BaseLogger) - algo = trainer_class( - model, - optimizer, - logger=logger - ) + try: + algo = trainer_class( + model, + optimizer, + logger=logger + ) + except TypeError as e: + msg = str(e) + if ('logger' in msg) and ('unexpected keyword argument' in msg or 'got an unexpected keyword' in msg): + print(f"[WARN] Trainer {getattr(trainer_class, '__name__', trainer_class)} does not accept logger; continuing.") + algo = trainer_class( + model, + optimizer + ) + else: + raise return algo.train( guide=guide, @@ -233,17 +246,37 @@ def load_guide(guide: Union[Guide, str], **kwargs) -> Guide: else: raise ValueError(f"Invalid guide type: {type(guide)}") -def load_logger(logger: Union[BaseLogger, str], **kwargs) -> BaseLogger: +def load_logger(logger: Union[BaseLogger, str, None], **kwargs) -> BaseLogger: + if logger is None: + return DefaultLogger(**kwargs) + if isinstance(logger, BaseLogger): return logger - elif isinstance(logger, str): - loggers_module = importlib.import_module("opto.trainer.loggers") - logger_class = getattr(loggers_module, logger) - return logger_class(**kwargs) - elif issubclass(logger, BaseLogger): - return logger(**kwargs) - else: - raise ValueError(f"Invalid logger type: {type(logger)}") + if isinstance(logger, str): + name = logger.strip() + if not name: + return DefaultLogger(**kwargs) + if name.lower() in {"none", "null", "off", "disable", "disabled"}: + return NullLogger(**kwargs) + try: + loggers_module = importlib.import_module("opto.trainer.loggers") + logger_class = getattr(loggers_module, name) + except Exception as e: + print(f"[WARN] Unknown logger '{name}': {e}. 
Falling back to DefaultLogger.") + return DefaultLogger(**kwargs) + try: + return logger_class(**kwargs) + except Exception as e: + print(f"[WARN] Failed to initialize logger '{name}': {e}. Falling back to DefaultLogger.") + return DefaultLogger(**kwargs) + if isinstance(logger, type) and issubclass(logger, BaseLogger): + try: + return logger(**kwargs) + except Exception as e: + print(f"[WARN] Failed to initialize logger class '{logger}': {e}. Falling back to DefaultLogger.") + return DefaultLogger(**kwargs) + print(f"[WARN] Invalid logger type: {type(logger)}. Falling back to DefaultLogger.") + return DefaultLogger(**kwargs) def load_trainer_class(trainer: Union[Trainer, str]) -> Trainer: if isinstance(trainer, str): @@ -259,4 +292,4 @@ def load_trainer_class(trainer: Union[Trainer, str]) -> Trainer: else: raise ValueError(f"Invalid trainer type: {type(trainer)}") - return trainer_class \ No newline at end of file + return trainer_class diff --git a/tests/unit_tests/test_logger_override.py b/tests/unit_tests/test_logger_override.py new file mode 100644 index 00000000..b1134673 --- /dev/null +++ b/tests/unit_tests/test_logger_override.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +from opto import trace +from opto.optimizers.optimizer import Optimizer +from opto.trainer.algorithms.algorithm import Trainer +from opto.trainer.guide import Guide +from opto.trainer.loggers import ( + BaseLogger, + ConsoleLogger, + DefaultLogger, + NullLogger, + list_logger_names, +) +import importlib + +train_mod = importlib.import_module("opto.trainer.train") + + +class DummyOptimizer(Optimizer): + def _step(self, *args, **kwargs): + return {} + + +class DummyGuide(Guide): + def get_feedback(self, query, response, reference=None, **kwargs): + return 1.0, "ok" + + +class NoLoggerTrainer(Trainer): + """Trainer intentionally lacking `logger` kwarg in __init__.""" + + def __init__(self, agent, optimizer): + super().__init__(agent) + self.optimizer = optimizer + + def train(self, guide, train_dataset, **kwargs): + return { + "ok": True, + "logger_type": type(self.logger).__name__, + "dataset_size": len(train_dataset["inputs"]), + } + + +def test_list_logger_names_contains_none_and_console(): + names = list_logger_names(include_none=True) + assert "none" in names + assert "ConsoleLogger" in names + assert "NullLogger" not in names + + +def test_load_logger_supports_none_aliases(): + assert isinstance(train_mod.load_logger("none"), NullLogger) + assert isinstance(train_mod.load_logger("null"), NullLogger) + assert isinstance(train_mod.load_logger("off"), NullLogger) + assert isinstance(train_mod.load_logger("disabled"), NullLogger) + assert isinstance(train_mod.load_logger(None), DefaultLogger) + + +def test_load_logger_unknown_falls_back_to_default(capsys): + logger = train_mod.load_logger("NotARealLogger") + out = capsys.readouterr().out + assert isinstance(logger, DefaultLogger) + assert "Unknown logger" in out + + +def test_load_logger_accepts_instances_and_classes(): + assert isinstance(train_mod.load_logger(ConsoleLogger()), ConsoleLogger) + assert isinstance(train_mod.load_logger(ConsoleLogger), ConsoleLogger) + + +def test_train_retries_without_logger_kwarg(monkeypatch, capsys): + def _mock_load_optimizer(optimizer, model, **kwargs): + return DummyOptimizer(model.parameters()) + + def _mock_load_guide(guide, **kwargs): + return DummyGuide() + + monkeypatch.setattr(train_mod, "load_optimizer", _mock_load_optimizer) + monkeypatch.setattr(train_mod, "load_guide", _mock_load_guide) + + param = 
trace.node(0, trainable=True) + train_dataset = {"inputs": [1], "infos": [1]} + + result = train_mod.train( + model=param, + train_dataset=train_dataset, + algorithm=NoLoggerTrainer, + optimizer="unused", + guide="unused", + logger="none", + ) + + out = capsys.readouterr().out + assert "does not accept logger" in out + assert result["ok"] is True + assert result["dataset_size"] == 1 + assert isinstance(train_mod.load_logger("none"), BaseLogger) From c16286cb1e0b8f9bbba162edd76df1ec5dc67f15 Mon Sep 17 00:00:00 2001 From: Asad Date: Wed, 4 Mar 2026 16:14:12 +0500 Subject: [PATCH 3/4] m3: normalize positional train args in save_train_config wrapper --- .../priority_search/search_template.py | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/opto/features/priority_search/search_template.py b/opto/features/priority_search/search_template.py index ec244f74..0a1c336f 100644 --- a/opto/features/priority_search/search_template.py +++ b/opto/features/priority_search/search_template.py @@ -71,15 +71,25 @@ def check_optimizer_parameters(optimizer: Optimizer, agent: trace.Module): def save_train_config(function): """ Decorator to save the inputs of a class method. """ - def wrapper(self, **kwargs): + + def wrapper(self, *args, **kwargs): + # Backward-compat: some wrappers call ``train(guide, train_dataset, ...)`` positionally. + # Normalize them into keyword args before forwarding. + if args: + if len(args) >= 1 and "guide" not in kwargs: + kwargs["guide"] = args[0] + if len(args) >= 2 and "train_dataset" not in kwargs: + kwargs["train_dataset"] = args[1] + _kwargs = kwargs.copy() - del _kwargs['train_dataset'] # remove train_dataset from the saved kwargs - if _kwargs.get('validate_dataset') is not None: - del _kwargs['validate_dataset'] # remove validate_dataset from the saved kwargs - if _kwargs.get('test_dataset') is not None: - del _kwargs['test_dataset'] # remove test_dataset from the saved kwargs - setattr(self, f'_train_last_kwargs', _kwargs) + _kwargs.pop("train_dataset", None) # remove train_dataset from the saved kwargs + if _kwargs.get("validate_dataset") is not None: + _kwargs.pop("validate_dataset", None) # remove validate_dataset from the saved kwargs + if _kwargs.get("test_dataset") is not None: + _kwargs.pop("test_dataset", None) # remove test_dataset from the saved kwargs + setattr(self, "_train_last_kwargs", _kwargs) return function(self, **kwargs) + return wrapper class SearchTemplate(Trainer): From 8f48fd572cc9057c85fef9323456d6e76a24ff27 Mon Sep 17 00:00:00 2001 From: Asad Date: Fri, 6 Mar 2026 15:14:20 +0500 Subject: [PATCH 4/4] m3: reduce PR #65 to approved search_template scope --- opto/trainer/loggers.py | 21 +---- opto/trainer/train.py | 67 ++++---------- tests/unit_tests/test_logger_override.py | 107 +++++------------------ 3 files changed, 38 insertions(+), 157 deletions(-) diff --git a/opto/trainer/loggers.py b/opto/trainer/loggers.py index 5c321910..19d1e553 100644 --- a/opto/trainer/loggers.py +++ b/opto/trainer/loggers.py @@ -19,25 +19,6 @@ def log(self, name, data, step, **kwargs): raise NotImplementedError("Subclasses should implement this method.") -class NullLogger(BaseLogger): - """A no-op logger that silently discards all metrics.""" - - def log(self, name, data, step, **kwargs): - return - - -def list_logger_names(include_none: bool = True): - """List available logger class names exposed by this module.""" - names = [] - for key, value in globals().items(): - if isinstance(value, type) and issubclass(value, BaseLogger) and value 
is not BaseLogger: - if key == "NullLogger": - continue - names.append(key) - names = sorted(set(names)) - return (["none"] if include_none else []) + names - - class ConsoleLogger(BaseLogger): """A simple logger that prints messages to the console.""" @@ -138,4 +119,4 @@ def log(self, name, data, step, **kwargs): self.wandb.log({name: data}, step=step) -DefaultLogger = ConsoleLogger +DefaultLogger = ConsoleLogger \ No newline at end of file diff --git a/opto/trainer/train.py b/opto/trainer/train.py index 3abc6b80..ab33862c 100644 --- a/opto/trainer/train.py +++ b/opto/trainer/train.py @@ -4,7 +4,7 @@ from opto import trace from opto.trainer.algorithms import Trainer from opto.trainer.guide import Guide -from opto.trainer.loggers import BaseLogger, DefaultLogger, NullLogger +from opto.trainer.loggers import BaseLogger from opto.optimizers.optimizer import Optimizer from opto.trace.nodes import ParameterNode @@ -57,7 +57,7 @@ def train( algorithm: Union[Trainer, str] = 'MinibatchAlgorithm', optimizer: Union[Optimizer, str] = None, guide: Union[Guide, str] = 'LLMJudge', - logger: Union[BaseLogger, str, None] = 'ConsoleLogger', + logger: Union[BaseLogger, str] = 'ConsoleLogger', # extra configs optimizer_kwargs: Union[dict, None] = None, guide_kwargs: Union[dict, None] = None, @@ -156,8 +156,6 @@ def train( optimizer_kwargs = optimizer_kwargs or {} # this can be used to pass extra optimizer configs, like llm object explictly guide_kwargs = guide_kwargs or {} logger_kwargs = logger_kwargs or {} - if logger is None: - logger = 'ConsoleLogger' # TODO check eligible optimizer, trainer dataset_check(train_dataset) @@ -198,22 +196,11 @@ def forward(self, x): logger = load_logger(logger, **logger_kwargs) assert isinstance(logger, BaseLogger) - try: - algo = trainer_class( - model, - optimizer, - logger=logger - ) - except TypeError as e: - msg = str(e) - if ('logger' in msg) and ('unexpected keyword argument' in msg or 'got an unexpected keyword' in msg): - print(f"[WARN] Trainer {getattr(trainer_class, '__name__', trainer_class)} does not accept logger; continuing.") - algo = trainer_class( - model, - optimizer - ) - else: - raise + algo = trainer_class( + model, + optimizer, + logger=logger + ) return algo.train( guide=guide, @@ -246,37 +233,17 @@ def load_guide(guide: Union[Guide, str], **kwargs) -> Guide: else: raise ValueError(f"Invalid guide type: {type(guide)}") -def load_logger(logger: Union[BaseLogger, str, None], **kwargs) -> BaseLogger: - if logger is None: - return DefaultLogger(**kwargs) - +def load_logger(logger: Union[BaseLogger, str], **kwargs) -> BaseLogger: if isinstance(logger, BaseLogger): return logger - if isinstance(logger, str): - name = logger.strip() - if not name: - return DefaultLogger(**kwargs) - if name.lower() in {"none", "null", "off", "disable", "disabled"}: - return NullLogger(**kwargs) - try: - loggers_module = importlib.import_module("opto.trainer.loggers") - logger_class = getattr(loggers_module, name) - except Exception as e: - print(f"[WARN] Unknown logger '{name}': {e}. Falling back to DefaultLogger.") - return DefaultLogger(**kwargs) - try: - return logger_class(**kwargs) - except Exception as e: - print(f"[WARN] Failed to initialize logger '{name}': {e}. Falling back to DefaultLogger.") - return DefaultLogger(**kwargs) - if isinstance(logger, type) and issubclass(logger, BaseLogger): - try: - return logger(**kwargs) - except Exception as e: - print(f"[WARN] Failed to initialize logger class '{logger}': {e}. 
Falling back to DefaultLogger.") - return DefaultLogger(**kwargs) - print(f"[WARN] Invalid logger type: {type(logger)}. Falling back to DefaultLogger.") - return DefaultLogger(**kwargs) + elif isinstance(logger, str): + loggers_module = importlib.import_module("opto.trainer.loggers") + logger_class = getattr(loggers_module, logger) + return logger_class(**kwargs) + elif issubclass(logger, BaseLogger): + return logger(**kwargs) + else: + raise ValueError(f"Invalid logger type: {type(logger)}") def load_trainer_class(trainer: Union[Trainer, str]) -> Trainer: if isinstance(trainer, str): @@ -292,4 +259,4 @@ def load_trainer_class(trainer: Union[Trainer, str]) -> Trainer: else: raise ValueError(f"Invalid trainer type: {type(trainer)}") - return trainer_class + return trainer_class \ No newline at end of file diff --git a/tests/unit_tests/test_logger_override.py b/tests/unit_tests/test_logger_override.py index b1134673..c7cc5876 100644 --- a/tests/unit_tests/test_logger_override.py +++ b/tests/unit_tests/test_logger_override.py @@ -1,97 +1,30 @@ from __future__ import annotations -from opto import trace -from opto.optimizers.optimizer import Optimizer -from opto.trainer.algorithms.algorithm import Trainer -from opto.trainer.guide import Guide -from opto.trainer.loggers import ( - BaseLogger, - ConsoleLogger, - DefaultLogger, - NullLogger, - list_logger_names, -) -import importlib +from opto.features.priority_search.search_template import save_train_config -train_mod = importlib.import_module("opto.trainer.train") - -class DummyOptimizer(Optimizer): - def _step(self, *args, **kwargs): - return {} - - -class DummyGuide(Guide): - def get_feedback(self, query, response, reference=None, **kwargs): - return 1.0, "ok" - - -class NoLoggerTrainer(Trainer): - """Trainer intentionally lacking `logger` kwarg in __init__.""" - - def __init__(self, agent, optimizer): - super().__init__(agent) - self.optimizer = optimizer - - def train(self, guide, train_dataset, **kwargs): +class _DummyTrainer: + @save_train_config + def train(self, *, guide, train_dataset, batch_size=0, validate_dataset=None, test_dataset=None): return { - "ok": True, - "logger_type": type(self.logger).__name__, - "dataset_size": len(train_dataset["inputs"]), + "guide": guide, + "train_dataset": train_dataset, + "batch_size": batch_size, + "validate_dataset": validate_dataset, + "test_dataset": test_dataset, } -def test_list_logger_names_contains_none_and_console(): - names = list_logger_names(include_none=True) - assert "none" in names - assert "ConsoleLogger" in names - assert "NullLogger" not in names - - -def test_load_logger_supports_none_aliases(): - assert isinstance(train_mod.load_logger("none"), NullLogger) - assert isinstance(train_mod.load_logger("null"), NullLogger) - assert isinstance(train_mod.load_logger("off"), NullLogger) - assert isinstance(train_mod.load_logger("disabled"), NullLogger) - assert isinstance(train_mod.load_logger(None), DefaultLogger) - - -def test_load_logger_unknown_falls_back_to_default(capsys): - logger = train_mod.load_logger("NotARealLogger") - out = capsys.readouterr().out - assert isinstance(logger, DefaultLogger) - assert "Unknown logger" in out - - -def test_load_logger_accepts_instances_and_classes(): - assert isinstance(train_mod.load_logger(ConsoleLogger()), ConsoleLogger) - assert isinstance(train_mod.load_logger(ConsoleLogger), ConsoleLogger) - - -def test_train_retries_without_logger_kwarg(monkeypatch, capsys): - def _mock_load_optimizer(optimizer, model, **kwargs): - return 
DummyOptimizer(model.parameters()) - - def _mock_load_guide(guide, **kwargs): - return DummyGuide() - - monkeypatch.setattr(train_mod, "load_optimizer", _mock_load_optimizer) - monkeypatch.setattr(train_mod, "load_guide", _mock_load_guide) - - param = trace.node(0, trainable=True) - train_dataset = {"inputs": [1], "infos": [1]} +def test_save_train_config_accepts_positional_guide_and_dataset(): + trainer = _DummyTrainer() + train_dataset = {"inputs": [], "infos": []} - result = train_mod.train( - model=param, - train_dataset=train_dataset, - algorithm=NoLoggerTrainer, - optimizer="unused", - guide="unused", - logger="none", - ) + result = trainer.train("guide", train_dataset, batch_size=1) - out = capsys.readouterr().out - assert "does not accept logger" in out - assert result["ok"] is True - assert result["dataset_size"] == 1 - assert isinstance(train_mod.load_logger("none"), BaseLogger) + assert result["guide"] == "guide" + assert result["train_dataset"] is train_dataset + assert result["batch_size"] == 1 + assert trainer._train_last_kwargs == { + "guide": "guide", + "batch_size": 1, + }
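
--
Usage note (illustrative sketch, not part of the patches above): after patch 3,
``save_train_config`` tolerates the positional calling convention that older
wrappers use — ``train(guide, train_dataset, ...)`` — which is exactly what the
final test in patch 4 exercises. Assuming a trainer whose ``train`` method is
decorated as in ``opto/features/priority_search/search_template.py``:

    trainer.train(guide, train_dataset, batch_size=4)
    # The wrapper folds the positional guide/train_dataset into kwargs, drops
    # the dataset entries (train/validate/test) from a copy, and snapshots the
    # remainder before forwarding everything to the wrapped train() as keyword
    # arguments, so afterwards:
    #   trainer._train_last_kwargs == {"guide": guide, "batch_size": 4}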