In [169]:
import pytesseract
pytesseract.pytesseract.tesseract_cmd = 'C:\\Program Files\\Tesseract-OCR\\tesseract.exe'
import json
import os
from PIL import Image
import cv2
import Levenshtein
from rank_bm25 import BM25L
import pickle
import torch
from pytorch_beam_search import seq2seq
from post_ocr_correction import correction
from pprint import pprint
import re

In [157]:
def load_dataset_json(json_path):
    with open(json_path, 'r') as f:
        data = json.load(f)
    return data

def perform_ocr(image_path):
    img = cv2.imread(image_path)
    extracted_text = pytesseract.image_to_string(img)
    return extracted_text

In [158]:
def save_ocr_json(image_id, extracted_text, save_dir):
    # Split text by "\n\n" to create passages
    passages = extracted_text.split("\n\n")
    
    # Prepare OCR JSON content with passage IDs
    ocr_json = {
        "image_id": image_id,
        "passages": [{"passage_id": i + 1, "text": passage.strip()} for i, passage in enumerate(passages) if passage.strip()]
    }
    
    # Save OCR JSON file
    ocr_file_name = f"{image_id}_ocr.json"
    ocr_file_path = os.path.join(save_dir, ocr_file_name)
    with open(ocr_file_path, 'w') as ocr_file:
        json.dump(ocr_json, ocr_file, indent=4)
    
    return ocr_file_path

def update_dataset_with_ocr_path(data, entry_index, ocr_file_path):
    data[entry_index]['ocr_json_path'] = ocr_file_path

def save_updated_dataset_json(data, json_path):
    with open(json_path, 'w') as f:
        json.dump(data, f, indent=4)

In [159]:
with open("OCRcorrection_en.arch", "rb") as file:
    architecture = pickle.load(file)
source = list(architecture["in_vocabulary"].keys())
target = list(architecture["out_vocabulary"].values())
source_index = seq2seq.Index(source)
target_index = seq2seq.Index(target)

# Remove old API keys from architecture
for k in ["in_vocabulary", "out_vocabulary", "model", "parameters"]:
    if k in architecture:
        architecture.pop(k)
model = seq2seq.Transformer(source_index, target_index, **architecture)

# Load the model state dictionary
state_dict = torch.load("OCRcorrection_en.pt", map_location=torch.device("cpu"))
state_dict["source_embeddings.weight"] = state_dict.pop("in_embeddings.weight")
state_dict["target_embeddings.weight"] = state_dict.pop("out_embeddings.weight")
model.eval()
model.load_state_dict(state_dict)

# Function to correct OCR text using beam search and post-ocr correction
def correct_ocr_text(ocr_text, reference_text):
    
    # Prepare the test data (OCR text) and reference text from gt file
    test = ocr_text
    reference = reference_text

    # # Convert test string to tensor
    # new_source = [list(test)]
    # X_new = source_index.text2tensor(new_source)

    # # Perform plain beam search
    # predictions, log_probabilities = seq2seq.beam_search(model, X_new, progress_bar=0)
    # just_beam = target_index.tensor2text(predictions[:, 0, :])[0]
    # just_beam = re.sub(r"<START>|<PAD>|<UNK>|<END>.*", "", just_beam)

    # Disjoint beam search correction
    # disjoint_beam = correction.disjoint(test, model, source_index, target_index, 5, "beam_search")

    # N-grams beam search correction
    votes, n_grams_beam = correction.n_grams(test, model, source_index, target_index, 5, "beam_search", "triangle")

    # Full evaluation (optional)
    # evaluation = correction.full_evaluation([test], [reference], model, source_index, target_index)
    # print(test)
    # print(reference)
    # print(n_grams_beam)
    # print(evaluation)

    return n_grams_beam

Model: Seq2Seq Transformer
Source index: <Seq2Seq Index with 164 items>
Target index: <Seq2Seq Index with 164 items>
Max sequence length: 110
Embedding dimension: 256
Feedforward dimension: 1024
Encoder layers: 2
Decoder layers: 2
Attention heads: 8
Activation: relu
Dropout: 0.5
Trainable parameters: 3,841,700



  state_dict = torch.load("OCRcorrection_en.pt", map_location=torch.device("cpu"))


In [160]:
# Function to calculate the length of a string in bits (based on UTF-8 encoding)
def calculate_bits(text):
    return len(text.encode('utf-8')) * 8  # Length in bits

