In [8]:
import os
import base64
from datasets import load_dataset
import google.generativeai as genai
from dotenv import load_dotenv
from typing import List, Dict, Tuple, Any
import pandas as pd
import faiss
import numpy as np
from sentence_transformers import SentenceTransformer
from sqlalchemy import create_engine, Column, Integer, String, MetaData, Table
from IPython.display import display
from PIL import Image
from io import BytesIO
from pdf2image import convert_from_bytes
import json
import re

In [9]:
load_dotenv()
api_key = os.getenv("GEMINI_API_KEY")

# 1. Function Definitions
def setup_database_from_dfs(dataframes: Dict[str, pd.DataFrame], db_path: str = 'sqlite:///hs_database.db'):
    """Sets up the SQLite database from a dictionary of pandas DataFrames."""
    engine = create_engine(db_path)
    metadata = MetaData()

    table_definitions = {
        'df_harmonized_system': Table(
            'df_harmonized_system', metadata,
            Column('hscode', String, primary_key=True),
            Column('section', String),
            Column('description', String),
            Column('parent', String),
            Column('level', Integer)
        ),
        'df_sections': Table(
            'df_sections', metadata,
            Column('section', String, primary_key=True),
            Column('name', String)
        ),
        # Add more table definitions here as needed...
    }

    metadata.create_all(engine)

    for table_name, df in dataframes.items():
        if table_name in table_definitions:
            df.to_sql(table_name, engine, if_exists='replace', index=False)
    return engine

def search_hs_code(description: str, k: int = 2) -> List[Dict[str, str]]:
    """
    Searches for Harmonized System codes based on a product description using FAISS.

    Args:
        description (str): The product description to search for.
        k (int, optional): The number of nearest neighbors to retrieve. Defaults to 2.

    Returns:
        List[Dict[str, str]]: A list of dictionaries, where each dictionary contains
                              the 'hscode', 'description', and 'section' of the matched HS code.
    """
    query_vector = embedding_model.encode([description], convert_to_numpy=True)
    distances, indices = index.search(query_vector, k)
    results = []
    for idx in indices[0]:
        if idx != -1:
            row = df_map.iloc[idx]
            results.append({"hscode": row['hscode'], "description": row['description'], "section": row['section']})
    return results

def encode_image(data: Any) -> str:
    """
    Encodes an image into a base64 string, handling different input types.

    Args:
        data (Any):  PIL Image object, bytes, or a path to an image file.

    Returns:
        str: The base64 encoded image string, or None if encoding fails.
    """
    if isinstance(data, Image.Image):
        buffered = BytesIO()
        data.save(buffered, format="JPEG")
        encoded_string = base64.b64encode(buffered.getvalue()).decode('utf-8')
        return encoded_string
    elif isinstance(data, bytes):
        try:
            img = Image.open(BytesIO(data))
            return encode_image(img)
        except:
            try:
                images = convert_from_bytes(data)
                if images:
                    return encode_image(images[0])
                else:
                    return None
            except:
                return None
    else:
        return None
    
