In [35]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
from deeponto.onto import Ontology
from deeponto.align.bertmap import BERTMapPipeline
from deeponto.align.evaluation import AlignmentEvaluator
from deeponto.utils import FileUtils
from deeponto.align.mapping import EntityMapping, ReferenceMapping
import pandas as pd
import random
import numpy as np

In [2]:
# Locations of the OWL files for the source (NCIT) and target (DOID) ontologies.
src_onto_path = "data/ncit2doid/ncit.owl"
tgt_onto_path = "data/ncit2doid/doid.owl"

# Parse both ontologies (source first, then target).
src_onto, tgt_onto = Ontology(src_onto_path), Ontology(tgt_onto_path)

# BERTMap configuration; only its annotation property IRIs are used further down.
config = BERTMapPipeline.load_bertmap_config()

Use the default configuration at /home/yuan/anaconda3/envs/deeponto/lib/python3.8/site-packages/deeponto/align/bertmap/default_config.yaml.


In [3]:
# Build one annotation index per ontology, shaped {class_iri: class_labels}.
# Both indexes are built from the same annotation properties; the second value
# returned by build_annotation_index is not needed here and is discarded.
annotation_property_iris = config.annotation_property_iris
src_annotation_index, _ = src_onto.build_annotation_index(annotation_property_iris)
tgt_annotation_index, _ = tgt_onto.build_annotation_index(annotation_property_iris)

In [4]:
# Tokenizer and model are loaded from the same checkpoint name.
checkpoint_name = "google/flan-t5-xxl"
tokenizer = T5Tokenizer.from_pretrained(checkpoint_name)
# device_map="auto" delegates device placement of the weights to the library.
model = T5ForConditionalGeneration.from_pretrained(checkpoint_name, device_map="auto")

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

In [5]:
# Read the candidate test set as a table; the scoring loop below reads its
# "SrcEntity", "TgtEntity" and "TgtCandidates" columns row by row.
test_cands = FileUtils.read_table("data/ncit2doid/test_cands.tsv")

In [8]:
# cap the number of labels per concept because of the window size limit of T5
def process_labels(labels, cut_off = 3):
    """Return at most ``cut_off`` labels from ``labels`` as a list.

    When there are at least ``cut_off`` labels they are ordered longest first
    (labels of equal length keep their incoming order) before being truncated.
    With fewer labels the incoming order is returned untouched.
    """
    candidates = list(labels)
    if len(candidates) < cut_off:
        # Too few labels to need ranking; the slice is a no-op copy here.
        return candidates[:cut_off]
    longest_first = sorted(candidates, key=len, reverse=True)
    return longest_first[:cut_off]

In [11]:
def v_text(src_labels, tgt_labels, src_parent_labels=None, tgt_parent_labels=None, src_child_labels=None, tgt_child_labels=None):
    """Build the yes/no prompt asking whether a source and a target concept are identical.

    Args:
        src_labels: names of the source concept (rendered with their ``str()`` form).
        tgt_labels: names of the target concept.
        src_parent_labels: optional names of the source concept's parents.
        tgt_parent_labels: optional names of the target concept's parents.
        src_child_labels: optional names of the source concept's children.
        tgt_child_labels: optional names of the target concept's children.

    Returns:
        The prompt string: an introduction, one line describing the source
        concept, one line describing the target concept, a blank line and the
        closing question. Optional parent/child sentences are appended to the
        line of the concept they belong to; falsy values (``None``, ``[]``) are
        skipped.
    """

    def _describe(role, labels, parent_labels, child_labels):
        # One line per concept: its names followed by any structural context.
        sentence = f"{role} concept: {labels}."
        if parent_labels:
            sentence += f" Its parent concepts are represented by: {parent_labels}."
        if child_labels:
            sentence += f" Its child concepts are represented by: {child_labels}."
        return sentence

    v = "Consider two concepts, each represented by a list of associated names.\n\n"
    v += _describe("Source", src_labels, src_parent_labels, src_child_labels) + "\n"
    # The target line is built the same way as the source line, so target child
    # labels stay on the target line instead of trailing after a blank line.
    v += _describe("Target", tgt_labels, tgt_parent_labels, tgt_child_labels) + "\n\n"
    v += "Given these representations, can you determine if the source concept and the target concept are identical? Please answer with \"Yes\" if they are identical or \"No\" if they are not."
    return v

