In [1]:
# Install the retrieval stack and vision-language helpers.
# `%pip` (unlike `!pip`) installs into the environment of the running kernel.
%pip install --upgrade byaldi
# poppler-utils is the system dependency that pdf2image wraps.
!sudo apt-get install -y poppler-utils
%pip install -q pdf2image qwen-vl-utils flash-attn


Collecting pyarrow>=15.0.0 (from datasets>=2.2.0->mteb==1.6.35->byaldi)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl (39.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m39.9/39.9 MB[0m [31m172.7 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: pyarrow
  Attempting uninstall: pyarrow
    Found existing installation: pyarrow 14.0.1
    Uninstalling pyarrow-14.0.1:
      Successfully uninstalled pyarrow-14.0.1
Successfully installed pyarrow-17.0.0
Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following additional packages will be installed:
  liblcms2-2 libnspr4 libnss3 libpoppler97 poppler-data
Suggested packages:
  liblcms2-utils ghostscript fonts-japanese-mincho | fonts-ipafont-mincho
  fonts-japanese-gothic | fonts-ipafont-gothic fonts-arphic-ukai
  fonts-arphic-uming fonts-nanum


In [2]:
# Pin transformers to a specific commit (presumably for Qwen2VLForConditionalGeneration,
# which app.py imports) and install accelerate. Explicit `%pip` avoids relying on automagic.
%pip install git+https://github.com/huggingface/transformers@21fac7abba2a37fae86106f87fcf9974fd1e3830 accelerate

Collecting git+https://github.com/huggingface/transformers@21fac7abba2a37fae86106f87fcf9974fd1e3830
  Cloning https://github.com/huggingface/transformers (to revision 21fac7abba2a37fae86106f87fcf9974fd1e3830) to /tmp/pip-req-build-xnglbd1w
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers /tmp/pip-req-build-xnglbd1w

  Running command git rev-parse -q --verify 'sha^21fac7abba2a37fae86106f87fcf9974fd1e3830'
  Running command git fetch -q https://github.com/huggingface/transformers 21fac7abba2a37fae86106f87fcf9974fd1e3830
  Running command git checkout -q 21fac7abba2a37fae86106f87fcf9974fd1e3830
  Resolved https://github.com/huggingface/transformers to commit 21fac7abba2a37fae86106f87fcf9974fd1e3830
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Note: you may need to restart the kernel to use updated packages.


In [3]:
# NOTE(review): the byaldi install above pulled pyarrow 17.0.0 (datasets 3.0.0 requires
# pyarrow>=15.0.0); pinning pyarrow==14.0.1 here makes pip report that conflict.
# TODO confirm which package actually needs pyarrow 14 before keeping this pin.
%pip install numpy==1.23.5 pyarrow==14.0.1 fsspec==2024.6.1


Collecting pyarrow==14.0.1
  Downloading pyarrow-14.0.1-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.0 kB)
Downloading pyarrow-14.0.1-cp310-cp310-manylinux_2_28_x86_64.whl (38.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m38.0/38.0 MB[0m [31m167.4 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: pyarrow
  Attempting uninstall: pyarrow
    Found existing installation: pyarrow 17.0.0
    Uninstalling pyarrow-17.0.0:
      Successfully uninstalled pyarrow-17.0.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
datasets 3.0.0 requires pyarrow>=15.0.0, but you have pyarrow 14.0.1 which is incompatible.[0m[31m
[0mSuccessfully installed pyarrow-14.0.1
Note: you may need to restart the kernel to use updated packages.


Streamlit app: extract text from an uploaded image and search it

In [13]:
%%writefile app.py

import hashlib
import html
import os
import re

import streamlit as st
from byaldi import RAGMultiModalModel
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch
from PIL import Image

def highlight_text(text, term):
    """Wrap every case-insensitive occurrence of `term` in `text` with <mark> tags.

    The result is rendered with ``st.markdown(..., unsafe_allow_html=True)``, so the
    surrounding text is HTML-escaped: a literal ``<`` or ``&`` in the extracted text is
    displayed as-is instead of being interpreted as markup.

    Args:
        text: the text to search in.
        term: the literal string to highlight. Regex metacharacters have no special
            meaning (``"c++"`` matches the text ``c++``).

    Returns:
        str: HTML-safe text with each match wrapped in ``<mark>...</mark>``, preserving
        the original casing of the matched text. An empty `term` highlights nothing.
    """
    if not term:
        # An empty pattern would match between every character.
        return html.escape(text, quote=False)

    pattern = re.compile(re.escape(term), flags=re.IGNORECASE)
    pieces = []
    cursor = 0
    for match in pattern.finditer(text):
        pieces.append(html.escape(text[cursor:match.start()], quote=False))
        pieces.append(f"<mark>{html.escape(match.group(0), quote=False)}</mark>")
        cursor = match.end()
    pieces.append(html.escape(text[cursor:], quote=False))
    return "".join(pieces)

@st.cache_resource
def load_models():
    """Load the ColPali retriever and the Qwen2-VL model and processor.

    Decorated with ``st.cache_resource`` so the checkpoints are loaded once and reused
    across Streamlit reruns instead of on every interaction.

    Returns:
        tuple: ``(model, processor, RAG)``. `model` is Qwen2-VL-2B-Instruct in bfloat16
        on CUDA in eval mode. `processor` is its ``AutoProcessor``. `RAG` is a byaldi
        ``RAGMultiModalModel`` wrapping ``vidore/colpali``.
    """
    qwen_model_id = "Qwen/Qwen2-VL-2B-Instruct"

    RAG = RAGMultiModalModel.from_pretrained("vidore/colpali")

    # `trust_remote_code` is only meaningful for Auto* classes; on this concrete class
    # transformers ignores it and emits a warning, so it is not passed here.
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        qwen_model_id,
        torch_dtype=torch.bfloat16,
    ).cuda().eval()

    processor = AutoProcessor.from_pretrained(qwen_model_id, trust_remote_code=True)

    return model, processor, RAG

def _image_message(image, prompt):
    """Build a single-turn Qwen2-VL chat message pairing `image` with the text `prompt`."""
    return [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]


def _generate_answer(model, processor, messages, max_new_tokens):
    """Run Qwen2-VL on chat `messages` and return only the newly generated text.

    Args:
        model: the Qwen2-VL model returned by ``load_models``.
        processor: the matching processor returned by ``load_models``.
        messages: chat messages. The vision inputs are taken from these same
            messages, so prompt text and images always agree.
        max_new_tokens: generation budget passed to ``model.generate``.

    Returns:
        str: the decoded continuation(s), joined with newlines.
    """
    prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    model_inputs = processor(
        text=[prompt],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    generated_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
    # The generated sequences start with the prompt tokens; keep only the continuation.
    continuations = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    decoded = processor.batch_decode(
        continuations, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return "\n".join(decoded)


def _intelli_search(rag, model, processor, images, query, index_path):
    """Answer a multi-word `query` with ColPali retrieval plus Qwen2-VL generation.

    Indexes `index_path` with byaldi the first time it is needed for the current
    upload, retrieves the best-matching page, asks Qwen2-VL the query about that page
    and renders the highlighted answer. Failures are shown via ``st.error``.
    """
    if not st.session_state['is_indexed']:
        try:
            rag.index(
                input_path=index_path,       # local copy of the uploaded image
                index_name="image_index",    # index will be saved at index_root/index_name/
                store_collection_with_index=False,
                overwrite=True,
            )
            st.session_state['is_indexed'] = True
        except Exception as e:
            # Surface the cause and skip searching an index that was never built.
            st.error(f"Error during indexing: {e}")
            return

    try:
        results = rag.search(query, k=1)
        if not results:
            st.error("Error during search: no matching page found")
            return
        # page_num is treated as 1-based, `images` is 0-based.
        page_index = results[0]["page_num"] - 1

        answer = _generate_answer(
            model,
            processor,
            _image_message(images[page_index], query),
            max_new_tokens=1000,
        )

        # Highlight the query within the answer and display it.
        highlighted_result = highlight_text(answer, query)
        st.subheader("Search Result (Intelli Answer):")
        st.markdown(highlighted_result, unsafe_allow_html=True)
    except Exception as e:
        st.error(f"Error during search: {str(e)}")


if 'is_indexed' not in st.session_state:
    st.session_state['is_indexed'] = False

st.title("Image to Text Extraction and Search with Highlighting")

uploaded_file = st.file_uploader("Upload an image", type=["png", "jpg", "jpeg"])
if uploaded_file is not None:
    # Streamlit re-executes this script on every widget interaction (e.g. each query),
    # so per-upload state is keyed on a hash of the uploaded bytes.
    upload_key = hashlib.sha256(uploaded_file.getvalue()).hexdigest()
    if st.session_state.get('upload_key') != upload_key:
        # A different file was uploaded: drop the previous extraction and force re-indexing.
        st.session_state['upload_key'] = upload_key
        st.session_state['is_indexed'] = False
        st.session_state.pop('extracted_text', None)

    # Save the upload to a local file for byaldi to index; basename() strips any
    # directory components from the client-supplied file name.
    temp_file_path = f"temp_{os.path.basename(uploaded_file.name)}"
    with open(temp_file_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    image = Image.open(uploaded_file)
    images = [image]
    st.image(image, caption='Uploaded Image', use_column_width=True)

    model, processor, RAG = load_models()

    # Text extraction (up to 5000 new tokens) is the slowest step: run it once per upload.
    if 'extracted_text' not in st.session_state:
        st.session_state['extracted_text'] = _generate_answer(
            model,
            processor,
            _image_message(image, "Extract the text from this image."),
            max_new_tokens=5000,
        )
    extracted_text = st.session_state['extracted_text']

    st.subheader("Extracted Text:")
    st.write(extracted_text)

    # Save the extracted text to a file
    with open("extracted_text.txt", "w", encoding="utf-8") as f:
        f.write(extracted_text)

    query = st.text_input("Search in Extracted Text", "")

    if query:
        if len(query.split()) == 1:
            # Single word: highlight its occurrences in the extracted text.
            highlighted_text = highlight_text(extracted_text, query)
            st.subheader("Search Result (Word Occurrences):")
            st.markdown(highlighted_text, unsafe_allow_html=True)
        else:
            # Multiple words: retrieval + vision-language answer ("Intelli search").
            _intelli_search(RAG, model, processor, images, query, temp_file_path)



Overwriting app.py


In [None]:
# Launch the Streamlit app written by the %%writefile cell above; this cell keeps
# running (serving the app) until it is interrupted.
!streamlit run app.py


Collecting usage statistics. To deactivate, set browser.gatherUsageStats to False.
[0m
[0m
[34m[1m  You can now view your Streamlit app in your browser.[0m
[0m
[34m  Network URL: [0m[1mhttp://10.192.11.106:8501[0m
[34m  External URL: [0m[1mhttp://54.210.241.79:8501[0m
[0m
Verbosity is set to 1 (active). Pass verbose=0 to make quieter.
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████████████| 2/2 [00:01<00:00,  1.01it/s]
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
Loading checkpoint shards: 100%|██████████████████| 2/2 [00:00<00:00,  5.83it/s]
Starting from v4.46, the `logits` model output will have the same type as the 