# 统计类别分布

In [1]:
from tqdm import tqdm
from glob import glob
import numpy as np
from sklearn.model_selection import train_test_split
import shutil
import os
from collections import Counter, defaultdict
from pprint import pprint

In [2]:
labels = [x.split('/')[-2] for x in glob('/home/ypw/data/xuelang/xuelang_round1_train*/*/*.jpg')]

In [3]:
counter = Counter(labels)
counter = counter.most_common()
# pprint(counter)

classes = [x[0] for x in counter if x[1] > 30]
classes.append('其他')
pprint(classes)

['正常', '吊经', '擦洞', '跳花', '毛洞', '织稀', '扎洞', '缺经', '毛斑', '其他']


# 创建符号链接

In [4]:
def make_symlinks(fnames, path):
    with tqdm(fnames) as pbar:
        for fname in pbar:
            img_fname = fname.split('/')[-1]
            label = fname.split('/')[-2]
            if label not in classes:
                label = '其他'
            os.symlink(fname, f'{path}/{label}/{img_fname}')

In [5]:
def get_class_fname_dict(fnames):
    class_fname_dict = defaultdict(list)
    for fname in fnames:
        label = fname.split('/')[-2]
        if label not in classes:
            label = '其他'
        class_fname_dict[label].append(fname)
    return class_fname_dict

In [6]:
def balance_make_symlinks(fnames, path):
    class_fname_dict = get_class_fname_dict(fnames)
    target_num = max([len(class_fname_dict[x]) for x in class_fname_dict])
    for label, fnames in class_fname_dict.items():
        n = len(fnames)
        for i in range(target_num):
            fname = fnames[i % n]
            img_fname = fname.split('/')[-1]
            os.symlink(fname, f'{path}/{label}/{i}_{img_fname}')

In [7]:
!rm -rf train valid

for c in classes:
    os.makedirs(f'train/{c}')
    os.makedirs(f'valid/{c}')

In [8]:
fnames = glob('/home/ypw/data/xuelang/xuelang_round1_train*/*/*.jpg')
train, valid = train_test_split(fnames, test_size=0.1)
balance_make_symlinks(train, 'train')
make_symlinks(valid, 'valid')

100%|██████████| 203/203 [00:00<00:00, 48645.59it/s]


# 查看生成的文件

In [9]:
labels = [x.split('/')[-2] for x in glob('train/*/*.jpg')]
counter = Counter(labels)
counter = counter.most_common()
counter

[('毛洞', 1191),
 ('正常', 1191),
 ('织稀', 1191),
 ('其他', 1191),
 ('毛斑', 1191),
 ('跳花', 1191),
 ('扎洞', 1191),
 ('擦洞', 1191),
 ('吊经', 1191),
 ('缺经', 1191)]

In [10]:
labels = [x.split('/')[-2] for x in glob('valid/*/*.jpg')]
counter = Counter(labels)
counter = counter.most_common()
counter

[('正常', 125),
 ('吊经', 21),
 ('其他', 11),
 ('擦洞', 10),
 ('跳花', 8),
 ('毛洞', 7),
 ('扎洞', 7),
 ('缺经', 7),
 ('织稀', 6),
 ('毛斑', 1)]