In [2]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/llama-3/transformers/8b-chat-hf/1/model.safetensors.index.json
/kaggle/input/llama-3/transformers/8b-chat-hf/1/model-00003-of-00004.safetensors
/kaggle/input/llama-3/transformers/8b-chat-hf/1/config.json
/kaggle/input/llama-3/transformers/8b-chat-hf/1/LICENSE
/kaggle/input/llama-3/transformers/8b-chat-hf/1/model-00001-of-00004.safetensors
/kaggle/input/llama-3/transformers/8b-chat-hf/1/model.py
/kaggle/input/llama-3/transformers/8b-chat-hf/1/USE_POLICY.md
/kaggle/input/llama-3/transformers/8b-chat-hf/1/tokenizer.json
/kaggle/input/llama-3/transformers/8b-chat-hf/1/tokenizer_config.json
/kaggle/input/llama-3/transformers/8b-chat-hf/1/example_text_completion.py
/kaggle/input/llama-3/transformers/8b-chat-hf/1/test_tokenizer.py
/kaggle/input/llama-3/transformers/8b-chat-hf/1/requirements.txt
/kaggle/input/llama-3/transformers/8b-chat-hf/1/tokenizer.py
/kaggle/input/llama-3/transformers/8b-chat-hf/1/model-00004-of-00004.safetensors
/kaggle/input/llama-3/transformers/8b-chat-hf

In [3]:
!pip install transformers bitsandbytes accelerate safetensors




In [2]:
import os
import json
import pandas as pd
import re
import logging
import time
from pypdf import PdfReader
from pdf2image import convert_from_bytes
import pytesseract
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# ===========================
# 1. Configuration and Setup
# ===========================

# Configure logging
logging.basicConfig(
    filename='invoice_extraction.log', 
    level=logging.INFO,
    format='%(asctime)s:%(levelname)s:%(message)s'
)

# Set up Llama model path and tokenizer/model
MODEL_PATH = "/kaggle/input/llama-3/transformers/8b-chat-hf/1"  # Adjust to a smaller model if possible
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load tokenizer and model
# Load tokenizer and model without explicitly moving the model to the device
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16,
    device_map="auto",
    offload_buffers=True  # This allows automatic offloading when memory is constrained
)



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:

# ==============================
# 2. Helper Functions
# ==============================

def get_pdf_text(pdf_doc):
    """
    Extracts text from a PDF file, handling regular and scanned PDFs.
    
    Parameters:
        pdf_doc (UploadedFile): The uploaded PDF file.
    
    Returns:
        str: The extracted text from the PDF.
    """
    text = ""
    try:
        pdf_reader = PdfReader(pdf_doc)
        num_pages = len(pdf_reader.pages)
        for page_number, page in enumerate(pdf_reader.pages, start=1):
            extracted_text = page.extract_text()
            if extracted_text and len(extracted_text.strip()) > 50:  # Threshold for text extraction
                text += extracted_text + "\n"
                logging.info(f"Text extracted from page {page_number} using PdfReader.")
            else:
                # If text extraction is insufficient, use OCR
                logging.info(f"Insufficient text on page {page_number}. Applying OCR.")
                pdf_doc.seek(0)  # Reset file pointer to read bytes
                pdf_bytes = pdf_doc.read()
                images = convert_from_bytes(pdf_bytes, first_page=page_number, last_page=page_number)
                for image in images:
                    ocr_text = pytesseract.image_to_string(image, config='--psm 6')  # Assume a single uniform block of text
                    if ocr_text and len(ocr_text.strip()) > 10:  # Threshold for OCR text
                        text += ocr_text + "\n"
                        logging.info(f"OCR text extracted from page {page_number}.")
    except Exception as e:
        logging.error(f"Error extracting text from PDF: {e}")
        print(f"Error extracting text from PDF: {e}")
    return text

