In [1]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
from deeponto.onto import Ontology
from deeponto.align.bertmap import BERTMapPipeline
from deeponto.align.evaluation import AlignmentEvaluator
from deeponto.utils import FileUtils
from deeponto.align.mapping import EntityMapping, ReferenceMapping
import pandas as pd
import random
import numpy as np

Please enter the maximum memory located to JVM [8g]:
8g maximum memory allocated to JVM.
JVM started successfully.


In [2]:
# load source and target ontologies
src_onto_path = "data/snomed2fma/snomed.body.owl"
tgt_onto_path = "data/snomed2fma/fma.body.owl"
src_onto = Ontology(src_onto_path)
tgt_onto = Ontology(tgt_onto_path)
config = BERTMapPipeline.load_bertmap_config()

[main] WARN uk.ac.manchester.cs.owl.owlapi.OWLOntologyManagerImpl - Illegal redeclarations of entities: reuse of entity http://purl.org/sig/ont/fma/has_direct_shape_type in punning not allowed [Declaration(DataProperty(<http://purl.org/sig/ont/fma/has_direct_shape_type>)), Declaration(ObjectProperty(<http://purl.org/sig/ont/fma/has_direct_shape_type>))]


Use the default configuration at /home/yuan/anaconda3/envs/deeponto/lib/python3.8/site-packages/deeponto/align/bertmap/default_config.yaml.


In [3]:
# build annotation index {class_iri: class_labels}
src_annotation_index, _ = src_onto.build_annotation_index(config.annotation_property_iris)
tgt_annotation_index, _ = tgt_onto.build_annotation_index(config.annotation_property_iris)

In [4]:
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xxl", device_map="auto")

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

In [5]:
test_cands = FileUtils.read_table("data/snomed2fma/test_cands.tsv")

In [6]:
# controlling input labels because of the window size limit of T5
def process_labels(labels, cut_off = 3):
    labels = list(labels)
    if len(labels) >= cut_off:
        labels.sort(key=len, reverse=True)
    return labels[:cut_off]

In [7]:
def v_text(src_labels, tgt_labels, src_parent_labels=None, tgt_parent_labels=None, src_child_labels=None, tgt_child_labels=None):
    v = f"Source Concept:\n"
    v += f"Names: {src_labels}\n"
    has_parent_child = False
    if src_parent_labels:
        v += f"Parent Concepts: {src_parent_labels}\n"
        has_parent_child = True
    if src_child_labels:
        v += f"Child Concepts: {src_child_labels}\n"
        has_parent_child = True
    v +="\n"
    v += f"Target Concept:\n"
    v += f"Names: {tgt_labels}\n"
    if tgt_parent_labels:
        v += f"Parent Concepts: {tgt_parent_labels}\n"
        has_parent_child = True
    if tgt_child_labels:
        v += f"Child Concepts: {tgt_child_labels}\n"
        has_parent_child = True
    v +="\n"
    if not has_parent_child:
        v = "Consider two concepts, each represented by a list of names.\n\n" + v
        v += "Based on the provided names, determine if the Source Concept and the Target Concept are identical. Please respond with either 'Yes' (if they are identical) or 'No' (if they are not identical)."
    else:
        v = "Consider two concepts, each represented by a list of names, and associated with additional contexts.\n\n" + v
        v += "Based on the provided names, parent and child concepts (if any), determine if the Source Concept and the Target Concept are identical. Please respond with either 'Yes' (if they are identical) or 'No' (if they are not identical)."
    return v

print(v_text(["A"], ["B"], ["P_A1", "P_A2"], ["P_B1"], ["C_A1"], ["C_B1"]))

Consider two concepts, each represented by a list of names, and associated with additional contexts.

Source Concept:
Names: ['A']
Parent Concepts: ['P_A1', 'P_A2']
Child Concepts: ['C_A1']

Target Concept:
Names: ['B']
Parent Concepts: ['P_B1']
Child Concepts: ['C_B1']

Based on the provided names, parent and child concepts (if any), determine if the Source Concept and the Target Concept are identical. Please respond with either 'Yes' (if they are identical) or 'No' (if they are not identical).


