In [1]:
import model
from data_utils.loader import Loader
from model.deconfounder import Deconfounder
from pyro.optim import Adam

import torch
import pyro
from pyro.distributions import Normal, Bernoulli
import numpy as np
from scipy import stats, sparse
from numpy import random as npr

In [2]:
data_path = "data/ohe_movies.csv"
loader = Loader(data_path)

In [3]:
X, y = torch.tensor(loader.X), torch.tensor(loader.y)
print(X.shape, y.shape)

torch.Size([3181, 129]) torch.Size([3181])


In [4]:
num_datapoints, data_dim = X.shape

holdout_portion = 0.1
n_holdout = int(holdout_portion * num_datapoints * data_dim)

holdout_row = np.random.randint(num_datapoints, size=n_holdout)
holdout_col = np.random.randint(data_dim, size=n_holdout)
holdout_mask = (sparse.coo_matrix((np.ones(n_holdout), \
                            (holdout_row, holdout_col)), \
                            shape = X.shape)).toarray()

holdout_subjects = np.unique(holdout_row)
holdout_mask = np.minimum(1, holdout_mask)

x_train = np.multiply(1-holdout_mask, X)
x_val = np.multiply(holdout_mask, X)

In [5]:
# linear conf
step1_opt = Adam({"lr": 0.0005})
step2_opt = Adam({"lr": 0.005})
# seed def = 3493204
deconfounder = Deconfounder(step1_opt, step2_opt, 
                            seed=5323,
                            step1_iters=500, step2_iters=500)

In [6]:
step1_params, step2_params = deconfounder.train(X, y, mask=torch.Tensor(1-holdout_mask))


 Training Z marginal and W parameter marginal...
[iteration 0001] loss: 378.1893
[iteration 0101] loss: 357.4209
[iteration 0201] loss: 336.0054
[iteration 0301] loss: 325.3372
[iteration 0401] loss: 312.4947
Updating value of hypermeterqz_mean
Updating value of hypermeterqz_stddv
Updating value of hypermeterqw_mean
Updating value of hypermeterqw_stddv
Training Bayesian regression parameters...
[iteration 0001] loss: 484.5499




[iteration 0101] loss: 385.2076
[iteration 0201] loss: 362.6943
[iteration 0301] loss: 334.9950
[iteration 0401] loss: 328.3191
Updating value of hypermeter: w_loc
Updating value of hypermeter: w_scale
Updating value of hypermeter: b_loc
Updating value of hypermeter: b_scale
Updating value of hypermeter: sigma_loc
Updating value of hypermeter: sigma_scale
Training complete.


In [7]:
step1_params['z_mean0'].shape

torch.Size([3181, 50])

In [8]:
step1_params['z_mean0'].shape

torch.Size([3181, 50])

In [9]:
n_rep = 100 # number of replicated datasets we generate
holdout_gen = np.zeros((n_rep,*(x_train.shape)))

for i in range(n_rep):
    w_sample = pyro.sample('w', Normal(step1_params['w_mean0'], step1_params['w_std0']))
    z_sample = pyro.sample('z', Normal(step1_params['z_mean0'], step1_params['z_std0']))
    linear_exp = torch.matmul(z_sample, w_sample)
    x_generated = pyro.sample("x", Bernoulli(logits = linear_exp))

    # look only at the heldout entries
    holdout_gen[i] = np.multiply(x_generated, holdout_mask)

In [10]:
from tqdm import tqdm

n_eval = 100 # we draw samples from the inferred Z and W
obs_ll = []
rep_ll = []
for j in tqdm(range(n_eval)):
    w_sample = pyro.sample('w', Normal(step1_params['w_mean0'], step1_params['w_std0']))
    z_sample = pyro.sample('z', Normal(step1_params['z_mean0'], step1_params['z_std0']))
    linear_exp = torch.matmul(z_sample, w_sample)
    x_generated = np.multiply(pyro.sample("x", Bernoulli(logits = linear_exp)), holdout_mask)
    obs_ll.append(np.mean(stats.norm(x_generated).logpdf(x_val), axis=1))
    rep_ll.append(np.mean(stats.norm(x_generated).logpdf(holdout_gen), axis=2))
    

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99


In [11]:
obs_ll_per_zi, rep_ll_per_zi = np.mean(np.array(obs_ll), axis=0), np.mean(np.array(rep_ll), axis=0)

pvals = np.array([np.mean(rep_ll_per_zi[:,i] < obs_ll_per_zi[i]) for i in range(num_datapoints)])
holdout_subjects = np.unique(holdout_row)
overall_pval = np.mean(pvals[holdout_subjects])
print("Predictive check p-values", overall_pval)

Predictive check p-values 0.49686262181703866
