In [None]:
import os
import json

# Root directory of the raw QASPER release. The loop below reads
# f'{origin_data_path}/qasper-{split}-v0.3.json' for each split; this value is a
# placeholder and must be set before running the cell.
origin_data_path = 'path/to/qasper'


def process(doc):
    """Flatten one QASPER paper into markdown-style plain text.

    The title and abstract come first. Each entry of ``doc['full_text']``
    follows, with its paragraphs joined by newlines. A section with a truthy
    ``section_name`` is preceded by a heading whose depth is the number of
    ``:::``-separated parts in the name; only the last part is used as the
    heading text.

    Args:
        doc: one paper record with a ``'full_text'`` list of section dicts
            (each with ``'paragraphs'`` and optionally ``'section_name'``),
            plus optional ``'title'`` / ``'abstract'`` strings.

    Returns:
        (text, counts) where ``counts`` is a 10-element list. ``counts[0]`` is 1
        for the document itself, and ``counts[k]`` is the number of depth-``k``
        headings emitted.
    """
    counts = [0] * 10
    counts[0] = 1

    pieces = [doc.get('title', ''), doc.get('abstract', '')]
    for section in doc['full_text']:
        name = section.get('section_name')
        if name:
            path = [part.strip() for part in name.split(':::')]
            depth = len(path)
            pieces.append('#' * depth + ' ' + path[-1])
            counts[depth] += 1
        # One piece per section body, even when it has no paragraphs, so the
        # newline layout matches the heading/body pairing exactly.
        pieces.append('\n'.join(section['paragraphs']))

    return '\n'.join(pieces), counts

    
# Convert every QASPER split into one UTF-8 text file per paper and report,
# for each split, the summed per-level heading counts returned by process().
splits = ['train', 'dev', 'test']
for split in splits:
    seg_point_count = [0]*10
    save_path = f'corpus/qasper/{split}_file'
    os.makedirs(save_path, exist_ok=True)

    input_path = f'{origin_data_path}/qasper-{split}-v0.3.json'
    # Use a context manager so the (large) JSON file handle is closed promptly.
    with open(input_path, 'r', encoding='utf-8') as f:
        doc_dict = json.load(f)

    for _id, doc in doc_dict.items():
        doc_text, spc = process(doc)
        # Explicit UTF-8: paper text is not guaranteed to be ASCII, and the
        # platform default encoding may not be able to represent it.
        with open(f'{save_path}/{_id}.txt', 'w', encoding='utf-8') as f:
            f.write(doc_text)
        seg_point_count = [total + cnt for total, cnt in zip(seg_point_count, spc)]

    print(split, seg_point_count)


In [None]:
import json
import os
from tqdm import tqdm

# Root directory of the raw GovReport release. The loop below reads
# split_ids/gao_<split>.ids, split_ids/crs_<split>.ids, gao/<id>.json and
# crs/<id>.json under it; this value is a placeholder and must be set before running.
origin_data_path = 'path/to/gov-report'


def process(section, level, spc):
    """Render a GovReport section and, recursively, its subsections as markdown.

    NOTE: this redefines the QASPER ``process`` from the earlier cell; the
    name is kept so the GovReport loop in this cell keeps working.

    Args:
        section: dict with ``'paragraphs'`` (list of str) and optionally
            ``'section_title'`` (str) and ``'subsections'`` (list of dicts of
            the same shape). Missing keys are treated as empty.
        level: heading depth; the section's heading starts with ``'#' * level``.
            Subsections are rendered at ``level + 1``.
        spc: list of per-level counters, mutated in place. ``spc[level]`` is
            incremented once per rendered section, so it must be long enough
            for the deepest level reached.

    Returns:
        ``(text, True)`` on success. Returns ``('', False)`` if this section or
        any descendant has neither a non-blank title nor any paragraphs. The
        GovReport loop uses the False flag to drop the whole report.
    """
    # `or` also covers an explicit null title/list in the JSON.
    title = section.get('section_title') or ''
    paragraphs = section.get('paragraphs') or []
    has_title = title.strip() != ''
    if not has_title and len(paragraphs) == 0:
        return '', False

    heading = f'{title}\n' if has_title else ''
    # NOTE(review): an untitled section still gets a '#' marker, so its first
    # paragraph lands on the heading line; kept to preserve the corpus format.
    parts = ['#' * level + ' ' + heading + '\n'.join(paragraphs) + '\n']
    spc[level] += 1

    for subsection in section.get('subsections') or []:
        sub_text, ok = process(subsection, level + 1, spc)
        if not ok:
            return '', False
        parts.append(sub_text)

    return ''.join(parts), True

    
# Reports shorter than this many characters are dropped (presumably the "5w",
# i.e. 5 x 10^4, in the output directory name).
MIN_DOC_CHARS = 50000

splits = ['train', 'valid', 'test']
# GovReport calls its validation split 'valid'; the output corpus uses 'dev'.
save_splits_map = {
    'train': 'train',
    'valid': 'dev',
    'test': 'test'
}


