In [1]:
import pickle

In [4]:
import logging
import pickle
import re
from pathlib import Path
from typing import Any, Dict, List, Optional

import faiss
import tiktoken
from pydantic import BaseModel
from langchain.callbacks.manager import AsyncCallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import LLMChain
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chat_models import ChatOpenAI
from langchain.docstore.document import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import BaseRetriever
from langchain.text_splitter import TokenTextSplitter, RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS

logger = logging.getLogger(__name__)
# A fresh logger is NOTSET and therefore inherits the root logger's WARNING
# level; set INFO explicitly so `logger.info(...)` records reach the handler.
logger.setLevel(logging.INFO)

# Formatter
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

# Stream handler, tagged with a name so it can be recognised on re-runs.
stream_handler = logging.StreamHandler()
stream_handler.set_name("notebook_stream_handler")
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(formatter)

# Re-running this cell returns the same logger object; drop the handler added
# by a previous run first, otherwise every record would be printed twice.
for previous_handler in [h for h in logger.handlers if h.get_name() == stream_handler.get_name()]:
    logger.removeHandler(previous_handler)
logger.addHandler(stream_handler)


In [5]:
class CustomeSplitter:
    """Split documents that are too long for the downstream model into token chunks.

    A document whose token count exceeds ``chunk_threshold`` is split with a
    ``TokenTextSplitter``; shorter documents are passed through unchanged.
    Each chunk keeps its parent's metadata, with ``source`` rewritten to
    ``"<source> chunk <n>"`` where ``n`` is the chunk's index within that
    document.

    Args:
        chunk_threshold: Token count above which a document is split.
        chunk_size: Maximum number of tokens per chunk.
        chunk_overlap: Number of tokens shared by consecutive chunks.
        encoding_name: tiktoken encoding used for both counting and splitting.
    """

    def __init__(self, chunk_threshold=6000, chunk_size=6000, chunk_overlap=50, encoding_name="cl100k_base"):
        self.chunk_threshold = chunk_threshold
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.enc = tiktoken.get_encoding(encoding_name)
        # Split with the same encoding used for counting; TokenTextSplitter
        # otherwise defaults to "gpt2", so chunk sizes would be measured in a
        # different token unit than the threshold. Special-token text (e.g.
        # "<|endoftext|>") is treated as ordinary text instead of raising.
        self.splitter = TokenTextSplitter(
            encoding_name=encoding_name,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            disallowed_special=(),
        )

    def token_counter(self, document):
        """Return the number of tokens in ``document.page_content``."""
        # disallowed_special=() keeps documents that mention special tokens
        # from raising, so they are still counted (and split if needed).
        return len(self.enc.encode(document.page_content, disallowed_special=()))

    def split(self, documents):
        """Return ``documents`` with every over-threshold document replaced by its chunks.

        Documents that fail to split are kept whole and the error is printed
        (best effort, so one bad document does not abort the whole corpus).
        """
        chunked_documents = []
        for doc_idx, doc in enumerate(documents):
            try:
                if self.token_counter(doc) <= self.chunk_threshold:
                    chunked_documents.append(doc)
                    continue

                chunks = self.splitter.split_documents([doc])
                # Build the full list before extending so a failure part-way
                # through cannot leave a partial set of chunks behind.
                renamed_chunks = [
                    Document(
                        page_content=chunk.page_content,
                        # Keep every metadata field; only tag the source.
                        metadata={
                            **chunk.metadata,
                            "source": f"{chunk.metadata['source']} chunk {chunk_idx}",
                        },
                    )
                    for chunk_idx, chunk in enumerate(chunks)
                ]
                chunked_documents.extend(renamed_chunks)
            except Exception as e:
                chunked_documents.append(doc)
                print(f"Error on document {doc_idx}")
                print(e)
                # .get: a missing source must not crash the error report itself.
                print(doc.metadata.get("source"))

        return chunked_documents


