In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

# IHDP: RieszNet

## Library Imports

In [1]:
from pathlib import Path
import os
import glob
from joblib import dump, load
import pandas as pd
import scipy
import scipy.stats
import scipy.special
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from utils.riesznet import RieszNet
from utils.moments import ate_moment_fn
from utils.ihdp_data import *

## Moment Definition

In [3]:
moment_fn = ate_moment_fn

## MAE Experiment

In [4]:
data_base_dir = "./data/IHDP/sim_data"
simulation_files = sorted(glob.glob("{}/*.csv".format(data_base_dir)))

### Estimator Settings

In [5]:
drop_prob = 0.0  # dropout prob of dropout layers throughout notebook
n_hidden = 100  # width of hidden layers throughout notebook

# Training params
learner_lr = 1e-5
learner_l2 = 1e-3
learner_l1 = 0.0
n_epochs = 600
earlystop_rounds = 40 # how many epochs to wait for an out-of-sample improvement
earlystop_delta = 1e-4
target_reg = 1.0
riesz_weight = 0.1

bs = 64
device = torch.cuda.current_device() if torch.cuda.is_available() else None
print("GPU:", torch.cuda.is_available())

from itertools import chain, combinations
from itertools import combinations_with_replacement as combinations_w_r

def _combinations(n_features, degree, interaction_only):
        comb = (combinations if interaction_only else combinations_w_r)
        return chain.from_iterable(comb(range(n_features), i)
                                   for i in range(0, degree + 1))

class Learner(nn.Module):

    def __init__(self, n_t, n_hidden, p, degree, interaction_only=False):
        super().__init__()
        n_common = 200
        self.monomials = list(_combinations(n_t, degree, interaction_only))
        self.common = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_t, n_common), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU())
        self.riesz_nn = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, 1))
        self.riesz_poly = nn.Sequential(nn.Linear(len(self.monomials), 1))
        self.reg_nn0 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, 1))
        self.reg_nn1 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, 1))
        self.reg_poly = nn.Sequential(nn.Linear(len(self.monomials), 1))


    def forward(self, x):
        poly = torch.cat([torch.prod(x[:, t], dim=1, keepdim=True)
                          for t in self.monomials], dim=1)
        feats = self.common(x)
        riesz = self.riesz_nn(feats) + self.riesz_poly(poly)
        reg = self.reg_nn0(feats) * (1 - x[:, [0]]) + self.reg_nn1(feats) * x[:, [0]] + self.reg_poly(poly)
        return torch.cat([reg, riesz], dim=1)

GPU: False


In [6]:
nsims = 1000
np.random.seed(123)
sim_ids = np.random.choice(len(simulation_files), nsims, replace=False)
methods = ['dr', 'direct', 'ips']
srr = {'dr' : True, 'direct' : False, 'ips' : True}

true_ATEs = []
results = []

for it, sim in enumerate(sim_ids):
    simulation_file = simulation_files[sim]
    x = load_and_format_covariates(simulation_file, delimiter=' ')
    t, y, y_cf, mu_0, mu_1 = load_other_stuff(simulation_file, delimiter=' ')
    X = np.c_[t, x]
    true_ATE = np.mean(mu_1 - mu_0)
    true_ATEs.append(true_ATE)

    y_scaler = StandardScaler(with_mean=True).fit(y)
    y = y_scaler.transform(y)
    
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2)

    torch.cuda.empty_cache()
    learner = Learner(X_train.shape[1], n_hidden, drop_prob, 0, interaction_only=True)
    agmm = RieszNet(learner, moment_fn)
    # Fast training
    agmm.fit(X_train, y_train, Xval=X_test, yval=y_test,
             earlystop_rounds=2, earlystop_delta=earlystop_delta,
             learner_lr=1e-4, learner_l2=learner_l2, learner_l1=learner_l1,
             n_epochs=100, bs=bs, target_reg=target_reg,
             riesz_weight=riesz_weight, optimizer='adam',
             model_dir=str(Path.home()), device=device, verbose=0)
    # Fine tune
    agmm.fit(X_train, y_train, Xval=X_test, yval=y_test,
             earlystop_rounds=earlystop_rounds, earlystop_delta=earlystop_delta,
             learner_lr=learner_lr, learner_l2=learner_l2, learner_l1=learner_l1,
             n_epochs=600, bs=bs, target_reg=target_reg,
             riesz_weight=riesz_weight, optimizer='adam', warm_start=True,
             model_dir=str(Path.home()), device=device, verbose=0)
    
    params = tuple(x * y_scaler.scale_[0] for method in methods
                   for x in agmm.predict_avg_moment(X, y,  model='earlystop', method = method, srr = srr[method])) + (true_ATE, )
                        
    results.append(params)

