In [24]:
import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
from scipy.integrate import odeint
import networkx as nx
from pybel.examples import sialic_acid_graph as sag
import pybel as pb
import json
import time
import csv
from pybel.io.jupyter import to_jupyter
import torch
import pyro
import covid19kg
import pandas as pd

# fix pyro's random seed so the stochastic sampling further down is reproducible
pyro.set_rng_seed(101)

In [5]:
# create generic discrete probability function
class cg_node():
    """A single causal-graph node over the three discrete levels -1, 0, +1.

    The conditional probability table (CPT) P(node | parents) is estimated by
    counting in ``p_init`` and sampled through pyro in ``sample``. Internally a
    level ``v`` is stored at index ``v + 1``.
    """

    # largest number of parents supported (the CPT has 3**(n_inputs + 1) entries)
    MAX_INPUTS = 4

    def __init__(self, n_inputs, name):
        """
        Args:
            n_inputs: number of parent nodes feeding into this node.
            name: node name, also used as the pyro sample-site name.
        """
        self.n_inputs = n_inputs
        self.name = name
        # parentless nodes are exogenous, all others endogenous
        self.label = 'exogenous' if n_inputs == 0 else 'endogenous'
        return

    def p_init(self, input_data, var_data):
        """Estimate the CPT of this node from data by counting.

        Args:
            input_data: 2-D array (n_samples, n_inputs) of parent values in {-1, 0, 1}.
            var_data: 1-D array (n_samples,) of this node's values in {-1, 0, 1}.

        Sets ``prob_dist``, a tensor of shape (3,) + (3,) * n_inputs indexed as
        [value + 1, parent_0 + 1, parent_1 + 1, ...]. It also sets ``n_count``, the
        fraction of samples falling in each parent configuration. A parent
        configuration that never occurs in the data gets a uniform distribution
        rather than the NaN a 0/0 division would produce. Prints an error and
        returns early when n_inputs exceeds MAX_INPUTS.
        """
        import itertools

        self.n_data = len(input_data)
        self.input_data = input_data
        self.var_data = var_data

        if self.n_inputs > self.MAX_INPUTS:
            print('error -- too many inputs')
            return

        # plain int: n_inputs may arrive as a numpy integer, and numpy scalars
        # broadcast instead of repeating when multiplied with a tuple
        n_parents = int(self.n_inputs)
        inputs = np.asarray(input_data)
        values = np.asarray(var_data)

        parent_shape = (3,) * n_parents
        n_count = np.zeros(parent_shape)
        p_ave = np.zeros((3,) + parent_shape)

        # every combination of parent level indices; a single () for exogenous nodes
        for config in itertools.product(range(3), repeat=n_parents):
            in_config = np.ones(self.n_data, dtype=bool)
            for col, level in enumerate(config):
                in_config &= inputs[:, col] == level - 1
            n_count[config] = np.sum(in_config)

            for level in range(3):
                if n_count[config] > 0:
                    hits = np.sum(in_config & (values == level - 1))
                    p_ave[(level,) + config] = hits / n_count[config]
                else:
                    # no data for this parent configuration: fall back to uniform
                    p_ave[(level,) + config] = 1.0 / 3

        n_frac = n_count / self.n_data
        if n_parents == 0:
            n_frac = float(n_frac)  # exogenous nodes keep a plain scalar count
        self.n_count = torch.tensor(n_frac)
        self.prob_dist = torch.tensor(p_ave)

        return

    def sample(self, data_in=()):
        """Draw a value in {-1, 0, 1} for this node through ``pyro.sample``.

        Args:
            data_in: parent values (each in {-1, 0, 1}) in the same order as the
                parent columns given to ``p_init``. Ignored for exogenous nodes.

        Returns:
            The sampled level as a squeezed int tensor, shifted back to {-1, 0, 1}.
        """
        if self.n_inputs > self.MAX_INPUTS:
            print('error -- too many inputs')
            p_temp = []
        else:
            # select the CPT column for the given parent configuration
            index = (slice(None),) + tuple(value + 1 for value in data_in[:int(self.n_inputs)])
            p_temp = torch.squeeze(self.prob_dist[index])

        return torch.squeeze(pyro.sample(self.name, pyro.distributions.Categorical(probs=p_temp)).int() - 1)

