In [29]:
from sentence_transformers import SentenceTransformer
import numpy as np

# Load the Sentence-BERT checkpoint. This one `model` object is used below to
# encode the category/subcategory descriptions and is passed to the search
# function to encode the query text.
model = SentenceTransformer('all-MiniLM-L6-v2')
# Alternative checkpoint: swap it in for the line above to compare results.
# If you do, re-run the embedding cell so the stored vectors and the query
# vectors come from the same model.
# model = SentenceTransformer('all-mpnet-base-v2')

# One natural-language description per top-level category; each string is
# embedded as-is by the model further down.
# The keys double as lookup keys into the per-category subcategory mapping
# (the search function indexes subcategory embeddings by category key), so
# every key here needs a matching key in `subcategory_descriptions`.
category_descriptions = {
    'books_movies_and_music': "Items related to books, movies, music, and other media.",
    'clothing_shoes_and_accessories': "Men's and women's clothing, shoes, and accessories.",
    'ebay_motors': "Vehicles, vehicle parts, and accessories.",
    'electronics': "Devices and gadgets like cameras, headphones, laptops, and video games.",
    'home_and_garden': "Furniture, decor, and other home and garden items.",
    'jewelry_and_watches': "Jewelry and watches for men and women.",
    'pet_supplies': "Supplies for pets, including dogs, cats, and other animals.",
    'sporting_goods': "Sports equipment and outdoor gear."
}

# Nested mapping: category key -> {subcategory name -> description sentence}.
# The outer keys mirror `category_descriptions`; the inner description strings
# are what gets embedded, while the inner keys are the labels reported in the
# search results.
subcategory_descriptions = {
    'books_movies_and_music': {
        'Books': "Books across various genres and formats.",
        'DVDs': "DVDs of movies, shows, and educational content.",
        'Guitars-Basses': "Guitars, basses, and related musical instruments.",
        'Pianos-Keyboards-Organs': "Pianos, keyboards, and organs for music enthusiasts."
    },
    'clothing_shoes_and_accessories': {
        'Mens-Clothing': "Clothing for men, including shirts, pants, jackets, and more.",
        'Mens-Shoes': "Shoes for men, including casual, formal, and sports shoes.",
        'Travel-Luggage': "Luggage and bags for travel.",
        'Womens-Clothing': "Clothing for women, including dresses, tops, and pants.",
        'Womens-Shoes': "Shoes for women, including heels, flats, and sports shoes."
    },
    'ebay_motors': {
        'ATVs': "All-terrain vehicles for off-road adventures.",
        'Boats': "Boats for leisure, fishing, and water sports.",
        'Cadillac': "Cadillac vehicles and accessories.",
        'Ford': "Ford vehicles and related parts.",
        'Jeep': "Jeep vehicles and accessories.",
        'Mercedes-Benz': "Mercedes-Benz cars and parts.",
        'Scooters-Mopeds': "Scooters and mopeds for urban transportation.",
        'Toyota': "Toyota cars and accessories.",
        'Toyota-Supra-Cars': "Toyota Supra cars and related accessories.",
        # NOTE(review): this description mentions "musical instruments" although
        # the entry sits under ebay_motors, which may pull its embedding toward
        # music-related queries — confirm this wording is intended.
        'Yamaha': "Yamaha vehicles and musical instruments."
    },
    'electronics': {
        'Digital-Cameras': "Digital cameras for photography.",
        'Headphones': "Headphones for personal audio experiences.",
        'Laptops-Netbooks': "Laptops and netbooks for personal and professional use.",
        'Video-Games': "Video games and gaming consoles."
    },
    'home_and_garden': {
        'Beds-Headboards': "Beds and headboards for comfortable sleeping.",
        'Chairs': "Chairs for seating in various settings.",
        'Chandeliers-Ceiling-Fixtures': "Lighting fixtures including chandeliers.",
        'Tables': "Tables for dining, working, and other uses."
    },
    'jewelry_and_watches': {
        'Engagement-Rings': "Engagement rings in various designs.",
        'Fine-Earrings': "Fine earrings for various occasions.",
        'Watches': "Watches for men and women."
    },
    'pet_supplies': {
        'Dog-Supplies': "Supplies for dogs including food, toys, and accessories.",
        'Fish-Aquariums': "Fish and aquarium supplies."
    },
    'sporting_goods': {
        'Archery-Equipment': "Equipment for archery enthusiasts.",
        'Basketball-Equipment': "Basketball gear and equipment.",
        'Boxing-MMA-Equipment': "Boxing and MMA equipment for training.",
        'Golf-Accessories': "Accessories for golf players."
    }
}

