# In [0]:
class Endpoints:
    """Create, list and delete Databricks external-model serving endpoints.

    All requests go through the MLflow deployments client obtained with
    ``get_deploy_client('databricks')``.

    Attributes:
        chat: Task identifier sent for chat endpoints (``"llm/v1/chat"``).
        embeddings: Task identifier sent for embeddings endpoints
            (``"llm/v1/embeddings"``).
        dp_client: The MLflow deployments client used for every request.
    """

    # Location of the OpenAI API key in Databricks secrets.
    SECRET_SCOPE = "genai-key-vault-secret"
    SECRET_KEY = "openaiapikey"

    def __init__(self):
        # Imported here so that merely defining the class does not import mlflow.
        from mlflow.deployments import get_deploy_client

        self.chat = "llm/v1/chat"
        self.embeddings = "llm/v1/embeddings"
        self.dp_client = get_deploy_client('databricks')

    @property
    def api_key(self):
        """Return the OpenAI API key read from Databricks secrets.

        The secret is fetched on access instead of at class-definition time,
        so defining the class has no side effects. Relies on the ``dbutils``
        global that Databricks notebooks provide.
        """
        return dbutils.secrets.get(scope=self.SECRET_SCOPE, key=self.SECRET_KEY)

    def get_endpoint(self):
        """Return the names of all endpoints reported by ``list_endpoints()``."""
        return [ep["name"] for ep in self.dp_client.list_endpoints()]

    def create(self, name, model, provider="openai", task="chat", api_key=None):
        """Create an external-model endpoint unless one named ``name`` exists.

        Args:
            name: Name of the endpoint to create.
            model: Name of the external model to serve.
            provider: External model provider.
                NOTE(review): the payload always carries an ``openai_config``
                section, so other providers are probably not handled -- confirm.
            task: ``"chat"`` selects the chat task; any other value selects
                the embeddings task.
            api_key: API key to send. When ``None`` (the default) the key is
                read from Databricks secrets via :attr:`api_key`.

        Returns:
            None. Progress is reported with ``print``.
        """
        if name in self.get_endpoint():
            print(f"Endpoint {name} already exists")
            return

        if api_key is None:
            # The original code read a bare ``api_key`` name here, which is not
            # visible from inside a method (class-body names are not in scope).
            api_key = self.api_key

        print("Creating Model Endpoint...", end='')
        self.dp_client.create_endpoint(
            config={
                "name": name,
                "config": {
                    "served_entities": [{
                        "external_model": {
                            "name": model,
                            "provider": provider,
                            "task": self.chat if task == "chat" else self.embeddings,
                            "openai_config": {
                                "openai_api_type": "openai",
                                # NOTE(review): the key is sent in plaintext;
                                # consider a secret reference if the serving
                                # API supports one -- confirm.
                                "openai_api_key_plaintext": api_key,
                            },
                        },
                    }],
                },
                "ai_gateway": {"usage_tracking_config": {"enabled": True}},
            }
        )
        print("Done")

    def delete(self, name):
        """Delete the endpoint called ``name`` if it exists.

        Args:
            name: Name of the endpoint to delete.

        Returns:
            None. Progress is reported with ``print``.
        """
        if name not in self.get_endpoint():
            print(f"Endpoint {name} does not exist")
            return

        print("Deleting Model Endpoint...", end='')
        self.dp_client.delete_endpoint(endpoint=name)
        print("Done")

# Example usage: DP = Endpoints()