In [1]:
import json
import os
import re
import threading
from glob import glob

import numpy as np

In [2]:
import spacy
from scispacy.umls_linking import UmlsEntityLinker
# Load the 'en_ner_jnlpba_md' spaCy model and append a UMLS entity linker to
# its pipeline. NOTE(review): the linker is presumably what populates
# `entity._.umls_ents`, which filter_protein_entities reads -- confirm against
# the installed scispacy version.
NER_tagger = spacy.load('en_ner_jnlpba_md')
linker = UmlsEntityLinker(resolve_abbreviations=True)
NER_tagger.add_pipe(linker)



In [3]:
from transformers import *
# Load the BERT tokenizer from the 'allenai/scibert_scivocab_uncased'
# checkpoint; generate_sentence_instance uses it both to tokenize text pieces
# and to convert the resulting tokens to ids.
bert_dir = 'allenai/scibert_scivocab_uncased'
tokenizer = BertTokenizer.from_pretrained(bert_dir)

In [4]:
import pandas as pd

# Build a DataFrame with one row per MRSTY.RRF line (CUI plus its semantic type).
with open('umls_data/MRSTY.RRF', 'r') as f:
    lines = []
    for raw_line in f:
        # every row is split on '|' and the final split element is discarded
        fields = raw_line.split('|')
        lines.append(fields[:-1])

mapping_df = pd.DataFrame(
    lines,
    columns=['CUI', 'TUI', 'STN', 'STY', 'ATUI', 'CVF'],
)
mapping_df.head()

Unnamed: 0,CUI,TUI,STN,STY,ATUI,CVF
0,C0000005,T116,A1.4.1.2.1.7,"Amino Acid, Peptide, or Protein",AT17648347,256
1,C0000005,T121,A1.4.1.1.1,Pharmacologic Substance,AT17575038,256
2,C0000005,T130,A1.4.1.1.4,"Indicator, Reagent, or Diagnostic Aid",AT17634323,256
3,C0000039,T109,A1.4.1.2.1,Organic Chemical,AT45562015,256
4,C0000039,T121,A1.4.1.1.1,Pharmacologic Substance,AT17567371,256


In [5]:
# Dictionary: CUI -> set of every semantic type (STY) listed for that CUI
sty_by_cui = mapping_df.groupby('CUI')['STY']
CUI2STY_mapping = sty_by_cui.apply(set).to_dict()

In [6]:
# Fixed sequence length of every generated instance: generate_sentence_instance
# drops sentences whose token list (including [CLS] and [SEP]) is longer than
# this and zero-pads shorter ones up to it.
max_length=230
# Placeholder stored in `abs_spans` for tokens that carry no character span
# ([CLS], [SEP], separator pieces and non-initial sub-word tokens).
dummy_span = None

In [7]:
def filter_protein_entities(entities, score_threshold=0.9, sty_mapping=None):
    '''
    Keep only the entities whose best UMLS match is a protein.

    An entity is kept when its first (best) UMLS candidate has a score strictly
    greater than `score_threshold` and the candidate's CUI is tagged with the
    semantic type 'Amino Acid, Peptide, or Protein'.

    Parameters
    ----------
    entities : iterable
        Entities exposing `entity._.umls_ents`, a sequence of (CUI, score)
        pairs ordered best-first.
    score_threshold : float, optional
        Minimum (exclusive) linking score; defaults to 0.9.
    sty_mapping : mapping, optional
        CUI -> collection of semantic type names. Defaults to the module-level
        `CUI2STY_mapping`.

    Returns
    -------
    list
        The entity objects themselves (not their spans), in input order.
    '''
    if sty_mapping is None:
        sty_mapping = CUI2STY_mapping

    protein_sty = 'Amino Acid, Peptide, or Protein'

    document_entities = []
    for entity in entities:
        candidates = entity._.umls_ents
        # entities without any UMLS candidate cannot be typed
        if len(candidates) == 0:
            continue

        # only the top-ranked candidate is considered
        best_matched_CUI, score = candidates[0]

        # CUIs missing from the mapping are treated as having no semantic type
        if score > score_threshold and protein_sty in sty_mapping.get(best_matched_CUI, ()):
            document_entities.append(entity)
    return document_entities