res = tuple(np.array(x) for x in zip(*results))
truth = res[-1:]
res_dict = {}
for it, method in enumerate(methods):
    point, lb, ub = res[it * 3: (it + 1)*3]
    res_dict[method] = {'point': point, 'lb': lb, 'ub': ub,
                        'MAE': np.mean(np.abs(point - truth)),
                        'std. err.': np.std(np.abs(point - truth)) / np.sqrt(nsims),
                        }
    print("{} : MAE = {:.3f} +/- {:.3f}".format(method, res_dict[method]['MAE'], res_dict[method]['std. err.']))

  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,


  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,


  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,


  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,
  return torch.load(os.path.join(self.model_dir,


  return torch.load(os.path.join(self.model_dir,

KeyboardInterrupt



In [7]:
res = tuple(np.array(x) for x in zip(*results))
truth = res[-1:]
res_dict = {}
for it, method in enumerate(methods):
    point, lb, ub = res[it * 3: (it + 1)*3]
    res_dict[method] = {'point': point, 'lb': lb, 'ub': ub,
                        'MAE': np.mean(np.abs(point - truth)),
                        'std. err.': np.std(np.abs(point - truth)) / np.sqrt(nsims),
                        }
    print("{} : MAE = {:.3f} +/- {:.3f}".format(method, res_dict[method]['MAE'], res_dict[method]['std. err.']))

dr : MAE = 0.164 +/- 0.006
direct : MAE = 0.159 +/- 0.003
ips : MAE = 0.797 +/- 0.038


In [14]:
split_data = np.array_split(data, k)
for i, subset in enumerate(split_data):
    print(i)

[autoreload of sklearn.utils.class_weight failed: Traceback (most recent call last):
  File "/opt/anaconda3/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 276, in check
    superreload(m, reload, self.old_objects)
  File "/opt/anaconda3/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 475, in superreload
    module = reload(module)
             ^^^^^^^^^^^^^^
  File "/opt/anaconda3/lib/python3.11/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 621, in _exec
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/opt/anaconda3/lib/python3.11/site-packages/sklearn/utils/class_weight.py", line 13, in <module>
    @validate_params(
     ^^^^^^^^^^^^^^^^
TypeError: validate_params() got an unexpected keyword argument 'prefer_skip_nested_validation'
]
[autoreload of threa

[autoreload of sklearn.preprocessing._encoders failed: Traceback (most recent call last):
  File "/opt/anaconda3/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 276, in check
    superreload(m, reload, self.old_objects)
  File "/opt/anaconda3/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 475, in superreload
    module = reload(module)
             ^^^^^^^^^^^^^^
  File "/opt/anaconda3/lib/python3.11/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 621, in _exec
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/opt/anaconda3/lib/python3.11/site-packages/sklearn/preprocessing/_encoders.py", line 12, in <module>
    from ..base import BaseEstimator, OneToOneFeatureMixin, TransformerMixin, _fit_context
ImportError: cannot import name '_fit_context' from 'sklearn.

[autoreload of sklearn.metrics._plot.roc_curve failed: Traceback (most recent call last):
  File "/opt/anaconda3/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 276, in check
    superreload(m, reload, self.old_objects)
  File "/opt/anaconda3/lib/python3.11/site-packages/IPython/extensions/autoreload.py", line 475, in superreload
    module = reload(module)
             ^^^^^^^^^^^^^^
  File "/opt/anaconda3/lib/python3.11/importlib/__init__.py", line 169, in reload
    _bootstrap._exec(spec, module)
  File "<frozen importlib._bootstrap>", line 621, in _exec
  File "<frozen importlib._bootstrap_external>", line 940, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/opt/anaconda3/lib/python3.11/site-packages/sklearn/metrics/_plot/roc_curve.py", line 1, in <module>
    from ...utils._plotting import _BinaryClassifierCurveDisplayMixin
  File "/opt/anaconda3/lib/python3.11/site-packages/sklearn/utils/_plotting.py", li

NameError: name 'data' is not defined

In [None]:
path = './results/IHDP/RieszNet/MAE'

if not os.path.exists(path):
    os.makedirs(path)
            
dump(res_dict, path + '/IHDP_MAE_NN.joblib')

### Table

In [None]:
path = './results/IHDP/RieszNet/MAE'

if not os.path.exists(path):
    os.makedirs(path)
    
methods_str = ["DR", "Direct", "IPS"] 

with open(path + '/IHDP_MAE_NN.tex', "w") as f:
    f.write("\\begin{tabular}{lc} \n" +
            "\\toprule \n" +
            "& MAE $\\pm$ std. err. \\\\ \n" +
            "\\midrule \n" +
            "\\multicolumn{2}{l}{\\textbf{Auto-DML:}} \\\\ \n")
    
    for i, method in enumerate(methods):
        f.write(" & ".join([methods_str[i], "{:.3f} $\\pm$ {:.3f}".format(res_dict[method]['MAE'], 
                                                                          res_dict[method]['std. err.'])]) + " \\\\ \n")

    f.write("\\multicolumn{2}{l}{\\textbf{Benchmark:}} \\\\"
            + "\n Dragonnet & 0.146 & 0.010 \\\\ \n \\bottomrule \n \\end{tabular}")

## Coverage Experiment

In [None]:
data_base_dir = "./data/IHDP/sim_data_redraw_T"
simulation_files = sorted(glob.glob("{}/*.csv".format(data_base_dir)))

In [None]:
def rmse_fn(y_pred, y_true):
    return np.sqrt(np.mean((y_pred - y_true)**2))

nsims = 100
np.random.seed(123)
sim_ids = np.random.choice(len(simulation_files), nsims, replace=False)
methods = ['dr', 'direct', 'ips']
srr = {'dr' : True, 'direct' : False, 'ips' : True}

true_ATEs = []
results = []

for it, sim in enumerate(sim_ids):
    simulation_file = simulation_files[sim]
    x = load_and_format_covariates(simulation_file, delimiter=' ')
    t, y, y_cf, mu_0, mu_1 = load_other_stuff(simulation_file, delimiter=' ')
    X = np.c_[t, x]
    true_ATE = np.mean(mu_1 - mu_0)
    true_ATEs.append(true_ATE)

    y_scaler = StandardScaler(with_mean=True).fit(y)
    y = y_scaler.transform(y)
    
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2)

    torch.cuda.empty_cache()
    learner = Learner(X_train.shape[1], n_hidden, drop_prob, 0, interaction_only=True)
    agmm = RieszNet(learner, moment_fn)
    # Fast training
    agmm.fit(X_train, y_train, Xval=X_test, yval=y_test,
             earlystop_rounds=2, earlystop_delta=earlystop_delta,
             learner_lr=1e-4, learner_l2=learner_l2, learner_l1=learner_l1,
             n_epochs=100, bs=bs, target_reg=target_reg,
             riesz_weight=riesz_weight, optimizer='adam',
             model_dir=str(Path.home()), device=device, verbose=0)
    # Fine tune
    agmm.fit(X_train, y_train, Xval=X_test, yval=y_test,
             earlystop_rounds=earlystop_rounds, earlystop_delta=earlystop_delta,
             learner_lr=learner_lr, learner_l2=learner_l2, learner_l1=learner_l1,
             n_epochs=600, bs=bs, target_reg=target_reg,
             riesz_weight=riesz_weight, optimizer='adam', warm_start=True,
             model_dir=str(Path.home()), device=device, verbose=0)
    
    params = tuple(x * y_scaler.scale_[0] for method in methods
                   for x in agmm.predict_avg_moment(X, y,  model='earlystop', method = method, srr = srr[method])) + (true_ATE, )
                        
    results.append(params)
                        
res = tuple(np.array(x) for x in zip(*results))
truth = res[-1:]
res_dict = {}
for it, method in enumerate(methods):
    point, lb, ub = res[it * 3: (it + 1)*3]
    res_dict[method] = {'point': point, 'lb': lb, 'ub': ub,
                        'cov': np.mean(np.logical_and(truth >= lb, truth <= ub)),
                        'bias': np.mean(point - truth),
                        'rmse': rmse_fn(point, truth)
                        }
    print("{} : bias = {:.3f}, rmse = {:.3f}, cov = {:.3f}".format(method, res_dict[method]['bias'], res_dict[method]['rmse'], res_dict[method]['cov']))

In [None]:
path = './results/IHDP/RieszNet/coverage'

if not os.path.exists(path):
    os.makedirs(path)
    
dump(res_dict, path + '/IHDP_coverage_NN.joblib')

### Histogram

In [None]:
path = './results/IHDP/RieszNet/coverage'

if not os.path.exists(path):
    os.makedirs(path)
    
method_strs = ["{}. Bias: {:.3f}, RMSE: {:.3f}, Coverage: {:.3f}".format(method, d['bias'], d['rmse'], d['cov'])
               for method, d in res_dict.items()]
plt.title("\n".join(method_strs))
for method, d in res_dict.items():
    plt.hist(np.array(d['point']), alpha=.5, label=method)
plt.axvline(x = np.mean(truth), label='true', color='red')
plt.legend()
plt.savefig(path + '/IHDP_coverage_NN.pdf', bbox_inches='tight')
plt.show()

## Ablation

### Effect of ‘end-to-end’ learning of shared representation

In [None]:
class RieszLearner(nn.Module):

    def __init__(self, n_t, p):
        super().__init__()
        n_common = 200
        self.common = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_t, n_common), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU())
        self.riesz_nn0 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, 1))
        self.riesz_nn1 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, 1))

    def forward(self, x):
        feats = self.common(x)
        riesz = self.riesz_nn0(feats) * (1 - x[:, [0]]) + self.riesz_nn1(feats) * x[:, [0]]
        return torch.cat([riesz, feats], dim = 1)

class RegLearner(nn.Module):

    def __init__(self, n_hidden, p):
        super().__init__()
        n_common = 200
        self.reg_nn0 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, 1))
        self.reg_nn1 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, 1))

    def forward(self, x):
        t = x[:, [0]]
        riesz = x[:, [1]]
        feats = x[:, 2:]
        reg = self.reg_nn0(feats) * (1 - t) + self.reg_nn1(feats) * t
        return torch.cat([reg, riesz], dim = 1)

