In [1]:
# Use the %pip magic so packages are installed into the running kernel's environment
# even when IPython automagic is disabled. `accelerate` is imported by the pipeline
# cell and `huggingface_hub` by the login cell, so both are listed explicitly.
%pip install torch torchvision transformers diffusers accelerate huggingface_hub numpy easyocr scipy networkx pillow

Note: you may need to restart the kernel to use updated packages.


In [2]:
# Authenticate with the Hugging Face Hub before any checkpoints are downloaded.
# login() asks for the access token interactively (shown as a widget in the cell
# output), so no token is stored in the notebook source.
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
import torch
import gc
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, BlipProcessor, BlipForConditionalGeneration
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image, ImageDraw
import numpy as np
import easyocr
import networkx as nx
from scipy.optimize import linear_sum_assignment
from accelerate import infer_auto_device_map, dispatch_model

class AITextCorrector:
    """Detect text in an image, propose replacement text, and inpaint it.

    Models loaded in ``__init__``:

    * EasyOCR - locates text regions and gives the initial reading of each.
    * TrOCR - recogniser exposed through ``recognize_text``.
    * BLIP - image captioning, also used by ``correct_text``.
    * Stable Diffusion inpainting - redraws the masked text regions.
    """

    def __init__(self, trocr_model="microsoft/trocr-large-handwritten",
                 blip_model="Salesforce/blip-image-captioning-base",
                 model_name="stabilityai/stable-diffusion-2-inpainting"):
        """
        Load every model used by the pipeline.

        Args:
            trocr_model: Hugging Face id of the TrOCR checkpoint.
            blip_model: Hugging Face id of the BLIP captioning checkpoint.
            model_name: Hugging Face id of the inpainting checkpoint, loaded
                with ``StableDiffusionInpaintPipeline``.
        """
        # Only CUDA or CPU is selected here; MPS is not used by this class.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        print("Using device:", self.device)

        # OCR - TrOCR. The model is loaded first, then placed according to the
        # device map that accelerate infers for it.
        self.ocr_processor = TrOCRProcessor.from_pretrained(trocr_model)
        self.ocr_model = VisionEncoderDecoderModel.from_pretrained(trocr_model)
        device_map = infer_auto_device_map(self.ocr_model)
        self.ocr_model = dispatch_model(self.ocr_model, device_map=device_map)

        # Captioning - BLIP
        self.blip_processor = BlipProcessor.from_pretrained(blip_model)
        self.blip_model = BlipForConditionalGeneration.from_pretrained(blip_model).to(self.device)

        # Text inpainting - Stable Diffusion inpainting pipeline
        self.model = StableDiffusionInpaintPipeline.from_pretrained(model_name).to(self.device)

        # Traditional OCR for bounding-box detection
        self.easyocr_model = easyocr.Reader(['en'])

    def detect_text_boxes(self, image):
        """
        Detect text regions with EasyOCR.

        Args:
            image: Image convertible with ``np.array`` (a PIL image in this pipeline).

        Returns:
            A list of dicts, one per detection, with keys ``"coordinates"``
            (first element of the EasyOCR result) and ``"text"`` (second element).
        """
        image_np = np.array(image)
        ocr_results = self.easyocr_model.readtext(image_np)
        return [{"coordinates": result[0], "text": result[1]} for result in ocr_results]

    def recognize_text(self, image):
        """
        Recognize the text in ``image`` with TrOCR and return the decoded string.
        """
        pixel_values = self.ocr_processor(images=image, return_tensors="pt").pixel_values.to(self.device)
        generated_ids = self.ocr_model.generate(pixel_values)
        return self.ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def generate_caption(self, image):
        """
        Generate a caption for ``image`` with BLIP and return it as a string.
        """
        inputs = self.blip_processor(images=image, return_tensors="pt").to(self.device)
        pixel_values = inputs["pixel_values"]
        with torch.no_grad():
            outputs = self.blip_model.generate(pixel_values=pixel_values)
        return self.blip_processor.decode(outputs[0], skip_special_tokens=True)

    def correct_text(self, extracted_text, caption, image):
        """
        Produce replacement text for one detected text region using BLIP.

        Args:
            extracted_text: Text read from the region by OCR.
            caption: Caption of the whole image.
            image: The image the text was detected in.

        Returns:
            The decoded BLIP output.

        NOTE(review): the prompt built from ``extracted_text`` and ``caption``
        is tokenized by the processor, but only ``pixel_values`` is handed to
        ``generate``, so the prompt does not reach the model. Passing it would
        change the generated output; confirm the intended behaviour before
        changing this.
        """
        prompt = f"Correct this text: {extracted_text} in context: {caption}"
        inputs = self.blip_processor(images=image, text=prompt, return_tensors="pt").to(self.device)
        pixel_values = inputs["pixel_values"]
        with torch.no_grad():
            outputs = self.blip_model.generate(pixel_values=pixel_values)
        return self.blip_processor.decode(outputs[0], skip_special_tokens=True)

    def create_mask(self, image_size, coordinates):
        """
        Create a binary mask for one text region.

        Args:
            image_size: ``(width, height)`` of the mask to create.
            coordinates: Sequence of polygon vertices outlining the region.

        Returns:
            A mode ``'L'`` image that is 255 inside the polygon and 0 elsewhere.
        """
        mask = Image.new('L', image_size, 0)
        draw = ImageDraw.Draw(mask)
        draw.polygon([tuple(point) for point in coordinates], outline=255, fill=255)
        return mask

    def graph_based_text_alignment(self, detected_boxes):
        """
        Run Hungarian matching over the pairwise distances of box centres.

        Args:
            detected_boxes: List of dicts as returned by ``detect_text_boxes``.

        Returns:
            A new list holding the same box dicts, indexed by the row indices
            of the assignment.

        NOTE(review): the cost matrix is square, and for square matrices
        ``linear_sum_assignment`` returns row indices equal to
        ``arange(n)``, so the returned order matches the input order.
        """
        if not detected_boxes:
            # Nothing to match; avoids building a degenerate 0x0 cost matrix.
            return []

        # Centre of each box = mean of its polygon vertices.
        centres = np.array(
            [np.mean(box['coordinates'], axis=0) for box in detected_boxes]
        ).reshape(-1, 2)

        # Pairwise Euclidean distances; the diagonal (i == j) is 0.
        cost_matrix = np.linalg.norm(centres[:, None, :] - centres[None, :, :], axis=-1)

        row_ind, _ = linear_sum_assignment(cost_matrix)
        return [detected_boxes[i] for i in row_ind]

    def inpaint_text(self, image, mask, corrected_text):
        """
        Inpaint the masked region of ``image`` with the inpainting pipeline.

        Args:
            image: Image to edit.
            mask: Mask image; the white area is the region to redraw.
            corrected_text: Text inserted into the generation prompt.

        Returns:
            The first image produced by the pipeline.
        """
        # The inpainting model works on 3-channel input; normalise other PIL
        # modes (e.g. RGBA, L, P) so they do not fail inside the pipeline.
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")

        result = self.model(
            prompt=f"Generate text '{corrected_text}' in a matching style",
            image=image,
            mask_image=mask,
            num_inference_steps=50,
            guidance_scale=7.5,
        )
        return result.images[0]

    def run_pipeline(self, image):
        """
        Run the complete text correction pipeline on ``image``.

        Each detected region whose proposed text differs from the OCR reading
        is inpainted and pasted back through the region mask.

        Args:
            image: Input PIL image; it is not modified.

        Returns:
            A copy of ``image`` with the changed regions pasted in.
        """
        text_boxes = self.detect_text_boxes(image)
        corrected_image = image.copy()
        if not text_boxes:
            # No text found: skip captioning, there is nothing to correct.
            return corrected_image

        caption = self.generate_caption(image)
        aligned_boxes = self.graph_based_text_alignment(text_boxes)

        for box in aligned_boxes:
            original_text = box["text"]
            corrected_text = self.correct_text(original_text, caption, image)

            if corrected_text.strip() == original_text.strip():
                continue  # Skip if no correction needed

            # The mask is created at the image size, so it needs no resizing.
            mask = self.create_mask(image.size, box["coordinates"])
            inpainted_region = self.inpaint_text(corrected_image, mask, corrected_text)

            # The pipeline output may not match the input size; bring it back
            # to the image size so it lines up with the mask.
            inpainted_region = inpainted_region.resize(image.size, Image.LANCZOS)

            # Only pixels under the mask are replaced.
            corrected_image.paste(inpainted_region, (0, 0), mask)

        return corrected_image

# Example usage: build the corrector (which loads all models), run the pipeline
# on a single image and save the result next to the notebook.
# The __main__ guard keeps this from running if the code is imported as a module.
if __name__ == "__main__":
    corrector = AITextCorrector()

    # Input path is relative to the working directory.
    input_image = Image.open("Incorrect_Images/spelling errors.jpg")  # Replace with your test image
    output_image = corrector.run_pipeline(input_image)
    print("Corrected Image Generated")
    output_image.save("Corrected_spelling_error.png")


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 16,
  "num_channels": 3,
  "num_hidden_layers": 24,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.48.3"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared dec

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
# Free memory held by the kernel and by whichever accelerator backend is present
import torch
import gc

# Run garbage collection first so unreachable tensors release their storage
gc.collect()

# Release cached accelerator memory. Each call is guarded so the cell also runs
# on machines without that backend (the pipeline above selects CUDA or CPU).
if torch.cuda.is_available():
    torch.cuda.empty_cache()
if torch.backends.mps.is_available():
    torch.mps.empty_cache()