In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
from deeponto.onto import Ontology
from deeponto.align.bertmap import BERTMapPipeline
from deeponto.align.evaluation import AlignmentEvaluator
from deeponto.utils import FileUtils
from deeponto.align.mapping import EntityMapping, ReferenceMapping
import pandas as pd
import random
import numpy as np

In [None]:
# load source and target ontologies
# The source side is read from the NCIT OWL file and the target side from the
# DOID OWL file; both are wrapped in DeepOnto `Ontology` objects.
# `config` is whatever `load_bertmap_config()` returns when called without
# arguments; this notebook reads `config.annotation_property_iris` from it when
# building the annotation indexes.
src_onto_path = "data/ncit2doid/ncit.owl"
tgt_onto_path = "data/ncit2doid/doid.owl"
src_onto = Ontology(src_onto_path)
tgt_onto = Ontology(tgt_onto_path)
config = BERTMapPipeline.load_bertmap_config()

In [None]:
# build annotation index {class_iri: class_labels}
# `build_annotation_index` returns a pair; only the first element (the index)
# is kept and the second is discarded. The same annotation property IRIs from
# `config` are used for both ontologies.
# NOTE(review): presumably the labels are the values of those annotation
# properties for each class -- confirm against the DeepOnto documentation.
src_annotation_index, _ = src_onto.build_annotation_index(config.annotation_property_iris)
tgt_annotation_index, _ = tgt_onto.build_annotation_index(config.annotation_property_iris)

In [None]:
# Load the tokenizer and the conditional-generation model from the
# "google/flan-t5-xxl" checkpoint. `device_map="auto"` leaves the placement of
# the model weights to the library.
# NOTE(review): the inference cells move their inputs to "cuda" explicitly, so
# a CUDA device is assumed to be available.
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xxl", device_map="auto")

In [None]:
# Candidate table for the test split. The inference cells iterate over its rows
# and read the columns "SrcEntity", "TgtEntity" and "TgtCandidates".
test_cands = FileUtils.read_table("data/ncit2doid/test_cands.tsv")

In [None]:
# Keep the prompt short: T5 has a limited input window, so only a few labels
# per concept are passed on.
def process_labels(labels, cut_off = 3):
    """Return at most `cut_off` labels as a new list.

    When there are at least `cut_off` labels they are ordered longest-first
    (ties keep their original relative order) before truncation; shorter
    inputs are returned in their original order.
    """
    candidates = list(labels)
    if len(candidates) < cut_off:
        # Too few labels to need ranking; the slice is a no-op copy.
        return candidates[:cut_off]
    longest_first = sorted(candidates, key=len, reverse=True)
    return longest_first[:cut_off]

In [None]:
def v_text(src_labels, tgt_labels, src_parent_labels=None, tgt_parent_labels=None, src_child_labels=None, tgt_child_labels=None):
    """Build the yes/no prompt asking whether two concepts are identical.

    The source concept and the target concept each get one line listing their
    labels; when parent and/or child labels are supplied (and non-empty) they
    are appended to that same line as extra sentences. The prompt ends with
    the instruction to answer "Yes" or "No".
    """
    pieces = ["Consider two concepts, each represented by a list of associated names.\n\n"]

    # Source line: labels, then optional structural context, then one newline.
    pieces.append(f"Source concept: {src_labels}.")
    if src_parent_labels:
        pieces.append(f" Its parent concepts: {src_parent_labels}.")
    if src_child_labels:
        pieces.append(f" Its child concepts: {src_child_labels}.")
    pieces.append("\n")

    # Target line: same layout, followed by a blank line before the question.
    pieces.append(f"Target concept: {tgt_labels}.")
    if tgt_parent_labels:
        pieces.append(f" Its parent concepts: {tgt_parent_labels}.")
    if tgt_child_labels:
        pieces.append(f" Its child concepts: {tgt_child_labels}.")
    pieces.append("\n\n")

    pieces.append("Given these representations, can you determine if the source concept and the target concept are identical? Please answer with \"Yes\" if they are identical or \"No\" if they are not.")
    return "".join(pieces)

print(v_text(["A"], ["B"], ["P_A1", "P_A2"], ["P_B1"], ["C_A1"], ["C_B1"]))

