In [19]:
import json
import os
from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from PIL import Image
from tqdm import tqdm

In [20]:
# Category names treated as garment *parts* rather than main garments.
# Intended to cover the apparel-part classes of the Fashionpedia ontology;
# "button" is kept from the original hand-written list.
# NOTE(review): confirm against the ClassLabel names of the loaded dataset --
# any part class missing here is reported as a main garment.
DEFAULT_PART_CATEGORIES = frozenset({
    'hood', 'collar', 'lapel', 'epaulette', 'sleeve', 'pocket', 'neckline',
    'buckle', 'zipper', 'applique', 'bead', 'bow', 'flower', 'fringe',
    'ribbon', 'rivet', 'ruffle', 'sequin', 'tassel', 'button',
})


def _category_names(features):
    """Return the ClassLabel ``names`` list for the ``objects.category`` feature.

    Depending on the installed ``datasets`` version, ``features['objects']``
    may be a plain dict, a list ``[{...}]``, or a wrapper (``Sequence``/``List``)
    exposing the inner feature via ``.feature``. ``['category']`` may likewise
    be wrapped. Unwrap whichever form is present before reading ``.names``.
    """
    objects_feature = features['objects']
    if isinstance(objects_feature, list):
        objects_feature = objects_feature[0]
    objects_feature = getattr(objects_feature, 'feature', objects_feature)
    category_feature = objects_feature['category']
    category_feature = getattr(category_feature, 'feature', category_feature)
    return category_feature.names


def analyze_categories(dataset=None, split='train', top_n=10,
                       part_categories=DEFAULT_PART_CATEGORIES, show_plot=True):
    """Summarize per-category object counts and annotated areas.

    Tallies every annotated object in ``dataset[split]`` and builds a
    per-category summary table. It prints the ``top_n`` main garments
    (categories not in ``part_categories``) ranked by total annotated area,
    and optionally plots their average area.

    Args:
        dataset: A pre-loaded ``DatasetDict``. When None (the default), it is
            downloaded with ``load_dataset("detection-datasets/fashionpedia")``.
            Passing one in avoids re-downloading on every call. It also lets
            the analysis run when the Hub download fails, e.g. with SSL
            certificate errors.
        split: Split to analyze (default ``'train'``).
        top_n: Number of main garments to print and plot (default 10).
        part_categories: Category names treated as garment parts and
            excluded from the main-garment ranking.
        show_plot: Whether to draw the bar chart of average areas.

    Returns:
        pandas.DataFrame with one row per category and columns
        ``category``, ``count``, ``avg_area``, ``area_std``, ``is_part``,
        ``score``, sorted by ``score`` descending. The frame is empty (but
        still has those columns) when the split contains no objects.
    """
    if dataset is None:
        dataset = load_dataset("detection-datasets/fashionpedia")
        print("Dataset loaded")
    split_ds = dataset[split]

    # Resolve the id -> name table once instead of once per object.
    class_names = _category_names(split_ds.features)

    # Every annotated area, grouped by category name; count == len(list).
    areas_by_category = defaultdict(list)

    print("Analyzing object statistics...")
    # Read only the 'objects' column: iterating whole rows would also decode
    # every image, which this analysis never looks at.
    for objects in tqdm(split_ds['objects']):
        for cat, area in zip(objects['category'], objects['area']):
            areas_by_category[class_names[cat]].append(area)

    summary = pd.DataFrame(
        [
            {
                'category': name,
                'count': len(areas),
                'avg_area': sum(areas) / len(areas),
                'area_std': np.std(areas),
            }
            for name, areas in areas_by_category.items()
        ],
        # Explicit columns keep the frame well-formed even with zero objects.
        columns=['category', 'count', 'avg_area', 'area_std'],
    )
    # isin always yields a bool column, so the ~ / boolean indexing below
    # also works on an empty frame.
    summary['is_part'] = summary['category'].isin(part_categories)

    # count * avg_area is the total annotated area of the category.
    summary['score'] = summary['count'] * summary['avg_area']
    summary = summary.sort_values('score', ascending=False)

    # Keep only main garments (not parts) for the ranking.
    main_garments = summary[~summary['is_part']].head(top_n)

    print(f"\nTop {top_n} Main Garments by frequency and size:")
    print(main_garments[['category', 'count', 'avg_area']].to_string())

    if show_plot:
        fig, ax = plt.subplots(figsize=(15, 6))
        ax.bar(main_garments['category'], main_garments['avg_area'])
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        ax.set_title('Average Area by Category')
        ax.set_ylabel('Average Area (pixels²)')
        fig.tight_layout()
        plt.show()

    return summary

In [21]:
stats_df = analyze_categories()

# Recommended easy labels based on:
# 1. Large average area (easier to detect)
# 2. High count (more training examples)
# 3. Clear, distinct appearance
recommended_labels = [
    "coat",           # Large outer garment, distinct shape
    "dress",          # Full-body garment, large area
    "pants",          # Clear lower-body garment
    "shirt, blouse",  # Common upper-body garment
    "bag, wallet",    # Distinct accessory
]

# The labels above are hand-picked; fail loudly if any of them does not match
# a category name in the analyzed split (typo, renamed class, wrong dataset).
missing_labels = sorted(set(recommended_labels) - set(stats_df['category']))
assert not missing_labels, f"Recommended labels not found in dataset: {missing_labels}"

# Display the statistics that back the recommendation.
stats_df[stats_df['category'].isin(recommended_labels)]
    

SSLError: (MaxRetryError("HTTPSConnectionPool(host='cdn-lfs.hf.co', port=443): Max retries exceeded with url: /repos/76/2b/762b4c4b8680ad7f14a7ea37aa912882e6345f3c9722627115a6eb005356d61a/a575e7d8303d52197cd7d8a5e327321a602d3eeafe39e47a0e5d2c03041549c2?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27train-00000-of-00007-fe108070118553c3.parquet%3B+filename%3D%22train-00000-of-00007-fe108070118553c3.parquet%22%3B&Expires=1731092917&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczMTA5MjkxN319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5oZi5jby9yZXBvcy83Ni8yYi83NjJiNGM0Yjg2ODBhZDdmMTRhN2VhMzdhYTkxMjg4MmU2MzQ1ZjNjOTcyMjYyNzExNWE2ZWIwMDUzNTZkNjFhL2E1NzVlN2Q4MzAzZDUyMTk3Y2Q3ZDhhNWUzMjczMjFhNjAyZDNlZWFmZTM5ZTQ3YTBlNWQyYzAzMDQxNTQ5YzI~cmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=EEVMsgpQvm6DXCVk3dHPnAMx2WPk0vcVjYnIz7ZnrtVv1Nxb3fpsjge7GFZ-lPDKgnRrrxpwQl42wR0UIPdVqKc48q-d8yKLSvMDHQTrgPcAed8Y9UH3VDBnjrTL9KxR4GRKE3SqSqK4J29AqQXsspeCHmxMPiXPT~isB~1qXtiS0qbjJw~bGnRjOV9BNoP29ftrTXtblkqdYPHAhVTB53k3ZQOmCoEnnrlmp03GpD5A3N7sCfVhM-Es8A5gx1dptTVpCU0Enf8Y6IFniLL1tK7h5oc0aWRGn4Jn2mKz-yEEMB8UT91zgvUGiqo50jUkjNHc729KaQ2hX2iaJCS31Q__&Key-Pair-Id=K3RPWS32NSSJCE (Caused by SSLError(SSLCertVerificationError(1, '[SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1131)')))"), '(Request ID: 1e6f8f67-77dd-4a9e-84d0-93fadf994074)')