Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

langchain llamaindex callback handler #127

Merged
merged 33 commits into from
Jun 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
ae286fb
feat: logger function for langchain
csgulati09 Apr 1, 2024
25dcec9
feat: basic llama index callbackHandler template
csgulati09 Apr 4, 2024
2d83b4f
feat: basic llama index callbackHandler template
csgulati09 Apr 4, 2024
cf33db1
feat: updated req res object for logging
csgulati09 Apr 8, 2024
cd543c3
Merge branch 'main' into feat/langchainCallbackHandler
csgulati09 Apr 8, 2024
4fe6eaf
feat: clean up for langchain and logger
csgulati09 Apr 13, 2024
77f0340
feat: llama index callback handler
csgulati09 Apr 18, 2024
c361a54
Merge branch 'main' into feat/langchainCallbackHandler
csgulati09 May 4, 2024
e2d39ea
feat: update langchain and llamaindex handlers + logger file
csgulati09 May 6, 2024
2a8256a
fix: linting issues + code clean up
csgulati09 May 9, 2024
fbbfdac
fix: llamaindex init file
csgulati09 May 18, 2024
bab40e3
fix: base url for langchain
csgulati09 Jun 1, 2024
1b3f158
Merge branch 'main' into feat/langchainCallbackHandler
csgulati09 Jun 1, 2024
2904274
fix:linting issues
csgulati09 Jun 1, 2024
dcc6b6e
Merge branch 'main' into feat/langchainCallbackHandler
csgulati09 Jun 1, 2024
627e754
feat: logger for langchain and llamaindex
csgulati09 Jun 4, 2024
4ff1b08
fix: linitng issues
csgulati09 Jun 4, 2024
ded6324
fix: file structure for callbackhanders
csgulati09 Jun 4, 2024
deadd8b
feat: test cases for langchain and llamaindex
csgulati09 Jun 8, 2024
7a35942
fix: linting issues
csgulati09 Jun 8, 2024
046b69a
Merge branch 'main' into feat/langchainCallbackHandler
csgulati09 Jun 11, 2024
57e3755
fix: models.json for tests and base url to prod
csgulati09 Jun 11, 2024
a11eced
fix: linting issues + conditional import
csgulati09 Jun 11, 2024
983b646
fix: extra dependency for conditional import
csgulati09 Jun 11, 2024
614cd58
fix: tested conditional import + init files fixed
csgulati09 Jun 12, 2024
8864bac
fix: token count for llamaindex
csgulati09 Jun 13, 2024
57da8e2
fix: restructuring setup.cfg
csgulati09 Jun 18, 2024
3940880
Merge branch 'main' into feat/langchainCallbackHandler
csgulati09 Jun 19, 2024
5d8552e
fix: import statement for llm test cases
csgulati09 Jun 19, 2024
244b7fe
feat: prompt tokens for llamaindex
csgulati09 Jun 19, 2024
7ebd68b
Merge branch 'main' into feat/langchainCallbackHandler
csgulati09 Jun 22, 2024
84483b1
fix: type + make file command
csgulati09 Jun 22, 2024
9d973b7
Merge branch 'main' into feat/langchainCallbackHandler
csgulati09 Jun 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 7 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,10 @@ upload:
rm -rf dist

dev:
pip install -e ".[dev]"
pip install -e ".[dev]"

langchain_callback:
pip install -e ".[langchain_callback]"

llama_index_callback:
pip install -e ".[llama_index_callback]"
33 changes: 33 additions & 0 deletions portkey_ai/api_resources/apis/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import json
import os
from typing import Optional
import requests

from portkey_ai.api_resources.global_constants import PORTKEY_BASE_URL


class Logger:
def __init__(
self,
api_key: Optional[str] = None,
) -> None:
api_key = api_key or os.getenv("PORTKEY_API_KEY")
if api_key is None:
raise ValueError("API key is required to use the Logger API")

self.headers = {
"Content-Type": "application/json",
"x-portkey-api-key": api_key,
}

self.url = PORTKEY_BASE_URL + "/logs"

def log(
self,
log_object: dict,
):
response = requests.post(
url=self.url, data=json.dumps(log_object), headers=self.headers
)

return response
3 changes: 2 additions & 1 deletion portkey_ai/llms/langchain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .chat import ChatPortkey
from .completion import PortkeyLLM
from .portkey_langchain_callback import PortkeyLangchain

__all__ = ["ChatPortkey", "PortkeyLLM"]
__all__ = ["ChatPortkey", "PortkeyLLM", "PortkeyLangchain"]
170 changes: 170 additions & 0 deletions portkey_ai/llms/langchain/portkey_langchain_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from datetime import datetime
import time
from typing import Any, Dict, List, Optional
from portkey_ai.api_resources.apis.logger import Logger

