In [None]:
from langchain.vectorstores import FAISS
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
import pandas as pd
import numpy as np
import os
import json
# Default search keyword arguments; Retriever forwards this mapping as
# `search_kwargs` when it turns the vector store into a retriever.
# NOTE: the identifier is misspelled ("defaut") but is kept unchanged because
# other code refers to it by this exact name.
defaut_search_config = dict(k=1)
class Retriever:
    """Two-stage retriever over a list of source texts.

    Each source text is split into token-bounded chunks, every chunk is stored
    as a ``Document`` tagged with the index of the text it came from, and the
    chunks are indexed in a FAISS vector store. ``search_main_document`` picks
    the source text of the best-matching chunk; ``search_chunks`` then ranks
    that text's chunks against a query by cosine similarity.

    Attributes:
        text_list: The source texts, in the order that defines their ids.
        embedder: Object providing ``embed_documents`` and ``embed_query``.
        main_tokenizer: Object providing ``tokenize(text)``; chunk length is
            measured as the number of tokens it returns.
        search_config: Mapping passed as ``search_kwargs`` to the retriever.
        num_token: Maximum chunk length, in tokens, used when building.
        corpus: List of chunk ``Document`` objects.
        db: The FAISS vector store built from ``corpus``.
        retriever: Retriever view of ``db``.
    """

    # Chunk size (in tokens) used when the caller does not supply one. The
    # original code provided no value at all, so this default is a choice made
    # here. NOTE(review): the splitter's default chunk overlap presumably has
    # to stay below this value -- confirm before lowering it.
    DEFAULT_NUM_TOKEN = 512

    def __init__(self, text_list, embedder, main_tokenizer, search_config=None,
                 num_token=DEFAULT_NUM_TOKEN):
        """Store the collaborators and build the chunk corpus and index.

        Args:
            text_list: Sequence of source texts to index.
            embedder: Embedding backend (see class docstring).
            main_tokenizer: Tokenizer used to measure chunk length.
            search_config: Retriever search kwargs. When omitted, a private
                copy of ``defaut_search_config`` is used so that instances do
                not share (and accidentally mutate) the module-level dict.
            num_token: Maximum number of tokens per chunk.
        """
        self.text_list = text_list
        self.embedder = embedder
        self.main_tokenizer = main_tokenizer
        if search_config is None:
            search_config = dict(defaut_search_config)
        self.search_config = search_config
        self.num_token = num_token
        self.corpus = None
        self.db = None
        self.retriever = None
        self._build()

    def _build_corpus(self, num_token=None):
        """Split every source text into chunks and store them in ``self.corpus``.

        Args:
            num_token: Maximum tokens per chunk; falls back to
                ``self.num_token`` when omitted.
        """
        if num_token is None:
            num_token = self.num_token
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=num_token,
            # Measure length in tokenizer tokens rather than characters.
            length_function=lambda x: len(self.main_tokenizer.tokenize(x)),
        )
        self.corpus = []
        for i, text in enumerate(self.text_list):
            # Document fields are passed by keyword; the 'id' metadata links
            # each chunk back to its position in text_list.
            self.corpus.extend(
                Document(page_content=text_chunk, metadata={'id': i})
                for text_chunk in text_splitter.split_text(text)
            )

    def _build(self):
        """Build the corpus, the FAISS index over it, and the retriever."""
        self._build_corpus(self.num_token)
        self.db = FAISS.from_documents(self.corpus, self.embedder)
        self.retriever = self.db.as_retriever(search_kwargs=self.search_config)

    def save_local(self, path):
        """Persist the retriever state under the directory ``path``.

        Writes ``text_list.csv``, ``corpus.csv``, ``search_config.json`` and a
        ``db`` entry holding the FAISS index. The directory is created if
        needed.
        """
        os.makedirs(path, exist_ok=True)
        text_list_df = pd.DataFrame(self.text_list)
        corpus_df = pd.DataFrame(
            [
                {
                    'page_content': doc.page_content,
                    'metadata': doc.metadata,
                }
                for doc in self.corpus
            ]
        )
        text_list_df.to_csv(os.path.join(path, 'text_list.csv'))
        corpus_df.to_csv(os.path.join(path, 'corpus.csv'))
        self.db.save_local(os.path.join(path, 'db'))
        with open(os.path.join(path, 'search_config.json'), 'w', encoding='utf-8') as f:
            json.dump(self.search_config, f)

    @classmethod
    def load_local(cls, path, embedder, main_tokenizer):
        """Recreate a retriever from a directory written by ``save_local``.

        The saved corpus and FAISS index are loaded as-is; nothing is re-split
        or re-embedded.

        Args:
            path: Directory previously passed to ``save_local``.
            embedder: Embedding backend to attach to the loaded index.
            main_tokenizer: Tokenizer to attach to the loaded instance.

        Returns:
            A ready-to-use instance of this class.

        Raises:
            ValueError: If any of the expected files is missing.
        """
        import ast  # function-scope import keeps this block self-contained

        text_list_path = os.path.join(path, 'text_list.csv')
        corpus_path = os.path.join(path, 'corpus.csv')
        db_path = os.path.join(path, 'db')
        config_path = os.path.join(path, 'search_config.json')

        # A distinct loop variable is used so that `path` is not overwritten.
        for required_path in (text_list_path, corpus_path, db_path, config_path):
            if not os.path.exists(required_path):
                raise ValueError(f'Path {required_path} does not exist')

        with open(config_path, 'r', encoding='utf-8') as f:
            search_config = json.load(f)

        # keep_default_na=False stops pandas from turning text such as "NA",
        # "null" or an empty chunk into NaN.
        text_list_df = pd.read_csv(text_list_path, keep_default_na=False)
        corpus_df = pd.read_csv(corpus_path, keep_default_na=False)
        text_list = text_list_df['0'].tolist()

        # The metadata column holds the repr of a dict. literal_eval parses
        # Python literals only, so a tampered file cannot execute code the way
        # the previous eval() call allowed.
        corpus = [
            Document(page_content=content, metadata=ast.literal_eval(metadata))
            for content, metadata in zip(corpus_df['page_content'], corpus_df['metadata'])
        ]

        # NOTE(review): some LangChain releases require an explicit opt-in
        # keyword for pickle-based loading here -- confirm against the
        # installed version.
        db = FAISS.load_local(db_path, embedder)

        # Bypass __init__, which would split and embed the whole corpus again
        # only for the result to be replaced by the loaded state.
        retriever = cls.__new__(cls)
        retriever.text_list = text_list
        retriever.embedder = embedder
        retriever.main_tokenizer = main_tokenizer
        retriever.search_config = search_config
        # The chunk size used for the saved corpus is not persisted; the class
        # default only matters if the corpus is rebuilt later.
        retriever.num_token = cls.DEFAULT_NUM_TOKEN
        retriever.corpus = corpus
        retriever.db = db
        retriever.retriever = db.as_retriever(search_kwargs=search_config)
        return retriever

    def search_main_document(self, query):
        """Find the source text whose chunk best matches ``query``.

        Takes the first document returned by the retriever, gathers every
        chunk in the corpus with the same ``id`` metadata, and embeds those
        chunks.

        Args:
            query: The query string.

        Returns:
            dict with keys ``'id'`` (source text index), ``'chunk_texts'``
            (list of chunk strings) and ``'chunk_embeddings'`` (the result of
            ``embedder.embed_documents`` for those chunks).

        Raises:
            IndexError: If the retriever returns no documents.
        """
        candidate_doc = self.retriever.get_relevant_documents(query)[0]
        id_ = candidate_doc.metadata['id']
        candidate_chunks = [
            doc.page_content for doc in self.corpus if doc.metadata['id'] == id_
        ]
        temp_embeddings = self.embedder.embed_documents(candidate_chunks)
        return {
            'id': id_,
            'chunk_texts': candidate_chunks,
            'chunk_embeddings': temp_embeddings,
        }

    def search_chunks(self, main_doc, query, k=2):
        """Return the ``k`` chunks of ``main_doc`` most similar to ``query``.

        Similarity is the cosine between the query embedding and each chunk
        embedding.

        Args:
            main_doc: Mapping with ``'chunk_texts'`` and ``'chunk_embeddings'``
                as produced by ``search_main_document``.
            query: The query string.
            k: Maximum number of chunks to return.

        Returns:
            List of chunk texts, best match first. Empty when ``main_doc``
            holds no chunks.
        """
        chunk_texts = main_doc['chunk_texts']
        chunk_embeddings = np.asarray(main_doc['chunk_embeddings'], dtype=float)
        if chunk_embeddings.size == 0:
            return []

        q_embedding = np.asarray(self.embedder.embed_query(query), dtype=float)
        dots = np.dot(chunk_embeddings, q_embedding)
        denom = np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(q_embedding)
        # A zero-norm vector would produce NaN, which argsort places last and
        # the reversal below would then rank first; score such chunks as 0.
        scores = np.divide(dots, denom, out=np.zeros_like(dots), where=denom > 0)
        top_k = np.argsort(scores)[::-1][:k]
        return [chunk_texts[i] for i in top_k]

In [5]:
import os
# Scratch cell: create 'hello_dir' relative to the current working directory;
# exist_ok=True makes the call a no-op when the directory is already there.
os.makedirs(name='hello_dir', exist_ok=True)