In [None]:
#| default_exp function_callbacks

In [None]:
#| export
from datetime import datetime
from typing import Any, Dict, List, Optional, Union, Callable, Awaitable
from uuid import UUID
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.schema.messages import BaseMessage
from langchain.schema.output import ChatGenerationChunk, GenerationChunk, LLMResult
from enum import Enum
import asyncio
from traceback import format_exception
from pino_inferior.core import OPENAI_API_KEY
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

In [None]:
#| export
class LLMEventType(Enum):
    """Kinds of LLM lifecycle events forwarded to an ``AsyncLLMCallback``."""
    # Run started; the callback's text is the prompt.
    LLM_START = "LLM_START"
    # One streamed token; the callback's text is the token.
    # NOTE(review): the value is "TOKEN", not "LLM_TOKEN" like the other members;
    # consumers may compare against the raw value, so confirm before renaming.
    LLM_TOKEN = "TOKEN"
    # Run failed; the callback's text is the formatted traceback.
    LLM_ERROR = "LLM_ERROR"
    # Run finished; the callback's text is the generated text.
    LLM_END = "LLM_END"


# Signature of the user-supplied coroutine: (event type, timestamp, text payload).
AsyncLLMCallback = Callable[[LLMEventType, datetime, str], Awaitable[None]]

In [None]:
#| export
class AsyncFunctionalStyleChatCompletionHandler(AsyncCallbackHandler):
    """Adapt LangChain's async callback-handler interface to one coroutine.

    Every LLM lifecycle event this handler receives is forwarded to the
    user-supplied ``callback`` as ``(event_type, timestamp, text)``:

    * ``LLM_START`` -- the prompt text (chat messages are flattened to
      ``"<type>: <content>"`` entries separated by blank lines);
    * ``LLM_TOKEN`` -- each streamed token;
    * ``LLM_END``   -- the text of the single generation;
    * ``LLM_ERROR`` -- the formatted traceback of the failure.

    The handler is written for one prompt and one generation per run and
    asserts that assumption.
    """

    def __init__(self, callback: AsyncLLMCallback) -> None:
        """Store ``callback``; it is awaited once for every forwarded event."""
        super().__init__()
        self.llm_callback = callback

    async def on_llm_start(self,
                           serialized: Dict[str, Any],
                           prompts: List[str],
                           *,
                           run_id: UUID,
                           parent_run_id: UUID | None = None,
                           tags: List[str] | None = None,
                           metadata: Dict[str, Any] | None = None,
                           **kwargs: Any) -> None:
        """Emit ``LLM_START`` with the prompt text.

        Raises:
            AssertionError: unless exactly one prompt is passed.
        """
        assert len(prompts) == 1, "This agent structure works with 1 query each time"
        # One timestamp shared by every prompt of this run.
        time = datetime.now()
        await asyncio.gather(*[
            self.llm_callback(LLMEventType.LLM_START, time, prompt)
            for prompt in prompts
        ])

    async def on_chat_model_start(self,
                                  serialized: Dict[str, Any],
                                  messages: List[List[BaseMessage]],
                                  *,
                                  run_id: UUID,
                                  parent_run_id: UUID | None = None,
                                  tags: List[str] | None = None,
                                  metadata: Dict[str, Any] | None = None,
                                  **kwargs: Any) -> Any:
        """Flatten each message list into one prompt string and delegate to ``on_llm_start``."""
        prompts = [
            "\n\n".join([
                f"{message.type}: {message.content}"
                for message in thread
            ])
            for thread in messages
        ]
        await self.on_llm_start(serialized, prompts, run_id=run_id, parent_run_id=parent_run_id, tags=tags, metadata=metadata, **kwargs)

    async def on_llm_new_token(self,
                               token: str,
                               *,
                               chunk: GenerationChunk | ChatGenerationChunk | None = None,
                               run_id: UUID,
                               parent_run_id: UUID | None = None,
                               tags: List[str] | None = None,
                               **kwargs: Any) -> None:
        """Emit ``LLM_TOKEN`` for one streamed token."""
        await self.llm_callback(LLMEventType.LLM_TOKEN, datetime.now(), token)

    async def on_llm_end(self,
                         response: LLMResult,
                         *,
                         run_id: UUID,
                         parent_run_id: UUID | None = None,
                         tags: List[str] | None = None,
                         **kwargs: Any) -> None:
        """Emit ``LLM_END`` with the text of the single generation.

        Raises:
            AssertionError: unless ``response`` holds exactly one prompt's
                generations, and that list holds exactly one generation.
        """
        assert len(response.generations) == 1, "This agent implementation works with 1 generated response"
        assert len(response.generations[0]) == 1, "This agent implementation works with 1 generated response"
        text = response.generations[0][0].text
        await self.llm_callback(
            LLMEventType.LLM_END,
            datetime.now(),
            text,
        )

    async def on_llm_error(self,
                           error: BaseException,
                           *,
                           run_id: UUID,
                           parent_run_id: UUID | None = None,
                           tags: List[str] | None = None,
                           **kwargs: Any) -> None:
        """Emit ``LLM_ERROR`` with the formatted traceback of ``error``."""
        # 3-argument form: same output as format_exception(error) on 3.10+,
        # and also valid on older interpreters.
        details = "".join(format_exception(type(error), error, error.__traceback__))
        await self.llm_callback(
            LLMEventType.LLM_ERROR,
            datetime.now(),
            details,
        )
        # Notify only; do not re-raise. The LLM run re-raises `error` to its own
        # caller after notifying handlers. LangChain treats an exception escaping
        # a handler as a failure of the handler itself: it is logged as
        # "Error in ... callback", or propagated only when `raise_error` is set.

