diff --git a/optillm.py b/optillm.py index 3874c95b..de767acd 100644 --- a/optillm.py +++ b/optillm.py @@ -6,6 +6,11 @@ from openai import AzureOpenAI, OpenAI from flask import Response import json +import importlib +import glob +import asyncio +import re +from concurrent.futures import ThreadPoolExecutor # Import the LiteLLM wrapper from litellm_wrapper import LiteLLMWrapper @@ -79,6 +84,111 @@ known_approaches = ["mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"] +plugin_approaches = {} + +def load_plugins(): + plugin_dir = os.path.join(os.path.dirname(__file__), 'optillm/plugins') + plugin_files = glob.glob(os.path.join(plugin_dir, '*.py')) + + for plugin_file in plugin_files: + module_name = os.path.basename(plugin_file)[:-3] # Remove .py extension + spec = importlib.util.spec_from_file_location(module_name, plugin_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + if hasattr(module, 'SLUG') and hasattr(module, 'run'): + plugin_approaches[module.SLUG] = module.run + logger.info(f"Loaded plugin: {module.SLUG}") + +def parse_combined_approach(model: str, known_approaches: list, plugin_approaches: dict): + if model == 'auto': + return 'SINGLE', ['bon'], model + + parts = model.split('-') + approaches = [] + operation = 'SINGLE' + model_parts = [] + parsing_approaches = True + + for part in parts: + if parsing_approaches: + if part in known_approaches or part in plugin_approaches: + approaches.append(part) + elif '&' in part: + operation = 'AND' + approaches.extend(part.split('&')) + elif '|' in part: + operation = 'OR' + approaches.extend(part.split('|')) + else: + parsing_approaches = False + model_parts.append(part) + else: + model_parts.append(part) + + if not approaches: + approaches = ['bon'] + operation = 'SINGLE' + + actual_model = '-'.join(model_parts) + + return operation, approaches, actual_model + +def execute_single_approach(approach, system_prompt, initial_query, client, model): + if approach in known_approaches: + # Execute known approaches + if approach == 'mcts': + return chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'], + server_config['mcts_exploration'], server_config['mcts_depth']) + elif approach == 'bon': + return best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n']) + elif approach == 'moa': + return mixture_of_agents(system_prompt, initial_query, client, model) + elif approach == 'rto': + return round_trip_optimization(system_prompt, initial_query, client, model) + elif approach == 'z3': + z3_solver = Z3SolverSystem(system_prompt, client, model) + return z3_solver.process_query(initial_query) + elif approach == "self_consistency": + return advanced_self_consistency_approach(system_prompt, initial_query, client, model) + elif approach == "pvg": + return inference_time_pv_game(system_prompt, initial_query, client, model) + elif approach == "rstar": + rstar = RStar(system_prompt, client, model, + max_depth=server_config['rstar_max_depth'], num_rollouts=server_config['rstar_num_rollouts'], + c=server_config['rstar_c']) + return rstar.solve(initial_query) + elif approach == "cot_reflection": + return cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response']) + elif approach == 'plansearch': + return plansearch(system_prompt, initial_query, client, model, n=server_config['n']) + elif approach == 'leap': + return 
leap(system_prompt, initial_query, client, model)
+        elif approach == 're2':
+            return re2_approach(system_prompt, initial_query, client, model, n=server_config['n'])
+    elif approach in plugin_approaches:
+        return plugin_approaches[approach](system_prompt, initial_query, client, model)
+    else:
+        raise ValueError(f"Unknown approach: {approach}")
+
+def execute_combined_approaches(approaches, system_prompt, initial_query, client, model):
+    final_response = initial_query
+    total_tokens = 0
+    for approach in approaches:
+        response, tokens = execute_single_approach(approach, system_prompt, final_response, client, model)
+        final_response = response
+        total_tokens += tokens
+    return final_response, total_tokens
+
+async def execute_parallel_approaches(approaches, system_prompt, initial_query, client, model):
+    async def run_approach(approach):
+        return await asyncio.to_thread(execute_single_approach, approach, system_prompt, initial_query, client, model)
+
+    tasks = [run_approach(approach) for approach in approaches]
+    results = await asyncio.gather(*tasks)
+    responses, tokens = zip(*results)
+    return list(responses), sum(tokens)
+
 def generate_streaming_response(final_response, model):
     # Yield the final response
     if isinstance(final_response, list):
