In [19]:
# Units accepted for each entity type. The set literal sits inside each
# comprehension, so every entity still gets its own independent set object.
entity_unit_map = {
    # Linear dimensions
    **{
        dimension: {'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'}
        for dimension in ('width', 'depth', 'height')
    },
    # Mass
    **{
        mass_entity: {'gram', 'kilogram', 'microgram', 'milligram', 'ounce', 'pound', 'ton'}
        for mass_entity in ('item_weight', 'maximum_weight_recommendation')
    },
    # Electrical
    'voltage': {'kilovolt', 'millivolt', 'volt'},
    'wattage': {'kilowatt', 'watt'},
    # Volume
    'item_volume': {
        'centilitre', 'cubic foot', 'cubic inch', 'cup', 'decilitre',
        'fluid ounce', 'gallon', 'imperial gallon', 'litre', 'microlitre',
        'millilitre', 'pint', 'quart',
    },
}

# Flat set of every unit that appears under any entity.
allowed_units = set().union(*entity_unit_map.values())

In [20]:
import re
#from constants import constants # Remove this line
import os
import pandas as pd
import multiprocessing
import time
from time import time as timer
import tqdm
import numpy as np
from pathlib import Path
from functools import partial
import urllib
from PIL import Image

# Notebook-local replacement for the commented-out `from constants import constants`:
# a bare class used purely as a namespace so `constants.allowed_units` resolves here.
class constants:
    allowed_units = allowed_units

def common_mistake(unit):
    """Map a commonly misspelled unit onto its allowed spelling.

    Returns `unit` unchanged when it is already in `constants.allowed_units`.
    Otherwise tries, in order, replacing 'ter' with 'tre' and then 'feet' with
    'foot', and returns the first variant that is allowed. When neither variant
    is allowed the original `unit` is returned as-is.
    """
    known_units = constants.allowed_units
    if unit in known_units:
        return unit
    for wrong, right in (('ter', 'tre'), ('feet', 'foot')):
        candidate = unit.replace(wrong, right)
        if candidate in known_units:
            return candidate
    return unit

def parse_string(s):
    """Parse a "<number> <unit>" string into a ``(number, unit)`` pair.

    Args:
        s: The value to parse. ``None``, anything whose ``str()`` is ``'nan'``,
            and blank strings are treated as "no value".

    Returns:
        ``(None, None)`` for a missing/blank value, otherwise ``(float, str)``
        where the unit has been passed through `common_mistake`.

    Raises:
        ValueError: If the value is not of the form "<number> <unit>" or the
            unit is not in `constants.allowed_units`.
    """
    if s is None or str(s) == 'nan':
        return None, None
    # A non-string value (e.g. a bare number) can never be "<number> <unit>";
    # stringify it so it is rejected below with ValueError like any other bad
    # input, instead of crashing on a missing `.strip()`.
    s_stripped = str(s).strip()
    if s_stripped == "":
        return None, None
    pattern = re.compile(r'^-?\d+(\.\d+)?\s+[a-zA-Z\s]+$')
    if not pattern.match(s_stripped):
        raise ValueError("Invalid format in {}".format(s))
    # Split only once: the number is the first token, the rest is the unit
    # (which may itself contain spaces, e.g. "cubic foot").
    parts = s_stripped.split(maxsplit=1)
    number = float(parts[0])
    unit = common_mistake(parts[1])
    if unit not in constants.allowed_units:
        raise ValueError("Invalid unit [{}] found in {}. Allowed units: {}".format(
            unit, s, constants.allowed_units))
    return number, unit


def create_placeholder_image(image_save_path):
    """Write a 100x100 solid black RGB image to `image_save_path`.

    Best-effort: any exception raised while building or saving the image is
    suppressed, and the function returns None either way.
    """
    try:
        Image.new('RGB', (100, 100), color='black').save(image_save_path)
    except Exception:
        return None

