In [1]:
import Levenshtein
import re

def compute_nld(s1, s2):
    """Compute the Normalized Levenshtein Distance (NLD) between two strings.

    Both inputs are normalized before comparison: leading/trailing whitespace
    is stripped, the text is lower-cased, and every run of whitespace is
    collapsed to a single space.

    Args:
        s1: First string.
        s2: Second string.

    Returns:
        The Levenshtein edit distance divided by the length of the longer
        normalized string. Returns 0.0 when both strings are empty after
        normalization (they are identical, and there is no length to
        normalize by).
    """
    s1 = re.sub(r'\s+', ' ', s1.strip().lower())
    s2 = re.sub(r'\s+', ' ', s2.strip().lower())

    longest = max(len(s1), len(s2))
    if longest == 0:
        # Guard against ZeroDivisionError: two empty strings are identical.
        return 0.0

    return Levenshtein.distance(s1, s2) / longest

def word_level_accuracy(gt, pred, case_sensitive=True):
    """Compute word-level accuracy (recall) and F1 between two strings.

    Words are obtained by whitespace splitting. Matches are counted as a
    multiset intersection, so a word that occurs k times in the ground truth
    can be matched at most k times, and repeated words are not under-counted.

    Args:
        gt: Ground-truth text.
        pred: Predicted text.
        case_sensitive: When False, both strings are lower-cased before
            comparison. Defaults to True (exact token match).

    Returns:
        Tuple ``(accuracy, f1)``. ``accuracy`` is matched words divided by the
        number of ground-truth words (i.e. recall). ``f1`` is the harmonic
        mean of precision and recall. Each value is 0.0 when its denominator
        would be zero.
    """
    from collections import Counter

    if not case_sensitive:
        gt, pred = gt.lower(), pred.lower()

    gt_counts = Counter(gt.split())
    pred_counts = Counter(pred.split())

    # Counter & Counter keeps the minimum count per word (multiset intersection).
    matched = sum((gt_counts & pred_counts).values())
    n_gt = sum(gt_counts.values())
    n_pred = sum(pred_counts.values())

    accuracy = matched / n_gt if n_gt else 0.0
    precision = matched / n_pred if n_pred else 0.0
    recall = accuracy
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0.0
    return accuracy, f1


In [16]:
# Evaluation samples as (image path, ground-truth text) pairs.
dataset = [
    # --- active samples ---
    ("Incorrect_Images/Incorrect_SOTP_sign.jpg", "STOP"),
    ("Incorrect_Images/Incorrect_Happy_Birthday.png.webp", "Happy Birthday"),
    # --- disabled samples (uncomment to include them in the evaluation) ---
    # ("GenAI_Dataset/Imagen(Gemini)/Gemini_Generated_Image_1mqqb41mqqb41mqq.jpg", "LIMITLESS POSSIBBITIES"),
    # ("GenAI_Dataset/Imagen(Gemini)/Gemini_Generated_Image_r8x05sr8x05sr8x0.jpg", "nice to met you"),
    # ("Incorrect_Images/incorrect_parking.jpg", "No UNORTHERISED PARKING THE COMMITTEE"),
]


In [17]:
import torch
import gc
from transformers import BlipProcessor, BlipForConditionalGeneration
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image, ImageDraw
import numpy as np
import easyocr
import networkx as nx
from scipy.optimize import linear_sum_assignment
from accelerate import infer_auto_device_map, dispatch_model

