In [1]:
import nltk
from nltk.tokenize.punkt import PunktSentenceTokenizer, PunktParameters, PunktLanguageVars
from nltk.tokenize import WhitespaceTokenizer

In [2]:
# Make sure the NLTK "punkt" package is present in the local nltk_data directory.
# The call returns True (shown as the cell output) when the package is available.
nltk.download('punkt')

[nltk_data] Downloading package punkt to C:\Users\Chuang Feng
[nltk_data]     Chia\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [3]:
from transformers import AutoTokenizer
import pandas as pd
from datasets import Dataset

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from datasets import load_dataset
# Load the WIESP2022-NER corpus by its dataset identifier; the result is a
# DatasetDict holding the "train", "validation" and "test" splits.
dataset = load_dataset("adsabs/WIESP2022-NER")

In [5]:
# Show the available splits with their columns and row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['bibcode', 'label_studio_id', 'ner_ids', 'ner_tags', 'section', 'tokens', 'unique_id'],
        num_rows: 1753
    })
    validation: Dataset({
        features: ['bibcode', 'label_studio_id', 'ner_ids', 'ner_tags', 'section', 'tokens', 'unique_id'],
        num_rows: 1366
    })
    test: Dataset({
        features: ['bibcode', 'label_studio_id', 'ner_ids', 'ner_tags', 'section', 'tokens', 'unique_id'],
        num_rows: 2505
    })
})

In [6]:
class SpacedLangVars(PunktLanguageVars):
    """Punkt language variables with a customised sentence-boundary context.

    Only ``_period_context_fmt`` is overridden. After a potential
    sentence-ending character the pattern looks ahead for either one or more
    non-word characters followed by whitespace, or plain whitespace, and then
    captures the following non-whitespace token as ``next_tok``.
    """

    _period_context_fmt = r"""
        %(SentEndChars)s             # a potential sentence ending
        (?=(?P<after_tok>
            (((%(NonWord)s)+\s+)            # either other punctuation
            |
            \s+)(?P<next_tok>\S+)     # or whitespace and some other token
        ))"""


# Plain whitespace tokenizer (not referenced by the other cells shown here).
tk = WhitespaceTokenizer()

# Abbreviation types registered with Punkt, written without the trailing period.
abbreviation = ['al', 'fig', 'tab', 'i.e', 'no', 'etal']
punkt_param = PunktParameters()
punkt_param.abbrev_types = set(abbreviation)

# The pre-built parameter object is handed over through ``train_text``; no raw
# training text is supplied.
pt = PunktSentenceTokenizer(train_text=punkt_param, lang_vars=SpacedLangVars())

In [7]:
import json
# Tokenizer for the "bert-base-cased" checkpoint; convert_to_features uses it
# to split pre-tokenized words into sub-tokens.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

In [8]:
# Label name -> integer id. Ids 0-30 are the "B-" tags, 31-61 the matching
# "I-" tags in the same order, and 62 (the last id) is the outside tag "O".
ner_tags = {
    "B-Archive": 0,
    "B-CelestialObject": 1,
    "B-CelestialObjectRegion": 2,
    "B-CelestialRegion": 3,
    "B-Citation": 4,
    "B-Collaboration": 5,
    "B-ComputingFacility": 6,
    "B-Database": 7,
    "B-Dataset": 8,
    "B-EntityOfFutureInterest": 9,
    "B-Event": 10,
    "B-Fellowship": 11,
    "B-Formula": 12,
    "B-Grant": 13,
    "B-Identifier": 14,
    "B-Instrument": 15,
    "B-Location": 16,
    "B-Mission": 17,
    "B-Model": 18,
    "B-ObservationalTechniques": 19,
    "B-Observatory": 20,
    "B-Organization": 21,
    "B-Person": 22,
    "B-Proposal": 23,
    "B-Software": 24,
    "B-Survey": 25,
    "B-Tag": 26,
    "B-Telescope": 27,
    "B-TextGarbage": 28,
    "B-URL": 29,
    "B-Wavelength": 30,
    "I-Archive": 31,
    "I-CelestialObject": 32,
    "I-CelestialObjectRegion": 33,
    "I-CelestialRegion": 34,
    "I-Citation": 35,
    "I-Collaboration": 36,
    "I-ComputingFacility": 37,
    "I-Database": 38,
    "I-Dataset": 39,
    "I-EntityOfFutureInterest": 40,
    "I-Event": 41,
    "I-Fellowship": 42,
    "I-Formula": 43,
    "I-Grant": 44,
    "I-Identifier": 45,
    "I-Instrument": 46,
    "I-Location": 47,
    "I-Mission": 48,
    "I-Model": 49,
    "I-ObservationalTechniques": 50,
    "I-Observatory": 51,
    "I-Organization": 52,
    "I-Person": 53,
    "I-Proposal": 54,
    "I-Software": 55,
    "I-Survey": 56,
    "I-Tag": 57,
    "I-Telescope": 58,
    "I-TextGarbage": 59,
    "I-URL": 60,
    "I-Wavelength": 61,
    "O": 62,
}