def _read_ids(path):
    """Return the stripped, non-blank report ids listed one per line in `path`."""
    with open(path, 'r', encoding='utf-8') as f:
        return [line.strip() for line in f if line.strip()]


def _load_json(path):
    """Parse the JSON file at `path`, closing the handle afterwards."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


for split in splits:
    # Per-level heading counts summed over every report kept for this split.
    seg_point_count = [0]*10
    save_path = f'corpus/gov-report_5w/{save_splits_map[split]}_file'
    os.makedirs(save_path, exist_ok=True)

    # GAO reports: d['report'] is a list of top-level sections. If process()
    # flags any of them as invalid, the whole report is discarded.
    for n in tqdm(_read_ids(f'{origin_data_path}/split_ids/gao_{split}.ids')):
        d = _load_json(f'{origin_data_path}/gao/{n}.json')
        spc = [0]*10
        spc[0] += 1  # the report itself
        parts = []
        for subs in d['report']:
            section_text, flag = process(subs, 1, spc)
            if not flag:
                parts = []
                break
            parts.append(section_text)
        doc = ''.join(parts)
        if doc and len(doc) >= MIN_DOC_CHARS:
            seg_point_count = [total + cnt for total, cnt in zip(seg_point_count, spc)]
            with open(f'{save_path}/gao_{n}.txt', 'w', encoding='utf-8') as f:
                f.write(doc)

    # CRS reports: d['reports'] is a single root section rendered as a whole.
    for n in tqdm(_read_ids(f'{origin_data_path}/split_ids/crs_{split}.ids')):
        d = _load_json(f'{origin_data_path}/crs/{n}.json')
        spc = [0]*10
        spc[0] += 1  # the report itself
        doc, flag = process(d['reports'], 1, spc)
        if flag and doc and len(doc) >= MIN_DOC_CHARS:
            seg_point_count = [total + cnt for total, cnt in zip(seg_point_count, spc)]
            with open(f'{save_path}/crs_{n}.txt', 'w', encoding='utf-8') as f:
                f.write(doc)

    print(split, seg_point_count)


In [None]:
import os
from tqdm import tqdm


# Root directory of the raw Wiki-727 release (placeholder; set before running).
origin_data_path = 'path/to/wiki_727'
# Files shorter than this many characters are skipped (presumably the "5w",
# i.e. 5 x 10^4, in the output directory name).
MIN_DOC_CHARS = 50000
# Wiki-727 header lines look like '========,<level>,<title>'.
HEADER_PREFIX = '========,'

splits = ['train', 'dev', 'test']
for split in splits:
    # Per-level heading counts summed over every file kept for this split.
    seg_point_count = [0]*10
    save_path = f'corpus/wiki_727_5w/{split}_file'
    os.makedirs(save_path, exist_ok=True)

    split_root = f'{origin_data_path}/{split}/'
    for root, dirnames, filenames in tqdm(os.walk(split_root)):
        # Flatten the sub-directory path into the output file name. relpath
        # removes the prefix exactly; str.lstrip would strip a *set of
        # characters* and could also eat leading characters of sub-directory
        # names, causing output-name collisions.
        rel_dir = os.path.relpath(root, split_root)
        dir_prefix = '' if rel_dir == os.curdir else rel_dir.replace(os.sep, '_')

        for fn in filenames:
            # Read each file once; the joined lines equal the full file text.
            with open(os.path.join(root, fn), 'r', encoding='utf-8') as f:
                lines = f.readlines()
            if sum(len(line) for line in lines) < MIN_DOC_CHARS:
                continue

            save_name = dir_prefix + '_' + fn
            spc = [0]*10
            spc[0] += 1  # the document itself

            headers = []  # (line index, level, title including its newline)
            for idx, line in enumerate(lines):
                if line.startswith(HEADER_PREFIX):
                    # maxsplit=2 keeps titles that contain commas intact,
                    # together with their trailing newline.
                    fields = line.split(',', 2)
                    level = int(fields[1].strip())
                    headers.append((idx, level, fields[2]))
                    spc[level] += 1

            # With at most one top-level section (e.g. only a lead section),
            # promote every deeper heading by one level so the body sections
            # become level-1 headings.
            promote = spc[1] <= 1
            for idx, level, title in headers:
                if promote and level > 1:
                    level -= 1
                lines[idx] = f"{'#' * level} {title}"
            if promote:
                top_level_count = spc.pop(1)
                # Former level-1 headings stay at level 1, so add their actual
                # count (instead of assuming exactly one).
                spc[1] += top_level_count
                spc.append(0)

            seg_point_count = [total + cnt for total, cnt in zip(seg_point_count, spc)]
            with open(f"{save_path}/{save_name}", 'w', encoding='utf-8') as f:
                f.writelines(lines)

    print(split, seg_point_count)
