In [1]:
import pickle
import re
from collections import defaultdict, Counter
from tqdm.notebook import tqdm
import pandas as pd
import random
from IPython.display import display

In [2]:
with open('data_processed.pkl', 'rb') as f:
    data = pickle.load(f)

In [3]:
path_has_book = set()
path_has_quote = set()
for cat, files in data.items():
    for path, lines in files.items():
        text = "".join(lines)
        quote_cnt = text.count('“') + text.count('”')
        book_cnt = text.count('《') + text.count('》')
        if quote_cnt > 10:
            path_has_quote.add(path)
        if book_cnt > 10:
            path_has_book.add(path)

path_has_quote_book = path_has_book & path_has_quote

In [4]:
random.seed('1')
dataset_train = defaultdict(dict)
dataset_dev = defaultdict(dict)
dataset_test = defaultdict(dict)
# dev_test_path = set()
for cat, files in data.items():
    paths = list(files.keys() & path_has_quote_book)
    random.shuffle(paths)
    size = len(paths)
    dev_size = test_size = int(len(paths) * 0.1)
    dev_path = paths[:dev_size]
    test_path = paths[dev_size:dev_size+test_size]
    train_path = paths[dev_size+test_size:]
#     dev_test_path |= set(dev_path) | set(test_path)
    dataset_dev[cat] = {path: files[path] for path in dev_path}
    dataset_test[cat] = {path: files[path] for path in test_path}
    dataset_train[cat] = {path: files[path] for path in train_path}

# for cat, files in data.items():
#     dataset_train[cat] = {path: files[path] for path in files if path not in dev_test_path}


In [5]:
def sampling(lines):
    return random.sample(lines, int(len(lines) * 0.01))

for cat in data:
    dataset_dev[cat] = {path: sampling(dataset_dev[cat][path]) for path in dataset_dev[cat]}
    dataset_test[cat] = {path: sampling(dataset_test[cat][path]) for path in dataset_test[cat]}
    dataset_train[cat] = {path: sampling(dataset_train[cat][path]) for path in dataset_train[cat]}


In [6]:
# Data statistics
def data_stat(data):
    df = pd.DataFrame([[cat,
               len(files),                                                             # 文章数
               sum([len(lines) for _, lines in files.items()]),                        # 段落数
               sum([sum([len(line) for line in lines])for _, lines in files.items()]), # 字数
              ] for cat, files in data.items()])
    display(df)
    display(df.sum())
    
data_stat(dataset_train)
data_stat(dataset_dev)
data_stat(dataset_test)

Unnamed: 0,0,1,2,3
0,易藏,41,573,43088
1,医藏,137,2075,176176
2,艺藏,96,109,12582
3,史藏,410,7092,785104
4,佛藏,44,486,47014
5,集藏,391,7316,953770
6,诗藏,227,4629,232277
7,子藏,340,2869,318214
8,儒藏,96,1303,127277
9,道藏,65,483,50724


0    易藏医藏艺藏史藏佛藏集藏诗藏子藏儒藏道藏
1                    1847
2                   26935
3                 2746226
dtype: object

Unnamed: 0,0,1,2,3
0,易藏,5,12,1773
1,医藏,16,235,22496
2,艺藏,11,88,9482
3,史藏,51,1380,151351
4,佛藏,5,48,4798
5,集藏,48,934,95966
6,诗藏,28,901,71522
7,子藏,42,225,29186
8,儒藏,12,224,22439
9,道藏,8,28,2424


0    易藏医藏艺藏史藏佛藏集藏诗藏子藏儒藏道藏
1                     226
2                    4075
3                  411437
dtype: object

Unnamed: 0,0,1,2,3
0,易藏,5,36,1652
1,医藏,16,261,21223
2,艺藏,11,20,1654
3,史藏,51,1195,136890
4,佛藏,5,30,2835
5,集藏,48,1679,243599
6,诗藏,28,253,26448
7,子藏,42,412,57896
8,儒藏,12,80,8246
9,道藏,8,26,2752


0    易藏医藏艺藏史藏佛藏集藏诗藏子藏儒藏道藏
1                     226
2                    3992
3                  503195
dtype: object

In [7]:
puncs_map = {
    '，': 'B-,',
    '。': 'B-.',
    '？': 'B-?',
    '！': 'B-!',
    '、': 'B-\\',
    '：': 'B-:',
    '；': 'B-;',
}

quotes_map = {
    '“': 'B',
    '”': 'E',
    '‘': 'B',
    '’': 'E',
}

books_map = {
    '《': 'B',
    '》': 'E',
    '〈': 'B',
    '〉': 'E',
}

