In [1]:
# Two-level hierarchy of object kinds: snake_case top-level keys map to their low-level
# node names, each of which starts out with its own empty list.
hierarchy_dict = {
    high: {low: [] for low in lows}
    for high, lows in (
        ("food_items", (
            "Fruits", "Vegetables", "Beverages", "Baking Ingredients", "Snacks",
            "Condiments & Sauces", "Meats", "Seafood", "Dairy", "Grains & Pasta",
            "Herbs & Spices", "Sweets", "Other Food Items",
        )),
        ("household_items", (
            "Cleaning Supplies", "Kitchenware", "Decor", "Furniture", "Appliances",
            "Other Household Items",
        )),
        ("tools_hardware", (
            "Hand Tools", "Power Tools", "Gardening Tools", "Hardware",
            "Other Tools & Hardware",
        )),
        ("electronics", (
            "Computing", "Entertainment", "Communications", "Other Electronics",
        )),
        ("personal_care_items", (
            "Hygiene", "Beauty", "Healthcare", "Other Personal Care Items",
        )),
        ("clothing_accessories", (
            "Clothing", "Footwear", "Accessories", "Other Clothing & Accessories",
        )),
        ("office_supplies", (
            "Writing Instruments", "Paper Products", "Organizational Supplies",
            "Other Office Supplies",
        )),
        ("other", (
            "Miscellaneous",
        )),
    )
}

In [2]:
from bddl.knowledge_base import *

Loading BDDL knowledge base... This may take a few seconds.
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\cgokmen\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [30]:
hierarchy = ""
for high, lows in hierarchy_dict.items():
    hierarchy += f"{high.replace('_', ' ').title()}:\n"
    for low in lows:
        hierarchy += f"  {low}\n"
    hierarchy += "\n"
print(hierarchy)

Food Items:
  Fruits
  Vegetables
  Beverages
  Baking Ingredients
  Snacks
  Condiments & Sauces
  Meats
  Seafood
  Dairy
  Grains & Pasta
  Herbs & Spices
  Sweets
  Other Food Items

Household Items:
  Cleaning Supplies
  Kitchenware
  Decor
  Furniture
  Appliances
  Other Household Items

Tools Hardware:
  Hand Tools
  Power Tools
  Gardening Tools
  Hardware
  Other Tools & Hardware

Electronics:
  Computing
  Entertainment
  Communications
  Other Electronics

Personal Care Items:
  Hygiene
  Beauty
  Healthcare
  Other Personal Care Items

Clothing Accessories:
  Clothing
  Footwear
  Accessories
  Other Clothing & Accessories

Office Supplies:
  Writing Instruments
  Paper Products
  Organizational Supplies
  Other Office Supplies

Other:
  Miscellaneous




In [18]:
# Total number of objects in the BDDL knowledge base, displayed as this cell's output.
len(Object.all_objects())

8842

In [4]:
# Only categories that contain at least one object take part in the hierarchy assignment.
all_categories = [category for category in Category.all_objects() if len(category.objects)]
def category_str(cats):
    """Render categories as one "<Title Cased Name>: <object count>" line each.

    Args:
        cats: Iterable of category objects exposing ``.name`` and ``.objects``;
            underscores in the name are shown as spaces.

    Returns:
        The newline-terminated lines concatenated together, e.g.
        "Chicken Breast: 3\\n"; an empty string for an empty iterable.
    """
    # str.join builds the result in one pass instead of repeated string concatenation.
    return "".join(
        f"{c.name.replace('_', ' ').title()}: {len(c.objects)}\n" for c in cats
    )

In [5]:
task_msg = lambda cx: f"""
I have a dataset of objects that I am releasing for research purposes. I am giving you a list of categories of objects that
are included in the dataset below as well as the number of objects in each category. I need to make a plot out of these that
shows the diversity of the dataset by highlighting the different kinds of objects we have in our dataset. I think a great way
of doing this could be a sankey plot with multiple levels. To do that, I need to create a hierarchy of these categories. I can
group some of the important stuff at a high level ("food", "furniture", "structure", etc.) and then go into more detail at the
lower levels. We can group the less frequently seen objects into "other" groups. I have made a hierarchy that you can see below:

{hierarchy}

Now I need to assign each of the object categories below to one of the hierarchy nodes. Here are the object categories and their
object counts:

{cx}
"""