# Preview the prompt with parent labels on both sides and child labels on the source side only.
demo_prompt = v_text(
    ["A"],
    ["B"],
    src_parent_labels=["P_A1", "P_A2"],
    tgt_parent_labels=["P_B1"],
    src_child_labels=["C_A1"],
)
print(demo_prompt)

Consider two concepts, each represented by a list of associated names.

Source concept: ['A']. Its parent concepts are represented by: ['P_A1', 'P_A2']. Its child concepts are represented by: ['C_A1'].
Target concept: ['B']. Its parent concepts are represented by: ['P_B1'].

Given these representations, can you determine if the source concept and the target concept are identical? Please answer with "Yes" if they are identical or "No" if they are not.


In [19]:
# vanilla setting (no structural context)
# For every test row, ask the model about each target candidate and record the
# first "Yes"/"No" token it generates together with that token's probability.

# A "Yes" must be more probable than this for the candidate to be predicted as a match.
yes_threshold = 0.55

results = []   # one list per row: (candidate IRI, "Yes"/"No", probability) per candidate
corrects = []  # one bool per row: was the reference target among the predictions?
for i, dp in test_cands.iterrows():
    src_iri = dp["SrcEntity"]
    tgt_iri = dp["TgtEntity"]
    # NOTE(review): eval() runs arbitrary code from the TSV cell; this is only
    # acceptable while the file is a trusted local artifact. ast.literal_eval is
    # the safe parser for a literal list/tuple of IRIs.
    tgt_cands = eval(dp["TgtCandidates"])
    src_labels = process_labels(src_annotation_index[src_iri])
    preds = []
    scores = []
    for tgt_cand_iri in tgt_cands:
        tgt_cand_labels = process_labels(tgt_annotation_index[tgt_cand_iri])
        input_ids = tokenizer(v_text(src_labels, tgt_cand_labels), return_tensors="pt").input_ids.to("cuda")
        outputs = model.generate(input_ids, max_new_tokens=3, return_dict_in_generate=True, output_scores=True)
        transition_scores = model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        )
        # Encoder-decoder output starts with one decoder start token; otherwise the
        # prompt itself precedes the generated tokens.
        input_length = 1 if model.config.is_encoder_decoder else input_ids.shape[1]
        generated_tokens = outputs.sequences[:, input_length:]
        has_answer = False
        for tok, score in zip(generated_tokens[0], transition_scores[0]):
            # the transition score is exponentiated to turn it into a probability
            score = np.exp(score.cpu().numpy())
            tok = tokenizer.decode(tok)
            if "Yes" in tok:
                if score > yes_threshold:
                    preds.append(tgt_cand_iri)
                scores.append((tgt_cand_iri, "Yes", score))
                has_answer = True
                break
            if "No" in tok:
                scores.append((tgt_cand_iri, "No", score))
                has_answer = True
                break
        # Checked once per candidate, after all generated tokens were inspected, so
        # every candidate contributes exactly one entry to `scores`.
        if not has_answer:
            # if no yes or no, giving the worst score
            scores.append((tgt_cand_iri, "No", 1.0))
    if not preds:
        preds += ["UnMatched"]
    correct = tgt_iri in preds
    corrects.append(correct)
    results.append(scores)
    print(correct, preds)