In [None]:
async def _callback(event: LLMEventType, time: datetime, text: str) -> None:
    """Demo ``AsyncLLMCallback``: print ``<event value> <timestamp> <text>`` to stdout."""
    parts = (event.value, time, text)
    print(" ".join(str(part) for part in parts))


# Sync `stream` with the handler passed through the constructor. The chunks are
# discarded because the handler's callback already prints every event.
callback = AsyncFunctionalStyleChatCompletionHandler(_callback)
model = ChatOpenAI(
    openai_api_key=OPENAI_API_KEY,
    streaming=True,
    callbacks=[callback],
)
for token in model.stream(
    [HumanMessage(content="Tell me what is (2+2) * 2. Think step by step")]
):
    pass

LLM_START 2023-11-08 00:41:02.474656 human: Tell me what is (2+2) * 2. Think step by step
_completion_with_retry {'messages': [{'role': 'user', 'content': 'Tell me what is (2+2) * 2. Think step by step'}], 'model': 'gpt-3.5-turbo', 'request_timeout': None, 'max_tokens': None, 'stream': True, 'n': 1, 'temperature': 0.7, 'api_key': 'sk-***REDACTED***', 'api_base': '', 'organization': ''}
TOKEN 2023-11-08 00:41:03.313383 
TOKEN 2023-11-08 00:41:03.315638 To
TOKEN 2023-11-08 00:41:03.322724  solve
TOKEN 2023-11-08 00:41:03.353727  (
TOKEN 2023-11-08 00:41:03.368723 2
TOKEN 2023-11-08 00:41:03.385724 +
TOKEN 2023-11-08 00:41:03.413130 2
TOKEN 2023-11-08 00:41:03.416131 )
TOKEN 2023-11-08 00:41:03.431058  *
TOKEN 2023-11-08 00:41:03.457078  
TOKEN 2023-11-08 00:41:03.479733 2
TOKEN 2023-11-08 00:41:03.491756  step
TOKEN 2023-11-08 00:41:03.510005  by
TOKEN 2023-11-08 00:41:03.527917  step
TOKEN 2023-11-08 00:41:03.540580 ,
TOKEN 2023-11-08 00:41:03.588207  f

In [None]:
# Synchronous `__call__` entry point on the streaming `model` from the cell above;
# the handler prints the events and the cell displays the returned message.
model([HumanMessage(content="Tell me what is (2+2) * 2. Think step by step")])

LLM_START 2023-11-08 00:41:05.467323 human: Tell me what is (2+2) * 2. Think step by step
_generate True
_completion_with_retry {'messages': [{'role': 'user', 'content': 'Tell me what is (2+2) * 2. Think step by step'}], 'model': 'gpt-3.5-turbo', 'request_timeout': None, 'max_tokens': None, 'stream': True, 'n': 1, 'temperature': 0.7, 'api_key': 'sk-***REDACTED***', 'api_base': '', 'organization': ''}
TOKEN 2023-11-08 00:41:06.441656 
TOKEN 2023-11-08 00:41:06.443637 To
TOKEN 2023-11-08 00:41:06.462808  solve
TOKEN 2023-11-08 00:41:06.475679  the
TOKEN 2023-11-08 00:41:06.497199  expression
TOKEN 2023-11-08 00:41:06.510230  (
TOKEN 2023-11-08 00:41:06.531216 2
TOKEN 2023-11-08 00:41:06.551300 +
TOKEN 2023-11-08 00:41:06.568955 2
TOKEN 2023-11-08 00:41:06.588864 )
TOKEN 2023-11-08 00:41:06.606869  *
TOKEN 2023-11-08 00:41:06.625906  
TOKEN 2023-11-08 00:41:06.644949 2
TOKEN 2023-11-08 00:41:06.663189  step
TOKEN 2023-11-08 00:41:06.684226  by
TOKEN 2023-

