In [2]:
from bs4 import BeautifulSoup
import os
import re
from IPython.core.debugger import set_trace
from transformers import BertTokenizerFast
import copy
from tqdm import tqdm
import html
from pprint import pprint
import glob
import json
import time

# Data Preprocessing

In [3]:
# Root of the nested-NER project. Set the NESTED_NER_ROOT environment variable
# to run this notebook on another machine; the author's path is the default.
project_root = os.environ.get("NESTED_NER_ROOT", "/home/wangyucheng/workplace/notebook/research/nested_ner")
ori_data_dir = os.path.join(project_root, "ori_data")
preprocessed_data_dir = os.path.join(project_root, "preprocessed_data")
exp_name = "genia"
genia_path = os.path.join(ori_data_dir, "GENIA_term_3.02", "GENIAcorpus3.02.xml")

In [4]:
# Parse the whole corpus file, closing the file handle once it has been read.
# NOTE(review): "lxml" is bs4's HTML parser; the selectors below were written
# against it, so verify before switching to an XML parser.
with open(genia_path, "r", encoding = "utf-8") as genia_file:
    soup = BeautifulSoup(genia_file, "lxml")

In [5]:
# All <article> elements that are direct children of a <set> element.
article_list = soup.select("set > article")

In [6]:
# Number of articles found; the recorded output of this cell is 2000.
len(article_list)

2000

In [7]:
# Path to a bert-base-cased checkpoint (per the directory name). Set the
# BERT_MODEL_PATH environment variable to use another location; the author's
# path is the default.
model_path = os.environ.get("BERT_MODEL_PATH", "/home/wangyucheng/opt/transformers_models_h5/bert-base-cased")
# Cased model: lowercasing is disabled explicitly.
tokenizer = BertTokenizerFast.from_pretrained(model_path, add_special_tokens = False, do_lower_case = False)

In [8]:
def get_char_ind2tok_ind(tok2char_span):
    '''
    Invert a token -> character-span mapping into a character -> token index list.

    tok2char_span: sequence of (start, end) character offsets, one per token
        (e.g. a tokenizer's "offset_mapping"). Entries whose end is 0 are
        skipped when determining the text length.
    return: list whose length is the end offset of the last token with a
        non-zero end; entry i is the index of the token covering character i.
        Characters covered by no token keep the placeholder 0.
        Returns [] when no token has a non-zero end (e.g. empty text).
    '''
    # Text length = end offset of the last token that actually covers characters.
    # Starting from 0 (instead of None) keeps empty input from crashing range().
    char_num = 0
    for tok_ind in range(len(tok2char_span) - 1, -1, -1):
        if tok2char_span[tok_ind][1] != 0:
            char_num = tok2char_span[tok_ind][1]
            break
    # Except for spaces, every character has a corresponding token.
    char_ind2tok_ind = [0] * char_num
    for tok_ind, (start, end) in enumerate(tok2char_span):
        for char_ind in range(start, end):
            char_ind2tok_ind[char_ind] = tok_ind
    return char_ind2tok_ind

In [9]:
def convert_to_dict(article):
    '''
    Convert one GENIA <article> tag into a plain dict with character- and
    token-level spans for every <cons> (term) element.

    article: bs4 Tag for an <article> element (left untouched; a copy is used)
    return:
        article_dict: {
            "id": medline_id,        # text of articleinfo > bibliomisc
            "text": article_text,    # article text with articleinfo removed
            "entity_list": [{
                "text": ...,         # text of the original <cons> element
                "lex": ...,          # "lex" attribute with "_" -> " ", or "[UNK]"
                "sem": ...,          # "sem" attribute, or "[UNK]"
                "subtype": ...,      # part after "G#" in sem, or None
                "char_span": (start, end),  # end-exclusive character offsets
                "tok_span": (start, end),   # end-exclusive token offsets
            }, ...]
        }
    '''
    # Work on a copy so extract() below does not remove tags from the shared DOM tree.
    article_cp = copy.copy(article)
    medline_id = article_cp.select_one("articleinfo").extract().select_one("bibliomisc").get_text()
    art_text = article_cp.get_text()
    article_dict = {
        "id": medline_id,
        "text": art_text,
    }

    # Split the serialized markup into tag segments and text segments, using
    # "⺀" as a separator (presumably absent from the corpus text -- not checked).
    segs = re.sub("(<[^>]+>)", r"⺀\1⺀", str(article_cp)).split("⺀")
    # str() leaves some characters escaped (e.g. ">" as "&gt;"), while get_text()
    # unescapes them. Unescape here so character positions match art_text;
    # otherwise "&gt;" would be counted as 4 characters instead of 1.
    segs = [html.unescape(s) for s in segs if s != ""]

    # Rebuild the markup with every text character replaced by its absolute
    # character index, so each <cons> element's text lists the positions it covers.
    str_w_pos = ""
    all_char_num = 0
    for s in segs:
        if re.match("<[^>]+>", s):
            str_w_pos += s
            continue
        char_num = len(s)
        char_pos = [str(all_char_num + i) for i in range(char_num)]
        if len(char_pos) > 0:
            str_w_pos += " " + " ".join(char_pos) + " "
        all_char_num += char_num

    # Parse the position-annotated markup to recover each term's character span.
    pos_soup = BeautifulSoup(str_w_pos, "lxml")
    cons_w_pos_list = pos_soup.select("cons")
    ori_cons_list = article_cp.select("cons")
    # The position rewrite must not change the number of annotated terms.
    assert len(cons_w_pos_list) == len(ori_cons_list)

    term_list = []
    offset_map = tokenizer.encode_plus(art_text,
                                       return_offsets_mapping = True,
                                       add_special_tokens = False)["offset_mapping"]
    char_ind2tok_ind = get_char_ind2tok_ind(offset_map)
    for ind, cons in enumerate(cons_w_pos_list):
        sem_text = cons.get("sem", "[UNK]")
        # subtype: the text after "G#" in the sem attribute, up to whitespace or a
        # parenthesis. Raw string: "\s" in a plain literal is an invalid escape
        # sequence (SyntaxWarning since Python 3.12).
        subtype = re.search(r"G#[^\s()]+", sem_text)
        if subtype is not None:
            subtype = subtype.group().split("#")[1]

        lex = cons["lex"].replace("_", " ") if "lex" in cons.attrs else "[UNK]"

        # First and last character index written inside this <cons>.
        pos_num = cons.get_text().strip().split(" ")
        span = (int(pos_num[0]), int(pos_num[-1]) + 1)

        term = {
            "text": ori_cons_list[ind].get_text(),
            "lex": lex,
            "sem": sem_text,
            "subtype": subtype,
            "char_span": span,
            "tok_span": (char_ind2tok_ind[span[0]], char_ind2tok_ind[span[1] - 1] + 1),
        }
        term_list.append(term)
    article_dict["entity_list"] = term_list
    return article_dict