In [8]:
# vanila setting (no structural context)
results = []
corrects = []
for i, dp in test_cands.iterrows():
    src_iri = dp["SrcEntity"]
    tgt_iri = dp["TgtEntity"]
    tgt_cands = eval(dp["TgtCandidates"])
    src_labels = process_labels(src_annotation_index[src_iri])
    preds = []
    scores = []
    for tgt_cand_iri in tgt_cands:
        tgt_cand_labels = process_labels(tgt_annotation_index[tgt_cand_iri])
        input_ids = tokenizer(v_text(src_labels, tgt_cand_labels), return_tensors="pt").input_ids.to("cuda")
        outputs = model.generate(input_ids, max_new_tokens=3, return_dict_in_generate=True, output_scores=True)
        transition_scores = model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        )
        input_length = 1 if model.config.is_encoder_decoder else input_ids.shape[1]
        generated_tokens = outputs.sequences[:, input_length:]
        has_answer = False
        for tok, score in zip(generated_tokens[0], transition_scores[0]):
            # | token | token string | logits | probability
            score = np.exp(score.cpu().numpy())
            tok = tokenizer.decode(tok)
            # probs.append((tok, score))
            if "Yes" in tok:
                # print(f"| {tok:8s} | {score:.2%}")
                if score > 0.55:
                    preds.append(tgt_cand_iri)
                scores.append((tgt_cand_iri, "Yes", score))
                has_answer = True
                break
            if "No" in tok:
                # print(f"| {tok:8s} | {score:.2%}")
                scores.append((tgt_cand_iri, "No", score))
                has_answer = True
                break
            if not has_answer:
                # if no yes or no, giving the worst score
                scores.append((tgt_cand_iri, "No", 1.0))
    if not preds:
        preds += ["UnMatched"]
    correct = tgt_iri in preds
    corrects.append(correct)
    results.append(scores)
    print(correct, preds)

True ['http://purl.org/sig/ont/fma/fma16072']
True ['http://purl.org/sig/ont/fma/fma83945']
False ['UnMatched']
True ['http://purl.org/sig/ont/fma/fma71410']
False ['UnMatched']
False ['UnMatched']
False ['UnMatched']
True ['http://purl.org/sig/ont/fma/fma47200']
True ['http://purl.org/sig/ont/fma/fma6652']
True ['http://purl.org/sig/ont/fma/fma323351']
False ['UnMatched']
True ['http://purl.org/sig/ont/fma/fma71766', 'http://purl.org/sig/ont/fma/fma5950']
True ['http://purl.org/sig/ont/fma/fma8661']
False ['UnMatched']
False ['UnMatched']
True ['http://purl.org/sig/ont/fma/fma61020']
False ['UnMatched']
True ['http://purl.org/sig/ont/fma/fma75189', 'http://purl.org/sig/ont/fma/fma19618', 'http://purl.org/sig/ont/fma/fma20281']
False ['http://purl.org/sig/ont/fma/fma75189', 'http://purl.org/sig/ont/fma/fma19618', 'http://purl.org/sig/ont/fma/fma20281']
True ['http://purl.org/sig/ont/fma/fma13715', 'http://purl.org/sig/ont/fma/fma13707']
True ['http://purl.org/sig/ont/fma/fma24139']
Tru

In [9]:
def ranking(scores):
    # rank the "yes" candidates according to their scores
    # the "yes" candidates is always before the "no" candidates
    # then rank the "no" candidates according to 1.0 - their scores
    yes = []
    no = []
    for tgt_iri, answer, score in scores:
        if answer == "Yes":
            yes.append((tgt_iri, "Yes", score))
        else:
            no.append((tgt_iri, "No", 1.0 - score))
    yes = list(sorted(yes, key=lambda x: x[2], reverse=True))
    no = list(sorted(no, key=lambda x: x[2], reverse=True))
    return yes + no

