In [1]:
#! pip install pandas

In [2]:
import os
# Placeholder only -- never commit a real key to the notebook.
# `setdefault` keeps a key that is already exported in the environment instead
# of overwriting it with the placeholder on every Run All.
os.environ.setdefault('OPENAI_API_KEY', 'sk-...')

In [3]:
from typing import List, Optional, TypedDict

import openai
import pandas as pd

from baml_client import b
from baml_client.tracing import trace

In [4]:
def load_categories(path: str = "./data/category.txt") -> List[str]:
    """Read category paths from a text file, one per line.

    Args:
        path: Location of the category file. Defaults to the notebook's
            bundled ``./data/category.txt``.

    Returns:
        The whitespace-stripped, non-empty lines of the file, in file order.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Explicit encoding so the result does not depend on the platform locale.
    with open(path, "r", encoding="utf-8") as f:
        stripped = (line.strip() for line in f)
        # Skip blank lines (e.g. a trailing newline at end of file): an empty
        # category has no "/" segment and breaks the parent extraction later on.
        return [name for name in stripped if name]

# Load the category paths once; the next cell builds `df` from this list.
categories = load_categories()

In [5]:
# One row per category path, tagged with its top-level ancestor
# (the first "/"-separated segment, with the leading slash restored).
df = pd.DataFrame(categories, columns=["Categories"])
df["Parent"] = df["Categories"].map(lambda path: "/" + path.split("/")[1])
df

Unnamed: 0,Categories,Parent
0,/Appliances,/Appliances
1,/Appliances/Refrigerators,/Appliances
2,/Appliances/Refrigerators/French Door Refriger...,/Appliances
3,/Appliances/Dishwashers,/Appliances
4,/Appliances/Dishwashers/Built-In Dishwashers,/Appliances
...,...,...
1405,/Safety Equipment/Traffic Safety Supplies/Traf...,/Safety Equipment
1406,/Safety Equipment/Traffic Safety Supplies/Traf...,/Safety Equipment
1407,/Safety Equipment/Traffic Safety Supplies/Traf...,/Safety Equipment
1408,/Safety Equipment/Knee Pads,/Safety Equipment


In [6]:
# Collapse each top-level parent's category paths into one newline-separated
# string, giving one row per parent.
df_grouped = (
    df.groupby("Parent")["Categories"]
    .apply("\n".join)
    .reset_index()
)
df_grouped

Unnamed: 0,Parent,Categories
0,/Appliances,/Appliances\n/Appliances/Refrigerators\n/Appli...
1,/Automotive,/Automotive\n/Automotive/Battery Charging Syst...
2,/Bath,/Bath\n/Bath/Bathroom Storage\n/Bath/Bathroom ...
3,/Building Materials,/Building Materials\n/Building Materials/Ladde...
4,/Cleaning,/Cleaning\n/Cleaning/Cleaning Supplies\n/Clean...
5,/Doors & Windows,/Doors & Windows\n/Doors & Windows/Windows\n/D...
6,/Electrical,"/Electrical\n/Electrical/Electrical Boxes, Con..."
7,/Flooring,/Flooring\n/Flooring/Flooring Supplies\n/Floor...
8,/Furniture,/Furniture\n/Furniture/Home Office Furniture\n...
9,/Hardware,/Hardware\n/Hardware/Cabinet Hardware\n/Hardwa...


In [7]:
def create_embedding(text: str, model: str = 'text-embedding-3-large') -> List[float]:
    """Request an embedding vector for `text` from the OpenAI embeddings endpoint.

    Args:
        text: The string to embed.
        model: Name of the embedding model to use. Vectors produced by
            different models should not be compared with each other, so use
            the same model for the stored category embeddings and for queries.

    Returns:
        The embedding of the first (and only) input, as returned in
        ``response.data[0].embedding``.
    """
    response = openai.embeddings.create(input=text, model=model)
    return response.data[0].embedding
# One embedding request per parent row (network call for every row).
df_grouped["Embedding"] = df_grouped["Categories"].apply(create_embedding)

In [8]:
# Number of category paths under each parent: the joined string holds one
# path per line, so the count is the number of newlines plus one.
df_grouped['count'] = df_grouped['Categories'].map(lambda joined: joined.count("\n") + 1)
df_grouped

Unnamed: 0,Parent,Categories,Embedding,count
0,/Appliances,/Appliances\n/Appliances/Refrigerators\n/Appli...,"[-0.033564064651727676, -0.007931031286716461,...",95
1,/Automotive,/Automotive\n/Automotive/Battery Charging Syst...,"[-0.002323379274457693, 0.0446450300514698, -0...",66
2,/Bath,/Bath\n/Bath/Bathroom Storage\n/Bath/Bathroom ...,"[-0.016305895522236824, 0.00472629489377141, -...",52
3,/Building Materials,/Building Materials\n/Building Materials/Ladde...,"[0.005039564799517393, -0.001713049947284162, ...",104
4,/Cleaning,/Cleaning\n/Cleaning/Cleaning Supplies\n/Clean...,"[-0.009269515983760357, 0.0029607058968394995,...",39
5,/Doors & Windows,/Doors & Windows\n/Doors & Windows/Windows\n/D...,"[-0.02355542965233326, 0.03024088405072689, -0...",29
6,/Electrical,"/Electrical\n/Electrical/Electrical Boxes, Con...","[0.01351709384471178, 0.024028806015849113, -0...",129
7,/Flooring,/Flooring\n/Flooring/Flooring Supplies\n/Floor...,"[-0.003810738679021597, 0.012377684004604816, ...",32
8,/Furniture,/Furniture\n/Furniture/Home Office Furniture\n...,"[0.022286467254161835, -0.004916132427752018, ...",12
9,/Hardware,/Hardware\n/Hardware/Cabinet Hardware\n/Hardwa...,"[-0.0026445782277733088, 0.03329707682132721, ...",91


In [9]:
class TestCase(TypedDict):
    """One evaluation example for the classifier."""

    # Product title; passed to `classify` as the tool name.
    text: str
    # Free-form product description; passed to `classify` alongside the title.
    description: str
    # Expected category path; compared against the classifier's answer by `score`.
    category: str


# Evaluation cases: each entry pairs a product title and description with the
# category path the classifier is expected to choose. Iterated by `main`.
tests: List[TestCase] = [
    {
        "text": "53 in Fiberglass Handle Mortar Hoe",
        "description": "This 53 in. Fiberglass Handle Mortar Hoe is designed for spreading and mixing mortar, cement, and concrete. It features a high-strength, forged metal head with a multi-step hammer-tone finish and integrated holes for improved flow of materials. The forged steel head is attached with dual stainless steel rivets, and the tool includes an integrated hanging hole for easy storage. The fiberglass handle is resistant to environmental damage and has an over-molded end-grip and mid-grip for better leverage and two-handed use.",
        "category": "/Building Materials/Concrete, Cement & Masonry/Concrete Tools",
    },
    {
        "text": "20V MAX* 14 in. Folding String Trimmer",
        "description": "This 20V MAX* 14 in. String Trimmer includes a folding hinge mechanism to reduce its length by 40%** for easier storage and portability. The QuickLoad™ Spool facilitates straightforward line replacement for the DEWALT 0.080 in. line. The 14 in. cutting swath covers a broad area in one pass. The variable speed trigger provides power control, while the Hi/Lo speed control switch assists with performance and runtime management.",
        "category": "/Outdoors/Outdoor Power Equipment",
    },
    {
        "text": "8 ft Fiberglass Twin Front Step Ladder",
        "description": "This DEWALT 8 ft. Fiberglass Twin Front Step Ladder features a 300lb load capacity per side, allowing two people to use it simultaneously, one on each side. It includes a top with magnet and tool slots, heavy duty boots, an impact absorption system, and a 25% larger step surface.",
        "category": "/Building Materials/Ladders/Step Ladders",
    },
    {
        "text": "Rotex™ Protective Eyewear",
        "description": "The DEWALT DPG103 ROTEX™ Safety Glass features a lightweight frame and flexible temples with soft, rubber grips for a comfortable and secure fit.",
        "category": "/Safety Equipment/Protective Eyewear/Safety Glasses",
    },
    {
        "text": "3/8 in Drive Deep Metric Sockets 12 pt",
        "description": "These Hand Sockets feature DirectTorque™ technology designed to help prevent rounding of fasteners and provide a secure grip. The chrome finish offers protection against chipping and flaking.",
        "category": "/Tools/Hand Tools/Ratchets & Sockets/Sockets",
    },
    {
        "text": "20V MAX* XR® Brushless Cordless 1/2 in. 3-Speed Hammer Drill",
        "description": "The 20V MAX XR Brushless 1/2 in. 3-Speed Hammer Drill is designed to handle various applications. It features the ANTI-ROTATION System which deactivates the tool if excessive rotational motion is detected, and can achieve up to 275 holes per charge. The drill has an all-metal transmission construction for durability and includes an adjustable 3-position LED to illuminate dim work areas. Battery and charger are sold separately.",
        "category": "/Tools/Power Tools/Drills/Hammer Drills",
    },
    {
        "text": "6:1 Chalk Reel 30m/100ft",
        "description": "This 6:1 Chalk Reel is suitable for various tasks including framing, wall layout, flooring, and cabinet work. It features a planetary gear system designed to prevent jams, enhancing its durability. The reel also comes with an extra-large flip cap to facilitate easy refilling and minimize chalk spillage.",
        "category": "/Tools/Hand Tools/Measuring Tools",
    },
    {
        "text": "Cold Water Residential Electric Pressure Washer (2000 PSI at 3.0 GPM)",
        "description": "This pressure washer is designed to meet the needs of cleaning professionals and individuals. Suitable for deck cleaning, wood restoration, paint preparation, graffiti removal, and various other cleaning tasks.",
        "category": "/Outdoors/Outdoor Power Equipment/Pressure Washers/Electric Pressure Washers",
    },
]

In [10]:
import numpy as np

def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Return the cosine of the angle between vectors `a` and `b`.

    Args:
        a: First vector.
        b: Second vector; must have the same length as `a`.

    Returns:
        A plain Python float in [-1, 1] (up to rounding). If either vector has
        zero norm the similarity is undefined; 0.0 is returned instead of NaN
        so that results stay safely sortable.

    Raises:
        ValueError: If the vectors have incompatible lengths (from ``np.dot``).
    """
    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    if denominator == 0:
        # Avoid 0/0 -> NaN (and the accompanying RuntimeWarning).
        return 0.0
    return float(np.dot(a, b) / denominator)