In [6]:
class cg_graph():
    """Causal graph of discrete (-1/0/+1) nodes built from BEL-style relations.

    Relations can be given three ways:
      - as 'subject*relation*object' strings (``text``),
      - as short-form statement strings (``str_list``),
      - as a pybel graph (``bel_graph``).
    Every subject/object becomes an entity. Only relations whose name contains
    'crease' (increases, decreases, directlyIncreases, ...) become causal edges.
    Each entity gets a ``cg_node`` with one input per distinct causal parent.
    """

    def __init__(self, str_list=(), bel_graph=(), text=()):
        """Parse the relations and build the node/parent bookkeeping.

        Sets the following attributes:
          - ``entity_list``: entity names in first-seen order.
          - ``edge_list``: [subject, object, relation] for each causal relation.
          - ``n_nodes``: number of entities.
          - ``node_dict``: name -> cg_node.
          - ``parent_ind_list``: sorted parent indices per entity, as numpy arrays.
          - ``parent_name_dict``: name -> list of parent names.

        ``text`` is read whenever it is given. ``bel_graph`` is only read when
        ``str_list`` is empty.
        """
        edge_list = []
        entity_list = []
        entity_index = {}  # name -> position in entity_list, for O(1) lookups

        def add_relation(sub, obj, rel):
            # register both entities in first-seen order
            for entity in (sub, obj):
                if entity not in entity_index:
                    entity_index[entity] = len(entity_list)
                    entity_list.append(entity)
            # only causal relations become edges; hasVariant, partOf, ... are ignored
            if rel.find('crease') > 0:
                edge_list.append([sub, obj, rel])

        # input data type for the CBN JSON export: 'subject*relation*object'
        if text:
            for item in text:
                parts = item.split('*')
                add_relation(parts[0], parts[2], parts[1])

        if str_list:
            for item in str_list:
                sub_ind = item.find('=')
                # NOTE(review): the relation taken here is the two characters starting
                # at '=' (e.g. '=>' or '=|'), which never contains 'crease'. So this
                # path registers entities but never creates edges. TODO confirm the
                # intended statement format and map BEL short forms to relation names.
                add_relation(item[:sub_ind-1], item[sub_ind+3:], item[sub_ind:sub_ind+2])
        elif bel_graph:
            for item in bel_graph.edges:
                edge_data = bel_graph.get_edge_data(item[0], item[1], item[2])
                add_relation(str(item[0]).replace('"', ''),
                             str(item[1]).replace('"', ''),
                             edge_data['relation'])

        n_nodes = len(entity_list)
        self.n_nodes = n_nodes
        self.edge_list = edge_list
        self.entity_list = entity_list

        # distinct parents per node; a sparse replacement for a dense n x n
        # adjacency matrix (3954 nodes -> ~16M entries for the COVID-19 graph)
        parent_sets = [set() for _ in range(n_nodes)]
        for sub, obj, _ in edge_list:
            parent_sets[entity_index[obj]].add(entity_index[sub])
        # sorted index arrays, as np.where over an adjacency-matrix column would give
        self.parent_ind_list = [np.array(sorted(parents), dtype=np.intp) for parents in parent_sets]

        self.node_dict = {name: cg_node(len(self.parent_ind_list[i]), name)
                          for i, name in enumerate(entity_list)}
        self.parent_name_dict = {name: [entity_list[j] for j in self.parent_ind_list[i]]
                                 for i, name in enumerate(entity_list)}

        return

    def prob_init(self, data_in):
        """Fit every node's CPT from data and record the exogenous nodes.

        Args:
            data_in: 2-D array (n_samples, n_nodes) with columns ordered like
                ``entity_list`` and values in {-1, 0, 1}.
        """
        exog_list = []
        prob_dict = {}

        for i, name in enumerate(self.entity_list):
            node = self.node_dict[name]
            node.p_init(data_in[:, self.parent_ind_list[i]], data_in[:, i])

            if node.n_inputs == 0:
                exog_list.append(name)
            prob_dict[name] = node.prob_dist

        self.exog_list = exog_list
        self.prob_dict = prob_dict

        return

    def model_sample(self):
        """Draw one joint sample of all nodes in topological order.

        Returns:
            dict mapping each entity name to its sampled value.

        Raises:
            ValueError: if some nodes can never be sampled because their parents
                form (or depend on) a directed cycle.
        """
        # exogenous nodes first
        sample_dict = {}
        for item in self.exog_list:
            sample_dict[item] = self.node_dict[item].sample()

        while len(sample_dict) < len(self.entity_list):
            n_before = len(sample_dict)

            # sample every unsampled node whose parents have all been sampled
            for item in self.entity_list:
                parents = self.parent_name_dict[item]
                if item not in sample_dict and all(parent in sample_dict for parent in parents):
                    sample_dict[item] = self.node_dict[item].sample(
                        [sample_dict[parent] for parent in parents])

            # a pass without progress would otherwise loop forever
            if len(sample_dict) == n_before:
                missing = sorted(item for item in self.entity_list if item not in sample_dict)
                raise ValueError('cannot sample nodes ' + str(missing)
                                 + ': their parents form a directed cycle or depend on one')

        return sample_dict

    def model_cond_sample(self, data_dict):
        """Sample the model conditioned on observed values (name -> value in {-1, 0, 1})."""
        # sample sites hold category indices, i.e. value + 1
        data_in = {item: data_dict[item] + 1 for item in data_dict}
        cond_model = pyro.condition(self.model_sample, data=data_in)
        return cond_model()

    def model_do_sample(self, do_dict):
        """Sample the model under the intervention do(name = value) for each do_dict entry."""
        data_in = {item: do_dict[item] + 1 for item in do_dict}
        do_model = pyro.do(self.model_sample, data=data_in)
        return do_model()

    def model_do_cond_sample(self, do_dict, data_dict):
        """Sample under interventions ``do_dict`` while conditioning on ``data_dict``.

        The two dicts must not share a node. If they do, a message is printed and
        None is returned.
        """
        if set(do_dict) & set(data_dict):
            print('overlapping lists!')
            return

        do_dict_in = {item: do_dict[item] + 1 for item in do_dict}
        data_dict_in = {item: data_dict[item] + 1 for item in data_dict}

        do_model = pyro.do(self.model_sample, data=do_dict_in)
        cond_model = pyro.condition(do_model, data=data_dict_in)
        return cond_model()

    def model_counterfact(self, obs_dict, do_dict_counter):
        """Counterfactual sample.

        Steps:
          1. Sample the exogenous variables conditioned on the observations ``obs_dict``.
          2. Re-run the model under the interventions ``do_dict_counter``, holding
             those exogenous values fixed.
        """
        cond_dict = self.model_cond_sample(obs_dict)
        cond_dict_temp = {item: cond_dict[item] for item in self.exog_list}
        return self.model_do_cond_sample(do_dict_counter, cond_dict_temp)

    def cond_mut_info(self, target, test, cond, data_in):
        """Empirical conditional mutual information I(target; test | cond), in nats.

        Args:
            target, test, cond: lists of entity names. If ``cond`` is empty, the
                parents of the target nodes are used as the conditioning set.
                The caller's list is not modified.
            data_in: 2-D array (n_samples, n_nodes) with columns ordered like
                ``entity_list``.

        Returns:
            H(target | cond) - H(target | test, cond), estimated from value counts.
        """
        from collections import Counter

        cond_temp = list(cond) if cond else []
        if not cond_temp:
            # default conditioning set: the parents of the target nodes
            for item in target:
                for parent in self.parent_name_dict[item]:
                    if parent not in cond_temp:
                        cond_temp.append(parent)

        target_inds = [self.entity_list.index(item) for item in target]
        test_inds = [self.entity_list.index(item) for item in test]
        cond_inds = [self.entity_list.index(item) for item in cond_temp]

        data_in = np.asarray(data_in)
        n_total = len(data_in)

        def weighted_cond_entropy(joint, cond_part):
            # n_total * H(joint | cond_part): sum over the distinct joint rows of
            # -count(row) * log(count(row) / count(its cond_part row)).
            # cond_part's columns are the trailing columns of joint, so equal
            # joint rows always map to the same cond_part row.
            joint_rows = [tuple(row) for row in joint.tolist()]
            cond_rows = [tuple(row) for row in cond_part.tolist()]
            joint_counts = Counter(joint_rows)
            cond_counts = Counter(cond_rows)

            total = 0
            seen = set()
            for joint_row, cond_row in zip(joint_rows, cond_rows):
                if joint_row not in seen:
                    seen.add(joint_row)
                    num = joint_counts[joint_row]
                    total = total - num*np.log(num/cond_counts[cond_row])
            return total

        null_entropy = weighted_cond_entropy(data_in[:, target_inds + cond_inds],
                                             data_in[:, cond_inds])
        hypth_entropy = weighted_cond_entropy(data_in[:, target_inds + test_inds + cond_inds],
                                              data_in[:, test_inds + cond_inds])

        return (null_entropy - hypth_entropy)/n_total

    def g_test(self, name, data_in):
        """G-test of the model's marginal distribution for one variable against data.

        Draws ``len(data_in)`` samples from the model and compares their
        -1/0/+1 counts with the counts in the data column.

        Args:
            name: list whose first element is the entity name to test.
            data_in: 2-D array (n_samples, n_nodes) with columns ordered like
                ``entity_list``.

        Returns:
            (g_val, p_val): the G statistic (torch scalar) and its chi-squared
            p-value with (number of levels - 1) degrees of freedom.
        """
        from scipy.stats import chi2

        levels = (-1, 0, 1)
        n_samples = len(data_in)

        # empirical model distribution for variable name
        model_data = np.zeros(n_samples)
        for i in range(n_samples):
            model_data[i] = self.model_sample()[name[0]].item()

        p_model = torch.Tensor([np.sum(model_data == level) for level in levels])
        print(p_model)

        name_data = data_in[:, self.entity_list.index(name[0])]
        p_data = torch.Tensor([np.sum(name_data == level) for level in levels])
        print(p_data)

        # 0 * log(0 / expected) contributes 0; computed directly it would be NaN
        terms = torch.where(p_data > 0, p_data*torch.log(p_data/p_model), torch.zeros_like(p_data))
        g_val = 2*torch.sum(terms)

        # goodness of fit over the categories: dof = number of categories - 1
        dof = len(levels) - 1
        p_val = chi2.sf(g_val.item(), dof)

        return g_val, p_val

    def write_to_cf(self, filename, spacing):
        """Write the causal graph to ``filename + '.json'`` for import into CausalFusion.

        Args:
            filename: output path without the '.json' extension.
            spacing: scale factor applied to the layout coordinates.
        """
        index = {name: i for i, name in enumerate(self.entity_list)}

        # layout on entity indices. Uses self.graph if one was attached; otherwise
        # builds a DiGraph from the causal edges.
        layout_graph = getattr(self, 'graph', None)
        if layout_graph is None:
            layout_graph = nx.DiGraph()
            layout_graph.add_nodes_from(range(len(self.entity_list)))
            layout_graph.add_edges_from(sorted({(index[sub], index[obj]) for sub, obj, _ in self.edge_list}))
        try:
            pos_dict = nx.drawing.layout.planar_layout(layout_graph)
        except nx.NetworkXException:
            # planar_layout rejects non-planar graphs; fall back to a seeded spring layout
            pos_dict = nx.drawing.layout.spring_layout(layout_graph, seed=0)

        nodes_out = []
        for i, name in enumerate(self.entity_list):
            node_id = 'node' + str(i)
            nodes_out.append({
                'id': node_id,
                'name': name,
                'label': name,
                'type': 'basic',
                'metadata': {
                    'x': spacing*pos_dict[i][0],
                    'y': spacing*pos_dict[i][1],
                    'label': '',
                    'shape': 'ellipse',
                    'fontSize': 14,
                    'sizeLabelMode': 5,
                    'font': {'size': 14},
                    'size': 14,
                    'labelNodeId': node_id + 'ID',
                    'labelNodeOffset': {'x': 0, 'y': 0},
                    'labelOffset': {'x': 0, 'y': 0},
                    'shadow': {'color': '#00000080', 'size': 0, 'x': 0, 'y': 0},
                },
            })

        edges_out = []
        for sub, obj, _ in self.edge_list:
            edges_out.append({
                'id': 'node' + str(index[sub]) + '->node' + str(index[obj]),
                'from': sub,
                'to': obj,
                'type': 'directed',
                'metadata': {'isLabelDraggable': True, 'label': ''},
            })

        write_dict = {
            'name': 'causal_graph',
            'nodes': nodes_out,
            'edges': edges_out,
            'task': {},
            'metadata': {},
            'project_id': '123456789',
            '_fileType': 'graph',
        }

        with open(filename + '.json', 'w') as json_file:
            json.dump(write_dict, json_file)
        