In [None]:
# vanilla setting (no structural context)
#
# For every row of `test_cands`, prompt the model once per candidate target and
# record exactly one (candidate_iri, answer, probability) triple per candidate:
#   * answer is "Yes"/"No" according to the first generated token containing
#     that word, and probability is that token's generation probability;
#   * if none of the generated tokens contains "Yes" or "No", the candidate
#     gets the fallback ("No", 1.0).
# `results` collects the per-row triple lists; `corrects` records whether the
# reference target was among the accepted predictions for that row.
results = []
corrects = []
for i, dp in test_cands.iterrows():
    src_iri = dp["SrcEntity"]
    tgt_iri = dp["TgtEntity"]
    # NOTE(review): eval() executes whatever the table cell contains. If the
    # column only ever holds a plain Python literal, ast.literal_eval would be
    # the safer way to parse it.
    tgt_cands = eval(dp["TgtCandidates"])
    src_labels = process_labels(src_annotation_index[src_iri])
    preds = []
    scores = []
    for tgt_cand_iri in tgt_cands:
        tgt_cand_labels = process_labels(tgt_annotation_index[tgt_cand_iri])
        input_ids = tokenizer(v_text(src_labels, tgt_cand_labels), return_tensors="pt").input_ids.to("cuda")
        outputs = model.generate(input_ids, max_new_tokens=3, return_dict_in_generate=True, output_scores=True)
        # With normalize_logits=True these are per-token log-probabilities.
        transition_scores = model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        )
        # Drop the non-generated prefix: one leading token for encoder-decoder
        # models, the echoed prompt otherwise.
        input_length = 1 if model.config.is_encoder_decoder else input_ids.shape[1]
        generated_tokens = outputs.sequences[:, input_length:]
        for tok, score in zip(generated_tokens[0], transition_scores[0]):
            # log-probability -> probability
            score = np.exp(score.cpu().numpy())
            tok = tokenizer.decode(tok)
            if "Yes" in tok:
                # Only a sufficiently confident "Yes" counts as a prediction,
                # but the score is recorded either way for later ranking.
                if score > 0.55:
                    preds.append(tgt_cand_iri)
                scores.append((tgt_cand_iri, "Yes", score))
                break
            if "No" in tok:
                scores.append((tgt_cand_iri, "No", score))
                break
        else:
            # Reached only when the loop finished without `break`, i.e. no
            # generated token contained "Yes" or "No". Record the worst score
            # once per candidate: a fully confident "No". (Previously this
            # fallback sat inside the token loop and could add extra entries
            # before a later token was even inspected.)
            scores.append((tgt_cand_iri, "No", 1.0))
    if not preds:
        preds += ["UnMatched"]
    correct = tgt_iri in preds
    corrects.append(correct)
    results.append(scores)
    print(correct, preds)

In [None]:
def ranking(scores):
    """Order `(tgt_iri, answer, score)` triples from best to worst candidate.

    Every "Yes" triple comes before every other triple. "Yes" triples keep
    their score and are sorted by it, highest first. All remaining triples are
    relabelled "No" with their score replaced by `1.0 - score`, then sorted by
    that value, highest first. Ties keep their input order.
    """
    accepted, rejected = [], []
    for candidate_iri, verdict, confidence in scores:
        if verdict == "Yes":
            accepted.append((candidate_iri, "Yes", confidence))
            continue
        # A confident "No" should rank low, hence the complement.
        rejected.append((candidate_iri, "No", 1.0 - confidence))

    def by_score(entry):
        return entry[2]

    accepted.sort(key=by_score, reverse=True)
    rejected.sort(key=by_score, reverse=True)
    return accepted + rejected

In [171]:
# uncomment when saving results
# The commented-out code below keys each row's ranked score list by its
# (SrcEntity, TgtEntity) pair and saves the resulting dict to disk.
# results_dict = dict()
# for i, scores in enumerate(results):
#     results_dict[test_cands.iloc[i]["SrcEntity"], test_cands.iloc[i]["TgtEntity"]] = ranking(scores)

# FileUtils.save_file(results_dict, "flan_t5_ncit2doid_results.pkl")
# Load the previously saved vanilla-setting results instead of re-running the model.
# NOTE(review): a .pkl file is presumably unpickled here -- only load files you trust.
ranked_results = FileUtils.load_file("flan_t5_ncit2doid_results.pkl")
# the first 50 mappings are the matched mappings
# (so `refs` keeps only those 50 reference mappings)
refs = ReferenceMapping.read_table_mappings("data/ncit2doid/refs/test_refs.tsv")[:50]

In [160]:
import math


