In [None]:
import os
import pickle
import numpy as np
from PIL import Image
from tqdm import tqdm

# CIFAR-10 class names. The position of a name in this list is the integer
# label that save_images uses to look it up.
CIFAR10_CLASSES = [
    'airplane',    # 0
    'automobile',  # 1
    'bird',        # 2
    'cat',         # 3
    'deer',        # 4
    'dog',         # 5
    'frog',        # 6
    'horse',       # 7
    'ship',        # 8
    'truck',       # 9
]

def unpickle(file):
    """Load and return the object stored in a pickle file.

    The file is opened in binary mode and loaded with ``encoding='bytes'``,
    so 8-bit strings pickled by Python 2 come back as ``bytes`` (which is why
    the batch dictionaries are accessed with keys such as ``b'data'``).

    Args:
        file (str): Path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    # NOTE: pickle can execute arbitrary code while loading; only call this
    # on files from a trusted source.
    with open(file, 'rb') as handle:
        contents = pickle.load(handle, encoding='bytes')
    return contents

def convert_to_caltech_style(cifar_root, output_dir):
    """Export the CIFAR-10 python batches as image files grouped by class.

    Creates ``output_dir`` if needed, then loads ``data_batch_1`` through
    ``data_batch_5`` followed by ``test_batch`` from ``cifar_root`` and hands
    each one to ``save_images``.

    Args:
        cifar_root (str): Directory containing the pickled batch files.
        output_dir (str): Directory that receives the exported images.
    """
    os.makedirs(output_dir, exist_ok=True)

    # (file name, is_train, batch_id): five training batches, then the single
    # test batch, which is reported as batch 0.
    jobs = [(f'data_batch_{n}', True, n) for n in range(1, 6)]
    jobs.append(('test_batch', False, 0))

    for batch_file, train_flag, number in jobs:
        loaded = unpickle(os.path.join(cifar_root, batch_file))
        save_images(loaded, output_dir, is_train=train_flag, batch_id=number)

def save_images(batch, output_dir, is_train, batch_id):
    """Write every image of one batch to ``output_dir/<class name>/<filename>``.

    Args:
        batch (dict): Batch dictionary with the byte-string keys ``b'data'``,
            ``b'labels'`` and ``b'filenames'``. Each ``b'data'`` row must be
            reshapeable to ``(3, 32, 32)``.
        output_dir (str): Root directory; one sub-folder per class is created
            on demand.
        is_train (bool): Only selects the "Train"/"Test" prefix of the
            progress-bar label. It does not affect where files are written,
            so train and test images share the same class folders.
        batch_id (int): Number shown in the progress-bar label.
    """
    data = batch[b'data']  # presumably [10000, 3072] for a standard batch
    labels = batch[b'labels']
    filenames = batch[b'filenames']

    split_name = 'Train' if is_train else 'Test'
    # Class folders already created during this call. This avoids repeating
    # os.makedirs for every image when there are only a few distinct classes.
    created_dirs = set()

    for i in tqdm(range(len(data)), desc=f"{split_name} batch {batch_id}"):
        # Each row is stored channel-first; convert to HWC for PIL.
        img = data[i].reshape(3, 32, 32).transpose(1, 2, 0)
        cls_name = CIFAR10_CLASSES[labels[i]]

        class_dir = os.path.join(output_dir, cls_name)
        if class_dir not in created_dirs:
            os.makedirs(class_dir, exist_ok=True)
            created_dirs.add(class_dir)

        # File names are stored as bytes in the batch dictionary.
        fname = filenames[i].decode('utf-8')
        Image.fromarray(img).save(os.path.join(class_dir, fname))

# Example usage: export the extracted CIFAR-10 python batches to class folders.
# Both paths are relative to the kernel's current working directory.
cifar10_py_folder = "./cifar-10-batches-py"  # Directory holding data_batch_1..5 and test_batch
output_folder = "./cifar10_caltech_style"  # Root folder that receives one sub-folder per class

convert_to_caltech_style(cifar_root=cifar10_py_folder, output_dir=output_folder)


Train batch 1: 100%|██████████| 10000/10000 [00:04<00:00, 2176.05it/s]
Train batch 2: 100%|██████████| 10000/10000 [00:03<00:00, 3227.87it/s]
Train batch 3: 100%|██████████| 10000/10000 [00:03<00:00, 3283.38it/s]
Train batch 4: 100%|██████████| 10000/10000 [00:03<00:00, 3190.60it/s]
Train batch 5: 100%|██████████| 10000/10000 [00:03<00:00, 3183.91it/s]
Test batch 0: 100%|██████████| 10000/10000 [00:03<00:00, 2860.47it/s]


In [None]:
import os
import csv
from pathlib import Path

def listdir_nohidden(path):
    """Return the entries of ``path`` whose names do not start with a dot.

    Names are returned in the order ``os.listdir`` yields them.
    """
    visible = []
    for entry in os.listdir(path):
        if entry.startswith('.'):
            continue
        visible.append(entry)
    return visible

def generate_csv(image_dir, save_path, ignored_categories=None, new_cnames=None):
    """
    Generate a CSV annotation file for a dataset stored as one folder per class.

    Every non-hidden sub-directory of ``image_dir`` is treated as a category and
    every non-hidden entry inside it as an image. Categories and the images
    within each category are processed in sorted order, so the generated ``id``
    column is reproducible across runs and file systems.

    Args:
        image_dir (str): Path to the top-level directory of the dataset.
        save_path (str): Path to save the generated CSV file. Its parent
            directory is created if it does not exist; a bare file name
            (no directory part) is written to the current directory.
        ignored_categories (list, optional): List of categories to ignore. Defaults to None.
        new_cnames (dict, optional): Mapping from folder name to the category
            name written in the ``label`` column. Defaults to None.

    The CSV has the columns ``id`` (running row index starting at 0),
    ``image_path`` (``image_dir`` joined with category and file name) and
    ``label`` (category name, after applying ``new_cnames``).
    """
    if ignored_categories is None:
        ignored_categories = []

    # Ensure the save path's directory exists. dirname() is '' for a bare file
    # name, and os.makedirs('') would raise, so only create a real directory.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Get the sorted list of categories. Stray files in image_dir are skipped,
    # since only directories can hold images.
    categories = sorted(
        c for c in listdir_nohidden(image_dir)
        if c not in ignored_categories
        and os.path.isdir(os.path.join(image_dir, c))
    )

    # Prepare data
    rows = []
    for category in categories:
        category_dir = os.path.join(image_dir, category)

        # Update category name (if mapping exists)
        if new_cnames is not None and category in new_cnames:
            label = new_cnames[category]
        else:
            label = category

        # os.listdir order is arbitrary; sort so ids are deterministic.
        for image_name in sorted(listdir_nohidden(category_dir)):
            rows.append({
                'id': len(rows),
                'image_path': os.path.join(category_dir, image_name),
                'label': label
            })

    # Write to the CSV file. newline='' lets the csv module control line
    # endings (otherwise Windows output gets an extra blank line per row).
    with open(save_path, mode='w', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=['id', 'image_path', 'label'])
        writer.writeheader()
        writer.writerows(rows)

# Example usage
if __name__ == "__main__":
    # Input image folder (one sub-folder per class) and output CSV location.
    # NOTE(review): the conversion cell writes to "./cifar10_caltech_style",
    # while this reads ".../cifar10" - confirm the folder was renamed/moved.
    image_dir = '/root/autodl-tmp/cifar-10/cifar10'  # Path to the class-per-folder CIFAR-10 images
    save_path = '/root/autodl-tmp/cifar-10/cifar10.csv'  # Where the annotation CSV is written

    # Nothing is skipped and no category is renamed; adjust as needed.
    ignored_categories = []
    new_cnames = None

    # Build the annotation file and report where it went.
    generate_csv(
        image_dir=image_dir,
        save_path=save_path,
        ignored_categories=ignored_categories,
        new_cnames=new_cnames,
    )
    print(f"Annotation file has been generated and saved to: {save_path}")

In [4]:
# Print the kernel's working directory; relative paths such as
# "./cifar-10-batches-py" are resolved against it.
!pwd

/root/autodl-tmp/cifar-10
