In [1]:
%load_ext autoreload
%autoreload 2

## Load Libraries

In [45]:
import os
import fitz 
import pandas as pd
import io
import requests
import base64
from langchain_core.messages import HumanMessage
from PIL import Image,ImageFile
from DocumentLoader import DocumentChunk,DocumentExtract
ImageFile.LOAD_TRUNCATED_IMAGES = True

## Class to Extract Images 

In [6]:
class ExtractionImage:
    """Extract the images embedded in a PDF and save them into a directory.

    The output directory is the PDF path with its final extension removed; it
    is created on construction when it does not exist yet. Images narrower
    than ``min_width`` or shorter than ``min_height`` pixels are skipped.
    """

    def __init__(self,filename):
        """Remember the PDF path and prepare the output directory.

        Args:
            filename: Path of the PDF file whose images will be extracted.
        """
        # splitext removes only the trailing extension. Splitting on the first
        # '.' truncated paths such as "./report.pdf" (-> "") or
        # "docs/v1.2/report.pdf" (-> "docs/v1").
        self.output_dir = os.path.splitext(filename)[0]
        self.output_format = 'png'  # format every extracted image is saved in
        self.min_width = 100        # minimum accepted width, in pixels
        self.min_height = 100       # minimum accepted height, in pixels
        self.filename = filename
        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

    def _open_pdf(self):
        """Open ``self.filename`` with PyMuPDF and return the document."""
        pdf_file = fitz.open(self.filename)
        return pdf_file

    def generate_images(self):
        """Save every sufficiently large embedded image to ``self.output_dir``.

        Files are named ``image<page>_<index>.<output_format>`` where ``page``
        is 1-based and ``index`` is the 1-based position of the image in the
        page's image list. Progress is reported with ``print``; note that the
        printed page numbers are 0-based.
        """
        pdf_file = self._open_pdf()
        try:
            for page_index in range(len(pdf_file)):
                # Get the page itself
                page = pdf_file[page_index]
                # Get image list
                image_list = page.get_images(full=True)
                # Print the number of images found on this page
                if image_list:
                    print(f"[+] Found a total of {len(image_list)} images in page {page_index}")
                else:
                    print(f"[!] No images found on page {page_index}")
                # Iterate over the images on the page
                for image_index, img in enumerate(image_list, start=1):
                    # The first entry of each image tuple is its XREF
                    xref = img[0]
                    # Extract the raw image bytes and load them into PIL
                    image_bytes = pdf_file.extract_image(xref)["image"]
                    image = Image.open(io.BytesIO(image_bytes))
                    if image.width < self.min_width or image.height < self.min_height:
                        print(f"[-] Skipping image {image_index} on page {page_index} due to its small size.")
                        continue
                    # PNG cannot store CMYK data; convert so saving does not fail.
                    if image.mode == "CMYK" and self.output_format.lower() == "png":
                        image = image.convert("RGB")
                    target_path = os.path.join(
                        self.output_dir,
                        f"image{page_index + 1}_{image_index}.{self.output_format}",
                    )
                    # Passing the path (not an open file object) lets PIL open
                    # and close the file itself, so no handle is leaked.
                    image.save(target_path, format=self.output_format.upper())
        finally:
            # Release the PDF even when extraction fails part-way through.
            pdf_file.close()

In [3]:
# API key and endpoint handed to GPTVisionCall by process_single_image.
# Both values are placeholders that must be filled in before running.
# NOTE(review): prefer loading real values from environment variables so that
# secrets are never saved inside the notebook.
GPT4V_keys = "<gpt4-v api keys>"
GPT4V_ENDPOINT = "<gpt4-v api endpoint>"

In [4]:
# Settings passed to DocumentChunk.load_keys below. The endpoint and key are
# placeholders for an Azure Form Recognizer resource; fill them in before
# running and keep real keys out of the saved notebook.
endpoint = "<azure form recognizer endpoint>"
keys = "<azure form recognizer keys>"
# Model identifier handed to DocumentChunk.load_keys.
model_id = "prebuilt-layout"

In [7]:
# Placeholder for the document passed to docChunk.get_dataframe below.
filename = '<fileName>'

In [8]:
# DocumentChunk comes from the project-local DocumentLoader module.
# NOTE(review): presumably load_keys returns a configured chunker instance —
# confirm against DocumentLoader.
docChunk = DocumentChunk.load_keys(endpoint,keys,model_id)

In [9]:
# NOTE(review): the meaning of `threshold` is defined by
# DocumentChunk.get_dataframe (presumably a chunk-size limit) — confirm there.
threshold = 300
# Later cells use `data` as a DataFrame with 'content' and 'page_number' columns.
data = docChunk.get_dataframe(filename,threshold)

