In [1]:
import pandas as pd

from bluegraph.core import PandasPGFrame
from bluegraph.preprocess.generators import CooccurrenceGenerator
from bluegraph.preprocess.encoders import ScikitLearnPGEncoder

## Data preparation

Fist, we read the source dataset with mentions of entities in different paragraphs

In [2]:
mentions = pd.read_csv("data/labeled_entity_occurrence.csv")

In [3]:
# Extract unique paper/seciton/paragraph identifiers
mentions = mentions.rename(columns={"occurrence": "paragraph"})
number_of_paragraphs = len(mentions["paragraph"].unique())

In [4]:
mentions

Unnamed: 0,entity,paragraph
0,lithostathine-1-alpha,1
1,pulmonary,1
2,host,1
3,lithostathine-1-alpha,2
4,surfactant protein d measurement,2
...,...,...
2281346,covid-19,227822
2281347,covid-19,227822
2281348,viral infection,227823
2281349,lipid,227823


We will also load a dataset that contains definitions of entities and their types

In [5]:
entity_data = pd.read_csv("data/entity_types_defs.csv")

In [6]:
entity_data

Unnamed: 0,entity,entity_type,definition
0,(e3-independent) e2 ubiquitin-conjugating enzyme,PROTEIN,(E3-independent) E2 ubiquitin-conjugating enzy...
1,(h115d)vhl35 peptide,CHEMICAL,A peptide vaccine derived from the von Hippel-...
2,"1,1-dimethylhydrazine",DRUG,"A clear, colorless, flammable, hygroscopic liq..."
3,"1,2-dimethylhydrazine",CHEMICAL,A compound used experimentally to induce tumor...
4,"1,25-dihydroxyvitamin d(3) 24-hydroxylase, mit...",PROTEIN,"1,25-dihydroxyvitamin D(3) 24-hydroxylase, mit..."
...,...,...,...
28127,zygomycosis,DISEASE,Any infection due to a fungus of the Zygomycot...
28128,zygomycota,ORGANISM,A phylum of fungi that are characterized by ve...
28129,zygosity,ORGANISM,"The genetic condition of a zygote, especially ..."
28130,zygote,CELL_COMPARTMENT,"The cell formed by the union of two gametes, e..."


### Generation of a co-occurrence graph

We first create a graph whose nodes are entities

In [7]:
graph = PandasPGFrame()
entity_nodes = mentions["entity"].unique()
graph.add_nodes(entity_nodes)
graph.add_node_types({n: "Entity" for n in entity_nodes})

entity_props = entity_data.rename(columns={"entity": "@id"}).set_index("@id")
graph.add_node_properties(entity_props["entity_type"], prop_type="category")
graph.add_node_properties(entity_props["definition"], prop_type="text")

In [8]:
paragraph_prop = pd.DataFrame({"paragraphs": mentions.groupby("entity").aggregate(set)["paragraph"]})
graph.add_node_properties(paragraph_prop, prop_type="category")

In [9]:
graph.nodes(raw_frame=True)

Unnamed: 0_level_0,@type,entity_type,definition,paragraphs
@id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
lithostathine-1-alpha,Entity,PROTEIN,"Lithostathine-1-alpha (166 aa, ~19 kDa) is enc...","{1, 2, 3, 18178, 195589, 104454, 88967, 104455..."
pulmonary,Entity,ORGAN,Relating to the lungs as the intended site of ...,"{1, 196612, 196613, 196614, 196621, 196623, 16..."
host,Entity,ORGANISM,An organism that nourishes and supports anothe...,"{1, 114689, 3, 221193, 180243, 180247, 28, 180..."
surfactant protein d measurement,Entity,PROTEIN,The determination of the amount of surfactant ...,"{145537, 2, 3, 4, 5, 6, 51202, 103939, 103940,..."
communication response,Entity,PATHWAY,A statement (either spoken or written) that is...,"{46592, 64000, 2, 28162, 166912, 226304, 88585..."
...,...,...,...,...
drug binding site,Entity,PATHWAY,The reactive parts of a macromolecule that dir...,"{225082, 225079}"
carbaril,Entity,CHEMICAL,A synthetic carbamate acetylcholinesterase inh...,"{225408, 225409, 225415, 225419, 225397}"
ny-eso-1 positive tumor cells present,Entity,CELL_TYPE,An indication that Cancer/Testis Antigen 1 exp...,"{225544, 226996}"
mustelidae,Entity,ORGANISM,Taxonomic family which includes the Ferret.,"{225901, 225903}"


