## Import libraries

In [1]:
import os
import json
import uuid
import base64
from hashlib import sha256
from tqdm.auto import tqdm
from typing import List, Union
from dotenv import load_dotenv
from IPython.display import Image

from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAI
from langchain_cohere import CohereRerank, CohereEmbeddings
from langchain_core.prompts import PromptTemplate
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_core.documents import Document
from langchain_core.messages import HumanMessage, AIMessage
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.output_parsers.string import StrOutputParser
from langchain_chroma import Chroma
from langchain.indexes import SQLRecordManager, index
from langchain.storage import LocalFileStore
from langchain.retrievers.multi_vector import MultiVectorRetriever


from src.scraper.scrape import crawl, crawl_webpage
from src.prompts import TEXT_SUMMARIZE_PROMPT, IMAGE_SUMMARIZE_PROMPT
from src.image_extractor import (
    GeminiImageExtractor,
    get_image,
    encode_image
)
from src.agent_chunking import GeminiChunker
from src.embedding import GeminiEmbedding
from src.utils import load_json, write_json

In [2]:
# Load variables from a local .env file into the process environment.
# NOTE(review): presumably this supplies the API keys for the Google / Cohere
# clients constructed later in the notebook — confirm which keys are required.
load_dotenv()

## Crawl

The cells below contain the crawling code. Uncomment and run them if this is the first time you are running the notebook; otherwise, use the already-crawled data in `data/tmp_data.json`.

In [3]:
# data_dict = crawl()

In [4]:
# with open(os.path.join(os.getcwd(), 'data/tmp_data.json'), 'w', encoding='utf-8') as f:
#     json.dump(data_dict, f, ensure_ascii=False, indent=4)

In [5]:
# # convert raw text to Document
# texts = [Document(page_content=v['text'], metadata={'source': {'url': k, 'title': v['title']}}) for k, v in data_dict.items()]

# # convert raw table to Document
# tables = []
# for k, v in data_dict.items():
#     for table in v['table']:
#         tables.append(Document(page_content=table, metadata={'source': {'url': k, 'title': v['title']}}))

Read the available crawled data

In [6]:
# Load the previously crawled pages: a mapping of page URL -> page record.
data_dict = load_json(os.path.join(os.getcwd(), 'data/tmp_data.json'))

# One Document per page body, tagged with the page it came from.
texts = [
    Document(
        page_content=page['text'],
        metadata={'source': {'url': url, 'title': page['title']}},
    )
    for url, page in data_dict.items()
]

# One Document per table found on a page, tagged with the same source info.
tables = [
    Document(
        page_content=page_table,
        metadata={'source': {'url': url, 'title': page['title']}},
    )
    for url, page in data_dict.items()
    for page_table in page['table']
]

## Partition content

In [7]:
# chunker = GeminiChunker()
# chunks = chunker.split_documents(texts)

### Summarize

### Summarize process

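Uncomment the cells below to regenerate the summaries; otherwise, skip to "Use saved summaries" to load the previously saved results.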

In [8]:
# MODEL_NAME = 'gemini-1.5-flash-latest'

In [9]:
# def generate_text_summaries(texts: Union[List[str], List[Document]], tables: Union[List[str], List[Document]], summary_text: bool = True):
#     prompt = PromptTemplate.from_template(template=TEXT_SUMMARIZE_PROMPT)
#     empty_response = RunnableLambda(lambda x : AIMessage(content="Error processing document"))
#     model = GoogleGenerativeAI(model=MODEL_NAME, temperature=0.0, top_k=1, top_p=0.1).with_fallbacks([empty_response])
#     chain = prompt | model | StrOutputParser()

#     text_summaries = []
#     table_summaries = []
#     if texts and summary_text:
#         if isinstance(texts[0], Document):
#             texts = [t.page_content for t in texts]
        
#         text_summaries = chain.batch(texts, {'max_concurrency':1})
#     else:
#         text_summaries = texts

#     if tables:
#         if isinstance(tables[0], Document):
#             tables = [t.page_content for t in tables]
#         table_summaries = chain.batch(tables, {'max_concurrency':1})

#     return text_summaries, table_summaries

In [10]:
# text_summaries, table_summaries = generate_text_summaries(texts=texts, tables=tables, summary_text=True)

In [11]:
# def generate_image_summaries(image_paths):
#     model = GeminiImageExtractor(custom_prompt=IMAGE_SUMMARIZE_PROMPT)
#     image_summaries = []
#     b64_images = []
#     for image in tqdm(sorted(image_paths)):
#         image_summaries.append(model.invoke(image))
#         b64_images.append(encode_image(get_image(image)))

#     return b64_images, image_summaries

In [12]:
# images = []
# image_summaries = []
# for k, v in data_dict.items():
#     if v['image_path']:
#         b64_images, summaries = generate_image_summaries(v['image_path'])
#         images += [Document(page_content=image, metadata={'source': {'url': k, 'title': v['title']}}) for image in b64_images]
#         image_summaries += summaries

In [13]:
# saved_docs = {
#     'texts': [json.loads(text.json(ensure_ascii=False)) for text in texts],
#     'tables': [json.loads(table.json(ensure_ascii=False)) for table in tables],
#     'images': [json.loads(image.json(ensure_ascii=False)) for image in images],
# }

# saved_summaries = {
#     'text_summaries': text_summaries,
#     'table_summaries': table_summaries,
#     'image_summaries': image_summaries
# }

# write_json(os.path.join(os.getcwd(), 'data/documents.json'), saved_docs)
# write_json(os.path.join(os.getcwd(), 'data/summaries.json'), saved_summaries)

### Use saved summaries

In [14]:
# Load the cached parent documents and their summaries from disk.
documents = load_json(os.path.join(os.getcwd(), 'data/documents.json'))
summaries = load_json(os.path.join(os.getcwd(), 'data/summaries.json'))

# Rebuild Document objects from their serialized dict form, one list per modality.
texts = [Document(**record) for record in documents['texts']]
tables = [Document(**record) for record in documents['tables']]
images = [Document(**record) for record in documents['images']]

# Summaries are kept as separate lists, one per modality.
text_summaries = summaries['text_summaries']
table_summaries = summaries['table_summaries']
image_summaries = summaries['image_summaries']

## Setup Multi-vector Retriever

### Ingest data

Record Manager

In [15]:
collection_name = 'multimodal'
namespace = f'chroma/{collection_name}'

# Single source of truth for the record-manager database location. Previously
# `db_path` pointed at <cwd>/record_manager_cache.sql while the engine URL
# pointed at <cwd>/database/record_manager_cache.sql, so the existence check
# looked at a file that was never written.
db_path = os.path.join(os.getcwd(), 'database/record_manager_cache.sql')

# SQLite cannot create the database file if its parent directory is missing.
os.makedirs(os.path.dirname(db_path), exist_ok=True)

# Build the URL from the variable instead of nesting same-type quotes inside the
# f-string, which is a SyntaxError on Python versions before 3.12.
record_manager = SQLRecordManager(
    namespace=namespace,
    db_url=f'sqlite:///{db_path}'
)

# Always ensure the schema exists rather than guarding on file existence: a
# database file left behind without its tables would otherwise never be repaired.
# NOTE(review): relies on create_schema() being safe to call repeatedly
# (it is expected to create only missing tables) — confirm against the
# installed LangChain version.
record_manager.create_schema()

Vectorstore

In [16]:
# Persistent Chroma collection holding the embedded summaries.
# NOTE(review): the collection is named 'test' while the record manager
# namespace is derived from 'multimodal' — confirm the mismatch is intentional.
chroma_persist_dir = os.path.join(os.getcwd(), 'database')
summary_embeddings = CohereEmbeddings(model='embed-multilingual-v3.0')

vectorstore = Chroma(
    persist_directory=chroma_persist_dir,
    embedding_function=summary_embeddings,
    collection_name='test',
    collection_metadata={'hnsw:space': 'cosine'},
)

Docstore

In [17]:
# File-backed key/value store rooted at database/docstore; it receives the raw
# parent contents keyed by doc id in the ingest step.
docstore_root = os.path.join(os.getcwd(), 'database/docstore')
store = LocalFileStore(root_path=docstore_root)

Ingest data

In [18]:
id_key = 'doc_id'

# Summaries are paired with their parent documents purely by position, so each
# modality must have exactly one summary per document. Fail loudly instead of
# raising a confusing IndexError or silently attaching summaries to the wrong parent.
for label, parents, parent_summaries in (
    ('text', texts, text_summaries),
    ('table', tables, table_summaries),
    ('image', images, image_summaries),
):
    if len(parents) != len(parent_summaries):
        raise ValueError(
            f'{label}: {len(parents)} documents but {len(parent_summaries)} summaries'
        )

# Parents and summaries concatenated in the same modality order.
source_docs = texts + tables + images
all_summaries = text_summaries + table_summaries + image_summaries

# Raw parent payloads, encoded to bytes for the docstore.
doc_contents = [doc.page_content.encode() for doc in source_docs]

# id_map: sha256(url) -> {url, title, doc_ids}; lets a source page be traced to
# every parent document id generated for it.
id_map = {}
doc_ids = []
for doc in source_docs:
    # NOTE(review): uuid4 ids differ on every run, so each summary's metadata
    # changes between runs; this likely makes index() treat everything as new
    # on re-ingestion — confirm whether deterministic ids are wanted.
    doc_id = str(uuid.uuid4())

    source = doc.metadata['source']
    hash_source = sha256(source['url'].encode('utf-8')).hexdigest()
    if hash_source not in id_map:
        # Copy the source info so the bookkeeping list is not injected into the
        # Document's own metadata dict.
        id_map[hash_source] = {**source, 'doc_ids': []}
    id_map[hash_source]['doc_ids'].append(doc_id)

    doc_ids.append(doc_id)

# One summary Document per parent, carrying only the parent's id as metadata.
doc_summaries = [
    Document(page_content=summary, metadata={id_key: doc_id})
    for doc_id, summary in zip(doc_ids, all_summaries)
]

In [19]:
# Sync the summary documents into the vector store, tracked by the record manager.
# NOTE(review): per the LangChain indexing docs, cleanup='full' is expected to
# delete previously indexed documents that are absent from this batch — confirm
# that wiping stale entries on every run is intended.
index(
    docs_source=doc_summaries,
    record_manager=record_manager,
    vector_store=vectorstore,
    cleanup='full',
    source_id_key=id_key  # group records by the parent document id stored in metadata
)

In [20]:
# Persist each raw parent payload in the docstore under its generated id.
docstore_items = list(zip(doc_ids, doc_contents))
store.mset(docstore_items)

In [21]:
# Save the source -> doc ids mapping next to the other database artifacts.
id_map_path = os.path.join(os.getcwd(), 'database/id_map.json')
write_json(id_map_path, id_map)

### Load retriever

Multi-vector retriever

In [None]:
# Redefined here so this section can run without executing the ingest cells.
id_key = 'doc_id'

# Similarity search over the summaries, returning up to 15 hits.
retriever_search_kwargs = {'k': 15}

# Summaries are searched in the vector store; the matching parents are then
# looked up in the docstore through the shared id key.
retriever = MultiVectorRetriever(
    id_key=id_key,
    docstore=store,
    vectorstore=vectorstore,
    search_kwargs=retriever_search_kwargs,
    search_type='similarity',
)