Skip to content

Commit

Permalink
Merge pull request #41 from artitw/langchain
Browse files Browse the repository at this point in the history
LangChain
  • Loading branch information
artitw committed Jul 9, 2023
2 parents 56e1b35 + d13b284 commit 37ccfad
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 1 deletion.
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="text2text",
version="1.2.4",
version="1.2.5",
author="Artit Wangperawong",
author_email="artitw@gmail.com",
description="Text2Text: Crosslingual NLP/G toolkit",
Expand All @@ -26,6 +26,7 @@
'peft',
'faiss-cpu',
'flask',
'langchain',
'googledrivedownloader',
'numpy',
'pandas',
Expand Down
75 changes: 75 additions & 0 deletions text2text/langchain/stfidf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""STF-IDF Retriever.
Based on https://github.com/artitw/text2text"""

from __future__ import annotations

from typing import Any, Dict, Iterable, List, Optional

from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document


class STFIDFRetriever(BaseRetriever):
index: Any
docs: List[Document]
k: int = 4

class Config:
"""Configuration for this pydantic object."""

arbitrary_types_allowed = True

@classmethod
def from_texts(
cls,
texts: Iterable[str],
metadatas: Optional[Iterable[dict]] = None,
**kwargs: Any,
) -> STFIDFRetriever:
try:
import text2text as t2t
except ImportError:
raise ImportError(
"Could not import text2text, please install with `pip install "
"text2text`."
)

index = t2t.Indexer().transform(texts)
metadatas = metadatas or ({} for _ in texts)
docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
return cls(index=index, docs=docs, **kwargs)

@classmethod
def from_documents(
cls,
documents: Iterable[Document],
*,
tfidf_params: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> STFIDFRetriever:
texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
return cls.from_texts(
texts=texts, metadatas=metadatas, **kwargs
)

def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
distances, pred_ids = self.index.search([query], k=self.k)
return [self.docs[i] for i in pred_ids[0] if i >= 0]

async def _aget_relevant_documents(
self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
) -> List[Document]:
raise NotImplementedError

async def aadd_documents(
self, documents: List[Document], **kwargs: Any
) -> List[str]:
texts, metadatas = zip(*((d.page_content, d.metadata) for d in documents))
self.docs += documents
self.index.add(texts)
31 changes: 31 additions & 0 deletions text2text/langchain/test_stfidf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import pytest

from text2text.langchain.stfidf import STFIDFRetriever
from langchain.schema import Document


@pytest.mark.requires("langchain")
def test_from_texts() -> None:
input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]
stfidf_retriever = STFIDFRetriever.from_texts(texts=input_texts)
assert len(stfidf_retriever.docs) == 3


@pytest.mark.requires("langchain")
def test_retrieval() -> None:
input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]
stfidf_retriever = STFIDFRetriever.from_texts(
texts=input_texts
)
assert len(stfidf_retriever.index.retrieve(["pen"], k=2)[0]) == 2


@pytest.mark.requires("langchain")
def test_from_documents() -> None:
input_docs = [
Document(page_content="I have a pen."),
Document(page_content="Do you have a pen?"),
Document(page_content="I have a bag."),
]
tfidf_retriever = STFIDFRetriever.from_documents(documents=input_docs)
assert len(tfidf_retriever.docs) == 3
10 changes: 10 additions & 0 deletions text2text/langchain/test_text2text_assistant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import pytest

from text2text.langchain.text2text_assistant import Text2TextAssistant

@pytest.mark.requires("langchain")
def test_llm_inference() -> None:
input_text = 'Say "hello, world" back to me'
llm = Text2TextAssistant()
result = llm(input_text)
assert "hello" in result.lower()
28 changes: 28 additions & 0 deletions text2text/langchain/text2text_assistant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Any, List, Mapping, Optional

import text2text as t2t
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM

class Text2TextAssistant(LLM):
model: t2t.Assistant = t2t.Assistant()

@property
def _llm_type(self) -> str:
return "Text2Text"

def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs
) -> str:
if stop is not None:
raise ValueError("stop kwargs are not permitted.")
return self.model.transform([prompt], **kwargs)[0]

@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
return {"type": self._llm_type}

0 comments on commit 37ccfad

Please sign in to comment.