In [None]:
%%time

from IPython.display import clear_output

!pip install transformers==4.45.0
!pip install bitsandbytes==0.44.1 accelerate
! pip install einops flash_attn # florence 2

clear_output()

CPU times: user 780 ms, sys: 169 ms, total: 949 ms
Wall time: 2min 17s


In [7]:
!pip install rouge_score

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=92ce8165ca0fef50091385fdc7cd09d10e8f9d36905041a17da2b41113c82532
  Stored in directory: /root/.cache/pip/wheels/1e/19/43/8a442dc83660ca25e163e1bd1f89919284ab0d0c1475475148
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2


In [17]:
import pandas as pd
import json
import os
import numpy as np
from rouge_score import rouge_scorer
import re
import glob
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
import time
import gc

# Install rouge-score if needed
# !pip install rouge-score

# Configuration
class CFG:
    """Run configuration for the Florence-2 OCR + ROUGE evaluation.

    The paths point at a Colab-mounted Google Drive; edit them for other
    environments.
    """

    # Hugging Face model id. The misspelled attribute name is kept because the
    # rest of the notebook reads ``CFG.florece_model``; ``florence_model`` is a
    # correctly spelled alias for new code.
    florece_model = "microsoft/Florence-2-large"
    florence_model = florece_model
    # Directory of input images (files ending in .png/.jpg/.jpeg are used).
    image_dir = '/content/drive/MyDrive/training_data/images'
    # Directory of ground-truth JSON files, looked up as <image stem>.json.
    annotation_dir = '/content/drive/MyDrive/training_data/annotations'  # Update this path
    # Maximum number of images to OCR (the first N after sorting by filename).
    num_images = 50
    # CSV written by process_ocr() and read back by calculate_rouge().
    output_csv = 'florence_ocr_results.csv'

# Load the Florence-2 processor and model (process_ocr() calls this once and reuses them for every image)
def build_model():
    """Load the Florence-2 processor and model named by ``CFG.florece_model``.

    The model goes to CUDA in float16 when a GPU is available and otherwise
    stays on the CPU in float32. It is switched to eval mode before returning.

    Returns:
        tuple: ``(processor, model)``.
    """
    print('Loading Florence model...')
    model_id = CFG.florece_model
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

    # Half precision is only used on the GPU; CPU inference stays in float32.
    if torch.cuda.is_available():
        device, dtype = torch.device("cuda"), torch.float16
    else:
        device, dtype = torch.device("cpu"), torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=dtype,
    )
    model = model.to(device)
    model.eval()
    return processor, model

# Process single image
def process_image(image_path, processor, model):
    """Run Florence-2 ``<OCR>`` on one image and measure latency and GPU memory.

    Args:
        image_path: Path of the image file to OCR.
        processor: Processor returned by ``build_model``.
        model: Model returned by ``build_model``. Its ``device`` and ``dtype``
            decide where and in what precision the pixel values are sent.

    Returns:
        tuple: ``(result, inference_time, max_memory)``.
        - ``result`` is whatever ``processor.post_process_generation`` returns
          (presumably a dict keyed by the task token; confirm for other models).
        - ``inference_time`` is the seconds spent on preprocessing, generation
          and decoding.
        - ``max_memory`` is the peak CUDA memory allocated in MB, or ``None``
          when CUDA is unavailable.

        On any error the message is printed and ``(None, None, None)`` is
        returned, so one bad image does not abort the batch.
    """
    use_cuda = torch.cuda.is_available()
    try:
        # Materialise the RGB image inside the ``with`` so the file handle is
        # closed right away; ``convert`` returns a new, fully loaded image.
        with Image.open(image_path) as raw_image:
            if raw_image.mode != 'RGB':
                print(f"Converting {raw_image.mode} image to RGB")
            image = raw_image.convert('RGB')

        start_time = time.perf_counter()
        # Peak-memory tracking only exists on CUDA; calling it on CPU raises.
        if use_cuda:
            torch.cuda.reset_peak_memory_stats()

        inputs = processor(
            text="<OCR>",
            images=image,
            return_tensors="pt"
        )

        # Move every tensor to the model's device. Pixel values must also match
        # the model's floating dtype (float16 on GPU, float32 on CPU); a fixed
        # float16 cast breaks CPU inference.
        inputs = {
            k: v.to(device=model.device, dtype=model.dtype if k == "pixel_values" else v.dtype)
            for k, v in inputs.items()
        }

        generated_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3
        )

        # Special tokens are kept, presumably because post_process_generation
        # relies on them to parse the task output.
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
        result = processor.post_process_generation(
            generated_text,
            task="<OCR>",
            image_size=(image.width, image.height)
        )

        inference_time = time.perf_counter() - start_time
        max_memory = torch.cuda.max_memory_allocated() / (1024 ** 2) if use_cuda else None  # MB

        # Drop tensor references so the cache release in ``finally`` can free them.
        del inputs, generated_ids
        return result, inference_time, max_memory

    except Exception as e:
        print(f"Error processing {image_path}: {str(e)}")
        return None, None, None

    finally:
        # Release cached GPU blocks and Python garbage between images, including
        # after a failure.
        if use_cuda:
            torch.cuda.empty_cache()
        gc.collect()

