In [None]:
import zipfile

# Unpack the invoice archive into a folder of the same name next to it.
ARCHIVE_PATH = '/content/drive/MyDrive/invoice_cf_data.zip'
EXTRACT_DIR = '/content/drive/MyDrive/invoice_cf_data'

with zipfile.ZipFile(ARCHIVE_PATH, mode='r') as archive:
    archive.extractall(path=EXTRACT_DIR)

In [1]:
!pip install paddlepaddle-gpu==2.6.2 paddleocr==2.10.0 seqeval

Collecting paddlepaddle-gpu==2.6.2
  Downloading paddlepaddle_gpu-2.6.2-cp312-cp312-manylinux1_x86_64.whl.metadata (8.6 kB)
Collecting paddleocr==2.10.0
  Downloading paddleocr-2.10.0-py3-none-any.whl.metadata (12 kB)
Collecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m3.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting astor (from paddlepaddle-gpu==2.6.2)
  Downloading astor-0.8.1-py2.py3-none-any.whl.metadata (4.2 kB)
Collecting opt-einsum==3.3.0 (from paddlepaddle-gpu==2.6.2)
  Downloading opt_einsum-3.3.0-py3-none-any.whl.metadata (6.5 kB)
Collecting pyclipper (from paddleocr==2.10.0)
  Downloading pyclipper-1.4.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (8.6 kB)
Collecting lmdb (from paddleocr==2.10.0)
  Downloading lmdb-1.7.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.me

In [2]:
import json
from pathlib import Path
from sklearn.model_selection import train_test_split


# Master annotation file: one JSON document per line.
MASTER_JSONL_FILE = "/content/drive/MyDrive/CF_Inv_Train/improved_ner_CF2K.jsonl"

# Fraction of the samples held out for validation.
VALIDATION_SET_SIZE = 0.20

# Fixed seed so the shuffle (and therefore the split) is reproducible.
RANDOM_SEED = 42

# Both output files are written next to the master file.
master_path = Path(MASTER_JSONL_FILE)
output_dir = master_path.parent
train_path = output_dir / "train.jsonl"
val_path = output_dir / "val.jsonl"

print(f"Reading master file: {master_path}")

try:

    with open(master_path, 'r', encoding='utf-8') as f:
        # Keep only non-blank records and strip their line terminators.
        # Blank lines must not be counted as samples, and a final record
        # without a trailing newline must not be glued to its neighbour
        # once the records are shuffled and written back out.
        all_lines = [line.rstrip('\r\n') for line in f if line.strip()]

    if not all_lines:
        raise ValueError("Input file is empty.")

    print(f"Found {len(all_lines)} total samples.")

    train_lines, val_lines = train_test_split(
        all_lines,
        test_size=VALIDATION_SET_SIZE,
        random_state=RANDOM_SEED,
        shuffle=True
    )

    # Re-attach exactly one newline per record when writing.
    with open(train_path, 'w', encoding='utf-8') as f:
        f.writelines(f"{line}\n" for line in train_lines)

    with open(val_path, 'w', encoding='utf-8') as f:
        f.writelines(f"{line}\n" for line in val_lines)

    print("-" * 50)
    print("Dataset splitting complete!")
    print(f"✓ Training data ({len(train_lines)} samples) saved to: {train_path}")
    print(f"✓ Validation data ({len(val_lines)} samples) saved to: {val_path}")
    print("-" * 50)

except FileNotFoundError:
    print("✗ ERROR: The file was not found at the specified path.")
    print(f"  Please double-check the path: {master_path}")
except Exception as e:
    # Any other failure (including the empty-file ValueError above) is
    # reported here rather than propagated, as in the original cell.
    print(f"An error occurred: {e}")

Reading master file: /content/drive/MyDrive/CF_Inv_Train/improved_ner_CF2K.jsonl
Found 1896 total samples.
--------------------------------------------------
Dataset splitting complete!
✓ Training data (1516 samples) saved to: /content/drive/MyDrive/CF_Inv_Train/train.jsonl
✓ Validation data (380 samples) saved to: /content/drive/MyDrive/CF_Inv_Train/val.jsonl
--------------------------------------------------


In [3]:
"""
LiLT-based Invoice Extraction System - Training Script
Production-Grade Implementation
License: Apache 2.0
"""

import json
import os
from typing import List, Dict, Tuple
from pathlib import Path
from dataclasses import dataclass

