In [13]:
import io
import json
import re
import os
import sys

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize

import nltk

from collections import namedtuple, defaultdict
from statistics import mean 
from typing import Dict, List, Set, Tuple

# Project-local imports: add the repository root (two levels up) to sys.path so the `models` and `utils` packages resolve.
sys.path.append("../../")
from models.bounding_box import FeatureType, Point, BoundingBox, DSU
from models.dish_segmenter import Dish
from models.word_unit import WordUnit
from utils.cv_preprocess import *
from utils.file_utils import *
from utils.nlp_preprocess import *

# IPython specific import for inline display
from IPython.display import display


In [14]:

def process_menu(filein, raw_ocr_directory='../../dataset/menu_photo_ocr_raw/',
                 overlap_thresholds=(0.3, 0.4, 0.5)):
    """Group the OCR'd text of one menu photo into dish entries.

    Loads the raw OCR annotation that belongs to ``filein``, uses the
    ``chinese_bbox`` boxes returned by ``filter_and_classify_bounds`` as
    locators, and tries every (overlap threshold, extend direction)
    combination. The combination whose groups have the highest *summed*
    Chinese/English semantic correlation is kept.

    Args:
        filein: Path to the menu image. Its base name without extension selects
            ``<name>_raw_annotation.json`` inside ``raw_ocr_directory``.
        raw_ocr_directory: Directory that holds the raw OCR annotation files.
        overlap_thresholds: Values passed as ``overlap_threshold`` to
            ``group_extended_boxes``, tried in order.

    Returns:
        ``(grouped_list, grouped_box)`` from the best-scoring combination, or
        ``(None, None)`` when no combination scores above zero.
    """
    stem = os.path.splitext(os.path.basename(filein))[0]
    raw_ocr_path = os.path.join(raw_ocr_directory, stem + "_raw_annotation.json")

    document = load_json(raw_ocr_path)
    # Only the image dimensions are needed. The context manager releases the
    # file handle that Image.open otherwise keeps open for lazy loading.
    with Image.open(filein) as image:
        container_width, container_height = image.size

    bounds = process_bounds_in_words(document)
    # Price boxes are an alternative locator; they are computed but not used
    # below. To segment by price instead, use
    #   locator_bounds = {"price_bounds": price_bounds}
    price_bounds = extract_price_bounds(bounds)
    filtered_bounds, chinese_bbox, english_bbox = filter_and_classify_bounds(bounds)

    locator_bounds = {"chinese_bbox": chinese_bbox}

    # Candidate extend directions to try for each kind of locator.
    extend_directions_by_locator = {
        "price_bounds": [
            [ExtendDirection.BOTTOM, ExtendDirection.LEFT],
            [ExtendDirection.BOTTOM, ExtendDirection.RIGHT],
        ],
        "chinese_bbox": [
            [ExtendDirection.BOTTOM, ExtendDirection.LEFT],
            [ExtendDirection.BOTTOM, ExtendDirection.RIGHT],
            [ExtendDirection.TOP, ExtendDirection.LEFT],
            [ExtendDirection.TOP, ExtendDirection.RIGHT],
        ],
    }

    # Keep the grouping whose Chinese/English text correlates best overall.
    best_score = 0
    best_grouped_list, best_grouped_box = None, None
    for overlap_threshold in overlap_thresholds:
        for locator_name, locator_bound in locator_bounds.items():
            for extend_directions in extend_directions_by_locator[locator_name]:
                extended_boxes = extend_bounding_boxes(
                    locator_bound, container_width, container_height,
                    extend_directions=extend_directions,
                )
                # NOTE(review): this sort is loop-invariant unless the helpers
                # mutate the boxes in place; hoist it once that is confirmed.
                sorted_bounding_boxes = sorted(filtered_bounds, key=lambda bbox: (bbox.y_min, bbox.x_min))
                grouped_list, grouped_box = group_extended_boxes(
                    extended_boxes, sorted_bounding_boxes, overlap_threshold=overlap_threshold
                )
                if not grouped_list:
                    continue

                # Summed rather than averaged, so a configuration that yields
                # more well-matched groups scores higher.
                score = sum(_group_correlation(string_list) for string_list in grouped_list)
                if score > best_score:
                    best_score = score
                    best_grouped_list = grouped_list
                    best_grouped_box = grouped_box

    return best_grouped_list, best_grouped_box


def _group_correlation(string_list):
    """Split one dish group's text into Chinese and English and return their semantic correlation."""
    # Empty entries are dropped; every remaining entry is space-joined
    # (presumably a sequence of words -- TODO confirm against group_extended_boxes).
    lines = [" ".join(string) for string in string_list if string != ""]
    chinese_text, english_text = split_chinese_english(lines)
    return calculate_semantic_correlation("".join(chinese_text), " ".join(english_text))




