In [2]:
import pdfplumber
import re
import json
from datetime import datetime
import uuid
from tqdm import tqdm

# Avoid importing transformers at startup to reduce initial load time
classifier = None

def load_classifier():
    """Lazily load the DistilBERT pipeline with a progress indicator."""
    global classifier
    if classifier is None:
        print("Loading DistilBERT model...")
        from transformers import pipeline
        with tqdm(total=1, desc="Initializing model") as pbar:
            classifier = pipeline("zero-shot-classification", model="distilbert-base-uncased")
            pbar.update(1)
    return classifier

def extract_text_from_pdf(pdf_path, max_pages=1):
    """Extract raw text from a PDF invoice, limiting to max_pages."""
    try:
        with pdfplumber.open(pdf_path) as pdf:
            text = ""
            for i, page in enumerate(pdf.pages):
                if i >= max_pages:
                    break
                text += page.extract_text() or ""
        return text
    except Exception as e:
        print(f"Error reading PDF: {e}")
        return ""

def parse_invoice_fields(text):
    """Extract key fields using regex patterns."""
    fields = {
        "invoice_number": r"Invoice\s*(?:No|Number)\s*[:#]?\s*([\w-]+)",
        "vendor": r"(?:From|Vendor|Supplier)\s*[:\n]?\s*([A-Za-z0-9\s]+?)(?:\n|$)",
        "total": r"Total\s*(?:Amount)?\s*[:$]?\s*(\d+\.\d{2})",
        "date": r"(?:Invoice\s*Date|Date)\s*[:\n]?\s*(\d{1,2}/\d{1,2}/\d{2,4})",
        "due_date": r"Due\s*Date\s*[:\n]?\s*(\d{1,2}/\d{1,2}/\d{2,4})",
        "line_items": r"(?:\n|^)\s*(\d+)\s+([A-Za-z\s-]+?)\s+(\d+\.\d{2})\s*(?:\n|$)"
    }
    
    extracted = {}
    
    # Extract single-match fields
    for field, pattern in fields.items():
        if field == "line_items":
            continue
        match = re.search(pattern, text, re.IGNORECASE)
        extracted[field] = match.group(1).strip() if match else None
    
    # Extract multiple line items
    line_items = []
    for match in re.finditer(fields["line_items"], text, re.IGNORECASE):
        line_items.append({
            "quantity": match.group(1),
            "description": match.group(2).strip(),
            "price": match.group(3)
        })
    extracted["line_items"] = line_items if line_items else None
    
    return extracted

def validate_fields(extracted, use_llm=True):
    """Validate extracted fields, optionally using DistilBERT."""
    validated = extracted.copy()
    
    if use_llm:
        classifier = load_classifier()
        # Define expected formats for validation
        field_expectations = {
            "invoice_number": ["alphanumeric code", "random text"],
            "vendor": ["company name", "random text"],
            "total": ["decimal number", "random text"],
            "date": ["MM/DD/YYYY", "random text"],
            "due_date": ["MM/DD/YYYY", "random text"]
        }
        
        for field, value in extracted.items():
            if field == "line_items" or not value:
                continue
            # Classify if the extracted value matches the expected format
            result = classifier(value, field_expectations[field], multi_label=False)
            if result["labels"][0] != field_expectations[field][0]:
                validated[field] = None  # Invalidate if it doesn't match
    
    # Validate date formats specifically
    for date_field in ["date", "due_date"]:
        if validated.get(date_field):
            try:
                datetime.strptime(validated[date_field], "%m/%d/%Y")
            except ValueError:
                validated[date_field] = None
    
    return validated

def save_to_json(data, output_path):
    """Save extracted and validated fields to JSON."""
    with open(output_path, "w") as f:
        json.dump(data, f, indent=4)

def main(pdf_path, output_json_path, use_llm=True, max_pages=1):
    """Main function to process invoice and output JSON."""
    # Extract text from PDF
    text = extract_text_from_pdf(pdf_path, max_pages)
    if not text:
        print("No text extracted from PDF.")
        return
    
    # Parse fields using regex
    extracted_fields = parse_invoice_fields(text)
    
    # Validate fields (with or without LLM)
    validated_fields = validate_fields(extracted_fields, use_llm)
    
    # Add artifact ID for tracking
    validated_fields["artifact_id"] = str(uuid.uuid4())
    
    # Save to JSON
    save_to_json(validated_fields, output_json_path)
    print(f"Parsed invoice saved to {output_json_path}")
    
    # Print result
    print(json.dumps(validated_fields, indent=4))

if __name__ == "__main__":
    # Sample run (replace with actual PDF path)
    sample_pdf = "data/sample_invoice.pdf"
    output_json = "data/parsed_invoice.json"
    main(sample_pdf, output_json, use_llm=True, max_pages=1)

CropBox missing from /Page, defaulting to MediaBox
CropBox missing from /Page, defaulting to MediaBox


Loading DistilBERT model...


Initializing model:   0%|                                                                        | 0/1 [00:00<?, ?it/s]Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Device set to use cpu
Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to -1. Define a descriptive label2id mapping in the model config to ensure correct outputs.
Initializing model: 100%|████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.51it/s]


Parsed invoice saved to data/parsed_invoice.json
{
    "invoice_number": null,
    "vendor": null,
    "total": null,
    "date": "05/01/2025",
    "due_date": "05/15/2025",
    "line_items": null,
    "artifact_id": "02024bc0-991e-4a63-bd27-8c2ca7a88138"
}