# Load the data

In [4]:
# read the CBN Epithelial Innate Immune Activation network (a JSON file); `with` closes the handle
with open('Epithelial Innate Immune Activation-2.0-Hs.jgf') as f:
    Epth = json.load(f)

In [56]:
# read the COVID-19 network (a JSON file); `with` closes the handle
with open('COVID19.jgf') as f:
    covid19 = json.load(f)

In [5]:
# inspect the container type returned by json.load
type(Epth)

dict

# Extract the edge types directly from the JSON format

The CBN website also categorizes some other relations as causal, but they are not included in this notebook: association (3), positiveCorrelation (4), negativeCorrelation (1) and biomarkerFor (2). Thus, we only have 215 causal edges.

In [57]:
# types has the structure (type_parent, type_children, type_relation): [statements]

types = {}
nodes = set()
parents = set()
children = set()
str_list = []
for edge in covid19['graph']['edges']:
    parent = edge['source']
    child = edge['target']
    relation = edge['relation']
    # keep only causal relations (increases/decreases and their "directly" variants)
    if relation.find('crease') > 0:
        statement = parent + '*' + relation + '*' + child
        nodes.add(parent)
        nodes.add(child)
        parents.add(parent)
        children.add(child)
        str_list.append(statement)
        # node type = text before the first '(' (the BEL function, e.g. 'p' in 'p(...)')
        parent_type = parent[:parent.find('(')]
        children_type = child[:child.find('(')]
        types.setdefault((parent_type, children_type, relation), []).append(statement)

