<a href="https://colab.research.google.com/github/arunimamuralitharan/aitask/blob/main/Arunima_OCR_Annotation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
!pip install pytesseract



In [6]:
import os
import json
import re
import cv2
import numpy as np
import pandas as pd
from PIL import Image
import pytesseract
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import accuracy_score, classification_report
from sklearn.pipeline import Pipeline
import warnings
warnings.filterwarnings('ignore')

class InvoiceOCRSystem:

    def __init__(self, dataset_path=None):
        self.dataset_path = dataset_path
        self.model_pipeline = None
        self.field_extractors = {}
        self.annotations = {}
        self.images_path = None
        self.annotations_path = None

    def load_dataset(self):
        print("Loading SROIE dataset")

        if self.dataset_path is None:
            try:
                import kagglehub
                print("Downloading SROIE dataset")
                self.dataset_path = kagglehub.dataset_download("urbikn/sroie-datasetv2")
                print(f"Dataset downloaded to: {self.dataset_path}")
            except ImportError:
                print("Kagglehub not installed.")
                return [], []
            except Exception as e:
                print(f"Error downloading dataset: {str(e)}")
                return [], []

        #Find images and annotations
        images = []
        annotations = []

        #Looking files in the dataset directory
        all_files = []
        for root, dirs, files in os.walk(self.dataset_path):
            for file in files:
                full_path = os.path.join(root, file)
                all_files.append(full_path)

        #Separate images and text files
        image_files = [f for f in all_files if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
        text_files = [f for f in all_files if f.lower().endswith('.txt')]

        print(f"Found {len(image_files)} image files and {len(text_files)} text files")

        #Match images with entity annotation files
        for img_path in image_files:
            img_name = os.path.splitext(os.path.basename(img_path))[0]

            #Look for entity annotation file first
            entity_file = None
            for txt_path in text_files:
                if 'entities' in txt_path and img_name in txt_path:
                    entity_file = txt_path
                    break

            if entity_file:
                try:
                    with open(entity_file, 'r', encoding='utf-8') as f:
                        content = f.read().strip()
                        ann_data = self.parse_entity_annotation(content)
                        if any(ann_data.values()):
                            images.append(img_path)
                            annotations.append(ann_data)
                except Exception as e:
                    print(f"Error loading {entity_file}: {str(e)}")
                    continue

        #Limiting dataset size for performance
        max_images = 100
        if len(images) > max_images:
            print(f"Limiting dataset to first {max_images} images for performance")
            images = images[:max_images]
            annotations = annotations[:max_images]

        print(f"Loaded {len(images)} images with annotations")
        if annotations:
            print(f"Sample annotation: {annotations[0]}")

        return images, annotations

    def parse_entity_annotation(self, content):
        annotation = {'vendor_name': '', 'total_amount': '', 'date': ''}

        try:
            #Parse JSON-like format
            data = json.loads(content)

            #Map SROIE fields to fields
            if 'company' in data:
                annotation['vendor_name'] = str(data['company']).strip('"')
            elif 'vendor' in data:
                annotation['vendor_name'] = str(data['vendor']).strip('"')

            if 'total' in data:
                annotation['total_amount'] = str(data['total']).strip('"$')
            elif 'amount' in data:
                annotation['total_amount'] = str(data['amount']).strip('"$')

            if 'date' in data:
                annotation['date'] = str(data['date']).strip('"')

        except json.JSONDecodeError:
            lines = content.strip().split('\n')
            for line in lines:
                line = line.strip()
                if not line:
                    continue

                if ':' in line:
                    key, value = line.split(':', 1)
                    key = key.strip().lower()
                    value = value.strip().strip('"')

                    if 'company' in key or 'vendor' in key:
                        annotation['vendor_name'] = value
                    elif 'total' in key or 'amount' in key:
                        annotation['total_amount'] = value.strip('$')
                    elif 'date' in key:
                        annotation['date'] = value

        return annotation

    def advanced_preprocessing(self, image_path):
        img = cv2.imread(image_path)
        if img is None:
            return None

        #To grayscale
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        #1.Noise reduction
        denoised = cv2.fastNlMeansDenoising(gray)

        #2.Contrast enhancement using CLAHE
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))
        enhanced = clahe.apply(denoised)

        #3.Gaussian blur to smooth
        blurred = cv2.GaussianBlur(enhanced, (1, 1), 0)

        #4.Adaptive thresholding
        thresh = cv2.adaptiveThreshold(blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                     cv2.THRESH_BINARY, 11, 2)

        #5.Morphological operations
        kernel = np.ones((1, 1), np.uint8)
        cleaned = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)

        #6.Dilation to make text thicker
        kernel2 = np.ones((1, 1), np.uint8)
        final = cv2.dilate(cleaned, kernel2, iterations=1)

        return final

    def extract_text_with_ocr(self, image_path):
        """Extract text from image using improved OCR"""
        try:
            processed_imgs = []

            processed_imgs.append(self.preprocess_image(image_path))

            #Advanced preprocessing
            advanced_processed = self.advanced_preprocessing(image_path)
            if advanced_processed is not None:
                processed_imgs.append(advanced_processed)

            best_text = ""
            best_confidence = 0

            #Different OCR config
            configs = [
                '--oem 3 --psm 6 -l eng',
                '--oem 3 --psm 4 -l eng',
                '--oem 3 --psm 3 -l eng',
                '--oem 1 --psm 6 -l eng'
            ]

            for img in processed_imgs:
                for config in configs:
                    try:
                        #Get text with confidence
                        data = pytesseract.image_to_data(img, config=config, output_type=pytesseract.Output.DICT)

                        #Calculate average confidence
                        confidences = [int(conf) for conf in data['conf'] if int(conf) > 0]
                        avg_confidence = sum(confidences) / len(confidences) if confidences else 0

                        if avg_confidence > best_confidence:
                            best_confidence = avg_confidence
                            best_text = pytesseract.image_to_string(img, config=config)

                    except Exception as e:
                        continue

            return best_text if best_text else "", {}

        except Exception as e:
            print(f"Error processing {image_path}: {str(e)}")
            return "", {}

    def preprocess_image(self, image_path):
        img = cv2.imread(image_path)
        if img is None:
            return None

        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        blurred = cv2.GaussianBlur(gray, (5, 5), 0)
        thresh = cv2.adaptiveThreshold(blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                     cv2.THRESH_BINARY, 11, 2)
        kernel = np.ones((2, 2), np.uint8)
        cleaned = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
        return cleaned

    def extract_fields_with_improved_patterns(self, text):
        fields = {
            'vendor_name': '',
            'total_amount': '',
            'date': ''
        }

        #Clean and normalize
        text = re.sub(r'\s+', ' ', text)  #Normalize whitespace
        lines = text.split('\n')

        #Vendor name patterns
        vendor_patterns = [
            r'([A-Z][A-Z\s&.,\-]{2,50}?)(?=\n|$)',
            r'^([A-Z][A-Za-z\s&.,\-]{3,50}?)(?=\n)',
            r'((?:[A-Z][a-z]+\s*){1,4}(?:Corp|Inc|Ltd|LLC|Company|Co|Store|Mart|Shop|Restaurant|Cafe|Hotel))',
            r'^(.+?)(?=\n.*(?:invoice|receipt|bill))',
        ]

        #Total amount patterns
        total_patterns = [
            r'(?:total|amount|sum|grand\s*total|balance|due)[:\s]*\$?\s*(\d+[,.]?\d*\.?\d*)',
            r'(?:total|amount|sum)[:\s]*(\d+[,.]?\d*\.?\d*)',
            r'\$\s*(\d+[,.]?\d*\.?\d*)\s*(?:total|$|\n)',
            r'(\d+\.\d{2})\s*(?:total|amount|sum|$)',
            r'(?:rm|usd|\$)\s*(\d+[,.]?\d*\.?\d*)',
            r'(\d+[,.]?\d{3}\.\d{2})',
            r'(\d+\.\d{2})',
        ]

        #Date patterns
        date_patterns = [
            r'(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})',
            r'(\d{1,2}\s+[A-Za-z]{3,9}\s+\d{2,4})',
            r'([A-Za-z]{3,9}\s+\d{1,2},?\s+\d{2,4})',
            r'(\d{2,4}-\d{1,2}-\d{1,2})',
            r'(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})',
            r'date[:\s]*(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})',
            r'(\d{1,2}\s+[A-Za-z]{3}\s+\d{2,4})',
        ]

        #Extract vendor name
        for pattern in vendor_patterns:
            match = re.search(pattern, text, re.IGNORECASE | re.MULTILINE)
            if match:
                candidate = match.group(1).strip()
                if len(candidate) > 2 and not re.match(r'^\d+$', candidate):
                    fields['vendor_name'] = candidate
                    break

        #If no vendor found, try first meaningful line
        if not fields['vendor_name']:
            for line in lines[:3]:  #Check first 3
                line = line.strip()
                if len(line) > 2 and not re.match(r'^\d', line) and line[0].isupper():
                    fields['vendor_name'] = line
                    break

        #Extract total amount
        for pattern in total_patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                amount = match.group(1).strip()
                #Clean amount
                amount = re.sub(r'[^\d.]', '', amount)
                if amount and '.' in amount:
                    fields['total_amount'] = amount
                    break

        #Extract date
        for pattern in date_patterns:
            match = re.search(pattern, text, re.IGNORECASE)
            if match:
                date_str = match.group(1).strip()
                if date_str:
                    fields['date'] = date_str
                    break

        return fields

    def process_invoices(self, image_paths, enable_annotation=False):
        results = []

        print(f"Processing {len(image_paths)} invoices...")

        for i, image_path in enumerate(image_paths):
            print(f"Processing {i+1}/{len(image_paths)}: {os.path.basename(image_path)}")

            #Extract text
            text, ocr_data = self.extract_text_with_ocr(image_path)

            if not text.strip():
                print(f"No text extracted from {image_path}")
                results.append({
                    'image_path': image_path,
                    'text': '',
                    'fields': {'vendor_name': '', 'total_amount': '', 'date': ''}
                })
                continue

            #Extract improved patterns
            fields = self.extract_fields_with_improved_patterns(text)

            #Debug output
            print(f"  Extracted: {fields}")

            results.append({
                'image_path': image_path,
                'text': text,
                'fields': fields
            })

        return results

    def evaluate_system(self, test_data, ground_truth):
        print("\n Evaluating System Performance...")
        print("=" * 40)

        field_accuracies = {}

        for field_name in ['vendor_name', 'total_amount', 'date']:
            correct = 0
            total = 0

            for result, truth in zip(test_data, ground_truth):
                if field_name in truth and truth[field_name]:
                    total += 1
                    extracted = result['fields'][field_name].lower().strip()
                    actual = truth[field_name].lower().strip()

                    if not extracted or not actual:
                        continue

                    print(f"  {field_name}: '{extracted}' vs '{actual}'")

                    if field_name == 'vendor_name':
                        if extracted == actual:
                            correct += 1
                        elif extracted in actual or actual in extracted:
                            correct += 1
                        else:
                            extracted_words = set(extracted.split())
                            actual_words = set(actual.split())
                            if extracted_words & actual_words:
                                correct += 1
                            else:
                                #Fuzzy matching for similar names
                                if self.fuzzy_match(extracted, actual, threshold=0.7):
                                    correct += 1

                    elif field_name == 'total_amount':
                        #Clean both amounts
                        extracted_clean = re.sub(r'[^\d.]', '', extracted)
                        actual_clean = re.sub(r'[^\d.]', '', actual)

                        if extracted_clean == actual_clean:
                            correct += 1
                        else:
                            try:
                                if float(extracted_clean) == float(actual_clean):
                                    correct += 1
                            except ValueError:
                                pass

                    elif field_name == 'date':
                        if extracted == actual:
                            correct += 1
                        else:
                            extracted_norm = self.normalize_date(extracted)
                            actual_norm = self.normalize_date(actual)
                            if extracted_norm == actual_norm:
                                correct += 1

            accuracy = correct / total if total > 0 else 0
            field_accuracies[field_name] = accuracy
            print(f"{field_name.replace('_', ' ').title()}: {accuracy:.2%} ({correct}/{total})")

        overall_accuracy = sum(field_accuracies.values()) / len(field_accuracies) if field_accuracies else 0
        print(f"\nOverall Accuracy: {overall_accuracy:.2%}")

        return field_accuracies, overall_accuracy

    def fuzzy_match(self, str1, str2, threshold=0.7):
       #Simple fuzzy string matching
        #Simple character-based similarity
        if not str1 or not str2:
            return False

        longer = str1 if len(str1) > len(str2) else str2
        shorter = str2 if len(str1) > len(str2) else str1

        matches = sum(1 for a, b in zip(longer, shorter) if a == b)
        similarity = matches / len(longer)

        return similarity >= threshold

    def normalize_date(self, date_str):
        if not date_str:
            return ""

        #Extract just numbers and separators
        normalized = re.sub(r'[^\d/\-]', '', date_str)
        return normalized

    def save_results(self, results, output_file="invoice_extraction_results.json"):
        with open(output_file, 'w') as f:
            json.dump(results, f, indent=2)
        print(f"Results saved to {output_file}")

    def train_field_classifiers(self, training_data):
        print("Training field classifiers...")

        print(" Field classifiers trained")