def call_llama_model(pages_data, max_input_length=1000, timeout=300):
    """
    Calls the Llama 3 model to extract invoice data in JSON format.
    
    Parameters:
        pages_data (str): The extracted text from the PDF.
        max_input_length (int): Maximum number of tokens to send to the model.
        timeout (int): Maximum time (in seconds) to wait for model inference.
    
    Returns:
        str or None: The raw extracted data from the model if successful; otherwise, None.
    """
    prompt_template = '''Extract the following fields from the invoice data: 
- Invoice No.
- Date
- Amount
- Total
- Email
- Place of Origin
- Taxable Value
- SGST Amount
- CGST Amount
- IGST Amount
- SGST Rate
- CGST Rate
- IGST Rate
- Tax Amount
- Tax Rate
- Final Amount
- Invoice Date
- Place of Supply
- GSTIN Supplier

Provide the output strictly in valid JSON format with no additional text, explanations, or comments. 
Ensure all keys are correctly spelled and correspond to the field names above. 
Do not include any trailing commas or syntax errors.

Here is the invoice data:
{pages}
'''
    # Limit the input text
    pages_data = pages_data[:max_input_length]
    prompt = prompt_template.format(pages=pages_data)

    try:
        print("Tokenizing input...")
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_input_length).to(model.device)
        print(f"Input shape: {inputs.input_ids.shape}")

        print("Starting model inference...")
        start_time = time.time()
        with torch.no_grad():
            outputs = model.generate(
                **inputs, 
                max_new_tokens=500, 
                temperature=0.3,
                do_sample=True,
                num_return_sequences=1,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,
            )
        inference_time = time.time() - start_time
        print(f"Model inference completed in {inference_time:.2f} seconds")

        llm_extracted_data = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Extracted data length: {len(llm_extracted_data)}")
        logging.info(f"Raw model response for data extraction: {llm_extracted_data}")
        return llm_extracted_data
    except Exception as e:
        logging.error(f"Exception during model call: {e}")
        print(f"An error occurred during model call: {e}")
        return None

def validate_data(field, value):
    """
    Validates the extracted data fields using regex patterns and assigns confidence levels.
    
    Parameters:
        field (str): The name of the field.
        value (str): The extracted value of the field.
    
    Returns:
        tuple: (is_valid (bool), confidence (str))
    """
    patterns  = {
    'Invoice No.': r'^[A-Za-z0-9\-]+$',                      # Alphanumeric characters and dashes
    'Date': r'^\d{2}/\d{2}/\d{4}$',                          # Format: DD/MM/YYYY
    'Amount': r'^\d+(\.\d{1,2})?$',                         # Decimal value (e.g., 100, 100.50)
    'Total': r'^\d+(\.\d{1,2})?$',                          # Decimal value (e.g., 100, 100.50)
    'Email': r'^[\w\.-]+@[\w\.-]+\.\w+$',                   # Valid email format
    'Place of Origin': r'^[A-Za-z\s\-]+$',                   # Alphabetic and spaces
    'Taxable Value': r'^\d+(\.\d{1,2})?$',                  # Decimal value (e.g., 100, 100.50)
    'SGST Amount': r'^\d+(\.\d{1,2})?$',                     # Decimal value (e.g., 100, 100.50)
    'CGST Amount': r'^\d+(\.\d{1,2})?$',                     # Decimal value (e.g., 100, 100.50)
    'IGST Amount': r'^\d+(\.\d{1,2})?$',                     # Decimal value (e.g., 100, 100.50)
    'SGST Rate': r'^\d+(\.\d{1,2})?$',                       # Percentage value
    'CGST Rate': r'^\d+(\.\d{1,2})?$',                       # Percentage value
    'IGST Rate': r'^\d+(\.\d{1,2})?$',                       # Percentage value
    'Tax Amount': r'^\d+(\.\d{1,2})?$',                      # Decimal value (e.g., 100, 100.50)
    'Tax Rate': r'^\d+(\.\d{1,2})?$',                        # Percentage value
    'Final Amount': r'^\d+(\.\d{1,2})?$',                   # Decimal value (e.g., 100, 100.50)
    'Invoice Date': r'^\d{2}/\d{2}/\d{4}$',                  # Format: DD/MM/YYYY
    'Place of Supply': r'^[A-Za-z\s\-]+$',                   # Alphabetic and spaces
    'GSTIN Supplier': r'^\d{2}[A-Z]{5}\d{4}[A-Z]{1}[A-Z\d]{1}[Z]{1}[A-Z\d]{1}$',  # GSTIN format
}


    if field in patterns:
        if re.match(patterns[field], str(value).strip()):
            return True, "High Confidence"
        else:
            return False, "Low Confidence"
    else:
        # For fields without specific patterns, basic non-empty check
        if str(value).strip():
            return True, "Medium Confidence"
        else:
            return False, "Low Confidence"

def extract_json(raw_text):
    """
    Extracts the first JSON object found in the raw text.
    
    Parameters:
        raw_text (str): The raw text containing JSON.
    
    Returns:
        str or None: The extracted JSON string if found; otherwise, None.
    """
    pattern = r'\{.*\}'  # Matches the first occurrence of {...}
    match = re.search(pattern, raw_text, re.DOTALL)
    if match:
        return match.group(0)
    else:
        return None

# ===========================================
# 3. Core Function to Process PDF Files
# ===========================================

