In [None]:
# Alias tables shared by every entity that measures the same physical quantity.
# Each inner list is one unit: index 0 is the canonical name that lookups
# return, the remaining entries are accepted spellings/abbreviations.
# Keeping a single table per quantity stops the per-entity copies from drifting
# apart (previously only 'width' knew the '"' inch alias, and several
# British/American spelling variants were missing).
_LENGTH_UNIT_ALIASES = [
    ['centimetre', 'centimetres', 'centimeter', 'centimeters', 'cm', 'cms'],
    ['foot', 'feet', 'ft'],
    ['inch', 'inches', 'in', '"'],
    ['metre', 'metres', 'meter', 'meters', 'm'],
    ['millimetre', 'millimetres', 'millimeter', 'millimeters', 'mm'],
    ['yard', 'yards', 'yd', 'yds'],
]

_WEIGHT_UNIT_ALIASES = [
    ['gram', 'grams', 'g'],
    ['kilogram', 'kilograms', 'kg', 'kgs'],
    ['microgram', 'micrograms', 'μg', 'mcg'],
    ['milligram', 'milligrams', 'mg'],
    ['ounce', 'ounces', 'oz'],
    ['pound', 'pounds', 'lb', 'lbs'],
    ['ton', 'tons', 't'],
]

# Maps an entity name to the list of unit alias groups that are valid for it.
# Every entity receives its own copy of the shared tables so that mutating one
# entry can never leak into another entity.
entity_unit_map = {
    'width': [list(aliases) for aliases in _LENGTH_UNIT_ALIASES],
    'depth': [list(aliases) for aliases in _LENGTH_UNIT_ALIASES],
    'height': [list(aliases) for aliases in _LENGTH_UNIT_ALIASES],
    'item_weight': [list(aliases) for aliases in _WEIGHT_UNIT_ALIASES],
    'maximum_weight_recommendation': [list(aliases) for aliases in _WEIGHT_UNIT_ALIASES],
    'voltage': [
        ['kilovolt', 'kilovolts', 'kv'],
        ['millivolt', 'millivolts', 'mv'],
        ['volt', 'volts', 'v'],
    ],
    'wattage': [
        ['kilowatt', 'kilowatts', 'kw'],
        ['watt', 'watts', 'w'],
    ],
    'item_volume': [
        ['centilitre', 'centilitres', 'centiliter', 'centiliters', 'cl'],
        ['cubic foot', 'cubic feet', 'cu ft', 'ft³'],
        ['cubic inch', 'cubic inches', 'cu in', 'in³'],
        ['cup', 'cups'],
        ['decilitre', 'decilitres', 'deciliter', 'deciliters', 'dl'],
        ['fluid ounce', 'fluid ounces', 'fl oz'],
        ['gallon', 'gallons', 'gal'],
        ['imperial gallon', 'imperial gallons', 'imp gal'],
        ['litre', 'liter', 'liters', 'litres', 'l'],
        ['microlitre', 'microliter', 'microliters', 'microlitres', 'μl', 'mcl'],
        ['millilitre', 'milliliter', 'milliliters', 'millilitres', 'ml'],
        ['pint', 'pints', 'pt'],
        ['quart', 'quarts', 'qt'],
    ],
}


In [None]:
import os
import pandas as pd
import re
from urllib.parse import urlparse
from paddleocr import PaddleOCR

# Initialize the PaddleOCR model with the angle classifier enabled.
# This happens once, when the module/cell is executed, and the resulting
# engine object is reused by every run_OCR() call.
ocr = PaddleOCR(use_angle_cls=True)

# Define paths relative to the current working directory (not to the location
# of this file), so the code must be run from the directory that contains the
# 'test_images' and 'dataset' folders.
BASE_PATH = os.getcwd()
IMAGES_FOLDER = os.path.join(BASE_PATH, 'test_images')
DATASET_FOLDER = os.path.join(BASE_PATH, 'dataset')

def run_OCR(image_name):
    """
    Run the shared PaddleOCR engine on one image from the images folder.

    Parameters:
    - image_name: File name of the image, resolved relative to IMAGES_FOLDER.

    Returns:
    - Whatever `ocr.ocr` returns for that image path, passed through unchanged.
    """
    return ocr.ocr(os.path.join(IMAGES_FOLDER, image_name))

