In [12]:
# Installation cell - run this first
!pip install transformers
!pip install torch
!pip install ipywidgets

import torch
from PIL import Image
import io
import base64
from IPython.display import display, HTML
from transformers import AutoProcessor, AutoModelForCausalLM
import ipywidgets as widgets
import matplotlib.pyplot as plt
from google.colab import output
import numpy as np

# Initialize model and processor.
# Choose device and precision together: half precision on the first CUDA GPU,
# full precision on CPU.
if torch.cuda.is_available():
    device, torch_dtype = "cuda:0", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

print("Loading Florence-2 model...")
# Florence-2's model and processor classes come from code in the Hub repo,
# hence trust_remote_code=True on both loads.
model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-large", trust_remote_code=True, torch_dtype=torch_dtype
).to(device)
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-large", trust_remote_code=True
)

# Create drawing canvas
def create_drawing_canvas(width=400, height=200):
    """Render a white HTML drawing canvas with a Clear button in the cell output.

    The embedded script draws black strokes while a pointer (mouse, pen or
    touch) is pressed. It defines two global JavaScript functions:
    ``clearCanvas()``, which is wired to the Clear button, and
    ``getImageData()``, which returns the canvas as a PNG data URL.
    ``get_drawn_image()`` evaluates ``getImageData()`` through
    ``output.eval_js``.

    Args:
        width: Canvas width in pixels (default 400).
        height: Canvas height in pixels (default 200).
    """
    # Only the <canvas> tag needs interpolation. The script stays a plain
    # string so its JavaScript braces need no f-string escaping.
    # touch-action:none keeps touch drags from scrolling the page instead of drawing.
    canvas_tag = (
        f'<canvas id="canvas" width="{int(width)}" height="{int(height)}" '
        'style="border:1px solid black; touch-action:none"></canvas>'
    )
    controls_and_script = """
    <br>
    <button onclick="clearCanvas()">Clear</button>
    <script>
        var canvas = document.getElementById('canvas');
        var ctx = canvas.getContext('2d');
        var isDrawing = false;
        var lastX = 0;
        var lastY = 0;

        // Stroke style is set once; clearCanvas() only changes fillStyle.
        // Round caps/joins avoid gaps between the short segments drawn on
        // each move event, giving continuous strokes.
        ctx.strokeStyle = 'black';
        ctx.lineWidth = 3;
        ctx.lineCap = 'round';
        ctx.lineJoin = 'round';

        function clearCanvas() {
            ctx.fillStyle = 'white';
            ctx.fillRect(0, 0, canvas.width, canvas.height);
        }

        function startDrawing(e) {
            isDrawing = true;
            [lastX, lastY] = [e.offsetX, e.offsetY];
        }

        function draw(e) {
            if (!isDrawing) return;
            ctx.beginPath();
            ctx.moveTo(lastX, lastY);
            ctx.lineTo(e.offsetX, e.offsetY);
            ctx.stroke();
            [lastX, lastY] = [e.offsetX, e.offsetY];
        }

        function stopDrawing() {
            isDrawing = false;
        }

        // Returns the canvas as a base64 PNG data URL (read from Python).
        function getImageData() {
            return canvas.toDataURL('image/png');
        }

        // Start from an opaque white background so the exported PNG has no
        // transparent pixels.
        clearCanvas();

        // Pointer events cover mouse, pen and touch input.
        canvas.addEventListener('pointerdown', startDrawing);
        canvas.addEventListener('pointermove', draw);
        canvas.addEventListener('pointerup', stopDrawing);
        canvas.addEventListener('pointerleave', stopDrawing);
        canvas.addEventListener('pointercancel', stopDrawing);
    </script>
    """
    display(HTML(canvas_tag + controls_and_script))

# Function to get image from canvas
def get_drawn_image(target_size=(400, 200)):
    """Fetch the current canvas drawing and return it as an RGB PIL image.

    Calls the JavaScript ``getImageData()`` defined by
    ``create_drawing_canvas()`` via Colab's ``output.eval_js`` and decodes
    the returned PNG data URL. Any transparency is flattened onto white.
    The result is centered on a white ``target_size`` background, after
    being shrunk to fit if it is larger (it is never enlarged).

    Args:
        target_size: (width, height) of the returned image. The default,
            (400, 200), matches the default canvas size.

    Returns:
        An RGB ``PIL.Image.Image`` of exactly ``target_size``.

    Raises:
        ValueError: if the JavaScript call does not return a base64 data URL,
            for example because the canvas has not been displayed.
    """
    data_url = output.eval_js('getImageData()')
    image = _data_url_to_rgb(data_url)
    return _letterbox_on_white(image, tuple(target_size))