@trace
def get_best_categories(tool: str, description: str, top_k: int = 3) -> List[str]:
    """Pick the top-level categories whose embeddings are closest to a product.

    The product title and description are embedded together and compared, by
    cosine similarity, against the stored embedding of every row in
    `df_grouped`.

    Args:
        tool: Product title.
        description: Product description.
        top_k: How many parent categories to return (expected to be
            non-negative). Defaults to 3.

    Returns:
        Up to `top_k` values from the "Parent" column, most similar first.
    """
    # Embed the query with the same model used for the category embeddings.
    mapping = openai.embeddings.create(input=f'{tool}\n{description}', model='text-embedding-3-large')
    query_embedding = mapping.data[0].embedding

    # Cosine similarity between the query and every parent category.
    scored = [
        (parent, cosine_similarity(embedding, query_embedding))
        for parent, embedding in zip(df_grouped["Parent"], df_grouped["Embedding"])
    ]

    # Highest similarity first; `sorted` is stable, so ties keep frame order.
    ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
    return [parent for parent, _ in ranked[:top_k]]

In [11]:
from baml_client.type_builder import TypeBuilder
from typing import Tuple
from baml_client.types import Classification

@trace
async def classify(tool: str, description: str) -> Optional[Classification]:
    """Classify a product into one of the known category paths.

    The candidate set is first narrowed with `get_best_categories`; every
    category path in `df` whose "Parent" is one of those roots is registered
    as an allowed value on a fresh `TypeBuilder`, which is then passed to the
    `b.Classify` call.

    Args:
        tool: Product title.
        description: Product description.

    Returns:
        The first classification returned by `b.Classify`, or None when it
        returns nothing. Callers must handle the None case.
    """
    root_categories = get_best_categories(tool, description)

    # Restrict the allowed answers to descendants of the shortlisted roots.
    tb = TypeBuilder()
    candidates = df.loc[df["Parent"].isin(root_categories), "Categories"]
    for category in candidates:
        tb.Tools.add_value(category)

    selected = await b.Classify(tool, description, count=1, baml_options={ "tb": tb })
    if not selected:
        return None
    return selected[0]


