# Object Bank Generation for Synthetic Data

This notebook generates an object bank by extracting segmented objects from the TACO dataset (`datasets/taco_official`). The object bank is saved to `datasets/object_bank_for_balancing` and can then be used by `roboflow_augmentation.ipynb` for synthetic data generation and dataset balancing.

In [None]:
import os
import json
import cv2
import numpy as np
import yaml
from PIL import Image, ImageDraw
import random
import shutil
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from collections import defaultdict, Counter
from pathlib import Path

print("Libraries imported successfully!")

## 1. Configuration

In [None]:
from pathlib import Path

# --- Paths ---
# BASE_DIR is empty, so every path below is relative to the current working
# directory of the notebook kernel.
BASE_DIR = Path('')
TACO_DIR = BASE_DIR / 'datasets/taco_official'
TACO_ANNOTATIONS_PATH = TACO_DIR / 'annotations.json'
OBJECT_BANK_DIR = BASE_DIR / 'datasets/object_bank_for_balancing'

# --- Object Bank Parameters ---
OBJECTS_PER_CLASS_IN_BANK = 300  # Number of objects to extract per class
MIN_OBJECT_PIXEL_AREA = 40 * 40  # Minimum area for objects (in pixels)

# --- Target Classes (standard waste detection classes) ---
# The list index of each name is the numeric class ID used in the mapping below.
TARGET_CLASSES = ['glass', 'metal', 'organic', 'paper', 'plastic']

# --- TACO to Target Class Mapping ---
# This mapping connects the TACO dataset categories to our target classes.
# Values are indices into TARGET_CLASSES; -1 marks a category to be ignored.
TACO_TO_TARGET_MAPPING = {
    # Glass (class 0)
    'Glass bottle': 0, 'Glass cup': 0, 'Glass jar': 0, 'Broken glass': 0,

    # Metal (class 1)
    'Aluminium foil': 1, 'Aluminium blister pack': 1, 'Drink can': 1,
    'Food Can': 1, 'Pop tab': 1, 'Scrap metal': 1, 'Aerosol': 1, 'Metal lid': 1,

    # Organic (class 2)
    'Food waste': 2,

    # Paper (class 3)
    'Paper': 3, 'Paper cup': 3, 'Drink carton': 3, 'Normal paper': 3,
    'Tissues': 3, 'Wrapping paper': 3, 'Magazine': 3, 'Carded blister pack': 3,
    'Other carton': 3, 'Meal carton': 3, 'Pizza box': 3, 'Paper bag': 3,

    # Plastic (class 4)
    'Clear plastic bottle': 4, 'Other plastic bottle': 4, 'Plastic bottle cap': 4,
    'Other plastic cup': 4, 'Plastic lid': 4, 'Shopping bag': 4, 'Plastic straw': 4,
    'Other plastic wrapper': 4, 'Other plastic': 4, 'Styrofoam piece': 4,
    'Plastic film': 4, 'Squeezable tube': 4, 'Plastic utensils': 4,
    'Tupperware': 4, 'Plastic glooves': 4, 'Lighter': 4,

    # Ignored categories (mapped to -1)
    'Unlabeled litter': -1, 'Cigarette': -1, 'Shoe': -1, 'Battery': -1,
    'Rope & strings': -1, 'Medical waste': -1,
}

# Create object bank directory. parents=True so that a missing `datasets/`
# folder does not abort the setup cell with FileNotFoundError.
OBJECT_BANK_DIR.mkdir(parents=True, exist_ok=True)

print("🚀 Object Bank Generation Setup")
print("=" * 50)
print(f"TACO Dataset Directory: {TACO_DIR}")
print(f"TACO Annotations: {TACO_ANNOTATIONS_PATH}")
print(f"Object Bank Output: {OBJECT_BANK_DIR}")
print(f"Target Classes: {TARGET_CLASSES}")
print(f"Objects per class: {OBJECTS_PER_CLASS_IN_BANK}")
print(f"Minimum object area: {MIN_OBJECT_PIXEL_AREA} pixels")