def main():
    print(" Improved CPU-Optimized Invoice OCR System")
    print("=" * 50)

    #Initialize system
    ocr_system = InvoiceOCRSystem()

    #Load SROIE dataset
    images, annotations = ocr_system.load_dataset()

    if images and annotations:
        print(" Processing SROIE Dataset...")

        #Limit data for demo
        max_samples = min(50, len(images))
        images = images[:max_samples]
        annotations = annotations[:max_samples]

        # Split data
        train_images, test_images = train_test_split(images, test_size=0.3, random_state=42)
        train_annotations, test_annotations = train_test_split(annotations, test_size=0.3, random_state=42)

        #Process training data
        print(f" Processing {len(train_images)} training images...")
        train_results = ocr_system.process_invoices(train_images)

        #Process test data
        print(f" Processing {len(test_images)} test images...")
        test_results = ocr_system.process_invoices(test_images)

        #Evaluate system
        field_accuracies, overall_accuracy = ocr_system.evaluate_system(test_results, test_annotations)

        ocr_system.save_results(test_results)

    else:
        print("No dataset loaded - please check dataset path")
        return

    print("\nSystem Performance Summary:")
    print(f"Overall Accuracy: {overall_accuracy:.2%}")

    if overall_accuracy >= 0.3:
        print(" System shows promising results")
    else:
        print("  System needs improvement")

    print(" Processing complete!")