def mean_reciprocal_rank(prediction_and_candidates):
    r"""Compute $MRR$ for a list of `(prediction_mapping, candidate_mappings)` pairs.

    $$MRR = \sum_i^N rank_i^{-1} / N$$

    For each pair the candidates are ordered with
    `EntityMapping.sort_entity_mappings_by_score` and the 1-based position of
    the prediction (compared via `to_tuple()`) in that ordering is its rank.
    A prediction that does not appear among its candidates has infinite rank
    and therefore contributes 0 to the sum.

    Args:
        prediction_and_candidates: sized collection of
            `(prediction_mapping, candidate_mappings)` pairs.

    Returns:
        The mean of the reciprocal ranks, or `0.0` when no pairs are given
        (instead of failing with a division by zero).
    """
    total = len(prediction_and_candidates)
    if total == 0:
        return 0.0

    sum_inverted_ranks = 0
    for pred, cands in prediction_and_candidates:
        ordered_candidates = [c.to_tuple() for c in EntityMapping.sort_entity_mappings_by_score(cands)]
        try:
            # Single scan: index() both tests membership and gives the position.
            rank = ordered_candidates.index(pred.to_tuple()) + 1
        except ValueError:
            rank = math.inf
        sum_inverted_ranks += 1 / rank
    return sum_inverted_ranks / total


In [174]:
# vanilla top1

# Precision, Recall, F1
# A source entity contributes a predicted mapping only when its top-ranked
# candidate was answered "Yes".
preds = [
    EntityMapping(src, ranked[0][0], "=", ranked[0][2])
    for (src, _), ranked in ranked_results.items()
    if ranked[0][1] == "Yes"
]
print(AlignmentEvaluator.f1(preds, refs, []))

# Accuracy
# yes_correct: top candidate answered "Yes" and equals the reference target.
# no_correct: top candidate answered "No" and the reference is "UnMatched".
yes_correct = sum(
    1
    for (_, gold), ranked in ranked_results.items()
    if ranked[0][1] == "Yes" and gold == ranked[0][0]
)
no_correct = sum(
    1
    for (_, gold), ranked in ranked_results.items()
    if ranked[0][1] == "No" and gold == "UnMatched"
)
print(yes_correct, no_correct, yes_correct + no_correct)

# MRR
# Pair each reference mapping with all of its scored candidate mappings.
formatted_results = [
    (
        EntityMapping(src, gold, "=", 1.0),
        [EntityMapping(src, cand_iri, "=", cand_score) for cand_iri, _, cand_score in ranked],
    )
    for (src, gold), ranked in ranked_results.items()
]
# again, only the first 50 has a match
print(mean_reciprocal_rank(formatted_results[:50]))

{'P': 0.736, 'R': 0.78, 'F1': 0.757}
39 39 78
0.9533333333333335


In [173]:
# compute Precision, Recall
# vanilla top1 + threshold
# A predicted mapping is emitted only when the top-ranked candidate was
# answered "Yes" with a score above 0.65.
preds = [
    EntityMapping(src, ranked[0][0], "=", ranked[0][2])
    for (src, _), ranked in ranked_results.items()
    if ranked[0][1] == "Yes" and ranked[0][2] > 0.65
]
print(AlignmentEvaluator.f1(preds, refs, []))

# Accuracy
# NOTE(review): a top "Yes" at or below the threshold yields no prediction
# above, yet it is never counted in no_correct even when the reference is
# "UnMatched" -- confirm this is the intended accuracy definition.
yes_correct = sum(
    1
    for (_, gold), ranked in ranked_results.items()
    if ranked[0][1] == "Yes" and gold == ranked[0][0] and ranked[0][2] > 0.65
)
no_correct = sum(
    1
    for (_, gold), ranked in ranked_results.items()
    if ranked[0][1] == "No" and gold == "UnMatched"
)
print(yes_correct, no_correct, yes_correct + no_correct)

# MRR
# Pair each reference mapping with all of its scored candidate mappings
# (the threshold plays no role here).
formatted_results = [
    (
        EntityMapping(src, gold, "=", 1.0),
        [EntityMapping(src, cand_iri, "=", cand_score) for cand_iri, _, cand_score in ranked],
    )
    for (src, gold), ranked in ranked_results.items()
]
# again, only the first 50 has a match
print(mean_reciprocal_rank(formatted_results[:50]))

{'P': 0.844, 'R': 0.76, 'F1': 0.8}
38 39 77
0.9533333333333335


