In [1]:
import json
import os
import numpy as np
from utils import *
import copy
import pprint
import glob
from data_load import *

No NCBI API key provided; throttling to 3 requests/second; see https://ncbiinsights.ncbi.nlm.nih.gov/2017/11/02/new-api-keys-for-the-e-utilities/


In [2]:
train_withGNP = True
ignoreLongDocument= False
os.environ['BC6PM_dir'] = '../BC6PM'
os.environ['pretrain_dir'] = '../BioBERT-Models'

In [3]:
if train_withGNP:
    train_data_dir = os.path.join(os.environ.get('BC6PM_dir'), 'GNormPlus', 'withAnn-Result', 'PMtask_Relations_TrainingSet_r.json')
else:
    train_data_dir = os.path.join(os.environ.get('BC6PM_dir'), 'json', 'PMtask_Relations_TrainingSet.json')
with open(train_data_dir) as f:
    train_data = json.load(f)
documents_train = train_data['documents']
with open(os.path.join(os.environ.get('BC6PM_dir'), 'GNormPlus', 'result', 'PMtask_Relations_TestSet_r.json')) as f:
    eval_data = json.load(f)
documents_eval = eval_data['documents']
with open(os.path.join(os.environ.get('BC6PM_dir'), 'json', 'PMtask_Relations_TestSet.json')) as f:
    eval_data_ground_truth = json.load(f)
documents_eval_ground_truth = eval_data_ground_truth['documents']
if ignoreLongDocument: # Ignore too long text
    with open('pmid2tokenlen.json') as f:
        pmid2tokenlen = json.load(f)
    def filterLongDocu(documents):
        documents_ = []
        for document in documents:
            if pmid2tokenlen.get(document['id'], 1) < 512: # some documents may not have RC instance
                documents_.append(document)
        return documents_
    documents_train = filterLongDocu(documents_train)
    documents_eval = filterLongDocu(documents_eval)
    documents_eval_ground_truth = filterLongDocu(documents_eval_ground_truth)

### Ground Truth Relation Distribution

In [4]:
relation_num = []
for document in documents_eval_ground_truth:
    relation_num.append(len(document['relations']))
relation_num = np.array(relation_num)
print(relation_num.mean(), [(relation_num==i).sum() for i in range(0,10)])

1.368503937007874 [0, 477, 111, 26, 15, 4, 2, 0, 0, 0]


### Eval Performance

In [5]:
def relationFilter(documents, confidence_threshold = 0.5, at_least = 0, at_most = 4):
    documents_rtn = copy.deepcopy(documents)
    for i, document in enumerate(documents):
        documents_rtn[i]['relations'] = []
        relations =  sorted(document['relations'], key=lambda relation: relation['infons']['confidence'], reverse = True)
        for j, relation in enumerate(relations):
            if relation['infons']['confidence']  > th_conf:
                 documents_rtn[i]['relations'].append(relation)
        if len(documents_rtn[i]['relations']) < at_least and len(relations)>0:
            documents_rtn[i]['relations'] = [relations[:at_least]]
        if len(documents_rtn[i]['relations']) > at_most and at_most > 0:
            documents_rtn[i]['relations'] = relations[:at_most]
    return documents_rtn
def countDocumentNums(documents_pred_all_relation, gold_standard_relations, confidence_threshold = 0.5):
    """
    if `documents` is `documents_filtered` :
    return (pred_positive_num, true_positive_num, false_positive_num, relation2PredResult) 
    if `documents` is `documents_all_relations`:
    return (pred_positive_num, label_1_num, label_0_num, relation2GroundTruthLabel)
    """
    pred_pos = 0
    label_1_num_or_tp_num = 0
    label_0_num_or_fp_num = 0
    relation2Result = {}
    for document in documents_pred_all_relation:
        for relation in document['relations']:
            if  relation['infons']['confidence'] >= confidence_threshold:
                pred_pos += 1

            geneids = [relation['infons']['Gene1'], relation['infons']['Gene2']]
            geneids.sort()
            relation_str = 'PMID' + document['id'] + '_' + '_'.join(geneids)
            if relation_str in gold_standard_relations:
                label_1_num_or_tp_num += 1
                relation2Result[relation_str] = 1
            else:
                label_0_num_or_fp_num += 1
                relation2Result[relation_str] = 0
        
    return pred_pos, label_1_num_or_tp_num, label_0_num_or_fp_num, relation2Result

#### Post process of outputjson data