class CustomRetriever(BaseRetriever, BaseModel):
    """Retriever over two FAISS-backed retrievers that can expand hits to full documents.

    The behaviour is chosen per call through the ``workflow`` argument of
    ``get_relevant_documents``:

    * ``1`` -- search ``base_retriever_all``, keep the first ``k_final``
      distinct ``source`` values (in relevance order) and return the entries
      of ``full_docs`` with those sources, in that same order.
    * ``2`` -- search ``base_retriever_data`` and return its top ``k_final``
      hits unchanged.
    * any other value -- search ``base_retriever_all`` and return its top
      ``k_final`` hits.

    For every workflow except ``2``, a trailing ``" chunk <n>"`` suffix is
    removed from the ``source`` of the returned documents (on copies; the
    stored documents are not modified).
    """

    # Documents returned by workflow 1, matched on metadata["source"].
    full_docs: List[Document]
    base_retriever_all: BaseRetriever = None
    base_retriever_data: BaseRetriever = None
    # Number of hits requested from each underlying retriever.
    k_initial: int = 10
    # Maximum number of documents (distinct sources, in workflow 1) returned.
    k_final: int = 4

    # Object with an ``info`` method (e.g. a logging.Logger). When None, the
    # logger named after this module is used instead.
    logger: Any = None

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @classmethod
    def from_documents(
        cls,
        full_docs: List[Document],
        vectorstore_all: FAISS,
        vectorstore_data: FAISS,
        search_kwargs: Optional[Dict[str, Any]] = None,
        k_initial: int = 10,
        k_final: int = 4,
        logger: Any = None,
        **kwargs: Any,
    ):
        """Build a retriever from two FAISS vector stores.

        Args:
            full_docs: Documents returned by workflow 1.
            vectorstore_all: Store searched by every workflow except 2.
            vectorstore_data: Store searched by workflow 2.
            search_kwargs: Extra search options passed to both stores'
                retrievers; ``k`` is always taken from ``k_initial``.
            k_initial: Number of hits requested from each store.
            k_final: Maximum documents / distinct sources returned.
            logger: Logger-like object with an ``info`` method, or None.
            **kwargs: Accepted for backward compatibility and ignored.
        """
        search_options = {**(search_kwargs or {}), "k": k_initial}
        return cls(
            full_docs=full_docs,
            # Separate dict copies so the two retrievers never share one object.
            base_retriever_all=vectorstore_all.as_retriever(search_kwargs=dict(search_options)),
            base_retriever_data=vectorstore_data.as_retriever(search_kwargs=dict(search_options)),
            k_initial=k_initial,
            k_final=k_final,
            logger=logger,
        )

    def _log(self, message: str) -> None:
        """Log ``message`` at INFO on ``self.logger``, or on the module logger if unset."""
        active_logger = self.logger if self.logger is not None else logging.getLogger(__name__)
        active_logger.info(message)

    def get_relevant_documents(self, query: str, workflow: int = 1) -> List[Document]:
        """Return documents relevant to ``query`` according to ``workflow`` (see class docstring)."""
        self._log(f"Workflow: {workflow}")

        if workflow == 2:
            results = self.base_retriever_data.get_relevant_documents(query=query)
            self._log(f"Retrieved {len(results)} documents")
            return results[:self.k_final]

        results = self.base_retriever_all.get_relevant_documents(query=query)
        self._log(f"Retrieved {len(results)} documents")

        if workflow != 1:
            return self.prepare_source(results[:self.k_final])

        # Distinct sources in relevance order (dict keys keep insertion order).
        doc_ids = list(dict.fromkeys(doc.metadata["source"] for doc in results))[:self.k_final]
        self._log(f"Retrieved {len(doc_ids)} unique documents")

        # Expand to full documents, ordered by each source's search rank rather
        # than by position in full_docs. NOTE(review): matching is by exact
        # source; if full_docs holds CustomeSplitter chunks ("<src> chunk <n>")
        # they will not match unsuffixed sources from the vector store --
        # confirm which document list is passed in.
        rank = {source: position for position, source in enumerate(doc_ids)}
        full_retrieved_docs = sorted(
            (d for d in self.full_docs if d.metadata["source"] in rank),
            key=lambda d: rank[d.metadata["source"]],
        )
        return self.prepare_source(full_retrieved_docs)

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        """Async retrieval is not supported by this retriever."""
        raise NotImplementedError

    @staticmethod
    def _base_source(source: str) -> str:
        """Strip a trailing ``" chunk <n>"`` suffix (as added by CustomeSplitter)."""
        # Anchored to the exact suffix so URLs that merely contain "chunk"
        # (e.g. ".../chunking-...") are left intact.
        return re.sub(r"\s+chunk\s+\d+\Z", "", source)

    def prepare_source(self, documents: List[Document]) -> List[Document]:
        """Return ``documents`` with chunk suffixes removed from their ``source``.

        Documents that need a change are copied rather than mutated, so the
        objects held in ``full_docs`` or a vector store's docstore stay intact.
        """
        prepared = []
        for doc in documents:
            source = doc.metadata["source"]
            base_source = self._base_source(source)
            if base_source == source:
                prepared.append(doc)
            else:
                prepared.append(
                    Document(
                        page_content=doc.page_content,
                        metadata={**doc.metadata, "source": base_source},
                    )
                )
        return prepared

