In [1]:
# Construction of dataset

import os, itertools, time, pickle, sys, glob, requests
import subprocess
from xml.dom import minidom
from collections import Counter, OrderedDict
from operator import itemgetter
from nltk.corpus import wordnet
import tensorflow as tf
import tensorflow_hub as hub
from scipy import spatial
from sklearn.metrics import precision_score, accuracy_score, recall_score, f1_score
from sklearn.feature_extraction.text import TfidfVectorizer
import re
import numpy as np
import scipy.sparse as sp
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from math import ceil, exp
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import networkx as nx
import matplotlib.pyplot as plt
from sentence_transformers import models, SentenceTransformer
%matplotlib inline  

In [2]:
flatten = lambda l: [item for sublist in l for item in sublist]

class Ontology():
    def __init__(self, ontology):
        self.ontology = ontology
        self.ontology_obj = minidom.parse(ontology)
        self.root = self.ontology_obj.documentElement
        self.construct_mapping_dict()
        
        self.parents_dict = {}
        self.subclasses = self.parse_subclasses()
        self.object_properties = self.parse_object_properties()
        self.data_properties = self.parse_data_properties()
        self.triples = self.parse_triples()
        self.classes = self.parse_classes()        
    
    def construct_mapping_dict(self):
        self.mapping_dict = {self.extract_ID(el, False): self.get_child_node(el, "rdfs:label")[0].firstChild.nodeValue for el in self.root.getElementsByTagName("owl:Class") if self.get_child_node(el, "rdfs:label")}
        self.mapping_dict_inv = {self.mapping_dict[key]: key for key in self.mapping_dict}
        return
        
    def get_child_node(self, element, tag):
        return [e for e in element._get_childNodes() if type(e)==minidom.Element and e._get_tagName() == tag]
        
    def has_attribute_value(self, element, attribute, value):
        return True if element.getAttribute(attribute).split("#")[-1] == value else False
    
    def get_subclass_triples(self):
        subclasses = self.get_subclasses()
        for (a,b,c) in subclasses:
            if c == "subclass_of" and a!="Thing" and b!="Thing":
                if b not in self.parents_dict:
                    self.parents_dict[b] = [a]
                else:
                    self.parents_dict[b].append(a)
        return [(b,a,c) for (a,b,c) in subclasses]
    
    def parse_triples(self, union_flag=0, subclass_of=True, data_prop=True):
        obj_props = self.object_properties
        if data_prop:
            data_props = self.data_properties
            props = obj_props + data_props
        else:
            props = obj_props
        all_triples = []
        for prop in props:
            domain_children = self.get_child_node(prop, "rdfs:domain")
            range_children = self.get_child_node(prop, "rdfs:range")
            domain_prop = self.filter_null([self.extract_ID(el) for el in domain_children])
            range_prop = self.filter_null([self.extract_ID(el) for el in range_children])
            if not domain_children or not range_children:
                continue
            if not domain_prop:
                domain_prop = self.filter_null([self.extract_ID(el) for el in domain_children[0].getElementsByTagName("owl:Class")])
            if not range_prop:
                range_prop = self.filter_null([self.extract_ID(el) for el in range_children[0].getElementsByTagName("owl:Class")])
            if domain_prop and range_prop:
                if union_flag == 0:
                    all_triples.extend([(el[0], el[1], self.extract_ID(prop)) for el in list(itertools.product(domain_prop, range_prop))])
                else:
                    all_triples.append(("###".join(domain_prop), "###".join(range_prop), self.extract_ID(prop)))
        if subclass_of:
            all_triples.extend(self.get_subclass_triples())
        return list(set(all_triples))
    
    def get_triples(self, union_flag=0, subclass_of=True):
        return self.parse_triples(union_flag, subclass_of)

    def parse_subclasses(self, union_flag=0):
        subclasses = self.root.getElementsByTagName("rdfs:subClassOf")
        subclass_pairs = []
        for el in subclasses:
            inline_subclasses = self.extract_ID(el)
            if inline_subclasses:
                subclass_pairs.append((el, el.parentNode, "subclass_of"))
            else:
                level1_class = self.get_child_node(el, "owl:Class")
                if not level1_class:
                    restriction = el.getElementsByTagName("owl:Restriction")
                    if not restriction:
                        continue
                    prop = self.get_child_node(restriction[0], "owl:onProperty")
                    some_vals = self.get_child_node(restriction[0], "owl:someValuesFrom")
                    
                    if not prop or not some_vals:
                        continue
#                     print(self.extract_ID(el), "**", self.extract_ID(some_vals[0]), "**", self.extract_ID(prop[0]))
                    try:
                        if self.extract_ID(prop[0]) and self.extract_ID(some_vals[0]):
                            subclass_pairs.append((el.parentNode, some_vals[0], self.extract_ID(prop[0])))
                        elif self.extract_ID(prop[0]) and not self.extract_ID(some_vals[0]):
                            class_vals = self.get_child_node(some_vals[0], "owl:Class")
                            subclass_pairs.append((el.parentNode, class_vals[0], self.extract_ID(prop[0])))
                        elif not self.extract_ID(prop[0]) and self.extract_ID(some_vals[0]):
                            prop_vals = self.get_child_node(prop[0], "owl:ObjectProperty")
                            subclass_pairs.append((el.parentNode, some_vals[0], self.extract_ID(prop_vals[0])))
                        else:
                            prop_vals = self.get_child_node(prop[0], "owl:ObjectProperty")
                            class_vals = self.get_child_node(some_vals[0], "owl:Class")
                            subclass_pairs.append((el.parentNode, class_vals[0], self.extract_ID(prop_vals[0])))
                    except Exception as e:
                        print ("error", e)
                        continue
                else:
                    if self.extract_ID(level1_class[0]):
                        subclass_pairs.append((level1_class[0], el.parentNode, "subclass_of"))
                    else:
