In [34]:
import os
from functools import partial
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pyro
import pyro.distributions as dist
from torch import nn
import seaborn as sns
from pyro.nn import PyroModule
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

In [35]:
#step-1 converting model to SCM 
def GroundTruthModel(data):
    alpha_1 = torch.tensor(0.75)
    beta_1 = torch.tensor(0.25)
    Nr = pyro.sample("Nr",dist.Categorical(torch.tensor([alpha_1,beta_1]))) 
    
    alpha_2 = torch.tensor(0.6)
    beta_2 = torch.tensor(0.4)
    Ns = pyro.sample("Ns",dist.Categorical(torch.tensor([alpha_2,beta_2])))
    
    alpha_3 = torch.tensor(0.)
    beta_3 = torch.tensor(1.)
    Nk = pyro.sample("Nk",dist.Normal(torch.tensor(alpha_3),torch.tensor(beta_3))) 
    
    Nr = Nr.type(torch.FloatTensor)
    Ns = Ns.type(torch.FloatTensor)
    Nk = Nk.type(torch.FloatTensor)
    
    gamma_1 =  pyro.sample("gamma_1",dist.Delta(Nr))
    gamma_2 =  pyro.sample("gamma_2",dist.Delta(Ns))
    gamma_3 =  pyro.sample("gamma_3",dist.Delta(Nk))

    
    R     = pyro.sample("R",dist.Delta(gamma_1))
    S     = pyro.sample("S",dist.Delta(gamma_2))
    K     = pyro.sample("K",dist.Delta(gamma_3))
    
    
    delta_1 = torch.tensor(0.)
    delta_2 = torch.tensor(1.)
    
    Gval  = K + 2.1* R + 3.3 * S + 0.5 * pyro.sample("g",dist.Normal(delta_1,delta_2))
    kappa_1 = torch.tensor(Gval)
    G     = pyro.sample("G",dist.Delta(kappa_1))
    
    Lval  = K + 5.8 * R + 0.7 * S + 0.1 * pyro.sample("l",dist.Normal(delta_1,delta_2))
    kappa_2 = torch.tensor(Lval)
    L     = pyro.sample("L",dist.Delta(kappa_2))
    
    Fval  = K + 2.3 * R + 1.0 * S + 0.3 * pyro.sample("f",dist.Normal(delta_1,delta_2))
    kappa_3 = torch.tensor(Fval)
    F     = pyro.sample("F",dist.Delta(kappa_3))
trace_handler = pyro.poutine.trace(GroundTruthModel)
samples = pd.DataFrame(columns=['R', 'S', 'K', 'G', 'L', 'F', 'p'])

In [36]:
#step-2 generate synthetic data from GTM
full_sample= []
for i in range(1000):
    trace = trace_handler.get_trace(1)
    R = trace.nodes['R']['value']
    S = trace.nodes['S']['value']
    K = trace.nodes['K']['value']
    G = trace.nodes['G']['value']
    L = trace.nodes['L']['value']
    F = trace.nodes['F']['value']
    # get prob of each combination
    log_prob = trace.log_prob_sum()
    p = np.exp(log_prob)
    samples = samples.append({'R': R, 'S': S, 'K': K, 'G': G, 'L':L, 'F': F, 'p': p}, ignore_index=True)
    full_sample.append(([R,S,G,L,K,F]))

samples.head()

  del sys.path[0]


Unnamed: 0,R,S,K,G,L,F,p
0,tensor(0.),tensor(1.),tensor(-0.7498),tensor(3.4070),tensor(-0.1890),tensor(-0.2442),tensor(0.0001)
1,tensor(0.),tensor(1.),tensor(-1.3202),tensor(1.4906),tensor(-0.6467),tensor(0.0407),tensor(0.0009)
2,tensor(0.),tensor(0.),tensor(0.0959),tensor(0.4512),tensor(0.1727),tensor(0.2072),tensor(0.0061)
3,tensor(0.),tensor(0.),tensor(1.2153),tensor(2.1946),tensor(1.2148),tensor(1.2494),tensor(0.0008)
4,tensor(1.),tensor(0.),tensor(-0.3527),tensor(1.5024),tensor(5.4682),tensor(1.8401),tensor(0.0029)