class AITextCorrector:
    """Pipeline that finds text in an image, proposes replacement text and repaints it.

    Components (all created in ``__init__``):
      * EasyOCR reader   - text detection and recognition.
      * BLIP             - image captioning (``BlipForConditionalGeneration``).
      * Stable Diffusion - inpainting (``StableDiffusionInpaintPipeline``).
    """

    def __init__(self, blip_model="Salesforce/blip-image-captioning-base",
                 model_name="stabilityai/stable-diffusion-2-inpainting"):
        """
        Load the BLIP captioning model, the inpainting pipeline and the OCR reader.

        Args:
            blip_model: Checkpoint name passed to ``BlipProcessor`` and
                ``BlipForConditionalGeneration``.
            model_name: Checkpoint name passed to ``StableDiffusionInpaintPipeline``.
        """
        # Apple MPS selection is intentionally disabled; only CUDA or CPU is used.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        print("Using device:", self.device)

        # Captioning model (BLIP)
        self.blip_processor = BlipProcessor.from_pretrained(blip_model)
        self.blip_model = BlipForConditionalGeneration.from_pretrained(blip_model).to(self.device)

        # Inpainting pipeline (Stable Diffusion)
        self.model = StableDiffusionInpaintPipeline.from_pretrained(model_name).to(self.device)

        # OCR reader used for both bounding-box detection and text recognition
        self.easyocr_model = easyocr.Reader(['en'])

    def _run_ocr(self, image):
        """Run the OCR reader on ``image`` (converted to a NumPy array) and return its raw results."""
        return self.easyocr_model.readtext(np.array(image))

    def detect_text_boxes(self, image):
        """
        Detect text regions with the OCR reader.

        Returns:
            A list of dicts, one per OCR result, with keys ``"coordinates"``
            (first element of the result) and ``"text"`` (second element).
        """
        return [{"coordinates": result[0], "text": result[1]} for result in self._run_ocr(image)]

    def recognize_text(self, image):
        """
        Recognize the text in the image with the OCR reader.

        Prints the recognized text and returns all detected pieces joined by
        single spaces, in the order the reader returned them.
        """
        recognized_text = " ".join(result[1] for result in self._run_ocr(image))
        print(f"OCR Output: {recognized_text}")
        return recognized_text

    def generate_caption(self, image):
        """
        Generate a caption for the image with BLIP and return it as a string.
        """
        inputs = self.blip_processor(images=image, return_tensors="pt").to(self.device)
        pixel_values = inputs["pixel_values"]
        with torch.no_grad():  # inference only, no gradients needed
            outputs = self.blip_model.generate(pixel_values=pixel_values)
        return self.blip_processor.decode(outputs[0], skip_special_tokens=True)

    def correct_text(self, extracted_text, caption, image):
        """
        Produce replacement text for ``extracted_text`` using BLIP.

        NOTE(review): the prompt built from ``extracted_text`` and ``caption``
        is tokenized by the processor, but only ``pixel_values`` is passed to
        ``generate``. The returned string therefore does not depend on
        ``extracted_text`` or ``caption``. Confirm whether the prompt tokens
        were meant to be forwarded to ``generate`` as well.
        """
        inputs = self.blip_processor(images=image, text=f"Correct this text: {extracted_text} in context: {caption}", return_tensors="pt").to(self.device)
        pixel_values = inputs["pixel_values"]
        with torch.no_grad():
            outputs = self.blip_model.generate(pixel_values=pixel_values)
        return self.blip_processor.decode(outputs[0], skip_special_tokens=True)

    def create_mask(self, image_size, coordinates):
        """
        Create a single-channel mask of ``image_size`` that is 255 inside the
        polygon given by ``coordinates`` and 0 elsewhere.
        """
        mask = Image.new('L', image_size, 0)
        draw = ImageDraw.Draw(mask)
        draw.polygon([tuple(point) for point in coordinates], outline=255, fill=255)
        return mask

    def graph_based_text_alignment(self, detected_boxes):
        """
        Order detected boxes via an assignment over pairwise centroid distances.

        A square cost matrix is built whose entry (i, j) is the Euclidean
        distance between the centroids of boxes i and j (0 on the diagonal),
        and ``scipy.optimize.linear_sum_assignment`` is run on it. The boxes
        are returned in the order of the resulting row indices.

        NOTE(review): for a square cost matrix SciPy returns the row indices
        in ascending order, so the returned list currently has the same order
        as ``detected_boxes``. Confirm what ordering was intended here.

        Returns:
            A new list containing the boxes of ``detected_boxes``.
        """
        if len(detected_boxes) < 2:
            # Nothing to assign; also avoids building an empty cost matrix.
            return list(detected_boxes)

        # Compute each centroid once instead of once per (i, j) pair.
        centroids = np.array([np.mean(box['coordinates'], axis=0) for box in detected_boxes])
        pairwise_offsets = centroids[:, None, :] - centroids[None, :, :]
        cost_matrix = np.linalg.norm(pairwise_offsets, axis=-1)

        row_ind, _col_ind = linear_sum_assignment(cost_matrix)
        return [detected_boxes[i] for i in row_ind]

    def inpaint_text(self, image, mask, corrected_text):
        """
        Run the inpainting pipeline on ``image`` within ``mask`` with a prompt
        asking for ``corrected_text``, and return the first generated image.
        """
        return self.model(prompt=f"Generate text '{corrected_text}' in a matching style", image=image, mask_image=mask, num_inference_steps=50, guidance_scale=7.5).images[0]

    def run_pipeline(self, image):
        """
        Run the complete text correction pipeline.

        For every detected box, a replacement text is generated; when it
        differs from the OCR text (ignoring surrounding whitespace), the box
        region is inpainted and pasted into a copy of ``image``.

        Returns:
            The corrected copy of ``image``; the input image is not modified.
        """
        text_boxes = self.detect_text_boxes(image)
        caption = self.generate_caption(image)
        aligned_boxes = self.graph_based_text_alignment(text_boxes)

        corrected_image = image.copy()

        for box in aligned_boxes:
            original_text = box["text"]
            corrected_text = self.correct_text(original_text, caption, image)

            if corrected_text.strip() == original_text.strip():
                continue  # Skip if no correction needed

            # The mask is created at image.size, so no further resizing is required.
            mask = self.create_mask(image.size, box["coordinates"])
            inpainted_region = self.inpaint_text(corrected_image, mask, corrected_text)

            # Bring the pipeline output back to the image size before pasting.
            inpainted_region = inpainted_region.resize(image.size, Image.LANCZOS)

            # Debugging: Print sizes before pasting
            print("Original image size:", image.size)
            print("Mask size:", mask.size)
            print("Inpainted region size:", inpainted_region.size)

            # Only pixels where the mask is set are taken from the inpainted image.
            corrected_image.paste(inpainted_region, (0, 0), mask)

        return corrected_image


