feat: Add a new Orchestrator "prompt_flow" (#1026)
Co-authored-by: Adam Dougal <adamdougal@microsoft.com>
tectonia and adamdougal committed Jun 14, 2024
1 parent a186884 commit ce9281a
Showing 16 changed files with 1,533 additions and 1,119 deletions.
8 changes: 6 additions & 2 deletions .env.sample
@@ -46,8 +46,12 @@ AZURE_FORM_RECOGNIZER_KEY=
# Azure AI Content Safety for filtering out the inappropriate questions or answers
AZURE_CONTENT_SAFETY_ENDPOINT=
AZURE_CONTENT_SAFETY_KEY=
-# Orchestration strategy. Use Azure OpenAI Functions (openai_function), Semantic Kernel (semantic_kernel) or LangChain (langchain) for message orchestration. If you are using model version 0613, select any strategy; if you are using model version 0314, select "langchain". Note that both `openai_function` and `semantic_kernel` use OpenAI function calling.
+# Orchestration strategy. Use Azure OpenAI Functions (openai_function), Semantic Kernel (semantic_kernel), LangChain (langchain) or Prompt Flow (prompt_flow) for message orchestration. If you are using model version 0613, select any strategy; if you are using model version 0314, select "langchain". Note that both `openai_function` and `semantic_kernel` use OpenAI function calling.
ORCHESTRATION_STRATEGY=openai_function
+# If Prompt Flow is selected as the orchestration strategy, provide the following environment variables. Note that Prompt Flow does not currently support RBAC authentication.
+AZURE_ML_WORKSPACE_NAME=
+PROMPT_FLOW_DEPLOYMENT_NAME=
+PROMPT_FLOW_ENDPOINT_NAME=
#Speech-to-text feature
AZURE_SPEECH_SERVICE_KEY=
AZURE_SPEECH_SERVICE_REGION=
@@ -58,4 +62,4 @@ AZURE_AUTH_TYPE=keys
USE_KEY_VAULT=true
AZURE_KEY_VAULT_ENDPOINT=
# Chat conversation type to decide between custom or byod (bring your own data) conversation type
CONVERSATION_FLOW=
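For reference, a filled-in configuration selecting the new strategy might look like the following; the resource names here are illustrative placeholders, not defaults shipped with the sample:

ORCHESTRATION_STRATEGY=prompt_flow
AZURE_ML_WORKSPACE_NAME=my-ml-workspace
PROMPT_FLOW_ENDPOINT_NAME=chat-endpoint
PROMPT_FLOW_DEPLOYMENT_NAME=chat-deployment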
8 changes: 8 additions & 0 deletions code/backend/batch/utilities/helpers/env_helper.py
@@ -227,6 +227,14 @@ def __load_config(self, **kwargs) -> None:
"LOAD_CONFIG_FROM_BLOB_STORAGE"
)

+        self.AZURE_ML_WORKSPACE_NAME = os.getenv("AZURE_ML_WORKSPACE_NAME", "")
+
+        self.PROMPT_FLOW_ENDPOINT_NAME = os.getenv("PROMPT_FLOW_ENDPOINT_NAME", "")
+
+        self.PROMPT_FLOW_DEPLOYMENT_NAME = os.getenv("PROMPT_FLOW_DEPLOYMENT_NAME", "")

    def should_use_data(self) -> bool:
        if (
            self.AZURE_SEARCH_SERVICE
12 changes: 12 additions & 0 deletions code/backend/batch/utilities/helpers/llm_helper.py
@@ -6,6 +6,8 @@
from semantic_kernel.connectors.ai.open_ai.prompt_execution_settings.azure_chat_prompt_execution_settings import (
    AzureChatPromptExecutionSettings,
)
+from azure.ai.ml import MLClient
+from azure.identity import DefaultAzureCredential
from .env_helper import EnvHelper


@@ -154,3 +156,13 @@ def get_sk_service_settings(self, service: AzureChatCompletion):
                max_tokens=self.llm_max_tokens,
            ),
        )

+    def get_ml_client(self):
+        if not hasattr(self, "_ml_client"):
+            self._ml_client = MLClient(
+                DefaultAzureCredential(),
+                self.env_helper.AZURE_SUBSCRIPTION_ID,
+                self.env_helper.AZURE_RESOURCE_GROUP,
+                self.env_helper.AZURE_ML_WORKSPACE_NAME,
+            )
+        return self._ml_client
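get_ml_client builds the client once and memoizes it on the helper, so later calls reuse the same MLClient and credential chain. A minimal usage sketch, assuming AZURE_SUBSCRIPTION_ID, AZURE_RESOURCE_GROUP and AZURE_ML_WORKSPACE_NAME are configured:

llm_helper = LLMHelper()
client = llm_helper.get_ml_client()           # first call constructs the MLClient
assert client is llm_helper.get_ml_client()   # subsequent calls return the cached instance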
@@ -5,3 +5,4 @@ class OrchestrationStrategy(Enum):
    OPENAI_FUNCTION = "openai_function"
    LANGCHAIN = "langchain"
    SEMANTIC_KERNEL = "semantic_kernel"
+    PROMPT_FLOW = "prompt_flow"
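Anything that enumerates the strategies picks the new member up automatically; for example:

[s.value for s in OrchestrationStrategy]
# ['openai_function', 'langchain', 'semantic_kernel', 'prompt_flow']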
93 changes: 93 additions & 0 deletions code/backend/batch/utilities/orchestrator/prompt_flow.py
@@ -0,0 +1,93 @@
import logging
from typing import List
import json
import tempfile

from .orchestrator_base import OrchestratorBase
from ..common.answer import Answer
from ..helpers.llm_helper import LLMHelper
from ..helpers.env_helper import EnvHelper

logger = logging.getLogger(__name__)


class PromptFlowOrchestrator(OrchestratorBase):
    def __init__(self) -> None:
        super().__init__()
        self.llm_helper = LLMHelper()
        self.env_helper = EnvHelper()

        # Get the ML client, endpoint and deployment names
        self.ml_client = self.llm_helper.get_ml_client()
        self.endpoint_name = self.env_helper.PROMPT_FLOW_ENDPOINT_NAME
        self.deployment_name = self.env_helper.PROMPT_FLOW_DEPLOYMENT_NAME

    async def orchestrate(
        self, user_message: str, chat_history: List[dict], **kwargs: dict
    ) -> list[dict]:
        # Call Content Safety tool on question
        if self.config.prompts.enable_content_safety:
            if response := self.call_content_safety_input(user_message):
                return response

        transformed_chat_history = self.transform_chat_history(chat_history)

        file_name = self.transform_data_into_file(
            user_message, transformed_chat_history
        )

        # Call the Prompt Flow service
        try:
            response = self.ml_client.online_endpoints.invoke(
                endpoint_name=self.endpoint_name,
                request_file=file_name,
                deployment_name=self.deployment_name,
            )
            result = json.loads(response)
            logger.debug(result)
        except Exception as error:
            logger.error("The request failed: %s", error)
            raise RuntimeError(f"The request failed: {error}") from error

        # Transform response into answer for further processing
        answer = Answer(question=user_message, answer=result["chat_output"])

        # Call Content Safety tool on answer
        if self.config.prompts.enable_content_safety:
            if response := self.call_content_safety_output(user_message, answer.answer):
                return response

        # Format the output for the UI
        messages = self.output_parser.parse(
            question=answer.question,
            answer=answer.answer,
            source_documents=answer.source_documents,
        )
        return messages

    def transform_chat_history(self, chat_history):
        transformed_chat_history = []
        for i, message in enumerate(chat_history):
            if message["role"] == "user":
                user_message = message["content"]
                assistant_message = ""
                if (
                    i + 1 < len(chat_history)
                    and chat_history[i + 1]["role"] == "assistant"
                ):
                    assistant_message = chat_history[i + 1]["content"]
                transformed_chat_history.append(
                    {
                        "inputs": {"chat_input": user_message},
                        "outputs": {"chat_output": assistant_message},
                    }
                )
        return transformed_chat_history

    def transform_data_into_file(self, user_message, chat_history):
        # Transform data input into a file for the Prompt Flow service
        data = {"chat_input": user_message, "chat_history": chat_history}
        body = str.encode(json.dumps(data))
        with tempfile.NamedTemporaryFile(delete=False) as file:
            file.write(body)
        return file.name
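The temporary file handed to online_endpoints.invoke via request_file is plain JSON. For a conversation with one prior exchange, its contents would look roughly like this (values illustrative):

{
    "chat_input": "And how do I enable it?",
    "chat_history": [
        {
            "inputs": {"chat_input": "What does the accelerator do?"},
            "outputs": {"chat_output": "It lets you chat over your own data ..."}
        }
    ]
}

The endpoint is expected to reply with a JSON body containing a chat_output key, which orchestrate parses into an Answer. Note that NamedTemporaryFile(delete=False) leaves the request file on disk after the call; cleanup is left to the OS temp directory.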
3 changes: 3 additions & 0 deletions code/backend/batch/utilities/orchestrator/strategies.py
@@ -2,6 +2,7 @@
from .open_ai_functions import OpenAIFunctionsOrchestrator
from .lang_chain_agent import LangChainAgent
from .semantic_kernel import SemanticKernelOrchestrator
+from .prompt_flow import PromptFlowOrchestrator


def get_orchestrator(orchestration_strategy: str):
@@ -11,5 +12,7 @@ def get_orchestrator(orchestration_strategy: str):
        return LangChainAgent()
    elif orchestration_strategy == OrchestrationStrategy.SEMANTIC_KERNEL.value:
        return SemanticKernelOrchestrator()
+    elif orchestration_strategy == OrchestrationStrategy.PROMPT_FLOW.value:
+        return PromptFlowOrchestrator()
    else:
        raise Exception(f"Unknown orchestration strategy: {orchestration_strategy}")
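With the new branch in place, resolving the orchestrator from the configured strategy string works as before; for example:

orchestrator = get_orchestrator("prompt_flow")  # returns a PromptFlowOrchestrator

Note that PromptFlowOrchestrator constructs the MLClient in its __init__, so the Azure ML settings described above must be populated for this call to succeed.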
1 change: 1 addition & 0 deletions code/frontend/vite.config.ts
@@ -10,6 +10,7 @@ export default defineConfig({
    sourcemap: true
  },
  server: {
+    host: true,
    proxy: {
      "/api": {
        target: "http://127.0.0.1:5050",
2 changes: 1 addition & 1 deletion code/tests/utilities/helpers/test_config_helper.py
@@ -446,7 +446,7 @@ def test_get_available_orchestration_strategies(config: Config):

    # then
    assert sorted(orchestration_strategies) == sorted(
-        ["openai_function", "langchain", "semantic_kernel"]
+        ["openai_function", "langchain", "prompt_flow", "semantic_kernel"]
    )


Expand Down
28 changes: 28 additions & 0 deletions code/tests/utilities/helpers/test_llm_helper.py
@@ -13,6 +13,11 @@
AZURE_OPENAI_MODEL = "mock-model"
AZURE_OPENAI_MAX_TOKENS = "100"
AZURE_OPENAI_EMBEDDING_MODEL = "mock-embedding-model"
+AZURE_SUBSCRIPTION_ID = "mock-subscription-id"
+AZURE_RESOURCE_GROUP = "mock-resource-group"
+AZURE_ML_WORKSPACE_NAME = "mock-ml-workspace"
+PROMPT_FLOW_ENDPOINT_NAME = "mock-endpoint-name"
+PROMPT_FLOW_DEPLOYMENT_NAME = "mock-deployment-name"


@pytest.fixture(autouse=True)
@@ -26,6 +31,11 @@ def env_helper_mock():
    env_helper.AZURE_OPENAI_MODEL = AZURE_OPENAI_MODEL
    env_helper.AZURE_OPENAI_MAX_TOKENS = AZURE_OPENAI_MAX_TOKENS
    env_helper.AZURE_OPENAI_EMBEDDING_MODEL = AZURE_OPENAI_EMBEDDING_MODEL
+    env_helper.AZURE_SUBSCRIPTION_ID = AZURE_SUBSCRIPTION_ID
+    env_helper.AZURE_RESOURCE_GROUP = AZURE_RESOURCE_GROUP
+    env_helper.AZURE_ML_WORKSPACE_NAME = AZURE_ML_WORKSPACE_NAME
+    env_helper.PROMPT_FLOW_ENDPOINT_NAME = PROMPT_FLOW_ENDPOINT_NAME
+    env_helper.PROMPT_FLOW_DEPLOYMENT_NAME = PROMPT_FLOW_DEPLOYMENT_NAME

    yield env_helper

@@ -127,3 +137,21 @@ def test_generate_embeddings_returns_embeddings(azure_openai_mock):

    # then
    assert actual_embeddings == expected_embeddings

+@patch("backend.batch.utilities.helpers.llm_helper.DefaultAzureCredential")
+@patch("backend.batch.utilities.helpers.llm_helper.MLClient")
+def test_get_ml_client_initializes_with_expected_parameters(
+    mock_ml_client, mock_default_credential, env_helper_mock
+):
+    # given
+    llm_helper = LLMHelper()
+
+    # when
+    llm_helper.get_ml_client()
+
+    # then
+    mock_ml_client.assert_called_once_with(
+        mock_default_credential.return_value,
+        env_helper_mock.AZURE_SUBSCRIPTION_ID,
+        env_helper_mock.AZURE_RESOURCE_GROUP,
+        env_helper_mock.AZURE_ML_WORKSPACE_NAME,
+    )
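A natural companion test, sketched here rather than part of this commit, would pin down the caching behaviour of get_ml_client, i.e. that a second call does not construct a second MLClient (the autouse env_helper_mock fixture supplies the Azure settings):

@patch("backend.batch.utilities.helpers.llm_helper.DefaultAzureCredential")
@patch("backend.batch.utilities.helpers.llm_helper.MLClient")
def test_get_ml_client_returns_cached_client(mock_ml_client, _mock_default_credential):
    # given
    llm_helper = LLMHelper()

    # when
    first = llm_helper.get_ml_client()
    second = llm_helper.get_ml_client()

    # then
    assert first is second
    mock_ml_client.assert_called_once()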