def _data_url_to_rgb(data_url):
    """Decode a ``data:<mime>;base64,<payload>`` URL into an RGB image on white."""
    if not isinstance(data_url, str) or ',' not in data_url:
        raise ValueError(
            f"Expected a base64 data URL from the canvas, got {data_url!r:.80}"
        )
    header, payload = data_url.split(',', 1)
    if ';base64' not in header:
        raise ValueError(f"Canvas data URL is not base64-encoded: {header!r:.80}")
    image = Image.open(io.BytesIO(base64.b64decode(payload)))
    # convert('RGB') alone would turn transparent pixels black, so composite
    # onto an opaque white layer first.
    rgba = image.convert('RGBA')
    white = Image.new('RGBA', rgba.size, (255, 255, 255, 255))
    return Image.alpha_composite(white, rgba).convert('RGB')


def _letterbox_on_white(image, target_size):
    """Shrink ``image`` to fit ``target_size`` (never enlarging) and center it on white."""
    fitted = image.copy()
    fitted.thumbnail(target_size, Image.Resampling.LANCZOS)
    background = Image.new('RGB', target_size, (255, 255, 255))
    offset = ((target_size[0] - fitted.width) // 2,
              (target_size[1] - fitted.height) // 2)
    background.paste(fitted, offset)
    return background

# Function to process image and get OCR result
def process_handwriting(image, max_new_tokens=50, num_beams=2):
    """Run Florence-2 OCR on ``image`` and return the recognized text.

    Uses the module-level ``processor``, ``model``, ``device`` and
    ``torch_dtype`` created when the model is loaded. Decoding follows the
    Florence-2 model card: special tokens are kept in ``batch_decode``, and
    ``processor.post_process_generation`` strips them for the OCR task.

    Args:
        image: PIL image to read; expected to be RGB.
        max_new_tokens: Generation budget. The default of 50 is short and
            can truncate longer or multi-line handwriting; raise it as needed.
        num_beams: Beam-search width. Sampling is always disabled, so output
            is deterministic.

    Returns:
        The recognized text with surrounding whitespace stripped.
    """
    task = "<OCR>"  # Florence-2 task token; must be passed alone, with no extra text

    # Same device/dtype move as the model card example. Per the transformers
    # docs, BatchFeature.to applies the dtype only to floating-point tensors,
    # so input_ids stay integer.
    inputs = processor(text=task, images=image, return_tensors="pt").to(device, torch_dtype)

    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        do_sample=False,
    )

    raw_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    parsed = processor.post_process_generation(
        raw_text, task=task, image_size=(image.width, image.height)
    )
    return parsed[task].strip()

# UI widgets: an output area for status/results and the button that triggers OCR.
output_text = widgets.Output()
process_button = widgets.Button(description="Recognize Text")

def on_button_click(b):
    """Button callback: read the canvas, run OCR, and print the result.

    All messages go to the ``output_text`` widget, which is cleared first.
    The clicked button is disabled while the work runs, to discourage
    repeated clicks, and is re-enabled in ``finally`` even on failure.
    Errors are reported in the widget (message, type and traceback)
    rather than raised.

    Args:
        b: The clicked ``widgets.Button``. ``None`` is tolerated, so the
            callback can be invoked manually.
    """
    import traceback

    if b is not None:
        b.disabled = True
    try:
        with output_text:
            output_text.clear_output()
            try:
                print("Getting image from canvas...")
                image = get_drawn_image()
                print("Image successfully loaded")
                print("Processing with Florence-2...")
                result = process_handwriting(image)
                print(f"Recognized text: {result}")
            except Exception as e:
                print(f"Error occurred: {str(e)}")
                print(f"Error type: {type(e)}")
                # Full traceback so failures are debuggable from the widget alone.
                print(traceback.format_exc())
    finally:
        if b is not None:
            b.disabled = False

# Hook up the callback, then render prompt, canvas, button and result area in order.
process_button.on_click(on_button_click)
print("Draw your text on the canvas below:")
create_drawing_canvas()
display(process_button, output_text)

Loading Florence-2 model...
Draw your text on the canvas below:


Button(description='Recognize Text', style=ButtonStyle())

Output()