For each node we will add the `frequency` property that counts the total number of paragraphs where the entity was mentioned.

In [10]:
frequencies = graph._nodes["paragraphs"].apply(len)
frequencies.name = "frequency"
graph.add_node_properties(frequencies)

In [11]:
graph.nodes(raw_frame=True)

Unnamed: 0_level_0,@type,entity_type,definition,paragraphs,frequency
@id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
lithostathine-1-alpha,Entity,PROTEIN,"Lithostathine-1-alpha (166 aa, ~19 kDa) is enc...","{1, 2, 3, 18178, 195589, 104454, 88967, 104455...",80
pulmonary,Entity,ORGAN,Relating to the lungs as the intended site of ...,"{1, 196612, 196613, 196614, 196621, 196623, 16...",8295
host,Entity,ORGANISM,An organism that nourishes and supports anothe...,"{1, 114689, 3, 221193, 180243, 180247, 28, 180...",2660
surfactant protein d measurement,Entity,PROTEIN,The determination of the amount of surfactant ...,"{145537, 2, 3, 4, 5, 6, 51202, 103939, 103940,...",268
communication response,Entity,PATHWAY,A statement (either spoken or written) that is...,"{46592, 64000, 2, 28162, 166912, 226304, 88585...",160
...,...,...,...,...,...
drug binding site,Entity,PATHWAY,The reactive parts of a macromolecule that dir...,"{225082, 225079}",2
carbaril,Entity,CHEMICAL,A synthetic carbamate acetylcholinesterase inh...,"{225408, 225409, 225415, 225419, 225397}",5
ny-eso-1 positive tumor cells present,Entity,CELL_TYPE,An indication that Cancer/Testis Antigen 1 exp...,"{225544, 226996}",2
mustelidae,Entity,ORGANISM,Taxonomic family which includes the Ferret.,"{225901, 225903}",2


Now, for constructing co-occurrence network we will select only 1000 most frequent entities.

In [12]:
nodes_to_include = graph._nodes.nlargest(1000, "frequency").index

In [13]:
nodes_to_include

Index(['covid-19', 'blood', 'human', 'infectious disorder', 'heart',
       'diabetes mellitus', 'lung', 'sars-cov-2', 'mouse', 'pulmonary',
       ...
       'wheezing', 'chief complaint', 'azathioprine', 'ileum', 'hematology',
       'nonalcoholic steatohepatitis', 'nervous system disorder',
       'renal impairment', 'urticaria', 'rectum'],
      dtype='object', name='@id', length=1000)

The `CooccurrenceGenerator` class allows us to generate co-occurrence edges from overlaps in node property values or edge (or edge properties). In this case we consider the `paragraph` node property and construct co-occurrence edges from overlapping sets of paragraphs. In addition, we will compute some co-occurrence statistics: total co-occurrence frequency and normalized pointwise mutual information (NPMI).

In [14]:
%%time
generator = CooccurrenceGenerator(graph.subgraph(nodes=nodes_to_include))
paragraph_cooccurrence_edges = generator.generate_from_nodes(
    "paragraphs", total_factor_instances=number_of_paragraphs,
    compute_statistics=["frequency", "npmi"],
    parallelize=True, cores=8)

Examining 499500 pairs of terms for co-occurrence...
CPU times: user 7.51 s, sys: 2.12 s, total: 9.63 s
Wall time: 1min 18s