AIMessageChunk(content='To solve the expression (2+2) * 2 step by step, we follow the order of operations, which is also known as PEMDAS (Parentheses, Exponents, Multiplication and Division from left to right, Addition and Subtraction from left to right).\n\n1. First, we evaluate the expression within the parentheses: (2+2) = 4.\n2. Now the expression becomes 4 * 2.\n3. Finally, we multiply 4 by 2: 4 * 2 = 8.\n\nTherefore, (2+2) * 2 equals 8.')

In [None]:
# Same request through the synchronous Runnable `invoke` API.
model.invoke([HumanMessage(content="Tell me what is (2+2) * 2. Think step by step")])

LLM_START 2023-11-08 00:41:09.003494 human: Tell me what is (2+2) * 2. Think step by step
_generate True
_completion_with_retry {'messages': [{'role': 'user', 'content': 'Tell me what is (2+2) * 2. Think step by step'}], 'model': 'gpt-3.5-turbo', 'request_timeout': None, 'max_tokens': None, 'stream': True, 'n': 1, 'temperature': 0.7, 'api_key': 'sk-***REDACTED***', 'api_base': '', 'organization': ''}
TOKEN 2023-11-08 00:41:09.959764 
TOKEN 2023-11-08 00:41:09.961765 To
TOKEN 2023-11-08 00:41:09.962765  solve
TOKEN 2023-11-08 00:41:09.970764  (
TOKEN 2023-11-08 00:41:09.988318 2
TOKEN 2023-11-08 00:41:10.007123 +
TOKEN 2023-11-08 00:41:10.045120 2
TOKEN 2023-11-08 00:41:10.088147 )
TOKEN 2023-11-08 00:41:10.101725  *
TOKEN 2023-11-08 00:41:10.115172  
TOKEN 2023-11-08 00:41:10.132421 2
TOKEN 2023-11-08 00:41:10.153958  step
TOKEN 2023-11-08 00:41:10.172337  by
TOKEN 2023-11-08 00:41:10.193357  step
TOKEN 2023-11-08 00:41:10.214199 ,
TOKEN 2023-11-08 00:

AIMessageChunk(content='To solve (2+2) * 2 step by step, follow the order of operations, which is commonly known as PEMDAS (Parentheses, Exponents, Multiplication and Division, and Addition and Subtraction):\n\nStep 1: Simplify inside the parentheses\n(2+2) = 4\n\nStep 2: Multiply\n4 * 2 = 8\n\nTherefore, (2+2) * 2 equals 8.')

In [None]:
# Same request through the async `ainvoke` API (top-level `await` in a notebook cell).
await model.ainvoke([HumanMessage(content="Tell me what is (2+2) * 2. Think step by step")])

LLM_START 2023-11-08 00:41:11.758682 human: Tell me what is (2+2) * 2. Think step by step
_agenerate True
run_manager= <langchain.callbacks.manager.AsyncCallbackManagerForLLMRun object> [<__main__.AsyncFunctionalStyleChatCompletionHandler object>]
TOKEN 2023-11-08 00:41:12.518859 
run_manager= <langchain.callbacks.manager.AsyncCallbackManagerForLLMRun object> [<__main__.AsyncFunctionalStyleChatCompletionHandler object>]
TOKEN 2023-11-08 00:41:12.519859 Sure
run_manager= <langchain.callbacks.manager.AsyncCallbackManagerForLLMRun object> [<__main__.AsyncFunctionalStyleChatCompletionHandler object>]
TOKEN 2023-11-08 00:41:12.538678 !
run_manager= <langchain.callbacks.manager.AsyncCallbackManagerForLLMRun object> [<__main__.AsyncFunctionalStyleChatCompletionHandler object>]
TOKEN 2023-11-08 00:41:12.550300  Let
run_manager= <langchain.callbacks.manager.AsyncCallbackManagerForLLMRun object> [<__main__.AsyncFunctionalStyleChatCompletionHandler object>]
TOKEN 2023-11-08 00:41:12.583776 's
run

