# In [None]:
import json
import os
import shutil
from collections import defaultdict
from tqdm import tqdm  # For progress bars
import time

# Paths
coco_json = "wcs_20220205_bboxes_with_classes.json"
cropped_dir = "yolo_subset/cropped_animals"
output_dir = "yolo_subset/resnet_dataset_corrected"

print(f"\n{'='*50}")
print("Starting Dataset Organization")
print(f"{'='*50}\n")
print(f"Current time: {time.strftime('%Y-%m-%d %H:%M:%S')}")
print(f"COCO JSON: {coco_json}")
print(f"Cropped images: {cropped_dir}")
print(f"Output directory: {output_dir}")

# Load COCO data
print("\n[1/4] Loading COCO annotations...")
start_time = time.time()
# JSON text is UTF-8; don't depend on the platform's default locale encoding.
with open(coco_json, encoding="utf-8") as f:
    coco_data = json.load(f)
load_time = time.time() - start_time
print(f"✓ Loaded {len(coco_data['images'])} images and {len(coco_data['annotations'])} annotations in {load_time:.2f}s")

# Map each image's normalized file_name (with "humans/" rewritten to "animals/")
# to its COCO image record, so crop names can be resolved back to an image id.
print("\n[2/4] Creating path mappings...")
path_to_data = {}
path_conversions = 0
path_collisions = 0

for img in tqdm(coco_data["images"], desc="Processing images"):
    original_path = img["file_name"]
    clean_path = original_path.replace("humans/", "animals/")
    if clean_path != original_path:
        path_conversions += 1
    if clean_path in path_to_data:
        # Two images normalize to the same key; the later record wins, but make
        # that visible instead of silently dropping the earlier one.
        path_collisions += 1
    path_to_data[clean_path] = img

print(f"  • Created {len(path_to_data)} path mappings")
print(f"  • Performed {path_conversions} path conversions (humans→animals)")
if path_collisions:
    print(f"  ! {path_collisions} images collided on the same normalized path (later entries kept)")

category_id_to_name = {cat["id"]: cat["name"] for cat in coco_data["categories"]}
print(f"\n  • Found {len(category_id_to_name)} categories:")
preview_count = 5
has_more_categories = len(category_id_to_name) > preview_count
for i, (cat_id, cat_name) in enumerate(list(category_id_to_name.items())[:preview_count]):
    # Only mark the preview as truncated when categories were actually left out.
    suffix = "..." if i == preview_count - 1 and has_more_categories else ""
    print(f"    {cat_id}: {cat_name}" + suffix)

# Create class folders
print("\n[3/4] Preparing directory structure...")
os.makedirs(f"{output_dir}/train", exist_ok=True)
os.makedirs(f"{output_dir}/val", exist_ok=True)
print(f"✓ Created base directories at {output_dir}")

def get_original_path(crop_path):
    """Map a crop file name back to the COCO ``file_name`` of its source image.

    The crop name is split on ``_``. The first token must be ``animals``; the
    second token is the folder and the third token, truncated at its first
    ``.``, is the image stem. Any further tokens (e.g. ``crop0.jpg``) are
    ignored. The result always uses a ``.jpg`` extension.

    Example: ``animals_0003_1406_crop0.jpg`` -> ``animals/0003/1406.jpg``.

    Returns None when the name has fewer than three ``_``-separated tokens or
    does not start with ``animals``.
    """
    head, *tail = crop_path.split('_')
    if head != "animals" or len(tail) < 2:
        return None
    folder = tail[0]
    stem = tail[1].partition('.')[0]
    return f"animals/{folder}/{stem}.jpg"

print("\n[4/4] Organizing cropped images...")
total_stats = {"train": {"matched": 0, "unmatched": 0}, "val": {"matched": 0, "unmatched": 0}}

# Index annotation category ids by image id ONCE. Scanning every annotation for
# every crop made this step O(crops x annotations). An insertion-ordered dict is
# used instead of a set so categories keep their annotation order: that makes the
# class chosen for multi-class images reproducible (iterating a set of strings
# yields a different order from run to run because of str hash randomization).
image_id_to_cat_ids = defaultdict(dict)
for ann in coco_data["annotations"]:
    image_id_to_cat_ids[ann["image_id"]][ann["category_id"]] = None