In [15]:
cutoff = paragraph_cooccurrence_edges["npmi"].mean()

In [16]:
paragraph_cooccurrence_edges = paragraph_cooccurrence_edges[paragraph_cooccurrence_edges["npmi"] > cutoff]

We add generated edges to the original graph

In [17]:
graph._edges = paragraph_cooccurrence_edges
graph.edge_prop_as_numeric("frequency")
graph.edge_prop_as_numeric("npmi")

In [18]:
graph.edges(raw_frame=True)

Unnamed: 0_level_0,Unnamed: 1_level_0,common_factors,frequency,npmi
@source_id,@target_id,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
surfactant protein d measurement,microorganism,"{2, 3, 7810, 58, 41, 7754, 7850, 26218, 7853, ...",19,0.235263
surfactant protein d measurement,lung,"{2, 103939, 51202, 5, 4, 103940, 15, 145438, 3...",93,0.221395
surfactant protein d measurement,alveolar,"{223872, 2, 51202, 100502, 7831, 149657, 19522...",25,0.336175
surfactant protein d measurement,epithelial cell,"{2, 4, 5, 222298, 7825, 7732, 7733, 169174, 7738}",9,0.175923
surfactant protein d measurement,molecule,"{2, 7750, 49991, 134504, 206448, 49, 52, 20645...",10,0.113611
...,...,...,...,...
severe acute respiratory syndrome,caax prenyl protease 2,"{205345, 185829, 227486, 220124, 220126}",5,0.142611
severe acute respiratory syndrome,transmembrane protease serine 2,"{223746, 223747, 167301, 223752, 200971, 22375...",21,0.238160
ciliated bronchial epithelial cell,cystic fibrosis pulmonary exacerbation,{46779},1,0.088963
ciliated bronchial epithelial cell,caax prenyl protease 2,"{215748, 220047}",2,0.151639


Recall that we have generated edges only for the 1000 most frequent entities, the rest of the entities will be isolated (having no incident edges). Let us remove all the isolated nodes.

In [19]:
graph.remove_isolated_nodes()

In [20]:
graph.number_of_nodes()

1000

Next, we save the generated co-occurrence graph.

In [21]:
# graph.to_csv("data/graph_nodes.csv", "data/graph_edges.csv",)

In [22]:
# graph = PandasPGFrame.from_csv(
#     "data/graph_nodes.csv", "data/graph_edges.csv",
#     node_property_types={
#         "@type": "category",
#         "entity_type": "category",
#         "definition": "text",
#         "paragraphs": "category",
#         "frequency": "numeric"
#     },
#     edge_property_types={
#         "common_factors": "category",
#         "frequency": "numeritc",
#         "ppmi": "numeric",
#         "npmi": "numeric"
#     })

### Node feature extraction

We extract node features from entity definitions using the `tfidf` model.

In [23]:
encoder = ScikitLearnPGEncoder(text_encoding_max_dimension=512)

In [24]:
%%time
transformed_graph = encoder.fit_transform(
    graph, node_properties=["definition"], edge_properties=None)

CPU times: user 666 ms, sys: 18.9 ms, total: 685 ms
Wall time: 690 ms


We can have a glance at the vocabulary that the encoder constructed for the 'definition' property

In [60]:
vocabulary = encoder._node_encoders["definition"].vocabulary_
list(vocabulary.keys())[:10]

['relating',
 'lungs',
 'site',
 'administration',
 'product',
 'usually',
 'action',
 'lower',
 'respiratory',
 'tract']

We will add additional properties to our transformed graph corresponding to the entity type labels. We will also add NPMI as an edge property to this transformed graph.

In [61]:
transformed_graph.add_node_properties(
    graph.get_node_property_values("entity_type"))
transformed_graph.add_edge_properties(
    graph.get_edge_property_values("npmi"), prop_type="numeric")

In [62]:
transformed_graph.nodes(raw_frame=True)

