# In [6]:
import os
import io
import re
import uuid
import tempfile
import gradio as gr
import pandas as pd
from datetime import datetime
from azure.ai.formrecognizer import DocumentAnalysisClient
from azure.core.credentials import AzureKeyCredential
import openpyxl
from openpyxl.styles import PatternFill

# Azure Document Intelligence (Form Recognizer) connection settings.
# Read from the environment so secrets are not hard-coded in the notebook;
# both fall back to "" (unset), exactly as the previous literals did.
AZURE_ENDPOINT = os.environ.get("AZURE_FORM_RECOGNIZER_ENDPOINT", "")
AZURE_KEY = os.environ.get("AZURE_FORM_RECOGNIZER_KEY", "")

def get_azure_client():
    """Create a DocumentAnalysisClient for the configured endpoint and key."""
    key_credential = AzureKeyCredential(AZURE_KEY)
    return DocumentAnalysisClient(endpoint=AZURE_ENDPOINT, credential=key_credential)

def safe_float(value):
    """Convert *value* to a float rounded to 2 decimal places.

    Objects exposing an ``amount`` attribute (currency values) are unwrapped
    first; plain numbers and numeric strings are converted directly.

    Returns:
        The rounded float, or ``None`` when *value* cannot be interpreted as
        a number (``None``, ``""``, non-numeric text, ints too large for a
        float).
    """
    try:
        if hasattr(value, "amount"):
            value = value.amount
        return round(float(value), 2)
    except (TypeError, ValueError, OverflowError):
        # Only conversion failures mean "not a number"; anything else
        # (including KeyboardInterrupt) must propagate.
        return None

def format_currency(value):
    """Round a currency value to 2 decimal places.

    Returns:
        The rounded float, or ``None`` when *value* is ``None`` or cannot be
        converted to a float.
    """
    if value is None:
        return None
    try:
        return round(float(value), 2)
    except (TypeError, ValueError, OverflowError):
        # Narrow catch: only conversion failures map to None.
        return None

# Matches "for <Company>" as printed above a signature block,
# e.g. "For Acme Traders Authorised Signatory".
_FOR_COMPANY_RE = re.compile(r"\bfor\s+([A-Za-z][A-Za-z\s&.,\-]+)", re.IGNORECASE)
# Words that end the company name inside a signature line.
_SIGNATURE_STOP_WORDS = frozenset(
    ["authorised", "authorized", "signature", "sign", "date", "customer", "seal", "stamp"]
)
# Function words never accepted as a company name on their own.
_NON_NAME_WORDS = frozenset(["the", "and", "or", "of", "in", "on", "at"])


def _line_bottom(line):
    """Return the largest y coordinate of *line*'s polygon (0.0 if it has none)."""
    return max((point.y for point in (line.polygon or [])), default=0.0)


def _company_from_signature_line(content):
    """Return the company name following "for" in *content*, or ``None``."""
    match = _FOR_COMPANY_RE.search(content)
    if not match:
        return None
    company_name = match.group(1).strip()
    # Cut at the first hard break (newline, double space or tab).
    for delimiter in ("\n", "  ", "\t"):
        if delimiter in company_name:
            company_name = company_name.split(delimiter)[0].strip()
    # Keep words up to the first signature stop word ("Authorised", "Seal", ...).
    kept_words = []
    for word in company_name.split():
        if word.lower() in _SIGNATURE_STOP_WORDS:
            break
        kept_words.append(word)
    if not kept_words:
        return None
    company_name = " ".join(kept_words).strip()
    if (
        3 <= len(company_name) <= 60
        and company_name.lower() not in _NON_NAME_WORDS
        and not re.match(r"^\d+[-\s]\d+", company_name)
    ):
        return company_name
    return None