# Inverse lookup: integer id -> label name.
ner_tags_swap = dict(zip(ner_tags.values(), ner_tags))

In [9]:
def get_label_idx(label, idx):
    """Return the label id for one sub-token.

    Args:
        label: Sequence of label ids, one per source word of the example.
        idx: Index of the source word the sub-token belongs to, or ``None``
            when the sub-token has no source word.

    Returns:
        ``label[idx]`` for a real word index; ``len(ner_tags) - 1`` when
        ``idx`` is ``None`` (in ``ner_tags`` the last id is the ``"O"`` tag).
    """
    # Identity test: `== None` would defer to a custom __eq__ on idx and can
    # misclassify objects that merely compare equal to None.
    if idx is None:
        return len(ner_tags) - 1
    return label[idx]

In [10]:
def convert_to_features(example_batch, indices=None):
    """Tokenize a batch of pre-split examples and align labels to sub-tokens.

    Args:
        example_batch: Batch mapping with a ``"tokens"`` column (one list of
            words per example) and, optionally, a ``"ner_ids"`` column with
            one label id per word.
        indices: Not used by this function.

    Returns:
        The tokenizer's batch encoding with two extra entries:
        ``"word_ids"`` (source-word index per sub-token, ``-1`` where the
        tokenizer reports ``None``) and, when ``"ner_ids"`` is present,
        ``"labels"`` (one label id per sub-token, via ``get_label_idx``).
    """
    texts = example_batch["tokens"]

    features = tokenizer.batch_encode_plus(
        texts,
        truncation=True,
        is_split_into_words=True,
    )

    # Compute the word ids once per example. The batch size comes from the
    # tokenized column itself rather than from the unrelated "unique_id"
    # column, so the function also works on batches without that column.
    raw_word_ids = [features.word_ids(i) for i in range(len(texts))]

    features["word_ids"] = [
        [-1 if word_id is None else word_id for word_id in ids]
        for ids in raw_word_ids
    ]
    if "ner_ids" in example_batch:
        features["labels"] = [
            [get_label_idx(label, word_id) for word_id in ids]
            for ids, label in zip(raw_word_ids, example_batch["ner_ids"])
        ]
    return features

