In [1]:
import pandas as pd

from bluegraph.core import PandasPGFrame
from bluegraph.preprocess.generators import CooccurrenceGenerator
from bluegraph.preprocess.encoders import ScikitLearnPGEncoder

# Data preparation

Fist, we read the source dataset with mentions of entities in different paragraphs

In [2]:
mentions = pd.read_csv("data/labeled_entity_occurrence.csv")

In [3]:
# Extract unique paper/seciton/paragraph identifiers
mentions = mentions.rename(columns={"occurrence": "paragraph"})
number_of_paragraphs = len(mentions["paragraph"].unique())

In [4]:
mentions

Unnamed: 0,entity,paragraph
0,lithostathine-1-alpha,1
1,pulmonary,1
2,host,1
3,lithostathine-1-alpha,2
4,surfactant protein d measurement,2
...,...,...
2281346,covid-19,227822
2281347,covid-19,227822
2281348,viral infection,227823
2281349,lipid,227823


We will also load a dataset that contains definitions of entities and their types

In [5]:
entity_data = pd.read_csv("data/entity_types_defs.csv")

In [6]:
entity_data

Unnamed: 0,entity,entity_type,definition
0,(e3-independent) e2 ubiquitin-conjugating enzyme,PROTEIN,(E3-independent) E2 ubiquitin-conjugating enzy...
1,(h115d)vhl35 peptide,CHEMICAL,A peptide vaccine derived from the von Hippel-...
2,"1,1-dimethylhydrazine",DRUG,"A clear, colorless, flammable, hygroscopic liq..."
3,"1,2-dimethylhydrazine",CHEMICAL,A compound used experimentally to induce tumor...
4,"1,25-dihydroxyvitamin d(3) 24-hydroxylase, mit...",PROTEIN,"1,25-dihydroxyvitamin D(3) 24-hydroxylase, mit..."
...,...,...,...
28127,zygomycosis,DISEASE,Any infection due to a fungus of the Zygomycot...
28128,zygomycota,ORGANISM,A phylum of fungi that are characterized by ve...
28129,zygosity,ORGANISM,"The genetic condition of a zygote, especially ..."
28130,zygote,CELL_COMPARTMENT,"The cell formed by the union of two gametes, e..."


## Generating a co-occurrence graph

We first create a graph whose nodes are entities

In [7]:
graph = PandasPGFrame()
entity_nodes = mentions["entity"].unique()
graph.add_nodes(entity_nodes)
graph.add_node_types({n: "Entity" for n in entity_nodes})

entity_props = entity_data.rename(columns={"entity": "@id"}).set_index("@id")
graph.add_node_properties(entity_props["entity_type"], prop_type="category")
graph.add_node_properties(entity_props["definition"], prop_type="text")

In [8]:
paragraph_prop = pd.DataFrame({"paragraphs": mentions.groupby("entity").aggregate(set)["paragraph"]})
graph.add_node_properties(paragraph_prop, prop_type="category")

For each node we will add the `frequency` property that counts the total number of paragraphs where the entity was mentioned.

In [9]:
frequencies = graph._nodes["paragraphs"].apply(len)
frequencies.name = "frequency"
graph.add_node_properties(frequencies)

In [10]:
graph.nodes(raw_frame=True)

Unnamed: 0_level_0,@type,entity_type,definition,paragraphs,frequency
@id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
lithostathine-1-alpha,Entity,PROTEIN,"Lithostathine-1-alpha (166 aa, ~19 kDa) is enc...","{1, 2, 3, 18178, 195589, 104454, 88967, 104455...",80
pulmonary,Entity,ORGAN,Relating to the lungs as the intended site of ...,"{1, 196612, 196613, 196614, 196621, 196623, 16...",8295
host,Entity,ORGANISM,An organism that nourishes and supports anothe...,"{1, 114689, 3, 221193, 180243, 180247, 28, 180...",2660
surfactant protein d measurement,Entity,PROTEIN,The determination of the amount of surfactant ...,"{145537, 2, 3, 4, 5, 6, 51202, 103939, 103940,...",268
communication response,Entity,PATHWAY,A statement (either spoken or written) that is...,"{46592, 64000, 2, 28162, 166912, 226304, 88585...",160
...,...,...,...,...,...
drug binding site,Entity,PATHWAY,The reactive parts of a macromolecule that dir...,"{225082, 225079}",2
carbaril,Entity,CHEMICAL,A synthetic carbamate acetylcholinesterase inh...,"{225408, 225409, 225415, 225419, 225397}",5
ny-eso-1 positive tumor cells present,Entity,CELL_TYPE,An indication that Cancer/Testis Antigen 1 exp...,"{225544, 226996}",2
mustelidae,Entity,ORGANISM,Taxonomic family which includes the Ferret.,"{225901, 225903}",2