In [12]:
from baml_client.types import Classification

def score(value: str, options: Optional[Classification]) -> float:
    """Grade a classification against the expected category path.

    Args:
        value: The expected category path, e.g. "/Tools/Hand Tools".
        options: The classification to grade; only its `category` attribute
            is read. May be None (no classification was produced).

    Returns:
        1.0 for an exact match, 0.5 when one path is an ancestor of the other
        (a "parent match"), and 0.0 otherwise -- including when `options` is
        None or either path is empty.
    """
    if options is None:
        return 0.0
    predicted = options.category
    if value == predicted:
        return 1.0
    if not value or not predicted:
        # An empty path is a prefix of everything; do not count it as a parent.
        return 0.0
    # Compare on "/" boundaries so that "/Bath" is not treated as a parent of
    # "/Bathroom"; rstrip tolerates a trailing slash on the shorter path.
    expected_is_parent = predicted.startswith(value.rstrip("/") + "/")
    predicted_is_parent = value.startswith(predicted.rstrip("/") + "/")
    if expected_is_parent or predicted_is_parent:
        return 0.5
    return 0.0

@trace
async def validate(test: TestCase):
    """Classify one test case and grade the result.

    Returns:
        A ``(score, classification)`` pair: the value produced by `score` for
        the case's expected category, and whatever `classify` returned.
    """
    prediction = await classify(test["text"], test["description"])
    grade = score(test['category'], prediction)
    return grade, prediction

