In [1]:
from doccano_transformer.datasets import NERDataset
from doccano_transformer.utils import read_jsonl
from tqdm import tqdm

In [3]:
# Read 'all.jsonl' through read_jsonl as an NERDataset and convert it with
# to_conll2003, splitting each text on whitespace (str.split).
dataset = read_jsonl(filepath='all.jsonl', dataset=NERDataset, encoding='utf-8')
conll = dataset.to_conll2003(tokenizer=str.split)

# NOTE(review): the input is read as UTF-8 but this file is written with the
# platform default encoding; the later cells read it back the same way, so
# any change of encoding has to be made in all of them together.
with open('dataset.conll', 'w') as f:
    # Every item exposes its already formatted text under the 'data' key;
    # tqdm only reports progress while the items are consumed.
    f.writelines(item['data'] for item in tqdm(conll))

1093it [00:00, 11077.89it/s]


In [4]:
import re

chars = [',', '.', '?', '!', '"', ':', '(', ')', '/']

# Build the splitting pattern from `chars` so the two cannot drift apart.
# The previous hand-written class '[,|.|?|...]' also contained a literal '|'
# ('|' is not an alternation operator inside [...]), so tokens were split on
# pipes even though '|' is not one of the punctuation marks listed above.
punct_splitter = re.compile('([{}])'.format(''.join(re.escape(c) for c in chars)))


def _split_punctuation(word, tag):
    """Split `word` on the characters in `chars` and tag every piece.

    Returns a list of (text, tag) pairs in their original order:

    * leading '(' or '"' pieces are tagged 'O';
    * the trailing run of punctuation pieces is tagged 'O';
    * the first remaining piece keeps `tag`;
    * later pieces keep `tag`, except that a 'B-' prefix becomes 'I-'.
    """
    pieces = [piece for piece in punct_splitter.split(word) if piece]

    # Index of the first piece that belongs to the trailing punctuation run.
    end_index = len(pieces)
    while end_index > 0 and pieces[end_index - 1] in chars:
        end_index -= 1

    # Only the prefix is rewritten; str.replace would also change any 'B-'
    # that occurs later inside the label (e.g. 'B-WEB-APP' -> 'I-WEI-APP').
    inside_tag = 'I-' + tag[2:] if tag.startswith('B-') else tag

    rows = []
    start_idx = 0  # index of the first piece that is not a leading opener
    for idx, text in enumerate(pieces):
        if idx == start_idx and text in ('(', '"'):
            # Leading opening bracket / quote: outside the entity.
            start_idx += 1
            rows.append((text, 'O'))
        elif idx >= end_index:
            # Trailing punctuation (e.g. "Done." or "add,"): outside the entity.
            rows.append((text, 'O'))
        elif idx > start_idx:
            rows.append((text, inside_tag))
        else:
            rows.append((text, tag))
    return rows


with open('dataset_pure_punct.conll', 'w') as output_file:
    with open('dataset.conll', 'r') as input_file:
        first = True
        for line in input_file:
            if line.startswith('-DOCSTART-'):
                # Documents are separated by a single blank line; nothing is
                # written in front of the first document.
                if first:
                    first = False
                else:
                    output_file.write('\n')
            elif line == '\n':
                # Blank lines of the input are dropped.
                continue
            else:
                contents = line.split()
                # Column 0 is the token, column 3 its tag.
                for text, piece_tag in _split_punctuation(contents[0], contents[3]):
                    output_file.write("{} {}\n".format(text, piece_tag))


In [5]:
# Fix the trailing character problem
#
# Punctuation that was tagged 'O' only because it ended a token is re-tagged
# when the entity actually continues after it: if the first queued token is
# tagged with an entity and the next word carries an 'I-' tag of the same
# type, every queued punctuation token receives that 'I-' tag.
delimiter = ',.?!":()/'


def _write_tokens(output_file, tokens):
    """Write each (word, tag) pair as one 'word tag' line."""
    for queued_word, queued_tag in tokens:
        output_file.write("{} {}\n".format(queued_word, queued_tag))


