fireworks[patch]: remove custom async and stream implementations (lan…
efriis authored and gkorland committed Mar 30, 2024
1 parent d516a31 commit c3f6aa9
Showing 5 changed files with 125 additions and 103 deletions.
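In effect, this change deletes the overridden _stream, _astream, and _agenerate methods from ChatFireworks, so those code paths now fall back to whatever the class inherits from BaseChatModel, while callers keep using the public stream/astream/invoke surface exercised by the new integration tests below. A minimal usage sketch, not part of this diff (the model name is copied from the LLM tests further down, and reading the API key from the FIREWORKS_API_KEY environment variable is an assumption):

# Minimal caller-facing sketch; assumes FIREWORKS_API_KEY is set in the environment.
# The model name is illustrative, taken from the integration tests below.
from langchain_fireworks import ChatFireworks

llm = ChatFireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")

# Streaming still goes through the public stream() method; each chunk's
# content attribute is a str, which is exactly what the new tests assert.
for chunk in llm.stream("I'm Pickle Rick"):
    print(chunk.content, end="", flush=True)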
102 changes: 1 addition & 101 deletions libs/partners/fireworks/langchain_fireworks/chat_models.py
@@ -7,10 +7,8 @@
 from operator import itemgetter
 from typing import (
     Any,
-    AsyncIterator,
     Callable,
     Dict,
-    Iterator,
     List,
     Literal,
     Mapping,
@@ -26,13 +24,11 @@
 from fireworks.client import AsyncFireworks, Fireworks  # type: ignore
 from langchain_core._api import beta
 from langchain_core.callbacks import (
-    AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models import LanguageModelInput
 from langchain_core.language_models.chat_models import (
     BaseChatModel,
-    agenerate_from_stream,
     generate_from_stream,
 )
 from langchain_core.messages import (
@@ -57,7 +53,7 @@
     JsonOutputKeyToolsParser,
     PydanticToolsParser,
 )
-from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
+from langchain_core.outputs import ChatGeneration, ChatResult
 from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
 from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
 from langchain_core.tools import BaseTool
@@ -348,40 +344,6 @@ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
             combined["system_fingerprint"] = system_fingerprint
         return combined
 
-    def _stream(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[ChatGenerationChunk]:
-        message_dicts, params = self._create_message_dicts(messages, stop)
-        params = {**params, **kwargs, "stream": True}
-
-        default_chunk_class = AIMessageChunk
-        for chunk in self.client.create(messages=message_dicts, **params):
-            if not isinstance(chunk, dict):
-                chunk = chunk.dict()
-            if len(chunk["choices"]) == 0:
-                continue
-            choice = chunk["choices"][0]
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
-            generation_info = {}
-            if finish_reason := choice.get("finish_reason"):
-                generation_info["finish_reason"] = finish_reason
-            logprobs = choice.get("logprobs")
-            if logprobs:
-                generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
-            )
-            if run_manager:
-                run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
-            yield chunk
-
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -438,68 +400,6 @@ def _create_chat_result(self, response: Union[dict, BaseModel]) -> ChatResult:
         }
         return ChatResult(generations=generations, llm_output=llm_output)
 
-    async def _astream(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> AsyncIterator[ChatGenerationChunk]:
-        message_dicts, params = self._create_message_dicts(messages, stop)
-        params = {**params, **kwargs, "stream": True}
-
-        default_chunk_class = AIMessageChunk
-        async for chunk in await self.async_client.create(
-            messages=message_dicts, **params
-        ):
-            if not isinstance(chunk, dict):
-                chunk = chunk.dict()
-            if len(chunk["choices"]) == 0:
-                continue
-            choice = chunk["choices"][0]
-            chunk = _convert_delta_to_message_chunk(
-                choice["delta"], default_chunk_class
-            )
-            generation_info = {}
-            if finish_reason := choice.get("finish_reason"):
-                generation_info["finish_reason"] = finish_reason
-            logprobs = choice.get("logprobs")
-            if logprobs:
-                generation_info["logprobs"] = logprobs
-            default_chunk_class = chunk.__class__
-            chunk = ChatGenerationChunk(
-                message=chunk, generation_info=generation_info or None
-            )
-            if run_manager:
-                await run_manager.on_llm_new_token(
-                    token=chunk.text, chunk=chunk, logprobs=logprobs
-                )
-            yield chunk
-
-    async def _agenerate(
-        self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        stream: Optional[bool] = None,
-        **kwargs: Any,
-    ) -> ChatResult:
-        should_stream = stream if stream is not None else self.streaming
-        if should_stream:
-            stream_iter = self._astream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
-            return await agenerate_from_stream(stream_iter)
-
-        message_dicts, params = self._create_message_dicts(messages, stop)
-        params = {
-            **params,
-            **({"stream": stream} if stream is not None else {}),
-            **kwargs,
-        }
-        response = await self.async_client.create(messages=message_dicts, **params)
-        return self._create_chat_result(response)
-
     @property
     def _identifying_params(self) -> Dict[str, Any]:
         """Get the identifying parameters."""
2 changes: 1 addition & 1 deletion libs/partners/fireworks/poetry.lock


2 changes: 1 addition & 1 deletion libs/partners/fireworks/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "langchain-fireworks"
-version = "0.1.0"
+version = "0.1.1"
 description = "An integration package connecting Fireworks and LangChain"
 authors = []
 readme = "README.md"
61 changes: 61 additions & 0 deletions libs/partners/fireworks/tests/integration_tests/test_chat_models.py
@@ -74,3 +74,64 @@ class MyTool(BaseModel):
         "name": "Erick",
     }
     assert tool_call["type"] == "function"
+
+
+def test_stream() -> None:
+    """Test streaming tokens from ChatFireworks."""
+    llm = ChatFireworks()
+
+    for token in llm.stream("I'm Pickle Rick"):
+        assert isinstance(token.content, str)
+
+
+async def test_astream() -> None:
+    """Test streaming tokens from ChatFireworks."""
+    llm = ChatFireworks()
+
+    async for token in llm.astream("I'm Pickle Rick"):
+        assert isinstance(token.content, str)
+
+
+async def test_abatch() -> None:
+    """Test abatch tokens from ChatFireworks."""
+    llm = ChatFireworks()
+
+    result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
+    for token in result:
+        assert isinstance(token.content, str)
+
+
+async def test_abatch_tags() -> None:
+    """Test batch tokens from ChatFireworks."""
+    llm = ChatFireworks()
+
+    result = await llm.abatch(
+        ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
+    )
+    for token in result:
+        assert isinstance(token.content, str)
+
+
+def test_batch() -> None:
+    """Test batch tokens from ChatFireworks."""
+    llm = ChatFireworks()
+
+    result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
+    for token in result:
+        assert isinstance(token.content, str)
+
+
+async def test_ainvoke() -> None:
+    """Test invoke tokens from ChatFireworks."""
+    llm = ChatFireworks()
+
+    result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
+    assert isinstance(result.content, str)
+
+
+def test_invoke() -> None:
+    """Test invoke tokens from ChatFireworks."""
+    llm = ChatFireworks()
+
+    result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
+    assert isinstance(result.content, str)
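The bare async def tests above are presumably collected by the package's pytest asyncio configuration; outside pytest, driving astream needs an explicit event loop. A small sketch mirroring test_astream (the asyncio.run wrapper and the print call are illustrative, not part of the tests):

import asyncio

from langchain_fireworks import ChatFireworks


async def main() -> None:
    llm = ChatFireworks()
    # Mirrors test_astream: each streamed chunk has a str content attribute.
    async for chunk in llm.astream("I'm Pickle Rick"):
        print(chunk.content, end="", flush=True)


asyncio.run(main())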
61 changes: 61 additions & 0 deletions libs/partners/fireworks/tests/integration_tests/test_llms.py
@@ -39,3 +39,64 @@ async def test_fireworks_acall() -> None:
     output_text = output.generations[0][0].text
     assert isinstance(output_text, str)
     assert output_text.count("bar") <= 1
+
+
+def test_stream() -> None:
+    """Test streaming tokens from Fireworks."""
+    llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
+
+    for token in llm.stream("I'm Pickle Rick"):
+        assert isinstance(token, str)
+
+
+async def test_astream() -> None:
+    """Test streaming tokens from Fireworks."""
+    llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
+
+    async for token in llm.astream("I'm Pickle Rick"):
+        assert isinstance(token, str)
+
+
+async def test_abatch() -> None:
+    """Test streaming tokens from Fireworks."""
+    llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
+
+    result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
+    for token in result:
+        assert isinstance(token, str)
+
+
+async def test_abatch_tags() -> None:
+    """Test batch tokens from Fireworks."""
+    llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
+
+    result = await llm.abatch(
+        ["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
+    )
+    for token in result:
+        assert isinstance(token, str)
+
+
+def test_batch() -> None:
+    """Test batch tokens from Fireworks."""
+    llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
+
+    result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
+    for token in result:
+        assert isinstance(token, str)
+
+
+async def test_ainvoke() -> None:
+    """Test invoke tokens from Fireworks."""
+    llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
+
+    result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
+    assert isinstance(result, str)
+
+
+def test_invoke() -> None:
+    """Test invoke tokens from Fireworks."""
+    llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")
+
+    result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
+    assert isinstance(result, str)
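Unlike ChatFireworks, the completion-style Fireworks LLM returns plain strings rather than message objects, which is why these tests assert isinstance(..., str) directly. A small batch sketch under that assumption (model name copied from the tests above; the print call is illustrative):

from langchain_fireworks import Fireworks

llm = Fireworks(model="accounts/fireworks/models/mixtral-8x7b-instruct")

# batch() returns one completion string per input prompt, in the same order.
results = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
for text in results:
    print(text[:60])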