@trace
async def main():
    """Run every case in `tests` through `validate` and print a report.

    For each case the outcome label, the expected category and the selected
    category are printed, followed by an overall pass rate.

    A case counts toward the pass total whenever its score is non-zero, so a
    "Parent Match" (0.5) is counted the same as an exact match (1).

    Returns:
        The number of cases with a non-zero score.
    """
    pass_count = 0
    for test in tests:
        res, selected = await validate(test)
        if res == 1:
            messg = 'passed'
        elif res == 0.5:
            messg = 'Parent Match'
        else:
            messg = 'failed'

        if res:
            pass_count += 1

        # `classify` returns None when nothing was selected; report that
        # instead of raising AttributeError on `.category`.
        selected_category = selected.category if selected is not None else None
        print(f"Test - {messg}: {test['text']}")
        print(f"Expected: {test['category']}")
        print(f"Selected: {selected_category}")
    print(f"Pass rate: {pass_count}/{len(tests)}")
    return pass_count


# Run the evaluation; the returned pass count is displayed as the cell result.
# NOTE(review): top-level `await` presumably relies on the notebook kernel's
# async cell support -- it is not valid in a plain Python script.
await main()

Test - passed: 53 in Fiberglass Handle Mortar Hoe
Expected: /Building Materials/Concrete, Cement & Masonry/Concrete Tools
Selected: /Building Materials/Concrete, Cement & Masonry/Concrete Tools
Test - Parent Match: 20V MAX* 14 in. Folding String Trimmer
Expected: /Outdoors/Outdoor Power Equipment
Selected: /Outdoors
Test - passed: 8 ft Fiberglass Twin Front Step Ladder
Expected: /Building Materials/Ladders/Step Ladders
Selected: /Building Materials/Ladders/Step Ladders
Test - Parent Match: Rotex™ Protective Eyewear
Expected: /Safety Equipment/Protective Eyewear/Safety Glasses
Selected: /Safety Equipment/Protective Eyewear
Test - passed: 3/8 in Drive Deep Metric Sockets 12 pt
Expected: /Tools/Hand Tools/Ratchets & Sockets/Sockets
Selected: /Tools/Hand Tools/Ratchets & Sockets/Sockets
Test - Parent Match: 20V MAX* XR® Brushless Cordless 1/2 in. 3-Speed Hammer Drill
Expected: /Tools/Power Tools/Drills/Hammer Drills
Selected: /Tools/Power Tools/Drills
Test - passed: 6:1 Chalk Reel 30m/100f

8