In [None]:
import asyncio
import base64
import logging
import os
from io import BytesIO
from typing import List, Literal, Optional

import requests
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from openai import OpenAI
from pdf2image import convert_from_bytes
from pydantic import BaseModel, Field

# Load environment variables (e.g. from a local .env file) into os.environ.
load_dotenv()

# Initialize FastAPI
app = FastAPI(title="Bill Extraction API", description="AI-powered invoice line item extractor")

# Initialize OpenAI Client
# The API key must be set in the environment variables (e.g., via Railway/Render dashboard).
# OpenAI(api_key=None) raises at construction time when no key is configured, which
# would crash the whole app on import. Only build the client when a key is present so
# the endpoint's own `if not api_key` check can report the misconfiguration instead.
api_key = os.getenv("OPENAI_API_KEY")
client = OpenAI(api_key=api_key) if api_key else None

# --- Pydantic Models (Strictly matching the provided JSON Schema) ---

class BillItem(BaseModel):
    # A single extracted bill line. This model is nested (via PageData) inside
    # ExtractionResult, which is passed to OpenAI as `response_format`, so the
    # field descriptions below are sent to the model as part of the output schema.
    # Keep field order/metadata stable: they define that schema.
    item_name: str = Field(description="Name of the item exactly as mentioned in the bill")
    item_amount: float = Field(description="Net Amount of the item post discounts")
    item_rate: float = Field(description="Unit rate of the item")
    item_quantity: float = Field(description="Quantity of the item")

class PageData(BaseModel):
    # Extraction result for one page. Produced by the model as an element of
    # ExtractionResult.pages and returned to API callers as an element of
    # ResponseData.pagewise_line_items.
    # page_no is declared as a string here, not an int.
    page_no: str = Field(description="The page number of the document")
    # Literal restricts the classification to exactly these three labels.
    page_type: Literal["Bill Detail", "Final Bill", "Pharmacy"] = Field(
        description="Classify the page. 'Bill Detail' for itemized lists, 'Pharmacy' for medical lists, 'Final Bill' for summary pages."
    )
    # May be empty: the system prompt asks for an empty list on summary pages
    # that only re-list earlier items.
    bill_items: List[BillItem]

class ExtractionResult(BaseModel):
    """Container for the structured output from the LLM"""
    # Used as the `response_format` of the OpenAI parse call. NOTE: pydantic
    # includes the class docstring above in the generated JSON schema, so
    # editing it changes what the model sees.
    pages: List[PageData]

# --- API Request/Response Schemas ---

class ExtractionRequest(BaseModel):
    # Request body for POST /extract-bill-data.
    document: str  # The URL of the file; fetched server-side with requests.get (see download_file)

class TokenUsage(BaseModel):
    # Token accounting for the OpenAI call, included in every APIResponse.
    total_tokens: int
    input_tokens: int  # filled from the OpenAI usage's prompt_tokens
    output_tokens: int  # filled from the OpenAI usage's completion_tokens

class ResponseData(BaseModel):
    # Response payload: the per-page line items plus their overall count.
    pagewise_line_items: List[PageData]
    total_item_count: int  # number of BillItem entries summed across all pages

class APIResponse(BaseModel):
    # Envelope returned by /extract-bill-data, both on success and for failures
    # the endpoint handles itself (is_success=False with `error` set).
    is_success: bool
    token_usage: TokenUsage
    data: Optional[ResponseData] = None
    error: Optional[str] = None  # error message when is_success is False; left as None on success

# --- Helper Functions ---