In [15]:
class MenuProcessor:
    """Segment every menu photo in a directory, with resumable progress.

    Handled file names are recorded in a JSON progress file. A later run loads
    that file and skips those names.
    """

    DEFAULT_RAW_OCR_DIR = '../../dataset/menu_photo_ocr_raw/'
    DEFAULT_PREPROCESSED_OCR_DIR = '../../dataset/unverified_menu_text/'

    def __init__(self, dir_path, progress_file_path='progress.json',
                 raw_ocr_directory=DEFAULT_RAW_OCR_DIR,
                 preprocessed_ocr_directory=DEFAULT_PREPROCESSED_OCR_DIR):
        """
        Args:
            dir_path: Directory containing the menu photos to process.
            progress_file_path: JSON file listing already-processed file names.
            raw_ocr_directory: Where ``<stem>_raw_annotation.json`` files live.
            preprocessed_ocr_directory: Where ``<stem>_prep_ocr.json`` outputs go.
        """
        self.dir_path = dir_path
        self.processed_files = []
        self.saving_progress = True  # False: process files without persisting progress
        self.progress_file_path = progress_file_path
        self.raw_ocr_directory = raw_ocr_directory
        self.preprocessed_ocr_directory = preprocessed_ocr_directory
        print("MenuProcessor Initialized.")
        self.setup_filepath()

    def setup_filepath(self):
        """Load the already-processed file names from the progress file.

        A missing file starts a fresh run. An unparsable (e.g. truncated) file
        is reported and also treated as a fresh run instead of aborting.
        """
        print("Setting up file paths...")

        try:
            with open(self.progress_file_path, 'r') as f:
                self.processed_files = json.load(f)
            print("Progress file loaded.")
        except FileNotFoundError:
            print("Progress file not found, starting fresh.")
            self.processed_files = []
        except json.JSONDecodeError:
            print(f"Warning: progress file {self.progress_file_path} is unreadable, starting fresh.")
            self.processed_files = []

    def process_files(self):
        """Segment every regular file in ``dir_path`` that is not yet recorded as processed."""
        all_files = [f for f in os.listdir(self.dir_path) if os.path.isfile(os.path.join(self.dir_path, f))]
        sorted_files = sort_filenames(all_files)

        for file_name in tqdm(sorted_files, desc='Processing files'):
            if file_name not in self.processed_files:
                self.prepare_file_paths(file_name)
                self.process_menu_segmentation()
                # Persist after every file so an interrupted run can resume.
                self.save_progress()

        print("All files have been processed.")

    def prepare_file_paths(self, file_name):
        """Derive the image, raw-OCR and output paths for ``file_name``."""
        self.file_name = file_name
        self.image_path = os.path.join(self.dir_path, self.file_name)
        stem = os.path.splitext(self.file_name)[0]

        self.raw_ocr_path = os.path.join(self.raw_ocr_directory, stem + "_raw_annotation.json")
        self.preprocessed_ocr_path = os.path.join(self.preprocessed_ocr_directory, stem + "_prep_ocr.json")

    def process_menu_segmentation(self):
        """Segment the current file, or mark it processed if it has no raw OCR annotation."""
        if not os.path.exists(self.raw_ocr_path):
            print(f"Warning: OCR file {self.raw_ocr_path} does not exist. Skipping this file.")
            self.processed_files.append(self.file_name)  # Add the file to processed_files to skip it in future
            return

        max_grouped_list, max_grouped_box = process_menu(self.image_path)
        self.save_segmented_menu(max_grouped_list)

    def save_segmented_menu(self, grouped_list):
        """Turn each text group into a dish, write the dishes as JSON, and mark the file processed.

        NOTE(review): when ``grouped_list`` is None the file is NOT marked
        processed, so it is attempted again on every run.
        """
        if grouped_list is None:
            return

        results = [segment_dish_text_list(string_list).to_dict() for string_list in grouped_list]
        save_json(results, self.preprocessed_ocr_path, verbose=False)
        # Record the file only after its output was written successfully.
        self.processed_files.append(self.file_name)

    def save_progress(self):
        """Atomically write ``processed_files`` to the progress file (if enabled).

        Writing to a temp file and then calling ``os.replace`` means an
        interruption can never leave a truncated progress file behind.
        """
        if not self.saving_progress:
            return
        tmp_path = self.progress_file_path + '.tmp'
        with open(tmp_path, 'w') as f:
            json.dump(self.processed_files, f)
        os.replace(tmp_path, self.progress_file_path)

In [16]:

# Segment every photo in this folder. The progress file lets an interrupted
# run resume where it stopped.
dir_path = '../../dataset/menu_photo/segment_by_price'
processor = MenuProcessor(dir_path)  # the constructor already loads progress.json
processor.saving_progress = True  # set False to process without recording progress
processor.process_files()

MenuProcessor Initialized.
Setting up file paths...
Progress file loaded.
Setting up file paths...
Progress file loaded.


Processing files: 100%|██████████| 114/114 [05:43<00:00,  3.01s/it]

All files have been processed.