In [11]:
def reconstruct_dataset(ds_type):
    """Split every document of one dataset split into sentences and tokenize them.

    Each document's words are joined with single spaces and split into
    sentences with ``pt``. The word count of every sentence is then used to
    slice the original ``tokens`` (and ``ner_ids`` / ``ner_tags`` when both
    are present) so each sentence becomes its own record.

    Args:
        ds_type: Key of the split to read from the module-level ``dataset``
            (e.g. ``"test"``).

    Returns:
        A ``Dataset`` with one row per sentence, carrying ``tokens``,
        ``unique_id`` (of the source document), ``part`` (sentence index
        within the document), ``id`` (running index over all sentences), the
        optional label columns, and the columns added by
        ``convert_to_features``.

    Raises:
        AssertionError: If the sentence word counts of a document do not add
            up to its number of tokens, i.e. the slices would be misaligned.
    """
    records = []
    next_id = 0  # running sentence id; avoids shadowing the builtin `id`
    for item in dataset[ds_type]:
        tokens = item["tokens"]
        has_labels = "ner_ids" in item and "ner_tags" in item
        sentence_list = pt.tokenize(" ".join(tokens))

        counter = 0
        for part, sentence in enumerate(sentence_list):
            prev_counter = counter
            # Words were joined with single spaces, so splitting on " "
            # recovers how many of the original tokens this sentence spans.
            counter += len(sentence.strip().split(" "))
            entry = {
                "tokens": tokens[prev_counter:counter],
                "unique_id": item["unique_id"],
                "part": part,
                "id": next_id,
            }
            next_id += 1
            if has_labels:
                entry["ner_ids"] = item["ner_ids"][prev_counter:counter]
                entry["ner_tags"] = item["ner_tags"][prev_counter:counter]
            records.append(entry)

        assert counter == len(tokens), (
            f"sentence split covers {counter} tokens but document "
            f"{item['unique_id']!r} has {len(tokens)}"
        )

    data = Dataset.from_list(records)
    # batch_size=-1 is kept from the original call; per the `datasets`
    # documentation a non-positive batch size passes the whole dataset to the
    # function as a single batch.
    data = data.map(convert_to_features, batched=True, batch_size=-1)
    return data

In [12]:
# Build the sentence-level, tokenized records for the "test" split.
data = reconstruct_dataset("test")

Map: 100%|██████████| 35657/35657 [00:03<00:00, 9663.84 examples/s]


In [13]:
# Materialize the result as a plain list of per-record mappings so the next
# cell can remove and add keys on each record.
data = list(data)

In [14]:
# Keys removed from every record before export. "word_ids" is kept: its pop
# was already commented out in the original version of this cell.
_DROPPED_KEYS = (
    "tokens",
    "unique_id",
    "part",
    "token_type_ids",
    "attention_mask",
    "ner_ids",
    "ner_tags",
)

for record in data:
    for dropped_key in _DROPPED_KEYS:
        # Default of None keeps the cell re-runnable (keys already removed)
        # and tolerates records that were built without label columns.
        record.pop(dropped_key, None)
    # Human-readable label name for every sub-token label id.
    record["text_labels"] = [ner_tags_swap[label_id] for label_id in record["labels"]]

In [15]:
# Export the records as JSON Lines: one serialized record per line.
with open("test_processed_new.jsonl", "w") as out_file:
    out_file.writelines(json.dumps(record) + "\n" for record in data)

In [16]:
# Disabled: "tokens" is popped from every record in In [14], so this lookup would raise KeyError.
#print(len(data[0]["tokens"]))

In [17]:
# Spot-check the first record: label name assigned to each sub-token.
print(data[0]["text_labels"])

['O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-Person', 'I-Person', 'I-Person', 'I-Person', 'I-Person', 'I-Person', 'B-Person', 'I-Person', 'I-Person', 'I-Person', 'B-Person', 'B-Person', 'I-Person', 'I-Person', 'B-Person', 'I-Person', 'I-Person', 'I-Person', 'B-Person', 'I-Person', 'I-Person', 'O', 'B-Person', 'I-Person', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


In [18]:
# Spot-check the first record: token ids produced by the tokenizer.
print(data[0]["input_ids"])

[101, 1109, 5752, 1156, 1176, 1106, 6243, 3379, 139, 15243, 11192, 1200, 117, 15760, 5308, 1200, 117, 26835, 4838, 7665, 117, 2639, 140, 17760, 117, 1847, 8411, 117, 1105, 5590, 8859, 1111, 5616, 10508, 1113, 6757, 8519, 2344, 117, 2233, 3252, 117, 1105, 1672, 2233, 118, 2235, 7577, 8015, 119, 102]


In [19]:
# Spot-check the first record: source-word index of each sub-token, where -1
# stands for tokens the tokenizer maps to no word (see convert_to_features).
print(data[0]["word_ids"])

[-1, 0, 1, 2, 3, 4, 5, 6, 7, 7, 7, 7, 7, 8, 9, 9, 9, 10, 10, 11, 11, 12, 13, 13, 13, 14, 15, 15, 16, 17, 18, 19, 20, 21, 22, 23, 23, 24, 24, 25, 26, 26, 27, 28, 29, 29, 29, 30, 31, 31, -1]
