diff --git a/.env.example b/.env.example
new file mode 100644
index 00000000..2150e6d9
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,24 @@
+# API Keys for LLM Providers
+# Copy this file to .env and fill in your actual API keys
+# DO NOT commit your .env file with real keys!
+
+# OpenAI (for GPT models and o1/o3 reasoning models)
+OPENAI_API_KEY=sk-...
+
+# Anthropic (for Claude models)
+ANTHROPIC_API_KEY=sk-ant-api03-...
+
+# Google Gemini
+GEMINI_API_KEY=...
+
+# DeepSeek
+DEEPSEEK_API_KEY=sk-...
+
+# Together AI
+TOGETHER_API_KEY=...
+
+# Fireworks AI
+FIREWORKS_AI_API_KEY=...
+
+# Local Server Deployment (SGLang, vLLM, Tokasaurus)
+SGLANG_API_KEY=...
diff --git a/README.md b/README.md
index 8ec45e8a..9c3232c8 100644
--- a/README.md
+++ b/README.md
@@ -80,7 +80,7 @@ pip install -r requirements.txt
 pip install -e .
 ```
 
-To call LLM API providers, set your `{INFERENCE_SERVER_PROVIDER}_API_KEY` API key.
+We use `litellm` for API calls. Please set your keys by creating a `.env` file following our `.env.example`.
 
 Running and profiling kernels require a GPU. If you don't have GPU available locally, you can set up [Modal](https://modal.com/). Set up your modal token after creating an account by running `modal token new`. Then, use the `generate_and_eval_single_sample_modal.py` script.
 
@@ -122,7 +122,7 @@ If you are using a different hardware, you can generate the baseline time with `
 We provide some reference baseline times a variety of NVIDIA GPUs across generations in `results/timing`, but we recommend you to generate your own baseline time for more accurate results (cluster power, software version, all affects timing result). See `results/timing/README.md` for more details.
 
 ### Multi-Turn Framework
-We have also releaed the test-time framework [Caesar](https://github.com/simonguozirui/caesar) that are used in the multi-turn / iterative refinement experiments in our paper. You can use or modify this framework for high-throughput test-time scaling (both sequential and parallel) targeting KernelBench problems.
+We have also released the test-time framework [Caesar](https://github.com/ScalingIntelligence/caesar) that is used in the multi-turn / iterative refinement experiments in our paper. You can use or modify this framework for high-throughput test-time scaling (both sequential and parallel) targeting KernelBench problems.
 
 ## 🛣️ Upcoming Roadmap
 Check out our [roadmap](https://github.com/ScalingIntelligence/KernelBench/issues/74) for what we plan to add as features. We welcome community contirbutions in these directions.
diff --git a/requirements.txt b/requirements.txt
index c912156c..a253f422 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -20,9 +20,6 @@
 einops
 dotenv
 numpy
 
-# to deprecate with litellm
-google-generativeai
-together
 openai
-anthropic
+litellm[proxy]
diff --git a/scripts/generate_and_eval_single_sample.py b/scripts/generate_and_eval_single_sample.py
index ff71e4bc..d596b557 100644
--- a/scripts/generate_and_eval_single_sample.py
+++ b/scripts/generate_and_eval_single_sample.py
@@ -50,10 +50,15 @@ def __init__(self):
         self.gpu_arch = ["Ada"]
 
         # Inference config
-        self.server_type = "deepseek"
-        self.model_name = "deepseek-coder"
-        self.max_tokens = 4096
-        self.temperature = 0.0
+        self.server_type = None
+        self.model_name = None
+        self.max_tokens = None
+        self.temperature = None
+
+        # Reasoning model specific parameters
+        self.is_reasoning_model = False  # set to True for o1, o3, Gemini 2.5 thinking, etc.
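+        # Illustrative only (assuming the usual pydra key=value CLI overrides), e.g.:
+        #   server_type=openai model_name=o3-mini is_reasoning_model=True reasoning_effort=high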
+        self.reasoning_effort = None  # for o1/o3: "low", "medium", "high"
+        self.budget_tokens = 0  # for Claude extended thinking mode
 
         # Logging
         self.logdir = os.path.join(REPO_TOP_DIR, "results/eval_logs")
@@ -81,6 +86,21 @@ def main(config: EvalConfig):
     """
     Keep it simple: Generate and evaluate a single sample
     """
+    from src.utils import SERVER_PRESETS
+
+    if config.server_type and config.server_type in SERVER_PRESETS:
+        preset = SERVER_PRESETS[config.server_type]
+        if config.model_name is None or config.model_name == "None":
+            config.model_name = preset.get("model_name", "None")
+        if config.max_tokens is None or config.max_tokens == "None":
+            config.max_tokens = preset.get("max_tokens", "None")
+        if config.temperature is None or config.temperature == "None":
+            config.temperature = preset.get("temperature", "None")
+
+    # Convert string boolean to actual boolean for reasoning model flag
+    if isinstance(config.is_reasoning_model, str):
+        config.is_reasoning_model = config.is_reasoning_model.lower() in ['true', '1', 'yes']
+
     print(f"Starting Eval with config: {config}")
 
     # Configurations
@@ -143,6 +163,9 @@ def main(config: EvalConfig):
         max_tokens=config.max_tokens,
         verbose=config.verbose,
         time_generation=True,
+        is_reasoning_model=config.is_reasoning_model,
+        reasoning_effort=config.reasoning_effort,
+        budget_tokens=config.budget_tokens,
     )
 
     # Use appropriate prompt constructor based on backend
diff --git a/scripts/generate_and_eval_single_sample_modal.py b/scripts/generate_and_eval_single_sample_modal.py
index e9e0866a..da449de1 100644
--- a/scripts/generate_and_eval_single_sample_modal.py
+++ b/scripts/generate_and_eval_single_sample_modal.py
@@ -56,10 +56,15 @@ def __init__(self):
 
 
         # Inference config
-        self.server_type = "deepseek"
-        self.model_name = "deepseek-coder"
-        self.max_tokens = 4096
-        self.temperature = 0.0
+        self.server_type = None
+        self.model_name = None
+        self.max_tokens = None
+        self.temperature = None
+
+        # Reasoning model specific parameters
+        self.is_reasoning_model = False  # set to True for o1, o3, Gemini 2.5 thinking, etc.
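+        # budget_tokens only takes effect for Anthropic models (forwarded as litellm's "thinking"
+        # parameter when > 0); it is ignored for other providers -- see query_server in src/utils.py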
+        self.reasoning_effort = None  # for o1/o3: "low", "medium", "high"
+        self.budget_tokens = 0  # for Claude extended thinking mode
 
         # Logging
         self.logdir = os.path.join(REPO_TOP_DIR, "results/eval_logs")
@@ -94,7 +99,6 @@ def __repr__(self):
         "clang" # note i skip a step
     )
     .pip_install( # required to build flash-attn
-        "anthropic",
        "numpy",
        "openai",
        "packaging",
@@ -103,8 +107,6 @@
        "tqdm",
        "datasets",
        "transformers",
-        "google-generativeai",
-        "together",
        "pytest",
        "ninja",
        "utils",
@@ -112,6 +114,8 @@
        #"apache-tvm",
        "python-dotenv",
        "nvidia-cutlass-dsl",
+        "litellm[proxy]",  # Unified LLM interface
+        "einops",  # for numerics
     )
     .add_local_python_source("src")
 )
@@ -140,6 +144,21 @@ def main(config: EvalConfig):
     """
     Keep it simple: Generate and evaluate a single sample
     """
+    from src.utils import SERVER_PRESETS
+
+    if config.server_type and config.server_type in SERVER_PRESETS:
+        preset = SERVER_PRESETS[config.server_type]
+        if config.model_name is None or config.model_name == "None":
+            config.model_name = preset.get("model_name", "None")
+        if config.max_tokens is None or config.max_tokens == "None":
+            config.max_tokens = preset.get("max_tokens", "None")
+        if config.temperature is None or config.temperature == "None":
+            config.temperature = preset.get("temperature", "None")
+
+    # Convert string boolean to actual boolean for reasoning model flag
+    if isinstance(config.is_reasoning_model, str):
+        config.is_reasoning_model = config.is_reasoning_model.lower() in ['true', '1', 'yes']
+
     print(f"Starting Eval with config: {config}")
 
     # Configurations
@@ -187,7 +206,10 @@ def main(config: EvalConfig):
         temperature=config.temperature,
         max_tokens=config.max_tokens,
         verbose=config.verbose,
-        time_generation=True)
+        time_generation=True,
+        is_reasoning_model=config.is_reasoning_model,
+        reasoning_effort=config.reasoning_effort,
+        budget_tokens=config.budget_tokens)
 
 
diff --git a/scripts/generate_samples.py b/scripts/generate_samples.py
index 5ee217cf..7915fc30 100644
--- a/scripts/generate_samples.py
+++ b/scripts/generate_samples.py
@@ -55,10 +55,15 @@ def __init__(self):
         self.api_query_interval = 0.0
 
         # Inference config
-        self.server_type = "deepseek"
-        self.model_name = "deepseek-coder"
-        self.max_tokens = 4096
+        self.server_type = None
+        self.model_name = None
+        self.max_tokens = None
         self.temperature = 0.0
+
+        # Reasoning model specific parameters
+        self.is_reasoning_model = False  # set to True for o1, o3, Gemini 2.5 thinking, etc.
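+        # the batch-generation default below is "low"; override per run (e.g. reasoning_effort=high) if needed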
+        self.reasoning_effort = "low"  # for o1/o3: "low", "medium", "high"
+        self.budget_tokens = 0  # for Claude extended thinking mode
 
         # Logging
         # Top Directory to Store Runs
@@ -192,6 +197,21 @@ def main(config: GenerationConfig):
     Batch Generate Samples for Particular Level
     Store generated kernels in the specified run directory
     """
+    from src.utils import SERVER_PRESETS
+
+    if config.server_type and config.server_type in SERVER_PRESETS:
+        preset = SERVER_PRESETS[config.server_type]
+        if config.model_name is None or config.model_name == "None":
+            config.model_name = preset.get("model_name", "None")
+        if config.max_tokens is None or config.max_tokens == "None":
+            config.max_tokens = preset.get("max_tokens", "None")
+        if config.temperature is None or config.temperature == "None":
+            config.temperature = preset.get("temperature", "None")
+
+    # Convert string boolean to actual boolean for reasoning model flag
+    if isinstance(config.is_reasoning_model, str):
+        config.is_reasoning_model = config.is_reasoning_model.lower() in ['true', '1', 'yes']
+
     print(f"Starting Batch Generation with config: {config}")
 
     # Dataset Configurations
@@ -217,6 +237,10 @@ def main(config: GenerationConfig):
 
     # set up run directory
     run_dir = os.path.join(config.runs_dir, config.run_name)
+    run_exists = os.path.exists(run_dir)
+    if run_exists:
+        print(f"\n⚠️ WARNING: Run directory already exists: {run_dir}")
+        print(f"   Existing kernels will be skipped. Use a different run_name for a fresh run.\n")
     os.makedirs(run_dir, exist_ok=True)
     pydra.save_yaml(config.to_dict(), os.path.join(run_dir, "generation_config.yaml"))
 
@@ -225,14 +249,22 @@
     ), "supporting local file-system based storage for now" # database integreation coming soon, need to migrate from CUDA Monkeys code
 
     problems_to_run = []
+    total_problems = 0
+    already_completed = 0
     for problem_id in range(
         problem_id_range.start, problem_id_range.stop + 1
     ):  # end index is inclusive
         for sample_id in range(config.num_samples):
+            total_problems += 1
             if not check_kernel_exists(run_dir, config.level, problem_id, sample_id):
                 problems_to_run.append(
                     WorkArgs(problem_id=int(problem_id), sample_id=sample_id)
                 )
+            else:
+                already_completed += 1
+
+    if already_completed > 0:
+        print(f"📁 Found {already_completed}/{total_problems} kernels already generated. Generating remaining {len(problems_to_run)} kernels.")
 
     # Create inference function with config parameters
     # We provide some presets in utils but you can also pass in your own, see query_server for more details
@@ -242,6 +274,9 @@
         temperature=config.temperature,
         max_tokens=config.max_tokens,
         verbose=config.verbose,
+        is_reasoning_model=config.is_reasoning_model,
+        reasoning_effort=config.reasoning_effort,
+        budget_tokens=config.budget_tokens,
     )
 
     # Launch workers
@@ -258,11 +293,16 @@
     )
 
     num_generated_samples = len(generation_results)
-    total_problems = len(problems_to_run)
-    num_failed_problems = total_problems - num_generated_samples
-    print(
-        f"Generated {num_generated_samples} samples for total {total_problems} problems, Please retry for the {num_failed_problems} failed problems."
-    )
+    num_attempted = len(problems_to_run)
+    num_failed_problems = num_attempted - num_generated_samples
+
+    if num_attempted == 0:
+        print(f"\n✅ All {total_problems} kernels already exist in {run_dir}")
+        print(f"   Use a different run_name if you want to generate fresh samples.\n")
+    else:
+        print(
+            f"\nGenerated {num_generated_samples} samples for {num_attempted} attempted problems. Please retry the {num_failed_problems} failed problems."
+        )
 
 
 if __name__ == "__main__":
diff --git a/src/utils.py b/src/utils.py
index 1af1b289..87d722af 100644
--- a/src/utils.py
+++ b/src/utils.py
@@ -17,10 +17,8 @@
 from tqdm import tqdm
 
 # API clients
-from together import Together
 from openai import OpenAI
-import google.generativeai as genai
-import anthropic
+from litellm import completion
 
 # from datasets import load_dataset
 import numpy as np
@@ -35,43 +33,13 @@
 from concurrent.futures import ProcessPoolExecutor, as_completed
 
-# Define API key access
-TOGETHER_KEY = os.environ.get("TOGETHER_API_KEY")
-DEEPSEEK_KEY = os.environ.get("DEEPSEEK_API_KEY")
-OPENAI_KEY = os.environ.get("OPENAI_API_KEY")
-GEMINI_KEY = os.environ.get("GEMINI_API_KEY")
-SGLANG_KEY = os.environ.get("SGLANG_API_KEY") # for Local Deployment
-ANTHROPIC_KEY = os.environ.get("ANTHROPIC_API_KEY")
-SAMBANOVA_API_KEY = os.environ.get("SAMBANOVA_API_KEY")
-FIREWORKS_API_KEY = os.environ.get("FIREWORKS_API_KEY")
+SGLANG_KEY = os.environ.get("SGLANG_API_KEY")
 
 ########################################################
 # Inference Helpers
 ########################################################
 
-@cache
-def load_deepseek_tokenizer():
-    # TODO: Should we update this for new deepseek? Same tokenizer?
-    # return AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-Coder-V2-Instruct-0724")
-    return AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V2", trust_remote_code=True)
-
-# Buffer because deepseek totally blocks us if we send stuff that's too long :(
-TOO_LONG_FOR_DEEPSEEK = 115_000
-
-
-def is_safe_to_send_to_deepseek(prompt):
-    tokenizer = load_deepseek_tokenizer()
-    # print(f"Prompt: {len(prompt)}")
-    # print(f"Prompt length: {len(tokenizer(prompt, verbose=False)['input_ids'])}")
-
-    if type(prompt) == str:
-        return (
-            len(tokenizer(prompt, verbose=False)["input_ids"]) < TOO_LONG_FOR_DEEPSEEK
-        )
-    else:
-        return len(tokenizer.apply_chat_template(prompt)) < TOO_LONG_FOR_DEEPSEEK
-
 def set_gpu_arch(arch_list: list[str]):
     """
     Set env variable for torch cuda arch list to build kernels for specified architectures
@@ -111,225 +79,17 @@ def query_server(
     - Anthropic
     - Gemini / Google AI Studio
     - Fireworks (OpenAI compatbility)
-    - SGLang (Local Server)
+    - Local Server (SGLang, vLLM, Tokasaurus)
     """
-    # Select model and client based on arguments
-    match server_type:
-        case "sglang":
-            url = f"http://{server_address}:{server_port}"
-            client = OpenAI(
-                api_key=SGLANG_KEY, base_url=f"{url}/v1", timeout=None, max_retries=0
-            )
-            model = "default"
-        case "deepseek":
-            client = OpenAI(
-                api_key=DEEPSEEK_KEY,
-                base_url="https://api.deepseek.com",
-                timeout=10000000,
-                max_retries=3,
-            )
-            model = model_name
-            assert model in ["deepseek-chat", "deepseek-coder", "deepseek-reasoner"], "Only support deepseek-chat or deepseek-coder for now"
-            if not is_safe_to_send_to_deepseek(prompt):
-                raise RuntimeError("Prompt is too long for DeepSeek")
-        case "fireworks":
-            client = OpenAI(
-                api_key=FIREWORKS_API_KEY,
-                base_url="https://api.fireworks.ai/inference/v1",
-                timeout=10000000,
-                max_retries=3,
-            )
-            model = model_name
-
-        case "anthropic":
-            client = anthropic.Anthropic(
-                api_key=ANTHROPIC_KEY,
-            )
-            model = model_name
-        case "google":
-            genai.configure(api_key=GEMINI_KEY)
-            model = model_name
-        case "together":
-            client = Together(api_key=TOGETHER_KEY)
-            model = model_name
-        case "sambanova":
-            client = OpenAI(api_key=SAMBANOVA_API_KEY, base_url="https://api.sambanova.ai/v1")
-            model = model_name
-
-        case "openai":
-            client = OpenAI(api_key=OPENAI_KEY)
-            model = model_name
-        case _:
-            raise NotImplementedError
-
-    if server_type != "google":
-        assert client is not None, "Client is not set, cannot proceed to generations"
-    else:
-        print(
-            f"Querying {server_type} {model} with temp {temperature} max tokens {max_tokens}"
-        )
-    # Logic to query the LLM
-    if server_type == "anthropic":
-        assert type(prompt) == str
-
-        if is_reasoning_model:
-            # Use beta endpoint with thinking enabled for reasoning models
-            response = client.beta.messages.create(
-                model=model,
-                system=system_prompt,
-                messages=[
-                    {"role": "user", "content": prompt},
-                ],
-                max_tokens=max_tokens,
-                # Claude thinking requires budget_tokens for thinking (reasoning)
-                thinking={"type": "enabled", "budget_tokens": budget_tokens},
-                betas=["output-128k-2025-02-19"],
-            )
-        else:
-            # Use standard endpoint for normal models
-            response = client.messages.create(
-                model=model,
-                system=system_prompt,
-                messages=[
-                    {"role": "user", "content": prompt},
-                ],
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
-                max_tokens=max_tokens,
-            )
-        outputs = [choice.text for choice in response.content if not hasattr(choice, 'thinking') or not choice.thinking]
-
-    elif server_type == "google":
-        # assert model_name == "gemini-1.5-flash-002", "Only test this for now"
-
-        generation_config = {
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "max_output_tokens": max_tokens,
-            "response_mime_type": "text/plain",
-        }
-
-        model = genai.GenerativeModel(
-            model_name=model_name,
-            system_instruction=system_prompt,
-            generation_config=generation_config,
+    # Local Server (SGLang, vLLM, Tokasaurus) - special handling
+    if server_type == "local":
+        url = f"http://{server_address}:{server_port}"
+        client = OpenAI(
+            api_key=SGLANG_KEY, base_url=f"{url}/v1", timeout=None, max_retries=0
        )
-
-        response = model.generate_content(prompt)
-
-        return response.text
-
-    elif server_type == "deepseek":
-
-        if model in ["deepseek-chat", "deepseek-coder"]:
-            # regular deepseek model
-            response = client.chat.completions.create(
-                model=model,
-                messages=[
-                    {"role": "system", "content": system_prompt},
-                    {"role": "user", "content": prompt},
-                ],
-                stream=False,
-                temperature=temperature,
-                n=num_completions,
-                max_tokens=max_tokens,
-                top_p=top_p,
-            )
-
-        else: # deepseek reasoner
-            assert is_reasoning_model, "Only support deepseek-reasoner for now"
-            assert model == "deepseek-reasoner", "Only support deepseek-reasoner for now"
-            response = client.chat.completions.create(
-                model=model,
-                messages=[
-                    {"role": "system", "content": system_prompt},
-                    {"role": "user", "content": prompt},
-                ],
-                stream=False,
-                n=num_completions,
-                max_tokens=max_tokens,
-                # do not use temperature or top_p
-            )
-        outputs = [choice.message.content for choice in response.choices]
-    elif server_type == "openai":
-        if is_reasoning_model:
-            assert "o1" in model or "o3" in model, "Only support o1 and o3 for now"
-            print(f"Using OpenAI reasoning model: {model} with reasoning effort {reasoning_effort}")
-            print(f"Using OpenAI reasoning model: {model} with reasoning effort {reasoning_effort}")
-            response = client.chat.completions.create(
-                model=model,
-                messages=[
-                    {"role": "user", "content": prompt},
-                ],
-                reasoning_effort=reasoning_effort,
-            )
-        else:
-            response = client.chat.completions.create(
-                model=model,
-                messages=[
-                    {"role": "system", "content": system_prompt},
-                    {"role": "user", "content": prompt},
-                ],
-                stream=False,
-                temperature=temperature,
-                n=num_completions,
-                max_tokens=max_tokens,
-                top_p=top_p,
-            )
-        outputs = [choice.message.content for choice in response.choices]
-    elif server_type == "together":
-        response = client.chat.completions.create(
-            model=model,
-            max_tokens=max_tokens,
-            temperature=temperature,
-            messages=[
-                {"role": "system", "content": system_prompt},
-                {"role": "user", "content": prompt},
-            ],
-            top_p=top_p,
-            top_k=top_k,
-            # repetition_penalty=1,
-            stop=["<|eot_id|>", "<|eom_id|>"],
-            # truncate=32256,
-            stream=False,
-        )
-        outputs = [choice.message.content for choice in response.choices]
-    elif server_type == "fireworks":
-        response = client.chat.completions.create(
-            model=model,
-            max_tokens=max_tokens,
-            temperature=temperature,
-            messages=[
-                {"role": "system", "content": system_prompt},
-                {"role": "user", "content": prompt},
-            ],
-            # top_p=top_p,
-            # top_k=top_k,
-            # repetition_penalty=1,
-            stop=["<|eot_id|>", "<|eom_id|>"],
-            # truncate=32256,
-            stream=False,
-        )
-        outputs = [choice.message.content for choice in response.choices]
-    elif server_type == "sambanova":
-        response = client.chat.completions.create(
-            model=model,
-            max_tokens=max_tokens,
-            temperature=temperature,
-            messages=[
-                {"role": "system", "content": system_prompt},
-                {"role": "user", "content": prompt},
-            ],
-            top_p=top_p,
-        )
-        outputs = [choice.message.content for choice in response.choices]
-    # for all other kinds of servers, use standard API
-    else:
-        if type(prompt) == str:
+        if isinstance(prompt, str):
             response = client.completions.create(
-                model=model,
+                model="default",
                 prompt=prompt,
                 temperature=temperature,
                 n=num_completions,
@@ -339,7 +99,7 @@ def query_server(
             outputs = [choice.text for choice in response.choices]
         else:
             response = client.chat.completions.create(
-                model=model,
+                model="default",
                 messages=prompt,
                 temperature=temperature,
                 n=num_completions,
@@ -347,42 +107,105 @@ def query_server(
                 top_p=top_p,
             )
             outputs = [choice.message.content for choice in response.choices]
-
-    # output processing
-    if len(outputs) == 1:
-        return outputs[0]
+
+        # output processing
+        if len(outputs) == 1:
+            return outputs[0]
+        else:
+            return outputs
+
+    # All other providers - use LiteLLM unified interface
+    # Build messages list with system prompt first (if not already present)
+    messages = []
+
+    # Check if prompt is already a list with a system message
+    if isinstance(prompt, list) and prompt and prompt[0].get("role") == "system":
+        # Prompt already has system message, use it directly
+        messages = prompt
     else:
-        return outputs
+        # Add system prompt first if provided
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+
+        # Then add the actual prompt
+        if isinstance(prompt, str):
+            messages.append({"role": "user", "content": prompt})
+        else:
+            messages.extend(prompt)
+
+    try:
+        completion_kwargs = {
+            "model": model_name,
+            "messages": messages,
+            "max_tokens": max_tokens,
+            "n": num_completions,
+        }
+
+        # Reasoning models (o1, o3, etc.) don't support standard sampling params
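+        # Illustrative final call shapes (model names are examples only):
+        #   o-series: completion(model="openai/o3-mini", messages=..., max_tokens=..., n=1, reasoning_effort="high")
+        #   Claude:   completion(model="anthropic/claude-3-7-sonnet-20250219", messages=..., max_tokens=..., n=1,
+        #                        thinking={"type": "enabled", "budget_tokens": 4096})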
+        if is_reasoning_model:
+            # Note: o1/o3 models don't support temperature, top_p, top_k
+            # LiteLLM will pass through reasoning_effort for OpenAI o1/o3 models
+            if reasoning_effort:
+                completion_kwargs["reasoning_effort"] = reasoning_effort
+            # Claude extended thinking uses "thinking" parameter with dict structure
+            # Format: {"type": "enabled", "budget_tokens": <int>}
+            if budget_tokens > 0 and "anthropic" in model_name.lower():
+                completion_kwargs["thinking"] = {"type": "enabled", "budget_tokens": budget_tokens}
+        else:
+            # Standard models support temperature and top_p
+            completion_kwargs["temperature"] = temperature
+            completion_kwargs["top_p"] = top_p
+
+            # top_k is not supported by OpenAI models
+            if "openai/" not in model_name.lower() and "gpt" not in model_name.lower():
+                completion_kwargs["top_k"] = top_k
+
+        response = completion(**completion_kwargs)
+
+        # output processing
+        if num_completions == 1:
+            content = response.choices[0].message.content
+            if content is None:
+                raise ValueError(f"LLM returned None content for model {model_name}. finish_reason: {response.choices[0].finish_reason}")
+            return content
+        else:
+            contents = [choice.message.content for choice in response.choices]
+            if any(c is None for c in contents):
+                raise ValueError(f"LLM returned None content in one or more completions for model {model_name}")
+            return contents
+    except Exception as e:
+        print(f"Error in query_server for model {model_name}: {e}")
+        raise
 
 
 # a list of presets for API server configs
 SERVER_PRESETS = {
     "deepseek": {
         "temperature": 1.6,
-        "model_name": "deepseek",
+        "model_name": "deepseek/deepseek-coder",
         "max_tokens": 4096
     },
     "google": {
-        "model_name": "gemini-1.5-flash-002",
+        "model_name": "gemini/gemini-2.5-flash",
         "temperature": 0.7, # need to experiment with temperature
-        "max_tokens": 8192,
+        "max_tokens": 16384,
     },
     "together": { # mostly for Llama 3.1
-        "model_name": "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
+        "model_name": "together_ai/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
         # "model_name": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
         "temperature": 0.7,
         "max_tokens": 4096,
     },
-    "sglang": { # this is for running locally, mostly for Llama
+    "local": { # this is for running locally (SGLang, vLLM, Tokasaurus), mostly for Llama
         "temperature": 0.8, # human eval pass@N temperature
         "server_port": 10210,
         "server_address": "matx2.stanford.edu",
         "max_tokens": 8192,
     },
-    "anthropic": { # for Claude 3.5 Sonnet
-        "model_name": "claude-3-5-sonnet-20241022",
+    "anthropic": { # for Claude 3.7 Sonnet
+        "model_name": "anthropic/claude-3-7-sonnet-20250219",
         "temperature": 0.8,
-        "max_tokens": 4096,
+        "max_tokens": 8192,
     },
     "openai": {
         "model_name": "gpt-4o-2024-08-06",
@@ -390,10 +213,10 @@ def query_server(
         "temperature": 0.0,
         "max_tokens": 4096,
     },
-    "sambanova": {
-        "model_name": "Meta-Llama-3.1-405B-Instruct",
-        "temperature": 0.1,
-        "max_tokens": 8192,
+    "fireworks": {
+        "model_name": "fireworks_ai/llama-v3p1-70b-instruct",
+        "temperature": 0.7,
+        "max_tokens": 4096,
     },
 }
 
@@ -402,6 +225,7 @@ def create_inference_server_from_presets(server_type: str = None,
                                          greedy_sample: bool = False,
                                          verbose: bool = False,
                                          time_generation: bool = False,
+                                         model_name: str = None,
                                          **kwargs,
                                          ) -> callable:
     """
@@ -409,15 +233,21 @@ def create_inference_server_from_presets(server_type: str = None,
     """
     def _query_llm(prompt: str | list[dict]):
         server_args = SERVER_PRESETS[server_type].copy()
-
+
+        if model_name is not None and model_name != "None":
+            server_args["model_name"] = model_name
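+        # precedence: preset defaults < explicit model_name / non-None kwargs < greedy_sample overrides below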
+
         if kwargs:
-            server_args.update(kwargs)
+            filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None and v != "None"}
+            server_args.update(filtered_kwargs)
+
         if greedy_sample:
             server_args["temperature"] = 0.0
             server_args["top_p"] = 1.0
             server_args["top_k"] = 1
+
         if verbose:
-            print(f"Querying server {server_type} with args: {server_args}")
+            print(f"Querying server {server_type} with model {server_args['model_name']} and args: {server_args}")
 
         if time_generation:
             start_time = time.time()
@@ -481,6 +311,9 @@ def extract_first_code(output_string: str, code_language_types: list[str]) -> st
     """
     Extract first code block from model output, specified by code_language_type
     """
+    if output_string is None:
+        return None
+
     trimmed = output_string.strip()
 
     # Extracting the first occurrence of content between backticks