@@ -99,18 +209,31 @@ def generate_streaming_response(final_response, model):
 def parse_conversation(messages):
     system_prompt = ""
     conversation = []
+    optillm_approach = None
 
     for message in messages:
         role = message['role']
         content = message['content']
 
         if role == 'system':
-            system_prompt = content
-        elif role in ['user', 'assistant']:
-            conversation.append(f"{role.capitalize()}: {content}")
+            system_prompt, optillm_approach = extract_optillm_approach(content)
+        elif role == 'user':
+            if not optillm_approach:
+                content, optillm_approach = extract_optillm_approach(content)
+            conversation.append(f"User: {content}")
+        elif role == 'assistant':
+            conversation.append(f"Assistant: {content}")
 
     initial_query = "\n".join(conversation)
-    return system_prompt, initial_query
+    return system_prompt, initial_query, optillm_approach
+
+def extract_optillm_approach(content):
+    match = re.search(r'<optillm_approach>(.*?)</optillm_approach>', content)
+    if match:
+        approach = match.group(1)
+        content = re.sub(r'<optillm_approach>.*?</optillm_approach>', '', content).strip()
+        return content, approach
+    return content, None
 
 # Optional API key configuration to secure the proxy
 @app.before_request
@@ -136,11 +259,18 @@ def proxy():
     stream = data.get('stream', False)
     messages = data.get('messages', [])
     model = data.get('model', server_config['model'])
-    n = data.get('n', server_config['n'])
 
-    system_prompt, initial_query = parse_conversation(messages)
+    optillm_approach = data.get('optillm_approach', {})
+
+    system_prompt, initial_query, message_optillm_approach = parse_conversation(messages)
+
+    # Use optillm_approach from extra_body if present, otherwise use from messages
+    if not optillm_approach and message_optillm_approach:
+        optillm_approach = message_optillm_approach
+
+    if optillm_approach:
+        model = f"{optillm_approach}-{model}"
 
-    approach = server_config['approach']
     base_url = server_config['base_url']
 
     if base_url != "":
@@ -148,53 +278,20 @@
     else:
         client = default_client
 