In [6]:
gold_standard_ids, gold_standard_relations = JSON_Collection_Relation({'documents': documents_eval_ground_truth})
th_conf = 0.5 # threshold of confidence
at_least = 0  # assume the minimal/maximal number of PPIm of one single document in predicited relations at least/most  
at_most = 4
fileList = glob.glob('outputjson/*Test*.json')
fileList = sorted(fileList, key = lambda f : (f.split(".json")[0].split("_")[0], int(f.split(".json")[0].split("_")[1])))
print('\t'.join(['Setting'+' '*42, 'Epoch', 'P-Ext', 'R-Ext', 'F1-Ext', 'P-Homo', 'R-Homo', 'F1-Homo']))#'R-GNP' ,
for fileName in fileList:
    name, idx = fileName.split(".json")[0].split("_")
    if int(idx) < 11:
        print(name.split("/"))
#         print("{:50}".format(name.split("/")[1]), idx,sep = '\t', end = '\t')
    else:
        continue    
    with open(fileName) as f:
        documents_all_relations = json.load(f)['documents']
    documents_filtered = relationFilter(documents_all_relations, confidence_threshold= th_conf, at_least = at_least, at_most = at_most)

    pred_positive_num, label_1_num, label_0_num, r2truth = countDocumentNums(documents_all_relations, gold_standard_relations)
    pred_positive_num, true_positive_num, false_positive_num, r2pred = countDocumentNums(documents_filtered, gold_standard_relations)

    # R-GNP: recall of all gene pair generated by GNP, which means ignoring recognized genes.
#     print("{:.2f}".format(100*true_positive_num/label_1_num), end='\t')

    # Exact match
    micro_precision, micro_recall, micro_f1, macro_precision, macro_recall, macro_f1 = Classification_Performance_Relation({'documents': documents_filtered}, gold_standard_ids, gold_standard_relations, homolo = False)
    print("{:.2f}\t{:.2f}\t{:.2f}".format(100*micro_precision, 100*micro_recall, 100*micro_f1), end = '\t')
    # Homolo match
    micro_precision, micro_recall, micro_f1, macro_precision, macro_recall, macro_f1 = Classification_Performance_Relation({'documents': documents_filtered}, gold_standard_ids, gold_standard_relations, homolo = True)
    print("{:.2f}\t{:.2f}\t{:.2f}".format(100*micro_precision, 100*micro_recall, 100*micro_f1))
#     break