In [None]:
data_base_dir = "./data/IHDP/sim_data_redraw_T"
simulation_files = sorted(glob.glob("{}/*.csv".format(data_base_dir)))

In [None]:
def rmse_fn(y_pred, y_true):
    return np.sqrt(np.mean((y_pred - y_true)**2))

nsims = 100
np.random.seed(123)
sim_ids = np.random.choice(len(simulation_files), nsims, replace=False)
methods = ['dr', 'direct', 'ips']
srr = {'dr' : True, 'direct' : False, 'ips' : True}

true_ATEs = []
results = []

for it, sim in enumerate(sim_ids):
    simulation_file = simulation_files[sim]
    x = load_and_format_covariates(simulation_file, delimiter=' ')
    t, y, y_cf, mu_0, mu_1 = load_other_stuff(simulation_file, delimiter=' ')
    X = np.c_[t, x]
    true_ATE = np.mean(mu_1 - mu_0)
    true_ATEs.append(true_ATE)

    y_scaler = StandardScaler(with_mean=True).fit(y)
    y = y_scaler.transform(y)
    
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2)

    torch.cuda.empty_cache()

    # Train Riesz
    rrlearner = RieszLearner(X_train.shape[1], drop_prob)
    rrnn = RieszNetRR(rrlearner, moment_fn)
    ## Fast training
    rrnn.fit(X_train, Xval=X_test,
             earlystop_rounds=2, earlystop_delta=1e-2,
             learner_lr=1e-1, learner_l2=learner_l2, learner_l1=learner_l1,
             n_epochs=100, bs=bs, optimizer='adam',
             model_dir=str(Path.home()), device=device, verbose=0)
    ## Fine tune
    rrnn.fit(X_train, Xval=X_test,
             earlystop_rounds=earlystop_rounds, earlystop_delta=1e-2,
             learner_lr=1e-3, learner_l2=learner_l2, learner_l1=learner_l1,
             n_epochs=600, bs=bs, optimizer='adam', warm_start=True,
             model_dir=str(Path.home()), device=device, verbose=0)

    # Train Reg
    reglearner = RegLearner(n_hidden, drop_prob)
    regnn = RieszNet(reglearner, moment_fn)
    
    inputs = np.hstack((X[:, [0]], rrnn.predict(X, model = 'earlystop')))
    input_train = np.hstack((X_train[:, [0]], rrnn.predict(X_train, model = 'earlystop')))
    input_test = np.hstack((X_test[:, [0]], rrnn.predict(X_test, model = 'earlystop')))
    
    ## Fast training
    regnn.fit(input_train, y_train, input_test, yval=y_test,
             earlystop_rounds=2, earlystop_delta=earlystop_delta,
             learner_lr=1e-4, learner_l2=learner_l2, learner_l1=learner_l1,
             n_epochs=100, bs=bs, target_reg=target_reg,
             riesz_weight=0.0, optimizer='adam',
             model_dir=str(Path.home()), device=device, verbose=0)
    # Fine tune
    regnn.fit(input_train, y_train, input_test, yval=y_test,
             earlystop_rounds=earlystop_rounds, earlystop_delta=earlystop_delta,
             learner_lr=learner_lr, learner_l2=learner_l2, learner_l1=learner_l1,
             n_epochs=600, bs=bs, target_reg=target_reg,
             riesz_weight=0.0, optimizer='adam', warm_start=True,
             model_dir=str(Path.home()), device=device, verbose=0)
    
    params = tuple(x * y_scaler.scale_[0] for method in methods
                   for x in regnn.predict_avg_moment(inputs, y,  model='earlystop', method = method, srr = srr[method])) + (true_ATE, )
                        
    results.append(params)
                        
