In [2]:
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter

In [3]:
# Inspect the Document class. Per the signature printed below, its constructor
# arguments (page_content, metadata) are keyword-only.
help(Document)

Help on class Document in module langchain.schema:

class Document(langchain.load.serializable.Serializable)
 |  Document(*, lc_kwargs: Dict[str, Any] = None, page_content: str, metadata: dict = None) -> None
 |  
 |  Interface for interacting with a document.
 |  
 |  Method resolution order:
 |      Document
 |      langchain.load.serializable.Serializable
 |      pydantic.main.BaseModel
 |      pydantic.utils.Representation
 |      abc.ABC
 |      builtins.object
 |  
 |  Static methods defined here:
 |  
 |  __json_encoder__ = pydantic_encoder(obj: Any) -> Any
 |  
 |  ----------------------------------------------------------------------
 |  Data and other attributes defined here:
 |  
 |  __abstractmethods__ = frozenset()
 |  
 |  __annotations__ = {'metadata': 'dict', 'page_content': 'str'}
 |  
 |  __class_vars__ = set()
 |  
 |  __config__ = <class 'langchain.schema.Config'>
 |  
 |  __custom_root_type__ = False
 |  
 |  __exclude_fields__ = {'lc_kwargs': True}
 |  
 |  __fiel

