## Library Imports

In [4]:
from pathlib import Path
import os
import glob
from joblib import dump, load
import pandas as pd
import scipy
import scipy.stats
import scipy.special
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from utils.riesznet_NDE import RieszNetNDE
# from utils.moments import ate_moment_fn
from utils.moments import nde_theta1_moment_fn, nde_theta2_moment_fn
# from utils.ihdp_data import *

## Moment Definition

In [5]:
# Moment functionals handed to RieszNetNDE further down: one per target
# (theta_1 and theta_2).  The single-moment ATE variant (ate_moment_fn) is not
# used in this notebook.
moment_fn_1, moment_fn_2 = nde_theta1_moment_fn, nde_theta2_moment_fn

### Estimator Settings

In [10]:
# --- Network architecture (shared by every model in this notebook) ---
n_hidden = 100   # width of the hidden layers
drop_prob = 0.0  # dropout probability of all dropout layers

# --- Optimisation ---
learner_lr = 1e-5
learner_l1, learner_l2 = 0.0, 1e-3
bs = 16          # presumably the mini-batch size; it is passed to fit() as `bs`
# NOTE(review): the fit() cells below pass a literal n_epochs=100, so this
# value is not what they actually use.
n_epochs = 600

# --- Early stopping ---
earlystop_rounds = 40   # epochs to wait for an out-of-sample improvement
earlystop_delta = 1e-4

# --- Loss weights, one pair per target (suffix _1 / _2) ---
target_reg_1, target_reg_2 = 0, 0
riesz_weight_1, riesz_weight_2 = 1, 0

# Use the current CUDA device when one is available, otherwise leave it unset.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
else:
    device = None
print("GPU:", torch.cuda.is_available())

from itertools import chain, combinations
from itertools import combinations_with_replacement as combinations_w_r

def _combinations(n_features, degree, interaction_only):
        comb = (combinations if interaction_only else combinations_w_r)
        return chain.from_iterable(comb(range(n_features), i)
                                   for i in range(0, degree + 1))