res = tuple(np.array(x) for x in zip(*results))
truth = res[-1:]
res_dict = {}
for it, method in enumerate(methods):
    point, lb, ub = res[it * 3: (it + 1)*3]
    res_dict[method] = {'point': point, 'lb': lb, 'ub': ub, 'truth': truth,
                        'cov': np.mean(np.logical_and(truth >= lb, truth <= ub)),
                        'bias': np.mean(point - truth),
                        'rmse': rmse_fn(point, truth)
                        }
    print("{} : bias = {:.3f}, rmse = {:.3f}, cov = {:.3f}".format(method, res_dict[method]['bias'], res_dict[method]['rmse'], res_dict[method]['cov']))

In [None]:
path = './results/IHDP/RieszNet/ablation'

if not os.path.exists(path):
    os.makedirs(path)
    
dump(res_dict, path + '/IHDP_shared_ablation.joblib')

['./results/IHDP/RieszNet/ablation/IHDP_shared_ablation.joblib']

### Effect of ‘end-to-end’ learning of TMLE adjustment

In [None]:
def rmse_fn(y_pred, y_true):
    return np.sqrt(np.mean((y_pred - y_true)**2))

nsims = 100
np.random.seed(123)
sim_ids = np.random.choice(len(simulation_files), nsims, replace=False)
methods = ['dr', 'direct', 'ips']