In [58]:
# number of causal nodes, causal statements and distinct edge types
for collection in (nodes, str_list, types):
    print(len(collection))

72
99
24


We have 72 nodes, 99 causal edges and 24 types of causal edges.

# Use bel2pyro model to extract the types

In [17]:
# Build the causal graph from the 'parent*relation*child' strings extracted from the
# JSON data above (the `text` input path of cg_graph), to compare its edge types.
graph_test = cg_graph(text = str_list)

In [18]:
# number of nodes and causal edges in the graph built from the text statements
for collection in (graph_test.node_dict, graph_test.edge_list):
    print(len(collection))

72
99


In [19]:
types = {}
for edge in graph_test.edge_list:
    # each edge is [subject, object, relation]; a node's type is the text before its first '('
    subject_name, object_name, relation_name = edge
    type_key = (subject_name[:subject_name.find('(')],
                object_name[:object_name.find('(')],
                relation_name)
    types.setdefault(type_key, []).append(edge)

In [20]:
# number of distinct (parent type, child type, relation) combinations
len(types)

24

We get the same result (72 nodes, 99 causal edges, 24 edge types) when using the bel2pyro model.

In [41]:
# new data: the COVID-19 knowledge graph provided by the covid19kg package
graph = covid19kg.get_graph()
# print summary statistics (nodes, edges, citations, ...) of the loaded graph
graph.summarize()