# Main processing function for OCR
def _extract_ocr_text(result, task="<OCR>"):
    """Return the recognized text from a ``process_image`` result, or ``None``.

    A dict result that contains ``task`` is unwrapped to that entry; any other
    non-empty result is stringified. Empty or missing results map to ``None``,
    so failed images stay blank in the CSV.
    """
    if not result:
        return None
    if isinstance(result, dict) and task in result:
        result = result[task]
    text = str(result)
    return text or None


def process_ocr():
    """OCR the configured images with Florence-2 and save the results as CSV.

    Images are the ``.png``/``.jpg``/``.jpeg`` files (case-insensitive) in
    ``CFG.image_dir``, sorted by path and truncated to ``CFG.num_images``. The
    model is loaded once and reused for every image.

    Returns:
        pandas.DataFrame: one row per image with columns ``image_id``,
        ``ocr_text`` (plain recognized text, or ``None`` when OCR failed or
        produced nothing), ``inference_time_sec`` and ``gpu_memory_usage_mb``.
        The same frame is written to ``CFG.output_csv``.
    """
    image_paths = sorted(
        os.path.join(CFG.image_dir, name)
        for name in os.listdir(CFG.image_dir)
        if name.lower().endswith(('.png', '.jpg', '.jpeg'))
    )[:CFG.num_images]

    print(f"Found {len(image_paths)} images to process")

    # Load the model once and reuse it for every image.
    processor, model = build_model()

    results = []
    for idx, img_path in enumerate(image_paths, 1):
        img_filename = os.path.basename(img_path)
        print(f"Processing image {idx}/{len(image_paths)}: {img_filename}")

        result, inf_time, mem_usage = process_image(img_path, processor, model)

        results.append({
            'image_id': img_filename,
            # Store the plain text rather than str(dict). The dict repr would add
            # the task key and repr escapes to every downstream ROUGE comparison.
            'ocr_text': _extract_ocr_text(result),
            'inference_time_sec': inf_time,
            'gpu_memory_usage_mb': mem_usage,
        })

    # Explicit columns keep the CSV readable even when no images were found.
    df = pd.DataFrame(
        results,
        columns=['image_id', 'ocr_text', 'inference_time_sec', 'gpu_memory_usage_mb'],
    )
    df.to_csv(CFG.output_csv, index=False)
    print(f"Results saved to {CFG.output_csv}")

    return df