Now, for constructing co-occurrence network we will select only 1000 most frequent entities.

In [11]:
nodes_to_include = graph._nodes.nlargest(2000, "frequency").index

In [12]:
nodes_to_include

Index(['covid-19', 'blood', 'human', 'infectious disorder', 'heart',
       'diabetes mellitus', 'lung', 'sars-cov-2', 'mouse', 'pulmonary',
       ...
       'hepatitis b virus e antigen measurement', 'transthyretin',
       'laser speckle imaging', 'expiration', 'vascular smooth muscle tissue',
       'human acellular dermal matrix', 'natural product', 'organic',
       'embolism', 'hepatosplenomegaly'],
      dtype='object', name='@id', length=2000)

The `CooccurrenceGenerator` class allows us to generate co-occurrence edges from overlaps in node property values or edge (or edge properties). In this case we consider the `paragraph` node property and construct co-occurrence edges from overlapping sets of paragraphs. In addition, we will compute some co-occurrence statistics: total co-occurrence frequency and normalized pointwise mutual information (NPMI).

In [14]:
%%time
generator = CooccurrenceGenerator(graph.subgraph(nodes=nodes_to_include))
paragraph_cooccurrence_edges = generator.generate_from_nodes(
    "paragraphs", total_factor_instances=number_of_paragraphs,
    compute_statistics=["frequency", "npmi"],
    parallelize=True, cores=8)

Examining 1999000 pairs of terms for co-occurrence...
CPU times: user 20.9 s, sys: 6.35 s, total: 27.2 s
Wall time: 4min 45s


We add generated edges to the original graph

In [15]:
graph._edges = paragraph_cooccurrence_edges
graph.edge_prop_as_numeric("frequency")
graph.edge_prop_as_numeric("npmi")

In [16]:
graph.edges(raw_frame=True)

Unnamed: 0_level_0,Unnamed: 1_level_0,common_factors,frequency,npmi
@source_id,@target_id,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
surfactant protein d measurement,communication response,"{2, 7789}",2,0.197223
surfactant protein d measurement,microorganism,"{2, 3, 7810, 58, 41, 7754, 7850, 26218, 7853, ...",19,0.235263
surfactant protein d measurement,organic,{2},1,0.163099
surfactant protein d measurement,lung,"{2, 103939, 51202, 5, 4, 103940, 15, 145438, 3...",93,0.221395
surfactant protein d measurement,alveolar,"{223872, 2, 51202, 100502, 7831, 149657, 19522...",25,0.336175
...,...,...,...,...
cystic fibrosis pulmonary exacerbation,pancreatic insufficiency,"{101484, 144333, 46991, 145747, 46869, 151581}",6,0.290745
cystic fibrosis pulmonary exacerbation,heparin-binding egf-like growth factor,"{145811, 149909, 46517, 150136, 46365, 144255}",6,0.330128
cystic fibrosis pulmonary exacerbation,dornase alfa,"{50948, 47114, 148330, 144113, 153335}",5,0.332547
transmembrane protease serine 2,betacoronavirus,"{188257, 188475, 198870, 204982}",4,0.238142


Recall that we have generated edges only for the 1000 most frequent entities, the rest of the entities will be isolated (having no incident edges). Let us remove all the isolated nodes.

In [17]:
graph.remove_isolated_nodes()

Next, we save the generated co-occurrence graph.

In [18]:
# graph.to_csv("data/graph_nodes.csv", "data/graph_edges.csv",)

In [19]:
# graph = PandasPGFrame.from_csv(
#     "data/graph_nodes.csv", "data/graph_edges.csv",
#     node_property_types={
#         "@type": "category",
#         "entity_type": "category",
#         "definition": "text",
#         "paragraphs": "category",
#         "frequency": "numeric"
#     },
#     edge_property_types={
#         "common_factors": "category",
#         "frequency": "numeritc",
#         "ppmi": "numeric",
#         "npmi": "numeric"
#     })

## Extract node features

We extract node features from entity definitions using the `tfidf` model.

In [20]:
encoder = ScikitLearnPGEncoder(text_encoding_max_dimension=512)

In [21]:
%%time
transformed_graph = encoder.fit_transform(
    graph, node_properties=["definition"], edge_properties=None)

CPU times: user 2.22 s, sys: 117 ms, total: 2.34 s
Wall time: 2.37 s


In [22]:
vocabulary = encoder._node_encoders["definition"].vocabulary_
list(vocabulary.keys())[:10]

['relating',
 'lungs',
 'site',
 'administration',
 'product',
 'usually',
 'action',
 'lower',
 'respiratory',
 'tract']

We will add additional properties to our transformed graph corresponding to the entity type labels. We will also add NPMI as an edge property to this transformed graph.

In [23]:
transformed_graph.add_node_properties(
    graph.get_node_property_values("entity_type"))
