In [1]:
import pyro
import torch
import pyro.distributions as dist
from scm import BoolDomain
import numpy as np
from npsem.where_do import POMISs
from optimiser import CausalOptimiser, Objective, SCM
import networkx as nx
from tqdm.auto import tqdm
import pandas as pd

def model(count):
    """Generative model over age, BMI, two treatments, cancer and outcome ``y``.

    All sites are sampled inside a single ``"samples"`` plate of size
    ``count``. Age and BMI are drawn first; aspirin and statin uptake are
    Bernoulli with logistic probabilities in age and BMI; cancer is Bernoulli
    with a logistic probability in age, BMI and both treatments; ``y`` is
    Normal with a mean linear in every upstream variable and scale 0.4.

    Args:
        count: size passed to the ``"samples"`` plate.

    Returns:
        The value sampled at the ``"y"`` site.
    """
    with pyro.plate("samples", count):
        # Covariates with no parents in the model.
        age = pyro.sample("age", dist.Uniform(55, 76))
        bmi = pyro.sample("bmi", dist.Normal(27, 0.7))

        # Treatment uptake: logistic in age and BMI.
        aspirin_logit = -8.0 + 0.10 * age + 0.03 * bmi
        statin_logit = -13.0 + 0.10 * age + 0.20 * bmi
        aspirin = pyro.sample("aspirin", dist.Bernoulli(torch.sigmoid(aspirin_logit)))
        statin = pyro.sample("statin", dist.Bernoulli(torch.sigmoid(statin_logit)))

        # Cancer depends on the covariates and on both treatments.
        cancer_logit = 2.2 - 0.05 * age + 0.01 * bmi - 0.04 * statin + 0.02 * aspirin
        cancer = pyro.sample("cancer", dist.Bernoulli(torch.sigmoid(cancer_logit)))

        # Outcome: linear mean in all upstream variables, fixed noise scale.
        y_loc = 6.8 + 0.04 * age - 0.15 * bmi - 0.60 * statin + 0.55 * aspirin + 1.00 * cancer
        outcome = pyro.sample("y", dist.Normal(y_loc, 0.4))
    return outcome

In [2]:
# Sweep an (age, bmi) grid. For every grid point, condition the generative
# model on that age/BMI, search each POMIS for its best intervention and keep
# the single best-scoring row together with the grid coordinates.

# Seed the global RNGs so a re-run of this cell starts from the same state.
# NOTE(review): assumes Objective draws its samples through the torch/pyro
# RNG - confirm against the optimiser module.
pyro.set_rng_seed(0)

values = []
for age_value in tqdm(np.arange(55, 76, 0.5)):
    # Convert once per outer iteration. The loop variable is kept separate from
    # the tensor so torch.tensor() is never called on an existing tensor
    # (which emits a copy-construct UserWarning).
    age_data = torch.tensor(age_value)
    for bmi_value in np.arange(22, 31, 0.5):
        bmi_data = torch.tensor(bmi_value)

        # Bind the conditioned model to a new name. Rebinding `model` itself
        # would stack one more condition handler per grid point and lose the
        # unconditioned model, making the cell non-idempotent on re-run.
        conditioned_model = pyro.condition(model, data={'bmi': bmi_data, 'age': age_data})
        scm = SCM(conditioned_model, [BoolDomain('aspirin'), BoolDomain('statin')], non_man={'age', 'bmi', 'cancer'})
        project = scm.induced_projection()
        pomises = POMISs(project, 'y')
        obj = Objective(scm, number_of_samples=5000)

        # The trailing positional arguments (1, 4) are passed through unchanged;
        # their meaning is defined by CausalOptimiser.optimise_for.
        results = {pomis: CausalOptimiser.optimise_for(pomis, obj, 1, 4) for pomis in pomises}

        # Keep only the row with the largest 'optimas' value.
        resultat = CausalOptimiser.parse_results(results).sort_values('optimas', ascending=False).iloc[0].to_dict()

        # Store plain Python floats (not 0-d arrays) so the resulting DataFrame
        # columns are ordinary numeric columns.
        resultat['age'] = float(age_value)
        resultat['bmi'] = float(bmi_value)
        values.append(resultat)

  0%|          | 0/42 [00:00<?, ?it/s]

In [3]:
# Build one table from the per-grid-point records, then count how often each
# distinct 'cont_values' entry was selected as the best one.
df_vals = pd.DataFrame(values)
df_vals['cont_values'].value_counts()

{'statin': False, 'aspirin': True}    756
Name: cont_values, dtype: int64

In [4]:
# Display the results table: one row per (age, bmi) grid point.
df_vals

Unnamed: 0,POMIS,optimas,cont_values,age,bmi
0,"(statin, aspirin)",6.671428,"{'statin': False, 'aspirin': True}",55.0,22.0
1,"(statin, aspirin)",6.599313,"{'statin': False, 'aspirin': True}",55.0,22.5
2,"(statin, aspirin)",6.522711,"{'statin': False, 'aspirin': True}",55.0,23.0
3,"(statin, aspirin)",6.461853,"{'statin': False, 'aspirin': True}",55.0,23.5
4,"(statin, aspirin)",6.385441,"{'statin': False, 'aspirin': True}",55.0,24.0
...,...,...,...,...,...
751,"(statin, aspirin)",6.317378,"{'statin': False, 'aspirin': True}",75.5,28.5
752,"(statin, aspirin)",6.240059,"{'statin': False, 'aspirin': True}",75.5,29.0
753,"(statin, aspirin)",6.163151,"{'statin': False, 'aspirin': True}",75.5,29.5
754,"(statin, aspirin)",6.103097,"{'statin': False, 'aspirin': True}",75.5,30.0