Unnamed: 0_level_0,features,@type,entity_type,node2vec,attri2vec,gcn_dgi
@id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
pulmonary,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,ORGAN,"[0.1292733997106552, -0.4646681547164917, 0.22...","[0.011593103408813477, 0.01611831784248352, 0....","[0.012593621388077736, 0.0, 0.0323212780058383..."
host,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,ORGANISM,"[0.08811916410923004, -0.2682400643825531, 0.5...","[0.03360855579376221, 0.04184946417808533, 0.0...","[0.010720414109528065, 0.0, 0.0392156355082988..."
surfactant protein d measurement,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,PROTEIN,"[0.05927851423621178, -0.3263268768787384, 0.5...","[0.04456901550292969, 0.013118952512741089, 0....","[0.0, 0.0, 0.030114546418190002, 0.0, 0.020363..."
microorganism,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,ORGANISM,"[0.22872965037822723, -0.3821519613265991, 0.4...","[0.07238531112670898, 0.10694733262062073, 0.0...","[0.03063371405005455, 0.0, 0.04631862416863441..."
lung,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,ORGAN,"[0.09093325585126877, -0.30084678530693054, 0....","[0.012629806995391846, 0.016663789749145508, 0...","[0.0033380892127752304, 0.0, 0.035439349710941..."
...,...,...,...,...,...,...
candida parapsilosis,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,ORGANISM,"[0.19119645655155182, -0.4822970926761627, 0.3...","[0.04431799054145813, 0.08474072813987732, 0.0...","[0.02675497531890869, 0.0, 0.03201226517558098..."
ciliated bronchial epithelial cell,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,CELL_TYPE,"[0.07605913281440735, -0.32774993777275085, 0....","[0.009401559829711914, 0.007941067218780518, 0...","[0.0, 0.0, 0.03129826858639717, 0.0, 0.0218690..."
cystic fibrosis pulmonary exacerbation,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,DISEASE,"[0.10934313386678696, -0.3377675712108612, 0.3...","[0.04671040177345276, 0.0782189667224884, 0.07...","[0.01399250142276287, 0.0, 0.03161218762397766..."
caax prenyl protease 2,"[0.0, 0.0, 0.3198444339599345, 0.0, 0.0, 0.0, ...",Entity,PROTEIN,"[0.06578871607780457, -0.320480078458786, 0.45...","[0.020859450101852417, 0.008496463298797607, 0...","[0.0, 0.0, 0.029353361576795578, 0.0, 0.022348..."


## Node embedding and downstream tasks

### Node embedding using StellarGraph

Using `StellarGraphNodeEmbedder` we construct three different embeddings of our transformed graph corresponding to different embedding techniques.

In [29]:
from bluegraph.backends.stellargraph import StellarGraphNodeEmbedder

In [30]:
embedder = StellarGraphNodeEmbedder(
    "node2vec", edge_weight="npmi", embedding_dimension=64, length=10, number_of_walks=20)
node2vec_embedding = embedder.fit_model(transformed_graph)

In [31]:
embedder = StellarGraphNodeEmbedder(
    "attri2vec", feature_vector_prop="features",
    length=5, number_of_walks=10,
    epochs=10, embedding_dimension=128, edge_weight="npmi")
attri2vec_embedding = embedder.fit_model(transformed_graph)

link_classification: using 'ip' method to combine node embeddings into edge embeddings


In [32]:
embedder = StellarGraphNodeEmbedder(
    "gcn_dgi", feature_vector_prop="features", epochs=250, embedding_dimension=512)
gcn_dgi_embedding = embedder.fit_model(transformed_graph)

Using GCN (local pooling) filters...


The `fit_model` methods produces a dataframe of the following shape

In [64]:
gcn_dgi_embedding

