In [None]:
#| default_exp function_callbacks

In [None]:
#| export
from datetime import datetime
from typing import Any, Dict, List, Optional, Union, Callable, Awaitable
from uuid import UUID
from langchain.callbacks.base import AsyncCallbackHandler
from langchain.schema.messages import BaseMessage
from langchain.schema.output import ChatGenerationChunk, GenerationChunk, LLMResult
from enum import Enum
import asyncio
from traceback import format_exception
from pino_inferior.core import OPENAI_API_KEY
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

In [None]:
#| export
class LLMEventType(Enum):
    """Kinds of LLM lifecycle events passed to an ``AsyncLLMCallback``."""
    LLM_START = "LLM_START"
    # Note: the value is "TOKEN", without the "LLM_" prefix the other members carry.
    LLM_TOKEN = "TOKEN"
    LLM_ERROR = "LLM_ERROR"
    LLM_END = "LLM_END"


# Coroutine function invoked as ``callback(event_type, timestamp, text)``.
AsyncLLMCallback = Callable[[LLMEventType, datetime, str], Awaitable[None]]

In [None]:
#| export
class AsyncFunctionalStyleChatCompletionHandler(AsyncCallbackHandler):
    """Callback handler that forwards LLM lifecycle events to one async function.

    Every hook awaits the stored callback with ``(event_type, timestamp, text)``,
    where the timestamp is ``datetime.now()`` taken inside the hook.
    """

    def __init__(self, callback: AsyncLLMCallback) -> None:
        """Keep ``callback``; it is awaited once for every reported event."""
        super().__init__()
        self.llm_callback = callback

    async def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: List[str] | None = None,
        metadata: Dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Report ``LLM_START`` with the prompt text.

        Raises:
            AssertionError: if ``prompts`` does not contain exactly one prompt.
        """
        assert len(prompts) == 1, "This agent structure works with 1 query each time"
        # A single timestamp is shared by every notification of this start event.
        started_at = datetime.now()
        notifications = [
            self.llm_callback(LLMEventType.LLM_START, started_at, prompt)
            for prompt in prompts
        ]
        await asyncio.gather(*notifications)

    async def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: List[str] | None = None,
        metadata: Dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Any:
        """Flatten each message thread to text and delegate to ``on_llm_start``.

        A thread becomes its ``"<type>: <content>"`` lines joined by blank lines.
        """
        prompts = []
        for thread in messages:
            rendered = [f"{message.type}: {message.content}" for message in thread]
            prompts.append("\n\n".join(rendered))
        await self.on_llm_start(
            serialized,
            prompts,
            run_id=run_id,
            parent_run_id=parent_run_id,
            tags=tags,
            metadata=metadata,
            **kwargs,
        )

    async def on_llm_new_token(
        self,
        token: str,
        *,
        chunk: GenerationChunk | ChatGenerationChunk | None = None,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: List[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Report ``LLM_TOKEN`` carrying the newly produced ``token``."""
        await self.llm_callback(LLMEventType.LLM_TOKEN, datetime.now(), token)

    async def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: List[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Report ``LLM_END`` with the text of the only generation.

        Raises:
            AssertionError: if ``response`` does not hold exactly one generation.
        """
        generations = response.generations
        assert len(generations) == 1, "This agent implementation works with 1 generated response"
        assert len(generations[0]) == 1, "This agent implementation works with 1 generated response"
        text = generations[0][0].text
        await self.llm_callback(LLMEventType.LLM_END, datetime.now(), text)

    async def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: List[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Report ``LLM_ERROR`` with the formatted traceback, then re-raise ``error``."""
        failed_at = datetime.now()
        trace = "".join(format_exception(error))
        await self.llm_callback(LLMEventType.LLM_ERROR, failed_at, trace)
        raise error

In [None]:
async def _callback(event: LLMEventType, time: datetime, text: str) -> None:
    """Print one LLM event as ``<event value> <timestamp> <text>``."""
    line = f"{event.value} {time} {text}"
    print(line)


# Demo: synchronous ``stream`` call with the handler passed to the model constructor.
callback = AsyncFunctionalStyleChatCompletionHandler(_callback)
model = ChatOpenAI(
    streaming=True,
    callbacks=[callback],
    openai_api_key=OPENAI_API_KEY,
)
for token in model.stream(
    [HumanMessage(content="Tell me what is (2+2) * 2. Think step by step")]
):
    # The streamed chunks are discarded; the loop only drains the stream.
    pass

LLM_START 2023-11-08 05:04:25.353040 human: Tell me what is (2+2) * 2. Think step by step
TOKEN 2023-11-08 05:04:26.901479 
TOKEN 2023-11-08 05:04:26.904479 Step
TOKEN 2023-11-08 05:04:26.919481  
TOKEN 2023-11-08 05:04:26.943478 1
TOKEN 2023-11-08 05:04:26.946478 :
TOKEN 2023-11-08 05:04:26.978479  Start
TOKEN 2023-11-08 05:04:26.983479  with
TOKEN 2023-11-08 05:04:26.996479  the
TOKEN 2023-11-08 05:04:27.013479  inner
TOKEN 2023-11-08 05:04:27.176480 most
TOKEN 2023-11-08 05:04:27.178480  parentheses
TOKEN 2023-11-08 05:04:27.179481 ,
TOKEN 2023-11-08 05:04:27.179481  which
TOKEN 2023-11-08 05:04:27.180482  is
TOKEN 2023-11-08 05:04:27.181478  
TOKEN 2023-11-08 05:04:27.182480 2
TOKEN 2023-11-08 05:04:27.183478 +
TOKEN 2023-11-08 05:04:27.184479 2
TOKEN 2023-11-08 05:04:27.197479 .

TOKEN 2023-11-08 05:04:27.216479   
TOKEN 2023-11-08 05:04:27.231480  
TOKEN 2023-11-08 05:04:27.250482 2
TOKEN 2023-11-08 05:04:27.262479  +
TOKEN 2023-11-08 05:04:27.282480  
TOKEN 2023-11-08 05:04:27.2

In [None]:
# Demo: send the question through the chat model's ``__call__`` entry point.
model([HumanMessage(content="Tell me what is (2+2) * 2. Think step by step")])

LLM_START 2023-11-08 05:04:28.120732 human: Tell me what is (2+2) * 2. Think step by step
TOKEN 2023-11-08 05:04:28.627197 
TOKEN 2023-11-08 05:04:28.629196 Step
TOKEN 2023-11-08 05:04:28.634955  
TOKEN 2023-11-08 05:04:28.651970 1
TOKEN 2023-11-08 05:04:28.672966 :
TOKEN 2023-11-08 05:04:28.688969  Start
TOKEN 2023-11-08 05:04:28.701971  by
TOKEN 2023-11-08 05:04:28.722967  evaluating
TOKEN 2023-11-08 05:04:28.739966  the
TOKEN 2023-11-08 05:04:28.750966  expression
TOKEN 2023-11-08 05:04:28.787981  inside
TOKEN 2023-11-08 05:04:28.837887  the
TOKEN 2023-11-08 05:04:28.841886  parentheses
TOKEN 2023-11-08 05:04:28.855157 :
TOKEN 2023-11-08 05:04:28.876139  
TOKEN 2023-11-08 05:04:28.892330 2
TOKEN 2023-11-08 05:04:28.916245  +
TOKEN 2023-11-08 05:04:28.934263  
TOKEN 2023-11-08 05:04:28.946271 2
TOKEN 2023-11-08 05:04:28.967284  =
TOKEN 2023-11-08 05:04:28.982316  
TOKEN 2023-11-08 05:04:29.002392 4
TOKEN 2023-11-08 05:04:29.021559 .

TOKEN 2023-11-08 05:04:29.038310 Step
TOKEN 2023-1

AIMessageChunk(content='Step 1: Start by evaluating the expression inside the parentheses: 2 + 2 = 4.\nStep 2: Now, the expression becomes 4 * 2.\nStep 3: Multiply 4 by 2, which gives us the final answer: 8.')

In [None]:
# Demo: send the question through the synchronous ``invoke`` entry point.
model.invoke([HumanMessage(content="Tell me what is (2+2) * 2. Think step by step")])

LLM_START 2023-11-08 05:04:29.662832 human: Tell me what is (2+2) * 2. Think step by step
TOKEN 2023-11-08 05:04:30.114008 
TOKEN 2023-11-08 05:04:30.116008 To
TOKEN 2023-11-08 05:04:30.127898  solve
TOKEN 2023-11-08 05:04:30.154553  (
TOKEN 2023-11-08 05:04:30.176467 2
TOKEN 2023-11-08 05:04:30.196675 +
TOKEN 2023-11-08 05:04:30.218362 2
TOKEN 2023-11-08 05:04:30.228967 )
TOKEN 2023-11-08 05:04:30.243976  *
TOKEN 2023-11-08 05:04:30.266546  
TOKEN 2023-11-08 05:04:30.314875 2
TOKEN 2023-11-08 05:04:30.318914  step
TOKEN 2023-11-08 05:04:30.322914  by
TOKEN 2023-11-08 05:04:30.338927  step
TOKEN 2023-11-08 05:04:30.356929 ,
TOKEN 2023-11-08 05:04:30.372931  we
TOKEN 2023-11-08 05:04:30.388381  follow
TOKEN 2023-11-08 05:04:30.407599  the
TOKEN 2023-11-08 05:04:30.422340  order
TOKEN 2023-11-08 05:04:30.441617  of
TOKEN 2023-11-08 05:04:30.458621  operations
TOKEN 2023-11-08 05:04:30.477495 ,
TOKEN 2023-11-08 05:04:30.505949  which
TOKEN 2023-11-08 05:04:30.506949  is
TOKEN 2023-11-08 0

AIMessageChunk(content='To solve (2+2) * 2 step by step, we follow the order of operations, which is also known as PEMDAS (Parentheses, Exponents, Multiplication and Division from left to right, and Addition and Subtraction from left to right).\n\n1. Start with the parentheses: (2+2) equals 4.\n   Equation becomes: 4 * 2.\n\n2. Next, perform the multiplication: 4 * 2 equals 8.\n\nTherefore, (2+2) * 2 equals 8.')

In [None]:
# Demo: send the question through the asynchronous ``ainvoke`` entry point.
await model.ainvoke([HumanMessage(content="Tell me what is (2+2) * 2. Think step by step")])

LLM_START 2023-11-08 05:04:32.126167 human: Tell me what is (2+2) * 2. Think step by step
TOKEN 2023-11-08 05:04:32.917700 
TOKEN 2023-11-08 05:04:32.917700 To
TOKEN 2023-11-08 05:04:32.963724  solve
TOKEN 2023-11-08 05:04:32.986533  the
TOKEN 2023-11-08 05:04:33.000902  expression
TOKEN 2023-11-08 05:04:33.022670  (
TOKEN 2023-11-08 05:04:33.034542 2
TOKEN 2023-11-08 05:04:33.050699 +
TOKEN 2023-11-08 05:04:33.066788 2
TOKEN 2023-11-08 05:04:33.083711 )
TOKEN 2023-11-08 05:04:33.101487  *
TOKEN 2023-11-08 05:04:33.119696  
TOKEN 2023-11-08 05:04:33.138037 2
TOKEN 2023-11-08 05:04:33.156495  step
TOKEN 2023-11-08 05:04:33.173545  by
TOKEN 2023-11-08 05:04:33.192975  step
TOKEN 2023-11-08 05:04:33.208660 ,
TOKEN 2023-11-08 05:04:33.225243  we
TOKEN 2023-11-08 05:04:33.255022  follow
TOKEN 2023-11-08 05:04:33.268023  the
TOKEN 2023-11-08 05:04:33.302188  order
TOKEN 2023-11-08 05:04:33.317953  of
TOKEN 2023-11-08 05:04:33.326111  operations
TOKEN 2023-11-08 05:04:33.329961 ,
TOKEN 2023-1

AIMessageChunk(content='To solve the expression (2+2) * 2 step by step, we follow the order of operations, which is known as PEMDAS (Parentheses, Exponents, Multiplication and Division from left to right, and Addition and Subtraction from left to right):\n\n1. First, we evaluate the expression inside the parentheses: (2+2) = 4.\n   So, we have 4 * 2.\n\n2. Next, we perform the multiplication: 4 * 2 = 8.\n\nTherefore, the answer to (2+2) * 2 is 8.')

In [None]:
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.chains import TransformChain

In [None]:
chat_prompt = ChatPromptTemplate.from_messages([
    HumanMessagePromptTemplate.from_template(template="Tell me what is (2+2) * 2. Think step by step")
])
callback = AsyncFunctionalStyleChatCompletionHandler(_callback)
model = ChatOpenAI(openai_api_key=OPENAI_API_KEY, streaming=True)
model.callbacks = [callback]

chain = chat_prompt | model
await chain.ainvoke({})

LLM_START 2023-11-08 05:04:35.629458 human: Tell me what is (2+2) * 2. Think step by step
TOKEN 2023-11-08 05:04:37.133022 
TOKEN 2023-11-08 05:04:37.133022 Step
TOKEN 2023-11-08 05:04:37.144130  
TOKEN 2023-11-08 05:04:37.158694 1
TOKEN 2023-11-08 05:04:37.173002 :
TOKEN 2023-11-08 05:04:37.187072  Evaluate
TOKEN 2023-11-08 05:04:37.203108  the
TOKEN 2023-11-08 05:04:37.219882  expression
TOKEN 2023-11-08 05:04:37.234562  inside
TOKEN 2023-11-08 05:04:37.252563  the
TOKEN 2023-11-08 05:04:37.290855  parentheses
TOKEN 2023-11-08 05:04:37.290855 .

TOKEN 2023-11-08 05:04:37.298888     
TOKEN 2023-11-08 05:04:37.334915  (
TOKEN 2023-11-08 05:04:37.338915 2
TOKEN 2023-11-08 05:04:37.340915  +
TOKEN 2023-11-08 05:04:37.363905  
TOKEN 2023-11-08 05:04:37.370905 2
TOKEN 2023-11-08 05:04:37.384906 )
TOKEN 2023-11-08 05:04:37.402076  =
TOKEN 2023-11-08 05:04:37.414899  
TOKEN 2023-11-08 05:04:37.429560 4
TOKEN 2023-11-08 05:04:37.442913 


TOKEN 2023-11-08 05:04:37.456924 Step
TOKEN 2023-11-08

AIMessageChunk(content='Step 1: Evaluate the expression inside the parentheses.\n     (2 + 2) = 4\n\nStep 2: Multiply the result from step 1 by 2.\n     4 * 2 = 8\n\nTherefore, (2+2) * 2 equals 8.')

In [None]:
ct = TransformChain(
    transform=lambda row: {"datetime_str": str(row["datetime"])},
    input_variables=["datetime"],
    output_variables=["datetime_str"]
)
chat_prompt = ChatPromptTemplate.from_messages([
    HumanMessagePromptTemplate.from_template(template="Tell me what is (2+2) * 2. Think step by step")
])
callback = AsyncFunctionalStyleChatCompletionHandler(_callback)
model = ChatOpenAI(openai_api_key=OPENAI_API_KEY, streaming=True)
model.callbacks = [callback]

chain = ct | chat_prompt | model
await chain.ainvoke({"datetime": datetime.now()})

TransformChain's atransform is not provided, falling back to synchronous transform


LLM_START 2023-11-08 05:04:38.091777 human: Tell me what is (2+2) * 2. Think step by step
TOKEN 2023-11-08 05:04:38.411720 
TOKEN 2023-11-08 05:04:38.411720 To
TOKEN 2023-11-08 05:04:38.448719  find
TOKEN 2023-11-08 05:04:38.460718  the
TOKEN 2023-11-08 05:04:38.481721  value
TOKEN 2023-11-08 05:04:38.492718  of
TOKEN 2023-11-08 05:04:38.514718  (
TOKEN 2023-11-08 05:04:38.531718 2
TOKEN 2023-11-08 05:04:38.546718 +
TOKEN 2023-11-08 05:04:38.564718 2
TOKEN 2023-11-08 05:04:38.580718 )
TOKEN 2023-11-08 05:04:38.597720  *
TOKEN 2023-11-08 05:04:38.612719  
TOKEN 2023-11-08 05:04:38.629718 2
TOKEN 2023-11-08 05:04:38.647719 ,
TOKEN 2023-11-08 05:04:38.663719  we
TOKEN 2023-11-08 05:04:38.676718  need
TOKEN 2023-11-08 05:04:38.694790  to
TOKEN 2023-11-08 05:04:38.717971  follow
TOKEN 2023-11-08 05:04:38.748071  the
TOKEN 2023-11-08 05:04:38.749072  order
TOKEN 2023-11-08 05:04:38.749072  of
TOKEN 2023-11-08 05:04:38.761093  operations
TOKEN 2023-11-08 05:04:38.776096 ,
TOKEN 2023-11-08 05:

AIMessageChunk(content='To find the value of (2+2) * 2, we need to follow the order of operations, also known as PEMDAS (Parentheses, Exponents, Multiplication and Division from left to right, and Addition and Subtraction from left to right). \n\n1. Start by evaluating the expression inside the parentheses: 2 + 2 = 4.\n2. Now, we have (4) * 2. \n3. Multiply 4 by 2: 4 * 2 = 8.\n\nTherefore, the value of (2+2) * 2 is 8.')

In [None]:
from langchain_openai_limiter import LimitAwaitChatOpenAI

In [None]:
ct = TransformChain(
    transform=lambda row: {"datetime_str": str(row["datetime"])},
    input_variables=["datetime"],
    output_variables=["datetime_str"]
)
chat_prompt = ChatPromptTemplate.from_messages([
    HumanMessagePromptTemplate.from_template(template="Tell me what is (2+2) * 2. Think step by step")
])
callback = AsyncFunctionalStyleChatCompletionHandler(_callback)
model = LimitAwaitChatOpenAI(chat_openai=ChatOpenAI(openai_api_key=OPENAI_API_KEY, streaming=True))
model.callbacks = [callback]

chain = ct | chat_prompt | model
await chain.ainvoke({"datetime": datetime.now()})

LLM_START 2023-11-08 05:04:40.409566 human: Tell me what is (2+2) * 2. Think step by step
TOKEN 2023-11-08 05:04:41.576429 
TOKEN 2023-11-08 05:04:41.576429 To
TOKEN 2023-11-08 05:04:41.587455  solve
TOKEN 2023-11-08 05:04:41.604731  the
TOKEN 2023-11-08 05:04:41.618554  expression
TOKEN 2023-11-08 05:04:41.632565  (
TOKEN 2023-11-08 05:04:41.645804 2
TOKEN 2023-11-08 05:04:41.664138 +
TOKEN 2023-11-08 05:04:41.675138 2
TOKEN 2023-11-08 05:04:41.687154 )
TOKEN 2023-11-08 05:04:41.739949  *
TOKEN 2023-11-08 05:04:41.740950  
TOKEN 2023-11-08 05:04:41.743749 2
TOKEN 2023-11-08 05:04:41.750191  step
TOKEN 2023-11-08 05:04:41.755991  by
TOKEN 2023-11-08 05:04:41.770992  step
TOKEN 2023-11-08 05:04:41.784030 ,
TOKEN 2023-11-08 05:04:41.799046  follow
TOKEN 2023-11-08 05:04:41.812004  the
TOKEN 2023-11-08 05:04:41.826565  order
TOKEN 2023-11-08 05:04:41.852312  of
TOKEN 2023-11-08 05:04:41.856316  operations
TOKEN 2023-11-08 05:04:41.869354 ,
TOKEN 2023-11-08 05:04:41.898187  which
TOKEN 202

AIMessageChunk(content='To solve the expression (2+2) * 2 step by step, follow the order of operations, which is usually remembered using the acronym PEMDAS:\n\n1. Parentheses: Within parentheses, evaluate any operations first.\n   (2+2) = 4\n\nNow the expression becomes 4 * 2.\n\n2. Multiplication: Multiply the numbers from left to right.\n   4 * 2 = 8\n\nTherefore, the final answer is 8.')

In [None]:
# Demo: synchronous ``stream`` call with the handler attached through
# ``model.callbacks`` after construction. ``_callback`` is the printing
# callback defined in the first demo cell; it is reused here rather than
# being defined a second time.
callback = AsyncFunctionalStyleChatCompletionHandler(_callback)
model = ChatOpenAI(openai_api_key=OPENAI_API_KEY, streaming=True)
model.callbacks = [callback]
for token in model.stream([HumanMessage(content="Tell me what is (2+2) * 2. Think step by step")]):
    pass

LLM_START 2023-11-08 05:04:43.005402 human: Tell me what is (2+2) * 2. Think step by step
TOKEN 2023-11-08 05:04:43.910937 
TOKEN 2023-11-08 05:04:43.912937 To
TOKEN 2023-11-08 05:04:43.925837  solve
TOKEN 2023-11-08 05:04:43.939851  the
TOKEN 2023-11-08 05:04:43.960631  expression
TOKEN 2023-11-08 05:04:43.984927  (
TOKEN 2023-11-08 05:04:43.991530 2
TOKEN 2023-11-08 05:04:44.010567 +
TOKEN 2023-11-08 05:04:44.024982 2
TOKEN 2023-11-08 05:04:44.045305 )
TOKEN 2023-11-08 05:04:44.081523  *
TOKEN 2023-11-08 05:04:44.088523  
TOKEN 2023-11-08 05:04:44.094918 2
TOKEN 2023-11-08 05:04:44.112918  step
TOKEN 2023-11-08 05:04:44.128742  by
TOKEN 2023-11-08 05:04:44.146691  step
TOKEN 2023-11-08 05:04:44.161704 ,
TOKEN 2023-11-08 05:04:44.183332  we
TOKEN 2023-11-08 05:04:44.195333  follow
TOKEN 2023-11-08 05:04:44.210333  the
TOKEN 2023-11-08 05:04:44.236044  order
TOKEN 2023-11-08 05:04:44.251517  of
TOKEN 2023-11-08 05:04:44.272545  operations
TOKEN 2023-11-08 05:04:44.287132 ,
TOKEN 2023-1

In [None]:
# Run nbdev's export step for this notebook's project.
import nbdev

nbdev.nbdev_export()