Now run the model again, this time adding structural context (the labels of each concept's parent and child concepts) to the prompt.

In [None]:
def _sample_neighbour_labels(annotation_index, neighbour_iris, cutoff):
    """Gather the distinct labels of the given classes and keep at most `cutoff` of them."""
    distinct_labels = set()
    for neighbour_iri in neighbour_iris:
        distinct_labels.update(process_labels(annotation_index[neighbour_iri]))
    # random.sample() requires a sequence: passing a set was deprecated in
    # Python 3.9 and raises TypeError from 3.11 on. Sorting also makes the
    # outcome reproducible for a given random seed, independent of set order.
    # (Assumes the labels are mutually comparable, e.g. all strings.)
    ordered_labels = sorted(distinct_labels)
    if len(ordered_labels) > cutoff:
        return random.sample(ordered_labels, k=cutoff)
    return ordered_labels


def get_parent_labels(ontology: Ontology, annotation_index, class_iri, cutoff = 3):
    """Return up to `cutoff` distinct labels of the direct inferred parents of a class.

    Args:
        ontology: ontology whose reasoner is queried for direct super entities.
        annotation_index: mapping from class IRI to that class's labels.
        class_iri: IRI of the class whose parents are wanted.
        cutoff: maximum number of labels returned; when more are available a
            random subset of this size is drawn.

    Returns:
        A list of labels (possibly empty).
    """
    concept = ontology.get_owl_object_from_iri(class_iri)
    concept_parent_iris = ontology.reasoner.get_inferred_super_entities(concept, direct=True)
    return _sample_neighbour_labels(annotation_index, concept_parent_iris, cutoff)


def get_child_labels(ontology, annotation_index, class_iri, cutoff = 3):
    """Return up to `cutoff` distinct labels of the direct inferred children of a class.

    Args:
        ontology: ontology whose reasoner is queried for direct sub entities.
        annotation_index: mapping from class IRI to that class's labels.
        class_iri: IRI of the class whose children are wanted.
        cutoff: maximum number of labels returned; when more are available a
            random subset of this size is drawn.

    Returns:
        A list of labels (possibly empty).
    """
    concept = ontology.get_owl_object_from_iri(class_iri)
    concept_children_iris = ontology.reasoner.get_inferred_sub_entities(concept, direct=True)
    return _sample_neighbour_labels(annotation_index, concept_children_iris, cutoff)

In [None]:
# structural-context setting (parent and child labels are added to the prompt)
#
# Same procedure as the run without context, except that the prompt for each
# (source, candidate) pair also lists labels of their direct parents and
# children. Exactly one (candidate_iri, answer, probability) triple is recorded
# per candidate:
#   * answer is "Yes"/"No" according to the first generated token containing
#     that word, and probability is that token's generation probability;
#   * if none of the generated tokens contains "Yes" or "No", the candidate
#     gets the fallback ("No", 1.0).
results = []
corrects = []
for i, dp in test_cands.iterrows():
    src_iri = dp["SrcEntity"]
    # The source context does not depend on the candidate, so fetch it once per row.
    src_parents = get_parent_labels(src_onto, src_annotation_index, src_iri)
    src_children = get_child_labels(src_onto, src_annotation_index, src_iri)
    tgt_iri = dp["TgtEntity"]
    # NOTE(review): eval() executes whatever the table cell contains. If the
    # column only ever holds a plain Python literal, ast.literal_eval would be
    # the safer way to parse it.
    tgt_cands = eval(dp["TgtCandidates"])
    src_labels = process_labels(src_annotation_index[src_iri])
    preds = []
    scores = []
    for tgt_cand_iri in tgt_cands:
        tgt_cand_labels = process_labels(tgt_annotation_index[tgt_cand_iri])
        tgt_cand_parents = get_parent_labels(tgt_onto, tgt_annotation_index, tgt_cand_iri)
        tgt_cand_children = get_child_labels(tgt_onto, tgt_annotation_index, tgt_cand_iri)
        # print(v_text(src_labels, tgt_cand_labels, src_parents, tgt_cand_parents, src_children, tgt_cand_children))
        # continue
        input_ids = tokenizer(v_text(src_labels, tgt_cand_labels, src_parents, tgt_cand_parents, src_children, tgt_cand_children), return_tensors="pt").input_ids.to("cuda")
        outputs = model.generate(input_ids, max_new_tokens=3, return_dict_in_generate=True, output_scores=True)
        # With normalize_logits=True these are per-token log-probabilities.
        transition_scores = model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        )
        # Drop the non-generated prefix: one leading token for encoder-decoder
        # models, the echoed prompt otherwise.
        input_length = 1 if model.config.is_encoder_decoder else input_ids.shape[1]
        generated_tokens = outputs.sequences[:, input_length:]
        for tok, score in zip(generated_tokens[0], transition_scores[0]):
            # log-probability -> probability
            score = np.exp(score.cpu().numpy())
            tok = tokenizer.decode(tok)
            if "Yes" in tok:
                # Only a sufficiently confident "Yes" counts as a prediction,
                # but the score is recorded either way for later ranking.
                if score > 0.55:
                    preds.append(tgt_cand_iri)
                scores.append((tgt_cand_iri, "Yes", score))
                break
            if "No" in tok:
                scores.append((tgt_cand_iri, "No", score))
                break
        else:
            # Reached only when the loop finished without `break`, i.e. no
            # generated token contained "Yes" or "No". Record the worst score
            # once per candidate: a fully confident "No". (Previously this
            # fallback sat inside the token loop and could add extra entries
            # before a later token was even inspected.)
            scores.append((tgt_cand_iri, "No", 1.0))
    if not preds:
        preds += ["UnMatched"]
    correct = tgt_iri in preds
    corrects.append(correct)
    results.append(scores)
    print(correct, preds)