In [6]:
def query(high, low):
    """Build a follow-up prompt asking which categories belong under the ``high -> low`` node.

    ``high`` is rendered with underscores replaced by spaces and title-cased (the form of the
    hierarchy_dict keys); ``low`` is inserted verbatim.
    """
    high_label = high.replace("_", " ").title()
    return (
        f"Now, tell me what categories should go in the {high_label} -> {low} node of the hierarchy as a JSON list. "
        "Return the JSON wrapped in three ticks (```). "
        "The list elements you return must match the elements from the above category list exactly. "
        "Don't say anything else."
    )

In [39]:
from openai import OpenAI
import getpass
import os

# Prefer the OPENAI_API_KEY environment variable so "Restart & Run All" does not block on an
# interactive prompt; fall back to a hidden getpass prompt when it is unset or empty.
apikey = os.environ.get("OPENAI_API_KEY") or getpass.getpass()
client = OpenAI(api_key=apikey)

········


In [38]:
def assignment_fixer(x):
    """Normalize an LLM-returned node assignment to a bare low-level node name.

    Qualified forms such as "Food Items / Meats", "Food Items: Meats" or
    "Food Items.Meats" are reduced to the text after the last "/", ":" and "."
    (applied in that order), then surrounding whitespace is stripped. Without
    the strip, "Food Items / Meats" would yield " Meats", which matches no node.

    Non-string values (e.g. a JSON list given instead of a name) are returned
    unchanged so that a downstream membership check can reject them rather than
    this function raising AttributeError.
    """
    if not isinstance(x, str):
        return x
    for separator in ("/", ":", "."):
        x = x.split(separator)[-1]
    return x.strip()

In [None]:
# Start with no category -> low-level-node assignments.
assignments = dict()

In [49]:
import json
import os

# Cache file for category -> low-level-node assignments between runs.
ASSIGNMENTS_PATH = "assignments.json"
# Start from the cached assignments when the cache file exists (otherwise keep `assignments`).
LOAD_CACHED_ASSIGNMENTS = True
# Query the LLM for every category still lacking a valid assignment, then rewrite the cache.
REGENERATE_ASSIGNMENTS = True
# Number of category names sent per LLM request.
batch_size = 100
# Give up after this many consecutive rounds in which no new category got a valid assignment.
MAX_STALLED_ROUNDS = 5

all_category_names = [c.name for c in all_categories]
available_lows = list(low for lows in hierarchy_dict.values() for low in lows)
known_category_names = set(all_category_names)  # O(1) membership checks

if LOAD_CACHED_ASSIGNMENTS and os.path.exists(ASSIGNMENTS_PATH):
    with open(ASSIGNMENTS_PATH, "r") as f:
        assignments = {k: assignment_fixer(v) for k, v in json.load(f).items()}