True ['http://purl.obolibrary.org/obo/DOID_3747']
True ['http://purl.obolibrary.org/obo/DOID_2490']
True ['http://purl.obolibrary.org/obo/DOID_5047']
True ['http://purl.obolibrary.org/obo/DOID_11713']
True ['http://purl.obolibrary.org/obo/DOID_7578']
True ['http://purl.obolibrary.org/obo/DOID_7689', 'http://purl.obolibrary.org/obo/DOID_6086']
True ['http://purl.obolibrary.org/obo/DOID_6511']
False ['UnMatched']
False ['UnMatched']
False ['UnMatched']
False ['UnMatched']
True ['http://purl.obolibrary.org/obo/DOID_6228']
False ['UnMatched']
False ['http://purl.obolibrary.org/obo/DOID_7505']
False ['UnMatched']
True ['http://purl.obolibrary.org/obo/DOID_7459']
True ['http://purl.obolibrary.org/obo/DOID_13372']
True ['http://purl.obolibrary.org/obo/DOID_0060111']
False ['UnMatched']
True ['http://purl.obolibrary.org/obo/DOID_12270']
True ['http://purl.obolibrary.org/obo/DOID_5637']
True ['http://purl.obolibrary.org/obo/DOID_12318']
True ['http://purl.obolibrary.org/obo/DOID_3646']
True ['h

In [76]:
def ranking(scores):
    """Order candidate answers from most to least likely match.

    ``scores`` is an iterable of ``(tgt_iri, answer, score)`` triples. Triples
    answered "Yes" come first, sorted by score in descending order. Every other
    triple is relabelled "No", its score replaced by ``1.0 - score``, and sorted
    by that value in descending order. Ties keep their incoming order.
    """

    def by_score_desc(entries):
        return sorted(entries, key=lambda entry: entry[2], reverse=True)

    yes_entries = [
        (iri, "Yes", prob) for iri, answer, prob in scores if answer == "Yes"
    ]
    no_entries = [
        (iri, "No", 1.0 - prob) for iri, answer, prob in scores if answer != "Yes"
    ]
    return by_score_desc(yes_entries) + by_score_desc(no_entries)

In [77]:
# Key the ranked candidate list of each scored row by that row's source entity.
results_dict = {
    test_cands.iloc[row_idx]["SrcEntity"]: ranking(candidate_scores)
    for row_idx, candidate_scores in enumerate(results)
}

In [79]:
# Persist the ranked candidates per source entity so they can be reloaded later.
FileUtils.save_file(results_dict, "flan_t5_ncit2doid_results.pkl")

In [80]:
# Reload the ranked candidates from disk.
# NOTE(review): the .pkl extension suggests a pickle file — only load files produced by this notebook.
ranked_results = FileUtils.load_file("flan_t5_ncit2doid_results.pkl")

In [83]:
# Sanity check on the reloaded results.
# NOTE(review): the asserted expression does not depend on k or v, so every
# iteration re-checks the same fact (the dict has 100 entries). If the intent
# was to check the number of ranked candidates per source entity, this should
# probably be len(v) — confirm before changing.
for k, v in ranked_results.items():
    assert len(ranked_results) == 100

In [90]:
# Build the predicted mappings used to compute Precision and Recall: for each
# source entity keep the first candidate in its ranked list that was answered
# "Yes"; source entities without any "Yes" contribute no prediction.
preds = []
for src, tgts in ranked_results.items():
    for t, answer, score in tgts:
        if answer == "Yes":
            # the mapping is only instantiated for the candidate that is kept
            preds.append(EntityMapping(src, t, "=", score))
            break

In [91]:
# Load the reference mappings and keep only the first 50 of them.
# NOTE(review): presumably 50 is meant to match the source entities that were scored — confirm.
refs = ReferenceMapping.read_table_mappings("data/ncit2doid/refs/test_refs.tsv")[:50]

In [92]:
# Score the predicted mappings against the reference mappings. The third argument
# is passed as an empty list (presumably reference mappings to ignore — confirm
# against the DeepOnto documentation).
AlignmentEvaluator.f1(preds, refs, [])

{'P': 0.736, 'R': 0.78, 'F1': 0.757}

In [89]:
# Display the predicted mappings for manual inspection.
preds

[EntityMapping(http://ncicb.nci.nih.gov/xml/owl/EVS/Thesaurus.owl#C27420 = http://purl.obolibrary.org/obo/DOID_3747, 0.9185360074043274),
 EntityMapping(http://ncicb.nci.nih.gov/xml/owl/EVS/Thesaurus.owl#C97172 = http://purl.obolibrary.org/obo/DOID_2490, 0.9413300156593323),
 EntityMapping(http://ncicb.nci.nih.gov/xml/owl/EVS/Thesaurus.owl#C7017 = http://purl.obolibrary.org/obo/DOID_5047, 0.8481510281562805),
 EntityMapping(http://ncicb.nci.nih.gov/xml/owl/EVS/Thesaurus.owl#C35610 = http://purl.obolibrary.org/obo/DOID_11713, 0.8025959730148315),
 EntityMapping(http://ncicb.nci.nih.gov/xml/owl/EVS/Thesaurus.owl#C7362 = http://purl.obolibrary.org/obo/DOID_7578, 0.7001780271530151),
 EntityMapping(http://ncicb.nci.nih.gov/xml/owl/EVS/Thesaurus.owl#C8312 = http://purl.obolibrary.org/obo/DOID_7689, 0.9076669812202454),
 EntityMapping(http://ncicb.nci.nih.gov/xml/owl/EVS/Thesaurus.owl#C8312 = http://purl.obolibrary.org/obo/DOID_6086, 0.7143499851226807),
 EntityMapping(http://ncicb.nci.nih.g