# Function to split text into smaller chunks while also calculating their length in bits
def split_text_into_chunks(text, max_length):
    # Split the text into words
    words = text.split()
    chunks = []
    chunk = []
    chunk_bits = 0  # Track the size of the current chunk in bits
    
    for word in words:
        # Calculate the number of bits for the current word
        word_bits = calculate_bits(word)  # in bits
        # Check if adding the word will exceed max_length in bits
        if chunk_bits + word_bits + 8 <= max_length:  # +8 accounts for space between words
            chunk.append(word)
            chunk_bits += word_bits + 8  # Add the word's bits and the space between words
        else:
            # If the chunk is too large, save the current chunk and start a new one
            chunks.append(" ".join(chunk))
            chunk = [word]
            chunk_bits = word_bits + 8  # Reset chunk and start with the new word
    
    # Add the last chunk
    if chunk:
        chunks.append(" ".join(chunk))
    
    return chunks

# Function to correct OCR text by processing chunks
def correct_ocr_text_in_chunks(extracted_text, reference_text, max_sequence_length):
    # Split the extracted text and reference text into chunks
    extracted_text_chunks = split_text_into_chunks(extracted_text, max_sequence_length)
    reference_text_chunks = split_text_into_chunks(reference_text, max_sequence_length)

    # Ensure that both extracted text and reference text have the same number of chunks
    # assert len(extracted_text_chunks) == len(reference_text_chunks), "Text and reference chunks do not match in size"

    corrected_chunks = []

    # Process each chunk
    for extracted_chunk, reference_chunk in zip(extracted_text_chunks, reference_text_chunks):
        # Perform OCR correction for each chunk
        corrected_chunk = correct_ocr_text(extracted_chunk, reference_chunk)
        corrected_chunks.append(corrected_chunk)
    
    # Combine the corrected chunks into a single text
    corrected_text = " ".join(corrected_chunks)
    return corrected_text

In [161]:
# Load the original dataset JSON
dataset_path = os.path.join('dataset', 'dataset.json')
data = load_dataset_json(dataset_path)

# Directory to save OCR text files
ocr_texts_dir = os.path.join('dataset', 'ocr_texts')
os.makedirs(ocr_texts_dir, exist_ok=True)

# Process each image in the dataset
for idx, entry in enumerate(data):
    image_id = entry['image_id']
    image_path = entry['image_path']
    gt_path = entry['gt_path']

    # Perform OCR on the image
    extracted_text = perform_ocr(image_path)

    # Save OCR JSON
    ocr_json_path = save_ocr_json(image_id, extracted_text, ocr_texts_dir)

    # Update the main dataset JSON entry with the OCR path
    update_dataset_with_ocr_path(data, idx, ocr_json_path)

    print(f"Processed and saved OCR data for image {image_id}")
    # break

# Save the updated dataset JSON with new OCR correction paths
with open(dataset_path, 'w') as f:
    json.dump(data, f, indent=4)

print("All images processed with OCR and final dataset updated.")

Processed and saved OCR data for image 3200797029
Processed and saved OCR data for image 3200797032
Processed and saved OCR data for image 3200797034
Processed and saved OCR data for image 3200797037
Processed and saved OCR data for image 3200801612
Processed and saved OCR data for image 3200801613
Processed and saved OCR data for image 3200801615
Processed and saved OCR data for image 3200801619
Processed and saved OCR data for image 3200801622
Processed and saved OCR data for image 3200801629
Processed and saved OCR data for image 3200801630
Processed and saved OCR data for image 3200801632
Processed and saved OCR data for image 3200801633
Processed and saved OCR data for image 3200801634
Processed and saved OCR data for image 3200803382
Processed and saved OCR data for image 3200803389
Processed and saved OCR data for image 3200803401
Processed and saved OCR data for image 3200803403
Processed and saved OCR data for image 3200807879
Processed and saved OCR data for image 3200807881


In [162]:
# Directory to save OCR corrected files
ocr_corrected_dir = os.path.join('dataset', 'ocr_corrected')
os.makedirs(ocr_corrected_dir, exist_ok=True)

# Load the original dataset JSON
dataset_path = os.path.join('dataset', 'dataset.json')
data = load_dataset_json(dataset_path)