# Function to extract all text from annotation file
def extract_text_from_annotation(annotation_file):
    """Concatenate every ``"text"`` field found in one annotation JSON file.

    Two layouts are accepted:
    - a top-level list of objects, e.g. ``[{"text": "..."}, ...]``;
    - a top-level object whose list-valued entries hold such objects, e.g.
      ``{"annotations": [{"text": "..."}]}``.

    Items that are not objects, and ``null`` texts, are ignored. Non-string
    texts are converted with ``str``. Texts are joined with single spaces in
    file order.

    Args:
        annotation_file: Path of the JSON file (read as UTF-8).

    Returns:
        str: The joined text, or ``""`` (after printing the error) when the
        file cannot be read or parsed.
    """
    try:
        with open(annotation_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Collect the lists that may hold text entries for either layout.
        if isinstance(data, list):
            candidate_lists = [data]
        elif isinstance(data, dict):
            candidate_lists = [value for value in data.values() if isinstance(value, list)]
        else:
            candidate_lists = []

        # The isinstance check matters: for a str item, ``'text' in item`` is a
        # substring test and ``item['text']`` would raise, discarding the file.
        all_texts = [
            str(item['text'])
            for items in candidate_lists
            for item in items
            if isinstance(item, dict) and item.get('text') is not None
        ]
        return " ".join(all_texts)
    except Exception as e:
        print(f"Error extracting text from {annotation_file}: {e}")
        return ""

# Function to clean text for ROUGE comparison
def clean_text(text):
    """Normalise a text value for ROUGE comparison.

    Missing values become ``""``. These are ``None``, the literal string
    ``"None"``, and float NaN, which is what ``pandas.read_csv`` yields for an
    empty cell. Any other value is converted with ``str``; each run of
    whitespace (spaces, tabs, newlines) is collapsed to one space, and leading
    and trailing whitespace is stripped.
    """
    if text is None or text == "None":
        return ""
    # Without this check a NaN cell becomes the string "nan", and a failed OCR
    # row would be scored instead of skipped.
    if isinstance(text, float) and np.isnan(text):
        return ""
    return re.sub(r'\s+', ' ', str(text)).strip()

# Main function to calculate ROUGE scores
# ROUGE variants computed for every image, in report order.
_ROUGE_TYPES = ('rouge1', 'rouge2', 'rougeL')


def _unwrap_ocr_cell(value, task="<OCR>"):
    """Return the plain OCR text stored in one ``ocr_text`` CSV cell, or ``None``.

    Empty cells come back from ``pd.read_csv`` as NaN and map to ``None``.
    Cells holding the repr of a post-processed dict, e.g. ``"{'<OCR>': '...'}"``,
    are parsed with ``ast.literal_eval``, which evaluates literals only. The
    ``task`` entry is returned, so the dict key and repr escapes do not leak
    into the ROUGE tokens. Anything else is returned as ``str(value)``.
    """
    import ast

    if value is None or (isinstance(value, float) and np.isnan(value)):
        return None
    text = str(value)
    stripped = text.strip()
    if stripped.startswith('{') and stripped.endswith('}'):
        try:
            parsed = ast.literal_eval(stripped)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            parsed = None
        if isinstance(parsed, dict) and task in parsed:
            return str(parsed[task])
    return text


def _print_ranking(title, entries):
    """Print ``entries`` as a numbered ``image - ROUGE-L F1`` list under ``title``."""
    print(f"\n{title}")
    for rank, score in enumerate(entries, 1):
        print(f"{rank}. {score['image_file']} - ROUGE-L F1: {score['rougeL_f1']:.4f}")


def calculate_rouge(ocr_csv=None, annotation_dir=None,
                    scores_json='rouge_scores.json',
                    summary_json='rouge_scores_summary.json'):
    """Score saved OCR output against the ground-truth annotations with ROUGE.

    For every CSV row, the OCR text and the text of
    ``<annotation_dir>/<image stem>.json`` are normalised with ``clean_text``.
    They are then compared with ROUGE-1, ROUGE-2 and ROUGE-L, with stemming
    enabled. A row is skipped with a message when it has no OCR text, no
    annotation file, or empty ground truth.

    Args:
        ocr_csv: CSV produced by ``process_ocr``. Defaults to ``CFG.output_csv``.
        annotation_dir: Directory of annotation JSON files. Defaults to
            ``CFG.annotation_dir``.
        scores_json: Output path for the per-image scores. It is always
            written, possibly as ``[]``.
        summary_json: Output path for the averaged scores. It is written only
            when at least one image was scored.

    Returns:
        dict | None: Mean precision, recall and F1 for each ROUGE type, keyed
        ``avg_<type>_<precision|recall|f1>``, or ``None`` when no image could
        be scored.
    """
    ocr_csv = CFG.output_csv if ocr_csv is None else ocr_csv
    annotation_dir = CFG.annotation_dir if annotation_dir is None else annotation_dir

    # Keep image ids as strings even if a filename looks numeric.
    ocr_results_df = pd.read_csv(ocr_csv, dtype={'image_id': str})
    print(f"Loaded {len(ocr_results_df)} OCR results")

    scorer = rouge_scorer.RougeScorer(list(_ROUGE_TYPES), use_stemmer=True)

    rouge_scores = []
    for _, row in ocr_results_df.iterrows():
        image_filename = row['image_id']
        ocr_text = clean_text(_unwrap_ocr_cell(row['ocr_text']))

        # Skip if OCR failed
        if not ocr_text:
            print(f"Skipping {image_filename} - No OCR text available")
            continue

        base_name = os.path.splitext(image_filename)[0]
        annotation_path = os.path.join(annotation_dir, f"{base_name}.json")
        if not os.path.exists(annotation_path):
            print(f"No annotation found for {image_filename}")
            continue

        ground_truth = clean_text(extract_text_from_annotation(annotation_path))
        if not ground_truth:
            print(f"Empty ground truth for {image_filename}")
            continue

        # score(target, prediction): the ground truth is the reference.
        scores = scorer.score(ground_truth, ocr_text)
        entry = {'image_file': image_filename}
        for rouge_type in _ROUGE_TYPES:
            entry[f'{rouge_type}_precision'] = scores[rouge_type].precision
            entry[f'{rouge_type}_recall'] = scores[rouge_type].recall
            entry[f'{rouge_type}_f1'] = scores[rouge_type].fmeasure
        entry['ground_truth_length'] = len(ground_truth)
        entry['ocr_text_length'] = len(ocr_text)
        rouge_scores.append(entry)

        print(f"Calculated ROUGE for {image_filename}")

    with open(scores_json, 'w') as f:
        json.dump(rouge_scores, f, indent=4)

    if not rouge_scores:
        print("No ROUGE scores calculated. Check your OCR results and annotation files.")
        return None

    avg_scores = {
        f'avg_{rouge_type}_{stat}': float(np.mean([s[f'{rouge_type}_{stat}'] for s in rouge_scores]))
        for rouge_type in _ROUGE_TYPES
        for stat in ('precision', 'recall', 'f1')
    }

    print("\nAverage ROUGE Scores:")
    for metric, value in avg_scores.items():
        print(f"{metric}: {value:.4f}")

    with open(summary_json, 'w') as f:
        json.dump(avg_scores, f, indent=4)

    ranked = sorted(rouge_scores, key=lambda s: s['rougeL_f1'], reverse=True)
    _print_ranking("Top 5 images by ROUGE-L F1 score:", ranked[:5])
    # List the bottom worst-first, so rank 1 is the lowest-scoring image.
    _print_ranking("Bottom 5 images by ROUGE-L F1 score:", ranked[::-1][:5])

    return avg_scores

# Main execution
if __name__ == "__main__":
    # Two-stage pipeline, run in order: stage 1 OCRs the images and writes
    # CFG.output_csv; stage 2 reads that CSV back and scores it with ROUGE.
    for _pipeline_stage in (process_ocr, calculate_rouge):
        _pipeline_stage()

Found 50 images to process
Loading Florence model...




Processing image 1/50: 0000971160.png
Converting L image to RGB
Processing image 2/50: 0000989556.png
Converting L image to RGB
Processing image 3/50: 0000990274.png
Converting L image to RGB
Processing image 4/50: 0000999294.png
Converting L image to RGB
Processing image 5/50: 0001118259.png
Converting L image to RGB
Processing image 6/50: 0001123541.png
Converting L image to RGB
Processing image 7/50: 0001129658.png
Converting L image to RGB
Processing image 8/50: 0001209043.png
Converting L image to RGB
Processing image 9/50: 0001239897.png
Converting L image to RGB
Processing image 10/50: 0001438955.png
Converting L image to RGB
Processing image 11/50: 0001456787.png
Converting L image to RGB
Processing image 12/50: 0001463282.png
Converting L image to RGB
Processing image 13/50: 0001463448.png
Converting L image to RGB
Processing image 14/50: 0001476912.png
Converting L image to RGB
Processing image 15/50: 0001477983.png
Converting L image to RGB
Processing image 16/50: 0001485288