In [8]:
from collections import defaultdict


def process_document(document, doc_id, output_dir, corpus_proteinOrigIdBySpan):
    '''
    Run spaCy to extract proteins and split a single document into sentences.

    Side effect: writes `{output_dir}/{doc_id}.a1` with one line per protein
    entity in the form `T<n>\\tProtein <start_char> <end_char>\\t<text>`.

    Parameters
    ----------
    document : str
        Raw text of the document; it is passed to the module-level `NER_tagger`.
    doc_id : str
        Identifier used for the `.a1` file name and forwarded to
        `generate_sentence_instance`.
    output_dir : str
        Existing directory that receives the `.a1` file.
    corpus_proteinOrigIdBySpan : mapping
        Forwarded unchanged to `generate_sentence_instance`.

    Returns
    -------
    list
        One element per sentence of the parsed document: whatever
        `generate_sentence_instance` returned for it (possibly None).
    '''
    parsed_document = NER_tagger(document)
    document_entities = filter_protein_entities(parsed_document.ents)

    # create mapping from starting character position to entity id (T1, T2, ...)
    ent_char2_id_map = {ent.start_char: f'T{idx+1}' for idx, ent in enumerate(document_entities)}

    # Explicit UTF-8: entity text can contain non-ASCII characters, which a
    # platform-default encoding may be unable to write.
    with open(f'{output_dir}/{doc_id}.a1', 'w', encoding='utf-8') as f:
        for ent in document_entities:
            f.write(f'{ent_char2_id_map[ent.start_char]}\tProtein {ent.start_char} {ent.end_char}\t{ent}\n')

    return [
        generate_sentence_instance(sentence, doc_id, ent_char2_id_map, corpus_proteinOrigIdBySpan)
        for sentence in parsed_document.sents
    ]
    

In [9]:
import pickle
# Serialises updates to the shared corpus-level accumulators. The driver runs
# process_documents from several threads, and one sentence is stored as one
# append per key of `all_result_dict`; without the lock, appends from different
# threads can interleave and leave the per-key lists misaligned.
_RESULT_LOCK = threading.Lock()


def process_documents(file_path, output_dir):
    '''
    Process every body-text paragraph of one PMC JSON file.

    For each entry of `document['body_text']` whose section is not an
    introduction/conclusion, the paragraph text is written to
    `{output_dir}/{doc_id}.txt`, run through `process_document`, and the
    resulting sentence instances are appended to the module-level
    `all_result_dict`. `corpus_docId2OrigId` records each successfully
    processed paragraph id.

    Parameters
    ----------
    file_path : str
        Path to the JSON file; the part of its base name before the first '.'
        becomes the prefix of every paragraph id.
    output_dir : str
        Directory (created if missing) for the `.txt` files and the `.a1`
        files written by `process_document`.

    Notes
    -----
    A paragraph whose processing raises is reported on stdout and skipped
    (best-effort); none of its sentences are added to `all_result_dict`.
    '''
    os.makedirs(output_dir, exist_ok=True)
    # basename handles both '/' and the platform's own separator
    file_name = os.path.basename(file_path).split('.')[0]
    with open(file_path, 'r', encoding='utf-8') as f:
        document = json.load(f)

    # ignore these sections
    skipped_sections = ('Introduction', 'Concluding remark', 'Conclusion')

    for paragraph_idx, doc in enumerate(document['body_text']):
        doc_id = f'{file_name}.d{paragraph_idx}'
        if doc['section'] in skipped_sections:
            continue

        with open(f'{output_dir}/{doc_id}.txt', 'w', encoding='utf-8') as f:
            f.write(doc['text'])

        try:
            # process_document returns the full list, so a failure on any
            # sentence discards the whole paragraph (same as before)
            sentence_results = process_document(doc['text'], doc_id, output_dir, corpus_proteinOrigIdBySpan)
        except Exception as e:
            print(f'{e} : {doc_id}')
            continue

        # publish all rows of this paragraph atomically with respect to other threads
        with _RESULT_LOCK:
            for sent_result_dict in sentence_results:
                if sent_result_dict is None:
                    continue
                # store key, value pairs into the all result dict
                for key, value in sent_result_dict.items():
                    all_result_dict[key].append(value)

            corpus_docId2OrigId[doc_id] = doc_id