def create_docs(pdf_file_paths):
    """
    Processes multiple PDF files to extract invoice data and compile it into a DataFrame.
    
    Parameters:
        pdf_file_paths (list): List of paths to PDF files.
    
    Returns:
        pd.DataFrame: DataFrame containing all extracted invoice data.
    """
    # Initialize DataFrame with additional columns for confidence and trust
    df = pd.DataFrame(columns=[
        'Invoice No.', 'Date', 'Amount', 'Total',
        'Email', 'Place of Origin', 'Taxable Value', 'SGST Amount',
        'CGST Amount', 'IGST Amount', 'SGST Rate', 'CGST Rate',
        'IGST Rate', 'Tax Amount', 'Tax Rate', 'Final Amount',
        'Invoice Date', 'Place of Supply',
        'GSTIN Supplier',
        'Confidence', 'Trust'
    ])
    
    # Initialize a list to hold rows
    rows = []

    # Metrics tracking
    metrics = {
        'total_files': 0,
        'successful_extractions': 0,
        'field_accuracy': {field: {'correct': 0, 'total': 0} for field in df.columns if field not in ['Confidence', 'Trust']}
    }
    
    for file_path in pdf_file_paths:
        metrics['total_files'] += 1
        file_name = os.path.basename(file_path)
        print(f"### Processing {file_name}...")

        # Read and extract text from PDF
        pdf_doc = open(file_path, 'rb')
        pdf_text = get_pdf_text(pdf_doc)
        pdf_doc.close()

        # Call the Llama model for data extraction
        llm_extracted_data = call_llama_model(pdf_text)
        
        if llm_extracted_data:
            # Extract JSON from model output
            json_data = extract_json(llm_extracted_data)
            if json_data:
                try:
                    extracted_fields = json.loads(json_data)
                    row = {field: extracted_fields.get(field, "") for field in df.columns[:-2]}  # All fields except Confidence and Trust
                    # Validate each field
                    for field, value in row.items():
                        is_valid, confidence = validate_data(field, value)
                        metrics['field_accuracy'][field]['total'] += 1
                        if is_valid:
                            metrics['field_accuracy'][field]['correct'] += 1
                    
                    row['Confidence'] = confidence
                    row['Trust'] = "High" if all([validate_data(f, row[f])[0] for f in row]) else "Low"
                    
                    # Add the row to the list
                    rows.append(row)
                    metrics['successful_extractions'] += 1
                except json.JSONDecodeError as e:
                    logging.error(f"JSON decode error for {file_name}: {e}")
                    print(f"JSON decode error for {file_name}: {e}")
            else:
                logging.warning(f"No valid JSON found in model output for {file_name}.")
                print(f"No valid JSON found in model output for {file_name}.")
        else:
            logging.error(f"Failed to extract data for {file_name}.")

    # Convert the list of rows to a DataFrame
    df = pd.DataFrame(rows)

    # Log metrics
    logging.info(f"Processed {metrics['total_files']} files with {metrics['successful_extractions']} successful extractions.")
    for field, accuracy in metrics['field_accuracy'].items():
        logging.info(f"Field: {field}, Accuracy: {accuracy['correct']}/{accuracy['total']}")

    return df

# ==============================
# 4. Running the Invoice Extraction
# ==============================

# Example: Process files from a given directory
pdf_directory = "/kaggle/input/test-data"
pdf_file_paths = [os.path.join(pdf_directory, file) for file in os.listdir(pdf_directory) if file.endswith('.pdf')]
print(f"Found {len(pdf_file_paths)} PDF files. Processing...")

# Process and extract data into DataFrame
extracted_data_df = create_docs(pdf_file_paths)

# Optionally, save the extracted DataFrame to CSV
output_csv_path = "extracted_invoice_data.csv"
extracted_data_df.to_csv(output_csv_path, index=False)
print(f"Extraction completed. Data saved to {output_csv_path}.")


Found 24 PDF files. Processing...
### Processing INV-145_Indraja Mohite.pdf...
Tokenizing input...
Input shape: torch.Size([1, 545])
Starting model inference...


Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


Model inference completed in 315.12 seconds
Extracted data length: 3013
### Processing INV-142_Urmila Jangam.pdf...
Tokenizing input...
Input shape: torch.Size([1, 532])
Starting model inference...
Model inference completed in 303.44 seconds
Extracted data length: 2981
### Processing INV-123_Asit.pdf...
Tokenizing input...
Input shape: torch.Size([1, 564])
Starting model inference...
Model inference completed in 303.27 seconds
Extracted data length: 2802
JSON decode error for INV-123_Asit.pdf: Extra data: line 22 column 1 (char 557)
### Processing INV-128_Atia Latif.pdf...
Tokenizing input...
Input shape: torch.Size([1, 530])
Starting model inference...
Model inference completed in 303.45 seconds
Extracted data length: 2562
### Processing INV-144_Atia Latif.pdf...
Tokenizing input...
Input shape: torch.Size([1, 597])
Starting model inference...
Model inference completed in 302.65 seconds
Extracted data length: 2709
### Processing INV-143_Prashant.pdf...
Tokenizing input...
Input shape: