In [3]:
import os
import pickle
import numpy as np
from PIL import Image
from tqdm import tqdm

# CIFAR-10 class names. save_images() indexes this list with the integer
# label stored in each batch, so the order of the entries matters.
CIFAR10_CLASSES = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

def unpickle(file):
    """Load and return the single object pickled in ``file``.

    The file is opened in binary mode and unpickled with
    ``encoding='bytes'``, which is why callers in this notebook access the
    result with byte-string keys such as ``b'data'``.

    NOTE: ``pickle.load`` can execute arbitrary code while loading; only call
    this on files from a trusted source.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    with open(file, 'rb') as handle:
        loaded = pickle.load(handle, encoding='bytes')
    return loaded

def convert_to_caltech_style(cifar_root, output_dir):
    """Export all CIFAR-10 batches under ``cifar_root`` as image files.

    Reads ``data_batch_1`` .. ``data_batch_5`` followed by ``test_batch`` and
    hands each one to ``save_images``. Every batch is written below the same
    ``output_dir``.

    Args:
        cifar_root: Directory containing the pickled batch files.
        output_dir: Destination root; created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)

    # (file name, is_train, batch_id) in processing order: the five training
    # batches first, then the test batch (reported as batch 0).
    sources = [(f'data_batch_{n}', True, n) for n in range(1, 6)]
    sources.append(('test_batch', False, 0))

    for file_name, train_flag, number in sources:
        loaded = unpickle(os.path.join(cifar_root, file_name))
        save_images(loaded, output_dir, is_train=train_flag, batch_id=number)

def save_images(batch, output_dir, is_train, batch_id):
    """Write every image of one batch to ``output_dir/<class name>/<filename>``.

    Args:
        batch: Mapping as returned by ``unpickle``. The keys ``b'data'``,
            ``b'labels'`` and ``b'filenames'`` are read. Each ``b'data'`` row
            must be reshapeable to ``(3, 32, 32)``; presumably a uint8 NumPy
            array of shape ``[N, 3072]`` -- confirm against the dataset.
        output_dir: Root directory that receives one sub-directory per class.
        is_train: Only selects the "Train"/"Test" prefix of the progress bar;
            it does not influence where files are written, so training and
            test images end up in the same class directories.
        batch_id: Number shown in the progress-bar description.

    Raises:
        IndexError: If a label is not a valid index into ``CIFAR10_CLASSES``
            (a negative label would silently wrap around, as with any list).
    """
    data = batch[b'data']
    labels = batch[b'labels']
    filenames = batch[b'filenames']

    split_name = 'Train' if is_train else 'Test'

    # Remember which class directories were already created so the filesystem
    # is touched once per class instead of once per image.
    created_dirs = set()

    for i in tqdm(range(len(data)), desc=f"{split_name} batch {batch_id}"):
        # Flat row -> (channels, height, width) -> (height, width, channels),
        # the layout Image.fromarray expects.
        img = data[i].reshape(3, 32, 32).transpose(1, 2, 0)
        cls_name = CIFAR10_CLASSES[labels[i]]

        class_dir = os.path.join(output_dir, cls_name)
        if class_dir not in created_dirs:
            os.makedirs(class_dir, exist_ok=True)
            created_dirs.add(class_dir)

        # File names are stored as bytes in the batch; an existing file with
        # the same name in the class directory is overwritten.
        fname = filenames[i].decode('utf-8')
        Image.fromarray(img).save(os.path.join(class_dir, fname))

# Example invocation
cifar10_py_folder = "./cifar-10-batches-py"  # path of the extracted archive
output_folder = "./cifar10_caltech_style"  # root of the per-class image folders

convert_to_caltech_style(cifar_root=cifar10_py_folder, output_dir=output_folder)