Unnamed: 0,embedding
pulmonary,"[0.012593621388077736, 0.0, 0.0323212780058383..."
host,"[0.010720414109528065, 0.0, 0.0392156355082988..."
surfactant protein d measurement,"[0.0, 0.0, 0.030114546418190002, 0.0, 0.020363..."
microorganism,"[0.03063371405005455, 0.0, 0.04631862416863441..."
lung,"[0.0033380892127752304, 0.0, 0.035439349710941..."
...,...
candida parapsilosis,"[0.02675497531890869, 0.0, 0.03201226517558098..."
ciliated bronchial epithelial cell,"[0.0, 0.0, 0.03129826858639717, 0.0, 0.0218690..."
cystic fibrosis pulmonary exacerbation,"[0.01399250142276287, 0.0, 0.03161218762397766..."
caax prenyl protease 2,"[0.0, 0.0, 0.029353361576795578, 0.0, 0.022348..."


Let us add the embedding vectors obtained using different models as node properties of our graph.

In [65]:
transformed_graph.add_node_properties(
    node2vec_embedding.rename(columns={"embedding": "node2vec"}))

In [66]:
transformed_graph.add_node_properties(
    attri2vec_embedding.rename(columns={"embedding": "attri2vec"}))

In [67]:
transformed_graph.add_node_properties(
    gcn_dgi_embedding.rename(columns={"embedding": "gcn_dgi"}))

In [68]:
transformed_graph.nodes(raw_frame=True)

Unnamed: 0_level_0,features,@type,entity_type,node2vec,attri2vec,gcn_dgi
@id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
pulmonary,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,ORGAN,"[0.1292733997106552, -0.4646681547164917, 0.22...","[0.011593103408813477, 0.01611831784248352, 0....","[0.012593621388077736, 0.0, 0.0323212780058383..."
host,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,ORGANISM,"[0.08811916410923004, -0.2682400643825531, 0.5...","[0.03360855579376221, 0.04184946417808533, 0.0...","[0.010720414109528065, 0.0, 0.0392156355082988..."
surfactant protein d measurement,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,PROTEIN,"[0.05927851423621178, -0.3263268768787384, 0.5...","[0.04456901550292969, 0.013118952512741089, 0....","[0.0, 0.0, 0.030114546418190002, 0.0, 0.020363..."
microorganism,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,ORGANISM,"[0.22872965037822723, -0.3821519613265991, 0.4...","[0.07238531112670898, 0.10694733262062073, 0.0...","[0.03063371405005455, 0.0, 0.04631862416863441..."
lung,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,ORGAN,"[0.09093325585126877, -0.30084678530693054, 0....","[0.012629806995391846, 0.016663789749145508, 0...","[0.0033380892127752304, 0.0, 0.035439349710941..."
...,...,...,...,...,...,...
candida parapsilosis,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,ORGANISM,"[0.19119645655155182, -0.4822970926761627, 0.3...","[0.04431799054145813, 0.08474072813987732, 0.0...","[0.02675497531890869, 0.0, 0.03201226517558098..."
ciliated bronchial epithelial cell,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,CELL_TYPE,"[0.07605913281440735, -0.32774993777275085, 0....","[0.009401559829711914, 0.007941067218780518, 0...","[0.0, 0.0, 0.03129826858639717, 0.0, 0.0218690..."
cystic fibrosis pulmonary exacerbation,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,DISEASE,"[0.10934313386678696, -0.3377675712108612, 0.3...","[0.04671040177345276, 0.0782189667224884, 0.07...","[0.01399250142276287, 0.0, 0.03161218762397766..."
caax prenyl protease 2,"[0.0, 0.0, 0.3198444339599345, 0.0, 0.0, 0.0, ...",Entity,PROTEIN,"[0.06578871607780457, -0.320480078458786, 0.45...","[0.020859450101852417, 0.008496463298797607, 0...","[0.0, 0.0, 0.029353361576795578, 0.0, 0.022348..."


### Node similarity

We would like to be able to search for similar nodes using the computed vector embeddings. For this we can use the `NodeSimilarityProcessor` interfaces provided as a part of `bluegraph`.

In [69]:
import numpy as np

from bluegraph.downstream.similarity import NodeSimilarityProcessor