#                         level2classes = level1_class[0].getElementsByTagName("owl:Class")
#                         subclass_pairs.extend([(elem, el.parentNode, "subclass_of") for elem in level2classes if self.extract_ID(elem)])
                        continue
        return subclass_pairs
        
    def get_subclasses(self):
        subclasses = [(self.extract_ID(a), self.extract_ID(b), c) for (a,b,c) in self.subclasses]
        return [el for el in subclasses if el[0] and el[1] and el[2] and el[0]!="Thing" and el[1]!="Thing"]
    
    def filter_null(self, data):
        return [el for el in data if el]
    
    def extract_ID(self, element, check_coded = True):
        element_id = element.getAttribute("rdf:ID") or element.getAttribute("rdf:resource") or element.getAttribute("rdf:about")
        element_id = element_id.split("#")[-1]
        if len(list(filter(str.isdigit, element_id))) >= 3 and "_" in element_id and check_coded:
            return self.mapping_dict[element_id]
        return element_id.replace("UNDEFINED_", "").replace("DO_", "")
    
    def parse_classes(self):
        class_elems = [self.extract_ID(el) for el in self.root.getElementsByTagName("owl:Class")]
        subclass_classes = list(set(flatten([el[:-1] for el in self.triples])))
        return list(set(self.filter_null(class_elems + subclass_classes)))
    
    def get_classes(self):
        return self.classes
    
    def get_entities(self):
        entities = [self.extract_ID(el) for el in self.root.getElementsByTagName("owl:Class")]
        return list(set(self.filter_null(entities)))

    def parse_data_properties(self):
        data_properties = [el for el in self.get_child_node(self.root, 'owl:DatatypeProperty')]
        fn_data_properties = [el for el in self.get_child_node(self.root, 'owl:FunctionalProperty') if el]
        fn_data_properties = [el for el in fn_data_properties if type(el)==minidom.Element and 
            [el for el in self.get_child_node(el, "rdf:type") if 
             self.has_attribute_value(el, "rdf:resource", "DatatypeProperty")]]
        inv_fn_data_properties = [el for el in self.get_child_node(self.root, 'owl:InverseFunctionalProperty') if el]
        inv_fn_data_properties = [el for el in inv_fn_data_properties if type(el)==minidom.Element and 
            [el for el in self.get_child_node(el, "rdf:type") if 
             self.has_attribute_value(el, "rdf:resource", "DatatypeProperty")]]
        return data_properties + fn_data_properties + inv_fn_data_properties
        
    def parse_object_properties(self):
        obj_properties = [el for el in self.get_child_node(self.root, 'owl:ObjectProperty')]
        fn_obj_properties = [el for el in self.get_child_node(self.root, 'owl:FunctionalProperty') if el]
        fn_obj_properties = [el for el in fn_obj_properties if type(el)==minidom.Element and 
            [el for el in self.get_child_node(el, "rdf:type") if 
             self.has_attribute_value(el, "rdf:resource", "ObjectProperty")]]
        inv_fn_obj_properties = [el for el in self.get_child_node(self.root, 'owl:InverseFunctionalProperty') if el]
        inv_fn_obj_properties = [el for el in inv_fn_obj_properties if type(el)==minidom.Element and 
            [el for el in self.get_child_node(el, "rdf:type") if 
             self.has_attribute_value(el, "rdf:resource", "ObjectProperty")]]
        return obj_properties + fn_obj_properties + inv_fn_obj_properties
    
    def get_object_properties(self):
        obj_props = [self.extract_ID(el) for el in self.object_properties]
        return list(set(self.filter_null(obj_props)))
    
    def get_data_properties(self):
        data_props = [self.extract_ID(el) for el in self.data_properties]
        return list(set(self.filter_null(data_props)))


In [3]:
USE_folder = "/home/vlead/USE"
alignment_folder = "../Anatomy/Alignments/"

# Load reference alignments 
def load_alignments(folder):
    alignments = []
    for f in os.listdir(folder):
        doc = minidom.parse(folder + f)
        ls = list(zip(doc.getElementsByTagName('entity1'), doc.getElementsByTagName('entity2')))
        alignments.extend([(a.getAttribute('rdf:resource'), b.getAttribute('rdf:resource')) for (a,b) in ls])
    return alignments

model = SentenceTransformer('bert-large-nli-mean-tokens')

def extract_huggingface_embeddings(words):
    return model.encode(words)

def cos_sim(a,b):
    return 1 - spatial.distance.cosine(a, b)


# reference_alignments = load_alignments(alignment_folder)

ra_anatomy_coded = load_alignments(alignment_folder)
ontologies_in_alignment = [["../Anatomy/Ontologies/mouse.owl", "../Anatomy/Ontologies/human.owl"]]

reference_alignments = []

for (ont1_name, ont2_name) in ontologies_in_alignment:
    ont1 = Ontology(ont1_name)
    ont2 = Ontology(ont2_name)
    for elem in ra_anatomy_coded:
        pre1, pre2 = elem[0].split("#")[0].split(".")[0].split("/")[-1], elem[1].split("#")[0].split(".")[0].split("/")[-1]
        elem1, elem2 = elem[0].split("#")[-1], elem[1].split("#")[-1]
        reference_alignments.append(( pre1 + "#" + "_".join(ont1.mapping_dict[elem1].split()), pre2 + "#" + "_".join(ont2.mapping_dict[elem2].split())))

gt_mappings = [tuple([elem.split("/")[-1] for elem in el]) for el in reference_alignments]
# gt_mappings.extend(ra_anatomy)


# ontologies_in_alignment = pickle.load(open("../data_generic.pkl", "rb"))[-1][:-1]


In [33]:
len(reference_alignments)

1516

In [5]:
# Combinatorial mapping generation
all_mappings = []
for l in ontologies_in_alignment:
    ont1 = Ontology(l[0])
    ont2 = Ontology(l[1])
    
    ent1 = ont1.get_entities()
    ent2 = ont2.get_entities()
    
    obj1 = ont1.get_object_properties()
    obj2 = ont2.get_object_properties()
    
    data1 = ont1.get_data_properties()
    data2 = ont2.get_data_properties()

    mappings = list(itertools.product(ent1, ent2)) + list(itertools.product(obj1, obj2)) + list(itertools.product(data1, data2))

    ont_prefix1 = l[0].split("/")[-1].split(".")[0]
    ont_prefix2 = l[1].split("/")[-1].split(".")[0]
    
    all_mappings.extend([(ont_prefix1 + "#" + el[0], ont_prefix2 + "#" + el[1]) for el in mappings])
    

data = {mapping: False for mapping in all_mappings}
for mapping in set(gt_mappings):
    data[mapping] = True

In [6]:
# Abbrevation resolution preprocessing

abbreviations_dict = {}
final_dict = {}