In [10]:
def generate_sentence_instance(sentence, doc_id, ent_char2_id_map, corpus_proteinOrigIdBySpan):
    '''
    Generate one instance per sentence.

    The sentence text is cut at the boundaries of its protein entities, each
    piece is further split on non-alphanumeric characters and tokenized with
    the module-level `tokenizer`. The first token of every protein entity is
    labelled 'Protein'; all other positions are labelled with the string
    'None'. All sequences are padded to `max_length`.

    Parameters
    ----------
    sentence
        Sentence span; this function reads `.ents`, `.start_char`, `.ent_id`
        and `str(sentence)`.
    doc_id
        Document identifier; used as key into `corpus_proteinOrigIdBySpan`,
        in log output and in the returned dict.
    ent_char2_id_map
        Mapping from an entity's `start_char` to its entity id string.
    corpus_proteinOrigIdBySpan
        Nested mapping doc_id -> {abs_span -> '<doc_id>.<entity id>'};
        updated in place for every protein entity of the sentence.

    Returns
    -------
    dict or None
        None if the sentence has no protein entity or if its token list
        (including [CLS]/[SEP]) is longer than `max_length`. Otherwise a dict
        with the keys 'tokenized_ids', 'entity_labels', 'mask_ids',
        'segment_ids', 'sent_ids', 'doc_ids', 'token_to_char_map',
        'char_to_token_map' and 'abs_spans'.

    Raises
    ------
    AssertionError
        If a label position is already taken, if the first character of an
        entity does not match the first character of its mapped token, or if
        the final length checks fail.
    KeyError
        If an entity's start offset has no entry in `char_to_token_map`.
    '''
    
    # only mark the head tokens as label

    entities = sentence.ents
    
    # start char offset of this sentence    
    sent_start = sentence.start_char
    entities = filter_protein_entities(entities)
    # don't consider sentence with no entity 
    if len(entities) == 0: return None
    
    # obtain the text of this sentence
    text = str(sentence)
    
    # sentence-relative char offset of a piece's first char -> index of its first bert token
    # TODO: create origin to token map
    char_to_token_map = {}
    # inverse of char_to_token_map
    token_to_char_map = {}
    
    # sentence-relative char offset of a piece's last char -> token count after that piece.
    # NOTE(review): filled below but neither read again nor returned.
    char_to_token_tail_map = {}
    new_tokens = ["[CLS]"]
            
    # running count of sentence characters consumed so far ([CLS] is not counted)
    cur_tokens_length = 0
    

    # split sentence into sub-sentences at every entity start/end (sentence-relative
    # offsets). The last boundary is len(text)+1; slicing past the end of a string
    # is harmless in Python.
    segment_numbers = [0]+\
    np.hstack([[entity.start_char - sent_start, entity.end_char - sent_start]  for entity in entities ]).tolist()\
    + [len(text)+1]   
    
    sub_sentences = [text[start:end] for start, end in zip(segment_numbers[:-1], segment_numbers[1:])]

    # per-token document-level span string ('start-end') or dummy_span; index 0 is [CLS]
    abs_spans = [dummy_span]

    for sub_sentence in sub_sentences:
        # create a mapping from character space to bert token space

        # one non-alphanumeric (ASCII) character; the capturing group makes
        # re.split keep each separator as its own piece
        skip_patterns = r'([^a-zA-Z0-9])'

        for original_texts in re.split(skip_patterns, sub_sentence):
        
            # re.split can yield empty strings around separators
            if len(original_texts) == 0: continue

            # a list of bert tokens
            bert_tokens = tokenizer.tokenize(original_texts)
            
            # sentence-relative offset of the first character of this piece
            head_char_position = cur_tokens_length 

            # head token position
            char_to_token_map[head_char_position] = len(new_tokens)
            token_to_char_map[len(new_tokens)] = head_char_position

            cur_tokens_length += len(original_texts)
            new_tokens.extend(bert_tokens)
            
            # sentence-relative offset of the last character of this piece
            tail_char_position = cur_tokens_length -1 

            # move the tail left past trailing spaces
            while text[tail_char_position] in [' '] and tail_char_position >= head_char_position:
                tail_char_position -= 1
            
            # the piece gets no span if it consisted only of spaces or if it
            # contains a non-alphanumeric character
            if tail_char_position < head_char_position or re.search(skip_patterns, original_texts) is not None: #original_texts=='\n':# 

                abs_spans.extend([dummy_span] * (len(bert_tokens) ))                
                continue
            
            tail_position = len(new_tokens) 
            
            # trim the tail token position if the token is not [UNK]
            # NOTE(review): the decremented tail_position is not used after this loop
            if new_tokens[tail_position-1] not in ['[UNK]']:
                while text[tail_char_position].lower() != new_tokens[tail_position-1][-1]:
                    # tail token position
                    tail_position -= 1


            char_to_token_tail_map[tail_char_position] = len(new_tokens) 
            # document-level span 'start-end'; sent_start shifts the offsets and
            # the +1 makes the end exclusive
            abs_span = '-'.join([str(head_char_position+sent_start), str(tail_char_position+1+sent_start)])                    

            # only the first token of the piece carries the span
            abs_spans.append(abs_span)  

            # append dummy spans
            abs_spans.extend([dummy_span] * (len(bert_tokens) -1))
    
    # labels are the string 'None' (not the None object) until overwritten
    entity_label = np.array(['None'] * (len(new_tokens) + 1) ,dtype=object)  # plus 1 for [SEP]

    # iterate over each entity
    for entity in entities:

        head_char_idx = entity.start_char - sent_start
        tail_char_idx = entity.end_char - sent_start

        # only add if there is a None
        assert entity_label[char_to_token_map[head_char_idx]] == 'None'
        
        # assign protein label
        entity_label[char_to_token_map[head_char_idx]] = 'Protein'



        # doc_id -> span -> protein origid
        # NOTE: this update happens before the max_length check below, so it
        # persists even when the sentence is later dropped
        corpus_proteinOrigIdBySpan[doc_id][abs_spans[char_to_token_map[head_char_idx]]] = f'{doc_id}.{ent_char2_id_map[entity.start_char]}'



        # make sure at least the first character of the mapping is correct
        assert new_tokens[char_to_token_map[head_char_idx]][0] == text[head_char_idx].lower(), (new_tokens[char_to_token_map[head_char_idx]][0] , text[head_char_idx], entity)


    # convert entity label back to list for zero padding
    entity_label = entity_label.tolist()


    new_tokens.append('[SEP]')
    abs_spans.append(dummy_span)            

    assert len(new_tokens) == len(abs_spans), (len(new_tokens) , len(abs_spans))


    # NOTE(review): the captured run log prints `sentence.ent_id` as 0 for every
    # sentence -- confirm it is the intended sentence identifier
    if len(new_tokens) > max_length:
        print(f"{doc_id}-{sentence.ent_id}: Exceed max length")
        return None
    tokenized_ids = tokenizer.convert_tokens_to_ids(new_tokens)

    # mask ids
    mask_ids = [1] * len(tokenized_ids)

    # segment ids
    segment_ids = [0] * len(tokenized_ids)

    if len(tokenized_ids) < max_length:
        # Zero-pad up to the sequence length (labels are padded with 'None')
        padding = [0] * (max_length - len(tokenized_ids))
        tokenized_ids += padding
        entity_label += ['None'] * len(padding)
        mask_ids += padding
        segment_ids += padding

    assert len(tokenized_ids) == max_length == len(mask_ids) == len(segment_ids) == len(entity_label ) ,\
     (len(tokenized_ids) ,max_length ,len(mask_ids) ,len(segment_ids), len(entity_label) )


    # abs_spans and the two maps are stored unpadded
    result_dict = {
    'tokenized_ids': tokenized_ids,
    'entity_labels': entity_label,
    'mask_ids': mask_ids,
    'segment_ids': segment_ids,
    'sent_ids':sentence.ent_id,
    'doc_ids': doc_id,
    'token_to_char_map': token_to_char_map,
    'char_to_token_map': char_to_token_map,
    'abs_spans':abs_spans
    }
    return result_dict