if __name__ == "__main__":
    import time
    start_time = time.time()

    try:
        main()
    except KeyboardInterrupt:
        print("\n  Process interrupted by user")
    except Exception as e:
        print(f" Error: {str(e)}")
        import traceback
        traceback.print_exc()

    end_time = time.time()
    print(f"\n  Total Runtime: {end_time - start_time:.2f} seconds")

 Improved CPU-Optimized Invoice OCR System
Loading SROIE dataset
Downloading SROIE dataset
Dataset downloaded to: /kaggle/input/sroie-datasetv2
Found 973 image files and 1947 text files
Limiting dataset to first 100 images for performance
Loaded 100 images with annotations
Sample annotation: {'vendor_name': 'HENG KEE DELIGHTS BAK KUT TEH.', 'total_amount': '42.00', 'date': '04/01/2018'}
 Processing SROIE Dataset...
 Processing 35 training images...
Processing 35 invoices...
Processing 1/35: X51005663274.jpg
  Extracted: {'vendor_name': 'LST RM Co', 'total_amount': '10.50', 'date': '08/02/2018'}
Processing 2/35: X51005442322.jpg
  Extracted: {'vendor_name': 'TS,', 'total_amount': '', 'date': ''}
Processing 3/35: X51006349081.jpg
  Extracted: {'vendor_name': 'tee Co', 'total_amount': '80.00', 'date': '26/04/2018'}
Processing 4/35: X51005568894.jpg
  Extracted: {'vendor_name': '', 'total_amount': '16.00', 'date': ''}
Processing 5/35: X51007231343.jpg
  Extracted: {'vendor_name': 'BarWangR