AIMessageChunk(content="Sure! Let's break it down step by step:\n\n1. First, we evaluate the expression inside the parentheses: 2 + 2 = 4.\n2. Next, we multiply the result from step 1 by 2: 4 * 2 = 8.\n\nTherefore, (2+2) * 2 equals 8.")

In [None]:
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.chains import TransformChain

In [None]:
chat_prompt = ChatPromptTemplate.from_messages([
    HumanMessagePromptTemplate.from_template(template="Tell me what is (2+2) * 2. Think step by step")
])
callback = AsyncFunctionalStyleChatCompletionHandler(_callback)
model = ChatOpenAI(openai_api_key=OPENAI_API_KEY, streaming=True)
model.callbacks = [callback]

chain = chat_prompt | model
await chain.ainvoke({})

LLM_START 2023-11-08 00:41:14.396746 human: Tell me what is (2+2) * 2. Think step by step
_agenerate True
run_manager= <langchain.callbacks.manager.AsyncCallbackManagerForLLMRun object> [<__main__.AsyncFunctionalStyleChatCompletionHandler object>]
TOKEN 2023-11-08 00:41:15.340708 
run_manager= <langchain.callbacks.manager.AsyncCallbackManagerForLLMRun object> [<__main__.AsyncFunctionalStyleChatCompletionHandler object>]
TOKEN 2023-11-08 00:41:15.340708 Step
run_manager= <langchain.callbacks.manager.AsyncCallbackManagerForLLMRun object> [<__main__.AsyncFunctionalStyleChatCompletionHandler object>]
TOKEN 2023-11-08 00:41:15.349718  
run_manager= <langchain.callbacks.manager.AsyncCallbackManagerForLLMRun object> [<__main__.AsyncFunctionalStyleChatCompletionHandler object>]
TOKEN 2023-11-08 00:41:15.365845 1
run_manager= <langchain.callbacks.manager.AsyncCallbackManagerForLLMRun object> [<__main__.AsyncFunctionalStyleChatCompletionHandler object>]
TOKEN 2023-11-08 00:41:15.382338 :
run_man

AIMessageChunk(content='Step 1: Begin by calculating the sum within the parentheses: 2 + 2 = 4.\nStep 2: Multiply the result from step 1 by 2: 4 * 2 = 8.\n\nTherefore, (2 + 2) * 2 equals 8.')

In [None]:
ct = TransformChain(
    transform=lambda row: {"datetime_str": str(row["datetime"])},
    input_variables=["datetime"],
    output_variables=["datetime_str"]
)
chat_prompt = ChatPromptTemplate.from_messages([
    HumanMessagePromptTemplate.from_template(template="Tell me what is (2+2) * 2. Think step by step")
])
callback = AsyncFunctionalStyleChatCompletionHandler(_callback)
model = ChatOpenAI(openai_api_key=OPENAI_API_KEY, streaming=True)
model.callbacks = [callback]

chain = ct | chat_prompt | model
await chain.ainvoke({"datetime": datetime.now()})

LLM_START 2023-11-08 00:50:39.622123 human: Tell me what is (2+2) * 2. Think step by step
_agenerate True
run_manager= <langchain.callbacks.manager.AsyncCallbackManagerForLLMRun object> [<__main__.AsyncFunctionalStyleChatCompletionHandler object>]
TOKEN 2023-11-08 00:50:39.985545 
run_manager= <langchain.callbacks.manager.AsyncCallbackManagerForLLMRun object> [<__main__.AsyncFunctionalStyleChatCompletionHandler object>]
TOKEN 2023-11-08 00:50:39.986545 Step
run_manager= <langchain.callbacks.manager.AsyncCallbackManagerForLLMRun object> [<__main__.AsyncFunctionalStyleChatCompletionHandler object>]
TOKEN 2023-11-08 00:50:39.994259  
run_manager= <langchain.callbacks.manager.AsyncCallbackManagerForLLMRun object> [<__main__.AsyncFunctionalStyleChatCompletionHandler object>]
TOKEN 2023-11-08 00:50:40.011337 1
run_manager= <langchain.callbacks.manager.AsyncCallbackManagerForLLMRun object> [<__main__.AsyncFunctionalStyleChatCompletionHandler object>]
TOKEN 2023-11-08 00:50:40.028513 :
run_man

AIMessageChunk(content='Step 1: Evaluate the expression inside the parentheses: 2 + 2 = 4.\nStep 2: Multiply the result from step 1 by 2: 4 * 2 = 8.\nTherefore, (2 + 2) * 2 = 8.')

In [None]:
from langchain_openai_limiter import LimitAwaitChatOpenAI