# Verify TACO dataset exists (report only; later cells re-check before loading)
if not TACO_DIR.exists():
    print(f"❌ ERROR: TACO directory not found at {TACO_DIR}")
elif not TACO_ANNOTATIONS_PATH.exists():
    print(f"❌ ERROR: TACO annotations not found at {TACO_ANNOTATIONS_PATH}")
else:
    print("✅ TACO dataset found")

print("\nTACO category mappings:")
for taco_cat, target_id in sorted(TACO_TO_TARGET_MAPPING.items()):
    if target_id != -1:
        target_name = TARGET_CLASSES[target_id]
        print(f"  {taco_cat} → {target_name} (ID: {target_id})")
    else:
        print(f"  {taco_cat} → IGNORED")

## 2. TACO Dataset Analysis

In [None]:
def _target_class_for_category(cat_info):
    """Return the target class ID for one TACO category record, or -1 if ignored.

    The category's own ``name`` is looked up in ``TACO_TO_TARGET_MAPPING`` first,
    because the mapping is keyed by fine-grained names such as 'Glass bottle'.
    The ``supercategory`` is used as a fallback so annotation files whose
    supercategories match the mapping keys keep working.
    NOTE(review): in the official TACO annotations the supercategory of e.g.
    'Glass bottle' is presumably 'Bottle' - confirm against your annotations.json.
    """
    name = cat_info.get('name')
    if name in TACO_TO_TARGET_MAPPING:
        return TACO_TO_TARGET_MAPPING[name]
    return TACO_TO_TARGET_MAPPING.get(cat_info.get('supercategory'), -1)


def analyze_taco_dataset():
    """Analyze the TACO dataset to understand available categories and annotations.

    Loads ``TACO_ANNOTATIONS_PATH``, prints how each TACO category maps onto
    ``TARGET_CLASSES``, counts the annotations per target class whose ``area``
    is at least ``MIN_OBJECT_PIXEL_AREA``, and plots the class distribution.

    Returns:
        tuple: ``(taco_data, categories_info)`` where ``taco_data`` is the parsed
        annotations JSON and ``categories_info`` maps category ID to its record.
        ``(None, None)`` if the annotations file is missing or unreadable.
    """
    print("🔍 ANALYZING TACO DATASET")
    print("=" * 50)

    if not TACO_ANNOTATIONS_PATH.exists():
        print(f"❌ ERROR: Annotations file not found at {TACO_ANNOTATIONS_PATH}")
        return None, None

    print("Loading TACO annotations...")
    try:
        with open(TACO_ANNOTATIONS_PATH, 'r', encoding='utf-8') as f:
            taco_data = json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        print(f"❌ ERROR: Could not read annotations at {TACO_ANNOTATIONS_PATH}: {e}")
        return None, None

    print(f"✅ Loaded TACO dataset:")
    print(f"  - Images: {len(taco_data['images']):,}")
    print(f"  - Annotations: {len(taco_data['annotations']):,}")
    print(f"  - Categories: {len(taco_data['categories'])}")

    # Create category lookup
    categories_info = {cat['id']: cat for cat in taco_data['categories']}

    print(f"\n📊 TACO Categories:")
    for cat_info in categories_info.values():
        target_class_id = _target_class_for_category(cat_info)
        target_name = TARGET_CLASSES[target_class_id] if target_class_id != -1 else "IGNORED"
        print(f"  {cat_info.get('name')} ({cat_info.get('supercategory')}) → {target_name}")

    # Count annotations per target class
    print(f"\n📈 Annotations per target class:")
    target_class_counts = defaultdict(int)
    total_valid = 0

    for ann in taco_data['annotations']:
        cat_info = categories_info.get(ann.get('category_id'))
        if cat_info is None:
            continue
        target_class_id = _target_class_for_category(cat_info)
        if target_class_id != -1 and ann.get('area', 0) >= MIN_OBJECT_PIXEL_AREA:
            target_class_counts[target_class_id] += 1
            total_valid += 1

    for class_id, class_name in enumerate(TARGET_CLASSES):
        count = target_class_counts.get(class_id, 0)
        print(f"  {class_name} (ID {class_id}): {count:,} valid annotations")

    total_annotations = len(taco_data['annotations'])
    # Guard the percentage: an annotation file with zero annotations must not crash.
    valid_pct = (total_valid / total_annotations * 100) if total_annotations else 0.0
    print(f"\nValid annotations (area >= {MIN_OBJECT_PIXEL_AREA} pixels): {total_valid:,} / {total_annotations:,} ({valid_pct:.1f}%)")

    # Create visualization
    if target_class_counts:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

        # Bar chart of annotations per class (only classes that have annotations)
        present_ids = sorted(target_class_counts)
        classes = [TARGET_CLASSES[i] for i in present_ids]
        counts = [target_class_counts[i] for i in present_ids]

        colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7'][:len(classes)]
        bars = ax1.bar(classes, counts, color=colors, alpha=0.8)
        ax1.set_title('Valid TACO Annotations per Target Class')
        ax1.set_xlabel('Target Class')
        ax1.set_ylabel('Number of Annotations')
        ax1.tick_params(axis='x', rotation=45)

        # Add value labels on bars, lifted slightly above the bar top
        label_offset = max(counts) * 0.01
        for bar, count in zip(bars, counts):
            ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + label_offset,
                     f'{count:,}', ha='center', va='bottom', fontweight='bold')

        # Pie chart of class distribution
        ax2.pie(counts, labels=classes, colors=colors, autopct='%1.1f%%', startangle=90)
        ax2.set_title('Distribution of Valid Annotations')

        plt.tight_layout()
        plt.show()

    return taco_data, categories_info