for mapping in all_mappings:
    mapping = tuple([el.split("#")[1] for el in mapping])
    is_abb = re.search("[A-Z][A-Z]+", mapping[0])
    if is_abb:
        abbreviation = "".join([el[0].upper() for el in mapping[1].split("_")])
        if is_abb.group() in abbreviation:
            
            start = abbreviation.find(is_abb.group())
            end = start + len(is_abb.group())
            fullform = "_".join(mapping[1].split("_")[start:end])
            print ("left", mapping, abbreviation, fullform)
            
            rest_first = " ".join([el for el in mapping[0].replace(is_abb.group(), "").split("_") if el]).lower()
            rest_second = " ".join(mapping[1].split("_")[:start] + mapping[1].split("_")[end:])
            if is_abb.group() not in final_dict:
                final_dict[is_abb.group()] = [(fullform, rest_first, rest_second)]
            else:
                final_dict[is_abb.group()].append((fullform, rest_first, rest_second))

    is_abb = re.search("[A-Z][A-Z]+", mapping[1])
    if is_abb:
        abbreviation = "".join([el[0].upper() for el in mapping[0].split("_")])
        
        if is_abb.group() in abbreviation:
            start = abbreviation.find(is_abb.group())
            end = start + len(is_abb.group())
            fullform = "_".join(mapping[0].split("_")[start:end])
            print ("right", mapping, abbreviation, fullform)

            rest_first = " ".join([el for el in mapping[1].replace(is_abb.group(), "").split("_") if el]).lower()
            rest_second = " ".join(mapping[0].split("_")[:start] + mapping[0].split("_")[end:])
            if is_abb.group() not in final_dict:
                final_dict[is_abb.group()] = [(fullform, rest_first, rest_second)]
            else:
                final_dict[is_abb.group()].append((fullform, rest_first, rest_second))

keys = [el for el in list(set(flatten([flatten([tup[1:] for tup in final_dict[key]]) for key in final_dict]))) if el]
abb_embeds = dict(zip(keys, extract_huggingface_embeddings(keys)))

scored_dict = {}
for abbr in final_dict:
    sim_list = [(tup[0], tup[1], tup[2], cos_sim(abb_embeds[tup[1]], abb_embeds[tup[2]])) if tup[1] and tup[2]
                else (tup[0], tup[1], tup[2], 0) for tup in final_dict[abbr]]
    scored_dict[abbr] = sorted(list(set(sim_list)), key=lambda x:x[-1], reverse=True)

resolved_dict = {key: scored_dict[key][0] for key in scored_dict}
filtered_dict = {key: " ".join(resolved_dict[key][0].split("_")) for key in resolved_dict if resolved_dict[key][-1] > 0.9}