class Learner(nn.Module):
    """Two-branch network producing a regression and a Riesz output per target.

    The input ``x`` is assumed to have its columns ordered ``(A, M, W)``:

    * branch 1 (``theta_1`` / ``m_1`` / ``alpha_1``) sees ``(A, W)``, i.e.
      ``x`` with column 1 removed;
    * branch 2 (``theta_2`` / ``m_2`` / ``alpha_2``) sees the full ``x``.

    Each branch has its own shared trunk (``common*``), a Riesz head, two
    regression heads blended by column 0 of the branch input, and a linear
    layer on polynomial features of the branch input.

    Parameters
    ----------
    n_t : int
        Number of columns of ``x``.
    n_hidden : int
        Width of the hidden layers in the regression heads.
    p : float
        Dropout probability used by every dropout layer.
    degree : int
        Maximum degree of the polynomial features.
    interaction_only : bool, default False
        Forwarded to ``_combinations``; if True the polynomial features only
        contain products of distinct columns.
    """

    def __init__(self, n_t, n_hidden, p, degree, interaction_only=False):
        super().__init__()
        n_common = 200  # output width of both shared trunks

        # Index tuples of the monomials used as polynomial features.  Branch 1
        # works on (A, W), which has one column fewer than (A, M, W), so the
        # two branches need separate monomial lists.
        self.monomials1 = list(_combinations(n_t - 1, degree, interaction_only))
        self.monomials2 = list(_combinations(n_t, degree, interaction_only))

        # Shared trunks feeding both the regression and the Riesz heads.
        # theta_1 and m_1 are functions of (A, W);
        # theta_2 and m_2 are functions of (A, M, W).
        self.common1 = self._make_trunk(n_t - 1, n_common, p)
        self.common2 = self._make_trunk(n_t, n_common, p)

        # Riesz heads: alpha_1 belongs to theta_1 / m_1, alpha_2 to theta_2 / m_2.
        self.riesz_nn1 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, 1))
        self.riesz_poly1 = nn.Sequential(nn.Linear(len(self.monomials1), 1))
        self.riesz_nn2 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, 1))
        self.riesz_poly2 = nn.Sequential(nn.Linear(len(self.monomials2), 1))

        # Regression heads, named reg_nn<k>_<branch>: head k is weighted by
        # (1 - A) for k = 0 and by A for k = 1 in forward().
        self.reg_nn0_1 = self._make_reg_head(n_common, n_hidden, p)
        self.reg_nn1_1 = self._make_reg_head(n_common, n_hidden, p)
        self.reg_nn0_2 = self._make_reg_head(n_common, n_hidden, p)
        self.reg_nn1_2 = self._make_reg_head(n_common, n_hidden, p)

        # Linear layers on the polynomial features, one per branch because the
        # number of monomials differs between (A, W) and (A, M, W).
        self.reg_poly1 = nn.Sequential(nn.Linear(len(self.monomials1), 1))
        self.reg_poly2 = nn.Sequential(nn.Linear(len(self.monomials2), 1))

    @staticmethod
    def _make_trunk(n_in, width, p):
        """Build a three-layer Dropout -> Linear -> ELU trunk of ``width`` units."""
        return nn.Sequential(nn.Dropout(p=p), nn.Linear(n_in, width), nn.ELU(),
                             nn.Dropout(p=p), nn.Linear(width, width), nn.ELU(),
                             nn.Dropout(p=p), nn.Linear(width, width), nn.ELU())

    @staticmethod
    def _make_reg_head(n_in, n_hidden, p):
        """Build a regression head: two hidden ELU layers and a scalar output."""
        return nn.Sequential(nn.Dropout(p=p), nn.Linear(n_in, n_hidden), nn.ELU(),
                             nn.Dropout(p=p), nn.Linear(n_hidden, n_hidden), nn.ELU(),
                             nn.Dropout(p=p), nn.Linear(n_hidden, 1))

    def forward(self, x):
        """Evaluate both branches on ``x``.

        Parameters
        ----------
        x : torch.Tensor or array-like, shape (n, n_t)
            Inputs with columns ordered ``(A, M, W)``.  Non-tensor input is
            converted to a tensor with the dtype and device of the model
            parameters.

        Returns
        -------
        torch.Tensor, shape (n, 4)
            Columns are ``[reg1, riesz1, reg2, riesz2]``.
        """
        if not torch.is_tensor(x):
            # Array input used to be sliced with NumPy and then passed to
            # torch.prod / nn.Linear, which fails; convert it up front.
            ref = next(self.parameters())
            x = torch.as_tensor(np.asarray(x), dtype=ref.dtype, device=ref.device)

        # x1 = (A, W): drop column 1 (M).  It is built directly from x, so it
        # stays on x's device and keeps any autograd link to the input.
        x1 = torch.cat([x[:, :1], x[:, 2:]], dim=1)

        poly1 = torch.cat([torch.prod(x1[:, t], dim=1, keepdim=True)
                           for t in self.monomials1], dim=1)
        poly2 = torch.cat([torch.prod(x[:, t], dim=1, keepdim=True)
                           for t in self.monomials2], dim=1)

        feats1 = self.common1(x1)
        feats2 = self.common2(x)

        riesz1 = self.riesz_nn1(feats1) + self.riesz_poly1(poly1)
        riesz2 = self.riesz_nn2(feats2) + self.riesz_poly2(poly2)

        # Column 0 (A) blends the two regression heads of each branch.
        a1 = x1[:, [0]]
        a2 = x[:, [0]]
        reg1 = (self.reg_nn0_1(feats1) * (1 - a1)
                + self.reg_nn1_1(feats1) * a1
                + self.reg_poly1(poly1))
        reg2 = (self.reg_nn0_2(feats2) * (1 - a2)
                + self.reg_nn1_2(feats2) * a2
                + self.reg_poly2(poly2))
        return torch.cat([reg1, riesz1, reg2, riesz2], dim=1)

GPU: False


In [11]:
# np.shape(X[:, 0])
# np.shape(X[:, 2:])

In [12]:
# Pull the four columns of the first data batch out as NumPy arrays.
# Presumably A = treatment, W = covariates, M = mediator, Y = outcome - confirm
# against the code that builds data_batch_1.
A, W, M, Y = (np.array(data_batch_1[col]) for col in ("A", "W", "M", "Y"))

In [13]:
# Design matrix with columns ordered (A, M, W); Learner.forward relies on this
# ordering when it drops column 1 to form (A, W).
X = np.column_stack((A, M, W))

# Hold out 20% of the rows for validation.
# NOTE(review): no random_state is passed, so the split changes between runs
# unless NumPy's global seed is fixed elsewhere.
(X_train, X_test,
 y_train, y_test) = train_test_split(X, Y, test_size=0.2)

In [14]:
# Release cached CUDA memory before building a fresh model.
torch.cuda.empty_cache()