In [10]:
# uncomment when saving results
results_dict = dict()
for i, scores in enumerate(results):
    results_dict[test_cands.iloc[i]["SrcEntity"], test_cands.iloc[i]["TgtEntity"]] = ranking(scores)

FileUtils.save_file(results_dict, "flan_t5_snomed2fma_results.pkl")
ranked_results = FileUtils.load_file("flan_t5_snomed2fma_results.pkl")
# the first 50 mappings are the matched mappings
refs = ReferenceMapping.read_table_mappings("data/snomed2fma/refs/test_refs.tsv")[:50]

In [11]:
import math
def mean_reciprocal_rank(prediction_and_candidates):
    r"""Compute $MRR$ for a list of `(prediction_mapping, candidate_mappings)` pair.

    $$MRR = \sum_i^N rank_i^{-1} / N$$
    """
    sum_inverted_ranks = 0
    for pred, cands in prediction_and_candidates:
        ordered_candidates = [c.to_tuple() for c in EntityMapping.sort_entity_mappings_by_score(cands)]
        if pred.to_tuple() in ordered_candidates:
            rank = ordered_candidates.index(pred.to_tuple()) + 1
        else:
            rank = math.inf
        sum_inverted_ranks += 1 / rank
    return sum_inverted_ranks / len(prediction_and_candidates)


In [12]:
# vanilla top1

# Precision, Recall, F1
preds = []
for (src, tgt), tgt_cands in ranked_results.items():
    t, answer, score = tgt_cands[0]
    mapping = EntityMapping(src, t, "=", score)
    if answer == "Yes":
        preds.append(mapping)
print(AlignmentEvaluator.f1(preds, refs, []))

# Accuracy
yes_correct = 0
no_correct = 0
for (src, tgt), tgt_cands in ranked_results.items():
    t, answer, score = tgt_cands[0]
    # print(tgt_cands[0])
    if answer == "Yes" and tgt == t:
        yes_correct += 1
    elif answer == "No" and tgt == "UnMatched":
        no_correct += 1
print(yes_correct, no_correct, yes_correct + no_correct)

# MRR
formatted_results = []
for (src, tgt), tgt_cands in ranked_results.items():
    ref_mapping = EntityMapping(src, tgt, "=", 1.0)
    cand_mappings = [EntityMapping(src, t, "=", score) for t, _, score in tgt_cands]
    formatted_results.append((ref_mapping, cand_mappings))
# again, only the first 50 has a match
print(mean_reciprocal_rank(formatted_results[:50]))

{'P': 0.479, 'R': 0.46, 'F1': 0.469}
26 35 61
0.8011904761904763


In [15]:
# compute Precision, Recall
# vanilla top1 + threshold
preds = []
for (src, tgt), tgt_cands in ranked_results.items():
    t, answer, score = tgt_cands[0]
    mapping = EntityMapping(src, t, "=", score)
    if answer == "Yes":
        if score > 0.55:
            preds.append(mapping)
print(AlignmentEvaluator.f1(preds, refs, []))

# Accuracy
yes_correct = 0
no_correct = 0
for (src, tgt), tgt_cands in ranked_results.items():
    t, answer, score = tgt_cands[0]
    # print(tgt_cands[0])
    if answer == "Yes" and tgt == t and score > 0.55:
        yes_correct += 1
    elif answer == "No" and tgt == "UnMatched":
        no_correct += 1
print(yes_correct, no_correct, yes_correct + no_correct)

# MRR
formatted_results = []
for (src, tgt), tgt_cands in ranked_results.items():
    ref_mapping = EntityMapping(src, tgt, "=", 1.0)
    cand_mappings = [EntityMapping(src, t, "=", score) for t, _, score in tgt_cands]
    formatted_results.append((ref_mapping, cand_mappings))
# again, only the first 50 has a match
print(mean_reciprocal_rank(formatted_results[:50]))

{'P': 0.5, 'R': 0.4, 'F1': 0.444}
23 35 58
0.8011904761904763


Now, run the model again with structural context

