In [12]:
import json
import os
import re
import threading
from glob import glob

import numpy as np

In [13]:
import spacy
# from scispacy.umls_linking import UmlsEntityLinker
# Load the spaCy pipeline 'en_ner_jnlpba_md'. The rest of the notebook relies on it
# tagging tokens with `ent_type_ == 'PROTEIN'`.
NER_tagger = spacy.load('en_ner_jnlpba_md')
# UMLS entity linking is currently disabled; the two lines below are kept for reference.
# linker = UmlsEntityLinker(resolve_abbreviations=True)
# NER_tagger.add_pipe(linker)

In [14]:
# Import only the class that is actually used; a wildcard import would pull every
# public transformers name into the notebook namespace and hide where names come from.
from transformers import BertTokenizer

# Identifier handed to BertTokenizer.from_pretrained to load the tokenizer.
bert_dir = 'allenai/scibert_scivocab_uncased'
tokenizer = BertTokenizer.from_pretrained(bert_dir)

In [6]:
# Fixed length of every instance: token sequences longer than this are dropped,
# shorter ones are zero-padded up to it.
max_length = 230

# Placeholder stored in `abs_spans` for tokens that carry no character span.
dummy_span = None

In [16]:
# Smoke test: run the tagger on a single example sentence and display the parsed document.
a = NER_tagger('In mature human B cells, BMP-6 inhibited cell growth, and rapidly induced phosphorylation of Smad1/5/8 followed by an upregulation of Id1.')
a

In mature human B cells, BMP-6 inhibited cell growth, and rapidly induced phosphorylation of Smad1/5/8 followed by an upregulation of Id1.

In [18]:
# Token index at which each recognised entity of the example document starts.
[ent.start for ent in a.ents]

[1, 6, 16, 22]

In [25]:
# Entity type string of every token in the example document; inspect the token at index 22.
entity_types = [token.ent_type_ for token in a]
entity_types[22]

'PROTEIN'

In [26]:
# Entity type read directly from the token at index 22 of the example document.
a[22].ent_type_

'PROTEIN'

In [30]:
def filter_protein_entities(entities, document):
    '''
    Keep only the entities whose first token is tagged as a protein.

    entities: iterable of spans exposing a `.start` token index.
    document: token sequence indexable by that index; each token exposes `.ent_type_`.

    Returns a new list holding the protein entities in their original order.
    '''
    return [
        span
        for span in entities
        if document[span.start].ent_type_ == 'PROTEIN'
    ]

In [32]:
from collections import defaultdict


def process_document(document, doc_id, output_dir, corpus_proteinOrigIdBySpan):
    '''
    Tag one text passage, write its protein annotations and build per-sentence instances.

    document: raw text of the passage; it is parsed with the module-level `NER_tagger`.
    doc_id: identifier used for the annotation file name and passed on to every instance.
    output_dir: directory in which `{doc_id}.a1` is written (must already exist).
    corpus_proteinOrigIdBySpan: mapping forwarded to generate_sentence_instance.

    Returns a list with one entry per sentence of the parsed passage: whatever
    generate_sentence_instance returned for that sentence.
    '''
    document = NER_tagger(document)
    document_entities = filter_protein_entities(document.ents, document)

    # create mapping from starting character position to entity id (T1, T2, ...)
    ent_char2_id_map = {ent.start_char: f'T{idx+1}' for idx, ent in enumerate(document_entities)}

    # One tab-separated line per protein entity. The encoding is explicit because
    # entity text can contain non-ASCII characters and the platform default may not be UTF-8.
    with open(f'{output_dir}/{doc_id}.a1', 'w', encoding='utf-8') as f:
        for ent in document_entities:
            f.write(f'{ent_char2_id_map[ent.start_char]}\tProtein {ent.start_char} {ent.end_char}\t{ent}\n')

    # The list is built completely before returning, so an exception raised for any
    # sentence propagates to the caller before a single instance is handed back.
    return [
        generate_sentence_instance(document, sentence, doc_id, ent_char2_id_map, corpus_proteinOrigIdBySpan)
        for sentence in document.sents
    ]
    

In [33]:
import pickle
# Serialises writes to the shared `all_result_dict`. process_documents is run on a
# thread pool, and one instance is stored as one append per key; without the lock the
# appends of two instances could interleave and leave the per-key lists misaligned.
_ALL_RESULT_LOCK = threading.Lock()