# Run TACO analysis
taco_data, categories_info = analyze_taco_dataset()

## 3. Object Extraction and Bank Creation

In [None]:
def find_image_path(image_info, taco_dir):
    """Finds the full path of a TACO image, searching in batch folders.

    Args:
        image_info: Image record with a ``'file_name'`` entry.
        taco_dir: Root directory of the dataset (``Path`` or ``str``).

    Returns:
        Path to the image file, or ``None`` if it cannot be located.
    """
    taco_dir = Path(taco_dir)
    img_filename = str(image_info['file_name'])

    # First, check if the filename is already a relative path to an existing file
    direct_path = taco_dir / img_filename
    if direct_path.is_file():
        return direct_path

    # A name that already carries a directory component is not searched further.
    if '/' in img_filename or '\\' in img_filename:
        return None
    # Missing dataset root: nothing to scan (avoids FileNotFoundError from listing).
    if not taco_dir.is_dir():
        return None

    # Otherwise search inside batch_* subdirectories. Sorted so that the result
    # is deterministic when the same file name exists in several batches.
    batch_dirs = sorted(
        entry for entry in taco_dir.iterdir()
        if entry.name.startswith('batch_') and entry.is_dir()
    )
    for batch_dir in batch_dirs:
        candidate = batch_dir / img_filename
        if candidate.is_file():
            return candidate
    return None

def extract_object(image, segmentation):
    """Extracts an object from an image using its segmentation mask, returning an RGBA image.

    Args:
        image: Image array as loaded by ``cv2.imread`` (converted with
            ``COLOR_BGR2RGBA`` below, so a 3-channel BGR array is expected).
        segmentation: List of polygons, each a flat ``[x0, y0, x1, y1, ...]``
            coordinate list.

    Returns:
        A PIL RGBA image cropped to the mask's bounding box, with the mask as
        the alpha channel. ``None`` if no usable polygon is present, the mask
        is empty, or the bounding box is only one pixel wide or high.
    """
    # Only polygon lists are supported; a dict (e.g. an RLE-encoded mask) or a
    # missing segmentation is skipped instead of raising.
    if segmentation is None or isinstance(segmentation, (dict, str)):
        return None

    img_h, img_w = image.shape[:2]

    mask = np.zeros((img_h, img_w), dtype=np.uint8)
    for poly in segmentation:
        # A polygon needs at least 3 points and an even number of coordinates,
        # otherwise the reshape into (x, y) pairs below would raise.
        if len(poly) < 6 or len(poly) % 2 != 0:
            continue
        pts = np.array(poly, np.int32).reshape((-1, 2))
        cv2.fillPoly(mask, [pts], 1)

    if not mask.any():
        return None

    # Find bounding box of the mask to crop
    y_indices, x_indices = np.nonzero(mask)
    y_min, y_max = y_indices.min(), y_indices.max()
    x_min, x_max = x_indices.min(), x_indices.max()

    # Ensure we have a valid bounding box (more than one pixel in each direction)
    if y_max <= y_min or x_max <= x_min:
        return None

    # Crop image and mask to the (inclusive) bounding box
    cropped_img_bgr = image[y_min:y_max+1, x_min:x_max+1]
    cropped_mask = mask[y_min:y_max+1, x_min:x_max+1]

    # Create 4-channel RGBA image
    rgba_object = cv2.cvtColor(cropped_img_bgr, cv2.COLOR_BGR2RGBA)
    rgba_object[:, :, 3] = cropped_mask * 255  # Mask is 0/1, so alpha becomes 0/255

    return Image.fromarray(rgba_object)