import torch
from torch.utils.data import Dataset
from transformers import (
    AutoProcessor,
    LiltForTokenClassification,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
from seqeval.metrics import (
    f1_score,
    precision_score,
    recall_score
)
from seqeval.scheme import IOB2
import numpy as np


# ============================================================================
# CONFIGURATION
# ============================================================================

@dataclass
class ModelConfig:
    """Default settings for the LiLT token-classification model.

    NOTE(review): nothing in the visible notebook instantiates this class;
    the same values appear as literals elsewhere (e.g. the 512 max length
    and the -100 label id) — confirm whether it is meant to be wired in.
    """
    # Hugging Face Hub identifier of the pretrained checkpoint.
    model_name: str = "SCUT-DLVCLab/lilt-roberta-en-base"
    # Informational only; not read by any code shown here.
    license: str = "Apache 2.0"
    # Sequence length in tokens used for padding/truncation.
    max_length: int = 512
    # Label id that the metric computation in this notebook skips.
    ignore_label_id: int = -100


# ============================================================================
# DATASET
# ============================================================================

class InvoiceDataset(Dataset):
    """Torch dataset that serves tokenized invoice samples from a JSONL file.

    Each non-blank line of the file must be a JSON object with three
    parallel lists: ``words``, ``bboxes`` and ``ner_tags``.  All samples are
    parsed and validated eagerly in ``__init__`` so that malformed data is
    reported immediately (with file and line number) instead of surfacing
    later inside ``__getitem__`` while training is already running.
    """

    # Keys every sample must provide.
    _REQUIRED_KEYS = ('words', 'bboxes', 'ner_tags')

    def __init__(
        self,
        jsonl_path: str,
        processor: AutoProcessor,
        label2id: Dict[str, int],
        max_length: int = 512
    ):
        """Load and validate every sample in ``jsonl_path``.

        Args:
            jsonl_path: Path to the JSONL file to read.
            processor: Callable tokenizer/processor accepting words, ``boxes``
                and ``word_labels`` (see ``__getitem__``).
            label2id: Mapping from tag string to integer class id.
            max_length: Length every encoded sample is padded/truncated to.

        Raises:
            ValueError: If a sample lacks a required key, its lists differ in
                length, or it uses a tag that is not in ``label2id``.
        """
        self.processor = processor
        self.label2id = label2id
        self.max_length = max_length
        self.samples = []

        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line_no, line in enumerate(f, start=1):
                if not line.strip():
                    continue  # tolerate blank lines
                sample = json.loads(line)
                self._validate_sample(sample, jsonl_path, line_no)
                self.samples.append(sample)

        print(f"Loaded {len(self.samples)} samples from {jsonl_path}")

    def _validate_sample(self, sample: Dict, jsonl_path: str, line_no: int) -> None:
        """Raise ``ValueError`` when ``sample`` cannot be encoded later on."""
        where = f"{jsonl_path}, line {line_no}"

        missing = [key for key in self._REQUIRED_KEYS if key not in sample]
        if missing:
            raise ValueError(f"Sample at {where} is missing keys: {missing}")

        sizes = {key: len(sample[key]) for key in self._REQUIRED_KEYS}
        if len(set(sizes.values())) != 1:
            raise ValueError(
                f"Sample at {where} has mismatched list lengths: {sizes}"
            )

        unknown = sorted({tag for tag in sample['ner_tags'] if tag not in self.label2id})
        if unknown:
            raise ValueError(
                f"Sample at {where} uses labels missing from label2id: {unknown}"
            )

    def __len__(self) -> int:
        """Number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Encode sample ``idx`` and return its tensors without a batch axis."""
        sample = self.samples[idx]
        words = sample['words']
        bboxes = sample['bboxes']
        ner_tags = sample['ner_tags']

        # Convert string labels to integer IDs (tags were validated at load).
        label_ids = [self.label2id[tag] for tag in ner_tags]

        # Tokenization, box alignment and label alignment are delegated to
        # the processor; every sample is padded/truncated to max_length.
        encoding = self.processor(
            words,
            boxes=bboxes,
            word_labels=label_ids,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

        # return_tensors='pt' adds a leading batch dimension of size 1;
        # drop it so the default collator can stack samples itself.
        item = {key: val.squeeze(0) for key, val in encoding.items()}

        return item


def load_labels_from_jsonl(jsonl_path: str) -> Tuple[Dict[str, int], Dict[int, str]]:
    """
    Build the tag <-> id mappings from the tags that occur in a JSONL file.

    Every non-blank line is parsed as JSON and its 'ner_tags' list is added
    to the label inventory.  Ids are assigned with 'O' first (when present)
    followed by the remaining tags in sorted order.

    Returns:
        (label2id, id2label)
    """
    seen_tags = set()

    with open(jsonl_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            seen_tags.update(json.loads(raw_line)['ner_tags'])

    # Deterministic ordering: the outside tag leads, the rest are alphabetical.
    ordered_tags = sorted(seen_tags - {'O'})
    if 'O' in seen_tags:
        ordered_tags.insert(0, 'O')

    label2id = {}
    id2label = {}
    for position, tag in enumerate(ordered_tags):
        label2id[tag] = position
        id2label[position] = tag

    print(f"Discovered {len(label2id)} unique labels: {ordered_tags}")
    return label2id, id2label


# ============================================================================
# TRAINING PIPELINE
# ============================================================================

class InvoiceTrainer:
    """Coordinates label discovery, model setup, fine-tuning and export."""

    def __init__(
        self,
        train_jsonl: str,
        val_jsonl: str,
        output_dir: str = './lilt_invoice_model',
        model_name: str = "SCUT-DLVCLab/lilt-roberta-en-base"
    ):
        """
        Prepare everything needed for training.

        The label inventory is read from the training file only, written to
        ``output_dir`` as JSON, and then used to size the classification
        head and to build both datasets.

        Args:
            train_jsonl: Path to the training JSONL file.
            val_jsonl: Path to the validation JSONL file.
            output_dir: Directory for label maps, checkpoints and the model.
            model_name: Pretrained checkpoint to fine-tune.
        """
        self.train_jsonl = train_jsonl
        self.val_jsonl = val_jsonl
        self.output_dir = output_dir
        self.model_name = model_name

        Path(output_dir).mkdir(parents=True, exist_ok=True)

        self.label2id, self.id2label = load_labels_from_jsonl(train_jsonl)

        # Persist both mappings so inference can restore them later.
        for filename, mapping in (
            ("label2id.json", self.label2id),
            ("id2label.json", self.id2label),
        ):
            with open(f"{output_dir}/{filename}", 'w') as handle:
                json.dump(mapping, handle, indent=2)

        self.processor = AutoProcessor.from_pretrained(model_name)

        self.model = LiltForTokenClassification.from_pretrained(
            model_name,
            num_labels=len(self.label2id),
            id2label=self.id2label,
            label2id=self.label2id
        )

        # Training split is built first, then validation.
        self.train_dataset, self.val_dataset = (
            InvoiceDataset(split_path, self.processor, self.label2id)
            for split_path in (train_jsonl, val_jsonl)
        )

        print(f"Model initialized with {len(self.label2id)} labels")
        print(f"Training samples: {len(self.train_dataset)}")
        print(f"Validation samples: {len(self.val_dataset)}")

    def train(
        self,
        num_epochs: int = 10,
        batch_size: int = 8,
        learning_rate: float = 5e-5,
        weight_decay: float = 0.01,
        warmup_ratio: float = 0.1,
        use_fp16: bool = True
    ):
        """
        Fine-tune the model, export the best checkpoint and report metrics.

        Evaluation and checkpointing happen once per epoch; the checkpoint
        with the highest entity-level F1 is restored at the end, saved to
        ``output_dir`` together with the processor, and re-evaluated.  The
        resulting precision/recall/F1 are written to ``final_metrics.json``.

        Returns:
            The ``Trainer`` instance that ran the job.
        """
        # Mixed precision is only enabled when a CUDA device is present.
        mixed_precision = use_fp16 and torch.cuda.is_available()

        training_args = TrainingArguments(
            output_dir=self.output_dir,
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            warmup_ratio=warmup_ratio,
            eval_strategy='epoch',
            save_strategy='epoch',
            save_total_limit=1,
            load_best_model_at_end=True,
            metric_for_best_model='f1',
            greater_is_better=True,
            logging_dir=f'{self.output_dir}/logs',
            logging_steps=50,
            fp16=mixed_precision,
            dataloader_num_workers=4,
            remove_unused_columns=False,
            push_to_hub=False,
            report_to=['tensorboard'],
            save_safetensors=True,
            save_only_model=True,
            logging_first_step=False,
            save_on_each_node=False,
        )

        def entity_level_metrics(eval_pred):
            """Strict IOB2 precision/recall/F1 over positions not labelled -100."""
            logits, gold_ids = eval_pred
            predicted_ids = np.argmax(logits, axis=2)

            gold_sequences = []
            predicted_sequences = []

            for predicted_row, gold_row in zip(predicted_ids, gold_ids):
                # Keep only positions that carry a real label.
                kept_pairs = [
                    (self.id2label[gold], self.id2label[guess])
                    for guess, gold in zip(predicted_row, gold_row)
                    if gold != -100
                ]
                if not kept_pairs:
                    continue
                gold_sequences.append([gold_tag for gold_tag, _ in kept_pairs])
                predicted_sequences.append([guess_tag for _, guess_tag in kept_pairs])

            scoring = {'mode': 'strict', 'scheme': IOB2}
            return {
                'precision': precision_score(gold_sequences, predicted_sequences, **scoring),
                'recall': recall_score(gold_sequences, predicted_sequences, **scoring),
                'f1': f1_score(gold_sequences, predicted_sequences, **scoring)
            }

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=self.train_dataset,
            eval_dataset=self.val_dataset,
            compute_metrics=entity_level_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
        )

        rule = "=" * 80

        print("\n" + rule)
        print("Starting Training Pipeline")
        print(rule + "\n")

        trainer.train()

        print("\n" + rule)
        print("Training Complete - Saving Best Model")
        print(rule + "\n")

        # Export first, then drop the checkpoints, then evaluate.
        trainer.save_model(self.output_dir)
        self.processor.save_pretrained(self.output_dir)

        self._cleanup_checkpoints()

        eval_results = trainer.evaluate()
        precision = eval_results['eval_precision']
        recall = eval_results['eval_recall']
        f1 = eval_results['eval_f1']

        print("\nFinal Validation Metrics:")
        print(f"  Precision: {precision:.4f}")
        print(f"  Recall: {recall:.4f}")
        print(f"  F1-Score: {f1:.4f}")

        final_metrics = {'precision': precision, 'recall': recall, 'f1': f1}
        with open(f"{self.output_dir}/final_metrics.json", 'w') as handle:
            json.dump(final_metrics, handle, indent=2)

        return trainer

    def _cleanup_checkpoints(self):
        """Delete every ``checkpoint-*`` directory directly under ``output_dir``."""
        import shutil

        checkpoint_dirs = [
            entry
            for entry in Path(self.output_dir).iterdir()
            if entry.is_dir() and entry.name.startswith('checkpoint-')
        ]
        for checkpoint_dir in checkpoint_dirs:
            print(f"Removing checkpoint: {checkpoint_dir.name}")
            shutil.rmtree(checkpoint_dir)


# ============================================================================
# MAIN EXECUTION
# ============================================================================

def main():
    """Run the fine-tuning job against the notebook's fixed input/output paths."""
    rule = "=" * 80

    print(rule)
    print("LILT INVOICE EXTRACTION - TRAINING PIPELINE")
    print(rule + "\n")

    # Locations of the two splits and of the exported model.
    TRAIN_JSONL = '/content/drive/MyDrive/CF_Inv_Train/train.jsonl'
    VAL_JSONL = '/content/drive/MyDrive/CF_Inv_Train/val.jsonl'
    OUTPUT_DIR = '/content/drive/MyDrive/CF_Inv_Train/models/lilt_invoice_final'

    # Stop before any model is loaded if either split is missing.
    train_present = os.path.exists(TRAIN_JSONL)
    if not train_present:
        raise FileNotFoundError(f"Training file not found: {TRAIN_JSONL}")
    val_present = os.path.exists(VAL_JSONL)
    if not val_present:
        raise FileNotFoundError(f"Validation file not found: {VAL_JSONL}")

    pipeline = InvoiceTrainer(
        train_jsonl=TRAIN_JSONL,
        val_jsonl=VAL_JSONL,
        output_dir=OUTPUT_DIR
    )

    hyperparameters = {
        'num_epochs': 15,
        'batch_size': 8,
        'learning_rate': 5e-5,
        'use_fp16': True,
    }
    pipeline.train(**hyperparameters)

    print(f"\n✓ Training complete! Model saved to: {OUTPUT_DIR}")
    print(f"✓ Use this model path for inference: {OUTPUT_DIR}")


if __name__ == '__main__':
    main()

LILT INVOICE EXTRACTION - TRAINING PIPELINE

Discovered 11 unique labels: ['O', 'B-INVOICE_AMOUNT', 'B-INVOICE_DATE', 'B-INVOICE_NUMBER', 'B-INVOICE_RAISED_BY', 'B-INVOICE_RAISED_TO', 'I-INVOICE_AMOUNT', 'I-INVOICE_DATE', 'I-INVOICE_NUMBER', 'I-INVOICE_RAISED_BY', 'I-INVOICE_RAISED_TO']


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/697 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/957 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/523M [00:00<?, ?B/s]

Some weights of LiltForTokenClassification were not initialized from the model checkpoint at SCUT-DLVCLab/lilt-roberta-en-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loaded 1516 samples from /content/drive/MyDrive/CF_Inv_Train/train.jsonl
Loaded 380 samples from /content/drive/MyDrive/CF_Inv_Train/val.jsonl
Model initialized with 11 labels
Training samples: 1516
Validation samples: 380

Starting Training Pipeline





Epoch,Training Loss,Validation Loss,Precision,Recall,F1
1,0.0502,0.01951,0.925249,0.95268,0.938764
2,0.0141,0.011134,0.964062,0.963512,0.963787
3,0.0059,0.011822,0.964567,0.977765,0.971121
4,0.0046,0.011489,0.967287,0.977765,0.972498
5,0.0038,0.010323,0.969612,0.982326,0.975927
6,0.0024,0.013171,0.969101,0.983466,0.976231
7,0.0021,0.012238,0.962606,0.968643,0.965615
8,0.0015,0.012945,0.973982,0.981756,0.977853
9,0.0026,0.010241,0.978953,0.981186,0.980068
10,0.0006,0.014547,0.975734,0.985747,0.980715



Training Complete - Saving Best Model

Removing checkpoint: checkpoint-2470



Final Validation Metrics:
  Precision: 0.9780
  Recall: 0.9875
  F1-Score: 0.9827

✓ Training complete! Model saved to: /content/drive/MyDrive/CF_Inv_Train/models/lilt_invoice_final
✓ Use this model path for inference: /content/drive/MyDrive/CF_Inv_Train/models/lilt_invoice_final


In [7]:
"""
LiLT-based Invoice Extraction System - Inference Script
Production-Grade Implementation with Train-Test Symmetry
License: Apache 2.0
"""

import json
from typing import List, Dict, Tuple, Optional

import torch
from transformers import AutoProcessor, LiltForTokenClassification
from paddleocr import PaddleOCR
from PIL import Image


# ============================================================================
# OCR SORTING ALGORITHMS (IDENTICAL TO TRAINING)
# ============================================================================

def _count_aligned_words(line: List, columns: List[Dict], tolerance: int = 10) -> int:
    """Counts how many words in a line align with the given column boundaries."""
    aligned_count = 0
    for word, bbox in line:
        word_x_start = bbox[0]
        for col in columns:
            if abs(word_x_start - col['x_start']) < tolerance:
                aligned_count += 1
                break
    return aligned_count


def find_best_table_candidate(
    lines: List,
    min_cols: int = 3,
    min_rows: int = 1
) -> Optional[Dict]:
    """Pick the line that most plausibly acts as a table header.

    Every line with at least ``min_cols`` words is tried as a header whose
    word left edges define the columns.  Following non-empty lines extend the
    table while they either have enough column-aligned words or look like an
    indented continuation; the first line that does neither ends it.  A
    candidate scores ``rows * columns`` and the highest strictly positive
    score wins (earlier candidates win ties).

    Returns:
        ``{'start_idx', 'end_idx', 'columns'}`` for the winner, or ``None``.
    """
    winner = None
    top_score = 0

    for header_idx, header in enumerate(lines):
        if len(header) < min_cols:
            continue

        columns = [
            {'x_start': box[0], 'x_center': (box[0] + box[2]) / 2}
            for _, box in header
        ]
        columns.sort(key=lambda col: col['x_start'])

        matched_rows = 0
        end_idx = header_idx + 1

        for row_idx in range(header_idx + 1, len(lines)):
            row = lines[row_idx]
            if not row:
                continue

            n_aligned = _count_aligned_words(row, columns)
            first_x = row[0][1][0]

            # Continuation text: barely aligned and starting right of column 0.
            looks_like_description = (
                n_aligned < 2 and first_x > (columns[0]['x_start'] + 20)
            )
            enough_aligned = n_aligned >= min(len(row), len(columns)) / 2

            if not (enough_aligned or looks_like_description):
                break

            matched_rows += 1
            end_idx = row_idx + 1

        score = matched_rows * len(columns)
        if matched_rows < min_rows or score <= top_score:
            continue

        top_score = score
        winner = {
            'start_idx': header_idx,
            'end_idx': end_idx,
            'columns': columns
        }

    return winner


def sort_table_row(row_lines: List, column_info: List[Dict]) -> List:
    """Reorder the words of one logical table row column by column.

    Each word goes to the last column in the leading run of columns whose
    ``'x_start' - 10`` does not exceed the word's left edge; words left of
    every column fall into the first one.  Inside a column, words are ordered
    by vertical centre and then by left edge.  The result is the columns'
    contents concatenated from first to last.
    """
    buckets = [[] for _ in column_info]

    for physical_line in row_lines:
        for word, bbox in physical_line:
            left_edge = bbox[0]

            # Count how many leading columns start at or before this word.
            columns_passed = 0
            for col in column_info:
                if left_edge < col['x_start'] - 10:
                    break
                columns_passed += 1

            target = max(columns_passed - 1, 0)
            buckets[target].append((word, bbox))

    ordered_row = []
    for bucket in buckets:
        bucket.sort(key=lambda entry: ((entry[1][1] + entry[1][3]) / 2, entry[1][0]))
        ordered_row.extend(bucket)

    return ordered_row


def flatten_lines_and_normalize(
    lines: List,
    image_width: int,
    image_height: int
) -> Dict[str, List]:
    """
    Flatten lines of ``(word, bbox)`` pairs and rescale boxes to [0, 1000].

    Each pixel coordinate is divided by the matching image dimension,
    multiplied by 1000 and truncated to an integer.  The result is then
    clamped to the closed range [0, 1000]: OCR boxes can extend slightly
    outside the image, and unclamped values would fall outside the
    coordinate range the layout model expects.  In-range coordinates are
    unaffected by the clamp.

    Args:
        lines: Lines in reading order, each a list of ``(word, [x0, y0, x1, y1])``.
        image_width: Image width, in the same unit as the x coordinates.
        image_height: Image height, in the same unit as the y coordinates.

    Returns:
        ``{'words': [...], 'bboxes': [[x0, y0, x1, y1], ...]}`` with both
        lists in the order the words were encountered.
    """
    def _scale(value, extent):
        # Truncate first, then clamp to the valid range.
        return max(0, min(1000, int(1000 * (value / extent))))

    sorted_words = []
    sorted_bboxes = []

    for line in lines:
        for word, bbox in line:
            sorted_words.append(word)

            x0, y0, x1, y1 = bbox[0], bbox[1], bbox[2], bbox[3]
            sorted_bboxes.append([
                _scale(x0, image_width),
                _scale(y0, image_height),
                _scale(x1, image_width),
                _scale(y1, image_height),
            ])

    return {'words': sorted_words, 'bboxes': sorted_bboxes}


def sort_ocr_words(
    ocr_result: List,
    image_width: int,
    image_height: int,
    line_threshold: int = 15
) -> Tuple[List[str], List[List[int]]]:
    """
    Turn a raw OCR result into reading-ordered words with normalized boxes.

    NOTE(review): the original author states that this ordering must be
    identical to the one used when the training data was generated.  That
    generation code is not part of this notebook, so the claim cannot be
    checked here — confirm before changing any threshold or sort key.

    Steps:
        1. Keep detections with confidence >= 0.5, split each detection's
           text on whitespace and give every word an equal horizontal share
           of the detection's axis-aligned box.
        2. Sort words by (top edge, left edge) and group them into lines: a
           word joins the current line when its vertical centre is within
           ``line_threshold`` of the vertical centre of the line's first
           word.  Each finished line is sorted left to right.
        3. Ask ``find_best_table_candidate`` for a table.  Without one, the
           lines are flattened as they are.  With one, the rows between the
           header and the table end are merged into logical rows (indented
           continuation lines are attached to the row above) and re-sorted
           column by column with ``sort_table_row``.
        4. Flatten and normalize with ``flatten_lines_and_normalize``.

    Args:
        ocr_result: OCR output; only ``ocr_result[0]`` is read, as a list of
            ``[box_points, (text, confidence)]`` entries.
        image_width: Image width in pixels
        image_height: Image height in pixels
        line_threshold: Maximum vertical-centre distance for two words to be
            placed on the same line (same unit as the OCR coordinates).

    Returns:
        Tuple of (sorted_words, normalized_bboxes); two empty lists when the
        OCR result is empty or nothing passes the confidence filter.
    """
    # Nothing detected at all.
    if not ocr_result or not ocr_result[0]:
        return [], []

    # Collect (word, [x0, y0, x1, y1]) pairs from the OCR detections.
    word_data = []
    for line in ocr_result[0]:
        bbox_points = line[0]  # four corner points: [[x1,y1], [x2,y2], [x3,y3], [x4,y4]]
        text = line[1][0]
        confidence = line[1][1]

        # Drop low-confidence detections.
        if confidence < 0.5:
            continue

        # Axis-aligned bounding rectangle of the four corner points.
        xs = [point[0] for point in bbox_points]
        ys = [point[1] for point in bbox_points]
        x0, x1 = min(xs), max(xs)
        y0, y1 = min(ys), max(ys)

        # A detection may contain several whitespace-separated words.
        text_words = text.split()
        if not text_words:
            continue

        # Approximation: every word gets the same share of the box width,
        # regardless of how many characters it has.
        word_width = (x1 - x0) / len(text_words)

        for idx, word in enumerate(text_words):
            word_x0 = x0 + (idx * word_width)
            word_x1 = word_x0 + word_width

            # Coordinates are truncated to integers here.
            word_data.append((
                word,
                [int(word_x0), int(y0), int(word_x1), int(y1)]
            ))

    # Every detection was filtered out.
    if not word_data:
        return [], []

    # STAGE 1: order by top edge, then by left edge.
    word_data.sort(key=lambda item: (item[1][1], item[1][0]))

    # Group consecutive words into lines.
    lines = []
    current_line = [word_data[0]]

    for word, bbox in word_data[1:]:
        # The anchor is the first word added to the current line; it does
        # not move as more words are appended.
        line_anchor_y = (current_line[0][1][1] + current_line[0][1][3]) / 2
        word_y = (bbox[1] + bbox[3]) / 2

        if abs(word_y - line_anchor_y) < line_threshold:
            current_line.append((word, bbox))
        else:
            # Close the line (left-to-right) and start a new one.
            current_line.sort(key=lambda item: item[1][0])
            lines.append(current_line)
            current_line = [(word, bbox)]

    # Close the last open line.
    if current_line:
        current_line.sort(key=lambda item: item[1][0])
        lines.append(current_line)

    # STAGE 2: table detection and column-aware sorting
    table_candidate = find_best_table_candidate(lines)

    if not table_candidate:
        # No table found - return flattened lines with normalization
        result = flatten_lines_and_normalize(lines, image_width, image_height)
        return result['words'], result['bboxes']

    # A table was found: header at start_idx, body up to (not including) end_idx.
    table_header_idx = table_candidate['start_idx']
    table_end_idx = table_candidate['end_idx']
    column_info = table_candidate['columns']
    header_line = lines[table_header_idx]

    processed_table_lines = []
    i = table_header_idx + 1

    while i < table_end_idx:
        # Each logical row starts with one physical line ...
        current_logical_row = [lines[i]]

        # ... and absorbs the continuation lines that follow it.
        j = i + 1
        while j < table_end_idx:
            next_line = lines[j]
            if not next_line:
                j += 1
                continue

            first_word_bbox = next_line[0][1]
            # "Indented" = starts more than 20 units right of the first column.
            is_indented = first_word_bbox[0] > (column_info[0]['x_start'] + 20)

            # Continuation line: indented and with fewer than two
            # column-aligned words.
            if is_indented and _count_aligned_words(next_line, column_info) < 2:
                current_logical_row.append(next_line)
                j += 1
            else:
                break

        sorted_row = sort_table_row(current_logical_row, column_info)
        processed_table_lines.append(sorted_row)
        # Resume at the first line that was not absorbed.
        i = j

    # Text before the table, the header as-is, the re-sorted rows, then the rest.
    final_lines = (
        lines[:table_header_idx] +
        [header_line] +
        processed_table_lines +
        lines[table_end_idx:]
    )

    # Flatten and normalize
    result = flatten_lines_and_normalize(final_lines, image_width, image_height)
    return result['words'], result['bboxes']


# ============================================================================
# ENTITY GROUPING
# ============================================================================

def group_entities(words: List[str], labels: List[str]) -> Dict[str, str]:
    """
    Collapse word-level B-/I- tags into one space-joined string per entity type.

    Rules implemented here:
      * a ``B-`` tag stores the entity collected so far and opens a new one;
      * an ``I-`` tag of the open entity's type extends it, while an ``I-``
        tag of another type is skipped without closing the open entity;
      * any other tag (or an ``I-`` tag with nothing open) stores and closes
        the open entity;
      * when a type occurs more than once, the last occurrence wins.

    The five invoice fields are always present in the result, defaulting to ''.
    """
    entities = dict.fromkeys(
        (
            'INVOICE_AMOUNT',
            'INVOICE_DATE',
            'INVOICE_NUMBER',
            'INVOICE_RAISED_BY',
            'INVOICE_RAISED_TO',
        ),
        ''
    )

    open_type = None
    open_tokens = []

    def store_open_entity():
        # Reads the enclosing variables at call time.
        if open_type and open_tokens:
            entities[open_type] = ' '.join(open_tokens)

    for word, label in zip(words, labels):
        if label.startswith('B-'):
            store_open_entity()
            open_type, open_tokens = label[2:], [word]
            continue

        if label.startswith('I-') and open_type:
            if label[2:] == open_type:
                open_tokens.append(word)
            continue

        store_open_entity()
        open_type, open_tokens = None, []

    store_open_entity()
    return entities


# ============================================================================
# INFERENCE CLASS
# ============================================================================

class InvoiceExtractor:
    """
    Inference pipeline: OCR -> word ordering -> token classification -> entity grouping.

    NOTE(review): the original author describes the word ordering as
    symmetric with training-data generation; that generation code is not
    visible in this notebook, so the symmetry is an assumption to confirm.
    """

    def __init__(self, model_path: str, device: str = None):
        """
        Initialize the extractor with a trained model.

        Args:
            model_path: Path to the trained model directory; it must also
                contain ``id2label.json``.
            device: Device to run inference on ('cuda', 'cpu', or None for auto)
        """
        self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')

        # Load model and processor
        self.model = LiltForTokenClassification.from_pretrained(model_path)
        self.processor = AutoProcessor.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()

        # JSON object keys are strings; convert them back to integer ids.
        with open(f"{model_path}/id2label.json", 'r') as f:
            self.id2label = {int(k): v for k, v in json.load(f).items()}

        # Initialize OCR
        self.ocr = PaddleOCR(use_angle_cls=True, lang='en', show_log=False)

        print(f"✓ Model loaded successfully on {self.device}")
        print(f"✓ Configured for {len(self.id2label)} entity types")

    def extract_from_image(self, image_path: str, debug: bool = False) -> Dict[str, str]:
        """
        Complete extraction pipeline for a single invoice image.

        Args:
            image_path: Path to the invoice image
            debug: If True, print intermediate results

        Returns:
            Dictionary with extracted invoice fields.  When no text is
            detected, all fields are empty and an 'error' key is added.
        """
        # Only the dimensions are needed; the context manager releases the
        # underlying file handle right away instead of leaving it to the
        # garbage collector.
        with Image.open(image_path) as image:
            image_width, image_height = image.size

        if debug:
            print(f"\nProcessing: {image_path}")
            print(f"Image size: {image_width}x{image_height}")

        # Stage 1: OCR
        ocr_result = self.ocr.ocr(image_path, cls=True)

        # Stage 2: reading-order sort and box normalization
        words, bboxes = sort_ocr_words(ocr_result, image_width, image_height)

        if debug:
            print(f"Detected {len(words)} words")

        if not words:
            return {
                'INVOICE_AMOUNT': '',
                'INVOICE_DATE': '',
                'INVOICE_NUMBER': '',
                'INVOICE_RAISED_BY': '',
                'INVOICE_RAISED_TO': '',
                'error': 'No text detected in image'
            }

        # Stage 3: Preprocess
        encoding = self.processor(
            text=words,
            boxes=bboxes,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=512
        )

        # Move to device
        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)
        bbox = encoding['bbox'].to(self.device)

        # Stage 4: Inference
        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                bbox=bbox
            )

        # Stage 5: Post-processing
        predictions = torch.argmax(outputs.logits, dim=2)
        predictions = predictions.squeeze(0).cpu().numpy()

        # Map token predictions back to words *by word index*.  Appending
        # labels sequentially and zipping them with `words` would shift
        # every later label by one as soon as a word produced no tokens.
        # Words without any token (including those cut off by truncation)
        # keep the outside tag 'O'.
        word_ids = encoding.word_ids(batch_index=0)
        predicted_labels = ['O'] * len(words)
        labelled_words = set()

        for token_idx, word_idx in enumerate(word_ids):
            # Skip special/padding tokens and non-first sub-tokens.
            if word_idx is None or word_idx in labelled_words:
                continue
            labelled_words.add(word_idx)
            predicted_labels[word_idx] = self.id2label[int(predictions[token_idx])]

        if debug:
            print("\nWord-Label pairs:")
            for word, label in zip(words[:20], predicted_labels[:20]):
                print(f"  {word:20s} -> {label}")

        # Stage 6: Entity grouping
        entities = group_entities(words, predicted_labels)

        return entities

    def extract_batch(self, image_paths: List[str], debug: bool = False) -> List[Dict[str, str]]:
        """
        Process several invoices one after another.

        A failure on one image does not stop the run: the exception text is
        stored under 'error' in that image's result and processing continues.

        Args:
            image_paths: List of paths to invoice images
            debug: If True, print progress

        Returns:
            List of dictionaries with extracted fields, one per input path
        """
        results = []
        total = len(image_paths)

        for idx, image_path in enumerate(image_paths, 1):
            if debug:
                print(f"\n[{idx}/{total}] Processing: {image_path}")

            try:
                result = self.extract_from_image(image_path, debug=False)
                results.append(result)

                if debug:
                    print(f"✓ Success")
            except Exception as e:
                # Deliberate best-effort: record the failure and move on.
                error_result = {
                    'error': str(e),
                    'image_path': image_path,
                    'INVOICE_AMOUNT': '',
                    'INVOICE_DATE': '',
                    'INVOICE_NUMBER': '',
                    'INVOICE_RAISED_BY': '',
                    'INVOICE_RAISED_TO': ''
                }
                results.append(error_result)

                if debug:
                    print(f"✗ Error: {str(e)}")

        return results

    def extract_to_json(self, image_path: str, output_path: str):
        """
        Extract invoice data and save to JSON file.

        Args:
            image_path: Path to the invoice image
            output_path: Path to save the JSON output
        """
        result = self.extract_from_image(image_path)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(result, f, indent=2, ensure_ascii=False)

        print(f"✓ Results saved to: {output_path}")


