In [1]:
# import packages
import os
import random
import json

In [2]:
# Seed the standard-library `random` generator so that the `random.shuffle`
# calls later in this notebook draw from a fixed pseudo-random sequence.
random.seed(123)

In [60]:
# Locate the dataset directories: first the split folders directly under the
# root, then the class folders inside each split.

# Root folder of the dataset
data_root = './birdsDatasets/'

# First level: keep only the entries of the root that are directories
first_path = []
for entry in os.listdir(data_root):
    candidate = os.path.join(data_root, entry)
    if os.path.isdir(candidate):
        first_path.append(candidate)

# Second level: for every first-level directory, keep only its sub-directories.
# second_path[i] is the list of sub-directories found inside first_path[i].
second_path = []
for split_dir in first_path:
    class_dirs = []
    for entry in os.listdir(split_dir):
        candidate = os.path.join(split_dir, entry)
        if os.path.isdir(candidate):
            class_dirs.append(candidate)
    second_path.append(class_dirs)

In [59]:
# The second-level directory names are the class labels provided by the
# dataset. Collect them per first-level directory (the dataset split).

# Maps split directory name -> list of class directory names found inside it
sample_class_dicts = {}

# first_path[i] is the split directory whose sub-directories are second_path[i],
# so the split name can be taken from first_path directly. This also keeps
# working when a split directory contains no class folders (it gets an empty
# list instead of raising IndexError on second_path[i][0]).
for split_dir, class_dirs in zip(first_path, second_path):

    # os.path.basename honours the platform's path separator, unlike a
    # hard-coded split('/')
    split_name = os.path.basename(split_dir)

    sample_class_dicts[split_name] = [os.path.basename(class_dir) for class_dir in class_dirs]

In [29]:
# Check whether the valid, train and test splits all provide exactly the same
# set of class names (chained comparison; the cell displays a single bool).
set(sample_class_dicts['valid']) == set(sample_class_dicts['train']) == set(sample_class_dicts['test'])

False

In [58]:
# The per-split class lists are not identical, so concatenate them and
# de-duplicate to obtain every class name offered by the dataset.
# The result is sorted: iteration order of a set of strings can change from one
# interpreter run to the next (string hash randomisation), which would assign
# different ids to the same class on every run.
all_classes = sorted(set(sample_class_dicts['valid'] + sample_class_dicts['train'] + sample_class_dicts['test']))

In [57]:
# Build the two-way class mapping: class name -> id and id -> class name.
# Ids are the positions of the names in all_classes.
cls_mappers = {

    # class to id
    'cls2id': {cls_name: cls_id for cls_id, cls_name in enumerate(all_classes)},

    # id to class
    'id2cls': dict(enumerate(all_classes)),
}

In [36]:
# save json
# The context manager guarantees the file is flushed and closed once the dump
# finishes (the bare open() call left the handle open).
# NOTE: JSON object keys are always strings, so the integer keys of 'id2cls'
# are written as "0", "1", ... and must be converted back after loading.
with open('./cls_mapper.json', 'w') as mapper_file:
    json.dump(cls_mappers, mapper_file)

In [37]:
# Helper: recursively walk a root directory and collect the paths of all files
# whose extension is in the accepted list.
def recursive_fetching(root, suffix=('jpg', 'png')):
    """Return the paths of all files under `root` with an accepted extension.

    Args:
        root: Directory to search. Sub-directories are visited recursively.
        suffix: Accepted file extensions, compared case-insensitively. Entries
            may be written with or without a leading dot. A single extension
            may also be passed as a plain string.

    Returns:
        A list of file paths (each built by joining `root` with the relative
        location of the file). Entries of every directory are visited in
        sorted name order, so the result does not depend on the order in
        which the operating system lists a directory.

    Raises:
        OSError: If `root` (or a directory below it) cannot be listed.
    """
    # A bare string would otherwise turn the membership test below into a
    # substring test ('jp' in 'jpg'), so wrap it.
    if isinstance(suffix, str):
        suffix = (suffix,)
    wanted = {ext.lower().lstrip('.') for ext in suffix}

    matches = []

    def _collect(directory):
        # sorted(): os.listdir returns entries in arbitrary order
        for name in sorted(os.listdir(directory)):
            full_path = os.path.join(directory, name)
            if os.path.isdir(full_path):
                # Directory: descend into it
                _collect(full_path)
            elif os.path.isfile(full_path):
                # splitext only reports a real extension, so a file that is
                # merely *named* 'jpg' (no dot) is not mistaken for an image
                extension = os.path.splitext(name)[1]
                if extension and extension[1:].lower() in wanted:
                    matches.append(full_path)

    _collect(root)

    return matches

