diff --git a/nemoguardrails/llm/prompts.py b/nemoguardrails/llm/prompts.py index 73e5dba8f..2c64f70da 100644 --- a/nemoguardrails/llm/prompts.py +++ b/nemoguardrails/llm/prompts.py @@ -114,18 +114,20 @@ def _get_prompt( def get_prompt(config: RailsConfig, task: Union[str, Task]) -> TaskPrompt: """Return the prompt for the given task.""" - # Currently, we use the main model for all tasks - # TODO: add support to use different models for different tasks # Fetch current task parameters like name, models to use, and the prompting mode task_name = str(task.value) if isinstance(task, Task) else task task_model = "unknown" if config.models: - task_model = config.models[0].engine - if config.models[0].model: - task_model += "/" + config.models[0].model - + _models = [model for model in config.models if model.type == task_name] + if not _models: + _models = [model for model in config.models if model.type == "main"] + + task_model = _models[0].engine + if _models[0].model: + task_model += "/" + _models[0].model + task_prompting_mode = "standard" if config.prompting_mode: # if exists in config, overwrite, else, default to "standard"