# Algorithm Alchemists

### Basic library imports

In [None]:
import pandas as pd
import re
import os
from tqdm import tqdm
import constants
import spacy
from typing import Dict, Set, Tuple, Optional, List, Any

### Read Dataset

In [None]:
# Root folder of the competition CSV files; every dataset below is read from it.
DATASET_FOLDER = '../dataset/'

# Load the four CSVs in the same order as before and bind each to its own name.
train, test, sample_test, sample_test_out = (
    pd.read_csv(os.path.join(DATASET_FOLDER, csv_name))
    for csv_name in ('train.csv', 'test.csv', 'sample_test.csv', 'sample_test_out.csv')
)

In [None]:
# Preview the first rows of the sample entity names next to the sample predictions.
sample_test['entity_name'].head(), sample_test_out['prediction'].head()

### Download images

In [None]:
# Pass the sample_test image links and the target folder '../images' to the
# project helper `utils.download_images` (presumably it downloads each link there).
# NOTE(review): the following cells check and OCR '../images_test', not '../images',
# and the later pipeline joins against test.csv — confirm which folder/dataset this
# download is meant to populate.
from utils import download_images
download_images(sample_test['image_link'], '../images')

In [None]:
# Sanity check: '../images_test' must exist and contain at least one entry
# (os.listdir raises if the folder is missing; the assert fails if it is empty).
assert len(os.listdir('../images_test'))>0

In [None]:
# Shell command executed through IPython (not Python syntax): recursively deletes
# './images' relative to the notebook's working directory.
# NOTE(review): the download cell writes to '../images', not './images' — confirm the path.
rm -rf ./images

#### Extracting Text from images using Paddle OCR

In [None]:
# importing necessary libraries
import cv2
import csv
from paddleocr import PaddleOCR

In [None]:
# Initialize the OCR model once; it is reused for every image in the loop below.
# lang='en' selects the English models; use_angle_cls=True presumably loads the
# text-angle classifier that `ocr.ocr(..., cls=True)` relies on — see PaddleOCR docs.
ocr = PaddleOCR(use_angle_cls=True, lang='en') 

In [None]:
# Define the directory containing images and the output CSV file path
# image_directory: folder scanned for image files by the OCR loop.
# output_csv: file the OCR loop appends its (image_link, extracted_text) rows to.
image_directory = '../images_test'
output_csv = '../dataset/ocr_results_test.csv'

In [None]:
# Run OCR over every image in `image_directory` and append one row per image to
# `output_csv`. The file is opened in append mode so an interrupted run can be
# continued; the header is therefore written only when the file is new or empty,
# otherwise every re-run would inject a duplicate header row into the data.
write_header = not os.path.exists(output_csv) or os.path.getsize(output_csv) == 0
image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')

with open(output_csv, mode='a', newline='', encoding='utf-8') as file:
    writer = csv.writer(file)
    if write_header:
        writer.writerow(["image_link", "extracted_text"])  # Header row

    # Loop through each image in the directory
    for image_name in os.listdir(image_directory):
        if not image_name.endswith(image_extensions):  # Filter for image files
            continue
        image_path = os.path.join(image_directory, image_name)

        # Read the image; cv2.imread returns None when the file cannot be decoded
        img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            print(f"{image_path}: could not be read as an image")
            continue

        # Perform OCR on the image and flatten the recognised words
        try:
            res = ocr.ocr(img, cls=True)
            if not res or res[0] is None:
                # No text detected: keep the row so the image is still accounted for
                extracted_text = None
            else:
                # Each word_info ends with a (text, confidence) pair; keep the text
                words = [
                    word_info[-1][0]
                    for line in res if line
                    for word_info in line
                ]
                extracted_text = " ".join(words).strip()
        except Exception as exc:
            # Best effort: report the failing image with the reason and move on
            print(f"{image_path}: {exc}")
            continue

        # Write the image path and extracted text to the CSV file
        writer.writerow([image_path, extracted_text])
print("OCR completed and results saved to", output_csv)

In [None]:
# Load the test set (df) and the OCR results written by the OCR loop (ou);
# the two frames are joined on the image file name in the next cell.

df = pd.read_csv("../dataset/test.csv")
ou = pd.read_csv("../dataset/ocr_results_test.csv")