def process_documents(file_path, output_dir):
    '''
    Process every body-text passage of one JSON file under pmc_json.

    file_path: path of a JSON file with a 'body_text' list whose items carry
        'section' and 'text'.
    output_dir: directory (created if missing) that receives `{doc_id}.txt` for every
        processed passage; it is also passed on to process_document.

    Passages whose section is 'Introduction', 'Concluding remark' or 'Conclusion'
    are skipped. Results are accumulated in the module-level `all_result_dict`,
    `corpus_docId2OrigId` and `corpus_proteinOrigIdBySpan`, which must exist before
    this function is called. Returns None.
    '''
    os.makedirs(output_dir, exist_ok=True)
    # file name up to its first dot, e.g. 'PMC123.xml.json' -> 'PMC123'
    file_name = os.path.basename(file_path).split('.')[0]
    with open(file_path, 'r', encoding='utf-8') as f:
        document = json.load(f)

    for passage_idx, doc in enumerate(document['body_text']):
        # ignore these sections
        if doc['section'] in ('Introduction', 'Concluding remark', 'Conclusion'):
            continue

        doc_id = f'{file_name}.d{passage_idx}'
        with open(f'{output_dir}/{doc_id}.txt', 'w', encoding='utf-8') as f:
            f.write(doc['text'])
        try:
            # process documents
            for sent_result_dict in process_document(doc['text'], doc_id, output_dir, corpus_proteinOrigIdBySpan):
                if sent_result_dict is None:
                    continue
                # store all key, value pairs of one instance atomically
                with _ALL_RESULT_LOCK:
                    for key, value in sent_result_dict.items():
                        all_result_dict[key].append(value)

            corpus_docId2OrigId[doc_id] = doc_id
        except Exception as e:
            # Deliberate best effort: a passage that fails is reported and skipped so
            # that the remaining passages and files are still processed.
            print(f'{e} : {doc_id}')