In [10]:
def get_tok2char_span_map(text):
    '''
    Tokenize `text` with the module-level `tokenizer` (no special tokens) and
    return each token's character span.

    text: raw string to tokenize
    return: the encoding's "offset_mapping" -- one (start, end) character
        offset pair per token
    '''
    encoding = tokenizer.encode_plus(text,
                                     add_special_tokens = False,
                                     return_offsets_mapping = True)
    return encoding["offset_mapping"]

In [11]:
# check spans
for art in tqdm(article_list):
    art_dict = convert_to_dict(art)
    art_text = art_dict["text"]
    tok2char_span = get_tok2char_span_map(art_text)
    for term in art_dict["entity_list"]:
#         # check char span
#         char_span = term["char_span"]
#         pred_text = art_text[char_span[0]:char_span[1]]
#         assert pred_text == term["text"]
        
        # check tok span
        # # voc 里必须加两个token：hypo, mineralo
        tok_span = term["tok_span"]
        char_span_list = tok2char_span[tok_span[0]:tok_span[1]]
        pred_text = art_text[char_span_list[0][0]:char_span_list[-1][1]]
        assert pred_text == term["text"]

100%|██████████| 2000/2000 [00:44<00:00, 45.24it/s]


In [12]:
def collapse(article_dict):
    '''
    Keep only the five target entity types -- RNA, DNA, protein, cell_line,
    cell_type -- in article_dict["entity_list"], in place.

    A term's type is its subtype's prefix before the first "_" (e.g.
    "protein_molecule" -> "protein"), except that "cell_type" and "cell_line"
    are kept whole. Terms without a subtype or with another type are dropped;
    each kept term gets a "type" key. Returns None.
    '''
    kept_categories = {"RNA", "DNA", "protein", "cell_line", "cell_type"}
    # These two contain an underscore but are categories in their own right.
    whole_name_subtypes = {"cell_type", "cell_line"}

    def _category_of(subtype):
        if subtype in whole_name_subtypes:
            return subtype
        return subtype.split("_")[0]

    filtered = []
    for entity in article_dict["entity_list"]:
        if entity["subtype"] is None:
            continue
        category = _category_of(entity["subtype"])
        if category not in kept_categories:
            continue
        entity["type"] = category
        filtered.append(entity)

    article_dict["entity_list"] = filtered

In [13]:
# Debug helper (disabled): inspect each article's entity list before and after collapse().
# for art in tqdm(article_list):
#     art_dict = convert_to_dict(art)
#     pprint(art_dict["entity_list"])
#     print()
#     collapse(art_dict)
#     for term in art_dict["entity_list"]:
#         if "type" not in term:
#             set_trace()
#     pprint(art_dict["entity_list"])
#     print("------------------")

In [14]:
# convert to dict
article_dict_list = []
for art in tqdm(article_list):
    art_dict = convert_to_dict(art)
    collapse(art_dict)
    article_dict_list.append(art_dict)

100%|██████████| 2000/2000 [00:41<00:00, 48.39it/s]


In [16]:
# Deterministic 90/10 train/eval split in corpus order (no shuffling).
train_num = int(0.9 * len(article_dict_list))
train_data = article_dict_list[:train_num]
eval_data = article_dict_list[train_num:]
print(len(train_data), len(eval_data))

1800 200


# Output

In [17]:
# Write the train/eval splits as JSON under preprocessed_data_dir/exp_name.
exp_path = os.path.join(preprocessed_data_dir, exp_name)
# makedirs also creates preprocessed_data_dir if it is missing and tolerates re-runs.
os.makedirs(exp_path, exist_ok = True)
train_save_path = os.path.join(exp_path, "train_data.json")
eval_save_path = os.path.join(exp_path, "eval_data.json")
# Context managers guarantee the files are flushed and closed.
with open(train_save_path, "w", encoding = "utf-8") as train_file:
    json.dump(train_data, train_file, ensure_ascii = False)
with open(eval_save_path, "w", encoding = "utf-8") as eval_file:
    json.dump(eval_data, eval_file, ensure_ascii = False)

In [18]:
# Entity types kept by collapse(); keep this list in sync with its save_types.
tags = ["RNA", "DNA", "protein", "cell_line", "cell_type"]
tag_path = os.path.join(preprocessed_data_dir, exp_name, "tags.json")
# Context manager guarantees the file is flushed and closed.
with open(tag_path, "w", encoding = "utf-8") as tag_file:
    json.dump(tags, tag_file, ensure_ascii = False)