202 changes: 150 additions & 52 deletions optillm.py
@@ -6,6 +6,11 @@
from openai import AzureOpenAI, OpenAI
from flask import Response
import json
import importlib.util
import glob
import asyncio
import re
from concurrent.futures import ThreadPoolExecutor

# Import the LiteLLM wrapper
from litellm_wrapper import LiteLLMWrapper
@@ -79,6 +84,111 @@
known_approaches = ["mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar",
"cot_reflection", "plansearch", "leap", "re2"]

plugin_approaches = {}

def load_plugins():
plugin_dir = os.path.join(os.path.dirname(__file__), 'optillm/plugins')
plugin_files = glob.glob(os.path.join(plugin_dir, '*.py'))

for plugin_file in plugin_files:
module_name = os.path.basename(plugin_file)[:-3] # Remove .py extension
spec = importlib.util.spec_from_file_location(module_name, plugin_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

if hasattr(module, 'SLUG') and hasattr(module, 'run'):
plugin_approaches[module.SLUG] = module.run
logger.info(f"Loaded plugin: {module.SLUG}")

def parse_combined_approach(model: str, known_approaches: list, plugin_approaches: dict):
if model == 'auto':
return 'SINGLE', ['bon'], model

parts = model.split('-')
approaches = []
operation = 'SINGLE'
model_parts = []
parsing_approaches = True

for part in parts:
if parsing_approaches:
if part in known_approaches or part in plugin_approaches:
approaches.append(part)
elif '&' in part:
operation = 'AND'
approaches.extend(part.split('&'))
elif '|' in part:
operation = 'OR'
approaches.extend(part.split('|'))
else:
parsing_approaches = False
model_parts.append(part)
else:
model_parts.append(part)

if not approaches:
approaches = ['bon']
operation = 'SINGLE'

actual_model = '-'.join(model_parts)

return operation, approaches, actual_model
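
# Illustrative parses (hypothetical model names, not taken from this diff):
#   parse_combined_approach("moa&mcts-gpt-4o-mini", ...) -> ("AND", ["moa", "mcts"], "gpt-4o-mini")
#   parse_combined_approach("bon|moa-gpt-4o-mini", ...)  -> ("OR", ["bon", "moa"], "gpt-4o-mini")
#   parse_combined_approach("leap-gpt-4o-mini", ...)     -> ("SINGLE", ["leap"], "gpt-4o-mini")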

def execute_single_approach(approach, system_prompt, initial_query, client, model):
if approach in known_approaches:
# Execute known approaches
if approach == 'mcts':
return chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
server_config['mcts_exploration'], server_config['mcts_depth'])
elif approach == 'bon':
return best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'])
elif approach == 'moa':
return mixture_of_agents(system_prompt, initial_query, client, model)
elif approach == 'rto':
return round_trip_optimization(system_prompt, initial_query, client, model)
elif approach == 'z3':
z3_solver = Z3SolverSystem(system_prompt, client, model)
return z3_solver.process_query(initial_query)
elif approach == "self_consistency":
return advanced_self_consistency_approach(system_prompt, initial_query, client, model)
elif approach == "pvg":
return inference_time_pv_game(system_prompt, initial_query, client, model)
elif approach == "rstar":
rstar = RStar(system_prompt, client, model,
max_depth=server_config['rstar_max_depth'], num_rollouts=server_config['rstar_num_rollouts'],
c=server_config['rstar_c'])
return rstar.solve(initial_query)
elif approach == "cot_reflection":
return cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'])
elif approach == 'plansearch':
return plansearch(system_prompt, initial_query, client, model, n=server_config['n'])
elif approach == 'leap':
return leap(system_prompt, initial_query, client, model)
elif approach == 're2':
return re2_approach(system_prompt, initial_query, client, model, n=server_config['n'])
elif approach in plugin_approaches:
return plugin_approaches[approach](system_prompt, initial_query, client, model)
else:
raise ValueError(f"Unknown approach: {approach}")

def execute_combined_approaches(approaches, system_prompt, initial_query, client, model):
final_response = initial_query
total_tokens = 0
for approach in approaches:
response, tokens = execute_single_approach(approach, system_prompt, final_response, client, model)
final_response = response
total_tokens += tokens
return final_response, total_tokens

async def execute_parallel_approaches(approaches, system_prompt, initial_query, client, model):
async def run_approach(approach):
return await asyncio.to_thread(execute_single_approach, approach, system_prompt, initial_query, client, model)

tasks = [run_approach(approach) for approach in approaches]
results = await asyncio.gather(*tasks)
responses, tokens = zip(*results)
return list(responses), sum(tokens)
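
# With the 'OR' operation every approach receives the same initial query in a worker thread,
# and the caller gets back the list of responses plus the summed completion-token count.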

def generate_streaming_response(final_response, model):
# Yield the final response
if isinstance(final_response, list):
@@ -99,18 +209,31 @@ def generate_streaming_response(final_response, model):
def parse_conversation(messages):
system_prompt = ""
conversation = []
optillm_approach = None

for message in messages:
role = message['role']
content = message['content']

if role == 'system':
system_prompt = content
elif role in ['user', 'assistant']:
conversation.append(f"{role.capitalize()}: {content}")
system_prompt, optillm_approach = extract_optillm_approach(content)
elif role == 'user':
if not optillm_approach:
content, optillm_approach = extract_optillm_approach(content)
conversation.append(f"User: {content}")
elif role == 'assistant':
conversation.append(f"Assistant: {content}")

initial_query = "\n".join(conversation)
return system_prompt, initial_query
return system_prompt, initial_query, optillm_approach

def extract_optillm_approach(content):
match = re.search(r'<optillm_approach>(.*?)</optillm_approach>', content)
if match:
approach = match.group(1)
content = re.sub(r'<optillm_approach>.*?</optillm_approach>', '', content).strip()
return content, approach
return content, None
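
# Illustrative example (hypothetical message content):
#   extract_optillm_approach("<optillm_approach>re2</optillm_approach> What is 2 + 2?")
#   -> ("What is 2 + 2?", "re2"); content without a tag is returned unchanged with None.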

# Optional API key configuration to secure the proxy
@app.before_request
@@ -136,65 +259,39 @@ def proxy():
stream = data.get('stream', False)
messages = data.get('messages', [])
model = data.get('model', server_config['model'])
n = data.get('n', server_config['n'])

system_prompt, initial_query = parse_conversation(messages)
optillm_approach = data.get('optillm_approach', {})

system_prompt, initial_query, message_optillm_approach = parse_conversation(messages)

# Use optillm_approach from extra_body if present, otherwise use from messages
if not optillm_approach and message_optillm_approach:
optillm_approach = message_optillm_approach

if optillm_approach:
model = f"{optillm_approach}-{model}"

approach = server_config['approach']
base_url = server_config['base_url']

if base_url != "":
client = OpenAI(api_key=API_KEY, base_url=base_url)
else:
client = default_client

# Handle 'auto' approach
if approach == 'auto':
for known_approach in known_approaches:
if model.startswith(f"{known_approach}-"):
approach = known_approach
model = model[len(known_approach)+1:]
break
else:
# If no known approach is found in the model name, default to 'bon'
approach = 'bon'


logger.info(f'Using approach {approach}, with {model}')
completion_tokens = 0
operation, approaches, model = parse_combined_approach(model, known_approaches, plugin_approaches)
logger.info(f'Using approach(es) {approaches}, operation {operation}, with model {model}')

try:
if approach == 'mcts':
final_response, completion_tokens = chat_with_mcts(system_prompt, initial_query, client, model, server_config['mcts_simulations'],
server_config['mcts_exploration'], server_config['mcts_depth'])
elif approach == 'bon':
final_response, completion_tokens = best_of_n_sampling(system_prompt, initial_query, client, model, server_config['best_of_n'])
elif approach == 'moa':
final_response, completion_tokens = mixture_of_agents(system_prompt, initial_query, client, model)
elif approach == 'rto':
final_response, completion_tokens = round_trip_optimization(system_prompt, initial_query, client, model)
elif approach == 'z3':
z3_solver = Z3SolverSystem(system_prompt, client, model)
final_response, completion_tokens = z3_solver.process_query(initial_query)
elif approach == "self_consistency":
final_response, completion_tokens = advanced_self_consistency_approach(system_prompt, initial_query, client, model)
elif approach == "pvg":
final_response, completion_tokens = inference_time_pv_game(system_prompt, initial_query, client, model)
elif approach == "rstar":
rstar = RStar(system_prompt, client, model,
max_depth=server_config['rstar_max_depth'], num_rollouts=server_config['rstar_num_rollouts'],
c=server_config['rstar_c'])
final_response, completion_tokens = rstar.solve(initial_query)
elif approach == "cot_reflection":
final_response, completion_tokens = cot_reflection(system_prompt, initial_query, client, model, return_full_response=server_config['return_full_response'])
elif approach == 'plansearch':
final_response, completion_tokens = plansearch(system_prompt, initial_query, client, model, n=n)
elif approach == 'leap':
final_response, completion_tokens = leap(system_prompt, initial_query, client, model)
elif approach == 're2':
final_response, completion_tokens = re2_approach(system_prompt, initial_query, client, model, n=n)
if operation == 'SINGLE':
final_response, completion_tokens = execute_single_approach(approaches[0], system_prompt, initial_query, client, model)
elif operation == 'AND':
final_response, completion_tokens = execute_combined_approaches(approaches, system_prompt, initial_query, client, model)
elif operation == 'OR':
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
final_response, completion_tokens = loop.run_until_complete(execute_parallel_approaches(approaches, system_prompt, initial_query, client, model))
else:
raise ValueError(f"Unknown approach: {approach}")
raise ValueError(f"Unknown operation: {operation}")
except Exception as e:
logger.error(f"Error processing request: {str(e)}")
return jsonify({"error": str(e)}), 500
@@ -233,7 +330,6 @@ def proxy():
logger.debug(f'API response: {response_data}')
return jsonify(response_data), 200


@app.route('/v1/models', methods=['GET'])
def proxy_models():
logger.info('Received request to /v1/models')
@@ -313,6 +409,8 @@ def main():
global server_config
args = parse_args()

# Load plugin approaches before any requests are served
load_plugins()
# Update server_config with all argument values
server_config.update(vars(args))

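A minimal usage sketch of the combined-approach routing added above; the proxy address is an assumption (it is not shown in this diff), and the model names are illustrative:

from openai import OpenAI

# Assumed endpoint; point base_url at wherever the optillm proxy is actually listening.
client = OpenAI(api_key="dummy-key", base_url="http://localhost:8000/v1")

# 'AND' pipeline: moa runs first and its output becomes the query for mcts
# (see execute_combined_approaches above).
response = client.chat.completions.create(
    model="moa&mcts-gpt-4o-mini",
    messages=[{"role": "user", "content": "How many r's are there in strawberry?"}],
)

# The same selection without touching the model name: proxy() prepends optillm_approach
# to the model before parse_combined_approach() runs.
response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "How many r's are there in strawberry?"}],
    extra_body={"optillm_approach": "moa&mcts"},
)
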
103 changes: 103 additions & 0 deletions optillm/plugins/memory_plugin.py
@@ -0,0 +1,103 @@
import re
from typing import Tuple, List
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

SLUG = "memory"

class Memory:
def __init__(self, max_size: int = 100):
self.max_size = max_size
self.items: List[str] = []
self.vectorizer = TfidfVectorizer()
self.vectors = None
self.completion_tokens = 0

def add(self, item: str):
if len(self.items) >= self.max_size:
self.items.pop(0)
self.items.append(item)
self.vectors = None # Reset vectors to force recalculation

def get_relevant(self, query: str, n: int = 5) -> List[str]:
if not self.items:
return []

if self.vectors is None:
self.vectors = self.vectorizer.fit_transform(self.items)

query_vector = self.vectorizer.transform([query])
similarities = cosine_similarity(query_vector, self.vectors).flatten()
top_indices = similarities.argsort()[-n:][::-1]

return [self.items[i] for i in top_indices]

def extract_query(text: str) -> Tuple[str, str]:
query_index = text.rfind("Query:")

if query_index != -1:
context = text[:query_index].strip()
query = text[query_index + 6:].strip()
else:
sentences = re.split(r'(?<=[.!?])\s+', text.strip())
if len(sentences) > 1:
context = ' '.join(sentences[:-1])
query = sentences[-1]
else:
context = text
query = "What is the main point of this text?"
return query, context
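
# Illustrative behaviour (hypothetical input):
#   extract_query("Paris is the capital of France. Query: What is the capital of France?")
#   -> ("What is the capital of France?", "Paris is the capital of France.")
#   Without a "Query:" marker the last sentence is treated as the query.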

def extract_key_information(text: str, client, model: str) -> Tuple[List[str], int]:
prompt = f"""Extract key information from the following text. Provide a list of important facts or concepts, each on a new line:

{text}

Key information:"""

response = client.chat.completions.create(
model=model,
messages=[{"role": "user", "content": prompt}],
max_tokens=1000
)

key_info = response.choices[0].message.content.strip().split('\n')

return [info.strip('- ') for info in key_info if info.strip()], response.usage.completion_tokens

def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str, int]:
memory = Memory()
query, context = extract_query(initial_query)
completion_tokens = 0

# Process context and add to memory
chunk_size = 10000
for i in range(0, len(context), chunk_size):
chunk = context[i:i+chunk_size]
key_info, tokens = extract_key_information(chunk, client, model)
completion_tokens += tokens
for info in key_info:
memory.add(info)

# Retrieve relevant information from memory
relevant_info = memory.get_relevant(query)

# Generate response using relevant information
prompt = f"""System: {system_prompt}

Context: {' '.join(relevant_info)}

{query}
"""

response = client.chat.completions.create(
model=model,
messages=[{"role": "user", "content": prompt}],
max_tokens=1000
)

final_response = response.choices[0].message.content.strip()
completion_tokens += response.usage.completion_tokens

return final_response, completion_tokens
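
A minimal end-to-end sketch of invoking this plugin through the proxy; the endpoint is an assumption, and the in-message tag relies on the extract_optillm_approach() hook added to parse_conversation() in optillm.py above:

from openai import OpenAI

client = OpenAI(api_key="dummy-key", base_url="http://localhost:8000/v1")  # assumed proxy address

long_document = "..."  # any long context; run() feeds it to Memory in 10,000-character chunks
response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{
        "role": "user",
        "content": f"<optillm_approach>memory</optillm_approach>{long_document}\n\nQuery: What are the key findings?",
    }],
)
print(response.choices[0].message.content)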