29 changes: 22 additions & 7 deletions optillm.py
@@ -416,16 +416,23 @@ def parse_conversation(messages):

def tagged_conversation_to_messages(response_text):
"""Convert a tagged conversation string or list of strings into a list of messages.
If the input doesn't contain User:/Assistant: tags, return it as is.

Args:
response_text: Either a string containing "User:" and "Assistant:" tags,
or a list of such strings.

Returns:
If input is a string: A list of message dictionaries.
If input is a list: A list of lists of message dictionaries.
If input has tags: A list of message dictionaries.
If input has no tags: The original input.
"""
def has_conversation_tags(text):
return "User:" in text or "Assistant:" in text

def process_single_response(text):
if not has_conversation_tags(text):
return text

messages = []
# Split on "User:" or "Assistant:" while keeping the delimiter
parts = re.split(r'(?=(User:|Assistant:))', text.strip())
@@ -447,7 +454,11 @@ def process_single_response(text):
return messages

if isinstance(response_text, list):
return [process_single_response(text) for text in response_text]
processed = [process_single_response(text) for text in response_text]
# If none of the responses had tags, return original list
if all(isinstance(p, str) for p in processed):
return response_text
return processed
else:
return process_single_response(response_text)
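
For context, a minimal sketch (not part of the PR) of how the list path above is expected to behave; the sample inputs are made up, and the exact message-dict fields beyond 'content' are assumptions based on the surrounding code:

# Illustrative only: assumed behavior of tagged_conversation_to_messages on lists.
untagged = ["Paris.", "4"]
mixed = ["User: capital of France?\nAssistant: Paris", "4"]

# No entry carries User:/Assistant: tags, so the original list is returned unchanged.
assert tagged_conversation_to_messages(untagged) is untagged

# At least one entry is tagged, so tagged entries become lists of message dicts
# (each assumed to carry a 'content' key) while untagged entries pass through as strings.
processed = tagged_conversation_to_messages(mixed)
assert isinstance(processed[0], list) and processed[1] == "4"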

@@ -555,14 +566,18 @@ def proxy():
except Exception as e:
logger.error(f"Error processing request: {str(e)}")
return jsonify({"error": str(e)}), 500

# Convert tagged conversation to messages format if needed
if isinstance(response, list):
response = [msg[-1]['content'] if isinstance(msg, list) and msg else msg
for msg in tagged_conversation_to_messages(response)]
processed_response = tagged_conversation_to_messages(response)
# If processed_response is a list of message lists, extract last message content
if processed_response != response: # Only process if format changed
response = [msg[-1]['content'] if isinstance(msg, list) and msg else msg
for msg in processed_response]
# Otherwise keep original response
else:
messages = tagged_conversation_to_messages(response)
if messages: # Only take the last message if we have any
if isinstance(messages, list) and messages: # Only process if format changed
response = messages[-1]['content']

if stream:
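A similarly hedged sketch of the single-string branch in proxy() above; the tagged text and the expected final content are illustrative, and the role/content layout of the parsed messages is assumed:

# Untagged text: the helper hands back the same string, so the response is untouched.
plain = "The answer is 42."
assert tagged_conversation_to_messages(plain) == plain

# Tagged conversation: the helper returns a message list, and the proxy keeps only
# the content of the final turn (expected here to be "42").
tagged = "User: what is 6*7?\nAssistant: 42"
msgs = tagged_conversation_to_messages(tagged)
if isinstance(msgs, list) and msgs:
    response = msgs[-1]['content']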
53 changes: 44 additions & 9 deletions scripts/eval_aime_benchmark.py
@@ -4,7 +4,7 @@
import logging
import re
import time
from typing import List, Dict, Tuple, Optional
from typing import List, Dict, Tuple, Optional, Union
from datetime import datetime
from openai import OpenAI
from datasets import load_dataset
@@ -89,9 +89,17 @@ def extract_answer(response: str) -> Optional[int]:

return None

def get_llm_response(problem: str, model: str) -> str:
def get_llm_response(problem: str, model: str) -> Union[str, List[Dict]]:
"""
Get response from the LLM for a given problem.
If multiple choices are returned, formats them as attempt dictionaries.

Args:
problem (str): The problem text
model (str): The model identifier

Returns:
Union[str, List[Dict]]: Either a string response or list of attempt dictionaries
"""
try:
response = client.with_options(timeout=1000.0).chat.completions.create(
@@ -101,7 +109,23 @@ def get_llm_response(problem: str, model: str) -> str:
],
max_tokens=8192,
)

# If there's more than one choice, format as attempts
if len(response.choices) > 1:
attempts = []
for i, choice in enumerate(response.choices):
response_text = choice.message.content.strip()
predicted_answer = extract_answer(response_text)
attempts.append({
"attempt_number": i + 1,
"response": response_text,
"predicted_answer": predicted_answer
})
return attempts

# If single choice, return as before
return response.choices[0].message.content.strip()

except Exception as e:
logger.error(f"Error getting LLM response: {e}")
return ""
@@ -119,14 +143,25 @@ def make_n_attempts(problem: str, model: str, n: int) -> List[Dict]:
List[Dict]: List of dictionaries containing response and predicted answer for each attempt
"""
attempts = []
for i in range(n):
remaining_attempts = n

while remaining_attempts > 0:
response = get_llm_response(problem, model)
predicted_answer = extract_answer(response)
attempts.append({
"attempt_number": i + 1,
"response": response,
"predicted_answer": predicted_answer
})

# If response is already formatted as attempts
if isinstance(response, list):
attempts.extend(response)
remaining_attempts = n - len(attempts)
else:
# Process single response as before
predicted_answer = extract_answer(response)
attempts.append({
"attempt_number": len(attempts) + 1,
"response": response,
"predicted_answer": predicted_answer
})
remaining_attempts -= 1

return attempts

def evaluate_pass_at_n(attempts: List[Dict], correct_answer: int) -> Tuple[bool, Optional[int]]:
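A rough usage sketch of the batching-aware loop above, with get_llm_response stubbed out; the import path is a guess at the repository layout, and importing the script assumes its module-level OpenAI client can be constructed in your environment:

import scripts.eval_aime_benchmark as bench  # assumed import path; adjust to your setup

def fake_get_llm_response(problem: str, model: str):
    # Pretend the proxy returned two choices at once, already formatted as attempts.
    return [
        {"attempt_number": 1, "response": "... so the answer is 7", "predicted_answer": 7},
        {"attempt_number": 2, "response": "... so the answer is 9", "predicted_answer": 9},
    ]

bench.get_llm_response = fake_get_llm_response  # monkeypatch for this sketch only
attempts = bench.make_n_attempts("toy problem", "any-model", n=4)
assert len(attempts) == 4  # two batched calls of two attempts each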
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
name="optillm",
version="0.0.23",
version="0.0.24",
packages=find_packages(),
py_modules=['optillm'],
package_data={