In [1]:
import random
from collections import defaultdict

def read_labels(filename, encoding='utf-8'):
    """Read a tab-separated label file and group its entries by label.

    A well-formed line holds exactly two tab-separated fields: a file path
    followed by a label. Lines with any other number of fields (including
    blank lines) are skipped.

    Args:
        filename: Path of the label file to read.
        encoding: Text encoding used to decode the file. Defaults to UTF-8 so
            the result does not depend on the platform's locale encoding.

    Returns:
        A ``defaultdict(list)`` that maps each integer label to the list of
        ``(filepath, label)`` tuples carrying that label, in file order. The
        ``label`` stored inside each tuple is the original string field.

    Raises:
        ValueError: If a two-field line has a label ``int()`` cannot parse.
    """
    labels = defaultdict(list)
    with open(filename, 'r', encoding=encoding) as file:
        for line in file:
            # strip() drops the trailing newline (and surrounding whitespace)
            # before the line is split into its fields.
            parts = line.strip().split('\t')
            if len(parts) != 2:
                continue  # skip malformed lines
            filepath, label = parts
            # Group by the numeric label, but keep the raw string in the tuple
            # so it can be written back out unchanged.
            labels[int(label)].append((filepath, label))
    return labels

def split_data(labels, train_ratio=0.9, seed=None):
    """Randomly split every label's items into a training and a test part.

    The split is done per label, so each label contributes the same
    proportion of its items to the training set. The input lists are not
    modified: each one is copied before being shuffled.

    Args:
        labels: Mapping from label to a list of items for that label.
        train_ratio: Fraction of each label's items placed in the training
            set; must lie in ``[0, 1]``. The per-label count is
            ``int(len(items) * train_ratio)``, i.e. rounded down.
        seed: Optional seed. When given, a private ``random.Random(seed)`` is
            used so the split is reproducible; when ``None``, the module-level
            ``random`` generator is used (and can be seeded by the caller).

    Returns:
        A ``(train_data, test_data)`` tuple of lists containing the items of
        all labels, concatenated in the mapping's iteration order.

    Raises:
        ValueError: If ``train_ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= train_ratio <= 1:
        raise ValueError(f"train_ratio must be between 0 and 1, got {train_ratio!r}")

    rng = random if seed is None else random.Random(seed)
    train_data = []
    test_data = []
    for items in labels.values():
        # Shuffle a copy: shuffling in place would silently reorder the
        # caller's data and make repeated calls depend on earlier ones.
        shuffled = list(items)
        rng.shuffle(shuffled)
        split_index = int(len(shuffled) * train_ratio)
        train_data.extend(shuffled[:split_index])
        test_data.extend(shuffled[split_index:])
    return train_data, test_data

def write_labels(data, filename, encoding='utf-8'):
    """Write ``(filepath, label)`` pairs to a tab-separated text file.

    Each pair becomes one line: the file path, a tab, the label, and a
    newline. An existing file at ``filename`` is overwritten.

    Args:
        data: Iterable of ``(filepath, label)`` pairs.
        filename: Path of the file to write.
        encoding: Text encoding of the output file. Defaults to UTF-8 so the
            result does not depend on the platform's locale encoding.
    """
    with open(filename, 'w', encoding=encoding) as file:
        for filepath, label in data:
            file.write(f"{filepath}\t{label}\n")

In [2]:
# Load the original label file; entries come back grouped by integer label.
label_path = './data/label/label.txt'
labels = read_labels(label_path)

In [6]:
# Inspect the grouped labels.
# NOTE(review): this renders the entire mapping; consider displaying only a
# small slice of each label's list to keep the notebook output short.
labels

defaultdict(list,
            {8: [('0458_c017_00003270_0.jpg', '8'),
              ('0227_c012_00032330_0.jpg', '8'),
              ('0075_c004_00047465_0.jpg', '8'),
              ('0455_c015_00021635_1.jpg', '8'),
              ('0538_c014_00055660_0.jpg', '8'),
              ('0433_c015_00050935_0.jpg', '8'),
              ('0171_c014_00021090_0.jpg', '8'),
              ('0072_c009_00087225_0.jpg', '8'),
              ('0483_c018_00066620_0.jpg', '8'),
              ('0535_c003_00067595_0.jpg', '8'),
              ('0645_c017_00054320_0.jpg', '8'),
              ('0432_c016_00061025_0.jpg', '8'),
              ('0171_c012_00045645_0.jpg', '8'),
              ('0491_c012_00056880_1.jpg', '8'),
              ('0562_c017_00031615_0.jpg', '8'),
              ('0415_c019_00030610_0.jpg', '8'),
              ('0434_c016_00057455_0.jpg', '8'),
              ('0505_c003_00087465_1.jpg', '8'),
              ('0419_c010_00017840_0.jpg', '8'),
              ('0530_c016_00071160_0.jpg', '8'),

In [8]:
# Print the number of items under each label, then the overall item count.
for label, items in labels.items():
    print(label, len(items))
total = sum(len(group) for group in labels.values())
print(total)

8 8194
5 8316
6 5415
3 3893
2 2889
1 2898
0 788
4 1578
33971


In [4]:
# Split the data 9:1 into training and test sets.
# Seed the global RNG first so a fresh "Restart & Run All" reproduces the same
# split instead of producing different train/test files on every run.
SPLIT_SEED = 42
random.seed(SPLIT_SEED)
train_data, test_data = split_data(labels, train_ratio=0.9)

In [5]:
# Persist each split to its own label file: training set first, then test set.
for split_rows, out_path in (
    (train_data, './data/label/label_train.txt'),
    (test_data, './data/label/label_test.txt'),
):
    write_labels(split_rows, out_path)