# Cleaning with LLMs 

Model card: https://huggingface.co/unsloth/Phi-3-medium-4k-instruct

In [None]:
# Load the model and necessary libraries
import torch
import pandas as pd
import numpy as np
import time
import re
from transformers import pipeline

# Build the Phi-3 text-generation pipeline once; the correction functions
# in this notebook call it through the module-level name `pipe`.
pipe = pipeline(
    "text-generation",
    model="unsloth/Phi-3-medium-4k-instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

  from .autonotebook import tqdm as notebook_tqdm
Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

## OCR Text Cleaning Pipeline

We'll create a specialized OCR correction pipeline using the Phi-3 language model with a carefully crafted prompt designed to:

1. Fix OCR artifacts and errors while preserving the original text's style and meaning
2. Avoid hallucinations or unnecessary rewrites
3. Flag ambiguous text rather than guessing
4. Preserve original formatting, punctuation, and capitalization
5. Process text in a pandas DataFrame efficiently

In [None]:
# Define our specialized OCR correction prompt
def create_ocr_correction_prompt(text):
    """
    Build the chat-style prompt asking the model to correct OCR errors.

    Parameters:
    -----------
    text : str
        The raw OCR text to embed in the prompt

    Returns:
    --------
    str: The prompt: a system section with the correction rules, a user
        section holding the text between two fence lines of three double
        quotes, and an opening assistant tag for the model to continue from
    """
    # The template is delimited with ''' because it contains literal """
    # fence lines around the text. Delimiting it with """ ended the string at
    # the first fence and left the function body syntactically invalid.
    # NOTE(review): Phi-3's documented chat format closes each turn with
    # <|end|>; confirm whether those tags should be added here.
    template = '''<|system|>
You are a meticulous text restorer. Your job is to correct OCR artifacts in legal or historical documents.

RULES:
1. Do NOT rewrite or modernize wording — preserve original grammar, style, and vocabulary, even if unusual.
2. Only fix:
   - Obvious misspellings caused by OCR (e.g., "oi" → "of", "xkxi" → remove or replace with intended text if 100% clear from context).
   - Broken words due to incorrect character recognition.
   - Spacing issues.
3. If a word is unclear or ambiguous, do NOT guess. Keep it as-is but flag it in [square brackets].
4. Preserve punctuation, capitalization, and formatting exactly as in the original unless it is clearly an OCR error.
5. Never add or remove whole sentences or change meaning.
6. Output only the corrected text — no explanations.
<|user|>
TEXT TO CORRECT:
"""
{}
"""
<|assistant|>
'''
    # str.format only interprets the template, so any braces inside `text`
    # are inserted verbatim.
    return template.format(text)

In [None]:
# Function to correct OCR text using the LLM
def correct_ocr_text(text, max_length=4000):
    """
    Clean OCR text using our specialized LLM prompt

    Parameters:
    -----------
    text : str
        The OCR text to clean. Non-string values and blank strings are
        returned unchanged without calling the model.
    max_length : int
        Forwarded to the pipeline as `max_length`. In transformers this
        limit covers the prompt tokens plus the generated tokens, so long
        inputs leave less room for the corrected output.

    Returns:
    --------
    str: Corrected text
    """
    # Nothing to correct: hand back non-strings and blank strings untouched
    if not isinstance(text, str) or not text.strip():
        return text

    # Prepare prompt with the text
    prompt = create_ocr_correction_prompt(text)

    # Greedy decoding keeps the output deterministic. `temperature` is not
    # passed: it only applies when sampling, so combined with do_sample=False
    # it had no effect on the result.
    outputs = pipe(
        prompt,
        max_length=max_length,
        do_sample=False,
        return_full_text=False,
    )
    result = outputs[0]["generated_text"].strip()

    # The prompt wraps the text in """ fence lines and the model may echo
    # them. Remove one whole fence from each end only; stripping individual
    # '"' characters would also eat quotation marks that belong to the text.
    fence = '"""'
    if result.startswith(fence):
        result = result[len(fence):]
    if result.endswith(fence):
        result = result[:-len(fence)]

    return result.strip()

In [None]:
# Create a DataFrame processor function that applies OCR correction to specified columns
def process_dataframe_ocr(df, text_columns, batch_size=10, show_progress=True):
    """
    Process a DataFrame by correcting OCR errors in specified text columns

    For every requested column that exists, a new column named
    '<column>_cleaned' is added to a copy of the DataFrame. Cells that are
    not non-blank strings are copied over unchanged.

    Parameters:
    -----------
    df : pandas DataFrame
        DataFrame containing the text to process (not modified)
    text_columns : list or str
        Column name(s) containing OCR text to correct
    batch_size : int
        Number of rows to process in each batch for progress tracking
    show_progress : bool
        Whether to show progress information

    Returns:
    --------
    pandas.DataFrame: Copy of `df` with the added '<column>_cleaned' columns

    Raises:
    -------
    ValueError
        If batch_size is smaller than 1
    """
    # A zero step makes range() fail and a negative one silently skips
    # every row, so reject both up front.
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Make a copy to avoid modifying the original
    result_df = df.copy()

    # Convert text_columns to list if it's a string
    if isinstance(text_columns, str):
        text_columns = [text_columns]

    # Total number of rows to process
    total_rows = len(df)

    for col in text_columns:
        if show_progress:
            print(f"Processing column: {col}")

        # Ensure the column exists
        if col not in df.columns:
            print(f"Warning: Column '{col}' not found in DataFrame. Skipping.")
            continue

        source = df[col]
        # Results are gathered positionally and assigned in one go below.
        # This keeps rows aligned when index labels repeat and creates the
        # cleaned column even when the DataFrame has no rows.
        cleaned_values = []

        # Process in batches (batches only drive progress reporting)
        for start in range(0, total_rows, batch_size):
            stop = min(start + batch_size, total_rows)

            if show_progress:
                print(f"  Processing rows {start+1} to {stop} of {total_rows}...")

            start_time = time.time()

            for value in source.iloc[start:stop]:
                # Only non-blank strings are sent to the model
                if isinstance(value, str) and value.strip():
                    cleaned_values.append(correct_ocr_text(value))
                else:
                    cleaned_values.append(value)

            elapsed = time.time() - start_time
            if show_progress:
                print(f"  Batch completed in {elapsed:.2f} seconds")

        result_df[f"{col}_cleaned"] = cleaned_values

    return result_df

In [None]:
# Function to analyze differences between original and corrected text
def analyze_corrections(original, corrected):
    """
    Compare one original text with its corrected version

    Both texts are lower-cased and reduced to their sets of distinct words,
    so every word count below is a count of unique words.

    Parameters:
    -----------
    original : str
        Original OCR text
    corrected : str
        Corrected text from the LLM

    Returns:
    --------
    dict: Metrics with the keys
        'original_words'  - distinct words in the original
        'corrected_words' - distinct words in the corrected text
        'different_words' - distinct corrected words absent from the original
        'percent_changed' - different_words as a percentage of original_words
        'flagged_words'   - number of [bracketed] spans in the corrected text
        All values are 0 when either argument is not a string.
    """
    metric_names = (
        'original_words',
        'corrected_words',
        'different_words',
        'percent_changed',
        'flagged_words',
    )

    # Anything that is not a pair of strings yields an all-zero record
    if not (isinstance(original, str) and isinstance(corrected, str)):
        return dict.fromkeys(metric_names, 0)

    def vocabulary(text):
        # Distinct lower-cased words of `text`
        return set(re.findall(r'\b\w+\b', text.lower()))

    before = vocabulary(original)
    after = vocabulary(corrected)

    # Words that appear only after correction
    introduced = after - before

    # Bracketed spans mark words the model flagged as ambiguous
    flagged_spans = re.findall(r'\[.*?\]', corrected)

    # Guard against dividing by zero when the original has no words
    if before:
        share_changed = (len(introduced) / len(before)) * 100
    else:
        share_changed = 0

    values = (
        len(before),
        len(after),
        len(introduced),
        share_changed,
        len(flagged_spans),
    )
    return dict(zip(metric_names, values))

In [None]:
# Function to generate correction metrics for the entire DataFrame
def generate_correction_metrics(df, original_cols, corrected_cols):
    """
    Generate metrics about OCR corrections across a DataFrame

    Each original column is paired with the corrected column at the same
    position, and analyze_corrections is applied to every row of the pair.

    Parameters:
    -----------
    df : pandas DataFrame
        DataFrame containing original and corrected text
    original_cols : list or str
        Column name(s) containing original OCR text
    corrected_cols : list or str
        Column name(s) containing corrected OCR text

    Returns:
    --------
    pandas.DataFrame: One row of metrics per (row, column pair), plus the
        'original_column' and 'corrected_column' identifiers. Empty when no
        requested pair exists in `df`.

    Raises:
    -------
    ValueError
        If the two column lists differ in length
    """
    # Convert to lists if strings
    if isinstance(original_cols, str):
        original_cols = [original_cols]
    if isinstance(corrected_cols, str):
        corrected_cols = [corrected_cols]

    # Check lengths match
    if len(original_cols) != len(corrected_cols):
        raise ValueError("Number of original columns must match number of corrected columns")

    # Per-pair frames are collected and concatenated once at the end rather
    # than growing a DataFrame with pd.concat on every iteration.
    per_pair_frames = []

    for orig_col, corr_col in zip(original_cols, corrected_cols):
        # Check columns exist
        if orig_col not in df.columns or corr_col not in df.columns:
            print(f"Warning: Columns '{orig_col}' and/or '{corr_col}' not found. Skipping.")
            continue

        # Walk the two columns in lockstep; only these two values are needed
        # per row, so there is no reason to build a full row Series.
        col_metrics = pd.DataFrame([
            analyze_corrections(original_text, corrected_text)
            for original_text, corrected_text in zip(df[orig_col], df[corr_col])
        ])

        # Add column identifiers
        col_metrics['original_column'] = orig_col
        col_metrics['corrected_column'] = corr_col

        per_pair_frames.append(col_metrics)

    if not per_pair_frames:
        return pd.DataFrame()

    return pd.concat(per_pair_frames, ignore_index=True)

## Example Usage

Here's an example of how to use the OCR correction pipeline on a sample dataset:

In [None]:
# Example: three short documents whose transcripts contain typical OCR errors
sample_data = dict(
    id=[1, 2, 3],
    title=['Document A', 'Document B', 'Document C'],
    transcript=[
        "Tke Board oi Directors met on Aprii 15, 2021 to discusa tke annual oudget. Mr. Smitk presented tke financial reports whick showed an increased prolit of 15% compared to last yéar.",
        "Purshant to Section 8.2 oi the agreement, the partles hereby agtee to exfend the Term for an additlonal period oi five (5) gears, commencing on January 1, 2022.",
        "Withm 30 deys of recerpt, all invoiccs must be processéd by the Acconnts department and forwarded t0 the Finance Director f0r approval.",
    ],
)

# Turn the raw dictionary into a DataFrame
sample_df = pd.DataFrame(data=sample_data)

# Show the uncorrected data
print("Original Sample Data:")
sample_df

In [None]:
# Run the OCR correction over the 'transcript' column of the sample data.
# Note: This will take some time depending on your hardware and model size.
# A batch size of 1 is used here so progress is reported after every row.
corrected_df = process_dataframe_ocr(
    df=sample_df,
    text_columns='transcript',
    batch_size=1,
)

# Show the DataFrame returned by the correction step
print("\nCorrected Sample Data:")
corrected_df

In [None]:
# Analyze corrections made by the model.
# analyze_corrections compares a single pair of strings, so the row-by-row
# work is delegated to generate_correction_metrics. The corrected text is in
# 'transcript_cleaned', the '<column>_cleaned' column that
# process_dataframe_ocr adds; 'transcript' still holds the original.
corrections_analysis = generate_correction_metrics(
    corrected_df, 'transcript', 'transcript_cleaned'
)

# Print the analysis results using the metric columns that
# analyze_corrections actually produces
print("Correction Analysis:")
if corrections_analysis.empty:
    print("No correction metrics available.")
else:
    print(f"Texts analyzed: {len(corrections_analysis)}")
    print(f"Total new words introduced: {corrections_analysis['different_words'].sum()}")
    print(f"Average new words per text: {corrections_analysis['different_words'].mean():.2f}")
    print(f"Average share of distinct words changed: {corrections_analysis['percent_changed'].mean():.2f}%")
    print(f"Spans flagged as ambiguous: {corrections_analysis['flagged_words'].sum()}")

In [None]:
# Create a side-by-side comparison DataFrame for visualization.
# The original text is in the lower-case 'transcript' column and the
# corrected text in 'transcript_cleaned' (added by process_dataframe_ocr);
# there is no 'Transcript' column, so looking it up raised a KeyError.
comparison_df = pd.DataFrame({
    'Original Text': sample_df['transcript'],
    'Corrected Text': corrected_df['transcript_cleaned']
})

# Function to highlight differences between original and corrected text
def highlight_differences(s):
    """
    Return the CSS for one row of the comparison table

    Parameters:
    -----------
    s : mapping with the keys 'Original Text' and 'Corrected Text'
        One row of the comparison DataFrame

    Returns:
    --------
    list: Two CSS strings (original cell, corrected cell): red and green
        backgrounds when the two texts differ, empty strings otherwise
    """
    # Whole-cell comparison only (for word-level highlighting, consider
    # libraries like difflib)
    texts_differ = s['Original Text'] != s['Corrected Text']
    if not texts_differ:
        return ['', '']
    return ['background-color: #ffcccc', 'background-color: #ccffcc']

# Render the comparison table, styling each row with highlight_differences
print("\nSide-by-Side Comparison (Red: Original, Green: Corrected):")
styled_comparison = comparison_df.style.apply(highlight_differences, axis=1)
styled_comparison

In [None]:
def save_corrected_data(df, output_path):
    """
    Write a DataFrame to an Excel file, creating the target folder if needed

    Parameters:
    -----------
    df : pandas DataFrame
        Data to save
    output_path : str
        Path of the Excel file to write
    """
    # Imported locally so this cell does not depend on the import cell
    from pathlib import Path

    # to_excel does not create missing folders itself
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    df.to_excel(output_path, index=False)


# Modify the path as needed
output_path = '../data/mccray/changed_data/sample_corrected.xlsx'

# Save the corrected sample data. The original line wrote `df`, a name that
# is not defined at this point in the notebook.
save_corrected_data(corrected_df, output_path)

In [None]:
# Function to load real data from Excel files
def load_mccray_data(file_path):
    """
    Read a McCray Excel file into a DataFrame, reporting failures

    Parameters:
    -----------
    file_path : str
        Path to the Excel file

    Returns:
    --------
    pandas DataFrame or None
        The loaded data, or None when anything goes wrong while loading
        (the error is printed rather than raised)
    """
    loaded = None
    try:
        frame = pd.read_excel(file_path)
        print(f"Loaded {len(frame)} rows from {file_path}")
        loaded = frame
    except Exception as err:
        # Best-effort loader: report the problem and fall through to None
        print(f"Error loading data: {err}")
    return loaded

# Example: Process a small subset of real data
# The example below is kept inside a string so it does not run; copy it into
# a cell and modify as needed to process real data

"""
# Path to a decade subset or other McCray data file
real_data_path = '../../data/mccray/changed_data/decade_subsets/McCray (1940s, 100 rows).xlsx'

# Load the data
real_df = load_mccray_data(real_data_path)

if real_df is not None and 'transcript' in real_df.columns:
    # Check how many rows have transcripts
    has_transcript = real_df['transcript'].notna()
    print(f"Rows with transcripts: {has_transcript.sum()} of {len(real_df)}")
    
    # Process only rows that have transcripts (limit to first 10 for testing)
    df_to_process = real_df[has_transcript].head(10).copy()
    
    if len(df_to_process) > 0:
        # Process the data with the LLM
        print("Processing real data with the OCR correction pipeline...")
        corrected_real_df = process_dataframe_ocr(df_to_process, 'transcript', batch_size=2)
        
        # Analyze the corrections row by row; the corrected text is in the
        # 'transcript_cleaned' column added by process_dataframe_ocr
        real_corrections = generate_correction_metrics(
            corrected_real_df, 'transcript', 'transcript_cleaned'
        )
        
        print("\nReal Data Correction Analysis:")
        print(f"Total new words introduced: {real_corrections['different_words'].sum()}")
        print(f"Average new words per text: {real_corrections['different_words'].mean():.2f}")
        
        # Save the corrected data (the target folder must already exist)
        output_path = '../../data/mccray/changed_data/llm_corrected/mccray_1940s_sample_corrected.xlsx'
        corrected_real_df.to_excel(output_path, index=False)
"""

In [None]:
# Performance optimization and large dataset processing
#
# Tips for optimizing the OCR correction pipeline:
#
# 1. Process Large Datasets in Chunks (process_large_dataset below)


def process_large_dataset(input_path, output_path, text_column='transcript', chunk_size=100, batch_size=5):
    """
    Correct a whole Excel dataset chunk by chunk and save the result

    The workbook is read in full and then handed to process_dataframe_ocr in
    slices of `chunk_size` rows, so progress is reported per chunk. Rows keep
    their original order; rows without text are passed through unchanged by
    process_dataframe_ocr.

    Parameters:
    -----------
    input_path : str
        Path to the input Excel file
    output_path : str
        Path to save the output Excel file
    text_column : str
        Column name containing the text to correct
    chunk_size : int
        Number of rows to process in each chunk
    batch_size : int
        Batch size for the LLM processing within each chunk

    Returns:
    --------
    pandas.DataFrame: The processed dataset that was written to output_path

    Raises:
    -------
    ValueError
        If chunk_size is smaller than 1
    KeyError
        If text_column is not a column of the input file
    """
    # Imported locally so this cell does not depend on the import cell
    from pathlib import Path

    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    # pd.read_excel has no `chunksize` option (unlike pd.read_csv), so the
    # file is loaded once and sliced into chunks here.
    full_df = pd.read_excel(input_path)

    if text_column not in full_df.columns:
        raise KeyError(f"Column '{text_column}' not found in {input_path}")

    # Process each chunk
    all_processed_chunks = []
    chunk_starts = range(0, len(full_df), chunk_size)

    for chunk_counter, start in enumerate(chunk_starts, start=1):
        chunk = full_df.iloc[start:start + chunk_size]
        print(f"Processing chunk {chunk_counter} ({len(chunk)} rows)")

        # The whole chunk is passed on: splitting it into text / no-text
        # parts and concatenating them afterwards would reorder the rows.
        all_processed_chunks.append(
            process_dataframe_ocr(chunk, text_column, batch_size=batch_size)
        )
        print(f"Completed chunk {chunk_counter}")

    # Combine all processed chunks (an empty input yields an empty copy)
    if all_processed_chunks:
        final_df = pd.concat(all_processed_chunks)
    else:
        final_df = full_df.copy()

    # Save the final result, creating the output folder if it is missing
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    final_df.to_excel(output_path, index=False)
    print(f"Complete dataset processed and saved to {output_path}")

    return final_df


# 2. Use GPU acceleration:
#    - If available, ensure the model is loaded on GPU using device='cuda'
#    - For multi-GPU setups, consider using device_map='auto'
#
# 3. Optimize memory usage:
#    - Set torch_dtype=torch.float16 for half precision
#    - Use 8-bit quantization: load_in_8bit=True
#    - Consider using Unsloth for optimized models: https://github.com/unslothai/unsloth
#
# 4. For extremely large datasets:
#    - Consider saving intermediate results after each chunk
#    - Implement checkpointing to resume processing if interrupted
#    - Use a distributed processing approach with multiple machines
#
# 5. For batch processing multiple files:
#    - Create a list of files and process them sequentially
#    - Implement parallel processing using multiprocessing or joblib
#
# Example usage for large dataset processing:
# process_large_dataset(
#     input_path='../../data/mccray/changed_data/McCray+.xlsx',
#     output_path='../../data/mccray/changed_data/llm_corrected/McCray+_corrected.xlsx',
#     chunk_size=100,
#     batch_size=5
# )