In [None]:
# Attach the OCR text to every test row, using the image file name as join key.

# Test rows carry a '/'-separated link: keep only its last component.
df['image_link'] = df['image_link'].apply(lambda link: link.split('/')[-1])
df = df.drop(columns=['group_id'])

# OCR rows carry a local path that may use '/' or '\\' separators: strip both.
ou['image_link'] = ou['image_link'].apply(
    lambda path: path.split('/')[-1].split('\\')[-1]
)

# Right join keeps every test row, including images without an OCR result;
# the join key is dropped once the frames are combined.
ou = ou.merge(df, on=['image_link'], how='right')
ou = ou.drop(columns=['image_link'])

In [None]:
# One spaCy Named Entity Recognition (NER) model is loaded per entity name,
# giving 8 models in total (one `spacy.load` call per model directory below).
# For the training script, refer to multiprocess_df.py.
# The variable names below are looked up by `predict_entity` when dispatching.

width = spacy.load("../models/output__width/model-best/")

depth = spacy.load("../models/output__depth/model-best/")

height = spacy.load("../models/output__height/model-best/")

max_weight = spacy.load("../models/output__maximum_weight_recommendation/model-best/")

wattage = spacy.load("../models/output__wattage/model-best/")

voltage = spacy.load("../models/output__voltage/model-best/")

item_volume = spacy.load("../models/output__item_volume/model-best/")

item_weight = spacy.load("../models/output__item_weight/model-best/")

In [None]:
# The backup method for entity value recognition is Rule-Based Recognition (RBR)
# The EntityValueExtractor class contains the methods for RBR