def name_parser(image_link):
    """
    Extract the file name (last path segment) from an image URL.

    Query strings and fragments are ignored because only the URL's path
    component is inspected.

    Parameters:
    - image_link: The URL of the image.

    Returns:
    - The text after the final '/' of the URL path; an empty string when the
      path ends with '/' or is empty.
    """
    # rpartition yields (head, sep, tail); the tail is the last path segment,
    # or the whole path when it contains no '/'.
    _, _, image_name = urlparse(image_link).path.rpartition('/')
    return image_name

# A run of digits with an optional '.' and/or ',' separator and optional
# trailing digits, optionally followed directly by a run of ASCII letters
# (the unit). Compiled once because the function is called for every OCR word.
_NUMBER_WITH_OPTIONAL_UNIT = re.compile(r'(\d+\.?\,?\d*)([a-zA-Z]+)?')


def split_alphanumeric_with_special_chars(string):
    """
    Split a string into numbers that carry a unit suffix and standalone numbers.

    A comma is treated as a decimal separator ("3,5" -> 3.5). OCR noise where
    both separators appear back to back ("1.,5") is collapsed to a single
    decimal point instead of crashing the float conversion.

    Parameters:
    - string: The string containing numbers and units.

    Returns:
    - numbers_with_units: List of (float, unit) tuples for numbers immediately
      followed by letters; the unit is returned as written, not normalised.
    - standalone_numbers: List of floats for numbers with no letters attached.
    """
    numbers_with_units = []
    standalone_numbers = []

    for number, unit in _NUMBER_WITH_OPTIONAL_UNIT.findall(string):
        # The pattern can capture ".," as a separator; after turning ',' into
        # '.' that would read "..", which float() rejects, so collapse it.
        value = float(number.replace(",", ".").replace("..", "."))
        if unit:
            numbers_with_units.append((value, unit))
        else:
            standalone_numbers.append(value)

    return numbers_with_units, standalone_numbers

def map_to_standard_unit(unit, entity_name):
    """
    Resolve a unit spelling to the canonical unit name for an entity.

    The comparison is case-insensitive. The canonical name is the first entry
    of the matching alias group in entity_unit_map.

    Parameters:
    - unit: The unit text to look up.
    - entity_name: Key into entity_unit_map selecting the valid alias groups;
      an unknown key raises KeyError.

    Returns:
    - The canonical unit of the first alias group containing `unit`, or None
      when no group matches.
    """
    alias_groups = entity_unit_map[entity_name]
    return next(
        (
            group[0]
            for group in alias_groups
            if unit.lower() in {alias.lower() for alias in group}
        ),
        None,
    )

# Running counter of predictor() calls, used only for progress logging.
count = 1


def _box_center(coordinates):
    """Return the (x, y) centroid of an OCR bounding box given as corner points."""
    return (
        sum(point[0] for point in coordinates) / len(coordinates),
        sum(point[1] for point in coordinates) / len(coordinates),
    )


def _closest_number_unit_pair(numbers, units):
    """Return the (number, unit) whose boxes are nearest; inputs hold (value, box) pairs."""
    best_pair = None
    best_distance = float('inf')
    for number, number_box in numbers:
        number_x, number_y = _box_center(number_box)
        for unit, unit_box in units:
            unit_x, unit_y = _box_center(unit_box)
            # Euclidean distance between the two box centres.
            distance = ((number_x - unit_x) ** 2 + (number_y - unit_y) ** 2) ** 0.5
            if distance < best_distance:
                best_distance = distance
                best_pair = (number, unit)
    return best_pair


