# Homework 1: Multi-Agent Supermarket Bill Analyzer

## Solution Overview
This solution implements a robust multi-agent system with:
- **Parallel Processing**: Concurrent analysis of multiple bill images
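- **Python Arithmetic**: The per-bill amounts are summed in Python rather than by the LLM (a `Calculator` tool is also defined)
- **Reflection Agent**: Re-checks the final calculation and can re-examine individual bill images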
- **Query Classification**: Rejection of irrelevant queries

## 1. Package Installation

In [1]:
!pip install langchain_google_genai langchain langchain-core -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.1/53.1 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.5/66.5 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m719.4/719.4 kB[0m [31m18.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m234.9/234.9 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-colab 1.0.0 requires google-auth==2.43.0, but you have google-auth 2.47.0 which is incompatible.[0m[31m
[0m

## 2. API Key Setup

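Store your Google API key in Colab Secrets under the name `VERTEX_API_KEY`. This cell also creates the Gemini model (`gemini-2.5-flash`, via Vertex AI) used by every agent below.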

In [40]:
import os
from google.colab import userdata

# Initialize the Gemini model
from langchain_google_genai import ChatGoogleGenerativeAI
# Shared Gemini chat model (gemini-2.5-flash through Vertex AI) used by every agent.
# temperature=0 keeps sampling as conservative as possible for extraction tasks.
# The key is read from Colab Secrets; VERTEX_API_KEY must be defined there.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    vertexai=True,
    temperature=0,
    api_key=userdata.get('VERTEX_API_KEY'),
)

## 3. Import Dependencies

In [41]:
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import RunnableParallel, RunnableLambda
from langchain_core.tools import Tool
from langchain_core.prompts import PromptTemplate
from PIL import Image
import base64
from io import BytesIO
import json
import re
from typing import List, Dict, Any

## 4. Helper Functions

In [19]:
import base64
import mimetypes

# Helper function to read and encode image
def image_to_base64(img_path):
    """Read the file at *img_path* in binary mode and return it Base64-encoded as a str."""
    with open(img_path, mode="rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")

# Helper function to encode local file to Base64 Data URL
def get_image_data_url(image_path):
    """Return a ``data:<mime>;base64,<payload>`` URL for a local image file.

    The MIME type is guessed from the file extension with ``mimetypes``;
    when no type can be guessed, ``image/png`` is used.
    """
    guessed_type = mimetypes.guess_type(image_path)[0]
    mime_type = guessed_type if guessed_type is not None else "image/png"

    payload = image_to_base64(image_path)
    return "data:" + mime_type + ";base64," + payload

## 5. Python Calculator Tool

Provides exact mathematical calculations for financial amounts

In [27]:
import ast
import operator

# Binary/unary operators the calculator accepts; any other AST node is rejected.
_CALC_BINARY_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Pow: operator.pow,
}
_CALC_UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}
# Functions callable from an expression, e.g. "sum([12.3, 45.6])".
_CALC_FUNCTIONS = {"sum": sum, "abs": abs, "round": round, "min": min, "max": max}
# Limits so inputs like "9**9**9**9" cannot hang the process or exhaust memory.
_CALC_MAX_EXPONENT = 100
_CALC_MAX_INT_BITS = 4096


def _calc_eval(node):
    """Recursively evaluate a whitelisted arithmetic AST node; raise on anything else."""
    if isinstance(node, ast.Expression):
        return _calc_eval(node.body)
    if (isinstance(node, ast.Constant)
            and isinstance(node.value, (int, float))
            and not isinstance(node.value, bool)):
        return node.value
    if isinstance(node, ast.BinOp) and type(node.op) in _CALC_BINARY_OPS:
        left = _calc_eval(node.left)
        right = _calc_eval(node.right)
        # Only numbers may be combined arithmetically: this blocks list
        # repetition such as "[0] * 10**9".
        if not all(isinstance(v, (int, float)) for v in (left, right)):
            raise TypeError("arithmetic operands must be numbers")
        if isinstance(node.op, ast.Pow) and abs(right) > _CALC_MAX_EXPONENT:
            raise ValueError("exponent too large")
        result = _CALC_BINARY_OPS[type(node.op)](left, right)
        if isinstance(result, int) and result.bit_length() > _CALC_MAX_INT_BITS:
            raise ValueError("result too large")
        return result
    if isinstance(node, ast.UnaryOp) and type(node.op) in _CALC_UNARY_OPS:
        return _CALC_UNARY_OPS[type(node.op)](_calc_eval(node.operand))
    if isinstance(node, (ast.List, ast.Tuple)):
        values = [_calc_eval(element) for element in node.elts]
        return values if isinstance(node, ast.List) else tuple(values)
    if (isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id in _CALC_FUNCTIONS
            and not node.keywords):
        args = [_calc_eval(arg) for arg in node.args]
        return _CALC_FUNCTIONS[node.func.id](*args)
    raise ValueError(f"unsupported expression element: {type(node).__name__}")


def python_calculator(expression: str) -> str:
    """
    Safely evaluate a mathematical expression.

    The expression is parsed with ``ast`` and only numeric literals,
    ``+ - * / // **``, unary ``+``/``-``, parentheses, list/tuple literals
    and the functions ``sum``, ``abs``, ``round``, ``min`` and ``max`` are
    evaluated. Names, attributes, strings and everything else are rejected,
    so no arbitrary code can run.

    Args:
        expression: A mathematical expression as string
            (e.g., "123.5 + 456.7" or "sum([12.3, 45.6, 78.9])")

    Returns:
        The result as a string, rounded to 2 decimal places when numeric
        (currency), or "Error: <message>" if the expression is invalid.
    """
    try:
        tree = ast.parse(expression.strip(), mode="eval")
        result = _calc_eval(tree)

        # Round to 2 decimal places for currency
        if isinstance(result, (int, float)):
            return str(round(result, 2))
        return str(result)
    except Exception as e:
        # Tool boundary: the agent consumes a string, so report errors as text.
        return f"Error: {str(e)}"

# Create the calculator tool
# LangChain Tool wrapper exposing python_calculator to agents under the name "Calculator".
_calculator_tool_spec = dict(
    name="Calculator",
    description="""Useful for performing exact mathematical calculations.
    Input should be a valid Python expression like '123.5 + 456.7' or 'sum([12.3, 45.6, 78.9])'.
    Always use this tool for adding numbers to ensure accuracy.""",
    func=python_calculator,
)
calculator_tool = Tool(**_calculator_tool_spec)

## 6. Multi-Agent System Components

In [26]:
class BillAnalysisAgent:
    """
    Main agent for analyzing supermarket bills.

    Classifies the user's query with the LLM, then extracts amounts from each
    receipt image (one LLM call per image, dispatched concurrently through
    ``RunnableParallel``). The per-bill amounts are summed in Python, so the
    cross-bill arithmetic is not left to the model.
    """

    # Query types the extraction and summation steps know how to answer.
    _VALID_QUERY_TYPES = ("total_spent", "without_discount")

    def __init__(self, llm):
        self.llm = llm
        # Kept available for tool-using callers; the sums below are computed directly in Python.
        self.calculator = calculator_tool

    @staticmethod
    def _response_text(response) -> str:
        """Return the text of an LLM response whose ``content`` may be a str or a list of parts."""
        content = getattr(response, "content", response)
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for part in content:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and part.get("type") == "text":
                    parts.append(str(part.get("text", "")))
            return "".join(parts)
        return str(content)

    @staticmethod
    def _parse_json_object(text):
        """Return the first JSON object (dict) found in *text*, or None.

        Tolerates Markdown code fences and surrounding prose, which models
        often add even when asked for JSON.
        """
        if not isinstance(text, str):
            return None
        candidates = [text.strip()]
        # Outermost {...} span: strips ```json fences and keeps nested objects intact.
        outer = re.search(r"\{.*\}", text, re.DOTALL)
        if outer:
            candidates.append(outer.group())
        # First flat {...} span: handles replies containing several separate objects.
        flat = re.search(r"\{[^{}]*\}", text, re.DOTALL)
        if flat:
            candidates.append(flat.group())
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if isinstance(parsed, dict):
                return parsed
        return None

    @staticmethod
    def _to_amount(value: Any):
        """Coerce an extracted amount to float.

        Accepts numbers and numeric strings such as ``"HK$1,234.50"``.
        Returns None for booleans, None, and anything non-numeric.
        """
        if isinstance(value, bool):
            return None
        if isinstance(value, (int, float)):
            return float(value)
        if isinstance(value, str):
            cleaned = re.sub(r"[^\d.\-]", "", value)
            try:
                return float(cleaned)
            except ValueError:
                return None
        return None

    def classify_query(self, query: str) -> Dict[str, Any]:
        """
        Classify the user query into valid or invalid categories.

        Returns:
            ``{"is_valid": bool, "query_type": "total_spent" | "without_discount" | "invalid",
            "reason": str}``. If the model's reply cannot be parsed, or it calls
            the query valid without naming a supported type, a keyword
            heuristic on the query text decides instead.
        """
        classification_prompt = f"""Analyze this query about supermarket bills and classify it:

Query: "{query}"

Valid query types:
1. "How much money did I spend in total?" (or similar)
2. "How much would I have paid without discount?" (or similar)

Respond in JSON format:
{{
  "is_valid": true/false,
  "query_type": "total_spent" / "without_discount" / "invalid",
  "reason": "explanation"
}}
"""

        response = self.llm.invoke([HumanMessage(content=classification_prompt)])
        parsed = self._parse_json_object(self._response_text(response))

        if parsed is not None and "is_valid" in parsed:
            flag = parsed.get("is_valid")
            if isinstance(flag, str):
                # Some replies quote booleans ("true"/"false").
                flag = flag.strip().lower() == "true"
            query_type = parsed.get("query_type")
            if not flag:
                return {
                    "is_valid": False,
                    "query_type": "invalid",
                    "reason": str(parsed.get("reason", "")),
                }
            if query_type in self._VALID_QUERY_TYPES:
                return {
                    "is_valid": True,
                    "query_type": query_type,
                    "reason": str(parsed.get("reason", "")),
                }
            # "Valid" but with an unknown type: fall through to the heuristic.

        # Fallback: keyword heuristic on the raw query text.
        lowered = query.lower()
        if "total" in lowered and ("spend" in lowered or "spent" in lowered):
            return {"is_valid": True, "query_type": "total_spent", "reason": "Query about total spending"}
        if "without" in lowered and "discount" in lowered:
            return {"is_valid": True, "query_type": "without_discount", "reason": "Query about price without discount"}
        return {"is_valid": False, "query_type": "invalid", "reason": "Query is not about bill totals or discounts"}

    def extract_bill_info(self, image_path: str, query_type: str) -> Dict[str, Any]:
        """
        Extract financial information from a single bill image.

        Args:
            image_path: Path to a local receipt image.
            query_type: ``"total_spent"`` asks only for the amount paid; any
                other value asks for the amount paid, the discounts and the
                pre-discount total.

        Returns:
            The model's JSON reply as a dict (keys as requested in the prompt).
            If no JSON can be parsed but a value labelled ``final_total``
            appears, ``{"final_total": x, "confidence": "medium"}``.
            Otherwise ``{"error": ..., "raw_response": ...}``.
        """
        # Load image
        image_data_url = get_image_data_url(image_path)

        if query_type == "total_spent":
            extraction_prompt = """Analyze this supermarket bill image carefully.

Extract the FINAL TOTAL amount that was actually paid (after all discounts).

Look for:
- "Total" or "總額" or "合計"
- The final amount charged
- The amount after all discounts applied

Respond in JSON format:
{
  "final_total": <number>,
  "currency": "HKD" or other,
  "confidence": "high"/"medium"/"low"
}

Be precise with the number. Only return the actual amount paid."""
        else:  # without_discount
            extraction_prompt = """Analyze this supermarket bill image carefully.

Calculate what the total would have been WITHOUT any discounts.

To do this:
1. Find the final total amount paid
2. Find all discount amounts (look for "Save", "Discount", "折扣", "優惠" etc.)
3. Sum all the discounts
4. Add discounts back to get original price

Respond in JSON format:
{
  "final_total": <number paid after discounts>,
  "total_discounts": <sum of all discounts>,
  "original_total": <total without discounts>,
  "confidence": "high"/"medium"/"low"
}

Be very careful to find ALL discount line items."""

        # Create message with image
        messages = [
            HumanMessage(
                content=[
                    {"type": "text", "text": extraction_prompt},
                    {"type": "image_url", "image_url": image_data_url}
                ]
            )
        ]

        response = self.llm.invoke(messages)
        content = self._response_text(response)

        result = self._parse_json_object(content)
        if result is not None:
            return result

        # Only trust a number explicitly labelled final_total: the first number
        # in free text could be a date, an item price or a quantity.
        labelled = re.search(r'final_total"?\s*[:=]\s*"?(-?[\d,]*\.?\d+)', content)
        if labelled:
            amount = self._to_amount(labelled.group(1))
            if amount is not None:
                return {"final_total": amount, "confidence": "medium"}
        return {"error": "Could not parse response", "raw_response": content}

    def process_multiple_bills(self, image_paths: List[str], query_type: str) -> List[Dict[str, Any]]:
        """
        Process multiple bill images concurrently using RunnableParallel.

        Returns one extraction dict per image, in the same order as *image_paths*.
        """
        if not image_paths:
            return []

        keys = [f"bill_{i}" for i in range(len(image_paths))]

        # Every task receives the whole {key: path} mapping and picks out its own
        # path; the key is bound as a default argument to avoid late binding.
        extraction_tasks = {
            key: RunnableLambda(
                lambda inputs, bill_key=key: self.extract_bill_info(inputs[bill_key], query_type)
            )
            for key in keys
        }

        parallel_chain = RunnableParallel(**extraction_tasks)
        results_dict = parallel_chain.invoke(dict(zip(keys, image_paths)))

        # Re-order explicitly by key rather than relying on the result dict's order.
        return [results_dict[key] for key in keys]

    def calculate_final_answer(self, extraction_results: List[Dict], query_type: str) -> float:
        """
        Sum the per-bill amounts in Python and round to 2 decimal places.

        For ``"total_spent"`` each bill contributes its ``final_total``. For any
        other query type it contributes ``original_total``, falling back to
        ``final_total``. Bills with no numeric amount (for example extraction
        errors) contribute nothing.
        """
        primary_key = "final_total" if query_type == "total_spent" else "original_total"

        amounts = []
        for result in extraction_results:
            if not isinstance(result, dict):
                continue
            amount = self._to_amount(result.get(primary_key))
            if amount is None and query_type != "total_spent":
                amount = self._to_amount(result.get("final_total"))
            if amount is not None:
                amounts.append(amount)

        # Currency: round away binary floating-point noise (e.g. 710.8000000000001).
        return float(round(sum(amounts), 2))

# Module-level bill agent shared by the pipeline function; wraps the Gemini model created above.
bill_agent = BillAnalysisAgent(llm=llm)

## 7. Reflection Agent

Verifies the accuracy of extracted information and calculations

In [24]:
class ReflectionAgent:
    """
    Reflection agent that double-checks the bill pipeline.

    ``verify_extraction`` asks the LLM to re-read one bill image and judge a
    previous extraction. ``verify_final_calculation`` recomputes the summed
    total in Python and compares it with the pipeline's answer.
    """

    def __init__(self, llm):
        self.llm = llm

    @staticmethod
    def _response_text(response) -> str:
        """Return the text of an LLM response whose ``content`` may be a str or a list of parts."""
        content = getattr(response, "content", response)
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for part in content:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and part.get("type") == "text":
                    parts.append(str(part.get("text", "")))
            return "".join(parts)
        return str(content)

    @staticmethod
    def _parse_verification(text) -> Dict[str, Any]:
        """Parse the verifier's JSON reply.

        If nothing parseable is found, return a low-confidence pass whose
        ``issues_found`` records that the reply could not be parsed. This
        keeps the best-effort behaviour without pretending the data was
        verified.
        """
        if isinstance(text, str):
            # Outermost {...} first (handles ```json fences), then the first flat object.
            for pattern in (r"\{.*\}", r"\{[^{}]*\}"):
                match = re.search(pattern, text, re.DOTALL)
                if not match:
                    continue
                try:
                    parsed = json.loads(match.group())
                except json.JSONDecodeError:
                    continue
                if isinstance(parsed, dict):
                    return parsed
        return {
            "is_correct": True,
            "confidence": "low",
            "issues_found": ["Could not parse verification response"],
        }

    @staticmethod
    def _to_amount(value: Any):
        """Coerce a number or numeric string (e.g. "HK$1,234.50") to float; None if not numeric."""
        if isinstance(value, bool):
            return None
        if isinstance(value, (int, float)):
            return float(value)
        if isinstance(value, str):
            cleaned = re.sub(r"[^\d.\-]", "", value)
            try:
                return float(cleaned)
            except ValueError:
                return None
        return None

    def verify_extraction(self, image_path: str, extracted_data: Dict, query_type: str) -> Dict[str, Any]:
        """
        Re-examine the bill image to verify extracted amounts.

        Returns the model's verdict dict, with keys as requested in the
        prompt: ``is_correct``, ``corrected_value``, ``issues_found`` and
        ``confidence``.
        """
        verification_prompt = f"""You are a verification agent. Review this bill image and verify the extracted information.

Extracted data: {json.dumps(extracted_data, indent=2)}
Query type: {query_type}

Verification tasks:
1. Confirm the final total amount is correct
2. If query is about 'without discount', verify all discounts were found
3. Check if the calculation makes sense

Respond in JSON format:
{{
  "is_correct": true/false,
  "corrected_value": <number if correction needed>,
  "issues_found": ["list of any issues"],
  "confidence": "high"/"medium"/"low"
}}
"""
        image_data_url = get_image_data_url(image_path)
        messages = [
            HumanMessage(
                content=[
                    {"type": "text", "text": verification_prompt},
                    {"type": "image_url", "image_url": image_data_url}
                ]
            )
        ]

        response = self.llm.invoke(messages)
        return self._parse_verification(self._response_text(response))

    def verify_final_calculation(self, extraction_results: List[Dict],
                                final_answer: float, query_type: str) -> Dict[str, Any]:
        """
        Verify the final calculation is mathematically correct.

        The amounts are recomputed the same way the bill agent selects them:
        ``final_total`` for "total_spent", otherwise ``original_total`` falling
        back to ``final_total``. Non-numeric amounts are skipped instead of
        crashing the sum.
        """
        primary_key = "final_total" if query_type == "total_spent" else "original_total"

        amounts = []
        for result in extraction_results:
            if not isinstance(result, dict):
                continue
            amount = self._to_amount(result.get(primary_key))
            if amount is None and query_type != "total_spent":
                amount = self._to_amount(result.get("final_total"))
            if amount is not None:
                amounts.append(amount)

        # Recalculate
        expected_sum = sum(amounts)

        # A cent of tolerance absorbs floating-point and rounding differences.
        is_correct = abs(expected_sum - final_answer) < 0.01

        return {
            "is_correct": is_correct,
            "expected_sum": round(expected_sum, 2),
            "actual_answer": round(final_answer, 2),
            "difference": round(abs(expected_sum - final_answer), 2),
            "individual_amounts": amounts
        }

# Module-level reflection agent used by the pipeline's verification step (same Gemini model).
reflection_agent = ReflectionAgent(llm=llm)

## 8. Main Pipeline Function

In [22]:
def analyze_bills_with_query(image_paths: List[str], query: str,
                            use_reflection: bool = True) -> Dict[str, Any]:
    """
    Run the full bill-analysis pipeline for one user query.

    Steps:
        1. classify the query;
        2. extract amounts from every bill image in parallel;
        3. sum the amounts in Python;
        4. optionally, have the reflection agent re-check the sum.

    Args:
        image_paths: List of paths to bill images
        query: User query string
        use_reflection: Whether to run the reflection agent's calculation check (default True)

    Returns:
        On rejection: ``{"status": "rejected", "reason": ..., "answer": None}``.
        On success, a dict with these keys:
            - ``status``
            - ``query_type``
            - ``answer`` (float)
            - ``extraction_results``
            - ``verification_results`` (None when reflection is disabled)
            - ``num_bills_processed``
    """
    print(f"\n{'='*60}")
    print(f"Processing query: {query}")
    print(f"Number of images: {len(image_paths)}")
    print(f"{'='*60}\n")

    # Step 1: Classify query
    print("Step 1: Classifying query...")
    classification = bill_agent.classify_query(query)
    print(f"Classification: {classification}")

    if not classification.get("is_valid"):
        return {
            "status": "rejected",
            "reason": classification.get("reason", "Query is not about bill totals or discounts"),
            "answer": None
        }

    query_type = classification.get("query_type")
    # Without this guard, any "valid" type other than "total_spent" would
    # silently be processed as "without_discount" downstream.
    if query_type not in ("total_spent", "without_discount"):
        return {
            "status": "rejected",
            "reason": f"Unsupported query type: {query_type!r}",
            "answer": None
        }

    # Step 2: Extract information from all bills in parallel
    print(f"\nStep 2: Extracting information from {len(image_paths)} bills in parallel...")
    extraction_results = bill_agent.process_multiple_bills(image_paths, query_type)
    print(f"Extraction completed. Results:")
    for i, result in enumerate(extraction_results):
        print(f"  Bill {i+1}: {result}")

    # Step 3: Sum the per-bill amounts in Python
    print(f"\nStep 3: Calculating final answer...")
    final_answer = bill_agent.calculate_final_answer(extraction_results, query_type)
    print(f"Initial answer: ${final_answer}")

    # Step 4: Reflection and verification (if enabled)
    verification_results = []
    if use_reflection:
        print(f"\nStep 4: Reflection agent verifying results...")

        # Verify calculation
        calc_verification = reflection_agent.verify_final_calculation(
            extraction_results, final_answer, query_type
        )
        print(f"Calculation verification: {calc_verification}")

        if not calc_verification["is_correct"]:
            print(f"⚠️  Calculation mismatch detected. Using corrected value.")
            final_answer = calc_verification["expected_sum"]

        verification_results.append(calc_verification)

    # Return final result
    print(f"\n{'='*60}")
    print(f"FINAL ANSWER: ${final_answer}")
    print(f"{'='*60}\n")

    return {
        "status": "success",
        "query_type": query_type,
        "answer": final_answer,
        "extraction_results": extraction_results,
        "verification_results": verification_results if use_reflection else None,
        "num_bills_processed": len(image_paths)
    }

## 9. Upload and Process Images

Download the receipt images (a zip archive on Google Drive) into the Colab working directory and extract them.

In [14]:
import gdown
# Google Drive file ID of the receipts archive; saved locally as receipts.zip.
file_id = "1oe2FZd3ZTO7nrDqjCafNvxicl08oF8JF"
download_url = "https://drive.google.com/uc?id=" + file_id
gdown.download(download_url, "receipts.zip", quiet=False)

Downloading...
From: https://drive.google.com/uc?id=1oe2FZd3ZTO7nrDqjCafNvxicl08oF8JF
To: /content/receipts.zip
100%|██████████| 1.61M/1.61M [00:00<00:00, 132MB/s]


'receipts.zip'

In [15]:
# Extract receipt1.jpg ... receipt7.jpg into the working directory (/content).
# The archive also contains __MACOSX/ metadata files, which are not used.
!unzip receipts.zip

Archive:  receipts.zip
  inflating: receipt1.jpg            
  inflating: __MACOSX/._receipt1.jpg  
  inflating: receipt2.jpg            
  inflating: __MACOSX/._receipt2.jpg  
  inflating: receipt3.jpg            
  inflating: __MACOSX/._receipt3.jpg  
  inflating: receipt4.jpg            
  inflating: __MACOSX/._receipt4.jpg  
  inflating: receipt5.jpg            
  inflating: __MACOSX/._receipt5.jpg  
  inflating: receipt6.jpg            
  inflating: __MACOSX/._receipt6.jpg  
  inflating: receipt7.jpg            
  inflating: __MACOSX/._receipt7.jpg  


In [35]:
# Paths of the seven extracted receipt images (receipt1.jpg ... receipt7.jpg).
image_paths = [f"/content/receipt{number}.jpg" for number in range(1, 8)]

## 10. Example Usage - Query 1

**Query 1**: How much money did I spend in total for these bills?

In [36]:
# Query 1: total amount actually paid across all bills, with the reflection check on.
query1 = "How much money did I spend in total for these bills?"
result1 = analyze_bills_with_query(image_paths, query1, use_reflection=True)

query1_answer = result1["answer"]  # graded by the Query 1 test cell below
print("\n✅ Query 1 Answer: $" + str(query1_answer))


Processing query: How much money did I spend in total for these bills?
Number of images: 7

Step 1: Classifying query...
Classification: {'is_valid': True, 'query_type': 'total_spent', 'reason': 'Query about total spending'}

Step 2: Extracting information from 7 bills in parallel...
Extraction completed. Results:
  Bill 1: {'final_total': 394.7, 'currency': 'HKD', 'confidence': 'high'}
  Bill 2: {'final_total': 316.1, 'currency': 'HKD', 'confidence': 'high'}
  Bill 3: {'final_total': 140.8, 'currency': 'HKD', 'confidence': 'high'}
  Bill 4: {'final_total': 514.0, 'currency': 'HKD', 'confidence': 'high'}
  Bill 5: {'final_total': 102.3, 'currency': 'HKD', 'confidence': 'high'}
  Bill 6: {'final_total': 190.8, 'currency': 'HKD', 'confidence': 'high'}
  Bill 7: {'final_total': 315.6, 'currency': 'HKD', 'confidence': 'high'}

Step 3: Calculating final answer...
Initial answer: $1974.3

Step 4: Reflection agent verifying results...
Calculation verification: {'is_correct': True, 'expected_

## 11. Example Usage - Query 2

**Query 2**: How much would I have had to pay without the discount?

In [37]:
# Query 2: what the bills would have cost before discounts, with the reflection check on.
query2 = "How much would I have had to pay without the discount?"
result2 = analyze_bills_with_query(image_paths, query2, use_reflection=True)

query2_answer = result2["answer"]  # graded by the Query 2 test cell below
print("\n✅ Query 2 Answer: $" + str(query2_answer))


Processing query: How much would I have had to pay without the discount?
Number of images: 7

Step 1: Classifying query...
Classification: {'is_valid': True, 'query_type': 'without_discount', 'reason': 'Query about price without discount'}

Step 2: Extracting information from 7 bills in parallel...
Extraction completed. Results:
  Bill 1: {'final_total': 394.7, 'total_discounts': 85.48, 'original_total': 480.2, 'confidence': 'high'}
  Bill 2: {'final_total': 316.1, 'total_discounts': 76.09, 'original_total': 392.19, 'confidence': 'high'}
  Bill 3: {'final_total': 140.8, 'total_discounts': 19.3, 'original_total': 160.1, 'confidence': 'high'}
  Bill 4: {'final_total': 514.0, 'total_discounts': 76.71, 'original_total': 590.8, 'confidence': 'high'}
  Bill 5: {'final_total': 102.3, 'total_discounts': 5.4, 'original_total': 107.7, 'confidence': 'high'}
  Bill 6: {'final_total': 190.8, 'total_discounts': 30.31, 'original_total': 221.11, 'confidence': 'high'}
  Bill 7: {'final_total': 315.6, 

## 12. Example Usage - Invalid Query

Test rejection of irrelevant queries

In [30]:
# Negative test: an off-topic query should be rejected at classification,
# before any bill image is processed.
invalid_query = "What's the weather today?"
result_invalid = analyze_bills_with_query(image_paths, invalid_query, use_reflection=False)

print("\n⛔ Invalid query result: " + str(result_invalid))


Processing query: What's the weather today?
Number of images: 2

Step 1: Classifying query...
Classification: {'is_valid': False, 'query_type': 'invalid', 'reason': 'Query is not about bill totals or discounts'}

⛔ Invalid query result: {'status': 'rejected', 'reason': 'Query is not about bill totals or discounts', 'answer': None}


## 13. Evaluation Code

Check each answer against the provided ground-truth amounts; an answer passes if it is within ±$2 of the expected total.

In [31]:
def test_query(answer, ground_truth_costs, tolerance=2):
    """
    Assert that *answer* is within *tolerance* dollars of the ground-truth total.

    Args:
        answer: The pipeline's answer (a number or a numeric string).
        ground_truth_costs: Per-bill ground-truth amounts; their sum is the expected total.
        tolerance: Maximum allowed absolute difference (default: ±$2).

    Raises:
        AssertionError: if *answer* is None (e.g. the query was rejected) or
            lies outside the tolerance.
    """
    # A rejected query yields answer=None; fail with a clear message instead of a TypeError.
    assert answer is not None, "No answer was produced (the query may have been rejected)"
    if isinstance(answer, str):
        answer = float(answer)

    expected_total = sum(ground_truth_costs)

    # Check if within +/- tolerance
    assert abs(answer - expected_total) <= tolerance, \
        f"Answer ${answer} is not within ${tolerance} of expected ${expected_total}"

    # Round only for display, so float noise (e.g. 2348.2000000000003) is hidden.
    print(f"✅ Test passed! Answer: ${answer}, Expected: ${round(expected_total, 2)}")


# Tell pytest this grading helper is not itself a test case (its name starts with "test_").
test_query.__test__ = False

In [38]:
# Ground truth for Query 1: the amount paid (after discounts) on each of the seven bills.
query_1_costs = [394.7, 316.1, 140.8, 514.0, 102.3, 190.8, 315.6]
test_query(query1_answer, ground_truth_costs=query_1_costs)

✅ Test passed! Answer: $1974.3, Expected: $1974.3


In [39]:
# Ground truth for Query 2: each bill's total before discounts.
query_2_costs = [480.20, 392.20, 160.10, 590.80, 107.70, 221.20, 396.00]
test_query(query2_answer, ground_truth_costs=query_2_costs)

✅ Test passed! Answer: $2348.06, Expected: $2348.2