def extract_vendor_name(fields, result):
    """Return ``(vendor_name, confidence)`` for an analysed invoice.

    The ``VendorName`` field is used when it holds more than four non-blank
    characters, with its reported confidence (``1.0`` if none was reported).
    Otherwise the lines at the end of the document are scanned for a
    signature phrase like "for Acme Traders Authorised Signatory". The
    company name after "for" is returned with a fixed heuristic confidence
    of ``0.7``.

    Args:
        fields: Mapping of field name to field objects with ``.value`` and
            ``.confidence``.
        result: Analysis result whose ``pages`` hold ``lines`` with
            ``content`` and ``polygon`` points (``.y``).

    Returns:
        ``(name, confidence)``, or ``("Unknown", 0.0)`` if nothing is found.
    """
    raw_name = fields.get("VendorName")
    if raw_name and raw_name.value and len(raw_name.value.strip()) > 4:
        # ``is None`` rather than ``or``: a genuine 0.0 confidence must not be
        # promoted to 1.0 (which would be shown as high confidence).
        confidence = raw_name.confidence if raw_name.confidence is not None else 1.0
        return raw_name.value.strip(), confidence

    # Order lines from the end of the document backwards: last page first,
    # and within a page by descending y (larger y is treated as lower on the
    # page). Sorting by y alone would interleave lines from different pages.
    ordered = []
    for page_index, page in enumerate(result.pages or []):
        for line in page.lines or []:
            ordered.append((page_index, _line_bottom(line), line))
    ordered.sort(key=lambda entry: (entry[0], entry[1]), reverse=True)

    bottom_lines_count = max(5, len(ordered) // 3)
    for _, _, line in ordered[:bottom_lines_count]:
        company_name = _company_from_signature_line(line.content.strip())
        if company_name:
            return company_name, 0.7

    return "Unknown", 0.0

def format_with_confidence_indicator(value, confidence, discrepancy):
    """Return *value* followed by one status emoji for the review table.

    Precedence (highest first) is discrepancy ⚠️, then calculated 📊
    (``confidence == -1``), then low 🔴 (< 0.6, or unknown), medium 🟡
    (< 0.8) and high 🟢. This matches the ordering used by the confidence
    summary and the Excel colouring, so a flagged cell looks the same
    everywhere.

    Args:
        value: Cell value; ``None`` or ``""`` yields ``""`` with no emoji.
        confidence: Extraction confidence in [0, 1], ``-1`` for calculated
            values, or ``None`` (treated as low).
        discrepancy: True if a cross-check on this value failed.
    """
    if value is None or value == "":
        return ""

    if discrepancy:
        indicator = "⚠️"  # calculation discrepancy
    elif confidence == -1:
        indicator = "📊"  # calculated rather than extracted
    elif confidence is None or confidence < 0.6:
        indicator = "🔴"  # low confidence
    elif confidence < 0.8:
        indicator = "🟡"  # medium confidence
    else:
        indicator = "🟢"  # high confidence
    return f"{value} {indicator}"

def get_confidence_summary(metadata, max_listed=5):
    """Summarise which extracted fields need a reviewer's attention.

    Args:
        metadata: One dict per extracted row, mapping column name to
            ``(value, confidence, discrepancy)``; confidence ``-1`` marks a
            calculated value and ``None`` is treated as 0.0.
        max_listed: Maximum fields listed per category; the remainder is
            reported as "(+N more)".

    Returns:
        Newline-separated lines for discrepancies, low-confidence (< 0.6)
        fields and calculated values (only non-empty categories), or ``""``.
        Row numbers are 1-based so they can be typed directly into the
        "Row Number (1-indexed)" approve/reject input.
    """
    if not metadata:
        return ""

    discrepancy_fields = []
    low_conf_fields = []
    calculated_fields = []
    for row_number, row_data in enumerate(metadata, start=1):
        for field, (_value, conf, disc) in row_data.items():
            if field == "Status":
                continue
            label = f"Row {row_number}: {field}"
            if disc:
                discrepancy_fields.append(label)
            elif conf == -1:
                calculated_fields.append(label)
            else:
                conf = 0.0 if conf is None else conf
                if conf < 0.6:
                    low_conf_fields.append(f"{label} ({conf:.1%})")

    def _section(title, entries):
        # List the first ``max_listed`` entries and count the rest.
        text = f"{title}: {', '.join(entries[:max_listed])}"
        hidden = len(entries) - max_listed
        if hidden > 0:
            text += f" (+{hidden} more)"
        return text

    sections = [
        ("⚠️ Calculation Discrepancies", discrepancy_fields),
        ("🔴 Low Confidence Fields", low_conf_fields),
        ("📊 Calculated Values", calculated_fields),
    ]
    return "\n".join(_section(title, entries) for title, entries in sections if entries)

# Two amounts closer than this (in currency units) are treated as equal when
# cross-checking line items, subtotal, tax and total.
_ROUND_TOLERANCE = 0.01


def _field_confidence(field):
    """Return a field's confidence, or 0.0 if the field or its confidence is missing."""
    return field.confidence if field and field.confidence is not None else 0.0


def _analyze_invoice(file_content, file_name):
    """Run the ``prebuilt-invoice`` model on *file_content* and return the result.

    Raises:
        gr.Error: with a user-facing message (file too large, quota/rate
            limit, or the underlying error text) if analysis fails.
    """
    client = get_azure_client()
    try:
        poller = client.begin_analyze_document(
            "prebuilt-invoice", document=io.BytesIO(file_content)
        )
        return poller.result()
    except Exception as e:
        error_msg = str(e).lower()
        if "size" in error_msg and ("large" in error_msg or "limit" in error_msg or "maximum" in error_msg):
            raise gr.Error(f"File '{file_name}' is too large. Please reduce the file size and try again.") from e
        if "quota" in error_msg or "rate" in error_msg:
            raise gr.Error("API quota exceeded. Please try again later.") from e
        raise gr.Error(f"Error processing '{file_name}': {str(e)}") from e


def _fill_missing_amount(subtotal_val, tax_val, total_val):
    """Derive one missing amount from the other two (total = subtotal + tax).

    Returns ``(subtotal, tax, total, subtotal_calc, tax_calc, total_calc)``,
    where each ``*_calc`` flag marks a value that was computed, not extracted.
    """
    subtotal_calc = tax_calc = total_calc = False
    if subtotal_val is None and total_val is not None and tax_val is not None:
        subtotal_val = round(total_val - tax_val, 2)
        subtotal_calc = True
    if tax_val is None and subtotal_val is not None and total_val is not None:
        tax_val = round(total_val - subtotal_val, 2)
        tax_calc = True
    if total_val is None and subtotal_val is not None and tax_val is not None:
        total_val = round(subtotal_val + tax_val, 2)
        total_calc = True
    return subtotal_val, tax_val, total_val, subtotal_calc, tax_calc, total_calc


def _extract_line_items(items_field):
    """Flatten the ``Items`` field into display-ready lists.

    Returns ``(descriptions, amounts, amount_confs, description_confs, item_sum)``:
    whitespace-normalised descriptions, amounts formatted to 2 decimals
    (``"0.00"`` when missing or unparseable), per-item confidences, and the
    sum of the amounts rounded to 2 decimals.
    """
    descriptions, amounts, amount_confs, description_confs = [], [], [], []
    item_sum = 0.0
    if not (items_field and items_field.value):
        return descriptions, amounts, amount_confs, description_confs, item_sum

    for item in items_field.value:
        desc_field = item.value.get("Description")
        amount_field = item.value.get("Amount")
        desc = desc_field.value if desc_field else ""
        amt = amount_field.value if amount_field else ""

        # Collapse line breaks and runs of whitespace so each description
        # fits on one line of the "; "-joined cell.
        descriptions.append(re.sub(r"\s+", " ", str(desc).strip()))
        description_confs.append(_field_confidence(desc_field))

        numeric_amt = safe_float(amt)  # also unwraps currency objects (.amount)
        if numeric_amt is None:
            numeric_amt = 0.0
        amounts.append(f"{numeric_amt:.2f}")
        amount_confs.append(_field_confidence(amount_field))
        item_sum += numeric_amt

    return descriptions, amounts, amount_confs, description_confs, round(item_sum, 2)


def _detect_discrepancies(item_sum, has_items, subtotal_val, tax_val, total_val):
    """Cross-check the amounts; return ``(subtotal_discrepancy, total_discrepancy)``.

    * Subtotal: flagged when line items were extracted but their sum matches
      neither the subtotal nor the total.
    * Total: flagged when subtotal + tax != total; or, if subtotal or tax is
      unknown, when the line items add up to more than the total.

    Line-item checks are skipped when no items were extracted. Otherwise the
    empty sum of 0.00 would flag every non-zero subtotal.
    """
    matches_subtotal = subtotal_val is not None and abs(item_sum - subtotal_val) <= _ROUND_TOLERANCE
    matches_total = total_val is not None and abs(item_sum - total_val) <= _ROUND_TOLERANCE

    subtotal_discrepancy = (
        has_items and subtotal_val is not None and not (matches_subtotal or matches_total)
    )

    total_discrepancy = False
    if total_val is not None and subtotal_val is not None and tax_val is not None:
        total_discrepancy = abs((subtotal_val + tax_val) - total_val) > _ROUND_TOLERANCE
    elif total_val is not None and has_items and not matches_total:
        total_discrepancy = total_val + _ROUND_TOLERANCE < item_sum
    return subtotal_discrepancy, total_discrepancy


def process_invoice(file_content, file_name):
    """Extract one review row per invoice document found in a file.

    Args:
        file_content: Raw bytes of the uploaded PDF/image.
        file_name: Display name used in error messages.

    Returns:
        A list of row dicts mapping column name to
        ``(value, confidence, discrepancy)``. Confidence ``-1`` marks an
        amount calculated from the other two rather than extracted.

    Raises:
        gr.Error: if the Azure analysis fails.
    """
    result = _analyze_invoice(file_content, file_name)

    rows = []
    for doc in result.documents:
        fields = doc.fields
        vendor, vendor_conf = extract_vendor_name(fields, result)

        total_field = fields.get("InvoiceTotal")
        subtotal_field = fields.get("SubTotal")
        tax_field = fields.get("TotalTax")

        (subtotal_val, tax_val, total_val,
         subtotal_calc, tax_calc, total_calc) = _fill_missing_amount(
            safe_float(subtotal_field.value) if subtotal_field else None,
            safe_float(tax_field.value) if tax_field else None,
            safe_float(total_field.value) if total_field else None,
        )

        descriptions, amounts, amount_confs, description_confs, item_sum = (
            _extract_line_items(fields.get("Items"))
        )
        subtotal_discrepancy, total_discrepancy = _detect_discrepancies(
            item_sum, bool(amounts), subtotal_val, tax_val, total_val
        )

        rows.append({
            "Company Name": (vendor, vendor_conf, False),
            # Weakest description confidence, so a poorly read line is surfaced.
            "Items Descriptions": (
                "; ".join(descriptions),
                min(description_confs) if description_confs else 1.0,
                False,
            ),
            "Items Amounts": (
                "; ".join(amounts),
                min(amount_confs) if amount_confs else 0.0,
                subtotal_discrepancy,
            ),
            "Total Tax": (
                format_currency(tax_val),
                _field_confidence(tax_field) if not tax_calc else -1,
                False,
            ),
            "Subtotal": (
                format_currency(subtotal_val),
                _field_confidence(subtotal_field) if not subtotal_calc else -1,
                subtotal_discrepancy,
            ),
            "Total Amount": (
                format_currency(total_val),
                _field_confidence(total_field) if not total_calc else -1,
                total_discrepancy,
            ),
        })
    return rows

def process_files(files):
    """Analyse every uploaded invoice and build the review table.

    Args:
        files: Uploaded files, either as paths or as objects exposing their
            temp-file path as ``.name``.

    Returns:
        ``(dataframe, status_message, confidence_summary)``. On any failure
        the dataframe is ``None`` and the message explains the problem;
        processing stops at the first failing file.

    Side effects:
        Replaces the module-level ``extraction_metadata`` with the raw
        ``(value, confidence, discrepancy)`` rows (used for Excel colouring).
    """
    global extraction_metadata

    if not files:
        return None, "No files uploaded.", ""

    max_file_size = 50 * 1024 * 1024
    all_data = []

    for file in files:
        # Plain paths are used as-is; otherwise take the path from ``.name``
        # (the upload object's shape depends on the Gradio version / File
        # ``type`` setting -- verify for your deployment).
        path = file if isinstance(file, (str, os.PathLike)) else getattr(file, "name", file)
        file_name = os.path.basename(path)
        try:
            # Check the size on disk first so oversized uploads are rejected
            # without being read into memory.
            if os.path.getsize(path) > max_file_size:
                return None, f"File '{file_name}' is too large.", ""
            with open(path, 'rb') as f:
                content = f.read()
            file_data = process_invoice(content, file_name)
        except Exception as e:
            return None, f"Error processing {file_name}: {str(e)}", ""
        if file_data:
            all_data.extend(file_data)

    if not all_data:
        return None, "No data extracted from files.", ""

    # Display table: every cell carries its confidence/discrepancy emoji.
    df_data = {
        col: [format_with_confidence_indicator(*row[col]) for row in all_data]
        for col in all_data[0]
    }
    # Review status, updated later by approve_row / reject_row.
    df_data["Status"] = ["Pending"] * len(all_data)
    df = pd.DataFrame(df_data)

    # Keep the raw tuples for Excel export colouring.
    extraction_metadata = all_data

    confidence_summary = get_confidence_summary(all_data)
    status_message = f"Successfully processed {len(files)} files. {len(all_data)} bills of materials extracted."
    return df, status_message, confidence_summary

def clean_value_for_excel(value_str):
    """Strip confidence emoji markers (plus any whitespace after each) from a cell.

    Non-string values are returned untouched; strings come back trimmed.
    """
    if isinstance(value_str, str):
        without_markers = re.sub(r'[🔴🟡🟢⚠️📊]\s*', '', value_str)
        return without_markers.strip()
    return value_str

def _set_row_status(df, row_index, status):
    """Set the ``Status`` cell of 1-indexed row *row_index* to *status*, in place.

    Returns *df* unchanged when there is no table, the table has no
    ``Status`` column, or the row number is missing, not a finite number, or
    out of range.
    """
    if df is None or "Status" not in df.columns:
        return df
    try:
        row_idx = int(row_index) - 1
    except (TypeError, ValueError, OverflowError):
        # e.g. an empty number input arriving as None, or NaN/inf.
        return df
    if not 0 <= row_idx < len(df):
        return df
    df.iloc[row_idx, df.columns.get_loc("Status")] = status
    return df


def approve_row(df, row_index):
    """Mark 1-indexed row *row_index* as ``"Approved ✅"`` and return *df*."""
    return _set_row_status(df, row_index, "Approved ✅")


def reject_row(df, row_index):
    """Mark 1-indexed row *row_index* as ``"Rejected ❌"`` and return *df*."""
    return _set_row_status(df, row_index, "Rejected ❌")


def _to_number_if_possible(value):
    """Return *value* as a float if it is a numeric string, otherwise unchanged."""
    if isinstance(value, str):
        try:
            return float(value)
        except ValueError:
            return value
    return value


def create_colored_excel(df):
    """Render the review table to ``.xlsx`` bytes with colour-coded cells.

    Emoji markers are stripped from every column except ``Status``. The
    currency columns (``Total Tax``, ``Subtotal``, ``Total Amount``) are
    written as numbers so Excel can calculate with them.

    Colouring:
        Rows whose status contains "Approved"/"Rejected" are filled entirely
        bright green/red. Otherwise each cell is coloured from the
        module-level ``extraction_metadata``: discrepancy crimson, calculated
        light blue, confidence >= 0.8 light green, >= 0.6 goldenrod, else
        salmon.

    Returns:
        The workbook as ``bytes``, or ``None`` for a missing/empty table.
    """
    if df is None or df.empty:
        return None

    # Per-cell (value, confidence, discrepancy) tuples from the last extraction.
    metadata = extraction_metadata or []

    clean_df = df.copy()
    for col in clean_df.columns:
        if col != "Status":
            clean_df[col] = clean_df[col].apply(clean_value_for_excel)
    # Amounts are shown as text in the UI; store them as numbers in Excel.
    for col in ("Total Tax", "Subtotal", "Total Amount"):
        if col in clean_df.columns:
            clean_df[col] = clean_df[col].apply(_to_number_if_possible)

    def solid(rgb):
        return PatternFill(start_color=rgb, end_color=rgb, fill_type='solid')

    fills = {
        'high_conf': solid('90EE90'),    # Light green
        'medium_conf': solid('DAA520'),  # Goldenrod
        'low_conf': solid('FA8072'),     # Salmon
        'calculated': solid('ADD8E6'),   # Light blue
        'discrepancy': solid('DC143C'),  # Crimson
        'approved': solid('00FF00'),     # Bright green
        'rejected': solid('FF0000'),     # Red
    }

    output = io.BytesIO()
    with pd.ExcelWriter(output, engine='openpyxl') as writer:
        clean_df.to_excel(writer, sheet_name='Invoice_Data', index=False)
        worksheet = writer.sheets['Invoice_Data']

        has_status = "Status" in df.columns
        for row_idx in range(len(df)):
            status = str(df.iloc[row_idx]['Status']) if has_status else ""
            # Rows beyond the stored metadata (table length changed since
            # extraction) get no confidence colouring instead of crashing.
            row_meta = metadata[row_idx] if row_idx < len(metadata) else {}

            for col_idx, col_name in enumerate(df.columns, start=1):
                # +2: row 1 holds the header and openpyxl rows are 1-indexed.
                cell = worksheet.cell(row=row_idx + 2, column=col_idx)

                if "Approved" in status:
                    cell.fill = fills['approved']
                elif "Rejected" in status:
                    cell.fill = fills['rejected']
                elif col_name != "Status" and col_name in row_meta:
                    _, conf, discrepancy = row_meta[col_name]
                    if discrepancy:
                        cell.fill = fills['discrepancy']
                    elif conf == -1:
                        cell.fill = fills['calculated']
                    elif conf is not None and conf >= 0.8:
                        cell.fill = fills['high_conf']
                    elif conf is not None and conf >= 0.6:
                        cell.fill = fills['medium_conf']
                    else:
                        cell.fill = fills['low_conf']

    output.seek(0)
    return output.getvalue()

# Raw results of the most recent process_files() run: one dict per extracted
# row, mapping column name -> (value, confidence, discrepancy), where
# confidence -1 marks a calculated value. Read by create_colored_excel() to
# colour cells.
# NOTE(review): this is process-wide module state, so every session shares
# (and overwrites) the same metadata -- confirm this is acceptable if the app
# is served to concurrent users.
extraction_metadata = []

def create_interface():
    """Build the Gradio review UI for invoice / bill-of-materials extraction.

    Layout: multi-file upload and process button; status and
    confidence-summary boxes; an indicator legend; the editable results
    table; approve/reject controls for a 1-indexed row; and an Excel
    download.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(title="Invoice Processor", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Bill of Materials Extraction")
        gr.Markdown("Upload bills, review extracted data with confidence indicators, edit if needed, and approve/reject each bill of materials.")

        with gr.Row():
            file_upload = gr.File(
                label="Upload Bill of Materials",
                file_count="multiple",
                file_types=[".pdf", ".jpg", ".png", ".jpeg"]
            )

        process_btn = gr.Button("Process Bills of materials", variant="primary")
        status_msg = gr.Textbox(label="Status", interactive=False)

        # Confidence summary
        confidence_summary = gr.Textbox(
            label="Confidence Summary - Fields that need attention:",
            interactive=False,
            lines=3
        )

        with gr.Row():
            # Legend explaining the emoji indicators used in the table
            gr.HTML("""
            <div style="border: 1px solid #ddd; padding: 15px; border-radius: 5px; background-color: #f9f9f9;">
                <h3>Confidence Indicators Legend</h3>
                <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(250px, 1fr)); gap: 10px; margin: 10px 0;">
                    <div style="display: flex; align-items: center; gap: 8px;">
                        <span style="font-size: 18px;">🟢</span>
                        <span>High Confidence (≥80%)</span>
                    </div>
                    <div style="display: flex; align-items: center; gap: 8px;">
                        <span style="font-size: 18px;">🟡</span>
                        <span>Medium Confidence (60-79%)</span>
                    </div>
                    <div style="display: flex; align-items: center; gap: 8px;">
                        <span style="font-size: 18px;">🔴</span>
                        <span>Low Confidence (<60%)</span>
                    </div>
                    <div style="display: flex; align-items: center; gap: 8px;">
                        <span style="font-size: 18px;">📊</span>
                        <span>Calculated Value</span>
                    </div>
                    <div style="display: flex; align-items: center; gap: 8px;">
                        <span style="font-size: 18px;">⚠️</span>
                        <span>Calculation Discrepancy</span>
                    </div>
                    <div style="display: flex; align-items: center; gap: 8px;">
                        <span style="font-size: 18px;">✅</span>
                        <span>Approved</span>
                    </div>
                    <div style="display: flex; align-items: center; gap: 8px;">
                        <span style="font-size: 18px;">❌</span>
                        <span>Rejected</span>
                    </div>
                </div>
                <p style="margin-top: 10px; font-style: italic; color: #666;">
                    💡 Tip: Focus on editing cells with 🔴 (low confidence) and ⚠️ (discrepancy) indicators
                </p>
            </div>
            """)

        # Data table
        data_table = gr.Dataframe(
            label="Bill of Materials Data - Edit cells with low confidence indicators (🔴, ⚠️)",
            interactive=True,
            wrap=True
        )

        with gr.Row():
            with gr.Column():
                row_number = gr.Number(label="Row Number (1-indexed)", value=1, precision=0)
            with gr.Column():
                approve_btn = gr.Button("Approve Row ✅", variant="secondary")
                reject_btn = gr.Button("Reject Row ❌", variant="stop")

        download_btn = gr.Button("Download Excel File", variant="primary")
        download_file = gr.File(label="Download", visible=True, file_count="single")

        # Event handlers
        process_btn.click(
            fn=process_files,
            inputs=[file_upload],
            outputs=[data_table, status_msg, confidence_summary]
        )

        approve_btn.click(
            fn=approve_row,
            inputs=[data_table, row_number],
            outputs=[data_table]
        )

        reject_btn.click(
            fn=reject_row,
            inputs=[data_table, row_number],
            outputs=[data_table]
        )

        def download_excel(df):
            """Write the current table to a timestamped .xlsx; return its path (or None)."""
            if df is None or df.empty:
                return None

            excel_data = create_colored_excel(df)
            if excel_data is None:
                return None
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            filename = f"Bills_{timestamp}.xlsx"

            # A fresh directory per download keeps the user-facing file name
            # while preventing two downloads in the same second (or from
            # different sessions) from overwriting each other's file.
            temp_path = os.path.join(tempfile.mkdtemp(prefix="bills_"), filename)
            with open(temp_path, "wb") as f:
                f.write(excel_data)

            return temp_path

        download_btn.click(
            fn=download_excel,
            inputs=[data_table],
            outputs=[download_file]
        )

    return demo

if __name__ == "__main__":
    # Script entry point: build the Gradio UI and start the local web server.
    # The app is bound to the module-level name ``demo``.
    demo = create_interface()
    demo.launch()

# * Running on local URL:  http://127.0.0.1:7865
# * To create a public link, set `share=True` in `launch()`.
