In [1]:
!pip install torch torchvision transformers diffusers numpy easyocr scipy networkx pillow

Collecting numpy
  Downloading numpy-2.2.3-cp310-cp310-macosx_14_0_arm64.whl.metadata (62 kB)
Downloading numpy-2.2.3-cp310-cp310-macosx_14_0_arm64.whl (5.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.4/5.4 MB[0m [31m48.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.23.5
    Uninstalling numpy-1.23.5:
      Successfully uninstalled numpy-1.23.5
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
albucore 0.0.13 requires numpy<2,>=1.24.4, but you have numpy 2.2.3 which is incompatible.
albumentations 1.4.10 requires numpy<2,>=1.24.4, but you have numpy 2.2.3 which is incompatible.
paddleocr 2.9.1 requires numpy<2.0, but you have numpy 2.2.3 which is incompatible.[0m[31m
[0mSuccessfully installed numpy-2.2.3


In [2]:
import os

from huggingface_hub import login

# Read the token from the HF_TOKEN environment variable when it is set, so "Run All" does not
# block on the interactive login widget. With no token set, login() prompts interactively as before.
login(token=os.environ.get("HF_TOKEN"))

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [6]:
import torch
import gc
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, BlipProcessor, BlipForConditionalGeneration
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image, ImageDraw
import numpy as np
import easyocr
import networkx as nx
from scipy.optimize import linear_sum_assignment
from accelerate import infer_auto_device_map, dispatch_model

class AITextCorrector:
    """Detect scene text, propose corrections and paint them back into the image.

    Components:
      * EasyOCR for text-box detection,
      * TrOCR for standalone recognition (``recognize_text``),
      * BLIP (v1) for captioning and prompt-conditioned "correction",
      * a Stable Diffusion inpainting pipeline to render the corrected text.
    """

    def __init__(self, trocr_model="microsoft/trocr-large-handwritten",
                 blip_model="Salesforce/blip-image-captioning-base",
                 model_name="stabilityai/stable-diffusion-2-inpainting"):
        """Load all models.

        Args:
            trocr_model: Hugging Face id of the TrOCR checkpoint.
            blip_model: Hugging Face id of the BLIP captioning checkpoint.
            model_name: Hugging Face id of the Stable Diffusion inpainting checkpoint.
        """
        # Only CUDA or CPU are selected; Apple MPS is not used.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        print("Using device:", self.device)

        # OCR - TrOCR. accelerate infers a device map and dispatches the model;
        # recognize_text() moves its inputs to self.device.
        self.ocr_processor = TrOCRProcessor.from_pretrained(trocr_model)
        self.ocr_model = VisionEncoderDecoderModel.from_pretrained(trocr_model)
        device_map = infer_auto_device_map(self.ocr_model)
        self.ocr_model = dispatch_model(self.ocr_model, device_map=device_map)

        # Captioning / prompt-conditioned generation - BLIP (v1).
        self.blip_processor = BlipProcessor.from_pretrained(blip_model)
        self.blip_model = BlipForConditionalGeneration.from_pretrained(blip_model).to(self.device)

        # Text inpainting - Stable Diffusion inpainting pipeline.
        self.model = StableDiffusionInpaintPipeline.from_pretrained(model_name).to(self.device)

        # EasyOCR for bounding-box detection.
        self.easyocr_model = easyocr.Reader(['en'])

    def detect_text_boxes(self, image):
        """Detect text regions with EasyOCR.

        Returns:
            List of dicts with ``"coordinates"`` (EasyOCR's box points) and ``"text"``.
        """
        image_np = np.array(image)
        ocr_results = self.easyocr_model.readtext(image_np)
        return [{"coordinates": result[0], "text": result[1]} for result in ocr_results]

    def recognize_text(self, image):
        """Recognize the text in ``image`` with TrOCR and return the first decoded string."""
        pixel_values = self.ocr_processor(images=image, return_tensors="pt").pixel_values.to(self.device)
        generated_ids = self.ocr_model.generate(pixel_values)
        return self.ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def generate_caption(self, image):
        """Generate an unconditional caption for ``image`` with BLIP."""
        inputs = self.blip_processor(images=image, return_tensors="pt").to(self.device)
        pixel_values = inputs["pixel_values"]
        with torch.no_grad():
            outputs = self.blip_model.generate(pixel_values=pixel_values)
        return self.blip_processor.decode(outputs[0], skip_special_tokens=True)

    def correct_text(self, extracted_text, caption, image, max_new_tokens=30):
        """Ask BLIP to continue a correction prompt built from the OCR text and caption.

        The prompt is now passed to ``generate``. Previously only the pixel values were
        passed, so the prompt and ``extracted_text`` were ignored and the "correction"
        was just another caption. BLIP's conditional generation returns the prompt
        followed by the continuation, so the decoded prompt is stripped off.

        Args:
            extracted_text: Text read by OCR for one region.
            caption: Caption describing the whole image.
            image: Image the text was read from.
            max_new_tokens: Maximum number of tokens generated beyond the prompt.

        Returns:
            The generated continuation, or ``extracted_text`` unchanged when the model
            produced nothing beyond the prompt (treated as "no correction").
        """
        prompt = f"Correct this text: {extracted_text} in context: {caption}"
        inputs = self.blip_processor(images=image, text=prompt, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.blip_model.generate(**inputs, max_new_tokens=max_new_tokens)
        decoded = self.blip_processor.decode(outputs[0], skip_special_tokens=True).strip()
        # Decode the prompt with the same tokenizer so any normalisation it applies
        # matches the echoed prefix exactly.
        echoed_prompt = self.blip_processor.decode(inputs["input_ids"][0], skip_special_tokens=True).strip()
        if decoded.startswith(echoed_prompt):
            decoded = decoded[len(echoed_prompt):].strip()
        return decoded or extracted_text

    def create_mask(self, image_size, coordinates):
        """Return an ``'L'`` mask of ``image_size`` with the polygon ``coordinates`` set to 255."""
        mask = Image.new('L', image_size, 0)
        draw = ImageDraw.Draw(mask)
        draw.polygon([tuple(point) for point in coordinates], outline=255, fill=255)
        return mask

    def graph_based_text_alignment(self, detected_boxes):
        """Order boxes by Hungarian matching on pairwise centroid distances.

        NOTE(review): for a square cost matrix ``linear_sum_assignment`` returns
        ``row_ind == arange(n)``. Only ``row_ind`` is used, so boxes come back in their
        input order. That behaviour is preserved; the cost matrix is now built once with
        numpy instead of recomputing both centroids for every pair.
        """
        if not detected_boxes:
            return []
        centroids = np.array([
            np.mean(np.asarray(box['coordinates'], dtype=float), axis=0) for box in detected_boxes
        ])
        cost_matrix = np.linalg.norm(centroids[:, None, :] - centroids[None, :, :], axis=-1)
        row_ind, _ = linear_sum_assignment(cost_matrix)
        return [detected_boxes[i] for i in row_ind]

    def inpaint_text(self, image, mask, corrected_text):
        """Inpaint ``corrected_text`` into the masked area with the Stable Diffusion pipeline."""
        return self.model(prompt=f"Generate text '{corrected_text}' in a matching style", image=image, mask_image=mask, num_inference_steps=50, guidance_scale=7.5).images[0]

    def run_pipeline(self, image):
        """Run detection -> caption -> correction -> inpainting and return the edited image.

        The input is converted to RGB first, since the inpainting pipeline needs three channels.
        Corrections that are empty or differ from the OCR text only by case or surrounding
        whitespace are skipped.
        """
        image = image.convert("RGB")
        text_boxes = self.detect_text_boxes(image)
        caption = self.generate_caption(image)
        aligned_boxes = self.graph_based_text_alignment(text_boxes)

        corrected_image = image.copy()

        for box in aligned_boxes:
            original_text = box["text"]
            corrected_text = self.correct_text(original_text, caption, image)

            candidate = corrected_text.strip()
            if not candidate or candidate.casefold() == original_text.strip().casefold():
                continue  # Nothing to correct for this region.

            mask = self.create_mask(image.size, box["coordinates"])
            inpainted_region = self.inpaint_text(corrected_image, mask, corrected_text)
            # The pipeline output may come back at a different resolution; paste() needs it
            # to match the mask, which has the source image size.
            inpainted_region = inpainted_region.resize(image.size, Image.LANCZOS)
            corrected_image.paste(inpainted_region, (0, 0), mask)

        return corrected_image

# Example usage
if __name__ == "__main__":
    corrector = AITextCorrector()

    # Replace with your test image. The context manager closes the file handle;
    # convert() returns an in-memory copy that stays valid afterwards.
    with Image.open("Incorrect_Images/Incorrect_SOTP_sign.jpg") as source_image:
        input_image = source_image.convert("RGB")
    output_image = corrector.run_pipeline(input_image)
    print("Corrected Image Generated")
    output_image.save("test.png")

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 1024,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 16,
  "num_channels": 3,
  "num_hidden_layers": 24,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.48.3"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 1024,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decod

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

In [None]:
# Free cached accelerator memory after the heavy pipeline cell.
import gc

import torch

# Collect unreachable Python objects first so their tensors can be released.
# NOTE: models still referenced by live variables (e.g. `corrector`) are not freed.
gc.collect()

# The correctors select CUDA or CPU, so an unconditional MPS call does not release
# the memory actually in use; empty whichever backend is present.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
if torch.backends.mps.is_available():
    torch.mps.empty_cache()

In [None]:
# For Levenshtein distance
!pip install python-Levenshtein


In [15]:
import torch
import numpy as np
from scipy.optimize import minimize
from scipy.spatial.distance import cdist
from PIL import Image, ImageDraw, ImageFont
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from diffusers import StableDiffusionInpaintPipeline
import easyocr

class EnhancedTextCorrector:
    """Detect, correct, re-layout and inpaint scene text.

    ``process_image`` runs: EasyOCR detection -> BLIP-2 captioning and prompt-based
    correction -> text-vector alignment -> L-BFGS-B layout optimisation ->
    Stable Diffusion 2 inpainting.
    """

    def __init__(self):
        """Load EasyOCR, BLIP-2 (OPT-2.7B) and the SD2 inpainting pipeline.

        Models are loaded in float16 on CUDA and float32 otherwise.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        model_dtype = torch.float16 if self.device.type == "cuda" else torch.float32

        self.easyocr_reader = easyocr.Reader(['en'])
        self.blip_processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
        self.blip_model = Blip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-opt-2.7b",
            torch_dtype=model_dtype
        ).to(self.device)

        self.inpaint_pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=model_dtype
        ).to(self.device)

    def detect_text_regions(self, image, min_confidence=0.4):
        """Detect text regions with EasyOCR in paragraph mode.

        Paragraph-mode results may carry no confidence score. They used to default to
        0.0 and were then all discarded by the threshold. They are now kept with
        ``confidence=None``, and only results that report a score are filtered.

        Args:
            image: Input image (anything ``np.array`` accepts).
            min_confidence: Minimum score for results that do report a confidence.

        Returns:
            List of dicts with ``'coordinates'`` (np.ndarray), ``'text'`` (str) and
            ``'confidence'`` (float or None).
        """
        results = self.easyocr_reader.readtext(np.array(image), paragraph=True)
        regions = []

        for item in results:
            try:
                coords = np.array(item[0])
                text = str(item[1])
                confidence = float(item[2]) if len(item) > 2 else None
            except (IndexError, TypeError, ValueError) as e:
                print(f"Skipping invalid OCR result: {item} | Error: {e}")
                continue

            if confidence is None or confidence > min_confidence:
                regions.append({
                    'coordinates': coords,
                    'text': text,
                    'confidence': confidence
                })

        return regions

    def _blip_inputs(self, image, text=None):
        """Run the BLIP-2 processor and move its tensors to the model's device and dtype.

        Casting to ``self.blip_model.dtype`` rather than a hard-coded float16 keeps the
        inputs consistent with the float32 model loaded on CPU.
        """
        processor_kwargs = {"images": image, "return_tensors": "pt"}
        if text is not None:
            processor_kwargs["text"] = text
        return self.blip_processor(**processor_kwargs).to(self.device, self.blip_model.dtype)

    def generate_contextual_caption(self, image):
        """Caption the whole image with BLIP-2 (at most 50 new tokens)."""
        inputs = self._blip_inputs(image)
        generated_ids = self.blip_model.generate(**inputs, max_new_tokens=50)
        return self.blip_processor.decode(generated_ids[0], skip_special_tokens=True)

    def correct_text_with_context(self, text, caption, image):
        """Ask BLIP-2 for a corrected version of ``text`` given the image and its caption.

        Returns:
            The model's answer with any echoed prompt removed, or ``text`` unchanged when
            the answer is empty (treated as "no correction").
        """
        prompt = f"Correct this text in image context: '{text}'. Image shows: {caption}. Correction:"
        inputs = self._blip_inputs(image, text=prompt)
        generated_ids = self.blip_model.generate(**inputs, max_new_tokens=50)
        decoded = self.blip_processor.decode(generated_ids[0], skip_special_tokens=True).strip()
        # Depending on the transformers version, the output may start with the prompt.
        if decoded.startswith(prompt):
            decoded = decoded[len(prompt):].strip()
        return decoded or text

    def calculate_energy(self, params, original_layout, lambda1=0.5, lambda2=0.3, mu=0.1, nu=0.1, d=30):
        """Energy minimised by ``optimize_layout``.

        Args:
            params: Flat vector ``[x0, y0, x1, y1, ..., h0, h1, ...]`` for N regions.
            original_layout: ``(N, 3)`` array of ``[x_center, y_center, height]``.
            lambda1, lambda2: Weights keeping positions / heights near the originals.
            mu: Weight pulling consecutive x-centres ``d`` pixels apart.
            nu: Weight penalising height differences between consecutive regions.
        """
        N = len(original_layout)
        positions = params[:2 * N].reshape(N, 2)
        heights = params[2 * N:3 * N]

        fidelity_pos = lambda1 * np.sum((positions - original_layout[:, :2]) ** 2)
        fidelity_h = lambda2 * np.sum((heights - original_layout[:, 2]) ** 2)
        spacing = mu * np.sum((np.diff(positions[:, 0]) - d) ** 2)
        uniformity = nu * np.sum(np.diff(heights) ** 2)

        return fidelity_pos + fidelity_h + spacing + uniformity

    def optimize_layout(self, text_regions, max_shift=50):
        """Adjust box centres and heights with L-BFGS-B and store the result per region.

        Every region gets an ``'optimized_coords'`` entry. Regions whose geometry could be
        measured get the optimised rectangle; the others keep their original coordinates,
        so ``process_image`` can always read the key.

        Args:
            text_regions: Region dicts from ``detect_text_regions`` (mutated in place).
            max_shift: Maximum displacement, in pixels, of a centre along x or y.

        Returns:
            ``text_regions`` (the same list).
        """
        if not text_regions:
            return text_regions

        for region in text_regions:
            region['optimized_coords'] = region.get('coordinates')

        layout_data = []
        valid_indices = []  # maps layout rows back to text_regions indices
        for idx, region in enumerate(text_regions):
            try:
                coords = region['coordinates']
                x_center = np.mean(coords[:, 0])
                y_center = np.mean(coords[:, 1])
                height = coords[3][1] - coords[0][1]
            except (IndexError, KeyError, TypeError) as e:
                print(f"Skipping invalid region: {e}")
                continue
            layout_data.append([x_center, y_center, height])
            valid_indices.append(idx)

        if not layout_data:
            return text_regions

        original_layout = np.array(layout_data, dtype=float)

        # Parameters are ordered [x0, y0, x1, y1, ..., h0, h1, ...]; the bounds must use
        # the same interleaved order.
        bounds = []
        for x, y, _ in original_layout:
            bounds.append((max(0, x - max_shift), x + max_shift))
            bounds.append((max(0, y - max_shift), y + max_shift))
        for h in original_layout[:, 2]:
            lower = max(1, h * 0.5)
            bounds.append((lower, max(lower, h * 2)))  # keep lower <= upper for tiny heights

        initial_guess = np.concatenate([
            original_layout[:, :2].ravel(),
            original_layout[:, 2]
        ])

        result = minimize(self.calculate_energy, initial_guess,
                          args=(original_layout,), method='L-BFGS-B',
                          bounds=bounds, options={'maxiter': 100})

        n_rows = len(original_layout)
        positions = result.x[:2 * n_rows].reshape(n_rows, 2)
        heights = result.x[2 * n_rows:]

        for row, idx in enumerate(valid_indices):
            region = text_regions[idx]
            try:
                region['optimized_coords'] = self.create_rectangular_bbox(
                    positions[row], heights[row], region['coordinates'])
            except IndexError:
                continue  # keep the original coordinates

        return text_regions

    def create_rectangular_bbox(self, center, height, original_bbox):
        """Axis-aligned box centred at ``center`` with the given height and the original width."""
        width = original_bbox[1][0] - original_bbox[0][0]
        return np.array([
            [center[0] - width / 2, center[1] - height / 2],
            [center[0] + width / 2, center[1] - height / 2],
            [center[0] + width / 2, center[1] + height / 2],
            [center[0] - width / 2, center[1] + height / 2]
        ])

    def text_to_vector(self, text):
        """Encode ``text`` as a ``(1, 11)`` vector: length, then up to 10 code points, zero-padded."""
        max_chars = 10
        text = str(text)
        vec = [len(text)]
        vec.extend(ord(c) for c in text[:max_chars])
        vec.extend([0] * (max_chars + 1 - len(vec)))
        return np.array(vec).reshape(1, -1)

    def semantic_alignment(self, original_texts, corrected_texts):
        """Reorder ``corrected_texts`` by Hungarian matching on cosine distance of text vectors.

        Falls back to the given order if the cost matrix is invalid (e.g. NaN distances).
        """
        # Imported here: this cell does not import it and otherwise relies on an earlier cell.
        from scipy.optimize import linear_sum_assignment

        if not original_texts or not corrected_texts:
            return corrected_texts

        try:
            orig_vecs = np.vstack([self.text_to_vector(t) for t in original_texts])
            corr_vecs = np.vstack([self.text_to_vector(t) for t in corrected_texts])

            cost_matrix = cdist(orig_vecs, corr_vecs, metric='cosine')
            _, col_ind = linear_sum_assignment(cost_matrix)
            return [corrected_texts[i] for i in col_ind]
        except ValueError as e:
            print(f"Alignment error: {e}. Returning original order.")
            return corrected_texts

    def inpaint_text_region(self, image, mask, corrected_text):
        """Inpaint ``corrected_text`` into the masked area with the SD2 inpainting pipeline."""
        return self.inpaint_pipe(
            prompt=f"High quality text: '{corrected_text}', matching original style",
            negative_prompt="blurry, distorted, inconsistent style",
            image=image,
            mask_image=mask,
            num_inference_steps=75,
            guidance_scale=9.0,
            strength=0.95
        ).images[0]

    def process_image(self, image_path):
        """Run the full pipeline on the image at ``image_path`` and return the corrected image."""
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        regions = self.detect_text_regions(image)
        caption = self.generate_contextual_caption(image)

        # Stage 1: Text correction
        corrected_texts = [self.correct_text_with_context(r['text'], caption, image) for r in regions]

        # Stage 2: Semantic alignment
        aligned_texts = self.semantic_alignment([r['text'] for r in regions], corrected_texts)

        # Stage 3: Geometric optimization
        optimized_regions = self.optimize_layout(regions)

        # Stage 4: Inpainting
        result_image = image.copy()
        for region, new_text in zip(optimized_regions, aligned_texts):
            if not new_text.strip() or new_text.lower() == region['text'].lower():
                continue

            mask = self.create_mask(image.size, region['optimized_coords'])
            # Also cover the original box so the old text is erased when the layout moved it.
            ImageDraw.Draw(mask).polygon([tuple(p) for p in region['coordinates']], fill=255)
            inpainted = self.inpaint_text_region(image, mask, new_text)
            # The pipeline output may not have the source size; paste() requires it to match the mask.
            if inpainted.size != image.size:
                inpainted = inpainted.resize(image.size, Image.LANCZOS)
            result_image.paste(inpainted, (0, 0), mask)

        return result_image

    def create_mask(self, image_size, coordinates):
        """Return an ``'L'`` mask of ``image_size`` with the polygon ``coordinates`` filled with 255."""
        mask = Image.new('L', image_size, 0)
        draw = ImageDraw.Draw(mask)
        draw.polygon([tuple(p) for p in coordinates], fill=255)
        return mask

# Usage: correct the sample sign image and write the result next to the notebook.
SOURCE_IMAGE = "Incorrect_Images/Incorrect_SOTP_sign.jpg"
corrector = EnhancedTextCorrector()
result = corrector.process_image(SOURCE_IMAGE)
result.save("corrected_image.jpg")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]