Merged
2 changes: 1 addition & 1 deletion optillm/__init__.py
@@ -1,5 +1,5 @@
# Version information
__version__ = "0.3.8"
__version__ = "0.3.9"

# Import from server module
from .server import (
11 changes: 8 additions & 3 deletions optillm/bon.py
@@ -4,9 +4,14 @@

logger = logging.getLogger(__name__)

def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: str, n: int = 3, request_id: str = None) -> str:
def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: str, n: int = 3, request_config: dict = None, request_id: str = None) -> str:
bon_completion_tokens = 0

# Extract max_tokens from request_config with default
max_tokens = 4096
if request_config:
max_tokens = request_config.get('max_tokens', max_tokens)

messages = [{"role": "system", "content": system_prompt},
{"role": "user", "content": initial_query}]

@@ -17,7 +22,7 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
provider_request = {
"model": model,
"messages": messages,
"max_tokens": 4096,
"max_tokens": max_tokens,
"n": n,
"temperature": 1
}
@@ -50,7 +55,7 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
provider_request = {
"model": model,
"messages": messages,
"max_tokens": 4096,
"max_tokens": max_tokens,
"temperature": 1
}
response = client.chat.completions.create(**provider_request)
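For reference, here is a minimal calling sketch for the updated `best_of_n_sampling` signature; `moa.py` below follows the same module-level pattern. The client setup, endpoint, and model name are placeholders, not part of this diff.

```python
# Hypothetical usage sketch: endpoint, API key, and model name are placeholders.
from openai import OpenAI
from optillm.bon import best_of_n_sampling

client = OpenAI(base_url="http://localhost:8000/v1", api_key="placeholder")

# Per-request options forwarded to the approach; only max_tokens is read here,
# with a fallback to 4096 when the key (or the whole dict) is absent.
request_config = {"max_tokens": 1024}

result = best_of_n_sampling(
    system_prompt="You are a helpful assistant.",
    initial_query="Summarize the CAP theorem in two sentences.",
    client=client,
    model="gpt-4o-mini",
    n=3,
    request_config=request_config,
)
print(result)
```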
21 changes: 13 additions & 8 deletions optillm/leap.py
@@ -10,7 +10,7 @@
logger = logging.getLogger(__name__)

class LEAP:
def __init__(self, system_prompt: str, client, model: str, request_id: str = None):
def __init__(self, system_prompt: str, client, model: str, request_config: dict = None, request_id: str = None):
self.system_prompt = system_prompt
self.client = client
self.model = model
@@ -19,6 +19,11 @@ def __init__(self, system_prompt: str, client, model: str, request_id: str = Non
self.high_level_principles = []
self.leap_completion_tokens = 0

# Extract max_tokens from request_config with default
self.max_tokens = 4096
if request_config:
self.max_tokens = request_config.get('max_tokens', self.max_tokens)

def extract_output(self, text: str) -> str:
match = re.search(r'<output>(.*?)(?:</output>|$)', text, re.DOTALL)
return match.group(1).strip() if match else ""
@@ -29,7 +34,7 @@ def extract_examples_from_query(self, initial_query: str) -> List[Tuple[str, str
# Prepare request for logging
provider_request = {
"model": self.model,
"max_tokens": 4096,
"max_tokens": self.max_tokens,
"messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": f"""
@@ -83,7 +88,7 @@ def generate_mistakes(self, examples: List[Tuple[str, str]]) -> List[Tuple[str,
# Prepare request for logging
provider_request = {
"model": self.model,
"max_tokens": 4096,
"max_tokens": self.max_tokens,
"messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": f"""
@@ -116,7 +121,7 @@ def generate_low_level_principles(self, mistakes: List[Tuple[str, str, str, str]
# Prepare request for logging
provider_request = {
"model": self.model,
"max_tokens": 4096,
"max_tokens": self.max_tokens,
"messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": f"""
@@ -152,7 +157,7 @@ def generate_high_level_principles(self) -> List[str]:
# Prepare request for logging
provider_request = {
"model": self.model,
"max_tokens": 4096,
"max_tokens": self.max_tokens,
"messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": f"""
@@ -185,7 +190,7 @@ def apply_principles(self, query: str) -> str:
# Prepare request for logging
provider_request = {
"model": self.model,
"max_tokens": 4096,
"max_tokens": self.max_tokens,
"messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": f"""
@@ -220,6 +225,6 @@ def solve(self, initial_query: str) -> str:

return self.apply_principles(initial_query)

def leap(system_prompt: str, initial_query: str, client, model: str, request_id: str = None) -> str:
leap_solver = LEAP(system_prompt, client, model, request_id)
def leap(system_prompt: str, initial_query: str, client, model: str, request_config: dict = None, request_id: str = None) -> str:
leap_solver = LEAP(system_prompt, client, model, request_config=request_config, request_id=request_id)
return leap_solver.solve(initial_query), leap_solver.leap_completion_tokens
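The LEAP change (and the PlanSearch change further down) shares one shape: the token limit is resolved once in `__init__` and reused for every provider request the instance builds. A standalone sketch of that pattern; the class and method names here are illustrative, not from the codebase.

```python
class ConfiguredSolver:
    """Illustrative stand-in for the LEAP/PlanSearch constructor pattern."""

    def __init__(self, request_config: dict = None):
        # Resolve the limit once; every later request reuses it.
        self.max_tokens = 4096
        if request_config:
            self.max_tokens = request_config.get("max_tokens", self.max_tokens)

    def build_request(self, model: str, messages: list) -> dict:
        return {"model": model, "messages": messages, "max_tokens": self.max_tokens}


print(ConfiguredSolver({"max_tokens": 512}).build_request("m", []))  # limit 512
print(ConfiguredSolver().build_request("m", []))                     # default 4096
```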
17 changes: 11 additions & 6 deletions optillm/mcts.py
@@ -26,7 +26,7 @@ def __init__(self, state: DialogueState, parent=None):
self.value = 0

class MCTS:
def __init__(self, simulation_depth, exploration_weight, client, model, request_id=None):
def __init__(self, simulation_depth, exploration_weight, client, model, request_config=None, request_id=None):
self.simulation_depth = simulation_depth
self.exploration_weight = exploration_weight
self.root = None
@@ -37,6 +37,11 @@ def __init__(self, simulation_depth, exploration_weight, client, model, request_
self.completion_tokens = 0
self.request_id = request_id

# Extract max_tokens from request_config with default
self.max_tokens = 4096
if request_config:
self.max_tokens = request_config.get('max_tokens', self.max_tokens)

def select(self, node: MCTSNode) -> MCTSNode:
logger.debug(f"Selecting node. Current node visits: {node.visits}, value: {node.value}")
if not node.children:
@@ -117,7 +122,7 @@ def generate_actions(self, state: DialogueState) -> List[str]:
provider_request = {
"model": self.model,
"messages": messages,
"max_tokens": 4096,
"max_tokens": self.max_tokens,
"n": n,
"temperature": 1
}
@@ -151,7 +156,7 @@ def apply_action(self, state: DialogueState, action: str) -> DialogueState:
provider_request = {
"model": self.model,
"messages": messages,
"max_tokens": 1024,
"max_tokens": min(self.max_tokens, 1024),
"n": 1,
"temperature": 1
}
@@ -220,11 +225,11 @@ def evaluate_state(self, state: DialogueState) -> float:
logger.warning("Failed to parse evaluation score. Using default value 0.5")
return 0.5 # Default to a neutral score if parsing fails

def chat_with_mcts(system_prompt: str, initial_query: str, client, model: str, num_simulations: int = 2, exploration_weight: float = 0.2,
simulation_depth: int = 1, request_id: str = None) -> str:
def chat_with_mcts(system_prompt: str, initial_query: str, client, model: str, num_simulations: int = 2, exploration_weight: float = 0.2,
simulation_depth: int = 1, request_config: dict = None, request_id: str = None) -> str:
logger.info("Starting chat with MCTS")
logger.info(f"Parameters: num_simulations={num_simulations}, exploration_weight={exploration_weight}, simulation_depth={simulation_depth}")
mcts = MCTS(simulation_depth=simulation_depth, exploration_weight=exploration_weight, client=client, model=model, request_id=request_id)
mcts = MCTS(simulation_depth=simulation_depth, exploration_weight=exploration_weight, client=client, model=model, request_config=request_config, request_id=request_id)
initial_state = DialogueState(system_prompt, [], initial_query)
logger.info(f"Initial query: {initial_query}")
final_state = mcts.search(initial_state, num_simulations)
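Note that `apply_action` now caps its per-action budget at 1024 tokens rather than using the resolved limit directly. A small sketch of that interaction; the helper name is illustrative.

```python
def per_action_limit(resolved_max_tokens: int, cap: int = 1024) -> int:
    # Mirrors "max_tokens": min(self.max_tokens, 1024) in apply_action:
    # large request limits are clamped, smaller ones pass through unchanged.
    return min(resolved_max_tokens, cap)


assert per_action_limit(4096) == 1024  # default request still capped per action
assert per_action_limit(512) == 512    # a tighter caller limit wins
```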
16 changes: 11 additions & 5 deletions optillm/moa.py
@@ -4,9 +4,15 @@

logger = logging.getLogger(__name__)

def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str, request_id: str = None) -> str:
def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str, request_config: dict = None, request_id: str = None) -> str:
logger.info(f"Starting mixture_of_agents function with model: {model}")
moa_completion_tokens = 0

# Extract max_tokens from request_config with default
max_tokens = 4096
if request_config:
max_tokens = request_config.get('max_tokens', max_tokens)

completions = []

logger.debug(f"Generating initial completions for query: {initial_query}")
@@ -19,7 +25,7 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
{"role": "system", "content": system_prompt},
{"role": "user", "content": initial_query}
],
"max_tokens": 4096,
"max_tokens": max_tokens,
"n": 3,
"temperature": 1
}
@@ -59,7 +65,7 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
{"role": "system", "content": system_prompt},
{"role": "user", "content": initial_query}
],
"max_tokens": 4096,
"max_tokens": max_tokens,
"temperature": 1
}

@@ -182,14 +188,14 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
"""

logger.debug("Generating final response")

provider_request = {
"model": model,
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": final_prompt}
],
"max_tokens": 8192,
"max_tokens": max_tokens,
"n": 1,
"temperature": 0.1
}
19 changes: 12 additions & 7 deletions optillm/plansearch.py
@@ -6,13 +6,18 @@
logger = logging.getLogger(__name__)

class PlanSearch:
def __init__(self, system_prompt: str, client, model: str, request_id: str = None):
def __init__(self, system_prompt: str, client, model: str, request_config: dict = None, request_id: str = None):
self.system_prompt = system_prompt
self.client = client
self.model = model
self.request_id = request_id
self.plansearch_completion_tokens = 0

# Extract max_tokens from request_config with default
self.max_tokens = 4096
if request_config:
self.max_tokens = request_config.get('max_tokens', self.max_tokens)

def generate_observations(self, problem: str, num_observations: int = 3) -> List[str]:
prompt = f"""You are an expert Python programmer. You will be given a competitive programming question
(problem specification). You will return several useful, non-obvious, and correct observations
@@ -27,7 +32,7 @@ def generate_observations(self, problem: str, num_observations: int = 3) -> List
# Prepare request for logging
provider_request = {
"model": self.model,
"max_tokens": 4096,
"max_tokens": self.max_tokens,
"messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": prompt}
@@ -71,7 +76,7 @@ def generate_derived_observations(self, problem: str, observations: List[str], n
# Prepare request for logging
provider_request = {
"model": self.model,
"max_tokens": 4096,
"max_tokens": self.max_tokens,
"messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": prompt}
@@ -113,7 +118,7 @@ def generate_solution(self, problem: str, observations: List[str]) -> str:
# Prepare request for logging
provider_request = {
"model": self.model,
"max_tokens": 4096,
"max_tokens": self.max_tokens,
"messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": prompt}
@@ -155,7 +160,7 @@ def implement_solution(self, problem: str, solution: str) -> str:
# Prepare request for logging
provider_request = {
"model": self.model,
"max_tokens": 4096,
"max_tokens": self.max_tokens,
"messages": [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": prompt}
@@ -204,6 +209,6 @@ def solve_multiple(self, problem: str, n: int, num_initial_observations: int = 3
solutions.append(python_implementation)
return solutions

def plansearch(system_prompt: str, initial_query: str, client, model: str, n: int = 1, request_id: str = None) -> List[str]:
planner = PlanSearch(system_prompt, client, model, request_id)
def plansearch(system_prompt: str, initial_query: str, client, model: str, n: int = 1, request_config: dict = None, request_id: str = None) -> List[str]:
planner = PlanSearch(system_prompt, client, model, request_config=request_config, request_id=request_id)
return planner.solve_multiple(initial_query, n), planner.plansearch_completion_tokens
19 changes: 12 additions & 7 deletions optillm/pvg.py
@@ -8,7 +8,7 @@

pvg_completion_tokens = 0

def generate_solutions(client, system_prompt: str, query: str, model: str, num_solutions: int, is_sneaky: bool = False, temperature: float = 0.7, request_id: str = None) -> List[str]:
def generate_solutions(client, system_prompt: str, query: str, model: str, num_solutions: int, is_sneaky: bool = False, temperature: float = 0.7, max_tokens: int = 4096, request_id: str = None) -> List[str]:
global pvg_completion_tokens
role = "sneaky" if is_sneaky else "helpful"
logger.info(f"Generating {num_solutions} {role} solutions")
@@ -36,7 +36,7 @@ def generate_solutions(client, system_prompt: str, query: str, model: str, num_s
"model": model,
"messages": messages,
"n": num_solutions,
"max_tokens": 4096,
"max_tokens": max_tokens,
"temperature": temperature,
}
response = client.chat.completions.create(**provider_request)
@@ -151,10 +151,15 @@ def extract_answer(final_state: str) -> Tuple[str, float]:
logger.warning("No answer found in the state.")
return "", 0.0

def inference_time_pv_game(system_prompt: str, initial_query: str, client, model: str, num_rounds: int = 2, num_solutions: int = 3, request_id: str = None) -> str:
def inference_time_pv_game(system_prompt: str, initial_query: str, client, model: str, num_rounds: int = 2, num_solutions: int = 3, request_config: dict = None, request_id: str = None) -> str:
global pvg_completion_tokens
logger.info(f"Starting inference-time PV game with {num_rounds} rounds and {num_solutions} solutions per round")


# Extract max_tokens from request_config with default
max_tokens = 4096
if request_config:
max_tokens = request_config.get('max_tokens', max_tokens)

best_solution = ""
best_score = -1

@@ -163,8 +168,8 @@

temperature = max(0.2, 0.7 - (round * 0.1))

helpful_solutions = generate_solutions(client, system_prompt, initial_query, model, num_solutions, temperature=temperature, request_id=request_id)
sneaky_solutions = generate_solutions(client, system_prompt, initial_query, model, num_solutions, is_sneaky=True, temperature=temperature, request_id=request_id)
helpful_solutions = generate_solutions(client, system_prompt, initial_query, model, num_solutions, temperature=temperature, max_tokens=max_tokens, request_id=request_id)
sneaky_solutions = generate_solutions(client, system_prompt, initial_query, model, num_solutions, is_sneaky=True, temperature=temperature, max_tokens=max_tokens, request_id=request_id)
all_solutions = helpful_solutions + sneaky_solutions

scores = verify_solutions(client, system_prompt, initial_query, all_solutions, model, request_id=request_id)
@@ -198,7 +203,7 @@ def inference_time_pv_game(system_prompt: str, initial_query: str, client, model
provider_request = {
"model": model,
"messages": messages,
"max_tokens": 1024,
"max_tokens": min(max_tokens, 1024),
"temperature": 0.5,
}
response = client.chat.completions.create(**provider_request)
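In pvg.py the limit is resolved once in the entry point and then threaded through `generate_solutions` as an explicit `max_tokens` keyword rather than being re-read from `request_config` inside each helper. A reduced sketch of that flow; the function names are illustrative stand-ins.

```python
def build_solution_request(prompt: str, max_tokens: int = 4096) -> dict:
    # Stand-in for the provider_request built inside generate_solutions.
    return {"messages": [{"role": "user", "content": prompt}], "max_tokens": max_tokens}


def run_pv_game(prompt: str, request_config: dict = None) -> dict:
    # Resolve once at the entry point, then pass the value down explicitly.
    max_tokens = 4096
    if request_config:
        max_tokens = request_config.get("max_tokens", max_tokens)
    return build_solution_request(prompt, max_tokens=max_tokens)


print(run_pv_game("Prove that 17 is prime.", {"max_tokens": 2048}))
```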
14 changes: 11 additions & 3 deletions optillm/reread.py
@@ -4,22 +4,28 @@

logger = logging.getLogger(__name__)

def re2_approach(system_prompt, initial_query, client, model, n=1, request_id: str = None):
def re2_approach(system_prompt, initial_query, client, model, n=1, request_config: dict = None, request_id: str = None):
"""
Implement the RE2 (Re-Reading) approach for improved reasoning in LLMs.

Args:
system_prompt (str): The system prompt to be used.
initial_query (str): The initial user query.
client: The OpenAI client object.
model (str): The name of the model to use.
n (int): Number of completions to generate.

request_config (dict): Optional configuration including max_tokens.

Returns:
str or list: The generated response(s) from the model.
"""
logger.info("Using RE2 approach for query processing")
re2_completion_tokens = 0

# Extract max_tokens from request_config if provided
max_tokens = None
if request_config:
max_tokens = request_config.get('max_tokens')

# Construct the RE2 prompt
re2_prompt = f"{initial_query}\nRead the question again: {initial_query}"
@@ -35,6 +41,8 @@ def re2_approach(system_prompt, initial_query, client, model, n=1, request_id: s
"messages": messages,
"n": n
}
if max_tokens is not None:
provider_request["max_tokens"] = max_tokens
response = client.chat.completions.create(**provider_request)

# Log provider call
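The reread.py change differs from the others: there is no 4096 fallback, and `max_tokens` is added to the provider request only when the caller actually supplied one, leaving the provider default in place otherwise. A self-contained sketch of that request construction; the function name is illustrative.

```python
def build_re2_request(model: str, messages: list, n: int = 1,
                      request_config: dict = None) -> dict:
    provider_request = {"model": model, "messages": messages, "n": n}
    # Only forward max_tokens when it was explicitly provided.
    max_tokens = request_config.get("max_tokens") if request_config else None
    if max_tokens is not None:
        provider_request["max_tokens"] = max_tokens
    return provider_request


print(build_re2_request("m", []))                                      # no max_tokens key
print(build_re2_request("m", [], request_config={"max_tokens": 256}))  # key forwarded
```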