## Define the Proxy Wrapper

In [0]:
import boto3
import mlflow

# Point the MLflow model registry at the "databricks-uc" registry URI.
mlflow.set_registry_uri("databricks-uc")
# Registered-model name used when logging, loading and serving the proxy below.
# NOTE(review): presumably in <catalog>.<schema>.<model> form — confirm.
uc_model_location = "adrian_test.genai.azure_openai_proxy"

In [0]:
import mlflow
import openai
import httpx
import os
import socket
from urllib.parse import urlparse

class AzureOpenAPIProxy(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc model that forwards chat requests to Azure OpenAI.

    Each setting listed by ``get_config_properties`` is read from the MLflow
    ``model_config`` first and from the process environment second, so the
    same logged model can be configured either at load time or through
    serving-endpoint environment variables.
    """

    def __init__(self):
        self.override_config = None

    def _get_config_value(self, context, key, default_value=None):
        """Return the configuration value for ``key`` as a string.

        Args:
            context: Object exposing a ``model_config`` mapping (may be None).
            key: Name of the setting; also used as the environment variable name.
            default_value: Fallback used when neither source provides a value.

        Raises:
            ValueError: If no non-empty value can be found.
        """
        model_config = context.model_config or {}
        value = model_config.get(key)
        if value in (None, ""):
            # The model is logged with every key set to None, so an unset key
            # falls through to the environment.
            value = os.getenv(key, default_value)
        # Raise explicitly instead of asserting: asserts vanish under `python -O`.
        if value in (None, ""):
            raise ValueError(f"Missing value for key {key}")
        return str(value)

    @staticmethod
    def get_config_properties():
        """Return the names of all settings the proxy reads."""
        return [
            "AZURE_OPENAI_API_KEY",
            "AZURE_OPENAI_ENDPOINT",
            "AZURE_OPENAI_API_VERSION",
            "AZURE_OPENAI_MODEL",
            "AZURE_OPENAI_VERIFY_SSL",
        ]

    def load_context(self, context):
        """Resolve configuration and build the Azure OpenAI client.

        Raises:
            ValueError: If a required setting is missing or
                AZURE_OPENAI_VERIFY_SSL is not a true/false string.
        """
        self.api_key = self._get_config_value(context, "AZURE_OPENAI_API_KEY")
        self.azure_endpoint = self._get_config_value(context, "AZURE_OPENAI_ENDPOINT")
        self.api_version = self._get_config_value(context, "AZURE_OPENAI_API_VERSION")
        self.model = self._get_config_value(context, "AZURE_OPENAI_MODEL")

        raw_verify_ssl = self._get_config_value(context, "AZURE_OPENAI_VERIFY_SSL", "true")
        normalized_verify_ssl = raw_verify_ssl.strip().lower()
        if normalized_verify_ssl not in ("true", "false"):
            raise ValueError(
                f"AZURE_OPENAI_VERIFY_SSL must be either True or False, got {raw_verify_ssl}"
            )
        self.verify_ssl = normalized_verify_ssl == "true"

        # Diagnostic only: print what the endpoint host resolves to so a
        # private vs. public resolution can be spotted in the logs.
        # `.hostname` (unlike `.netloc`) excludes any port or userinfo, which
        # would make the DNS lookup fail.
        host = urlparse(self.azure_endpoint).hostname
        if host:
            try:
                print(socket.gethostbyname_ex(host))
            except OSError as err:
                # A failed lookup should not stop the model from loading; the
                # first real request will surface the connectivity problem.
                print(f"Could not resolve {host}: {err}")
        else:
            print(f"No host found in AZURE_OPENAI_ENDPOINT {self.azure_endpoint!r}; skipping DNS check")

        self.client = openai.AzureOpenAI(
            api_key=self.api_key,
            api_version=self.api_version,
            azure_endpoint=self.azure_endpoint,
            http_client=httpx.Client(verify=self.verify_ssl),
        )

    def predict(self, context, messages, params=None):
        """Send one chat-completion request and return the response as a dict.

        Args:
            context: Unused here; the client is prepared in ``load_context``.
            messages: Tabular input supporting ``to_dict(orient="records")``
                with a ``messages`` column (assumed to be a pandas DataFrame —
                confirm). Only the first row is forwarded.
            params: Accepted for interface compatibility; currently unused.

        Returns:
            The chat-completion response converted with ``.dict()``.
        """
        chat_messages = messages.to_dict(orient="records")[0]["messages"]
        response = self.client.chat.completions.create(
            # Use the configured deployment/model rather than a hard-coded one.
            model=self.model,
            messages=chat_messages,
        )
        return response.dict()



In [0]:
# Start from MLflow's default pyfunc environment and pin the openai package to
# the version installed in this notebook, so serving uses the same client.
conda_env = mlflow.pyfunc.get_default_conda_env()
for dep in conda_env["dependencies"]:
    # The pip requirements live in the one dict entry keyed by "pip";
    # the remaining entries are plain conda requirement strings.
    if isinstance(dep, dict) and "pip" in dep:
        dep["pip"].append(f"openai=={openai.__version__}")
conda_env

In [0]:
# Log the proxy as a pyfunc model inside a fresh run and register it under
# `uc_model_location`. Every config key is stored as None so the real values
# can be supplied later via model_config or environment variables.
with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        python_model=AzureOpenAPIProxy(),
        artifact_path="model",
        registered_model_name=uc_model_location,
        conda_env=conda_env,
        model_config=dict.fromkeys(AzureOpenAPIProxy.get_config_properties()),
        signature=mlflow.models.ModelSignature(
            mlflow.types.llm.CHAT_MODEL_INPUT_SCHEMA,
            mlflow.types.llm.CHAT_MODEL_OUTPUT_SCHEMA,
        ),
    )

## Test the Proxy

In [0]:
import mlflow
mlflow_client = mlflow.MlflowClient()
model_versions = mlflow_client.search_model_versions(f"name='{uc_model_location}'")
# Compare versions numerically: as strings, "9" would sort above "10".
max_version = max(int(mv.version) for mv in model_versions)
logged_model = f'models:/{uc_model_location}/{max_version}'

# Load model as a PyFuncModel, supplying the real settings through model_config
# (the model itself was logged with every key set to None).
loaded_model = mlflow.pyfunc.load_model(logged_model, model_config={
    "AZURE_OPENAI_API_KEY": dbutils.secrets.get("<scope>", "<key>"),
    "AZURE_OPENAI_ENDPOINT": "https://<endpoint>.openai.azure.com",
    "AZURE_OPENAI_API_VERSION": "2023-05-15",
    "AZURE_OPENAI_MODEL": "gpt-4",
    "AZURE_OPENAI_VERIFY_SSL": "True"
})

# Smoke test: send a single user message through the proxy.
response = loaded_model.predict({"messages": [{"role": "user", "content": "Hello, how are you?"}]})

In [0]:
# Show the value returned by the predict call above as the cell output.
response

## Deploy the Proxy

In [0]:
from requests.exceptions import HTTPError
from mlflow.deployments import get_deploy_client


def create_or_update_depyloyment(name, api_key_secret_scope, api_key_secret_key, api_version, api_endpoint, verify_ssl="True", model="gpt-4"):
    """Create the serving endpoint ``name`` for the proxy, or update it if it exists.

    The newest registered version of ``uc_model_location`` is served with all
    traffic routed to it; the Azure OpenAI settings are passed to the model as
    environment variables.

    Args:
        name: Serving endpoint name.
        api_key_secret_scope: Secret scope holding the Azure OpenAI API key.
        api_key_secret_key: Secret key holding the Azure OpenAI API key.
        api_version: Value for AZURE_OPENAI_API_VERSION.
        api_endpoint: Value for AZURE_OPENAI_ENDPOINT.
        verify_ssl: Value for AZURE_OPENAI_VERIFY_SSL ("True"/"False"; a bool
            is also accepted and converted to its string form).
        model: Value for AZURE_OPENAI_MODEL.

    Raises:
        ValueError: If the model has no registered versions.
        HTTPError: If looking up the endpoint fails for a reason other than 404.
    """
    deploy_client = get_deploy_client("databricks")
    mlflow_client = mlflow.MlflowClient()
    model_versions = mlflow_client.search_model_versions(f"name='{uc_model_location}'")
    if not model_versions:
        raise ValueError(f"No registered versions found for model {uc_model_location}")
    # Compare versions numerically ("9" would otherwise sort above "10"), then
    # convert back to a string for the request payload.
    max_version = str(max(int(mv.version) for mv in model_versions))
    model_config = {
        "served_entities": [
            {
                "entity_name": uc_model_location,
                "entity_version": max_version,
                "workload_size": "Small",
                "scale_to_zero_enabled": True,
                'environment_vars': {
                    # Renders as {{secrets/<scope>/<key>}}: the key is passed as a
                    # secret reference, not as a literal value.
                    'AZURE_OPENAI_API_KEY': f'{{{{secrets/{api_key_secret_scope}/{api_key_secret_key}}}}}',
                    'AZURE_OPENAI_API_VERSION': api_version,
                    'AZURE_OPENAI_ENDPOINT': api_endpoint,
                    'AZURE_OPENAI_MODEL': model,
                    # Environment variable values must be strings.
                    'AZURE_OPENAI_VERIFY_SSL': str(verify_ssl),
                },
            }
        ],
        "traffic_config": {
            "routes": [
                {
                    "served_model_name": f"{uc_model_location.split('.')[-1]}-{max_version}",
                    "traffic_percentage": 100,
                }
            ]
        },
    }
    try:
        deploy_client.get_endpoint(name)
    except HTTPError as e:
        # Only "endpoint not found" means we should create it; anything else
        # (auth, server error, no response at all) is re-raised unchanged.
        if e.response is None or e.response.status_code != 404:
            raise
        create_or_update_func = deploy_client.create_endpoint
    else:
        create_or_update_func = deploy_client.update_endpoint
    create_or_update_func(name, model_config)


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
create_or_update_deployment = create_or_update_depyloyment

In [0]:
# Deploy the proxy to the "azure_openai_proxy_public" endpoint; verify_ssl and
# model are left at the function's defaults.
create_or_update_depyloyment(
    name="azure_openai_proxy_public",
    api_key_secret_scope="<scope>",
    api_key_secret_key="<key>",
    api_version="2023-05-15",
    api_endpoint="https://<endpoint>.openai.azure.com",
)