In [18]:
class AITextCorrectorAblation(AITextCorrector):
    """Variant of ``AITextCorrector`` whose pipeline stages can be switched off for ablation runs."""

    def run_pipeline_ablation(self, image, use_simulated_annealing=True, use_ocr_inpainting=True):
        """
        Run the correction pipeline with individual stages optionally disabled.

        Args:
            image: Input image; it is not modified.
            use_simulated_annealing: When True, detected boxes are passed
                through ``self.graph_based_text_alignment`` before processing;
                when False, they are used in detection order.
                NOTE(review): despite the flag name, the step being toggled is
                ``graph_based_text_alignment``, not a simulated-annealing routine.
            use_ocr_inpainting: When False, no region is inpainted and an
                unmodified copy of ``image`` is returned.

        Returns:
            A corrected copy of ``image``.
        """
        corrected_image = image.copy()

        if not use_ocr_inpainting:
            # Inpainting is the only step that alters the image, so OCR,
            # captioning and per-box text generation would be wasted work.
            return corrected_image

        text_boxes = self.detect_text_boxes(image)
        caption = self.generate_caption(image)

        if use_simulated_annealing:
            aligned_boxes = self.graph_based_text_alignment(text_boxes)
        else:
            aligned_boxes = text_boxes  # keep detection order

        for box in aligned_boxes:
            original_text = box["text"]
            corrected_text = self.correct_text(original_text, caption, image)

            if corrected_text.strip() == original_text.strip():
                continue  # Skip if no correction needed

            # The mask is created at image.size, so it can be used directly
            # for both inpainting and pasting.
            mask = self.create_mask(image.size, box["coordinates"])

            # Inpaint the corrected text, then bring the result back to the image size.
            inpainted_region = self.inpaint_text(corrected_image, mask, corrected_text)
            inpainted_region = inpainted_region.resize(image.size, Image.LANCZOS)

            # Debugging: Print sizes before pasting
            print("Original image size:", image.size)
            print("Mask size:", mask.size)
            print("Inpainted region size:", inpainted_region.size)

            # Only pixels where the mask is set are taken from the inpainted image.
            corrected_image.paste(inpainted_region, (0, 0), mask)

        return corrected_image


