In [1]:
import re
from collections import Counter

## 读取原始数据

In [2]:
# Gather the sample names of every split into two flat lists.
names = ['train', 'val', 'test']
real_all = []
synt_all = []
for split in names:
    # Real list: entries ending in '_back' are dropped; everything else is
    # kept verbatim (a '_front' suffix, if present, stays on the name).
    with open(f'./dataset/{split}_real.txt', 'r') as fh:
        cleaned = [raw.strip() for raw in fh]
    real_all += [entry for entry in cleaned if not entry.endswith('_back')]

    # Synthetic list: same '_back' filter, but '_front' is removed from the
    # surviving names.
    with open(f'./dataset/{split}_synt.txt', 'r') as fh:
        cleaned = [raw.strip() for raw in fh]
    synt_all += [
        entry.replace('_front', '')
        for entry in cleaned
        if not entry.endswith('_back')
    ]

In [3]:
# Total vs. distinct entry counts: real list first, synthetic list second.
for entries in (real_all, synt_all):
    print(len(entries))
    print(len(set(entries)))

1043
1043
12550
12403


## 原始分布

In [4]:
# Category = text before the first underscore; names are de-duplicated via
# set() before being tallied.
reals = [name.partition('_')[0] for name in set(real_all)]
Counter(reals)

Counter({'Hem': 134,
         'Move1': 114,
         'Miss': 65,
         'Mesh': 122,
         'Full': 16,
         'Cable2': 71,
         'Links2': 105,
         'Cable1': 69,
         'Tuck': 167,
         'Move2': 138,
         'Links1': 42})

In [5]:
# Per-category tally of the distinct synthetic names (category = leading
# token before the first '_').
synts = [name.split('_', 1)[0] for name in set(synt_all)]
Counter(synts)

Counter({'Cable1': 1312,
         'Hem': 2686,
         'Move2': 2858,
         'Links2': 981,
         'Miss': 1419,
         'Cable2': 829,
         'Move1': 1522,
         'Mesh': 229,
         'Links1': 261,
         'Tuck': 303,
         'Full': 3})

In [6]:
# Distinct synthetic names that also exist in the real list, either verbatim
# or as the real entry '<name>_front'.
# The real-name set is built once here; the original rebuilt set(real_all)
# twice for every synthetic name.
real_names = set(real_all)
synt_match = [
    line
    for line in set(synt_all)
    if line in real_names or line + '_front' in real_names
]

In [7]:
# Number of distinct synthetic names that have a matching real entry.
len(synt_match)

988

In [8]:
# Category tally restricted to the synthetic names that matched a real entry.
synt_matchs = [name.partition('_')[0] for name in set(synt_match)]
Counter(synt_matchs)

Counter({'Cable1': 69,
         'Links2': 105,
         'Tuck': 167,
         'Miss': 65,
         'Mesh': 122,
         'Hem': 134,
         'Cable2': 71,
         'Move1': 98,
         'Move2': 114,
         'Links1': 42,
         'Full': 1})

## 增加非real

In [9]:
# For every category except 'Full': how many unmatched synthetic names to add
# so the category reaches the per-category cap, limited by how many unmatched
# names the category actually has.
per_category_cap = 300
# Tally the synthetic names once; the original rebuilt Counter(synts) on
# every loop iteration.
synt_counts = Counter(synts)
residuals = {
    cat: min(per_category_cap - matched, synt_counts[cat] - matched)
    for cat, matched in Counter(synt_matchs).items()
    if cat != 'Full'
}
residuals

{'Cable1': 231,
 'Links2': 195,
 'Tuck': 133,
 'Miss': 235,
 'Mesh': 107,
 'Hem': 166,
 'Cable2': 229,
 'Move1': 202,
 'Move2': 186,
 'Links1': 219}

In [10]:
# Start from the matched names (minus 'Full'), then top up each category with
# unmatched synthetic names until its quota in `residuals` is used up.
synt_tmp = [line for line in synt_match if not line.startswith('Full')]

# Set lookup replaces the original scan of the `synt_match` list for every
# candidate line in every category pass.
matched_names = set(synt_match)
# NOTE: candidates are visited in set-iteration order, as before. For strings
# that order can change between interpreter runs (hash randomization), so
# *which* names get picked is not reproducible across kernel restarts.
unmatched = [line for line in set(synt_all) if line not in matched_names]

for cat, total in residuals.items():
    # Compare the exact category token (text before the first '_'), i.e. the
    # same key the quotas were counted by, instead of a raw prefix test that
    # would also accept a longer category sharing the prefix.
    same_cat = [line for line in unmatched if line.split('_')[0] == cat]
    # max(..., 0) keeps a negative quota meaning "add nothing".
    synt_tmp.extend(same_cat[:max(total, 0)])