In [11]:
from tqdm.notebook import tqdm
import multiprocessing 
from joblib import Parallel, delayed
import shutil

import os

# Corpus-level accumulators that process_documents / generate_sentence_instance fill in
corpus_docId2OrigId = {}
corpus_proteinOrigIdBySpan = defaultdict(dict)
all_result_dict = defaultdict(list)

genia_input_folder='genia_cord_19'
output_dir='preprocessed_data'
f_name = 'CORD_19_PMC'

# start from a clean slate: remove whatever a previous run left behind
for stale_dir in (genia_input_folder, output_dir):
    shutil.rmtree(stale_dir, ignore_errors=True)

# one worker thread per CPU core, one task per input JSON file
parallel = Parallel(multiprocessing.cpu_count(), backend="threading", verbose=0)
pmc_json_paths = glob('custom_license/custom_license/pmc_json/*')
parallel(
    delayed(process_documents)(file_path, genia_input_folder)
    for file_path in tqdm(pmc_json_paths)
)

# consecutive integer id for every collected sentence instance
all_result_dict['sample_ids'] = list(range(len(all_result_dict['tokenized_ids'])))

# create directory
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# persist the three accumulators as pickles
pickle_targets = (
    (f'{output_dir}/{f_name}.pkl', all_result_dict),
    (f'{output_dir}/{f_name}protIdBySpan.pkl', corpus_proteinOrigIdBySpan),
    (f'{output_dir}/{f_name}origIdById.pkl', corpus_docId2OrigId),
)
for pickle_path, payload in pickle_targets:
    with open(pickle_path, 'wb') as f:
        pickle.dump(payload, f, protocol = 4)