Train batch 1: 100%|██████████| 10000/10000 [00:04<00:00, 2176.05it/s]
Train batch 2: 100%|██████████| 10000/10000 [00:03<00:00, 3227.87it/s]
Train batch 3: 100%|██████████| 10000/10000 [00:03<00:00, 3283.38it/s]
Train batch 4: 100%|██████████| 10000/10000 [00:03<00:00, 3190.60it/s]
Train batch 5: 100%|██████████| 10000/10000 [00:03<00:00, 3183.91it/s]
Test batch 0: 100%|██████████| 10000/10000 [00:03<00:00, 2860.47it/s]


In [5]:
import os
import csv
from pathlib import Path

def listdir_nohidden(path):
    """Return the names of all entries in ``path`` that do not start with '.'.

    The order is whatever ``os.listdir`` yields (i.e. arbitrary).
    """
    return [f for f in os.listdir(path) if not f.startswith('.')]


def generate_csv(image_dir, save_path, ignored_categories=None, new_cnames=None):
    """Write a CSV annotation file for a one-folder-per-class image dataset.

    Every non-hidden sub-directory of ``image_dir`` is treated as a class and
    every non-hidden entry inside it as one image. The CSV has the columns
    ``id`` (running index starting at 0), ``image_path`` (``image_dir`` joined
    with class and file name) and ``label`` (the class name, optionally
    renamed through ``new_cnames``).

    Classes and the images within each class are processed in sorted order so
    that the generated ids are reproducible across runs and filesystems.

    Args:
        image_dir (str): Top-level dataset directory.
        save_path (str): Path of the CSV file to write. Missing parent
            directories are created; a bare file name is written to the
            current working directory.
        ignored_categories (list, optional): Class directory names to skip.
            Defaults to None (nothing skipped).
        new_cnames (dict, optional): Mapping from directory name to the label
            to write instead. Defaults to None (no renaming).
    """
    if ignored_categories is None:
        ignored_categories = []

    # os.path.dirname() returns '' for a bare file name, and os.makedirs('')
    # raises, so only create the parent when there is one.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Only directories can be classes; stray files next to the class folders
    # (e.g. a README or a previously generated CSV) are skipped.
    categories = sorted(
        c for c in listdir_nohidden(image_dir)
        if c not in ignored_categories
        and os.path.isdir(os.path.join(image_dir, c))
    )

    rows = []
    for category in categories:
        category_dir = os.path.join(image_dir, category)

        # The path always uses the on-disk directory name; only the label
        # written to the CSV is remapped.
        label = category
        if new_cnames is not None and category in new_cnames:
            label = new_cnames[category]

        for image_name in sorted(listdir_nohidden(category_dir)):
            rows.append({
                'id': len(rows),
                'image_path': os.path.join(category_dir, image_name),
                'label': label
            })

    # newline='' is required by the csv module to avoid blank lines on
    # platforms that translate '\n'; UTF-8 keeps non-ASCII paths portable.
    with open(save_path, mode='w', newline='', encoding='utf-8') as file:
        writer = csv.DictWriter(file, fieldnames=['id', 'image_path', 'label'])
        writer.writeheader()
        writer.writerows(rows)

# Example usage
if __name__ == "__main__":
    # Dataset location and output location
    image_dir = '/root/autodl-tmp/cifar-10/cifar10'  # top-level image directory (one folder per class)
    save_path = '/root/autodl-tmp/cifar-10/cifar10.csv'  # where the CSV is written

    # Optional class filtering / renaming
    ignored_categories = []  # adjust as needed
    new_cnames = None  # define a mapping here to rename classes

    # Build the CSV and report where it went
    generate_csv(
        image_dir=image_dir,
        save_path=save_path,
        ignored_categories=ignored_categories,
        new_cnames=new_cnames,
    )
    print(f"标注文件已生成并保存到: {save_path}")

标注文件已生成并保存到: /root/autodl-tmp/cifar-10/cifar10.csv


In [4]:
!pwd

/root/autodl-tmp/cifar-10