# Iterate through the OCR files and perform corrections
for idx, entry in enumerate(data):
    ocr_file_path = entry['ocr_json_path']

    # Load the ground truth reference text with explicit encoding
    with open(gt_path, 'r', encoding='utf-8') as gt_file:
        reference_text = gt_file.read().strip()
    
    # Create a list to store corrected passages
    corrected_passages = []

    # Load OCR JSON file
    with open(ocr_file_path, 'r', encoding='utf-8') as f:
        ocr_data = json.load(f)
    
    # Extract image info (id and path from the OCR JSON file or main dataset)
    image_id = ocr_data.get('image_id')
    
    # Extract passages from the OCR JSON
    for passage in ocr_data['passages']:
        reference_text_passage = reference_text[:len(passage['text'])]
        reference_text = reference_text[len(passage['text']):]
        
        # Perform OCR correction using the model
        correction_results = correct_ocr_text_in_chunks(passage['text'], reference_text, max_sequence_length=512*2)
        
        # Assuming the model returns a corrected text or other useful result
        corrected_passages.append(correction_results)
    
    # After processing all passages, save the corrected OCR data
    corrected_ocr_data = {
        "image_id": image_id,
        "passages": [{"passage_id": i + 1, "text": corrected_passage} for i, corrected_passage in enumerate(corrected_passages)]
    }

    corrected_file_path = os.path.join(ocr_corrected_dir, f'{image_id}_corrected.json')
    with open(corrected_file_path, 'w') as corrected_file:
        json.dump(corrected_ocr_data, corrected_file, indent=4)
    
    # Update the main dataset JSON entry with the OCR path
    update_dataset_with_ocr_path(data, idx, corrected_file_path)

    print(f"Processed and saved OCR data for image {image_id}")
    # break

# Save the updated dataset JSON with new OCR correction paths
with open(dataset_path, 'w') as f:
    json.dump(data, f, indent=4)

print("All images processed with OCR and final dataset updated.")


Processed and saved OCR data for image 3200797029
Processed and saved OCR data for image 3200797032
Processed and saved OCR data for image 3200797034
Processed and saved OCR data for image 3200797037
Processed and saved OCR data for image 3200801612
Processed and saved OCR data for image 3200801613
Processed and saved OCR data for image 3200801615
Processed and saved OCR data for image 3200801619
Processed and saved OCR data for image 3200801622
Processed and saved OCR data for image 3200801629
Processed and saved OCR data for image 3200801630
Processed and saved OCR data for image 3200801632
Processed and saved OCR data for image 3200801633
Processed and saved OCR data for image 3200801634
Processed and saved OCR data for image 3200803382
Processed and saved OCR data for image 3200803389
Processed and saved OCR data for image 3200803401
Processed and saved OCR data for image 3200803403
Processed and saved OCR data for image 3200807879
Processed and saved OCR data for image 3200807881


In [173]:


# Paths to the directories
ocr_texts_dir = os.path.join('dataset', 'ocr_texts')
ocr_corrected_dir = os.path.join('dataset', 'ocr_corrected')
gt_dir = os.path.join('dataset', 'gt')

def calculate_cer(original_text, corrected_text):
    """
    Calculates the Character Error Rate (CER) between two texts.
    
    Args:
    - original_text (str): The original OCR text.
    - corrected_text (str): The reference or corrected text.

    Returns:
    - float: The CER value as a percentage.
    """
    distance = Levenshtein.distance(original_text, corrected_text)
    cer = distance / len(corrected_text) if corrected_text else float('inf')
    return cer * 100 

def calculate_average_cer(ocr_dir, gt_dir):
    """
    Calculates the average CER across all OCR files compared to their ground truths.

    Args:
    - ocr_dir (str): Directory path to the OCR JSON files.
    - gt_dir (str): Directory path to the ground truth text files.

    Returns:
    - float: The average CER across all files.
    """
    total_cer = 0
    file_count = 0

    for ocr_file in os.listdir(ocr_dir):
        # Construct file paths
        ocr_path = os.path.join(ocr_dir, ocr_file)
        if ocr_dir == ocr_corrected_dir: file_count += 9
        gt_path = os.path.join(gt_dir, f"{ocr_file.split('_')[0]}.txt")
        
        # Load OCR and reference data
        if not os.path.exists(gt_path):
            print(f"Ground truth file for {ocr_file} not found.")
            continue
        
        with open(ocr_path, 'r', encoding='utf-8') as f_ocr, open(gt_path, 'r', encoding='utf-8') as f_gt:
            ocr_data = json.load(f_ocr)
            reference_text = f_gt.read().strip()
        
        # Concatenate all passages into a single text
        combined_ocr_text = " ".join(passage.get('text', "") for passage in ocr_data.get('passages', []))

        # Calculate CER for the combined text
        cer = calculate_cer(combined_ocr_text, reference_text)
        total_cer += cer
        file_count += 1
        # print(f"File: {ocr_file}, CER: {cer:.2f}%")

    # Average CER across all files
    average_cer = total_cer / file_count if file_count else 0
    # print(f"Average CER across all files: {average_cer:.2f}%")
    return average_cer