In [163]:
# Uncomment to save the structural-context results: the code keys each row's
# ranked score list by its (SrcEntity, TgtEntity) pair and writes the dict to disk.
# results_dict = dict()
# for i, scores in enumerate(results):
#     results_dict[test_cands.iloc[i]["SrcEntity"], test_cands.iloc[i]["TgtEntity"]] = ranking(scores)

# FileUtils.save_file(results_dict, "flan_t5_ncit2doid_results_context.pkl")
# Load the previously saved structural-context results; this rebinds
# `ranked_results`, so the evaluation cells that follow score this run.
# NOTE(review): a .pkl file is presumably unpickled here -- only load files you trust.
ranked_results = FileUtils.load_file("flan_t5_ncit2doid_results_context.pkl")

In [170]:
# compute Precision, Recall
# A source entity contributes a predicted mapping only when its top-ranked
# candidate was answered "Yes".
preds = [
    EntityMapping(src, ranked[0][0], "=", ranked[0][2])
    for (src, _), ranked in ranked_results.items()
    if ranked[0][1] == "Yes"
]
print(AlignmentEvaluator.f1(preds, refs, []))

# Accuracy
# yes_correct: top candidate answered "Yes" and equals the reference target.
# no_correct: top candidate answered "No" and the reference is "UnMatched".
yes_correct = sum(
    1
    for (_, gold), ranked in ranked_results.items()
    if ranked[0][1] == "Yes" and gold == ranked[0][0]
)
no_correct = sum(
    1
    for (_, gold), ranked in ranked_results.items()
    if ranked[0][1] == "No" and gold == "UnMatched"
)
print(yes_correct, no_correct, yes_correct + no_correct)

# MRR
# Pair each reference mapping with all of its scored candidate mappings.
formatted_results = [
    (
        EntityMapping(src, gold, "=", 1.0),
        [EntityMapping(src, cand_iri, "=", cand_score) for cand_iri, _, cand_score in ranked],
    )
    for (src, gold), ranked in ranked_results.items()
]
# again, only the first 50 has a match
print(mean_reciprocal_rank(formatted_results[:50]))

{'P': 0.789, 'R': 0.6, 'F1': 0.682}
30 45 75
0.9416666666666668


In [169]:
# compute Precision, Recall
# A predicted mapping is emitted only when the top-ranked candidate was
# answered "Yes" with a score above 0.65.
preds = [
    EntityMapping(src, ranked[0][0], "=", ranked[0][2])
    for (src, _), ranked in ranked_results.items()
    if ranked[0][1] == "Yes" and ranked[0][2] > 0.65
]
print(AlignmentEvaluator.f1(preds, refs, []))

# Accuracy
# NOTE(review): a top "Yes" at or below the threshold yields no prediction
# above, yet it is never counted in no_correct even when the reference is
# "UnMatched" -- confirm this is the intended accuracy definition.
yes_correct = sum(
    1
    for (_, gold), ranked in ranked_results.items()
    if ranked[0][1] == "Yes" and gold == ranked[0][0] and ranked[0][2] > 0.65
)
no_correct = sum(
    1
    for (_, gold), ranked in ranked_results.items()
    if ranked[0][1] == "No" and gold == "UnMatched"
)
print(yes_correct, no_correct, yes_correct + no_correct)

# MRR
# Pair each reference mapping with all of its scored candidate mappings
# (the threshold plays no role here).
formatted_results = [
    (
        EntityMapping(src, gold, "=", 1.0),
        [EntityMapping(src, cand_iri, "=", cand_score) for cand_iri, _, cand_score in ranked],
    )
    for (src, gold), ranked in ranked_results.items()
]
# again, only the first 50 has a match
print(mean_reciprocal_rank(formatted_results[:50]))

{'P': 0.909, 'R': 0.4, 'F1': 0.556}
20 45 65
0.9416666666666668