We construct similarity processors for different embeddings and query top 10 most similar nodes to the terms `glucose` and `covid-19`.

In [70]:
node2vec_l2 = NodeSimilarityProcessor(transformed_graph, "node2vec", similarity="euclidean")
node2vec_cosine = NodeSimilarityProcessor(
    transformed_graph, "node2vec", similarity="cosine")

In [71]:
node2vec_l2.get_similar_nodes(["glucose", "covid-19"], k=10)

{'glucose': {'glucose': 0.0,
  'high density lipoprotein': 0.009419716,
  'nonalcoholic fatty liver disease': 0.013079917,
  'metabolic disorder': 0.015094419,
  'insulin': 0.02462136,
  'metabolic syndrome': 0.028628448,
  'hyperglycemia': 0.0542507,
  'organic phosphate': 0.061695296,
  'dopamine': 0.06182158,
  'respiration': 0.062377073},
 'covid-19': {'covid-19': 0.0,
  'systemic inflammatory response syndrome': 0.04765156,
  'person': 0.049512457,
  'coronavirus': 0.052711684,
  'sterile': 0.05350962,
  'middle east respiratory syndrome': 0.07760977,
  'hydroxychloroquine': 0.07878502,
  'gas exchanger device': 0.080798015,
  'severe acute respiratory syndrome': 0.10219382,
  'hypoglycemia': 0.10504973}}

In [72]:
node2vec_cosine.get_similar_nodes(["glucose", "covid-19"], k=10)

{'glucose': {'glucose': 1.0,
  'high density lipoprotein': 0.99822766,
  'nonalcoholic fatty liver disease': 0.9973357,
  'metabolic disorder': 0.996855,
  'metabolic syndrome': 0.995108,
  'insulin': 0.9948307,
  'hyperglycemia': 0.9890145,
  'respiration': 0.9884881,
  'dopamine': 0.98723495,
  'metformin': 0.9871816},
 'covid-19': {'covid-19': 1.0,
  'coronavirus': 0.99175423,
  'hydroxychloroquine': 0.98977274,
  'systemic inflammatory response syndrome': 0.98943216,
  'person': 0.9885227,
  'sterile': 0.98816544,
  'fatal': 0.98755616,
  'middle east respiratory syndrome': 0.98249936,
  'tidal volume': 0.9820141,
  'gas exchanger device': 0.98122466}}

In [73]:
attri2vec_l2 = NodeSimilarityProcessor(transformed_graph, "attri2vec")
attri2vec_cosine = NodeSimilarityProcessor(
    transformed_graph, "attri2vec", similarity="cosine")

In [74]:
attri2vec_l2.get_similar_nodes(["glucose", "covid-19"], k=10)

{'glucose': {'glucose': 0.0,
  'pericyte': 0.010421049,
  'cell': 0.011523034,
  'immunoglobulin': 0.011836685,
  'sodium chloride': 0.012092523,
  'serine protease': 0.012353379,
  'vitamin': 0.01254856,
  'tissue': 0.0129674245,
  'animal': 0.013297284,
  'hemoglobin': 0.013491765},
 'covid-19': {'covid-19': 0.0,
  'autoimmune disease': 0.0005344469,
  'pleural effusion': 0.0005353675,
  'chronic obstructive pulmonary disease': 0.00054345716,
  'osteoporosis': 0.00061592436,
  'vasculitis': 0.0006377298,
  'neuropathy': 0.0006947107,
  'pulmonary edema': 0.0007101719,
  'liver failure': 0.0007402084,
  'anemia': 0.00085831934}}

In [75]:
attri2vec_cosine.get_similar_nodes(["glucose", "covid-19"], k=10)