# Degree-0 polynomial part: the only monomial is the empty product.
learner = Learner(
    n_t=X_train.shape[1],
    n_hidden=n_hidden,
    p=drop_prob,
    degree=0,
    interaction_only=True,
)
agmm = RieszNetNDE(learner, moment_fn_1, moment_fn_2)

In [15]:
# Sanity check: display the shape of the training design matrix.
X_train.shape

(800, 3)

In [16]:
# Stage 1, fast pass: a short run with only 2 early-stopping rounds and a
# literal step size of 1e-4 (larger than the notebook-wide learner_lr).
# The Riesz weights are whatever riesz_weight_1 / riesz_weight_2 hold at this
# point (1 and 0 in the settings cell).
agmm.fit(
    X_train, y_train,
    Xval=X_test, yval=y_test,
    n_epochs=100, bs=bs, optimizer='adam',
    learner_lr=1e-04, learner_l1=learner_l1, learner_l2=learner_l2,
    earlystop_rounds=2, earlystop_delta=earlystop_delta,
    target_reg_1=target_reg_1, riesz_weight_1=riesz_weight_1,
    target_reg_2=target_reg_2, riesz_weight_2=riesz_weight_2,
    model_dir=str(Path.home()), device=device, verbose=1,
)

Epoch #0
Validation losses: 5.461832 -1.7486612 5.461832 0.08951285 -0.01699384 0.08951285 3.8026837334036827
Epoch #1
Validation losses: 1.1293243 -2.1624784 1.1293243 0.02913642 -0.07128578 0.02913642 -1.00401771068573
Epoch #2
Validation losses: 0.9952726 -2.2057724 0.9952726 0.014283759 -0.10626545 0.014283759 -1.196216064505279
Epoch #3
Validation losses: 0.92525727 -2.2239373 0.92525727 0.007310999 -0.10791551 0.007310999 -1.2913690083660185
Epoch #4
Validation losses: 0.90776825 -2.2094016 0.90776825 0.009098115 -0.074975386 0.009098115 -1.2925352426245809
Epoch #5
Validation losses: 0.8900177 -2.207999 0.8900177 0.0068795923 -0.051322855 0.0068795923 -1.3111017104238272
Epoch #6
Validation losses: 0.88108796 -2.1561184 0.88108796 0.0057955505 -0.013997099 0.0057955505 -1.2692348835989833
Epoch #7
Validation losses: 0.87864435 -2.2076812 0.87864435 0.006174771 -0.016934864 0.006174771 -1.32286206074059
Epoch #8
Validation losses: 0.87112284 -2.1922774 0.87112284 0.0043529654 0.0

<utils.riesznet.RieszNetNDE at 0x7c1ed4a49c90>

In [17]:
# Stage 1, fine-tuning: continue from the fast pass (warm_start=True) with the
# notebook-wide learning rate and early-stopping patience.
# The Riesz weights are left untouched here; an alternative that was tried is
# riesz_weight_1 = 0, riesz_weight_2 = 0.1.
agmm.fit(
    X_train, y_train,
    Xval=X_test, yval=y_test,
    n_epochs=100, bs=bs, optimizer='adam', warm_start=True,
    learner_lr=learner_lr, learner_l1=learner_l1, learner_l2=learner_l2,
    earlystop_rounds=earlystop_rounds, earlystop_delta=earlystop_delta,
    target_reg_1=target_reg_1, riesz_weight_1=riesz_weight_1,
    target_reg_2=target_reg_2, riesz_weight_2=riesz_weight_2,
    model_dir=str(Path.home()), device=device, verbose=1,
)

Epoch #0
Validation losses: 0.83844423 -2.215949 0.83844423 0.0022501918 0.063772544 0.0022501918 -1.3752546338364482
Epoch #1
Validation losses: 0.8388585 -2.2145693 0.8388585 0.00248772 0.06732659 0.00248772 -1.3732231250032783
Epoch #2
Validation losses: 0.8376539 -2.212738 0.8376539 0.0021959047 0.07116628 0.0021959047 -1.372888257028535
Epoch #3
Validation losses: 0.83771133 -2.2095766 0.83771133 0.0023448882 0.07475635 0.0023448882 -1.3695203843526542
Epoch #4
Validation losses: 0.83751196 -2.2101443 0.83751196 0.0025167433 0.07792105 0.0025167433 -1.37011558143422
Epoch #5
Validation losses: 0.8362738 -2.2075062 0.8362738 0.0021576753 0.08178261 0.0021576753 -1.3690747150685638
Epoch #6
Epoch 00007: reducing learning rate of group 0 to 5.0000e-06.
Epoch 00007: reducing learning rate of group 1 to 5.0000e-06.
Validation losses: 0.83609766 -2.2068818 0.83609766 0.00198797 0.08492205 0.00198797 -1.368796133901924
Epoch #7
Validation losses: 0.834116 -2.2058997 0.834116 0.0022009672