transformed_graph.add_edge_properties(
    graph.get_edge_property_values("npmi"), prop_type="numeric")

In [24]:
transformed_graph.nodes(raw_frame=True)

Unnamed: 0_level_0,features,@type,entity_type
@id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
pulmonary,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,ORGAN
host,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,ORGANISM
surfactant protein d measurement,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,PROTEIN
communication response,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,PATHWAY
microorganism,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,ORGANISM
...,...,...,...
cefoxitin,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,CHEMICAL
transmembrane protease serine 2,"[0.0, 0.0, 0.3570916937758588, 0.0, 0.0, 0.0, ...",Entity,PROTEIN
"intraoperative cardiac injury, ctcae","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,DISEASE
betacoronavirus,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,ORGANISM


## Embed using StellarGraph

Using `StellarGraphNodeEmbedder` we construct three different embeddings of our transformed graph.

In [25]:
from bluegraph.backends.stellargraph import StellarGraphNodeEmbedder

In [26]:
embedder = StellarGraphNodeEmbedder(
    "node2vec", edge_weight="npmi", embedding_dimension=128, length=10, number_of_walks=20)
node2vec_embedding = embedder.fit_model(transformed_graph)

In [76]:
embedder = StellarGraphNodeEmbedder(
    "attri2vec", feature_vector_prop="features",
    length=5, number_of_walks=10,
    epochs=10, embedding_dimension=128, edge_weight="npmi")
attri2vec_embedding = embedder.fit_model(transformed_graph)

link_classification: using 'ip' method to combine node embeddings into edge embeddings


In [77]:
embedder = StellarGraphNodeEmbedder(
    "gcn_dgi", feature_vector_prop="features", epochs=150, embedding_dimension=512)
gcn_dgi_embedding = embedder.fit_model(transformed_graph)

Using GCN (local pooling) filters...


In [78]:
transformed_graph.add_node_properties(
    node2vec_embedding.rename(columns={"embedding": "node2vec"}))

In [79]:
transformed_graph.add_node_properties(
    attri2vec_embedding.rename(columns={"embedding": "attri2vec"}))

In [80]:
transformed_graph.add_node_properties(
    gcn_dgi_embedding.rename(columns={"embedding": "gcn_dgi"}))

In [81]:
transformed_graph.nodes(raw_frame=True)