In [17]:
def get_parent_labels(ontology: Ontology, annotation_index, class_iri, cutoff = 3):
    concept = ontology.get_owl_object_from_iri(class_iri)
    concept_parents = ontology.get_asserted_parents(concept)
    # concept_parent_iris = [str(p.getIRI()) for p in concept_parents]
    concept_parent_labels = []
    for p in concept_parents:
        try:
            p_iri = str(p.getIRI())
            concept_parent_labels += process_labels(annotation_index[p_iri])
        except:
            continue
    concept_parent_labels = set(concept_parent_labels)
    if len(concept_parent_labels) > cutoff:
        return list(random.sample(concept_parent_labels, k=cutoff))
    else:
        return list(concept_parent_labels)
    
def get_child_labels(ontology, annotation_index, class_iri, cutoff = 3):
    concept = ontology.get_owl_object_from_iri(class_iri)
    concept_children = ontology.get_asserted_children(concept)
    # concept_children_iris = [str(p.getIRI()) for p in concept_children]
    concept_children_labels = []
    for c in concept_children:
        try: 
            c_iri = str(c.getIRI())
            concept_children_labels += process_labels(annotation_index[c_iri])
        except:
            continue
    concept_children_labels = set(concept_children_labels)
    if len(concept_children_labels) > cutoff:
        return list(random.sample(concept_children_labels, k=cutoff))
    else:
        return list(concept_children_labels)

In [18]:
# with structural context
results = []
corrects = []
for i, dp in test_cands.iterrows():
    src_iri = dp["SrcEntity"]
    src_parents = get_parent_labels(src_onto, src_annotation_index, src_iri)
    src_children = get_child_labels(src_onto, src_annotation_index, src_iri)
    tgt_iri = dp["TgtEntity"]
    tgt_cands = eval(dp["TgtCandidates"])
    src_labels = process_labels(src_annotation_index[src_iri])
    preds = []
    scores = []
    for tgt_cand_iri in tgt_cands:
        tgt_cand_labels = process_labels(tgt_annotation_index[tgt_cand_iri])
        tgt_cand_parents = get_parent_labels(tgt_onto, tgt_annotation_index, tgt_cand_iri)
        tgt_cand_children = get_child_labels(tgt_onto, tgt_annotation_index, tgt_cand_iri)
        # print(v_text(src_labels, tgt_cand_labels, src_parents, tgt_cand_parents, src_children, tgt_cand_children))
        # continue
        input_ids = tokenizer(v_text(src_labels, tgt_cand_labels, src_parents, tgt_cand_parents, src_children, tgt_cand_children), return_tensors="pt").input_ids.to("cuda")
        outputs = model.generate(input_ids, max_new_tokens=3, return_dict_in_generate=True, output_scores=True)
        transition_scores = model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        )
        input_length = 1 if model.config.is_encoder_decoder else input_ids.shape[1]
        generated_tokens = outputs.sequences[:, input_length:]
        has_answer = False
        for tok, score in zip(generated_tokens[0], transition_scores[0]):
            # | token | token string | logits | probability
            score = np.exp(score.cpu().numpy())
            tok = tokenizer.decode(tok)
            # probs.append((tok, score))
            if "Yes" in tok:
                # print(f"| {tok:8s} | {score:.2%}")
                if score > 0.55:
                    preds.append(tgt_cand_iri)
                scores.append((tgt_cand_iri, "Yes", score))
                has_answer = True
                break
            if "No" in tok:
                # print(f"| {tok:8s} | {score:.2%}")
                scores.append((tgt_cand_iri, "No", score))
                has_answer = True
                break
            if not has_answer:
                # if no yes or no, giving the worst score
                scores.append((tgt_cand_iri, "No", 1.0))
    if not preds:
        preds += ["UnMatched"]
    correct = tgt_iri in preds
    corrects.append(correct)
    results.append(scores)
    print(correct, preds)

