## Import libraries

In [1]:
import os, sys

# Put the parent of the current working directory on the import path so that
# the `src.*` packages imported below can be resolved.
_parent_dir = os.path.join(os.getcwd(), '..')
sys.path.append(os.path.abspath(_parent_dir))

In [2]:
import json
import uuid
import base64
from hashlib import sha256
from tqdm.auto import tqdm
from typing import List, Union
from dotenv import load_dotenv
from IPython.display import Image

from langchain_google_genai import GoogleGenerativeAI
from langchain_cohere import CohereEmbeddings
from langchain_core.prompts import PromptTemplate
from langchain_core.documents import Document
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableLambda
from langchain_core.output_parsers.string import StrOutputParser
from langchain_chroma import Chroma
from langchain.indexes import SQLRecordManager, index
from langchain.storage import LocalFileStore
from langchain.retrievers.multi_vector import MultiVectorRetriever


from src.scraper.scrape import crawl
from src.prompts import TEXT_SUMMARIZE_PROMPT, IMAGE_SUMMARIZE_PROMPT
from src.image_extractor import (
    GeminiImageExtractor,
    get_image,
    encode_image
)
from src.agent_chunking import GeminiChunker
from src.embedding import GeminiEmbedding
from src.utils import load_json, write_json

In [3]:
# Load variables from a local .env file into the process environment.
# NOTE(review): presumably this supplies the API keys for the Gemini and Cohere
# clients used further down — confirm which variables are required.
load_dotenv()

True

## Crawl

The cells below crawl the source data. Uncomment and run them the first time you run this notebook; otherwise, use the already-crawled data in `data/tmp_data.json`.

In [4]:
# data_dict = crawl()

In [5]:
# with open(os.path.join(os.getcwd(), '../data/tmp_data.json'), 'w', encoding='utf-8') as f:
#     json.dump(data_dict, f, ensure_ascii=False, indent=4)

In [6]:
# # convert raw text to Document
# texts = [Document(page_content=v['text'], metadata={'source': k, 'title': v['title']}) for k, v in data_dict.items()]

# # convert raw table to Document
# tables = []
# for k, v in data_dict.items():
#     for table in v['table']:
#         tables.append(Document(page_content=table, metadata={'source': {'url': k, 'title': v['title']}}))

Read the available crawled data

In [7]:
# Load the previously crawled pages. Each key is used as the page URL and each
# value supplies the page's 'text', 'title' and 'table' entries.
crawl_path = os.path.join(os.getcwd(), '../data/tmp_data.json')
data_dict = load_json(crawl_path)

# One Document per crawled page, carrying the page's raw text.
texts = []
for url, page in data_dict.items():
    source_info = {'url': url, 'title': page['title']}
    texts.append(Document(page_content=page['text'], metadata={'source': source_info}))

# One Document per table on any page, tagged with the page it came from.
tables = [
    Document(page_content=tbl, metadata={'source': {'url': url, 'title': page['title']}})
    for url, page in data_dict.items()
    for tbl in page['table']
]

## Partition content

In [8]:
# chunker = GeminiChunker()
# chunks = chunker.split_documents(texts)

### Summarize

### Summarize process

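Uncomment the cells below to regenerate the summaries; otherwise, skip ahead and load the saved summaries in the next section.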

In [9]:
# MODEL_NAME = 'gemini-1.5-flash-latest'

In [10]:
# def generate_text_summaries(texts: Union[List[str], List[Document]], tables: Union[List[str], List[Document]], summary_text: bool = True):
#     prompt = PromptTemplate.from_template(template=TEXT_SUMMARIZE_PROMPT)
#     empty_response = RunnableLambda(lambda x : AIMessage(content="Error processing document"))
#     model = GoogleGenerativeAI(model=MODEL_NAME, temperature=0.0, top_k=1, top_p=0.1).with_fallbacks([empty_response])
#     chain = prompt | model | StrOutputParser()

#     text_summaries = []
#     table_summaries = []
#     if texts and summary_text:
#         if isinstance(texts[0], Document):
#             texts = [t.page_content for t in texts]
        
#         text_summaries = chain.batch(texts, {'max_concurrency':1})
#     else:
#         text_summaries = texts

#     if tables:
#         if isinstance(tables[0], Document):
#             tables = [t.page_content for t in tables]
#         table_summaries = chain.batch(tables, {'max_concurrency':1})

#     return text_summaries, table_summaries

In [11]:
# text_summaries, table_summaries = generate_text_summaries(texts=texts, tables=tables, summary_text=True)

In [12]:
# def generate_image_summaries(image_paths):
#     model = GeminiImageExtractor(custom_prompt=IMAGE_SUMMARIZE_PROMPT)
#     image_summaries = []
#     b64_images = []
#     for image in tqdm(sorted(image_paths)):
#         image_summaries.append(model.invoke(image))
#         b64_images.append(encode_image(get_image(image)))

#     return b64_images, image_summaries

In [13]:
# images = []
# image_summaries = []
# for k, v in data_dict.items():
#     if v['image_path']:
#         b64_images, summaries = generate_image_summaries(v['image_path'])
#         images += [Document(page_content=image, metadata={'source': k, 'title': v['title']}) for image in b64_images]
#         image_summaries += summaries

In [14]:
# saved_docs = {
#     'texts': [json.loads(text.json(ensure_ascii=False)) for text in texts],
#     'tables': [json.loads(table.json(ensure_ascii=False)) for table in tables],
#     'images': [json.loads(image.json(ensure_ascii=False)) for image in images],
# }

# saved_summaries = {
#     'text_summaries': text_summaries,
#     'table_summaries': table_summaries,
#     'image_summaries': image_summaries
# }

# write_json(os.path.join(os.getcwd(), '../data/documents.json'), saved_docs)
# write_json(os.path.join(os.getcwd(), '../data/summaries.json'), saved_summaries)

### Use saved summaries

In [15]:
# Read the cached originals and their summaries from disk.
documents = load_json(os.path.join(os.getcwd(), '../data/documents.json'))
summaries = load_json(os.path.join(os.getcwd(), '../data/summaries.json'))


def _as_documents(records):
    """Rebuild Document objects from their saved dict form."""
    return [Document(**record) for record in records]


# These rebind `texts` / `tables` from the crawl-loading cell to the saved versions.
texts = _as_documents(documents['texts'])
tables = _as_documents(documents['tables'])
images = _as_documents(documents['images'])

text_summaries = summaries['text_summaries']
table_summaries = summaries['table_summaries']
image_summaries = summaries['image_summaries']

## Setup Multi-vector Retriever

### Ingest data

Vectorstore

In [16]:
# Summaries are embedded with Cohere's multilingual model and kept in a
# Chroma collection persisted under ../database.
summary_embedder = CohereEmbeddings(model='embed-multilingual-v3.0')
chroma_dir = os.path.join(os.getcwd(), '../database')

vectorstore = Chroma(
    persist_directory=chroma_dir,
    embedding_function=summary_embedder,
    collection_metadata={'hnsw:space': 'cosine'},  # cosine distance for the index
    collection_name='multimodal',
)

Docstore

In [17]:
# File-backed key/value store rooted at an absolute path under ../database/docstore.
docstore_root = os.path.join(os.getcwd(), '../database/docstore')
store = LocalFileStore(root_path=os.path.abspath(docstore_root))

Ingest data

In [18]:
id_key = 'doc_id'

# Originals and summaries are paired purely by position (texts, then tables,
# then images), so both sequences must have the same length. A mismatch would
# otherwise link summaries to the wrong parent document or fail mid-way.
all_summaries = text_summaries + table_summaries + image_summaries
n_originals = len(texts) + len(tables) + len(images)
if len(all_summaries) != n_originals:
    raise ValueError(
        f'Summary/document count mismatch: {len(all_summaries)} summaries '
        f'for {n_originals} documents'
    )

# One fresh UUID per original document; it is written into the matching
# summary's metadata under `id_key` so the two can be joined later.
doc_ids = [str(uuid.uuid4()) for _ in range(n_originals)]
doc_summaries = [
    Document(page_content=summary, metadata={id_key: doc_id})
    for doc_id, summary in zip(doc_ids, all_summaries)
]

# NOTE(review): new UUIDs are generated on every run and the collection is
# persisted on disk, so re-running this cell adds the summaries again.
embedding_ids = vectorstore.add_documents(doc_summaries)

In [19]:
# Maps sha256(source URL) -> list of embedding ids created for that source.
map_id = {}
# Serialized (JSON, UTF-8 bytes) form of every original document, in input order.
doc_contents = []
for d, e_id in zip((texts + tables + images), embedding_ids):
    source = d.metadata['source']
    # 'source' is a {'url', 'title'} dict for documents built from the crawl
    # data in this notebook, but a plain URL string in the other document
    # constructors; a dict has no .encode(), so pull the URL out first.
    source_url = source['url'] if isinstance(source, dict) else source
    hash_source = sha256(source_url.encode()).hexdigest()

    map_id.setdefault(hash_source, []).append(e_id)
    doc_contents.append(d.json().encode())

In [20]:
# Write the serialized original documents to the file-backed docstore, keyed by
# the same doc_ids that were stored in the summaries' metadata.
store.mset(list(zip(doc_ids, doc_contents)))

In [21]:
# Save the source-hash -> embedding-ids mapping alongside the database files.
write_json(os.path.join(os.getcwd(), '../database/id_map.json'), map_id)

### Load retriever

Multi-vector retriever

In [22]:
# Metadata key under which each summary carries the id of its original document.
id_key = "doc_id"

# Retriever settings gathered in one place: summaries are searched in the
# vectorstore and the matching originals are read from the docstore.
retriever_settings = {
    'vectorstore': vectorstore,
    'docstore': store,
    'id_key': id_key,
    'search_type': 'similarity',
    'search_kwargs': {'k': 15},  # passed through to the vectorstore search
}

retriever = MultiVectorRetriever(**retriever_settings)