Skip to content

Commit

Permalink
Add _type for all parsers (langchain-ai#4189)
Browse files Browse the repository at this point in the history
Used for serialization. Also add test that recurses through
our subclasses to check they have them implemented

Would fix langchain-ai#3217
Blocking: mlflow/mlflow#8297

---------

Signed-off-by: Sunish Sheth <sunishsheth2009@gmail.com>
Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
  • Loading branch information
2 people authored and EandrewJones committed May 12, 2023
1 parent eed598b commit b53c8f5
Show file tree
Hide file tree
Showing 14 changed files with 106 additions and 8 deletions.
4 changes: 4 additions & 0 deletions langchain/agents/chat/output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@ def parse(self, text: str) -> Union[AgentAction, AgentFinish]:

except Exception:
raise OutputParserException(f"Could not parse LLM output: {text}")

@property
def _type(self) -> str:
return "chat"
4 changes: 4 additions & 0 deletions langchain/agents/conversational/output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@ def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
action = match.group(1)
action_input = match.group(2)
return AgentAction(action.strip(), action_input.strip(" ").strip('"'), text)

@property
def _type(self) -> str:
return "conversational"
4 changes: 4 additions & 0 deletions langchain/agents/conversational_chat/output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,7 @@ def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
return AgentFinish({"output": action_input}, text)
else:
return AgentAction(action, action_input, text)

@property
def _type(self) -> str:
return "conversational_chat"
4 changes: 4 additions & 0 deletions langchain/agents/mrkl/output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,7 @@ def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
action = match.group(1).strip()
action_input = match.group(2)
return AgentAction(action, action_input.strip(" ").strip('"'), text)

@property
def _type(self) -> str:
return "mrkl"
4 changes: 4 additions & 0 deletions langchain/agents/react/output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@ def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
return AgentFinish({"output": action_input}, text)
else:
return AgentAction(action, action_input, text)

@property
def _type(self) -> str:
return "react"
4 changes: 4 additions & 0 deletions langchain/agents/self_ask_with_search/output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,7 @@ def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
if " " == after_colon[0]:
after_colon = after_colon[1:]
return AgentAction("Intermediate Answer", after_colon, text)

@property
def _type(self) -> str:
return "self_ask"
8 changes: 8 additions & 0 deletions langchain/agents/structured_chat/output_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
except Exception as e:
raise OutputParserException(f"Could not parse LLM output: {text}") from e

@property
def _type(self) -> str:
return "structured_chat"


class StructuredChatOutputParserWithRetries(AgentOutputParser):
base_parser: AgentOutputParser = Field(default_factory=StructuredChatOutputParser)
Expand Down Expand Up @@ -76,3 +80,7 @@ def from_llm(
return cls(base_parser=base_parser)
else:
return cls()

@property
def _type(self) -> str:
return "structured_chat_with_retries"
4 changes: 4 additions & 0 deletions langchain/chains/api/openapi/requests_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def parse(self, llm_output: str) -> str:
return f"MESSAGE: {message_match.group(1).strip()}"
return "ERROR making request"

@property
def _type(self) -> str:
return "api_requester"


class APIRequesterChain(LLMChain):
"""Get the request parser."""
Expand Down
4 changes: 4 additions & 0 deletions langchain/chains/api/openapi/response_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def parse(self, llm_output: str) -> str:
else:
raise ValueError(f"No response found in output: {llm_output}.")

@property
def _type(self) -> str:
return "api_responder"


class APIResponderChain(LLMChain):
"""Get the response parser."""
Expand Down
4 changes: 4 additions & 0 deletions langchain/chains/llm_bash/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def get_code_blocks(t: str) -> List[str]:

return code_blocks

@property
def _type(self) -> str:
return "bash"


PROMPT = PromptTemplate(
input_variables=["question"],
Expand Down
2 changes: 1 addition & 1 deletion langchain/output_parsers/fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ def get_format_instructions(self) -> str:

@property
def _type(self) -> str:
return self.parser._type
return "output_fixing"
6 changes: 5 additions & 1 deletion langchain/output_parsers/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def get_format_instructions(self) -> str:

@property
def _type(self) -> str:
return self.parser._type
return "retry"


class RetryWithErrorOutputParser(BaseOutputParser[T]):
Expand Down Expand Up @@ -122,3 +122,7 @@ def parse(self, completion: str) -> T:

def get_format_instructions(self) -> str:
return self.parser.get_format_instructions()

@property
def _type(self) -> str:
return "retry_with_error"
15 changes: 9 additions & 6 deletions langchain/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,25 +227,25 @@ class BaseChatMessageHistory(ABC):
class FileChatMessageHistory(BaseChatMessageHistory):
storage_path: str
session_id: str
@property
def messages(self):
with open(os.path.join(storage_path, session_id), 'r:utf-8') as f:
messages = json.loads(f.read())
return messages_from_dict(messages)
return messages_from_dict(messages)
def add_user_message(self, message: str):
message_ = HumanMessage(content=message)
messages = self.messages.append(_message_to_dict(_message))
with open(os.path.join(storage_path, session_id), 'w') as f:
json.dump(f, messages)
def add_ai_message(self, message: str):
message_ = AIMessage(content=message)
messages = self.messages.append(_message_to_dict(_message))
with open(os.path.join(storage_path, session_id), 'w') as f:
json.dump(f, messages)
def clear(self):
with open(os.path.join(storage_path, session_id), 'w') as f:
f.write("[]")
Expand Down Expand Up @@ -348,7 +348,10 @@ def get_format_instructions(self) -> str:
@property
def _type(self) -> str:
"""Return the type key."""
raise NotImplementedError
raise NotImplementedError(
f"_type property is not implemented in class {self.__class__.__name__}."
" This is required for serialization."
)

def dict(self, **kwargs: Any) -> Dict:
"""Return dictionary representation of output parser."""
Expand Down
47 changes: 47 additions & 0 deletions tests/unit_tests/output_parsers/test_base_output_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Test the BaseOutputParser class and its sub-classes."""
from abc import ABC
from typing import List, Optional, Set, Type

import pytest

from langchain.schema import BaseOutputParser


def non_abstract_subclasses(
cls: Type[ABC], to_skip: Optional[Set] = None
) -> List[Type]:
"""Recursively find all non-abstract subclasses of a class."""
_to_skip = to_skip or set()
subclasses = []
for subclass in cls.__subclasses__():
if not getattr(subclass, "__abstractmethods__", None):
if subclass.__name__ not in _to_skip:
subclasses.append(subclass)
subclasses.extend(non_abstract_subclasses(subclass, to_skip=_to_skip))
return subclasses


_PARSERS_TO_SKIP = {"FakeOutputParser", "BaseOutputParser"}
_NON_ABSTRACT_PARSERS = non_abstract_subclasses(
BaseOutputParser, to_skip=_PARSERS_TO_SKIP
)


@pytest.mark.parametrize("cls", _NON_ABSTRACT_PARSERS)
def test_subclass_implements_type(cls: Type[BaseOutputParser]) -> None:
try:
cls._type
except NotImplementedError:
pytest.fail(f"_type property is not implemented in class {cls.__name__}")


def test_all_subclasses_implement_unique_type() -> None:
types = []
for cls in _NON_ABSTRACT_PARSERS:
try:
types.append(cls._type)
except NotImplementedError:
# This is handled in the previous test
pass
dups = set([t for t in types if types.count(t) > 1])
assert not dups, f"Duplicate types: {dups}"

0 comments on commit b53c8f5

Please sign in to comment.