In [19]:
# Flag the chunks whose content contains the '<table-start>' marker.
data['is_table'] = data['content'].str.contains('<table-start>')

In [22]:
# Distinct page numbers that have at least one chunk flagged as a table.
data[data['is_table']==True].page_number.unique()

array([ 2,  7,  8,  9, 10, 11, 13, 14, 16, 17, 18, 22, 23, 24, 25, 26, 27,
       28, 29, 30, 31, 32, 33, 34, 39, 40], dtype=int64)

## Generating Summary

In [40]:
class GPTVisionCall:
    """Small client for a GPT-4 Vision chat endpoint with OCR/grounding enabled.

    The request payload enables the "ocr" and "grounding" enhancements and
    points them at an Azure Computer Vision resource. The key of that resource
    is no longer embedded in the source: pass ``vision_key`` or set the
    ``AZURE_COMPUTER_VISION_KEY`` environment variable.
    """

    # Computer Vision endpoint used when none is supplied.
    DEFAULT_VISION_ENDPOINT = "https://imageanalysisvision.cognitiveservices.azure.com/"

    def __init__(self,GPT4V_KEY,GPT4V_ENDPOINT,vision_endpoint=None,vision_key=None):
        """Store the endpoint, request headers and Computer Vision settings.

        Args:
            GPT4V_KEY: API key sent in the ``api-key`` request header.
            GPT4V_ENDPOINT: URL the chat request is posted to.
            vision_endpoint: Azure Computer Vision endpoint. Defaults to the
                ``AZURE_COMPUTER_VISION_ENDPOINT`` environment variable, then
                to ``DEFAULT_VISION_ENDPOINT``.
            vision_key: Azure Computer Vision key. Defaults to the
                ``AZURE_COMPUTER_VISION_KEY`` environment variable.
        """
        self.endpoint = GPT4V_ENDPOINT
        self.headers = {
            "Content-Type": "application/json",
            "api-key": GPT4V_KEY,
        }
        self.vision_endpoint = vision_endpoint or os.environ.get(
            "AZURE_COMPUTER_VISION_ENDPOINT", self.DEFAULT_VISION_ENDPOINT
        )
        self.vision_key = vision_key or os.environ.get("AZURE_COMPUTER_VISION_KEY")

    def _get_encoded_image(self,image):
        """Return the content of the file at path ``image`` as a base64 ASCII string."""
        # The context manager closes the file as soon as it has been read.
        with open(image, 'rb') as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode('ascii')
        return encoded_image

    def _call_api(self,messages):
        """Post ``messages`` to the endpoint and return the decoded JSON reply.

        Args:
            messages: Chat messages placed under the payload's "messages" key.

        Returns:
            The parsed JSON body of the response.

        Raises:
            ValueError: If no Computer Vision key has been configured.
            SystemExit: If the HTTP request fails or returns an error status.
        """
        if not self.vision_key:
            raise ValueError(
                "Azure Computer Vision key is not configured: pass vision_key "
                "or set the AZURE_COMPUTER_VISION_KEY environment variable."
            )
        # Payload for the request
        payload = {
            "enhancements": {
                "ocr": {"enabled": True},
                "grounding": {"enabled": True},
            },
            "dataSources": [
                {
                    "type": "AzureComputerVision",
                    "parameters": {
                        "endpoint": self.vision_endpoint,
                        "key": self.vision_key,
                    },
                }
            ],
            "messages": messages,
            "temperature": 0,
            "max_tokens": 4096,
        }
        # Send request to the endpoint given to the constructor (not to a
        # module-level global, which may be undefined or different).
        try:
            response = requests.post(self.endpoint, headers=self.headers, json=payload)
            response.raise_for_status()  # Raises an HTTPError on an unsuccessful status code
            return response.json()
        except requests.RequestException as e:
            raise SystemExit(f"Failed to make the request. Error: {e}") from e

In [41]:
def process_single_image(image_path):
    """Ask GPT-4 Vision for a retrieval-oriented summary of one image file.

    Args:
        image_path: Path of the image file to summarise.

    Returns:
        A ``(encoded_image, summary)`` tuple: the base64 encoding of the file
        and the text of the first choice in the reply, or ``None`` as summary
        when the API call returned nothing.
    """
    import mimetypes  # local import keeps this cell self-contained

    gptv_obj = GPTVisionCall(GPT4V_keys,GPT4V_ENDPOINT)
    encoded_image = gptv_obj._get_encoded_image(image_path)
    sys_prompt = """You are an assistant tasked with summarizing images for retrieval. \
        These summaries will be embedded and used to retrieve the raw image. \
        Give a detailed summary of the image that is well optimized for retrieval."""
    user_txt_prompt = """Let's think step by step, and Provide detailed summary for the image. Make sure to follow all instructions"""
    # Declare the real media type of the file (e.g. image/png) in the data URL
    # instead of always claiming JPEG; fall back to JPEG when it is unknown.
    mime_type = mimetypes.guess_type(str(image_path))[0]
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/jpeg"
    user_img_prompt = f"data:{mime_type};base64,{encoded_image}"
    messages = [
        {'role': 'system', 'content': sys_prompt},
        {'role': 'user', 'content': [
            {"type": "text", "text": user_txt_prompt},
            {"type": "image_url", "image_url": {"url": user_img_prompt}},
        ]},
    ]
    res_json = gptv_obj._call_api(messages)
    if not res_json:
        return (encoded_image, None)
    return (encoded_image, res_json['choices'][0]['message']['content'])

