In [None]:
import os
import spacy
import json
import torch
import numpy as np
from tokenizations import get_alignments, get_original_spans

In [None]:
def load_labels(file_name):
    """Read a JSON file and return the decoded object.

    Args:
        file_name: Path of the JSON file, opened read-only in text mode.

    Returns:
        Whatever ``json.load`` decodes from the file content.
    """
    with open(file_name, 'r') as label_file:
        return json.load(label_file)


def load_conll_2003_data(file_name):
    """Parse a four-column, whitespace-separated CoNLL-2003 style file.

    A line that splits into exactly four fields contributes its first field
    as a token and its fourth field as that token's NER label. Any other
    line (typically a blank one) ends the sentence collected so far.

    The last sentence is kept even when the file does not end with a
    separator line, and runs of separator lines do not produce empty
    sentences.

    Args:
        file_name: Path of the file to read.

    Returns:
        A tuple ``(app_sentence, all_sentence, all_labels)`` of equally long
        lists with one entry per sentence: the tokens followed by the marker
        ``"@SB@"``, the tokens alone, and the NER labels.
    """
    with open(file_name, 'r') as f:
        lines = f.readlines()
    # A trailing empty "line" acts as a sentinel separator, so the final
    # sentence is flushed even if the file lacks a closing blank line.
    lines.append('')

    app_sentence = list()
    all_sentence = list()
    all_labels = list()
    sentence = list()
    labels = list()
    for line in lines:
        fields = line.split()
        if len(fields) == 4:
            sentence.append(fields[0])
            labels.append(fields[3])
        elif sentence:
            # Separator reached: store the sentence, once with and once
            # without the sentence-boundary marker.
            app_sentence.append(sentence + ["@SB@"])
            all_sentence.append(sentence)
            all_labels.append(labels)
            sentence = list()
            labels = list()

    # Sanity check: every token has exactly one label.
    for tokens, token_labels in zip(all_sentence, all_labels):
        assert len(tokens) == len(token_labels)

    return app_sentence, all_sentence, all_labels

In [None]:
# Location of the CoNLL-2003 split parsed in this cell.
data_dir = r'data/conll2003/'
data_name = 'dev.txt'
app_tokens, all_tokens, all_labels = load_conll_2003_data(os.path.join(data_dir, data_name))

# Label inventory read from JSON, plus the inverse lookup from each entry to
# its position in the inventory.
idx2label = load_labels(os.path.join('data', 'CoNLL2003-labels.json'))
label2idx = {v: k for k, v in enumerate(idx2label)}

# Fail fast on labels that are absent from the inventory and report which
# ones, so the problem can be located without re-running by hand.
unknown_labels = sorted({lb for lbs in all_labels for lb in lbs if lb not in label2idx})
if unknown_labels:
    raise ValueError('labels missing from CoNLL2003-labels.json: {}'.format(unknown_labels))

# Each index sequence gets an extra leading 0 in front of the per-token
# indices. NOTE(review): presumably a slot for a leading special token —
# confirm against the consumer of `lb_indices`.
lb_indices = [[0] + [label2idx[lb] for lb in lbs] for lbs in all_labels]

In [None]:
# Space-joined text of every sentence: once from the plain token lists and
# once from the token lists that end with the "@SB@" marker.
sentences = [' '.join(tokens) for tokens in all_tokens]
app_sentences = [' '.join(tokens) for tokens in app_tokens]

In [None]:
# Both texts are built from the same leading ten sentences, so that the
# marker-annotated text and the plain text describe the same content.
first_ten = slice(10)
sents = ' '.join(sentences[first_ten])
app_sents = ' '.join(app_sentences[first_ten])

In [None]:
# Load the spaCy model and process the marker-annotated text; `app_sents`
# carries an "@SB@" token after every original sentence.
nlp = spacy.load('en_core_web_sm')
app_doc = nlp(app_sents)

In [None]:
# Scan the marker-annotated document (`app_doc`) for '@SB@' tokens. Scanning
# `doc` here would fail on a fresh kernel, because `doc` is only created in a
# later cell, and that document contains no markers anyway.
#
# sep_indices:   position of every marker inside `app_doc`.
# start_indices: position of the token following each marker once all markers
#                seen so far are removed, i.e. the sentence start in the
#                marker-free text.
# NOTE(review): assumes the tokenizer keeps '@SB@' as a single token and
# tokenizes the rest of the text the same with and without markers — confirm.
sep_indices = list()
start_indices = list()
for i, token in enumerate(app_doc):
    if token.text == '@SB@':
        sep_indices.append(i)
        start_indices.append(i+1 - len(sep_indices))

In [None]:
def set_custom_boundaries(doc):
    """Pipeline component that overrides the sentence starts of ``doc``.

    Every token is first marked as not starting a sentence; then the first
    token and each token at a position listed in the module-level
    ``start_indices`` (all entries but the last) are marked as sentence
    starts.

    Args:
        doc: The document being processed; it is modified in place.

    Returns:
        The same ``doc`` object.
    """
    for token in doc:
        token.is_sent_start = False
    doc[0].is_sent_start = True
    # The final entry is skipped. NOTE(review): presumably because it is the
    # position right after the last '@SB@', i.e. one past the end of the
    # marker-free text — confirm.
    for start in start_indices[:-1]:
        doc[start].is_sent_start = True
    return doc


# A freshly loaded pipeline gets the component inserted ahead of the parser
# and is then run on the marker-free text.
nlp = spacy.load("en_core_web_sm")
nlp.add_pipe(set_custom_boundaries, before="parser")
doc = nlp(sents)

# One sentence per output line; every token is followed by a single space.
for sent in doc.sents:
    print(''.join(str(t) + ' ' for t in sent))

---