HBox(children=(FloatProgress(value=0.0, max=7802.0), HTML(value='')))

PMC1852721.d9-0: Exceed max length
PMC2725397.d9-0: Exceed max length
PMC2736877.d15-0: Exceed max length
PMC2970923.d5-0: Exceed max length
PMC2970804.d14-0: Exceed max length
PMC3127101.d5-0: Exceed max length
PMC3127101.d13-0: Exceed max length
PMC3230553.d20-0: Exceed max length
PMC3304337.d0-0: Exceed max length
PMC3304337.d10-0: Exceed max length
('[', 'µ', µL proteinase K) : PMC3311214.d2
PMC3522772.d23-0: Exceed max length
PMC3858665.d0-0: Exceed max length
('g', '\xad', ­glutamine) : PMC3867001.d13
PMC4000470.d1-0: Exceed max length
PMC4095937.d20-0: Exceed max length
PMC4148155.d4-0: Exceed max length
PMC4095937.d64-0: Exceed max length
PMC4095937.d86-0: Exceed max length
PMC4095937.d123-0: Exceed max length
PMC4095937.d137-0: Exceed max length
PMC4426353.d0-0: Exceed max length
PMC4514392.d12-0: Exceed max length
PMC4578197.d1-0: Exceed max length
('[', 'µ', µM CT) : PMC4580978.d15
('[', '…', … I) : PMC3654146.d1416
PMC4682876.d15-0: Exceed max length
PMC4711683.d13-0: Excee

PMC7122307.d18-0: Exceed max length
PMC7122375.d16-0: Exceed max length
PMC7122433.d4-0: Exceed max length
PMC7122471.d4-0: Exceed max length
PMC7122603.d401-0: Exceed max length
PMC7122853.d21-0: Exceed max length
PMC7122853.d34-0: Exceed max length
('[', '˜', ˜13–17 amino acids) : PMC7122908.d8
PMC7122908.d35-0: Exceed max length
PMC7123060.d10-0: Exceed max length
PMC7123080.d19-0: Exceed max length
PMC7122603.d1161-0: Exceed max length
PMC7123154.d18-0: Exceed max length
PMC7123232.d7-0: Exceed max length
PMC7123232.d10-0: Exceed max length
PMC7123262.d6-0: Exceed max length
PMC7123311.d5-0: Exceed max length
PMC7123318.d67-0: Exceed max length
('c', '\xad', ­collagen) : PMC7123455.d22
PMC7123919.d10-0: Exceed max length
PMC7123921.d9-0: Exceed max length
PMC7123921.d19-0: Exceed max length
PMC7123921.d27-0: Exceed max length
PMC7123921.d29-0: Exceed max length
PMC7123921.d30-0: Exceed max length
PMC7123938.d22-0: Exceed max length
PMC7123921.d39-0: Exceed max length
PMC7123921.d52