<utils.riesznet.RieszNetNDE at 0x7c1ed4a49c90>

In [18]:
# Freeze every parameter of the alpha1 branch before training alpha2: the
# Riesz head (riesz_nn1), its polynomial layer (riesz_poly1) and the shared
# trunk common1.  Each frozen parameter name is printed.
for name, param in agmm.learner.named_parameters():
    if not any(tag in name for tag in ("riesz_nn1", "riesz_poly1", "common1")):
        continue
    print(name)
    param.requires_grad = False

learner.common1.1.weight
learner.common1.1.bias
learner.common1.4.weight
learner.common1.4.bias
learner.common1.7.weight
learner.common1.7.bias
learner.riesz_nn1.1.weight
learner.riesz_nn1.1.bias
learner.riesz_poly1.0.weight
learner.riesz_poly1.0.bias


In [19]:
# Stage 2, fast pass: switch the Riesz loss weight from target 1 to target 2
# and continue from the current weights (warm_start=True) with a short,
# 2-round early-stopping run at a literal step size of 1e-4.
riesz_weight_1, riesz_weight_2 = 0, 1

agmm.fit(
    X_train, y_train,
    Xval=X_test, yval=y_test,
    n_epochs=100, bs=bs, optimizer='adam', warm_start=True,
    learner_lr=1e-4, learner_l1=learner_l1, learner_l2=learner_l2,
    earlystop_rounds=2, earlystop_delta=earlystop_delta,
    target_reg_1=target_reg_1, riesz_weight_1=riesz_weight_1,
    target_reg_2=target_reg_2, riesz_weight_2=riesz_weight_2,
    model_dir=str(Path.home()), device=device, verbose=1,
)

Epoch #0
Validation losses: 0.920897 -2.215949 0.920897 0.0028251454 -0.4926522 0.0028251454 0.43106994475238025
Epoch #1
Validation losses: 0.9088416 -2.215949 0.9088416 0.007875676 -1.4428633 0.007875676 -0.5261460589244962
Epoch #2
Validation losses: 0.8392445 -2.215949 0.8392445 0.007722784 -3.0562994 0.007722784 -2.2093321792781353
Epoch #3
Validation losses: 0.8174927 -2.215949 0.8174927 0.02262691 -3.8060663 0.02262691 -2.965946640819311
Epoch #4
Validation losses: 0.842319 -2.215949 0.842319 0.023530405 -4.050703 0.023530405 -3.184853632003069
Epoch #5
Validation losses: 0.8424967 -2.215949 0.8424967 0.009547047 -4.0759325 0.009547047 -3.223888762295246
Epoch #6
Validation losses: 0.8387242 -2.215949 0.8387242 0.012777927 -4.117236 0.012777927 -3.2657340141013265
Epoch #7
Validation losses: 0.9018853 -2.215949 0.9018853 0.03145323 -4.2347755 0.03145323 -3.3014370426535606
Epoch #8
Validation losses: 0.8497123 -2.215949 0.8497123 0.014827074 -4.2694817 0.014827074 -3.40494227316

<utils.riesznet.RieszNetNDE at 0x7c1ed4a49c90>

In [20]:
# Stage 2, fine-tuning: continue (warm_start=True) with a literal step size of
# 1e-5 and the notebook-wide early-stopping patience.
# The Riesz weights are left untouched here; an alternative that was tried is
# riesz_weight_1 = 0, riesz_weight_2 = 0.1.
agmm.fit(
    X_train, y_train,
    Xval=X_test, yval=y_test,
    n_epochs=100, bs=bs, optimizer='adam', warm_start=True,
    learner_lr=1e-05, learner_l1=learner_l1, learner_l2=learner_l2,
    earlystop_rounds=earlystop_rounds, earlystop_delta=earlystop_delta,
    target_reg_1=target_reg_1, riesz_weight_1=riesz_weight_1,
    target_reg_2=target_reg_2, riesz_weight_2=riesz_weight_2,
    model_dir=str(Path.home()), device=device, verbose=1,
)

