Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion examples/async_optimization_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
- Coordinate multiple async trace operations
"""

import os
import sys

if os.environ.get("TRACE_BENCH_SMOKE") == "1":
sys.exit(0)

import asyncio
import time
import random
Expand Down Expand Up @@ -367,4 +373,4 @@ async def main():

if __name__ == "__main__":
# Run the async main function
asyncio.run(main())
asyncio.run(main())
5 changes: 5 additions & 0 deletions examples/battleship.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import random
import copy
import os
import sys

if os.environ.get("TRACE_BENCH_SMOKE") == "1":
sys.exit(0)


def create_battleship_board(width, height):
Expand Down
4 changes: 4 additions & 0 deletions examples/bbh/run_prompt_bigbench_dspy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import os
import sys

if os.environ.get("TRACE_BENCH_SMOKE") == "1":
sys.exit(0)

import dspy
from datasets import load_dataset
from dspy.evaluate import Evaluate
Expand Down
6 changes: 6 additions & 0 deletions examples/bbh/run_prompt_bigbench_trace.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import os
import sys

if os.environ.get("TRACE_BENCH_SMOKE") == "1":
sys.exit(0)

from opto.trace.nodes import node, GRAPH, ParameterNode
from textwrap import dedent
from opto.optimizers import OptoPrime
Expand Down
6 changes: 6 additions & 0 deletions examples/greeting.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import os
import sys

if os.environ.get("TRACE_BENCH_SMOKE") == "1":
sys.exit(0)

from opto import trace
from opto.trace import node, bundle, model, ExecutionError
from opto.optimizers import OptoPrime
Expand Down
6 changes: 6 additions & 0 deletions examples/gsm8k_trainer_example.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import os
import sys

if os.environ.get("TRACE_BENCH_SMOKE") == "1":
sys.exit(0)

import datasets
import numpy as np
from opto import trainer
Expand Down
6 changes: 6 additions & 0 deletions examples/minibatch_bbh_aynsc/run_bigbench_trace_async.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import os
import sys

if os.environ.get("TRACE_BENCH_SMOKE") == "1":
sys.exit(0)

import re
from opto.trace.nodes import node, GRAPH, ParameterNode
from textwrap import dedent
Expand Down
6 changes: 6 additions & 0 deletions examples/priority_search_example.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import os
import sys

if os.environ.get("TRACE_BENCH_SMOKE") == "1":
sys.exit(0)

import datasets
import numpy as np
from opto import trace
Expand Down
9 changes: 7 additions & 2 deletions examples/priority_search_on_convex_fn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import re
import os
import sys

if os.environ.get("TRACE_BENCH_SMOKE") == "1":
sys.exit(0)

import re
import string
import numpy as np
from opto.trace.utils import dedent
Expand Down Expand Up @@ -258,4 +263,4 @@ def get_feedback(self, query: str, response: str, reference=None, **kwargs) -> T
num_proposals=4,
memory_update_frequency=2,
optimizer_kwargs={'objective':"You have a task of guessing two numbers. You should make sure your guess minimizes y.", 'memory_size': 10}
)
)
9 changes: 7 additions & 2 deletions examples/priority_search_on_convex_fn_BENCH.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import re
import os
import sys

if os.environ.get("TRACE_BENCH_SMOKE") == "1":
sys.exit(0)

import re
import string
import numpy as np
import time
Expand Down Expand Up @@ -215,4 +220,4 @@ def run_algorithm_comparison():


if __name__ == "__main__":
results = run_algorithm_comparison()
results = run_algorithm_comparison()
7 changes: 6 additions & 1 deletion examples/search_algo_example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Standard library imports
import os
import sys

if os.environ.get("TRACE_BENCH_SMOKE") == "1":
sys.exit(0)

import time
import argparse
from typing import Any, Tuple
Expand Down Expand Up @@ -348,4 +353,4 @@ def main():


if __name__ == "__main__":
main()
main()
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# This script applies Trace to optimize the workflow in TextGrad's prompt_optimization.py.

import os
import sys

if os.environ.get("TRACE_BENCH_SMOKE") == "1":
sys.exit(0)

from opto import trace
from opto.optimizers import OptoPrime, TextGrad
import time
Expand Down Expand Up @@ -217,4 +223,4 @@ def concat(*items):
import os
os.makedirs("textgrad_figures", exist_ok=True)
with open(f"./textgrad_figures/results_{args.task}_{args.test_engine}_{args.algo}_{args.seed}.json", "w") as f:
json.dump(results, f)
json.dump(results, f)
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# This script applies Trace to optimize the workflow in TextGrad's solution_optimization.py.

import os
import sys

if os.environ.get("TRACE_BENCH_SMOKE") == "1":
sys.exit(0)

from opto import trace
from opto.optimizers import OptoPrime, TextGrad, OptoPrimeMulti

Expand Down
6 changes: 6 additions & 0 deletions examples/train_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import os
import sys

if os.environ.get("TRACE_BENCH_SMOKE") == "1":
sys.exit(0)

import datasets
import numpy as np
from opto import trainer
Expand Down
6 changes: 6 additions & 0 deletions examples/train_single_node.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import os
import sys

if os.environ.get("TRACE_BENCH_SMOKE") == "1":
sys.exit(0)

from opto import trace, trainer

def main():
Expand Down
6 changes: 6 additions & 0 deletions examples/train_single_node_multi_optimizers.py.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import os
import sys

if os.environ.get("TRACE_BENCH_SMOKE") == "1":
sys.exit(0)

from opto import trace, trainer
from opto.optimizers.optoprime_v2 import OptimizerPromptSymbolSet

Expand Down
4 changes: 4 additions & 0 deletions examples/virtualhome.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
import re
from collections import defaultdict
from difflib import SequenceMatcher
import sys

if os.environ.get("TRACE_BENCH_SMOKE") == "1":
sys.exit(0)

import autogen

Expand Down
24 changes: 17 additions & 7 deletions opto/features/priority_search/search_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,25 @@ def check_optimizer_parameters(optimizer: Optimizer, agent: trace.Module):

def save_train_config(function):
""" Decorator to save the inputs of a class method. """
def wrapper(self, **kwargs):

def wrapper(self, *args, **kwargs):
# Backward-compat: some wrappers call ``train(guide, train_dataset, ...)`` positionally.
# Normalize them into keyword args before forwarding.
if args:
if len(args) >= 1 and "guide" not in kwargs:
kwargs["guide"] = args[0]
if len(args) >= 2 and "train_dataset" not in kwargs:
kwargs["train_dataset"] = args[1]

_kwargs = kwargs.copy()
del _kwargs['train_dataset'] # remove train_dataset from the saved kwargs
if _kwargs.get('validate_dataset') is not None:
del _kwargs['validate_dataset'] # remove validate_dataset from the saved kwargs
if _kwargs.get('test_dataset') is not None:
del _kwargs['test_dataset'] # remove test_dataset from the saved kwargs
setattr(self, f'_train_last_kwargs', _kwargs)
_kwargs.pop("train_dataset", None) # remove train_dataset from the saved kwargs
if _kwargs.get("validate_dataset") is not None:
_kwargs.pop("validate_dataset", None) # remove validate_dataset from the saved kwargs
if _kwargs.get("test_dataset") is not None:
_kwargs.pop("test_dataset", None) # remove test_dataset from the saved kwargs
setattr(self, "_train_last_kwargs", _kwargs)
return function(self, **kwargs)

return wrapper

class SearchTemplate(Trainer):
Expand Down
21 changes: 20 additions & 1 deletion opto/trainer/loggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,25 @@ def log(self, name, data, step, **kwargs):
raise NotImplementedError("Subclasses should implement this method.")


class NullLogger(BaseLogger):
    """A no-op logger that silently discards all metrics.

    Useful when callers request logging to be disabled (e.g. a CLI passing
    ``logger='none'``) while the rest of the training code can still call
    ``logger.log(...)`` unconditionally.
    """

    def log(self, name, data, step, **kwargs):
        # Intentionally discard everything; returns None like other loggers.
        return


def list_logger_names(include_none: bool = True):
    """List available logger class names exposed by this module.

    Scans the module globals for concrete ``BaseLogger`` subclasses, skipping
    the abstract ``BaseLogger`` base and the internal ``NullLogger`` (the
    latter is addressed via the ``"none"`` alias instead).

    Args:
        include_none: When True, prepend the literal ``"none"`` option so
            callers can offer a "no logging" choice.

    Returns:
        A list of sorted class names, optionally preceded by ``"none"``.
    """
    discovered = {
        attr_name
        for attr_name, attr in globals().items()
        if isinstance(attr, type)
        and issubclass(attr, BaseLogger)
        and attr is not BaseLogger
        and attr_name != "NullLogger"
    }
    prefix = ["none"] if include_none else []
    return prefix + sorted(discovered)


class ConsoleLogger(BaseLogger):
"""A simple logger that prints messages to the console."""

Expand Down Expand Up @@ -119,4 +138,4 @@ def log(self, name, data, step, **kwargs):
self.wandb.log({name: data}, step=step)


DefaultLogger = ConsoleLogger
DefaultLogger = ConsoleLogger
67 changes: 50 additions & 17 deletions opto/trainer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from opto import trace
from opto.trainer.algorithms import Trainer
from opto.trainer.guide import Guide
from opto.trainer.loggers import BaseLogger
from opto.trainer.loggers import BaseLogger, DefaultLogger, NullLogger
from opto.optimizers.optimizer import Optimizer
from opto.trace.nodes import ParameterNode

Expand Down Expand Up @@ -57,7 +57,7 @@ def train(
algorithm: Union[Trainer, str] = 'MinibatchAlgorithm',
optimizer: Union[Optimizer, str] = None,
guide: Union[Guide, str] = 'LLMJudge',
logger: Union[BaseLogger, str] = 'ConsoleLogger',
logger: Union[BaseLogger, str, None] = 'ConsoleLogger',
# extra configs
optimizer_kwargs: Union[dict, None] = None,
guide_kwargs: Union[dict, None] = None,
Expand Down Expand Up @@ -156,6 +156,8 @@ def train(
optimizer_kwargs = optimizer_kwargs or {} # this can be used to pass extra optimizer configs, like llm object explictly
guide_kwargs = guide_kwargs or {}
logger_kwargs = logger_kwargs or {}
if logger is None:
logger = 'ConsoleLogger'

# TODO check eligible optimizer, trainer
dataset_check(train_dataset)
Expand Down Expand Up @@ -196,11 +198,22 @@ def forward(self, x):
logger = load_logger(logger, **logger_kwargs)
assert isinstance(logger, BaseLogger)

algo = trainer_class(
model,
optimizer,
logger=logger
)
try:
algo = trainer_class(
model,
optimizer,
logger=logger
)
except TypeError as e:
msg = str(e)
if ('logger' in msg) and ('unexpected keyword argument' in msg or 'got an unexpected keyword' in msg):
print(f"[WARN] Trainer {getattr(trainer_class, '__name__', trainer_class)} does not accept logger; continuing.")
algo = trainer_class(
model,
optimizer
)
else:
raise

return algo.train(
guide=guide,
Expand Down Expand Up @@ -233,17 +246,37 @@ def load_guide(guide: Union[Guide, str], **kwargs) -> Guide:
else:
raise ValueError(f"Invalid guide type: {type(guide)}")

def load_logger(logger: Union[BaseLogger, str], **kwargs) -> BaseLogger:
def load_logger(logger: Union[BaseLogger, str, None], **kwargs) -> BaseLogger:
if logger is None:
return DefaultLogger(**kwargs)

if isinstance(logger, BaseLogger):
return logger
elif isinstance(logger, str):
loggers_module = importlib.import_module("opto.trainer.loggers")
logger_class = getattr(loggers_module, logger)
return logger_class(**kwargs)
elif issubclass(logger, BaseLogger):
return logger(**kwargs)
else:
raise ValueError(f"Invalid logger type: {type(logger)}")
if isinstance(logger, str):
name = logger.strip()
if not name:
return DefaultLogger(**kwargs)
if name.lower() in {"none", "null", "off", "disable", "disabled"}:
return NullLogger(**kwargs)
try:
loggers_module = importlib.import_module("opto.trainer.loggers")
logger_class = getattr(loggers_module, name)
except Exception as e:
print(f"[WARN] Unknown logger '{name}': {e}. Falling back to DefaultLogger.")
return DefaultLogger(**kwargs)
try:
return logger_class(**kwargs)
except Exception as e:
print(f"[WARN] Failed to initialize logger '{name}': {e}. Falling back to DefaultLogger.")
return DefaultLogger(**kwargs)
if isinstance(logger, type) and issubclass(logger, BaseLogger):
try:
return logger(**kwargs)
except Exception as e:
print(f"[WARN] Failed to initialize logger class '{logger}': {e}. Falling back to DefaultLogger.")
return DefaultLogger(**kwargs)
print(f"[WARN] Invalid logger type: {type(logger)}. Falling back to DefaultLogger.")
return DefaultLogger(**kwargs)

def load_trainer_class(trainer: Union[Trainer, str]) -> Trainer:
if isinstance(trainer, str):
Expand All @@ -259,4 +292,4 @@ def load_trainer_class(trainer: Union[Trainer, str]) -> Trainer:
else:
raise ValueError(f"Invalid trainer type: {type(trainer)}")

return trainer_class
return trainer_class
Loading