In [11]:
# Size of the working selection: matched non-'Full' names plus the
# per-category top-up.
len(synt_tmp)

2890

In [12]:
# Category tally of the working selection after the per-category top-up.
synt_tmps = [name.split('_', 1)[0] for name in set(synt_tmp)]
Counter(synt_tmps)

Counter({'Cable1': 300,
         'Hem': 300,
         'Move2': 300,
         'Tuck': 300,
         'Links2': 300,
         'Miss': 300,
         'Cable2': 300,
         'Links1': 261,
         'Move1': 300,
         'Mesh': 229})

## 凑到3000个

In [13]:
# How many more synthetic names are needed to reach 3000 in total.
shortfall = 3000 - len(synt_tmp)
print(shortfall)

110


In [14]:
# Hand-set number of additional synthetic names to draw per category.
# The values sum to 110, the gap to 3000 printed by the preceding cell.
res2 = {
    'Cable1': 15,
    'Hem': 16,
    'Miss': 15,
    'Move2': 16,
    'Links2': 15,
    'Move1': 15,
    'Mesh': 0,  # all 229 distinct Mesh names are already selected
    'Cable2': 15,
    'Links1': 0,  # all 261 distinct Links1 names are already selected
    'Tuck': 3  # 303 distinct Tuck names exist, 300 are already selected
}

In [15]:
# Extend the working selection with the per-category extras listed in `res2`.
synt_final = synt_tmp.copy()

# Set lookup replaces the original scan of the `synt_tmp` list for every
# candidate line in every category pass.
already_selected = set(synt_tmp)
# NOTE: candidates are visited in set-iteration order, as before. For strings
# that order can change between interpreter runs (hash randomization), so
# *which* names get picked is not reproducible across kernel restarts.
unused = [line for line in set(synt_all) if line not in already_selected]

for cat, total in res2.items():
    # Match on the exact category token (text before the first '_') rather
    # than a raw prefix, so a longer category sharing the prefix is not used.
    same_cat = [line for line in unused if line.split('_')[0] == cat]
    # max(..., 0) keeps a negative quota meaning "add nothing".
    synt_final.extend(same_cat[:max(total, 0)])
len(synt_final)

3000

In [16]:
# Real entries (excluding 'Full') whose name, as-is or with '_front' removed,
# is among the matched synthetic names. Order follows `real_all`.
# A set is used for the lookups; the original scanned the `synt_match` list
# up to twice per real entry.
matched_names = set(synt_match)
real_final = [
    line
    for line in real_all
    if not line.startswith('Full')
    and (line in matched_names or line.replace('_front', '') in matched_names)
]
len(real_final)

987

In [17]:
# Combined list: the selected real entries, followed by every selected
# synthetic name that has no real counterpart.
# `inst_tmp` holds the real names with '_front' removed, i.e. in the same
# form as the synthetic names.
inst_tmp = [line.replace('_front', '') for line in real_final]
# Set lookup replaces the original scan of the `inst_tmp` list per line.
real_keys = set(inst_tmp)
resi3 = [line for line in synt_final if line not in real_keys]
inst_final = real_final + resi3
len(inst_final)

3000

## train val split

In [25]:
# Split `synt_final` into validation and training lists: the first
# `val_per_category` names of each category (in `synt_final` order) go to
# validation, everything else to training.
val_per_category = 30
val_synt = []
val_synt_cnt = {
    'Cable1': 0,
    'Hem': 0,
    'Miss': 0,
    'Move2': 0,
    'Links2': 0,
    'Move1': 0,
    'Mesh': 0,
    'Cable2': 0,
    'Links1': 0,
    'Tuck': 0
}

for line in synt_final:
    cat = line.split('_')[0]
    # A category missing from `val_synt_cnt` raises KeyError, as before.
    if val_synt_cnt[cat] < val_per_category:
        val_synt.append(line)
        val_synt_cnt[cat] += 1

# Set lookup replaces the original scan of the `val_synt` list per line.
val_names = set(val_synt)
train_synt = [line for line in synt_final if line not in val_names]
print(len(train_synt))
print(len(val_synt))

2700
300


In [29]:
# Persist the synthetic train/val splits, one name per line.
for out_path, split_lines in (
    ('./dataset/train_synt_test.txt', train_synt),
    ('./dataset/val_synt_test.txt', val_synt),
):
    with open(out_path, 'w') as out:
        out.writelines(entry + '\n' for entry in split_lines)

# The remaining list files are created (or truncated) with no content.
for out_path in (
    './dataset/train_unsup_test.txt',
    './dataset/train_real_test.txt',
    './dataset/val_real_test.txt',
):
    with open(out_path, 'w'):
        pass

In [30]:
# with open('./dataset/synt_test.txt', 'w') as f:
#     for line in synt_final:
#         f.write(line+'\n')