## Import libraries

In [None]:
import os, sys

# Add the parent of the current working directory to the import path
# (presumably so the `src` package imported below resolves when the notebook
# is run from a sub-folder -- confirm against the repository layout).
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

In [1]:
from dotenv import load_dotenv
from typing import List

from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.retrievers import BaseRetriever
from langchain_core.documents.compressor import BaseDocumentCompressor

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_cohere import CohereRerank, CohereEmbeddings
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers.string import StrOutputParser
from langchain_chroma import Chroma
from langchain.storage import LocalFileStore
from langchain.retrievers.multi_vector import MultiVectorRetriever

from src.prompts import ADMISSION_CONSULTANT_PROMPT
from src.utils import get_image_type

In [2]:
# Load environment variables from a .env file (presumably the API keys used by
# the Cohere and Google clients created below -- confirm).
load_dotenv()

True

## Setup chains

Multi-vector retriever

In [3]:
# Chroma collection 'multimodal', configured for cosine distance, embedded with
# Cohere's multilingual model and persisted under ../database relative to the
# current working directory.
vectorstore = Chroma(
    collection_name='multimodal',
    collection_metadata={'hnsw:space': 'cosine'},
    embedding_function=CohereEmbeddings(model='embed-multilingual-v3.0'),
    persist_directory=os.path.join(os.getcwd(), '../database')
)

# File-backed byte store for the documents, located at ../database/docstore.
store = LocalFileStore(root_path=os.path.abspath(os.path.join(os.getcwd(), '../database/docstore')))

# Presumably the metadata key linking vector-store entries to docstore
# records -- confirm against the indexing code.
id_key = "doc_id"

# Retriever that runs a similarity search (k=15) on the vector store and is
# backed by the docstore above.
multivector_retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=store,
    id_key=id_key,
    search_type='similarity',
    search_kwargs={'k':15}
)

Reranker

In [4]:
# Cohere multilingual reranker; passed to the RAG instance created further down.
reranker = CohereRerank(model='rerank-multilingual-v3.0')

RAG class

In [None]:
class RAG:
    """Multimodal retrieval-augmented generation pipeline.

    Combines a retriever, a reranker and a chat model into one LangChain
    runnable: retrieved payloads are parsed into documents, split into text
    documents and base64 images, the text documents are reordered and
    truncated by the reranker, and everything is packed into a single
    multimodal ``HumanMessage`` for the chat model.
    """

    def __init__(self,
                 chat_model: BaseChatModel,
                 multivector_retriever: BaseRetriever,
                 reranker: BaseDocumentCompressor,
                 num_retrieved_docs: int = 5,
                 compress_docs: bool = False):
        """Store the pipeline components.

        Args:
            chat_model: Model that answers the final prompt.
            multivector_retriever: Retriever whose output is fed to
                ``_parse_retrieved_context``.
            reranker: Object exposing ``compress_documents`` and, optionally,
                ``rerank``.
            num_retrieved_docs: Maximum number of text documents kept after
                reordering.
            compress_docs: When true, reorder with
                ``reranker.compress_documents``; otherwise use
                ``reranker.rerank`` and fall back to compression on failure.
        """
        self.chat_model = chat_model
        self.multivector_retriever = multivector_retriever
        self.reranker = reranker
        self.num_retrieved_docs = num_retrieved_docs
        self.reorder_method = self._compress_documents if compress_docs else self._rerank_documents

    def get_chain(self):
        """Build and return the runnable that maps a question to an answer string."""
        chain = (
            {
                'context':
                    self.multivector_retriever
                    | RunnableLambda(self._parse_retrieved_context)
                    | RunnableLambda(self._split_text_image_content),
                'question': RunnablePassthrough()
            }
            | RunnableLambda(self.reorder_method)
            | RunnableLambda(self._create_prompt)
            | self.chat_model
            | StrOutputParser()
        )

        return chain

    def _parse_retrieved_context(self, context: List[bytes]):
        """Deserialize each raw retrieved payload into a ``Document``."""
        return [Document.parse_raw(b) for b in context]

    def _split_text_image_content(self, retrieved_docs: List[Document]):
        """Separate retrieved documents into text documents and image payloads.

        Documents whose metadata ``type`` is ``'image'`` contribute their
        ``page_content`` to ``'images'``; every other document is kept whole
        under ``'texts'``.
        """
        texts = []
        b64_images = []
        for doc in retrieved_docs:
            if doc.metadata.get('type') == 'image':
                b64_images.append(doc.page_content)
            else:
                texts.append(doc)

        return {'texts': texts, 'images': b64_images}

    def _compress_documents(self, data_dict: dict):
        """Reorder the text documents with ``reranker.compress_documents``.

        Replaces ``data_dict['context']['texts']`` in place with at most
        ``num_retrieved_docs`` of the returned documents and returns
        ``data_dict``.
        """
        texts = data_dict['context']['texts']
        if not texts:
            # Nothing to rank; skip the reranker call.
            return data_dict

        compressed_texts = self.reranker.compress_documents(documents=texts, query=data_dict['question'])
        data_dict['context']['texts'] = compressed_texts[:self.num_retrieved_docs]

        return data_dict

    def _rerank_documents(self, data_dict: dict):
        """Reorder the text documents with ``reranker.rerank``.

        ``rerank`` is expected to return entries carrying an ``'index'`` into
        the submitted documents; the first ``num_retrieved_docs`` of them
        select the documents that are kept, in ranked order. If the reranker
        has no usable ``rerank`` (or the call fails), the method falls back to
        ``_compress_documents``.

        NOTE(review): the reranker may cap the number of results itself
        (e.g. through its own ``top_n`` setting) -- confirm it is not lower
        than ``num_retrieved_docs``.
        """
        texts = data_dict['context']['texts']
        if not texts:
            # Nothing to rank; skip the reranker call.
            return data_dict

        try:
            ranking = self.reranker.rerank(documents=texts, query=data_dict['question'])
            top_indices = [r['index'] for r in ranking[:self.num_retrieved_docs]]
        except Exception:
            # Deliberate best-effort: a compressor without `rerank`, or a
            # failed rerank call, degrades to the compression path.
            return self._compress_documents(data_dict)

        # The original read the non-existent key 'text' here, which always
        # raised KeyError and silently forced the fallback above.
        data_dict['context']['texts'] = [texts[i] for i in top_indices]

        return data_dict

    def _create_prompt(self, data_dict: dict):
        """Build the single multimodal ``HumanMessage`` sent to the chat model.

        The text part is ``ADMISSION_CONSULTANT_PROMPT`` formatted with the
        joined text context and the question; each base64 image is appended
        as a ``data:`` URL whose subtype comes from ``get_image_type``.
        """
        context = '\n\n'.join([text.page_content for text in data_dict['context']['texts']])
        message = [
            {
                'type': 'text',
                'text': ADMISSION_CONSULTANT_PROMPT.format(context, data_dict['question'])
            }
        ]

        if data_dict['context']['images']:
            for image in data_dict['context']['images']:
                message.append(
                    {
                        'type': 'image_url',
                        'image_url': {'url': f'data:image/{get_image_type(image)};base64,{image}'}
                    }
                )

        return [HumanMessage(content=message)]

