## Setup

First, we need to install packages beyond those pre-installed in Colab in order to run Florence-2.

In [None]:
%%capture
!pip install timm flash_attn einops;

In [None]:
%%bash
git clone https://github.com/AssemblyAI-Community/florence-2
mv florence-2/** .
rm -rf ./florence-2/

In [None]:
!python -m pip install --upgrade transformers==4.53.1



In [None]:
import copy

from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import requests

import utils

%matplotlib inline

In [None]:
!huggingface-cli login --token "add_your_huggingface_token_here"

In [None]:
!ngrok config add-authtoken "add_your_ngrok_token_here"


In [None]:
# Florence-2 checkpoint shared by the processor and the model.
model_id = 'microsoft/Florence-2-large'

# The processor and the model are loaded independently from the same checkpoint.
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

# Load the weights, switch to inference mode, then move the model to the GPU.
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
model = model.eval()
model = model.cuda()

In [None]:
utils.set_model_info(model, processor)

In [None]:
!pkill ngrok || true

In [None]:
!pip install groq

In [None]:
!pip install groq pillow fal-client requests flask pyngrok flask-cors


In [None]:
!pkill ngrok || true
!ngrok config add-authtoken 33bfkQRER7AYPl7DUpPLVG4ahtd_3afEGFNzNnGxKfoAuTGZk


In [None]:
import os
import json
import glob
import requests
import threading
from PIL import Image, ImageDraw
from groq import Groq
import fal_client
from flask import Flask, request, jsonify, render_template_string, send_file
from flask_cors import CORS
from pyngrok import ngrok
from werkzeug.utils import secure_filename
from IPython.display import display, HTML

# --- CELL 3: Configuration ---
# API Keys
# Prefer the GROQ_API_KEY environment variable so the key does not have to be
# written into the notebook; the placeholder string remains the fallback.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "paste_Your_GROQ_API_KEY_Here")

# Paths
UPLOAD_FOLDER = "/content/uploads"            # uploaded source images
MASKS_FOLDER = "/content/correction_masks"    # per-correction mask images
OUTPUT_DIR = "/content/fal_corrections"       # corrected images served by /output
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}

os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(MASKS_FOLDER, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Initialize Groq client
client = Groq(api_key=GROQ_API_KEY)

# --- CELL 4: Flask App Setup ---
app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max file size
CORS(app)

def allowed_file(filename):
    """Return True when the text after the last '.' is an allowed image extension."""
    _, dot, extension = filename.rpartition('.')
    # An empty separator means the name contains no '.' at all.
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS

# --- CELL 5: OCR Validation Function ---
def validate_and_correct_ocr(results_task, image_context_description, client):
    """
    Analyze OCR results and identify incorrect text with suggested corrections.

    Args:
        results_task: Mapping with a 'labels' list of OCR strings and a
            parallel 'quad_boxes' list (one box per label).
        image_context_description: Free-text description of the image that is
            embedded in the LLM prompt.
        client: Object exposing ``chat.completions.create(...)``.

    Returns:
        A list of dicts with the keys incorrect_text, suggested_correction,
        confidence, reason, quad_box and box_index. The list is empty when the
        LLM call fails or its reply is not a JSON array.
    """
    labels = results_task['labels']
    boxes = results_task['quad_boxes']
    print(image_context_description)
    prompt = f"""You are an expert OCR validator analyzing text from an image.

IMAGE CONTEXT:
{image_context_description}

EXTRACTED OCR TEXT (in order):
{json.dumps(labels, indent=2)}

TASK:
Identify which texts are incorrect or misspelled, and suggest corrections.
Consider the context provided by the user, if the misspelled word does not match with any of the words given by user, then replace it with a sensible and related word of your own.

Return ONLY a JSON array with this exact structure:
[
  {{
    "incorrect_text": "exact text from OCR",
    "suggested_correction": "what it should be",
    "confidence": "high/medium/low",
    "reason": "brief explanation"
  }}
]

Rules:
- Only include texts that are clearly wrong or misspelled
- Use context clues from the description
- Be specific with corrections
- Skip texts that are correct
"""

    result_text = None
    try:
        completion = client.chat.completions.create(
            model="llama-3.3-70b-versatile",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
            max_tokens=2048,
        )

        result_text = completion.choices[0].message.content.strip()

        # Clean JSON from markdown code blocks
        if "```json" in result_text:
            result_text = result_text.split("```json")[1].split("```")[0].strip()
        elif "```" in result_text:
            result_text = result_text.split("```")[1].split("```")[0].strip()

        parsed = json.loads(result_text)
        if not isinstance(parsed, list):
            raise ValueError("expected a JSON array of corrections")

    except Exception as e:
        print(f"⚠️ Error parsing LLM response: {e}")
        print(f"Raw response: {result_text if result_text is not None else 'No response'}")
        return []

    # Record every position of each label so that a string occurring several
    # times in the OCR output is matched to distinct boxes, in reading order.
    unclaimed = {}
    for idx, label in enumerate(labels):
        unclaimed.setdefault(label, []).append(idx)

    # Map incorrect labels to their quad boxes
    validation_results = []
    for item in parsed:
        if not isinstance(item, dict):
            continue
        incorrect_text = item.get('incorrect_text')
        suggestion = item.get('suggested_correction')
        # Entries without usable text cannot be located or applied; skip them
        # instead of failing the whole request.
        if not isinstance(incorrect_text, str) or not isinstance(suggestion, str):
            continue
        candidates = unclaimed.get(incorrect_text)
        if not candidates:
            continue  # not an OCR label, or every occurrence is already used
        idx = candidates.pop(0)
        validation_results.append({
            "incorrect_text": incorrect_text,
            "suggested_correction": suggestion,
            "confidence": item.get('confidence', 'unknown'),
            "reason": item.get('reason', ''),
            "quad_box": boxes[idx],
            "box_index": idx
        })

    return validation_results

# --- CELL 6: Generate Masks for Incorrect Boxes Only ---
def create_correction_masks(image_path, validation_results, output_dir=MASKS_FOLDER):
    """
    Create one mask image per incorrect region (white polygon on black).

    Any existing ``mask_*.png`` files in ``output_dir`` are deleted first.

    Args:
        image_path: Path of the source image; only its size is used.
        validation_results: Dicts carrying 'quad_box' (flat list of
            alternating x, y coordinates), 'suggested_correction',
            'incorrect_text', 'confidence' and 'reason'.
        output_dir: Directory that receives the mask files.

    Returns:
        A list of dicts with the keys mask_path, correction_text,
        incorrect_text, confidence and reason, in input order.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Clear existing masks
    for old_mask in glob.glob(os.path.join(output_dir, "mask_*.png")):
        os.remove(old_mask)

    # Only the dimensions are needed, so read them without decoding the pixel
    # data and let the context manager close the file handle.
    with Image.open(image_path) as base_img:
        canvas_size = base_img.size

    mask_info = []

    for i, item in enumerate(validation_results, 1):
        quad = item["quad_box"]
        mask = Image.new("L", canvas_size, 0)  # Black background
        draw = ImageDraw.Draw(mask)
        # Pair the flat [x0, y0, x1, y1, ...] sequence into (x, y) vertices.
        polygon = list(zip(quad[0::2], quad[1::2]))
        draw.polygon(polygon, outline=255, fill=255)  # White box

        # Create safe filename from suggested correction
        safe_text = "".join(c if c.isalnum() else "_" for c in item["suggested_correction"][:30])
        mask_path = os.path.join(output_dir, f"mask_{i:02d}_{safe_text}.png")
        mask.save(mask_path)

        mask_info.append({
            "mask_path": mask_path,
            "correction_text": item["suggested_correction"],
            "incorrect_text": item["incorrect_text"],
            "confidence": item["confidence"],
            "reason": item["reason"]
        })

        print(f"✅ Created mask {i}: '{item['incorrect_text']}' → '{item['suggested_correction']}'")

    return mask_info

# --- CELL 7: No separate upload needed ---
# The correction step sends the image and mask files directly with each /generate request,
# so a separate upload function is not required.

# --- CELL 8: Sequential Correction with ngrok Calligrapher ---
def apply_corrections_with_ngrok(source_image_path, mask_info_list, ngrok_url, output_dir=OUTPUT_DIR):
    """
    Apply corrections one after another through the Calligrapher /generate endpoint.

    Each step uploads the current image together with one mask file; the image
    returned by the server is saved and becomes the input of the next step.

    Args:
        source_image_path: Path of the image to start from.
        mask_info_list: Dicts carrying 'mask_path', 'correction_text',
            'incorrect_text', 'confidence' and 'reason'.
        ngrok_url: Base URL of the server exposing /generate.
        output_dir: Directory that receives the intermediate and final images.

    Returns:
        Path of the image produced by the last step (the source path when the
        list is empty), or None if any step raised an error.
    """
    # base64 is not among the file-level imports but is required to decode the
    # image payload returned by the server.
    import base64
    import traceback

    # Get original dimensions
    with Image.open(source_image_path) as img:
        orig_width, orig_height = img.size
    print(f"📏 Original image size: {orig_width}×{orig_height}")

    total = len(mask_info_list)
    print(f"\n🖋️ Starting sequential corrections for {total} boxes...\n")

    try:
        os.makedirs(output_dir, exist_ok=True)
        # Tolerate a base URL given with a trailing slash.
        endpoint = f"{ngrok_url.rstrip('/')}/generate"
        current_image_path = source_image_path

        for idx, mask_info in enumerate(mask_info_list, start=1):
            correction_text = mask_info["correction_text"]
            incorrect_text = mask_info["incorrect_text"]
            mask_path = mask_info["mask_path"]

            print(f"\n{'='*70}")
            print(f"🔧 Step {idx}/{total}: Correcting '{incorrect_text}' → '{correction_text}'")
            print(f"   Confidence: {mask_info['confidence']}")
            print(f"   Reason: {mask_info['reason']}")
            print(f"{'='*70}")

            # Create prompt with the CORRECT text
            prompt = f"The text is '{correction_text}'"
            print(f"💬 Prompt: {prompt}")

            # Both files must stay open while the multipart request is sent.
            with open(current_image_path, 'rb') as img_file, open(mask_path, 'rb') as mask_file:
                files = {
                    'image_file': img_file,
                    'mask_file': mask_file
                }
                data = {
                    'prompt': prompt,
                    'num_inference_steps': 50,
                    'guidance_scale': 1.0,
                    'use_context': True
                }

                print(f"📡 Calling {endpoint}...")
                response = requests.post(endpoint, files=files, data=data, timeout=600)
                response.raise_for_status()
                result = response.json()

            # Extract base64 image and save
            image_base64 = result["image_base64"]
            print(f"✅ Generated image received")

            # Sanitise the text the same way the mask filenames are built, so
            # characters such as '/' cannot alter the output path.
            safe_text = "".join(c if c.isalnum() else "_" for c in correction_text[:30])
            new_image_path = os.path.join(output_dir, f"step_{idx:02d}_{safe_text}.png")
            with open(new_image_path, "wb") as f:
                f.write(base64.b64decode(image_base64))

            # Verify dimensions
            with Image.open(new_image_path) as new_img:
                w, h = new_img.size
                print(f"🖼️ Output size: {w}×{h} (expected: {orig_width}×{orig_height})")

            # Update source for next iteration
            current_image_path = new_image_path

        print(f"\n{'='*70}")
        print(f"✅ COMPLETED! All {total} corrections applied.")
        print(f"📁 Final corrected image: {current_image_path}")
        print(f"{'='*70}")

        return current_image_path

    except Exception as e:
        # Best-effort boundary: report the failure and let the caller see None.
        print(f"❌ Error during sequential correction: {e}")
        traceback.print_exc()
        return None
# --- CELL 9: Main Pipeline Function ---
def run_ocr_correction_pipeline(image_path, image_context_description, results_task,
                                ngrok_url="https://your-ngrok-url.ngrok.io"):
    """
    Complete pipeline with user-provided context.

    Validates the OCR output with the LLM, builds one mask per incorrect
    region, then applies the corrections sequentially.

    Args:
        image_path: Path of the image that was OCR'd.
        image_context_description: User-provided description of the image.
        results_task: OCR result mapping passed to validate_and_correct_ocr.
        ngrok_url: Base URL of the Calligrapher server; the default is a
            placeholder that must be replaced with the real public URL.

    Returns:
        ``(final_image_path, validation_results)``. When nothing needs
        correcting the result is ``(None, [])``; when the correction step
        fails, ``final_image_path`` is None while ``validation_results`` is
        still populated.
    """
    print("\n" + "="*70)
    print("🚀 STARTING OCR CORRECTION PIPELINE")
    print("="*70)

    # STEP 1: Validate OCR and find incorrect texts
    print("\n📋 STEP 1: Validating OCR results...")
    validation_results = validate_and_correct_ocr(
        results_task,
        image_context_description,
        client
    )

    if not validation_results:
        print("\n✅ No incorrect OCR text found! Image is perfect.")
        return None, []

    print(f"\n🔍 Found {len(validation_results)} incorrect text(s):")
    for i, item in enumerate(validation_results, 1):
        print(f"  {i}. '{item['incorrect_text']}' → '{item['suggested_correction']}' ({item['confidence']})")

    # STEP 2: Generate masks for incorrect boxes only
    print(f"\n🎭 STEP 2: Generating {len(validation_results)} correction mask(s)...")
    mask_info_list = create_correction_masks(image_path, validation_results)

    # STEP 3: Apply corrections sequentially through the ngrok-hosted server
    print("\n🔧 STEP 3: Applying corrections with Fal Calligrapher...")

    # The image being corrected is this function's own image_path argument.
    final_image_path = apply_corrections_with_ngrok(image_path, mask_info_list, ngrok_url)
    print("\n✨ PIPELINE COMPLETE! ✨")
    return final_image_path, validation_results

# --- CELL 10: HTML Template ---
# Single-page upload form rendered by the index route. The script posts the
# form to /api/correct and renders the JSON reply. Every server-supplied value
# is passed through escapeHtml() before being written to innerHTML, because it
# may contain OCR- or LLM-generated text.
HTML_TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
    <title>OCR Correction Service</title>
    <style>
        * { margin: 0; padding: 0; box-sizing: border-box; }
        body {
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            min-height: 100vh;
            padding: 20px;
        }
        .container {
            max-width: 800px;
            margin: 0 auto;
            background: white;
            border-radius: 20px;
            box-shadow: 0 20px 60px rgba(0,0,0,0.3);
            padding: 40px;
        }
        h1 {
            color: #667eea;
            margin-bottom: 10px;
            font-size: 2.5em;
        }
        .subtitle {
            color: #666;
            margin-bottom: 30px;
            font-size: 1.1em;
        }
        .form-group {
            margin-bottom: 25px;
        }
        label {
            display: block;
            margin-bottom: 8px;
            font-weight: 600;
            color: #333;
        }
        input[type="file"] {
            width: 100%;
            padding: 12px;
            border: 2px dashed #667eea;
            border-radius: 10px;
            cursor: pointer;
            transition: all 0.3s;
        }
        input[type="file"]:hover {
            border-color: #764ba2;
            background: #f8f9ff;
        }
        textarea {
            width: 100%;
            padding: 15px;
            border: 2px solid #e0e0e0;
            border-radius: 10px;
            font-size: 14px;
            font-family: inherit;
            resize: vertical;
            min-height: 150px;
            transition: border-color 0.3s;
        }
        textarea:focus {
            outline: none;
            border-color: #667eea;
        }
        button {
            width: 100%;
            padding: 15px;
            background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
            color: white;
            border: none;
            border-radius: 10px;
            font-size: 16px;
            font-weight: 600;
            cursor: pointer;
            transition: transform 0.2s;
        }
        button:hover {
            transform: translateY(-2px);
        }
        button:disabled {
            opacity: 0.6;
            cursor: not-allowed;
        }
        #status {
            margin-top: 20px;
            padding: 15px;
            border-radius: 10px;
            display: none;
        }
        .status-processing {
            background: #fff3cd;
            color: #856404;
            border: 2px solid #ffeeba;
        }
        .status-success {
            background: #d4edda;
            color: #155724;
            border: 2px solid #c3e6cb;
        }
        .status-error {
            background: #f8d7da;
            color: #721c24;
            border: 2px solid #f5c6cb;
        }
        #result {
            margin-top: 30px;
            display: none;
        }
        .result-image {
            width: 100%;
            border-radius: 10px;
            margin: 20px 0;
            box-shadow: 0 5px 15px rgba(0,0,0,0.2);
        }
        .corrections-list {
            background: #f8f9ff;
            padding: 20px;
            border-radius: 10px;
            margin-top: 20px;
        }
        .correction-item {
            padding: 12px;
            margin: 10px 0;
            background: white;
            border-left: 4px solid #667eea;
            border-radius: 5px;
        }
        .download-btn {
            display: inline-block;
            padding: 12px 30px;
            background: #28a745;
            color: white;
            text-decoration: none;
            border-radius: 8px;
            margin-top: 15px;
            transition: background 0.3s;
        }
        .download-btn:hover {
            background: #218838;
        }
    </style>
</head>
<body>
    <div class="container">
        <h1>🔤 OCR Correction Service</h1>
        <p class="subtitle">Upload an image and describe its context for intelligent OCR correction</p>

        <form id="uploadForm" enctype="multipart/form-data">
            <div class="form-group">
                <label for="image">📷 Select Image (PNG, JPG)</label>
                <input type="file" id="image" name="image" accept=".png,.jpg,.jpeg" required>
            </div>

            <div class="form-group">
                <label for="context">📝 Image Context Description</label>
                <textarea id="context" name="context" placeholder="Describe the image context. For example:&#10;&#10;This is a fashion magazine cover featuring Tom Cruise.&#10;- Main magazine title: 'FASHION' (large, at the top)&#10;- Primary headline: 'CRUISE CONTROL'&#10;- Subheadings include fashion trends and watch guides" required></textarea>
            </div>

            <button type="submit" id="submitBtn">
                🚀 Process Image
            </button>
        </form>

        <div id="status"></div>
        <div id="result"></div>
    </div>

    <script>
        // Escape a value for safe insertion into HTML text or a quoted attribute.
        const HTML_ESCAPES = {'&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;', "'": '&#39;'};
        function escapeHtml(value) {
            return String(value).replace(/[&<>"']/g, (ch) => HTML_ESCAPES[ch]);
        }

        document.getElementById('uploadForm').addEventListener('submit', async (e) => {
            e.preventDefault();

            const formData = new FormData();
            const imageFile = document.getElementById('image').files[0];
            const context = document.getElementById('context').value;

            formData.append('image', imageFile);
            formData.append('context', context);

            const statusDiv = document.getElementById('status');
            const resultDiv = document.getElementById('result');
            const submitBtn = document.getElementById('submitBtn');

            // Show processing status
            statusDiv.style.display = 'block';
            statusDiv.className = 'status-processing';
            statusDiv.innerHTML = '⏳ Processing your image... This may take a few minutes.';
            submitBtn.disabled = true;
            resultDiv.style.display = 'none';

            try {
                const response = await fetch('/api/correct', {
                    method: 'POST',
                    body: formData
                });

                const data = await response.json();

                if (data.success) {
                    statusDiv.className = 'status-success';
                    statusDiv.innerHTML = '✅ Processing complete!';

                    let resultHTML = '<h2>Results</h2>';

                    if (data.corrections && data.corrections.length > 0) {
                        resultHTML += `<p><strong>Found ${data.corrections.length} correction(s):</strong></p>`;
                        resultHTML += '<div class="corrections-list">';
                        data.corrections.forEach((corr, idx) => {
                            resultHTML += `
                                <div class="correction-item">
                                    <strong>${idx + 1}.</strong> "${escapeHtml(corr.incorrect_text)}" → "${escapeHtml(corr.suggested_correction)}"<br>
                                    <small><em>${escapeHtml(corr.reason)} (${escapeHtml(corr.confidence)} confidence)</em></small>
                                </div>
                            `;
                        });
                        resultHTML += '</div>';

                        if (data.final_image_url) {
                            resultHTML += `
                                <img src="${escapeHtml(data.final_image_url)}" class="result-image" alt="Corrected Image">
                                <a href="${escapeHtml(data.final_image_url)}" class="download-btn" download>📥 Download Corrected Image</a>
                            `;
                        }
                    } else {
                        resultHTML += '<p>No corrections needed - the image is already perfect! ✨</p>';
                    }

                    resultDiv.innerHTML = resultHTML;
                    resultDiv.style.display = 'block';

                } else {
                    statusDiv.className = 'status-error';
                    statusDiv.innerHTML = `❌ Error: ${escapeHtml(data.error || 'Unknown error occurred')}`;
                }

            } catch (error) {
                statusDiv.className = 'status-error';
                statusDiv.innerHTML = `❌ Error: ${escapeHtml(error.message)}`;
            } finally {
                submitBtn.disabled = false;
            }
        });
    </script>
</body>
</html>
"""

# --- CELL 11: Flask Routes ---
@app.route('/')
def index():
    """Render the upload page from the in-memory HTML template."""
    page = render_template_string(HTML_TEMPLATE)
    return page

@app.route('/api/correct', methods=['POST'])
def correct_image():
    """
    API endpoint to process image with user-provided context.

    Expects a multipart form with an 'image' file and a 'context' field.
    Responds with JSON: ``success`` plus either the corrections and the URL of
    the corrected image, a "no corrections needed" message, or an ``error``.
    Status codes: 400 for invalid requests, 502 when corrections were found
    but the corrected image could not be generated, 500 for unexpected errors.
    """
    try:
        # Validate request
        if 'image' not in request.files:
            return jsonify({'success': False, 'error': 'No image file provided'}), 400

        if 'context' not in request.form:
            return jsonify({'success': False, 'error': 'No context description provided'}), 400

        file = request.files['image']
        context_description = request.form['context']

        if file.filename == '':
            return jsonify({'success': False, 'error': 'No file selected'}), 400

        if not allowed_file(file.filename):
            return jsonify({'success': False, 'error': 'Invalid file type. Use PNG, JPG, or JPEG'}), 400

        # Save uploaded file
        filename = secure_filename(file.filename)
        if not allowed_file(filename):
            # Sanitising can strip the whole stem (e.g. non-ASCII names) and
            # leave no usable extension; fall back to a generic name that
            # keeps the already-validated extension.
            extension = file.filename.rsplit('.', 1)[1].lower()
            filename = f"upload.{extension}"
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)

        # Run OCR with region detection on the uploaded image
        print("🔍 Running OCR on uploaded image...")
        image = Image.open(filepath).convert("RGB")
        task = utils.TaskType.OCR_WITH_REGION
        results = utils.run_example(task, image)
        results_task = results[task]

        # Run correction pipeline
        final_image_path, corrections = run_ocr_correction_pipeline(
            filepath,
            context_description,
            results_task
        )

        # Prepare response
        if final_image_path:
            return jsonify({
                'success': True,
                'final_image_url': f'/output/{os.path.basename(final_image_path)}',
                'corrections': corrections,
                'total_corrections': len(corrections)
            })

        if corrections:
            # Corrections were identified but no image came back, so the
            # generation step failed; do not report this as "nothing to fix".
            return jsonify({
                'success': False,
                'error': 'Corrections were identified but the corrected image could not be generated',
                'corrections': corrections,
                'total_corrections': len(corrections)
            }), 502

        return jsonify({
            'success': True,
            'corrections': [],
            'total_corrections': 0,
            'message': 'No corrections needed'
        })

    except Exception as e:
        import traceback
        traceback.print_exc()
        return jsonify({'success': False, 'error': str(e)}), 500



In [None]:
@app.route('/output/<filename>')
def serve_output(filename):
    """Serve corrected images from OUTPUT_DIR.

    A name that does not resolve to a file inside OUTPUT_DIR yields a 404
    response instead of an unhandled file error.
    """
    # Imported here because only send_file is imported at the top of the file.
    from flask import send_from_directory

    # send_from_directory joins the path safely and raises NotFound for
    # missing files or names that would escape the directory.
    return send_from_directory(OUTPUT_DIR, filename)

# --- CELL 12: Start Server with Ngrok ---
def start_server(port=5000):
    """Start Flask server with ngrok tunnel.

    Opens an ngrok tunnel to ``port``, prints usage instructions with the
    public URL, then runs the Flask app (blocking). The tunnel is disconnected
    when the app stops.

    Args:
        port: Local port for both the Flask app and the tunnel.
    """
    # Start ngrok tunnel. Current pyngrok returns a tunnel object whose
    # ``public_url`` attribute holds the address; fall back to the returned
    # value itself in case a plain URL string is returned.
    tunnel = ngrok.connect(port)
    public_url = getattr(tunnel, "public_url", tunnel)

    print('\n' + '='*70)
    print('🌐 OCR CORRECTION SERVICE STARTED')
    print('='*70)
    print(f'🔗 Public URL: {public_url}')
    print(f'🏠 Local URL: http://localhost:{port}')
    print('='*70)
    print('\n📝 Usage:')
    print('1. Open the public URL in your browser')
    print('2. Upload an image')
    print('3. Provide context description')
    print('4. Click "Process Image"')
    print('\n⏹️  Press Ctrl+C to stop the server\n')

    # Display clickable link in Colab
    display(HTML(f'<h2><a href="{public_url}" target="_blank">🔗 Click here to open the service</a></h2>'))

    try:
        # Run Flask app (blocks until the server is stopped)
        app.run(port=port, debug=False, use_reloader=False)
    finally:
        # Close the tunnel so a later run does not find a stale one.
        ngrok.disconnect(public_url)

# --- CELL 13: Execute ---
if __name__ == "__main__":
    # Emit the two-line startup notice about the OCR integration point in the
    # /api/correct route, then launch the server.
    startup_notice = (
        "⚠️  IMPORTANT: You need to integrate your OCR implementation in the /api/correct route",
        "    Look for the 'YOU NEED TO ADD YOUR OCR IMPLEMENTATION HERE' comment\n",
    )
    for notice_line in startup_notice:
        print(notice_line)

    start_server()