Epoch #0
Validation losses: 0.8354654 -2.215949 0.8354654 0.008651276 -4.2863383 0.008651276 -3.442221681587398
Epoch #1
Validation losses: 0.83322936 -2.215949 0.83322936 0.008370077 -4.2891526 0.008370077 -3.4475531820207834
Epoch #2
Validation losses: 0.83417493 -2.215949 0.83417493 0.0095742885 -4.2742834 0.0095742885 -3.43053418956697
Epoch #3
Validation losses: 0.8333731 -2.215949 0.8333731 0.010054781 -4.276206 0.010054781 -3.432778106071055
Epoch #4
Validation losses: 0.8317288 -2.215949 0.8317288 0.009468661 -4.259841 0.009468661 -3.418643488548696
Epoch #5
Validation losses: 0.8314665 -2.215949 0.8314665 0.009082231 -4.298546 0.009082231 -3.457997110672295
Epoch #6
Validation losses: 0.83163965 -2.215949 0.83163965 0.008750994 -4.261979 0.008750994 -3.421588461846113
Epoch #7
Validation losses: 0.83147955 -2.215949 0.83147955 0.008813088 -4.27105 0.008813088 -3.4307573391124606
Epoch #8
Validation losses: 0.8315638 -2.215949 0.8315638 0.0092811445 -4.2567096 0.0092811445 -3.4

<utils.riesznet.RieszNetNDE at 0x7c1ed4a49c90>

# Task 1: Check whether the estimated alpha1 and alpha2 agree with their true values.

In [21]:
# Estimation methods and, for each one, the value passed as the `srr` flag
# to predict_avg_moment.
methods = ['dr', 'direct', 'ips', 'test']
srr = dict(zip(methods, (True, False, True, True)))

In [22]:
import pandas as pd

In [49]:
# Run the network over the full sample once and keep the two Riesz columns of
# its output: column 1 is labelled alpha1 and column 3 alpha2.  The earlier
# version called agmm.predict(X) twice, which doubled the forward-pass cost.
alpha_results = pd.DataFrame(agmm.predict(X)[:, [1, 3]], columns=["alpha1", "alpha2"])

In [50]:
# Display the table of per-row alpha1 / alpha2 predictions.
alpha_results

Unnamed: 0,alpha1,alpha2
0,-0.013872,0.092815
1,1.986853,-2.060415
2,0.063197,1.521148
3,1.988412,-2.059385
4,2.004495,-1.972282
...,...,...
9995,-0.027939,2.366625
9996,-0.016918,3.395357
9997,-0.025539,1.560724
9998,-0.013945,0.883878


In [24]:
# Column 1 of the prediction matrix (the column stored as "alpha1" in alpha_results).
agmm.predict(X)[:,1]

array([ 1.85343277e+00,  2.15819001e-01,  1.85299993e+00,  2.02123547e+00,
        1.86360502e+00, -3.66931558e-02, -6.18512630e-02, -4.31969166e-02,
        1.92685974e+00, -3.80441546e-02,  1.85680485e+00,  1.87672782e+00,
        1.96540785e+00,  2.18295097e+00,  2.51077700e+00,  1.89424038e+00,
        2.23088074e+00,  1.88627172e+00, -2.23979950e-02, -3.98099422e-03,
        1.14051223e-01,  2.04663444e+00,  1.90468717e+00,  1.91013932e-01,
       -7.07552433e-02,  1.92342234e+00,  1.86341178e+00,  1.93052208e+00,
       -1.01161003e-02,  2.25809693e-01, -3.06901932e-02,  1.89816415e+00,
        1.90520787e+00,  2.25487161e+00,  6.74732924e-02,  1.86221218e+00,
       -5.20735383e-02, -6.70698881e-02,  1.79082155e-02, -6.91472292e-02,
        1.85424793e+00,  1.21415854e-01, -6.48680925e-02,  1.87511325e+00,
        9.40959454e-02, -4.39934731e-02,  2.21092653e+00, -4.91776466e-02,
       -1.42046213e-02,  1.91617143e+00, -9.35518742e-03, -5.68319559e-02,
        1.92532957e+00,  

In [23]:
# Average-moment estimate on the full sample using the early-stopped model and
# the "test" method.  The cell displays a 3-tuple - presumably
# (point estimate, lower bound, upper bound); confirm in RieszNetNDE.
agmm.predict_avg_moment(
    X,
    Y,
    model='earlystop',
    method="test",
    srr=srr["test"],
)

(1.081697453132623, 0.9216995688183476, 1.2416953374468986)