true_ATEs = []
results = []

for it, sim in enumerate(sim_ids):
    simulation_file = simulation_files[sim]
    x = load_and_format_covariates(simulation_file, delimiter=' ')
    t, y, y_cf, mu_0, mu_1 = load_other_stuff(simulation_file, delimiter=' ')
    X = np.c_[t, x]
    true_ATE = np.mean(mu_1 - mu_0)
    true_ATEs.append(true_ATE)

    y_scaler = StandardScaler(with_mean=True).fit(y)
    y = y_scaler.transform(y)
    
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2)

    torch.cuda.empty_cache()
    learner = Learner(X_train.shape[1], n_hidden, drop_prob)
    agmm = RieszNet(learner, moment_fn)
    # Fast training
    agmm.fit(X_train, y_train, Xval=X_test, yval=y_test,
             earlystop_rounds=2, earlystop_delta=earlystop_delta,
             learner_lr=1e-4, learner_l2=learner_l2, learner_l1=learner_l1,
             n_epochs=100, bs=bs, target_reg=0.0,
             riesz_weight=riesz_weight, optimizer='adam',
             model_dir=str(Path.home()), device=device, verbose=0)
    # Fine tune
    agmm.fit(X_train, y_train, Xval=X_test, yval=y_test,
             earlystop_rounds=earlystop_rounds, earlystop_delta=earlystop_delta,
             learner_lr=learner_lr, learner_l2=learner_l2, learner_l1=learner_l1,
             n_epochs=600, bs=bs, target_reg=0.0,
             riesz_weight=riesz_weight, optimizer='adam', warm_start=True,
             model_dir=str(Path.home()), device=device, verbose=0)
    
    params = tuple(x * y_scaler.scale_[0] for method in methods
                   for x in agmm.predict_avg_moment(X, y,  model='earlystop', method = method, srr = False, postTMLE = True)) + (true_ATE, )
                        
    results.append(params)
                        