In [None]:
# Two files from the same directory are read line by line into memory:
# corpus1 from 'dev.txt' and corpus2 from 'eng.testa'.
data_dir = r'data/conll2003/'
data_name_ori = 'dev.txt'
data_name_new = 'eng.testa'

with open(os.path.join(data_dir, data_name_ori), 'r') as f:
    corpus1 = list(f)
with open(os.path.join(data_dir, data_name_new), 'r') as f:
    corpus2 = list(f)

In [None]:
# Walk both files in lock-step: `i` indexes corpus1 and `j` indexes corpus2.
# Tokens are taken from corpus2 and NER labels from corpus1; a '-DOCSTART-'
# line in corpus2 closes the current document.
i = 0
j = 0
corpus_sent = list()    # per document: list of sentences (token lists)
sep_corpus = list()     # per document: the same sentences, each ending with "@SB@"
sep_sentences = list()  # sentences (with marker) of the document in progress
sentences = list()      # sentences of the document in progress
sentence = list()       # tokens of the sentence in progress
corpus_label = list()   # per document: list of label sequences
label_seqs = list()     # label sequences of the document in progress
label_seq = list()      # labels of the sentence in progress
while i<len(corpus1) and j<len(corpus2):
    try:
        token2, _, _, ner_label2 = corpus2[j].strip().split()
        if token2 == '-DOCSTART-':
            # Document boundary: store what was collected for the previous
            # document (nothing to store before the first one).
            if sentences:
                corpus_sent.append(sentences)
                sep_corpus.append(sep_sentences)
                corpus_label.append(label_seqs)
                sentences = list()
                sep_sentences = list()
                label_seqs = list()
            # Only corpus2 is advanced here. NOTE(review): this assumes
            # corpus1 contains no '-DOCSTART-' lines — confirm.
            j += 2  # skip a new line
        else:
            _, _, _, ner_label1 = corpus1[i].strip().split()
            sentence.append(token2)
            label_seq.append(ner_label1)
            i += 1
            j += 1
        
    except ValueError:
        # A line that does not split into four fields (typically a blank
        # line) ends the current sentence.
        if sentence:
            sep_sentences.append(sentence + ["@SB@"])
            sentences.append(sentence)
            label_seqs.append(label_seq)
            sentence = list()
            label_seq = list()
    
        i += 1
        j += 1

# Flush a sentence that is still pending because the input did not end with
# a separator line; otherwise its tokens and labels would be lost.
if sentence:
    sep_sentences.append(sentence + ["@SB@"])
    sentences.append(sentence)
    label_seqs.append(label_seq)
    sentence = list()
    label_seq = list()

# Store the last document, unless nothing was collected for it (an empty
# document would only yield an empty text downstream).
if sentences:
    corpus_sent.append(sentences)
    sep_corpus.append(sep_sentences)
    corpus_label.append(label_seqs)

In [None]:
# Bundle the per-document sentences and label sequences and write them to
# "CoNLL03-dev.pt".
data = dict(documents=corpus_sent, labels=corpus_label)
torch.save(data, "CoNLL03-dev.pt")

---

In [None]:
def set_custom_boundaries(doc, start_indices):
    """Mark sentence starts on ``doc`` from an explicit list of positions.

    Every token is first marked as not starting a sentence; then the first
    token and each token at a position in ``start_indices[:-1]`` are marked
    as sentence starts. An empty ``doc`` is returned untouched instead of
    failing on ``doc[0]``.

    Args:
        doc: Indexable, iterable sequence of tokens exposing a writable
            ``is_sent_start`` attribute; it is modified in place.
        start_indices: Token positions that begin a sentence. The last entry
            is ignored.

    Returns:
        The same ``doc`` object.

    Raises:
        IndexError: If a used entry of ``start_indices`` is outside ``doc``.
    """
    if len(doc) == 0:
        return doc
    for token in doc:
        token.is_sent_start = False
    doc[0].is_sent_start = True
    # NOTE(review): the final entry is presumably the position after the last
    # boundary marker, i.e. one past the end of the text — confirm.
    for start in start_indices[:-1]:
        doc[start].is_sent_start = True
    return doc

In [None]:
# Separate pipeline instance, used below to process the '@SB@'-annotated
# text; no custom component is added to it in the visible cells.
sep_nlp = spacy.load('en_core_web_sm')

In [None]:
# One pipeline is shared by all documents: loading the model inside the loop
# would re-read it for every document. The boundary component takes the
# sentence starts of the document currently being processed from
# `doc_start_indices`, which is refreshed in place before each call.
doc_start_indices = list()
nlp = spacy.load("en_core_web_sm")
nlp.add_pipe(lambda doc: set_custom_boundaries(doc, start_indices=doc_start_indices), before="parser")

docs = []
for cp, sep_cp in zip(corpus_sent, sep_corpus):
    # Plain text of the document, without and with the '@SB@' markers.
    sentences = [' '.join(tokens) for tokens in cp]
    sep_sentences = [' '.join(tokens) for tokens in sep_cp]

    sents = ' '.join(sentences)
    sep_sents = ' '.join(sep_sentences)

    # Locate the markers in the annotated text and convert each one into the
    # position of the following token in the marker-free text.
    # NOTE(review): assumes '@SB@' stays a single token and the remaining
    # text is tokenized identically in both versions — confirm.
    sep_doc = sep_nlp(sep_sents)
    n_sep = 0
    start_indices = list()
    for i, token in enumerate(sep_doc):
        if token.text == '@SB@':
            n_sep += 1
            start_indices.append(i+1 - n_sep)

    doc_start_indices[:] = start_indices
    doc = nlp(sents)
    docs.append(doc)

In [None]:
# Write the list of processed documents to 'something.pt'.
torch.save(docs, 'something.pt')

In [None]:
# Display how many documents were processed.
len(docs)