class EntityValueExtractor:
    """Rule-based extractor that turns free text into a "<number> <unit>" string.

    The extractor looks for number/unit pairs in the text, keeps those whose
    unit is allowed for the requested entity and returns the largest one.
    When no pair qualifies it falls back to a unit mentioned anywhere in the
    text, and finally to the entity's default unit.
    """

    def __init__(self, entity_unit_map: Dict[str, Set[str]], default_units: Dict[str, str], unit_mapping: Dict[str, str]):
        """Store the configuration and precompute the lookup helpers.

        Args:
            entity_unit_map: entity name -> set of canonical units allowed for it.
            default_units: entity name -> canonical unit used as last resort.
            unit_mapping: abbreviation/variant (lower case) -> canonical unit.
        """
        self.entity_unit_map = entity_unit_map
        self.default_units = default_units
        self.unit_mapping = unit_mapping
        self.allowed_units = {unit for units in entity_unit_map.values() for unit in units}
        # Kept unchanged for backward compatibility. It maps canonical unit ->
        # the LAST abbreviation seen, so it cannot list all abbreviations of a unit.
        self.reverse_unit_mapping = {v: k for k, v in unit_mapping.items()}

        # canonical unit -> regex matching any of its abbreviations as a
        # standalone token (not embedded in a longer word). Longer
        # alternatives come first so e.g. "inches" is tried before "in".
        abbreviations: Dict[str, List[str]] = {}
        for abbr, canonical in unit_mapping.items():
            abbreviations.setdefault(canonical, []).append(abbr.lower())
        self._abbreviation_patterns = {
            canonical: re.compile(
                r'(?<![a-z])(?:'
                + '|'.join(re.escape(a) for a in sorted(abbrs, key=len, reverse=True))
                + r')(?![a-z])'
            )
            for canonical, abbrs in abbreviations.items()
        }

    def preprocess_text(self, text: str) -> str:
        """Normalise raw text so numbers and units become separate tokens.

        Returns an empty string for non-string input. The result is lower case.
        """
        if not isinstance(text, str):
            return ""
        text = text.replace('"', ' inch ')
        # text = text.replace('\'', ' foot ')
        # "1,000 " style grouping: drop a comma that is followed only by zeros
        text = re.sub(r'(\d+),(?=0+[^\d])', r'\1', text)
        # Any remaining digit,digit comma is treated as a decimal separator
        text = re.sub(r'(\d+),(\d+)', r'\1.\2', text)
        text = text.lower()
        # Common OCR confusion: "0z" read instead of "oz"
        text = re.sub(r'(\d+)\s*0z', r'\1 oz', text)
        # Handle prefix
        text = re.sub(r'wt\.(\d+)([a-zA-Z]+)', r'\1 \2', text)
        text = re.sub(r'max\.(\d+)([a-zA-Z]+)', r'\1 \2', text)
        text = re.sub(r'qty\.(\d+)([a-zA-Z]+)', r'\1 \2', text)
        text = re.sub(r'qt\.(\d+)([a-zA-Z]+)', r'\1 \2', text)
        # Handle hyphenated number-unit pairs
        text = re.sub(r'(\d+)-([a-zA-Z]+)', r'\1 \2', text)
        text = re.sub(r'([a-zA-Z]+)-(\d+)', r'\1 \2', text)
        # Separate numbers and units when directly attached
        text = re.sub(r'(\d+)([a-zA-Z]+)', r'\1 \2', text)
        return text

    def extract_values_from_text(self, text: str) -> List[Tuple[float, str]]:
        """Return every (value, normalised unit) pair found in the text.

        A pair is a number followed by an alphabetic token; the token is not
        required to be a known unit (it is returned as normalised by
        `normalize_unit`).
        """
        text = self.preprocess_text(text)
        # Updated pattern to handle 'WT.' cases, hyphenated pairs, and general number-unit pairs
        pattern = r'(?:wt\.)?([-+]?\d*\.?\d+)\s*-?\s*([a-zA-Z]+(?:-[a-zA-Z]+)?)'
        matches = re.findall(pattern, text)
        return [(float(value), self.normalize_unit(unit)) for value, unit in matches]

    def normalize_unit(self, unit: str) -> str:
        """Map a unit token to its canonical name.

        The exact (lower-cased) token is looked up first so that mapped
        plurals such as "inches" resolve correctly; only when it is unknown
        are trailing "s" characters stripped and the lookup retried. Tokens
        that are still unknown are returned in their stripped form.
        """
        unit = unit.lower()
        if unit in self.unit_mapping:
            return self.unit_mapping[unit]
        singular = unit.rstrip('s')
        return self.unit_mapping.get(singular, singular)

    def find_unit_in_text(self, text: str, entity: str) -> Optional[str]:
        """Return a unit allowed for `entity` that is mentioned in the text.

        Canonical names are searched first (as substrings, longest name first
        so "kilogram" wins over its substring "gram"); if none is present the
        abbreviations from `unit_mapping` are searched as standalone tokens.
        Candidates are visited in a sorted order to keep the result
        deterministic. Returns None when nothing matches.

        NOTE: short abbreviations that are also ordinary words (e.g. "in")
        can match unrelated text; this is only used as a fallback.
        """
        text = self.preprocess_text(text)
        candidates = sorted(self.entity_unit_map.get(entity, set()), key=lambda u: (-len(u), u))
        for unit in candidates:
            if unit.lower() in text:
                return unit
        for unit in candidates:
            pattern = self._abbreviation_patterns.get(unit)
            if pattern is not None and pattern.search(text):
                return unit
        return None

    def map_value_to_entity(self, value: float, unit: str, entity: str) -> Optional[Tuple[float, str]]:
        """Return (value, canonical unit) if the unit is allowed for `entity`, else None."""
        normalized_unit = self.normalize_unit(unit)
        if normalized_unit in self.entity_unit_map.get(entity, set()):
            return (value, normalized_unit)
        return None

    def format_float(self, value: float) -> str:
        """Format the magnitude of `value` without trailing zeros.

        The sign is discarded (the absolute value is formatted). Values below
        0.01 or at least 1e7 keep up to six decimals, all others up to two.
        Zero is rendered as "0.0".
        """
        if value == 0:
            return "0.0"
        abs_value = abs(value)
        if abs_value < 0.01 or abs_value >= 1e7:
            return f"{abs_value:.6f}".rstrip('0').rstrip('.')
        else:
            return f"{abs_value:.2f}".rstrip('0').rstrip('.')

    def extract(self, text: Any, entity: str) -> Optional[str]:
        """Return the best "<number> <unit>" guess for `entity`, or None.

        Resolution order:
          1. the largest value whose own unit is allowed for the entity;
          2. the largest value in the text paired with an allowed unit that
             is mentioned elsewhere in the text;
          3. the largest value paired with the entity's default unit.
        Non-string input is converted with `str()` first.
        """
        if not isinstance(text, str):
            text = str(text)
        extracted_values = self.extract_values_from_text(text)

        # 1. Number/unit pairs whose unit is valid for this entity
        matching_values = []
        for value, unit in extracted_values:
            result = self.map_value_to_entity(value, unit, entity)
            if result is not None:
                matching_values.append(result)

        if matching_values:
            largest_value = max(matching_values, key=lambda x: x[0])
            return f"{self.format_float(largest_value[0])} {largest_value[1]}"

        values = [value for value, _ in extracted_values]

        # 2. A valid unit appears somewhere, but not right after a number
        found_unit = self.find_unit_in_text(text, entity)
        if found_unit and values:
            return f"{self.format_float(max(values))} {found_unit}"

        # 3. Last resort: assume the entity's default unit
        default_unit = self.default_units.get(entity)
        if default_unit and values:
            return f"{self.format_float(max(values))} {default_unit}"

        return None