def _lookup_target_class(cat_info):
    """Return the target class ID for one TACO category record, or -1 if ignored.

    The category ``name`` is tried first because ``TACO_TO_TARGET_MAPPING`` is
    keyed by fine-grained names such as 'Glass bottle'; the ``supercategory``
    is the fallback.
    NOTE(review): confirm which field of your annotations.json carries the
    names listed in the mapping.
    """
    name = cat_info.get('name')
    if name in TACO_TO_TARGET_MAPPING:
        return TACO_TO_TARGET_MAPPING[name]
    return TACO_TO_TARGET_MAPPING.get(cat_info.get('supercategory'), -1)


def _plot_extraction_summary(extraction_stats):
    """Plot extracted-vs-target counts and the extraction success rate per class."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Extracted objects per class
    classes = list(extraction_stats.keys())
    extracted_counts = [stats['extracted'] for stats in extraction_stats.values()]
    target_counts = [OBJECTS_PER_CLASS_IN_BANK] * len(classes)

    x = np.arange(len(classes))
    width = 0.35

    bars1 = ax1.bar(x - width/2, extracted_counts, width, label='Extracted', alpha=0.8, color='green')
    ax1.bar(x + width/2, target_counts, width, label='Target', alpha=0.6, color='lightgray')

    ax1.set_title('Objects Extracted per Class')
    ax1.set_xlabel('Class')
    ax1.set_ylabel('Number of Objects')
    ax1.set_xticks(x)
    ax1.set_xticklabels(classes, rotation=45)
    ax1.legend()

    # Add value labels, lifted slightly above the bar top
    label_offset = max(extracted_counts) * 0.01
    for bar, count in zip(bars1, extracted_counts):
        ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + label_offset,
                 f'{count}', ha='center', va='bottom', fontweight='bold')

    # Success rate per class (0 when nothing was attempted)
    success_rates = [
        (stats['extracted'] / stats['attempted'] * 100) if stats['attempted'] > 0 else 0
        for stats in extraction_stats.values()
    ]

    bars3 = ax2.bar(classes, success_rates, alpha=0.8, color='orange')
    ax2.set_title('Extraction Success Rate')
    ax2.set_xlabel('Class')
    ax2.set_ylabel('Success Rate (%)')
    ax2.tick_params(axis='x', rotation=45)
    ax2.set_ylim(0, 100)

    # Add value labels
    for bar, rate in zip(bars3, success_rates):
        ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 2,
                 f'{rate:.1f}%', ha='center', va='bottom', fontweight='bold')

    plt.tight_layout()
    plt.show()


def create_object_bank():
    """Creates the object bank by extracting objects from the TACO dataset.

    Reads the module-level ``taco_data`` and ``categories_info``, keeps the
    annotations that map to a target class and have an ``area`` of at least
    ``MIN_OBJECT_PIXEL_AREA``, and writes up to ``OBJECTS_PER_CLASS_IN_BANK``
    RGBA cut-outs per class to ``OBJECT_BANK_DIR/<class_name>/NNNN.png``.
    Files are numbered from 0000, so files of a previous run are overwritten.

    Returns:
        bool: ``True`` if at least one object was written, ``False`` otherwise.
    """
    print("\n🌟 STARTING OBJECT BANK CREATION")
    print("=" * 60)

    if taco_data is None or categories_info is None:
        print("❌ ERROR: TACO data not loaded. Please run the analysis first.")
        return False

    print(f"📊 Processing {len(taco_data['annotations'])} total annotations...")

    images_info = {img['id']: img for img in taco_data['images']}

    # Group annotations by our target classes
    class_annotations = defaultdict(list)
    print("🔍 Filtering annotations based on mapping and area...")

    for ann in tqdm(taco_data['annotations'], desc="Filtering annotations"):
        cat_info = categories_info.get(ann.get('category_id'))
        if cat_info is None:
            continue
        target_class_id = _lookup_target_class(cat_info)
        if target_class_id != -1 and ann.get('area', 0) >= MIN_OBJECT_PIXEL_AREA:
            class_annotations[target_class_id].append(ann)

    print(f"✅ Found matching annotations:")
    for class_id, class_name in enumerate(TARGET_CLASSES):
        count = len(class_annotations.get(class_id, []))
        print(f"  {class_name}: {count:,} annotations")

    if not class_annotations:
        print("❌ ERROR: No annotations matched the filter criteria.")
        return False

    # Create class directories in object bank
    print(f"\n📁 Creating class directories...")
    for class_name in TARGET_CLASSES:
        class_dir = OBJECT_BANK_DIR / class_name
        # parents=True: do not depend on the bank root having been created earlier
        class_dir.mkdir(parents=True, exist_ok=True)
        print(f"  Created: {class_dir}")

    # Extract and save objects
    print(f"\n🎨 EXTRACTING OBJECTS")
    print("-" * 40)

    total_extracted_count = 0
    extraction_stats = {}

    for class_id, class_name in enumerate(TARGET_CLASSES):
        print(f"\n🔄 Processing class: {class_name}")

        annotations_for_class = class_annotations.get(class_id, [])
        if not annotations_for_class:
            print(f"  ⚠️ No annotations found for {class_name}")
            extraction_stats[class_name] = {
                'extracted': 0,
                'attempted': 0,
                'img_not_found': 0,
                'extraction_failed': 0
            }
            continue

        # Shuffle annotations for better variety
        random.shuffle(annotations_for_class)

        object_count = 0
        img_not_found_count = 0
        extraction_failed_count = 0
        attempted_count = 0

        target_count = min(len(annotations_for_class), OBJECTS_PER_CLASS_IN_BANK)

        # Context manager so the progress bar is closed even if an error escapes.
        with tqdm(total=target_count, desc=f"Extracting {class_name}") as pbar:
            for ann in annotations_for_class:
                if object_count >= OBJECTS_PER_CLASS_IN_BANK:
                    break

                attempted_count += 1
                img_info = images_info.get(ann.get('image_id'))
                if not img_info:
                    extraction_failed_count += 1
                    continue

                img_path = find_image_path(img_info, TACO_DIR)
                if not img_path:
                    img_not_found_count += 1
                    continue

                # Best effort: any failure to decode the image counts as a failed extraction.
                try:
                    image = cv2.imread(str(img_path))
                except Exception:
                    image = None
                if image is None:
                    extraction_failed_count += 1
                    continue

                # A single malformed annotation must not abort the whole run.
                try:
                    extracted_obj = extract_object(image, ann.get('segmentation'))
                except (ValueError, TypeError, cv2.error):
                    extracted_obj = None

                if extracted_obj is None:
                    extraction_failed_count += 1
                    continue

                save_path = OBJECT_BANK_DIR / class_name / f"{object_count:04d}.png"
                try:
                    extracted_obj.save(save_path)
                except (OSError, ValueError) as e:
                    extraction_failed_count += 1
                    print(f"    ⚠️ Failed to save {save_path}: {e}")
                else:
                    object_count += 1
                    total_extracted_count += 1
                    pbar.update(1)

        # Store statistics
        extraction_stats[class_name] = {
            'extracted': object_count,
            'attempted': attempted_count,
            'img_not_found': img_not_found_count,
            'extraction_failed': extraction_failed_count
        }

        print(f"  ✅ Extracted: {object_count}/{target_count} objects")
        if img_not_found_count > 0:
            print(f"  ⚠️ Images not found: {img_not_found_count}")
        if extraction_failed_count > 0:
            print(f"  ⚠️ Extraction failures: {extraction_failed_count}")

    # Final summary
    print(f"\n🎉 OBJECT BANK CREATION COMPLETE")
    print("=" * 50)
    print(f"Total objects extracted: {total_extracted_count:,}")
    print(f"Object bank location: {OBJECT_BANK_DIR}")

    # Create summary visualization
    if extraction_stats:
        _plot_extraction_summary(extraction_stats)

    if total_extracted_count == 0:
        print("❌ ERROR: No objects were extracted. Please check the configuration and TACO dataset.")
        return False

    print("✅ Object bank creation successful!")
    return True

# Create the object bank
success = False  # defined up front so later cells can read it even if loading failed
if taco_data is not None:
    success = create_object_bank()
else:
    print("❌ Cannot create object bank: TACO data not loaded.")

## 4. Object Bank Verification

In [None]:
def _make_sample_grid(object_paths, tile_size=64):
    """Build a square-ish RGBA preview grid from the given image files.

    Every image is shrunk to fit ``tile_size`` and centred on a transparent
    ``tile_size`` x ``tile_size`` canvas. ``Image.thumbnail`` keeps the aspect
    ratio, so without the canvas the tiles would have different shapes and
    could not be concatenated.

    Returns:
        numpy array of the grid, or ``None`` if no image could be loaded.
    """
    tiles = []
    for obj_path in object_paths:
        try:
            with Image.open(obj_path) as img:
                thumb = img.convert('RGBA')
        except (OSError, ValueError):
            continue  # unreadable file: leave it out of the preview
        thumb.thumbnail((tile_size, tile_size), Image.Resampling.LANCZOS)

        canvas = Image.new('RGBA', (tile_size, tile_size), (0, 0, 0, 0))
        offset = ((tile_size - thumb.width) // 2, (tile_size - thumb.height) // 2)
        canvas.paste(thumb, offset)
        tiles.append(np.array(canvas))

    if not tiles:
        return None

    grid_size = int(np.ceil(np.sqrt(len(tiles))))
    blank_tile = np.zeros_like(tiles[0])

    rows = []
    for start in range(0, len(tiles), grid_size):
        row_tiles = tiles[start:start + grid_size]
        # Pad the last row so that all rows have the same width
        row_tiles += [blank_tile] * (grid_size - len(row_tiles))
        rows.append(np.concatenate(row_tiles, axis=1))
    return np.concatenate(rows, axis=0)


def verify_object_bank():
    """Verifies the created object bank and shows statistics.

    Counts the ``*.png`` files in ``OBJECT_BANK_DIR/<class_name>`` for every
    entry of ``TARGET_CLASSES``, prints per-class and summary statistics, and
    plots counts, distribution, balance and sample thumbnails for the first
    three classes.

    Returns:
        bool: ``True`` if the bank contains at least one object, else ``False``.
    """
    print("🔍 VERIFYING OBJECT BANK")
    print("=" * 50)

    if not OBJECT_BANK_DIR.exists():
        print(f"❌ Object bank directory not found: {OBJECT_BANK_DIR}")
        return False

    bank_stats = {}
    total_objects = 0

    print("📊 Object bank contents:")
    for class_name in TARGET_CLASSES:
        class_dir = OBJECT_BANK_DIR / class_name
        if class_dir.exists():
            # Sorted so that the previewed samples are the same on every run
            objects = sorted(class_dir.glob('*.png'))
            count = len(objects)
            bank_stats[class_name] = {
                'count': count,
                'objects': objects[:10]  # Store first 10 for visualization
            }
            total_objects += count
            status = "✅" if count > 0 else "⚠️"
            print(f"  {status} {class_name}: {count:,} objects")
        else:
            bank_stats[class_name] = {'count': 0, 'objects': []}
            print(f"  ❌ {class_name}: directory not found")

    print(f"\n📈 Total objects in bank: {total_objects:,}")

    if total_objects == 0:
        print("❌ Object bank is empty!")
        return False

    # Calculate statistics
    counts = [stats['count'] for stats in bank_stats.values()]
    avg_per_class = np.mean(counts)
    min_count = min(counts)
    max_count = max(counts)

    print(f"📊 Statistics:")
    print(f"  Average per class: {avg_per_class:.1f}")
    print(f"  Minimum per class: {min_count}")
    print(f"  Maximum per class: {max_count}")
    print(f"  Balance ratio: {min_count/max_count:.2f}" if max_count > 0 else "  Balance ratio: 0.00")

    # Create visualization
    if any(counts):
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        fig.suptitle('Object Bank Verification', fontsize=16, fontweight='bold')

        # 1. Object counts bar chart
        colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7']
        bars = axes[0,0].bar(TARGET_CLASSES, counts, color=colors, alpha=0.8)
        axes[0,0].set_title('Objects per Class')
        axes[0,0].set_ylabel('Number of Objects')
        axes[0,0].tick_params(axis='x', rotation=45)

        # Add value labels
        for bar, count in zip(bars, counts):
            axes[0,0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(counts)*0.01,
                           f'{count}', ha='center', va='bottom', fontweight='bold')

        # Add target line
        axes[0,0].axhline(y=OBJECTS_PER_CLASS_IN_BANK, color='red', linestyle='--',
                          label=f'Target: {OBJECTS_PER_CLASS_IN_BANK}')
        axes[0,0].legend()

        # 2. Pie chart of distribution
        axes[0,1].pie(counts, labels=TARGET_CLASSES, colors=colors, autopct='%1.1f%%',
                      startangle=90)
        axes[0,1].set_title('Class Distribution')

        # 3. Balance analysis (each class relative to the largest class)
        balance_scores = [count/max_count if max_count > 0 else 0 for count in counts]
        bars2 = axes[0,2].bar(TARGET_CLASSES, balance_scores, color=colors, alpha=0.8)
        axes[0,2].set_title('Class Balance Score')
        axes[0,2].set_ylabel('Balance Score (0-1)')
        axes[0,2].tick_params(axis='x', rotation=45)
        axes[0,2].set_ylim(0, 1)

        # Add value labels
        for bar, score in zip(bars2, balance_scores):
            axes[0,2].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
                           f'{score:.2f}', ha='center', va='bottom', fontweight='bold')

        # 4-6. Sample objects from each class (show first 3 classes)
        for idx, class_name in enumerate(TARGET_CLASSES[:3]):
            ax = axes[1, idx]
            objects = bank_stats[class_name]['objects']
            class_count = bank_stats[class_name]["count"]

            if not objects:
                ax.text(0.5, 0.5, f'No {class_name}\nobjects',
                        ha='center', va='center', transform=ax.transAxes)
                ax.set_title(f'{class_name.title()}\n(0 objects)')
            else:
                grid = _make_sample_grid(objects[:9])  # Show up to 9 objects
                if grid is None:
                    ax.text(0.5, 0.5, f'Cannot load\n{class_name} objects',
                            ha='center', va='center', transform=ax.transAxes)
                    ax.set_title(f'{class_name.title()}\n({class_count} objects)')
                else:
                    ax.imshow(grid)
                    ax.set_title(f'{class_name.title()} Samples\n({class_count} objects)')

            ax.axis('off')

        plt.tight_layout()
        plt.show()

    # Quality assessment (the bank is known to be non-empty at this point)
    if min_count >= OBJECTS_PER_CLASS_IN_BANK * 0.8:
        print("✅ EXCELLENT: Object bank is well-populated across all classes!")
    elif min_count >= OBJECTS_PER_CLASS_IN_BANK * 0.5:
        print("⚠️  GOOD: Object bank has reasonable coverage, but some classes are under-represented.")
    else:
        print("⚠️  FAIR: Object bank exists but has significant imbalances.")

    return True

# Verify the object bank
verification_success = verify_object_bank()