if REGENERATE_ASSIGNMENTS:
    previous_remaining = None
    stalled_rounds = 0
    while True:
        # Clear out any assignments that are invalid
        assignments = {k: v for k, v in assignments.items() if v in available_lows}
        remaining = known_category_names - set(assignments.keys())
        if not remaining:
            break
        # Without this guard the loop never ends if the LLM keeps failing on the same categories.
        if remaining == previous_remaining:
            stalled_rounds += 1
            if stalled_rounds >= MAX_STALLED_ROUNDS:
                print(f"Giving up after {stalled_rounds} rounds without progress; still unassigned:", sorted(remaining))
                break
        else:
            stalled_rounds = 0
        previous_remaining = remaining
        print("Remaining:", len(remaining))
        batch_cats = sorted(remaining)[:batch_size]
        # NOTE(review): task_msg's text promises per-category object counts, but only the
        # category names are sent here.
        batch_cats_str = "\n".join(batch_cats)
        batch_task_msg = task_msg(batch_cats_str)
        response = client.chat.completions.create(
            model="gpt-4-1106-preview",
            response_format={"type": "json_object"},
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": batch_task_msg},
                {"role": "user", "content": (
                    "Now, in JSON format, tell me the hierarchical node assignment for each of the categories. "
                    "For example, 'chicken_breast': 'Meats'. You need to give an assignment for each of the categories "
                    "listed above - do NOT skip over any or give summarized answers like 'so on', etc. and only output "
                    "the lower assignment node (e.g. 'Meats', not 'Food Items' or 'Food Items / Meats'). Make sure each "
                    "category you output is one of the items in the above list of subcategories - do NOT make up new "
                    "hierarchy nodes. If you think an object does not fall into any particular node, put it to one of the "
                    "Other XYZ nodes or in the Miscellaneous node. For example, put generic office supplies into Other "
                    "Office Supplies, or put baked goods into Other Food Items, or potted plants into Decor."
                )},
            ],
        )
        # Keys the model invents are dropped here; invalid node names are dropped by the filter
        # at the top of the next iteration.
        assignments.update({
            k: assignment_fixer(v)
            for k, v in json.loads(response.choices[0].message.content).items()
            if k in known_category_names
        })
    # Dump the assignments (possibly partial if we gave up; the checks in later cells catch that).
    with open(ASSIGNMENTS_PATH, "w") as f:
        json.dump(assignments, f)

Remaining: 139
Remaining: 73
Remaining: 15
Remaining: 3
Remaining: 3


In [50]:
# Debug aid: categories from the most recent LLM response that are not in `assignments`, i.e.
# presumably rejected by the validity filter. `response` only exists if the regeneration loop
# issued at least one request, so fall back to an empty dict instead of raising NameError.
_last_response = globals().get("response")
(
    {
        k: assignment_fixer(v)
        for k, v in json.loads(_last_response.choices[0].message.content).items()
        if k in all_category_names and k not in assignments
    }
    if _last_response is not None
    else {}
)

{}

In [51]:
# Compare assigned category names against the categories that need an assignment.
found_assignments = set(assignments.keys())
# Categories with objects that never received an assignment.
print("Cats missing assignment", set(all_category_names) - found_assignments)
# Assigned keys that are not known categories (e.g. names the LLM invented).
print("Assignments that are invalid", found_assignments - set(all_category_names))

Cats missing assignment set()
Assignments that are invalid set()


In [52]:
# Equal lengths alone would not prove coverage (a missing category could be offset by an
# extra key), so check the key sets themselves as well.
assert set(assignments) == set(all_category_names), "assignment keys differ from the category names"
assert len(assignments) == len(all_category_names)
print(len(assignments))

1734


In [53]:
# Total objects across all assigned categories; compare with len(Object.all_objects()).
sum(len(Category.get(category_name).objects) for category_name in assignments.keys())

8842

In [54]:
import collections

# Sanity checks: low-level node names must be unique across the whole hierarchy, and every
# (normalized) assigned node must be one of them.
found_lows = {assignment_fixer(node) for node in assignments.values()}
available_lows = [low for group in hierarchy_dict.values() for low in group]
repeated_lows = {
    low for low, count in collections.Counter(available_lows).items() if count > 1
}
available_lows_set = set(available_lows)
assert not repeated_lows, repeated_lows

unknown_lows = found_lows.difference(available_lows_set)
assert not unknown_lows, unknown_lows

['Fruits', 'Vegetables', 'Beverages', 'Baking Ingredients', 'Snacks', 'Condiments & Sauces', 'Meats', 'Seafood', 'Dairy', 'Grains & Pasta', 'Herbs & Spices', 'Sweets', 'Other Food Items', 'Cleaning Supplies', 'Kitchenware', 'Decor', 'Furniture', 'Appliances', 'Other Household Items', 'Hand Tools', 'Power Tools', 'Gardening Tools', 'Hardware', 'Other Tools & Hardware', 'Computing', 'Entertainment', 'Communications', 'Other Electronics', 'Hygiene', 'Beauty', 'Healthcare', 'Other Personal Care Items', 'Clothing', 'Footwear', 'Accessories', 'Other Clothing & Accessories', 'Writing Instruments', 'Paper Products', 'Organizational Supplies', 'Other Office Supplies', 'Miscellaneous']


