diff --git a/optillm.py b/optillm.py
index 3874c95b..de767acd 100644
--- a/optillm.py
+++ b/optillm.py
@@ -6,6 +6,11 @@
from openai import AzureOpenAI, OpenAI
from flask import Response
import json
+import importlib
+import glob
+import asyncio
+import re
+from concurrent.futures import ThreadPoolExecutor
# Import the LiteLLM wrapper
from litellm_wrapper import LiteLLMWrapper
@@ -79,6 +84,111 @@
known_approaches = ["mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar",
"cot_reflection", "plansearch", "leap", "re2"]
+plugin_approaches = {}
+
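+# A plugin is any Python module in optillm/plugins that defines a SLUG string and a
+# run(system_prompt, initial_query, client, model) callable returning a
+# (response, completion_tokens) tuple; load_plugins() registers each one under its SLUG.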
+def load_plugins():
+ plugin_dir = os.path.join(os.path.dirname(__file__), 'optillm/plugins')
+ plugin_files = glob.glob(os.path.join(plugin_dir, '*.py'))
+
+ for plugin_file in plugin_files:
+ module_name = os.path.basename(plugin_file)[:-3] # Remove .py extension
+ spec = importlib.util.spec_from_file_location(module_name, plugin_file)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+
+ if hasattr(module, 'SLUG') and hasattr(module, 'run'):
+ plugin_approaches[module.SLUG] = module.run
+ logger.info(f"Loaded plugin: {module.SLUG}")
+
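+# The approach prefix is parsed off the model name. For example, "bon-gpt-4o-mini"
+# yields ('SINGLE', ['bon'], 'gpt-4o-mini'); prefixes joined with '&' such as
+# "moa&rto-gpt-4o-mini" are chained sequentially (AND), while '|' as in
+# "bon|moa-gpt-4o-mini" runs the approaches in parallel (OR).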
+def parse_combined_approach(model: str, known_approaches: list, plugin_approaches: dict):
+ if model == 'auto':
+ return 'SINGLE', ['bon'], model
+
+ parts = model.split('-')
+ approaches = []
+ operation = 'SINGLE'
+ model_parts = []
+ parsing_approaches = True
+
+ for part in parts:
+ if parsing_approaches:
+ if part in known_approaches or part in plugin_approaches:
+ approaches.append(part)
+ elif '&' in part:
+ operation = 'AND'
+ approaches.extend(part.split('&'))
+ elif '|' in part:
+ operation = 'OR'
+ approaches.extend(part.split('|'))
+ else:
+ parsing_approaches = False
+ model_parts.append(part)
+ else:
+ model_parts.append(part)
+
+ if not approaches:
+ approaches = ['bon']
+ operation = 'SINGLE'
+
+ actual_model = '-'.join(model_parts)
+
+ return operation, approaches, actual_model
+
+def execute_single_approach(approach, system_prompt, initial_query, client, model):
+ if approach in known_approaches:
+ # Execute known approaches
+ if approach == 'mcts':
+ return chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
+ server_config['mcts_exploration'], server_config['mcts_depth'])
+ elif approach == 'bon':
+ return best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'])
+ elif approach == 'moa':
+ return mixture_of_agents(system_prompt, initial_query, client, model)
+ elif approach == 'rto':
+ return round_trip_optimization(system_prompt, initial_query, client, model)
+ elif approach == 'z3':
+ z3_solver = Z3SolverSystem(system_prompt, client, model)
+ return z3_solver.process_query(initial_query)
+ elif approach == "self_consistency":
+ return advanced_self_consistency_approach(system_prompt, initial_query, client, model)
+ elif approach == "pvg":
+ return inference_time_pv_game(system_prompt, initial_query, client, model)
+ elif approach == "rstar":
+ rstar = RStar(system_prompt, client, model,
+ max_depth=server_config['rstar_max_depth'], num_rollouts=server_config['rstar_num_rollouts'],
+ c=server_config['rstar_c'])
+ return rstar.solve(initial_query)
+ elif approach == "cot_reflection":
+ return cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'])
+ elif approach == 'plansearch':
+ return plansearch(system_prompt, initial_query, client, model, n=server_config['n'])
+ elif approach == 'leap':
+ return leap(system_prompt, initial_query, client, model)
+ elif approach == 're2':
+ return re2_approach(system_prompt, initial_query, client, model, n=server_config['n'])
+ elif approach in plugin_approaches:
+ return plugin_approaches[approach](system_prompt, initial_query, client, model)
+ else:
+ raise ValueError(f"Unknown approach: {approach}")
+
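+# AND pipelines are sequential: each approach receives the previous approach's
+# response as its query, and completion tokens are accumulated across the chain.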
+def execute_combined_approaches(approaches, system_prompt, initial_query, client, model):
+ final_response = initial_query
+ total_tokens = 0
+ for approach in approaches:
+ response, tokens = execute_single_approach(approach, system_prompt, final_response, client, model)
+ final_response = response
+ total_tokens += tokens
+ return final_response, total_tokens
+
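+# OR runs every approach on the original query concurrently (each call is pushed
+# to a worker thread via asyncio.to_thread) and returns the list of responses.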
+async def execute_parallel_approaches(approaches, system_prompt, initial_query, client, model):
+ async def run_approach(approach):
+ return await asyncio.to_thread(execute_single_approach, approach, system_prompt, initial_query, client, model)
+
+ tasks = [run_approach(approach) for approach in approaches]
+ results = await asyncio.gather(*tasks)
+ responses, tokens = zip(*results)
+ return list(responses), sum(tokens)
+
def generate_streaming_response(final_response, model):
# Yield the final response
if isinstance(final_response, list):
@@ -99,18 +209,31 @@ def generate_streaming_response(final_response, model):
def parse_conversation(messages):
system_prompt = ""
conversation = []
+ optillm_approach = None
for message in messages:
role = message['role']
content = message['content']
if role == 'system':
- system_prompt = content
- elif role in ['user', 'assistant']:
- conversation.append(f"{role.capitalize()}: {content}")
+ system_prompt, optillm_approach = extract_optillm_approach(content)
+ elif role == 'user':
+ if not optillm_approach:
+ content, optillm_approach = extract_optillm_approach(content)
+ conversation.append(f"User: {content}")
+ elif role == 'assistant':
+ conversation.append(f"Assistant: {content}")
initial_query = "\n".join(conversation)
- return system_prompt, initial_query
+ return system_prompt, initial_query, optillm_approach
+
+def extract_optillm_approach(content):
+    match = re.search(r'<optillm_approach>(.*?)</optillm_approach>', content)
+    if match:
+        approach = match.group(1)
+        content = re.sub(r'<optillm_approach>.*?</optillm_approach>', '', content).strip()
+ return content, approach
+ return content, None
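+# Example: a user message "<optillm_approach>leap</optillm_approach> Solve ..." has
+# the tag stripped and is routed to the leap approach by prefixing the model name.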
# Optional API key configuration to secure the proxy
@app.before_request
@@ -136,11 +259,18 @@ def proxy():
stream = data.get('stream', False)
messages = data.get('messages', [])
model = data.get('model', server_config['model'])
- n = data.get('n', server_config['n'])
- system_prompt, initial_query = parse_conversation(messages)
+ optillm_approach = data.get('optillm_approach', {})
+
+ system_prompt, initial_query, message_optillm_approach = parse_conversation(messages)
+
+ # Use optillm_approach from extra_body if present, otherwise use from messages
+ if not optillm_approach and message_optillm_approach:
+ optillm_approach = message_optillm_approach
+
+ if optillm_approach:
+ model = f"{optillm_approach}-{model}"
- approach = server_config['approach']
base_url = server_config['base_url']
if base_url != "":
@@ -148,53 +278,20 @@ def proxy():
else:
client = default_client
- # Handle 'auto' approach
- if approach == 'auto':
- for known_approach in known_approaches:
- if model.startswith(f"{known_approach}-"):
- approach = known_approach
- model = model[len(known_approach)+1:]
- break
- else:
- # If no known approach is found in the model name, default to 'bon'
- approach = 'bon'
-
-
- logger.info(f'Using approach {approach}, with {model}')
- completion_tokens = 0
+ operation, approaches, model = parse_combined_approach(model, known_approaches, plugin_approaches)
+ logger.info(f'Using approach(es) {approaches}, operation {operation}, with model {model}')
try:
- if approach == 'mcts':
- final_response, completion_tokens = chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
- server_config['mcts_exploration'], server_config['mcts_depth'])
- elif approach == 'bon':
- final_response, completion_tokens = best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'])
- elif approach == 'moa':
- final_response, completion_tokens = mixture_of_agents(system_prompt, initial_query, client, model)
- elif approach == 'rto':
- final_response, completion_tokens = round_trip_optimization(system_prompt, initial_query, client, model)
- elif approach == 'z3':
- z3_solver = Z3SolverSystem(system_prompt, client, model)
- final_response, completion_tokens = z3_solver.process_query(initial_query)
- elif approach == "self_consistency":
- final_response, completion_tokens = advanced_self_consistency_approach(system_prompt, initial_query, client, model)
- elif approach == "pvg":
- final_response, completion_tokens = inference_time_pv_game(system_prompt, initial_query, client, model)
- elif approach == "rstar":
- rstar = RStar(system_prompt, client, model,
- max_depth=server_config['rstar_max_depth'], num_rollouts=server_config['rstar_num_rollouts'],
- c=server_config['rstar_c'])
- final_response, completion_tokens = rstar.solve(initial_query)
- elif approach == "cot_reflection":
- final_response, completion_tokens = cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'])
- elif approach == 'plansearch':
- final_response, completion_tokens = plansearch(system_prompt, initial_query, client, model, n=n)
- elif approach == 'leap':
- final_response, completion_tokens = leap(system_prompt, initial_query, client, model)
- elif approach == 're2':
- final_response, completion_tokens = re2_approach(system_prompt, initial_query, client, model, n=n)
+ if operation == 'SINGLE':
+ final_response, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model)
+ elif operation == 'AND':
+ final_response, completion_tokens = execute_combined_approaches(approaches, system_prompt, initial_query, client, model)
+ elif operation == 'OR':
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ final_response, completion_tokens = loop.run_until_complete(execute_parallel_approaches(approaches, system_prompt, initial_query, client, model))
else:
- raise ValueError(f"Unknown approach: {approach}")
+ raise ValueError(f"Unknown operation: {operation}")
except Exception as e:
logger.error(f"Error processing request: {str(e)}")
return jsonify({"error": str(e)}), 500
@@ -233,7 +330,6 @@ def proxy():
logger.debug(f'API response: {response_data}')
return jsonify(response_data), 200
-
@app.route('/v1/models', methods=['GET'])
def proxy_models():
logger.info('Received request to /v1/models')
@@ -313,6 +409,8 @@ def main():
global server_config
args = parse_args()
+ # Call this function at the start of main()
+ load_plugins()
# Update server_config with all argument values
server_config.update(vars(args))
diff --git a/optillm/plugins/memory_plugin.py b/optillm/plugins/memory_plugin.py
new file mode 100644
index 00000000..00d2bade
--- /dev/null
+++ b/optillm/plugins/memory_plugin.py
@@ -0,0 +1,103 @@
+import re
+from typing import Tuple, List
+import numpy as np
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+
+SLUG = "memory"
+
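+# Short-term memory plugin: the context is split into ~10k-character chunks, key
+# facts are extracted from each chunk with the LLM and stored in a bounded FIFO
+# memory, then the facts most similar to the query (TF-IDF cosine similarity)
+# are retrieved and used as context for the final answer.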
+class Memory:
+ def __init__(self, max_size: int = 100):
+ self.max_size = max_size
+ self.items: List[str] = []
+ self.vectorizer = TfidfVectorizer()
+ self.vectors = None
+ self.completion_tokens = 0
+
+ def add(self, item: str):
+ if len(self.items) >= self.max_size:
+ self.items.pop(0)
+ self.items.append(item)
+ self.vectors = None # Reset vectors to force recalculation
+
+ def get_relevant(self, query: str, n: int = 5) -> List[str]:
+ if not self.items:
+ return []
+
+ if self.vectors is None:
+ self.vectors = self.vectorizer.fit_transform(self.items)
+
+ query_vector = self.vectorizer.transform([query])
+ similarities = cosine_similarity(query_vector, self.vectors).flatten()
+ top_indices = similarities.argsort()[-n:][::-1]
+
+ return [self.items[i] for i in top_indices]
+
+def extract_query(text: str) -> Tuple[str, str]:
+ query_index = text.rfind("Query:")
+
+ if query_index != -1:
+ context = text[:query_index].strip()
+ query = text[query_index + 6:].strip()
+ else:
+ sentences = re.split(r'(?<=[.!?])\s+', text.strip())
+ if len(sentences) > 1:
+ context = ' '.join(sentences[:-1])
+ query = sentences[-1]
+ else:
+ context = text
+ query = "What is the main point of this text?"
+ return query, context
+
+def extract_key_information(text: str, client, model: str) -> Tuple[List[str], int]:
+ prompt = f"""Extract key information from the following text. Provide a list of important facts or concepts, each on a new line:
+
+{text}
+
+Key information:"""
+
+ response = client.chat.completions.create(
+ model=model,
+ messages=[{"role": "user", "content": prompt}],
+ max_tokens=1000
+ )
+
+ key_info = response.choices[0].message.content.strip().split('\n')
+
+ return [info.strip('- ') for info in key_info if info.strip()], response.usage.completion_tokens
+
+def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str, int]:
+ memory = Memory()
+ query, context = extract_query(initial_query)
+ completion_tokens = 0
+
+ # Process context and add to memory
+ chunk_size = 10000
+ for i in range(0, len(context), chunk_size):
+ chunk = context[i:i+chunk_size]
+ key_info, tokens = extract_key_information(chunk, client, model)
+ completion_tokens += tokens
+ for info in key_info:
+ memory.add(info)
+
+ # Retrieve relevant information from memory
+ relevant_info = memory.get_relevant(query)
+
+ # Generate response using relevant information
+ prompt = f"""System: {system_prompt}
+
+Context: {' '.join(relevant_info)}
+
+{query}
+"""
+
+ response = client.chat.completions.create(
+ model=model,
+ messages=[{"role": "user", "content": prompt}],
+ max_tokens=1000
+ )
+
+ final_response = response.choices[0].message.content.strip()
+ completion_tokens += response.usage.completion_tokens
+
+ return final_response, completion_tokens
\ No newline at end of file
diff --git a/optillm/plugins/readurls_plugin.py b/optillm/plugins/readurls_plugin.py
new file mode 100644
index 00000000..4392f2c4
--- /dev/null
+++ b/optillm/plugins/readurls_plugin.py
@@ -0,0 +1,82 @@
+import re
+from typing import Tuple, List
+import requests
+from bs4 import BeautifulSoup
+from urllib.parse import urlparse
+
+SLUG = "readurls"
+
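+# The readurls plugin extracts http(s) URLs from the query, fetches each page,
+# keeps the text of headings and paragraphs (preferring article/main containers),
+# truncates it to 40k characters, and inlines it next to the URL in the query.
+# No LLM call is made, so it reports 0 completion tokens.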
+def extract_urls(text: str) -> List[str]:
+ # Updated regex pattern to be more precise
+ url_pattern = re.compile(r'https?://[^\s\'"]+')
+
+ # Find all matches
+ urls = url_pattern.findall(text)
+
+ # Clean up the URLs
+ cleaned_urls = []
+ for url in urls:
+ # Remove trailing punctuation and quotes
+ url = re.sub(r'[,\'\"\)\]]+$', '', url)
+ cleaned_urls.append(url)
+
+ return cleaned_urls
+
+def fetch_webpage_content(url: str, max_length: int = 40000) -> str:
+ try:
+ headers = {
+            'User-Agent': 'optillm/0.0.1 (https://github.com/codelion/optillm)'
+ }
+
+ response = requests.get(url, headers=headers, timeout=10)
+ response.raise_for_status()
+
+ # Make a soup
+ soup = BeautifulSoup(response.content, 'lxml')
+
+ # Remove script and style elements
+ for script in soup(["script", "style"]):
+ script.decompose()
+
+ # Get text from various elements
+ text_elements = []
+
+ # Prioritize content from main content tags
+ for tag in ['article', 'main', 'div[role="main"]', '.main-content']:
+ content = soup.select_one(tag)
+ if content:
+ text_elements.extend(content.find_all(['h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'p']))
+ break
+
+ # If no main content found, fall back to all headers and paragraphs
+ if not text_elements:
+ text_elements = soup.find_all(['h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'p'])
+
+ # Extract text from elements
+ text = ' '.join(element.get_text(strip=True) for element in text_elements)
+
+ # Remove extra whitespace
+ text = re.sub(r'\s+', ' ', text).strip()
+
+ # Remove footnote superscripts in brackets
+ text = re.sub(r"\[.*?\]+", '', text)
+
+ # Truncate to max_length
+ if len(text) > max_length:
+ text = text[:max_length] + '...'
+
+ return text
+ except Exception as e:
+ return f"Error fetching content: {str(e)}"
+
+def run(system_prompt: str, initial_query: str, client=None, model=None) -> Tuple[str, int]:
+ urls = extract_urls(initial_query)
+ # print(urls)
+ modified_query = initial_query
+
+ for url in urls:
+ content = fetch_webpage_content(url)
+ domain = urlparse(url).netloc
+ modified_query = modified_query.replace(url, f"{url} [Content from {domain}: {content}]")
+ # print(modified_query)
+ return modified_query, 0
\ No newline at end of file
diff --git a/optillm/wim.py b/optillm/wim.py
new file mode 100644
index 00000000..56596a08
--- /dev/null
+++ b/optillm/wim.py
@@ -0,0 +1,122 @@
+from collections import deque
+import tiktoken
+import re
+
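+# Writing-in-the-Margins style long-context processing: the context is consumed in
+# chunks; for each chunk the model writes a margin note prefixed YES#/NO# indicating
+# whether it contains query-relevant information, the YES margins are kept, and the
+# final answer is generated from the buffered context plus the collected margins.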
+class WiMInfiniteContextAPI:
+ def __init__(self, system_prompt, client, model, max_context_tokens=64000, max_margins=10, chunk_size=16000):
+ self.model = model
+ self.max_context_tokens = max_context_tokens
+ self.max_margins = max_margins
+ self.chunk_size = chunk_size
+ self.context_buffer = deque()
+ self.margins = deque(maxlen=max_margins)
+ try:
+ self.tokenizer = tiktoken.encoding_for_model(model)
+ except:
+ self.tokenizer = tiktoken.get_encoding("o200k_base")
+ self.system_message = system_prompt
+ self.client = client
+ self.win_completion_tokens = 0
+
+ def count_tokens(self, text):
+ return len(self.tokenizer.encode(text))
+
+ def trim_context_buffer(self):
+ while self.count_tokens("".join(self.context_buffer)) > self.max_context_tokens:
+ self.context_buffer.popleft()
+
+ def generate_margin(self, chunk, query):
+ messages = [
+ {"role": "system", "content": self.system_message},
+ {"role": "user", "content": f"""
+'''text
+{chunk}
+'''
+Copy over all context relevant to the query: {query}
+Provide the answer in the format: <YES/NO>#<relevant context>.
+Here are the rules:
+- If you don't know how to answer the query - start your answer with NO#
+- If the text is not related to the query - start your answer with NO#
+- If you can extract relevant information - start your answer with YES#
+- If the text does not mention the person by name - start your answer with NO#
+Example answers:
+- YES#Western philosophy originated in Ancient Greece in the 6th century BCE with the pre-Socratics.
+- NO#No relevant context.
+"""}
+ ]
+ response = self.client.chat.completions.create(
+ model=self.model,
+ messages=messages,
+ max_tokens = 512
+ )
+ self.win_completion_tokens += response.usage.completion_tokens
+ return response.choices[0].message.content
+
+ def classify_margin(self, margin):
+ return margin.startswith("YES#")
+
+ def extract_query(self, text):
+ # Split the text into sentences
+ sentences = re.split(r'(?<=[.!?])\s+', text)
+
+ # Check if the last sentence starts with "Query:"
+ if sentences[-1].startswith("Query:"):
+            return sentences[-1][6:].strip(), " ".join(sentences[:-1])
+
+ # If not, assume the last sentence is the query
+        return sentences[-1].strip(), " ".join(sentences[:-1])
+
+ def process_chunk(self, chunk, query):
+ self.context_buffer.append(chunk)
+ self.trim_context_buffer()
+ margin = self.generate_margin(chunk, query)
+ if self.classify_margin(margin):
+ self.margins.append(margin.split("#", 1)[1])
+
+ def process_stream(self, text_stream, query):
+ for chunk in text_stream:
+ self.process_chunk(chunk, query)
+
+ def generate_final_answer(self, query):
+ context = "".join(self.context_buffer)
+ margins = "\n".join(self.margins)
+ messages = [
+ {"role": "system", "content": self.system_message},
+ {"role": "user", "content": f"""
+'''text
+{context}
+'''
+I asked my assistant to read and analyse the above content page by page to help you complete this task. These are margin notes left on each page:
+'''text
+{margins}
+'''
+Read again the note(s) and the provided content, take a deep breath and answer the query.
+{self.instruction}
+{query}
+"""}
+ ]
+ response = self.client.chat.completions.create(
+ model=self.model,
+ messages=messages
+ )
+ self.win_completion_tokens += response.usage.completion_tokens
+ return response.choices[0].message.content
+
+ def run(self, text_stream, query):
+ self.process_stream(text_stream, query)
+ return self.generate_final_answer(query)
+
+ @property
+ def instruction(self):
+ return "Answer the following question based on the provided context and margin notes:"
+
+ # Usage
+ def text_stream_generator(self, text):
+ for i in range(0, len(text), self.chunk_size):
+ yield text[i:i+self.chunk_size]
+
+ def process_query(self, initial_query):
+ query, context = self.extract_query(initial_query)
+ text_stream = self.text_stream_generator(context)
+ final_answer = self.run(text_stream, query)
+ return final_answer, self.win_completion_tokens
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index fd16e529..47c69eb5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,4 +7,9 @@ flask
torch
transformers
azure.identity
-litellm
\ No newline at end of file
+tiktoken
+scikit-learn
+litellm
+requests
+beautifulsoup4
+lxml
\ No newline at end of file
diff --git a/scripts/eval_frames_benchmark.py b/scripts/eval_frames_benchmark.py
new file mode 100644
index 00000000..16c44c09
--- /dev/null
+++ b/scripts/eval_frames_benchmark.py
@@ -0,0 +1,152 @@
+import argparse
+import json
+import os
+import time
+from typing import List, Dict
+
+from openai import OpenAI
+from datasets import load_dataset
+from tqdm import tqdm
+
+client = OpenAI(api_key="none", base_url="http://localhost:8000/v1")
+SLEEP_INTERVAL = 60
+
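+# Results are appended to a per-model JSON file so an interrupted run can resume
+# from the last processed index.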
+def load_existing_results(filename: str) -> List[Dict]:
+ try:
+ with open(filename, 'r') as f:
+ return json.load(f)
+ except FileNotFoundError:
+ return []
+
+def save_result(filename: str, result: Dict):
+ results = load_existing_results(filename)
+ results.append(result)
+ with open(filename, 'w') as f:
+ json.dump(results, f, indent=2)
+
+def get_last_processed_index(results: List[Dict]) -> int:
+ if not results:
+ return -1
+ return max(int(r.get('index', -1)) for r in results)
+
+
+def generate_llm_prompt(prompt: str, wiki_links: List[str]) -> str:
+ return f"Here are the relevant Wikipedia articles:\n{wiki_links}\n\nBased on all the information, answer the query. \n\nQuery: {prompt}\n\n"
+
+def get_llm_response(prompt: str, model: str) -> str:
+ response = client.chat.completions.create(
+ model=model,
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": prompt}
+ ],
+ max_tokens=1000,
+ n=1,
+ stop=None,
+ temperature=0.7,
+ )
+ return response.choices[0].message.content.strip()
+
+def evaluate_response(question: str, llm_response: str, ground_truth: str, model: str) -> Dict[str, str]:
+ evaluation_prompt = f"""===Task===
+I need your help in evaluating an answer provided by an LLM against a ground
+truth answer. Your task is to determine if the ground truth answer is present in the LLM's
+response. Please analyze the provided data and make a decision.
+===Instructions===
+1. Carefully compare the "Predicted Answer" with the "Ground Truth Answer".
+2. Consider the substance of the answers – look for equivalent information or correct answers.
+Do not focus on exact wording unless the exact wording is crucial to the meaning.
+3. Your final decision should be based on whether the meaning and the vital facts of the
+"Ground Truth Answer" are present in the "Predicted Answer:"
+===Input Data===
+- Question: {question}
+- Predicted Answer: {llm_response}
+- Ground Truth Answer: {ground_truth}
+===Output Format===
+Provide your final evaluation in the following format:
+"Explanation:" (How you made the decision?)
+"Decision:" ("TRUE" or "FALSE" )
+Please proceed with the evaluation."""
+
+ evaluation_response = client.chat.completions.create(
+ model=model,
+ messages=[
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": evaluation_prompt}
+ ],
+ max_tokens=300,
+ n=1,
+ stop=None,
+ temperature=0.3,
+ )
+
+ evaluation_text = evaluation_response.choices[0].message.content.strip()
+
+ # Extract the decision and explanation
+ lines = evaluation_text.split('\n')
+ decision = "FALSE"
+ explanation = ""
+ for line in lines:
+ if line.startswith("Decision:"):
+ decision = line.split(":")[1].strip().upper()
+ elif line.startswith("Explanation:"):
+ explanation = line.split(":", 1)[1].strip()
+
+ return {"decision": decision, "explanation": explanation}
+
+def main(model: str):
+ # Load the dataset
+ dataset = load_dataset("google/frames-benchmark", split="test")
+
+ filename = f"evaluation_results_{model.replace('/', '_')}.json"
+ existing_results = load_existing_results(filename)
+ last_processed_index = get_last_processed_index(existing_results)
+
+ for item in tqdm(dataset, desc="Processing samples"):
+ index = int(item['Unnamed: 0'])
+ if index <= last_processed_index:
+ continue
+
+ prompt = generate_llm_prompt(item['Prompt'], item['wiki_links'])
+ llm_response = get_llm_response(prompt, model)
+ evaluation = evaluate_response(item['Prompt'], llm_response, item['Answer'], model)
+
+ result = {
+ "index": index,
+ "prompt": item['Prompt'],
+ "ground_truth": item['Answer'],
+ "llm_response": llm_response,
+ "evaluation_decision": evaluation['decision'],
+ "evaluation_explanation": evaluation['explanation'],
+ "reasoning_type": item['reasoning_types']
+ }
+
+ save_result(filename, result)
+ print(f"Index: {index}, Decision: {result['evaluation_decision']}")
+ time.sleep(SLEEP_INTERVAL)
+
+ # Calculate and print summary statistics
+ results = load_existing_results(filename)
+ total_samples = len(results)
+ correct_answers = sum(1 for r in results if r['evaluation_decision'] == 'TRUE')
+ accuracy = correct_answers / total_samples
+
+ print(f"Model: {model}")
+ print(f"Total samples: {total_samples}")
+ print(f"Correct answers: {correct_answers}")
+ print(f"Accuracy: {accuracy:.2%}")
+
+ # Print accuracy by reasoning type
+ reasoning_types = set(r['reasoning_types'] for r in results)
+ for rt in reasoning_types:
+ rt_samples = [r for r in results if r['reasoning_types'] == rt]
+ rt_correct = sum(1 for r in rt_samples if r['evaluation_decision'] == 'TRUE')
+ rt_accuracy = rt_correct / len(rt_samples)
+ print(f"Accuracy for {rt}: {rt_accuracy:.2%}")
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Evaluate LLM performance on google/frames-benchmark")
+ parser.add_argument("--model", type=str, required=True, help="OpenAI model to use (e.g., gpt-4o, gpt-4o-mini)")
+ args = parser.parse_args()
+
+ main(args.model)
\ No newline at end of file
diff --git a/scripts/gen_optillm_dataset.py b/scripts/gen_optillm_dataset.py
new file mode 100644
index 00000000..9a3b3cd4
--- /dev/null
+++ b/scripts/gen_optillm_dataset.py
@@ -0,0 +1,96 @@
+import os
+import json
+import argparse
+import asyncio
+from tqdm import tqdm
+from datasets import load_dataset
+from openai import AsyncOpenAI
+from typing import List, Dict, Any
+import random
+
+# OptILM approaches
+APPROACHES = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"]
+
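+# For each prompt, a response is generated with the base model ("none") and with
+# every OptILM approach via the local proxy; the responses are shuffled, ranked by
+# the LLM, and written to a JSONL record together with their ranks.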
+async def generate_response(prompt: str, approach: str) -> Dict[str, Any]:
+ """Generate a response using the specified approach."""
+ if approach == "none":
+ # Use the base model without any optimization technique
+ client = AsyncOpenAI()
+ response = await client.chat.completions.create(
+ model="gpt-4o-mini",
+ messages=[{"role": "user", "content": prompt}],
+ )
+ return {
+ "content": response.choices[0].message.content,
+ "tokens": response.usage.completion_tokens,
+ }
+ else:
+ # Use OptILM with the specified approach
+ client = AsyncOpenAI(api_key="none", base_url="http://localhost:8000/v1")
+ response = await client.chat.completions.create(
+ model=f"{approach}-gpt-4o-mini", # Assuming OptILM uses this naming convention
+ messages=[{"role": "user", "content": prompt}],
+ )
+ return {
+ "content": response.choices[0].message.content,
+ "tokens": response.usage.completion_tokens,
+ }
+
+async def rank_responses(prompt: str, responses: List[Dict[str, Any]]) -> List[int]:
+ """Rank the responses using the LLM."""
+ ranking_prompt = f"Given the following prompt:\n\n{prompt}\n\nRank the following responses from best to worst, considering accuracy, completeness, and relevance. Provide the ranking as a comma-separated list of indices (0-indexed). Do not add any explanations or any other text other than the comma-separated list.\n\n"
+ for i, response in enumerate(responses):
+ ranking_prompt += f"Response {i}:\n{response['content']}\n\n"
+ client = AsyncOpenAI()
+ ranking_response = await client.chat.completions.create(
+ model="gpt-4o-mini",
+ messages=[{"role": "user", "content": ranking_prompt}],
+ )
+
+ ranking_str = ranking_response.choices[0].message.content.strip()
+ print(ranking_str)
+ return [int(idx) for idx in ranking_str.split(",")]
+
+async def process_sample(sample: Dict[str, Any]) -> Dict[str, Any]:
+ """Process a single sample from the dataset."""
+ prompt = sample["turns"][0]["content"]
+ results = []
+
+ # Generate responses for each approach
+ for approach in APPROACHES:
+ response = await generate_response(prompt, approach)
+ results.append({"approach": approach, **response})
+
+ random.shuffle(results)
+ # Rank the responses
+ rankings = await rank_responses(prompt, results)
+
+ # Add rankings to results
+ for rank, idx in enumerate(rankings):
+ results[idx]["rank"] = rank
+
+ return {
+ "prompt": prompt,
+ "results": results,
+ }
+
+async def generate_dataset(num_samples: int, output_file: str):
+ """Generate the dataset and save it to a JSONL file."""
+ dataset = load_dataset("lmsys/arena-hard-auto-v0.1", split="train")
+
+ with open(output_file, "w") as f:
+ for sample in tqdm(dataset.select(range(num_samples)), total=num_samples):
+ result = await process_sample(sample)
+ f.write(json.dumps(result) + "\n")
+
+def main():
+ parser = argparse.ArgumentParser(description="Generate OptILM dataset")
+ parser.add_argument("--num_samples", type=int, default=100, help="Number of samples to process")
+ parser.add_argument("--output_file", type=str, default="optillm_dataset.jsonl", help="Output file path")
+ args = parser.parse_args()
+
+ asyncio.run(generate_dataset(args.num_samples, args.output_file))
+ print(f"Dataset generated and saved to {args.output_file}")
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/requirements.txt b/scripts/requirements.txt
new file mode 100644
index 00000000..aee11b28
--- /dev/null
+++ b/scripts/requirements.txt
@@ -0,0 +1 @@
+datasets
diff --git a/setup.py b/setup.py
index 61579c36..8097c206 100644
--- a/setup.py
+++ b/setup.py
@@ -14,6 +14,12 @@
"torch",
"transformers",
"azure-identity",
+ "tiktoken",
+ "scikit-learn",
+ "litellm",
+ "requests",
+ "beautifulsoup4",
+ "lxml",
],
author="codelion",
author_email="codelion@okyasoft.com",
diff --git a/test_cases.json b/test_cases.json
index a9c8df7a..fadf3e08 100644
--- a/test_cases.json
+++ b/test_cases.json
@@ -28,5 +28,10 @@
"name" : "reddit",
"system_prompt": "",
"query" : "There are 24 volunteers. Over the next 3 weeks, each volunteer is assigned to a different task. There are 8 tasks. Each week, the volunteers switch tasks. Each task has 3 volunteers assigned to it. Volunteers cannot be assigned to the same task more than once, and volunteers cannot share the same task more than once."
+ },
+ {
+ "name" : "GH",
+ "system_prompt" : "",
+ "query" : "Find the largest possible real part of[(75+117i)z+\frac{96+144i}{z}]where z is a complex number with |z|=4"
}
]