In [34]:
def generate_sentence_instance(document, sentence, doc_id, ent_char2_id_map, corpus_proteinOrigIdBySpan):
    '''
    Generate one instance per sentence.

    The sentence text is cut at every protein-entity boundary and at every character
    that is not an ASCII letter or digit. Each piece is tokenized with the
    module-level `tokenizer`, and the first token of every protein entity is
    labelled 'Protein'; all other positions are labelled 'None'.

    document: parsed document the sentence belongs to; only passed on to
        filter_protein_entities.
    sentence: sentence span; `.ents`, `.start_char`, `.ent_id` and `str()` are used.
    doc_id: identifier stored in the result and used as key of corpus_proteinOrigIdBySpan.
    ent_char2_id_map: maps an entity's document-level `start_char` to its entity id.
    corpus_proteinOrigIdBySpan: nested mapping updated in place:
        doc_id -> span entry of the entity's first token -> f'{doc_id}.{entity id}'.

    Returns None when the sentence has no protein entity, or when the token sequence
    including [CLS] and [SEP] is longer than `max_length` (a message is printed in
    that case). Otherwise returns a dict whose 'tokenized_ids', 'entity_labels',
    'mask_ids' and 'segment_ids' lists are padded to `max_length`.

    Raises AssertionError when one of the internal consistency checks fails.
    '''

    # only mark the head tokens as label

    entities = sentence.ents

    # start char offset of this sentence
    sent_start = sentence.start_char
    entities = filter_protein_entities(entities, document)
    # don't consider sentence with no entity
    if len(entities) == 0: return None

    # obtain the text of this sentence
    text = str(sentence)

    # head entity char position -> bert head token position
    # TODO: create origin to token map
    # Both maps use character offsets relative to the start of the sentence.
    char_to_token_map = {}
    token_to_char_map = {}

    # NOTE(review): filled below but not part of the returned dict.
    char_to_token_tail_map = {}
    new_tokens = ["[CLS]"]

    # number of characters of `text` consumed so far, i.e. the sentence-relative
    # offset at which the next piece starts
    cur_tokens_length = 0


    # split sentence into sub-sentences: cut points are 0, the sentence-relative
    # start and end of every entity, and one past the end of the text, so every
    # entity starts exactly at the beginning of a sub-sentence
    segment_numbers = [0]+\
    np.hstack([[entity.start_char - sent_start, entity.end_char - sent_start]  for entity in entities ]).tolist()\
    + [len(text)+1]   

    sub_sentences = [text[start:end] for start, end in zip(segment_numbers[:-1], segment_numbers[1:])]

    # absolute span of each entity; one entry per token, starting with the
    # placeholder for [CLS]
    abs_spans = [dummy_span]

    for sub_sentence in sub_sentences:
        # create a mapping from character space to bert token space

        # The capturing group makes re.split return the separators as pieces of
        # their own, so no character of the sub-sentence is lost.
        skip_patterns = r'([^a-zA-Z0-9])'

        for original_texts in re.split(skip_patterns, sub_sentence):

            if len(original_texts) == 0: continue

            # a list of bert tokens
            bert_tokens = tokenizer.tokenize(original_texts)

            head_char_position = cur_tokens_length 

            # head token position: index that the first token of this piece gets
            # in new_tokens (recorded even if the piece produces no token)
            char_to_token_map[head_char_position] = len(new_tokens)
            token_to_char_map[len(new_tokens)] = head_char_position

            cur_tokens_length += len(original_texts)
            new_tokens.extend(bert_tokens)

            tail_char_position = cur_tokens_length -1 

            # step back over trailing spaces of this piece
            while text[tail_char_position] in [' '] and tail_char_position >= head_char_position:
                tail_char_position -= 1

            # this is invalid: the piece consists of spaces only or is a separator
            # character, so its tokens get placeholder spans
            if tail_char_position < head_char_position or re.search(skip_patterns, original_texts) is not None:

                abs_spans.extend([dummy_span] * (len(bert_tokens) ))                
                continue

            tail_position = len(new_tokens) 

            # trim the tail token position if the token is not [UNK]
            # NOTE(review): the resulting tail_position is not read afterwards;
            # the statements below use len(new_tokens) instead.
            if new_tokens[tail_position-1] not in ['[UNK]']:
                while text[tail_char_position].lower() != new_tokens[tail_position-1][-1]:
                    # tail token position
                    tail_position -= 1


            char_to_token_tail_map[tail_char_position] = len(new_tokens) 
            # document-level character span "start-end" (end exclusive) of this piece
            abs_span = '-'.join([str(head_char_position+sent_start), str(tail_char_position+1+sent_start)])                    

            # the span belongs to the first token of the piece
            abs_spans.append(abs_span)  

            # append dummy spans for the remaining tokens of the piece
            abs_spans.extend([dummy_span] * (len(bert_tokens) -1))

    entity_label = np.array(['None'] * (len(new_tokens) + 1) ,dtype=object)  # plus 1 for [SEP]

    # iterate over each entity
    for entity in entities:

        head_char_idx = entity.start_char - sent_start
        tail_char_idx = entity.end_char - sent_start

        # only add if there is a None
        assert entity_label[char_to_token_map[head_char_idx]] == 'None'

        # assign protein label
        entity_label[char_to_token_map[head_char_idx]] = 'Protein'



        # doc_id -> span -> protein origid
        # NOTE(review): this write happens before the max_length check below, so
        # entries are also recorded for sentences that end up being rejected.
        corpus_proteinOrigIdBySpan[doc_id][abs_spans[char_to_token_map[head_char_idx]]] = f'{doc_id}.{ent_char2_id_map[entity.start_char]}'



        # make sure at least the first character of the mapping is correct
        # NOTE(review): the logged run shows this failing for entities that start
        # with characters such as 'µ', 'Å' or '…'; the caller then reports and
        # skips the passage.
        assert new_tokens[char_to_token_map[head_char_idx]][0] == text[head_char_idx].lower(), (new_tokens[char_to_token_map[head_char_idx]][0] , text[head_char_idx], entity)


    # convert entity label back to list for zero padding
    entity_label = entity_label.tolist()


    new_tokens.append('[SEP]')
    abs_spans.append(dummy_span)            

    assert len(new_tokens) == len(abs_spans), (len(new_tokens) , len(abs_spans))


    if len(new_tokens) > max_length:
        print(f"{doc_id}-{sentence.ent_id}: Exceed max length")
        return None
    tokenized_ids = tokenizer.convert_tokens_to_ids(new_tokens)

    # mask ids: 1 for every real token, 0 for padding added below
    mask_ids = [1] * len(tokenized_ids)

    # segment ids: a single segment, all zeros
    segment_ids = [0] * len(tokenized_ids)

    if len(tokenized_ids) < max_length:
        # Zero-pad up to the sequence length
        padding = [0] * (max_length - len(tokenized_ids))
        tokenized_ids += padding
        entity_label += ['None'] * len(padding)
        mask_ids += padding
        segment_ids += padding

    assert len(tokenized_ids) == max_length == len(mask_ids) == len(segment_ids) == len(entity_label ) ,\
     (len(tokenized_ids) ,max_length ,len(mask_ids) ,len(segment_ids), len(entity_label) )


    # NOTE(review): 'sent_ids' stores sentence.ent_id, which prints as 0 for every
    # sentence in the logged run; presumably it is not a sentence index. Confirm.
    # 'abs_spans' and the two maps are returned unpadded.
    result_dict = {
    'tokenized_ids': tokenized_ids,
    'entity_labels': entity_label,
    'mask_ids': mask_ids,
    'segment_ids': segment_ids,
    'sent_ids':sentence.ent_id,
    'doc_ids': doc_id,
    'token_to_char_map': token_to_char_map,
    'char_to_token_map': char_to_token_map,
    'abs_spans':abs_spans
    }
    return result_dict