True ['http://purl.org/sig/ont/fma/fma16072']
False ['UnMatched']
False ['UnMatched']
False ['UnMatched']
False ['UnMatched']
False ['UnMatched']
False ['UnMatched']
True ['http://purl.org/sig/ont/fma/fma47200']
False ['UnMatched']
True ['http://purl.org/sig/ont/fma/fma323351']
False ['UnMatched']
False ['UnMatched']
True ['http://purl.org/sig/ont/fma/fma8661']
False ['UnMatched']
True ['http://purl.org/sig/ont/fma/fma84245']
True ['http://purl.org/sig/ont/fma/fma61020']
False ['UnMatched']
True ['http://purl.org/sig/ont/fma/fma75189', 'http://purl.org/sig/ont/fma/fma19618']
False ['http://purl.org/sig/ont/fma/fma75189', 'http://purl.org/sig/ont/fma/fma19618']
True ['http://purl.org/sig/ont/fma/fma13715']
False ['UnMatched']
False ['UnMatched']
False ['UnMatched']
False ['UnMatched']
False ['UnMatched']
False ['UnMatched']
False ['UnMatched']
False ['UnMatched']
False ['UnMatched']
False ['UnMatched']
False ['http://purl.org/sig/ont/fma/fma15767']
True ['http://purl.org/sig/ont/fma/fma

In [19]:
results_dict = dict()
for i, scores in enumerate(results):
    results_dict[test_cands.iloc[i]["SrcEntity"], test_cands.iloc[i]["TgtEntity"]] = ranking(scores)

FileUtils.save_file(results_dict, "flan_t5_snomed2fma_results_context.pkl")
ranked_results = FileUtils.load_file("flan_t5_snomed2fma_results_context.pkl")

In [20]:
# compute Precision, Recall
preds = []
for (src, tgt), tgt_cands in ranked_results.items():
    t, answer, score = tgt_cands[0]
    mapping = EntityMapping(src, t, "=", score)
    if answer == "Yes":
        preds.append(mapping)
print(AlignmentEvaluator.f1(preds, refs, []))

# Accuracy
yes_correct = 0
no_correct = 0
for (src, tgt), tgt_cands in ranked_results.items():
    t, answer, score = tgt_cands[0]
    # print(tgt_cands[0])
    if answer == "Yes" and tgt == t:
        yes_correct += 1
    elif answer == "No" and tgt == "UnMatched":
        no_correct += 1
print(yes_correct, no_correct, yes_correct + no_correct)

# MRR
formatted_results = []
for (src, tgt), tgt_cands in ranked_results.items():
    ref_mapping = EntityMapping(src, tgt, "=", 1.0)
    cand_mappings = [EntityMapping(src, t, "=", score) for t, _, score in tgt_cands]
    formatted_results.append((ref_mapping, cand_mappings))
# again, only the first 50 has a match
print(mean_reciprocal_rank(formatted_results[:50]))

{'P': 0.5, 'R': 0.26, 'F1': 0.342}
17 46 63
0.7254080267558528


In [21]:
# compute Precision, Recall
preds = []
for (src, tgt), tgt_cands in ranked_results.items():
    t, answer, score = tgt_cands[0]
    mapping = EntityMapping(src, t, "=", score)
    if answer == "Yes":
        if score > 0.65:
            preds.append(mapping)
print(AlignmentEvaluator.f1(preds, refs, []))

# Accuracy
yes_correct = 0
no_correct = 0
for (src, tgt), tgt_cands in ranked_results.items():
    t, answer, score = tgt_cands[0]
    # print(tgt_cands[0])
    if answer == "Yes" and tgt == t and score > 0.65:
        yes_correct += 1
    elif answer == "No" and tgt == "UnMatched":
        no_correct += 1
print(yes_correct, no_correct, yes_correct + no_correct)

# MRR
formatted_results = []
for (src, tgt), tgt_cands in ranked_results.items():
    ref_mapping = EntityMapping(src, tgt, "=", 1.0)
    cand_mappings = [EntityMapping(src, t, "=", score) for t, _, score in tgt_cands]
    formatted_results.append((ref_mapping, cand_mappings))
# again, only the first 50 has a match
print(mean_reciprocal_rank(formatted_results[:50]))

{'P': 0.667, 'R': 0.2, 'F1': 0.308}
11 46 57
0.7254080267558528
