# Vision QA Agent with Mistral and LlamaIndex

## Introduction 

This notebook shows how to process images and PDFs with OCR and vision models. It uses the Mistral OCR model to extract text from documents and Pixtral 12B (via AWS Bedrock Marketplace) for visual understanding and analysis. A LlamaIndex agent, driven by Mistral Large 2 on Amazon Bedrock, routes each user query to the appropriate method: OCR extraction or vision analysis.

## Prerequisites
1. AWS Account Access:
- Active AWS account with permissions to use Amazon Bedrock.

2. Amazon Bedrock Service:
- Model access to Mistral Large 2 (mistral.mistral-large-2407-v1:0).
- A Pixtral 12B endpoint deployed from Bedrock Marketplace; its ARN is used as the model ID in the vision tool.
- Permissions to invoke Bedrock models (Bedrock Converse API).

3. Mistral API Access:
- Valid API key for the Mistral AI platform to use the OCR service.

In [None]:
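!pip install -U llama-index llama-index-core llama-index-llms-bedrock-converse llama-index-llms-bedrock

In [None]:
# mistralai is the Mistral AI client; pdf2image also needs the poppler utilities installed on the system
!pip install mistralai boto3 pillow pdf2image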

- Imports the libraries used for AWS access, the Mistral API, and document handling.

In [None]:

import glob
import base64
import mimetypes
from pathlib import Path
import boto3
import tempfile
from PIL import Image
from pdf2image import convert_from_path

from mistralai import Mistral

- Configures the orchestration LLM: Mistral Large 2 (mistral.mistral-large-2407-v1:0) through the Bedrock Converse API, with responses capped at 2048 tokens.

In [None]:
from llama_index.llms.bedrock_converse import BedrockConverse
from llama_index.core.tools import FunctionTool

from llama_index.core import Settings

# Create the LLM once and reuse it as the global default
llm = BedrockConverse(model="mistral.mistral-large-2407-v1:0", max_tokens=2048)
Settings.llm = llm

In [None]:

from mistralai import DocumentURLChunk, ImageURLChunk, TextChunk
MISTRAL_API_KEY="<ADD_YOUR_MISTRAL_KEY>"
client = Mistral(api_key=MISTRAL_API_KEY)
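
As an alternative to pasting the key into the notebook, it can be read from the environment. This is a minimal sketch; the environment variable name `MISTRAL_API_KEY` and the helper `load_api_key` are assumptions for illustration, not part of the Mistral client.

```python
import os


def load_api_key(env_var: str = "MISTRAL_API_KEY") -> str:
    """Return the API key stored in `env_var`, or raise a clear error.

    The variable name is an assumption; use whatever name your setup defines.
    """
    key = os.environ.get(env_var, "").strip()
    if not key:
        raise RuntimeError(
            f"Set the {env_var} environment variable before running this notebook."
        )
    return key
```

The client could then be created with `Mistral(api_key=load_api_key())`, which keeps the secret out of the saved notebook.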

- Defines the OCR extraction functionality. It identifies the file format and performs OCR using the Mistral OCR API, returning markdown-formatted text. It supports PDF, JPG, JPEG, and PNG formats.

In [None]:
def extract_info_use_OCR(file_prefix='uploaded_file', model="mistral-ocr-latest") -> str:
    """
    Extracts text information from an image or PDF file using OCR.
    
    Args:
        file_prefix (str): Prefix of the file without extension (e.g., 'uploaded_file').
        model (str): The OCR model to use for processing. Defaults to "mistral-ocr-latest".

    Returns: 
        str: A formatted string containing the extracted text in markdown format.
    
    Raises:
        FileNotFoundError: If no file matches the prefix.
        ValueError: If several files match the prefix or the format is unsupported.
        AssertionError: If the matched path is not a regular file.
    """
    matching_files = glob.glob(f"{file_prefix}.*")
    
    if not matching_files:
        raise FileNotFoundError(f"No files found with prefix: {file_prefix}")
    
    if len(matching_files) > 1:
        raise ValueError(f"Multiple files found with prefix {file_prefix}: {matching_files}")
    
    file_path = matching_files[0]

    file = Path(file_path)
    assert file.is_file(), f"File not found: {file_path}"

    # Determine file type
    file_extension = file.suffix.lower()
    mime_type, _ = mimetypes.guess_type(file_path)
    
    # Process based on file type
    if file_extension in ['.jpg', '.jpeg', '.png'] or mime_type in ['image/jpeg', 'image/png']:
        return _process_image(file, model)
    elif file_extension == '.pdf' or mime_type == 'application/pdf':
        return _process_pdf(file, model)
    else:
        raise ValueError(f"Unsupported file format: {file_extension}. Supported formats: JPG, JPEG, PNG, PDF")

def _process_image(image_file, model):
    """Helper function to process image files with OCR."""
    # Encode image as a base64 data URL, using the file's actual MIME type
    mime_type, _ = mimetypes.guess_type(image_file.name)
    encoded = base64.b64encode(image_file.read_bytes()).decode()
    base64_data_url = f"data:{mime_type or 'image/jpeg'};base64,{encoded}"

    # Process image with OCR
    image_response = client.ocr.process(
        document=ImageURLChunk(image_url=base64_data_url),
        model=model
    )

    # A single image yields a single page
    image_ocr_markdown = image_response.pages[0].markdown

    response = f"""
    This is the image's OCR output in markdown:\n\n{image_ocr_markdown}\n.\n 
    """ 
    return response

def _process_pdf(pdf_file, model):
    """Helper function to process PDF files with OCR."""
    # Upload PDF file to Mistral's OCR service
    uploaded_file = client.files.upload(
        file={
            "file_name": pdf_file.stem,
            "content": pdf_file.read_bytes(),
        },
        purpose="ocr",
    )

    # Get URL for the uploaded file
    signed_url = client.files.get_signed_url(file_id=uploaded_file.id, expiry=1)

    # Process PDF with OCR, including embedded images
    pdf_response = client.ocr.process(
        document=DocumentURLChunk(document_url=signed_url.url),
        model=model,
        include_image_base64=True
    )
    
    # Collect the markdown of every page and join the pages with blank lines
    all_pages_markdown = [page.markdown for page in pdf_response.pages]
    combined_markdown = "\n\n".join(all_pages_markdown)
    
    response = f"""
    This is the information extracted from the file:\n\n{combined_markdown}\n.\n 
    """
    return response

extract_info_use_OCR_tool = FunctionTool.from_defaults(fn=extract_info_use_OCR)
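
The file lookup used by the OCR tool (find exactly one file with the given prefix, then branch on its extension) can be tried out on its own without calling any API. The sketch below is a simplified stand-in under stated assumptions: `resolve_upload` and `SUPPORTED` are hypothetical names, and it decides by extension only, whereas the function above also consults the guessed MIME type.

```python
from pathlib import Path

# Hypothetical mapping from extension to the processing path it would take
SUPPORTED = {".jpg": "image", ".jpeg": "image", ".png": "image", ".pdf": "pdf"}


def resolve_upload(prefix: str, directory: str = ".") -> tuple[Path, str]:
    """Find the single file named `<prefix>.<ext>` and classify it."""
    matches = sorted(Path(directory).glob(f"{prefix}.*"))
    if not matches:
        raise FileNotFoundError(f"No files found with prefix: {prefix}")
    if len(matches) > 1:
        raise ValueError(f"Multiple files found with prefix {prefix}: {matches}")

    # Compare the extension case-insensitively
    kind = SUPPORTED.get(matches[0].suffix.lower())
    if kind is None:
        raise ValueError(f"Unsupported file format: {matches[0].suffix}")
    return matches[0], kind
```

Because the lookup is by prefix, two uploads with different extensions in the same directory make it ambiguous, which is why the multiple-match case raises instead of picking one.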

- Handles vision queries with the Pixtral 12B model on AWS Bedrock. PDFs are first converted to page images, which are then sent together with the user's prompt through the Bedrock Converse API.

In [None]:
# Create bedrock runtime object
bedrock_runtime = boto3.client("bedrock-runtime")

def vision_unstanding_use_pixtral(prompt="", file_prefix='uploaded_file'):
    """
    Processes images or PDF documents with the Pixtral vision model.
    
    Args:
        prompt (str): Text prompt to send to the vision model. Defaults to empty string.
        file_prefix (str): Prefix of the file to process (without extension).
                           Defaults to 'uploaded_file'.
                           The function will search for any file matching this prefix.
                          
    Returns:
        str: The text response from the vision model based on the provided images and prompt.

    """
    # Configuration
    model_id = '<BEDROCK_MARKETPLACE_MODEL_ARN>'
    config = {
        "temperature": 0.6,
        "top_p": 0.9,
        "max_tokens": 5000
    }
    
    system_prompt = '''
    You are a helpful ai assistant for a vision related task. 
    You generate insights and answer questions based on provided images. 
    '''
    
    # Find files matching the prefix
    matching_files = glob.glob(f"{file_prefix}.*")
    
    if not matching_files:
        raise FileNotFoundError(f"No files found with prefix: {file_prefix}")
    
    if len(matching_files) > 1:
        raise ValueError(f"Multiple files found with prefix {file_prefix}: {matching_files}. Please specify a unique prefix.")
    
    file_path = matching_files[0]
    file = Path(file_path)


    # Process file based on type
    suffix = file.suffix.lower()
    temp_dir = None  # Set only when PDF pages are rendered to temporary images
    try:
        if suffix == '.pdf':
            image_paths = _convert_pdf_to_images(file)
            # All page images share one temporary directory
            temp_dir = image_paths[0].parent if image_paths else None
        elif suffix in ['.jpg', '.jpeg', '.png', '.gif', '.bmp']:
            image_paths = [file_path]
        else:
            raise ValueError(f"Unsupported file format: {file.suffix}. Supported formats: PDF, JPG, JPEG, PNG, GIF, BMP")

        # Get model response from images
        return _get_vision_model_response(prompt, image_paths, model_id, system_prompt, config)
    finally:
        # Clean up temporary page images on success and on error
        if temp_dir is not None:
            _cleanup_temp_files(temp_dir)

def _convert_pdf_to_images(pdf_path):
    """Convert PDF pages to optimized images."""
    # Create temporary directory
    temp_dir = Path(tempfile.mkdtemp())
    
    # Convert PDF pages with optimized settings
    pages = convert_from_path(str(pdf_path), dpi=100)
    
    image_paths = []
    for i, page in enumerate(pages):
        # Resize to 50% of original size
        width, height = page.size
        new_width = width // 2
        new_height = height // 2
        resized_page = page.resize((new_width, new_height))
        
        # Save as optimized PNG (PNG is lossless, so no quality setting applies)
        image_path = temp_dir / f'page_{i+1}.png'
        resized_page.save(image_path, 'PNG', optimize=True)
        image_paths.append(image_path)
    
    return image_paths

def _get_image_format(image_path):
    """Determine the format of an image file."""
    with Image.open(image_path) as img:
        fmt = img.format.lower() if img.format else 'jpeg'
        if fmt == 'jpg':
            fmt = 'jpeg'
    return fmt

def _get_vision_model_response(prompt, image_paths, model_id, system_prompt, config):
    """Send images and prompt to the vision model and get response."""
    # Build content blocks
    content_blocks = []
    
    # Add text prompt if provided
    if prompt.strip():
        content_blocks.append({"text": prompt})
    
    # Add images
    for img_path in image_paths:
        fmt = _get_image_format(img_path)
        with open(img_path, 'rb') as f:
            image_raw_bytes = f.read()
            
        content_blocks.append({
            "image": {
                "format": fmt,
                "source": {
                    "bytes": image_raw_bytes
                }
            }
        })
    
    # Construct message payload
    messages = [{"role": "user", "content": content_blocks}]
    
    # Create request payload
    request_payload = {
        "messages": messages,
        "inferenceConfig": {
            "maxTokens": config["max_tokens"],
            "temperature": config["temperature"],
            "topP": config["top_p"]
        },
        "system": [{"text": system_prompt}],
        "modelId": model_id
    }
    
    # Call the model API
    response = bedrock_runtime.converse(**request_payload)
    
    # Extract text from response
    assistant_message = response.get('output', {}).get('message', {})
    assistant_content = assistant_message.get('content', [])
    result_text = "".join(block.get('text', '') for block in assistant_content)
    
    return result_text

def _cleanup_temp_files(temp_dir):
    """Clean up temporary files and directories."""
    import shutil
    try:
        shutil.rmtree(temp_dir)
    except Exception:
        pass

vision_unstanding_use_pixtral_tool = FunctionTool.from_defaults(fn=vision_unstanding_use_pixtral)
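
The message structure passed to the Converse API can be inspected offline before any call is made. The sketch below mirrors the content-block layout used in the cell above (an optional text block followed by one image block per page); `build_converse_messages` is a hypothetical helper name, and the empty-content check is an addition for illustration.

```python
def build_converse_messages(prompt: str, images: list[tuple[str, bytes]]) -> list[dict]:
    """Build a single user message from a prompt and (format, raw_bytes) image pairs."""
    content = []

    # The text block is optional: skip it when the prompt is blank
    if prompt.strip():
        content.append({"text": prompt})

    # One image block per image, carrying raw bytes (not base64 text)
    for fmt, data in images:
        content.append({"image": {"format": fmt, "source": {"bytes": data}}})

    if not content:
        raise ValueError("Provide a prompt, at least one image, or both.")
    return [{"role": "user", "content": content}]
```

For example, `build_converse_messages("Describe the chart", [("png", page_bytes)])` produces one message whose content holds the text block first and the image block second, which is the shape `bedrock_runtime.converse` receives as `messages` above.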

- Implements a LlamaIndex agent that decides which tool to invoke (OCR extraction or visual understanding) based on the user query.

In [None]:
from llama_index.core.agent import FunctionCallingAgentWorker
from llama_index.core.agent import AgentRunner

import time
current_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))

system_prompt = f"""
You are a Vision QA assistant that extracts information from uploaded images or PDF files.

When responding to user queries:
- Use the extract_info_use_OCR tool when asked to extract or list information from the image/file
- Use the vision_unstanding_use_pixtral tool when asked for general understanding or questions about the image/file
  - Pass the user's original query as the prompt argument of this tool

If you don't know the answer, respond only with: "Sorry, I don't know." Never fabricate information.

Current time is: {current_time}
"""

agent_worker = FunctionCallingAgentWorker.from_tools(
    [extract_info_use_OCR_tool, vision_unstanding_use_pixtral_tool], 
    llm=llm, 
    verbose=False, # Set verbose=True to display the full trace of steps. 
    system_prompt = system_prompt,
    # allow_parallel_tool_calls = True # Uncomment this line to allow multiple tool invocations
)
agent = AgentRunner(agent_worker)
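
The agent can only choose a tool by the name and description it is given. To my understanding, `FunctionTool.from_defaults` takes the tool name from the function's `__name__` and builds the description from the signature and docstring unless told otherwise, so those are the strings a system prompt should refer to. The sketch below approximates that metadata with the standard library only; `tool_metadata` and `sample_tool` are hypothetical names, and the exact text LlamaIndex produces may differ.

```python
import inspect


def tool_metadata(fn) -> dict:
    """Approximate the name/description pair a function-based tool exposes."""
    signature = inspect.signature(fn)
    doc = inspect.getdoc(fn) or ""
    return {
        "name": fn.__name__,
        "description": f"{fn.__name__}{signature}\n{doc}",
    }


def sample_tool(prompt: str = "", file_prefix: str = "uploaded_file") -> str:
    """Answer a question about the uploaded file."""
    return ""


meta = tool_metadata(sample_tool)
print(meta["name"])
print(meta["description"])
```

Running the same helper on the two tool functions defined earlier is a quick way to check which names and docstrings the LLM would be choosing between.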

- Demonstrates real-time interaction where the agent processes user input and chooses appropriate tools to extract information or analyse visual content.

In [None]:
while True:
    text_input = input("User: ")
    if text_input == "exit":
        break
    response = agent.chat(text_input)
    print(f"Agent: {response}")
    print("-" * 120)
    print(" New question: ")

## Gradio App 

In [None]:
!pip install gradio

In [None]:
# Define custom CSS for a box-like appearance
custom_css = """
.my-box {
    border: 2px solid #ccc;
    padding: 16px;
    border-radius: 8px;
    background-color: #f9f9f9;
    max-width: 300px;  /* sets a maximum width */
}
"""

In [None]:
import gradio as gr
import glob
import os
import shutil

# Function to handle file upload and renaming
def upload_file(file):
    if not file:
        return None, "⚠️ No file uploaded yet. Please upload a file first.", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)

    # With type="filepath" the value is a path; fall back to .name for file-like objects
    source_path = getattr(file, "name", file)

    # Get the file extension
    file_extension = os.path.splitext(source_path)[1].lower()

    # List of allowed extensions
    allowed_extensions = [".pdf", ".jpg", ".jpeg", ".png"]

    # Reject unsupported files, returning one value per output component
    if file_extension not in allowed_extensions:
        message = f"⚠️ Invalid file type. Please upload a file with one of the following extensions: {', '.join(allowed_extensions)}"
        return None, message, gr.update(visible=False), gr.update(visible=False), gr.update(value=message, visible=True)

    # Remove any earlier upload so only one file matches the 'uploaded_file' prefix
    for old_path in glob.glob("uploaded_file.*"):
        os.remove(old_path)

    # Copy the file to the working directory under the fixed name
    new_filename = f"uploaded_file{file_extension}"
    new_path = os.path.join("./", new_filename)
    shutil.copy(source_path, new_path)

    status = f"✅ File uploaded and renamed to {new_filename}"
    if file_extension == ".pdf":
        # Show the PDF component
        return new_path, status, gr.update(visible=False), gr.update(value=source_path, visible=True), gr.update(visible=False)
    # Show the image component
    return new_path, status, gr.update(value=source_path, visible=True), gr.update(visible=False), gr.update(visible=False)

def display_uploaded_file(file):
    # Check the file type and display the appropriate component
    file_extension = os.path.splitext(file.name)[1].lower()
    
    if file_extension in [".jpg", ".jpeg", ".png"]:  # Image file types
        return file  # Return image for display
    elif file_extension == ".pdf":  # PDF file type
        return file  # Return PDF file for download
    else:
        return "⚠️ Unsupported file type. Please upload an image or PDF."

# Function to clear uploaded file, chat history, trace output, and reset messages
def reset_all():
    # Remove the uploaded file if it exists
    uploaded_file_name = "uploaded_file"  # Set the filename you want to delete
    for filename in os.listdir("."):
        # Check if the file starts with "uploaded_file" (ignoring the suffix)
        if filename.startswith(uploaded_file_name):
            # Construct the full file path
            file_path = os.path.join(".", filename)
            
            # Check if it's a file (not a directory) before removing
            if os.path.isfile(file_path):
                os.remove(file_path)
                print(f"✅ {filename} has been deleted.")

    # Return the reset states
    return None, [], "⚠️ No file uploaded yet. Please upload a file first.", "", "", None


def add_user_message(chat_history, user_input):
    chat_history = chat_history or []
    chat_history.append((user_input, None))  # Append user's message with no response yet
    return chat_history

# Function to generate responses and log the LLM reasoning process
def chat_with_trace(user_input, uploaded_file_path, file_status, chat_history=None):
    # Avoid a shared mutable default; start from an empty history if none is given
    chat_history = chat_history or []

    # Check if a file has been uploaded (the user's message is already in the history)
    if not uploaded_file_path:
        chat_history.append((None, "⚠️ Please upload a file first before we continue the conversation."))
        return chat_history, "⚠️ No file has been uploaded yet.", ""
    
    # Store the reasoning process
    trace_steps = []
    trace_steps.append("🔄 Thinking...")
    trace_steps.append(f"📁 Using uploaded file: {uploaded_file_path}")
    
    # Get the AI response
    response = agent.chat(user_input)
    
    for tool_output in response.sources:
        tool_name = tool_output.tool_name  # Name of the tool invoked
        raw_output = tool_output.raw_output  # The raw output from the tool
        trace_steps.append("🔧 Calling Function...")
        trace_steps.append(tool_name)
        trace_steps.append("🔧 Function Output...")
        trace_steps.append(str(raw_output))  # Raw output is not guaranteed to be a string

    ai_response = response.response

    # Append to chat history
    chat_history.append((None, ai_response))

    trace_steps.append("✅ Response Generated.")

    # Join the trace steps, one per line
    trace_text = "\n".join(trace_steps)
    
    # The empty string clears the input textbox
    return chat_history, trace_text, ""
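
The trace formatting can be tested without an agent or a UI. The sketch below is an assumption-laden stand-in: `ToolCall` is a hypothetical dataclass exposing only the two attributes read from each tool output above (`tool_name` and `raw_output`), and the `max_chars` truncation is an addition for illustration, since OCR output for long PDFs can be large.

```python
from dataclasses import dataclass
from typing import Any, Iterable


@dataclass
class ToolCall:
    """Hypothetical stand-in for one tool output of an agent response."""
    tool_name: str
    raw_output: Any


def format_trace(file_path: str, sources: Iterable[ToolCall], max_chars: int = 500) -> str:
    """Render tool calls as newline-separated trace text."""
    lines = [f"Using uploaded file: {file_path}"]
    for call in sources:
        # Raw output may be any object, so convert it before joining
        output = str(call.raw_output)
        if len(output) > max_chars:
            output = output[:max_chars] + "..."
        lines.append(f"Called: {call.tool_name}")
        lines.append(f"Output: {output}")
    lines.append("Response generated.")
    return "\n".join(lines)
```

Keeping the formatting in a pure function like this makes it easy to check the trace text separately from the Gradio callbacks.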


In [None]:
# Create a Gradio interface
with gr.Blocks(css=custom_css) as chatbot_ui:

    gr.Markdown("")
    gr.Markdown("# 🤖 Chatbot with Mistral Models and Agent")

    with gr.Column(elem_classes="my-box"):
        gr.Markdown("""
        ### **🔥 Model and Agent Framework**\n\n
        - Agent Orchestration Model: **Mistral Large 2** \n\n
        - OCR Model: **Mistral OCR** \n\n
        - Vision Language Model (VLM): **Pixtral 12B** \n\n
        - Agent Framework: **LlamaIndex**\n\n
        """)
    

    uploaded_file_path = gr.State(None)
    chat_history = gr.State([])
    with gr.Row():
        with gr.Column(scale=2):  # Left: Chat Interface
            
            chatbox = gr.Chatbot(label="Chat Window")
            user_input = gr.Textbox(label="Your Message", placeholder="Type a message...")
            send_button = gr.Button("Send")

            # button to clear the file and chat history
            reset_button = gr.Button("Clear All")

        with gr.Column(scale=1):  # Right: Upload File and  LLM Trace
            # Add file upload component
            file_upload = gr.File(
                label="Upload File (PDF, JPG, PNG, JPEG)",
                file_types=[".pdf", ".jpg", ".jpeg", ".png"],
                type="filepath"
            )
            image_display = gr.Image(label="Uploaded Image", visible=False)
            pdf_display = gr.File(label="Uploaded PDF File", visible=False)
            error_message = gr.Textbox(label="Error", interactive=False, visible=False)
            
            # Status message for file upload
            file_status = gr.Textbox(
                label="File Status", 
                value="⚠️ No file uploaded yet. Please upload a file first.",
                interactive=False
            )
            
            trace_output = gr.Textbox(label="Agent Traces", interactive=False, lines=6)

    # Trigger the upload function when a file is uploaded
    file_upload.change(
        fn=upload_file,  # The function to run when a file is uploaded
        inputs=[file_upload],  # The input (the uploaded file)
        outputs=[uploaded_file_path, file_status, image_display, pdf_display, error_message]  # Outputs: stored path, status message, and the three preview components
    )

    # Send button: show the user's message first, then run the agent
    send_button.click(
        fn=add_user_message,
        inputs=[chatbox, user_input],
        outputs=[chatbox]
    ).then(
        fn=chat_with_trace,
        inputs=[user_input, uploaded_file_path, file_status, chatbox],
        outputs=[chatbox, trace_output, user_input]
    )

    # Button click event to clear the file, chat history, and reset the status
    reset_button.click(
        fn=reset_all,  # Function to reset everything
        inputs=[],  # No inputs needed for the reset
        outputs=[file_upload, chat_history, file_status, trace_output, user_input, chatbox]  # Outputs: reset values for file, chat history, file status, and trace output
    )

# Launch the Gradio app
chatbot_ui.launch(share=True)

In [None]:
agent.memory.reset() # clear the chat memory

## Conclusion 

This notebook combined Mistral OCR for text extraction with Pixtral 12B for visual understanding, and used a LlamaIndex agent backed by Mistral Large 2 to choose between the two for each query. The same agent is exposed through a command-line chat loop and through a Gradio app with file upload and agent traces.