In [1]:
!pip install torch torchvision transformers diffusers numpy easyocr scipy networkx pillow

Collecting numpy
  Downloading numpy-2.2.3-cp310-cp310-macosx_14_0_arm64.whl.metadata (62 kB)
Downloading numpy-2.2.3-cp310-cp310-macosx_14_0_arm64.whl (5.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.4/5.4 MB[0m [31m48.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.23.5
    Uninstalling numpy-1.23.5:
      Successfully uninstalled numpy-1.23.5
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
albucore 0.0.13 requires numpy<2,>=1.24.4, but you have numpy 2.2.3 which is incompatible.
albumentations 1.4.10 requires numpy<2,>=1.24.4, but you have numpy 2.2.3 which is incompatible.
paddleocr 2.9.1 requires numpy<2.0, but you have numpy 2.2.3 which is incompatible.[0m[31m
[0mSuccessfully installed numpy-2.2.3


In [2]:
# Authenticate with the Hugging Face Hub so the checkpoints below can be downloaded.
# Called without a token, login() shows an interactive prompt in the notebook,
# which keeps the access token out of the saved source.
import huggingface_hub
from huggingface_hub import login

huggingface_hub.login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
import torch
import gc
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, BlipProcessor, BlipForConditionalGeneration
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image, ImageDraw
import numpy as np
import easyocr
import networkx as nx
from scipy.optimize import linear_sum_assignment
from accelerate import infer_auto_device_map, dispatch_model

class AITextCorrector:
    """Detect text in an image, ask BLIP for a corrected version and inpaint it back.

    Models used:
      * EasyOCR for text-box detection (and the recognized text per box),
      * TrOCR for recognition (exposed via ``recognize_text``; not used by ``run_pipeline``),
      * BLIP image captioning for the caption and the text correction,
      * Stable Diffusion 2 inpainting for re-rendering corrected text.
    """

    def __init__(self, trocr_model="microsoft/trocr-large-handwritten",
                 blip_model="Salesforce/blip-image-captioning-base",
                 model_name="stabilityai/stable-diffusion-2-inpainting"):
        """
        Load all models.

        Args:
            trocr_model: Hugging Face id of the TrOCR checkpoint (used by recognize_text).
            blip_model: Hugging Face id of the BLIP captioning checkpoint.
            model_name: Hugging Face id of the Stable Diffusion inpainting checkpoint.
        """
        # Only CUDA or CPU are selected; Apple MPS is intentionally not used here.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print("Using device:", self.device)

        # OCR - TrOCR. accelerate computes a device map and dispatches the model
        # according to it.
        self.ocr_processor = TrOCRProcessor.from_pretrained(trocr_model)
        ocr_model = VisionEncoderDecoderModel.from_pretrained(trocr_model)
        self.ocr_model = dispatch_model(ocr_model, device_map=infer_auto_device_map(ocr_model))

        # Captioning / correction - BLIP (not BLIP-2).
        self.blip_processor = BlipProcessor.from_pretrained(blip_model)
        self.blip_model = BlipForConditionalGeneration.from_pretrained(blip_model).to(self.device)

        # Text inpainting - Stable Diffusion 2 inpainting pipeline.
        self.model = StableDiffusionInpaintPipeline.from_pretrained(model_name).to(self.device)

        # Traditional OCR for bounding-box detection.
        self.easyocr_model = easyocr.Reader(['en'])

    def detect_text_boxes(self, image):
        """
        Detect text regions with EasyOCR.

        Returns:
            list of {"coordinates": polygon points, "text": recognized string}.
        """
        image_np = np.array(image)
        ocr_results = self.easyocr_model.readtext(image_np)
        return [{"coordinates": result[0], "text": result[1]} for result in ocr_results]

    def recognize_text(self, image):
        """
        Recognize the text in ``image`` with TrOCR and return the decoded string.
        """
        pixel_values = self.ocr_processor(images=image, return_tensors="pt").pixel_values.to(self.device)
        generated_ids = self.ocr_model.generate(pixel_values)
        return self.ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def generate_caption(self, image):
        """
        Generate an unconditional BLIP caption for ``image``.
        """
        inputs = self.blip_processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.blip_model.generate(pixel_values=inputs["pixel_values"])
        return self.blip_processor.decode(outputs[0], skip_special_tokens=True)

    @staticmethod
    def _strip_prompt_echo(generated, prompt):
        """
        Remove a leading copy of ``prompt`` from a decoded generation.

        BLIP's conditional generation returns the prompt tokens followed by the
        continuation, so the prompt must be removed to get the model's answer.
        The comparison is case-insensitive; if the prefix is absent, the
        generation is returned unchanged (apart from surrounding whitespace).
        """
        text = generated.strip()
        head = prompt.strip()
        if head and text[:len(head)].lower() == head.lower():
            text = text[len(head):]
        return text.strip()

    def correct_text(self, extracted_text, caption, image):
        """
        Ask BLIP for a corrected version of ``extracted_text`` given the caption.

        The text prompt is passed to ``generate`` together with the image, so
        the output is conditioned on it (previously only ``pixel_values`` were
        passed and the prompt was silently ignored).

        Returns:
            The generated continuation with the echoed prompt removed; may be "".
        """
        prompt = f"Correct this text: {extracted_text} in context: {caption}"
        inputs = self.blip_processor(images=image, text=prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            # The prompt alone can be longer than generate's default max_length,
            # so bound the number of *new* tokens instead.
            outputs = self.blip_model.generate(**inputs, max_new_tokens=30)
        generated = self.blip_processor.decode(outputs[0], skip_special_tokens=True)
        # Decode the prompt with the same tokenizer so its normalization matches the output.
        prompt_echo = self.blip_processor.decode(inputs["input_ids"][0], skip_special_tokens=True)
        return self._strip_prompt_echo(generated, prompt_echo)

    def create_mask(self, image_size, coordinates):
        """
        Create a binary 'L' mask of ``image_size`` with the polygon filled white (255).
        """
        mask = Image.new('L', image_size, 0)
        draw = ImageDraw.Draw(mask)
        draw.polygon([tuple(point) for point in coordinates], outline=255, fill=255)
        return mask

    def graph_based_text_alignment(self, detected_boxes):
        """
        Hungarian matching over pairwise box-center distances.

        NOTE(review): for a square cost matrix ``linear_sum_assignment`` always
        returns ``row_ind == arange(n)`` and only ``row_ind`` is used, so this
        returns the boxes in their original (EasyOCR) order; the matching in
        ``col_ind`` is discarded. TODO decide what alignment is actually intended.
        """
        if len(detected_boxes) < 2:
            return list(detected_boxes)

        centers = np.array(
            [np.mean(np.asarray(box["coordinates"], dtype=float), axis=0) for box in detected_boxes]
        )
        # Pairwise Euclidean distance between box centers (zero on the diagonal).
        cost_matrix = np.linalg.norm(centers[:, None, :] - centers[None, :, :], axis=-1)

        row_ind, _ = linear_sum_assignment(cost_matrix)
        return [detected_boxes[i] for i in row_ind]

    def inpaint_text(self, image, mask, corrected_text):
        """
        Inpaint the masked area of ``image`` with Stable Diffusion, prompting for ``corrected_text``.

        The pipeline is called without height/width, so the returned image may not
        match ``image.size``; callers resize it before pasting.
        """
        return self.model(prompt=f"Generate text '{corrected_text}' in a matching style", image=image, mask_image=mask, num_inference_steps=50, guidance_scale=7.5).images[0]

    def run_pipeline(self, image):
        """
        Run detection -> captioning -> correction -> inpainting on ``image``.

        Boxes whose correction is empty, or equal to the OCR text ignoring case
        and surrounding whitespace, are left untouched.

        Returns:
            A new RGB image with corrected regions inpainted.
        """
        # Work in RGB, matching what the inpainting pipeline produces.
        image = image.convert("RGB")
        text_boxes = self.detect_text_boxes(image)
        if not text_boxes:
            return image.copy()

        caption = self.generate_caption(image)
        aligned_boxes = self.graph_based_text_alignment(text_boxes)

        corrected_image = image.copy()
        for box in aligned_boxes:
            original_text = box["text"]
            corrected_text = self.correct_text(original_text, caption, image).strip()

            if not corrected_text or corrected_text.lower() == original_text.strip().lower():
                continue  # Skip if no (usable) correction

            mask = self.create_mask(image.size, box["coordinates"])
            inpainted_region = self.inpaint_text(corrected_image, mask, corrected_text)
            # The pipeline may return a different resolution; paste needs equal sizes.
            if inpainted_region.size != corrected_image.size:
                inpainted_region = inpainted_region.resize(corrected_image.size, Image.LANCZOS)

            # Only the masked text region is taken from the inpainted result.
            corrected_image.paste(inpainted_region, (0, 0), mask)

        return corrected_image

# Example usage: run the AITextCorrector pipeline on one DALL·E 3 sample
# and write the corrected image to test.png.
if __name__ == "__main__":
    corrector = AITextCorrector()

    # Replace with your test image
    input_path = "GenAI_Dataset/Dall-E3(ChatGpt)/DALL·E 2025-02-15 00.29.46 - A breathtaking night sky filled with countless stars, stretching across the horizon. The Milky Way is visible, creating a stunning cosmic backdrop. In.webp"
    input_image = Image.open(input_path)

    output_image = corrector.run_pipeline(input_image)
    print("Corrected Image Generated")
    output_image.save("test.png")

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 16,
  "num_channels": 3,
  "num_hidden_layers": 24,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.48.3"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 1024,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decod

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
# Free unused memory after running the correctors.
import gc

import torch


def free_accelerator_memory():
    """Run the garbage collector and release cached accelerator memory.

    Only backends that are actually available are touched, so this is safe on
    CPU-only, CUDA and Apple-silicon machines alike.

    Returns:
        int: number of unreachable objects found by ``gc.collect()``.
    """
    collected = gc.collect()
    # The correctors pick CUDA when available, otherwise CPU.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Calling torch.mps.empty_cache() without an MPS backend is not valid.
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    return collected


free_accelerator_memory();

In [None]:
# For Levenshtein distance
!pip install python-Levenshtein


In [17]:
import torch
import numpy as np
from scipy.optimize import minimize, linear_sum_assignment
from scipy.spatial.distance import cdist
from PIL import Image, ImageDraw
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from diffusers import StableDiffusionInpaintPipeline
import easyocr

class EnhancedTextCorrector:
    """Detect, correct and re-render text in an image.

    Stages: EasyOCR detection -> BLIP-2 (OPT-2.7B) captioning and per-region
    text correction -> text/region matching -> layout optimization -> Stable
    Diffusion 2 inpainting of every region whose text changed.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Half precision only on CUDA; on CPU everything stays float32.
        self.dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        print(f"Initializing models on {self.device}")

        # Initialize OCR reader with GPU acceleration when available
        self.easyocr_reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())

        # Initialize BLIP-2 for contextual understanding
        self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
        self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-opt-2.7b",
            torch_dtype=self.dtype
        ).to(self.device)

        # Initialize inpainting pipeline with memory optimization
        self.inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=self.dtype
        ).to(self.device)
        self.inpaint_pipe.enable_attention_slicing()

    def detect_text_regions(self, image, min_confidence=0.3):
        """Detect text regions with EasyOCR in paragraph mode.

        Args:
            image: image accepted by ``np.array`` (e.g. a PIL image).
            min_confidence: results whose confidence is <= this value are dropped.
                Results without a confidence (EasyOCR's paragraph mode returns
                ``[box, text]`` pairs) are kept, since they cannot be filtered.

        Returns:
            list of dicts with 'coordinates' (int polygon), 'text', 'confidence'
            (float, or None when unavailable) and 'original_bbox' (4x2 array).
        """
        results = self.easyocr_reader.readtext(np.array(image), paragraph=True)
        regions = []

        for item in results:
            try:
                coords = np.array(item[0]).astype(int)
                text = str(item[1])
                confidence = float(item[2]) if len(item) > 2 else None

                if confidence is not None and confidence <= min_confidence:
                    continue
                regions.append({
                    'coordinates': coords,
                    'text': text,
                    'confidence': confidence,
                    'original_bbox': self.get_axis_aligned_bbox(coords)
                })
            except Exception as e:
                print(f"Skipping invalid OCR result: {e}")

        print(f"Detected {len(regions)} text regions")
        return regions

    def get_axis_aligned_bbox(self, polygon):
        """Convert an (N, 2) polygon into its axis-aligned rectangle (TL, TR, BR, BL)."""
        x_coords = polygon[:, 0]
        y_coords = polygon[:, 1]
        return np.array([
            [min(x_coords), min(y_coords)],
            [max(x_coords), min(y_coords)],
            [max(x_coords), max(y_coords)],
            [min(x_coords), max(y_coords)]
        ])

    def generate_contextual_caption(self, image):
        """Generate an image description, prompted with "a high quality photo of"."""
        # Cast floating inputs to the model's own dtype (float16 on CUDA, float32 on CPU).
        inputs = self.blip_processor(
            images=image,
            text="a high quality photo of",
            return_tensors="pt"
        ).to(self.device, self.blip_model.dtype)

        generated_ids = self.blip_model.generate(
            **inputs,
            max_new_tokens=100,
            num_beams=5,
            early_stopping=True
        )
        return self.blip_processor.decode(generated_ids[0], skip_special_tokens=True)

    @staticmethod
    def _clean_correction(generated, prompt=""):
        """Turn a decoded generation into a bare correction string.

        Removes a leading copy of ``prompt`` (case-insensitive), then surrounding
        whitespace and enclosing quote characters. Quotes inside the text, such as
        apostrophes in "DON'T", are preserved.
        """
        cleaned = generated.strip()
        head = prompt.strip()
        if head and cleaned[:len(head)].lower() == head.lower():
            cleaned = cleaned[len(head):]
        return cleaned.strip().strip("\"'").strip()

    def correct_text_with_context(self, text, caption, image):
        """Ask BLIP-2 for a corrected version of ``text`` given the image and caption."""
        prompt = f"Correct this text in image context: '{text}'. Image shows: {caption}. Correction must be:"

        inputs = self.blip_processor(
            images=image,
            text=prompt,
            return_tensors="pt"
        ).to(self.device, self.blip_model.dtype)

        # Greedy decoding: a temperature would be ignored without do_sample=True.
        generated_ids = self.blip_model.generate(
            **inputs,
            max_new_tokens=50,
            repetition_penalty=1.5
        )

        corrected = self.blip_processor.decode(generated_ids[0], skip_special_tokens=True)
        # Depending on the transformers version the prompt may be echoed back; drop it.
        return self._clean_correction(corrected, prompt)

    def optimize_text_layout(self, regions, max_shift=50, height_tolerance=0.2):
        """Refine region centers/heights by minimizing ``layout_energy``.

        Each region contributes (x_center, y_center, height). Centers may move by
        at most ``max_shift`` pixels and heights by +/- ``height_tolerance``
        (fraction). Adds an 'optimized_bbox' to every region (mutates in place)
        and returns the same list.
        """
        if not regions:
            return regions

        initial_params = np.array([
            [
                (bbox[0][0] + bbox[2][0]) / 2,  # x_center
                (bbox[0][1] + bbox[2][1]) / 2,  # y_center
                bbox[2][1] - bbox[0][1]          # height
            ]
            for bbox in (region['original_bbox'] for region in regions)
        ], dtype=float)

        # minimize() sees the row-major flattened vector [x0, y0, h0, x1, y1, h1, ...],
        # so the bounds must be interleaved per region in exactly that order.
        bounds = []
        for x, y, h in initial_params:
            bounds.extend([
                (x - max_shift, x + max_shift),
                (y - max_shift, y + max_shift),
                (h * (1 - height_tolerance), h * (1 + height_tolerance)),
            ])

        result = minimize(
            self.layout_energy,
            initial_params.flatten(),
            args=(initial_params,),
            method='L-BFGS-B',
            bounds=bounds,
            options={'maxiter': 200}
        )

        optimized = result.x.reshape(-1, 3)
        for params, region in zip(optimized, regions):
            region['optimized_bbox'] = self.create_optimized_bbox(params, region['original_bbox'])

        return regions

    def layout_energy(self, params, original):
        """Quadratic layout energy.

        Penalizes moving centers away from ``original`` (weight 0.5), changing
        heights (weight 0.3) and non-uniform horizontal spacing between
        consecutive regions in list order (weight 0.2).
        """
        params = params.reshape(-1, 3)
        energy = 0.0

        # Positional fidelity
        energy += 0.5 * np.sum((params[:, :2] - original[:, :2])**2)

        # Size consistency
        energy += 0.3 * np.sum((params[:, 2] - original[:, 2])**2)

        # Inter-element spacing
        if len(params) > 1:
            dx = np.diff(params[:, 0])
            energy += 0.2 * np.sum((dx - np.mean(dx))**2)

        return energy

    def create_optimized_bbox(self, params, original_bbox):
        """Build an integer (TL, TR, BR, BL) box from (x_center, y_center, height).

        Width is taken from ``original_bbox``; coordinates are rounded to the
        nearest pixel (plain int casting would truncate and bias boxes inwards).
        """
        x_center, y_center, height = params
        width = original_bbox[2][0] - original_bbox[0][0]
        corners = np.array([
            [x_center - width/2, y_center - height/2],
            [x_center + width/2, y_center - height/2],
            [x_center + width/2, y_center + height/2],
            [x_center - width/2, y_center + height/2]
        ])
        return np.rint(corners).astype(int)

    def semantic_text_matching(self, sources, targets):
        """Assign each source text a target text by minimum-cost matching.

        Cost = 0.7 * cosine distance of ``text_features`` + 0.3 * |len difference| / 20.

        Returns:
            list of targets ordered to match ``sources`` (``[]`` if either is empty).
        """
        if len(sources) == 0 or len(targets) == 0:
            return []

        source_vecs = np.array([self.text_features(t) for t in sources], dtype=float)
        target_vecs = np.array([self.text_features(t) for t in targets], dtype=float)

        # Cosine distance is undefined for all-zero vectors (empty strings);
        # treat such pairs as maximally dissimilar so the assignment never sees NaN.
        char_dist = np.nan_to_num(cdist(source_vecs, target_vecs, 'cosine'), nan=1.0)
        source_lens = np.array([len(str(s)) for s in sources])
        target_lens = np.array([len(str(t)) for t in targets])
        len_dist = np.abs(np.subtract.outer(source_lens, target_lens))
        combined = 0.7 * char_dist + 0.3 * len_dist / 20

        # Optimal assignment
        _, col_ind = linear_sum_assignment(combined)
        return [targets[j] for j in col_ind]

    def text_features(self, text, prefix_len=10):
        """Fixed-length feature vector for ``text`` (lower-cased).

        Layout: [length, code points of the first ``prefix_len`` chars
        (zero-padded), counts of a, e, i, o, u]. The fixed length is required
        so the vectors stack into a 2-D array.
        """
        text = str(text).lower()
        prefix_codes = [ord(c) for c in text[:prefix_len]]
        prefix_codes += [0] * (prefix_len - len(prefix_codes))
        return [
            len(text),
            *prefix_codes,
            *[text.count(vowel) for vowel in 'aeiou']
        ]

    def generate_text_mask(self, image_size, bbox):
        """Return an 'L' mask of ``image_size`` with the ``bbox`` polygon filled white."""
        mask = Image.new('L', image_size, 0)
        draw = ImageDraw.Draw(mask)
        draw.polygon([tuple(p) for p in bbox], fill=255)
        return mask

    def inpaint_text_region(self, image, mask, text):
        """Inpaint the masked area of ``image`` with Stable Diffusion, prompting for ``text``.

        Stable Diffusion pipelines require height and width divisible by 8, so
        generation runs at the nearest smaller multiple of 8. The result is
        resized back to ``image.size`` so it can be pasted with ``mask``.
        """
        width = max(8, image.width // 8 * 8)
        height = max(8, image.height // 8 * 8)
        inpainted = self.inpaint_pipe(
            prompt=f"Professional sign with text: '{text}', perfect spelling, crisp edges, matching style",
            negative_prompt="deformed, blurry, disfigured, incorrect text, watermark",
            image=image,
            mask_image=mask,
            num_inference_steps=100,
            guidance_scale=12.0,
            height=height,
            width=width
        ).images[0]
        if inpainted.size != image.size:
            inpainted = inpainted.resize(image.size, Image.LANCZOS)
        return inpainted

    def process_image(self, image_path):
        """Run the full pipeline on the image at ``image_path`` and return the result (RGB)."""
        # Load and validate input
        orig_image = Image.open(image_path).convert("RGB")
        print("\nProcessing image:", image_path)

        # Stage 1: Text detection
        regions = self.detect_text_regions(orig_image)
        if not regions:
            print("No text regions found")
            return orig_image

        # Stage 2: Context understanding
        caption = self.generate_contextual_caption(orig_image)
        print(f"Image context: {caption}")

        # Stage 3: Text correction
        corrected_texts = []
        for i, region in enumerate(regions):
            original = region['text']
            corrected = self.correct_text_with_context(original, caption, orig_image)
            print(f"Region {i+1}: '{original}' → '{corrected}'")
            corrected_texts.append(corrected)

        # Stage 4: Text matching
        aligned_texts = self.semantic_text_matching(
            [r['text'] for r in regions],
            corrected_texts
        )

        # Stage 5: Layout optimization
        optimized_regions = self.optimize_text_layout(regions)

        # Stage 6: Inpainting
        result_image = orig_image.copy()
        for region, new_text in zip(optimized_regions, aligned_texts):
            # Skip empty corrections and case-only differences.
            if not new_text.strip() or new_text.lower() == region['text'].lower():
                continue

            print(f"Processing: {region['text']} → {new_text}")
            try:
                # NOTE(review): the optimized bbox may be shifted away from the
                # original glyphs, so the mask may not fully cover them — confirm intended.
                bbox = region.get('optimized_bbox', region['original_bbox'])
                mask = self.generate_text_mask(orig_image.size, bbox)

                # Inpaint with corrected text; only the masked area is pasted.
                inpainted = self.inpaint_text_region(orig_image, mask, new_text)
                result_image.paste(inpainted, (0, 0), mask)
            except Exception as e:
                # Best effort per region: report and continue with the rest.
                print(f"Error processing region: {e}")
                continue

        return result_image

# Usage example: run the enhanced pipeline on one sample, then save and show the result.
if __name__ == "__main__":
    corrector = EnhancedTextCorrector()

    source_path = "Incorrect_Images/Incorrect_SOTP_sign.jpg"
    result = corrector.process_image(source_path)

    result.save("corrected_image.jpg")
    result.show()

Using CPU. Note: This module is much faster with a GPU.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