for split in ["train", "val"]:
    print(f"\n{'='*30}")
    print(f"Processing {split} set")
    print(f"{'='*30}")

    split_files = [f for f in os.listdir(f"{cropped_dir}/{split}") if f.lower().endswith((".jpg", ".jpeg", ".png"))]
    num_files = len(split_files)
    print(f"Found {num_files} images to process")

    start_time = time.time()
    # "unmatched" is the total; "no_source" / "no_annotations" break it down by cause.
    stats = {"matched": 0, "unmatched": 0, "no_source": 0, "no_annotations": 0, "multi_class": 0}
    class_distribution = defaultdict(int)

    for crop_path in tqdm(split_files, desc=f"Processing {split}"):
        # Resolve the crop back to the COCO image it was cut from.
        original_path = get_original_path(crop_path)
        img_data = path_to_data.get(original_path) if original_path else None
        if not img_data:
            stats["unmatched"] += 1
            stats["no_source"] += 1
            continue

        # Unique class names for this image, in first-annotation order.
        cat_ids = image_id_to_cat_ids.get(img_data["id"], {})
        class_names = list(dict.fromkeys(category_id_to_name[cid] for cid in cat_ids))
        if not class_names:
            stats["unmatched"] += 1
            stats["no_annotations"] += 1
            continue
        if len(class_names) > 1:
            # The source image carries several classes and the crop is filed under
            # the first annotated one. NOTE(review): which annotation this crop
            # corresponds to is not determined here — confirm this is acceptable.
            stats["multi_class"] += 1

        # Folder-safe class name: spaces -> "_", and "/" would create nested dirs.
        class_name = class_names[0].replace(" ", "_").replace("/", "_")
        os.makedirs(f"{output_dir}/{split}/{class_name}", exist_ok=True)

        shutil.copy2(
            f"{cropped_dir}/{split}/{crop_path}",
            f"{output_dir}/{split}/{class_name}/{crop_path}"
        )
        stats["matched"] += 1
        class_distribution[class_name] += 1

    total_stats[split] = stats
    elapsed = time.time() - start_time

    print(f"\n{split.upper()} Results:")
    print(f"- Processing time: {elapsed:.2f}s ({elapsed/max(1,num_files):.3f}s per image)")
    print(f"- Matched: {stats['matched']} ({stats['matched']/max(1,num_files)*100:.1f}%)")
    print(f"- Unmatched: {stats['unmatched']}")
    print(f"    no source image: {stats['no_source']}, no annotations: {stats['no_annotations']}")
    if stats["multi_class"]:
        print(f"- Filed from multi-class images (first class used): {stats['multi_class']}")

    print("\nClass distribution (top 5):")
    for cls, count in sorted(class_distribution.items(), key=lambda x: -x[1])[:5]:
        print(f"  {cls}: {count} images")

print(f"\n{'='*50}")
print("Organization Complete!")
print(f"{'='*50}\n")

total_matched = total_stats["train"]["matched"] + total_stats["val"]["matched"]
total_unmatched = total_stats["train"]["unmatched"] + total_stats["val"]["unmatched"]
total_images = total_matched + total_unmatched

print("Final Statistics:")
print(f"- Total matched images: {total_matched} ({total_matched/max(1,total_images)*100:.1f}%)")
print(f"- Total unmatched images: {total_unmatched}")
print(f"- Output directory: {output_dir}")

print("\nNext steps:")
print("1. Verify the directory structure")
print("2. Check unmatched images if percentage is high")
print("3. Proceed with training using the organized dataset")

# Save stats to file (explicit encoding so the report is portable across platforms).
with open(f"{output_dir}/organization_stats.txt", "w", encoding="utf-8") as f:
    f.write("Dataset Organization Report\n")
    f.write(f"Date: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")
    f.write(f"Total matched images: {total_matched}\n")
    f.write(f"Total unmatched images: {total_unmatched}\n\n")
    f.write("Details per split:\n")
    for split in ["train", "val"]:
        f.write(f"{split.upper()}: {total_stats[split]['matched']} matched, {total_stats[split]['unmatched']} unmatched\n")