Setting                                          	Epoch	P-Ext	R-Ext	F1-Ext	P-Homo	R-Homo	F1-Homo
['outputjson\\Joint-RC-NER-SaveModel-Test-fineTune']
27.56	42.46	33.42	29.79	45.48	36.00
['outputjson\\Joint-RC-NER-SaveModel-Test-fineTune']
32.13	41.08	36.06	34.42	44.08	38.66
['outputjson\\Joint-RC-NER-SaveModel-Test-fineTune']
26.38	45.11	33.29	28.51	48.49	35.91
['outputjson\\Joint-RC-NER-SaveModel-Test-fineTune']
34.62	38.20	36.32	36.84	40.60	38.63
['outputjson\\Joint-RC-NER-SaveModel-Test-fineTune']
26.87	44.99	33.65	28.95	48.03	36.13
['outputjson\\Joint-RC-NER-SaveModel-Test-fineTune']
32.62	40.62	36.19	35.11	43.50	38.86
['outputjson\\Joint-RC-NER-SaveModel-Test-fineTune']
26.87	45.45	33.78	29.26	48.84	36.59
['outputjson\\Joint-RC-NER-SaveModel-Test-fineTune']
30.63	39.13	34.36	32.85	41.88	36.82
['outputjson\\Joint-RC-NER-SaveModel-Test-fineTune']
37.52	35.10	36.27	39.83	37.47	38.61
['outputjson\\Joint-RC-NER-SaveModel-Test-fineTune']
33.85	37.51	35.59	36.16	40.02	38.00
['outputjson\

### Case Study

#### collect result

In [8]:
fileList = glob.glob('outputjson/*Test*.json')
fileList = sorted(fileList, key = lambda f : (f.split(".json")[0].split("_")[0], int(f.split(".json")[0].split("_")[1])))
targetFileName = {
               'rc_triage': 'Joint-RC-triage-Test-fineTune-GNP-weightLabel_10',
                'rc':  'RC-Only-Test-fineTune-GNP-weightLabel_10',
                'rc_ner':  'Joint-RC-NER-SaveModel-Test-fineTune-GNP_10'
}
settings = sorted(targetFileName.keys())
r2pred_dict = {}
setting2documents_all_relations = {}
for setting in settings:
    path = os.path.join('outputjson', targetFileName[setting]+'.json')
    print(setting.upper())
    with open(path) as f:
        setting2documents_all_relations[setting] = json.load(f)['documents']
    documents_filtered = relationFilter(setting2documents_all_relations[setting], confidence_threshold= th_conf, at_least = at_least, at_most = at_most)

    pred_positive_num, label_1_num, label_0_num, r2truth = countDocumentNums(setting2documents_all_relations[setting], gold_standard_relations)
    print('pred_positive', pred_positive_num,'label_1', label_1_num,'label_0', label_0_num, sep = '\t')
    pred_positive_num, true_positive_num, false_positive_num, r2pred_dict[setting] = countDocumentNums(documents_filtered, gold_standard_relations)
    print('pred_positive', pred_positive_num, 'tp', true_positive_num,'fp',  false_positive_num, sep = '\t')

NameError: name 'glob' is not defined

In [9]:
pred_result2relation = {}
print('\t'.join(['pmid', 'g1', 'g2', "label"]+settings))
for truth_label_target in [0, 1]:
    notFound = 'false_neg' if truth_label_target == 1 else 'true_neg'
    predResult ={1:'true', 0: 'false', 'true_neg': 'true', 'false_neg': 'false'}

    for relation_str, label in r2truth.items():
        if label == truth_label_target:
            pred_result = [label] + [predResult[r2pred_dict[setting].get(relation_str, notFound)] for setting in settings]
            pred_result = tuple(pred_result)
            print('\t'.join(relation_str[4:].split('_')), '\t'.join([str(x) for x in pred_result]), sep ='\t')
            if pred_result2relation.get(pred_result) is None:
                    pred_result2relation[pred_result] = []
            pred_result2relation[pred_result].append(relation_str)

NameError: name 'settings' is not defined

In [10]:
print('\t'.join(['label'] +settings+ [ 'cnt']) )
pred_result_cnt = { k: len(v) for k, v in pred_result2relation.items()}
for k, v in pred_result_cnt.items():
    print( '\t'.join([str(x) for x in k]), v, sep='\t')

NameError: name 'settings' is not defined

#### Print case

In [11]:
def getDocumentByRelation(documents, relation_str):
    pmid, g1,g2 = relation_str[4:].split('_')
    for document in documents:
        if document['id'] != pmid:
            continue
        for relation in document['relations']:
            geneids = [relation['infons']['Gene1'], relation['infons']['Gene2']]
            geneids.sort()
            if geneids == [g1, g2]:
                return document, relation
def printGeneMentionInRelation(document, relation_str):
    pmid, g1,g2 = relation_str[4:].split('_')
    geneid2text = {}
    geneid2text[g1] = []
    geneid2text[g2] = []
    for ann in (document['passages'][0]['annotations'] + document['passages'][1]['annotations']):
        if Annotation.gettype(ann) == 'Gene':
            geneid = Annotation.getNCBIID(ann)
            if geneid in [g1, g2]:
                geneid2text[geneid].append(ann['text'])
    for id_, texts in geneid2text.items():
        if id_ == g1:
            print("GeneA", end='\t')
        else:
            print("GeneB", end = '\t')
        print(f"ID: {id_}", '; Mentions:' if len(set(texts)) > 1 else 'Mention:', ','.join(set(texts)))

def pprintline(text, width = 50):
    dashNum = width - len(text)
    offset = dashNum //2
    print('-'*(dashNum-offset), text, '-'*(dashNum+offset))


In [11]:
# result = (0, 'false', 'true', 'true')
result = (1, 'true', 'true', 'true')
pprintline("Result")
print('label','\t'.join(settings), sep = '\t')
print(result[0],'\t'.join(result[1:]), sep='\t')
cnt = 0
for relation_str in pred_result2relation[result]:
    pprintline('PMID&GeneID')

    setting2relation = {}
    print( '\t'.join(relation_str[4:].split('_')))
    pprintline('Confidence')
    for setting, documents_all_relations in setting2documents_all_relations.items():
        document, relation = getDocumentByRelation( documents_all_relations, relation_str)
        setting2relation[setting] = relation
        print(f"{setting:10}", f"{relation['infons']['confidence']:.2%}", sep = '\t')
    pprintline('Gene Mention')
    printGeneMentionInRelation(document, relation_str)
    pprintline('Original Text')
    print(document['passages'][0]['text'], document['passages'][1]['text'])
    dataset = RCDataSet([document],os.environ.get('pretrain_dir') )
    for example in dataset.examples:
        if example.guid == relation_str[4:]:
            pprintline('Pre-processed Text')
            preprocessedText = example.text_a
            for word in preprocessedText.split():
                if word == "Gene_A":
                    formated = f"\033[0;37;41m{word}\033[0m"
                elif word == "Gene_B":
                    formated = f"\033[0;37;40m{word}\033[0m"
                elif word == "Gene_S":
                    formated = f"\033[0;37;40m{word}\033[0m"
                elif word == "Gene_N":
                    formated = f"\033[0;30;43m{word}\033[0m"
                else:
                    formated = word
                print(formated, end =' ')
            print()
#             print(preprocessedText)
            ids = dataset.collate_fn([example])[0]['input_ids'][0]
            pprintline('To BERT')
            tokens = dataset.tokenizer.convert_ids_to_tokens(ids)
            print(dataset.tokenizer.decode(ids))
            pprintline('Pre-processed Text For Latex')
            preprocessedText = example.text_a

            for word in preprocessedText.split():
                if word == "Gene_A":
                    formated = "\colorbox{blue-wy}{Gene\_A}"
                elif word == "Gene_B":
                    formated = "\colorbox{mossgreen}{Gene\_B}"
                elif word == "Gene_S":
                    formated = "\colorbox{plum(web)}{Gene\_S}"
                elif word == "Gene_N":
                    formated = "\colorbox{peach}{Gene\_N}"
                else:
                    formated = word
                print(formated, end =' ')
            print()
    cnt += 1
    if cnt > 3:
        break
#     break

---------------------- Result ------------------------------------------------------------------
label	rc	rc_ner	rc_triage
1	true	true	true


KeyError: (1, 'true', 'true', 'true')

### (Unfinished) Gene Distribution in Sentences 在句子中分布情况 

In [12]:
def train_document_process(document):
    pmid = document['id']
    relations = set()
    genes = set()
    text_li = []
    passages = document['passages']
    split_word = re.compile(r"\w+|\S")
    for r in document['relations']:
        relations.add((r['infons']['Gene1'], r['infons']['Gene2']))

    for passage in passages:
        anns = passage['annotations']
        text = passage['text']
        offset_p = passage['offset']
        index = 0
        if len(anns) == 0:
            text_li.extend(split_word.findall(text))
        else:
            anns = Annotation.sortAnns(anns)
            for ann in anns:
                if Annotation.gettype(ann) == 'Gene':
                    for infon_type in ann['infons']:
                        if infon_type.lower() == 'ncbi gene':
                            genes.add(ann['infons'][infon_type])
                else:
                    continue
                for i, location in enumerate(ann['locations']):
                    offset = location['offset']
                    length = location['length']
                    text_li.extend(split_word.findall(
                        text[index:offset-offset_p]))
                    if i == len(ann['locations']) - 1:
                        ncbi_gene_id = Annotation.getNCBIID(ann)
                        # for infon_type in ann['infons']:
                            # if infon_type.lower() in ['ncbi gene', 'identifier']:
                                # ncbi_gene_id = ann['infons'][infon_type]
                        text_li.append("Gene_{}".format(ncbi_gene_id))
                    index = max(offset - offset_p + length, index)
            text_li.extend(split_word.findall(text[index:]))
    return pmid, text_li, genes, relations

In [13]:
with open(os.path.join(os.environ.get('BC6PM_dir'), 'GNormPlus', 'withAnn-Result', 'PMtask_Relations_TrainingSet_r.json')) as f:
    collection = json.load(f)
    documents = collection['documents']
pmid, text_li, genes, relations = train_document_process(documents[0])
for sent in (' '.join(text_li)).split("."):
    sent_li = sent.split(' ')
    print(sent_li)
    for gene in genes:
        if f"Gene_{gene}" in sent_li:
            print(gene, 'exist')
print(genes)

['Crystal', 'structure', 'of', 'Gene_817592', 'from', 'Arabidopsis', 'thaliana', 'and', 'its', 'Gene_827837', 'binding', 'characterization', '']
817592 exist
827837 exist
['', 'Microtubules', 'are', 'composed', 'of', 'polymerized', 'alpha', '/', 'Gene_827837', 'heterodimers', '']
827837 exist
['', 'Biogenesis', 'of', 'assembly', '-', 'competent', 'tubulin', 'dimers', 'is', 'a', 'complex', 'multistep', 'process', 'that', 'requires', 'sequential', 'actions', 'of', 'distinct', 'molecular', 'chaperones', 'and', 'cofactors', '']
['', 'Gene_817592', '(', 'Gene_817592', ')', ',', 'which', 'captures', 'Gene_827837', 'during', 'the', 'folding', 'pathway', ',', 'has', 'been', 'identified', 'in', 'many', 'organisms', '']
817592 exist
827837 exist
['', 'Here', ',', 'we', 'report', 'the', 'crystal', 'structure', 'of', 'Arabidopsis', 'thaliana', 'Gene_817592', '(', 'Gene_817592', ',', 'Gene_817592', ')', ',', 'which', 'forms', 'a', 'monomeric', 'three', '-', 'helix', 'bundle', '']
817592 exist
['', 