In [56]:
# Collect the path of every image file found anywhere under data_root (using
# recursive_fetching's default suffix list), then display how many were found.
dataset_items = recursive_fetching(data_root)
len(dataset_items)

89885

In [42]:
# With every sample path collected, shuffle the list in place so the samples
# are in random order.
random.shuffle(dataset_items)

In [55]:
# Group the sample paths by class id. Resulting layout:
#
#   dataset_dict = {
#       <class id>: ['./birdsDatasets/<split>/<CLASS NAME>/4.jpg',
#                    './birdsDatasets/<split>/<CLASS NAME>/5.jpg', ...],
#       ...
#   }

dataset_dict = {}
for it in dataset_items:

    # The class name is the name of the directory that holds the file.
    # dirname/basename follow the platform's path separator, unlike split('/').
    cls_name = os.path.basename(os.path.dirname(it))

    # cls to id (raises KeyError for a directory that is not a known class)
    cls_id = cls_mappers['cls2id'][cls_name]

    # Create the list on first sight of this class id, then append
    dataset_dict.setdefault(cls_id, []).append(it)

In [49]:
# Split every class into train/eval/test according to the ratios below.
train_ratio, eval_ratio, test_ratio = 0.8, 0.1, 0.1

# Accumulators for the three final subsets
train_set, eval_set, test_set = [], [], []

# Only the per-class sample lists are needed, not the class ids
for class_samples in dataset_dict.values():

    # Number of samples available for this class
    length = len(class_samples)

    # Train and eval counts are truncated towards zero; test takes the remainder
    train_num = int(length * train_ratio)
    eval_num = int(length * eval_ratio)
    test_num = length - train_num - eval_num

    # Shuffle this class's samples in place before slicing
    random.shuffle(class_samples)

    # Slice boundaries: [0, train_num) | [train_num, eval_end) | [eval_end, length)
    eval_end = train_num + eval_num
    train_set += class_samples[:train_num]
    eval_set += class_samples[train_num:eval_end]
    test_set += class_samples[eval_end:]

In [52]:
# Shuffle each subset again (they were assembled class by class), in the
# order train, eval, test.
for subset in (train_set, eval_set, test_set):
    random.shuffle(subset)

# Report the size of each subset
summary = f'train set samples number:{len(train_set)};\ntest set samples number:{len(test_set)};\neval set samples number:{len(eval_set)};'
print(summary)

train set samples number:71719;
test set samples number:9388;
eval set samples number:8778;


In [53]:
# save
def save_meata_data(meta_path,datasets):
    """Write one `<class id>|<sample path>` line per sample to `meta_path`.

    Args:
        meta_path: Path of the text file to create (overwritten if it exists).
        datasets: Iterable of sample file paths. The name of the directory
            that directly contains each file is taken as its class name.

    Raises:
        KeyError: If a sample's directory name is not a key of
            `cls_mappers['cls2id']`.
    """

    # file write
    with open(meta_path,'w') as f:
        for path in datasets:

            # Class name = name of the containing directory. dirname/basename
            # follow the platform's path separator, unlike split('/').
            cls_name = os.path.basename(os.path.dirname(path))

            # class name to id (module-level mapping)
            cls_id = cls_mappers['cls2id'][cls_name]

            # write
            f.write(f'{cls_id}|{path}\n')

In [54]:
# Pair every output file with the subset it describes, then write them in
# the order train, test, eval.
meta_targets = (
    ('./train_meta_data.txt', train_set),
    ('./test_meta_data.txt', test_set),
    ('./eval_meta_data.txt', eval_set),
)
for meta_path, datasets in meta_targets:
    save_meata_data(meta_path, datasets)