res = tuple(np.array(x) for x in zip(*results))
truth = res[-1:]
res_dict = {}
for it, method in enumerate(methods):
    point, lb, ub = res[it * 3: (it + 1)*3]
    res_dict[method] = {'point': point, 'lb': lb, 'ub': ub, 'truth': truth,
                        'cov': np.mean(np.logical_and(truth >= lb, truth <= ub)),
                        'bias': np.mean(point - truth),
                        'rmse': rmse_fn(point, truth)
                        }
    print("{} : bias = {:.3f}, rmse = {:.3f}, cov = {:.3f}".format(method, res_dict[method]['bias'], res_dict[method]['rmse'], res_dict[method]['cov']))

dr : bias = -0.058, rmse = 0.188, cov = 0.930
direct : bias = -0.058, rmse = 0.188, cov = 0.670
ips : bias = -0.171, rmse = 0.321, cov = 0.990


In [None]:
path = './results/IHDP/RieszNet/ablation'

if not os.path.exists(path):
    os.makedirs(path)
    
dump(res_dict, path + '/IHDP_postTMLE_ablation.joblib')

['./results/IHDP/RieszNet/ablation/IHDP_postTMLE_ablation.joblib']

### Separate Networks

In [None]:
class RieszLearner(nn.Module):

    def __init__(self, n_t, p):
        super().__init__()
        n_common = 200
        self.common = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_t, n_common), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU())
        self.riesz_nn0 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, 1))
        self.riesz_nn1 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, 1))

    def forward(self, x):
        feats = self.common(x)
        riesz = self.riesz_nn0(feats) * (1 - x[:, [0]]) + self.riesz_nn1(feats) * x[:, [0]]
        return riesz

class RegLearner(nn.Module):

    def __init__(self, n_t, n_hidden, p):
        super().__init__()
        n_common = 200
        self.common = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_t, n_common), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU())
        self.reg_nn0 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, 1))
        self.reg_nn1 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, 1))

    def forward(self, x):
        X = x[:, 0:-1]
        riesz = x[:, [-1]]
        feats = self.common(X)
        reg = self.reg_nn0(feats) * (1 - X[:, [0]]) + self.reg_nn1(feats) * X[:, [0]]
        return torch.cat([reg, riesz], dim = 1)

In [None]:
def rmse_fn(y_pred, y_true):
    return np.sqrt(np.mean((y_pred - y_true)**2))

nsims = 100
np.random.seed(123)
sim_ids = np.random.choice(len(simulation_files), nsims, replace=False)
methods = ['dr', 'direct', 'ips']
srr = {'dr' : True, 'direct' : False, 'ips' : True}

true_ATEs = []
results = []