# Embed every description once up front so later lookups only need to encode the query.
def _encode_descriptions(descriptions):
    """Return a dict with the same keys as ``descriptions``, each mapped to ``model.encode`` of its text."""
    return {name: model.encode(text) for name, text in descriptions.items()}


category_embeddings = _encode_descriptions(category_descriptions)
subcategory_embeddings = {
    parent: _encode_descriptions(children)
    for parent, children in subcategory_descriptions.items()
}

# Sanity check: show only the first five components of each vector so the
# cell output stays readable.
print("Category Embeddings:")
for category in category_embeddings:
    embedding = category_embeddings[category]
    print(f"{category}: {embedding[:5]}...")  # leading components only

print("\nSubcategory Embeddings:")
# Flatten the category -> subcategory nesting; the parent category is not printed.
_subcategory_pairs = (
    pair
    for subcategories in subcategory_embeddings.values()
    for pair in subcategories.items()
)
for subcategory, embedding in _subcategory_pairs:
    print(f"{subcategory}: {embedding[:5]}...")  # leading components only


Category Embeddings:
books_movies_and_music: [ 0.01548985 -0.03924794 -0.04360467 -0.00504606  0.0211649 ]...
clothing_shoes_and_accessories: [0.02551067 0.03625545 0.02715615 0.02555301 0.03353601]...
ebay_motors: [ 0.00602218  0.03761044  0.07130247 -0.02300078  0.00485703]...
electronics: [-0.01214534  0.020263    0.06607424 -0.0968603   0.06716678]...
home_and_garden: [0.0327583  0.04490215 0.07843655 0.00485398 0.01965619]...
jewelry_and_watches: [-0.00795137  0.04848052  0.01808016  0.02092567 -0.11314842]...
pet_supplies: [ 0.00161075  0.03867755  0.0821714   0.03885077 -0.05908728]...
sporting_goods: [ 0.00379407  0.07182351  0.06321409 -0.01055412  0.00652471]...

Subcategory Embeddings:
Books: [ 7.96836539e-05 -8.34209919e-02 -4.56190221e-02  1.16966395e-02
 -6.95318505e-02]...