In [None]:
# Configuration for the rule-based extractor

# Canonical unit assumed for an entity when the text contains numbers but no
# usable unit (last fallback in EntityValueExtractor.extract).
default_units = {
    'height': 'centimetre',
    'width': 'centimetre',
    'depth': 'centimetre',
    'length': 'centimetre',
    'item_weight': 'gram',
    'maximum_weight_recommendation': 'gram',
    'voltage': 'volt',
    'wattage': 'watt',
    'item_volume': 'millilitre'
}

# Abbreviation / spelling variant (lower case) -> canonical unit name.
# Used by EntityValueExtractor.normalize_unit to canonicalise unit tokens.
unit_mapping = {
    'in': 'inch', '"': 'inch', 'inch': 'inch', 'inches': 'inch', 'foot': 'foot', 'ft': 'foot',
    'cm': 'centimetre', 'm': 'metre', 'mm': 'millimetre', 
    'yard': 'yard', 'yd': 'yard',
    'g': 'gram', 'kg': 'kilogram', 
    'mg': 'milligram', 'lb': 'pound', 'lbs': 'pound', 'oz': 'ounce', 
    'l': 'litre', 'ml': 'millilitre', 'cl': 'centilitre',
    'v': 'volt', 'kv': 'kilovolt', 'mv': 'millivolt',
    'w': 'watt', 'kw': 'kilowatt',
    'fl oz': 'fluid ounce', 'gal': 'gallon', 'qt': 'quart', 'pt': 'pint',
    'cu ft': 'cubic foot', 'cu in': 'cubic inch',
}

In [None]:
# Create the shared EntityValueExtractor instance used by predict_entity.
# The allowed units per entity come from the project-level `constants` module.
extractor = EntityValueExtractor(constants.entity_unit_map, default_units, unit_mapping)

In [None]:
# Normalisation of OCR text before it is fed to the NER models