In [6]:
# NOTE(review): every file is now read from a single DATA_DIR. The blog pickle
# was already loaded from "../data", while the others used the absolute path
# /Users/arshath/play/chainlink-assistant/data. This assumes both name the
# same directory (the notebook runs from a sibling folder of data/) -- confirm.
DATA_DIR = Path("../data")


def load_pickle(filename):
    """Return the object unpickled from ``DATA_DIR / filename``.

    ``pickle.load`` can execute arbitrary code, so only use it on files this
    project produced itself.
    """
    with open(DATA_DIR / filename, "rb") as f:
        return pickle.load(f)


blog_docs = load_pickle("blog_2023-08-18.pkl")
chainlink_docs = load_pickle("chain_link_main_docs_2023-08-18.pkl")
chainlink_youtube_docs = load_pickle("chain_link_you_tube_docs_2023-08-18.pkl")
stackoverflow_docs = load_pickle("stackoverflow_documents.pkl")
tech_docs = load_pickle("techdocs_2023-08-18.pkl")
# The education snapshot is dated 2023-08-14, unlike the other files.
education_docs = load_pickle("education_2023-08-14.pkl")
data_docs = load_pickle("datadocs_2023-08-18.pkl")

In [7]:
# Corpus for the general "all" index. data_docs is not included here; it is
# embedded into its own vector store in a later cell.
all_docs = (
    blog_docs
    + chainlink_docs
    + chainlink_youtube_docs
    + stackoverflow_docs
    + tech_docs
    + education_docs
)

# NOTE(review): previously written to the absolute path
# /Users/arshath/play/chainlink-assistant/data/documents.pkl; "../data" assumes
# the notebook runs next to data/, as the blog pickle path already did.
with open(Path("../data") / "documents.pkl", "wb") as f:
    pickle.dump(all_docs, f)

In [8]:

# Large token-based chunks of every document, sized for the 16k model
# (CustomeSplitter defaults: split above 6000 tokens into 6000-token chunks).
# NOTE(review): chunked_full_documents is neither saved nor used in the cells
# shown here; persist it if the retriever's full_docs should be these chunks.
full_doc_splitter = CustomeSplitter()
chunked_full_documents = full_doc_splitter.split(all_docs)

# Small character-based chunks used for embedding and similarity search.
SEARCH_CHUNK_SIZE = 1200
SEARCH_CHUNK_OVERLAP = 50
splitter = RecursiveCharacterTextSplitter(
    chunk_size=SEARCH_CHUNK_SIZE,
    chunk_overlap=SEARCH_CHUNK_OVERLAP,
)

# A single embeddings client serves both vector stores.
embeddings = OpenAIEmbeddings()

# Vector store over all documents except the data docs.
split_docs = splitter.split_documents(all_docs)
vectorstore_all = FAISS.from_documents(split_docs, embedding=embeddings)

# Separate vector store over the data docs only.
split_docs_data = splitter.split_documents(data_docs)
vectorstore_data = FAISS.from_documents(split_docs_data, embedding=embeddings)



In [9]:
def save_vectorstore(vectorstore, index_path, store_path):
    """Persist a FAISS vector store as a raw faiss index file plus a pickle.

    The index is written with ``faiss.write_index`` and detached from the store
    while the rest of the store object is pickled (presumably because the raw
    faiss index itself does not pickle). It is re-attached afterwards -- even if
    pickling fails -- so the in-memory store stays usable and this cell can be
    re-run.
    """
    index = vectorstore.index
    faiss.write_index(index, index_path)
    vectorstore.index = None
    try:
        with open(store_path, "wb") as f:
            pickle.dump(vectorstore, f)
    finally:
        vectorstore.index = index


# NOTE(review): the pickled store likely also carries its embeddings object;
# check whether that includes the OpenAI API key before sharing these files.
save_vectorstore(vectorstore_all, "docs_all.index", "faiss_store_all.pkl")
save_vectorstore(vectorstore_data, "docs_data.index", "faiss_store_data.pkl")