forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
/
chat_model_facade.py
33 lines (27 loc) · 1.12 KB
/
chat_model_facade.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from __future__ import annotations
from typing import List, Optional
from langchain.chat_models.base import BaseChatModel, SimpleChatModel
from langchain.schema import BaseMessage
from langchain.llms.base import BaseLanguageModel
from langchain.utils import serialize_msgs
class ChatModelFacade(SimpleChatModel):
    """Adapter exposing any ``BaseLanguageModel`` through the chat-model interface.

    Chat models receive the message list directly; plain completion LLMs
    receive the messages serialized into a single prompt string.
    """

    # The wrapped model: either a chat model or a plain completion LLM.
    llm: BaseLanguageModel

    def _call(self, messages: List[BaseMessage], stop: Optional[List[str]] = None) -> str:
        """Route ``messages`` to the wrapped model and return its text output.

        Args:
            messages: Conversation history to send to the model.
            stop: Optional stop sequences forwarded to the model.

        Returns:
            The model's reply as a plain string.

        Raises:
            ValueError: If ``self.llm`` is neither a chat model nor a
                language model.
        """
        # BaseChatModel subclasses BaseLanguageModel, so the more specific
        # check must come first or chat models would take the LLM branch.
        if isinstance(self.llm, BaseChatModel):
            return self.llm(messages, stop=stop).content
        elif isinstance(self.llm, BaseLanguageModel):
            # Plain LLMs take a single string prompt, not a message list.
            return self.llm(serialize_msgs(messages), stop=stop)
        else:
            raise ValueError(
                f"Invalid llm type: {type(self.llm)}. Must be a chat model or language model."
            )

    @classmethod
    def of(cls, llm):
        """Return ``llm`` unchanged if it is already a chat model, else wrap it.

        Args:
            llm: A chat model (returned as-is) or a plain language model
                (wrapped in a facade).

        Raises:
            ValueError: If ``llm`` is neither a chat model nor a language model.
        """
        if isinstance(llm, BaseChatModel):
            return llm
        elif isinstance(llm, BaseLanguageModel):
            # Fix: SimpleChatModel is a pydantic model, which rejects
            # positional constructor arguments — the field must be passed
            # by keyword, otherwise this raises TypeError at runtime.
            return cls(llm=llm)
        else:
            raise ValueError(
                f"Invalid llm type: {type(llm)}. Must be a chat model or language model."
            )