Skip to content

Commit

Permalink
STF-IDF for LangChain
Browse files Browse the repository at this point in the history
  • Loading branch information
artitw committed Jul 9, 2023
1 parent 56e1b35 commit c976ae4
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 1 deletion.
75 changes: 75 additions & 0 deletions langchain/stfidf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""STF-IDF Retriever.
Based on https://github.com/artitw/text2text"""

from __future__ import annotations

from typing import Any, Dict, Iterable, List, Optional

from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.schema import BaseRetriever, Document


class STFIDFRetriever(BaseRetriever):
    """Retriever that answers queries through a text2text ``Indexer`` index.

    The position of each entry in ``docs`` is used as the document id that
    ``index.search`` returns, so ``docs`` and ``index`` must be kept in step.

    NOTE(review): ``from_texts`` builds the instance with keyword arguments,
    which presumably relies on ``BaseRetriever`` being a pydantic model in
    the installed LangChain version -- confirm.
    """

    # Search index; built by ``t2t.Indexer().transform`` in ``from_texts``.
    index: Any
    # Documents aligned by position with the ids returned by ``index.search``.
    docs: List[Document]
    # Number of results requested from the index per query.
    k: int = 4

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @classmethod
    def from_texts(
        cls,
        texts: Iterable[str],
        metadatas: Optional[Iterable[dict]] = None,
        **kwargs: Any,
    ) -> STFIDFRetriever:
        """Build a retriever by indexing raw strings.

        Args:
            texts: Strings to index. Any iterable is accepted, including
                one-shot iterators.
            metadatas: Optional metadata dicts, paired with ``texts`` by
                position. When missing or empty, every document gets ``{}``.
            **kwargs: Forwarded to the class constructor (for example ``k``).

        Returns:
            A retriever holding the new index and one ``Document`` per text.

        Raises:
            ImportError: If the ``text2text`` package is not installed.
        """
        try:
            import text2text as t2t
        except ImportError as err:
            raise ImportError(
                "Could not import text2text, please install with `pip install "
                "text2text`."
            ) from err

        # ``texts`` is consumed by the indexer and again when building the
        # documents; materialize it so a generator is not exhausted after
        # the first pass.
        text_list = list(texts)
        metadata_list = list(metadatas) if metadatas is not None else []
        if not metadata_list:
            metadata_list = [{} for _ in text_list]

        index = t2t.Indexer().transform(text_list)
        docs = [
            Document(page_content=text, metadata=metadata)
            for text, metadata in zip(text_list, metadata_list)
        ]
        return cls(index=index, docs=docs, **kwargs)

    @classmethod
    def from_documents(
        cls,
        documents: Iterable[Document],
        *,
        tfidf_params: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> STFIDFRetriever:
        """Build a retriever from existing ``Document`` objects.

        Args:
            documents: Documents whose ``page_content`` is indexed and whose
                ``metadata`` is carried over.
            tfidf_params: Accepted for signature compatibility; it is not
                used by this implementation.
            **kwargs: Forwarded to ``from_texts``.

        Returns:
            A retriever built through ``from_texts``.
        """
        # Two comprehensions instead of ``zip(*...)`` so that an empty input
        # does not fail on tuple unpacking.
        document_list = list(documents)
        texts = [doc.page_content for doc in document_list]
        metadatas = [doc.metadata for doc in document_list]
        return cls.from_texts(texts=texts, metadatas=metadatas, **kwargs)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return the documents the index ranks for ``query``.

        Args:
            query: Search string.
            run_manager: Callback manager for the run; not used here.

        Returns:
            Up to ``k`` documents, in the order the index returned them.
        """
        _, pred_ids = self.index.search([query], k=self.k)
        # Negative ids are skipped; presumably the index pads missing hits
        # with -1 -- TODO confirm against the text2text Indexer.
        return [self.docs[i] for i in pred_ids[0] if i >= 0]

    async def _aget_relevant_documents(
        self, query: str, *, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Asynchronous retrieval is not implemented.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    async def aadd_documents(
        self, documents: List[Document], **kwargs: Any
    ) -> List[str]:
        """Add documents to the index and to ``docs``.

        Args:
            documents: Documents to append.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The positions of the added documents within ``docs``, as strings.
        """
        new_docs = list(documents)
        if not new_docs:
            return []

        texts = [doc.page_content for doc in new_docs]
        # Update the index first so a failure there does not leave ``docs``
        # holding entries the index does not know about.
        self.index.add(texts)
        first_position = len(self.docs)
        self.docs += new_docs
        return [
            str(position)
            for position in range(first_position, first_position + len(new_docs))
        ]
30 changes: 30 additions & 0 deletions langchain/test_stfidf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import pytest

from text2text.langchain.stfidf import STFIDFRetriever
from langchain.schema import Document


@pytest.mark.requires("langchain")
def test_from_texts() -> None:
    """A retriever built from three strings holds three documents."""
    corpus = ["I have a pen.", "Do you have a pen?", "I have a bag."]
    retriever = STFIDFRetriever.from_texts(texts=corpus)
    document_count = len(retriever.docs)
    assert document_count == 3


@pytest.mark.requires("langchain")
def test_retrieval_with_stfidf_params() -> None:
    """A retriever created with ``k=2`` returns two documents for a query."""
    input_texts = ["I have a pen.", "Do you have a pen?", "I have a bag."]
    stfidf_retriever = STFIDFRetriever.from_texts(texts=input_texts, k=2)
    # Go through the public entry point: ``_get_relevant_documents`` declares
    # ``run_manager`` as a required keyword-only argument, so calling it with
    # only the query raises TypeError.
    results = stfidf_retriever.get_relevant_documents("pen")
    assert len(results) == 2

@pytest.mark.requires("langchain")
def test_from_documents() -> None:
    """A retriever built from three ``Document`` objects keeps all three."""
    contents = ("I have a pen.", "Do you have a pen?", "I have a bag.")
    source_docs = [Document(page_content=content) for content in contents]
    retriever = STFIDFRetriever.from_documents(documents=source_docs)
    document_count = len(retriever.docs)
    assert document_count == 3
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="text2text",
version="1.2.4",
version="1.2.5",
author="Artit Wangperawong",
author_email="artitw@gmail.com",
description="Text2Text: Crosslingual NLP/G toolkit",
Expand All @@ -26,6 +26,7 @@
'peft',
'faiss-cpu',
'flask',
'langchain',
'googledrivedownloader',
'numpy',
'pandas',
Expand Down

0 comments on commit c976ae4

Please sign in to comment.