try:
from langchain_core.callbacks import BaseCallbackHandler
except ImportError:
raise ImportError("Please pip install langchain-core to use PortkeyLangchain")


class PortkeyLangchain(BaseCallbackHandler):
def __init__(
self,
api_key: str,
) -> None:
super().__init__()
self.startTimestamp: float = 0
self.endTimestamp: float = 0

self.api_key = api_key

self.portkey_logger = Logger(api_key=api_key)

self.log_object: Dict[str, Any] = {}
self.prompt_records: Any = []

self.request: Any = {}
self.response: Any = {}

# self.responseHeaders: Dict[str, Any] = {}
self.responseBody: Any = None
self.responseStatus: int = 0

self.streamingMode: bool = False

if not api_key:
raise ValueError("Please provide an API key to use PortkeyCallbackHandler")

def on_llm_start(
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> None:
for prompt in prompts:
messages = prompt.split("\n")
for message in messages:
role, content = message.split(":", 1)
self.prompt_records.append(
{"role": role.lower(), "content": content.strip()}
)

self.startTimestamp = float(datetime.now().timestamp())

self.streamingMode = kwargs.get("invocation_params", False).get("stream", False)

self.request["method"] = "POST"
self.request["url"] = serialized.get("kwargs", "").get(
"base_url", "chat/completions"
)
self.request["provider"] = serialized["id"][2]
self.request["headers"] = serialized.get("kwargs", {}).get(
"default_headers", {}
)
self.request["headers"].update({"provider": serialized["id"][2]})
self.request["body"] = {"messages": self.prompt_records}
self.request["body"].update({**kwargs.get("invocation_params", {})})

def on_chain_start(
self,
serialized: Dict[str, Any],
inputs: Dict[str, Any],
**kwargs: Any,
) -> None:
"""Run when chain starts running."""

def on_llm_end(self, response: Any, **kwargs: Any) -> None:
self.endTimestamp = float(datetime.now().timestamp())
responseTime = self.endTimestamp - self.startTimestamp

usage = (response.llm_output or {}).get("token_usage", "") # type: ignore[union-attr]

self.response["status"] = (
200 if self.responseStatus == 0 else self.responseStatus
)
self.response["body"] = {
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": response.generations[0][0].text,
},
"logprobs": response.generations[0][0].generation_info.get("logprobs", ""), # type: ignore[union-attr] # noqa: E501
"finish_reason": response.generations[0][0].generation_info.get("finish_reason", ""), # type: ignore[union-attr] # noqa: E501
}
]
}
self.response["body"].update({"usage": usage})
self.response["body"].update({"id": str(kwargs.get("run_id", ""))})
self.response["body"].update({"created": int(time.time())})
self.response["body"].update({"model": (response.llm_output or {}).get("model_name", "")}) # type: ignore[union-attr] # noqa: E501
self.response["body"].update({"system_fingerprint": (response.llm_output or {}).get("system_fingerprint", "")}) # type: ignore[union-attr] # noqa: E501
self.response["time"] = int(responseTime * 1000)
self.response["headers"] = {}
self.response["streamingMode"] = self.streamingMode

self.log_object.update(
{
"request": self.request,
"response": self.response,
}
)

self.portkey_logger.log(log_object=self.log_object)

def on_chain_end(
self,
outputs: Dict[str, Any],
**kwargs: Any,
) -> None:
"""Run when chain ends running."""
pass

def on_chain_error(self, error: BaseException, **kwargs: Any) -> None:
self.responseBody = error
self.responseStatus = error.status_code # type: ignore[attr-defined]
"""Do nothing."""
pass

def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
self.responseBody = error
self.responseStatus = error.status_code # type: ignore[attr-defined]
"""Do nothing."""
pass

def on_tool_error(self, error: BaseException, **kwargs: Any) -> None:
self.responseBody = error
self.responseStatus = error.status_code # type: ignore[attr-defined]
pass

def on_text(self, text: str, **kwargs: Any) -> None:
pass

def on_agent_finish(self, finish: Any, **kwargs: Any) -> None:
pass

def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
self.streamingMode = True
"""Do nothing."""
pass