In [19]:
# pandas is needed for the results table below; importing it here keeps this
# cell runnable on a fresh kernel (Restart & Run All).
import pandas as pd

# Each condition toggles the two optional stages of run_pipeline_ablation.
ablation_conditions = [
    {"name": "Baseline (No Correction)", "simulated": False, "ocr": False},
    {"name": "With Simulated Annealing", "simulated": True, "ocr": False},
    {"name": "With OCR In-painting", "simulated": False, "ocr": True},
    {"name": "Full Model", "simulated": True, "ocr": True}
]

results_ablation = []
corrector_ablation = AITextCorrectorAblation()

# NOTE(review): the inpainting call is not given a fixed random generator, so
# scores of conditions that inpaint may vary between runs - confirm whether a
# seed should be set before reporting these numbers.
for cond in ablation_conditions:
    metrics = []
    for img_path, gt_text in dataset:  # (image path, ground-truth text) pairs
        # Context manager closes the underlying file once the sample is processed.
        with Image.open(img_path) as image:
            corrected_image = corrector_ablation.run_pipeline_ablation(
                image, use_simulated_annealing=cond["simulated"], use_ocr_inpainting=cond["ocr"]
            )
            pred_text = corrector_ablation.recognize_text(corrected_image)
        _, f1 = word_level_accuracy(gt_text, pred_text)
        metrics.append(f1)
    # Guard against an empty dataset instead of dividing by zero.
    avg_f1 = sum(metrics) / len(metrics) if metrics else 0.0
    results_ablation.append({"Method": cond["name"], "Avg F1 Score": round(avg_f1, 2)})

df_ablation = pd.DataFrame(results_ablation)
print(df_ablation.to_latex(index=False))


Using device: cpu


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

OCR Output: SOTP
OCR Output: HAPPP  Hanpdday Birthday
OCR Output: SOTP
OCR Output: HAPPP  Hanpdday Birthday


  0%|          | 0/50 [00:00<?, ?it/s]

Original image size: (2816, 2112)
Mask size: (2816, 2112)
Inpainted region size: (2816, 2112)
OCR Output: STOP


  0%|          | 0/50 [00:00<?, ?it/s]

Original image size: (512, 512)
Mask size: (512, 512)
Inpainted region size: (512, 512)


  0%|          | 0/50 [00:00<?, ?it/s]

Original image size: (512, 512)
Mask size: (512, 512)
Inpainted region size: (512, 512)


  0%|          | 0/50 [00:00<?, ?it/s]

Original image size: (512, 512)
Mask size: (512, 512)
Inpainted region size: (512, 512)
OCR Output: Happpy BitR Hdday


  0%|          | 0/50 [00:00<?, ?it/s]

Original image size: (2816, 2112)
Mask size: (2816, 2112)
Inpainted region size: (2816, 2112)
OCR Output: S7ITOP


  0%|          | 0/50 [00:00<?, ?it/s]

Original image size: (512, 512)
Mask size: (512, 512)
Inpainted region size: (512, 512)


  0%|          | 0/50 [00:00<?, ?it/s]

Original image size: (512, 512)
Mask size: (512, 512)
Inpainted region size: (512, 512)


  0%|          | 0/50 [00:00<?, ?it/s]

Original image size: (512, 512)
Mask size: (512, 512)
Inpainted region size: (512, 512)
OCR Output: Hopppv HAppPY BIRTHDDAY
\begin{tabular}{lr}
\toprule
Method & Avg F1 Score \\
\midrule
Baseline (No Correction) & 0.200000 \\
With Simulated Annealing & 0.200000 \\
With OCR In-painting & 0.500000 \\
Full Model & 0.000000 \\
\bottomrule
\end{tabular}



In [20]:
import pandas as pd
# Rebuild the results table from the collected records and emit it as LaTeX.
df_ablation = pd.DataFrame(data=results_ablation)
latex_table = df_ablation.to_latex(index=False)
print(latex_table)

\begin{tabular}{lr}
\toprule
Method & Avg F1 Score \\
\midrule
Baseline (No Correction) & 0.200000 \\
With Simulated Annealing & 0.200000 \\
With OCR In-painting & 0.500000 \\
Full Model & 0.000000 \\
\bottomrule
\end{tabular}