def download_image(image_link, save_folder, retries=3, delay=3):
    """Download `image_link` into `save_folder`, falling back to a placeholder.

    The file is saved under the last path component of the link. Nothing is
    done when `image_link` is not a string or the target file already exists.

    Args:
        image_link: URL of the image to fetch.
        save_folder: Directory the file is written into.
        retries: Maximum number of download attempts.
        delay: Seconds to wait between two attempts.

    Returns:
        None. If every attempt fails, `create_placeholder_image` is called for
        the target path instead.
    """
    # Import the submodule explicitly: a bare `import urllib` does not load
    # `urllib.request`, and the resulting AttributeError would be swallowed by
    # the retry loop, turning every link into a placeholder.
    from urllib.request import urlretrieve

    if not isinstance(image_link, str):
        return

    filename = Path(image_link).name
    image_save_path = os.path.join(save_folder, filename)

    if os.path.exists(image_save_path):
        return

    for attempt in range(retries):
        try:
            urlretrieve(image_link, image_save_path)
            return
        except Exception:
            # `Exception` (not a bare except) so Ctrl-C / SystemExit still stop
            # the run. Only wait when another attempt is actually coming.
            if attempt < retries - 1:
                time.sleep(delay)

    create_placeholder_image(image_save_path)  # Black placeholder for invalid links/images