def on_tool_start(
self,
serialized: Dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
pass

def on_agent_action(self, action: Any, **kwargs: Any) -> Any:
"""Do nothing."""
pass

def on_tool_end(
self,
output: Any,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
pass
4 changes: 2 additions & 2 deletions portkey_ai/llms/llama_index/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .completions import PortkeyLLM
from .portkey_llama_callback import PortkeyLlamaindex

__all__ = ["PortkeyLLM"]
__all__ = ["PortkeyLlamaindex"]
160 changes: 160 additions & 0 deletions portkey_ai/llms/llama_index/portkey_llama_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import time
from typing import Any, Dict, List, Optional
from portkey_ai.api_resources.apis.logger import Logger
from datetime import datetime

try:
from llama_index.core.callbacks.base_handler import (
BaseCallbackHandler as LlamaIndexBaseCallbackHandler,
)
from llama_index.core.utilities.token_counting import TokenCounter
except ModuleNotFoundError:
raise ModuleNotFoundError(
"Please install llama-index to use Portkey Callback Handler"
)
except ImportError:
raise ImportError("Please pip install llama-index to use Portkey Callback Handler")


class PortkeyLlamaindex(LlamaIndexBaseCallbackHandler):
startTimestamp: int = 0
endTimestamp: float = 0

def __init__(
self,
api_key: str,
) -> None:
super().__init__(
event_starts_to_ignore=[],
event_ends_to_ignore=[],
)

self.api_key = api_key

self.portkey_logger = Logger(api_key=api_key)

self._token_counter = TokenCounter()
self.completion_tokens = 0
self.prompt_tokens = 0
self.token_llm = 0

self.log_object: Dict[str, Any] = {}
self.prompt_records: Any = []

self.request: Any = {}
self.response: Any = {}

self.responseTime: int = 0
self.streamingMode: bool = False

if not api_key:
raise ValueError("Please provide an API key to use PortkeyCallbackHandler")

def on_event_start( # type: ignore[return]
self,
event_type: Any,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
parent_id: str = "",
**kwargs: Any,
) -> str:
"""Run when an event starts and return id of event."""

if event_type == "llm":
self.llm_event_start(payload)

def on_event_end(
self,
event_type: Any,
payload: Optional[Dict[str, Any]] = None,
event_id: str = "",
**kwargs: Any,
) -> None:
"""Run when an event ends."""

if event_type == "llm":
self.llm_event_stop(payload, event_id)

def start_trace(self, trace_id: Optional[str] = None) -> None:
"""Run when an overall trace is launched."""
self.startTimestamp = int(datetime.now().timestamp())

def end_trace(
self,
trace_id: Optional[str] = None,
trace_map: Optional[Dict[str, List[str]]] = None,
) -> None:
"""Run when an overall trace is exited."""

def llm_event_start(self, payload: Any) -> None:
if "messages" in payload:
chunks = payload.get("messages", {})
self.prompt_tokens = self._token_counter.estimate_tokens_in_messages(chunks)
messages = payload.get("messages", {})
self.prompt_records = [
{"role": m.role.value, "content": m.content} for m in messages
]
self.request["method"] = "POST"
self.request["url"] = payload.get("serialized", {}).get(
"api_base", "chat/completions"
)
self.request["provider"] = payload.get("serialized", {}).get("class_name", "")
self.request["headers"] = {}
self.request["body"] = {"messages": self.prompt_records}
self.request["body"].update(
{"model": payload.get("serialized", {}).get("model", "")}
)
self.request["body"].update(
{"temperature": payload.get("serialized", {}).get("temperature", "")}
)

return None

def llm_event_stop(self, payload: Any, event_id) -> None:
self.endTimestamp = float(datetime.now().timestamp())
responseTime = self.endTimestamp - self.startTimestamp

data = payload.get("response", {})

chunks = payload.get("messages", {})
self.completion_tokens = self._token_counter.estimate_tokens_in_messages(chunks)
self.token_llm = self.prompt_tokens + self.completion_tokens
self.response["status"] = 200
self.response["body"] = {
"choices": [
{
"index": 0,
"message": {
"role": data.message.role.value,
"content": data.message.content,
},
"logprobs": data.logprobs,
"finish_reason": "done",
}
]
}
self.response["body"].update(
{
"usage": {
"prompt_tokens": self.prompt_tokens,
"completion_tokens": self.completion_tokens,
"total_tokens": self.token_llm,
}
}
)
self.response["body"].update({"id": event_id})
self.response["body"].update({"created": int(time.time())})
self.response["body"].update({"model": data.raw.get("model", "")})
self.response["time"] = int(responseTime * 1000)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the unit of responseTime here? seconds or ms?

self.response["headers"] = {}
self.response["streamingMode"] = self.streamingMode

self.log_object.update(
{
"request": self.request,
"response": self.response,
}
)
self.portkey_logger.log(log_object=self.log_object)

return None