In [1]:
# Placeholder: directory whose files are summarised one by one in the next cell.
image_path = "Path to image directory"

In [53]:
%%time
import time

# Pause between consecutive GPT-4V requests, in seconds.
REQUEST_PAUSE_SECONDS = 15

images_summaries = []
img_base64_list = []
# One record per processed image. Kept at module level so that the work done
# so far is still available if a request fails part-way through.
image_records = []

# os.listdir returns entries in arbitrary order and may include directories;
# sort the names for a reproducible row order and keep regular files only.
image_names = sorted(
    name for name in os.listdir(image_path)
    if os.path.isfile(os.path.join(image_path, name))
)
for path in image_names:
    img_path = os.path.join(image_path,path)
    encoded_img,summary = process_single_image(img_path)
    img_base64_list.append(encoded_img)
    images_summaries.append(summary)
    image_records.append({'image_name': path, 'summary': summary, 'encoding': encoded_img})
    time.sleep(REQUEST_PAUSE_SECONDS)

# Build the frame once instead of growing it row by row.
dataframe = pd.DataFrame(image_records, columns=['image_name','summary','encoding'])

Wall time: 10min 15s


In [61]:
from openai import AzureOpenAI
 
# Azure OpenAI client used by process_single_text_chunk for text/table
# summaries. The key and endpoint are placeholders to fill in before running;
# keep real keys out of the saved notebook.
client = AzureOpenAI(
  api_key = "<api-keys>",  
  api_version = "2023-05-15",
  azure_endpoint = "<azure endpoint>"
)
 

def _gpt_text_summary(client,sys_prompt,user_prompt):
    completion = client.chat.completions.create(
      model='gpt-4',
      messages=[
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": user_prompt}
      ]
    )
    return completion

In [63]:
def process_single_text_chunk(element):
    """Return the model-written summary of one text or table chunk.

    ``element`` is interpolated into the user prompt; the reply is requested
    through ``_gpt_text_summary`` with the module-level ``client`` and the
    content of its first choice is returned.
    """
    instructions = """You are an assistant tasked with summarizing tables and text for retrieval. \
    These summaries will be embedded and used to retrieve the raw text or table elements.\n
    Table part will start from '<table-start>' and ends with '<table-end>'"""
    request = f"""Provide detailed and concise summary of the table or text that is well optimized for retrieval.\n\nTable or text: {element} """
    completion = _gpt_text_summary(client, instructions, request)
    first_choice = completion.choices[0]
    return first_choice.message.content

In [64]:
%%time
# One summary per row of data['content'], in row order.
text_summary = [
    process_single_text_chunk(chunk)
    for chunk in data['content'].tolist()
]

Wall time: 6min 33s


In [66]:
# Attach the generated summaries to their chunks (same order as data['content']).
data['summary'] = text_summary

In [67]:
# Save the chunks together with their summaries to an Excel file, without the index column.
data.to_excel('table_text.xlsx',index=False)

In [68]:
from langchain.vectorstores import FAISS
from langchain.chat_models import AzureChatOpenAI
from langchain.vectorstores import Chroma
import uuid
from langchain.embeddings import AzureOpenAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_core.documents import Document

In [75]:
import re

# Names written by ExtractionImage look like "image<page>_<index>.png".
# Compiled once instead of on every call.
_IMAGE_NAME_PATTERN = re.compile(r'image(\d+)_\d+\.png')


def extract_number_from_filename(filename):
    """Return the page number encoded in an extracted image's file name.

    Args:
        filename: File name such as ``"image12_3.png"``.

    Returns:
        The first number of the name as an ``int`` (``12`` in the example), or
        ``None`` when ``filename`` is not a string or does not have exactly
        the ``image<digits>_<digits>.png`` shape.
    """
    # Values such as None/NaN can reach this function through DataFrame.apply.
    if not isinstance(filename, str):
        return None
    # fullmatch also anchors the end, so "image3_1.png.bak" is rejected.
    match = _IMAGE_NAME_PATTERN.fullmatch(filename)
    return int(match.group(1)) if match else None

