From 0016daaddd890b5832efd797ceba5224a8eddcc5 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Tue, 8 Jul 2025 20:27:24 +0800 Subject: [PATCH 1/6] Add majority voting plugin for candidate selection Introduces a plugin that generates multiple candidate solutions using the OpenAI API and selects the most frequent answer via majority voting. Includes answer extraction, normalization, and a summary of the voting process. Useful for tasks with discrete answers such as math, coding, and multiple choice problems. --- optillm/plugins/majority_voting_plugin.py | 271 ++++++++++++++++++++++ 1 file changed, 271 insertions(+) create mode 100644 optillm/plugins/majority_voting_plugin.py diff --git a/optillm/plugins/majority_voting_plugin.py b/optillm/plugins/majority_voting_plugin.py new file mode 100644 index 00000000..45c7ab68 --- /dev/null +++ b/optillm/plugins/majority_voting_plugin.py @@ -0,0 +1,271 @@ +""" +Majority Voting Plugin for OptILLM + +This plugin implements a majority voting approach where k candidate solutions +are generated and the most frequent answer is selected. This is particularly +effective for problems with discrete answers (math, coding, multiple choice). + +The plugin uses the OpenAI API's n parameter to generate multiple responses +efficiently in a single API call. +""" + +import re +import logging +from typing import Tuple, Dict, Any, List, Optional +from collections import Counter +import json + +logger = logging.getLogger(__name__) + +# Plugin identifier +SLUG = "majority_voting" + +# Default number of candidates to generate +DEFAULT_K = 6 + +# Default temperature for candidate generation +DEFAULT_TEMPERATURE = 0.6 + +def extract_answer(text: str) -> Optional[str]: + """ + Extract the answer from a response text. + + This function looks for common answer patterns in the response: + 1. Text after "Answer:" or "Final Answer:" + 2. Text within \\boxed{} (LaTeX format) + 3. Numbers at the end of the response + 4. 
The last line if it's short (likely the answer) + + Args: + text: The response text to extract answer from + + Returns: + The extracted answer or None if no clear answer found + """ + # Remove any trailing whitespace + text = text.strip() + + # Pattern 1: Look for "Answer:" or "Final Answer:" patterns + answer_patterns = [ + r'(?:final\s+)?answer\s*[:=]\s*(.+?)(?:\n|$)', + r'(?:the\s+)?(?:final\s+)?answer\s+is\s*[:=]?\s*(.+?)(?:\n|$)', + r'(?:therefore|thus|so)\s*,?\s*(.+?)(?:\n|$)' + ] + + for pattern in answer_patterns: + match = re.search(pattern, text, re.IGNORECASE) + if match: + answer = match.group(1).strip() + # Clean up the answer + answer = answer.rstrip('.,;') + if answer: + logger.debug(f"Extracted answer using pattern: {answer}") + return answer + + # Pattern 2: Look for LaTeX boxed format + boxed_match = re.search(r'\\boxed\{([^}]+)\}', text) + if boxed_match: + answer = boxed_match.group(1).strip() + logger.debug(f"Extracted boxed answer: {answer}") + return answer + + # Pattern 3: Look for standalone numbers (useful for math problems) + # Check the last few lines for a number + lines = text.split('\n') + for line in reversed(lines[-3:]): # Check last 3 lines + line = line.strip() + # Match numbers (including decimals, fractions, negative numbers) + number_match = re.match(r'^-?\d+\.?\d*$|^-?\d+/\d+$', line) + if number_match: + logger.debug(f"Extracted number answer: {line}") + return line + + # Pattern 4: If the last line is short (< 50 chars), it might be the answer + if lines: + last_line = lines[-1].strip() + if last_line and len(last_line) < 50 and not last_line.endswith(':'): + logger.debug(f"Using last line as answer: {last_line}") + return last_line + + # Pattern 5: For multiple choice, look for single letter answers + mc_match = re.search(r'\b([A-E])\b(?:\s*\))?$', text) + if mc_match: + answer = mc_match.group(1) + logger.debug(f"Extracted multiple choice answer: {answer}") + return answer + + logger.warning("Could not extract a clear answer from the response") + return None + +def normalize_answer(answer: str) -> str: + """ + Normalize an answer for comparison. + + This helps ensure that equivalent answers are treated as the same: + - Converts to lowercase + - Removes extra whitespace + - Removes quotes + - Normalizes number formats + + Args: + answer: The answer to normalize + + Returns: + The normalized answer + """ + # Convert to lowercase + answer = answer.lower().strip() + + # Remove quotes + answer = answer.strip('"\'') + + # Normalize whitespace + answer = ' '.join(answer.split()) + + # Try to normalize numbers + try: + # Check if it's a float + if '.' in answer: + num = float(answer) + # Format to remove trailing zeros + answer = f"{num:g}" + else: + # Try integer + num = int(answer) + answer = str(num) + except ValueError: + # Not a number, keep as is + pass + + # Handle yes/no variations + if answer in ['yes', 'yeah', 'yep', 'true', 'correct']: + answer = 'yes' + elif answer in ['no', 'nope', 'false', 'incorrect']: + answer = 'no' + + return answer + +def run( + system_prompt: str, + initial_query: str, + client, + model: str, + request_config: Dict[str, Any] = None +) -> Tuple[str, int]: + """ + Main entry point for the majority voting plugin. + + Generates k candidate solutions and returns the most frequent answer. 
+ + Args: + system_prompt: System prompt for the model + initial_query: User's query + client: OpenAI-compatible client instance + model: Model identifier + request_config: Additional configuration parameters + + Returns: + Tuple of (response_text, completion_tokens_used) + """ + logger.info("Starting majority voting process") + + # Extract parameters from request_config + k = DEFAULT_K + temperature = DEFAULT_TEMPERATURE + + if request_config: + k = request_config.get('k', DEFAULT_K) + # Allow overriding temperature if needed + temperature = request_config.get('temperature', DEFAULT_TEMPERATURE) + # Respect max_tokens if provided + max_tokens = request_config.get('max_tokens', 4096) + else: + max_tokens = 4096 + + logger.info(f"Generating {k} candidates with temperature={temperature}") + + # Prepare messages + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": initial_query} + ] + + try: + # Generate k candidates in a single API call using n parameter + response = client.chat.completions.create( + model=model, + messages=messages, + n=k, + temperature=temperature, + max_tokens=max_tokens + ) + + # Extract all candidate responses + candidates = [choice.message.content for choice in response.choices] + total_tokens = response.usage.completion_tokens + + logger.info(f"Generated {len(candidates)} candidates. Tokens used: {total_tokens}") + + # Extract answers from each candidate + answers = [] + answer_to_response = {} # Map normalized answers to full responses + + for i, candidate in enumerate(candidates): + answer = extract_answer(candidate) + if answer: + normalized = normalize_answer(answer) + answers.append(normalized) + # Keep the first full response for each unique answer + if normalized not in answer_to_response: + answer_to_response[normalized] = candidate + logger.debug(f"Candidate {i+1} answer: {answer} (normalized: {normalized})") + else: + logger.warning(f"Could not extract answer from candidate {i+1}") + + if not answers: + logger.warning("No answers could be extracted from any candidate") + # Return the first candidate as fallback + return candidates[0] if candidates else "Error: No candidates generated", total_tokens + + # Count answer frequencies + answer_counts = Counter(answers) + logger.info(f"Answer distribution: {dict(answer_counts)}") + + # Get the most common answer + most_common_answer, count = answer_counts.most_common(1)[0] + confidence = count / len(answers) + + logger.info(f"Most common answer: '{most_common_answer}' with {count}/{len(answers)} votes ({confidence:.1%} confidence)") + + # Get the full response corresponding to the most common answer + winning_response = answer_to_response.get(most_common_answer, candidates[0]) + + # Add voting summary to the response + voting_summary = f"\n\n**Majority Voting Result**:\n" + voting_summary += f"- Generated {k} candidates\n" + voting_summary += f"- Most common answer: {most_common_answer}\n" + voting_summary += f"- Votes: {count}/{len(answers)} ({confidence:.1%} confidence)\n" + + if len(answer_counts) > 1: + voting_summary += f"- Other answers: " + other_answers = [f"{ans} ({cnt} votes)" for ans, cnt in answer_counts.items() if ans != most_common_answer] + voting_summary += ", ".join(other_answers) + + # Return the full response from the winning answer with voting summary + final_response = winning_response + voting_summary + + return final_response, total_tokens + + except Exception as e: + logger.error(f"Error in majority voting: {str(e)}") + # Fall back to single response + 
logger.info("Falling back to single response generation") + + response = client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens + ) + + return response.choices[0].message.content, response.usage.completion_tokens \ No newline at end of file From e4d09250d61390f0766a73943c11fcfea5a242c7 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Tue, 8 Jul 2025 21:29:16 +0800 Subject: [PATCH 2/6] Improve answer extraction and logging in majority voting Enhanced the extract_answer function to better handle LaTeX boxed answers and multiple choice patterns. Moved majority voting summary from being appended to the response to logging via the logger, ensuring cleaner output. --- optillm/plugins/majority_voting_plugin.py | 60 ++++++++++--------- scripts/eval_optillmbench.py | 73 ++++++++++++++++++----- 2 files changed, 91 insertions(+), 42 deletions(-) diff --git a/optillm/plugins/majority_voting_plugin.py b/optillm/plugins/majority_voting_plugin.py index 45c7ab68..3031cd10 100644 --- a/optillm/plugins/majority_voting_plugin.py +++ b/optillm/plugins/majority_voting_plugin.py @@ -45,7 +45,14 @@ def extract_answer(text: str) -> Optional[str]: # Remove any trailing whitespace text = text.strip() - # Pattern 1: Look for "Answer:" or "Final Answer:" patterns + # Pattern 1: Look for LaTeX boxed format first (handle both \boxed and \\boxed) + boxed_match = re.search(r'\\{1,2}boxed\{([^}]+)\}', text) + if boxed_match: + answer = boxed_match.group(1).strip() + logger.debug(f"Extracted boxed answer: {answer}") + return answer + + # Pattern 2: Look for "Answer:" or "Final Answer:" patterns answer_patterns = [ r'(?:final\s+)?answer\s*[:=]\s*(.+?)(?:\n|$)', r'(?:the\s+)?(?:final\s+)?answer\s+is\s*[:=]?\s*(.+?)(?:\n|$)', @@ -62,13 +69,6 @@ def extract_answer(text: str) -> Optional[str]: logger.debug(f"Extracted answer using pattern: {answer}") return answer - # Pattern 2: Look for LaTeX boxed format - boxed_match = re.search(r'\\boxed\{([^}]+)\}', text) - if boxed_match: - answer = boxed_match.group(1).strip() - logger.debug(f"Extracted boxed answer: {answer}") - return answer - # Pattern 3: Look for standalone numbers (useful for math problems) # Check the last few lines for a number lines = text.split('\n') @@ -80,20 +80,29 @@ def extract_answer(text: str) -> Optional[str]: logger.debug(f"Extracted number answer: {line}") return line - # Pattern 4: If the last line is short (< 50 chars), it might be the answer + # Pattern 4: For multiple choice, look for single letter answers + # Check this before the generic last line check + mc_patterns = [ + r'(?:the\s+)?(?:correct\s+)?(?:answer|option)\s+is\s+([A-E])(?:\b|$)', + r'(?:choose|select|pick)\s+(?:option\s+)?([A-E])(?:\b|$)', + r'\b([A-E])\s*\)\s*[A-Za-z]+.*is\s+(?:the\s+)?(?:correct|right)', + r'^([A-E])$', # Just a letter on its own line + ] + + for pattern in mc_patterns: + mc_match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE) + if mc_match: + answer = mc_match.group(1).upper() + logger.debug(f"Extracted multiple choice answer: {answer}") + return answer + + # Pattern 5: If the last line is short (< 50 chars), it might be the answer if lines: last_line = lines[-1].strip() if last_line and len(last_line) < 50 and not last_line.endswith(':'): logger.debug(f"Using last line as answer: {last_line}") return last_line - # Pattern 5: For multiple choice, look for single letter answers - mc_match = re.search(r'\b([A-E])\b(?:\s*\))?$', text) - if mc_match: - answer = mc_match.group(1) - 
logger.debug(f"Extracted multiple choice answer: {answer}") - return answer - logger.warning("Could not extract a clear answer from the response") return None @@ -240,21 +249,18 @@ def run( # Get the full response corresponding to the most common answer winning_response = answer_to_response.get(most_common_answer, candidates[0]) - # Add voting summary to the response - voting_summary = f"\n\n**Majority Voting Result**:\n" - voting_summary += f"- Generated {k} candidates\n" - voting_summary += f"- Most common answer: {most_common_answer}\n" - voting_summary += f"- Votes: {count}/{len(answers)} ({confidence:.1%} confidence)\n" + # Log voting summary to console instead of adding to response + logger.info("Majority Voting Summary:") + logger.info(f" - Generated {k} candidates") + logger.info(f" - Most common answer: {most_common_answer}") + logger.info(f" - Votes: {count}/{len(answers)} ({confidence:.1%} confidence)") if len(answer_counts) > 1: - voting_summary += f"- Other answers: " other_answers = [f"{ans} ({cnt} votes)" for ans, cnt in answer_counts.items() if ans != most_common_answer] - voting_summary += ", ".join(other_answers) - - # Return the full response from the winning answer with voting summary - final_response = winning_response + voting_summary + logger.info(f" - Other answers: {', '.join(other_answers)}") - return final_response, total_tokens + # Return only the full response from the winning answer + return winning_response, total_tokens except Exception as e: logger.error(f"Error in majority voting: {str(e)}") diff --git a/scripts/eval_optillmbench.py b/scripts/eval_optillmbench.py index ed59d356..0756855b 100644 --- a/scripts/eval_optillmbench.py +++ b/scripts/eval_optillmbench.py @@ -21,19 +21,44 @@ logger = logging.getLogger(__name__) # Define the approaches to test -# Each approach is (name, description) +# Each approach is (name, description, extra_body_params) APPROACHES = [ - ("none", "Baseline without any optimization"), - ("leap", "LEAP Approach"), - ("rto", "Round Trip Optimization"), - ("cot_reflection", "Chain of Thought with Reflection"), - ("self_consistency", "Self Consistency Check"), - ("plansearch", "Planning with Search"), - ("re2", "ReRead Approach"), - ("z3", "Z3 Solver for Mathematical Problems"), - ("coc", "Chain of Code"), - ("executecode" , "Execute Code"), - ("spl", "System Prompt Learning") + ("none", "Baseline without any optimization", {}), + ("leap", "LEAP Approach", {}), + ("rto", "Round Trip Optimization", {}), + ("cot_reflection", "Chain of Thought with Reflection", {}), + ("self_consistency", "Self Consistency Check", {}), + ("plansearch", "Planning with Search", {}), + ("re2", "ReRead Approach", {}), + ("z3", "Z3 Solver for Mathematical Problems", {}), + ("coc", "Chain of Code", {}), + ("executecode" , "Execute Code", {}), + ("spl", "System Prompt Learning", {}) +] + +# Define test-time compute approaches for sequential and parallel scaling +TEST_TIME_COMPUTE_APPROACHES = [ + # Baseline + ("none", "Baseline without any optimization", {}), + + # Sequential test-time compute using thinkdeeper with different thinking budgets + ("thinkdeeper_8k", "ThinkDeeper with 8K thinking tokens", { + "decoding": "thinkdeeper", + "max_thinking_tokens": 8000 + }), + ("thinkdeeper_16k", "ThinkDeeper with 16K thinking tokens", { + "decoding": "thinkdeeper", + "max_thinking_tokens": 16000 + }), + ("thinkdeeper_32k", "ThinkDeeper with 32K thinking tokens", { + "decoding": "thinkdeeper", + "max_thinking_tokens": 32000 + }), + + # Parallel test-time compute using 
majority voting with different k values + ("majority_voting_6", "Majority Voting with k=6", {"k": 6}), + ("majority_voting_36", "Majority Voting with k=36", {"k": 36}), + ("majority_voting_60", "Majority Voting with k=60", {"k": 60}), ] def load_optillm_bench() -> datasets.Dataset: @@ -265,6 +290,7 @@ def evaluate_model( model: str, dataset: datasets.Dataset, approach: str, + approach_extra_body: Dict[str, Any] = None, max_samples: int = None ) -> Tuple[Dict[str, float], List[Dict[str, Any]]]: """ @@ -286,8 +312,18 @@ def evaluate_model( # Prepare the dataset examples = dataset if max_samples is None else dataset.select(range(max_samples)) - # Create model name with approach - full_model_name = f"{approach}-{model}" if approach != "none" else model + # Create model name with approach - handle special cases + if approach == "none": + full_model_name = model + elif approach.startswith("thinkdeeper_"): + # For thinkdeeper, use base model name (decoding is passed in extra_body) + full_model_name = model + elif approach.startswith("majority_voting_"): + # For majority voting, use majority_voting prefix + full_model_name = f"majority_voting-{model}" + else: + # Standard approach prefix + full_model_name = f"{approach}-{model}" for example in tqdm(examples, desc=f"Evaluating {approach}"): try: @@ -297,6 +333,11 @@ def evaluate_model( # Record start time start_time = time.time() + # Prepare extra_body parameters + extra_body = {"spl_learning": False} + if approach_extra_body: + extra_body.update(approach_extra_body) + # Make API call response = client.chat.completions.create( model=full_model_name, @@ -306,7 +347,7 @@ def evaluate_model( ], temperature=0.2, max_tokens=4096, - extra_body= {"spl_learning": False}, + extra_body=extra_body, ) # Calculate time taken @@ -469,6 +510,8 @@ def main(): help="Directory to save results") parser.add_argument("--approaches", nargs="+", help="Specific approaches to evaluate (default: all)") + parser.add_argument("--test-time-compute", action="store_true", + help="Evaluate test-time compute approaches (sequential and parallel scaling)") parser.add_argument("--debug", action="store_true", help="Enable debug logging") args = parser.parse_args() From d2ed0632130c2aa28bb415cd965291d1cef3d287 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Wed, 9 Jul 2025 07:46:11 +0800 Subject: [PATCH 3/6] Add MLX ThinkDeeper support and update eval configs Introduces a new MLX-compatible ThinkDeeper implementation (`thinkdeeper_mlx.py`) and integrates it into the inference pipeline for MLX models. Updates the inference logic to select the appropriate ThinkDeeper version based on the backend, and refines fallback and error handling for MLX generation. The evaluation script is updated with new test-time compute scaling approaches, including revised ThinkDeeper and majority voting configurations, and improved reporting for test-time compute experiments. 
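For reference, a rough client-side sketch of the two scaling modes this enables, issued through the OpenAI-compatible proxy (the base URL, API key, and model names are placeholders and the token budgets are illustrative; the `decoding` and thinking-token fields and the `majority_voting-` model prefix follow the evaluation script in this patch):

    from openai import OpenAI

    # Placeholder endpoint and credentials for a locally running optillm proxy.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="optillm")

    question = [{"role": "user", "content": "What is 17 * 23?"}]

    # Sequential scaling: ThinkDeeper decoding with an explicit thinking budget.
    # For MLX-backed models the pipeline now routes this to thinkdeeper_decode_mlx.
    sequential = client.chat.completions.create(
        model="your-local-model",                  # placeholder model id
        messages=question,
        max_tokens=4096,
        extra_body={
            "decoding": "thinkdeeper",
            "min_thinking_tokens": 8000,           # illustrative budgets
            "max_thinking_tokens": 32000,
        },
    )

    # Parallel scaling: majority voting over k candidates, selected via model prefix.
    parallel = client.chat.completions.create(
        model="majority_voting-your-local-model",
        messages=question,
        max_tokens=4096,
        extra_body={"k": 6},
    )

    print(sequential.choices[0].message.content)
    print(parallel.choices[0].message.content)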
--- optillm/inference.py | 262 ++++++++++++++-------------- optillm/thinkdeeper_mlx.py | 327 +++++++++++++++++++++++++++++++++++ scripts/eval_optillmbench.py | 70 +++++--- 3 files changed, 502 insertions(+), 157 deletions(-) create mode 100644 optillm/thinkdeeper_mlx.py diff --git a/optillm/inference.py b/optillm/inference.py index 206b357e..94a003db 100644 --- a/optillm/inference.py +++ b/optillm/inference.py @@ -22,6 +22,7 @@ from optillm.cot_decoding import cot_decode from optillm.entropy_decoding import entropy_decode from optillm.thinkdeeper import thinkdeeper_decode +from optillm.thinkdeeper_mlx import thinkdeeper_decode_mlx from optillm.autothink import autothink_decode # Configure logging @@ -33,6 +34,7 @@ import mlx.core as mx from mlx_lm import load as mlx_load, generate as mlx_generate from mlx_lm.tokenizer_utils import TokenizerWrapper + from mlx_lm.sample_utils import make_sampler MLX_AVAILABLE = True logger.info("MLX framework available") except ImportError: @@ -349,85 +351,46 @@ def generate( return responses, token_counts, logprobs_results def _robust_mlx_generate(self, prompt: str, max_tokens: int, temperature: float, top_p: float, repetition_penalty: float) -> str: - """Robust MLX generation with multiple parameter combinations""" - - # Try different parameter combinations based on MLX-LM version - parameter_combinations = [ - # Version 1: Current style with positional args and temp - { - "style": "positional_temp", - "args": (self.model, self.tokenizer, prompt), - "kwargs": { - "max_tokens": max_tokens, - "temp": temperature, - "top_p": top_p, - "repetition_penalty": repetition_penalty, - "verbose": False - } - }, - # Version 2: All keyword arguments with temp - { - "style": "keyword_temp", - "args": (), - "kwargs": { - "model": self.model, - "tokenizer": self.tokenizer, - "prompt": prompt, - "max_tokens": max_tokens, - "temp": temperature, - "top_p": top_p, - "repetition_penalty": repetition_penalty, - "verbose": False - } - }, - # Version 3: Using temperature instead of temp - { - "style": "positional_temperature", - "args": (self.model, self.tokenizer, prompt), - "kwargs": { - "max_tokens": max_tokens, - "temperature": temperature, - "top_p": top_p, - "repetition_penalty": repetition_penalty, - "verbose": False - } - }, - # Version 4: Minimal parameters only - { - "style": "minimal", - "args": (self.model, self.tokenizer, prompt), - "kwargs": { - "max_tokens": max_tokens, - "temp": temperature, - "verbose": False - } - }, - # Version 5: Just essential parameters - { - "style": "essential", - "args": (self.model, self.tokenizer, prompt), - "kwargs": { - "max_tokens": max_tokens - } - } - ] + """Robust MLX generation using sampler approach""" - last_error = None - - for combo in parameter_combinations: + try: + # Create sampler with generation parameters + sampler = make_sampler( + temp=temperature, + top_p=top_p, + min_p=0.0, # Default min_p + min_tokens_to_keep=1 # Default min_tokens_to_keep + ) + + # Generate using the sampler + response = mlx_generate( + self.model, + self.tokenizer, + prompt, + max_tokens=max_tokens, + sampler=sampler, + verbose=False + ) + + return response + + except Exception as e: + logger.error(f"MLX generation with sampler failed: {str(e)}") + + # Fallback: Try minimal parameters without sampler try: - logger.debug(f"Trying MLX generation with style: {combo['style']}") - response = mlx_generate(*combo["args"], **combo["kwargs"]) - logger.debug(f"Successfully generated with style: {combo['style']}") + logger.debug("Attempting MLX generation 
without sampler") + response = mlx_generate( + self.model, + self.tokenizer, + prompt, + max_tokens=max_tokens, + verbose=False + ) return response - - except Exception as e: - last_error = e - logger.debug(f"Failed with style {combo['style']}: {str(e)}") - continue - - # If all combinations failed, raise the last error - raise RuntimeError(f"All MLX generation methods failed. Last error: {str(last_error)}") + except Exception as fallback_e: + logger.error(f"MLX fallback generation also failed: {str(fallback_e)}") + raise def format_chat_prompt(self, system_prompt: str, user_prompt: str) -> str: """Format the prompt according to model's chat template""" @@ -1691,37 +1654,47 @@ def create( if decoding: logger.info(f"Using specialized decoding approach: {decoding}") - # Ensure model is in eval mode and on correct device - pipeline.current_model.eval() - device = pipeline.current_model.device + # Check if this decoding approach is supported for MLX + mlx_unsupported_decodings = ["cot_decoding", "entropy_decoding", "autothink"] + if isinstance(pipeline, MLXInferencePipeline) and decoding in mlx_unsupported_decodings: + logger.warning(f"{decoding} is not supported for MLX models. Falling back to standard generation.") + decoding = None + + if decoding: + # For PyTorch pipelines, ensure model is in eval mode and get device + # MLX pipelines handle this differently + if not isinstance(pipeline, MLXInferencePipeline): + pipeline.current_model.eval() + device = pipeline.current_model.device + else: + device = None # MLX doesn't use torch devices if decoding == "cot_decoding": # Use directly available parameters for CoT - cot_params = { - "k": k, - "num_beams": num_beams, - "max_new_tokens": max_tokens if max_tokens is not None else 512, - "temperature": temperature, - "top_p": top_p, - "repetition_penalty": 1.0, - "length_penalty": length_penalty, - "no_repeat_ngram_size": no_repeat_ngram_size, - "early_stopping": early_stopping, - "aggregate_paths": aggregate_paths, - } - - result, confidence = cot_decode( - pipeline.current_model, - pipeline.tokenizer, - messages, - **cot_params - ) - responses = [result] - logprobs_results = [{"confidence_score": confidence} if confidence is not None else None] - completion_tokens = len(pipeline.tokenizer.encode(result)) + cot_params = { + "k": k, + "num_beams": num_beams, + "max_new_tokens": max_tokens if max_tokens is not None else 512, + "temperature": temperature, + "top_p": top_p, + "repetition_penalty": 1.0, + "length_penalty": length_penalty, + "no_repeat_ngram_size": no_repeat_ngram_size, + "early_stopping": early_stopping, + "aggregate_paths": aggregate_paths, + } + + result, confidence = cot_decode( + pipeline.current_model, + pipeline.tokenizer, + messages, + **cot_params + ) + responses = [result] + logprobs_results = [{"confidence_score": confidence} if confidence is not None else None] + completion_tokens = len(pipeline.tokenizer.encode(result)) elif decoding == "entropy_decoding": - # Ensure model is using full precision original_dtype = pipeline.current_model.dtype pipeline.current_model = pipeline.current_model.to(torch.float32) @@ -1778,43 +1751,66 @@ def create( } thinkdeeper_config.update(custom_config) - result = thinkdeeper_decode( - pipeline.current_model, - pipeline.tokenizer, - messages, - thinkdeeper_config + # Check if we're using MLX pipeline + if isinstance(pipeline, MLXInferencePipeline): + logger.info("Using MLX ThinkDeeper implementation") + + # Ensure we have enough tokens for thinking + response + user_max_tokens = max_tokens if 
max_tokens is not None else 512 + total_tokens_needed = max_thinking_tokens + 512 # thinking + response buffer + adjusted_max_tokens = max(user_max_tokens, total_tokens_needed) + + # Add max_tokens to thinkdeeper config + thinkdeeper_config_with_tokens = thinkdeeper_config.copy() + thinkdeeper_config_with_tokens["max_tokens"] = adjusted_max_tokens + + logger.debug(f"ThinkDeeper tokens: user={user_max_tokens}, thinking={max_thinking_tokens}, adjusted={adjusted_max_tokens}") + + result = thinkdeeper_decode_mlx( + pipeline.model, + pipeline.tokenizer, + messages, + thinkdeeper_config_with_tokens + ) + else: + logger.info("Using PyTorch ThinkDeeper implementation") + result = thinkdeeper_decode( + pipeline.current_model, + pipeline.tokenizer, + messages, + thinkdeeper_config ) responses = [result] logprobs_results = [None] completion_tokens = len(pipeline.tokenizer.encode(result)) elif decoding == "autothink": # Get steering dataset configuration - steering_dataset = kwargs.get("steering_dataset", "codelion/Qwen3-0.6B-pts-steering-vectors") - target_layer = kwargs.get("target_layer", 19) - - # Prepare AutoThink configuration - autothink_config = { - "steering_dataset": steering_dataset, - "target_layer": target_layer, - "pattern_strengths": kwargs.get("pattern_strengths", { - "depth_and_thoroughness": 2.5, - "numerical_accuracy": 2.0, - "self_correction": 3.0, - "exploration": 2.0, - "organization": 1.5 - }) - } - - # Process with AutoThink - result = autothink_decode( - pipeline.current_model, - pipeline.tokenizer, - messages, - autothink_config - ) - responses = [result] - logprobs_results = [None] - completion_tokens = len(pipeline.tokenizer.encode(result)) + steering_dataset = kwargs.get("steering_dataset", "codelion/Qwen3-0.6B-pts-steering-vectors") + target_layer = kwargs.get("target_layer", 19) + + # Prepare AutoThink configuration + autothink_config = { + "steering_dataset": steering_dataset, + "target_layer": target_layer, + "pattern_strengths": kwargs.get("pattern_strengths", { + "depth_and_thoroughness": 2.5, + "numerical_accuracy": 2.0, + "self_correction": 3.0, + "exploration": 2.0, + "organization": 1.5 + }) + } + + # Process with AutoThink + result = autothink_decode( + pipeline.current_model, + pipeline.tokenizer, + messages, + autothink_config + ) + responses = [result] + logprobs_results = [None] + completion_tokens = len(pipeline.tokenizer.encode(result)) else: raise ValueError(f"Unknown specialized decoding approach: {decoding}") diff --git a/optillm/thinkdeeper_mlx.py b/optillm/thinkdeeper_mlx.py new file mode 100644 index 00000000..043e2876 --- /dev/null +++ b/optillm/thinkdeeper_mlx.py @@ -0,0 +1,327 @@ +""" +MLX-compatible implementation of ThinkDeeper +Provides the same functionality as the PyTorch version but adapted for MLX framework +""" + +import random +from typing import Tuple, Dict, Any, List +import logging + +logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) + +try: + import mlx.core as mx + from mlx_lm import generate as mlx_generate + from mlx_lm.sample_utils import make_sampler + MLX_AVAILABLE = True +except ImportError: + MLX_AVAILABLE = False + +DEFAULT_CONFIG = { + "min_thinking_tokens": 1024, + "max_thinking_tokens": 4196, + "max_thoughts": 64, + "prefill": "", + "start_think_token": "", + "end_think_token": "", + "thought_switch_tokens": [ + "Wait,", + "Alternatively,", + ], +} + +class MLXThinkDeeperProcessor: + def __init__(self, config: Dict[str, Any], tokenizer, model): + self.config = {**DEFAULT_CONFIG, **config} + 
self.tokenizer = tokenizer + self.model = model + + # Get token IDs for think markers + start_tokens = self.tokenizer.encode(self.config['start_think_token']) + end_tokens = self.tokenizer.encode(self.config['end_think_token']) + self._start_think_token = start_tokens[0] if len(start_tokens) == 1 else start_tokens[1] + self.end_think_token = end_tokens[0] if len(end_tokens) == 1 else end_tokens[1] + + # Store thought switch markers as token sequences + self.thought_switch_sequences = [] + for phrase in self.config["thought_switch_tokens"]: + # Encode without adding special tokens to get exact sequence + token_ids = self.tokenizer.encode(phrase, add_special_tokens=False) + self.thought_switch_sequences.append(token_ids) + + # Track thought switches + self.thought_count = 0 + self.current_sequence = [] # Track recent tokens for sequence matching + self.max_sequence_length = max(len(seq) for seq in self.thought_switch_sequences) if self.thought_switch_sequences else 5 + + # Track total tokens for budget management + self.total_tokens_generated = 0 + self.max_total_tokens = config.get('max_tokens', 8192) # Default to 8192 if not specified + + def is_thought_switch(self, token: int) -> bool: + """Check if adding this token creates a thought switch sequence.""" + # Add new token to current sequence + self.current_sequence.append(token) + + # Keep only the most recent tokens that could match our sequences + if len(self.current_sequence) > self.max_sequence_length: + self.current_sequence = self.current_sequence[-self.max_sequence_length:] + + # Check if current sequence ends with any thought switch sequence + for sequence in self.thought_switch_sequences: + if len(sequence) <= len(self.current_sequence) and \ + self.current_sequence[-len(sequence):] == sequence: + return True + + return False + + def reasoning_effort(self, messages) -> str: + """Generate response with ThinkDeeper's controlled thinking process using MLX""" + + # Prepare the messages with thinking token + thinking_messages = messages.copy() + thinking_messages.append({ + "role": "assistant", + "content": f"{self.config['start_think_token']}\n{self.config['prefill']}" + }) + + # Convert messages to prompt using tokenizer + if hasattr(self.tokenizer, 'apply_chat_template'): + prompt = self.tokenizer.apply_chat_template( + thinking_messages, + continue_final_message=False, # This was causing MLX failures! 
+ tokenize=False, + add_generation_prompt=True # Standard generation prompt + ) + else: + # Fallback: simple concatenation + prompt = "" + for msg in thinking_messages: + prompt += f"{msg['role']}: {msg['content']}\n" + + + # Initialize tracking variables + n_thinking_tokens = 0 + seen_end_think = False + response_chunks = [] + + # Use MLX generation with custom token-by-token control + # Since MLX doesn't support token-by-token generation like PyTorch, + # we'll use a different approach: generate in chunks and check for markers + + current_prompt = prompt + max_chunk_size = 150 # Increase chunk size - MLX may work better with larger chunks + consecutive_empty_chunks = 0 + max_empty_chunks = 3 # Allow up to 3 consecutive empty chunks before stopping + + while (n_thinking_tokens < self.config["max_thinking_tokens"] and + self.thought_count < self.config["max_thoughts"] and + self.total_tokens_generated < self.max_total_tokens - 512): # Reserve 512 tokens for final response + try: + # Generate a small chunk of tokens + chunk_response = self._generate_chunk( + current_prompt, + max_tokens=min(max_chunk_size, self.config["max_thinking_tokens"] - n_thinking_tokens), + temperature=0.6 + ) + + if not chunk_response or chunk_response.strip() == "": + consecutive_empty_chunks += 1 + + if consecutive_empty_chunks >= max_empty_chunks: + break + + # Try with different parameters for next attempt + max_chunk_size = min(max_chunk_size + 50, 300) # Increase chunk size more aggressively + continue + else: + # Reset empty chunk counter on successful generation + consecutive_empty_chunks = 0 + max_chunk_size = 150 # Reset chunk size + + # Update token counts + chunk_tokens = len(self.tokenizer.encode(chunk_response)) + self.total_tokens_generated += chunk_tokens + + # Check for end think token in the chunk + if self.config['end_think_token'] in chunk_response: + # Split at the end think token + parts = chunk_response.split(self.config['end_think_token'], 1) + before_end = parts[0] + after_end = parts[1] if len(parts) > 1 else "" + + response_chunks.append(before_end) + n_thinking_tokens += len(self.tokenizer.encode(before_end)) + + # Check if we've reached minimum thinking tokens + if n_thinking_tokens < self.config["min_thinking_tokens"]: + # Insert thought transition instead of ending + transition = random.choice(self.config["thought_switch_tokens"]) + response_chunks.append(transition) + current_prompt += before_end + transition + n_thinking_tokens += len(self.tokenizer.encode(transition)) + self.thought_count += 1 + continue + else: + # Natural end - add the end token and continue for conclusion + response_chunks.append(self.config['end_think_token']) + current_prompt += before_end + self.config['end_think_token'] + seen_end_think = True + + # Generate conclusion after thinking + if after_end.strip(): + response_chunks.append(after_end) + else: + conclusion = self._generate_chunk(current_prompt, max_tokens=200, temperature=0.3) + if conclusion: + response_chunks.append(conclusion) + break + else: + # No end think token found, add the chunk and continue + response_chunks.append(chunk_response) + current_prompt += chunk_response + n_thinking_tokens += len(self.tokenizer.encode(chunk_response)) + + # Check for thought switch patterns in the chunk + for phrase in self.config["thought_switch_tokens"]: + if phrase in chunk_response: + self.thought_count += 1 + break + + # Safety check to avoid infinite loops + if len(response_chunks) > 100: + logger.warning("Too many chunks generated, stopping to avoid 
infinite loop") + break + + except Exception as e: + logger.error(f"Error during MLX chunk generation: {str(e)}") + break + + # Enforce minimum thinking tokens if not reached + if not seen_end_think and n_thinking_tokens < self.config["min_thinking_tokens"]: + while n_thinking_tokens < self.config["min_thinking_tokens"] and self.thought_count < self.config["max_thoughts"]: + # Add transition and continue thinking + transition = random.choice(self.config["thought_switch_tokens"]) + response_chunks.append(f" {transition} ") + current_prompt += f" {transition} " + + # Generate more thinking content + additional_thinking = self._generate_chunk( + current_prompt, + max_tokens=min(200, self.config["min_thinking_tokens"] - n_thinking_tokens + 100), + temperature=0.6 + ) + + if additional_thinking and additional_thinking.strip(): + response_chunks.append(additional_thinking) + current_prompt += additional_thinking + additional_tokens = len(self.tokenizer.encode(additional_thinking)) + n_thinking_tokens += additional_tokens + self.thought_count += 1 + else: + # If generation fails, break to avoid infinite loop + break + + # If we haven't seen end think token, force it + if not seen_end_think: + response_chunks.append(self.config['end_think_token']) + + # Add a brief conclusion + try: + conclusion = self._generate_chunk( + current_prompt + self.config['end_think_token'], + max_tokens=100, + temperature=0.3 + ) + if conclusion: + response_chunks.append(conclusion) + except Exception as e: + logger.error(f"Error generating conclusion: {str(e)}") + + # Join all chunks and create final response + response_content = "".join(response_chunks) + full_response = f"{self.config['start_think_token']}\n{self.config['prefill']}{response_content}" + + return full_response + + def _generate_chunk(self, prompt: str, max_tokens: int, temperature: float) -> str: + """Generate a small chunk of text using MLX with proper sampler""" + try: + # Let MLX fail naturally to identify the real issue + + # Create sampler with specified thinkdeeper parameters + sampler = make_sampler( + temp=temperature, + top_p=0.95, + top_k=20, + min_p=0.0, + min_tokens_to_keep=3 + ) + + # Use mlx_generate with the sampler + # Ensure we have minimum tokens to generate - larger minimum for better MLX performance + actual_max_tokens = max(max_tokens, 30) # At least 30 tokens for better generation + + response = mlx_generate( + self.model, + self.tokenizer, + prompt, + max_tokens=actual_max_tokens, + sampler=sampler, + verbose=False + ) + + # MLX generate might return just the generated tokens or the full text + # Check if response starts with the prompt + if response: + if response.startswith(prompt): + # Response includes the prompt, extract new content + new_content = response[len(prompt):] + else: + # Response is just the generated tokens + new_content = response + + if new_content.strip(): # Only return non-empty content + return new_content + + return "" + + except Exception as e: + logger.error(f"Error in MLX chunk generation: {str(e)}") + return "" + +def thinkdeeper_decode_mlx( + model, + tokenizer, + messages: List[Dict[str, str]], + request_config: Dict[str, Any] = None +) -> str: + """MLX-compatible ThinkDeeper processing function""" + logger.info("Starting MLX ThinkDeeper processing") + + if not MLX_AVAILABLE: + raise RuntimeError("MLX framework not available for ThinkDeeper processing") + + # Extract config from request_config if provided + config = DEFAULT_CONFIG.copy() + if request_config: + # Update only valid keys from 
DEFAULT_CONFIG + for key in DEFAULT_CONFIG: + if key in request_config: + config[key] = request_config[key] + + # Also handle max_tokens which is not in DEFAULT_CONFIG + if 'max_tokens' in request_config: + config['max_tokens'] = request_config['max_tokens'] + + logger.info(f"MLX ThinkDeeper using config: {config}") + + try: + processor = MLXThinkDeeperProcessor(config, tokenizer, model) + response = processor.reasoning_effort(messages) + return response + + except Exception as e: + logger.error(f"Error in MLX ThinkDeeper processing: {str(e)}") + raise \ No newline at end of file diff --git a/scripts/eval_optillmbench.py b/scripts/eval_optillmbench.py index 0756855b..594b0ecf 100644 --- a/scripts/eval_optillmbench.py +++ b/scripts/eval_optillmbench.py @@ -41,24 +41,30 @@ # Baseline ("none", "Baseline without any optimization", {}), - # Sequential test-time compute using thinkdeeper with different thinking budgets - ("thinkdeeper_8k", "ThinkDeeper with 8K thinking tokens", { + # Sequential test-time compute using thinkdeeper with different minimum thinking budgets + ("thinkdeeper_4k", "ThinkDeeper with 4K min thinking tokens", { "decoding": "thinkdeeper", - "max_thinking_tokens": 8000 + "min_thinking_tokens": 4000, + "max_thinking_tokens": 20000, # Allow up to 20K for completion + "max_tokens": 24000 # Total budget: 20K thinking + 4K response }), - ("thinkdeeper_16k", "ThinkDeeper with 16K thinking tokens", { + ("thinkdeeper_8k", "ThinkDeeper with 8K min thinking tokens", { "decoding": "thinkdeeper", - "max_thinking_tokens": 16000 + "min_thinking_tokens": 8000, + "max_thinking_tokens": 32000, # Allow up to 32K for completion + "max_tokens": 36000 # Total budget: 32K thinking + 4K response }), - ("thinkdeeper_32k", "ThinkDeeper with 32K thinking tokens", { + ("thinkdeeper_16k", "ThinkDeeper with 16K min thinking tokens", { "decoding": "thinkdeeper", - "max_thinking_tokens": 32000 + "min_thinking_tokens": 16000, + "max_thinking_tokens": 48000, # Allow up to 48K for completion + "max_tokens": 52000 # Total budget: 48K thinking + 4K response }), # Parallel test-time compute using majority voting with different k values ("majority_voting_6", "Majority Voting with k=6", {"k": 6}), - ("majority_voting_36", "Majority Voting with k=36", {"k": 36}), - ("majority_voting_60", "Majority Voting with k=60", {"k": 60}), + ("majority_voting_12", "Majority Voting with k=12", {"k": 12}), + ("majority_voting_18", "Majority Voting with k=18", {"k": 18}), ] def load_optillm_bench() -> datasets.Dataset: @@ -448,14 +454,20 @@ def save_results(metrics: Dict[str, float], detailed_results: List[Dict[str, Any logger.info(f"Results saved to {base_filename}_*") -def generate_report(all_metrics: Dict[str, Dict[str, float]], output_dir: str): +def generate_report(all_metrics: Dict[str, Dict[str, float]], output_dir: str, is_test_time_compute: bool = False): """Generate a comprehensive report comparing all approaches.""" report = [] # Header - report.append("# OptiLLM Bench Evaluation Report") + report_title = "OptiLLM Bench Test-Time Compute Evaluation Report" if is_test_time_compute else "OptiLLM Bench Evaluation Report" + report.append(f"# {report_title}") report.append(f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + if is_test_time_compute: + report.append("This report evaluates test-time compute scaling approaches:") + report.append("- **Sequential scaling**: ThinkDeeper with varying thinking token budgets") + report.append("- **Parallel scaling**: Majority voting with varying k values\n") + # 
Overall Results Table report.append("## Overall Results") headers = ["Approach", "Accuracy", "Avg Time (s)", "Total Time (s)"] @@ -537,44 +549,54 @@ def main(): dataset = load_optillm_bench() # Determine which approaches to evaluate - approaches_to_test = ( - [a[0] for a in APPROACHES if a[0] in args.approaches] - if args.approaches - else [a[0] for a in APPROACHES] - ) + if args.test_time_compute: + # Use test-time compute approaches + approaches_config = TEST_TIME_COMPUTE_APPROACHES + if args.approaches: + # Filter test-time compute approaches if specific ones are requested + approaches_config = [a for a in TEST_TIME_COMPUTE_APPROACHES if a[0] in args.approaches] + else: + # Use standard approaches + if args.approaches: + approaches_config = [a for a in APPROACHES if a[0] in args.approaches] + else: + approaches_config = APPROACHES # Store all metrics for final report all_metrics = {} # Evaluate each approach - for approach in approaches_to_test: - logger.info(f"Evaluating approach: {approach}") + for approach_name, description, extra_body_params in approaches_config: + logger.info(f"Evaluating approach: {approach_name} - {description}") + if extra_body_params: + logger.info(f"Extra parameters: {extra_body_params}") try: metrics, detailed_results = evaluate_model( client, args.model, dataset, - approach, + approach_name, + extra_body_params, args.max_samples ) - all_metrics[approach] = metrics + all_metrics[approach_name] = metrics # Save results for this approach - save_results(metrics, detailed_results, args.model, approach, + save_results(metrics, detailed_results, args.model, approach_name, args.output_dir) - logger.info(f"Completed evaluation for {approach}") + logger.info(f"Completed evaluation for {approach_name}") logger.info(f"Accuracy: {metrics['accuracy']*100:.2f}%") logger.info(f"Average time per sample: {metrics['average_time']:.2f}s") except Exception as e: - logger.error(f"Error evaluating approach {approach}: {e}") + logger.error(f"Error evaluating approach {approach_name}: {e}") continue # Generate final report - generate_report(all_metrics, args.output_dir) + generate_report(all_metrics, args.output_dir, args.test_time_compute) if __name__ == "__main__": main() \ No newline at end of file From 580946362230652a4abc68628c7b8a6e3bed4edd Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Thu, 10 Jul 2025 22:02:55 +0800 Subject: [PATCH 4/6] support ttc argument --- scripts/eval_aime_benchmark.py | 69 ++++++++++++++++++++++++++++------ scripts/eval_optillmbench.py | 30 +++++++-------- 2 files changed, 72 insertions(+), 27 deletions(-) diff --git a/scripts/eval_aime_benchmark.py b/scripts/eval_aime_benchmark.py index a854f833..ac61c35d 100644 --- a/scripts/eval_aime_benchmark.py +++ b/scripts/eval_aime_benchmark.py @@ -256,7 +256,7 @@ def analyze_logits_probs(logprobs_data: List[Dict]) -> Dict: "token_count": len(token_entropies) } -def get_llm_response(problem: str, model: str, analyze_logits: bool = False) -> Union[str, List[Dict]]: +def get_llm_response(problem: str, model: str, analyze_logits: bool = False, extra_body: dict = None) -> Union[str, List[Dict]]: """ Get response from the LLM for a given problem. If multiple choices are returned, formats them as attempt dictionaries. 
@@ -276,18 +276,16 @@ def get_llm_response(problem: str, model: str, analyze_logits: bool = False) -> kwargs["logprobs"] = True kwargs["top_logprobs"] = 3 + # Add extra_body if provided + if extra_body: + kwargs["extra_body"] = extra_body + response = client.with_options(timeout=1000.0).chat.completions.create( model=model, messages=[ {"role": "user", "content": SYSTEM_PROMPT + problem} ], max_tokens=8192, - # extra_body={ - # "decoding": "thinkdeeper", - # "min_thinking_tokens" : 0, - # "max_thinking_tokens" : 8000, - # "max_thoughts": 100, - # }, **kwargs ) @@ -333,7 +331,7 @@ def get_llm_response(problem: str, model: str, analyze_logits: bool = False) -> logger.error(f"Error getting LLM response: {e}") return "" -def make_n_attempts(problem: str, model: str, n: int, analyze_thoughts: bool = False, analyze_logits: bool = False) -> List[Dict]: +def make_n_attempts(problem: str, model: str, n: int, analyze_thoughts: bool = False, analyze_logits: bool = False, extra_body: dict = None) -> List[Dict]: """ Make n attempts to solve a problem and return all responses and predictions. @@ -351,7 +349,7 @@ def make_n_attempts(problem: str, model: str, n: int, analyze_thoughts: bool = F remaining_attempts = n while remaining_attempts > 0: - response = get_llm_response(problem, model, analyze_logits) + response = get_llm_response(problem, model, analyze_logits, extra_body) # If response is already formatted as attempts if isinstance(response, list): @@ -774,7 +772,7 @@ def save_raw_response(filename: str, problem_id: int, response_data: Dict): return response_id -def main(model: str, n_attempts: int, analyze_thoughts: bool = False, analyze_logits: bool = False): +def main(model: str, n_attempts: int, analyze_thoughts: bool = False, analyze_logits: bool = False, test_time_compute: bool = False, approach_name: str = None, extra_body: dict = None): """Main evaluation function that handles gaps in processed indexes.""" os.makedirs("results", exist_ok=True) @@ -784,6 +782,8 @@ def main(model: str, n_attempts: int, analyze_thoughts: bool = False, analyze_lo suffix_parts.append("thought_analysis") if analyze_logits: suffix_parts.append("logit_analysis") + if approach_name: + suffix_parts.append(approach_name) suffix = "_" + "_".join(suffix_parts) if suffix_parts else "" results_file = f"results/evaluation_results_{model.replace('/', '_')}_pass_at_{n_attempts}{suffix}.json" @@ -804,7 +804,7 @@ def main(model: str, n_attempts: int, analyze_thoughts: bool = False, analyze_lo correct_answer = int(item['answer']) # Make n attempts for each problem - attempts = make_n_attempts(problem_text, model, n_attempts, analyze_thoughts, analyze_logits) + attempts = make_n_attempts(problem_text, model, n_attempts, analyze_thoughts, analyze_logits, extra_body) is_correct, first_correct = evaluate_pass_at_n(attempts, correct_answer) result = { @@ -826,6 +826,51 @@ def main(model: str, n_attempts: int, analyze_thoughts: bool = False, analyze_lo parser.add_argument("--n", type=int, default=1, help="Number of attempts per problem (for pass@n evaluation)") parser.add_argument("--analyze-thoughts", action="store_true", help="Analyze thinking patterns in responses") parser.add_argument("--analyze-logits", action="store_true", help="Analyze token probability distributions") + parser.add_argument("--test-time-compute", action="store_true", help="Evaluate test-time compute scaling approaches") args = parser.parse_args() - main(args.model, args.n, args.analyze_thoughts, args.analyze_logits) \ No newline at end of file + if 
args.test_time_compute: + # Define test-time compute approaches with same config as eval_optillmbench.py + TEST_TIME_COMPUTE_APPROACHES = [ + # Baseline + ("none", "Baseline without any optimization", {}), + + # Sequential test-time compute using thinkdeeper with controlled thinking budgets + ("thinkdeeper_2k", "ThinkDeeper with 2K thinking tokens", { + "decoding": "thinkdeeper", + "min_thinking_tokens": 2048, + "max_thinking_tokens": 2560, # min + 512 for flexibility + "max_tokens": 3072 # Total budget: max_thinking_tokens + 512 + }), + ("thinkdeeper_4k", "ThinkDeeper with 4K thinking tokens", { + "decoding": "thinkdeeper", + "min_thinking_tokens": 4096, + "max_thinking_tokens": 4608, # min + 512 for flexibility + "max_tokens": 5120 # Total budget: max_thinking_tokens + 512 + }), + ("thinkdeeper_8k", "ThinkDeeper with 8K thinking tokens", { + "decoding": "thinkdeeper", + "min_thinking_tokens": 8192, + "max_thinking_tokens": 8704, # min + 512 for flexibility + "max_tokens": 9216 # Total budget: max_thinking_tokens + 512 + }), + + # Parallel test-time compute using majority voting with different k values + ("majority_voting_3", "Majority Voting with k=3", {"k": 3}), + ("majority_voting_6", "Majority Voting with k=6", {"k": 6}), + ("majority_voting_9", "Majority Voting with k=9", {"k": 9}), + ] + + # Run evaluation for each approach + for approach_slug, approach_name, extra_body in TEST_TIME_COMPUTE_APPROACHES: + print(f"\n{'=' * 80}") + print(f"Evaluating: {approach_name}") + print(f"Model: {args.model}") + print(f"Approach: {approach_slug}") + print(f"Extra body: {extra_body}") + print(f"{'=' * 80}\n") + + main(args.model, args.n, args.analyze_thoughts, args.analyze_logits, + test_time_compute=True, approach_name=approach_slug, extra_body=extra_body) + else: + main(args.model, args.n, args.analyze_thoughts, args.analyze_logits) \ No newline at end of file diff --git a/scripts/eval_optillmbench.py b/scripts/eval_optillmbench.py index 594b0ecf..58eac413 100644 --- a/scripts/eval_optillmbench.py +++ b/scripts/eval_optillmbench.py @@ -41,30 +41,30 @@ # Baseline ("none", "Baseline without any optimization", {}), - # Sequential test-time compute using thinkdeeper with different minimum thinking budgets - ("thinkdeeper_4k", "ThinkDeeper with 4K min thinking tokens", { + # Sequential test-time compute using thinkdeeper with controlled thinking budgets + ("thinkdeeper_2k", "ThinkDeeper with 2K thinking tokens", { "decoding": "thinkdeeper", - "min_thinking_tokens": 4000, - "max_thinking_tokens": 20000, # Allow up to 20K for completion - "max_tokens": 24000 # Total budget: 20K thinking + 4K response + "min_thinking_tokens": 2048, + "max_thinking_tokens": 2560, # min + 512 for flexibility + "max_tokens": 3072 # Total budget: max_thinking_tokens + 512 }), - ("thinkdeeper_8k", "ThinkDeeper with 8K min thinking tokens", { + ("thinkdeeper_4k", "ThinkDeeper with 4K thinking tokens", { "decoding": "thinkdeeper", - "min_thinking_tokens": 8000, - "max_thinking_tokens": 32000, # Allow up to 32K for completion - "max_tokens": 36000 # Total budget: 32K thinking + 4K response + "min_thinking_tokens": 4096, + "max_thinking_tokens": 4608, # min + 512 for flexibility + "max_tokens": 5120 # Total budget: max_thinking_tokens + 512 }), - ("thinkdeeper_16k", "ThinkDeeper with 16K min thinking tokens", { + ("thinkdeeper_8k", "ThinkDeeper with 8K thinking tokens", { "decoding": "thinkdeeper", - "min_thinking_tokens": 16000, - "max_thinking_tokens": 48000, # Allow up to 48K for completion - "max_tokens": 52000 # Total 
budget: 48K thinking + 4K response + "min_thinking_tokens": 8192, + "max_thinking_tokens": 8704, # min + 512 for flexibility + "max_tokens": 9216 # Total budget: max_thinking_tokens + 512 }), # Parallel test-time compute using majority voting with different k values + ("majority_voting_3", "Majority Voting with k=3", {"k": 3}), ("majority_voting_6", "Majority Voting with k=6", {"k": 6}), - ("majority_voting_12", "Majority Voting with k=12", {"k": 12}), - ("majority_voting_18", "Majority Voting with k=18", {"k": 18}), + ("majority_voting_9", "Majority Voting with k=9", {"k": 9}), ] def load_optillm_bench() -> datasets.Dataset: From f656189b39da621a1e0d72e850b071fb66769725 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Fri, 11 Jul 2025 06:01:18 +0800 Subject: [PATCH 5/6] Add fallback for providers without n parameter support Updated best_of_n_sampling, mixture_of_agents, and majority_voting_plugin to handle providers that do not support the 'n' parameter by generating completions/candidates one by one in a loop. This improves compatibility with a wider range of API providers and ensures robust completion generation even when batch generation is not available. --- optillm/bon.py | 49 ++++++-- optillm/moa.py | 68 ++++++++--- optillm/plugins/majority_voting_plugin.py | 130 ++++++++++++---------- 3 files changed, 167 insertions(+), 80 deletions(-) diff --git a/optillm/bon.py b/optillm/bon.py index 8ee752a3..3da7d140 100644 --- a/optillm/bon.py +++ b/optillm/bon.py @@ -10,16 +10,45 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st completions = [] - response = client.chat.completions.create( - model=model, - messages=messages, - max_tokens=4096, - n=n, - temperature=1 - ) - completions = [choice.message.content for choice in response.choices] - logger.info(f"Generated {len(completions)} initial completions. Tokens used: {response.usage.completion_tokens}") - bon_completion_tokens += response.usage.completion_tokens + try: + # Try to generate n completions in a single API call using n parameter + response = client.chat.completions.create( + model=model, + messages=messages, + max_tokens=4096, + n=n, + temperature=1 + ) + completions = [choice.message.content for choice in response.choices] + logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}") + bon_completion_tokens += response.usage.completion_tokens + + except Exception as e: + logger.warning(f"n parameter not supported by provider: {str(e)}") + logger.info(f"Falling back to generating {n} completions one by one") + + # Fallback: Generate completions one by one in a loop + for i in range(n): + try: + response = client.chat.completions.create( + model=model, + messages=messages, + max_tokens=4096, + temperature=1 + ) + completions.append(response.choices[0].message.content) + bon_completion_tokens += response.usage.completion_tokens + logger.debug(f"Generated completion {i+1}/{n}") + + except Exception as fallback_error: + logger.error(f"Error generating completion {i+1}: {str(fallback_error)}") + continue + + if not completions: + logger.error("Failed to generate any completions") + return "Error: Could not generate any completions", 0 + + logger.info(f"Generated {len(completions)} completions using fallback method. 
Total tokens used: {bon_completion_tokens}") # Rate the completions rating_messages = messages.copy() diff --git a/optillm/moa.py b/optillm/moa.py index 5306e25b..21d5e105 100644 --- a/optillm/moa.py +++ b/optillm/moa.py @@ -8,19 +8,61 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str completions = [] logger.debug(f"Generating initial completions for query: {initial_query}") - response = client.chat.completions.create( - model=model, - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": initial_query} - ], - max_tokens=4096, - n=3, - temperature=1 - ) - completions = [choice.message.content for choice in response.choices] - moa_completion_tokens += response.usage.completion_tokens - logger.info(f"Generated {len(completions)} initial completions. Tokens used: {response.usage.completion_tokens}") + + try: + # Try to generate 3 completions in a single API call using n parameter + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": initial_query} + ], + max_tokens=4096, + n=3, + temperature=1 + ) + completions = [choice.message.content for choice in response.choices] + moa_completion_tokens += response.usage.completion_tokens + logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}") + + except Exception as e: + logger.warning(f"n parameter not supported by provider: {str(e)}") + logger.info("Falling back to generating 3 completions one by one") + + # Fallback: Generate 3 completions one by one in a loop + completions = [] + for i in range(3): + try: + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": initial_query} + ], + max_tokens=4096, + temperature=1 + ) + completions.append(response.choices[0].message.content) + moa_completion_tokens += response.usage.completion_tokens + logger.debug(f"Generated completion {i+1}/3") + + except Exception as fallback_error: + logger.error(f"Error generating completion {i+1}: {str(fallback_error)}") + continue + + if not completions: + logger.error("Failed to generate any completions") + return "Error: Could not generate any completions", 0 + + logger.info(f"Generated {len(completions)} completions using fallback method. Total tokens used: {moa_completion_tokens}") + + # Handle case where fewer than 3 completions were generated + if len(completions) < 3: + original_count = len(completions) + # Pad with the first completion to ensure we have 3 + while len(completions) < 3: + completions.append(completions[0]) + logger.warning(f"Only generated {original_count} unique completions, padded to 3 for critique") logger.debug("Preparing critique prompt") critique_prompt = f""" diff --git a/optillm/plugins/majority_voting_plugin.py b/optillm/plugins/majority_voting_plugin.py index 3031cd10..311072b7 100644 --- a/optillm/plugins/majority_voting_plugin.py +++ b/optillm/plugins/majority_voting_plugin.py @@ -213,65 +213,81 @@ def run( candidates = [choice.message.content for choice in response.choices] total_tokens = response.usage.completion_tokens - logger.info(f"Generated {len(candidates)} candidates. Tokens used: {total_tokens}") + logger.info(f"Generated {len(candidates)} candidates using n parameter. 
Tokens used: {total_tokens}") - # Extract answers from each candidate - answers = [] - answer_to_response = {} # Map normalized answers to full responses - - for i, candidate in enumerate(candidates): - answer = extract_answer(candidate) - if answer: - normalized = normalize_answer(answer) - answers.append(normalized) - # Keep the first full response for each unique answer - if normalized not in answer_to_response: - answer_to_response[normalized] = candidate - logger.debug(f"Candidate {i+1} answer: {answer} (normalized: {normalized})") - else: - logger.warning(f"Could not extract answer from candidate {i+1}") - - if not answers: - logger.warning("No answers could be extracted from any candidate") - # Return the first candidate as fallback - return candidates[0] if candidates else "Error: No candidates generated", total_tokens - - # Count answer frequencies - answer_counts = Counter(answers) - logger.info(f"Answer distribution: {dict(answer_counts)}") - - # Get the most common answer - most_common_answer, count = answer_counts.most_common(1)[0] - confidence = count / len(answers) - - logger.info(f"Most common answer: '{most_common_answer}' with {count}/{len(answers)} votes ({confidence:.1%} confidence)") - - # Get the full response corresponding to the most common answer - winning_response = answer_to_response.get(most_common_answer, candidates[0]) - - # Log voting summary to console instead of adding to response - logger.info("Majority Voting Summary:") - logger.info(f" - Generated {k} candidates") - logger.info(f" - Most common answer: {most_common_answer}") - logger.info(f" - Votes: {count}/{len(answers)} ({confidence:.1%} confidence)") - - if len(answer_counts) > 1: - other_answers = [f"{ans} ({cnt} votes)" for ans, cnt in answer_counts.items() if ans != most_common_answer] - logger.info(f" - Other answers: {', '.join(other_answers)}") + except Exception as e: + logger.warning(f"n parameter not supported by provider: {str(e)}") + logger.info(f"Falling back to generating {k} candidates one by one") - # Return only the full response from the winning answer - return winning_response, total_tokens + # Fallback: Generate candidates one by one in a loop + candidates = [] + total_tokens = 0 - except Exception as e: - logger.error(f"Error in majority voting: {str(e)}") - # Fall back to single response - logger.info("Falling back to single response generation") + for i in range(k): + try: + response = client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens + ) + candidates.append(response.choices[0].message.content) + total_tokens += response.usage.completion_tokens + logger.debug(f"Generated candidate {i+1}/{k}") + + except Exception as fallback_error: + logger.error(f"Error generating candidate {i+1}: {str(fallback_error)}") + continue - response = client.chat.completions.create( - model=model, - messages=messages, - temperature=temperature, - max_tokens=max_tokens - ) + if not candidates: + logger.error("Failed to generate any candidates") + return "Error: Could not generate any candidates", 0 - return response.choices[0].message.content, response.usage.completion_tokens \ No newline at end of file + logger.info(f"Generated {len(candidates)} candidates using fallback method. 
Total tokens used: {total_tokens}") + + # Extract answers from each candidate + answers = [] + answer_to_response = {} # Map normalized answers to full responses + + for i, candidate in enumerate(candidates): + answer = extract_answer(candidate) + if answer: + normalized = normalize_answer(answer) + answers.append(normalized) + # Keep the first full response for each unique answer + if normalized not in answer_to_response: + answer_to_response[normalized] = candidate + logger.debug(f"Candidate {i+1} answer: {answer} (normalized: {normalized})") + else: + logger.warning(f"Could not extract answer from candidate {i+1}") + + if not answers: + logger.warning("No answers could be extracted from any candidate") + # Return the first candidate as fallback + return candidates[0] if candidates else "Error: No candidates generated", total_tokens + + # Count answer frequencies + answer_counts = Counter(answers) + logger.info(f"Answer distribution: {dict(answer_counts)}") + + # Get the most common answer + most_common_answer, count = answer_counts.most_common(1)[0] + confidence = count / len(answers) + + logger.info(f"Most common answer: '{most_common_answer}' with {count}/{len(answers)} votes ({confidence:.1%} confidence)") + + # Get the full response corresponding to the most common answer + winning_response = answer_to_response.get(most_common_answer, candidates[0]) + + # Log voting summary to console instead of adding to response + logger.info("Majority Voting Summary:") + logger.info(f" - Generated {len(candidates)} candidates") + logger.info(f" - Most common answer: {most_common_answer}") + logger.info(f" - Votes: {count}/{len(answers)} ({confidence:.1%} confidence)") + + if len(answer_counts) > 1: + other_answers = [f"{ans} ({cnt} votes)" for ans, cnt in answer_counts.items() if ans != most_common_answer] + logger.info(f" - Other answers: {', '.join(other_answers)}") + + # Return only the full response from the winning answer + return winning_response, total_tokens \ No newline at end of file From cf2578e8683ce71c5613affd2aa0a6ed1c6a9ca5 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Fri, 11 Jul 2025 06:04:34 +0800 Subject: [PATCH 6/6] Bump version to 0.1.20 Update version number in __init__.py and setup.py to 0.1.20 for new release. --- optillm/__init__.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/optillm/__init__.py b/optillm/__init__.py index c7e63bcc..24870a61 100644 --- a/optillm/__init__.py +++ b/optillm/__init__.py @@ -2,7 +2,7 @@ import os # Version information -__version__ = "0.1.19" +__version__ = "0.1.20" # Get the path to the root optillm.py spec = util.spec_from_file_location( diff --git a/setup.py b/setup.py index d164df56..73a48e32 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="optillm", - version="0.1.19", + version="0.1.20", packages=find_packages(include=['optillm', 'optillm.*']), # This ensures all subpackages are included py_modules=['optillm'], package_data={
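
The try-the-n-parameter-then-fall-back-to-a-loop pattern introduced in PATCH 5/6 now appears three times, in optillm/bon.py, optillm/moa.py, and optillm/plugins/majority_voting_plugin.py. The sketch below distills that pattern into a single helper for reference; the name generate_n_completions and its signature are illustrative assumptions, not part of these patches, and it presupposes only an OpenAI-compatible client object like the one the patched code already uses.

import logging
from typing import List, Tuple

logger = logging.getLogger(__name__)

def generate_n_completions(client, model: str, messages: List[dict], n: int,
                           max_tokens: int = 4096,
                           temperature: float = 1.0) -> Tuple[List[str], int]:
    """Return up to n completions plus the completion-token count.

    Illustrative helper only: it first tries a single request using the
    n parameter, then falls back to generating completions one by one if
    the provider rejects that request, mirroring the fallback added in
    this patch series.
    """
    completions: List[str] = []
    total_tokens = 0
    try:
        # Fast path: one API call returning n choices.
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            n=n,
            temperature=temperature,
        )
        completions = [choice.message.content for choice in response.choices]
        total_tokens = response.usage.completion_tokens
    except Exception as exc:
        logger.warning(f"n parameter not supported by provider: {exc}")
        # Fallback: request each completion individually.
        for i in range(n):
            try:
                response = client.chat.completions.create(
                    model=model,
                    messages=messages,
                    max_tokens=max_tokens,
                    temperature=temperature,
                )
                completions.append(response.choices[0].message.content)
                total_tokens += response.usage.completion_tokens
            except Exception as fallback_error:
                logger.error(f"Error generating completion {i+1}: {fallback_error}")
                continue
    return completions, total_tokens

Factoring the pattern out this way would let the per-approach modules keep only their rating, critique, or voting logic, while the provider-compatibility concern lives in one place; whether to do so is left as a possible follow-up, not something these patches implement.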