Covid19KG v0.0.1-dev
Number of Nodes: 3954
Number of Edges: 9484
Number of Citations: 185
Number of Authors: 950
Network Density: 6.07E-04
Number of Components: 29


In [38]:
# build the causal graph directly from the pybel graph (the bel_graph input path)
graph_test = cg_graph(bel_graph = graph)

In [39]:
# number of nodes and causal edges in the graph built from the pybel graph
for collection in (graph_test.node_dict, graph_test.edge_list):
    print(len(collection))

3954
2324


In [44]:
# check the class of the object returned by covid19kg.get_graph()
type(graph)

pybel.struct.graph.BELGraph

## Refining the node types and labels of the BEL graph

In [107]:
# Node categories keyed by label; each list holds BEL function prefixes (text before the first '(')
label_dict = {'Abundance':['a', 'r', 'm', 'g', 'p','pop', 'composite', 'complex','frag','fus','loc','pmod','var'],
             'Process': ['bp', 'path','act'],
             'Transformation':['sec','surf','deg','rxn','tloc','fromLoc','products','reactants','toLoc']}


def _node_label(node_type):
    """Map a node type prefix to its label_dict category.

    Returns 'Others' for prefixes not listed in label_dict. That includes the
    empty prefix, which a node string starting with '(' (a nested statement) produces.
    """
    for label, members in label_dict.items():
        if node_type in members:
            return label
    return 'Others'