In [78]:
# Derive each image's page number from its file name (None when the name does not match).
dataframe['page_number'] = dataframe['image_name'].apply(extract_number_from_filename)

In [111]:
def create_multi_vector_retriever(
    vectorstore, text_df, image_df
):
    """Build a MultiVectorRetriever that indexes summaries and returns raw content.

    Summaries are added to ``vectorstore``; the matching raw content is kept
    in an in-memory docstore under a generated id that is also written to the
    summary's metadata under ``"doc_id"``.

    Args:
        vectorstore: Vector store that receives the summary documents.
        text_df: DataFrame with 'summary', 'content' and 'page_number'
            columns; ignored when it is not a DataFrame.
        image_df: DataFrame with 'summary', 'encoding' and 'page_number'
            columns; ignored when it is not a DataFrame.

    Returns:
        The populated ``MultiVectorRetriever``.
    """
    store = InMemoryStore()
    id_key = "doc_id"

    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=store,
        id_key=id_key,
    )

    def add_documents(retriever, doc_summaries, doc_contents,page_number):
        """Index the rows that have a usable summary; skip the others."""
        # A summary can be missing (e.g. None when an image request returned
        # nothing); such a row cannot be embedded, so it is left out of both
        # the vector store and the docstore.
        rows = [
            (summary, content, page)
            for summary, content, page in zip(doc_summaries, doc_contents, page_number)
            if isinstance(summary, str) and summary.strip()
        ]
        skipped = len(doc_contents) - len(rows)
        if skipped:
            print(f"[!] Skipping {skipped} element(s) without a summary")
        if not rows:
            # Nothing to index; do not send an empty batch to the vector store.
            return
        doc_ids = [str(uuid.uuid4()) for _ in rows]
        summary_docs = [
            Document(page_content=summary, metadata={id_key: doc_id, 'page_number': page})
            for doc_id, (summary, _content, page) in zip(doc_ids, rows)
        ]
        retriever.vectorstore.add_documents(summary_docs)
        retriever.docstore.mset(
            [(doc_id, content) for doc_id, (_summary, content, _page) in zip(doc_ids, rows)]
        )

    if isinstance(text_df, pd.DataFrame):
        add_documents(retriever, text_df['summary'].tolist(), text_df['content'].tolist(),text_df['page_number'].tolist())
    if isinstance(image_df, pd.DataFrame):
        add_documents(retriever, image_df['summary'].tolist(), image_df['encoding'].tolist(),image_df['page_number'].tolist())

    return retriever

In [112]:
%%time
# Embedding model used by the vector store to embed the summaries.
embed_model = AzureOpenAIEmbeddings(
        azure_deployment="text-embedding-ada-002",
        openai_api_version="2023-05-15",
    )
# Chroma collection persisted under ./chroma_db; the collection name is a placeholder.
vectorstore = Chroma(
    collection_name='<name of collection>', embedding_function=embed_model,persist_directory="./chroma_db"
)

# Index the text/table summaries (`data`) and the image summaries (`dataframe`).
retriever = create_multi_vector_retriever(vectorstore,data,dataframe)

INFO:backoff:Backing off send_request(...) for 0.1s (requests.exceptions.SSLError: HTTPSConnectionPool(host='app.posthog.com', port=443): Max retries exceeded with url: /batch/ (Caused by SSLError(SSLCertVerificationError(1, '[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1129)'))))
INFO:backoff:Backing off send_request(...) for 0.7s (requests.exceptions.SSLError: HTTPSConnectionPool(host='app.posthog.com', port=443): Max retries exceeded with url: /batch/ (Caused by SSLError(SSLCertVerificationError(1, '[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1129)'))))
INFO:backoff:Backing off send_request(...) for 3.2s (requests.exceptions.SSLError: HTTPSConnectionPool(host='app.posthog.com', port=443): Max retries exceeded with url: /batch/ (Caused by SSLError(SSLCertVerificationError(1, '[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer

Wall time: 4.62 s


ERROR:backoff:Giving up send_request(...) after 4 tries (requests.exceptions.SSLError: HTTPSConnectionPool(host='app.posthog.com', port=443): Max retries exceeded with url: /batch/ (Caused by SSLError(SSLCertVerificationError(1, '[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1129)'))))


In [115]:
import json
## store mapping of summary and actual chunk
# retriever.docstore.store maps each generated doc_id to its raw chunk (text
# content or base64 image encoding). The docstore lives only in memory, so it
# is written to a JSON file here; '<filename>' in the path is a placeholder.
with open('chroma_db/<filename>.json', "w") as json_file:
    json.dump(retriever.docstore.store, json_file)