def preprocess_text(text):
    """Rewrite "<number><unit variant>" occurrences as "<number> <canonical unit>".

    The substitutions are applied one after another, in the order listed, each
    with case-insensitive matching. The text is otherwise returned unchanged.
    """
    # (pattern, replacement) pairs; the order is significant because each
    # substitution operates on the output of the previous one.
    substitutions = (
        # Length and dimension units
        (r'(\d+(\.\d+)?)\s*\'', r' \1 foot '),  # Single quote for feet
        (r'(\d+(\.\d+)?)\s*\"', r' \1 inch '),  # Double quote for inches
        (r'(\d+(\.\d+)?)\s*(in|In|"|\'|inch|Inch|inchs|inches|Inches)\b', r' \1 inch '),
        (r'(\d+(\.\d+)?)\s*(ft|FT|feet|Feet|foot|Foot)\b', r'\1 foot '),
        (r'(\d+(\.\d+)?)\s*(cm|CM|centimeters|Centimeters|centimetre|Centimetre)\b', r' \1 centimetre '),
        (r'(\d+(\.\d+)?)\s*(m|M|metre|Metre|meters|Meters)\b', r' \1 metre '),
        (r'(\d+(\.\d+)?)\s*(mm|MM|millimeters|Millimeters|millimetre|Millimetre)\b', r' \1 millimetre '),
        (r'(\d+(\.\d+)?)\s*(yard|Yard|yards|Yards)\b', r'\1 yard '),
        # Weight units
        (r'(\d+(\.\d+)?)\s*(g|gr|G|grams|Grams|gram|Gram)\b', r' \1 gram '),
        (r'(\d+(\.\d+)?)\s*(kg|KG|kilograms|Kilograms|kilogram|Kilogram)\b', r' \1 kilogram '),
        (r'(\d+(\.\d+)?)\s*(mg|MG|milligrams|Milligrams|milligram|Milligram)\b', r' \1 milligram '),
        (r'(\d+(\.\d+)?)\s*(lb|1b|1bs|LB|lbs|LBS|pounds|Pounds|pound|Pound)\b', r' \1 pound '),
        (r'(\d+(\.\d+)?)\s*(oz|OZ|0z|ounces|Ounces|ounce|Ounce)\b', r' \1 ounce '),
        (r'(\d+(\.\d+)?)\s*(ton|Ton|tons|Tons)\b', r' \1 ton '),
        # Volume units
        (r'(\d+(\.\d+)?)\s*(l|L|liters|Liters|litres|Litres|litre|Litre)\b', r' \1 litre '),
        (r'(\d+(\.\d+)?)\s*(ml|ML|milliliters|Milliliters|millilitres|Millilitres|millilitre|Millilitre)\b', r' \1 millilitre '),
        (r'(\d+(\.\d+)?)\s*(cl|CL|centiliters|Centiliters|centilitre|Centilitre)\b', r' \1 centilitre '),
        (r'(\d+(\.\d+)?)\s*(dl|DL|deciliters|Deciliters|decilitre|Decilitre)\b', r' \1 decilitre '),
        (r'(\d+(\.\d+)?)\s*(microlitre|Microlitre|microliters|Microliters|µL|uL)\b', r' \1 microlitre '),
        (r'(\d+(\.\d+)?)\s*(pint|Pint|pints|Pints)\b', r' \1 pint '),
        (r'(\d+(\.\d+)?)\s*(quart|Quart|quarts|Quarts)\b', r' \1 quart '),
        (r'(\d+(\.\d+)?)\s*(cup|Cup|cups|Cups)\b', r' \1 cup '),
        (r'(\d+(\.\d+)?)\s*(gallon|Gallon|gallons|Gallons)\b', r' \1 gallon '),
        (r'(\d+(\.\d+)?)\s*(imperial gallon|Imperial Gallon|imperial gallons|Imperial Gallons)\b', r' \1 imperial gallon '),
        (r'(\d+(\.\d+)?)\s*(cubic inch|Cubic Inch|cubic inches|Cubic Inches)\b', r' \1 cubic inch '),
        (r'(\d+(\.\d+)?)\s*(cubic foot|Cubic Foot|cubic feet|Cubic Feet)\b', r' \1 cubic foot '),
        (r'(\d+(\.\d+)?)\s*(fl oz|floz|FL OZ|fluid ounce|Fluid Ounce|fluid ounces|Fluid Ounces)\b', r' \1 fluid ounce '),
        # Voltage units
        (r'(\d+(\.\d+)?)\s*(volt|Volt|volts|Volts|v|V)\b', r' \1 volt '),
        (r'(\d+(\.\d+)?)\s*(kilovolt|Kilovolt|kV|KV)\b', r' \1 kilovolt '),
        (r'(\d+(\.\d+)?)\s*(millivolt|Millivolt|mV|MV)\b', r' \1 millivolt '),
        # Power units
        (r'(\d+(\.\d+)?)\s*(watt|Watt|watts|Watts|w|W)\b', r' \1 watt '),
        (r'(\d+(\.\d+)?)\s*(kilowatt|Kilowatt|kW|KW)\b', r' \1 kilowatt '),
    )
    # Fix unit misspellings and detach units that are glued to their number
    for regex, canonical in substitutions:
        text = re.sub(regex, canonical, text, flags=re.IGNORECASE)
    return text

In [None]:
# Processing the model generated outputs and Rule-Based Recognition outputs
import utils
# "<number> <unit words>", e.g. "12.5 fluid ounce"; compiled once, reused per call
_VALUE_UNIT_RE = re.compile(r'^-?\d+(\.\d+)?\s+[a-zA-Z\s]+$')