# ============================================================================
# MAIN EXECUTION
# ============================================================================

def main():
    """Load the exported model and run it on one sample invoice image."""
    rule = "=" * 80

    print(rule)
    print("LILT INVOICE EXTRACTION - INFERENCE PIPELINE")
    print(rule + "\n")

    # Directory the training cell exported the model to.
    MODEL_PATH = '/content/drive/MyDrive/CF_Inv_Train/models/lilt_invoice_final'

    extractor = InvoiceExtractor(model_path=MODEL_PATH)

    # Single-image demo with debug output switched on.
    print("\n--- Single Image Inference ---")
    sample_image = '/content/drive/MyDrive/8e147617-4c63-4cdd-a6db-f70288d617b3.jpg'
    result = extractor.extract_from_image(sample_image, debug=True)

    print("\nExtracted Fields:")
    print(json.dumps(result, indent=2))

if __name__ == '__main__':
    main()

LILT INVOICE EXTRACTION - INFERENCE PIPELINE

✓ Model loaded successfully on cuda
✓ Configured for 11 entity types

--- Single Image Inference ---

Processing: /content/drive/MyDrive/8e147617-4c63-4cdd-a6db-f70288d617b3.jpg
Image size: 1200x1600
Detected 128 words

Word-Label pairs:
  SHREE                -> B-INVOICE_RAISED_BY
  GOVIND               -> I-INVOICE_RAISED_BY
  AUTO                 -> I-INVOICE_RAISED_BY
  TVS                  -> I-INVOICE_RAISED_BY
  OPP.SATAE            -> O
  BANK,CHH.SAMBHAJINAGAR -> O
  ROAD.VERUL(ELLORA)DIST.CHH.SAMBHAJINAGAR -> O
  Tel-94227052189403004981 -> O
  Tax                  -> O
  Invoice              -> O
  Original/Duplicate/Triplicate -> O
  Name                 -> O
  of                   -> O
  Buyer                -> O
  MR.SANTOSH           -> B-INVOICE_RAISED_TO
  KHANHU               -> I-INVOICE_RAISED_TO
  THAPE                -> I-INVOICE_RAISED_TO
  27AHGPA4782L1Z3      -> O
  Address              -> O
  AT.KASABKHEDA        