<a href="https://colab.research.google.com/github/Pallavi-Kolambkar/NEU_project/blob/master/BEL_to_SCM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Package installation

To run the notebook, install the packages listed below:
* pyro http://download.pytorch.org/whl/cpu/torch-0.4.0-cp27-cp27mu-linux_x86_64.whl 
* covid19kg https://github.com/COVID-19-Causal-Reasoning/covid19kg

# Background of the Project

The project aims to create a causal model to study the host response to SARS-CoV-2 infection. A distinctive trait of the virus is the growing abundance of one particular interleukin protein, Interleukin-6 (IL-6). Referring to the paper "COVID-19: a new virus, but an old cytokine release syndrome" by Toshio Hirano and Masaaki Murakami, one possible mechanism to explain the rise in IL-6 is the creation of an IL-6 amplification cytokine storm. This project looks into the two possible targets leading to the creation of the IL-6 amplification cycle and uses intervention and counterfactual queries to support the observations reported in the paper.

 






In [0]:
#import all packages

import numpy as np
import matplotlib.pyplot as plt
import scipy as sp
import json
import time
import csv
import torch
import pyro
import pyro.distributions as dist  # `dist` is used by the SCM functions below
import pandas as pd
import covid19kg
from scipy.stats import norm
from sklearn.linear_model import LinearRegression
from sklearn import metrics
from pyro.infer import Importance, EmpiricalMarginal
from torch.distributions.transforms import AffineTransform
pyro.set_rng_seed(101)

In [0]:
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', None)  # None means no truncation; -1 is deprecated



# Data Cleaning

In [0]:
# load the data from covid19kg in JSON format
# graph = covid19kg.get_graph()
# pb.to_jgif_file(graph, 'covid19kg.jgf')
# f = open('COVID19.jgf')
# covid19kg_dict = json.load(f)

# load the data if using google colab
from google.colab import files 
uploaded = files.upload()
with open('COVID-19-new.json') as f:
    covid19kg_dict=json.load(f)

Saving COVID-19-new.json to COVID-19-new.json


In [0]:
#convert the json data into dataframe

parents=[]
children=[]
edgetype=[]
parent_type=[]
children_type=[]
parent_labels = []
children_labels = []
nodes = set()

label_dict = {'Abundance':['a', 'r', 'm', 'g', 'p','pop', 'composite',
                           'complex','frag','fus','loc','pmod','var'],
             'Process': ['bp', 'path','act'],
             'Transformation':['sec','surf','deg','rxn','tloc','fromLoc',
                               'products','reactants','toLoc']}
             
for eachedge in covid19kg_dict[0]['nanopub']['assertions']:
    parents.append(eachedge['subject'])
    children.append(eachedge['object'])
    edgetype.append(eachedge['relation'])
    nodes.add(eachedge['subject'])
    nodes.add(eachedge['object'])
    p=eachedge['subject']
    c=eachedge['object']
    ptype = p[:p.find('(')]
    ctype = c[:c.find('(')]
    parent_type.append(ptype)
    children_type.append(ctype)
    # assign exactly one coarse label per node; unknown prefixes become 'Others'
    parent_labels.append(next(
        (label for label in label_dict if ptype in label_dict[label]), 'Others'))
    children_labels.append(next(
        (label for label in label_dict if ctype in label_dict[label]), 'Others'))


In [0]:
newcovid_df = pd.DataFrame(list(zip(parents, parent_type, edgetype, children,
                                    children_type, parent_labels, 
                                    children_labels)), columns = ['Parent',
                                                                  'Parent_type',
                                                                  'Relation',
                                                                  'Child',
                                                                  'Child_type',
                                                                  'Parent_label',
                                                                  'Child_label'
                                                                  ])

In [0]:
ailias_dict = {'bp(GO:"tumor necrosis factor-mediated signaling pathway")':
               'TNF',
               'bp(GO:"epidermal growth factor receptor signaling pathway")':
               'EGFR',
               'act(p(HGNC:EGF))':'EGF',
               'act(p(HGNC:ADAM17))': 'ADAM17',
               'bp(GO:"positive regulation of interleukin-6 production")':
               'IL_6_AMP',
               'act(complex(GO:"NF-kappaB complex"))': 'NF_xB',
               'bp(GO:"pattern recognition receptor signaling pathway")': 'PRR',
               'a(TAX:"Severe acute respiratory syndrome coronavirus 2")':
               'SARS_2',
               'act(complex(p(HGNC:IL6),p(HGNC:STAT3)))': 'IL6_STAT3',
               'act(p(HGNC:IL6R))': 'IL_6',
               'act(p(HGNC:AGTR1))': 'ATIR',
               'a(CHEBI:"angiotensin II")': 'AngII',
               'act(p(HGNC:ACE2))': 'ACE2',
               'deg(p(HGNC:ACE2))': 'deg',
               'path(MESH:"Severe Acute Respiratory Syndrome")': 'ARDS',
               'act(complex(GO:"NF-kappaB p50/p65 complex"))': 'NF_xB' }
newcovid_df['Parent_alias'] = newcovid_df['Parent'].apply(
    lambda x : ailias_dict[x])
newcovid_df['Child_alias'] = newcovid_df['Child'].apply(
    lambda x : ailias_dict[x])
nodes = [ailias_dict[i] for i in nodes]

In [0]:
newcovid_df.sort_values('Parent').head()

| | Parent | Parent_type | Relation | Child | Child_type | Parent_label | Child_label | Parent_alias | Child_alias |
|---|---|---|---|---|---|---|---|---|---|
| 13 | a(CHEBI:"angiotensin II") | a | increases | act(p(HGNC:AGTR1)) | act | Abundance | Process | AngII | ATIR |
| 8 | a(TAX:"Severe acute respiratory syndrome coronavirus 2") | a | increases | bp(GO:"pattern recognition receptor signaling pathway") | bp | Abundance | Process | SARS_2 | PRR |
| 16 | a(TAX:"Severe acute respiratory syndrome coronavirus 2") | a | increases | deg(p(HGNC:ACE2)) | deg | Abundance | Transformation | SARS_2 | deg |
| 6 | act(complex(GO:"NF-kappaB complex")) | act | increases | bp(GO:"positive regulation of interleukin-6 production") | bp | Process | Process | NF_xB | IL_6_AMP |
| 9 | act(complex(p(HGNC:IL6),p(HGNC:STAT3))) | act | increases | bp(GO:"positive regulation of interleukin-6 production") | bp | Process | Process | IL6_STAT3 | IL_6_AMP |


# The SCM

We classified all nodes into three types: Abundance, Process and Transformation. Abundance and Transformation nodes are sampled from a log-normal distribution, and Process nodes should be sampled from a Bernoulli distribution. (Details of the classification can be found at https://github.com/belbio/bel_specifications/blob/master/specifications/bel_v2_1_0.yaml)

Based on the node types, we have three basic types of parent-child pairs. For Abundance to Process, we used a logic gate corresponding to the relation type 'increases' or 'decreases'. For Process to Abundance, we first check whether the Process is active, then sample the value from the child's distribution. For Abundance to Transformation, we used a quadratic function to get the node's value. For more complex situations, we combined these three basic cases. (We used an AND gate for multiple Process parents here, which should be validated further before it is applied to other graphs.)
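
The three basic parent-child mechanisms can be sketched without Pyro. This is a minimal illustration under stated assumptions, not the notebook's model: the function names are hypothetical, the 0.5 threshold is borrowed from the helper functions defined in the next cell, and the baseline of 0.0 for an inactive parent process is an assumption made only for this example.

```python
import random

THRESHOLD = 0.5  # assumed cut-off, same value as the notebook's helper functions


def abundance_to_process(parent_value, noise, relation):
    """Logic gate: turn a continuous parent value into a 0/1 process state."""
    above = parent_value + noise > THRESHOLD
    if relation == "increases":
        return 1.0 if above else 0.0
    if relation == "decreases":
        return 0.0 if above else 1.0
    raise ValueError(f"unsupported relation: {relation}")


def process_to_abundance(process_active, rng):
    """Sample the child's abundance only when the parent process is active."""
    if process_active:
        return rng.lognormvariate(0.0, 0.5)  # child's own log-normal distribution
    return 0.0  # assumed baseline for an inactive parent (illustration only)


def abundance_to_transformation(parent_value, noise):
    """Quadratic response: x + x^2 + noise."""
    return parent_value + parent_value ** 2 + noise


toy_rng = random.Random(101)
virus = toy_rng.lognormvariate(0.0, 0.5)
prr = abundance_to_process(virus, toy_rng.gauss(0.0, 0.1), "increases")
deg = abundance_to_transformation(virus, toy_rng.gauss(0.0, 0.1))
print(virus, prr, deg)
```

More complex nodes combine these pieces, for example by summing several gate outputs and thresholding the sum to obtain an AND gate.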

In [0]:


def check_increase(x):
    """
    Description: Helper function for SCM_model(), 
                 to be used with increasing type edges
    Parameters:  Result of parents' equation (x)
    Returns:     1.0 if value is greater than set threshold
                 else 0.0
    """
    threshold = 0.5
    if x > threshold:
        return 1.0
    else:
        return 0.0

def check_decrease(x):
    """
    Description: Helper function for SCM_model(), 
                 to be used with decreasing type edges
    Parameters:  Result of parents' equation (x)
    Returns:     0.0 if value is greater than set threshold
                 else 1.0
    """
    threshold = 0.5
    if x > threshold:
        return 0.0
    else:
        return 1.0


In [0]:
def SCM_model():
    """
    Description: This function builds a Structural Causal Model for the
                 Covid-19-new knowledge graph
    Parameters: None
    Returns: sampled values for all nodes of the graph in tensor format
    """

    threshold = 0.5
    # the threshold should be determined carefully (e.g. 0.5)
    lognormal_dist = dist.LogNormal(torch.tensor(0.0), torch.tensor(0.5))
    lognormal_dist_noise = dist.Gamma(torch.tensor(0.5),torch.tensor(2.))
    normal_dist = dist.Normal(torch.tensor(0.0),torch.tensor(0.1))
    samples = {}

    SARS_N = pyro.sample('SARS_N',lognormal_dist) 
    # SARS_2 = pyro.sample('SARS_2',pyro.distributions.Delta(SARS_N))
    SARS_2 = pyro.sample('SARS_2',dist.Normal(SARS_N, 0.01))
    samples['SARS_2'] = SARS_2

    PRR_N = pyro.sample("PRR_N", normal_dist)
    PRR = torch.tensor(check_increase(PRR_N + SARS_2))
    PRR = pyro.sample('PRR',pyro.distributions.Delta(PRR))
    samples['PRR'] = PRR

    deg_N = pyro.sample("deg_N", normal_dist)
    deg = SARS_2 + SARS_2*SARS_2 + deg_N
    # deg = pyro.sample('deg',pyro.distributions.Delta(deg))
    deg = pyro.sample('deg',pyro.distributions.Normal(deg, 0.01))
    samples['deg'] = deg

    ACE2_N = pyro.sample("ACE2_N", normal_dist)
    ACE2 = torch.tensor(check_decrease(deg + ACE2_N))
    ACE2 = pyro.sample('ACE2',pyro.distributions.Delta(ACE2)) 
    samples['ACE2'] = ACE2 

    AngII_N = pyro.sample('AngII_N',lognormal_dist_noise)
    if ACE2 == 0:
        AngII = pyro.sample('AngII',lognormal_dist)
    else:
        AngII = pyro.sample('AngII',pyro.distributions.Normal(AngII_N, 0.01))
    samples['AngII'] = AngII

    ATIR_N = pyro.sample("ATIR_N", normal_dist)
    ATIR = torch.tensor(check_increase(AngII + ATIR_N))
    ATIR = pyro.sample('ATIR',pyro.distributions.Delta(ATIR))
    samples['ATIR'] = ATIR

    ADAM17_N = pyro.sample("ADAM17_N", normal_dist)
    ADAM17 = torch.tensor(check_increase(ATIR + ADAM17_N))
    ADAM17 = pyro.sample('ADAM17',pyro.distributions.Delta(ADAM17))
    samples['ADAM17'] = ADAM17

    TNF_N = pyro.sample("TNF_N", normal_dist)
    TNF = torch.tensor(check_increase(ADAM17 + TNF_N))
    TNF = pyro.sample('TNF',pyro.distributions.Delta(TNF))
    samples['TNF'] = TNF

    EGF_N = pyro.sample("EGF_N", normal_dist)
    EGF = torch.tensor(check_increase(ADAM17 + EGF_N))
    EGF = pyro.sample('EGF',pyro.distributions.Delta(EGF))
    samples['EGF'] = EGF

    EGFR_N = pyro.sample("EGFR_N", normal_dist)
    EGFR = torch.tensor(check_increase(EGF + EGFR_N))
    EGFR = pyro.sample('EGFR',pyro.distributions.Delta(EGFR))
    samples['EGFR'] = EGFR

    IL_6_N = pyro.sample("IL_6_N", normal_dist)
    IL_6 = torch.tensor(check_increase(ADAM17 + IL_6_N))
    IL_6 = pyro.sample('IL_6',pyro.distributions.Delta(IL_6))
    samples['IL_6'] = IL_6

    IL6_STAT3_N = pyro.sample("IL6_STAT3_N", normal_dist)
    IL6_STAT3 = torch.tensor(check_increase(IL_6 + IL6_STAT3_N))
    IL6_STAT3 = pyro.sample('IL6_STAT3',pyro.distributions.Delta(IL6_STAT3))
    samples['IL6_STAT3'] = IL6_STAT3

    # TNF PRR EGFR (ATRI while there is a edge in the graph) code as AND gate
    NF_xB_N = pyro.sample("NF_xB_N", normal_dist)
    # the threshold should be determined carefully
    if TNF + PRR + EGFR + NF_xB_N > 2.5:
        NF_xB = torch.tensor(1.0)
    else: 
        NF_xB = torch.tensor(0.0)
    NF_xB = pyro.sample('NF_xB',pyro.distributions.Delta(NF_xB))
    samples['NF_xB'] = NF_xB

    # NFXB IL6STAT3 
    IL_6_AMP_N = pyro.sample("IL_6_AMP_N", normal_dist)
    # the threshold should be determined carefully
    if IL6_STAT3 + NF_xB + IL_6_AMP_N > 1.5:
        IL_6_AMP = torch.tensor(1.0)
    else: 
        IL_6_AMP = torch.tensor(0.0)
    IL_6_AMP = pyro.sample('IL_6_AMP',pyro.distributions.Delta(IL_6_AMP))
    samples['IL_6_AMP'] = IL_6_AMP

    ARDS_N = pyro.sample("ARDS_N", normal_dist)
    ARDS = torch.tensor(check_increase(IL_6_AMP + ARDS_N))
    ARDS = pyro.sample('ARDS',pyro.distributions.Delta(ARDS))
    samples['ARDS'] = ARDS

    return samples

In [0]:
pyro.set_rng_seed(101)
SCM_model()

# The Intervention function

In [0]:


def intervention(model, do_variable, do_val, target_variable):
    """
      Description: This is a function to perform intervention
      query for sampling
      Parameters:  Structural Causal Model (model),
               a list of variables to be intervened (do_variable),
               list of values for intervened variables in
               same order (do_val),
               target variable (target_variable)
      Returns:  probability of target variable in given setting
    """
    # get the conditions for the do model
    conditions = {}
    for i in range(len(do_variable)):
        conditions[do_variable[i]] = torch.tensor(do_val[i])
    do_model = pyro.do(model, data=conditions)
    posterior = pyro.infer.Importance(do_model, num_samples=1000).run()
    marginal = EmpiricalMarginal(posterior, target_variable)
    target = [marginal().item() for i in range(1000)]
    return np.mean(target)
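
To make the effect of an intervention concrete, here is a dependency-free toy version of the same idea. It is a sketch under assumptions, not the model above: `toy_scm` and `estimate` are hypothetical names, the three-node chain (virus, ACE2, ARDS) is a drastic simplification of the graph, and plain Monte Carlo sampling stands in for the importance sampling used in `intervention`.

```python
import random


def toy_scm(rng, do=None):
    """Three-node chain virus -> ACE2 -> ARDS; `do` overrides a node's equation."""
    do = do or {}
    values = {}
    values["virus"] = rng.lognormvariate(0.0, 0.5)
    # ACE2 is switched off when the viral load (plus noise) exceeds 0.5
    ace2_noise = rng.gauss(0.0, 0.1)
    values["ACE2"] = 0.0 if values["virus"] + ace2_noise > 0.5 else 1.0
    if "ACE2" in do:
        values["ACE2"] = do["ACE2"]  # intervention: ignore the parent entirely
    # ARDS becomes active when ACE2 is inactive
    ards_noise = rng.gauss(0.0, 0.1)
    values["ARDS"] = 1.0 if (1.0 - values["ACE2"]) + ards_noise > 0.5 else 0.0
    return values


def estimate(target, do=None, num_samples=2000, seed=101):
    """Monte Carlo estimate of P(target = 1 | do(...))."""
    rng = random.Random(seed)
    hits = sum(toy_scm(rng, do)[target] for _ in range(num_samples))
    return hits / num_samples


p_off = estimate("ARDS", do={"ACE2": 0.0})
p_on = estimate("ARDS", do={"ACE2": 1.0})
print(p_off, p_on)
```

The key point is that the intervened node no longer listens to its parents, while everything downstream is recomputed from the forced value.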

We ran a few interventions as a sanity check and for debugging.
We know that inhibition of ACE_2 can contribute to ARDS, which implies that if ACE_2 is inhibited (set to 0), the probability of ARDS should be high.

So, checking P(ARDS | do (ACE_2= 0)) 

In [0]:
intervention(SCM_model,['ACE2'],[0.0],'ARDS')

0.814

What is the Probability of ARDS if ACE_2 is active (set to 1)?

Checking, P (ARDS | do (ACE_2 = 1))

In [0]:
intervention(SCM_model,['ACE2'],[1.0],'ARDS')

Similarly following the second path, if PRR is active (set to 1), the probability of ARDS should be high. 

Checking, P (ARDS | do (PRR = 1))

In [0]:
intervention(SCM_model,['PRR'],[1.0],'ARDS')

What would the probability of ARDS be if PRR is set to low (0)?

P (ARDS | do (PRR = 0))

In [0]:
intervention(SCM_model, ['PRR'], [0.0], 'ARDS')

What would the probability of ARDS be if IL6_STAT3 is set to low (0)?

P (ARDS | do (IL6_STAT3 = 0))

In [0]:
intervention(SCM_model, ['IL6_STAT3'], [0.0], 'ARDS')

# The Counterfactual function

In [0]:


def counterfactual(model, observed_variable, observed_val, exogenous_variable,
                   cf_variable, cf_val, target_variable):
    """
      Description: This is a function to perform counterfactual
      query for sampling
      Parameters:  Structural Causal Model (model),
               a list of observed variables (observed_variable),
               list of values of observed variables in
               same order (observed_val),
               list of exogenous variables (exogenous_variable),
               counterfactual variable (cf_variable),
               supposed value of counterfactual variable (cf_val),
               target variable (target_variable)
      Returns:  probability of target variable in given setting
    """
    target = []
    # get the conditional model
    conditions = {}
    for i in range(len(observed_variable)):
        conditions[observed_variable[i]] = torch.tensor(observed_val[i])
    conditioned_model = pyro.condition(model, data=conditions)
    # get the posterior and marginal distribution
    posterior = pyro.infer.Importance(conditioned_model,
                                      num_samples=1000).run()
    marginal = EmpiricalMarginal(posterior, exogenous_variable)
    # get the do model
    dos = {}
    for i in range(len(cf_variable)):
        dos[cf_variable[i]] = torch.tensor(cf_val[i])
    do_model = pyro.do(model, data=dos)
    # sample target variable under updated posterior and do model
    for _ in range(1000):
        samples = marginal()
        exog = {}
        # use a separate index so the outer loop variable is not shadowed
        for j in range(len(exogenous_variable)):
            exog[exogenous_variable[j]] = samples[j].clone().detach()
        cf_model = pyro.condition(do_model, data=exog)
        trace_handler = pyro.poutine.trace(cf_model)
        trace = trace_handler.get_trace()

        target.append(trace.nodes[target_variable]['value'].numpy().item())

    return np.mean(target)
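
The three steps of a counterfactual query (abduction, action, prediction) are easier to follow on a two-node toy model. The sketch below is an illustration under assumptions: `y_equation` and `toy_counterfactual` are hypothetical names, the noise distributions are made up, and rejection sampling replaces the importance sampling used in `counterfactual`.

```python
import random


def y_equation(x, n_y):
    """Structural equation for Y: a threshold gate on X plus noise."""
    return 1.0 if x + n_y > 0.5 else 0.0


def toy_counterfactual(x_obs, y_obs, x_cf, num_draws=20000, seed=101):
    """Estimate P(Y_{X = x_cf} = 1 | X = x_obs, Y = y_obs) by rejection sampling."""
    rng = random.Random(seed)
    outcomes = []
    for _ in range(num_draws):
        # exogenous noise terms
        n_x = 1.0 if rng.random() < 0.5 else 0.0
        n_y = rng.gauss(0.0, 0.3)
        x = n_x
        y = y_equation(x, n_y)
        # 1. abduction: keep only noise values consistent with the evidence
        if x != x_obs or y != y_obs:
            continue
        # 2. action: replace X's equation with the counterfactual value
        # 3. prediction: recompute Y with the *same* noise n_y
        outcomes.append(y_equation(x_cf, n_y))
    if not outcomes:
        raise ValueError("no sampled noise matched the observed evidence")
    return sum(outcomes) / len(outcomes)


p_cf = toy_counterfactual(x_obs=1.0, y_obs=1.0, x_cf=0.0)
print(p_cf)
```

Reusing the noise inferred from the evidence is what separates a counterfactual from a plain intervention, which would draw fresh noise instead.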

We started our queries with debugging questions.
For example, we know that if SARS_2 is active, ACE_2 will be inhibited, which can lead to ARDS following path 1. So what would the probability of ARDS have been if ACE_2 had been active?

P ($ARDS_{ACE_2 = 1}$= 1 | ACE_2 = 0, SARS_2 = 1, ARDS = 1)

In [0]:
ARDS_ACE_2_1 = counterfactual(SCM_model, ['SARS_2', 'ACE2', 'ARDS'],
                              [1., 0., 1.],
                              ['SARS_N', 'PRR_N', 'deg_N',
                               'ACE2_N', 'AngII_N', 'ATIR_N',
                               'ADAM17_N', 'TNF_N', 'EGF_N',
                               'EGFR_N', 'IL_6_N', 'IL6_STAT3_N',
                               'NF_xB_N', 'IL_6_AMP_N', 'ARDS_N'],
                ['ACE2'], [1.], 'ARDS')
print(ARDS_ACE_2_1)



0.044


As expected it came out to be close to 0.

One of the major pathways for NF-κB activation after coronavirus infection is the MyD88 pathway through pattern recognition receptors (PRRs), leading to the induction of a variety of pro-inflammatory cytokines, including interleukin-6 (IL-6), tumor necrosis factor alpha (TNFα) and chemokines. 

This means that if PRR is set to inactive or 0, ARDS should be close to 0 too.

So, checking P ($ARDS_{PRR = 0}$ = 1 | PRR = 1, NF_xB = 1, ARDS = 1)

In [0]:
ARDS_PRR_0 = counterfactual(SCM_model, ['PRR', 'NF_xB', 'ARDS'],
                              [1., 1., 1.],
                              ['SARS_N', 'PRR_N', 'deg_N',
                               'ACE2_N', 'AngII_N', 'ATIR_N',
                               'ADAM17_N', 'TNF_N', 'EGF_N',
                               'EGFR_N', 'IL_6_N', 'IL6_STAT3_N',
                               'NF_xB_N', 'IL_6_AMP_N', 'ARDS_N'],
                ['PRR'], [0.], 'ARDS')
print(ARDS_PRR_0)



0.0


It is indeed 0. 
This indicates that our Structural Causal Model is able to produce the results we expected following the knowledge graph. 

Now, we will move on and ask questions we are actually interested in.

We would like to know if there is a node in our pathways where we can intervene with some drug and break the path and hence the cytokine storm. The paper suggests that we can intervene on sIL-6R$\alpha$ or IL6-STAT3 safely without facing any side-effects. 

So, we set up our counterfactual query with the knowledge that if SARS_2 is active (set to 1), ACE_2 will be inactive (set to 0) and it can lead to ARDS to be active (set to 1). So what will the probability of ARDS be if soluble IL-6R$\alpha$ is set to low (0)?

So, checking 
P ($ARDS_{sIL-6R\alpha = 0}$= 1 | ACE_2 = 0, SARS_2 = 1, ARDS = 1)

In [0]:
ARDS_sIL6_0 = counterfactual(SCM_model, ['SARS_2', 'ACE2', 'ARDS'],
                              [1., 0., 1.],
                              ['SARS_N', 'PRR_N', 'deg_N',
                               'ACE2_N', 'AngII_N', 'ATIR_N',
                               'ADAM17_N', 'TNF_N', 'EGF_N',
                               'EGFR_N', 'IL_6_N', 'IL6_STAT3_N',
                               'NF_xB_N', 'IL_6_AMP_N', 'ARDS_N'],
                ['IL_6'], [0.], 'ARDS')
print(ARDS_sIL6_0)



0.0


The probability of ARDS is 0 in this setting, so sIL-6R$\alpha$ is worth considering as a target.

Next, given the same setting as above, we check what the probability of ARDS would be if we set IL6-STAT3 to low (0).

P ($ARDS_{IL6-STAT3 = 0}$= 1 | ACE_2 = 0, SARS_2 = 1, ARDS = 1)

In [0]:
ARDS_IL6_STAT3_0 = counterfactual(SCM_model, ['SARS_2', 'ACE2', 'ARDS'],
                              [1., 0., 1.],
                              ['SARS_N', 'PRR_N', 'deg_N',
                               'ACE2_N', 'AngII_N', 'ATIR_N',
                               'ADAM17_N', 'TNF_N', 'EGF_N',
                               'EGFR_N', 'IL_6_N', 'IL6_STAT3_N',
                               'NF_xB_N', 'IL_6_AMP_N', 'ARDS_N'],
                ['IL6_STAT3'], [0.], 'ARDS')
print(ARDS_IL6_STAT3_0)



0.0


In this case too, we observed a 0 probability of ARDS.

# Automatically build the SCM

The model above is constructed for a specific BEL graph. To generalize the process of assigning weights and thresholds to the edges when creating the SCM, we decided to use KL divergence and linear regression. This will help us automate the assignment of weights and thresholds for any given BEL graph. We are still working on automating the step that identifies and modifies the cumulative relationship of the parent nodes with the child node (i.e. defining the node as an AND gate or an OR gate), since this relationship is very specific to the knowledge in the graph.

In [0]:
def get_distribution(node_type):
    '''
    Description: Return the prior distribution for a node based on its type
    Parameters: the node's type ('Abundance', 'Process' or 'Transformation')
    Returns: a Pyro distribution object to sample the node's value from
    '''
    if node_type == 'Abundance':
        return dist.LogNormal(torch.tensor(0.0),torch.tensor(1.0))
    if node_type == 'Process':
        return dist.Categorical(torch.tensor([0.5,0.5]))
    else:
        return dist.LogNormal(torch.tensor(0.0),torch.tensor(1.0))

In [0]:
def SCM_rowwise(parent_label, child_label, relation, threshold, w):
    """
    Description: This function builds a Structural Causal Model
                  for any child-parent cluster
    Parameters: list of parent labels holding same relationship with child
                (parent_label), 
                label of child (child_label), 
                relationship of parents with child (relation), 
                threshold for cut-off (threshold), 
                weights for parents' equations (w) 
    Returns: sampled values for all nodes in tensor format
    """
    parents = []
    process = []
    abundance = []
    transformation = []
    
    for i in range(len(parent_label)):
        parent = pyro.sample("par_%s" % i, get_distribution(parent_label[i]))
        parents.append(parent)
        if parent_label[i] == 'Process':
            process.append(parent)
        if parent_label[i] == 'Abundance':
            abundance.append(parent)
        if parent_label[i] == 'Transformation':
            transformation.append(parent)
    if child_label != 'Transformation':
        if relation == 'increases' or relation == 'directlyIncreases':
            # process is more like an OR switch
            if sum(process) > 0 :
                child_N = sum([x*y for x,y in zip(w ,abundance)]) + \
                pyro.sample("child_n", dist.LogNormal(torch.tensor(0.0),
                                                    torch.tensor(1.0)))
                  # + sum([x*y*y for x,y in zip(w ,transformation)])
            else: 
                child_N = pyro.sample("child_n", dist.LogNormal(
                  torch.tensor(0.0),torch.tensor(1.0))) 
            if child_label == 'Process':
                check = lambda x : 1.0 if x > threshold else 0.0
                child_N = torch.tensor(check(child_N))
            child = pyro.sample('child',pyro.distributions.Delta(child_N))

        if relation == 'decreases' or relation == 'directlyDecreases':
            if sum(process) > 0:
                child_N = pyro.sample("child_n", dist.LogNormal(
                  torch.tensor(0.0),torch.tensor(1.0)))
            else:
                child_N =  sum([x*y for x,y in zip(w ,abundance)]) + \
                pyro.sample("child_n", dist.LogNormal(
                  torch.tensor(0.0),torch.tensor(1.0)))
                  # + sum([x*y*y for x,y in zip(w ,transformation)])
            if child_label == 'Process':
                check = lambda x : 0.0 if x > threshold else 1.0
                child_N = torch.tensor(check(child_N))
            child = pyro.sample('child',pyro.distributions.Delta(child_N))


    else:
        if relation == 'increases' or relation == 'directlyIncreases':
            ## process is more like an OR switch
            if sum(process) > 0 :
                child_N =  sum([x*y for x,y in zip(w ,abundance)]) + \
                pyro.sample("child_n", dist.LogNormal(
                  torch.tensor(0.0),torch.tensor(1.0))) + \
                sum([x*y*y for x,y in zip(w ,transformation)])
            else: 
                child_N = pyro.sample("child_n", dist.LogNormal(
                    torch.tensor(0.0),torch.tensor(1.0))) 
            if child_label == 'Process':
                check = lambda x : 1.0 if x > threshold else 0.0
                child_N = torch.tensor(check(child_N))
            child = pyro.sample('child',pyro.distributions.Delta(child_N))

        if relation == 'decreases' or relation == 'directlyDecreases':
            if sum(process) > 0:
                child_N = pyro.sample("child_n", dist.LogNormal(
                    torch.tensor(0.0),torch.tensor(1.0))) 
            else:
                child_N =  sum([x*y for x,y in zip(w ,abundance)]) + \
                pyro.sample("child_n", dist.LogNormal(
                    torch.tensor(0.0),torch.tensor(1.0))) + \
                    sum([x*y*y for x,y in zip(w ,transformation)])
            if child_label == 'Process':
                check = lambda x : 0.0 if x > threshold else 1.0
                child_N = torch.tensor(check(child_N))
            child = pyro.sample('child',pyro.distributions.Delta(child_N))
    # use the return when doing the KL divergence
#     return child.numpy()
    return{'parents': parents[0],'child': child}

## KL Divergence

Adding a KL Divergence test to check the correctness of distributions we get from SCM against the true distribution of our child nodes. 

In [0]:
# get the model for certain weights and thresholds
SCM_list23 = []
true_children = []
for i in range(newcovid_df.shape[0]):
    test = newcovid_df.iloc[i]
    # SCM_rowwise expects a list of parent labels and one weight per parent
    SCM_list23.append(SCM_rowwise([test['Parent_label']],
                                  test['Child_label'],
                                  test['Relation'], 2,
                                  [3]))
    true_children.append(pyro.sample("child",
                                     get_distribution(
                                         test['Child_label'])))

In [0]:
def KL(P,Q):
    """
    Description: Function to check KL-Divergence 
                 Epsilon is used here to avoid conditional code for
                 checking that neither P nor Q is equal to 0.
    Parameters:  Two probability distributions P and Q
    Returns:     Divergence between the distributions. 0 if they are same
    
    """
    epsilon = 0.00001

    P = P+epsilon
    Q = Q+epsilon

    divergence = np.sum(P*np.log(P/Q))
    return divergence

In [0]:
# get the KL divergence for row 5; SCM_rowwise returns a dict,
# so compare its 'child' entry with the reference sample
print(KL(np.asarray(SCM_list23[5]['child']), np.asarray(true_children[5])))
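
A KL divergence compares two probability distributions, so in practice the samples have to be turned into histograms over shared bins before they are compared. The sketch below shows one way to do that with the standard library only. The names `histogram` and `kl_divergence`, the bin width, and the number of bins are assumptions for this example; the epsilon smoothing mirrors the `KL` function above.

```python
import math
import random


def histogram(samples, bin_width, num_bins):
    """Normalized bin frequencies for non-negative samples."""
    counts = [0] * num_bins
    for s in samples:
        index = min(int(s / bin_width), num_bins - 1)  # last bin collects the tail
        counts[index] += 1
    return [c / len(samples) for c in counts]


def kl_divergence(p, q, epsilon=1e-5):
    """Discrete KL(P || Q) with epsilon smoothing to avoid log(0)."""
    return sum((pi + epsilon) * math.log((pi + epsilon) / (qi + epsilon))
               for pi, qi in zip(p, q))


kl_rng = random.Random(101)
reference = [kl_rng.lognormvariate(0.0, 1.0) for _ in range(5000)]
same_shape = [kl_rng.lognormvariate(0.0, 1.0) for _ in range(5000)]
shifted = [kl_rng.lognormvariate(1.0, 1.0) for _ in range(5000)]

p_hist = histogram(reference, bin_width=0.5, num_bins=20)
kl_same = kl_divergence(p_hist, histogram(same_shape, 0.5, 20))
kl_shifted = kl_divergence(p_hist, histogram(shifted, 0.5, 20))
print(kl_same, kl_shifted)
```

Samples drawn from the same distribution should give a divergence near zero, while the shifted distribution should give a clearly larger value.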

##  Linear regression to get weights from SCM

In [0]:
# get the linear regressor
regressor = LinearRegression()  

In [0]:
# get 100 samples from one row-wised model
samples_2= []
test = newcovid_df.iloc[16]
for i in range(100):
    # parent labels go in a list; the child label is a single string
    samples_2.append(SCM_rowwise([test['Parent_label']], test['Child_label'],
                                 test['Relation'], 2, [2]))
samples_2_df = pd.DataFrame(samples_2)
samples_2_df.head()

In [0]:
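# assumed intent: regress 'child' on 'parents' (the fit step was missing)
X = samples_2_df['parents'].apply(float).to_numpy().reshape(-1, 1)
y = samples_2_df['child'].apply(float).to_numpy()
sample2_fit = regressor.fit(X, y)
#To retrieve the intercept:
print(sample2_fit.intercept_)
#For retrieving the slope:
print(sample2_fit.coef_)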
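
For a single parent, the least-squares fit that `LinearRegression` performs has a closed form, which makes it easy to check that an edge weight can be recovered from samples. The sketch below uses synthetic data only: `fit_line` is a hypothetical helper, and the true weight of 2.0 and the noise level are assumptions chosen for the illustration, not values taken from the graph.

```python
import random


def fit_line(xs, ys):
    """Ordinary least squares for y = intercept + slope * x (closed form)."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    covariance = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    variance = sum((x - mean_x) ** 2 for x in xs)
    slope = covariance / variance
    intercept = mean_y - slope * mean_x
    return intercept, slope


ols_rng = random.Random(101)
true_weight = 2.0  # hypothetical edge weight used to generate the toy data
parent_samples = [ols_rng.lognormvariate(0.0, 1.0) for _ in range(200)]
child_samples = [true_weight * p + ols_rng.gauss(0.0, 0.1) for p in parent_samples]

intercept, slope = fit_line(parent_samples, child_samples)
print(intercept, slope)
```

The recovered slope plays the role of the weight `w` passed to `SCM_rowwise`, and the intercept absorbs any constant offset in the child's noise.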