def process_line(line):
    tokens = []
    seg_labels = []
    punc_labels = []
    quote_labels = []
    book_labels = []
    quote_begin = -1
    book_begin = -1
    for c in line:
        if c in puncs_map:
            if len(tokens) == 0:
                continue
            if punc_labels[-1] != 'O':
                continue
            punc_labels[-1] = puncs_map[c]
            seg_labels[-1] = 'B'
            continue

        if c in quotes_map:
            if len(tokens) == 0:
                continue
            if quotes_map[c] == 'B':
                if quote_begin != -1:
                    quote_labels[quote_begin:] = ['O'] * (len(quote_labels) - quote_begin)
                quote_begin = len(tokens)
            if quotes_map[c] == 'E':
                quote_begin = -1
            continue

        if c in books_map:
            if books_map[c] == 'B':
                if book_begin != -1:
                    book_labels[book_begin:] = ['O'] * (len(book_labels) - book_begin)
                book_begin = len(tokens)
            if books_map[c] == 'E':
                book_begin = -1
            continue
            
        punc_labels.append('O')
        seg_labels.append('O')

        if quote_begin == -1:
            quote_labels.append('O')
        elif quote_begin == len(tokens):
            quote_labels.append('B')
        else:
            quote_labels.append('I')
            
        if book_begin == -1:
            book_labels.append('O')
        elif book_begin == len(tokens):
            book_labels.append('B')
        else:
            book_labels.append('I')
        
        tokens.append(c)
    zipped = zip(tokens, seg_labels, punc_labels, quote_labels, book_labels)
    return [' '.join(l) for l in zipped]
    
test_str = '《谷梁传》曰：“壅河，三日不流，晋君率群臣素服而哭之，河乃流。”《左传》曰：“晋侯以传召伯宗，伯宗辟重曰：‘辟传。’重人曰：‘待我，不如捷之速也。’问其所，曰：‘绛人也。’问绛事焉，曰：‘梁山崩，将召伯宗谋之。’问将若之何，曰：‘山有朽坏而崩，可若何？国主山川，故山崩川竭，君为之不举，降服、乘缦、撤乐、出次、祝币，史辞以礼焉，其如此而已。虽伯宗若之何？”'
process_line(test_str)

['谷 O O O B',
 '梁 O O O I',
 '传 O O O I',
 '曰 B B-: O O',
 '壅 O O B O',
 '河 B B-, I O',
 '三 O O I O',
 '日 O O I O',
 '不 O O I O',
 '流 B B-, I O',
 '晋 O O I O',
 '君 O O I O',
 '率 O O I O',
 '群 O O I O',
 '臣 O O I O',
 '素 O O I O',
 '服 O O I O',
 '而 O O I O',
 '哭 O O I O',
 '之 B B-, I O',
 '河 O O I O',
 '乃 O O I O',
 '流 B B-. I O',
 '左 O O O B',
 '传 O O O I',
 '曰 B B-: O O',
 '晋 O O O O',
 '侯 O O O O',
 '以 O O O O',
 '传 O O O O',
 '召 O O O O',
 '伯 O O O O',
 '宗 B B-, O O',
 '伯 O O O O',
 '宗 O O O O',
 '辟 O O O O',
 '重 O O O O',
 '曰 B B-: O O',
 '辟 O O B O',
 '传 B B-. I O',
 '重 O O O O',
 '人 O O O O',
 '曰 B B-: O O',
 '待 O O B O',
 '我 B B-, I O',
 '不 O O I O',
 '如 O O I O',
 '捷 O O I O',
 '之 O O I O',
 '速 O O I O',
 '也 B B-. I O',
 '问 O O O O',
 '其 O O O O',
 '所 B B-, O O',
 '曰 B B-: O O',
 '绛 O O B O',
 '人 O O I O',
 '也 B B-. I O',
 '问 O O O O',
 '绛 O O O O',
 '事 O O O O',
 '焉 B B-, O O',
 '曰 B B-: O O',
 '梁 O O B O',
 '山 O O I O',
 '崩 B B-, I O',
 '将 O O I O',
 '召 O O I O',
 '伯 O O I O'

In [9]:
# 构建训练集 dataset_train
result = []
for cat, files in tqdm(dataset_train.items()):
    for path, lines in tqdm(files.items(), leave=False):
        result.append('-DOCSTART- -SEG- -PUNC- '
                     + ('-QUOTE- ' if path in path_has_quote else '-X- ')
                     + ('-BOOK-' if path in path_has_book else '-X-' ))
        result.append('')
        
        for line in lines:
            result.extend(process_line(line))
            result.append('')

with open('punctuation/train.txt', 'wt') as f:
    f.write('\n'.join(result))

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/137 [00:00<?, ?it/s]

  0%|          | 0/96 [00:00<?, ?it/s]

  0%|          | 0/410 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/391 [00:00<?, ?it/s]

  0%|          | 0/227 [00:00<?, ?it/s]

  0%|          | 0/340 [00:00<?, ?it/s]

  0%|          | 0/96 [00:00<?, ?it/s]

  0%|          | 0/65 [00:00<?, ?it/s]

In [54]:
!(cd punctuation && tar -czf train.txt.tar.gz train.txt && rm train.txt)

In [10]:
# 构建开发集，测试集
def build_dev_test(dataset, dataset_type):
    result_all = []
    for cat, files in tqdm(dataset.items()):
        result = []
        for lines in files.values():
            result.append('-DOCSTART- -SEG- -PUNC- -QUOTE- -BOOK-')
            result.append('')
            for line in lines:
                result.extend(process_line(line))
                result.append('')
        result_all.extend(result)
#         with open(f'punctuation/{cat}.{dataset_type}.txt', 'wt') as f:
#             f.write('\n'.join(result))
    with open(f'punctuation/{dataset_type}.txt', 'wt') as f:
        f.write('\n'.join(result_all))

build_dev_test(dataset_dev, 'dev')
build_dev_test(dataset_test, 'test')

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]