In [37]:
#step-3 GUIDE
def ProposedModel(data):
   
    alpha_1 = pyro.param('alpha_1', torch.tensor(0.75)) #, constraint=constraints.positive
    alpha_2 = pyro.param('alpha_2', torch.tensor(0.6))#
    alpha_3 = pyro.param('alpha_3', torch.tensor(0.))#
    
    beta_1 = pyro.param('beta_1', torch.tensor(0.25)) #
    beta_2 = pyro.param('beta_2', torch.tensor(0.4)) #
    beta_3 = pyro.param('beta_3', torch.tensor(1.)) #
    
    Nr = pyro.sample('Nr', dist.Categorical(torch.tensor([alpha_1,beta_1])))
    Ns = pyro.sample('Ns', dist.Categorical(torch.tensor([alpha_2,beta_2])))
    Nk = pyro.sample('Nk', dist.Normal(torch.tensor(alpha_3),torch.tensor(beta_3)))
    
    Nr = Nr.type(torch.FloatTensor)
    Ns = Ns.type(torch.FloatTensor)
    Nk = Nk.type(torch.FloatTensor)
    
    gamma_1 = pyro.sample('gamma_1',dist.Delta(Nr))
    gamma_2 = pyro.sample('gamma_2',dist.Delta(Nr))
    gamma_3 = pyro.sample('gamma_3',dist.Delta(Nr))
    
    
    R     = pyro.sample("R",dist.Delta(gamma_1))
    S     = pyro.sample("S",dist.Delta(gamma_2))
    K     = pyro.sample("K",dist.Delta(gamma_3))
    
    delta_1 = pyro.param('delta_1', torch.tensor(0.))#
    delta_2 = pyro.param('delta_2', torch.tensor(1.))#
    
    Gval  = K + alpha_1* R + beta_1 * S + gamma_1 * pyro.sample("g",dist.Normal(torch.tensor(delta_1),torch.tensor(delta_2)))
    
    G     = pyro.sample("G",dist.Delta(Gval))
    
    Lval  = K + alpha_2 * R + beta_2 * S + gamma_2 * pyro.sample("l",dist.Normal(torch.tensor(delta_1),torch.tensor(delta_2)))
    L     = pyro.sample("L",dist.Delta(Lval))
    
    Fval  = K + alpha_3 * R + beta_3 * S + gamma_3 * pyro.sample("f",dist.Normal(torch.tensor(delta_1),torch.tensor(delta_2)))
    F     = pyro.sample("F",dist.Delta(Lval))
trace_handler = pyro.poutine.trace(ProposedModel)
samples = pd.DataFrame(columns=['R', 'S', 'K', 'G', 'L', 'F', 'p'])

In [38]:
#Step 4 - train the ProposedModel on SD 
adam_params = {"lr": 0.0005}
optimizer = Adam(adam_params)
svi = SVI(GroundTruthModel, ProposedModel, optimizer, loss=Trace_ELBO())
n_steps = 2501
for step in range(n_steps):
    svi.step(full_sample)

  
  del sys.path[0]


In [39]:
#Printing samples from the PM model.
#Later we need to use this model to do step 5 through 8
from pyro.infer import Predictive


num_samples = 1000
predictive = Predictive(GroundTruthModel, guide=ProposedModel, num_samples=1000)
svi_samples = {k: v.reshape(num_samples).detach().cpu().numpy()
               for k, v in predictive(1).items()
               if k != "obs"}
print(svi_samples)

  
  del sys.path[0]


{'Nr': array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1,
       0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0,
       0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1,
       0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0,
       1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0,
       0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0,
       1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0,
       0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0,
       0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0,
       0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
       0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0