In [61]:
IGNORE_CATEGORIES = {"walls", "floors", "ceilings", "driveway", "lawn", "roof"}
hierarchylow = []
for high, lows in hierarchy_dict.items():
    for low in lows:
        objs = sum(len(c.objects) for c in all_categories if assignments[c.name] == low and c.name not in IGNORE_CATEGORIES)
        hierarchylow.append((high, low, objs))

In [62]:
# Display the (high-level key, low-level node, object count) tuples computed above.
hierarchylow

[('food_items', 'Fruits', 138),
 ('food_items', 'Vegetables', 183),
 ('food_items', 'Beverages', 281),
 ('food_items', 'Baking Ingredients', 95),
 ('food_items', 'Snacks', 95),
 ('food_items', 'Condiments & Sauces', 81),
 ('food_items', 'Meats', 79),
 ('food_items', 'Seafood', 22),
 ('food_items', 'Dairy', 66),
 ('food_items', 'Grains & Pasta', 78),
 ('food_items', 'Herbs & Spices', 120),
 ('food_items', 'Sweets', 100),
 ('food_items', 'Other Food Items', 172),
 ('household_items', 'Cleaning Supplies', 76),
 ('household_items', 'Kitchenware', 559),
 ('household_items', 'Decor', 701),
 ('household_items', 'Furniture', 1026),
 ('household_items', 'Appliances', 182),
 ('household_items', 'Other Household Items', 373),
 ('tools_hardware', 'Hand Tools', 19),
 ('tools_hardware', 'Power Tools', 6),
 ('tools_hardware', 'Gardening Tools', 33),
 ('tools_hardware', 'Hardware', 30),
 ('tools_hardware', 'Other Tools & Hardware', 93),
 ('electronics', 'Computing', 52),
 ('electronics', 'Entertainmen

In [63]:
# Total number of objects counted across all low-level nodes.
sum(count for _high, _low, count in hierarchylow)

6685

In [64]:
# Roll the low-level counts up to one ("objects", high-level key, total) edge per top-level node.
hierarchyhigh = [
    (
        "objects",
        high_key,
        sum(count for parent, _low, count in hierarchylow if parent == high_key),
    )
    for high_key in hierarchy_dict
]
print(hierarchyhigh)

[('objects', 'food_items', 1510), ('objects', 'household_items', 2917), ('objects', 'tools_hardware', 181), ('objects', 'electronics', 271), ('objects', 'personal_care_items', 153), ('objects', 'clothing_accessories', 202), ('objects', 'office_supplies', 487), ('objects', 'other', 964)]


In [65]:
# All (source, target, value) edges: low-level edges first, then the root -> top-level edges.
fullhierarchy = [*hierarchylow, *hierarchyhigh]

In [75]:
# Human-readable display names for the top-level hierarchy_dict keys.
hr_hier = dict(
    food_items="Food Items",
    household_items="Household Items",
    tools_hardware="Tools & Hardware",
    electronics="Electronics",
    personal_care_items="Personal Care Items",
    clothing_accessories="Clothing & Accessories",
    office_supplies="Office Supplies",
    other="Other",
)

In [91]:
# "Display Name,count" rows, one per top-level node.
print("\n".join(f"{hr_hier[high_key]},{total}" for _root, high_key, total in hierarchyhigh))

Food Items,1510
Household Items,2917
Tools & Hardware,181
Electronics,271
Personal Care Items,153
Clothing & Accessories,202
Office Supplies,487
Other,964


In [67]:
import plotly.graph_objects as go

def isother(v):
    """Return True if a node name contains the word "other" or "miscellaneous" (case-insensitive)."""
    words = v.lower().split()
    return "other" in words or "miscellaneous" in words


# Catch-all low-level nodes are not drawn as separate wedges; their counts are attached to the
# top-level node itself. With plotly's default branchvalues="remainder", a parent's drawn size
# is its own value plus its children's, so those objects still count toward the parent.
otherhigh = [
    ("", high, sum(v[2] for v in hierarchylow if v[0] == high and isother(v[1])))
    for high in hierarchy_dict.keys()
]
nonotherlow = [x for x in hierarchylow if not isother(x[1])]
custhier = nonotherlow + otherhigh


def _sunburst_label(node):
    """Display name for a node: hr_hier for top-level keys, otherwise underscores -> spaces, title-cased."""
    return hr_hier.get(node, node.replace("_", " ").title())


# Use the same top-level display names (e.g. "Tools & Hardware") as the pie chart and CSV output.
custhier = [(_sunburst_label(high), _sunburst_label(low), val) for high, low, val in custhier]
parents, labels, values = zip(*custhier)
fig = go.Figure(go.Sunburst(
    labels=labels,
    parents=parents,
    values=values,
))
fig.update_layout(margin=dict(t=0, b=0, l=0, r=0), autosize=False, width=800, height=1000)

fig.show()
# To export a static image (requires kaleido): fig.write_image("fig1.png")

In [87]:
import pandas as pd
import plotly.express as px
# Pie chart of object counts per top-level node, labelled with the human-readable names.
df = pd.DataFrame([
    {"category": hr_hier[high_key], "count": total}
    for _root, high_key, total in hierarchyhigh
])
fig = px.pie(df, names="category", values="count")
fig.update_layout(margin={"t": 0, "b": 0, "l": 0, "r": 0}, autosize=False, width=600, height=600)
fig.show()

In [85]:
import numpy as np
# Grand total over the top-level nodes, then a "count/Display Name" listing, one node per line.
totalhigh = sum(total for _root, _high, total in hierarchyhigh)
print(",\n".join(f"{total}/{hr_hier[high_key]}" for _root, high_key, total in hierarchyhigh))

1510/Food Items,
2917/Household Items,
181/Tools & Hardware,
271/Electronics,
153/Personal Care Items,
202/Clothing & Accessories,
487/Office Supplies,
964/Other


In [102]:
import kaleido

In [109]:
def _sankey_label(node):
    """Display name for a node: hr_hier for top-level keys, otherwise underscores -> spaces, title-cased."""
    return hr_hier.get(node, node.replace("_", " ").title())


# One "Source [value] Target" line per edge (presumably for pasting into a text-driven Sankey
# tool such as SankeyMATIC). Top-level names match the pie chart and CSV output.
for source, target, value in fullhierarchy:
    source_str = _sankey_label(source)
    target_str = _sankey_label(target)
    print(f'{source_str} [{value}] {target_str}')

Food Items [108] Fruits
Food Items [160] Vegetables
Food Items [277] Beverages
Food Items [91] Baking Ingredients
Food Items [81] Snacks
Food Items [73] Condiments & Sauces
Food Items [63] Meats
Food Items [19] Seafood
Food Items [62] Dairy
Food Items [66] Grains & Pasta
Food Items [115] Herbs & Spices
Food Items [85] Sweets
Food Items [91] Other Food Items
Household Items [61] Cleaning Supplies
Household Items [433] Kitchenware
Household Items [340] Decor
Household Items [833] Furniture
Household Items [110] Appliances
Household Items [457] Other Household Items
Tools Hardware [16] Hand Tools
Tools Hardware [5] Power Tools
Tools Hardware [29] Gardening Tools
Tools Hardware [21] Hardware
Tools Hardware [43] Other Tools & Hardware
Electronics [47] Computing
Electronics [45] Entertainment
Electronics [9] Communications
Electronics [49] Other Electronics
Personal Care Items [52] Hygiene
Personal Care Items [14] Beauty
Personal Care Items [48] Healthcare
Personal Care Items [22] Other Pers