DVDs: [ 0.01236952 -0.06343742 -0.01189125 -0.01657241 -0.04884796]...
Guitars-Basses: [ 0.01787486 -0.04763495  0.00674222 -0.02637216 -0.13242719]...
Pianos-Keyboards-Organs: [ 0.04348233 -0.0436656 

In [30]:
# Function to compute cosine similarity
def cosine_similarity(v1, v2):
    """Return the cosine similarity between two 1-D vectors.

    Computed directly with NumPy (imported as ``np`` at the top of the
    notebook) so the function does not rely on a ``cosine`` name having been
    imported by some other cell.

    Args:
        v1: First vector (1-D array-like of numbers).
        v2: Second vector, same length as ``v1``.

    Returns:
        float: Similarity in ``[-1.0, 1.0]``. ``0.0`` is returned when either
        vector has zero norm, where the cosine is undefined; this keeps NaN
        out of the score-based sorting done by callers.

    Raises:
        ValueError: If the two vectors differ in length (raised by ``np.dot``).
    """
    # Work in float64 so float32 embeddings do not lose precision in the dot product.
    a = np.asarray(v1, dtype=np.float64)
    b = np.asarray(v2, dtype=np.float64)
    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    if denominator == 0.0:
        return 0.0
    # Clip away rounding error that can push the ratio marginally outside [-1, 1].
    return float(np.clip(np.dot(a, b) / denominator, -1.0, 1.0))

# Two-stage search: shortlist the closest categories, then rank the subcategories inside them
def get_top_subcategories(
    description,
    model,
    category_embeddings,
    subcategory_embeddings,
    top_n_categories=3,
    top_n_subcategories=5,
    verbose=True,
):
    """Rank subcategories for a free-text description using a two-stage search.

    The description is embedded once. The categories are scored against it and
    the best ``top_n_categories`` are kept; then every subcategory belonging to
    those categories is scored and the best ``top_n_subcategories`` overall are
    returned.

    Args:
        description: Text to classify; passed to ``model.encode``.
        model: Object exposing ``encode(text)`` that returns a vector.
        category_embeddings: Mapping of category key -> embedding vector.
        subcategory_embeddings: Mapping of category key ->
            {subcategory name -> embedding vector}. It must contain every
            shortlisted category key.
        top_n_categories: Number of categories kept after the first stage.
        top_n_subcategories: Maximum number of results returned.
        verbose: When true (the default, matching the previous behavior),
            print the shortlisted category keys. Pass ``False`` to silence it.

    Returns:
        list[tuple[str, float]]: ``(subcategory, similarity)`` pairs sorted by
        similarity, highest first. The parent category is not part of the
        result.

    Raises:
        KeyError: If a shortlisted category has no entry in
            ``subcategory_embeddings``.
    """
    # Encode the query once; both stages compare against this single vector.
    description_embedding = model.encode(description)

    # Stage 1: score every category and keep the closest ones.
    category_scores = {
        category: cosine_similarity(description_embedding, embedding)
        for category, embedding in category_embeddings.items()
    }
    top_categories = sorted(category_scores, key=category_scores.get, reverse=True)[:top_n_categories]
    if verbose:
        print(top_categories)

    # Stage 2: score every subcategory of the shortlisted categories. A single
    # stable sort over all candidates yields the same top-N as sorting and
    # truncating per category first, without the redundant intermediate sorts.
    candidates = [
        (subcategory, cosine_similarity(description_embedding, sub_embedding))
        for category in top_categories
        for subcategory, sub_embedding in subcategory_embeddings[category].items()
    ]
    candidates.sort(key=lambda pair: pair[1], reverse=True)
    return candidates[:top_n_subcategories]

# Example: classify a free-text shopping request.
# The query is written as adjacent string literals (one sentence per line);
# Python joins them into a single string, so mind the trailing spaces.
description = (
    "Loose-fitting, casual dress with long sleeves and a relaxed, oversized silhouette. "
    "The dress should have a scoop neckline with raw-edge detailing, a slightly dropped waist, and a rounded hemline that falls above the knee. "
    "Made from a soft, lightweight fabric like cotton or jersey for a flowy, comfortable look. "
    "Looking specifically for a solid red color, preferably in shades like burgundy, crimson, or wine."
)
top_subcategories = get_top_subcategories(
    description,
    model,
    category_embeddings=category_embeddings,
    subcategory_embeddings=subcategory_embeddings,
)

['jewelry_and_watches', 'clothing_shoes_and_accessories', 'sporting_goods']


In [31]:
# Report each ranked subcategory with its similarity score, formatted to four decimals.
for ranked_pair in top_subcategories:
    subcategory, score = ranked_pair
    print(f"Subcategory: {subcategory}, Similarity Score: {score:.4f}")

Subcategory: Womens-Clothing, Similarity Score: 0.5436
Subcategory: Mens-Clothing, Similarity Score: 0.4503
Subcategory: Mens-Shoes, Similarity Score: 0.3145
Subcategory: Fine-Earrings, Similarity Score: 0.2811
Subcategory: Womens-Shoes, Similarity Score: 0.2676
