6 changes: 5 additions & 1 deletion optillm.py
@@ -49,7 +49,11 @@ def get_config():
     # OpenAI, Azure, or LiteLLM API configuration
     if os.environ.get("OPENAI_API_KEY"):
         API_KEY = os.environ.get("OPENAI_API_KEY")
-        default_client = OpenAI(api_key=API_KEY)
+        base_url = server_config['base_url']
+        if base_url != "":
+            default_client = OpenAI(api_key=API_KEY, base_url=base_url)
+        else:
+            default_client = OpenAI(api_key=API_KEY)
     elif os.environ.get("AZURE_OPENAI_API_KEY"):
         API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
         API_VERSION = os.environ.get("AZURE_API_VERSION")
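This change lets the default OpenAI client honor a custom base_url from the server config, so the proxy can front any OpenAI-compatible endpoint. A minimal sketch of the resulting behavior (the local URL is illustrative, and server_config here is a hypothetical stand-in for optillm's parsed configuration):

import os
from openai import OpenAI

# Hypothetical stand-in for optillm's parsed server configuration
server_config = {"base_url": "http://localhost:8080/v1"}

API_KEY = os.environ.get("OPENAI_API_KEY")
base_url = server_config['base_url']
if base_url != "":
    # Point the client at the configured OpenAI-compatible endpoint
    default_client = OpenAI(api_key=API_KEY, base_url=base_url)
else:
    # Fall back to the default api.openai.com endpoint
    default_client = OpenAI(api_key=API_KEY)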
154 changes: 154 additions & 0 deletions optillm/plugins/router_plugin.py
@@ -0,0 +1,154 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_model
from optillm.mcts import chat_with_mcts
from optillm.bon import best_of_n_sampling
from optillm.moa import mixture_of_agents
from optillm.rto import round_trip_optimization
from optillm.self_consistency import advanced_self_consistency_approach
from optillm.pvg import inference_time_pv_game
from optillm.z3_solver import Z3SymPySolverSystem
from optillm.rstar import RStar
from optillm.cot_reflection import cot_reflection
from optillm.plansearch import plansearch
from optillm.leap import leap
from optillm.reread import re2_approach

SLUG = "router"

# Constants
MAX_LENGTH = 512
APPROACHES = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"]
MODEL_NAME = "codelion/optillm-bert-uncased"

class OptILMClassifier(nn.Module):
    def __init__(self, base_model, num_labels):
        super().__init__()
        self.base_model = base_model
        # Project the scalar effort value into a 64-dim feature vector
        self.effort_encoder = nn.Sequential(
            nn.Linear(1, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU()
        )
        # Classify over the [CLS] embedding concatenated with the effort features
        self.classifier = nn.Linear(base_model.config.hidden_size + 64, num_labels)

    def forward(self, input_ids, attention_mask, effort):
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0]  # [CLS] token; shape: (batch_size, hidden_size)
        effort_encoded = self.effort_encoder(effort.unsqueeze(1))  # Shape: (batch_size, 64)
        combined_input = torch.cat((pooled_output, effort_encoded), dim=1)
        logits = self.classifier(combined_input)
        return logits
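# Shape walk-through (illustrative; bert-large-uncased has hidden_size 1024): for a batch of
# size B, pooled_output is (B, 1024), effort_encoded is (B, 64), combined_input is (B, 1088),
# and logits has shape (B, len(APPROACHES)) == (B, 13).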

def load_optillm_model():
    # Prefer Apple MPS, then CUDA, then fall back to CPU
    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
    # Load the base model
    base_model = AutoModel.from_pretrained("google-bert/bert-large-uncased")
    # Create the OptILMClassifier
    model = OptILMClassifier(base_model, num_labels=len(APPROACHES))
    model.to(device)
    # Download the safetensors file
    safetensors_path = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
    # Load the state dict from the safetensors file
    load_model(model, safetensors_path)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    return model, tokenizer, device
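# Note (illustrative): hf_hub_download caches the checkpoint locally, and safetensors'
# load_model copies the weights into the already-constructed module in place.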

def preprocess_input(tokenizer, system_prompt, initial_query):
    combined_input = f"{system_prompt}\n\nUser: {initial_query}"
    encoding = tokenizer.encode_plus(
        combined_input,
        add_special_tokens=True,
        max_length=MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt'
    )
    return encoding['input_ids'], encoding['attention_mask']
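# Note (illustrative): with padding='max_length' and truncation=True, both returned
# tensors always have shape (1, MAX_LENGTH), giving the classifier a fixed-size batch of one.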

def predict_approach(model, input_ids, attention_mask, device, effort=0.7):
    model.eval()
    with torch.no_grad():
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        effort_tensor = torch.tensor([effort], dtype=torch.float).to(device)

        logits = model(input_ids, attention_mask=attention_mask, effort=effort_tensor)
        probabilities = F.softmax(logits, dim=1)
        predicted_approach_index = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0][predicted_approach_index].item()

    return APPROACHES[predicted_approach_index], confidence
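# Example usage (illustrative; the prompts are made up):
#   model, tokenizer, device = load_optillm_model()
#   ids, mask = preprocess_input(tokenizer, "You are a helpful assistant.",
#                                "Prove that the square root of 2 is irrational.")
#   approach, confidence = predict_approach(model, ids, mask, device, effort=0.7)
#   # approach is one of APPROACHES; confidence is its softmax probability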

def run(system_prompt, initial_query, client, model, **kwargs):
    try:
        # Load the trained model
        router_model, tokenizer, device = load_optillm_model()

        # Preprocess the input
        input_ids, attention_mask = preprocess_input(tokenizer, system_prompt, initial_query)

        # Predict the best approach
        predicted_approach, _ = predict_approach(router_model, input_ids, attention_mask, device)

        print(f"Router predicted approach: {predicted_approach}")

        # Route to the appropriate approach or use the model directly
        if predicted_approach == "none":
            # Use the model directly without routing
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": initial_query}
                ]
            )
            return response.choices[0].message.content, response.usage.completion_tokens
        elif predicted_approach == "mcts":
            return chat_with_mcts(system_prompt, initial_query, client, model, **kwargs)
        elif predicted_approach == "bon":
            return best_of_n_sampling(system_prompt, initial_query, client, model, **kwargs)
        elif predicted_approach == "moa":
            return mixture_of_agents(system_prompt, initial_query, client, model)
        elif predicted_approach == "rto":
            return round_trip_optimization(system_prompt, initial_query, client, model)
        elif predicted_approach == "z3":
            z3_solver = Z3SymPySolverSystem(system_prompt, client, model)
            return z3_solver.process_query(initial_query)
        elif predicted_approach == "self_consistency":
            return advanced_self_consistency_approach(system_prompt, initial_query, client, model)
        elif predicted_approach == "pvg":
            return inference_time_pv_game(system_prompt, initial_query, client, model)
        elif predicted_approach == "rstar":
            rstar = RStar(system_prompt, client, model, **kwargs)
            return rstar.solve(initial_query)
        elif predicted_approach == "cot_reflection":
            return cot_reflection(system_prompt, initial_query, client, model, **kwargs)
        elif predicted_approach == "plansearch":
            return plansearch(system_prompt, initial_query, client, model, **kwargs)
        elif predicted_approach == "leap":
            return leap(system_prompt, initial_query, client, model)
        elif predicted_approach == "re2":
            return re2_approach(system_prompt, initial_query, client, model, **kwargs)
        else:
            raise ValueError(f"Unknown approach: {predicted_approach}")

    except Exception as e:
        # Log the error and fall back to using the model directly
        print(f"Error in router plugin: {str(e)}. Falling back to direct model usage.")
        response = client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": initial_query}
            ]
        )
        return response.choices[0].message.content, response.usage.completion_tokens
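Since the plugin registers SLUG = "router", it can be selected through optillm's approach-prefix convention on the model name. A hedged end-to-end sketch (the proxy port, API key handling, and base model name are assumptions, not part of this diff):

from openai import OpenAI

# Client pointed at a locally running optillm proxy (port assumed)
client = OpenAI(api_key="dummy", base_url="http://localhost:8000/v1")

response = client.chat.completions.create(
    model="router-gpt-4o-mini",  # "router-" prefix selects this plugin; base model is illustrative
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "If 2x + 3 = 11, what is x?"}
    ]
)
print(response.choices[0].message.content)

On each request the plugin classifies the prompt, dispatches to the predicted approach (or answers directly for "none"), and falls back to a plain completion if anything goes wrong.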