Unnamed: 0_level_0,features,@type,entity_type,node2vec,attri2vec,gcn_dgi
@id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
pulmonary,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,ORGAN,"[0.08507618308067322, 0.08960522711277008, -0....","[2.2304715457721613e-05, 2.30722962442087e-05,...","[0.013598875142633915, 0.014510048553347588, 0..."
host,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,ORGANISM,"[0.06400765478610992, -0.03573942929506302, -0...","[0.00033861398696899414, 0.0005860328674316406...","[0.013612938113510609, 0.01374877616763115, 0...."
surfactant protein d measurement,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,PROTEIN,"[0.06652342528104782, -0.05284678190946579, -0...","[0.00031250715255737305, 0.0001552104949951172...","[0.004781435243785381, 0.011905567720532417, 0..."
communication response,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,PATHWAY,"[0.057319268584251404, -0.055733706802129745, ...","[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, ...","[0.006033382378518581, 0.011774810031056404, 0..."
microorganism,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,ORGANISM,"[0.06486348062753677, 0.0019589976873248816, -...","[0.0027934908866882324, 0.004592776298522949, ...","[0.014491668902337551, 0.012006293050944805, 0..."
...,...,...,...,...,...,...
cefoxitin,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,CHEMICAL,"[0.07369915395975113, 0.04778401181101799, -0....","[9.508418588666245e-05, 4.140875898883678e-05,...","[0.0046445345506072044, 0.005839885678142309, ..."
transmembrane protease serine 2,"[0.0, 0.0, 0.3570916937758588, 0.0, 0.0, 0.0, ...",Entity,PROTEIN,"[0.059876542538404465, -0.05344674736261368, -...","[0.001462697982788086, 0.0003876984119415283, ...","[0.006346818991005421, 0.013288427144289017, 0..."
"intraoperative cardiac injury, ctcae","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,DISEASE,"[0.07498086988925934, 0.03932490199804306, -0....","[0.0029019415378570557, 0.0014293193817138672,...","[0.00040178094059228897, 0.009683115407824516,..."
betacoronavirus,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",Entity,ORGANISM,"[0.06702052056789398, -0.004926824942231178, -...","[0.0006445348262786865, 0.0002875030040740967,...","[0.002051224932074547, 0.0083466786891222, 0.0..."


## Similar entities

In [82]:
import numpy as np

from bluegraph.downstream.similarity import NodeSimilarityProcessor

In [83]:
node2vec_l2 = NodeSimilarityProcessor(transformed_graph, "node2vec")
node2vec_cosine = NodeSimilarityProcessor(
    transformed_graph, "node2vec", similarity="cosine")

In [84]:
node2vec_l2.get_similar_nodes(["glucose", "covid-19"], k=10)

{'glucose': {'glucose': 0.0,
  'high density lipoprotein': 0.0022658955,
  'diet': 0.0024289526,
  'serum ldl cholesterol measurement': 0.00281263,
  'angiotensin converting enzyme measurement': 0.002945846,
  'renin': 0.0029646917,
  'hypothalamus': 0.0030440816,
  'metformin': 0.0030828633,
  "parkinson's disease": 0.0031146465,
  'nadh dehydrogenase (ubiquinone)': 0.003123281},
 'covid-19': {'covid-19': 0.0,
  'coronavirus': 0.0029676454,
  'n-terminal fragment brain natriuretic protein': 0.0033119144,
  'bias': 0.0039801677,
  'antithrombin antigen measurement': 0.004200006,
  'brain death': 0.004505181,
  'genus': 0.0045362785,
  'coronary vasospasm': 0.004553576,
  'measles': 0.004568702,
  'campylobacter': 0.004868541}}

In [85]:
node2vec_cosine.get_similar_nodes(["glucose", "covid-19"], k=10)

{'glucose': {'glucose': 1.0,
  'high density lipoprotein': 0.99952364,
  'diet': 0.999498,
  'thiazolidinedione antidiabetic agent': 0.99947554,
  'serum ldl cholesterol measurement': 0.9994713,
  'renin': 0.9994428,
  'dopamine hydrochloride': 0.99943995,
  'cholesterol to hdl-cholesterol ratio measurement': 0.9994321,
  'microcirculatory bed': 0.9994309,
  'nonalcoholic steatohepatitis': 0.9994266},
 'covid-19': {'covid-19': 1.0,
  'gas exchanger device': 0.99951154,
  'brain natriuretic peptide measurement': 0.9995111,
  'cardiac valve injury': 0.9995052,
  'tidal volume': 0.99949956,
  'lopinavir/ritonavir': 0.9994949,
  'coronavirus': 0.99948764,
  'diabetes mellitus': 0.9994817,
  'cholecystokinin': 0.999477,
  'antithrombin antigen measurement': 0.99947095}}

In [86]:
attri2vec_l2 = NodeSimilarityProcessor(transformed_graph, "attri2vec")
attri2vec_cosine = NodeSimilarityProcessor(
    transformed_graph, "attri2vec", similarity="cosine")

In [87]:
attri2vec_l2.get_similar_nodes(["glucose", "covid-19"], k=10)

{'glucose': {'glucose': 0.0,
  'hyperplasia': 2.4351445e-07,
  'eo_disease_maps_to_human_disease': 2.4647102e-07,
  'citrulline': 2.487164e-07,
  'vasopressor': 2.4942256e-07,
  'ribonucleic acid': 2.836718e-07,
  'blood circulation': 3.014543e-07,
  'thrombosis': 3.2471385e-07,
  'pharmacologic substance': 3.367524e-07,
  'angiotensin ii receptor antagonist': 3.435445e-07},
 'covid-19': {'covid-19': 0.0,
  'organ': 8.51084e-09,
  'pollutant': 1.1228083e-08,
  'hypothalamus': 1.2533386e-08,
  'blood': 1.3337877e-08,
  'disease or disorder': 1.5125666e-08,
  'systemic inflammatory response syndrome': 1.5514285e-08,
  'mupirocin': 1.790136e-08,
  'chicken pox': 1.8030148e-08,
  'severe acute respiratory syndrome': 2.4289317e-08}}

In [88]:
attri2vec_cosine.get_similar_nodes(["glucose", "covid-19"], k=10)

{'glucose': {'glucose': 1.0,
  'dofetilide': 0.9906268,
  'plaque': 0.9895834,
  'smooth muscle cell': 0.9882462,
  'blood product': 0.98774266,
  'leridistim': 0.9874199,
  'upper respiratory infection, ctcae': 0.98728484,
  'genetic disorder': 0.98699605,
  'dornase alfa': 0.9868901,
  'polyunsaturated fatty acid': 0.9867853},
 'covid-19': {'covid-19': 1.0,
  'arginine': 0.99533504,
  'pioglitazone': 0.99434143,
  'encephalitis': 0.99408466,
  'transcription factor': 0.99359196,
  'opioid': 0.99355924,
  'pulmonary edema': 0.99340224,
  'lower respiratory tract infection': 0.99314576,
  'hypogammaglobulinemia': 0.9927063,
  'cyclophosphamide': 0.99262846}}

In [89]:
gcn_l2 = NodeSimilarityProcessor(transformed_graph, "gcn_dgi")
gcn_cosine = NodeSimilarityProcessor(
    transformed_graph, "gcn_dgi", similarity="cosine")

In [90]:
gcn_l2.get_similar_nodes(["glucose", "covid-19"], k=10)

{'glucose': {'glucose': 0.0,
  'diabetes mellitus': 1.2274763e-05,
  'plasma': 1.5114634e-05,
  'kidney': 1.7331444e-05,
  'survival': 2.1568154e-05,
  'inflammation': 2.3674924e-05,
  'cancer': 2.6763799e-05,
  'organ': 2.6941272e-05,
  'person': 2.7500597e-05,
  'tissue': 2.8594814e-05},
 'covid-19': {'covid-19': 0.0,
  'obesity': 7.531098e-05,
  'insulin': 0.00011627325,
  'cardiovascular disorder': 0.00013621274,
  'nuclear': 0.00017118,
  'fibrosis': 0.00018238342,
  'dysfunction': 0.00019133318,
  'calcium': 0.00020952553,
  'sars-cov-2': 0.00021393181,
  'interleukin-19': 0.0002337117}}

In [91]:
gcn_cosine.get_similar_nodes(["glucose", "covid-19"], k=10)

{'glucose': {'glucose': 1.0,
  'diabetes mellitus': 0.9997525,
  'plasma': 0.999693,
  'kidney': 0.999645,
  'survival': 0.99956703,
  'inflammation': 0.9995329,
  'cancer': 0.9994532,
  'organ': 0.99944824,
  'tissue': 0.99944377,
  'person': 0.9994389},
 'covid-19': {'covid-19': 1.0000001,
  'obesity': 0.99840355,
  'insulin': 0.99757445,
  'cardiovascular disorder': 0.9970801,
  'nuclear': 0.9966514,
  'fibrosis': 0.99604404,
  'dysfunction': 0.9958848,
  'calcium': 0.9954989,
  'sars-cov-2': 0.995387,
  'interleukin-19': 0.9948972}}

## Node classification

We will build a predictive model for entity type prediction based on:

- Only node features
- Node2vec embeddings (only structure)
- Attri2vec embeddings (structure and node features)
- GCN Deep Graph Infomax embeddings (structure and node features)

### Splitting the graph into train/test set

In [92]:
from sklearn import model_selection
from sklearn.svm import LinearSVC

In [93]:
transformed_graph._nodes["entity_type"]

@id
pulmonary                                  ORGAN
host                                    ORGANISM
surfactant protein d measurement         PROTEIN
communication response                   PATHWAY
microorganism                           ORGANISM
                                          ...   
cefoxitin                               CHEMICAL
transmembrane protease serine 2          PROTEIN
intraoperative cardiac injury, ctcae     DISEASE
betacoronavirus                         ORGANISM
favipiravir                                 DRUG
Name: entity_type, Length: 2000, dtype: object

In [123]:
class NodeClassifier(object):

    def __init__(self, model, feature_vector_prop=None, feature_props=None):
        self.model = model
        self.feature_vector_prop = feature_vector_prop
        self.feature_props = feature_props

    def _concatenate_feature_props(self, pgframe, nodes):
        if self.feature_props is None or len(self.feature_props) == 0:
            raise ValueError
        return pgframe.nodes(
            raw_frame=True).loc[nodes, self.feature_props].to_numpy()

    def fit(self, pgframe, train_nodes=None, labels=None, label_prop=None):
        # If no train nodes provided, use all nodes of the input graph
        if train_nodes is None:
            train_nodes = pgframe.nodes()
            
        # If no labels provided, try to use a label property
        if labels is None:
            if label_prop not in pgframe.node_properties():
                raise ValueError()
            labels = pgframe.get_node_property_values(
                label_prop).loc[train_nodes].tolist()
        
        # If no feature vector property provided, try to concatenate feature_props
        if self.feature_vector_prop:
            data = pgframe.get_node_property_values(
                self.feature_vector_prop).loc[train_nodes].tolist()
        else:
            data = self._concatenate_feature_props(pgframe, train_nodes)
        self.model.fit(data, labels)

    def predict(self, pgframe, predict_nodes=None):
        # If no prediction nodes provided, use all nodes of the input graph
        if predict_nodes is None:
            predict_nodes = pgframe.nodes()
        if self.feature_vector_prop:
            data = pgframe.get_node_property_values(
                self.feature_vector_prop).loc[predict_nodes].tolist()
        else:
            data = self._concatenate_feature_props(pgframe, predict_nodes)
        return self.model.predict(data)

Split the graph into 

In [124]:
train_nodes, test_nodes = model_selection.train_test_split(
    transformed_graph.nodes(), train_size=0.4)

In [132]:
features_classifier = NodeClassifier(LinearSVC(), feature_vector_prop="features")
features_classifier.fit(transformed_graph, train_nodes=train_nodes, label_prop="entity_type")
pred_y = features_classifier.predict(transformed_graph, predict_nodes=test_nodes)
accuracy = (transformed_graph._nodes.loc[test_nodes, "entity_type"] == pred_y).mean()
print(accuracy)

0.5666666666666667


In [133]:
node2vec_classifier = NodeClassifier(LinearSVC(), feature_vector_prop="node2vec")
node2vec_classifier.fit(transformed_graph, train_nodes=train_nodes, label_prop="entity_type")
pred_y = node2vec_classifier.predict(transformed_graph, predict_nodes=test_nodes)
accuracy = (transformed_graph._nodes.loc[test_nodes, "entity_type"] == pred_y).mean()
print(accuracy)

0.4


In [134]:
attri2vec_classifier = NodeClassifier(LinearSVC(), feature_vector_prop="attri2vec")
attri2vec_classifier.fit(transformed_graph, train_nodes=train_nodes, label_prop="entity_type")
pred_y = attri2vec_classifier.predict(transformed_graph, predict_nodes=test_nodes)
accuracy = (transformed_graph._nodes.loc[test_nodes, "entity_type"] == pred_y).mean()
print(accuracy)

0.3458333333333333




In [135]:
gcn_dgi_classifier = NodeClassifier(LinearSVC(), feature_vector_prop="gcn_dgi")
gcn_dgi_classifier.fit(transformed_graph, train_nodes=train_nodes, label_prop="entity_type")
pred_y = gcn_dgi_classifier.predict(transformed_graph, predict_nodes=test_nodes)
accuracy = (transformed_graph._nodes.loc[test_nodes, "entity_type"] == pred_y).mean()
print(accuracy)

0.3883333333333333


## Split nodes into train/test set

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
def graph_train_test_split(pgframe, test_size=0.3, random_state=None):
    def edges_to_include(df):
        index = df.index
        selectors = list(
            map(lambda e: e[0] in index and e[1] in index, pgframe._edges.index))
        return pgframe._edges.index[selectors]
    
    train_nodes, test_nodes = train_test_split(pgframe._nodes, test_size=test_size)
    train_edges = edges_to_include(train_nodes)
    test_edges = edges_to_include(test_nodes)
    train_pgframe = PandasPGFrame.from_frames(
        train_nodes,
        pgframe._edges.loc[train_edges],
        node_prop_types=pgframe._node_prop_types,
        edge_prop_types=pgframe._edge_prop_types)
    test_pgframe = PandasPGFrame.from_frames(
        test_nodes,
        pgframe._edges.loc[test_edges],
        node_prop_types=pgframe._node_prop_types,
        edge_prop_types=pgframe._edge_prop_types)
    return train_pgframe, test_pgframe

In [None]:
train_graph, test_graph = graph_train_test_split(transformed_frame, test_size=0.2)

In [None]:
from blugraph.downstream.node_classification import NodeClassifier

## Link prediction

In [None]:
from bluegraph.backends.stellargraph import pgframe_to_stellargraph

In [None]:
stellar_object = pgframe_to_stellargraph(
    transformed_frame, feature_vector_prop="features", directed=False, edge_weight="npmi")

In [None]:
print(stellar_object.info())

In [None]:
train_targets = train_graph.get_node_property_values("entity_type")
test_targets = test_graph.get_node_property_values("entity_type")

In [None]:
train_targets.value_counts()

In [None]:
from sklearn import preprocessing

In [None]:
target_encoding = preprocessing.LabelBinarizer()

encoded_train_targets = target_encoding.fit_transform(train_targets)
encoded_test_targets = target_encoding.transform(test_targets)

In [105]:
from stellargraph.mapper import FullBatchNodeGenerator
from stellargraph.layer import GCN

from tensorflow.keras import layers, optimizers, losses, metrics, Model

In [106]:
generator = FullBatchNodeGenerator(stellar_object, method="gcn", weighted=True)

Using GCN (local pooling) filters...


In [107]:
train_gen = generator.flow(train_targets.index, encoded_train_targets)

In [132]:
gcn = GCN(layer_sizes=[16, 16], activations=["relu", "relu"], generator=generator, dropout=0.5)

In [124]:
x_inp, x_out = gcn.in_out_tensors()
x_out

<KerasTensor: shape=(1, None, 30) dtype=float32 (created by layer 'gather_indices_5')>

In [125]:
predictions = layers.Dense(units=encoded_train_targets.shape[1], activation="softmax")(x_out)

In [126]:
model = Model(inputs=x_inp, outputs=predictions)
model.compile(
    optimizer=optimizers.Adam(lr=0.01),
    loss=losses.categorical_crossentropy,
    metrics=["acc"],
)

In [127]:
model.fit(
    train_gen,
    epochs=400,
    verbose=2,
    shuffle=False,  # this should be False, since shuffling data means shuffling the whole graph
)

Epoch 1/400
1/1 - 1s - loss: 2.2013 - acc: 0.1167
Epoch 2/400
1/1 - 0s - loss: 2.1604 - acc: 0.2711
Epoch 3/400
1/1 - 0s - loss: 2.1143 - acc: 0.2698
Epoch 4/400
1/1 - 0s - loss: 2.0560 - acc: 0.2699
Epoch 5/400
1/1 - 0s - loss: 1.9881 - acc: 0.2699
Epoch 6/400
1/1 - 0s - loss: 1.9196 - acc: 0.2699
Epoch 7/400
1/1 - 0s - loss: 1.8688 - acc: 0.2699
Epoch 8/400
1/1 - 0s - loss: 1.8307 - acc: 0.2700
Epoch 9/400
1/1 - 0s - loss: 1.7918 - acc: 0.2743
Epoch 10/400
1/1 - 0s - loss: 1.7480 - acc: 0.3178
Epoch 11/400
1/1 - 0s - loss: 1.6999 - acc: 0.3978
Epoch 12/400
1/1 - 0s - loss: 1.6589 - acc: 0.4680
Epoch 13/400
1/1 - 0s - loss: 1.6288 - acc: 0.4852
Epoch 14/400
1/1 - 0s - loss: 1.5991 - acc: 0.4861
Epoch 15/400
1/1 - 0s - loss: 1.5726 - acc: 0.4884
Epoch 16/400
1/1 - 0s - loss: 1.5400 - acc: 0.4965
Epoch 17/400
1/1 - 0s - loss: 1.5057 - acc: 0.5094
Epoch 18/400
1/1 - 0s - loss: 1.4759 - acc: 0.5257
Epoch 19/400
1/1 - 0s - loss: 1.4429 - acc: 0.5369
Epoch 20/400
1/1 - 0s - loss: 1.4268 - a

Epoch 161/400
1/1 - 0s - loss: 1.0207 - acc: 0.6437
Epoch 162/400
1/1 - 0s - loss: 1.0101 - acc: 0.6442
Epoch 163/400
1/1 - 0s - loss: 1.0140 - acc: 0.6462
Epoch 164/400
1/1 - 0s - loss: 1.0169 - acc: 0.6467
Epoch 165/400
1/1 - 0s - loss: 1.0069 - acc: 0.6525
Epoch 166/400
1/1 - 0s - loss: 1.0056 - acc: 0.6492
Epoch 167/400
1/1 - 0s - loss: 1.0212 - acc: 0.6483
Epoch 168/400
1/1 - 0s - loss: 1.0162 - acc: 0.6464
Epoch 169/400
1/1 - 0s - loss: 1.0046 - acc: 0.6498
Epoch 170/400
1/1 - 0s - loss: 1.0162 - acc: 0.6467
Epoch 171/400
1/1 - 0s - loss: 1.0062 - acc: 0.6464
Epoch 172/400
1/1 - 0s - loss: 1.0137 - acc: 0.6472
Epoch 173/400
1/1 - 0s - loss: 1.0029 - acc: 0.6493
Epoch 174/400
1/1 - 0s - loss: 1.0051 - acc: 0.6502
Epoch 175/400
1/1 - 0s - loss: 1.0071 - acc: 0.6517
Epoch 176/400
1/1 - 0s - loss: 1.0096 - acc: 0.6469
Epoch 177/400
1/1 - 0s - loss: 1.0050 - acc: 0.6505
Epoch 178/400
1/1 - 0s - loss: 1.0131 - acc: 0.6447
Epoch 179/400
1/1 - 0s - loss: 0.9980 - acc: 0.6501
Epoch 180/40

Epoch 319/400
1/1 - 0s - loss: 0.9769 - acc: 0.6529
Epoch 320/400
1/1 - 0s - loss: 0.9652 - acc: 0.6611
Epoch 321/400
1/1 - 0s - loss: 0.9573 - acc: 0.6612
Epoch 322/400
1/1 - 0s - loss: 0.9621 - acc: 0.6600
Epoch 323/400
1/1 - 0s - loss: 0.9532 - acc: 0.6649
Epoch 324/400
1/1 - 0s - loss: 0.9596 - acc: 0.6613
Epoch 325/400
1/1 - 0s - loss: 0.9550 - acc: 0.6657
Epoch 326/400
1/1 - 0s - loss: 0.9532 - acc: 0.6636
Epoch 327/400
1/1 - 0s - loss: 0.9532 - acc: 0.6588
Epoch 328/400
1/1 - 0s - loss: 0.9578 - acc: 0.6585
Epoch 329/400
1/1 - 0s - loss: 0.9601 - acc: 0.6617
Epoch 330/400
1/1 - 0s - loss: 0.9561 - acc: 0.6649
Epoch 331/400
1/1 - 0s - loss: 0.9614 - acc: 0.6601
Epoch 332/400
1/1 - 0s - loss: 0.9529 - acc: 0.6662
Epoch 333/400
1/1 - 0s - loss: 0.9542 - acc: 0.6635
Epoch 334/400
1/1 - 0s - loss: 0.9651 - acc: 0.6582
Epoch 335/400
1/1 - 0s - loss: 0.9631 - acc: 0.6578
Epoch 336/400
1/1 - 0s - loss: 0.9570 - acc: 0.6667
Epoch 337/400
1/1 - 0s - loss: 0.9552 - acc: 0.6633
Epoch 338/40

<tensorflow.python.keras.callbacks.History at 0x7fa5fe01e320>

In [130]:
test_gen = generator.flow(list(test_targets.index), encoded_test_targets)

In [131]:
test_metrics = model.evaluate(test_gen)
print("\nTest Set Metrics:")
for name, val in zip(model.metrics_names, test_metrics):
    print("\t{}: {:0.4f}".format(name, val))


Test Set Metrics:
	loss: 1.4122
	acc: 0.5856


In [134]:
all_gen = generator.flow(transformed_frame.nodes())
all_predictions = model.predict(all_gen)

In [135]:
embedding_model = Model(inputs=x_inp, outputs=x_out)

In [136]:
emb = embedding_model.predict(all_gen)
emb.shape

(1, 17989, 30)

In [140]:
emb[0].shape

(17989, 30)

# prepare entity data

In [329]:
neighbors, dist = node2vec.get_similar_points(existing_indices=["glucose", "covid-19"])

In [330]:
neighbors

[Index(['glucose', 'nonalcoholic steatohepatitis', 'insulin', 'adenosine',
        'metabolic disorder', 'testis', 'stress', 'alanine', 'plasma',
        'glyburide'],
       dtype='object'),
 Index(['covid-19', 'multiple organ failure', 'platelet', 'coronavirus',
        'acute respiratory distress syndrome', 'thrombolytic agent',
        'cardiovascular complication', 'severe acute respiratory syndrome',
        'cardiovascular system', 'dysfunction'],
       dtype='object')]

In [331]:
dist

array([[1.0000001 , 0.9999429 , 0.9999413 , 0.9999406 , 0.99993944,
        0.99993914, 0.99993783, 0.99993575, 0.99993414, 0.9999339 ],
       [1.        , 0.9999411 , 0.999937  , 0.9999359 , 0.99993193,
        0.99993044, 0.9999237 , 0.999921  , 0.99991894, 0.9999187 ]],
      dtype=float32)

In [332]:
node2vec = FaissSimilarityProcessor(
    dimension=64,
    similarity="dot",
    initial_vectors=np.array(node2vec_embedding["embedding"].tolist()),
    initial_index=node2vec_embedding.index,)

In [333]:
neighbors, dist = node2vec.get_similar_points(existing_indices=["glucose", "covid-19"])

In [334]:
neighbors

[Index(['intestine', 'leukocyte', 'bone marrow', 'skin necrosis', 'man',
        'skin rash', 'accumulation', 'growth factor', 'proliferation', 'colon'],
       dtype='object'),
 Index(['intestine', 'leukocyte', 'bone marrow', 'skin necrosis', 'man',
        'skin rash', 'accumulation', 'growth factor', 'proliferation', 'colon'],
       dtype='object')]

In [12]:
a = [5, 0]
b = [0, 5]
c = [3, 2]
d = [5, 0.5]

In [13]:
import numpy as np

In [15]:

# a, d, c, b
# b, c, d, a

In [16]:
# a, d, c, b

In [10]:
def cosine(a, b):
    return np.inner(np.array(a)/np.linalg.norm(np.array(a)), np.array(b)/np.linalg.norm(np.array(b)))

In [26]:
cosine(b, a)

0.0

In [27]:
import networkx as nx

In [28]:
nx.from_pandas_edgelist?

In [1]:
import faiss
import numpy as np

In [2]:
# a = np.array([5, 0])
# b = np.array([0, 5])
# c = np.array([3, 2])
# d = np.array([5, 0.5])
# vectors = [a, a*2, b, b*2, c, c*2, d, d*2]
vectors = np.random.rand(10000, 2)
vectors = np.array(vectors).astype(np.float32)

In [4]:
q = faiss.IndexFlatL2(2)
index = faiss.IndexIVFFlat(q, 2, 100)

In [5]:
index.train(vectors)

In [6]:
index.make_direct_map()

In [7]:
index.add(vectors)

In [8]:
index.reconstruct(1)

array([0.48872378, 0.42502713], dtype=float32)