-    # Handle 'auto' approach
-    if approach == 'auto':
-        for known_approach in known_approaches:
-            if model.startswith(f"{known_approach}-"):
-                approach = known_approach
-                model = model[len(known_approach)+1:]
-                break
-        else:
-            # If no known approach is found in the model name, default to 'bon'
-            approach = 'bon'
-
-
-    logger.info(f'Using approach {approach}, 
with {model}') - completion_tokens = 0 + operation, approaches, model = parse_combined_approach(model, known_approaches, plugin_approaches) + logger.info(f'Using approach(es) {approaches}, operation {operation}, with model {model}') try: - if approach == 'mcts': - final_response, completion_tokens = chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'], - server_config['mcts_exploration'], server_config['mcts_depth']) - elif approach == 'bon': - final_response, completion_tokens = best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n']) - elif approach == 'moa': - final_response, completion_tokens = mixture_of_agents(system_prompt, initial_query, client, model) - elif approach == 'rto': - final_response, completion_tokens = round_trip_optimization(system_prompt, initial_query, client, model) - elif approach == 'z3': - z3_solver = Z3SolverSystem(system_prompt, client, model) - final_response, completion_tokens = z3_solver.process_query(initial_query) - elif approach == "self_consistency": - final_response, completion_tokens = advanced_self_consistency_approach(system_prompt, initial_query, client, model) - elif approach == "pvg": - final_response, completion_tokens = inference_time_pv_game(system_prompt, initial_query, client, model) - elif approach == "rstar": - rstar = RStar(system_prompt, client, model, - max_depth=server_config['rstar_max_depth'], num_rollouts=server_config['rstar_num_rollouts'], - c=server_config['rstar_c']) - final_response, completion_tokens = rstar.solve(initial_query) - elif approach == "cot_reflection": - final_response, completion_tokens = cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response']) - elif approach == 'plansearch': - final_response, completion_tokens = plansearch(system_prompt, initial_query, client, model, n=n) - elif approach == 'leap': - final_response, completion_tokens = leap(system_prompt, initial_query, client, model) - elif approach == 're2': - final_response, completion_tokens = re2_approach(system_prompt, initial_query, client, model, n=n) + if operation == 'SINGLE': + final_response, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model) + elif operation == 'AND': + final_response, completion_tokens = execute_combined_approaches(approaches, system_prompt, initial_query, client, model) + elif operation == 'OR': + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + final_response, completion_tokens = loop.run_until_complete(execute_parallel_approaches(approaches, system_prompt, initial_query, client, model)) else: - raise ValueError(f"Unknown approach: {approach}") + raise ValueError(f"Unknown operation: {operation}") except Exception as e: logger.error(f"Error processing request: {str(e)}") return jsonify({"error": str(e)}), 500 @@ -233,7 +330,6 @@ def proxy(): logger.debug(f'API response: {response_data}') return jsonify(response_data), 200 - @app.route('/v1/models', methods=['GET']) def proxy_models(): logger.info('Received request to /v1/models') @@ -313,6 +409,8 @@ def main(): global server_config args = parse_args() + # Call this function at the start of main() + load_plugins() # Update server_config with all argument values server_config.update(vars(args)) diff --git a/optillm/plugins/memory_plugin.py b/optillm/plugins/memory_plugin.py new file mode 100644 index 00000000..00d2bade --- /dev/null +++ b/optillm/plugins/memory_plugin.py @@ 
-0,0 +1,103 @@ +import re +from typing import Tuple, List +import numpy as np +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.metrics.pairwise import cosine_similarity + +SLUG = "memory" + +class Memory: + def __init__(self, max_size: int = 100): + self.max_size = max_size + self.items: List[str] = [] + self.vectorizer = TfidfVectorizer() + self.vectors = None + self.completion_tokens = 0 + + def add(self, item: str): + if len(self.items) >= self.max_size: + self.items.pop(0) + self.items.append(item) + self.vectors = None # Reset vectors to force recalculation + + def get_relevant(self, query: str, n: int = 5) -> List[str]: + if not self.items: + return [] + + if self.vectors is None: + self.vectors = self.vectorizer.fit_transform(self.items) + + query_vector = self.vectorizer.transform([query]) + similarities = cosine_similarity(query_vector, self.vectors).flatten() + top_indices = similarities.argsort()[-n:][::-1] + + return [self.items[i] for i in top_indices] + +def extract_query(text: str) -> Tuple[str, str]: + query_index = text.rfind("Query:") + + if query_index != -1: + context = text[:query_index].strip() + query = text[query_index + 6:].strip() + else: + sentences = re.split(r'(?<=[.!?])\s+', text.strip()) + if len(sentences) > 1: + context = ' '.join(sentences[:-1]) + query = sentences[-1] + else: + context = text + query = "What is the main point of this text?" + return query, context + +def extract_key_information(text: str, client, model: str) -> List[str]: + prompt = f"""Extract key information from the following text. Provide a list of important facts or concepts, each on a new line: + +{text} + +Key information:""" + + response = client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": prompt}], + max_tokens=1000 + ) + + key_info = response.choices[0].message.content.strip().split('\n') + + return [info.strip('- ') for info in key_info if info.strip()], response.usage.completion_tokens + +def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str, int]: + memory = Memory() + query, context = extract_query(initial_query) + completion_tokens = 0 + + # Process context and add to memory + chunk_size = 10000 + for i in range(0, len(context), chunk_size): + chunk = context[i:i+chunk_size] + key_info, tokens = extract_key_information(chunk, client, model) + completion_tokens += tokens + for info in key_info: + memory.add(info) + + # Retrieve relevant information from memory + relevant_info = memory.get_relevant(query) + + # Generate response using relevant information + prompt = f"""System: {system_prompt} + +Context: {' '.join(relevant_info)} + +{query} +""" + + response = client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": prompt}], + max_tokens=1000 + ) + + final_response = response.choices[0].message.content.strip() + completion_tokens += response.usage.completion_tokens + + return final_response, completion_tokens \ No newline at end of file diff --git a/optillm/plugins/readurls_plugin.py b/optillm/plugins/readurls_plugin.py new file mode 100644 index 00000000..4392f2c4 --- /dev/null +++ b/optillm/plugins/readurls_plugin.py @@ -0,0 +1,82 @@ +import re +from typing import Tuple, List +import requests +from bs4 import BeautifulSoup +from urllib.parse import urlparse + +SLUG = "readurls" + +def extract_urls(text: str) -> List[str]: + # Updated regex pattern to be more precise + url_pattern = re.compile(r'https?://[^\s\'"]+') + + # Find all matches + urls = 
url_pattern.findall(text)
+
+    # Clean up the URLs
+    cleaned_urls = []
+    for url in urls:
+        # Remove trailing punctuation and quotes
+        url = re.sub(r'[,\'\"\)\]]+$', '', url)
+        cleaned_urls.append(url)
+
+    return cleaned_urls
+
+def fetch_webpage_content(url: str, max_length: int = 40000) -> str:
+    try:
+        headers = {
+            'User-Agent': 'optillm/0.0.1 (https://github.com/codelion/optillm)'
+        }
+
+        response = requests.get(url, headers=headers, timeout=10)
+        response.raise_for_status()
+
+        # Make a soup
+        soup = BeautifulSoup(response.content, 'lxml')
+
+        # Remove script and style elements
+        for script in soup(["script", "style"]):
+            script.decompose()
+
+        # Get text from various elements
+        text_elements = []
+
+        # Prioritize content from main content tags
+        for tag in ['article', 'main', 'div[role="main"]', '.main-content']:
+            content = soup.select_one(tag)
+            if content:
+                text_elements.extend(content.find_all(['h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'p']))
+                break
+
+        # If no main content found, fall back to all headers and paragraphs
+        if not text_elements:
+            text_elements = soup.find_all(['h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'p'])
+
+        # Extract text from elements
+        text = ' '.join(element.get_text(strip=True) for element in text_elements)
+
+        # Remove extra whitespace
+        text = re.sub(r'\s+', ' ', text).strip()
+
+        # Remove footnote superscripts in brackets
+        text = re.sub(r"\[.*?\]+", '', text)
+
+        # Truncate to max_length
+        if len(text) > max_length:
+            text = text[:max_length] + '...'
+
+        return text
+    except Exception as e:
+        return f"Error fetching content: {str(e)}"
+
+def run(system_prompt, initial_query: str, client=None, model=None) -> Tuple[str, int]:
+    urls = extract_urls(initial_query)
+    # print(urls)
+    modified_query = initial_query
+
+    for url in urls:
+        content = fetch_webpage_content(url)
+        domain = urlparse(url).netloc
+        modified_query = modified_query.replace(url, f"{url} [Content from {domain}: {content}]")
+    # print(modified_query)
+    return modified_query, 0
\ No newline at end of file
diff --git a/optillm/wim.py b/optillm/wim.py
new file mode 100644
index 00000000..56596a08
--- /dev/null
+++ b/optillm/wim.py
@@ -0,0 +1,122 @@
+from collections import deque
+import tiktoken
+import re
+
+class WiMInfiniteContextAPI:
+    def __init__(self, system_prompt, client, model, max_context_tokens=64000, max_margins=10, chunk_size=16000):
+        self.model = model
+        self.max_context_tokens = max_context_tokens
+        self.max_margins = max_margins
+        self.chunk_size = chunk_size
+        self.context_buffer = deque()
+        self.margins = deque(maxlen=max_margins)
+        try:
+            self.tokenizer = tiktoken.encoding_for_model(model)
+        except:
+            self.tokenizer = tiktoken.get_encoding("o200k_base")
+        self.system_message = system_prompt
+        self.client = client
+        self.win_completion_tokens = 0
+
+    def count_tokens(self, text):
+        return len(self.tokenizer.encode(text))
+
+    def trim_context_buffer(self):
+        while self.count_tokens("".join(self.context_buffer)) > self.max_context_tokens:
+            self.context_buffer.popleft()
+
+    def generate_margin(self, chunk, query):
+        messages = [
+            {"role": "system", "content": self.system_message},
+            {"role": "user", "content": f"""
+'''text
+{chunk}
+'''
+Copy over all context relevant to the query: {query}
+Provide the answer in the format: <YES/NO>#<Relevant context>.
+Here are rules: +- If you don't know how to answer the query - start your answer with NO# +- If the text is not related to the query - start your answer with NO# +- If you can extract relevant information - start your answer with YES# +- If the text does not mention the person by name - start your answer with NO# +Example answers: +- YES#Western philosophy originated in Ancient Greece in the 6th century BCE with the pre-Socratics. +- NO#No relevant context. +"""} + ] + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + max_tokens = 512 + ) + self.win_completion_tokens += response.usage.completion_tokens + return response.choices[0].message.content + + def classify_margin(self, margin): + return margin.startswith("YES#") + + def extract_query(self, text): + # Split the text into sentences + sentences = re.split(r'(?<=[.!?])\s+', text) + + # Check if the last sentence starts with "Query:" + if sentences[-1].startswith("Query:"): + return sentences[-1][6:].strip(), "".join(sentences[:-1]) + + # If not, assume the last sentence is the query + return sentences[-1].strip(), "".join(sentences[:-1]) + + def process_chunk(self, chunk, query): + self.context_buffer.append(chunk) + self.trim_context_buffer() + margin = self.generate_margin(chunk, query) + if self.classify_margin(margin): + self.margins.append(margin.split("#", 1)[1]) + + def process_stream(self, text_stream, query): + for chunk in text_stream: + self.process_chunk(chunk, query) + + def generate_final_answer(self, query): + context = "".join(self.context_buffer) + margins = "\n".join(self.margins) + messages = [ + {"role": "system", "content": self.system_message}, + {"role": "user", "content": f""" +'''text +{context} +''' +I asked my assistant to read and analyse the above content page by page to help you complete this task. These are margin notes left on each page: +'''text +{margins} +''' +Read again the note(s) and the provided content, take a deep breath and answer the query. 
+{self.instruction} +{query} +"""} + ] + response = self.client.chat.completions.create( + model=self.model, + messages=messages + ) + self.win_completion_tokens += response.usage.completion_tokens + return response.choices[0].message.content + + def run(self, text_stream, query): + self.process_stream(text_stream, query) + return self.generate_final_answer(query) + + @property + def instruction(self): + return "Answer the following question based on the provided context and margin notes:" + + # Usage + def text_stream_generator(self, text): + for i in range(0, len(text), self.chunk_size): + yield text[i:i+self.chunk_size] + + def process_query(self, initial_query): + query, context = self.extract_query(initial_query) + text_stream = self.text_stream_generator(context) + final_answer = self.run(text_stream, query) + return final_answer, self.win_completion_tokens \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index fd16e529..47c69eb5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,4 +7,9 @@ flask torch transformers azure.identity -litellm \ No newline at end of file +tiktoken +scikit-learn +litellm +requests +beautifulsoup4 +lxml \ No newline at end of file diff --git a/scripts/eval_frames_benchmark.py b/scripts/eval_frames_benchmark.py new file mode 100644 index 00000000..16c44c09 --- /dev/null +++ b/scripts/eval_frames_benchmark.py @@ -0,0 +1,152 @@ +import argparse +import json +import os +import time +from typing import List, Dict + +from openai import OpenAI +from datasets import load_dataset +from tqdm import tqdm + +client = OpenAI(api_key="none", base_url="http://localhost:8000/v1") +SLEEP_INTERVAL = 60 + +def load_existing_results(filename: str) -> List[Dict]: + try: + with open(filename, 'r') as f: + return json.load(f) + except FileNotFoundError: + return [] + +def save_result(filename: str, result: Dict): + results = load_existing_results(filename) + results.append(result) + with open(filename, 'w') as f: + json.dump(results, f, indent=2) + +def get_last_processed_index(results: List[Dict]) -> int: + if not results: + return -1 + return max(int(r.get('index', -1)) for r in results) + + +def generate_llm_prompt(prompt: str, wiki_links: List[str]) -> str: + return f"Here are the relevant Wikipedia articles:\n{wiki_links}\n\nBased on all the information, answer the query. \n\nQuery: {prompt}\n\n" + +def get_llm_response(prompt: str, model: str) -> str: + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt} + ], + max_tokens=1000, + n=1, + stop=None, + temperature=0.7, + ) + return response.choices[0].message.content.strip() + +def evaluate_response(question: str, llm_response: str, ground_truth: str, model: str) -> Dict[str, str]: + evaluation_prompt = f"""===Task=== +I need your help in evaluating an answer provided by an LLM against a ground +truth answer. Your task is to determine if the ground truth answer is present in the LLM's +response. Please analyze the provided data and make a decision. +===Instructions=== +1. Carefully compare the "Predicted Answer" with the "Ground Truth Answer". +2. Consider the substance of the answers – look for equivalent information or correct answers. +Do not focus on exact wording unless the exact wording is crucial to the meaning. +3. 
Your final decision should be based on whether the meaning and the vital facts of the +"Ground Truth Answer" are present in the "Predicted Answer:" +===Input Data=== +- Question: {question} +- Predicted Answer: {llm_response} +- Ground Truth Answer: {ground_truth} +===Output Format=== +Provide your final evaluation in the following format: +"Explanation:" (How you made the decision?) +"Decision:" ("TRUE" or "FALSE" ) +Please proceed with the evaluation.""" + + evaluation_response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": evaluation_prompt} + ], + max_tokens=300, + n=1, + stop=None, + temperature=0.3, + ) + + evaluation_text = evaluation_response.choices[0].message.content.strip() + + # Extract the decision and explanation + lines = evaluation_text.split('\n') + decision = "FALSE" + explanation = "" + for line in lines: + if line.startswith("Decision:"): + decision = line.split(":")[1].strip().upper() + elif line.startswith("Explanation:"): + explanation = line.split(":", 1)[1].strip() + + return {"decision": decision, "explanation": explanation} + +def main(model: str): + # Load the dataset + dataset = load_dataset("google/frames-benchmark", split="test") + + filename = f"evaluation_results_{model.replace('/', '_')}.json" + existing_results = load_existing_results(filename) + last_processed_index = get_last_processed_index(existing_results) + + for item in tqdm(dataset, desc="Processing samples"): + index = int(item['Unnamed: 0']) + if index <= last_processed_index: + continue + + prompt = generate_llm_prompt(item['Prompt'], item['wiki_links']) + llm_response = get_llm_response(prompt, model) + evaluation = evaluate_response(item['Prompt'], llm_response, item['Answer'], model) + + result = { + "index": index, + "prompt": item['Prompt'], + "ground_truth": item['Answer'], + "llm_response": llm_response, + "evaluation_decision": evaluation['decision'], + "evaluation_explanation": evaluation['explanation'], + "reasoning_type": item['reasoning_types'] + } + + save_result(filename, result) + print(f"Index: {index}, Decision: {result['evaluation_decision']}") + time.sleep(SLEEP_INTERVAL) + + # Calculate and print summary statistics + results = load_existing_results(filename) + total_samples = len(results) + correct_answers = sum(1 for r in results if r['evaluation_decision'] == 'TRUE') + accuracy = correct_answers / total_samples + + print(f"Model: {model}") + print(f"Total samples: {total_samples}") + print(f"Correct answers: {correct_answers}") + print(f"Accuracy: {accuracy:.2%}") + + # Print accuracy by reasoning type + reasoning_types = set(r['reasoning_types'] for r in results) + for rt in reasoning_types: + rt_samples = [r for r in results if r['reasoning_types'] == rt] + rt_correct = sum(1 for r in rt_samples if r['evaluation_decision'] == 'TRUE') + rt_accuracy = rt_correct / len(rt_samples) + print(f"Accuracy for {rt}: {rt_accuracy:.2%}") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluate LLM performance on google/frames-benchmark") + parser.add_argument("--model", type=str, required=True, help="OpenAI model to use (e.g., gpt-4o, gpt-4o-mini)") + args = parser.parse_args() + + main(args.model) \ No newline at end of file diff --git a/scripts/gen_optillm_dataset.py b/scripts/gen_optillm_dataset.py new file mode 100644 index 00000000..9a3b3cd4 --- /dev/null +++ b/scripts/gen_optillm_dataset.py @@ -0,0 +1,96 @@ +import os +import json +import 
argparse +import asyncio +from tqdm import tqdm +from datasets import load_dataset +from openai import AsyncOpenAI +from typing import List, Dict, Any +import random + +# OptILM approaches +APPROACHES = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"] + +async def generate_response(prompt: str, approach: str) -> Dict[str, Any]: + """Generate a response using the specified approach.""" + if approach == "none": + # Use the base model without any optimization technique + client = AsyncOpenAI() + response = await client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": prompt}], + ) + return { + "content": response.choices[0].message.content, + "tokens": response.usage.completion_tokens, + } + else: + # Use OptILM with the specified approach + client = AsyncOpenAI(api_key="none", base_url="http://localhost:8000/v1") + response = await client.chat.completions.create( + model=f"{approach}-gpt-4o-mini", # Assuming OptILM uses this naming convention + messages=[{"role": "user", "content": prompt}], + ) + return { + "content": response.choices[0].message.content, + "tokens": response.usage.completion_tokens, + } + +async def rank_responses(prompt: str, responses: List[Dict[str, Any]]) -> List[int]: + """Rank the responses using the LLM.""" + ranking_prompt = f"Given the following prompt:\n\n{prompt}\n\nRank the following responses from best to worst, considering accuracy, completeness, and relevance. Provide the ranking as a comma-separated list of indices (0-indexed). Do not add any explanations or any other text other than the comma-separated list.\n\n" + for i, response in enumerate(responses): + ranking_prompt += f"Response {i}:\n{response['content']}\n\n" + client = AsyncOpenAI() + ranking_response = await client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": ranking_prompt}], + ) + + ranking_str = ranking_response.choices[0].message.content.strip() + print(ranking_str) + return [int(idx) for idx in ranking_str.split(",")] + +async def process_sample(sample: Dict[str, Any]) -> Dict[str, Any]: + """Process a single sample from the dataset.""" + prompt = sample["turns"][0]["content"] + results = [] + + # Generate responses for each approach + for approach in APPROACHES: + response = await generate_response(prompt, approach) + results.append({"approach": approach, **response}) + + random.shuffle(results) + # Rank the responses + rankings = await rank_responses(prompt, results) + + # Add rankings to results + for rank, idx in enumerate(rankings): + results[idx]["rank"] = rank + + return { + "prompt": prompt, + "results": results, + } + +async def generate_dataset(num_samples: int, output_file: str): + """Generate the dataset and save it to a JSONL file.""" + dataset = load_dataset("lmsys/arena-hard-auto-v0.1", split="train") + + with open(output_file, "w") as f: + for sample in tqdm(dataset.select(range(num_samples)), total=num_samples): + result = await process_sample(sample) + f.write(json.dumps(result) + "\n") + +def main(): + parser = argparse.ArgumentParser(description="Generate OptILM dataset") + parser.add_argument("--num_samples", type=int, default=100, help="Number of samples to process") + parser.add_argument("--output_file", type=str, default="optillm_dataset.jsonl", help="Output file path") + args = parser.parse_args() + + asyncio.run(generate_dataset(args.num_samples, args.output_file)) + print(f"Dataset generated and saved to 
{args.output_file}") + +if __name__ == "__main__": + main() diff --git a/scripts/requirements.txt b/scripts/requirements.txt new file mode 100644 index 00000000..aee11b28 --- /dev/null +++ b/scripts/requirements.txt @@ -0,0 +1 @@ +datasets diff --git a/setup.py b/setup.py index 61579c36..8097c206 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,12 @@ "torch", "transformers", "azure-identity", + "tiktoken", + "scikit-learn", + "litellm", + "requests", + "beautifulsoup4", + "lxml", ], author="codelion", author_email="codelion@okyasoft.com", diff --git a/test_cases.json b/test_cases.json index a9c8df7a..fadf3e08 100644 --- a/test_cases.json +++ b/test_cases.json @@ -28,5 +28,10 @@ "name" : "reddit", "system_prompt": "", "query" : "There are 24 volunteers. Over the next 3 weeks, each volunteer is assigned to a different task. There are 8 tasks. Each week, the volunteers switch tasks. Each task has 3 volunteers assigned to it. Volunteers cannot be assigned to the same task more than once, and volunteers cannot share the same task more than once." + }, + { + "name" : "GH", + "system_prompt" : "", + "query" : "Find the largest possible real part of[(75+117i)z+\frac{96+144i}{z}]where z is a complex number with |z|=4" } ]