Skip to content

Commit

Permalink
feat: Removing new line cleanups from OpenAI Embeddings
Browse files Browse the repository at this point in the history
- For models other than first gen text embeddings (-001) new line removal is not necessary

openai/openai-python#418
  • Loading branch information
tazarov committed May 3, 2024
1 parent 93cc872 commit 4411197
Show file tree
Hide file tree
Showing 3 changed files with 4,523 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,14 @@ class OpenAIEmbeddingModeModel(str, Enum):
def get_embedding(client: OpenAI, text: str, engine: str, **kwargs: Any) -> List[float]:
"""Get embedding.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
NOTE: New line removal is only necessary for non-code engines of gen 001.
See https://github.com/openai/openai-python/issues/418 for more details.
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
"""
text = text.replace("\n", " ")
if engine.endswith("001"):
text = text.replace("\n", " ")

return (
client.embeddings.create(input=[text], model=engine, **kwargs).data[0].embedding
Expand All @@ -144,14 +144,12 @@ async def aget_embedding(
) -> List[float]:
"""Asynchronously get embedding.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
NOTE: New line removal is only necessary for non-code engines of gen 001.
See https://github.com/openai/openai-python/issues/418 for more details.
"""
text = text.replace("\n", " ")
if engine.endswith("001"):
text = text.replace("\n", " ")

return (
(await aclient.embeddings.create(input=[text], model=engine, **kwargs))
Expand All @@ -166,16 +164,14 @@ def get_embeddings(
) -> List[List[float]]:
"""Get embeddings.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
NOTE: New line removal is only necessary for non-code engines of gen 001.
See https://github.com/openai/openai-python/issues/418 for more details.
"""
assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

list_of_text = [text.replace("\n", " ") for text in list_of_text]
if engine.endswith("001"):
list_of_text = [text.replace("\n", " ") for text in list_of_text]

data = client.embeddings.create(input=list_of_text, model=engine, **kwargs).data
return [d.embedding for d in data]
Expand All @@ -190,16 +186,14 @@ async def aget_embeddings(
) -> List[List[float]]:
"""Asynchronously get embeddings.
NOTE: Copied from OpenAI's embedding utils:
https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py
Copied here to avoid importing unnecessary dependencies
like matplotlib, plotly, scipy, sklearn.
NOTE: New line removal is only necessary for non-code engines of gen 001.
See https://github.com/openai/openai-python/issues/418 for more details.
"""
assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."

list_of_text = [text.replace("\n", " ") for text in list_of_text]
if engine.endswith("001"):
list_of_text = [text.replace("\n", " ") for text in list_of_text]

data = (
await aclient.embeddings.create(input=list_of_text, model=engine, **kwargs)
Expand Down

0 comments on commit 4411197

Please sign in to comment.