def get_information(jgf_file):
    """Tabulate the causal edges of a JSON graph loaded with json.load.

    Args:
        jgf_file: dict whose ['graph']['edges'] is a list of edge dicts with
            'source', 'target' and 'relation' keys.

    Returns:
        DataFrame with one row per causal edge (relation name containing 'crease').
        Columns:
          - parents, children: source and target node strings.
          - types: the relation name.
          - statements: 'parent*relation*child'.
          - parent_types, children_types: text before the first '('.
          - parent_labels, children_labels: label_dict category, or 'Others'.
    """
    columns = ['parents', 'children', 'types', 'statements',
               'parent_types', 'parent_labels', 'children_types', 'children_labels']
    records = []

    for edge in jgf_file['graph']['edges']:
        parent = edge['source']
        child = edge['target']
        relation = edge['relation']
        # keep only causal relations (increases/decreases and their "directly" variants)
        if relation.find('crease') <= 0:
            continue

        parent_type = parent[:parent.find('(')]
        children_type = child[:child.find('(')]
        # exactly one label per node keeps every column the same length
        records.append({
            'parents': parent,
            'children': child,
            'types': relation,
            'statements': parent + '*' + relation + '*' + child,
            'parent_types': parent_type,
            'parent_labels': _node_label(parent_type),
            'children_types': children_type,
            'children_labels': _node_label(children_type),
        })

    return pd.DataFrame(records, columns=columns)

In [113]:
df = get_information(covid19)
# drop edges whose child label is 'Others' (e.g. the two nested statements)
df = df.loc[df['children_labels'] != 'Others']

In [121]:
# number of causal edges per (relation, parent label, child label) combination
stats = df.groupby(['types','parent_labels','children_labels']).count()
stats

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,parents,children,statements,parent_types,children_types
types,parent_labels,children_labels,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
decreases,Abundance,Abundance,22,22,22,22,22
decreases,Abundance,Process,5,5,5,5,5
decreases,Abundance,Transformation,1,1,1,1,1
decreases,Process,Abundance,3,3,3,3,3
directlyDecreases,Abundance,Process,2,2,2,2,2
directlyIncreases,Abundance,Process,1,1,1,1,1
increases,Abundance,Abundance,21,21,21,21,21
increases,Abundance,Process,15,15,15,15,15
increases,Abundance,Transformation,2,2,2,2,2
increases,Process,Abundance,22,22,22,22,22


In [None]:
def get_distribution(node_type):
    """Return a prior pyro distribution for a BEL node with function prefix ``node_type``.

    Continuous quantities get LogNormal(0, 1). Binary ones get a two-outcome
    Categorical with equal probabilities.
    """
    # NOTE(review): the children nodes include one category, 'pop', that we have no
    # information about yet.
    # TODO: dispatch on the label_dict categories instead of the hard-coded lists
    # below; label_dict also uses short forms ('rxn', 'deg') that the lists below
    # spell out ('reaction', 'degradation').
    dist = pyro.distributions
    if node_type in ['a', 'r', 'm', 'g', 'p', 'composite', 'complex']:
        ## this is an abundance type node
        ## chose lognormal because we need a +ve continuous distribution here
        return dist.LogNormal(torch.tensor([0.0]), torch.tensor([1.0]))
    if node_type in ['bp', 'path']:
        ## processes have binary distribution. path is a pathology process
        ## ('p' is already caught by the abundance branch above)
        return dist.Categorical(probs=torch.tensor([0.5, 0.5]))
    if node_type in ['act', 'molecularActivity', 'chap', 'pep', 'ribo']:
        ## activity is continuous
        return dist.LogNormal(torch.tensor([0.0]), torch.tensor([1.0]))
    if node_type in ['reaction', 'degradation']:
        ## it should be continuous so starting with lognormal
        return dist.LogNormal(torch.tensor([0.0]), torch.tensor([1.0]))
    if node_type in ['tloc', 'sec', 'surf', 'tscript', 'tport']:
        ## transport category
        return dist.LogNormal(torch.tensor([0.0]), torch.tensor([1.0]))
    if node_type in ['gtp', 'cat', 'kin', 'phos']:
        ## these are binary
        return dist.Categorical(probs=torch.tensor([0.5, 0.5]))