In [5]:
def parse_retrieved_context(context: List[bytes]):
    """Turn each raw retrieved payload into a ``Document`` via ``Document.parse_raw``."""
    documents = []
    for raw_payload in context:
        documents.append(Document.parse_raw(raw_payload))
    return documents

def split_text_image_content(retrieved_docs: List[Document]):
    """Partition documents into text documents and image payloads.

    A document whose metadata ``type`` equals ``'image'`` contributes its
    ``page_content`` to ``'images'``; any other document is stored whole under
    ``'texts'``. Input order is preserved within each group.
    """
    buckets = {'texts': [], 'images': []}
    for document in retrieved_docs:
        if document.metadata.get('type') != 'image':
            buckets['texts'].append(document)
            continue
        buckets['images'].append(document.page_content)

    return buckets

def compress_documents(data_dict: dict):
    """Rerank the text documents with a freshly created Cohere reranker.

    Sends ``data_dict['context']['texts']`` and ``data_dict['question']`` to
    ``CohereRerank.compress_documents``, stores the first five results back
    under ``data_dict['context']['texts']`` and returns ``data_dict``.
    """
    cohere_reranker = CohereRerank(model='rerank-multilingual-v3.0')
    ranked_texts = cohere_reranker.compress_documents(
        documents=data_dict['context']['texts'],
        query=data_dict['question'],
    )
    # Keep only the first five documents returned by the reranker.
    data_dict['context']['texts'] = ranked_texts[:5]
    return data_dict

Chain

In [7]:
# Assemble the pipeline with Gemini as the chat model. compress_docs=False
# selects the `rerank`-based reordering method of RAG.
rag = RAG(chat_model=ChatGoogleGenerativeAI(model='gemini-1.5-flash-latest'),
          multivector_retriever=multivector_retriever,
          reranker=reranker,
          compress_docs=False)

# Runnable that takes a question string and returns the answer string.
chain = rag.get_chain()

In [8]:
# Smoke test with a sample question (Vietnamese for "How many campuses does
# the school have?").
chain.invoke('Trường có bao nhiêu cơ sở?')

'Trường có hai cơ sở: \n- Cơ sở chính: Khu đô thị ĐH Quốc gia, Khu phố 6, phường Linh Trung, TP. Thủ Đức, TP.HCM.\n- Cơ sở 2: 227 Nguyễn Văn Cừ, Q5, Tp.HCM.'