def parse_string(s):
    """Validate a "<number> <unit>" prediction string.

    Returns `s` unchanged when, after stripping, it consists of a number
    followed by a unit that (after `utils.common_mistake` correction) is in
    `constants.allowed_units`. Returns "" for anything else, including None,
    NaN and any other non-string input.
    """
    if not isinstance(s, str):
        return ""
    s_stripped = s.strip()
    # The literal string 'nan' is treated as missing, like a real NaN
    if s_stripped == "" or s == 'nan':
        return ""
    if not _VALUE_UNIT_RE.match(s_stripped):
        return ""
    # The pattern guarantees a numeric first token, so only the unit needs checking
    unit = utils.common_mistake(s_stripped.split(maxsplit=1)[1])
    if unit not in constants.allowed_units:
        return ""
    return s

def predict_entity(text, entity_type, counter):
    """Predict the "<number> <unit>" value of `entity_type` from OCR text.

    Args:
        text: OCR text of the image; any non-string value (e.g. NaN for images
            without text) yields "".
        entity_type: entity name, e.g. "width" or "item_weight".
        counter: single-element list; counter[0] is incremented for every
            wattage row, which is handled by the rule-based extractor only.

    Returns:
        "" for non-string text; otherwise the first entity found by the NER
        model when it passes `parse_string`, else the result of the
        rule-based `extractor` (which may be None).
    """
    if not isinstance(text, str):
        return ""

    # Wattage never uses its NER model, so skip the model preprocessing entirely
    if entity_type == 'wattage':
        counter[0] += 1
        return extractor.extract(text, entity_type)

    # Dictionary to map entity types to their respective models
    model_map = {
        "width": width,
        "height": height,
        "maximum_weight_recommendation": max_weight,
        "wattage": wattage,
        "voltage": voltage,
        "item_volume": item_volume,
        "item_weight": item_weight,
        "depth": depth
    }

    # Entity types without a trained model go straight to the rule-based fallback
    model = model_map.get(entity_type)
    if model is not None:
        doc = model(preprocess_text(text))
        # Sanity check: accept the model output only if it is a valid "<number> <unit>"
        if doc.ents:
            candidate = str(doc.ents[0])
            if parse_string(candidate) != "":
                return candidate
    return extractor.extract(text, entity_type)

In [54]:
# Run the prediction over every row of the merged test frame

# Needed so that DataFrame.progress_apply (apply with a progress bar) is available
tqdm.pandas()

# One-element list shared with predict_entity, which bumps counter[0] in place
counter = [0]


def safe_predict_entity(row):
    """Predict one row; on any exception report the row and return "Error"."""
    try:
        prediction = predict_entity(row['extracted_text'], row['entity_name'],counter)
    except Exception as e:
        print(f"Error processing row: {row}")
        print(f"Error message: {str(e)}")
        prediction = "Error"
    return prediction


ou['prediction'] = ou.progress_apply(safe_predict_entity, axis=1)
print(counter[0])

100%|██████████| 131187/131187 [15:09<00:00, 144.24it/s]

5384





In [55]:
# Display the first rows of the merged frame, now including the 'prediction' column.
ou.head()

Unnamed: 0,extracted_text,index,entity_name,prediction
0,2.63in 6.68cm 91.44cm - 199.39cm 36in - 78in,0,height,2.63 inch
1,"Size Width Length One Size 42cm/16.54"" 200cm/7...",1,width,200 centimetre
2,"Size Width Length One Size 42cm/16.54"" 200cm/7...",2,height,200 centimetre
3,"Size Width Length One Size 42cm/16.54"" 200cm/7...",3,depth,200 centimetre
4,"Size Width Length One Size 10.50cm/4.13"" 90cm/...",4,depth,90 centimetre


In [56]:
# Remove the model input columns so only the remaining columns are exported
ou = ou.drop(columns=['extracted_text', 'entity_name'])

In [57]:
# Store the predictions in output.csv (working directory), without the DataFrame index
ou.to_csv('output.csv',index=False)

### Run Sanity check using src/sanity.py

In [58]:
# Shell command run from the notebook: pass test.csv and the generated output.csv
# to the provided sanity.py script.
# NOTE(review): the recorded cell output names 'test_output_5_6.csv', not 'output.csv' —
# it looks stale; re-run this cell to confirm the check passes for output.csv.
!python sanity.py --test_filename ../dataset/test.csv --output_filename "output.csv"

Parsing successfull for file: test_output_5_6.csv