In [None]:
import numpy as np
# Default search settings used by Retriever, which hands them to the vector
# store retriever as `search_kwargs`; `k` is the number of documents requested.
# The (misspelled) name is kept because Retriever refers to it by this name.
defaut_search_config = dict(k=1)
class Retriever:
    """Two-stage retriever over a list of raw texts.

    Stage 1 (`search_main_document`) uses a FAISS index over token-bounded
    chunks to pick the source text whose chunk best matches the query.
    Stage 2 (`search_chunks`) ranks the chunks of that source text against a
    (possibly different) query by dot product of their embeddings.

    The index is not built by the constructor; call `_build()` (or
    `load_local()`) before searching.
    """

    # Chunk size (in tokens) used by `_build` when the caller does not pass one.
    DEFAULT_NUM_TOKEN = 256

    def __init__(self, text_list, embedder, main_tokenizer, search_config=None):
        """Store the inputs; no indexing happens here.

        Args:
            text_list: sequence of raw text strings; the position of each text
                becomes the `id` stored in its chunks' metadata.
            embedder: object exposing `embed_documents(list[str])` and
                `embed_query(str)`; also passed to FAISS.
            main_tokenizer: object exposing `tokenize(str)`; used only to
                measure chunk length in tokens.
            search_config: dict forwarded as `search_kwargs` to the FAISS
                retriever. Defaults to a copy of `defaut_search_config`.
        """
        self.text_list = text_list
        self.embedder = embedder
        self.main_tokenizer = main_tokenizer
        # Copy so that neither the module-level default nor the caller's dict
        # is shared between instances.
        if search_config is None:
            search_config = defaut_search_config
        self.search_config = dict(search_config)
        self.corpus = []
        self.db = None
        self.retriever = None

    def _build_corpus(self, num_token, chunk_overlap=None):
        """Split every text into chunks of at most `num_token` tokens.

        Args:
            num_token: maximum chunk length, measured with `main_tokenizer`.
            chunk_overlap: overlap between consecutive chunks, in tokens. When
                omitted, the splitter's default of 200 is used but capped at
                half the chunk size; the splitter rejects an overlap larger
                than the chunk size, which made small `num_token` values fail.
        """
        if chunk_overlap is None:
            chunk_overlap = min(200, num_token // 2)

        def token_length(text):
            return len(self.main_tokenizer.tokenize(text))

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=num_token,
            chunk_overlap=chunk_overlap,
            length_function=token_length,
        )
        corpus = []
        for i, text in enumerate(self.text_list):
            # Document only accepts keyword arguments.
            corpus.extend(
                Document(page_content=text_chunk, metadata={'id': i})
                for text_chunk in text_splitter.split_text(text)
            )
        self.corpus = corpus

    def _build(self, num_token=None, chunk_overlap=None):
        """Chunk the texts, index them in FAISS and create the retriever.

        Args:
            num_token: chunk size in tokens; defaults to `DEFAULT_NUM_TOKEN`.
            chunk_overlap: forwarded to `_build_corpus`.

        Raises:
            ValueError: if splitting produced no chunks (nothing to index).
        """
        if num_token is None:
            num_token = self.DEFAULT_NUM_TOKEN
        self._build_corpus(num_token, chunk_overlap)
        if not self.corpus:
            raise ValueError("Cannot build an index: text_list produced no chunks.")
        self.db = FAISS.from_documents(self.corpus, self.embedder)
        self.retriever = self.db.as_retriever(search_kwargs=self.search_config)

    def save_local(self, path):
        """Persist the FAISS index to `path`.

        Raises:
            RuntimeError: if the index has not been built or loaded yet.
        """
        if self.db is None:
            raise RuntimeError("Index not built; call _build() or load_local() first.")
        self.db.save_local(path)

    def load_local(self, path):
        """Load a FAISS index previously written by `save_local`.

        Also recreates the retriever and rebuilds `self.corpus` from the
        loaded docstore so that `search_main_document` works afterwards.
        """
        # NOTE(review): FAISS.load_local unpickles the stored docstore, so only
        # load paths you trust.
        self.db = FAISS.load_local(path, self.embedder)
        self.retriever = self.db.as_retriever(search_kwargs=self.search_config)
        # Walk the index positions in order so chunks keep their insertion
        # order; docstore.search returns a string instead of a Document when
        # an id is missing, hence the isinstance filter.
        id_map = self.db.index_to_docstore_id
        found = (self.db.docstore.search(id_map[pos]) for pos in sorted(id_map))
        self.corpus = [doc for doc in found if isinstance(doc, Document)]

    def search_main_document(self, query):
        """Find the source text that best matches `query`.

        Returns:
            dict with keys `id` (index of the source text in `text_list`),
            `chunk_texts` (all chunks of that text) and `chunk_embeddings`
            (their embeddings, from `embedder.embed_documents`). If the
            retriever returns nothing, `id` is None and both lists are empty.

        Raises:
            RuntimeError: if the index has not been built or loaded yet.
        """
        if self.retriever is None:
            raise RuntimeError("Index not built; call _build() or load_local() first.")
        candidates = self.retriever.get_relevant_documents(query)
        if not candidates:
            return {'id': None, 'chunk_texts': [], 'chunk_embeddings': []}
        id_ = candidates[0].metadata['id']
        candidate_chunks = [
            doc.page_content for doc in self.corpus if doc.metadata['id'] == id_
        ]
        temp_embeddings = self.embedder.embed_documents(candidate_chunks)
        return {
            'id': id_,
            'chunk_texts': candidate_chunks,
            'chunk_embeddings': temp_embeddings
        }

    def search_chunks(self, main_doc, query, k=2):
        """Return up to `k` chunks of `main_doc` ranked by similarity to `query`.

        Args:
            main_doc: dict as returned by `search_main_document`.
            query: text embedded with `embedder.embed_query`.
            k: maximum number of chunks to return.

        Returns:
            list of chunk texts, highest dot-product score first; empty when
            `main_doc` has no chunks.
        """
        chunk_texts = main_doc['chunk_texts']
        chunk_embeddings = main_doc['chunk_embeddings']
        if len(chunk_texts) == 0:
            # np.dot cannot align an empty chunk matrix with the query vector.
            return []
        q_embedding = self.embedder.embed_query(query)
        scores = np.dot(chunk_embeddings, q_embedding)
        top_k = np.argsort(scores)[::-1][:k]
        return [chunk_texts[i] for i in top_k]
    

    