left ('hippocampus CA4', 'Calcarine_Artery') CA Calcarine_Artery
left ('hippocampus CA4', 'Anterior_Choroidal_Artery') ACA Choroidal_Artery
left ('hippocampus CA4', 'Spinal_Cord_Arachnoid_Membrane') SCAM Cord_Arachnoid
left ('hippocampus CA4', 'External_Circumflex_Artery') ECA Circumflex_Artery
left ('hippocampus CA4', 'Cerebral_Arachnoid_Membrane') CAM Cerebral_Arachnoid
left ('hippocampus CA4', 'External_Carotid_Artery_Branch') ECAB Carotid_Artery
left ('hippocampus CA4', 'Internal_Calcanean_Artery') ICA Calcanean_Artery
left ('hippocampus CA4', 'Cervical_Artery') CA Cervical_Artery
left ('hippocampus CA4', 'Posterior_Cerebral_Artery_Branch') PCAB Cerebral_Artery
left ('hippocampus CA4', 'Cerebral_Artery') CA Cerebral_Artery
left ('hippocampus CA4', 'Callosomarginal_Artery') CA Callosomarginal_Artery
left ('hippocampus CA4', 'Left_Coronary_Artery_Branch') LCAB Coronary_Artery
left ('hippocampus CA4', 'Circumflex_Branch_of_the_Left_Coronary_Artery') CBOTLCA Coronary_Artery
left ('hipp

left ('hippocampus CA2', 'Calcarine_Artery') CA Calcarine_Artery
left ('hippocampus CA2', 'Anterior_Choroidal_Artery') ACA Choroidal_Artery
left ('hippocampus CA2', 'Spinal_Cord_Arachnoid_Membrane') SCAM Cord_Arachnoid
left ('hippocampus CA2', 'External_Circumflex_Artery') ECA Circumflex_Artery
left ('hippocampus CA2', 'Cerebral_Arachnoid_Membrane') CAM Cerebral_Arachnoid
left ('hippocampus CA2', 'External_Carotid_Artery_Branch') ECAB Carotid_Artery
left ('hippocampus CA2', 'Internal_Calcanean_Artery') ICA Calcanean_Artery
left ('hippocampus CA2', 'Cervical_Artery') CA Cervical_Artery
left ('hippocampus CA2', 'Posterior_Cerebral_Artery_Branch') PCAB Cerebral_Artery
left ('hippocampus CA2', 'Cerebral_Artery') CA Cerebral_Artery
left ('hippocampus CA2', 'Callosomarginal_Artery') CA Callosomarginal_Artery
left ('hippocampus CA2', 'Left_Coronary_Artery_Branch') LCAB Coronary_Artery
left ('hippocampus CA2', 'Circumflex_Branch_of_the_Left_Coronary_Artery') CBOTLCA Coronary_Artery
left ('hipp

left ('trochlear IV nerve', 'External_Iliac_Vein') EIV Iliac_Vein
left ('trochlear IV nerve', 'Ileocecal_Valve') IV Ileocecal_Valve
left ('trochlear IV nerve', 'Internal_Iliac_Vein') IIV Iliac_Vein
left ('trochlear IV nerve', 'Right_Innominate_Vein') RIV Innominate_Vein
left ('trochlear IV nerve', 'Innominate_Vein') IV Innominate_Vein
left ('trochlear IV nerve', 'Ileocolic_Vein') IV Ileocolic_Vein
left ('trochlear IV nerve', 'Inferior_Vermian_Artery') IVA Inferior_Vermian
left ('trochlear IV nerve', 'Interosseous_Vein') IV Interosseous_Vein
left ('trochlear IV nerve', 'Iliac_Vein') IV Iliac_Vein
left ('trochlear IV nerve', 'Inferior_Vena_Cava_Opening') IVCO Inferior_Vena
left ('trochlear IV nerve', 'Left_Innominate_Vein') LIV Innominate_Vein
left ('trochlear IV nerve', 'Fissure_of_the_Inferior_Vena_Cava') FOTIVC Inferior_Vena
left ('trochlear IV nerve', 'Ileal_Vein') IV Ileal_Vein
left ('trochlear IV nerve', 'Superior_Intercostal_Vein') SIV Intercostal_Vein
left ('trochlear IV nerve', 

left ('abducens VI nerve', 'Velum_Interpositum_Cistern') VIC Velum_Interpositum


In [7]:

def camel_case_split(identifier):
    matches = re.finditer('.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)', identifier)
    return [m.group(0) for m in matches]

def parse(word):
    return flatten([el.split("_") for el in camel_case_split(word)])
    

extracted_elems = []

for ont_name in list(set(flatten(ontologies_in_alignment))):
    ont = Ontology(ont_name)
    entities = ont.get_entities()
    props = ont.get_object_properties() + ont.get_data_properties()
    triples = list(set(flatten(ont.get_triples())))
    extracted_elems.extend([ont_name.split("/")[-1].split(".")[0] + "#" + elem for elem in entities + props + triples])

extracted_elems = list(set(extracted_elems))
inp_anatomy = [" ".join(parse(word.split("#")[1])) for word in extracted_elems]

roman_regex = "^M{0,3}(CM|CD|D?C{0,3})(XC|XL|L?X{0,3})(IX|IV|V?I{0,3})$"
inp_anatomy = [" ".join([word.replace("/", " or ") for word in elem.split() 
           if not re.search(roman_regex, word) and not re.search("\d", word)]) for elem in inp_anatomy]

inp_resolved = []
for concept in inp_anatomy:
    for key in filtered_dict:
        concept = concept.replace(key, filtered_dict[key])
    final_list = []
    # Lowering case except in abbreviations
    for word in concept.split(" "):
        if not re.search("[A-Z][A-Z]+", word):
            final_list.append(word.lower())
        else:
            final_list.append(word)
    concept = " ".join(final_list)
    inp_resolved.append(concept)


print ("Total number of extracted unique classes and properties from entire RA set: ", len(extracted_elems))

extracted_elems = ["<UNK>"] + extracted_elems

embeds = np.array([np.zeros(512,)] + list(extract_huggingface_embeddings(inp_resolved)))
# embeds = np.array([np.zeros(512,)] + list(extractUSEEmbeddings(inp_spellchecked)))
embeddings = dict(zip(extracted_elems, embeds))


emb_vals = list(embeddings.values())
emb_indexer = {key: i for i, key in enumerate(list(embeddings.keys()))}
emb_indexer_inv = {i: key for i, key in enumerate(list(embeddings.keys()))}


Total number of extracted unique classes and properties from entire RA set:  6055


In [25]:
def path_to_root(elem, ont_mappings):
    if elem not in ont_mappings or not ont_mappings[elem]:
        return []
    output = flatten([[e] + path_to_root(e, ont_mappings) for e in ont_mappings[elem]])
    return output

def get_one_hop_neighbours(ont, K=1):
    ont_obj = Ontology(ont)
    triples = ont_obj.get_triples()
    entities = [(a,b) for (a,b,c) in triples]
    neighbours_dict = {elem: [elem] for elem in list(set(flatten(entities)))}
    for e1, e2 in entities:
        neighbours_dict[e1].append(e2)
        neighbours_dict[e2].append(e1)
    print (neighbours_dict)
    rootpath_dict = ont_obj.parents_dict
    rootpath_dict = {elem: path_to_root(elem, rootpath_dict) for elem in rootpath_dict}
    ont = ont.split("/")[-1].split(".")[0]
    
    for entity in neighbours_dict:
        if entity in rootpath_dict and len(rootpath_dict[entity]) > 0:
            neighbours_dict[entity].extend(rootpath_dict[entity])
        else:
            continue
    
    
#     prop_triples = ont_obj.get_triples(subclass_of=False)
#     neighbours_dict_props = {c: [c] for a,b,c in prop_triples}
#     for e1, e2, p in prop_triples:
#         neighbours_dict_props[p].extend([e1, e2])

    #neighbours_dict = {**neighbours_dict, **neighbours_dict_props}
    
    # for elem in ont_obj.get_entities() + ont_obj.get_object_properties() + ont_obj.get_data_properties():
    #     if elem not in neighbours_dict:
    #         neighbours_dict[elem] = [elem]

    neighbours_dict = {el: neighbours_dict[el][:1] + sorted(list(set(neighbours_dict[el][1:])))
                       for el in neighbours_dict}
#     neighbours_dict = {el: neighbours_dict[el][:23] for el in neighbours_dict if len( neighbours_dict[el]) > 2}
#     ont = ont.split("/")[-1].split(".")[0]
    neighbours_dict = {ont + "#" + el: [ont + "#" + e for e in neighbours_dict[el]] for el in neighbours_dict}
    return neighbours_dict

neighbours_dicts = {ont.split("/")[-1].split(".")[0]: get_one_hop_neighbours(ont) for ont in list(set(flatten(ontologies_in_alignment)))}
max_neighbours = np.max(flatten([[len(el[e]) for e in el] for el in neighbours_dicts.values()]))
neighbours_lens = {ont: {key: len(neighbours_dicts[ont][key]) for key in neighbours_dicts[ont]}
                   for ont in neighbours_dicts}
neighbours_dicts = {ont: {key: neighbours_dicts[ont][key] + ["<UNK>" for i in range(max_neighbours -len(neighbours_dicts[ont][key]))]
              for key in neighbours_dicts[ont]} for ont in neighbours_dicts}


data_items = data.items()
data_shuffled_t = [elem for elem in data_items if elem[1]]
data_shuffled_f = [elem for elem in data_items if not elem[1]][:400*len(data_shuffled_t)]

data_shuffled_train = data_shuffled_t[:int(0.9*len(data_shuffled_t))] + data_shuffled_f[:int(0.9*len(data_shuffled_f))]
data_shuffled_test = data_shuffled_t[int(0.9*len(data_shuffled_t)):] + data_shuffled_f[int(0.9*len(data_shuffled_f)):]

np.random.shuffle(data_shuffled_train)
np.random.shuffle(data_shuffled_test)

data_train = OrderedDict(data_shuffled_train)
data_test = OrderedDict(data_shuffled_test)

# ontologies_in_alignment = [[el.split("/")[1].split(".")[0] for el in ont] for ont in ontologies_in_alignment]
f = open("data_anatomy.pkl", "wb")
pickle.dump([data_train, data_test, emb_indexer, emb_indexer_inv, emb_vals, gt_mappings, neighbours_dicts, ontologies_in_alignment], f)

{'Systemic_Vein': ['Systemic_Vein', 'Vein_of_the_Head_or_Neck', 'Deep_Vein', 'Vein', 'Superficial_Vein'], 'Parathyroid_Gland_Chief_Cell': ['Parathyroid_Gland_Chief_Cell', 'Parathyroid_Gland_Parenchymal_Cell', 'Chief_Cell'], 'Depressor_Labii_Inferioris': ['Depressor_Labii_Inferioris', 'Muscle'], 'Tonsillar_Crypt': ['Tonsillar_Crypt', 'Tonsil', 'Tonsil_Part'], 'Skin_of_Other_and_Unspecified_Parts_of_Face': ['Skin_of_Other_and_Unspecified_Parts_of_Face', 'Other_Anatomic_Concept'], 'Anterior_Lobe_of_the_Pituitary_Gland': ['Anterior_Lobe_of_the_Pituitary_Gland', 'Corticotrope_Cell', 'Pituitary_Gland', 'Endocrine_System_Part', 'Adenohypophysial_Cell', 'Gonadotrope_Cell'], 'Photosensitive_Region_of_the_Retina': ['Photosensitive_Region_of_the_Retina', 'Neural_Retina', 'Eye_Part'], 'Fundus': ['Fundus', 'Other_Anatomic_Concept'], 'Isthmic_Segment_of_the_Fallopian_Tube': ['Isthmic_Segment_of_the_Fallopian_Tube', 'Fallopian_Tube', 'Female_Reproductive_System_Part'], 'Portal_Vein': ['Portal_Vein', 

{'pelvic girdle muscle': ['pelvic girdle muscle', 'pelvis muscle'], 'lung connective tissue': ['lung connective tissue', 'lung', 'respiratory system connective tissue'], 'adrenal gland zona reticularis': ['adrenal gland zona reticularis', 'adrenal gland cortex zone'], 'cochlear duct': ['cochlear duct', 'vestibular membrane', 'basilar membrane', 'spiral sulcus', 'tectorial membrane', 'limbus lamina spiralis', 'spiral ligament', 'cochlea', 'membranous labyrinth', 'stria vascularis', 'spiral organ'], 'mammary gland milk': ['mammary gland milk', 'mammary gland fluid/secretion'], 'area cribosa': ['area cribosa', 'papillary duct'], 'ethmoidal artery': ['ethmoidal artery', 'artery'], 'cellular cartilage': ['cellular cartilage', 'cartilage'], 'lobar bronchus connective tissue': ['lobar bronchus connective tissue', 'lobar bronchus'], 'lower leg bone': ['lower leg bone', 'lower leg', 'leg bone', 'tibia', 'sesamoid bone of gastrocnemius', 'fibula'], 'lymph node primary follicle': ['lymph node pri

In [31]:
emb_indexer["upper_jaw"]

KeyError: 'upper_jaw'

606400

In [28]:
cos_sim(*extractUSEEmbeddings(['late paid applicant', 'late registered participant']))

0.6148052215576172

In [9]:
# AML test
def is_test(test_onto, key):
    return tuple([el.split("#")[0] for el in key]) in test_onto

results = []
all_ont_pairs = list(set([tuple([el.split("#")[0] for el in l]) for l in data.keys()]))
for i in list(range(0, len(all_ont_pairs), 3)):
    test_onto = all_ont_pairs[i:i+3]
    for ont_pair in test_onto:
        a, b, c = ont_pair[0], ont_pair[1], ont_pair[0] + "-" + ont_pair[1]
        java_command = "java -jar AML_v3.1/AgreementMakerLight.jar -s conference_ontologies/" + a + ".owl" + \
                            " -t conference_ontologies/" + b + ".owl -o AML-test-results/" + c + ".rdf -a"
        process = subprocess.Popen(java_command.split(), stdout=subprocess.PIPE)
        output, error = process.communicate()
    print (os.listdir("AML-test-results/"))
    pred_aml = load_alignments("AML-test-results/")
    pred_aml = [tuple([el.split("/")[-1] for el in key]) for key in pred_aml]
    tp = len([elem for elem in pred_aml if data[elem]])
    fn = len([key for key in gt_mappings if key not in set(pred_aml) and is_test(test_onto, key)])
    fp = len([elem for elem in pred_aml if not data[elem]])

    precision = tp/(tp+fp)
    recall = tp/(tp+fn)
    f1score = 2 * precision * recall / (precision + recall)
    f2score = 5 * precision * recall / (4 * precision + recall)
    f0_5score = 1.25 * precision * recall / (0.25 * precision + recall)
    print (precision, recall, f1score, f2score, f0_5score)
    
    metrics = [precision, recall, f1score, f2score, f0_5score]
    results.append(metrics)
    
    _ = [os.remove(f) for f in glob.glob('AML-test-results/*')]
    
print ("Final Results:", np.mean(results, axis=0))

['confOf-sigkdd.rdf', 'iasted-sigkdd.rdf', 'cmt-ekaw.rdf']
0.8275862068965517 0.7272727272727273 0.7741935483870968 0.7453416149068324 0.8053691275167786
['confOf-iasted.rdf', 'conference-edas.rdf', 'cmt-sigkdd.rdf']
0.8148148148148148 0.5789473684210527 0.6769230769230768 0.6145251396648045 0.7534246575342465
['ekaw-sigkdd.rdf', 'conference-sigkdd.rdf', 'conference-confOf.rdf']
0.78125 0.6097560975609756 0.684931506849315 0.6377551020408163 0.7396449704142012
['confOf-edas.rdf', 'edas-iasted.rdf', 'cmt-conference.rdf']
0.7941176470588235 0.5094339622641509 0.6206896551724137 0.548780487804878 0.7142857142857143
['edas-sigkdd.rdf', 'conference-iasted.rdf', 'ekaw-iasted.rdf']
0.7916666666666666 0.48717948717948717 0.6031746031746031 0.5277777777777778 0.7037037037037036
['cmt-confOf.rdf', 'cmt-edas.rdf', 'edas-ekaw.rdf']
0.8181818181818182 0.5192307692307693 0.6352941176470589 0.5601659751037344 0.733695652173913
['conference-ekaw.rdf', 'confOf-ekaw.rdf', 'cmt-iasted.rdf']
0.79545454545

In [168]:
ontologies_in_alignment = [[el.split("/")[1].split(".")[0] for el in ont] for ont in ontologies_in_alignment][:-1] + [["human", "mouse"]]

In [8]:
for i in list(range(0, len(ontologies_in_alignment)-1, 3)):
    
    test_onto = ontologies_in_alignment[i:i+3]
    print (test_onto)

[['conference', 'ekaw'], ['confOf', 'ekaw'], ['ekaw', 'sigkdd']]
[['edas', 'sigkdd'], ['confOf', 'sigkdd'], ['iasted', 'sigkdd']]
[['confOf', 'iasted'], ['conference', 'iasted'], ['confOf', 'edas']]
[['edas', 'iasted'], ['conference', 'edas'], ['cmt', 'ekaw']]
[['cmt', 'confOf'], ['cmt', 'edas'], ['conference', 'sigkdd']]
[['cmt', 'sigkdd'], ['conference', 'confOf'], ['edas', 'ekaw']]
[['cmt', 'conference'], ['cmt', 'iasted'], ['ekaw', 'iasted']]


In [20]:
f = open("data_unhas.pkl", "wb")
pickle.dump([data, emb_indexer, emb_indexer_inv, emb_vals, gt_mappings, neighbours_dicts, ontologies_in_alignment], f)


In [20]:
def count_non_unk(elem):
    return len([l for l in elem if l!="<UNK>"])
neighbours_dicts = {ont: {el: neighbours_dicts[ont][el][:int(sys.argv[1])] for el in neighbours_dicts[ont]
       if count_non_unk(neighbours_dicts[ont][el]) > int(sys.argv[2])} for ont in neighbours_dicts}

(167, 1240)

In [18]:
from collections import Counter
c = Counter(dict(sorted(Counter(flatten([l.split() for l in inp])).items(), key=lambda x:x[1], reverse=True)))
c

Counter({'has': 95,
         'paper': 61,
         'conference': 58,
         'of': 54,
         'is': 49,
         'by': 46,
         'topic': 39,
         'a': 35,
         'review': 31,
         'event': 27,
         'committee': 27,
         'member': 24,
         'chair': 22,
         'registration': 18,
         'date': 18,
         'author': 18,
         'reviewer': 15,
         'for': 14,
         'session': 14,
         'fee': 13,
         'contribution': 13,
         'program': 13,
         'speaker': 12,
         'abstract': 11,
         'name': 11,
         'computer': 11,
         'deadline': 11,
         'submission': 10,
         'city': 9,
         'hotel': 9,
         'in': 9,
         'to': 9,
         'networks': 9,
         'student': 9,
         'was': 8,
         'document': 8,
         'presenter': 8,
         'an': 8,
         'time': 8,
         'presentation': 8,
         'on': 8,
         'call': 7,
         'subclass': 7,
         'invited': 7,
         'tal

In [32]:
extracted_elems

['<UNK>',
 'human#Vastus_Lateralis',
 'mouse#thoracic vertebra',
 'human#Axillary_Vein',
 'human#Dorsal_Curve',
 'mouse#enteric nervous system',
 'human#Semitendinosus',
 'mouse#chest blood vessel',
 'mouse#pancreas parenchyma',
 'human#Utricle',
 'mouse#bed nucleus of stria terminalis',
 'human#Gastrointestinal_Tract_Lamina_Propria',
 'human#Mesencephalic_Perforating_Artery',
 'human#Bile_Duct_Epithelium',
 'mouse#vein endothelium',
 'mouse#knee joint',
 'human#Bronchus_Elastic_Tissue',
 'human#Longitudinal_Fissure',
 'mouse#respiratory system venous blood vessel',
 'mouse#efferent duct epithelium',
 'human#Perforating_Canal',
 'mouse#pars intermedia',
 'mouse#cranial ganglion/nerve',
 'human#Subsegmental_Bronchus',
 'human#Metacarpal_Bone_Digit_1',
 'human#Subsegmental_Bronchus_of_Left_Lung',
 'mouse#anterior abdominal wall muscle',
 'mouse#dorsal root ganglion',
 'mouse#salivary gland epithelium',
 'mouse#lymph node endothelium',
 'human#Arytenoid_Cartilage',
 'human#Central_Portion

In [83]:
import requests

url = "https://montanaflynn-spellcheck.p.rapidapi.com/check/"

headers = {
    'x-rapidapi-host': "montanaflynn-spellcheck.p.rapidapi.com",
    'x-rapidapi-key': "9965b01207msh06291e57d6f2c55p1a6a16jsn0fb016da4a62"
    }

# inp_spellchecked = []
for concept in inp[731:]:
    querystring = {"text": concept}
    response = requests.request("GET", url, headers=headers, params=querystring).json()
    if response["suggestion"] != concept:
        resolved = str(concept)
        for word in response["corrections"]:
            if not re.search("[A-Z][A-Z]+", concept):
                resolved = resolved.replace(word, response["corrections"][word][0])
        
        inp_spellchecked.append(resolved)
        print (concept, resolved)
    else:
        inp_spellchecked.append(concept)




registeered applicant registered applicant
technically organised by technically organized by
ngo no
sponzorship sponsorship


In [77]:
querystring = {"text": "technically Organised By"}
response = requests.request("GET", url, headers=headers, params=querystring)
response.json()

{'original': 'technically Organised By',
 'suggestion': 'technically Organized By',
 'corrections': {'Organised': ['Organized',
   'Organist',
   'Organism',
   'Organizes',
   'Disorganize',
   'Organize',
   'Agonized']}}

In [None]:
inp_spellchecked[730], inp[731]

In [78]:
fn_spellchecked, fp_spellchecked = [dict(el) for el in pickle.load(open("test_v2.pkl", "rb"))]
fn_baseline, fp_baseline = [dict(el) for el in pickle.load(open("test_best.pkl", "rb"))]
fn_unhas, fp_unhas = [dict(el) for el in pickle.load(open("test_unhas.pkl", "rb"))]
fn_resolved, fp_resolved = [dict(el) for el in pickle.load(open("test_resolved.pkl", "rb"))]

fn_dict, fp_dict = {}, {}
def create_comparison_file(file, idx):
    fn, fp = [dict(el) for el in pickle.load(open(file, "rb"))]
    
    for key in fn:
        if key in fn_dict:
            fn_dict[key][idx] = fn[key]
        else:
            fn_dict[key] = ["N/A" for i in range(4)]
            fn_dict[key][idx] = fn[key]
    
    for key in fp:
        if key in fp_dict:
            fp_dict[key][idx] = fp[key]
        else:
            fp_dict[key] = ["N/A" for i in range(4)]
            fp_dict[key][idx] = fp[key]
    

create_comparison_file("test_best.pkl", 0)
create_comparison_file("test_unhas.pkl", 1)
create_comparison_file("test_v2.pkl", 2)
create_comparison_file("test_resolved.pkl", 3)

open("fn - comparison.tsv", "w+").write("\n".join(["\t".join([str(el) for el in flatten(el)]) for el in fn_dict.items()]))
open("fp - comparison.tsv", "w+").write("\n".join(["\t".join([str(el) for el in flatten(el)]) for el in fp_dict.items()]))

7796

In [14]:
ontologies_in_alignment = pickle.load(open("data_path.pkl", "rb"))[-1]
ontologies_in_alignment

[['confOf', 'sigkdd'],
 ['iasted', 'sigkdd'],
 ['cmt', 'ekaw'],
 ['confOf', 'iasted'],
 ['conference', 'edas'],
 ['cmt', 'sigkdd'],
 ['ekaw', 'sigkdd'],
 ['conference', 'confOf'],
 ['conference', 'sigkdd'],
 ['confOf', 'edas'],
 ['cmt', 'conference'],
 ['edas', 'iasted'],
 ['conference', 'iasted'],
 ['edas', 'sigkdd'],
 ['ekaw', 'iasted'],
 ['cmt', 'edas'],
 ['edas', 'ekaw'],
 ['cmt', 'confOf'],
 ['confOf', 'ekaw'],
 ['conference', 'ekaw'],
 ['cmt', 'iasted']]

In [72]:
d = {('confOf#Organization', 'sigkdd#Organizator'): (1,2,3,4),
 ('iasted#Document', 'sigkdd#Document'): (5,6,78,8)}
[[str(el) for el in flatten(el)] for el in d.items()]

[['confOf#Organization', 'sigkdd#Organizator', '1', '2', '3', '4'],
 ['iasted#Document', 'sigkdd#Document', '5', '6', '78', '8']]

In [34]:
abbreviations_dict = {}
final_dict = {}

for mapping in all_mappings:
    mapping = tuple([el.split("#")[1] for el in mapping])
    is_abb = re.search("[A-Z][A-Z]+", mapping[0])
    if is_abb:
        abbreviation = "".join([el[0].upper() for el in mapping[1].split("_")])
        if is_abb.group() in abbreviation:
            
            start = abbreviation.find(is_abb.group())
            end = start + len(is_abb.group())
            fullform = "_".join(mapping[1].split("_")[start:end])
            print ("left", mapping, abbreviation, fullform)
            
            rest_first = " ".join([el for el in mapping[0].replace(is_abb.group(), "").split("_") if el]).lower()
            rest_second = " ".join(mapping[1].split("_")[:start] + mapping[1].split("_")[end:])
            if is_abb.group() not in final_dict:
                final_dict[is_abb.group()] = [(fullform, rest_first, rest_second)]
            else:
                final_dict[is_abb.group()].append((fullform, rest_first, rest_second))

    is_abb = re.search("[A-Z][A-Z]+", mapping[1])
    if is_abb:
        abbreviation = "".join([el[0].upper() for el in mapping[0].split("_")])
        
        if is_abb.group() in abbreviation:
            start = abbreviation.find(is_abb.group())
            end = start + len(is_abb.group())
            fullform = "_".join(mapping[0].split("_")[start:end])
            print ("right", mapping, abbreviation, fullform)

            rest_first = " ".join([el for el in mapping[1].replace(is_abb.group(), "").split("_") if el]).lower()
            rest_second = " ".join(mapping[0].split("_")[:start] + mapping[0].split("_")[end:])
            if is_abb.group() not in final_dict:
                final_dict[is_abb.group()] = [(fullform, rest_first, rest_second)]
            else:
                final_dict[is_abb.group()].append((fullform, rest_first, rest_second))

keys = [el for el in list(set(flatten([flatten([tup[1:] for tup in final_dict[key]]) for key in final_dict]))) if el]
abb_embeds = dict(zip(keys, extractUSEEmbeddings(keys)))

scored_dict = {}
for abbr in final_dict:
    sim_list = [(tup[0], tup[1], tup[2], cos_sim(abb_embeds[tup[1]], abb_embeds[tup[2]])) if tup[1] and tup[2]
                else (tup[0], tup[1], tup[2], 0) for tup in final_dict[abbr]]
    scored_dict[abbr] = sorted(list(set(sim_list)), key=lambda x:x[-1], reverse=True)

resolved_dict = {key: scored_dict[key][0] for key in scored_dict}
filtered_dict = {key: " ".join(resolved_dict[key][0].split("_")) for key in resolved_dict if resolved_dict[key][-1] > 0.9}
inp_resolved = []
for concept in inp:
    for key in filtered_dict:
        concept = concept.replace(key, filtered_dict[key])
    inp_resolved.append(concept)
inp_resolved

left ('Chair_PC', 'Program_Chair') PC Program_Chair
left ('Chair_PC', 'Program_Committee') PC Program_Committee
left ('Chair_PC', 'Program_Committee_member') PCM Program_Committee
left ('Member_PC', 'Program_Chair') PC Program_Chair
left ('Member_PC', 'Program_Committee') PC Program_Committee
left ('Member_PC', 'Program_Committee_member') PCM Program_Committee
left ('Chair_PC', 'Presenter_city') PC Presenter_city
left ('Member_PC', 'Presenter_city') PC Presenter_city
left ('OC_Member', 'Organizing_Committee_member') OCM Organizing_Committee
left ('OC_Member', 'Organizing_Committee') OC Organizing_Committee
left ('PC_Member', 'Program_Chair') PC Program_Chair
left ('PC_Member', 'Program_Committee') PC Program_Committee
left ('PC_Member', 'Program_Committee_member') PCM Program_Committee
left ('OC_Chair', 'Organizing_Committee_member') OCM Organizing_Committee
left ('OC_Chair', 'Organizing_Committee') OC Organizing_Committee
left ('PC_Chair', 'Program_Chair') PC Program_Chair
left ('PC_C

In [48]:
keys = [el for el in list(set(flatten([flatten([tup[1:] for tup in final_dict[key]]) for key in final_dict]))) if el]
abb_embeds = dict(zip(keys, extractUSEEmbeddings(keys)))


In [27]:
cos_sim(*extractUSEEmbeddings(["Conference Banquet", "Dinner Banquet"]))

0.8085169792175293

In [26]:
fn

[(('confOf#hasEmail', 'sigkdd#E-mail'), 0.9161555776735063),
 (('confOf#Chair_PC', 'sigkdd#Program_Chair'), 0.8290806880788957),
 (('iasted#Student_registration_fee', 'sigkdd#Registration_Student'),
  0.9156892709934972),
 (('iasted#Deadline_for_notification_of_acceptance',
   'sigkdd#Deadline_Author_notification'),
  0.5085231269229767),
 (('iasted#Nonmember_registration_fee', 'sigkdd#Registration_Non-Member'),
  0.8098175068174337),
 (('cmt#ConferenceMember', 'ekaw#Conference_Participant'), 0.3532655254301476),
 (('cmt#Author', 'ekaw#Paper_Author'), 0.9288647440450568),
 (('cmt#writtenBy', 'ekaw#reviewWrittenBy'), 0.7592434303979944),
 (('cmt#hasBeenAssigned', 'ekaw#reviewerOfPaper'), 0.2502871547766108),
 (('cmt#assignedTo', 'ekaw#hasReviewer'), 0.3499445900379554),
 (('cmt#PaperFullVersion', 'ekaw#Regular_Paper'), 0.8379052493933772),
 (('confOf#Administrative_event', 'iasted#Activity_before_conference'),
  0.37740041094526244),
 (('conference#has_the_first_name', 'edas#hasFirstNam

In [53]:
scored_dict = {}
for abbr in final_dict:
    sim_list = [(tup[0], tup[1], tup[2], cos_sim(abb_embeds[tup[1]], abb_embeds[tup[2]])) if tup[1] and tup[2]
                else (tup[0], tup[1], tup[2], 0) for tup in final_dict[abbr]]
    scored_dict[abbr] = sorted(list(set(sim_list)), key=lambda x:x[-1], reverse=True)


In [61]:
inp_case_handled = []
for concept in inp:
    final_list = []
    for word in concept.split(" "):
        if not re.search("[A-Z][A-Z]+", concept):
            final_list.append(word.lower())
        else:
            final_list.append(word)
    case_resolved = " ".join(final_list)
    inp_case_handled.append(case_resolved)
    
inp_case_handled

['pay',
 'rejected by',
 'Registration SIGMOD Member',
 'is connected with',
 'NGO',
 'overhead projector',
 'name of conference',
 'call for participation',
 'coffee break',
 'scientifically organised by',
 'volunteer',
 'publisher',
 'regular',
 'add program committee member',
 'contact email',
 'part of event',
 'dinner banquet',
 'social program',
 'decision',
 'is sent by',
 'was a committee chair of',
 'organisation',
 'paper',
 'assign external reviewer',
 'was a member of',
 'has cost amount',
 'session chair',
 'single level conference',
 'hotel',
 'deadline hotel reservation',
 'country',
 'important dates',
 'paper',
 'has parts',
 'event',
 'paper due on',
 'is the 1th part of',
 'has surname',
 'wireless communications topic',
 'two level conference',
 'has first name',
 'bid',
 'is designed for',
 'author',
 'double hotel room',
 'review form',
 'presentation',
 'subject area',
 'positive integer',
 'deadline abstract submission',
 'has a degree',
 'is present in',
 'rela

In [23]:
Ontology("conference_ontologies/conference.owl").triples

[('Conference_proceedings', 'string', 'has_a_name'),
 ('Committee_member', 'Committee', 'was_a_member_of'),
 ('Review', 'Conference_document', 'subclass_of'),
 ('Review_preference', 'Reviewer', 'belongs_to_reviewers'),
 ('Information_for_participants', 'Conference_document', 'subclass_of'),
 ('Regular_author', 'Conference_contributor', 'subclass_of'),
 ('Late_paid_applicant', 'Paid_applicant', 'subclass_of'),
 ('Active_conference_participant', 'Presentation', 'gives_presentations'),
 ('Invited_speaker', 'Conference_contributor', 'subclass_of'),
 ('Program_committee', 'Conference_volume', 'was_a_program_committee_of'),
 ('Topic', 'Review_preference', 'has_been_assigned_a_review_reference'),
 ('Conference_volume', 'Important_dates', 'has_important_dates'),
 ('Rejected_contribution', 'Reviewed_contribution', 'subclass_of'),
 ('Contribution_co-author', 'Regular_author', 'subclass_of'),
 ('Active_conference_participant', 'Conference_contributor', 'subclass_of'),
 ('Presentation', 'Abstract'

In [67]:
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=385.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=213450.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=435778770.0, style=ProgressStyle(descri…




In [75]:
# from transformers import XLNetTokenizer, XLNetModel
# import torch
# import scipy
# import torch.nn as nn
# import torch.nn.functional as F

tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

input_ids = torch.tensor(tokenizer.encode("fastigial nucleus", add_special_tokens=True)).unsqueeze(0)
outputs = model(input_ids)
last_hidden_states = outputs[0].mean(1)

input_ids = torch.tensor(tokenizer.encode("femur", add_special_tokens=True)).unsqueeze(0) 

outputs1 = model(input_ids)
last_hidden_states1 = outputs1[0].mean(1)

cos = nn.CosineSimilarity(dim=1, eps=1e-6)
cos(last_hidden_states, last_hidden_states1)

tensor([0.6853], grad_fn=<DivBackward0>)

In [76]:
last_hidden_states.shape

torch.Size([1, 768])