with open('dataset_pure_punct_notrail.conll', 'w') as output_file:
    with open('dataset_pure_punct.conll', 'r') as input_file:
        # Tokens whose final tag is still unknown: one word plus the
        # punctuation tokens that follow it.
        queue = list()
        # True once a punctuation token has been queued since the last flush.
        occur = False
        for line in input_file:
            if line == '\n':
                # Sentence boundary: nothing can follow, flush unchanged.
                _write_tokens(output_file, queue)
                queue.clear()
                output_file.write('\n')
                occur = False
                continue

            contents = line.split()
            word = contents[0]
            tag = contents[1]

            if word in delimiter:
                # Punctuation: keep it queued until the next word is known.
                occur = True
            elif len(queue) != 0:
                head_tag = queue[0][1]
                entity_continues = (
                    occur
                    and head_tag != 'O'
                    and tag.startswith('I-')
                    and head_tag[2:] == tag[2:]
                )
                if entity_continues:
                    # Same entity on both sides: pull the punctuation into it.
                    _write_tokens(
                        output_file,
                        [(w, tag if w in delimiter else t) for w, t in queue],
                    )
                else:
                    _write_tokens(output_file, queue)
                occur = False
                queue.clear()
            queue.append((word, tag))

        # The input need not end with a blank line, so whatever is still
        # queued at end of file has to be written here; without this the
        # last token(s) of the file were silently dropped.
        _write_tokens(output_file, queue)
        queue.clear()


In [7]:
import random

proportion = 80  # a sentence goes to train.txt when its 1-100 draw is <= this

# A dedicated, seeded generator makes the split identical on every run
# without touching the state of the global `random` module.
SPLIT_SEED = 42
split_rng = random.Random(SPLIT_SEED)


def _write_sentence(sentence_lines, train_file, dev_file):
    """Write one sentence, followed by a blank line, to train or dev.

    A number from 1 to 100 is drawn; values above `proportion` send the
    sentence to `dev_file`, all others to `train_file`.
    """
    dice = split_rng.randint(1, 100)
    target = dev_file if dice > proportion else train_file
    target.writelines(sentence_lines)
    target.write('\n')


with open('dataset_pure_punct_notrail.conll', 'r') as input_file, \
        open('train.txt', 'w') as train_file, \
        open('dev.txt', 'w') as dev_file:
    cache = []  # lines of the sentence currently being read
    for line in input_file:
        if line != '\n':
            cache.append(line)
        else:
            # A blank line closes the current sentence.
            _write_sentence(cache, train_file, dev_file)
            cache.clear()
    # The last sentence may not be followed by a blank line.
    if len(cache) > 0:
        _write_sentence(cache, train_file, dev_file)

In [9]:
import random

total = 5  # number of fold files
curIdx = 0  # fold that receives the next sentence
fileHandler = list()

# NOTE(review): the files are written as '<i>.conll' in the working
# directory, while the fold-combining cell reads 'k-fold/<i>.conll'; confirm
# whether they are meant to be moved by hand or written there directly.
try:
    for i in range(total):
        fileHandler.append(open(f'{i}.conll', 'w'))

    with open('dataset_pure_punct_notrail.conll', 'r') as input_file:
        cache = []  # lines of the sentence currently being read
        for line in input_file:
            if line != '\n':
                cache.append(line)
                continue
            # A blank line closes the sentence: hand it to the current fold
            # and move on to the next one (round-robin).
            handler = fileHandler[curIdx]
            handler.writelines(cache)
            handler.write('\n')
            cache.clear()
            curIdx = (curIdx + 1) % total
        # The last sentence may not be followed by a blank line.
        if len(cache) > 0:
            handler = fileHandler[curIdx]
            handler.writelines(cache)
            handler.write('\n')
            cache.clear()
            curIdx = (curIdx + 1) % total
finally:
    # The handles were previously never closed, so buffered data was not
    # guaranteed to be on disk when the fold files were read afterwards.
    for handler in fileHandler:
        handler.close()

In [10]:
import os

NUM_FOLDS = 5

# For every fold i: dev.txt is a copy of fold i, train.txt is the
# concatenation of all the other folds in ascending order.
for i in range(NUM_FOLDS):
    fold_dir = f'k-fold/combine/{i}'
    # open(..., 'w') does not create missing parent directories.
    os.makedirs(fold_dir, exist_ok=True)

    with open(f'{fold_dir}/train.txt', 'w') as train_file:
        for j in range(NUM_FOLDS):
            if i == j:
                continue  # the held-out fold goes to dev.txt only
            with open(f'k-fold/{j}.conll', 'r') as input_file:
                train_file.writelines(input_file)

    with open(f'{fold_dir}/dev.txt', 'w') as dev_file:
        with open(f'k-fold/{i}.conll', 'r') as input_file:
            dev_file.writelines(input_file)