In [None]:
ct = TransformChain(
    transform=lambda row: {"datetime_str": str(row["datetime"])},
    input_variables=["datetime"],
    output_variables=["datetime_str"]
)
chat_prompt = ChatPromptTemplate.from_messages([
    HumanMessagePromptTemplate.from_template(template="Tell me what is (2+2) * 2. Think step by step")
])
callback = AsyncFunctionalStyleChatCompletionHandler(_callback)
model = LimitAwaitChatOpenAI(chat_openai=ChatOpenAI(openai_api_key=OPENAI_API_KEY, streaming=True))
model.callbacks = [callback]

chain = ct | chat_prompt | model
await chain.ainvoke({"datetime": datetime.now()})

LLM_START 2023-11-08 00:41:34.474421 human: Tell me what is (2+2) * 2. Think step by step
_agenerate True
run_manager= <langchain.callbacks.manager.AsyncCallbackManagerForLLMRun object> [<__main__.AsyncFunctionalStyleChatCompletionHandler object>]
TOKEN 2023-11-08 00:41:35.703854 
run_manager= <langchain.callbacks.manager.AsyncCallbackManagerForLLMRun object> [<__main__.AsyncFunctionalStyleChatCompletionHandler object>]
TOKEN 2023-11-08 00:41:35.703854 To
run_manager= <langchain.callbacks.manager.AsyncCallbackManagerForLLMRun object> [<__main__.AsyncFunctionalStyleChatCompletionHandler object>]
TOKEN 2023-11-08 00:41:35.716244  solve
run_manager= <langchain.callbacks.manager.AsyncCallbackManagerForLLMRun object> [<__main__.AsyncFunctionalStyleChatCompletionHandler object>]
TOKEN 2023-11-08 00:41:35.732884  (
run_manager= <langchain.callbacks.manager.AsyncCallbackManagerForLLMRun object> [<__main__.AsyncFunctionalStyleChatCompletionHandler object>]
TOKEN 2023-11-08 00:41:35.749904 2
run

AIMessageChunk(content='To solve (2+2) * 2 step by step, follow the order of operations, which is parentheses, exponents, multiplication and division (from left to right), and addition and subtraction (from left to right):\n\n1. Start with the parentheses: (2+2) = 4\n   Therefore, the expression becomes 4 * 2.\n\n2. Multiply: 4 * 2 = 8\n\nHence, (2+2) * 2 equals 8.')

In [None]:
# Sync `stream` again, this time attaching the handler by assigning
# `model.callbacks` after construction. Reuses the printing `_callback`
# coroutine from the first demo cell instead of redefining (shadowing) it.
callback = AsyncFunctionalStyleChatCompletionHandler(_callback)
model = ChatOpenAI(openai_api_key=OPENAI_API_KEY, streaming=True)
model.callbacks = [callback]
# The chunks are discarded; the handler's callback prints every event.
for _ in model.stream([HumanMessage(content="Tell me what is (2+2) * 2. Think step by step")]):
    pass

LLM_START 2023-11-08 00:41:39.258279 human: Tell me what is (2+2) * 2. Think step by step
_completion_with_retry {'messages': [{'role': 'user', 'content': 'Tell me what is (2+2) * 2. Think step by step'}], 'model': 'gpt-3.5-turbo', 'request_timeout': None, 'max_tokens': None, 'stream': True, 'n': 1, 'temperature': 0.7, 'api_key': 'sk-***REDACTED***', 'api_base': '', 'organization': ''}
TOKEN 2023-11-08 00:41:39.562918 
TOKEN 2023-11-08 00:41:39.564295 To
TOKEN 2023-11-08 00:41:39.608966  solve
TOKEN 2023-11-08 00:41:39.627989  the
TOKEN 2023-11-08 00:41:39.662852  expression
TOKEN 2023-11-08 00:41:39.674771  (
TOKEN 2023-11-08 00:41:39.680761 2
TOKEN 2023-11-08 00:41:39.697838 +
TOKEN 2023-11-08 00:41:39.720067 2
TOKEN 2023-11-08 00:41:39.729823 )
TOKEN 2023-11-08 00:41:39.783780  *
TOKEN 2023-11-08 00:41:39.797906  
TOKEN 2023-11-08 00:41:39.817690 2
TOKEN 2023-11-08 00:41:39.839778  step
TOKEN 2023-11-08 00:41:39.856284  by
TOKEN 2023-11-08 00:41:39.

In [None]:
# Regenerate the library module from this notebook's `#| export` cells.
import nbdev

nbdev.nbdev_export()