Skip to content

Commit

Permalink
Bump up openai version to >=1.0 & use get_conn (#36014)
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajkoti committed Dec 6, 2023
1 parent 58e264c commit d2514b4
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 46 deletions.
42 changes: 23 additions & 19 deletions airflow/providers/openai/hooks/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

from __future__ import annotations

from functools import cached_property
from typing import Any

import openai
from openai import OpenAI

from airflow.hooks.base import BaseHook

Expand All @@ -41,37 +42,40 @@ class OpenAIHook(BaseHook):
def __init__(self, conn_id: str = default_conn_name, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.conn_id = conn_id
openai.api_key = self._get_api_key()
api_base = self._get_api_base()
if api_base:
openai.api_base = api_base

@staticmethod
def get_ui_field_behaviour() -> dict[str, Any]:
@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Return custom field behaviour."""
return {
"hidden_fields": ["schema", "port", "login", "extra"],
"hidden_fields": ["schema", "port", "login"],
"relabeling": {"password": "API Key"},
"placeholders": {},
}

def test_connection(self) -> tuple[bool, str]:
try:
openai.Model.list()
self.conn.models.list()
return True, "Connection established!"
except Exception as e:
return False, str(e)

def _get_api_key(self) -> str:
"""Get the OpenAI API key from the connection."""
conn = self.get_connection(self.conn_id)
if not conn.password:
raise ValueError("OpenAI API key not found in connection")
return str(conn.password)
@cached_property
def conn(self) -> OpenAI:
"""Return an OpenAI connection object."""
return self.get_conn()

def _get_api_base(self) -> None | str:
def get_conn(self) -> OpenAI:
"""Return an OpenAI connection object."""
conn = self.get_connection(self.conn_id)
return conn.host
extras = conn.extra_dejson
openai_client_kwargs = extras.get("openai_client_kwargs", {})
api_key = openai_client_kwargs.pop("api_key", None) or conn.password
base_url = openai_client_kwargs.pop("base_url", None) or conn.host or None
return OpenAI(
api_key=api_key,
base_url=base_url,
**openai_client_kwargs,
)

def create_embeddings(
self,
Expand All @@ -84,6 +88,6 @@ def create_embeddings(
:param text: The text to generate embeddings for.
:param model: The model to use for generating embeddings.
"""
response = openai.Embedding.create(model=model, input=text, **kwargs)
embeddings: list[float] = response["data"][0]["embedding"]
response = self.conn.embeddings.create(model=model, input=text, **kwargs)
embeddings: list[float] = response.data[0].embedding
return embeddings
2 changes: 1 addition & 1 deletion airflow/providers/openai/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ integrations:

dependencies:
- apache-airflow>=2.5.0
- openai[datalib]>=0.28.1,<1.0
- openai[datalib]>=1.0

hooks:
- integration-name: OpenAI
Expand Down
17 changes: 17 additions & 0 deletions docs/apache-airflow-providers-openai/connections.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,20 @@ API Key (required)

Host (optional)
The host address of the OpenAI instance.

Extra (optional)
Specify the extra parameters (as json dictionary) that can be used in the
connection. All parameters are optional.
This ``extra`` field accepts a nested dictionary with key ``openai_client_kwargs`` as key-value pairs that
are passed to the `OpenAI client <https://github.com/openai/openai-python/blob/main/src/openai/_client.py>`__
on instantiation. For example, to set the timeout for the client, you can pass the following dictionary
as the ``extra`` field:

.. code-block:: json
{
"openai_client_kwargs": {
"timeout": 10,
"api_key": "YOUR_API_KEY"
}
}
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@
"openai": {
"deps": [
"apache-airflow>=2.5.0",
"openai[datalib]>=0.28.1,<1.0"
"openai[datalib]>=1.0"
],
"cross-providers-deps": [],
"excluded-python-versions": []
Expand Down
127 changes: 102 additions & 25 deletions tests/providers/openai/hooks/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,48 +16,125 @@
# under the License.
from __future__ import annotations

from unittest.mock import Mock, patch
import os
from unittest.mock import patch

import pytest
from openai.types import CreateEmbeddingResponse, Embedding

from airflow.models import Connection
from airflow.providers.openai.hooks.openai import OpenAIHook


@pytest.fixture
def openai_hook():
with patch("airflow.providers.openai.hooks.openai.OpenAIHook._get_api_key"), patch(
"airflow.providers.openai.hooks.openai.OpenAIHook._get_api_base"
) as _:
yield OpenAIHook(conn_id="test_conn_id")
def mock_openai_connection():
conn_id = "openai_conn"
conn = Connection(
conn_id=conn_id,
conn_type="openai",
)
os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri()
yield conn


@pytest.fixture
def mock_embeddings_response():
return {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
def mock_openai_hook(mock_openai_connection):
with patch("airflow.providers.openai.hooks.openai.OpenAI"):
yield OpenAIHook(conn_id=mock_openai_connection.conn_id)


@pytest.fixture
def mock_completions_response():
return Mock(
id="completion-id",
object="completion",
created=1234567890,
model="text-davinci-002",
usage={"prompt_tokens": 15, "completion_tokens": 32, "total_tokens": 47},
choices=[Mock(text="the quick brown fox", finish_reason="stop", index=0)],
def mock_embeddings_response():
return CreateEmbeddingResponse(
data=[Embedding(embedding=[0.1, 0.2, 0.3], index=0, object="embedding")],
model="text-embedding-ada-002-v2",
object="list",
usage={"prompt_tokens": 4, "total_tokens": 4},
)


def test_create_embeddings(openai_hook, mock_embeddings_response):
def test_create_embeddings(mock_openai_hook, mock_embeddings_response):
text = "Sample text"
with patch("openai.Embedding.create", return_value=mock_embeddings_response):
embeddings = openai_hook.create_embeddings(text)
mock_openai_hook.conn.embeddings.create.return_value = mock_embeddings_response
embeddings = mock_openai_hook.create_embeddings(text)
assert embeddings == [0.1, 0.2, 0.3]


def test_get_api_key():
mock_connection = Mock()
mock_connection.password = "your_api_key"
OpenAIHook.get_connection = Mock(return_value=mock_connection)
api_key = OpenAIHook()._get_api_key()
assert api_key == "your_api_key"
def test_openai_hook_test_connection(mock_openai_hook):
result, message = mock_openai_hook.test_connection()
assert result is True
assert message == "Connection established!"


@patch("airflow.providers.openai.hooks.openai.OpenAI")
def test_get_conn_with_api_key_in_extra(mock_client):
conn_id = "api_key_in_extra"
conn = Connection(
conn_id=conn_id,
conn_type="openai",
extra={"openai_client_kwargs": {"api_key": "api_key_in_extra"}},
)
os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri()
hook = OpenAIHook(conn_id=conn_id)
hook.get_conn()
mock_client.assert_called_once_with(
api_key="api_key_in_extra",
base_url=None,
)


@patch("airflow.providers.openai.hooks.openai.OpenAI")
def test_get_conn_with_api_key_in_password(mock_client):
conn_id = "api_key_in_password"
conn = Connection(
conn_id=conn_id,
conn_type="openai",
password="api_key_in_password",
)
os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri()
hook = OpenAIHook(conn_id=conn_id)
hook.get_conn()
mock_client.assert_called_once_with(
api_key="api_key_in_password",
base_url=None,
)


@patch("airflow.providers.openai.hooks.openai.OpenAI")
def test_get_conn_with_base_url_in_extra(mock_client):
conn_id = "base_url_in_extra"
conn = Connection(
conn_id=conn_id,
conn_type="openai",
extra={"openai_client_kwargs": {"base_url": "base_url_in_extra", "api_key": "api_key_in_extra"}},
)
os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri()
hook = OpenAIHook(conn_id=conn_id)
hook.get_conn()
mock_client.assert_called_once_with(
api_key="api_key_in_extra",
base_url="base_url_in_extra",
)


@patch("airflow.providers.openai.hooks.openai.OpenAI")
def test_get_conn_with_openai_client_kwargs(mock_client):
conn_id = "openai_client_kwargs"
conn = Connection(
conn_id=conn_id,
conn_type="openai",
extra={
"openai_client_kwargs": {
"api_key": "api_key_in_extra",
"organization": "organization_in_extra",
}
},
)
os.environ[f"AIRFLOW_CONN_{conn.conn_id.upper()}"] = conn.get_uri()
hook = OpenAIHook(conn_id=conn_id)
hook.get_conn()
mock_client.assert_called_once_with(
api_key="api_key_in_extra",
base_url=None,
organization="organization_in_extra",
)

0 comments on commit d2514b4

Please sign in to comment.