# Custom Chat Model

Let's see how to create a custom LangChain `ChatModel` implementation.

This lets you wrap an LLM that LangChain does not support yet, use it in your own code, and maybe even contribute it back to LangChain!

Wrapping the LLM with the `ChatModel` interface makes it easy to swap your `LLM` into existing `LangChain` programs with minimal code modifications!

In addition, wrapping the LLM in a `ChatModel` implementation automatically gives it the standard LangChain `Runnable` interface, along with some optimizations out of the box (e.g., `batch`) and async support (e.g., `ainvoke`, `abatch`).

There are two interfaces you can inherit from to provide your implementation:

1. `SimpleChatModel` -- Meant for prototyping, when not all features are required (e.g., no need for function calling).
2. `BaseChatModel` -- Best suited for a full implementation that supports all features (e.g., streaming, function calling).

## Inputs and outputs

Before we start implementing a chat model, let's take a look at the inputs and outputs of chat models.

### Messages

Chat models take messages as inputs and return chat messages as outputs. 

Here are the messages:

- `SystemMessage`: Used for priming AI behavior, usually passed in as the first of a sequence of input messages.
- `HumanMessage`: Represents a message from a person interacting with the chat model.
- `AIMessage`: Represents a message from the chat model. This can be either text or a request to invoke a tool.
- `FunctionMessage` / `ToolMessage`: Message for passing the results of tool invocation back to the model.

::: {.callout-note}
`ToolMessage` and `FunctionMessage` closely follow OpenAI's `tool` and `function` roles.

This is a rapidly developing field and as more models add function calling capabilities, expect that there will be additions to this schema.
:::

```python
from langchain_core.messages import BaseMessage, SystemMessage, AIMessage, HumanMessage, FunctionMessage, ToolMessage
```

### Streaming Variant

All the chat messages have a streaming variant that contains `Chunk` in the name.

```python
from langchain_core.messages import SystemMessageChunk, AIMessageChunk, HumanMessageChunk, FunctionMessageChunk, ToolMessageChunk
```

These chunks are used when streaming output from chat models, and they all support the `+` operator, so consecutive chunks can be added together into a single chunk!

```python
AIMessageChunk(content="Hello") + AIMessageChunk(content=" World!")
```

```
AIMessageChunk(content='Hello World!')
```

## Simple Chat Model

Inheriting from `SimpleChatModel` is great for prototyping!

It won't allow you to implement all features that you might want out of a chat model, but it's quick to implement, and if you need more you can transition to `BaseChatModel` shown below.

Let's implement a chat model that echoes back the first `n` characters of the last message in the prompt!

You need to implement the following:

* The method `_call` - Used to generate a response string from a prompt (a list of messages).
* The property `_llm_type` - Used to uniquely identify the type of the model, for logging purposes.

Optionally, you can also implement the following:

* The property `_identifying_params` - Represents the model parameterization, for logging purposes.
* `_stream` - Used to implement streaming.
* `_agenerate` - Used to implement a native async method.
* `_astream` - Used to implement an async version of `_stream`.

:::{.callout-caution}

If you implement streaming and want it to work asynchronously as well, you must also provide an async implementation of streaming (`_astream`).

If you want to reuse the logic of the sync `_stream` method, you can use the trick shown in `CustomChatModelAdvanced._astream` below, which runs `_stream` in a separate executor.
:::
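The idea behind running a sync stream in an executor can be shown without LangChain. This sketch uses only the standard library's `loop.run_in_executor`; the functions `slow_tokens` and `aiter_in_thread` are hypothetical. Calling a generator function only creates the generator, and the blocking work happens on each `next()` call, so the sketch pulls every item with `next()` inside the thread pool to keep that work off the event loop.

```python
import asyncio
import time
from typing import AsyncIterator, Iterator, List, TypeVar

T = TypeVar("T")


def slow_tokens(text: str) -> Iterator[str]:
    """Hypothetical blocking token source, standing in for a sync `_stream`."""
    for char in text:
        time.sleep(0.01)  # simulate blocking I/O per token
        yield char


async def aiter_in_thread(iterator: Iterator[T]) -> AsyncIterator[T]:
    """Drain a sync iterator in the default thread pool, one item at a time."""
    loop = asyncio.get_running_loop()
    done = object()  # sentinel: avoids raising StopIteration inside a Future
    while True:
        item = await loop.run_in_executor(None, next, iterator, done)
        if item is done:
            break
        yield item


async def main() -> List[str]:
    return [token async for token in aiter_in_thread(slow_tokens("cat"))]


print(asyncio.run(main()))  # ['c', 'a', 't']
```

The sentinel passed as `next`'s default is the design choice worth noting: letting `StopIteration` escape from a function running in an executor does not end an `async for` loop cleanly, so the sketch checks for the sentinel instead.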


```python
from typing import Any, Dict, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models import SimpleChatModel
from langchain_core.messages import BaseMessage, HumanMessage


class CustomChatModel(SimpleChatModel):
    """A custom chat model that echoes the first `n` characters of the last message.

    When contributing an implementation to LangChain, carefully document
    the model including the initialization parameters, include
    an example of how to initialize the model and include any relevant
    links to the underlying model's documentation or API.

    Example:

        .. code-block:: python

            model = CustomChatModel(n=2)
            result = model.invoke([HumanMessage(content="hello")])
            result = model.batch([[HumanMessage(content="hello")],
                                  [HumanMessage(content="world")]])
    """

    n: int
    """The number of characters from the last message of the prompt to be echoed."""

    # This is the core logic. It accepts a prompt composed of a list of messages
    # and returns a response which is a string.
    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Implementation of the chat model logic.

        Args:
            messages: the prompt composed of a list of messages.
            stop: a list of strings on which the model should stop generating.
                  If generation stops due to a stop token, the stop token itself
                  SHOULD BE INCLUDED as part of the output. This is not enforced
                  across models right now, but it's a good practice to follow since
                  it makes it much easier to parse the output of the model
                  downstream and understand why generation stopped.
            run_manager: A run manager that contains callbacks for the LLM.
                The on_chat_model_start and on_llm_end callbacks are automatically
                called by the wrapping code, so you don't need to invoke them
                when using SimpleChatModel.
                Please refer to the callbacks section in the documentation for
                more details about callbacks.
        """
        return messages[-1].content[: self.n]

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model.

        This property must return a string. It's used for logging purposes, and
        will be accessible from callbacks. It should identify the type
        of the model uniquely.
        """
        return "echoing_chat_model"

    # **Optional** override the identifying parameters to include additional
    # information about the model parameterization for logging purposes
    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters."""
        return {"n": self.n}
```


### Let's test it 🧪

The chat model implements LangChain's standard `Runnable` interface, which many LangChain abstractions support!

```python
model = CustomChatModel(n=7)
```

```python
model.invoke([HumanMessage(content='hello world!')])
```

```
AIMessage(content='hello w')
```

In addition, it supports the standard type conversions associated with chat models, which makes it more convenient to work with!

A string input is interpreted as a `HumanMessage`!

```python
model.invoke("Hello World!")
```

```
AIMessage(content='Hello W')
```

And look, async works as well 😘

```python
await model.ainvoke("Hello World!")
```

```
AIMessage(content='Hello W')
```

By adding a LangChain interface to your LLM, you get a bunch of optimizations out of the box!

By default, `batch` executes in a thread pool, which means that operations that block on I/O (e.g., an API call to another service) will automatically run in parallel!

:::{.callout-note}
The default `batch` implementation will not help to speed up CPU / GPU bound operations, unless those operations are delegated to lower level code that releases the GIL.

If this doesn't mean much to you, then just ignore it. 😵‍💫
:::

```python
model.batch(['hello world', 'goodbye moon'])
```

```
[AIMessage(content='hello w'), AIMessage(content='goodbye')]
```
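To see why a thread pool helps with I/O-bound work, here is a standard-library sketch of the same principle. It is not LangChain's `batch` implementation; `fake_api_call` is a hypothetical stand-in that sleeps instead of waiting on a network call.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import List


def fake_api_call(prompt: str) -> str:
    """Hypothetical I/O-bound call: sleeping releases the GIL, like waiting on a socket."""
    time.sleep(0.2)
    return prompt[:7]


prompts = ["hello world", "goodbye moon", "good morning", "good night"]

start = time.perf_counter()
with ThreadPoolExecutor(max_workers=len(prompts)) as pool:
    # `map` preserves input order, just like `batch` returns outputs in order.
    results: List[str] = list(pool.map(fake_api_call, prompts))
elapsed = time.perf_counter() - start

print(results)  # ['hello w', 'goodbye', 'good mo', 'good ni']
print(f"{elapsed:.2f}s")  # close to 0.2s, since the four sleeps overlap
```

In LangChain, `batch` also accepts a `config` such as `{"max_concurrency": 2}` to cap how many inputs are processed at once.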

:::{.callout-important}
The `astream` interface will work as well, but it won't actually stream since we haven't provided a streaming implementation!
:::

```python
async for chunk in model.astream('hello world'):
    print(chunk)
```

```
content='hello w'
```


## BaseChatModel

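You may want to add additional features like function calling, running the model in JSON mode, or streaming.

To do so, inherit from `BaseChatModel`, which is a lower-level class, and implement the following:

* `_generate` - Used to generate a chat result from a prompt.
* The property `_llm_type` - Used to uniquely identify the type of the model, for logging purposes.

Optional:

* `_stream` - Used to implement streaming.
* `_agenerate` - Used to implement a native async method.
* `_astream` - Used to implement an async version of `_stream`.
* The property `_identifying_params` - Represents the model parameterization, for logging purposes.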

```python
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from langchain_core.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import run_in_executor


class CustomChatModelAdvanced(BaseChatModel):
    """A custom chat model that echoes the first `n` characters of the last message.

    When contributing an implementation to LangChain, carefully document
    the model including the initialization parameters, include
    an example of how to initialize the model and include any relevant
    links to the underlying model's documentation or API.

    Example:

        .. code-block:: python

            model = CustomChatModelAdvanced(n=2)
            result = model.invoke([HumanMessage(content="hello")])
            result = model.batch([[HumanMessage(content="hello")],
                                  [HumanMessage(content="world")]])
    """

    n: int
    """The number of characters from the last message of the prompt to be echoed."""

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Override the _generate method to implement the chat model logic.

        This can be a call to an API, a call to a local model, or any other
        implementation that generates a response to the input prompt.

        Args:
            messages: the prompt composed of a list of messages.
            stop: a list of strings on which the model should stop generating.
                  If generation stops due to a stop token, the stop token itself
                  SHOULD BE INCLUDED as part of the output. This is not enforced
                  across models right now, but it's a good practice to follow since
                  it makes it much easier to parse the output of the model
                  downstream and understand why generation stopped.
            run_manager: A run manager with callbacks for the LLM.
        """
        last_message = messages[-1]
        tokens = last_message.content[: self.n]
        message = AIMessage(content=tokens)
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the output of the model.

        This method should be implemented if the model can generate output
        in a streaming fashion. If the model does not support streaming,
        do not implement it. In that case streaming requests will be automatically
        handled by the _generate method.

        Args:
            messages: the prompt composed of a list of messages.
            stop: a list of strings on which the model should stop generating.
                  If generation stops due to a stop token, the stop token itself
                  SHOULD BE INCLUDED as part of the output. This is not enforced
                  across models right now, but it's a good practice to follow since
                  it makes it much easier to parse the output of the model
                  downstream and understand why generation stopped.
            run_manager: A run manager with callbacks for the LLM.
        """
        last_message = messages[-1]
        tokens = last_message.content[: self.n]

        for token in tokens:
            chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))

            # Invoke the callback BEFORE yielding the chunk.
            if run_manager:
                run_manager.on_llm_new_token(token, chunk=chunk)

            yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """An async variant of `_stream`.

        If not provided, the default behavior is to delegate to the _generate method.

        The implementation below instead delegates to `_stream` and pulls each
        chunk from it in a separate thread, so the event loop is not blocked.

        If you're able to natively support async, then by all means do so!
        """
        # Creating the generator is cheap; the real work happens on each `next()`.
        iterator = await run_in_executor(
            None,
            self._stream,
            messages,
            stop=stop,
            run_manager=run_manager.get_sync() if run_manager else None,
            **kwargs,
        )
        # Pull chunks one at a time in the executor. The `done` sentinel avoids
        # raising StopIteration across the executor boundary.
        done = object()
        while True:
            chunk = await run_in_executor(None, next, iterator, done)
            if chunk is done:
                break
            yield chunk

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return "echoing-chat-model-advanced"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters."""
        return {"n": self.n}
```

### Let's test it 🧪

Like the simple model, this chat model implements LangChain's standard `Runnable` interface:

```python
model = CustomChatModelAdvanced(n=3)
```

```python
model.invoke([HumanMessage(content='hello!'), AIMessage(content="Hi there human!"), HumanMessage(content="Meow!")])
```

```
AIMessage(content='Meo')
```

```python
model.invoke('hello')
```

```
AIMessage(content='hel')
```

```python
model.batch(['hello', 'goodbye'])
```

```
[AIMessage(content='hel'), AIMessage(content='goo')]
```

```python
for chunk in model.stream('cat'):
    print(chunk.content, end='|')
```

```
c|a|t|
```

Please see the implementation of `_astream` in the model! If you do not implement it, the output will not stream: `astream` will yield the whole response as a single chunk.

```python
async for chunk in model.astream('cat'):
    print(chunk.content, end='|')
```

```
c|a|t|
```

Let's try the `astream_events` API, which also helps double-check that all the callbacks were implemented!

```python
async for event in model.astream_events('cat', version='v1'):
    print(event)
```

```
{'event': 'on_chat_model_start', 'run_id': 'fe9141d7-78a0-4a5e-a267-44194f3a1d39', 'name': 'CustomChatModelAdvanced', 'tags': [], 'metadata': {}, 'data': {'input': 'cat'}}
{'event': 'on_chat_model_stream', 'run_id': 'fe9141d7-78a0-4a5e-a267-44194f3a1d39', 'tags': [], 'metadata': {}, 'name': 'CustomChatModelAdvanced', 'data': {'chunk': AIMessageChunk(content='c')}}
{'event': 'on_chat_model_stream', 'run_id': 'fe9141d7-78a0-4a5e-a267-44194f3a1d39', 'tags': [], 'metadata': {}, 'name': 'CustomChatModelAdvanced', 'data': {'chunk': AIMessageChunk(content='a')}}
{'event': 'on_chat_model_stream', 'run_id': 'fe9141d7-78a0-4a5e-a267-44194f3a1d39', 'tags': [], 'metadata': {}, 'name': 'CustomChatModelAdvanced', 'data': {'chunk': AIMessageChunk(content='t')}}
{'event': 'on_chat_model_end', 'name': 'CustomChatModelAdvanced', 'run_id': 'fe9141d7-78a0-4a5e-a267-44194f3a1d39', 'tags': [], 'metadata': {}, 'data': {'output': AIMessageChunk(content='cat')}}
```




## Contributing

We would very much appreciate it if you contributed your chat model to LangChain.

Here's a checklist to help you validate your implementation of the chat model.

### Checklist

An overview of things to verify to make sure that the implementation is done correctly.

Documentation:
- [ ] The model contains doc-strings for all initialization arguments, as these will be surfaced in the [API Reference](https://api.python.langchain.com/en/stable/langchain_api_reference.html).
- [ ] The class doc-string for the model contains a link to the model API if the model is powered by a service.

Tests:
- [ ] Add unit or integration tests for the overridden methods. Verify that `invoke`, `ainvoke`, `batch`, and `stream` work if you've overridden the corresponding code.

Streaming (if you're implementing it):
- [ ] Provide an async implementation via `_astream`
- [ ] Make sure to invoke the `on_llm_new_token` callback
- [ ] `on_llm_new_token` is invoked BEFORE yielding the chunk

Stop Token Behavior:
- [ ] Stop token should be respected
- [ ] Stop token should be INCLUDED as part of the response
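For text that is already in hand, both stop-token rules can be expressed as a small helper: truncate at the earliest stop sequence and keep that sequence in the output. `apply_stop` is a hypothetical helper, not a LangChain API; a real integration would usually pass `stop` to the provider's API and only need something like this as a fallback.

```python
from typing import List, Optional


def apply_stop(text: str, stop: Optional[List[str]]) -> str:
    """Hypothetical helper: cut `text` right after the earliest stop sequence.

    The stop sequence itself is kept in the output. If two stop sequences
    start at the same position, the longer one wins.
    """
    if not stop:
        return text
    hits = [(text.find(s), s) for s in stop if s and s in text]
    if not hits:
        return text
    index, sequence = min(hits, key=lambda hit: (hit[0], -len(hit[1])))
    return text[: index + len(sequence)]


print(repr(apply_stop("Hello world. Bye.", ["."])))  # 'Hello world.'
print(repr(apply_stop("a\nb\nc", ["\n", "b"])))  # 'a\n'
```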

Secret API Keys:
- [ ] If your model connects to an API it will likely accept API keys as part of its initialization. Use Pydantic's `SecretStr` type for secrets, so they don't get accidentally printed out when folks print the model.
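A minimal sketch of the `SecretStr` behavior, shown on a plain Pydantic model named `HypotheticalChatModelSettings` (a made-up class). Since `BaseChatModel` is itself a Pydantic model, the same field declaration should carry over to a chat model class.

```python
from pydantic import BaseModel, SecretStr


class HypotheticalChatModelSettings(BaseModel):
    """Hypothetical settings object; only the field declarations matter here."""

    base_url: str
    api_key: SecretStr  # masked when the object is printed


settings = HypotheticalChatModelSettings(
    base_url="https://example.com", api_key="sk-not-a-real-key"
)

print(settings)  # api_key is shown masked as '**********'

# The raw value is only available through an explicit call.
raw_key = settings.api_key.get_secret_value()
```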