From 23de9188de6b20d00c29204d03c7ff49f9d743d3 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Wed, 7 May 2025 10:38:40 +0800 Subject: [PATCH 01/11] init implementation --- README.md | 3 +- optillm/autothink/README.md | 95 +++++ optillm/autothink/__init__.py | 7 + optillm/autothink/autothink.py | 88 +++++ optillm/autothink/classifier.py | 152 ++++++++ optillm/autothink/example.py | 98 ++++++ optillm/autothink/processor.py | 376 ++++++++++++++++++++ optillm/autothink/steering.py | 603 ++++++++++++++++++++++++++++++++ optillm/inference.py | 29 ++ requirements.txt | 1 + setup.py | 1 + 11 files changed, 1452 insertions(+), 1 deletion(-) create mode 100644 optillm/autothink/README.md create mode 100644 optillm/autothink/__init__.py create mode 100644 optillm/autothink/autothink.py create mode 100644 optillm/autothink/classifier.py create mode 100644 optillm/autothink/example.py create mode 100644 optillm/autothink/processor.py create mode 100644 optillm/autothink/steering.py diff --git a/README.md b/README.md index 8fe90d53..aa32de68 100644 --- a/README.md +++ b/README.md @@ -343,7 +343,7 @@ Check this log file for connection issues, tool execution errors, and other diag | Approach | Slug | Description | | ------------------------------------ | ------------------ | ---------------------------------------------------------------------------------------------- | -| Cerebras Planning and Optimization | `cepo` | Combines Best of N, Chain-of-Thought, Self-Reflection, Self-Improvement, and various prompting techniques | +| Cerebras Planning and Optimization | `cepo` | Combines Best of N, Chain-of-Thought, Self-Reflection, Self-Improvement, and various prompting techniques | | CoT with Reflection | `cot_reflection` | Implements chain-of-thought reasoning with \, \ and \ sections | | PlanSearch | `plansearch` | Implements a search algorithm over candidate plans for solving a problem in natural language | | ReRead | `re2` | Implements rereading to improve reasoning by processing queries twice | @@ -359,6 +359,7 @@ Check this log file for connection issues, tool execution errors, and other diag | CoT Decoding | N/A for proxy | Implements chain-of-thought decoding to elicit reasoning without explicit prompting | | Entropy Decoding | N/A for proxy | Implements adaptive sampling based on the uncertainty of tokens during generation | | Thinkdeeper | N/A for proxy | Implements the `reasoning_effort` param from OpenAI for reasoning models like DeepSeek R1 | +| AutoThink | N/A for proxy | Combines query complexity classification with steering vectors to enhance reasoning | ## Implemented plugins diff --git a/optillm/autothink/README.md b/optillm/autothink/README.md new file mode 100644 index 00000000..74199d1d --- /dev/null +++ b/optillm/autothink/README.md @@ -0,0 +1,95 @@ +# AutoThink + +AutoThink is an adaptive thinking approach for Large Language Models that combines query complexity classification with steering vector guidance to enhance model reasoning capabilities. + +## Overview + +AutoThink combines several advanced techniques to optimize the thinking process of LLMs: + +1. **Query Complexity Classification**: Uses an adaptive classifier to determine if a query requires HIGH or LOW complexity reasoning +2. **Token Budget Allocation**: Dynamically allocates thinking tokens based on query complexity +3. **Steering Vector Guidance**: Applies activation-based steering vectors to guide the model's reasoning process +4. 
**Controlled Thinking Process**: Manages explicit thinking phases with start and end tokens + +## How It Works + +### 1. Query Classification + +AutoThink uses the `adaptive-classifier/llm-router` model to classify incoming queries: + +- **HIGH**: Complex queries requiring deep reasoning, multi-step calculations, or thorough exploration +- **LOW**: Simpler queries requiring less extensive reasoning + +### 2. Token Budget + +Based on the classification, AutoThink allocates different token budgets for the thinking phase: + +- **HIGH**: 70-90% of max tokens allocated for thinking +- **LOW**: 20-40% of max tokens allocated for thinking + +### 3. Steering Vectors + +AutoThink uses pre-extracted steering vectors from datasets like `codelion/Qwen3-0.6B-pts-steering-vectors`. These vectors represent different reasoning patterns: + +- **Depth and thoroughness**: Encourages detailed, step-by-step reasoning +- **Numerical accuracy**: Promotes precise calculations and verification +- **Self-correction**: Facilitates error detection and correction +- **Exploration**: Supports considering multiple approaches +- **Organization**: Improves logical structure in responses + +During inference, the model's internal activations are modified based on these vectors to enhance specific reasoning capabilities. + +### 4. Controlled Thinking Process + +The generation process includes: +1. A thinking phase marked by `` and `` tokens +2. Automatic adjustment of thinking time based on query complexity +3. Dynamic application of steering vectors +4. Graceful transition to the final response + +## Configuration + +AutoThink can be configured with: + +```python +{ + "model_name": "your-model-name", + "classifier_model": "adaptive-classifier/llm-router", + "steering_dataset": "codelion/Qwen3-0.6B-pts-steering-vectors", + "target_layer": 19, # Layer to apply steering vectors + "high_complexity_min_tokens": 1024, + "high_complexity_max_tokens": 4096, + "low_complexity_min_tokens": 256, + "low_complexity_max_tokens": 1024, + "pattern_strengths": { + "depth_and_thoroughness": 2.5, # Steering strength for different patterns + "numerical_accuracy": 2.0, + "self_correction": 3.0, + "exploration": 2.0, + "organization": 1.5 + } +} +``` + +## Usage + +```python +from optillm.autothink import autothink_decode + +response = autothink_decode( + model, + tokenizer, + messages, + { + "steering_dataset": "codelion/Qwen3-0.6B-pts-steering-vectors", + "target_layer": 19 + } +) +``` + +## Benefits + +- **Adaptive Resource Usage**: Models think more on complex problems and less on simple ones +- **Enhanced Reasoning**: Steering vectors guide the model toward better reasoning patterns +- **Efficiency**: Better performance without increasing model size +- **Customizability**: Can be tailored for different domains using domain-specific steering vector datasets diff --git a/optillm/autothink/__init__.py b/optillm/autothink/__init__.py new file mode 100644 index 00000000..7b269554 --- /dev/null +++ b/optillm/autothink/__init__.py @@ -0,0 +1,7 @@ +""" +AutoThink - Adaptive thinking approach for LLMs with query complexity classification and steering vectors. +""" + +from .autothink import autothink_decode, AutoThinkProcessor + +__all__ = ["autothink_decode", "AutoThinkProcessor"] diff --git a/optillm/autothink/autothink.py b/optillm/autothink/autothink.py new file mode 100644 index 00000000..bbf24968 --- /dev/null +++ b/optillm/autothink/autothink.py @@ -0,0 +1,88 @@ +""" +AutoThink main implementation. 
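+
+A rough sketch of the flow implemented by the underlying processor (names below
+refer to the classifier and processor modules in this package):
+
+    complexity, confidence = classifier.get_complexity_with_confidence(query)
+    min_tokens, max_tokens = processor.get_token_budget(complexity)
+    # generation then runs between the configured start/end think tokens,
+    # with steering hooks installed on the configured decoder layer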
+ +This module provides the main implementation of AutoThink, combining +query complexity classification with steering vectors to enhance reasoning. +""" + +import logging +from typing import Dict, List, Any, Optional +from transformers import PreTrainedModel, PreTrainedTokenizer + +from .processor import AutoThinkProcessor + +logger = logging.getLogger(__name__) + +class AutoThinkProcessor: + """ + Main AutoThink processor class for external use. + Wraps the internal processor implementation. + """ + + def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, config: Dict[str, Any] = None): + """ + Initialize the AutoThink processor. + + Args: + model: Language model + tokenizer: Model tokenizer + config: Configuration dictionary + """ + self.config = config or {} + self.processor = None + self.model = model + self.tokenizer = tokenizer + + def __call__(self, messages: List[Dict[str, str]]) -> str: + """ + Process messages with AutoThink's controlled thinking. + + Args: + messages: List of message dictionaries + + Returns: + Generated response + """ + # Create processor on first use to allow for model loading + if self.processor is None: + self.processor = self._create_processor() + + return self.processor.process(messages) + + def _create_processor(self): + """Create the internal processor instance.""" + return AutoThinkProcessor(self.config, self.tokenizer, self.model) + +def autothink_decode( + model: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + messages: List[Dict[str, str]], + request_config: Optional[Dict[str, Any]] = None +) -> str: + """ + Main plugin execution function with AutoThink's controlled thinking process. + + Args: + model: Language model + tokenizer: Model tokenizer + messages: List of message dictionaries + request_config: Optional configuration dictionary + + Returns: + Generated response with thinking process + """ + logger.info("Starting AutoThink processing") + + # Create config dictionary + config = {} + if request_config: + config.update(request_config) + + try: + processor = AutoThinkProcessor(model, tokenizer, config) + response = processor(messages) + return response + + except Exception as e: + logger.error(f"Error in AutoThink processing: {str(e)}") + raise diff --git a/optillm/autothink/classifier.py b/optillm/autothink/classifier.py new file mode 100644 index 00000000..261faaad --- /dev/null +++ b/optillm/autothink/classifier.py @@ -0,0 +1,152 @@ +""" +Query complexity classifier for AutoThink. + +This module provides functionality to classify queries as HIGH or LOW complexity +using the adaptive-classifier model. +""" + +import logging +from typing import Dict, Any, Tuple, Optional, List, Union +import os +import sys + +logger = logging.getLogger(__name__) + +class ComplexityClassifier: + """ + Classifies queries as HIGH or LOW complexity for token budget allocation. + Uses the adaptive-classifier model for classification. + """ + + def __init__(self, model_name: str = "adaptive-classifier/llm-router"): + """ + Initialize the complexity classifier. 
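+
+        Illustrative usage (a sketch; assumes the adaptive-classifier package is
+        installed and the router model can be downloaded):
+
+            classifier = ComplexityClassifier("adaptive-classifier/llm-router")
+            label, score = classifier.get_complexity_with_confidence(
+                "Prove that the product of two odd integers is odd."
+            )
+            # label is "HIGH" or "LOW"; score is the top prediction's confidence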
+ + Args: + model_name: HuggingFace model name or path for the classifier + """ + self.model_name = model_name + self.classifier = None + + # Load model + self._load_model() + + def _load_model(self): + """Load the classification model using adaptive-classifier library.""" + try: + # Check if adaptive-classifier is installed + try: + import adaptive_classifier + except ImportError: + logger.info("Installing adaptive-classifier library...") + os.system(f"{sys.executable} -m pip install adaptive-classifier") + import adaptive_classifier + + # Import the AdaptiveClassifier class + from adaptive_classifier import AdaptiveClassifier + + logger.info(f"Loading complexity classifier model: {self.model_name}") + self.classifier = AdaptiveClassifier.from_pretrained(self.model_name) + logger.info("Classifier loaded successfully") + + except Exception as e: + logger.error(f"Error loading complexity classifier: {e}") + # Fallback to basic classification if model fails to load + self.classifier = None + + def predict(self, text: str) -> List[Tuple[str, float]]: + """ + Predict the complexity label for a given text. + + Args: + text: The query text to classify + + Returns: + List of (label, score) tuples sorted by confidence + """ + if self.classifier is None: + logger.warning("Classifier not loaded. Using fallback classification.") + return self._fallback_classification(text) + + try: + # Make prediction using the AdaptiveClassifier + predictions = self.classifier.predict(text) + logger.debug(f"Classifier predictions: {predictions}") + + # Make sure predictions are in the expected format + if isinstance(predictions, list) and all(isinstance(p, tuple) and len(p) == 2 for p in predictions): + # Sort by confidence (assuming higher score = higher confidence) + predictions.sort(key=lambda x: x[1], reverse=True) + return predictions + else: + logger.warning(f"Unexpected prediction format: {predictions}") + return self._fallback_classification(text) + + except Exception as e: + logger.error(f"Error during classification: {e}") + return self._fallback_classification(text) + + def _fallback_classification(self, text: str) -> List[Tuple[str, float]]: + """ + Simple heuristic classification when model isn't available. + + Args: + text: The query text + + Returns: + List of (label, score) tuples + """ + # Count key indicators of complexity + complexity_indicators = [ + "explain", "analyze", "compare", "evaluate", "synthesize", + "how", "why", "complex", "detail", "thorough", "comprehensive", + "step by step", "calculate", "prove", "justify", "multiple", + "consequences", "implications", "differentiate", "frameworks" + ] + + # Count mentions of complexity indicators + count = sum(1 for indicator in complexity_indicators if indicator.lower() in text.lower()) + + # Calculate complexity probability based on count and text length + text_length_factor = min(len(text) / 100, 2.0) # Cap at 2.0 + indicator_factor = min(count / 3, 1.5) # Cap at 1.5 + + # Combined factor determines HIGH vs LOW + complexity_score = text_length_factor * indicator_factor + + if complexity_score > 1.0: + return [("HIGH", 0.7), ("LOW", 0.3)] + else: + return [("LOW", 0.8), ("HIGH", 0.2)] + + def is_high_complexity(self, text: str, threshold: float = 0.5) -> bool: + """ + Determine if a query is high complexity. 
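+
+        Worked example with the heuristic fallback (used only when the classifier
+        model is unavailable): a 150-character query containing three indicator
+        words such as "explain", "why" and "step by step" gives
+        text_length_factor = min(150/100, 2.0) = 1.5 and
+        indicator_factor = min(3/3, 1.5) = 1.0, so complexity_score = 1.5 > 1.0,
+        yielding ("HIGH", 0.7), which clears the default threshold of 0.5.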
+ + Args: + text: The query text + threshold: Confidence threshold for HIGH classification + + Returns: + Boolean indicating if the query is high complexity + """ + predictions = self.predict(text) + + for label, score in predictions: + if label == "HIGH" and score >= threshold: + return True + + return False + + def get_complexity_with_confidence(self, text: str) -> Tuple[str, float]: + """ + Get the complexity label and confidence score. + + Args: + text: The query text + + Returns: + Tuple of (complexity_label, confidence_score) + """ + predictions = self.predict(text) + return predictions[0] # Return highest confidence prediction diff --git a/optillm/autothink/example.py b/optillm/autothink/example.py new file mode 100644 index 00000000..952e2c36 --- /dev/null +++ b/optillm/autothink/example.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +""" +Example usage of AutoThink. + +This script demonstrates how to use AutoThink with a language model. +""" + +import torch +import argparse +import logging +from transformers import AutoModelForCausalLM, AutoTokenizer + +from optillm.autothink import autothink_decode + +# Set up logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +def main(): + parser = argparse.ArgumentParser(description="Run AutoThink demo") + parser.add_argument("--model", type=str, default="deepseek-ai/deepseek-r1-llama-8b", + help="Model name or path") + parser.add_argument("--steering-dataset", type=str, + default="codelion/Qwen3-0.6B-pts-steering-vectors", + help="Steering vectors dataset") + parser.add_argument("--target-layer", type=int, default=19, + help="Target layer for steering") + parser.add_argument("--query", type=str, + default="Explain quantum computing to me in detail", + help="Query to process") + + args = parser.parse_args() + + # Load model and tokenizer + try: + logger.info(f"Loading model: {args.model}") + + # Determine device + device = "cuda" if torch.cuda.is_available() else "cpu" + logger.info(f"Using device: {device}") + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(args.model) + + # Load model with appropriate configuration based on device + model_kwargs = {"trust_remote_code": True} + + if device == "cuda": + model_kwargs["torch_dtype"] = torch.float16 + model_kwargs["device_map"] = "auto" + + model = AutoModelForCausalLM.from_pretrained(args.model, **model_kwargs) + + # Ensure proper PAD token + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + logger.info("Model and tokenizer loaded successfully") + + # Create AutoThink configuration + config = { + "steering_dataset": args.steering_dataset, + "target_layer": args.target_layer, + "pattern_strengths": { + "depth_and_thoroughness": 2.5, + "numerical_accuracy": 2.0, + "self_correction": 3.0, + "exploration": 2.0, + "organization": 1.5 + } + } + + # Create messages + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": args.query} + ] + + # Process with AutoThink + logger.info("Running AutoThink processing...") + response = autothink_decode(model, tokenizer, messages, config) + + # Print response + print("\n" + "=" * 80) + print("QUERY:", args.query) + print("-" * 80) + print(response) + print("=" * 80 + "\n") + + except Exception as e: + logger.error(f"Error in AutoThink demo: {str(e)}") + raise + +if __name__ == "__main__": + main() diff --git a/optillm/autothink/processor.py 
b/optillm/autothink/processor.py new file mode 100644 index 00000000..39a4d67c --- /dev/null +++ b/optillm/autothink/processor.py @@ -0,0 +1,376 @@ +""" +AutoThink processor implementation. + +This module implements the AutoThink processor for controlled thinking +with query complexity classification and steering vectors. +""" + +import torch +import random +import logging +from transformers import PreTrainedModel, PreTrainedTokenizer, DynamicCache +from typing import Dict, List, Any, Optional, Union, Tuple + +from .classifier import ComplexityClassifier +from .steering import SteeringVectorManager, install_steering_hooks, remove_steering_hooks + +logger = logging.getLogger(__name__) + +# Default configurations +DEFAULT_CONFIG = { + # General configuration + "min_thinking_tokens": 256, + "max_thinking_tokens": 2048, + "max_thoughts": 64, + "prefill": "", + "start_think_token": "", + "end_think_token": "", + + # Complexity-specific configurations + "high_complexity_min_tokens": 1024, + "high_complexity_max_tokens": 4096, + "low_complexity_min_tokens": 256, + "low_complexity_max_tokens": 1024, + + # Thought switch tokens + "thought_switch_tokens": [ + "Wait,", + "Alternatively,", + "However,", + "Additionally,", + "Let's consider,", + "On second thought,", + "Actually,", + "Furthermore,", + "Looking at it differently,", + "To be thorough," + ], + + # Classifier configuration + "classifier_model": "adaptive-classifier/llm-router", + "complexity_threshold": 0.6, + + # Steering configuration + "steering_dataset": "", + "target_layer": 19, + "pattern_strengths": { + "depth_and_thoroughness": 2.5, + "numerical_accuracy": 2.0, + "self_correction": 3.0, + "exploration": 2.0, + "organization": 1.5 + } +} + +class AutoThinkProcessor: + """ + AutoThink processor for controlled thinking with + complexity classification and steering vectors. + """ + + def __init__(self, config: Dict[str, Any], tokenizer: PreTrainedTokenizer, model: PreTrainedModel): + """ + Initialize the AutoThink processor. 
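+
+        Illustrative construction (a sketch; model and tokenizer loading elided,
+        and note the argument order here is config, tokenizer, model):
+
+            processor = AutoThinkProcessor(
+                {"steering_dataset": "codelion/Qwen3-0.6B-pts-steering-vectors",
+                 "target_layer": 19},
+                tokenizer,
+                model,
+            )
+            response = processor.process(messages)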
+ + Args: + config: Configuration dictionary + tokenizer: Model tokenizer + model: Language model + """ + # Merge default config with provided config + self.config = {**DEFAULT_CONFIG, **config} + self.tokenizer = tokenizer + self.model = model + + # Initialize classifier + self.classifier = ComplexityClassifier(self.config["classifier_model"]) + + # Get token IDs for think markers + start_tokens = self.tokenizer.encode(self.config['start_think_token']) + end_tokens = self.tokenizer.encode(self.config['end_think_token']) + self._start_think_token = start_tokens[0] if len(start_tokens) == 1 else start_tokens[1] + self.end_think_token = end_tokens[0] if len(end_tokens) == 1 else end_tokens[1] + + # Store thought switch markers as token sequences + self.thought_switch_sequences = [] + for phrase in self.config["thought_switch_tokens"]: + token_ids = self.tokenizer.encode(phrase, add_special_tokens=False) + self.thought_switch_sequences.append(token_ids) + logger.debug(f"Encoded '{phrase}' to token sequence: {token_ids}") + logger.debug(f"Decoded back: {self.tokenizer.decode(token_ids)}") + + # Track thought switches + self.thought_count = 0 + self.current_sequence = [] # Track recent tokens for sequence matching + self.max_sequence_length = max(len(seq) for seq in self.thought_switch_sequences) + + # Initialize steering vector manager and hooks if dataset is provided + self.steering_manager = None + self.steering_hooks = [] + + if self.config["steering_dataset"]: + self._setup_steering() + + def _setup_steering(self): + """Set up steering vector management.""" + try: + # Initialize steering vector manager + self.steering_manager = SteeringVectorManager( + dataset_name=self.config["steering_dataset"], + target_layer=self.config["target_layer"] + ) + + # Set pattern strengths + if "pattern_strengths" in self.config: + for pattern, strength in self.config["pattern_strengths"].items(): + self.steering_manager.set_steering_strength(pattern, strength) + + # Create tokenized contexts for efficient matching + self.steering_manager.create_tokenized_contexts(self.tokenizer) + + # Install hooks on the model + self.steering_hooks = install_steering_hooks( + self.model, + self.steering_manager, + self.tokenizer + ) + + logger.info(f"Set up steering with {len(self.steering_hooks)} hooks") + + except Exception as e: + logger.error(f"Error setting up steering: {e}") + self.steering_manager = None + self.steering_hooks = [] + + def _cleanup_steering(self): + """Clean up steering hooks.""" + if self.steering_hooks: + remove_steering_hooks(self.steering_hooks) + self.steering_hooks = [] + + def classify_complexity(self, query: str) -> Tuple[str, float]: + """ + Classify query complexity. + + Args: + query: The query to classify + + Returns: + Tuple of (complexity_label, confidence_score) + """ + complexity, confidence = self.classifier.get_complexity_with_confidence(query) + logger.info(f"Query classified as {complexity} with confidence {confidence:.2f}") + return complexity, confidence + + def get_token_budget(self, complexity: str) -> Tuple[int, int]: + """ + Get token budget based on complexity. 
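+
+        With the default configuration this maps:
+
+            get_token_budget("HIGH") -> (1024, 4096)
+            get_token_budget("LOW")  -> (256, 1024)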
+ + Args: + complexity: Complexity label (HIGH or LOW) + + Returns: + Tuple of (min_tokens, max_tokens) + """ + if complexity == "HIGH": + return ( + self.config["high_complexity_min_tokens"], + self.config["high_complexity_max_tokens"] + ) + else: + return ( + self.config["low_complexity_min_tokens"], + self.config["low_complexity_max_tokens"] + ) + + def is_thought_switch(self, token: int) -> bool: + """ + Check if adding this token creates a thought switch sequence. + + Args: + token: Token ID to check + + Returns: + Boolean indicating if this completes a thought switch + """ + # Add new token to current sequence + self.current_sequence.append(token) + + # Keep only the most recent tokens that could match our sequences + if len(self.current_sequence) > self.max_sequence_length: + self.current_sequence = self.current_sequence[-self.max_sequence_length:] + + # Check if current sequence ends with any thought switch sequence + for sequence in self.thought_switch_sequences: + if len(sequence) <= len(self.current_sequence) and \ + self.current_sequence[-len(sequence):] == sequence: + return True + + return False + + @torch.inference_mode() + def process(self, messages: List[Dict[str, str]]) -> str: + """ + Process messages with AutoThink's controlled thinking. + + Args: + messages: List of message dictionaries + + Returns: + Generated response + """ + try: + # Extract the query from the messages + query = self._extract_query(messages) + + # Classify query complexity + complexity, confidence = self.classify_complexity(query) + + # Get token budget based on complexity + min_tokens, max_tokens = self.get_token_budget(complexity) + logger.info(f"Using token budget: {min_tokens}-{max_tokens} for {complexity} complexity") + + # Prepare messages with thinking start token + thinking_messages = messages.copy() + thinking_messages.append({ + "role": "assistant", + "content": f"{self.config['start_think_token']}\n{self.config['prefill']}" + }) + + # Tokenize the messages + tokens = self.tokenizer.apply_chat_template( + thinking_messages, + continue_final_message=True, + return_tensors="pt" + ).to(self.model.device) + + # Update token history in steering hooks + if self.steering_hooks: + token_ids = tokens[0].tolist() + for hook, _ in self.steering_hooks: + hook.update_token_history(token_ids) + # Try to match with a steering vector + hook.try_match() + + # Generate with controlled thinking + kv = DynamicCache() + n_thinking_tokens = 0 + seen_end_think = False + response_chunks = [] + + while True: + out = self.model(input_ids=tokens, past_key_values=kv, use_cache=True) + logits = out.logits[0, -1, :] + + # Check if we need to force end thinking + force_end = (n_thinking_tokens >= max_tokens or + self.thought_count >= self.config["max_thoughts"]) + + if force_end and not seen_end_think: + logger.debug(f"Forcing end think token. 
Tokens: {n_thinking_tokens}, Thoughts: {self.thought_count}") + next_token = self.end_think_token + response_chunks.append(self.tokenizer.decode([next_token])) + seen_end_think = True + tokens = torch.tensor([[next_token]]).to(tokens.device) + continue + else: + next_token = torch.multinomial( + torch.softmax(logits, dim=-1), 1 + ).item() + + kv = out.past_key_values + next_str = self.tokenizer.decode([next_token]) + + # Update steering hooks with new token + if self.steering_hooks: + for hook, _ in self.steering_hooks: + hook.update_token_history([next_token]) + + # Check if this is a thought-switching token (only if not in conclusion phase) + if not seen_end_think and self.is_thought_switch(next_token): + self.thought_count += 1 + logger.debug(f"Detected thought switch marker. Total thoughts: {self.thought_count}") + self.current_sequence = [] + + # Handle natural end think token + if next_token == self.end_think_token: + seen_end_think = True + logger.debug("Found end think token") + + # If we haven't reached minimum tokens, continue with thought transition + if n_thinking_tokens < min_tokens: + replacement = random.choice(self.config["thought_switch_tokens"]) + logger.debug(f"Inserting thought transition: '{replacement}' (tokens: {n_thinking_tokens})") + response_chunks.append(replacement) + replacement_tokens = self.tokenizer.encode(replacement) + n_thinking_tokens += len(replacement_tokens) + tokens = torch.tensor([replacement_tokens]).to(tokens.device) + self.thought_count += 1 + seen_end_think = False + continue + + # Handle EOS token + if next_token == self.model.config.eos_token_id: + logger.debug("Found EOS token") + if seen_end_think: + logger.debug("Reached EOS after end think token - stopping generation") + response_chunks.append(next_str) + break + elif n_thinking_tokens < min_tokens: + # Continue with thought transition if under minimum tokens + replacement = random.choice(self.config["thought_switch_tokens"]) + logger.debug(f"Inserting thought transition: '{replacement}' (tokens: {n_thinking_tokens})") + response_chunks.append(replacement) + replacement_tokens = self.tokenizer.encode(replacement) + n_thinking_tokens += len(replacement_tokens) + tokens = torch.tensor([replacement_tokens]).to(tokens.device) + self.thought_count += 1 + continue + else: + # Force end think token and continue generating for natural conclusion + logger.debug("Reached EOS without end think token - adding end token and continuing generation") + response_chunks.append(self.tokenizer.decode([self.end_think_token])) + tokens = torch.tensor([[self.end_think_token]]).to(tokens.device) + seen_end_think = True + continue + + # Normal token processing + response_chunks.append(next_str) + if not seen_end_think: + n_thinking_tokens += 1 + tokens = torch.tensor([[next_token]]).to(tokens.device) + + # Clean up steering hooks + self._cleanup_steering() + + # Join all chunks and add framing tokens + response = "".join(response_chunks) + full_response = f"{self.config['start_think_token']}\n{self.config['prefill']}{response}" + + logger.debug(f"Final response length: {len(full_response)} chars, Total thoughts: {self.thought_count}") + return full_response + + except Exception as e: + # Clean up steering hooks in case of error + self._cleanup_steering() + logger.error(f"Error in AutoThink processing: {str(e)}") + raise + + def _extract_query(self, messages: List[Dict[str, str]]) -> str: + """ + Extract the query from messages for classification. 
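+
+        For example, given
+            [{"role": "system", "content": "You are a helpful assistant."},
+             {"role": "user", "content": "What is 17 * 24?"}]
+        this returns "What is 17 * 24?" (the most recent user message).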
+ + Args: + messages: List of message dictionaries + + Returns: + Extracted query string + """ + # Get the last user message + user_messages = [m["content"] for m in messages if m["role"] == "user"] + + if user_messages: + return user_messages[-1] + + # Fallback to concatenated messages + return " ".join(m["content"] for m in messages) diff --git a/optillm/autothink/steering.py b/optillm/autothink/steering.py new file mode 100644 index 00000000..56c68297 --- /dev/null +++ b/optillm/autothink/steering.py @@ -0,0 +1,603 @@ +""" +Steering vector manager for AutoThink. + +This module provides functionality to load and apply steering vectors +from Hugging Face datasets during inference. +""" + +import torch +import logging +import random +import json +import datasets +from typing import Dict, List, Any, Tuple, Optional, Union +from collections import defaultdict + +logger = logging.getLogger(__name__) + +class SteeringVectorManager: + """ + Manager for loading and applying steering vectors from a dataset. + """ + + def __init__( + self, + dataset_name: str, + target_layer: int = 19, + cache_dir: Optional[str] = None, + device: Optional[str] = None + ): + """ + Initialize the steering vector manager. + + Args: + dataset_name: Name of the HuggingFace dataset containing steering vectors + target_layer: Target layer for applying steering vectors + cache_dir: Directory for caching the dataset + device: Device to use for tensors + """ + self.dataset_name = dataset_name + self.target_layer = target_layer + self.cache_dir = cache_dir + self.device = device or ( + "cuda" if torch.cuda.is_available() else + "mps" if torch.backends.mps.is_available() else + "cpu" + ) + + # Storage for steering vectors + self.steering_vectors = [] + self.pattern_to_vectors = {} + self.tokenized_contexts = {} + + # Default steering strengths + self.default_strength = 2.0 + self.pattern_strengths = { + "depth_and_thoroughness": 2.5, + "numerical_accuracy": 2.0, + "self_correction": 3.0, + "exploration": 2.0, + "organization": 1.5, + "unknown": 1.0 + } + + # If dataset is provided, load it + if dataset_name: + self.load_dataset() + + def load_dataset(self): + """Load steering vectors from the HuggingFace dataset.""" + try: + logger.info(f"Loading steering vectors from dataset: {self.dataset_name}") + + # Load the dataset + dataset = datasets.load_dataset(self.dataset_name, cache_dir=self.cache_dir) + + # Get the main split (usually 'train') + main_split = list(dataset.keys())[0] + vector_data = dataset[main_split] + + # Load each item as a steering vector + for item in vector_data: + # Convert dataset item to proper format + vector = self._process_dataset_item(item) + if vector: + self.steering_vectors.append(vector) + + # Group by reasoning pattern + pattern = vector.get("reasoning_pattern", "unknown") + if pattern not in self.pattern_to_vectors: + self.pattern_to_vectors[pattern] = [] + self.pattern_to_vectors[pattern].append(vector) + + logger.info(f"Loaded {len(self.steering_vectors)} steering vectors") + logger.info(f"Found {len(self.pattern_to_vectors)} reasoning patterns: {list(self.pattern_to_vectors.keys())}") + + # Log the first vector for debugging + if self.steering_vectors: + first_vector = self.steering_vectors[0] + logger.info(f"First vector sample - pattern: {first_vector.get('reasoning_pattern', 'missing')}") + if 'pivot_context' in first_vector: + context_len = len(first_vector['pivot_context']) + logger.info(f"First vector pivot_context length: {context_len}") + + except Exception as e: + 
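            # Loading is best-effort: fall back to running without steering rather than failing the request.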
logger.error(f"Error loading steering vectors: {e}") + self.steering_vectors = [] + self.pattern_to_vectors = {} + + def _process_dataset_item(self, item: Dict[str, Any]) -> Dict[str, Any]: + """ + Process a dataset item into a steering vector. + + Args: + item: Dataset item + + Returns: + Processed steering vector or None if invalid + """ + try: + # Check if item has the required fields + required_fields = ["pivot_context", "steering_vector", "reasoning_pattern"] + if not all(field in item for field in required_fields): + return None + + # Convert steering_vector to a proper format if it's a string or list + steering_vector = item["steering_vector"] + if isinstance(steering_vector, str): + # Try to parse JSON string + try: + steering_vector = json.loads(steering_vector) + except json.JSONDecodeError: + # Try comma-separated format + steering_vector = [float(x) for x in steering_vector.strip("[]").split(",")] + + # Ensure we have a proper list + if not isinstance(steering_vector, list): + logger.warning(f"Invalid steering vector format: {type(steering_vector)}") + return None + + # Create the steering vector dictionary + vector = { + "pivot_context": item["pivot_context"], + "pivot_token": item.get("pivot_token", ""), + "pivot_token_id": item.get("pivot_token_id", -1), + "prob_before": item.get("prob_before", 0.0), + "prob_after": item.get("prob_after", 0.0), + "prob_delta": item.get("prob_delta", 0.0), + "model_id": item.get("model_id", ""), + "task_type": item.get("task_type", "unknown"), + "steering_vector": steering_vector, + "cluster_id": item.get("cluster_id", -1), + "reasoning_pattern": item.get("reasoning_pattern", "unknown"), + "cluster_vector": item.get("cluster_vector", steering_vector), + "steering_layer": item.get("steering_layer", self.target_layer), + } + + return vector + + except Exception as e: + logger.error(f"Error processing dataset item: {e}") + return None + + def create_tokenized_contexts(self, tokenizer): + """ + Pre-tokenize context patterns for efficient matching during generation. 
+ + Args: + tokenizer: Tokenizer for encoding contexts + """ + # Get configurations + max_pts_tokens = 256 # Maximum tokens to store for matching + + count = 0 + for vector in self.steering_vectors: + # Get the context + context = vector.get("pivot_context", "") + if not context: + continue + + # Pre-tokenize the context for faster matching + tokenized_context = tokenizer.encode(context, add_special_tokens=False) + + # Keep only up to max_pts_tokens + if len(tokenized_context) > max_pts_tokens: + tokenized_context = tokenized_context[-max_pts_tokens:] + + # Store the tokenized context with its vector + tuple_key = tuple(tokenized_context) + self.tokenized_contexts[tuple_key] = vector + + # Store additional shorter versions for partial matching + for suffix_len in [4, 8, 12]: + if len(tokenized_context) > suffix_len: + suffix = tokenized_context[-suffix_len:] + suffix_tuple = tuple(suffix) + if suffix_tuple not in self.tokenized_contexts: + self.tokenized_contexts[suffix_tuple] = vector + + count += 1 + + # Log statistics + logger.info(f"Pre-tokenized {count} contexts into {len(self.tokenized_contexts)} token patterns") + + # Count patterns by length for debugging + length_counts = {} + for key in self.tokenized_contexts.keys(): + length = len(key) + if length not in length_counts: + length_counts[length] = 0 + length_counts[length] += 1 + + logger.info(f"Token pattern length distribution: {sorted(length_counts.items())}") + + def get_steering_strength(self, pattern: str) -> float: + """ + Get the steering strength for a specific pattern. + + Args: + pattern: The reasoning pattern + + Returns: + The steering strength + """ + return self.pattern_strengths.get(pattern, self.default_strength) + + def set_steering_strength(self, pattern: str, strength: float): + """ + Set the steering strength for a specific pattern. + + Args: + pattern: The reasoning pattern + strength: The steering strength + """ + self.pattern_strengths[pattern] = strength + logger.info(f"Set strength for {pattern} to {strength}") + + def get_pattern_vectors(self, pattern: str) -> List[Dict[str, Any]]: + """ + Get all steering vectors for a specific reasoning pattern. + + Args: + pattern: The reasoning pattern + + Returns: + List of steering vectors + """ + return self.pattern_to_vectors.get(pattern, []) + +class SteeringHook: + """Hook for applying steering vectors during generation.""" + + def __init__(self, manager: SteeringVectorManager, layer_num: int, tokenizer=None): + """ + Initialize the steering hook. + + Args: + manager: The steering vector manager + layer_num: The layer number to apply steering to + tokenizer: Tokenizer for token-based matching + """ + self.manager = manager + self.layer_num = layer_num + self.tokenizer = tokenizer + + # For token-based matching + self.token_history = [] # Store token IDs for matching + self.max_history = 256 # Maximum tokens to keep in history + + # State tracking + self.match_found = False + self.current_vector = None + self.last_pattern = None + + # Single pattern for entire request + self.active_pattern = None # Currently active pattern + self.generation_started = False + + logger.info(f"Initialized hook for layer {layer_num}") + + def __call__(self, module, input_tensors, output): + """ + Apply steering to the output of a layer. 
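+
+        Note: the hook is registered with module.register_forward_hook (see
+        install_steering_hooks below), so it runs after the target layer's
+        forward pass, and a non-None return value replaces the layer's output.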
+ + Args: + module: The module being hooked + input_tensors: The input tensors + output: The output tensor + + Returns: + Modified output tensor + """ + try: + # Skip if no active pattern is set + if not self.active_pattern: + return output + + # Apply steering vector if available + if self.current_vector is not None: + # Get the appropriate steering strength + pattern = self.current_vector.get("reasoning_pattern", "unknown") + strength = self.manager.get_steering_strength(pattern) + + # Keep strength within safe bounds + safe_strength = min(max(strength, 0.1), 2.0) + + # Log when pattern changes + if pattern != self.last_pattern: + logger.info(f"Switching to {pattern} reasoning pattern with strength {safe_strength}") + self.last_pattern = pattern + + # Apply the steering vector + try: + if isinstance(output, tuple): + # Some models return a tuple + hidden_states = output[0] + modified_hidden_states = self._apply_steering_vector(hidden_states, self.current_vector, safe_strength) + + # Validate the result + if modified_hidden_states.shape == hidden_states.shape: + return (modified_hidden_states,) + output[1:] + else: + logger.error(f"Modified hidden states have wrong shape. Expected {hidden_states.shape}, got {modified_hidden_states.shape}") + return output + else: + # Direct tensor output + return self._apply_steering_vector(output, self.current_vector, safe_strength) + + except Exception as e: + logger.error(f"Error applying steering: {e}") + return output + + return output + except Exception as e: + logger.error(f"Critical error in hook: {e}") + return output + + def _apply_steering_vector(self, hidden_states: torch.Tensor, + steering_vector: Dict[str, Any], + scaling_factor: float = 2.0) -> torch.Tensor: + """ + Apply a steering vector to hidden states. 
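+
+        The core update is, in sketch form (the body adds shape checks, scaling
+        bounds and NaN guards):
+
+            hidden_states[-1, -1, :] += scaling * normalize(steering_vector)
+            # subtracted instead when the vector is marked as negative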
+ + Args: + hidden_states: The hidden states tensor + steering_vector: Dictionary with steering vector data + scaling_factor: Factor to scale the steering vector by + + Returns: + Modified hidden states tensor + """ + try: + # Make a deep clone + hidden_states_clone = hidden_states.clone().detach() + + # Check what kind of vector we're using + vector_data = None + if "steering_vector" in steering_vector: + vector_data = steering_vector["steering_vector"] + vector_type = "steering_vector" + elif "cluster_vector" in steering_vector: + vector_data = steering_vector["cluster_vector"] + vector_type = "cluster_vector" + else: + logger.warning("No valid vector found in steering data") + return hidden_states + + # Convert vector to tensor + vector = torch.tensor(vector_data, + dtype=hidden_states.dtype, + device=hidden_states.device) + + # Log vector info + pattern = steering_vector.get("reasoning_pattern", "unknown") + logger.debug(f"Applying {vector_type} for pattern '{pattern}' with scaling {scaling_factor}") + + # Apply scaling based on prob_delta if available + if "prob_delta" in steering_vector: + prob_delta = abs(steering_vector["prob_delta"]) + prob_delta_capped = min(max(prob_delta, 0.1), 2.0) + scaling_factor *= prob_delta_capped + + # Check if the token is positive or negative + is_positive = steering_vector.get("is_positive", True) + + # Verify shapes are compatible + hs_shape = hidden_states.shape + vector_shape = vector.shape + + if len(vector_shape) != 1 or vector_shape[0] != hs_shape[-1]: + logger.error(f"Shape mismatch - hidden_states: {hs_shape}, vector: {vector_shape}") + return hidden_states + + # Bound scaling factor for safety + safe_scaling = min(max(scaling_factor, 0.0), 3.0) + + # Apply steering + if len(hs_shape) >= 3 and hs_shape[0] > 0 and hs_shape[1] > 0: + # Apply to the last token's representation + if is_positive: + # Normalize vector to prevent numerical instability + vector_norm = torch.nn.functional.normalize(vector, dim=0) + hidden_states_clone[-1, -1, :] = hidden_states_clone[-1, -1, :] + safe_scaling * vector_norm + else: + vector_norm = torch.nn.functional.normalize(vector, dim=0) + hidden_states_clone[-1, -1, :] = hidden_states_clone[-1, -1, :] - safe_scaling * vector_norm + + # Check for NaN or inf values + if torch.isnan(hidden_states_clone).any() or torch.isinf(hidden_states_clone).any(): + logger.error("NaN or inf values detected after applying vector, reverting to original") + return hidden_states + else: + logger.error(f"Hidden states shape not suitable for steering: {hs_shape}") + return hidden_states + + return hidden_states_clone + except Exception as e: + logger.error(f"Unexpected error applying steering vector: {e}") + return hidden_states + + def update_token_history(self, new_tokens: List[int]): + """ + Update the token history with new tokens. + + Args: + new_tokens: New token IDs to add + """ + # Add to token history + self.token_history.extend(new_tokens) + + # Trim history if needed + if len(self.token_history) > self.max_history: + self.token_history = self.token_history[-self.max_history:] + + # Log token updates periodically + if random.random() < 0.01: + logger.debug(f"Token history updated, now has {len(self.token_history)} tokens") + + def try_match(self) -> bool: + """ + Try to match the current context with a steering vector. 
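+
+        A pattern is selected at most once per request: the first successful
+        match sets self.active_pattern, and subsequent calls return False
+        without re-matching.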
+ + Returns: + Boolean indicating if a match was found + """ + # If we already have an active pattern, don't try to match again + if self.generation_started and self.active_pattern: + return False + + # Only attempt pattern matching at the beginning of generation + self.generation_started = True + + # Try token-based matching + match_result = self._try_token_match() + + # If a match is found, set this as the permanent pattern for this generation + if match_result and self.current_vector: + new_pattern = self.current_vector.get("reasoning_pattern", "unknown") + self.active_pattern = new_pattern + logger.info(f"Selected '{new_pattern}' pattern for this request") + + return match_result + + def _try_token_match(self) -> bool: + """ + Try to match using token-based context. + + Returns: + Boolean indicating if a match was found + """ + # Ensure we have enough tokens + if len(self.token_history) < 4: + return False + + # Track best match + best_match = { + 'length': 0, + 'vector': None, + 'is_partial': True + } + + # Check for matches in tokenized contexts + for tokenized_context, vector in self.manager.tokenized_contexts.items(): + token_list = list(tokenized_context) + token_len = len(token_list) + + # Try partial matching for shorter contexts + if len(self.token_history) < token_len: + # Only try partial matching if we have enough context tokens + if len(self.token_history) >= 4: + # Calculate how many tokens to match + match_len = min(len(self.token_history), max(4, token_len // 2)) + # Try to match the end of the token sequence + if self.token_history[-match_len:] == token_list[-match_len:]: + # Track this match - prefer longer matches + if match_len > best_match['length']: + best_match = { + 'length': match_len, + 'vector': vector, + 'is_partial': True, + 'match_len': match_len, + 'token_len': token_len + } + else: + # Full matching when we have enough tokens + if self.token_history[-token_len:] == token_list: + # Track this match - full matches are preferred + if token_len >= best_match['length']: + best_match = { + 'length': token_len, + 'vector': vector, + 'is_partial': False, + 'match_len': token_len, + 'token_len': token_len + } + + # Apply best match if found + if best_match['vector'] is not None: + match_type = "PARTIAL" if best_match['is_partial'] else "FULL" + self.match_found = True + self.current_vector = best_match['vector'] + pattern = best_match['vector'].get("reasoning_pattern", "unknown") + logger.info(f"Found {match_type} token match ({best_match['match_len']}/{best_match['token_len']} tokens) for {pattern} pattern") + return True + + return False + + def reset(self): + """Reset the hook state.""" + self.match_found = False + self.current_vector = None + self.token_history = [] + self.last_pattern = None + self.active_pattern = None + self.generation_started = False + +def install_steering_hooks(model, manager: SteeringVectorManager, tokenizer=None) -> List[Tuple]: + """ + Install steering hooks on a model. 
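+
+    Typical usage (mirrors AutoThinkProcessor._setup_steering and
+    _cleanup_steering):
+
+        hooks = install_steering_hooks(model, manager, tokenizer)
+        ...  # run generation
+        remove_steering_hooks(hooks)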
+ + Args: + model: The model to install hooks on + manager: The steering vector manager + tokenizer: Tokenizer for token-based matching + + Returns: + List of installed hooks + """ + hooks = [] + + # Target layer is specified in the manager + layer_num = manager.target_layer + logger.info(f"Attempting to install hook on layer {layer_num}") + + # First, log model structure to help with debugging + model_type = type(model).__name__ + logger.info(f"Model type is {model_type}") + + # Find the appropriate module - depends on model architecture + module = None + if hasattr(model, 'transformer'): + logger.info("Model has 'transformer' attribute") + if hasattr(model.transformer, 'h') and layer_num < len(model.transformer.h): + module = model.transformer.h[layer_num] + logger.info(f"Using transformer.h[{layer_num}]") + elif hasattr(model, 'model'): + logger.info("Model has 'model' attribute") + if hasattr(model.model, 'layers') and layer_num < len(model.model.layers): + module = model.model.layers[layer_num] + logger.info(f"Using model.layers[{layer_num}]") + elif hasattr(model.model, 'decoder') and hasattr(model.model.decoder, 'layers') and layer_num < len(model.model.decoder.layers): + module = model.model.decoder.layers[layer_num] + logger.info(f"Using model.decoder.layers[{layer_num}]") + elif hasattr(model, 'layers') and layer_num < len(model.layers): + module = model.layers[layer_num] + logger.info(f"Using layers[{layer_num}]") + + if module is None: + logger.error(f"Could not find appropriate module for layer {layer_num}") + logger.error("Model structure not compatible with current hook installation logic") + return [] + + # Create and register hook + hook = SteeringHook(manager, layer_num, tokenizer) + handle = module.register_forward_hook(hook) + + # Return both hook object and handle for later removal + hooks.append((hook, handle)) + + logger.info(f"Installed hook on layer {layer_num} successfully") + + return hooks + +def remove_steering_hooks(hooks): + """ + Remove steering hooks from a model. 
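+
+    Each handle is the one returned by register_forward_hook in
+    install_steering_hooks above.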
+ + Args: + hooks: List of (hook, handle) tuples + """ + for _, handle in hooks: + handle.remove() + + logger.info(f"Removed {len(hooks)} hooks") diff --git a/optillm/inference.py b/optillm/inference.py index f7f265c0..624b5a5d 100644 --- a/optillm/inference.py +++ b/optillm/inference.py @@ -20,6 +20,7 @@ from optillm.cot_decoding import cot_decode from optillm.entropy_decoding import entropy_decode from optillm.thinkdeeper import thinkdeeper_decode +from optillm.autothink import autothink_decode # Configure logging logging.basicConfig(level=logging.INFO) @@ -1467,6 +1468,34 @@ def create( responses = [result] logprobs_results = [None] completion_tokens = len(pipeline.tokenizer.encode(result)) + elif decoding == "autothink": + # Get steering dataset configuration + steering_dataset = kwargs.get("steering_dataset", "codelion/Qwen3-0.6B-pts-steering-vectors") + target_layer = kwargs.get("target_layer", 19) + + # Prepare AutoThink configuration + autothink_config = { + "steering_dataset": steering_dataset, + "target_layer": target_layer, + "pattern_strengths": kwargs.get("pattern_strengths", { + "depth_and_thoroughness": 2.5, + "numerical_accuracy": 2.0, + "self_correction": 3.0, + "exploration": 2.0, + "organization": 1.5 + }) + } + + # Process with AutoThink + result = autothink_decode( + pipeline.current_model, + pipeline.tokenizer, + messages, + autothink_config + ) + responses = [result] + logprobs_results = [None] + completion_tokens = len(pipeline.tokenizer.encode(result)) else: raise ValueError(f"Unknown specialized decoding approach: {decoding}") diff --git a/requirements.txt b/requirements.txt index 651f8d3e..1996b88d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -27,4 +27,5 @@ spacy<3.8.0 cerebras_cloud_sdk outlines[transformers] sentencepiece +adaptive-classifier mcp \ No newline at end of file diff --git a/setup.py b/setup.py index 9c521910..ea4d4102 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,7 @@ "outlines[transformers]", "sentencepiece", "mcp", + "adaptive-classifier", ], entry_points={ 'console_scripts': [ From 335b900e0f7ee6adaffdaeb5ef430e200ea32491 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Wed, 7 May 2025 10:47:33 +0800 Subject: [PATCH 02/11] Update autothink.py --- optillm/autothink/autothink.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/optillm/autothink/autothink.py b/optillm/autothink/autothink.py index bbf24968..2169558c 100644 --- a/optillm/autothink/autothink.py +++ b/optillm/autothink/autothink.py @@ -9,7 +9,7 @@ from typing import Dict, List, Any, Optional from transformers import PreTrainedModel, PreTrainedTokenizer -from .processor import AutoThinkProcessor +from .processor import AutoThinkProcessor as InternalProcessor logger = logging.getLogger(__name__) @@ -34,8 +34,11 @@ def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, confi self.tokenizer = tokenizer def __call__(self, messages: List[Dict[str, str]]) -> str: - """ - Process messages with AutoThink's controlled thinking. + """Process messages with AutoThink's controlled thinking.""" + return self.process(messages) + + def process(self, messages: List[Dict[str, str]]) -> str: + """Process messages with AutoThink's controlled thinking. 
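+
+        Example (a sketch; this wrapper takes model, tokenizer, config in that
+        order, unlike the internal processor):
+
+            processor = AutoThinkProcessor(model, tokenizer, {"target_layer": 19})
+            text = processor.process(messages)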
Args: messages: List of message dictionaries @@ -51,7 +54,7 @@ def __call__(self, messages: List[Dict[str, str]]) -> str: def _create_processor(self): """Create the internal processor instance.""" - return AutoThinkProcessor(self.config, self.tokenizer, self.model) + return InternalProcessor(self.config, self.tokenizer, self.model) def autothink_decode( model: PreTrainedModel, @@ -80,7 +83,7 @@ def autothink_decode( try: processor = AutoThinkProcessor(model, tokenizer, config) - response = processor(messages) + response = processor.process(messages) return response except Exception as e: From c304749c37ea36669f7b59b6638a1ffd6327b767 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Wed, 7 May 2025 16:55:35 +0800 Subject: [PATCH 03/11] Update steering.py --- optillm/autothink/steering.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/optillm/autothink/steering.py b/optillm/autothink/steering.py index 56c68297..931a3a70 100644 --- a/optillm/autothink/steering.py +++ b/optillm/autothink/steering.py @@ -303,10 +303,14 @@ def __call__(self, module, input_tensors, output): # Keep strength within safe bounds safe_strength = min(max(strength, 0.1), 2.0) - # Log when pattern changes + # Log when pattern changes or is applied if pattern != self.last_pattern: - logger.info(f"Switching to {pattern} reasoning pattern with strength {safe_strength}") + logger.info(f"[STEERING ACTIVATED] Switching to {pattern} reasoning pattern with strength {safe_strength}") self.last_pattern = pattern + else: + # Log periodically that steering is still active (every ~20 tokens) + if random.random() < 0.05: + logger.info(f"[STEERING ACTIVE] Applying {pattern} pattern with strength {safe_strength}") # Apply the steering vector try: @@ -478,6 +482,11 @@ def _try_token_match(self) -> bool: 'is_partial': True } + # Log token history periodically + if random.random() < 0.01: + history_sample = self.token_history[-5:] if len(self.token_history) >= 5 else self.token_history + logger.debug(f"Token matching with history (last {len(history_sample)} of {len(self.token_history)} tokens): {history_sample}") + # Check for matches in tokenized contexts for tokenized_context, vector in self.manager.tokenized_contexts.items(): token_list = list(tokenized_context) @@ -519,7 +528,12 @@ def _try_token_match(self) -> bool: self.match_found = True self.current_vector = best_match['vector'] pattern = best_match['vector'].get("reasoning_pattern", "unknown") - logger.info(f"Found {match_type} token match ({best_match['match_len']}/{best_match['token_len']} tokens) for {pattern} pattern") + pivot_token = best_match['vector'].get("pivot_token", "") + + logger.info(f"[STEERING MATCH FOUND] {match_type} token match for '{pattern}' pattern") + logger.info(f"[STEERING DETAILS] Match quality: {best_match['match_len']}/{best_match['token_len']} tokens") + logger.info(f"[STEERING DETAILS] Pivot token: '{pivot_token}'") + return True return False From 24fa4434b7f73447006b28236fde50f985f6de21 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Wed, 7 May 2025 18:03:55 +0800 Subject: [PATCH 04/11] test steering --- optillm/autothink/processor.py | 15 +- optillm/autothink/steering.py | 371 ++++++++++++++++++++++++--------- 2 files changed, 288 insertions(+), 98 deletions(-) diff --git a/optillm/autothink/processor.py b/optillm/autothink/processor.py index 39a4d67c..277647c8 100644 --- a/optillm/autothink/processor.py +++ b/optillm/autothink/processor.py @@ -135,10 +135,10 @@ def _setup_steering(self): 
self.tokenizer ) - logger.info(f"Set up steering with {len(self.steering_hooks)} hooks") + logger.info(f"STEERING: Set up steering with {len(self.steering_hooks)} hooks") except Exception as e: - logger.error(f"Error setting up steering: {e}") + logger.error(f"STEERING: Error setting up steering: {e}") self.steering_manager = None self.steering_hooks = [] @@ -147,6 +147,7 @@ def _cleanup_steering(self): if self.steering_hooks: remove_steering_hooks(self.steering_hooks) self.steering_hooks = [] + logger.info("STEERING: Hooks removed successfully") def classify_complexity(self, query: str) -> Tuple[str, float]: """ @@ -338,6 +339,16 @@ def process(self, messages: List[Dict[str, str]]) -> str: response_chunks.append(next_str) if not seen_end_think: n_thinking_tokens += 1 + + # Update steering hooks with new token + if self.steering_hooks: + for hook, _ in self.steering_hooks: + # Update token history with the new token + hook.update_token_history([next_token]) + # Check for matches occasionally during generation + if random.random() < 0.1: # 10% chance per token + hook.try_match() + tokens = torch.tensor([[next_token]]).to(tokens.device) # Clean up steering hooks diff --git a/optillm/autothink/steering.py b/optillm/autothink/steering.py index 931a3a70..dc4d81fb 100644 --- a/optillm/autothink/steering.py +++ b/optillm/autothink/steering.py @@ -163,11 +163,12 @@ def _process_dataset_item(self, item: Dict[str, Any]) -> Dict[str, Any]: def create_tokenized_contexts(self, tokenizer): """ Pre-tokenize context patterns for efficient matching during generation. + Similar to how guided mode does token-based matching. Args: tokenizer: Tokenizer for encoding contexts """ - # Get configurations + # Get configurations - use similar defaults as ThinkDeeperProcessor max_pts_tokens = 256 # Maximum tokens to store for matching count = 0 @@ -177,29 +178,32 @@ def create_tokenized_contexts(self, tokenizer): if not context: continue - # Pre-tokenize the context for faster matching + # Pre-tokenize the context for faster matching during generation tokenized_context = tokenizer.encode(context, add_special_tokens=False) - # Keep only up to max_pts_tokens + # Keep only up to max_pts_tokens - no point storing more than our history capacity if len(tokenized_context) > max_pts_tokens: + # Get only the last max_pts_tokens (most important for matching) tokenized_context = tokenized_context[-max_pts_tokens:] - # Store the tokenized context with its vector + # Store the tokenized context with its corresponding vector tuple_key = tuple(tokenized_context) self.tokenized_contexts[tuple_key] = vector - # Store additional shorter versions for partial matching + # Store additional shorter versions for partial matching during early generation + # Create shorter suffixes for early matching when context is still building up for suffix_len in [4, 8, 12]: if len(tokenized_context) > suffix_len: suffix = tokenized_context[-suffix_len:] suffix_tuple = tuple(suffix) + # Only store if not already present (avoid overwriting longer matches) if suffix_tuple not in self.tokenized_contexts: self.tokenized_contexts[suffix_tuple] = vector count += 1 - # Log statistics - logger.info(f"Pre-tokenized {count} contexts into {len(self.tokenized_contexts)} token patterns") + # Log statistics about the tokenized contexts + logger.info(f"STEERING: Pre-tokenized {count} contexts into {len(self.tokenized_contexts)} token patterns") # Count patterns by length for debugging length_counts = {} @@ -209,7 +213,7 @@ def create_tokenized_contexts(self, 
tokenizer): length_counts[length] = 0 length_counts[length] += 1 - logger.info(f"Token pattern length distribution: {sorted(length_counts.items())}") + logger.info(f"STEERING: Token pattern length distribution: {sorted(length_counts.items())}") def get_steering_strength(self, pattern: str) -> float: """ @@ -232,7 +236,7 @@ def set_steering_strength(self, pattern: str, strength: float): strength: The steering strength """ self.pattern_strengths[pattern] = strength - logger.info(f"Set strength for {pattern} to {strength}") + logger.info(f"STEERING: Set strength for {pattern} to {strength}") def get_pattern_vectors(self, pattern: str) -> List[Dict[str, Any]]: """ @@ -245,6 +249,38 @@ def get_pattern_vectors(self, pattern: str) -> List[Dict[str, Any]]: List of steering vectors """ return self.pattern_to_vectors.get(pattern, []) + + def get_steering_vector(self, context: str, match_key: Optional[str] = None) -> Optional[Dict[str, Any]]: + """ + Get the most appropriate steering vector for a context. + + Args: + context: The current generation context. + match_key: Optional key for matching. + + Returns: + Dictionary with steering data or None if no match. + """ + if match_key is not None: + # Try exact matching by key + for vector in self.steering_vectors: + # Get the last 100 chars of the pivot_context for comparison + vector_context = vector.get("pivot_context", "") + vector_key = vector_context[-100:] if len(vector_context) >= 100 else vector_context + + # Perform exact match comparison and log for debugging + if vector_key == match_key: + logger.debug(f"STEERING: Context match found for '{vector.get('pivot_token', '')}' with pattern {vector.get('reasoning_pattern', 'unknown')}") + return vector + + # For first 5 attempts, log debugging info when match fails + if random.random() < 0.001: # Log a small random sample for debugging + logger.debug(f"STEERING: Match failed - key length: {len(match_key)}, vector key length: {len(vector_key)}") + logger.debug(f"STEERING: Match key sample: '{match_key[:20]}...'") + logger.debug(f"STEERING: Vector key sample: '{vector_key[:20]}...'") + + # If no match found, return None + return None class SteeringHook: """Hook for applying steering vectors during generation.""" @@ -262,7 +298,10 @@ def __init__(self, manager: SteeringVectorManager, layer_num: int, tokenizer=Non self.layer_num = layer_num self.tokenizer = tokenizer - # For token-based matching + # For text-based matching (original approach) + self.context_buffer = "" + + # For token-based matching (guided-style approach) self.token_history = [] # Store token IDs for matching self.max_history = 256 # Maximum tokens to keep in history @@ -275,7 +314,7 @@ def __init__(self, manager: SteeringVectorManager, layer_num: int, tokenizer=Non self.active_pattern = None # Currently active pattern self.generation_started = False - logger.info(f"Initialized hook for layer {layer_num}") + logger.info(f"STEERING: Initialized hook for layer {layer_num}") def __call__(self, module, input_tensors, output): """ @@ -289,53 +328,66 @@ def __call__(self, module, input_tensors, output): Returns: Modified output tensor """ + # Use a try-except block around the entire function to prevent crashing try: # Skip if no active pattern is set if not self.active_pattern: return output - # Apply steering vector if available + # Apply steering vector (only if we have an active pattern) if self.current_vector is not None: # Get the appropriate steering strength pattern = self.current_vector.get("reasoning_pattern", "unknown") strength = 
self.manager.get_steering_strength(pattern) - # Keep strength within safe bounds - safe_strength = min(max(strength, 0.1), 2.0) + # Keep strength within safe bounds - use lower values for better stability + safe_strength = min(max(strength, 0.1), 2.0) # Limit between 0.1 and 2.0 - # Log when pattern changes or is applied + # Log when pattern changes if pattern != self.last_pattern: - logger.info(f"[STEERING ACTIVATED] Switching to {pattern} reasoning pattern with strength {safe_strength}") + logger.info(f"STEERING: Switching to {pattern} reasoning pattern with strength {safe_strength}") self.last_pattern = pattern else: - # Log periodically that steering is still active (every ~20 tokens) + # Log periodically that steering is still active if random.random() < 0.05: - logger.info(f"[STEERING ACTIVE] Applying {pattern} pattern with strength {safe_strength}") + logger.info(f"STEERING: Still applying {pattern} pattern with strength {safe_strength}") - # Apply the steering vector + # Apply the steering vector using our safer function try: if isinstance(output, tuple): - # Some models return a tuple + # Some models return a tuple where the first element is the hidden states hidden_states = output[0] - modified_hidden_states = self._apply_steering_vector(hidden_states, self.current_vector, safe_strength) - # Validate the result - if modified_hidden_states.shape == hidden_states.shape: - return (modified_hidden_states,) + output[1:] - else: - logger.error(f"Modified hidden states have wrong shape. Expected {hidden_states.shape}, got {modified_hidden_states.shape}") + # Apply steering - if it fails, return original + try: + # Create a new reference for the modified hidden states + modified_hidden_states = self._apply_steering_vector(hidden_states, self.current_vector, safe_strength) + # Validate the result has the right shape + if modified_hidden_states.shape == hidden_states.shape: + # Create a new tuple with the modified hidden states + return (modified_hidden_states,) + output[1:] + else: + logger.error(f"STEERING: Modified hidden states have wrong shape. 
Expected {hidden_states.shape}, got {modified_hidden_states.shape}") + return output + except Exception as e: + logger.error(f"STEERING: Error applying steering to tuple output: {e}") return output else: # Direct tensor output - return self._apply_steering_vector(output, self.current_vector, safe_strength) + try: + # Apply steering directly + return self._apply_steering_vector(output, self.current_vector, safe_strength) + except Exception as e: + logger.error(f"STEERING: Error applying steering to direct output: {e}") + return output except Exception as e: - logger.error(f"Error applying steering: {e}") + logger.error(f"STEERING: Unexpected error in steering application: {e}") return output return output except Exception as e: - logger.error(f"Critical error in hook: {e}") + logger.error(f"STEERING: Critical error in hook: {e}") return output def _apply_steering_vector(self, hidden_states: torch.Tensor, @@ -353,11 +405,11 @@ def _apply_steering_vector(self, hidden_states: torch.Tensor, Modified hidden states tensor """ try: - # Make a deep clone + # Make a DEEP clone to avoid in-place modification issues hidden_states_clone = hidden_states.clone().detach() # Check what kind of vector we're using - vector_data = None + vector_type = None if "steering_vector" in steering_vector: vector_data = steering_vector["steering_vector"] vector_type = "steering_vector" @@ -365,62 +417,111 @@ def _apply_steering_vector(self, hidden_states: torch.Tensor, vector_data = steering_vector["cluster_vector"] vector_type = "cluster_vector" else: - logger.warning("No valid vector found in steering data") + logger.warning("STEERING: No valid vector found in steering data") + return hidden_states # No steering vector found + + # Safely convert vector to tensor + try: + vector = torch.tensor(vector_data, + dtype=hidden_states.dtype, + device=hidden_states.device) + except Exception as e: + logger.error(f"STEERING: Error converting vector to tensor: {e}") return hidden_states - # Convert vector to tensor - vector = torch.tensor(vector_data, - dtype=hidden_states.dtype, - device=hidden_states.device) - # Log vector info pattern = steering_vector.get("reasoning_pattern", "unknown") - logger.debug(f"Applying {vector_type} for pattern '{pattern}' with scaling {scaling_factor}") + logger.debug(f"STEERING: Applying {vector_type} for pattern '{pattern}' with base scaling {scaling_factor}") # Apply scaling based on prob_delta if available if "prob_delta" in steering_vector: prob_delta = abs(steering_vector["prob_delta"]) + # Limit the impact of prob_delta to prevent extreme scaling prob_delta_capped = min(max(prob_delta, 0.1), 2.0) scaling_factor *= prob_delta_capped + logger.debug(f"STEERING: Adjusted scaling by prob_delta {prob_delta_capped} to {scaling_factor}") # Check if the token is positive or negative is_positive = steering_vector.get("is_positive", True) - # Verify shapes are compatible + # Log tensor shapes and verify compatibility hs_shape = hidden_states.shape vector_shape = vector.shape + logger.debug(f"STEERING: hidden_states shape: {hs_shape}, vector shape: {vector_shape}") + # Verify shapes are compatible if len(vector_shape) != 1 or vector_shape[0] != hs_shape[-1]: - logger.error(f"Shape mismatch - hidden_states: {hs_shape}, vector: {vector_shape}") + logger.error(f"STEERING: Shape mismatch - hidden_states: {hs_shape}, vector: {vector_shape}") return hidden_states - # Bound scaling factor for safety - safe_scaling = min(max(scaling_factor, 0.0), 3.0) - - # Apply steering - if len(hs_shape) >= 3 and 
hs_shape[0] > 0 and hs_shape[1] > 0: - # Apply to the last token's representation - if is_positive: - # Normalize vector to prevent numerical instability - vector_norm = torch.nn.functional.normalize(vector, dim=0) - hidden_states_clone[-1, -1, :] = hidden_states_clone[-1, -1, :] + safe_scaling * vector_norm + # Bound scaling factor for safety - using a tighter range to prevent instability + safe_scaling = min(max(scaling_factor, 0.0), 3.0) # Limit between 0 and 3 + + # Apply steering with safe indexing - with additional safeguards + try: + if len(hs_shape) >= 3 and hs_shape[0] > 0 and hs_shape[1] > 0: + # Apply to the last token's representation (safe indexing) + if is_positive: + # For positive tokens, add the vector + # Normalize vector first to prevent numerical instability + vector_norm = torch.nn.functional.normalize(vector, dim=0) + hidden_states_clone[-1, -1, :] = hidden_states_clone[-1, -1, :] + safe_scaling * vector_norm + else: + # For negative tokens, subtract the vector + vector_norm = torch.nn.functional.normalize(vector, dim=0) + hidden_states_clone[-1, -1, :] = hidden_states_clone[-1, -1, :] - safe_scaling * vector_norm + + # Check for NaN or inf values after modification + if torch.isnan(hidden_states_clone).any() or torch.isinf(hidden_states_clone).any(): + logger.error("STEERING: NaN or inf values detected after applying vector, reverting to original") + return hidden_states else: - vector_norm = torch.nn.functional.normalize(vector, dim=0) - hidden_states_clone[-1, -1, :] = hidden_states_clone[-1, -1, :] - safe_scaling * vector_norm - - # Check for NaN or inf values - if torch.isnan(hidden_states_clone).any() or torch.isinf(hidden_states_clone).any(): - logger.error("NaN or inf values detected after applying vector, reverting to original") + logger.error(f"STEERING: Hidden states shape not suitable for steering: {hs_shape}") return hidden_states - else: - logger.error(f"Hidden states shape not suitable for steering: {hs_shape}") + except IndexError as e: + logger.error(f"STEERING: IndexError when applying vector: {e}") + logger.error(f"STEERING: Indices: [-1, -1, :], tensor shape: {hidden_states.shape}") return hidden_states return hidden_states_clone except Exception as e: - logger.error(f"Unexpected error applying steering vector: {e}") + logger.error(f"STEERING: Unexpected error applying steering vector: {e}") return hidden_states + def update_context(self, new_tokens: str): + """ + Update the context buffer with new tokens. + + Args: + new_tokens: New tokens to add to the context. 
+ """ + # Both methods - text-based and token-based + if self.tokenizer is not None: + # Token-based approach (similar to guided mode) + # Tokenize the new text + token_ids = self.tokenizer.encode(new_tokens, add_special_tokens=False) + + if token_ids: # Only proceed if we got tokens + # Add to token history + self.token_history.extend(token_ids) + + # Trim history if needed + if len(self.token_history) > self.max_history: + self.token_history = self.token_history[-self.max_history:] + + # Log token updates periodically + if random.random() < 0.01: + logger.debug(f"STEERING: Token history updated, now has {len(self.token_history)} tokens") + else: + # Original text-based approach as fallback + # Update context buffer + self.context_buffer += new_tokens + + # Keep only the last 500 characters + if len(self.context_buffer) > 500: + self.context_buffer = self.context_buffer[-500:] + logger.debug(f"STEERING: Context buffer trimmed to {len(self.context_buffer)} chars") + def update_token_history(self, new_tokens: List[int]): """ Update the token history with new tokens. @@ -437,44 +538,45 @@ def update_token_history(self, new_tokens: List[int]): # Log token updates periodically if random.random() < 0.01: - logger.debug(f"Token history updated, now has {len(self.token_history)} tokens") + logger.debug(f"STEERING: Token history updated, now has {len(self.token_history)} tokens") - def try_match(self) -> bool: + def try_match(self): """ Try to match the current context with a steering vector. - - Returns: - Boolean indicating if a match was found + Only allows one pattern to be selected for the entire generation. """ - # If we already have an active pattern, don't try to match again + # If we already have an active pattern for this generation, don't try to match again if self.generation_started and self.active_pattern: return False # Only attempt pattern matching at the beginning of generation self.generation_started = True - - # Try token-based matching - match_result = self._try_token_match() - + + # Use token-based matching or text-based matching as appropriate + if self.tokenizer is not None and hasattr(self.manager, 'tokenized_contexts') and self.manager.tokenized_contexts: + # Token-based matching (similar to guided mode) + match_result = self._try_token_match() + else: + # Text-based matching as fallback + match_result = self._try_text_match() + # If a match is found, set this as the permanent pattern for this generation if match_result and self.current_vector: new_pattern = self.current_vector.get("reasoning_pattern", "unknown") self.active_pattern = new_pattern - logger.info(f"Selected '{new_pattern}' pattern for this request") - + logger.info(f"STEERING: Selected '{new_pattern}' pattern for this request") + return match_result - def _try_token_match(self) -> bool: + def _try_token_match(self): """ - Try to match using token-based context. - - Returns: - Boolean indicating if a match was found + Try to match using token-based context (similar to guided mode). 
""" # Ensure we have enough tokens if len(self.token_history) < 4: + logger.debug(f"STEERING: Not enough tokens to match ({len(self.token_history)})") return False - + # Track best match best_match = { 'length': 0, @@ -485,7 +587,7 @@ def _try_token_match(self) -> bool: # Log token history periodically if random.random() < 0.01: history_sample = self.token_history[-5:] if len(self.token_history) >= 5 else self.token_history - logger.debug(f"Token matching with history (last {len(history_sample)} of {len(self.token_history)} tokens): {history_sample}") + logger.debug(f"STEERING: Token matching with history (last {len(history_sample)} of {len(self.token_history)} tokens): {history_sample}") # Check for matches in tokenized contexts for tokenized_context, vector in self.manager.tokenized_contexts.items(): @@ -494,11 +596,11 @@ def _try_token_match(self) -> bool: # Try partial matching for shorter contexts if len(self.token_history) < token_len: - # Only try partial matching if we have enough context tokens + # Only try partial matching if we have enough context tokens (at least 4) if len(self.token_history) >= 4: - # Calculate how many tokens to match + # Calculate how many tokens to match - minimum of context length or 1/2 of token sequence match_len = min(len(self.token_history), max(4, token_len // 2)) - # Try to match the end of the token sequence + # Try to match the end of the token sequence with the context tokens if self.token_history[-match_len:] == token_list[-match_len:]: # Track this match - prefer longer matches if match_len > best_match['length']: @@ -530,11 +632,83 @@ def _try_token_match(self) -> bool: pattern = best_match['vector'].get("reasoning_pattern", "unknown") pivot_token = best_match['vector'].get("pivot_token", "") - logger.info(f"[STEERING MATCH FOUND] {match_type} token match for '{pattern}' pattern") - logger.info(f"[STEERING DETAILS] Match quality: {best_match['match_len']}/{best_match['token_len']} tokens") - logger.info(f"[STEERING DETAILS] Pivot token: '{pivot_token}'") + logger.info(f"STEERING: Found {match_type} token match ({best_match['match_len']}/{best_match['token_len']} tokens) for {pattern} pattern") + logger.info(f"STEERING: Pivot token: '{pivot_token}'") return True + + # If no match, try fuzzy matching with 70% similarity threshold + if len(self.token_history) >= 8 and not self.match_found: + logger.debug("STEERING: No exact match found, trying fuzzy matching") + for tokenized_context, vector in self.manager.tokenized_contexts.items(): + token_list = list(tokenized_context) + token_len = len(token_list) + + if token_len >= 8: # Only try fuzzy matching for contexts with enough tokens + match_len = min(len(self.token_history), token_len) + last_tokens = self.token_history[-match_len:] + context_tokens = token_list[-match_len:] + + # Count matching tokens + matches = sum(1 for a, b in zip(last_tokens, context_tokens) if a == b) + similarity = matches / match_len + + if similarity >= 0.7: # 70% similarity threshold + if match_len > best_match['length']: + best_match = { + 'length': match_len, + 'vector': vector, + 'is_partial': True, + 'match_len': match_len, + 'token_len': token_len, + 'similarity': similarity + } + + # Apply fuzzy match if found + if best_match['vector'] is not None: + self.match_found = True + self.current_vector = best_match['vector'] + pattern = best_match['vector'].get("reasoning_pattern", "unknown") + pivot_token = best_match['vector'].get("pivot_token", "") + similarity = best_match.get('similarity', 0.0) + + 
logger.info(f"STEERING: Found fuzzy match ({similarity:.2f} similarity) for {pattern} pattern") + logger.info(f"STEERING: Pivot token: '{pivot_token}'") + + return True + + # TEMPORARY: Force a match for testing purposes sometimes + if not self.match_found and len(self.manager.steering_vectors) > 0 and random.random() < 0.05: + logger.info("STEERING: Forcing a random steering vector match for testing") + # Pick a random vector + random_vector = random.choice(self.manager.steering_vectors) + self.match_found = True + self.current_vector = random_vector + pattern = random_vector.get("reasoning_pattern", "unknown") + logger.info(f"STEERING: Forced '{pattern}' pattern for testing") + return True + + return False + + def _try_text_match(self): + """Try to match using text-based context (original approach).""" + # Get the last 100 characters as the match key + match_key = self.context_buffer[-100:] if len(self.context_buffer) >= 100 else self.context_buffer + + # Log context buffer periodically to debug + if random.random() < 0.01: # Log occasionally to avoid spam + logger.debug(f"STEERING: Current context buffer (last 50 chars): '{self.context_buffer[-50:]}'") + logger.debug(f"STEERING: Matching with key (length {len(match_key)}): '{match_key[:20]}...'") + + # Try to find a matching steering vector using original matching + vector = self.manager.get_steering_vector(self.context_buffer, match_key) + + if vector is not None: + self.match_found = True + self.current_vector = vector + pattern = vector.get("reasoning_pattern", "unknown") + logger.info(f"STEERING: Found text match for {pattern} reasoning pattern: '{vector.get('pivot_token', '')}'") + return True return False @@ -542,8 +716,11 @@ def reset(self): """Reset the hook state.""" self.match_found = False self.current_vector = None + self.context_buffer = "" self.token_history = [] self.last_pattern = None + + # Reset pattern tracking self.active_pattern = None self.generation_started = False @@ -563,34 +740,36 @@ def install_steering_hooks(model, manager: SteeringVectorManager, tokenizer=None # Target layer is specified in the manager layer_num = manager.target_layer - logger.info(f"Attempting to install hook on layer {layer_num}") + logger.info(f"STEERING: Attempting to install hook on layer {layer_num}") # First, log model structure to help with debugging model_type = type(model).__name__ - logger.info(f"Model type is {model_type}") + logger.info(f"STEERING: Model type is {model_type}") + if hasattr(model, 'config'): + logger.info(f"STEERING: Model architecture is {model.config.architectures[0] if hasattr(model.config, 'architectures') else 'unknown'}") # Find the appropriate module - depends on model architecture module = None if hasattr(model, 'transformer'): - logger.info("Model has 'transformer' attribute") + logger.info("STEERING: Model has 'transformer' attribute") if hasattr(model.transformer, 'h') and layer_num < len(model.transformer.h): module = model.transformer.h[layer_num] - logger.info(f"Using transformer.h[{layer_num}]") + logger.info(f"STEERING: Using transformer.h[{layer_num}]") elif hasattr(model, 'model'): - logger.info("Model has 'model' attribute") + logger.info("STEERING: Model has 'model' attribute") if hasattr(model.model, 'layers') and layer_num < len(model.model.layers): module = model.model.layers[layer_num] - logger.info(f"Using model.layers[{layer_num}]") + logger.info(f"STEERING: Using model.layers[{layer_num}]") elif hasattr(model.model, 'decoder') and hasattr(model.model.decoder, 'layers') and layer_num < 
len(model.model.decoder.layers): module = model.model.decoder.layers[layer_num] - logger.info(f"Using model.decoder.layers[{layer_num}]") + logger.info(f"STEERING: Using model.decoder.layers[{layer_num}]") elif hasattr(model, 'layers') and layer_num < len(model.layers): module = model.layers[layer_num] - logger.info(f"Using layers[{layer_num}]") + logger.info(f"STEERING: Using layers[{layer_num}]") if module is None: - logger.error(f"Could not find appropriate module for layer {layer_num}") - logger.error("Model structure not compatible with current hook installation logic") + logger.error(f"STEERING: Could not find appropriate module for layer {layer_num}") + logger.error("STEERING: Model structure not compatible with current hook installation logic") return [] # Create and register hook @@ -600,7 +779,7 @@ def install_steering_hooks(model, manager: SteeringVectorManager, tokenizer=None # Return both hook object and handle for later removal hooks.append((hook, handle)) - logger.info(f"Installed hook on layer {layer_num} successfully") + logger.info(f"STEERING: Installed hook on layer {layer_num} successfully") return hooks @@ -614,4 +793,4 @@ def remove_steering_hooks(hooks): for _, handle in hooks: handle.remove() - logger.info(f"Removed {len(hooks)} hooks") + logger.info(f"STEERING: Removed {len(hooks)} hooks") From 5950c7398c3e38faeea0cc9b9c4f19a130218777 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Wed, 7 May 2025 18:09:49 +0800 Subject: [PATCH 05/11] Update steering.py remove random steering test --- optillm/autothink/steering.py | 29 +++++++++-------------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/optillm/autothink/steering.py b/optillm/autothink/steering.py index dc4d81fb..3a0c4472 100644 --- a/optillm/autothink/steering.py +++ b/optillm/autothink/steering.py @@ -666,29 +666,18 @@ def _try_token_match(self): # Apply fuzzy match if found if best_match['vector'] is not None: - self.match_found = True - self.current_vector = best_match['vector'] - pattern = best_match['vector'].get("reasoning_pattern", "unknown") - pivot_token = best_match['vector'].get("pivot_token", "") - similarity = best_match.get('similarity', 0.0) - - logger.info(f"STEERING: Found fuzzy match ({similarity:.2f} similarity) for {pattern} pattern") - logger.info(f"STEERING: Pivot token: '{pivot_token}'") - - return True - - # TEMPORARY: Force a match for testing purposes sometimes - if not self.match_found and len(self.manager.steering_vectors) > 0 and random.random() < 0.05: - logger.info("STEERING: Forcing a random steering vector match for testing") - # Pick a random vector - random_vector = random.choice(self.manager.steering_vectors) self.match_found = True - self.current_vector = random_vector - pattern = random_vector.get("reasoning_pattern", "unknown") - logger.info(f"STEERING: Forced '{pattern}' pattern for testing") + self.current_vector = best_match['vector'] + pattern = best_match['vector'].get("reasoning_pattern", "unknown") + pivot_token = best_match['vector'].get("pivot_token", "") + similarity = best_match.get('similarity', 0.0) + + logger.info(f"STEERING: Found fuzzy match ({similarity:.2f} similarity) for {pattern} pattern") + logger.info(f"STEERING: Pivot token: '{pivot_token}'") + return True - return False + return False def _try_text_match(self): """Try to match using text-based context (original approach).""" From c22e40664a821754b85094bf7f34d77c2a1b345e Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Wed, 7 May 2025 18:14:12 +0800 Subject: [PATCH 
06/11] Update steering.py --- optillm/autothink/steering.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/optillm/autothink/steering.py b/optillm/autothink/steering.py index 3a0c4472..e2147c21 100644 --- a/optillm/autothink/steering.py +++ b/optillm/autothink/steering.py @@ -666,18 +666,18 @@ def _try_token_match(self): # Apply fuzzy match if found if best_match['vector'] is not None: - self.match_found = True - self.current_vector = best_match['vector'] - pattern = best_match['vector'].get("reasoning_pattern", "unknown") - pivot_token = best_match['vector'].get("pivot_token", "") - similarity = best_match.get('similarity', 0.0) - - logger.info(f"STEERING: Found fuzzy match ({similarity:.2f} similarity) for {pattern} pattern") - logger.info(f"STEERING: Pivot token: '{pivot_token}'") - - return True - - return False + self.match_found = True + self.current_vector = best_match['vector'] + pattern = best_match['vector'].get("reasoning_pattern", "unknown") + pivot_token = best_match['vector'].get("pivot_token", "") + similarity = best_match.get('similarity', 0.0) + + logger.info(f"STEERING: Found fuzzy match ({similarity:.2f} similarity) for {pattern} pattern") + logger.info(f"STEERING: Pivot token: '{pivot_token}'") + + return True + + return False def _try_text_match(self): """Try to match using text-based context (original approach).""" From d51ea08ab4d986128c23b9603d9790d173fb56d8 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Wed, 7 May 2025 18:23:58 +0800 Subject: [PATCH 07/11] fixes --- optillm/autothink/processor.py | 5 ++--- optillm/autothink/steering.py | 11 ++++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/optillm/autothink/processor.py b/optillm/autothink/processor.py index 277647c8..5ab71e2c 100644 --- a/optillm/autothink/processor.py +++ b/optillm/autothink/processor.py @@ -345,9 +345,8 @@ def process(self, messages: List[Dict[str, str]]) -> str: for hook, _ in self.steering_hooks: # Update token history with the new token hook.update_token_history([next_token]) - # Check for matches occasionally during generation - if random.random() < 0.1: # 10% chance per token - hook.try_match() + # Check for matches on EVERY token + hook.try_match() tokens = torch.tensor([[next_token]]).to(tokens.device) diff --git a/optillm/autothink/steering.py b/optillm/autothink/steering.py index e2147c21..2df60962 100644 --- a/optillm/autothink/steering.py +++ b/optillm/autothink/steering.py @@ -545,14 +545,12 @@ def try_match(self): Try to match the current context with a steering vector. Only allows one pattern to be selected for the entire generation. 
""" - # If we already have an active pattern for this generation, don't try to match again - if self.generation_started and self.active_pattern: + # If we already have an active pattern, don't try to match again + if self.active_pattern: return False - # Only attempt pattern matching at the beginning of generation - self.generation_started = True - # Use token-based matching or text-based matching as appropriate + match_result = False if self.tokenizer is not None and hasattr(self.manager, 'tokenized_contexts') and self.manager.tokenized_contexts: # Token-based matching (similar to guided mode) match_result = self._try_token_match() @@ -560,6 +558,9 @@ def try_match(self): # Text-based matching as fallback match_result = self._try_text_match() + # Set generation started flag AFTER trying to match + self.generation_started = True + # If a match is found, set this as the permanent pattern for this generation if match_result and self.current_vector: new_pattern = self.current_vector.get("reasoning_pattern", "unknown") From 9378c2f6e4a28ad76536d7192497a8322e446cee Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Tue, 13 May 2025 12:21:31 +0800 Subject: [PATCH 08/11] match both tokens and context text --- optillm/autothink/processor.py | 15 +++- optillm/autothink/steering.py | 137 +++++++++++++++++++++++++++++++-- 2 files changed, 142 insertions(+), 10 deletions(-) diff --git a/optillm/autothink/processor.py b/optillm/autothink/processor.py index 5ab71e2c..2fcb4998 100644 --- a/optillm/autothink/processor.py +++ b/optillm/autothink/processor.py @@ -245,11 +245,16 @@ def process(self, messages: List[Dict[str, str]]) -> str: return_tensors="pt" ).to(self.model.device) - # Update token history in steering hooks + # Reset and update token history in steering hooks if self.steering_hooks: token_ids = tokens[0].tolist() + prompt_text = self.tokenizer.decode(token_ids) for hook, _ in self.steering_hooks: + # Reset the hook state for a new generation + hook.reset() + # Update both token history and text context buffer hook.update_token_history(token_ids) + hook.update_context(prompt_text) # Try to match with a steering vector hook.try_match() @@ -343,13 +348,19 @@ def process(self, messages: List[Dict[str, str]]) -> str: # Update steering hooks with new token if self.steering_hooks: for hook, _ in self.steering_hooks: - # Update token history with the new token + # Update both token history and text context hook.update_token_history([next_token]) + hook.update_context(next_str) # Check for matches on EVERY token hook.try_match() tokens = torch.tensor([[next_token]]).to(tokens.device) + # Reset and clean up steering hooks + if self.steering_hooks: + for hook, _ in self.steering_hooks: + hook.reset() + # Clean up steering hooks self._cleanup_steering() diff --git a/optillm/autothink/steering.py b/optillm/autothink/steering.py index 2df60962..11facb8d 100644 --- a/optillm/autothink/steering.py +++ b/optillm/autothink/steering.py @@ -539,23 +539,60 @@ def update_token_history(self, new_tokens: List[int]): # Log token updates periodically if random.random() < 0.01: logger.debug(f"STEERING: Token history updated, now has {len(self.token_history)} tokens") + + def update_context(self, new_tokens: str): + """ + Update the context buffer with new tokens. + + Args: + new_tokens: New tokens to add to the context. 
+ """ + # Both methods - text-based and token-based + if self.tokenizer is not None: + # Token-based approach (similar to guided mode) + # Tokenize the new text + token_ids = self.tokenizer.encode(new_tokens, add_special_tokens=False) + + if token_ids: # Only proceed if we got tokens + # Add to token history + self.token_history.extend(token_ids) + + # Trim history if needed + if len(self.token_history) > self.max_history: + self.token_history = self.token_history[-self.max_history:] + + # Log token updates periodically + if random.random() < 0.01: + logger.debug(f"STEERING: Token history updated, now has {len(self.token_history)} tokens") + + # Text-based approach (always update) + # Update context buffer + self.context_buffer += new_tokens + + # Keep only the last 500 characters + if len(self.context_buffer) > 500: + self.context_buffer = self.context_buffer[-500:] + logger.debug(f"STEERING: Context buffer trimmed to {len(self.context_buffer)} chars") def try_match(self): """ Try to match the current context with a steering vector. Only allows one pattern to be selected for the entire generation. + Tries both token-based and text-based matching approaches. """ # If we already have an active pattern, don't try to match again if self.active_pattern: return False - # Use token-based matching or text-based matching as appropriate + # Try both token-based and text-based matching match_result = False + + # First try token-based matching if available if self.tokenizer is not None and hasattr(self.manager, 'tokenized_contexts') and self.manager.tokenized_contexts: - # Token-based matching (similar to guided mode) match_result = self._try_token_match() - else: - # Text-based matching as fallback + + # If token matching fails, try text-based matching + if not match_result: match_result = self._try_text_match() # Set generation started flag AFTER trying to match @@ -638,6 +675,46 @@ def _try_token_match(self): return True + # If no match, try fuzzy matching with 70% similarity threshold + if len(self.token_history) >= 8 and not self.match_found: + logger.debug("STEERING: No exact match found, trying fuzzy matching") + for tokenized_context, vector in self.manager.tokenized_contexts.items(): + token_list = list(tokenized_context) + token_len = len(token_list) + + if token_len >= 8: # Only try fuzzy matching for contexts with enough tokens + match_len = min(len(self.token_history), token_len) + last_tokens = self.token_history[-match_len:] + context_tokens = token_list[-match_len:] + + # Count matching tokens + matches = sum(1 for a, b in zip(last_tokens, context_tokens) if a == b) + similarity = matches / match_len + + if similarity >= 0.7: # 70% similarity threshold + if match_len > best_match['length']: + best_match = { + 'length': match_len, + 'vector': vector, + 'is_partial': True, + 'match_len': match_len, + 'token_len': token_len, + 'similarity': similarity + } + + # Apply fuzzy match if found + if best_match['vector'] is not None: + self.match_found = True + self.current_vector = best_match['vector'] + pattern = best_match['vector'].get("reasoning_pattern", "unknown") + pivot_token = best_match['vector'].get("pivot_token", "") + similarity = best_match.get('similarity', 0.0) + + logger.info(f"STEERING: Found fuzzy match ({similarity:.2f} similarity) for {pattern} pattern") + logger.info(f"STEERING: Pivot token: '{pivot_token}'") + + return True + # If no match, try fuzzy matching with 70% similarity threshold if len(self.token_history) >= 8 and not self.match_found: logger.debug("STEERING: No 
exact match found, trying fuzzy matching") @@ -682,6 +759,10 @@ def _try_token_match(self): def _try_text_match(self): """Try to match using text-based context (original approach).""" + # Skip if context buffer is too short + if len(self.context_buffer) < 10: # Require at least 10 chars for matching + return False + # Get the last 100 characters as the match key match_key = self.context_buffer[-100:] if len(self.context_buffer) >= 100 else self.context_buffer @@ -697,22 +778,62 @@ def _try_text_match(self): self.match_found = True self.current_vector = vector pattern = vector.get("reasoning_pattern", "unknown") - logger.info(f"STEERING: Found text match for {pattern} reasoning pattern: '{vector.get('pivot_token', '')}'") + pivot_token = vector.get("pivot_token", "") + logger.info(f"STEERING: Found text match for {pattern} reasoning pattern") + logger.info(f"STEERING: Pivot token: '{pivot_token}'") return True + + # Attempt fuzzy text matching as a fallback + if len(match_key) >= 20: # Only try for reasonably sized contexts + # Try each steering vector for approximate match + best_match = None + best_similarity = 0.0 + + for vector in self.manager.steering_vectors: + vector_context = vector.get("pivot_context", "") + if not vector_context or len(vector_context) < 20: + continue + + # Get the end of the vector context (last 100 chars) + vector_key = vector_context[-100:] if len(vector_context) >= 100 else vector_context + + # Calculate simple character-level similarity + min_length = min(len(match_key), len(vector_key)) + matching_chars = sum(1 for a, b in zip(match_key, vector_key) if a == b) + similarity = matching_chars / min_length if min_length > 0 else 0 + + # Keep track of best match above threshold + if similarity >= 0.7 and similarity > best_similarity: # 70% similarity threshold + best_similarity = similarity + best_match = vector + + # Use the best match if found + if best_match is not None: + self.match_found = True + self.current_vector = best_match + pattern = best_match.get("reasoning_pattern", "unknown") + pivot_token = best_match.get("pivot_token", "") + logger.info(f"STEERING: Found fuzzy text match ({best_similarity:.2f} similarity) for {pattern} pattern") + logger.info(f"STEERING: Pivot token: '{pivot_token}'") + return True return False def reset(self): - """Reset the hook state.""" + """Reset the hook state for a new generation.""" self.match_found = False self.current_vector = None + + # Clear both text and token histories self.context_buffer = "" self.token_history = [] - self.last_pattern = None - # Reset pattern tracking + # Reset pattern and state tracking + self.last_pattern = None self.active_pattern = None self.generation_started = False + + logger.info("STEERING: Hook state reset for new generation") def install_steering_hooks(model, manager: SteeringVectorManager, tokenizer=None) -> List[Tuple]: """ From 32aa02fd38c571aa1880ae813cac555cda0737b0 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Thu, 15 May 2025 18:17:16 +0800 Subject: [PATCH 09/11] Update README.md --- README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/README.md b/README.md index aa32de68..c7622c41 100644 --- a/README.md +++ b/README.md @@ -468,6 +468,16 @@ Authorization: Bearer your_secret_api_key ## SOTA results on benchmarks with optillm +### AutoThink on GPQA-Diamond & MMLU-Pro (May 2025) + +| **Model** | **GPQA-Diamond** | | **MMLU-Pro** | | 
+|----------------|-----------------------------|--------------------------|----------------------------|--------------------------| +| | Accuracy (%) | Avg. Tokens | Accuracy (%) | Avg. Tokens | +| DeepSeek-R1-Distill-Qwen-1.5B | 21.72 | 7868.26 | 25.58 | 2842.75 | +| with Fixed Budget | 28.47 | 3570.00 | 26.18 | 1815.67 | +| **with AutoThink** | **31.06** | **3520.52** | **26.38** | **1792.50** | + + ### LongCePO on LongBench v2 (Apr 2025) | Model¹ | Context window | Short samples (up to 32K words) | Medium samples (32–128K words) | @@ -552,6 +562,7 @@ called patchflows. We saw huge performance gains across all the supported patchf ![Results showing optillm mixture of agents approach used with patchflows](https://raw.githubusercontent.com/codelion/optillm/main/moa-patchwork-results.png) ## References +- [AutoThink: efficient inference for reasoning LLMs](https://dx.doi.org/10.2139/ssrn.5253327) - [Implementation](optillm/autothink) - [CePO: Empowering Llama with Reasoning using Test-Time Compute](https://cerebras.ai/blog/cepo) - [Implementation](optillm/cepo) - [LongCePO: Empowering LLMs to efficiently leverage infinite context](https://cerebras.ai/blog/longcepo) - [Implementation](optillm/plugins/longcepo) - [Chain of Code: Reasoning with a Language Model-Augmented Code Emulator](https://arxiv.org/abs/2312.04474) - [Inspired the implementation of coc plugin](optillm/plugins/coc_plugin.py) From f3e6592a96e17f00fae3a17a8c38d79d28fc110e Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Thu, 15 May 2025 18:40:14 +0800 Subject: [PATCH 10/11] Update README.md --- optillm/autothink/README.md | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/optillm/autothink/README.md b/optillm/autothink/README.md index 74199d1d..f86027f3 100644 --- a/optillm/autothink/README.md +++ b/optillm/autothink/README.md @@ -15,7 +15,7 @@ AutoThink combines several advanced techniques to optimize the thinking process ### 1. Query Classification -AutoThink uses the `adaptive-classifier/llm-router` model to classify incoming queries: +AutoThink uses the `adaptive-classifier/llm-router` [model](https://huggingface.co/adaptive-classifier/llm-router) to classify incoming queries: - **HIGH**: Complex queries requiring deep reasoning, multi-step calculations, or thorough exploration - **LOW**: Simpler queries requiring less extensive reasoning @@ -29,7 +29,7 @@ Based on the classification, AutoThink allocates different token budgets for the ### 3. Steering Vectors -AutoThink uses pre-extracted steering vectors from datasets like `codelion/Qwen3-0.6B-pts-steering-vectors`. These vectors represent different reasoning patterns: +AutoThink uses pre-extracted steering vectors from [datasets](https://huggingface.co/datasets?other=pts) like `codelion/Qwen3-0.6B-pts-steering-vectors`. 
These vectors represent different reasoning patterns: - **Depth and thoroughness**: Encourages detailed, step-by-step reasoning - **Numerical accuracy**: Promotes precise calculations and verification @@ -93,3 +93,18 @@ response = autothink_decode( - **Enhanced Reasoning**: Steering vectors guide the model toward better reasoning patterns - **Efficiency**: Better performance without increasing model size - **Customizability**: Can be tailored for different domains using domain-specific steering vector datasets + + +## Citation + +If you use this approach in your research, please cite: + +```bibtex +@article{autothink, + title={AutoThink: efficient inference for reasoning LLMs}, + author={Sharma, Asankhaya}, + journal={SSRN Artificial Intelligence eJournal}, + year={2025}, + url = {https://dx.doi.org/10.2139/ssrn.5253327} +} +``` From ce9277a1039413286d9fa6d6da78baa0dbaf89e9 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Sat, 17 May 2025 21:23:50 +0800 Subject: [PATCH 11/11] update for new release --- optillm/__init__.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/optillm/__init__.py b/optillm/__init__.py index 8fdde599..c4312267 100644 --- a/optillm/__init__.py +++ b/optillm/__init__.py @@ -2,7 +2,7 @@ import os # Version information -__version__ = "0.1.11" +__version__ = "0.1.12" # Get the path to the root optillm.py spec = util.spec_from_file_location( diff --git a/setup.py b/setup.py index ea4d4102..5e4c2fd9 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="optillm", - version="0.1.11", + version="0.1.12", packages=find_packages(include=['optillm', 'optillm.*']), # This ensures all subpackages are included py_modules=['optillm'], package_data={
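
Taken together, the steering changes in these patches reduce to one mechanism: a forward hook on a single decoder layer that adds a normalized steering vector, scaled by a bounded per-pattern strength, to the last token's hidden state. The sketch below condenses that mechanism into a standalone helper. It is illustrative rather than the project's API; the layer index (19), the example strength (2.5), and the placeholder `model`/`vec` names are assumptions taken from the README defaults above.

```python
import torch
import torch.nn.functional as F

def make_steering_hook(steering_vector: torch.Tensor, strength: float = 2.0):
    """Forward hook that nudges the last token's hidden state by a normalized vector."""
    safe_strength = min(max(strength, 0.1), 2.0)  # bounded, as in SteeringHook.__call__

    def hook(module, inputs, output):
        hidden = output[0] if isinstance(output, tuple) else output
        if hidden.dim() < 3:  # expect [batch, seq, hidden]
            return output
        vec = F.normalize(steering_vector.to(device=hidden.device, dtype=hidden.dtype), dim=0)
        steered = hidden.clone()
        steered[-1, -1, :] = steered[-1, -1, :] + safe_strength * vec  # last token only
        if torch.isnan(steered).any() or torch.isinf(steered).any():
            return output  # revert on numerical problems, as the patched hook does
        if isinstance(output, tuple):
            return (steered,) + output[1:]
        return steered

    return hook

# Usage sketch (model, vec and the layer index are placeholders; AutoThink targets layer 19 by default):
# handle = model.model.layers[19].register_forward_hook(make_steering_hook(vec, strength=2.5))
# ...generate...
# handle.remove()
```

Bounding the strength and reverting on NaN/inf mirrors the safeguards added in these patches, so a badly scaled vector degrades to a no-op rather than corrupting generation.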
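
The matching logic introduced and refined across PATCH 04–08 selects a steering vector by comparing the recent token-ID history against pre-tokenized pivot contexts: an exact suffix match is preferred, with a fuzzy fallback that accepts at least 70% overlapping token IDs. The sketch below condenses that rule; the token IDs in the example are made up, and the selection policy is simplified (it does not reproduce the longest-match preference or the per-generation pattern locking implemented above).

```python
from typing import Dict, List, Optional, Tuple

def match_pivot_context(history: List[int],
                        tokenized_contexts: Dict[tuple, dict],
                        min_tokens: int = 4,
                        fuzzy_threshold: float = 0.7) -> Optional[Tuple[dict, float]]:
    """Condensed matching rule: exact suffix match first, then >=70% token overlap."""
    if len(history) < min_tokens:
        return None
    best: Optional[Tuple[dict, float]] = None
    for ctx, vector in tokenized_contexts.items():
        ctx_tokens = list(ctx)
        match_len = min(len(history), len(ctx_tokens))
        if match_len < min_tokens:
            continue
        tail, ctx_tail = history[-match_len:], ctx_tokens[-match_len:]
        if tail == ctx_tail:
            return vector, 1.0  # exact suffix match wins immediately
        similarity = sum(a == b for a, b in zip(tail, ctx_tail)) / match_len
        if similarity >= fuzzy_threshold and (best is None or similarity > best[1]):
            best = (vector, similarity)
    return best

# Toy call with made-up token IDs:
# contexts = {(11, 42, 7, 99): {"reasoning_pattern": "self_correction"}}
# match_pivot_context([3, 11, 42, 7, 99], contexts)  # -> ({"reasoning_pattern": ...}, 1.0)
```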
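
The steering vectors themselves come from a Hugging Face dataset and are grouped by `reasoning_pattern`, with per-pattern strengths set in the configuration shown in the AutoThink README. A rough loading sketch is below; it assumes the dataset exposes a default `train` split and a `reasoning_pattern` column (field names follow `_process_dataset_item` above), which may not hold for every steering-vector dataset.

```python
from collections import defaultdict
from datasets import load_dataset

# Per-pattern strengths as listed in the AutoThink README configuration.
PATTERN_STRENGTHS = {
    "depth_and_thoroughness": 2.5,
    "numerical_accuracy": 2.0,
    "self_correction": 3.0,
    "exploration": 2.0,
    "organization": 1.5,
}

def load_pattern_vectors(dataset_name: str = "codelion/Qwen3-0.6B-pts-steering-vectors"):
    """Group steering-vector records by reasoning pattern."""
    ds = load_dataset(dataset_name, split="train")  # assumed split name
    by_pattern = defaultdict(list)
    for item in ds:
        by_pattern[item.get("reasoning_pattern", "unknown")].append(item)
    return by_pattern

# vectors = load_pattern_vectors()
# print({pattern: len(v) for pattern, v in vectors.items()})
# print(PATTERN_STRENGTHS.get("self_correction"))  # -> 3.0
```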