diff --git a/optillm/__init__.py b/optillm/__init__.py
index c17bd9b..6aa0cea 100644
--- a/optillm/__init__.py
+++ b/optillm/__init__.py
@@ -1,5 +1,5 @@
 # Version information
-__version__ = "0.3.8"
+__version__ = "0.3.9"
 
 # Import from server module
 from .server import (
diff --git a/optillm/bon.py b/optillm/bon.py
index c23b643..e22ee18 100644
--- a/optillm/bon.py
+++ b/optillm/bon.py
@@ -4,9 +4,14 @@
 
 logger = logging.getLogger(__name__)
 
-def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: str, n: int = 3, request_id: str = None) -> str:
+def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: str, n: int = 3, request_config: dict = None, request_id: str = None) -> str:
     bon_completion_tokens = 0
 
+    # Extract max_tokens from request_config with default
+    max_tokens = 4096
+    if request_config:
+        max_tokens = request_config.get('max_tokens', max_tokens)
+
     messages = [{"role": "system", "content": system_prompt},
                 {"role": "user", "content": initial_query}]
 
@@ -17,7 +22,7 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
     provider_request = {
         "model": model,
         "messages": messages,
-        "max_tokens": 4096,
+        "max_tokens": max_tokens,
         "n": n,
         "temperature": 1
    }
@@ -50,7 +55,7 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
     provider_request = {
         "model": model,
         "messages": messages,
-        "max_tokens": 4096,
+        "max_tokens": max_tokens,
         "temperature": 1
     }
     response = client.chat.completions.create(**provider_request)
diff --git a/optillm/leap.py b/optillm/leap.py
index 5f212d2..095d8f8 100644
--- a/optillm/leap.py
+++ b/optillm/leap.py
@@ -10,7 +10,7 @@
 logger = logging.getLogger(__name__)
 
 class LEAP:
-    def __init__(self, system_prompt: str, client, model: str, request_id: str = None):
+    def __init__(self, system_prompt: str, client, model: str, request_config: dict = None, request_id: str = None):
         self.system_prompt = system_prompt
         self.client = client
         self.model = model
@@ -19,6 +19,11 @@ def __init__(self, system_prompt: str, client, model: str, request_id: str = Non
         self.high_level_principles = []
         self.leap_completion_tokens = 0
 
+        # Extract max_tokens from request_config with default
+        self.max_tokens = 4096
+        if request_config:
+            self.max_tokens = request_config.get('max_tokens', self.max_tokens)
+
     def extract_output(self, text: str) -> str:
         match = re.search(r'<output>(.*?)(?:</output>|$)', text, re.DOTALL)
         return match.group(1).strip() if match else ""
@@ -29,7 +34,7 @@ def extract_examples_from_query(self, initial_query: str) -> List[Tuple[str, str
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": f"""
@@ -83,7 +88,7 @@ def generate_mistakes(self, examples: List[Tuple[str, str]]) -> List[Tuple[str,
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": f"""
@@ -116,7 +121,7 @@ def generate_low_level_principles(self, mistakes: List[Tuple[str, str, str, str]
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": f"""
@@ -152,7 +157,7 @@ def generate_high_level_principles(self) -> List[str]:
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": f"""
@@ -185,7 +190,7 @@ def apply_principles(self, query: str) -> str:
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": f"""
@@ -220,6 +225,6 @@ def solve(self, initial_query: str) -> str:
         return self.apply_principles(initial_query)
 
 
-def leap(system_prompt: str, initial_query: str, client, model: str, request_id: str = None) -> str:
-    leap_solver = LEAP(system_prompt, client, model, request_id)
+def leap(system_prompt: str, initial_query: str, client, model: str, request_config: dict = None, request_id: str = None) -> str:
+    leap_solver = LEAP(system_prompt, client, model, request_config=request_config, request_id=request_id)
     return leap_solver.solve(initial_query), leap_solver.leap_completion_tokens
\ No newline at end of file
diff --git a/optillm/mcts.py b/optillm/mcts.py
index 8517749..2baf35b 100644
--- a/optillm/mcts.py
+++ b/optillm/mcts.py
@@ -26,7 +26,7 @@ def __init__(self, state: DialogueState, parent=None):
         self.value = 0
 
 class MCTS:
-    def __init__(self, simulation_depth, exploration_weight, client, model, request_id=None):
+    def __init__(self, simulation_depth, exploration_weight, client, model, request_config=None, request_id=None):
         self.simulation_depth = simulation_depth
         self.exploration_weight = exploration_weight
         self.root = None
@@ -37,6 +37,11 @@ def __init__(self, simulation_depth, exploration_weight, client, model, request_
         self.completion_tokens = 0
         self.request_id = request_id
 
+        # Extract max_tokens from request_config with default
+        self.max_tokens = 4096
+        if request_config:
+            self.max_tokens = request_config.get('max_tokens', self.max_tokens)
+
     def select(self, node: MCTSNode) -> MCTSNode:
         logger.debug(f"Selecting node. Current node visits: {node.visits}, value: {node.value}")
         if not node.children:
@@ -117,7 +122,7 @@ def generate_actions(self, state: DialogueState) -> List[str]:
         provider_request = {
             "model": self.model,
             "messages": messages,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "n": n,
             "temperature": 1
         }
@@ -151,7 +156,7 @@ def apply_action(self, state: DialogueState, action: str) -> DialogueState:
         provider_request = {
             "model": self.model,
             "messages": messages,
-            "max_tokens": 1024,
+            "max_tokens": min(self.max_tokens, 1024),
             "n": 1,
             "temperature": 1
         }
@@ -220,11 +225,11 @@ def evaluate_state(self, state: DialogueState) -> float:
             logger.warning("Failed to parse evaluation score. Using default value 0.5")
             return 0.5  # Default to a neutral score if parsing fails
 
-def chat_with_mcts(system_prompt: str, initial_query: str, client, model: str, num_simulations: int = 2, exploration_weight: float = 0.2,
-                   simulation_depth: int = 1, request_id: str = None) -> str:
+def chat_with_mcts(system_prompt: str, initial_query: str, client, model: str, num_simulations: int = 2, exploration_weight: float = 0.2,
+                   simulation_depth: int = 1, request_config: dict = None, request_id: str = None) -> str:
     logger.info("Starting chat with MCTS")
     logger.info(f"Parameters: num_simulations={num_simulations}, exploration_weight={exploration_weight}, simulation_depth={simulation_depth}")
-    mcts = MCTS(simulation_depth=simulation_depth, exploration_weight=exploration_weight, client=client, model=model, request_id=request_id)
+    mcts = MCTS(simulation_depth=simulation_depth, exploration_weight=exploration_weight, client=client, model=model, request_config=request_config, request_id=request_id)
     initial_state = DialogueState(system_prompt, [], initial_query)
     logger.info(f"Initial query: {initial_query}")
     final_state = mcts.search(initial_state, num_simulations)
diff --git a/optillm/moa.py b/optillm/moa.py
index 86371c1..6f5f9ad 100644
--- a/optillm/moa.py
+++ b/optillm/moa.py
@@ -4,9 +4,15 @@
 
 logger = logging.getLogger(__name__)
 
-def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str, request_id: str = None) -> str:
+def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str, request_config: dict = None, request_id: str = None) -> str:
     logger.info(f"Starting mixture_of_agents function with model: {model}")
     moa_completion_tokens = 0
+
+    # Extract max_tokens from request_config with default
+    max_tokens = 4096
+    if request_config:
+        max_tokens = request_config.get('max_tokens', max_tokens)
+
     completions = []
 
     logger.debug(f"Generating initial completions for query: {initial_query}")
@@ -19,7 +25,7 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
             {"role": "system", "content": system_prompt},
             {"role": "user", "content": initial_query}
         ],
-        "max_tokens": 4096,
+        "max_tokens": max_tokens,
         "n": 3,
         "temperature": 1
     }
@@ -59,7 +65,7 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
             {"role": "system", "content": system_prompt},
             {"role": "user", "content": initial_query}
         ],
-        "max_tokens": 4096,
+        "max_tokens": max_tokens,
         "temperature": 1
     }
 
@@ -182,14 +188,14 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
     """
 
     logger.debug("Generating final response")
-    
+
     provider_request = {
         "model": model,
         "messages": [
             {"role": "system", "content": system_prompt},
             {"role": "user", "content": final_prompt}
         ],
-        "max_tokens": 8192,
+        "max_tokens": max_tokens,
         "n": 1,
         "temperature": 0.1
     }
diff --git a/optillm/plansearch.py b/optillm/plansearch.py
index f91c9a8..85e91a7 100644
--- a/optillm/plansearch.py
+++ b/optillm/plansearch.py
@@ -6,13 +6,18 @@
 logger = logging.getLogger(__name__)
 
 class PlanSearch:
-    def __init__(self, system_prompt: str, client, model: str, request_id: str = None):
+    def __init__(self, system_prompt: str, client, model: str, request_config: dict = None, request_id: str = None):
         self.system_prompt = system_prompt
         self.client = client
         self.model = model
         self.request_id = request_id
         self.plansearch_completion_tokens = 0
 
+        # Extract max_tokens from request_config with default
+        self.max_tokens = 4096
+        if request_config:
+            self.max_tokens = request_config.get('max_tokens', self.max_tokens)
+
     def generate_observations(self, problem: str, num_observations: int = 3) -> List[str]:
         prompt = f"""You are an expert Python programmer. You will be given a competitive programming question
 (problem specification). You will return several useful, non-obvious, and correct observations
@@ -27,7 +32,7 @@ def generate_observations(self, problem: str, num_observations: int = 3) -> List
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": prompt}
@@ -71,7 +76,7 @@ def generate_derived_observations(self, problem: str, observations: List[str], n
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": prompt}
@@ -113,7 +118,7 @@ def generate_solution(self, problem: str, observations: List[str]) -> str:
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": prompt}
@@ -155,7 +160,7 @@ def implement_solution(self, problem: str, solution: str) -> str:
         # Prepare request for logging
         provider_request = {
             "model": self.model,
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "messages": [
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": prompt}
@@ -204,6 +209,6 @@ def solve_multiple(self, problem: str, n: int, num_initial_observations: int = 3
             solutions.append(python_implementation)
         return solutions
 
-def plansearch(system_prompt: str, initial_query: str, client, model: str, n: int = 1, request_id: str = None) -> List[str]:
-    planner = PlanSearch(system_prompt, client, model, request_id)
+def plansearch(system_prompt: str, initial_query: str, client, model: str, n: int = 1, request_config: dict = None, request_id: str = None) -> List[str]:
+    planner = PlanSearch(system_prompt, client, model, request_config=request_config, request_id=request_id)
     return planner.solve_multiple(initial_query, n), planner.plansearch_completion_tokens
diff --git a/optillm/pvg.py b/optillm/pvg.py
index 44c2b27..bb327d3 100644
--- a/optillm/pvg.py
+++ b/optillm/pvg.py
@@ -8,7 +8,7 @@
 
 pvg_completion_tokens = 0
 
-def generate_solutions(client, system_prompt: str, query: str, model: str, num_solutions: int, is_sneaky: bool = False, temperature: float = 0.7, request_id: str = None) -> List[str]:
+def generate_solutions(client, system_prompt: str, query: str, model: str, num_solutions: int, is_sneaky: bool = False, temperature: float = 0.7, max_tokens: int = 4096, request_id: str = None) -> List[str]:
     global pvg_completion_tokens
     role = "sneaky" if is_sneaky else "helpful"
     logger.info(f"Generating {num_solutions} {role} solutions")
@@ -36,7 +36,7 @@ def generate_solutions(client, system_prompt: str, query: str, model: str, num_s
         "model": model,
         "messages": messages,
         "n": num_solutions,
-        "max_tokens": 4096,
+        "max_tokens": max_tokens,
         "temperature": temperature,
     }
     response = client.chat.completions.create(**provider_request)
@@ -151,10 +151,15 @@ def extract_answer(final_state: str) -> Tuple[str, float]:
     logger.warning("No answer found in the state.")
     return "", 0.0
 
-def inference_time_pv_game(system_prompt: str, initial_query: str, client, model: str, num_rounds: int = 2, num_solutions: int = 3, request_id: str = None) -> str:
+def inference_time_pv_game(system_prompt: str, initial_query: str, client, model: str, num_rounds: int = 2, num_solutions: int = 3, request_config: dict = None, request_id: str = None) -> str:
     global pvg_completion_tokens
     logger.info(f"Starting inference-time PV game with {num_rounds} rounds and {num_solutions} solutions per round")
-    
+
+    # Extract max_tokens from request_config with default
+    max_tokens = 4096
+    if request_config:
+        max_tokens = request_config.get('max_tokens', max_tokens)
+
     best_solution = ""
     best_score = -1
 
@@ -163,8 +168,8 @@ def inference_time_pv_game(system_prompt: str, initial_query: str, client, model
 
         temperature = max(0.2, 0.7 - (round * 0.1))
 
-        helpful_solutions = generate_solutions(client, system_prompt, initial_query, model, num_solutions, temperature=temperature, request_id=request_id)
-        sneaky_solutions = generate_solutions(client, system_prompt, initial_query, model, num_solutions, is_sneaky=True, temperature=temperature, request_id=request_id)
+        helpful_solutions = generate_solutions(client, system_prompt, initial_query, model, num_solutions, temperature=temperature, max_tokens=max_tokens, request_id=request_id)
+        sneaky_solutions = generate_solutions(client, system_prompt, initial_query, model, num_solutions, is_sneaky=True, temperature=temperature, max_tokens=max_tokens, request_id=request_id)
 
         all_solutions = helpful_solutions + sneaky_solutions
         scores = verify_solutions(client, system_prompt, initial_query, all_solutions, model, request_id=request_id)
@@ -198,7 +203,7 @@ def inference_time_pv_game(system_prompt: str, initial_query: str, client, model
     provider_request = {
         "model": model,
         "messages": messages,
-        "max_tokens": 1024,
+        "max_tokens": min(max_tokens, 1024),
         "temperature": 0.5,
     }
     response = client.chat.completions.create(**provider_request)
diff --git a/optillm/reread.py b/optillm/reread.py
index 32706a3..1a03a01 100644
--- a/optillm/reread.py
+++ b/optillm/reread.py
@@ -4,22 +4,28 @@
 
 logger = logging.getLogger(__name__)
 
-def re2_approach(system_prompt, initial_query, client, model, n=1, request_id: str = None):
+def re2_approach(system_prompt, initial_query, client, model, n=1, request_config: dict = None, request_id: str = None):
     """
     Implement the RE2 (Re-Reading) approach for improved reasoning in LLMs.
-    
+
     Args:
     system_prompt (str): The system prompt to be used.
     initial_query (str): The initial user query.
     client: The OpenAI client object.
     model (str): The name of the model to use.
     n (int): Number of completions to generate.
-    
+    request_config (dict): Optional configuration including max_tokens.
+
     Returns:
     str or list: The generated response(s) from the model.
     """
     logger.info("Using RE2 approach for query processing")
     re2_completion_tokens = 0
+
+    # Extract max_tokens from request_config if provided
+    max_tokens = None
+    if request_config:
+        max_tokens = request_config.get('max_tokens')
 
     # Construct the RE2 prompt
     re2_prompt = f"{initial_query}\nRead the question again: {initial_query}"
@@ -35,6 +41,8 @@ def re2_approach(system_prompt, initial_query, client, model, n=1, request_id: s
         "messages": messages,
         "n": n
     }
+    if max_tokens is not None:
+        provider_request["max_tokens"] = max_tokens
 
     response = client.chat.completions.create(**provider_request)
 
     # Log provider call
diff --git a/optillm/rstar.py b/optillm/rstar.py
index 5cbeda6..520641b 100644
--- a/optillm/rstar.py
+++ b/optillm/rstar.py
@@ -23,17 +23,23 @@ def __init__(self, state: str, action: str, parent: 'Node' = None):
         self.value = 0.0
 
 class RStar:
-    def __init__(self, system: str, client, model: str, max_depth: int = 3, num_rollouts: int = 5, c: float = 1.4, request_id: str = None):
+    def __init__(self, system: str, client, model: str, max_depth: int = 3, num_rollouts: int = 5, c: float = 1.4, request_config: dict = None, request_id: str = None):
         self.client = client
         self.model_name = model
         self.max_depth = max_depth
         self.num_rollouts = num_rollouts
         self.c = c
         self.actions = ["A1", "A2", "A3", "A4", "A5"]
-        self.original_question = None 
+        self.original_question = None
         self.system = system
         self.rstar_completion_tokens = 0
         self.request_id = request_id
+
+        # Extract max_tokens from request_config with default
+        self.max_tokens = 4096
+        if request_config:
+            self.max_tokens = request_config.get('max_tokens', self.max_tokens)
+
         logger.debug(f"Initialized RStar with model: {model}, max_depth: {max_depth}, num_rollouts: {num_rollouts}")
 
     async def generate_response_async(self, prompt: str) -> str:
@@ -102,7 +108,7 @@ def generate_response(self, prompt: str) -> str:
                 {"role": "system", "content": "You are a helpful assistant focused on solving mathematical problems. Stick to the given question and avoid introducing new scenarios."},
                 {"role": "user", "content": prompt}
             ],
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "temperature": 0.2
         }
         response = self.client.chat.completions.create(**provider_request)
diff --git a/optillm/rto.py b/optillm/rto.py
index 59ca88d..aa517dc 100644
--- a/optillm/rto.py
+++ b/optillm/rto.py
@@ -15,8 +15,14 @@ def extract_code_from_prompt(text):
     logger.warning("Could not extract code from prompt. Returning original text.")
     return text
 
-def round_trip_optimization(system_prompt: str, initial_query: str, client, model: str, request_id: str = None) -> str:
+def round_trip_optimization(system_prompt: str, initial_query: str, client, model: str, request_config: dict = None, request_id: str = None) -> str:
     rto_completion_tokens = 0
+
+    # Extract max_tokens from request_config with default
+    max_tokens = 4096
+    if request_config:
+        max_tokens = request_config.get('max_tokens', max_tokens)
+
     messages = [{"role": "system", "content": system_prompt},
                 {"role": "user", "content": initial_query}]
 
@@ -24,7 +30,7 @@ def round_trip_optimization(system_prompt: str, initial_query: str, client, mode
     provider_request = {
         "model": model,
         "messages": messages,
-        "max_tokens": 4096,
+        "max_tokens": max_tokens,
         "n": 1,
         "temperature": 0.1
     }
@@ -64,7 +70,7 @@ def round_trip_optimization(system_prompt: str, initial_query: str, client, mode
     provider_request = {
         "model": model,
         "messages": messages,
-        "max_tokens": 4096,
+        "max_tokens": max_tokens,
         "n": 1,
         "temperature": 0.1
     }
@@ -89,7 +95,7 @@ def round_trip_optimization(system_prompt: str, initial_query: str, client, mode
     provider_request = {
         "model": model,
         "messages": messages,
-        "max_tokens": 4096,
+        "max_tokens": max_tokens,
         "n": 1,
         "temperature": 0.1
     }
diff --git a/optillm/self_consistency.py b/optillm/self_consistency.py
index 441599b..3b63213 100644
--- a/optillm/self_consistency.py
+++ b/optillm/self_consistency.py
@@ -7,7 +7,7 @@
 logger = logging.getLogger(__name__)
 
 class AdvancedSelfConsistency:
-    def __init__(self, client, model: str, num_samples: int = 5, similarity_threshold: float = 0.8, request_id: str = None):
+    def __init__(self, client, model: str, num_samples: int = 5, similarity_threshold: float = 0.8, request_config: dict = None, request_id: str = None):
         self.client = client
         self.model = model
         self.num_samples = num_samples
@@ -15,6 +15,11 @@ def __init__(self, client, model: str, num_samples: int = 5, similarity_thresho
         self.self_consistency_completion_tokens = 0
         self.request_id = request_id
 
+        # Extract max_tokens from request_config with default
+        self.max_tokens = 4096
+        if request_config:
+            self.max_tokens = request_config.get('max_tokens', self.max_tokens)
+
     def generate_responses(self, system_prompt: str, user_prompt: str) -> List[str]:
         responses = []
         for _ in range(self.num_samples):
@@ -25,7 +30,7 @@ def generate_responses(self, system_prompt: str, user_prompt: str) -> List[str]:
                     {"role": "user", "content": user_prompt}
                 ],
                 "temperature": 1,
-                "max_tokens": 4096
+                "max_tokens": self.max_tokens
             }
 
             response = self.client.chat.completions.create(**provider_request)
@@ -83,8 +88,8 @@ def evaluate(self, system_prompt: str, user_prompt: str) -> Dict[str, any]:
         "aggregated_result": aggregated_result
     }
 
-def advanced_self_consistency_approach(system_prompt: str, initial_query: str, client, model: str, request_id: str = None) -> str:
-    self_consistency = AdvancedSelfConsistency(client, model, request_id=request_id)
+def advanced_self_consistency_approach(system_prompt: str, initial_query: str, client, model: str, request_config: dict = None, request_id: str = None) -> str:
+    self_consistency = AdvancedSelfConsistency(client, model, request_config=request_config, request_id=request_id)
     result = self_consistency.evaluate(system_prompt, initial_query)
 
     logger.info("Advanced Self-Consistency Results:")
diff --git a/optillm/server.py b/optillm/server.py
index e5975df..08f59f3 100644
--- a/optillm/server.py
+++ b/optillm/server.py
@@ -414,33 +414,33 @@ def execute_single_approach(approach, system_prompt, initial_query, client, mode
         return response, 0
     elif approach == 'mcts':
         return chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
-                              server_config['mcts_exploration'], server_config['mcts_depth'], request_id)
+                              server_config['mcts_exploration'], server_config['mcts_depth'], request_config, request_id)
     elif approach == 'bon':
-        return best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'], request_id)
+        return best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'], request_config, request_id)
     elif approach == 'moa':
-        return mixture_of_agents(system_prompt, initial_query, client, model, request_id)
+        return mixture_of_agents(system_prompt, initial_query, client, model, request_config, request_id)
     elif approach == 'rto':
-        return round_trip_optimization(system_prompt, initial_query, client, model, request_id)
+        return round_trip_optimization(system_prompt, initial_query, client, model, request_config, request_id)
     elif approach == 'z3':
-        z3_solver = Z3SymPySolverSystem(system_prompt, client, model, request_id=request_id)
+        z3_solver = Z3SymPySolverSystem(system_prompt, client, model, request_config=request_config, request_id=request_id)
         return z3_solver.process_query(initial_query)
     elif approach == "self_consistency":
-        return advanced_self_consistency_approach(system_prompt, initial_query, client, model, request_id)
+        return advanced_self_consistency_approach(system_prompt, initial_query, client, model, request_config, request_id)
     elif approach == "pvg":
-        return inference_time_pv_game(system_prompt, initial_query, client, model, request_id)
+        return inference_time_pv_game(system_prompt, initial_query, client, model, request_config=request_config, request_id=request_id)
     elif approach == "rstar":
         rstar = RStar(system_prompt, client, model,
                       max_depth=server_config['rstar_max_depth'], num_rollouts=server_config['rstar_num_rollouts'],
-                      c=server_config['rstar_c'], request_id=request_id)
+                      c=server_config['rstar_c'], request_config=request_config, request_id=request_id)
         return rstar.solve(initial_query)
     elif approach == "cot_reflection":
         return cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'], request_config=request_config, request_id=request_id)
     elif approach == 'plansearch':
-        return plansearch(system_prompt, initial_query, client, model, n=server_config['n'], request_id=request_id)
+        return plansearch(system_prompt, initial_query, client, model, n=server_config['n'], request_config=request_config, request_id=request_id)
     elif approach == 'leap':
-        return leap(system_prompt, initial_query, client, model, request_id)
+        return leap(system_prompt, initial_query, client, model, request_config, request_id)
     elif approach == 're2':
-        return re2_approach(system_prompt, initial_query, client, model, n=server_config['n'], request_id=request_id)
+        return re2_approach(system_prompt, initial_query, client, model, n=server_config['n'], request_config=request_config, request_id=request_id)
     elif approach == 'cepo':
         return cepo(system_prompt, initial_query, client, model, cepo_config, request_id)
     elif approach == 'mars':
diff --git a/optillm/z3_solver.py b/optillm/z3_solver.py
index dcc83d1..7e70876 100644
--- a/optillm/z3_solver.py
+++ b/optillm/z3_solver.py
@@ -133,13 +133,19 @@ def Rational(numerator, denominator=1):
         return ("success", output_buffer.getvalue())
 
 class Z3SymPySolverSystem:
-    def __init__(self, system_prompt: str, client, model: str, timeout: int = 30, request_id: str = None):
+    def __init__(self, system_prompt: str, client, model: str, timeout: int = 30, request_config: dict = None, request_id: str = None):
         self.system_prompt = system_prompt
         self.model = model
         self.client = client
         self.timeout = timeout
         self.solver_completion_tokens = 0
         self.request_id = request_id
+
+        # Extract max_tokens from request_config with default
+        self.max_tokens = 4096
+        if request_config:
+            self.max_tokens = request_config.get('max_tokens', self.max_tokens)
+
         logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
     def process_query(self, query: str) -> str:
@@ -221,7 +227,7 @@ def generate_response(self, query: str, analysis: str, solver_result: Dict[str,
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": response_prompt}
             ],
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "n": 1,
             "temperature": 0.1
         }
@@ -242,7 +248,7 @@ def standard_llm_inference(self, query: str) -> str:
                 {"role": "system", "content": self.system_prompt},
                 {"role": "user", "content": query}
             ],
-            "max_tokens": 4096,
+            "max_tokens": self.max_tokens,
             "n": 1,
             "temperature": 0.1
         }
diff --git a/pyproject.toml b/pyproject.toml
index ece589c..f676fce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "optillm"
-version = "0.3.8"
+version = "0.3.9"
 description = "An optimizing inference proxy for LLMs."
 readme = "README.md"
 license = "Apache-2.0"
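Usage sketch (not part of the diff): with these changes, a client-supplied max_tokens on a normal OpenAI-compatible request is forwarded by server.py as request_config and caps the provider calls made by each approach; most approaches fall back to 4096, while re2 only sets max_tokens when the caller provides one. The proxy URL, port, and approach-prefixed model name below are illustrative assumptions, not values defined in this changeset.

    from openai import OpenAI

    # Point the OpenAI SDK at a locally running optillm proxy; the port and the
    # approach-prefixed model name are assumptions for illustration only.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="optillm")

    response = client.chat.completions.create(
        model="moa-gpt-4o-mini",  # mixture_of_agents now reads max_tokens via request_config
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Summarize the problem in two sentences."},
        ],
        max_tokens=1024,  # forwarded to the approach as request_config['max_tokens']
    )
    print(response.choices[0].message.content)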