In [None]:
from tqdm.notebook import tqdm
import multiprocessing 
from joblib import Parallel, delayed
import shutil

import os

# Shared accumulators that process_documents / generate_sentence_instance fill in place.
corpus_docId2OrigId = {}                          # doc_id -> doc_id
corpus_proteinOrigIdBySpan = defaultdict(dict)    # doc_id -> {span: protein id}
all_result_dict = defaultdict(list)               # field name -> one value per instance

genia_input_folder = 'genia_cord_19'    # receives the per-passage .txt / .a1 files
output_dir = 'preprocessed_data'        # receives the pickled results
f_name = 'CORD_19_PMC'                  # common prefix of the pickle file names


# clear out previous results
shutil.rmtree(genia_input_folder, ignore_errors=True)
shutil.rmtree(output_dir, ignore_errors=True)


# One task per input file, run on a thread pool; the workers share the
# accumulators defined above.
parallel = Parallel(multiprocessing.cpu_count(), backend="threading", verbose=0)
parallel(delayed(process_documents)(file_path, genia_input_folder) for file_path in tqdm(glob('custom_license/custom_license/pmc_json/*')))

# consecutive id for every collected instance
all_result_dict['sample_ids'] = list(range(len(all_result_dict['tokenized_ids'])))


# create directory (no separate existence check needed)
os.makedirs(output_dir, exist_ok=True)

with open(f'{output_dir}/{f_name}.pkl', 'wb') as f:
    pickle.dump(all_result_dict, f, protocol=4)

with open(f'{output_dir}/{f_name}protIdBySpan.pkl', 'wb') as f:
    pickle.dump(corpus_proteinOrigIdBySpan, f, protocol=4)

with open(f'{output_dir}/{f_name}origIdById.pkl', 'wb') as f:
    pickle.dump(corpus_docId2OrigId, f, protocol=4)

HBox(children=(FloatProgress(value=0.0, max=7802.0), HTML(value='')))

('[', 'µ', µM) : PMC1142193.d3
PMC1474053.d9-0: Exceed max length
('[', 'µ', µM) : PMC1475929.d14
PMC1824729.d5-0: Exceed max length
('[', '…', …) : PMC1876871.d3
PMC1852721.d9-0: Exceed max length
PMC1859983.d18-0: Exceed max length
('[', '▶', ▶  ) : PMC1850917.d31
('[', '…', …) : PMC2516372.d36
('a', 'Å', Å2 buried31) : PMC2584968.d10
PMC2630752.d2-0: Exceed max length
('[', 'µ', µL) : PMC2662633.d1
('[', '…', …) : PMC2681116.d1
('a', 'Å', Å) : PMC2658058.d22
('[', 'µ', µl protein G-Sepharose) : PMC2655771.d25
('[', 'µ', µL) : PMC2687000.d2
('a', 'Å', Å) : PMC2692245.d6
PMC2712425.d6-0: Exceed max length
PMC2723995.d11-0: Exceed max length
PMC2725397.d9-0: Exceed max length
('[', 'µ', µL) : PMC2725899.d10
PMC2751450.d2-0: Exceed max length
PMC2736877.d15-0: Exceed max length
('[', '…', …) : PMC2751450.d14
PMC2749164.d20-0: Exceed max length
PMC2736877.d30-0: Exceed max length
('[', 'ɛ', ɛ (epsilon) portion) : PMC2769243.d10
PMC2782110.d3-0: Exceed max length
PMC2782110.d3-0: Exceed m