{'glucose': {'glucose': 1.0,
  'aldosterone': 0.96977174,
  'food': 0.96892774,
  'high sensitivity c-reactive protein measurement': 0.96874994,
  'electrolytes': 0.96874887,
  'cell': 0.96841586,
  'macrolide': 0.9680135,
  'pericyte': 0.96730614,
  'vitamin': 0.9654966,
  'glycoprotein': 0.9653116},
 'covid-19': {'covid-19': 0.99999994,
  'middle east respiratory syndrome': 0.9832929,
  'septicemia': 0.975744,
  'childhood-onset systemic lupus erythematosus': 0.9735457,
  'severe acute respiratory syndrome': 0.97291094,
  'infertility': 0.9713731,
  'cystic fibrosis pulmonary exacerbation': 0.968078,
  'pulmonary': 0.96767455,
  'h1n1 influenza': 0.9663221,
  'allergic rhinitis': 0.96605027}}

In [76]:
gcn_l2 = NodeSimilarityProcessor(transformed_graph, "gcn_dgi")
gcn_cosine = NodeSimilarityProcessor(
    transformed_graph, "gcn_dgi", similarity="cosine")

In [77]:
gcn_l2.get_similar_nodes(["glucose", "covid-19"], k=10)

{'glucose': {'glucose': 0.0,
  'insulin': 0.0033723912,
  'glucose tolerance test': 0.0035072344,
  'triglycerides': 0.0036036542,
  'high density lipoprotein': 0.0037229697,
  'cholesterol': 0.0054044337,
  'uric acid': 0.0056960755,
  'organic phosphate': 0.0057665086,
  'proteinuria': 0.0062844246,
  'low density lipoprotein': 0.006509323},
 'covid-19': {'covid-19': 0.0,
  'coronavirus': 0.0009918846,
  'acute respiratory distress syndrome': 0.0022277941,
  'fatal': 0.0024863742,
  'myocarditis': 0.0035599233,
  'angiotensin ii receptor antagonist': 0.0038421175,
  'cardiac valve injury': 0.0039954726,
  'sars-cov-2': 0.005124502,
  'diabetes mellitus': 0.0051375376,
  'severe acute respiratory syndrome': 0.005171085}}

In [78]:
gcn_cosine.get_similar_nodes(["glucose", "covid-19"], k=10)

{'glucose': {'glucose': 1.0,
  'triglycerides': 0.98296165,
  'insulin': 0.9824796,
  'high density lipoprotein': 0.9789133,
  'cholesterol': 0.97885364,
  'glucose tolerance test': 0.9785327,
  'low density lipoprotein': 0.9738021,
  'plasma': 0.97061884,
  'glomerulus': 0.9683418,
  'calcium': 0.9677339},
 'covid-19': {'covid-19': 1.0,
  'coronavirus': 0.99588025,
  'acute respiratory distress syndrome': 0.9908408,
  'fatal': 0.9901426,
  'angiotensin ii receptor antagonist': 0.9854746,
  'myocarditis': 0.9845337,
  'cardiac valve injury': 0.9828098,
  'cardiovascular disorder': 0.98031616,
  'severe acute respiratory syndrome': 0.9791422,
  'sars-cov-2': 0.97871155}}

### Node classification

Another downstream task that we would like to perform is node classification. We would like to automatically assign entity types according to their node embeddings. For this we will build predictive models for entity type prediction based on:

- Only node features
- Node2vec embeddings (only structure)
- Attri2vec embeddings (structure and node features)
- GCN Deep Graph Infomax embeddings (structure and node features)

In [47]:
from bluegraph.downstream.node_classification import NodeClassifier
from bluegraph.downstream.benchmark import get_classification_scores

from sklearn import model_selection
from sklearn.svm import LinearSVC

First of all, we split the graph nodes into the train and the test sets.

In [79]:
train_nodes, test_nodes = model_selection.train_test_split(
    transformed_graph.nodes(), train_size=0.8)

Now we use the `NodeClassifier` interface to create our classification models. As the base model we will use the linear SVM classifier (`LinearSVC`) provided by `scikit-learn`.

