<a href="https://colab.research.google.com/github/RDGopal/IB9AU-2026/blob/main/MLM2_OCR_with_Gemini.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Why Modern VLMs Excel for Real-World Receipts

Modern Vision-Language Models (VLMs) such as Google Gemini, GPT-4V, and Claude 3 are pre-trained on large, diverse collections of images and text. As a result, they can interpret and reason about a document in a general, open-ended way, instead of relying on rigid patterns learned from one specific document type.

The key advantages of using VLMs for real-world receipts include:

1.  **Better Generalization**: VLMs are more robust to variations in layout, font, language, and content. Their extensive pre-training lets them process receipts from a wide range of sources without fine-tuning for each specific layout or design.
2.  **Improved Handling of Noisy or Unstructured Data**: Unlike template-based OCR systems or models heavily reliant on learned document structures, VLMs can often make sense of incomplete, messy, or semi-structured data, inferring meaning from context and visual cues.
3.  **Semantic Understanding**: VLMs move beyond mere character recognition or pattern matching. They understand the *semantics* of the information on a receipt. For instance, they can identify 'total amount' even if it's labeled differently (e.g., 'Grand Total', 'Amount Due', 'Sum') or placed in an unusual location, because they grasp the underlying financial concept.
4.  **Zero-Shot/Few-Shot Learning**: With well-designed prompts, VLMs can often perform extraction tasks with no training examples for a given receipt type, or with only a few, making them highly adaptable to new data streams.

In essence, VLMs offer a flexible and scalable solution for the unpredictable nature of real-world financial documents, which makes them well suited to practical FinTech applications.

## Setup Gemini API Key

#### Instructions
1. Go to the Google AI Studio website (https://aistudio.google.com/app/apikey) to generate an API key. Make sure you are logged in with your Google account.
2. Click 'Create API key in new project' or 'Get API Key' to generate a new key.
3. Once generated, copy the API key.
4. In Google Colab, click on the 'Secrets' tab (🔑 icon on the left sidebar).
5. Click 'Add new secret'.
6. For the 'Name' field, type `GOOGLE_API_KEY`.
7. For the 'Value' field, paste the API key you copied from Google AI Studio.
8. Ensure the 'Notebook access' checkbox is enabled for this secret.
9. In a new code cell, load the API key from Colab secrets with `from google.colab import userdata` followed by `GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')`.
10. Finally, import `os` and set the environment variable for the Gemini API by running `os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY`.

This code cell loads the API key from Colab secrets and sets it as an environment variable for the Gemini API.

In [None]:
from google.colab import userdata
import os

# Load the API key from Colab secrets, stored under the name 'GOOGLE_API_KEY' (step 6)
GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')

# Set the environment variable for the Gemini API
os.environ["GOOGLE_API_KEY"] = GOOGLE_API_KEY

print("API key loaded and environment variable set.")
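`google.colab` cannot be imported outside Colab. The sketch below adds a fallback and assumes that, in that case, the key is supplied through an environment variable with the same name. `load_api_key` is a hypothetical helper and is not part of any library.

```python
import os


def load_api_key(name: str = "GOOGLE_API_KEY") -> str:
    """Return an API key from Colab secrets, or from an environment variable outside Colab.

    Hypothetical helper: the environment-variable fallback is an assumption for
    running the notebook locally.
    """
    try:
        from google.colab import userdata  # importable only inside Colab
    except ImportError:
        userdata = None

    if userdata is not None:
        return userdata.get(name)

    # Not in Colab: fall back to an environment variable with the same name
    value = os.environ.get(name)
    if not value:
        raise KeyError(f"{name} not found in Colab secrets or environment variables")
    return value
```

In Colab, `userdata.get` raises an error if the secret is missing or notebook access is disabled. Outside Colab, set `GOOGLE_API_KEY` in your shell before you start Jupyter.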

This code block imports the `google.generativeai` library and then lists all available Gemini models, filtering for those that support content generation. This helps identify which models can be used for tasks like receipt analysis.

In [None]:
import google.generativeai as genai

print("Listing available Gemini models:")
for m in genai.list_models():
    if "generateContent" in m.supported_generation_methods:
        print(f"Name: {m.name}, Description: {m.description}")

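This cell creates a `GenerativeModel` backed by `gemini-2.5-flash`, a multimodal model that accepts both images and text, and stores it in the variable `gemini_pro_vision`. Despite the variable name, the model is Gemini 2.5 Flash, not Gemini Pro Vision.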

In [None]:
# Create a model handle for Gemini 2.5 Flash (variable name kept for use in later cells)
gemini_pro_vision = genai.GenerativeModel('gemini-2.5-flash')

print("Google Gemini 2.5 Flash model initialized.")

This code loads the image at `/content/receipt1.jpg` into a PIL Image object and then displays it. Upload `receipt1.jpg` to the Colab session's `/content` directory before running the cell.

In [None]:
from PIL import Image
# Load the image from the specified path
receipt_image_path = '/content/receipt1.jpg'
receipt_image = Image.open(receipt_image_path)
print(f"Image '{receipt_image_path}' loaded successfully.")
display(receipt_image)
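Phone photos of receipts often carry an EXIF orientation tag. They can also be very large or stored in a mode that JPEG cannot hold, such as RGBA. The sketch below shows one possible preprocessing step using Pillow's `ImageOps.exif_transpose`, `Image.convert` and `Image.thumbnail`. `prepare_receipt_image` and the 2048-pixel cap are illustrative assumptions, not values taken from the Gemini documentation.

```python
from PIL import Image, ImageOps


def prepare_receipt_image(img: Image.Image, max_side: int = 2048) -> Image.Image:
    """Normalise a receipt photo before sending it to a VLM (hypothetical helper)."""
    # Rotate according to the EXIF orientation tag so the receipt is upright
    img = ImageOps.exif_transpose(img)
    # JPEG cannot store alpha or palette modes, so convert to plain RGB
    img = img.convert("RGB")
    # Shrink in place so the longer side is at most max_side (never upscales)
    img.thumbnail((max_side, max_side))
    return img
```

For example: `receipt_image = prepare_receipt_image(Image.open(receipt_image_path))`.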

This code block defines a detailed prompt for the Gemini VLM, instructing it to extract specific information from a receipt image into a structured JSON format. It then prepares the image and sends the request to the Gemini API, finally parsing and printing the structured response.

In [None]:
import io
from PIL import Image
import json

# Define a detailed prompt string for the Gemini VLM
prompt = """
Analyze this receipt image and extract the following information in a structured JSON format.
Ensure that all numerical values are returned as numbers (float or int), not strings.

Extract:
- "store_name": The name of the store.
- "date": The date of the transaction in YYYY-MM-DD format.
- "time": The time of the transaction in HH:MM format (24-hour).
- "items": A list of individual items purchased. Each item should be an object with:
    - "description": Name of the item.
    - "quantity": Number of units purchased (integer).
    - "unit_price": Price per unit (float).
    - "total_price": Total price for that item (float).
- "subtotal": The subtotal amount before tax (float).
- "tax": The tax amount (float).
- "total_amount": The final total amount paid (float).
- "currency": The currency symbol or code (e.g., "USD", "$").

If any information is not present, use null for its value.

Example JSON structure:
{
  "store_name": "Example Store",
  "date": "2023-10-26",
  "time": "14:30",
  "items": [
    {
      "description": "Item A",
      "quantity": 1,
      "unit_price": 10.50,
      "total_price": 10.50
    },
    {
      "description": "Item B",
      "quantity": 2,
      "unit_price": 5.25,
      "total_price": 10.50
    }
  ],
  "subtotal": 21.00,
  "tax": 1.50,
  "total_amount": 22.50,
  "currency": "$"
}
"""

# Convert receipt_image to JPEG bytes. Phone photos may load as MPO or RGBA images,
# so convert to RGB first: JPEG cannot store an alpha channel.
img_byte_arr = io.BytesIO()
receipt_image.convert('RGB').save(img_byte_arr, format='JPEG')
img_byte_arr = img_byte_arr.getvalue()

# Build the request: the prompt followed by the image as inline JPEG data
contents = [prompt, {"mime_type": "image/jpeg", "data": img_byte_arr}]

# Send the prompt and image to the model
print("Sending request to Gemini 2.5 Flash...")
response = gemini_pro_vision.generate_content(contents)

# Access the text part of the response
response_text = response.text

# Print the raw text response
print("\n--- Raw Gemini Response ---")
print(response_text)
print("---------------------------")

# Attempt to parse the response text as a JSON object
try:
    # Remove markdown code fences (```json ... ```) if present.
    # str.lstrip/rstrip remove character sets, not prefixes, so use removeprefix/removesuffix.
    cleaned_response_text = response_text.strip()
    cleaned_response_text = cleaned_response_text.removeprefix('```json').removeprefix('```')
    cleaned_response_text = cleaned_response_text.removesuffix('```').strip()

    structured_data = json.loads(cleaned_response_text)
    print("\n--- Parsed Structured Data (JSON) ---")
    print(json.dumps(structured_data, indent=2))
    print("-------------------------------------")
except json.JSONDecodeError as e:
    print(f"\nError decoding JSON from Gemini output: {e}")
    print("Raw response could not be parsed as valid JSON. Check the output above for format issues.")
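The prompt asks for line items, a subtotal, tax and a total, so the parsed output can be cross-checked for internal consistency. The sketch below assumes `structured_data` follows the schema in the prompt. `check_receipt_consistency` and its 0.01 absolute tolerance are hypothetical choices.

```python
def check_receipt_consistency(data: dict, tol: float = 0.01) -> list:
    """Return human-readable arithmetic inconsistencies in an extracted receipt."""
    issues = []
    items = data.get("items") or []

    # Each line: quantity * unit_price should equal total_price
    for i, item in enumerate(items):
        qty, unit, line_total = item.get("quantity"), item.get("unit_price"), item.get("total_price")
        if None not in (qty, unit, line_total) and abs(qty * unit - line_total) > tol:
            issues.append(f"item {i}: quantity * unit_price != total_price")

    # The sum of line totals should equal the subtotal
    subtotal = data.get("subtotal")
    line_totals = [item.get("total_price") for item in items]
    if subtotal is not None and line_totals and None not in line_totals:
        if abs(sum(line_totals) - subtotal) > tol:
            issues.append("sum of item totals != subtotal")

    # Subtotal plus tax should equal the total paid
    tax, total = data.get("tax"), data.get("total_amount")
    if None not in (subtotal, tax, total) and abs(subtotal + tax - total) > tol:
        issues.append("subtotal + tax != total_amount")

    return issues
```

A non-empty result does not prove the extraction is wrong. Discounts, service charges and tax-inclusive prices can all break these identities, so the checks work better for flagging receipts for review than for rejecting them. Usage: `print(check_receipt_consistency(structured_data))`.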

This cell loads the `naver-clova-ix/cord-v2` dataset using the `datasets` library.

In [None]:
from datasets import load_dataset

ds = load_dataset("naver-clova-ix/cord-v2")

This cell displays the dataset object, showing its structure and the available splits (train, validation, test) along with their features and number of rows.

In [None]:
ds

This code cell simply prints the total number of records available in the training split of the dataset, providing a quick overview of its size.

In [None]:
len(ds['train'])

This cell converts the first 5 records of the training dataset into a Pandas DataFrame and displays it, allowing for a quick inspection of the data structure and content, including the image and ground truth JSON.

In [None]:
import pandas as pd
pd.DataFrame(ds['train'][:5])

This code block displays the first 5 images of the training set alongside their ground truth. It builds an HTML table that embeds each image as base64-encoded PNG data and pretty-prints each ground-truth record as JSON.

In [None]:
from IPython.display import display, HTML
import pandas as pd
import json
import io
import base64

# Re-create the DataFrame of the first 5 training records so this cell is self-contained
df_first_5 = pd.DataFrame(ds['train'][:5])

print("Displaying the first 5 images and their ground truth in a table:")

html_output = """
<style>
  table {
    border-collapse: collapse;
    width: 100%;
  }
  th, td {
    border: 1px solid #ddd;
    padding: 8px;
    text-align: left;
    vertical-align: top;
  }
  th {
    background-color: #f2f2f2;
  }
  img {
    max-width: 300px; /* Limit image width */
    height: auto;
    display: block; /* Remove extra space below image */
  }
  pre {
    white-space: pre-wrap; /* Preserve whitespace and wrap text */
    word-wrap: break-word;
  }
</style>
<table>
  <tr>
    <th>Image</th>
    <th>Ground Truth</th>
  </tr>
"""

for i, row in df_first_5.iterrows():
    html_output += "<tr>"
    # Image column
    img_bytes = io.BytesIO()
    row['image'].save(img_bytes, format='PNG') # Save the PIL image to bytes
    img_base64 = base64.b64encode(img_bytes.getvalue()).decode('utf-8')
    html_output += f"<td><img src='data:image/png;base64,{img_base64}' width='200'></td>"

    # Ground Truth column
    ground_truth_json = json.loads(row['ground_truth'])
    html_output += f"<td><pre>{json.dumps(ground_truth_json, indent=2)}</pre></td>"
    html_output += "</tr>"
html_output += "</table>"

display(HTML(html_output))
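Comparing Gemini's item-level output with CORD requires mapping the `gt_parse` menu entries onto the prompt's `items` schema. The sketch below makes two assumptions: each menu entry uses the keys `nm` (name), `cnt` (count) and `price`, and a receipt with a single item may store `menu` as a bare dict instead of a list. Check both against the ground truth shown in the table. `cord_menu_items` is a hypothetical helper.

```python
import json


def cord_menu_items(ground_truth_str: str) -> list:
    """Return CORD menu entries as a list of dicts in the prompt's item naming.

    Assumes gt_parse['menu'] is a list of item dicts or, for a single item,
    a bare dict; both shapes are normalised to a list. Values stay as strings.
    """
    gt_parse = json.loads(ground_truth_str).get("gt_parse", {})
    menu = gt_parse.get("menu", [])
    if isinstance(menu, dict):
        menu = [menu]
    return [
        {"description": m.get("nm"), "quantity": m.get("cnt"), "price": m.get("price")}
        for m in menu
        if isinstance(m, dict)
    ]
```

The prices are returned unchanged as strings and still need the cleaning applied in the next cell before any numeric comparison.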

This code block extracts the image, store name, and overall total price for each receipt in `ds['train']`. It removes separators and other non-digit characters from the total price string so that it can be converted to a float, then builds a DataFrame named `df_receipt_summary` from the extracted values.

In [None]:
import pandas as pd
import json
import re # Import regular expression module

# Initialize an empty list to store the extracted data for each receipt
receipt_data_df = []

# Iterate through each record in ds['train']
for record in ds['train']:
    image = record['image']
    ground_truth_str = record['ground_truth']

    store_name = None
    total_price = None

    try:
        ground_truth_json = json.loads(ground_truth_str)
        gt_parse = ground_truth_json.get('gt_parse', {})

        # Extract the store name (company name preferred, then seller name)
        company_name = gt_parse.get('company', {}).get('name')
        seller_name = gt_parse.get('seller', {}).get('name')

        if company_name:
            store_name = company_name
        elif seller_name:
            store_name = seller_name

        # Extract the total_price string
        total_info = gt_parse.get('total', {})
        total_price_str = total_info.get('total_price')

        # Clean and convert total_price to float.
        # CORD receipts are priced in Indonesian rupiah, where '.' and ',' act as
        # thousands separators (e.g. "60.000" means 60000), so every non-digit
        # character (separators, currency symbols, spaces) is removed.
        # Caveat: a genuine decimal such as "12.50" would become 1250.
        if isinstance(total_price_str, str):
            digits = re.sub(r'\D', '', total_price_str)
            total_price = float(digits) if digits else None

        # Append extracted data to the list
        receipt_data_df.append({
            'image': image,
            'store_name': store_name,
            'total_price': total_price
        })

    except json.JSONDecodeError as e:
        print(f"Error decoding JSON for a record: {e}")
        continue # Skip to the next record if JSON is invalid
    except Exception as e:
        print(f"An unexpected error occurred for a record: {e}")
        continue # Skip to the next record if an error occurs

# Create the new DataFrame
df_receipt_summary = pd.DataFrame(receipt_data_df)

print(f"Created df_receipt_summary with {len(df_receipt_summary)} records.")
print("First 5 rows of df_receipt_summary:")
display(df_receipt_summary.head())
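The cell above removes every separator. That works for rupiah amounts such as `60.000`, but it turns `12.50` into `1250`. To parse prices in other formats, such as Gemini's answers in the task below, a heuristic like the following may help. `parse_price` and its three-digit rule are assumptions, not a standard.

```python
import re
from typing import Optional


def parse_price(text) -> Optional[float]:
    """Parse strings such as '60.000', '1,234.50', '1.234,56' or 'Rp 45.500'.

    Heuristic (an assumption): if the last '.' or ',' is followed by exactly
    three digits, or by nothing, it is a thousands separator; otherwise it is
    the decimal point. Returns None when the text contains no digits.
    """
    if text is None:
        return None
    s = re.sub(r"[^\d.,]", "", str(text))  # keep digits and separators only
    if not re.search(r"\d", s):
        return None
    last_sep = max(s.rfind("."), s.rfind(","))
    if last_sep == -1:
        return float(s)  # plain integer string
    int_digits = re.sub(r"\D", "", s[:last_sep])
    frac_part = s[last_sep + 1:]  # digits only, since last_sep is the last separator
    if len(frac_part) == 3 or not frac_part:
        # Thousands separator (or trailing separator): join all digits
        return float(int_digits + frac_part)
    return float((int_digits or "0") + "." + frac_part)
```

The rule is ambiguous for strings such as `1.500`, which it reads as 1500. That is correct for rupiah but wrong for currencies that use three decimal places.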

This cell assigns the `df_receipt_summary` DataFrame (created in the previous data extraction step) to `df_receipt`, as per the task instructions. It then prints the number of records and displays the head of `df_receipt` to verify its content and structure.

In [None]:
df_receipt = df_receipt_summary

print(f"Number of records in df_receipt: {len(df_receipt)}")
print("First 5 rows of df_receipt:")
display(df_receipt.head())

This cell calculates and prints the number of records in the `df_receipt` DataFrame where the 'store_name' column has a `None` value. This is useful for understanding the completeness of the extracted store name data.

In [None]:
num_none_store_name = df_receipt['store_name'].isnull().sum()
print(f"Number of records with None in 'store_name': {num_none_store_name}")

This cell removes the 'store_name' column from the `df_receipt` DataFrame, as it was found to be entirely `None`. It then displays the head of the DataFrame to confirm the column's removal.

In [None]:
df_receipt = df_receipt.drop(columns=['store_name'])
print(" 'store_name' column removed from df_receipt.")
print("First 5 rows of df_receipt after column removal:")
display(df_receipt.head())

# Required Task 10
Your task is to utilize the Gemini VLM to predict the `total_price` for a subset of receipts and then evaluate the model's performance against the ground truth `total_price` already present in `df_receipt`.

#### Instructions:

1.  **Randomly Select 15 Records**:
    *   From the `df_receipt` DataFrame, randomly select 15 receipts (fix `random_state` in `DataFrame.sample` so the selection is reproducible). This will be your test set for Gemini's prediction.
    *   Store these 15 records in a new DataFrame, say `df_test_receipts`.

2.  **Define a Prompt for Gemini**:
    *   Create a clear and concise prompt that instructs Gemini to extract only the `total_price` from a given receipt image. Emphasize that the output should be a single numerical value (float).
    *   Example prompt: `"Extract the total amount from this receipt. Provide only the numerical value as a float."`

3.  **Process `df_test_receipts` with Gemini**:
    *   Iterate through each row in `df_test_receipts`.
    *   For each receipt's image, call the Gemini VLM with your defined prompt.
    *   Parse Gemini's response to extract the predicted `total_price`. Handle potential errors (e.g., non-numeric responses, API issues) by assigning `None` or `NaN` if a valid price cannot be extracted.
    *   Add the extracted prediction as a new column, `predicted_total_price`, to `df_test_receipts`.

4.  **Evaluate Predictions**:
    *   Compare the `predicted_total_price` with the `total_price` (ground truth) in `df_test_receipts`.
    *   Calculate appropriate evaluation metrics. Consider the following:
        *   **Mean Absolute Error (MAE)**: Average of the absolute differences between predicted and actual values.
        *   **Number of successful extractions**: Count how many predictions were successfully extracted (not `None` or `NaN`).
        *   **Accuracy within a threshold**: Calculate the percentage of predictions whose absolute error is within a given percentage (e.g., 5% or 10%) of the ground-truth value.

5.  **Display Results**:
    *   Print the calculated evaluation metrics.
    *   Display `df_test_receipts` with `total_price` and `predicted_total_price` columns for a few sample rows to show the comparison.
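The step 4 metrics could be computed along the lines of the sketch below. `evaluate_predictions` is a hypothetical helper. It assumes `df_test_receipts` has numeric (or coercible) `total_price` and `predicted_total_price` columns, and its 5% default tolerance is only an example value.

```python
import pandas as pd


def evaluate_predictions(df: pd.DataFrame,
                         truth_col: str = "total_price",
                         pred_col: str = "predicted_total_price",
                         rel_tol: float = 0.05) -> dict:
    """Return MAE, extraction count and within-tolerance accuracy (a sketch)."""
    # Coerce to numbers; unparseable values become NaN
    truth = pd.to_numeric(df[truth_col], errors="coerce")
    pred = pd.to_numeric(df[pred_col], errors="coerce")

    # Rows where both a prediction and a ground-truth value exist
    valid = truth.notna() & pred.notna()
    abs_err = (pred[valid] - truth[valid]).abs()

    # Relative error against the ground truth; rows with a zero total are skipped
    nonzero = truth[valid] != 0
    rel_err = abs_err[nonzero] / truth[valid][nonzero].abs()
    hits = int((rel_err <= rel_tol).sum())

    n_rows = len(df)
    return {
        "n_rows": n_rows,
        "n_extracted": int(pred.notna().sum()),
        "mae": float(abs_err.mean()) if valid.any() else float("nan"),
        # Failed extractions count as misses: the denominator is all rows
        "accuracy_within_tol": hits / n_rows if n_rows else float("nan"),
    }
```

In this sketch, MAE is averaged only over rows where both values exist, while the threshold accuracy divides by all rows. Report `n_extracted` alongside MAE so the two numbers are read together.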
