feat(langchain): improve generation_info and llm outputs
Signed-off-by: Tomas Dvorak <toomas2d@gmail.com>
Tomas2D committed Nov 10, 2023
1 parent 4178613 commit 6bf56fa
Showing 12 changed files with 341 additions and 118 deletions.
13 changes: 11 additions & 2 deletions examples/user/langchain_chat_generate.py
@@ -5,8 +5,8 @@

from genai.credentials import Credentials
from genai.extensions.langchain.chat_llm import LangChainChatInterface
from genai.schemas import GenerateParams
from genai.schemas.generate_params import ChatOptions
from genai.schemas import ChatOptions, GenerateParams, ReturnOptions
from genai.schemas.generate_params import HAPOptions, ModerationsOptions

# make sure you have a .env file under genai root with
# GENAI_KEY=<your-genai-key>
@@ -25,6 +25,13 @@
temperature=0.5,
top_k=50,
top_p=1,
stream=True,
return_options=ReturnOptions(input_text=False, input_tokens=True),
moderations=ModerationsOptions(
# Threshold is set to very low level to flag everything (testing purposes)
# or set to True to enable HAP with default settings
hap=HAPOptions(input=True, output=False, threshold=0.01)
),
),
)

@@ -50,6 +57,8 @@
conversation_id = result.generations[0][0].generation_info["meta"]["conversation_id"]
print(f"New conversation with ID '{conversation_id}' has been created!")
print(f"Response: {result.generations[0][0].text}")
print(result.llm_output)
print(result.generations[0][0].generation_info)

prompt = "Show me some simple code example."
print(f"Request: {prompt}")
34 changes: 34 additions & 0 deletions examples/user/langchain_evaluator.py
@@ -0,0 +1,34 @@
import os

from dotenv import load_dotenv
from langchain.evaluation import EvaluatorType, load_evaluator

from genai.credentials import Credentials
from genai.extensions.langchain import LangChainChatInterface
from genai.schemas import GenerateParams

# make sure you have a .env file under genai root with
# GENAI_KEY=<your-genai-key>
load_dotenv()
api_key = os.getenv("GENAI_KEY", None)
api_endpoint = os.getenv("GENAI_API", None)
credentials = Credentials(api_key, api_endpoint=api_endpoint)

# Load a trajectory (conversation) evaluator
llm = LangChainChatInterface(
model="meta-llama/llama-2-70b-chat",
credentials=credentials,
params=GenerateParams(
decoding_method="sample",
min_new_tokens=1,
max_new_tokens=100,
length_penalty={
"decay_factor": 1.5,
"start_index": 50,
},
temperature=1.2,
stop_sequences=["<|endoftext|>", "}]"],
),
)
evaluator = load_evaluator(evaluator=EvaluatorType.AGENT_TRAJECTORY, llm=llm)
print(evaluator)
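
As a rough usage sketch (not part of this diff), an agent-trajectory evaluator loaded this way is typically invoked with the original input, the final prediction, and the recorded trajectory of (AgentAction, observation) pairs; the values below are illustrative only.

from langchain.schema import AgentAction

eval_result = evaluator.evaluate_agent_trajectory(
    input="What is 3 + 4?",
    prediction="3 + 4 equals 7.",
    agent_trajectory=[
        # One recorded step: the action the agent took and the observation it got back
        (AgentAction(tool="calculator", tool_input="3 + 4", log="Using the calculator tool"), "7"),
    ],
)
print(eval_result)  # usually a dict with a score and the model's reasoning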
15 changes: 11 additions & 4 deletions examples/user/langchain_generate.py
@@ -7,7 +7,8 @@

from genai.credentials import Credentials
from genai.extensions.langchain import LangChainInterface
from genai.schemas import GenerateParams, ReturnOptions
from genai.schemas import GenerateParams
from genai.schemas.generate_params import HAPOptions, ModerationsOptions

# make sure you have a .env file under genai root with
# GENAI_KEY=<your-genai-key>
@@ -29,7 +30,7 @@ def on_llm_new_token(


llm = LangChainInterface(
model="google/flan-ul2",
model="google/flan-t5-xl",
credentials=Credentials(api_key, api_endpoint),
params=GenerateParams(
decoding_method="sample",
@@ -39,12 +40,18 @@ def on_llm_new_token(
temperature=0.5,
top_k=50,
top_p=1,
return_options=ReturnOptions(generated_tokens=True, token_logprobs=True, input_tokens=True),
moderations=ModerationsOptions(
# Threshold is set to very low level to flag everything (testing purposes)
# or set to True to enable HAP with default settings
hap=HAPOptions(input=True, output=True, threshold=0.01)
),
),
)

result = llm.generate(
prompts=["Tell me about IBM."],
callbacks=[Callback()],
)
print(result)
print(f"Response: {result.generations[0][0].text}")
print(result.llm_output)
print(result.generations[0][0].generation_info)
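
For orientation only: after this commit llm_output is assembled by create_llm_output (see the chat_llm.py changes below), so the printed value has roughly the shape sketched here; the keys inside token_usage are illustrative, not captured from a real run.

# print(result.llm_output) -> roughly:
# {'model_name': 'google/flan-t5-xl', 'token_usage': {'input_token_count': ..., 'generated_token_count': ...}}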
26 changes: 26 additions & 0 deletions examples/user/llama_index_stream.py
@@ -0,0 +1,26 @@
import os

from dotenv import load_dotenv
from llama_index.llms import LangChainLLM

from genai import Credentials
from genai.extensions.langchain import LangChainInterface
from genai.schemas import GenerateParams

load_dotenv()
api_key = os.environ.get("GENAI_KEY")
api_url = os.environ.get("GENAI_API")
langchain_model = LangChainInterface(
model="meta-llama/llama-2-70b-chat",
credentials=Credentials(api_key, api_endpoint=api_url),
params=GenerateParams(
decoding_method="sample",
min_new_tokens=1,
max_new_tokens=10,
),
)

llm = LangChainLLM(llm=langchain_model)
response_gen = llm.stream_chat("What is a molecule?")
for delta in response_gen:
print(delta.delta)
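
A minimal non-streaming counterpart, assuming the same llama_index LLM interface as above; complete() returns a CompletionResponse whose text attribute holds the full answer.

# Non-streaming request through the same wrapped model (illustrative)
response = llm.complete("What is a molecule?")
print(response.text)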
105 changes: 74 additions & 31 deletions src/genai/extensions/langchain/chat_llm.py
@@ -6,7 +6,6 @@
from pydantic import ConfigDict

from genai import Credentials, Model
from genai.exceptions import GenAiException
from genai.schemas import GenerateParams
from genai.schemas.chat import AIMessage, BaseMessage, HumanMessage, SystemMessage
from genai.schemas.generate_params import ChatOptions
@@ -27,16 +26,14 @@
from .utils import (
create_generation_info_from_response,
create_llm_output,
extract_token_usage,
load_config,
update_token_usage,
update_token_usage_stream,
)
except ImportError:
raise ImportError("Could not import langchain: Please install ibm-generative-ai[langchain] extension.")

__all__ = ["LangChainChatInterface"]


logger = logging.getLogger(__name__)

Message = Union[LCBaseMessage, BaseMessage]
@@ -77,6 +74,7 @@ class LangChainChatInterface(BaseChatModel):
model: str
params: Optional[GenerateParams] = None
model_config = ConfigDict(extra="forbid", protected_namespaces=())
streaming: Optional[bool] = None

@classmethod
def is_lc_serializable(cls) -> bool:
@@ -123,20 +121,33 @@ def _stream(
model = Model(self.model, params=params, credentials=self.credentials)

stream = model.chat_stream(messages=convert_messages_to_genai(messages), options=options, **kwargs)
conversation_id: Optional[str] = None
for response in stream:
result = response.results[0] if response else None
if not result:
if not response:
continue

generated_text = result.generated_text or ""
generation_info = create_generation_info_from_response(response)
chunk = ChatGenerationChunk(
message=LCAIMessageChunk(content=generated_text, generation_info=generation_info),
generation_info=generation_info,
)
yield chunk
if run_manager:
run_manager.on_llm_new_token(token=generated_text, chunk=chunk, response=response)
def send_chunk(*, text: str = "", generation_info: dict):
logger.info("Chunk received: {}".format(text))
chunk = ChatGenerationChunk(
message=LCAIMessageChunk(content=text, generation_info=generation_info),
generation_info=generation_info,
)
yield chunk
if run_manager:
run_manager.on_llm_new_token(token=text, chunk=chunk, response=response)

if not conversation_id:
conversation_id = response.conversation_id
else:
response.conversation_id = conversation_id

if response.moderation:
generation_info = create_generation_info_from_response(response, result=response.moderation)
yield from send_chunk(generation_info=generation_info)

for result in response.results or []:
generation_info = create_generation_info_from_response(response, result=result)
yield from send_chunk(text=result.generated_text or "", generation_info=generation_info)

def _generate(
self,
@@ -149,28 +160,63 @@ def _generate(
) -> ChatResult:
params = to_model_instance(self.params, GenerateParams)
params.stop_sequences = stop or params.stop_sequences

model = Model(self.model, params=params, credentials=self.credentials)
response = model.chat(messages=convert_messages_to_genai(messages), options=options, **kwargs)
result = response.results[0]
assert result

message = LCAIMessage(content=result.generated_text or "")
params.stream = params.stream or self.streaming

def handle_stream():
final_generation: Optional[ChatGenerationChunk] = None
for result in self._stream(
messages=messages,
stop=stop,
run_manager=run_manager,
options=options,
**kwargs,
):
if final_generation:
token_usage = result.generation_info.pop("token_usage")
final_generation += result
update_token_usage_stream(
target=final_generation.generation_info["token_usage"],
source=token_usage,
)
else:
final_generation = result

assert final_generation and final_generation.generation_info
return {
"text": final_generation.text,
"generation_info": final_generation.generation_info.copy(),
}

def handle_non_stream():
model = Model(self.model, params=params, credentials=self.credentials)
response = model.chat(messages=convert_messages_to_genai(messages), options=options, **kwargs)

assert response.results
result = response.results[0]

return {
"text": result.generated_text or "",
"generation_info": create_generation_info_from_response(response, result=result),
}

result = handle_stream() if params.stream else handle_non_stream()
return ChatResult(
generations=[
ChatGeneration(message=message, generation_info=create_generation_info_from_response(response))
ChatGeneration(
message=LCAIMessage(content=result["text"]),
generation_info=result["generation_info"].copy(),
)
],
llm_output=create_llm_output(
model=self.model,
token_usage=extract_token_usage(result.model_dump()),
token_usages=[result["generation_info"]["token_usage"]],
),
)

def get_num_tokens(self, text: str) -> int:
model = Model(self.model, params=self.params, credentials=self.credentials)
response = model.tokenize([text], return_tokens=False)[0]
if response.token_count is None:
raise GenAiException("Invalid tokenize result!")
assert response.token_count is not None
return response.token_count

def get_num_tokens_from_messages(self, messages: list[LCBaseMessage]) -> int:
@@ -179,8 +225,5 @@ def get_num_tokens_from_messages(self, messages: list[LCBaseMessage]) -> int:
return sum([response.token_count for response in responses if response.token_count])

def _combine_llm_outputs(self, llm_outputs: list[Optional[dict]]) -> dict:
overall_token_usage: dict = extract_token_usage({})
update_token_usage(
target=overall_token_usage, sources=[output.get("token_usage") for output in llm_outputs if output]
)
return {"model_name": self.model, "token_usage": overall_token_usage}
token_usages = [output.get("token_usage") for output in llm_outputs if output]
return create_llm_output(model=self.model, token_usages=token_usages)
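
The helpers imported from .utils (update_token_usage_stream, create_llm_output) are defined in src/genai/extensions/langchain/utils.py and are not shown in this excerpt; below is a hypothetical sketch of how such token-usage merging might look, which may well differ from the real implementations.

from typing import Optional


def update_token_usage_stream(*, target: dict, source: Optional[dict]) -> None:
    # Fold one chunk's token usage into the running total kept on the final chunk.
    for key, value in (source or {}).items():
        if isinstance(value, (int, float)):
            target[key] = target.get(key, 0) + value
        elif value is not None:
            target[key] = value


def create_llm_output(*, model: str, token_usages: list[Optional[dict]]) -> dict:
    # Merge token usage across generations into a single llm_output payload.
    token_usage: dict = {}
    for usage in token_usages:
        update_token_usage_stream(target=token_usage, source=usage)
    return {"model_name": model, "token_usage": token_usage}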