In [87]:
features_classifier = NodeClassifier(LinearSVC(), feature_vector_prop="features")
features_classifier.fit(transformed_graph, train_elements=train_nodes, label_prop="entity_type")
features_pred = features_classifier.predict(transformed_graph, predict_elements=test_nodes)

In [82]:
node2vec_classifier = NodeClassifier(LinearSVC(), feature_vector_prop="node2vec")
node2vec_classifier.fit(transformed_graph, train_elements=train_nodes, label_prop="entity_type")
node2vec_pred = node2vec_classifier.predict(transformed_graph, predict_elements=test_nodes)

In [83]:
attri2vec_classifier = NodeClassifier(LinearSVC(), feature_vector_prop="attri2vec")
attri2vec_classifier.fit(transformed_graph, train_elements=train_nodes, label_prop="entity_type")
attri2vec_pred = attri2vec_classifier.predict(transformed_graph, predict_elements=test_nodes)



In [84]:
gcn_dgi_classifier = NodeClassifier(LinearSVC(), feature_vector_prop="gcn_dgi")
gcn_dgi_classifier.fit(transformed_graph, train_elements=train_nodes, label_prop="entity_type")
gcn_dgi_pred = gcn_dgi_classifier.predict(transformed_graph, predict_elements=test_nodes)

Let us have a look at the scores of different node classification models we have produced.

In [86]:
true_labels = transformed_graph._nodes.loc[test_nodes, "entity_type"]

In [132]:
get_classification_scores(true_labels, features_pred, multiclass=True)

{'accuracy': 0.545,
 'precision': 0.545,
 'recall': 0.545,
 'f1_score': 0.545,
 'roc_auc_score': 0.761836169848989}

In [133]:
get_classification_scores(true_labels, node2vec_pred, multiclass=True)

{'accuracy': 0.375,
 'precision': 0.375,
 'recall': 0.375,
 'f1_score': 0.375,
 'roc_auc_score': 0.6977021844962854}

In [134]:
get_classification_scores(true_labels, attri2vec_pred, multiclass=True)

{'accuracy': 0.46,
 'precision': 0.46,
 'recall': 0.46,
 'f1_score': 0.46,
 'roc_auc_score': 0.728145074309897}

In [135]:
get_classification_scores(true_labels, gcn_dgi_pred, multiclass=True)

{'accuracy': 0.395,
 'precision': 0.395,
 'recall': 0.395,
 'f1_score': 0.395,
 'roc_auc_score': 0.7043840332385899}

## Link prediction

Finally, we would like to use the produced node embeddings to predict the existance of edges. This downstream task is formulated as follows: given a pair of nodes and their embedding vectors, is there an edge between these nodes?

In [53]:
from bluegraph.downstream.link_prediction import (generate_negative_edges,
                                                  EdgePredictor)

As the first step of the edges prediciton task we will generate false edges for training (node pairs that don't have edges between them).

In [54]:
false_edges = generate_negative_edges(transformed_graph)

We will now split both true and false edges into training and test sets.

In [55]:
true_train_edges, true_test_edges = model_selection.train_test_split(
    transformed_graph.edges(), train_size=0.8)

In [56]:
false_train_edges, false_test_edges = model_selection.train_test_split(
    false_edges, train_size=0.8)

And, finally, we will use the `EdgePredictor` interface to build our model (using `LinearSVC` as before and the Hadamard product as the binary operator between the embedding vectors for the source and the target nodes.

In [136]:
model = EdgePredictor(LinearSVC(), feature_vector_prop="node2vec",
                      operator="hadamard", directed=False)
model.fit(transformed_graph, true_train_edges, negative_samples=false_train_edges)

In [137]:
true_labels = np.hstack([
    np.ones(len(true_test_edges)),
    np.zeros(len(false_test_edges))])

In [138]:
y_pred = model.predict(transformed_graph, true_test_edges + false_test_edges)

Let us have a look at the obtained scores.

In [140]:
get_classification_scores(true_labels, y_pred)

TypeError: 'numpy.float64' object is not iterable