def predictor(image_link, category_id, entity_name):
    """
    Generate a prediction for an image by extracting numbers and units using OCR.

    Resolution order:
    1. The first number written directly together with a unit that is valid for
       `entity_name` (e.g. "5cm").
    2. Otherwise, the standalone number whose bounding box is closest to the
       bounding box of a standalone valid unit (e.g. "5" next to "cm").
    3. Otherwise an empty string.

    Parameters:
    - image_link: The URL of the image to be processed; only its file name is
      used, to locate the image via run_OCR.
    - category_id: The category ID of the image (not used in the current implementation).
    - entity_name: The name of the entity for unit mapping (key of entity_unit_map).

    Returns:
    - The predicted value in the format "{number} {unit}", where number is a
      float and unit is the canonical unit name, or "" when no number/unit pair
      can be formed.
    """
    global count
    print()
    # `index` is set by the driver loop; look it up defensively so the function
    # can also be called on its own without raising NameError.
    print(f"Count: {count}, Index: {globals().get('index')}")
    print(image_link)
    count += 1

    OCR_result = run_OCR(name_parser(image_link))
    if not OCR_result:
        # Nothing was recognised at all.
        return ""

    numbers_with_unit = []   # (value, canonical unit) written as one token
    standalone_numbers = []  # (value, bounding box)
    standalone_units = []    # (canonical unit, bounding box)

    for line in OCR_result:
        if not line:
            return ""
        for word_info in line:
            # word_info[0] is used as the bounding box (corner points) and
            # word_info[1][0] as the recognised text.
            box = word_info[0]
            text = word_info[1][0]

            attached, bare = split_alphanumeric_with_special_chars(text)

            # Numbers with an attached unit count only when the unit is valid
            # for this entity; otherwise the token is ignored entirely.
            for number, unit in attached:
                standard_unit = map_to_standard_unit(unit, entity_name)
                if standard_unit:
                    numbers_with_unit.append((number, standard_unit))

            # Numbers without a unit are kept so they can be paired with a
            # nearby standalone unit later.
            for number in bare:
                standalone_numbers.append((number, box))

            # Standalone words that are valid units for this entity.
            for word in re.findall(r'[a-zA-Z]+', text):
                standard_unit = map_to_standard_unit(word, entity_name)
                if standard_unit:
                    standalone_units.append((standard_unit, box))

    # A number written together with a valid unit is the strongest evidence.
    if numbers_with_unit:
        number, unit = numbers_with_unit[0]
        return f"{number} {unit}"

    # Otherwise pair the closest standalone number and standalone unit.
    if standalone_numbers and standalone_units:
        closest = _closest_number_unit_pair(standalone_numbers, standalone_units)
        if closest is not None:
            closest_number, closest_unit = closest
            return f"{closest_number} {closest_unit}"

    # A number without any usable unit (or a unit without a number) cannot be
    # turned into a "{number} {unit}" prediction.
    return ""

if __name__ == "__main__":
    # Paths to input and output files
    input_file = os.path.join(DATASET_FOLDER, 'test.csv')
    output_file = os.path.join(DATASET_FOLDER, 'test_out.csv')

    # Read the test data
    test = pd.read_csv(input_file)

    # Resume support: if the output file already holds predictions, continue
    # after the last value of its "index" column; otherwise start at 0.
    start_index = 0
    if os.path.exists(output_file):
        try:
            last_line = pd.read_csv(output_file).tail(1)
            start_index = last_line['index'].values[0] + 1
        except Exception as e:
            # Best effort: an unreadable/empty output file means start over.
            print(f"Error reading the output file: {e}")
            start_index = 0

    # Allow custom starting points, overriding the automatic start index if needed
    custom_start_index = 0  # Process from the beginning
    start_index = max(start_index, custom_start_index)

    # Buffer of predictions not yet written to disk
    results_list = []

    def _flush_results():
        """Append the buffered predictions to the output CSV and empty the buffer."""
        if not results_list:
            return
        # Write the header only when the file is new or still empty, so a
        # zero-byte leftover file does not end up without column names.
        write_header = (
            not os.path.exists(output_file)
            or os.path.getsize(output_file) == 0
        )
        pd.DataFrame(results_list).to_csv(
            output_file, mode='a', header=write_header, index=False
        )
        results_list.clear()

    try:
        # Iterate through the rows whose "index" value is at or after start_index
        for _, row in test[test['index'] >= start_index].iterrows():
            # `index` must stay a module-level name: predictor() reads it for
            # its progress output.
            index = row['index']

            prediction = predictor(row['image_link'], row['group_id'], row['entity_name'])
            results_list.append({'index': index, 'prediction': prediction})

            # Checkpoint to disk whenever the row index reaches a multiple of 500
            if (index + 1) % 500 == 0:
                _flush_results()
    finally:
        # Write whatever is still buffered, both after a normal finish and when
        # a row fails, so completed predictions are kept and the next run can
        # resume from them.
        _flush_results()
    


In [None]:
# Print a completion marker once execution reaches this cell.
print("Done")