In [25]:
import boto3
import csv
import os
from constants import entity_unit_map

In [26]:
textract_client = boto3.client('textract')
s3_client = boto3.client('s3')

In [27]:
def create_custom_query(entity_name):
    units = ', '.join(entity_unit_map.get(entity_name, []))  # Get the allowed units for the entity, e.g., "cm, foot, inch"
    
    custom_queries = [
        f"Extract the {entity_name} in {units}",
        f"Extract the {entity_name} from the given object in the image",
        f"Extract the {entity_name} from the image"
    ]
    
    return custom_queries

In [28]:
# Function to extract text from images using Textract with a dynamic query
def extract_entity_value_from_image(image_bytes, entity_name):
    queries = create_custom_query(entity_name)
    
    response = textract_client.analyze_document(
        Document={'Bytes': image_bytes},
        FeatureTypes=["QUERIES"],
        QueriesConfig={
            "Queries": [{"Text": custom_queries, "Alias": None} for query in queries]
        }
    )

    # Extract text from Textract response for the specific entity
    extracted_text = ''
    for block in response['Blocks']:
        if block['BlockType'] == 'QUERY_RESULT' and block['Query']['Alias'] == entity_name:
            extracted_text = block.get('Text', '')

    # Format and validate the extracted value to match the "x unit" format
    return format_extracted_value(extracted_text, entity_name)

In [29]:
# Function to format extracted text to "x unit" where x is a float and unit is valid
def format_extracted_value(extracted_text, entity_name):
    if not extracted_text:
        return ""

    # Extract the first valid float number and unit from the text
    allowed_units = entity_unit_map.get(entity_name, [])
    pattern = r'([-+]?\d*\.\d+|\d+)\s*(' + '|'.join(allowed_units) + r')'
    match = re.search(pattern, extracted_text)
    
    if match:
        value = match.group(1)
        unit = match.group(2)
        return f"{float(value)} {unit}"  # Ensure standard formatting with float
    else:
        return ""  # Return empty string if no valid value is found

In [30]:
def get_image_from_s3(bucket_name, s3_key):
    s3_object = s3_client.get_object(Bucket=bucket_name, Key=s3_key)
    return s3_object['Body'].read()

In [31]:
def process_images_and_save_to_csv(bucket_name, image_folder_prefix, output_csv):
    with open(output_csv, mode='w', newline='') as csvfile:
        fieldnames = ['index', 'group_id', 'entity_name', 'extracted_value']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        index = 0  # Initialize index for CSV
        # List the directories and files in the S3 bucket
        result = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=image_folder_prefix)
        for content in result.get('Contents', []):
            s3_key = content['Key']
            
            # Extract group_id and entity_name from the S3 key path
            parts = s3_key.split('/')
            if len(parts) < 3:
                continue  # Skip files that don't match the expected folder structure

            group_id = parts[-3]  # Assuming folder structure: <prefix>/group_id/entity_name/image_file
            entity_name = parts[-2]

            # Skip files that are not images or folders
            if not entity_name in entity_unit_map or s3_key.endswith('/'):
                continue
            
            # Read image bytes directly from S3
            image_bytes = get_image_from_s3(bucket_name, s3_key)
            
            # Extract value using Textract
            extracted_value = extract_entity_value_from_image(image_bytes, entity_name)

            # Write to CSV with index
            writer.writerow({
                'index': index,
                'group_id': group_id,
                'entity_name': entity_name,
                'extracted_value': extracted_value
            })
            index += 1

In [32]:
if __name__ == "__main__":
    bucket_name = 'ocrtextextraction'  
    image_folder_prefix = 'preprocessed_images/train/'
    output_csv = 'output_file.csv'  # Path to the CSV output file

    process_images_and_save_to_csv(bucket_name, image_folder_prefix, output_csv)

NameError: name 'custom_queries' is not defined