Adding LiteLLM support #78 (Merged)

Changes from all commits (6 commits):
- 437c3ab adding litellm support for all except sambanova (pythonomar22)
- 099b95b adding reasoning config support (pythonomar22)
- 38f3c32 addressing some comments (pythonomar22)
- bbe7e76 fixing modal litellm (pythonomar22)
- 9dac84e setting defaults
- 0aaa85d litellm lgtm (simonguozirui)
New file (+24 lines): API key template for LLM providers

@@ -0,0 +1,24 @@
# API Keys for LLM Providers
# Copy this file to .env and fill in your actual API keys
# DO NOT commit your .env file with real keys!

# OpenAI (for GPT models and o1/o3 reasoning models)
OPENAI_API_KEY=sk-...

# Anthropic (for Claude models)
ANTHROPIC_API_KEY=sk-ant-api03-...

# Google Gemini
GEMINI_API_KEY=...

# DeepSeek
DEEPSEEK_API_KEY=sk-...

# Together AI
TOGETHER_API_KEY=...

# Fireworks AI
FIREWORKS_AI_API_KEY=...

# Local Server Deployment (SGLang, vLLM, Tokasaurus)
SGLANG_API_KEY=...
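To illustrate how these keys feed into LiteLLM once copied into .env, here is a minimal sketch (it assumes python-dotenv and litellm are installed; the model string and prompt are illustrative and not taken from this PR):

```python
# Minimal sketch: load provider keys from .env, then route a prompt through LiteLLM.
# LiteLLM picks the matching key (e.g. OPENAI_API_KEY) from the environment based
# on the provider prefix in the model string. The model name here is illustrative.
import litellm
from dotenv import load_dotenv

load_dotenv()  # reads OPENAI_API_KEY, ANTHROPIC_API_KEY, ... from .env

response = litellm.completion(
    model="openai/gpt-4o",
    messages=[{"role": "user", "content": "Write a CUDA kernel for vector addition."}],
    temperature=0.0,
    max_tokens=1024,
)
print(response.choices[0].message.content)
```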
Requirements file:

@@ -20,9 +20,6 @@ einops
dotenv
numpy

# to deprecate with litellm
google-generativeai
together
openai
anthropic
litellm[proxy]
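The packages marked for deprecation are the per-provider SDKs; litellm[proxy] covers the same backends behind a single completion interface. A rough sketch of that consolidation (the model identifiers are illustrative LiteLLM-style names, not defaults from this repo):

```python
# One call signature for several backends; each provider is selected by the
# model-string prefix and authenticated via its environment variable.
import litellm

messages = [{"role": "user", "content": "Say hello in one sentence."}]

for model in (
    "openai/gpt-4o",                          # uses OPENAI_API_KEY
    "anthropic/claude-3-5-sonnet-20241022",   # uses ANTHROPIC_API_KEY
    "gemini/gemini-1.5-pro",                  # uses GEMINI_API_KEY
):
    resp = litellm.completion(model=model, messages=messages, max_tokens=64)
    print(f"{model}: {resp.choices[0].message.content}")
```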
Generation script, GenerationConfig defaults:

@@ -55,10 +55,15 @@ def __init__(self):
         self.api_query_interval = 0.0

         # Inference config
-        self.server_type = "deepseek"
-        self.model_name = "deepseek-coder"
-        self.max_tokens = 4096
+        self.server_type = None
+        self.model_name = None
+        self.max_tokens = None
         self.temperature = 0.0

+        # Reasoning model specific parameters
+        self.is_reasoning_model = False  # set to True for o1, o3, Gemini 2.5 thinking, etc.
+        self.reasoning_effort = "low"  # for o1/o3: "low", "medium", "high"
+        self.budget_tokens = 0  # for Claude extended thinking mode

         # Logging
         # Top Directory to Store Runs

Review comment on the new None defaults (Collaborator): so are we not specifying a default, this could also work, and just expect them to use a preset here?
Reply (Author): updated to reflect slack thread - but yeah
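The three new reasoning knobs correspond to provider-specific request options. A hedged sketch of how they could be forwarded through a LiteLLM call (this is not the repo's actual query code; the routing below is an assumption based on LiteLLM's documented reasoning_effort and Anthropic thinking options):

```python
# Hypothetical helper showing how is_reasoning_model / reasoning_effort /
# budget_tokens could translate into LiteLLM call arguments.
import litellm

def query_model(model, messages, is_reasoning_model=False,
                reasoning_effort="low", budget_tokens=0,
                temperature=0.0, max_tokens=4096):
    kwargs = {"model": model, "messages": messages, "max_tokens": max_tokens}
    if is_reasoning_model and "claude" in model and budget_tokens > 0:
        # Claude extended thinking: allocate an explicit thinking-token budget
        kwargs["thinking"] = {"type": "enabled", "budget_tokens": budget_tokens}
    elif is_reasoning_model:
        # o1/o3-style models take a reasoning effort level instead of temperature
        kwargs["reasoning_effort"] = reasoning_effort
    else:
        kwargs["temperature"] = temperature
    return litellm.completion(**kwargs)
```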
@@ -192,6 +197,21 @@ def main(config: GenerationConfig):
     Batch Generate Samples for Particular Level
     Store generated kernels in the specified run directory
     """
+    from src.utils import SERVER_PRESETS
+
+    if config.server_type and config.server_type in SERVER_PRESETS:
+        preset = SERVER_PRESETS[config.server_type]
+        if config.model_name is None or config.model_name == "None":
+            config.model_name = preset.get("model_name", "None")
+        if config.max_tokens is None or config.max_tokens == "None":
+            config.max_tokens = preset.get("max_tokens", "None")
+        if config.temperature is None or config.temperature == "None":
+            config.temperature = preset.get("temperature", "None")
+
+    # Convert string boolean to actual boolean for reasoning model flag
+    if isinstance(config.is_reasoning_model, str):
+        config.is_reasoning_model = config.is_reasoning_model.lower() in ['true', '1', 'yes']

     print(f"Starting Batch Generation with config: {config}")

     # Dataset Configurations
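For context, SERVER_PRESETS is a mapping from server_type to default generation settings that fills in any field left as None. The concrete entries below are invented for illustration and are not copied from src/utils:

```python
# Hypothetical shape of SERVER_PRESETS: unset config fields (None) fall back
# to the preset for the chosen server_type. The values here are made up.
SERVER_PRESETS = {
    "openai": {"model_name": "gpt-4o", "max_tokens": 4096, "temperature": 0.0},
    "anthropic": {"model_name": "claude-3-5-sonnet-20241022", "max_tokens": 8192, "temperature": 0.0},
    "deepseek": {"model_name": "deepseek-coder", "max_tokens": 4096, "temperature": 0.0},
}
```

With a preset in place, a user only has to set server_type; model_name, max_tokens, and temperature inherit the preset values unless overridden explicitly.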
@@ -217,6 +237,10 @@ def main(config: GenerationConfig):

     # set up run directory
     run_dir = os.path.join(config.runs_dir, config.run_name)
+    run_exists = os.path.exists(run_dir)
+    if run_exists:
+        print(f"\n⚠️ WARNING: Run directory already exists: {run_dir}")
+        print(f" Existing kernels will be skipped. Use a different run_name for a fresh run.\n")
     os.makedirs(run_dir, exist_ok=True)
     pydra.save_yaml(config.to_dict(), os.path.join(run_dir, "generation_config.yaml"))
@@ -225,14 +249,22 @@ def main(config: GenerationConfig):
     ), "supporting local file-system based storage for now"  # database integreation coming soon, need to migrate from CUDA Monkeys code

     problems_to_run = []
+    total_problems = 0
+    already_completed = 0
     for problem_id in range(
         problem_id_range.start, problem_id_range.stop + 1
     ):  # end index is inclusive
         for sample_id in range(config.num_samples):
+            total_problems += 1
             if not check_kernel_exists(run_dir, config.level, problem_id, sample_id):
                 problems_to_run.append(
                     WorkArgs(problem_id=int(problem_id), sample_id=sample_id)
                 )
+            else:
+                already_completed += 1
+
+    if already_completed > 0:
+        print(f"📁 Found {already_completed}/{total_problems} kernels already generated. Generating remaining {len(problems_to_run)} kernels.")

     # Create inference function with config parameters
     # We provide some presets in utils but you can also pass in your own, see query_server for more details
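check_kernel_exists comes from the repo's utilities and is not shown in this diff; the resume behavior amounts to a filesystem check along these lines (the directory layout below is an assumption for illustration only):

```python
import os

def check_kernel_exists(run_dir: str, level: int, problem_id: int, sample_id: int) -> bool:
    """Hypothetical sketch: report whether a generated kernel file is already on disk.

    The level_{level}/problem_{id}/sample_{id} layout is assumed for this example
    and may not match the repo's actual naming scheme.
    """
    kernel_path = os.path.join(
        run_dir, f"level_{level}", f"problem_{problem_id}", f"sample_{sample_id}", "kernel.py"
    )
    return os.path.exists(kernel_path)
```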
@@ -242,6 +274,9 @@ def main(config: GenerationConfig):
         temperature=config.temperature,
         max_tokens=config.max_tokens,
         verbose=config.verbose,
+        is_reasoning_model=config.is_reasoning_model,
+        reasoning_effort=config.reasoning_effort,
+        budget_tokens=config.budget_tokens,
     )

     # Launch workers
@@ -258,11 +293,16 @@ def main(config: GenerationConfig):

     num_generated_samples = len(generation_results)
-    total_problems = len(problems_to_run)
-    num_failed_problems = total_problems - num_generated_samples
-    print(
-        f"Generated {num_generated_samples} samples for total {total_problems} problems, Please retry for the {num_failed_problems} failed problems."
-    )
+    num_attempted = len(problems_to_run)
+    num_failed_problems = num_attempted - num_generated_samples
+
+    if num_attempted == 0:
+        print(f"\n✅ All {total_problems} kernels already exist in {run_dir}")
+        print(f" Use a different run_name if you want to generate fresh samples.\n")
+    else:
+        print(
+            f"\nGenerated {num_generated_samples} samples for total {num_attempted} problems, Please retry for the {num_failed_problems} failed problems."
+        )


 if __name__ == "__main__":