def process_invoice(image_data: Any, model: Any) -> List[Dict[str, Any]]:
    """
    Processes an invoice image, extracts information, and adds harmonized codes.

    Args:
        image_data: The image data of the invoice.
        model: The Gemini model.

    Returns:
        List[Dict[str, Any]]: A list of dictionaries, where each dictionary
        represents the extracted information from the invoice, including
        harmonized codes for the items.  Returns an empty list on error.
    """
    encoded_image = encode_image(image_data)
    vision_prompt = """
    Analyze this invoice and extract the following information in JSON format:

    {
        "seller_name": "Seller Name",
        "seller_address": "Seller Address",
        "receiver_name": "Receiver Name",
        "receiver_address": "Receiver Address",
        "items": [
            {
                "name": "Item Name",
                "quantity": 1,
                "price": 10.00,
                "harmonized_code": "HS Code",
                "description": "HS Description"
            },
            {
                "name": "Another Item Name",
                "quantity": 2,
                "price": 20.00,
                "harmonized_code": "HS Code",
                "description": "HS Description"
            }
        ]
    }

    If any information is not found, use "None" as the value.  For example, if the seller name is not on the invoice, use "seller_name": "None".
    Return ONLY the JSON object.
    """

    if model:
        vision_response = model.generate_content([vision_prompt, {"mime_type": "image/jpeg", "data": encoded_image}])
        response_text = vision_response.text.strip()

        try:
            json_match = re.search(r'```(?:json)?\s*({.*?})\s*```', response_text, re.DOTALL)
            if json_match:
                json_string = json_match.group(1)
            else:
                json_string = response_text
            vision_json = json.loads(json_string)
            items = vision_json.get("items", [])

            for item in items:
                hs_codes = search_hs_code(item.get('name', ''))
                if hs_codes:
                    item['harmonized_code'] = hs_codes[0]['hscode']
                    item['description'] = hs_codes[0]['description']
                else:
                    item['harmonized_code'] = "None"
                    item['description'] = "None"
            return [vision_json]  # Wrap the result in a list
        except json.JSONDecodeError as e:
            print(f"Error: Vision response was not valid JSON: {e}")
            print("Raw response:", response_text)
            return []
        except Exception as e:
            print(f"Error: Unexpected error: {e}")
            print("Raw response:", response_text)
            return []
    else:
        print("Gemini API not configured.")
        return []

In [10]:
# 2. Main Code Execution
dataframes = {
    'df_harmonized_system': pd.read_csv('./data/harmonized-system.csv'),
    'df_sections': pd.read_csv('./data/sections.csv'),
}

for df_name, df in dataframes.items():
    print(f"\nSample from {df_name}:")
    display(df.head())

engine = setup_database_from_dfs(dataframes)

query = "SELECT hscode, description, section FROM df_harmonized_system"
df = pd.read_sql(query, engine)

embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
embeddings = embedding_model.encode(df['description'].tolist(), convert_to_numpy=True)

index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

faiss.write_index(index, "hs_codes.index")
df[['hscode', 'description', 'section']].to_csv("hs_code_mapping.csv", index=False)

index = faiss.read_index("hs_codes.index")
df_map = pd.read_csv("hs_code_mapping.csv")


Sample from df_harmonized_system:


Unnamed: 0,section,hscode,description,parent,level
0,I,1,Animals; live,TOTAL,2
1,I,101,"Horses, asses, mules and hinnies; live",01,4
2,I,10121,"Horses; live, pure-bred breeding animals",0101,6
3,I,10129,"Horses; live, other than pure-bred breeding an...",0101,6
4,I,10130,Asses; live,0101,6



Sample from df_sections:


Unnamed: 0,section,name
0,I,live animals; animal products
1,II,Vegetable products
2,III,Animal or vegetable fats and oils and their cl...
3,IV,"Prepared foodstuffs; beverages, spirits and vi..."
4,V,Mineral products


In [11]:
description = "orange"
results = search_hs_code(description)

print("Search Results:")
for result in results:
    print(f"HS Code: {result['hscode']}, Description: {result['description']}, Section: {result['section']}")


Search Results:
HS Code: 080510, Description: Fruit, edible; oranges, fresh or dried, Section: II
HS Code: 0805, Description: Citrus fruit; fresh or dried, Section: II


In [12]:
genai_model = None # Initialize outside the if block
if api_key:
    genai.configure(api_key=api_key)
    try:
        genai_model = genai.GenerativeModel('gemini-1.5-flash')
    except Exception as e:
        print(f"Error loading model: {e}")
else:
    print("GEMINI_API_KEY not found in environment variables.")


In [13]:
# Load only 2 samples from the train split.  Changed from 10 to 2.
dataset = load_dataset("mychen76/invoices-and-receipts_ocr_v1", split="train[1:3]") # Changed from train[1:2]

all_results = []
for data_point in dataset:
    image_data = data_point['image']
    results = process_invoice(image_data, genai_model) # Pass the model
    all_results.extend(results) # Use extend to add all results from process_invoice

# Print the results
print("\nRaw Results (Before Review):")
for result in all_results:
    print(json.dumps(result, indent=4))

Generating train split:   0%|          | 0/2043 [00:00<?, ? examples/s]

: 