In [1]:
import os
import json
import torch
from datasets import load_dataset
from PIL import Image
import pandas as pd
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
import torch
from datasets import load_dataset
from collections import Counter
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import json
from tqdm import tqdm

class FashionpediaPrep:
    """Prepare a COCO-style subset of the Fashionpedia detection dataset.

    Wraps a Hugging Face ``DatasetDict`` loaded from the Hub and offers two steps:
    ``analyze_labels`` (per-split label frequencies, printed and plotted) and
    ``create_subset`` (filter rows to chosen labels, save JPEGs plus COCO JSON).

    Two per-row annotation layouts are accepted: top-level ``'labels'``/``'boxes'``
    lists, or an ``'objects'`` dict holding ``'category'``/``'bbox'`` lists.
    NOTE(review): inspect ``self.dataset['train'].features`` to confirm which layout
    (and which bbox convention) the loaded dataset actually uses.
    """

    # Hub datasets name the held-out split either "validation" or "val".
    _SPLIT_ALIASES = {
        "validation": ("validation", "val"),
        "val": ("val", "validation"),
    }
    # PIL image modes the JPEG encoder can write without a mode conversion.
    _JPEG_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})

    def __init__(self, cache_dir=None):
        """Load ``detection-datasets/fashionpedia`` from the Hugging Face Hub.

        Args:
            cache_dir: Optional cache directory passed to ``load_dataset``;
                ``None`` uses the library default.
        """
        self.dataset = load_dataset("detection-datasets/fashionpedia", cache_dir=cache_dir)
        print(f"Dataset loaded with splits: {self.dataset.keys()}")

    # ------------------------------------------------------------------ helpers

    def _split_name(self, split):
        """Return the key in ``self.dataset`` that corresponds to ``split``.

        Raises:
            KeyError: if neither ``split`` nor any of its aliases is present.
        """
        for candidate in self._SPLIT_ALIASES.get(split, (split,)):
            if candidate in self.dataset:
                return candidate
        raise KeyError(f"Split {split!r} not found; available splits: {list(self.dataset.keys())}")

    @staticmethod
    def _row_labels(item):
        """Return one row's labels (top-level ``'labels'`` or ``objects['category']``)."""
        if "labels" in item:
            return item["labels"]
        return item["objects"]["category"]

    @staticmethod
    def _row_boxes(item):
        """Return one row's boxes (top-level ``'boxes'`` or ``objects['bbox']``)."""
        if "boxes" in item:
            return item["boxes"]
        return item["objects"]["bbox"]

    @staticmethod
    def _without_images(split_data):
        """Return ``split_data`` without its ``'image'`` column when that is possible.

        Iterating rows of a Hugging Face ``Dataset`` decodes the image feature for
        every row; dropping the column first makes label-only passes far cheaper.
        Objects without ``column_names``/``remove_columns`` are returned unchanged.
        """
        column_names = getattr(split_data, "column_names", None)
        if column_names and "image" in column_names and hasattr(split_data, "remove_columns"):
            return split_data.remove_columns("image")
        return split_data

    def _count_labels(self, split):
        """Return a ``Counter`` of every label occurrence across all rows of ``split``."""
        counts = Counter()
        for item in self._without_images(self.dataset[self._split_name(split)]):
            counts.update(self._row_labels(item))
        return counts

    @staticmethod
    def _label_table(train_counter, val_counter):
        """Build the per-label frequency table returned by ``analyze_labels``.

        Args:
            train_counter, val_counter: mappings label -> count (e.g. ``Counter``).

        Returns:
            DataFrame with 'Label', 'Train Count', 'Val Count', 'Total',
            sorted by 'Total' descending.
        """
        # Compute the label universe once so every column shares one ordering;
        # sorting by str() makes row order reproducible across kernel restarts.
        all_labels = sorted(set(train_counter) | set(val_counter), key=str)
        table = pd.DataFrame({
            'Label': all_labels,
            'Train Count': [train_counter.get(label, 0) for label in all_labels],
            'Val Count': [val_counter.get(label, 0) for label in all_labels],
        })
        table['Total'] = table['Train Count'] + table['Val Count']
        # Stable sort keeps the deterministic label order among equal totals.
        return table.sort_values('Total', ascending=False, kind='stable')

    @staticmethod
    def _plot_top_labels(df, top_n=20):
        """Draw a stacked train/val bar chart of the first ``top_n`` rows of ``df``."""
        top = df.head(top_n)
        positions = list(range(len(top)))
        # Positional arrays avoid any pandas index alignment after sorting.
        train_counts = top['Train Count'].to_numpy()
        fig, ax = plt.subplots(figsize=(15, 8))
        ax.bar(positions, train_counts, label='Train', alpha=0.8)
        ax.bar(positions, top['Val Count'].to_numpy(), bottom=train_counts, label='Val', alpha=0.8)
        ax.set_xticks(positions)
        ax.set_xticklabels([str(label) for label in top['Label']], rotation=45, ha='right')
        ax.set_xlabel('Label')
        ax.set_ylabel('Count')
        ax.set_title(f'Top {top_n} Labels Distribution')
        ax.legend()
        fig.tight_layout()
        plt.show()

    @classmethod
    def _keep_selected(cls, item, wanted):
        """Restrict one row's annotations to labels contained in ``wanted``.

        Returns:
            ``(labels, boxes)`` lists for the matching annotations, or ``None``
            when the row contains no wanted label.
        """
        labels = cls._row_labels(item)
        boxes = cls._row_boxes(item)
        kept = [i for i, label in enumerate(labels) if label in wanted]
        if not kept:
            return None
        return [labels[i] for i in kept], [boxes[i] for i in kept]

    def _filter_and_select(self, split, n_samples, wanted):
        """Collect up to ``n_samples`` rows of ``split`` that contain a wanted label.

        Args:
            split: logical split name ("train" or "validation").
            n_samples: maximum number of rows to return; ``None`` keeps every
                match and values <= 0 return an empty list.
            wanted: container of labels to keep (only membership is used).

        Returns:
            list of dicts with keys 'image', 'boxes', 'labels', 'image_id', where
            boxes/labels are restricted to the wanted labels.
        """
        # Checked up front so a limit of 0 never admits a row.
        if n_samples is not None and n_samples <= 0:
            return []
        full_split = self.dataset[self._split_name(split)]
        selected = []
        # Scan an image-free view; decode the image only for rows that are kept.
        for idx, item in enumerate(tqdm(self._without_images(full_split), desc=f"Processing {split}")):
            kept = self._keep_selected(item, wanted)
            if kept is None:
                continue
            labels, boxes = kept
            selected.append({
                'image': full_split[idx]['image'],
                'boxes': boxes,
                'labels': labels,
                'image_id': item['image_id'],
            })
            if n_samples is not None and len(selected) >= n_samples:
                break
        return selected

    @staticmethod
    def _to_coco_bbox(box, box_format="xywh"):
        """Return ``box`` as a COCO ``[x, y, width, height]`` list.

        ``'xywh'`` passes the box through unchanged; ``'xyxy'`` interprets it as
        ``[x_min, y_min, x_max, y_max]`` and converts it.
        """
        if box_format == "xyxy":
            x_min, y_min, x_max, y_max = box
            return [x_min, y_min, x_max - x_min, y_max - y_min]
        return list(box)

    @staticmethod
    def _image_file_name(item):
        """File name used for an item's saved image."""
        return f"{item['image_id']}.jpg"

    @classmethod
    def _coco_annotations(cls, data, categories, category_ids, box_format="xywh"):
        """Build the COCO annotation dict for one split's selected items.

        Image ids are positions in ``data``; annotation ids restart at 0 per call.
        ``category_ids`` maps each label to its 1-based COCO category id.
        """
        images = []
        annotations = []
        for image_idx, item in enumerate(data):
            images.append({
                "id": image_idx,
                "file_name": cls._image_file_name(item),
                "width": item['image'].width,
                "height": item['image'].height,
            })
            for box, label in zip(item['boxes'], item['labels']):
                bbox = cls._to_coco_bbox(box, box_format)
                annotations.append({
                    "id": len(annotations),
                    "image_id": image_idx,
                    "category_id": category_ids[label],
                    "bbox": bbox,
                    "area": bbox[2] * bbox[3],
                    "iscrowd": 0,
                })
        return {"images": images, "annotations": annotations, "categories": categories}

    @classmethod
    def _save_images(cls, data, split_dir, split):
        """Write every item's image into ``split_dir`` as ``<image_id>.jpg``."""
        for item in tqdm(data, desc=f"Saving {split}"):
            image = item['image']
            # JPEG cannot store alpha/palette modes (e.g. RGBA, P); PIL would raise.
            if image.mode not in cls._JPEG_MODES:
                image = image.convert("RGB")
            image.save(split_dir / cls._image_file_name(item))

    # --------------------------------------------------------------- public API

    def analyze_labels(self):
        """Count, print and plot label frequencies for the train and validation splits.

        Returns:
            pandas.DataFrame with columns 'Label', 'Train Count', 'Val Count' and
            'Total', sorted by 'Total' descending. The 20 most frequent labels are
            also drawn as a stacked bar chart.
        """
        label_table = self._label_table(
            self._count_labels('train'),
            self._count_labels('validation'),
        )
        print("\nLabel Statistics:")
        print(label_table)
        self._plot_top_labels(label_table, top_n=20)
        return label_table

    def create_subset(self, selected_labels, n_train, m_val, output_dir="fashion_subset",
                      box_format="xywh"):
        """Export a subset restricted to ``selected_labels`` as JPEG images + COCO JSON.

        Files written under ``output_dir``: ``train/<image_id>.jpg``,
        ``val/<image_id>.jpg``, ``train_annotations.json``, ``val_annotations.json``.

        Args:
            selected_labels: labels to keep; their order defines the 1-based COCO
                category ids. Duplicates are ignored.
            n_train: maximum number of training images (``None`` = all matches).
            m_val: maximum number of validation images (``None`` = all matches).
            output_dir: destination directory, created with parents if missing.
            box_format: ``'xywh'`` (default) writes source boxes through unchanged;
                ``'xyxy'`` converts ``[x_min, y_min, x_max, y_max]`` boxes to COCO.

        Returns:
            ``pathlib.Path`` of ``output_dir``.

        Raises:
            ValueError: for an unknown ``box_format``.
        """
        if box_format not in ("xywh", "xyxy"):
            raise ValueError(f"box_format must be 'xywh' or 'xyxy', got {box_format!r}")

        # Category ids follow the caller's order; dict.fromkeys drops duplicates
        # and gives O(1) label -> id lookups (instead of list.index per box).
        category_ids = {label: cid for cid, label in enumerate(dict.fromkeys(selected_labels), 1)}
        categories = [{"id": cid, "name": label} for label, cid in category_ids.items()]

        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        train_subset = self._filter_and_select('train', n_train, category_ids)
        val_subset = self._filter_and_select('validation', m_val, category_ids)

        for split, data in (('train', train_subset), ('val', val_subset)):
            split_dir = output_dir / split
            split_dir.mkdir(exist_ok=True)
            self._save_images(data, split_dir, split)
            coco_annotations = self._coco_annotations(data, categories, category_ids, box_format)
            with open(output_dir / f"{split}_annotations.json", 'w') as f:
                json.dump(coco_annotations, f)

        print(f"\nDataset created at {output_dir}")
        print(f"Train: {len(train_subset)} images")
        print(f"Val: {len(val_subset)} images")

        return output_dir

In [10]:
# Build the prep helper using the default Hugging Face cache location, then
# compute per-split label frequencies (prints a table and draws a bar chart).
prep = FashionpediaPrep(cache_dir=None)
label_stats = prep.analyze_labels()
    

DatasetNotFoundError: Dataset 'valentinafeve/yolos-fashionpedia' doesn't exist on the Hub or cannot be accessed.