# Run the CER calculation
cer_text = calculate_average_cer(ocr_texts_dir, gt_dir)
cer_corr = calculate_average_cer(ocr_corrected_dir, gt_dir)
print(f"CER before error correction: {cer_text:.2f}%")
print(f"CER after error correction: {cer_corr:.2f}%")

CER before error correction: 8.54%
CER after error correction: 6.63%


In [163]:
# Path to the directory containing OCR JSON files
ocr_corr_dir = os.path.join('dataset', 'ocr_corrected')
ocr_img_dir = os.path.join('dataset', 'img')

# List to store tokenized passages along with image information
tokenized_extracted_text_list = []
image_info_list = []  # List to store image info (id, file path)

# Iterate through all OCR JSON files in the directory
for ocr_file in os.listdir(ocr_corr_dir):
    ocr_file_path = os.path.join(ocr_corr_dir, ocr_file)
    
    # Load OCR JSON file
    with open(ocr_file_path, 'r', encoding='utf-8') as f:
        ocr_data = json.load(f)
    
    # Extract image info (id and path from the OCR JSON file or main dataset)
    image_id = ocr_data.get('image_id')
    # You can modify the image path extraction logic based on your actual dataset
    image_path = os.path.join(ocr_img_dir, f"{image_id}.jpg")  # Assuming the image file name is based on the image_id

    # Extract passages from the OCR JSON
    for passage in ocr_data['passages']:
        # Tokenize passage text (remove punctuation and non-word characters)
        tokenized_passage = re.findall(r'\b\w+\b', passage['text'])
        if(len(tokenized_passage) < 10):
            continue
        tokenized_extracted_text_list.append(tokenized_passage)
        passage_id = passage['passage_id']
        image_info_list.append((image_id, passage_id ,image_path))  # Store image id and path for each passage

bm25 = BM25L(tokenized_extracted_text_list)
# Print tokenized list
# print(tokenized_extracted_text_list)

In [164]:
query = "A PROMINENT speaker at a free trader’s meeting at Madrid"
tokenized_query = query.split(" ")

scores = bm25.get_scores(tokenized_query)

top_n = 3
top_passages_idx = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_n]

print(f"Top {top_n} Passages (based on BM25 scores):")
for idx in top_passages_idx:
    passage = ' '.join(tokenized_extracted_text_list[idx])  # Join tokens back into a string
    image_id, passage_id ,image_path = image_info_list[idx]  # Get the image id and path for the passage
    print(f"Passage: {passage}")
    print(f'Passage Score: {scores[idx]}')
    print(f"Source Image ID: {image_id}")
    print(f"Image Path: {image_path}\n")

Top 3 Passages (based on BM25 scores):
Passage: A PROMINENT speaker at a free traders meeting at Madrid on Monday stated that there are in Spain more than a million anda half male adults occupied in di ﬀerent industries and liberal profesffons and theso with the exception of some 59 of Catalans are eager for free trade while of these Catalans 4 of are from Barcelona
Passage Score: 41.08977378126222
Source Image ID: 3200801632
Image Path: dataset\img\3200801632.jpg

Passage: a comedian and vocalist and was also an agent and the action was brought agains him to recover the sum of lls Si beivs commis on due npon engagements precured for him by the plainti ﬀ The first item was comwmissicu at ﬁve per cent up 2 Myr Macdermoit s engsgenw nt at tho Londen Pavilion at per week on whica engagement there were twenty three weol commis on due up to July 20 h Then ihere was c mmis on at Sve per cent upow Mr Macdermott s engagements at the Cambridge
Passage Score: 19.64269643280513
Source Image ID: 3