diff --git a/optillm.py b/optillm.py
index 4cdf97f0..afd63e31 100644
--- a/optillm.py
+++ b/optillm.py
@@ -49,7 +49,11 @@ def get_config():
     # OpenAI, Azure, or LiteLLM API configuration
     if os.environ.get("OPENAI_API_KEY"):
         API_KEY = os.environ.get("OPENAI_API_KEY")
-        default_client = OpenAI(api_key=API_KEY)
+        base_url = server_config['base_url']
+        if base_url != "":
+            default_client = OpenAI(api_key=API_KEY, base_url=base_url)
+        else:
+            default_client = OpenAI(api_key=API_KEY)
     elif os.environ.get("AZURE_OPENAI_API_KEY"):
         API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
         API_VERSION = os.environ.get("AZURE_API_VERSION")
diff --git a/optillm/plugins/router_plugin.py b/optillm/plugins/router_plugin.py
new file mode 100644
index 00000000..de2b0805
--- /dev/null
+++ b/optillm/plugins/router_plugin.py
@@ -0,0 +1,154 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import AutoModel, AutoTokenizer, AutoConfig
+from huggingface_hub import hf_hub_download
+from safetensors import safe_open
+from safetensors.torch import load_model
+from transformers import AutoTokenizer, AutoModel
+from optillm.mcts import chat_with_mcts
+from optillm.bon import best_of_n_sampling
+from optillm.moa import mixture_of_agents
+from optillm.rto import round_trip_optimization
+from optillm.self_consistency import advanced_self_consistency_approach
+from optillm.pvg import inference_time_pv_game
+from optillm.z3_solver import Z3SymPySolverSystem
+from optillm.rstar import RStar
+from optillm.cot_reflection import cot_reflection
+from optillm.plansearch import plansearch
+from optillm.leap import leap
+from optillm.reread import re2_approach
+
+SLUG = "router"
+
+# Constants
+MAX_LENGTH = 512
+APPROACHES = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"]
+MODEL_NAME = "codelion/optillm-bert-uncased"
+
+class OptILMClassifier(nn.Module):
+    def __init__(self, base_model, num_labels):
+        super().__init__()
+        self.base_model = base_model
+        self.effort_encoder = nn.Sequential(
+            nn.Linear(1, 64),
+            nn.ReLU(),
+            nn.Linear(64, 64),
+            nn.ReLU()
+        )
+        self.classifier = nn.Linear(base_model.config.hidden_size + 64, num_labels)
+
+    def forward(self, input_ids, attention_mask, effort):
+        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
+        pooled_output = outputs.last_hidden_state[:, 0]  # Shape: (batch_size, hidden_size)
+        effort_encoded = self.effort_encoder(effort.unsqueeze(1))  # Shape: (batch_size, 64)
+        combined_input = torch.cat((pooled_output, effort_encoded), dim=1)
+        logits = self.classifier(combined_input)
+        return logits
+
+def load_optillm_model():
+    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
+    # Load the base model
+    base_model = AutoModel.from_pretrained("google-bert/bert-large-uncased")
+    # Create the OptILMClassifier
+    model = OptILMClassifier(base_model, num_labels=len(APPROACHES))
+    model.to(device)
+    # Download the safetensors file
+    safetensors_path = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
+    # Load the state dict from the safetensors file
+    load_model(model, safetensors_path)
+
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    return model, tokenizer, device
+
+def preprocess_input(tokenizer, system_prompt, initial_query):
+    combined_input = f"{system_prompt}\n\nUser: {initial_query}"
+    encoding = tokenizer.encode_plus(
+        combined_input,
+        add_special_tokens=True,
+        max_length=MAX_LENGTH,
+        padding='max_length',
+        truncation=True,
+        return_attention_mask=True,
+        return_tensors='pt'
+    )
+    return encoding['input_ids'], encoding['attention_mask']
+
+def predict_approach(model, input_ids, attention_mask, device, effort=0.7):
+    model.eval()
+    with torch.no_grad():
+        input_ids = input_ids.to(device)
+        attention_mask = attention_mask.to(device)
+        effort_tensor = torch.tensor([effort], dtype=torch.float).to(device)
+
+        logits = model(input_ids, attention_mask=attention_mask, effort=effort_tensor)
+        probabilities = F.softmax(logits, dim=1)
+        predicted_approach_index = torch.argmax(probabilities, dim=1).item()
+        confidence = probabilities[0][predicted_approach_index].item()
+
+    return APPROACHES[predicted_approach_index], confidence
+
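+# A minimal usage sketch of the helpers above (values are illustrative):
+#
+#   model, tokenizer, device = load_optillm_model()
+#   input_ids, attention_mask = preprocess_input(tokenizer, "You are a helpful assistant.", "Solve the Tower of Hanoi problem with 4 disks")
+#   approach, confidence = predict_approach(model, input_ids, attention_mask, device, effort=0.7)
+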
+def run(system_prompt, initial_query, client, model, **kwargs):
+    try:
+        # Load the trained model
+        router_model, tokenizer, device = load_optillm_model()
+
+        # Preprocess the input
+        input_ids, attention_mask = preprocess_input(tokenizer, system_prompt, initial_query)
+
+        # Predict the best approach
+        predicted_approach, _ = predict_approach(router_model, input_ids, attention_mask, device)
+
+        print(f"Router predicted approach: {predicted_approach}")
+
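+        # The "none" branch below calls the model directly and returns a
+        # (response_text, completion_tokens) tuple; the other branches delegate
+        # to the corresponding optillm approach.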
Falling back to direct model usage.") + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": initial_query} + ] + ) + return response.choices[0].message.content, response.usage.completion_tokens diff --git a/scripts/train_optillm_classifier.py b/scripts/train_optillm_classifier.py index d47654d6..b6bd96f0 100644 --- a/scripts/train_optillm_classifier.py +++ b/scripts/train_optillm_classifier.py @@ -1,32 +1,32 @@ import argparse import torch -from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler -from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup +from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler +from transformers import AutoTokenizer, AutoModel +from transformers import PreTrainedModel, PretrainedConfig, AutoConfig from datasets import load_dataset -from sklearn.model_selection import train_test_split +from sklearn.model_selection import KFold from tqdm import tqdm import torch.nn.functional as F -from safetensors.torch import save_model +import torch.nn as nn +from safetensors.torch import save_model, load_model from collections import Counter - -# Check for MPS (Apple Silicon) support -if torch.backends.mps.is_available(): - device = torch.device("mps") -elif torch.cuda.is_available(): - device = torch.device("cuda") -else: - device = torch.device("cpu") - -print(f"Using device: {device}") +from torch.optim.lr_scheduler import ReduceLROnPlateau +import numpy as np # Constants APPROACHES = ["none", "mcts", "bon", "moa", "rto", "z3", "self_consistency", "pvg", "rstar", "cot_reflection", "plansearch", "leap", "re2"] MAX_LENGTH = 512 +# Device selection +device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu") +print(f"Using device: {device}") + class OptILMDataset(Dataset): - def __init__(self, prompts, best_approaches, tokenizer): + def __init__(self, prompts, approaches, ranks, tokens, tokenizer): self.prompts = prompts - self.best_approaches = best_approaches + self.approaches = approaches + self.ranks = ranks + self.tokens = tokens self.tokenizer = tokenizer def __len__(self): @@ -34,7 +34,9 @@ def __len__(self): def __getitem__(self, idx): prompt = self.prompts[idx] - best_approach = self.best_approaches[idx] + approaches = self.approaches[idx] + ranks = self.ranks[idx] + tokens = self.tokens[idx] encoding = self.tokenizer.encode_plus( prompt, @@ -49,7 +51,9 @@ def __getitem__(self, idx): return { 'input_ids': encoding['input_ids'].flatten(), 'attention_mask': encoding['attention_mask'].flatten(), - 'labels': torch.tensor(APPROACHES.index(best_approach), dtype=torch.long) + 'approaches': torch.tensor([APPROACHES.index(approach) for approach in approaches], dtype=torch.long), + 'ranks': torch.tensor(ranks, dtype=torch.float), + 'tokens': torch.tensor(tokens, dtype=torch.float), } def load_and_preprocess_data(tokenizer): @@ -63,46 +67,69 @@ def load_and_preprocess_data(tokenizer): if not results: continue - # Filter the list to exclude items where rank is None - filtered_data = [item for item in results if item['rank'] is not None] - # Find the best approach (lowest rank) - best_result = min(filtered_data, key=lambda x: x['rank']) - best_approach = best_result['approach'] + + # Filter results to only include approaches with valid ranks and tokens + valid_results = [result for result in results if result['rank'] is not None and 'tokens' in 
+        approaches = [result['approach'] for result in valid_results]
+        ranks = [result['rank'] for result in valid_results]
+        tokens = [result['tokens'] for result in valid_results]

         data_items.append({
             'prompt': prompt,
-            'best_approach': best_approach
+            'approaches': approaches,
+            'ranks': ranks,
+            'tokens': tokens
         })

-    # Print some statistics
     print(f"Total data points: {len(data_items)}")
     print(f"Unique prompts: {len(set(item['prompt'] for item in data_items))}")

-    approach_counts = Counter(item['best_approach'] for item in data_items)
-    print("Best Approach distribution:")
+    approach_counts = Counter(approach for item in data_items for approach in item['approaches'])
+    print("Approach distribution:")
     for approach, count in approach_counts.items():
         print(f"  {approach}: {count}")

-    # Split the data
-    train_data, val_data = train_test_split(data_items, test_size=0.2, random_state=42)
-
-    train_dataset = OptILMDataset(
-        [item['prompt'] for item in train_data],
-        [item['best_approach'] for item in train_data],
+    return OptILMDataset(
+        [item['prompt'] for item in data_items],
+        [item['approaches'] for item in data_items],
+        [item['ranks'] for item in data_items],
+        [item['tokens'] for item in data_items],
         tokenizer
     )

-    val_dataset = OptILMDataset(
-        [item['prompt'] for item in val_data],
-        [item['best_approach'] for item in val_data],
-        tokenizer
-    )
-
-    return train_dataset, val_dataset

 def calculate_accuracy(predictions, labels):
     return (predictions == labels).float().mean()

-def train(model, train_dataloader, val_dataloader, optimizer, scheduler, num_epochs):
+class OptILMClassifier(nn.Module):
+    def __init__(self, base_model, num_labels):
+        super().__init__()
+        self.base_model = base_model
+        self.effort_encoder = nn.Sequential(
+            nn.Linear(1, 64),
+            nn.ReLU(),
+            nn.Linear(64, 64),
+            nn.ReLU()
+        )
+        self.classifier = nn.Linear(base_model.config.hidden_size + 64, num_labels)
+
+    def forward(self, input_ids, attention_mask, effort):
+        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
+        pooled_output = outputs.last_hidden_state[:, 0]  # Shape: (batch_size, hidden_size)
+        effort_encoded = self.effort_encoder(effort.unsqueeze(1))  # Shape: (batch_size, 64)
+        combined_input = torch.cat((pooled_output, effort_encoded), dim=1)
+        logits = self.classifier(combined_input)
+        return logits
+
+def train(model, train_dataloader, val_dataloader, optimizer, scheduler, num_epochs, patience, clip_value):
     best_val_accuracy = 0.0
+    epochs_without_improvement = 0

     for epoch in range(num_epochs):
         model.train()
@@ -112,120 +139,184 @@ def train(model, train_dataloader, val_dataloader, optimizer, scheduler, num_epo
         for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
             input_ids = batch['input_ids'].to(device)
             attention_mask = batch['attention_mask'].to(device)
-            labels = batch['labels'].to(device)
-
-            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
-            loss = outputs.loss
-            logits = outputs.logits
+            approaches = batch['approaches'].to(device)
+            ranks = batch['ranks'].to(device)
+            tokens = batch['tokens'].to(device)
+
+            # Normalize tokens to [0, 1] range as a proxy for effort
+            effort = (tokens - tokens.min()) / (tokens.max() - tokens.min())
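+            # NOTE: min and max are taken over the whole batch tensor, so this
+            # effort scale is batch-dependent and divides by zero if all token
+            # counts in the batch are equal.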
+
+            # Use the minimum rank (best approach) for each prompt
+            best_approach_indices = ranks.argmin(dim=1)
+
+            logits = model(input_ids, attention_mask, effort[:, 0])  # effort[:, 0] is the normalized effort of the first approach column
+
+            # Calculate standard cross-entropy loss
+            ce_loss = F.cross_entropy(logits, best_approach_indices)
+
+            # Calculate effort-sensitive loss
+            effort_loss = F.mse_loss(logits.softmax(dim=1).gather(1, best_approach_indices.unsqueeze(1)).squeeze(), effort[:, 0])
+
+            # Combine losses
+            loss = ce_loss + 0.1 * effort_loss  # Adjust the weight of effort_loss as needed

             loss.backward()
-            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
             optimizer.step()
-            scheduler.step()
             optimizer.zero_grad()

             total_loss += loss.item()
             predictions = torch.argmax(logits, dim=-1)
-            total_accuracy += calculate_accuracy(predictions, labels)
+            total_accuracy += calculate_accuracy(predictions, best_approach_indices)

         avg_train_loss = total_loss / len(train_dataloader)
         avg_train_accuracy = total_accuracy / len(train_dataloader)

         # Validation
-        model.eval()
-        total_val_accuracy = 0
-
-        with torch.no_grad():
-            for batch in val_dataloader:
-                input_ids = batch['input_ids'].to(device)
-                attention_mask = batch['attention_mask'].to(device)
-                labels = batch['labels'].to(device)
-
-                outputs = model(input_ids, attention_mask=attention_mask)
-                logits = outputs.logits
-                predictions = torch.argmax(logits, dim=-1)
-                total_val_accuracy += calculate_accuracy(predictions, labels)
-
-        avg_val_accuracy = total_val_accuracy / len(val_dataloader)
+        avg_val_accuracy = validate(model, val_dataloader)

         print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, Train Accuracy: {avg_train_accuracy:.4f}, Val Accuracy: {avg_val_accuracy:.4f}")

+        # Learning rate scheduling
+        if isinstance(scheduler, ReduceLROnPlateau):
+            scheduler.step(avg_val_accuracy)
+        else:
+            scheduler.step()
+
         if avg_val_accuracy > best_val_accuracy:
             best_val_accuracy = avg_val_accuracy
-            # Save the best model
+            epochs_without_improvement = 0
             save_model(model, "best_model.safetensors")
+        else:
+            epochs_without_improvement += 1
+
+        if epochs_without_improvement >= patience:
+            print(f"Early stopping triggered after {epoch+1} epochs")
+            break
+
+def validate(model, val_dataloader):
+    model.eval()
+    total_val_accuracy = 0
+
+    with torch.no_grad():
+        for batch in val_dataloader:
+            input_ids = batch['input_ids'].to(device)
+            attention_mask = batch['attention_mask'].to(device)
+            approaches = batch['approaches'].to(device)
+            ranks = batch['ranks'].to(device)
+            tokens = batch['tokens'].to(device)
+
+            effort = (tokens - tokens.min()) / (tokens.max() - tokens.min())
+            best_approach_indices = ranks.argmin(dim=1)
+
+            logits = model(input_ids, attention_mask, effort[:, 0])
+            predictions = torch.argmax(logits, dim=-1)
+            total_val_accuracy += calculate_accuracy(predictions, best_approach_indices)
+
+    return total_val_accuracy / len(val_dataloader)

-def inference(model, tokenizer, prompt):
+def inference(model, tokenizer, prompt, effort_levels):
     model.eval()
     with torch.no_grad():
         inputs = tokenizer(prompt, return_tensors="pt", max_length=MAX_LENGTH, truncation=True, padding="max_length")
         input_ids = inputs["input_ids"].to(device)
         attention_mask = inputs["attention_mask"].to(device)

-        outputs = model(input_ids, attention_mask=attention_mask)
-        logits = outputs.logits
-        probabilities = F.softmax(logits, dim=1)
-        predicted_approach_index = torch.argmax(probabilities, dim=1).item()
+        results = []
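+        # Sweep the same prompt across the requested effort levels; the predicted
+        # approach can change as the effort conditioning changes.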
+        for effort in effort_levels:
+            effort_tensor = torch.tensor([effort], dtype=torch.float).to(device)
+            logits = model(input_ids, attention_mask, effort_tensor)
+            probabilities = F.softmax(logits, dim=1)
+            predicted_approach_index = torch.argmax(probabilities, dim=1).item()
+            results.append((APPROACHES[predicted_approach_index], probabilities[0][predicted_approach_index].item()))

-    return APPROACHES[predicted_approach_index], probabilities[0][predicted_approach_index].item()
+    return results

 def main(args):
-    # Load tokenizer and model
     tokenizer = AutoTokenizer.from_pretrained(args.model_name)
-    model = AutoModelForSequenceClassification.from_pretrained(args.model_name, num_labels=len(APPROACHES))
-    model.to(device)
+    dataset = load_and_preprocess_data(tokenizer)

-    # Load and preprocess data
-    train_dataset, val_dataset = load_and_preprocess_data(tokenizer)
+    kf = KFold(n_splits=args.k_folds, shuffle=True, random_state=42)
+
+    best_val_accuracy = 0
+    best_fold = 0
+
+    for fold, (train_indices, val_indices) in enumerate(kf.split(dataset), 1):
+        print(f"\nTraining Fold {fold}")
+
+        train_sampler = SubsetRandomSampler(train_indices)
+        val_sampler = SubsetRandomSampler(val_indices)
+
+        train_dataloader = DataLoader(dataset, batch_size=args.batch_size, sampler=train_sampler)
+        val_dataloader = DataLoader(dataset, batch_size=args.batch_size, sampler=val_sampler)
+
+        base_model = AutoModel.from_pretrained(args.model_name)
+        model = OptILMClassifier(base_model, num_labels=len(APPROACHES)).to(device)

-    # Create data loaders
-    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size)
-    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size)
+        optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=0.01)
+        scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=2, verbose=True)

-    # Optimizer and scheduler
-    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=0.01)
-    total_steps = len(train_dataloader) * args.num_epochs
-    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=total_steps//10, num_training_steps=total_steps)
+        train(model, train_dataloader, val_dataloader, optimizer, scheduler, args.num_epochs, args.patience, args.clip_value)

-    # Train the model
-    train(model, train_dataloader, val_dataloader, optimizer, scheduler, args.num_epochs)
+        # Evaluate the model on the validation set
+        fold_val_accuracy = validate(model, val_dataloader)
+        print(f"Fold {fold} Validation Accuracy: {fold_val_accuracy:.4f}")

-    # Save the final model
-    save_model(model, "final_model.safetensors")
+        # Save the model for this fold
+        save_model(model, f"model_fold_{fold}.safetensors")
+
+        # Update best model if this fold performed better
+        if fold_val_accuracy > best_val_accuracy:
+            best_val_accuracy = fold_val_accuracy
+            best_fold = fold
+            save_model(model, "best_model.safetensors")
+
+    print(f"\nBest performing model was from fold {best_fold} with validation accuracy {best_val_accuracy:.4f}")
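+    # Note: train() also saves to best_model.safetensors on per-epoch improvements
+    # within each fold, so a later fold can overwrite an earlier fold's checkpoint
+    # even when the cross-fold comparison above does not select it.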

     if args.push_to_hub:
-        model.push_to_hub(args.hub_model_id)
+        base_model = AutoModel.from_pretrained(args.model_name)
+        # best_model = OptILMClassifier(base_model, num_labels=len(APPROACHES))
+        # best_model.to(device)
+        # load_model(best_model, "best_model.safetensors")
+        # We just push the base model and upload the safetensors file manually, as the OptILMClassifier class doesn't have a push_to_hub method.
+        base_model.push_to_hub(args.hub_model_id)
         tokenizer.push_to_hub(args.hub_model_id)

-    # Example inferences
+    # Load the best model for inference
+    base_model = AutoModel.from_pretrained(args.model_name)
+    best_model = OptILMClassifier(base_model, num_labels=len(APPROACHES))
+    best_model.to(device)
+    load_model(best_model, "best_model.safetensors")
+    best_model.eval()
+
     test_prompts = [
         "Maximize x + y subject to: x + 2y <= 10, x >= 0, y >= 0",
         "Find the shortest path between nodes A and B in the given graph",
         "Solve the Tower of Hanoi problem with 4 disks",
         "Determine if the given number is prime",
-        "Find all possible combinations of coins that sum up to $1",
-        "Implement a binary search algorithm",
-        "Design an algorithm to find the longest palindromic substring",
-        "Solve the 8-queens problem",
-        "Implement a depth-first search algorithm for a graph",
-        "Find the maximum subarray sum in a given array of integers"
+        "Find all possible combinations of coins that sum up to $1"
     ]

+    effort_levels = [0.2, 0.5, 0.8, 1.0]
+
     print("\nInference Examples:")
     for prompt in test_prompts:
-        predicted_approach, confidence = inference(model, tokenizer, prompt)
         print(f"\nTest Prompt: {prompt}")
-        print(f"Predicted Approach: {predicted_approach}")
-        print(f"Confidence: {confidence:.4f}")
+        results = inference(best_model, tokenizer, prompt, effort_levels)
+        for effort, (approach, confidence) in zip(effort_levels, results):
+            print(f"Effort: {effort:.1f}, Predicted Approach: {approach}, Confidence: {confidence:.4f}")

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Train OptILM classifier")
     parser.add_argument("--model_name", type=str, default="google-bert/bert-large-uncased", help="Pretrained model name")
     parser.add_argument("--batch_size", type=int, default=4, help="Batch size for training")
-    parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate")
-    parser.add_argument("--num_epochs", type=int, default=10, help="Number of training epochs")
+    parser.add_argument("--learning_rate", type=float, default=1e-6, help="Learning rate")
+    parser.add_argument("--num_epochs", type=int, default=10, help="Maximum number of training epochs")
     parser.add_argument("--push_to_hub", action="store_true", help="Push model to Hugging Face Hub")
     parser.add_argument("--hub_model_id", type=str, help="Model ID for Hugging Face Hub")
+    parser.add_argument("--k_folds", type=int, default=5, help="Number of folds for cross-validation")
+    parser.add_argument("--patience", type=int, default=3, help="Number of epochs to wait for improvement before early stopping")
+    parser.add_argument("--clip_value", type=float, default=1.0, help="Gradient clipping value")

     args = parser.parse_args()
     main(args)
\ No newline at end of file