# Invoke a Fine-Tuned Document-to-JSON Multimodal Model

This notebook demonstrates how to process documents using a deployed SageMaker endpoint. It shows how to:

1. Load and encode document images
2. Send requests to a SageMaker endpoint
3. Parse the responses and display the extracted information

## Prerequisites

Before running this notebook, ensure you have:

- AWS credentials configured
- Fine-tuned a model and deployed it to a SageMaker endpoint (see [06_deploy_model_endpoint.ipynb](./06_deploy_model_endpoint.ipynb))
- Access to the document images

## Setup

First, let's install and import the required modules.

In [None]:
%pip install --quiet boto3 json2table

In [None]:
import json
from pathlib import Path
from IPython.display import display, JSON, Image as IPImage, HTML
from typing import Dict, Optional, Any, Union
import boto3
from PIL import Image
import io
import base64
from json2table import convert
from utils.docdiff import get_pil_image, image_formatter

## Helpers
The following helper functions prepare the request, invoke the endpoint, and display the results.

In [None]:
def display_results(image_path: str, response: Dict) -> None:
    """
    Display the input image and the extracted information side by side using a flexbox layout.
    
    Args:
        image_path: Path to the input image
        response: Parsed response from the endpoint
    """
    from html import escape

    # Take the generated text from the first choice, if there is one
    choices = response.get("choices") or [{}]
    json_content = choices[0].get("message", {}).get("content")

    if json_content:
        try:
            # Render the model's JSON output as an HTML table
            html_table = convert(
                json.loads(json_content),
                table_attributes={"class": "table table-striped"}
            )
        except json.JSONDecodeError:
            # Fall back to the raw text if the model did not return valid JSON
            html_table = f"<pre>{escape(json_content)}</pre>"
    else:
        html_table = "No content found in response"

    img_html = image_formatter(get_pil_image(image_path))

    html = f"""
    <div style="font-size: 24px; font-weight: bold; text-align: center; margin-bottom: 20px;">Invoice Details</div>
    <div style="display: flex; gap: 0px; align-items: flex-start;">
        <!-- Image Section -->
        <div style="flex: 1; text-align: center;">
            {img_html}
        </div>
    
        <!-- Table Section -->
        <div style="flex: 1; overflow-x: auto;">
            {html_table}
        </div>
    </div>
    """

    display(HTML(html))
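
`display_results` relies on an OpenAI-style chat completion response, where the generated text sits at `choices[0].message.content`. The sketch below isolates that lookup with defensive parsing, so a missing choice or malformed JSON yields `None` instead of an exception. The sample response is hypothetical and trimmed to the fields the helper reads; `extract_json_content` is an illustrative name, not part of the notebook's `utils` package.

```python
import json
from typing import Any, Dict, Optional


def extract_json_content(response: Dict[str, Any]) -> Optional[Any]:
    """Return the parsed JSON from choices[0].message.content, or None."""
    choices = response.get("choices") or []
    if not choices:
        return None  # no completion returned
    content = choices[0].get("message", {}).get("content")
    if not content:
        return None  # empty or missing message content
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        return None  # the model produced text that is not valid JSON


# Hypothetical response, trimmed to the fields used above
sample_response = {
    "choices": [
        {"message": {"role": "assistant", "content": '{"TOTAL": "1,234.00", "DATE": "2021-03-04"}'}}
    ]
}

print(extract_json_content(sample_response))
print(extract_json_content({"choices": []}))
```

Returning `None` rather than raising keeps the display step usable when the model output is truncated or not valid JSON.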

In [None]:
def encode_image(image_path: Union[str, Path]) -> Optional[str]:
    """
    Convert an image to a base64 string with a data URI (MIME type) prefix.
    
    Args:
        image_path: Path to the image file
        
    Returns:
        Base64-encoded image string with a data URI prefix, or None if the image cannot be loaded
    """
    try:
        with Image.open(image_path) as image:
            # Re-save the image in its original format and look up the matching MIME type
            buffered = io.BytesIO()
            image.save(buffered, format=image.format)
            mime_type = Image.MIME[image.format]
            encoded = base64.b64encode(buffered.getvalue()).decode()
            return f"data:{mime_type};base64,{encoded}"
    except Exception as e:
        print(f"Error loading image: {e}")
        return None
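
`encode_image` decodes the file with Pillow and saves it again before base64-encoding it. If the file is already in a format the model accepts, a lighter alternative is to base64-encode the raw bytes and infer the MIME type from the file extension. The sketch below assumes the extension is accurate; `encode_file_as_data_uri` is a hypothetical helper name.

```python
import base64
import mimetypes
from pathlib import Path
from typing import Union


def encode_file_as_data_uri(path: Union[str, Path]) -> str:
    """Build a data URI (data:<mime>;base64,<payload>) from a file's raw bytes."""
    # Infer the MIME type from the extension only; the file contents are not inspected
    mime_type, _ = mimetypes.guess_type(str(path))
    if mime_type is None or not mime_type.startswith("image/"):
        raise ValueError(f"Cannot infer an image MIME type for {path}")
    encoded = base64.b64encode(Path(path).read_bytes()).decode("ascii")
    return f"data:{mime_type};base64,{encoded}"
```

For a `.png` file such as the sample invoice, the result starts with `data:image/png;base64,`. Skipping the decode and re-encode step also avoids any change Pillow might make to the image bytes.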

In [None]:
def load_schema(schema_path: Optional[str]) -> Optional[Dict[str, Any]]:
    """
    Load a JSON schema from a file.

    Args:
        schema_path: Path to the JSON schema file

    Returns:
        The loaded JSON schema, or None if no path is given or loading fails
    """
    if schema_path:
        try:
            with open(schema_path, 'r') as f:
                json_schema = json.load(f)
            print("Schema loaded successfully")
            return json_schema
        except Exception as e:
            print(f"Error loading schema: {e}")

    return None
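
When a schema is configured, the notebook later turns the schema's top-level `required` array into the comma-separated list of keys that goes into the prompt. The sketch below shows that step on a small, hypothetical schema (the real `groundtruth_schema.json` lists many more fields). Unlike the notebook, it falls back to all declared `properties` when `required` is missing; `properties_from_schema` is an illustrative name.

```python
from typing import Any, Dict


def properties_from_schema(schema: Dict[str, Any]) -> str:
    """Join the schema's top-level required keys into a prompt-friendly string."""
    # Assumption: fall back to every declared property if there is no "required" list
    keys = schema.get("required") or list(schema.get("properties", {}))
    return ", ".join(keys)


# Hypothetical, heavily trimmed invoice schema
example_schema = {
    "type": "object",
    "properties": {
        "DATE": {"type": "string"},
        "NUMBER": {"type": "string"},
        "TOTAL": {"type": "string"},
    },
    "required": ["DATE", "NUMBER", "TOTAL"],
}

print(properties_from_schema(example_schema))  # DATE, NUMBER, TOTAL
```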

In [None]:
def prepare_payload(
    model_name: str,
    image_base64: Optional[str],
    properties_to_extract: str = "",
    guided_decoding: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Prepare the request payload.
    
    Args:
        model_name: The model to invoke
        image_base64: Optional base64-encoded image (data URI)
        properties_to_extract: Optional JSON key names to extract; they are added to the prompt
        guided_decoding: Optional vLLM structured output configuration, e.g. {"guided_json": schema}
        
    Returns:
        Request payload dictionary
    """
    content = [{"type": "text", "text": "Document pages: "}]

    if image_base64:
        content.append({
            "type": "image_url",
            "image_url": {"url": image_base64}
        })

    content.append({
        "type": "text",
        "text": f"Process all document pages and extract the following information in JSON format: {properties_to_extract}"
    })

    payload = {
        "model": model_name,
        "messages": [
            {
                "role": "system",
                "content": "You are a document processing expert and assistant."
            },
            {
                "role": "user",
                "content": content
            }
        ]
    }

    # extra_body is an OpenAI-client option that merges its keys into the request body.
    # This payload is posted as raw JSON, so merge the structured output options
    # (such as guided_json) into the top level, where vLLM's server reads them.
    if guided_decoding:
        payload.update(guided_decoding)

    return payload
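
The request body follows the OpenAI chat completions format, with the image passed as a data URI inside an `image_url` content part. `extra_body` is a parameter of the OpenAI Python client that merges its keys into the request body; when you post raw JSON, as `invoke_endpoint` does, vLLM's OpenAI-compatible server expects options such as `guided_json` at the top level. The sketch below shows that layout. It assumes a vLLM version that accepts `guided_json` (newer releases may name this option differently), and `build_chat_body` is a hypothetical name.

```python
import json
from typing import Any, Dict, Optional


def build_chat_body(
    model: str,
    image_data_uri: str,
    instruction: str,
    guided_decoding: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Assemble an OpenAI-style chat request with optional top-level decoding options."""
    body: Dict[str, Any] = {
        "model": model,
        "messages": [
            {"role": "system", "content": "You are a document processing expert and assistant."},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Document pages: "},
                    {"type": "image_url", "image_url": {"url": image_data_uri}},
                    {"type": "text", "text": instruction},
                ],
            },
        ],
    }
    # Merge options such as {"guided_json": schema} into the top level,
    # which is what the OpenAI client's extra_body would do for you
    if guided_decoding:
        body.update(guided_decoding)
    return body


body = build_chat_body(
    "document-to-json",
    "data:image/png;base64,AAAA",
    "Extract DATE, TOTAL in JSON format",
    {"guided_json": {"type": "object", "required": ["DATE", "TOTAL"]}},
)
print(json.dumps(body, indent=2))
```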

    

In [None]:
def invoke_endpoint(endpoint_name: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """
    Send a payload to the SageMaker endpoint and return the parsed response.
    
    Args:
        endpoint_name: Name of the SageMaker endpoint
        payload: The request body to send to the endpoint
        
    Returns:
        The parsed JSON response body
    """
    runtime = boto3.client('sagemaker-runtime')

    print(f"Invoking {endpoint_name}")
    # Invoke the endpoint with a JSON request body
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='application/json',
        Body=json.dumps(payload)
    )
    
    # The response body is a stream; read it once and parse it as JSON
    response_body = json.loads(response['Body'].read().decode())
    return response_body
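
`invoke_endpoint` reads the `Body` of the boto3 response, a botocore `StreamingBody`, and parses it as JSON. The sketch below mimics that step with an in-memory stream so it runs without AWS access, and additionally flags a truncated generation through the OpenAI-style `finish_reason` field. `io.BytesIO` stands in for `StreamingBody` here (both expose `.read()`); the helper name and the sample payload are hypothetical.

```python
import io
import json
from typing import Any, Dict


def parse_invoke_response(response: Dict[str, Any]) -> Dict[str, Any]:
    """Decode the JSON body of an InvokeEndpoint-style response and flag truncation."""
    body = json.loads(response["Body"].read().decode("utf-8"))
    for choice in body.get("choices", []):
        # "length" means generation hit the token limit, so the JSON may be cut off
        if choice.get("finish_reason") == "length":
            print("Warning: output was truncated; the JSON may be incomplete")
    return body


# In-memory stand-in for the StreamingBody returned by boto3
fake_response = {
    "Body": io.BytesIO(
        json.dumps({"choices": [{"finish_reason": "stop", "message": {"content": "{}"}}]}).encode("utf-8")
    )
}

parsed = parse_invoke_response(fake_response)
print(parsed["choices"][0]["finish_reason"])  # stop
```

Note that the stream can only be read once, so keep the parsed dictionary rather than reading `Body` again.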
            

## Configuration
Replace the values below with your own.

In [None]:
ENDPOINT_NAME = "document-to-json" # Replace with your endpoint name
MODEL_NAME = ENDPOINT_NAME # Replace with the vLLM served model name if you changed it
IMAGE_PATH = "./data/Fatura2-invoices-original-strat2/images/Template1_Instance151.png" # Replace with image path

Optionally, you can use structured output (constrained decoding) to guide the model's response structure, for example to allow only JSON that conforms to a given schema:
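
vLLM's OpenAI-compatible server has accepted several constrained-decoding options as extra request fields, including `guided_json` (a JSON schema), `guided_choice` (a fixed list of strings), and `guided_regex` (a regular expression). Option names and availability vary across vLLM versions, so check the documentation for the version your endpoint runs. The sketch below only builds the option dictionaries; the schema and values are hypothetical.

```python
import json

# A minimal JSON schema: the output must be an object with string DATE and TOTAL fields
invoice_schema = {
    "type": "object",
    "properties": {
        "DATE": {"type": "string"},
        "TOTAL": {"type": "string"},
    },
    "required": ["DATE", "TOTAL"],
}

# Constrain the whole completion to match the schema (the option this notebook uses)
json_constraint = {"guided_json": invoice_schema}

# Alternatives: restrict output to one of a fixed set of strings, or to a regex
choice_constraint = {"guided_choice": ["invoice", "receipt", "other"]}
regex_constraint = {"guided_regex": r"\d{4}-\d{2}-\d{2}"}

# Assumption: each dictionary is merged into the top level of a raw request body,
# which is what the OpenAI client's extra_body parameter does
for constraint in (json_constraint, choice_constraint, regex_constraint):
    print(json.dumps(constraint))
```

Schema-guided decoding guarantees the shape of the output, not the correctness of the extracted values.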

In [None]:
# Optional: path to a JSON schema file describing the expected output format.
# Set to None to disable structured output.
SCHEMA_PATH = "./data/Fatura2-invoices-original-strat2/groundtruth_schema.json"

In [None]:
guided_decoding = None

In [None]:
properties_to_extract = ""
# properties_to_extract = "AMOUNT_DUE, BILL_TO, BUYER, CONDITIONS, DATE, DISCOUNT, DUE_DATE, GST(1%), GST(12%), GST(18%), GST(20%), GST(5%), GST(7%), GST(9%), GSTIN, GSTIN_BUYER, GSTIN_SELLER, INVOICE_INFO, LOGO, NOTE, NUMBER, PAYMENT_DETAILS, PO_NUMBER, SELLER_ADDRESS, SELLER_EMAIL, SELLER_NAME, SELLER_SITE, SEND_TO, SUB_TOTAL, TABLE, TAX, TITLE, TOTAL, TOTAL_WORDS"

## Prepare Input

In [None]:
if SCHEMA_PATH:
    schema = load_schema(SCHEMA_PATH)
    if schema:
        # Constrain decoding to the schema and list its required keys in the prompt
        guided_decoding = {"guided_json": schema}
        properties_to_extract = ", ".join(schema.get("required", []))

In [None]:
image_base64 = encode_image(IMAGE_PATH)
if image_base64 is None:
    raise ValueError(f"Could not encode image: {IMAGE_PATH}")

payload = prepare_payload(MODEL_NAME, image_base64, properties_to_extract, guided_decoding)
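
SageMaker real-time endpoints cap the request payload of `InvokeEndpoint` (6 MB at the time of writing; confirm the current quota in the SageMaker documentation), and base64 encoding inflates an image by about a third. A quick pre-flight check of the serialized size, sketched below, avoids a failed call on large multi-page inputs. `MAX_PAYLOAD_BYTES` and `check_payload_size` are illustrative names, and treating 6 MB as 6 × 1024 × 1024 bytes is an assumption.

```python
import base64
import json
from typing import Any, Dict

# Assumed limit for real-time InvokeEndpoint requests; verify against current AWS quotas
MAX_PAYLOAD_BYTES = 6 * 1024 * 1024


def check_payload_size(payload: Dict[str, Any], limit: int = MAX_PAYLOAD_BYTES) -> int:
    """Return the serialized payload size in bytes, raising if it exceeds the limit."""
    size = len(json.dumps(payload).encode("utf-8"))
    if size > limit:
        raise ValueError(f"Payload is {size:,} bytes; the limit is {limit:,} bytes")
    return size


# base64 turns every 3 input bytes into 4 output characters
raw_image = bytes(300)
encoded = base64.b64encode(raw_image).decode("ascii")
print(len(raw_image), len(encoded))  # 300 400

print(check_payload_size({"messages": [{"content": encoded}]}))
```

In the notebook, you could call `check_payload_size(payload)` right before invoking the endpoint.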

## Invoke Endpoint

In [None]:
response = invoke_endpoint(ENDPOINT_NAME, payload)

## Inspect Response

Next, let's look at the raw response.

In [None]:
JSON(response, expanded=True)
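
vLLM's OpenAI-compatible responses typically include a `usage` object with `prompt_tokens`, `completion_tokens`, and `total_tokens`. Image inputs can account for a large share of the prompt tokens, so these counts are worth checking when sizing the endpoint. The helper below and its sample numbers are hypothetical.

```python
from typing import Any, Dict


def summarize_usage(response: Dict[str, Any]) -> str:
    """Format the token counts from an OpenAI-style response, if present."""
    usage = response.get("usage")
    if not usage:
        return "No usage information in response"
    return (
        f"prompt={usage.get('prompt_tokens', 0)} "
        f"completion={usage.get('completion_tokens', 0)} "
        f"total={usage.get('total_tokens', 0)}"
    )


# Hypothetical counts for illustration only
sample = {"usage": {"prompt_tokens": 1500, "completion_tokens": 320, "total_tokens": 1820}}
print(summarize_usage(sample))  # prompt=1500 completion=320 total=1820
```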

Next, let's show the image of the document and the extracted JSON side by side:

In [None]:
display_results(IMAGE_PATH, response)

## Conclusion

We received a structured JSON output from the fine-tuned and deployed endpoint by sending the image of a document and the desired target schema.