for it, sim in enumerate(sim_ids):
    simulation_file = simulation_files[sim]
    x = load_and_format_covariates(simulation_file, delimiter=' ')
    t, y, y_cf, mu_0, mu_1 = load_other_stuff(simulation_file, delimiter=' ')
    X = np.c_[t, x]
    true_ATE = np.mean(mu_1 - mu_0)
    true_ATEs.append(true_ATE)

    y_scaler = StandardScaler(with_mean=True).fit(y)
    y = y_scaler.transform(y)
    
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2)

    torch.cuda.empty_cache()

    # Train Riesz
    rrlearner = RieszLearner(X_train.shape[1], drop_prob)
    rrnn = RieszNetRR(rrlearner, moment_fn)
    ## Fast training
    rrnn.fit(X_train, Xval=X_test,
             earlystop_rounds=2, earlystop_delta=1e-2,
             learner_lr=1e-1, learner_l2=learner_l2, learner_l1=learner_l1,
             n_epochs=100, bs=bs, optimizer='adam',
             model_dir=str(Path.home()), device=device, verbose=0)
    ## Fine tune
    rrnn.fit(X_train, Xval=X_test,
             earlystop_rounds=earlystop_rounds, earlystop_delta=1e-2,
             learner_lr=1e-3, learner_l2=learner_l2, learner_l1=learner_l1,
             n_epochs=600, bs=bs, optimizer='adam', warm_start=True,
             model_dir=str(Path.home()), device=device, verbose=0)

    # Train Reg
    reglearner = RegLearner(X_train.shape[1], n_hidden, drop_prob)
    regnn = RieszNet(reglearner, moment_fn)
    
    inputs = np.hstack((X, rrnn.predict(X, model = 'earlystop')))
    input_train = np.hstack((X_train, rrnn.predict(X_train, model = 'earlystop')))
    input_test = np.hstack((X_test, rrnn.predict(X_test, model = 'earlystop')))
    
    ## Fast training
    regnn.fit(input_train, y_train, input_test, yval=y_test,
             earlystop_rounds=2, earlystop_delta=earlystop_delta,
             learner_lr=1e-4, learner_l2=learner_l2, learner_l1=learner_l1,
             n_epochs=100, bs=bs, target_reg=target_reg,
             riesz_weight=0.0, optimizer='adam',
             model_dir=str(Path.home()), device=device, verbose=0)
    # Fine tune
    regnn.fit(input_train, y_train, input_test, yval=y_test,
             earlystop_rounds=earlystop_rounds, earlystop_delta=earlystop_delta,
             learner_lr=learner_lr, learner_l2=learner_l2, learner_l1=learner_l1,
             n_epochs=600, bs=bs, target_reg=target_reg,
             riesz_weight=0.0, optimizer='adam', warm_start=True,
             model_dir=str(Path.home()), device=device, verbose=0)
    
    params = tuple(x * y_scaler.scale_[0] for method in methods
                   for x in regnn.predict_avg_moment(inputs, y,  model='earlystop', method = method, srr = srr[method])) + (true_ATE, )
                        
    results.append(params)
                        
res = tuple(np.array(x) for x in zip(*results))
truth = res[-1:]
res_dict = {}
for it, method in enumerate(methods):
    point, lb, ub = res[it * 3: (it + 1)*3]
    res_dict[method] = {'point': point, 'lb': lb, 'ub': ub, 'truth': truth,
                        'cov': np.mean(np.logical_and(truth >= lb, truth <= ub)),
                        'bias': np.mean(point - truth),
                        'rmse': rmse_fn(point, truth)
                        }
    print("{} : bias = {:.3f}, rmse = {:.3f}, cov = {:.3f}".format(method, res_dict[method]['bias'], res_dict[method]['rmse'], res_dict[method]['cov']))

dr : bias = -0.176, rmse = 0.411, cov = 0.880
direct : bias = -0.125, rmse = 0.190, cov = 0.710
ips : bias = -0.034, rmse = 1.739, cov = 0.690


In [None]:
path = './results/IHDP/RieszNet/ablation'

if not os.path.exists(path):
    os.makedirs(path)
    
dump(res_dict, path + '/IHDP_separateNNs_ablation.joblib')

['./results/IHDP/RieszNet/ablation/IHDP_separateNNs_ablation.joblib']

#### Table

In [None]:
path = './results/IHDP/RieszNet/ablation'

if not os.path.exists(path):
    os.makedirs(path)
    
methods_str = ["DR", "Direct", "IPS"] 
methods = ['direct', 'ips', 'dr']
files = ['./results/IHDP/RieszNet/coverage/IHDP_coverage_NN.joblib',
         path + '/IHDP_separateNNs_ablation.joblib',
         path + '/IHDP_shared_ablation.joblib',
         path + '/IHDP_postTMLE_ablation.joblib']
names = ["RieszNet", "Separate NNs", "No end-to-end", "TMLE post-proc."]

with open(path + "/ablation.tex", "w") as f:
    f.write("\\begin{tabular}{*{10}{r}} \n" +
            "\\toprule \n" +
            "& \\multicolumn{3}{c}{Direct} & \\multicolumn{3}{c}{IPS} & \\multicolumn{3}{c}{DR} \\\\ \n" +
            "\\cmidrule(lr){2-4} \\cmidrule(lr){5-7} \\cmidrule(lr){8-10} \n" +
            "&  Bias &  RMSE &  Cov. &  Bias &  RMSE &  Cov. &  Bias &  RMSE &  Cov. \\\\ \n" +
            "\\midrule \n")
    
    for i in range(4):
        loaded = load(files[i])
        f.write(names[i] + " & ")    
        f.write(" & ".join(["{:.3f}".format(np.mean(loaded[method][x])) for method in methods
                            for x in ['bias', 'rmse', 'cov']]) + " \\\\ \n")

    f.write("\\bottomrule \n \\end{tabular}")