def download_images(image_links, download_folder, allow_multiprocessing=True):
    """Download every link in `image_links` into `download_folder`.

    Args:
        image_links: Sized collection of image URLs (its `len()` is used for
            the progress bar total).
        download_folder: Target directory; created if it does not exist.
        allow_multiprocessing: When true, downloads run in a pool of 64 worker
            processes; otherwise they run one by one in this process.

    Each link is handled by `download_image` with ``retries=3, delay=3``.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(download_folder, exist_ok=True)

    # The file does `import tqdm`, which binds the *module*; the progress-bar
    # callable is its `tqdm` attribute. Calling the module raises TypeError.
    progress_bar = tqdm.tqdm

    if allow_multiprocessing:
        download_image_partial = partial(
            download_image, save_folder=download_folder, retries=3, delay=3)

        with multiprocessing.Pool(64) as pool:
            # list(...) drains the lazy imap iterator so every download runs.
            list(progress_bar(pool.imap(download_image_partial, image_links),
                              total=len(image_links)))
            pool.close()
            pool.join()
    else:
        for image_link in progress_bar(image_links, total=len(image_links)):
            download_image(image_link, save_folder=download_folder, retries=3, delay=3)

In [21]:
# Read the train split, then the test split, from the dataset folder.
train_df, test_df = map(pd.read_csv, ('data/dataset/train.csv', 'data/dataset/test.csv'))

In [22]:
import re

def extract_entity_value(text, entity_name, allowed_units):
    """Return the first "<number> <unit>" mention found in `text`.

    Args:
        text: Text to search; it is lower-cased before matching, so units are
            expected in lower case. Non-string values (None, NaN) yield "".
        entity_name: Unused by the matching itself; kept so callers can pass
            the row's entity alongside its units.
        allowed_units: Iterable of unit names to look for.

    Returns:
        "<number> <unit>" for the first match (an optional single whitespace
        between number and unit is accepted), or "" when nothing matches.
    """
    # Missing text or no units to look for: nothing can match. An empty
    # alternation would otherwise match any bare number.
    if not isinstance(text, str) or not allowed_units:
        return ""

    # Longest units first so a unit is never shadowed by a shorter one that is
    # its prefix; the alphabetical tie-break makes the pattern independent of
    # set iteration order. Units are escaped so they match literally.
    ordered_units = sorted(allowed_units, key=lambda unit: (-len(unit), unit))
    unit_alternatives = '|'.join(re.escape(unit) for unit in ordered_units)

    pattern = r'(\d+(\.\d+)?)\s?(' + unit_alternatives + ')'
    match = re.search(pattern, text.lower())
    if match:
        return f"{match.group(1)} {match.group(3)}"
    return ""

# Example usage on the train set.
# Each row is matched against the units allowed for its own entity. Looping over
# entities and assigning the whole column on every pass would leave only the
# last entity's units in effect, and reusing the name `allowed_units` as the
# loop variable would overwrite the notebook-wide set of all units.
if 'extracted_text' not in train_df.columns:
    raise KeyError(
        "train_df has no 'extracted_text' column; run the text-extraction step "
        "that creates it before this cell")

train_df['predicted_value'] = train_df.apply(
    lambda row: extract_entity_value(
        row['extracted_text'], row['entity_name'], entity_unit_map[row['entity_name']]),
    axis=1,
)

KeyError: 'extracted_text'

In [23]:
from sklearn.ensemble import RandomForestRegressor
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline

# Regress the numeric part of entity_value on the extracted text of each row.
# RandomForestRegressor needs a 2-D numeric feature matrix and a numeric target,
# so the raw strings are TF-IDF vectorised inside the pipeline and the leading
# number of "<number> <unit>" is parsed out of entity_value first.
# NOTE(review): assumes entity_value looks like "<number> <unit>"; rows whose
# first token is not a number are left out of training. Confirm against the data.
text_features = train_df['extracted_text'].fillna('').astype(str)
numeric_target = pd.to_numeric(
    train_df['entity_value'].astype(str).str.split(n=1).str[0], errors='coerce')
has_numeric_target = numeric_target.notna()

# Fixed random_state so the split and the forest are reproducible on re-run.
X_train, X_val, y_train, y_val = train_test_split(
    text_features[has_numeric_target], numeric_target[has_numeric_target],
    test_size=0.2, random_state=42)

# Placeholder: a simple regression model on top of TF-IDF features.
model = make_pipeline(TfidfVectorizer(), RandomForestRegressor(random_state=42))
model.fit(X_train, y_train)

# Make predictions (numeric only: this model does not predict a unit).
train_df['predicted_value'] = model.predict(text_features)

# R^2 on the held-out split, shown as the cell output.
model.score(X_val, y_val)

KeyError: 'extracted_text'

In [24]:
def format_prediction(value, allowed_units):
    """Normalise a "<number> <unit>" prediction to "<number with 2 decimals> <unit>".

    Args:
        value: Predicted string, e.g. "2 cubic foot".
        allowed_units: Collection of units accepted for this prediction.

    Returns:
        The formatted prediction, or "" when `value` is not a string of the form
        "<number> <unit>", the unit is not allowed, or the number is not finite.
    """
    try:
        parts = value.split()
        if len(parts) < 2:
            return ""
        number = float(parts[0])
        # Everything after the number is the unit, so multi-word units such as
        # "cubic foot" or "fluid ounce" are kept (with single spaces).
        unit = ' '.join(parts[1:])
        if unit not in allowed_units:
            return ""
    except (AttributeError, TypeError, ValueError):
        # Not a string, not a number, or `allowed_units` is unusable.
        return ""
    # float() also accepts "nan" / "inf", which are not valid predictions.
    if number != number or number in (float('inf'), float('-inf')):
        return ""
    return f"{number:.2f} {unit}"

# Apply formatting on test predictions.
# Each row is validated against the units allowed for its own entity. Looping
# over entities and assigning the whole column on every pass would leave only
# the last entity's units in effect, and reusing the name `allowed_units` as
# the loop variable would overwrite the notebook-wide set of all units.
if 'predicted_value' not in test_df.columns:
    raise KeyError(
        "test_df has no 'predicted_value' column; generate predictions for the "
        "test set before formatting them")

test_df['prediction'] = test_df.apply(
    lambda row: format_prediction(
        row['predicted_value'], entity_unit_map[row['entity_name']]),
    axis=1,
)

# Output the final predictions to a CSV file
test_df[['index', 'prediction']].to_csv('test_out.csv', index=False)

KeyError: 'predicted_value'

In [8]:
# NOTE(review): in the recorded run this import failed with
# "ModuleNotFoundError: No module named 'src.sanity'". Presumably the notebook
# has to be started from the directory that contains `src/` — TODO confirm.
from src.sanity import sanity_check

# Run sanity check on the output file
# NOTE(review): `sanity_check` is defined outside this notebook; confirm that it
# accepts just the output path before relying on this call.
sanity_check('test_out.csv')

ModuleNotFoundError: No module named 'src.sanity'