def download_file(url: str, max_bytes: int = 50 * 1024 * 1024) -> tuple[bytes, str]:
    """Download a document and return its raw bytes and Content-Type.

    The URL is supplied by the API caller, so only http(s) URLs are fetched and
    the body is streamed with a size cap instead of being buffered unbounded
    (an arbitrarily large file could otherwise exhaust server memory).

    Args:
        url: Location of the document (PDF or image) to fetch.
        max_bytes: Largest accepted response body, in bytes. Defaults to 50 MiB.

    Returns:
        ``(content, content_type)``; ``content_type`` is ``""`` when the server
        sends no Content-Type header.

    Raises:
        HTTPException: 400 when the URL is not http(s), the request fails or
            returns an error status, or the body is larger than ``max_bytes``.
    """
    # NOTE(review): the target host is not restricted, so internal addresses are
    # still reachable (SSRF). Add an allow-list if this service runs next to
    # private endpoints.
    if not url.strip().lower().startswith(("http://", "https://")):
        raise HTTPException(
            status_code=400,
            detail="Failed to download document: only http(s) URLs are supported",
        )

    too_large_detail = f"Failed to download document: file exceeds {max_bytes} bytes"
    try:
        # stream=True defers reading the body so the size cap can stop the
        # download early; the `with` block releases the connection afterwards.
        with requests.get(url, timeout=30, stream=True) as response:
            response.raise_for_status()
            content_type = response.headers.get("Content-Type", "")

            # Fail fast when the server already announces an oversized body.
            declared = response.headers.get("Content-Length", "")
            if declared.isdigit() and int(declared) > max_bytes:
                raise HTTPException(status_code=400, detail=too_large_detail)

            # Enforce the cap while reading too: Content-Length may be absent or wrong.
            chunks: List[bytes] = []
            received = 0
            for chunk in response.iter_content(chunk_size=64 * 1024):
                received += len(chunk)
                if received > max_bytes:
                    raise HTTPException(status_code=400, detail=too_large_detail)
                chunks.append(chunk)
    except requests.RequestException as e:
        raise HTTPException(status_code=400, detail=f"Failed to download document: {str(e)}")

    return b"".join(chunks), content_type

def _is_supported_image(data: bytes) -> bool:
    """Return True if ``data`` starts with a PNG, JPEG, GIF or WEBP file signature."""
    if data.startswith((b"\x89PNG\r\n\x1a\n", b"\xff\xd8\xff", b"GIF87a", b"GIF89a")):
        return True
    # WEBP is a RIFF container: "RIFF" <4-byte size> "WEBP".
    return data[:4] == b"RIFF" and data[8:12] == b"WEBP"


def process_document_to_images(file_content: bytes, content_type: str) -> List[str]:
    """Convert a downloaded PDF or image into base64-encoded page images.

    A PDF (detected by Content-Type or by its ``%PDF`` header) is rendered with
    pdf2image and every page is re-encoded as JPEG. Any other payload must be a
    PNG, JPEG, GIF or WEBP image and is passed through unchanged as one page.

    Args:
        file_content: Raw downloaded bytes.
        content_type: Content-Type reported for the download (may be ``""``).

    Returns:
        One base64 string (without a data-URL prefix) per page.

    Raises:
        HTTPException: 500 if PDF rendering fails (e.g. poppler missing or an
            unreadable PDF); 400 if the PDF yields no pages or a non-PDF
            payload is not a recognized image.
    """
    if "pdf" in content_type.lower() or file_content.startswith(b"%PDF"):
        try:
            # Requires the poppler-utils binaries to be installed on the system.
            pages = convert_from_bytes(file_content)
            base64_images = []
            for page in pages:
                buffered = BytesIO()
                page.save(buffered, format="JPEG", quality=85)
                base64_images.append(base64.b64encode(buffered.getvalue()).decode("utf-8"))
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"PDF processing failed: {str(e)}")
        if not base64_images:
            raise HTTPException(status_code=400, detail="PDF processing failed: document has no pages")
        return base64_images

    # Non-PDF payloads are forwarded verbatim, so reject anything the vision
    # model cannot read (e.g. an HTML error page served by the document URL)
    # with a clear 400 instead of an opaque downstream model error.
    if not _is_supported_image(file_content):
        raise HTTPException(
            status_code=400,
            detail="Image processing failed: unsupported or unrecognized image format "
            "(expected PDF, PNG, JPEG, GIF or WEBP)",
        )
    return [base64.b64encode(file_content).decode("utf-8")]

# --- Core Endpoint ---

_logger = logging.getLogger(__name__)

# System prompt for the vision model. The item_amount rule follows the BillItem
# field description (net amount post discounts) instead of forcing
# Rate * Quantity, which would overstate any line that carries a discount.
_SYSTEM_PROMPT = """
You are an expert invoice data extraction AI. Your goal is to digitize bills accurately.

CRITICAL EXTRACTION RULES:
1. **Goal**: Extract specific line items (products, services, medicines).
2. **Values**:
   - 'item_amount': The net amount charged for that line AFTER any line-level discount, exactly as printed on the bill. Do not replace it with Rate * Quantity when the bill shows a different net amount.
   - 'item_rate': Unit price. If missing, calculate as Amount / Quantity.
   - 'item_quantity': If missing, assume 1.
3. **Double Counting Prevention (The "Final Total" Rule)**:
   - Do NOT extract "Subtotal", "Total", "Balance Due", or "Grand Total" as line items. These are aggregations.
   - If a page is a 'Final Bill' or 'Summary' that re-lists items from previous pages, return an EMPTY list for that page to avoid double counting.
   - ONLY extract items from a 'Final Bill' page if they are NEW charges (e.g., 'Delivery Fee', 'Service Tax', 'Discount') that were not listed on the detail pages.
4. **Validation**: Ensure the sum of all 'item_amount' values roughly equals the invoice grand total.
"""