PMC3983375.d82-0: Exceed max length
PMC4095937.d20-0: Exceed max length
('[', '≳', ≳6
) : PMC4051247.d41
PMC4148155.d4-0: Exceed max length
PMC4095937.d64-0: Exceed max length
('a', 'Å', Å) : PMC4165820.d7
PMC4095937.d78-0: Exceed max length
('[', '™', ™ Fluorescent Protein) : PMC4167438.d14
PMC4167443.d9-0: Exceed max length
PMC4095937.d86-0: Exceed max length
PMC4095937.d94-0: Exceed max length
('[', '¼', ¼) : PMC4169708.d21
('[', '⩾', ⩾80) : PMC4180291.d7
('[', '…', …) : PMC4193163.d1
('[', '…', …) : PMC4193163.d2
('[', '…', …) : PMC4193163.d10
PMC4095937.d118-0: Exceed max length
PMC4095937.d122-0: Exceed max length
PMC4095937.d123-0: Exceed max length
('[', 'µ', µL) : PMC4179962.d31
('[', 'ϕ', ϕ6) : PMC4207942.d5
PMC4095937.d137-0: Exceed max length
('[', '…', …I) : PMC3654146.d886
('[', '…', …) : PMC4205922.d29
('[', '…', …) : PMC3654146.d889
('[', 'ϕ', ϕ6) : PMC4207942.d28
('[', 'ϕ', ϕ6 polymerases) : PMC4207942.d30
('[', 'ϕ', ϕ6 RdRp) : PMC4207942.d36
('a', 'Å', Å) : PMC4268874

('=', '≠', ≠) : PMC6044682.d15
PMC6052913.d35-0: Exceed max length
PMC6052913.d36-0: Exceed max length
PMC6052913.d41-0: Exceed max length
('[', 'µ', µL) : PMC6111962.d8
PMC6111954.d12-0: Exceed max length
('[', '…', …) : PMC6106442.d8
PMC6066311.d38-0: Exceed max length
('[', 'µ', µ-calpain) : PMC6140930.d8
PMC6145396.d12-0: Exceed max length
('[', 'µ', µ-calpain) : PMC6140930.d23
PMC6145396.d17-0: Exceed max length
PMC6145396.d18-0: Exceed max length
PMC6145396.d20-0: Exceed max length
PMC6178105.d4-0: Exceed max length
PMC6145396.d24-0: Exceed max length
PMC6145396.d25-0: Exceed max length
('[', 'µ', µ-calpain) : PMC6140930.d32
PMC6145396.d30-0: Exceed max length
PMC6145396.d31-0: Exceed max length
PMC6145396.d32-0: Exceed max length
PMC6145396.d33-0: Exceed max length
PMC6145396.d35-0: Exceed max length
PMC6145396.d36-0: Exceed max length
PMC6145396.d36-0: Exceed max length
PMC6207709.d7-0: Exceed max length
PMC6145396.d37-0: Exceed max length
PMC6145396.d37-0: Exceed max length
PM

PMC7080032.d23-0: Exceed max length
PMC7080032.d23-0: Exceed max length
PMC7080032.d23-0: Exceed max length
PMC7080032.d23-0: Exceed max length
PMC7080032.d23-0: Exceed max length
PMC7080047.d21-0: Exceed max length
PMC7080047.d22-0: Exceed max length
PMC7080047.d23-0: Exceed max length
PMC7080032.d24-0: Exceed max length
PMC7080032.d24-0: Exceed max length
PMC7080032.d24-0: Exceed max length
PMC7080032.d24-0: Exceed max length
PMC7080032.d24-0: Exceed max length
PMC7080032.d24-0: Exceed max length
PMC7080032.d24-0: Exceed max length
PMC7080032.d25-0: Exceed max length
PMC7080032.d26-0: Exceed max length
PMC7080032.d26-0: Exceed max length
PMC7080032.d28-0: Exceed max length
('o', 'Ö', Öls bestimmt werden kann) : PMC7080060.d7
PMC7080047.d32-0: Exceed max length
PMC7080032.d30-0: Exceed max length
PMC7080032.d30-0: Exceed max length
('o', 'Ö', Öle ist) : PMC7080060.d11
('o', 'Ö', Öls) : PMC7080060.d12
('o', 'Ö', Ölen zählen) : PMC7080060.d18
PMC7080093.d8-0: Exceed max length
PMC708009