_USER_INSTRUCTION = "Analyze these invoice pages. Extract the line items into the structured JSON format."


def _image_mime_type(b64_image: str) -> str:
    """Return the media type of a base64 page image, sniffed from its magic bytes.

    PDF pages are re-encoded as JPEG by process_document_to_images, but a
    directly downloaded image is forwarded as-is, so a PNG/GIF/WEBP must not be
    labelled image/jpeg in its data URL. Falls back to image/jpeg.
    """
    # 16 base64 characters decode to the first 12 bytes, enough for every signature below.
    head = base64.b64decode(b64_image[:16])
    if head.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    if head.startswith((b"GIF87a", b"GIF89a")):
        return "image/gif"
    if head[:4] == b"RIFF" and head[8:12] == b"WEBP":
        return "image/webp"
    return "image/jpeg"


def _token_usage(usage) -> TokenUsage:
    """Map an OpenAI usage object to TokenUsage (all zeros when usage is unavailable)."""
    if usage is None:
        return TokenUsage(total_tokens=0, input_tokens=0, output_tokens=0)
    return TokenUsage(
        total_tokens=usage.total_tokens,
        input_tokens=usage.prompt_tokens,
        output_tokens=usage.completion_tokens,
    )


def _failure_response(error: str, usage=None) -> APIResponse:
    """Build the is_success=False envelope with empty data and the given error text."""
    return APIResponse(
        is_success=False,
        token_usage=_token_usage(usage),
        data=ResponseData(pagewise_line_items=[], total_item_count=0),  # Empty data on failure
        error=error,
    )


def _run_extraction(document_url: str) -> APIResponse:
    """Download, rasterize and extract line items for one document.

    Every step here blocks (HTTP download, poppler rendering, sync OpenAI call),
    so the endpoint runs this on a worker thread.
    """
    # 1. Download the file and 2. convert it to base64 page images.
    file_content, content_type = download_file(document_url)
    base64_images = process_document_to_images(file_content, content_type)
    if not base64_images:
        return _failure_response("Document contains no pages to analyze")

    # 3. One text instruction followed by every page image, in page order.
    user_content = [{"type": "text", "text": _USER_INSTRUCTION}]
    user_content.extend(
        {
            "type": "image_url",
            "image_url": {"url": f"data:{_image_mime_type(img)};base64,{img}", "detail": "high"},
        }
        for img in base64_images
    )

    # 4. Call OpenAI with Structured Outputs.
    completion = client.beta.chat.completions.parse(
        model="gpt-4o-2024-08-06",  # pinned snapshot supporting Structured Outputs with vision
        messages=[
            {"role": "system", "content": _SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ],
        response_format=ExtractionResult,
        temperature=0.0,  # Zero temperature for maximum determinism
    )

    message = completion.choices[0].message
    if message.parsed is None:
        # A refusal (or otherwise empty output) leaves `parsed` unset; report
        # it explicitly instead of failing later with an AttributeError.
        reason = getattr(message, "refusal", None) or "model returned no structured output"
        return _failure_response(f"Extraction failed: {reason}", completion.usage)

    # 5. Post-processing: total item count across all pages.
    pages = message.parsed.pages
    return APIResponse(
        is_success=True,
        token_usage=_token_usage(completion.usage),
        data=ResponseData(
            pagewise_line_items=pages,
            total_item_count=sum(len(page.bill_items) for page in pages),
        ),
    )


@app.post("/extract-bill-data", response_model=APIResponse)
async def extract_bill_data(request: ExtractionRequest):
    """Extract pagewise bill line items from the document at ``request.document``.

    Returns an APIResponse. Processing errors (download, PDF rendering, model
    call) are reported as is_success=False with ``error`` set rather than
    raised. Only a missing OPENAI_API_KEY raises an HTTP 500.
    """
    if not api_key:
        raise HTTPException(status_code=500, detail="Server Configuration Error: OPENAI_API_KEY is missing.")

    try:
        # Run the blocking pipeline off the event loop so concurrent requests
        # are not serialized behind one another.
        return await asyncio.to_thread(_run_extraction, request.document)
    except Exception as e:
        # Return a failed response structure rather than a 500 error; log first
        # so the failure is not silently lost. The URL is not logged because it
        # may embed access tokens.
        _logger.exception("Bill extraction failed")
        return _failure_response(str(e))

# For local testing: uvicorn main:app --reload