In [95]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [56]:
# from google.colab import drive

# drive.mount("/content/drive")

# %cd "/content/drive/MyDrive/Ivan_Diaz/RieszLearning-main"
# # %pwd
# # from pycox.models import LogisticHazard

# Randomized Interventional Effects

## Library Imports

In [96]:
from pathlib import Path
import os
import glob
from joblib import dump, load
import pandas as pd
import scipy
import scipy.stats
import scipy.special
import torch
import torch.nn as nn
import math
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from utils.riesznet import RieszNet, RieszNetNDE, RieszNetRDERIE
# from utils.moments import ate_moment_fn
from utils.moments import nde_theta1_moment_fn, nde_theta2_moment_fn
from utils.moments import rde_rie_theta3_moment_fn, rde_rie_theta2_moment_fn, rde_rie_theta1_moment_fn
# from utils.ihdp_data import *
from scipy.spatial import distance_matrix
from cvxopt import spmatrix, matrix
from cvxopt import solvers

## Moment Definition

In [97]:
moment_fn_1 = rde_rie_theta1_moment_fn
moment_fn_2 = rde_rie_theta2_moment_fn
moment_fn_3 = rde_rie_theta3_moment_fn
# moment_fn = ate_moment_fn

In [98]:
N = 2000

#W_1 = scipy.stats.norm.rvs(loc = 0, scale = 1, size = N)
W_2 = scipy.stats.bernoulli.rvs(0.5, size = N)
#W_3 = scipy.stats.uniform.rvs(loc = 0, scale = 1, size = N)

#prob_A = scipy.special.expit(-1 + W_1 / 3 + W_2 / 4)
prob_A = scipy.special.expit(-0.25 + W_2 / 4)
A = scipy.stats.bernoulli.rvs(prob_A, size = N)

prob_Z = scipy.special.expit(-0.25 + A / 2)
Z = scipy.stats.bernoulli.rvs(prob_Z, size = N)

prob_M = scipy.special.expit(-0.25 + A / 2 - Z / 2 + W_2 / 4)
M = scipy.stats.bernoulli.rvs(prob_M, size = N)

#prob_Y = scipy.special.expit(-1 + M / 2 + A / 4 - Z / 2 + W_1 / 4 - W_3 / 4)
prob_Y = scipy.special.expit(-0.25 + M / 2 + A / 4 - Z / 2 + W_2 / 4)
Y = scipy.stats.bernoulli.rvs(prob_Y, size = N)

#data_batch_1 = pd.DataFrame(data = {"W_1": W_1, "W_2": W_2, "W_3": W_3, "A": A, "Z": Z, "M": M, "Y": Y})
data_batch_1 = pd.DataFrame(data = {"W_2": W_2, "A": A, "Z": Z, "M": M, "Y": Y})

### Estimator Settings

In [100]:
drop_prob = 0.0  # dropout prob of dropout layers throughout notebook
n_hidden = 100  # width of hidden layers throughout notebook

# Training params
learner_lr = 1e-5
learner_l2 = 1e-3
learner_l1 = 0.0
n_epochs = 600
earlystop_rounds = 40 # how many epochs to wait for an out-of-sample improvement
earlystop_delta = 1e-4
target_reg_1 = 1
target_reg_2 = 1
target_reg_3 = 1
riesz_weight_1 = 1
riesz_weight_2 = 0
riesz_weight_3 = 0

bs = 32
device = torch.cuda.current_device() if torch.cuda.is_available() else None
print("GPU:", torch.cuda.is_available())

from itertools import chain, combinations
from itertools import combinations_with_replacement as combinations_w_r

def _combinations(n_features, degree, interaction_only):
        comb = (combinations if interaction_only else combinations_w_r)
        return chain.from_iterable(comb(range(n_features), i)
                                   for i in range(0, degree + 1))

class Learner(nn.Module):

    def __init__(self, n_t, n_hidden, p, degree, interaction_only=False):
        super().__init__()
        n_common = 200

        # Don't know what it is, but should be different for 1, 2 and 3
        self.monomials1 = list(_combinations(n_t - 2, degree, interaction_only))
        self.monomials2 = list(_combinations(n_t - 1, degree, interaction_only))
        self.monomials3 = list(_combinations(n_t, degree, interaction_only))

        # Common layers for g and alpha
        # theta_1 and m_1 are the function of (A, W)
        # theta_2 and m_2 are the function of (A, Z, W)
        # theta_3 and m_3 are the function of (A, Z, M, W)
        self.common1 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_t - 2, n_common), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU())
        self.common2 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_t - 1, n_common), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU())
        self.common3 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_t, n_common), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_common, n_common), nn.ELU())

        # Riesz specific layers
        # alpha_1 is related to theta_1 and m_1
        # alpha_2 is related to theta_2 and m_2
        # alpha_3 is related to theta_3 and m_3
        self.riesz_nn1 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, 1))
        self.riesz_poly1 = nn.Sequential(nn.Linear(len(self.monomials1), 1))
        self.riesz_nn2 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, 1))
        self.riesz_poly2 = nn.Sequential(nn.Linear(len(self.monomials2), 1))
        self.riesz_nn3 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, 1))
        self.riesz_poly3 = nn.Sequential(nn.Linear(len(self.monomials3), 1))

        # Regression loss layers
        # Indexes are the same as Riesz specific layers
        self.reg_nn0_1 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, 1))
        self.reg_nn1_1 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, 1))
        self.reg_nn0_2 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, 1))
        self.reg_nn1_2 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, 1))
        self.reg_nn0_3 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, 1))
        self.reg_nn1_3 = nn.Sequential(nn.Dropout(p=p), nn.Linear(n_common, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, n_hidden), nn.ELU(),
                                    nn.Dropout(p=p), nn.Linear(n_hidden, 1))

        # Don't know what it is, but should be different for 1 and 2
        self.reg_poly1 = nn.Sequential(nn.Linear(len(self.monomials1), 1))
        self.reg_poly2 = nn.Sequential(nn.Linear(len(self.monomials2), 1))
        self.reg_poly3 = nn.Sequential(nn.Linear(len(self.monomials3), 1))


    def forward(self, x):
        # Create a new dataset x1, which is (A, W), used for theta_1 and m_1
        # Create a new dataset x2, which is (A, Z, W), used for theta_2 and m_2
        # Recall: we assume x = (A, Z, M, W)

        if torch.is_tensor(x):
            with torch.no_grad():
                x1 = torch.cat([torch.reshape(x[:, 0].to(device), (-1, 1)), x[:, 3:]], dim=1)
        else:
            x1 = np.hstack([np.array(x[:, 0]).reshape(-1, 1), x[:, 3:]])
        
        if torch.is_tensor(x):
            with torch.no_grad():
                x2 = torch.cat([torch.reshape(x[:, [0, 1]].to(device), (-1, 2)), x[:, 3:]], dim=1)
        else:
            x2 = np.hstack([np.array(x[:, [0, 1]]).reshape(-1, 2), x[:, 3:]])

        poly1 = torch.cat([torch.prod(x1[:, t], dim=1, keepdim=True)
                          for t in self.monomials1], dim=1)
        poly2 = torch.cat([torch.prod(x2[:, t], dim=1, keepdim=True)
                          for t in self.monomials2], dim=1)
        poly3 = torch.cat([torch.prod(x[:, t], dim=1, keepdim=True)
                          for t in self.monomials2], dim=1)

        feats1 = self.common1(x1)
        feats2 = self.common2(x2)
        feats3 = self.common3(x)

        riesz1 = self.riesz_nn1(feats1) + self.riesz_poly1(poly1)
        riesz2 = self.riesz_nn2(feats2) + self.riesz_poly2(poly2)
        riesz3 = self.riesz_nn3(feats3) + self.riesz_poly3(poly3)

        reg1 = self.reg_nn0_1(feats1) * (1 - x1[:, [0]]) + self.reg_nn1_1(feats1) * x1[:, [0]] + self.reg_poly1(poly1)
        reg2 = self.reg_nn0_2(feats2) * (1 - x2[:, [0]]) + self.reg_nn1_2(feats2) * x2[:, [0]] + self.reg_poly2(poly2)
        reg3 = self.reg_nn0_3(feats3) * (1 - x[:, [0]]) + self.reg_nn1_3(feats3) * x[:, [0]] + self.reg_poly3(poly3)
        return torch.cat([reg1, riesz1, reg2, riesz2, reg3, riesz3], dim=1)

GPU: False


In [101]:
A = np.array(data_batch_1["A"])
Z = np.array(data_batch_1["Z"])
M = np.array(data_batch_1["M"])
#W_1 = np.array(data_batch_1["W_1"])
W_2 = np.array(data_batch_1["W_2"])
#W_3 = np.array(data_batch_1["W_3"])
Y = np.array(data_batch_1["Y"])

In [102]:
#X = np.c_[A, Z, M, W_1, W_2, W_3]
X = np.c_[A, Z, M, W_2]
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size = 0.2)

In [103]:
torch.cuda.empty_cache()
learner = Learner(X_train.shape[1], n_hidden, drop_prob, 0, interaction_only=True)
agmm = RieszNetRDERIE(learner, moment_fn_1, moment_fn_2, moment_fn_3)

In [104]:
def permutation_matrix_cpu(data_temp):
    # A, Z, M, W1
    N = np.shape(data_temp)[0]
    D = distance_matrix(data_temp[:, [0, 3]].reshape(-1, 2), data_temp[:, [0, 3]].reshape(-1, 2))  # .reshape(-1, )
    row, col = np.triu_indices(N, 1)
    D[row, col] = 0

    D = D.reshape(-1, )
    D = D / max(D)
    D = matrix(D)

    A2 = -np.eye(N * N)
    b2 = np.zeros(N * N)

    row_list = []
    col_list = []
    data_list = []
    for i in range(N):
        index_col_list = [i * N + j for j in range(N)]
        col_list = col_list + index_col_list
        row_list = row_list + [i for j in range(N)]
        data_list = data_list + [1 for j in range(N)]

    for i in range(N - 1):
        index_col_list = [j * N + i for j in range(N)]
        col_list = col_list + index_col_list
        row_list = row_list + [i + N for j in range(N)]
        data_list = data_list + [1 for j in range(N)]

    index_col_list = [j * (N + 1) for j in range(N)]
    col_list = col_list + index_col_list
    row_list = row_list + [2 * N - 1 for j in range(N)]
    data_list = data_list + [1 for j in range(N)]

    row_list_1 = [i for i in range(N * N)]
    row_list_2 = [i + N * N for i in row_list]
    row_list_3 = [i + 2 * N for i in row_list_2]
    row_list = row_list_1 + row_list_2 + row_list_3

    data_list = data_list + [-i for i in data_list]
    col_list = col_list * 2
    data_list = [-1 for i in range(N * N)] + data_list
    col_list = [i for i in range(N * N)] + col_list

    b3 = np.ones(2 * N)
    b3[2 * N - 1] = 0
    b4 = -b3
    b = matrix(b2.tolist() + b3.tolist() + b4.tolist())
    A = spmatrix(data_list, row_list, col_list)

    sol = solvers.lp(D, A, b, solver="glpk")
    P = np.array(sol['x']).reshape(N, N)
    P[P < 10 ** (-3)] = 0
    P[P > 1 - 10 ** (-3)] = 1

    res = data_temp.copy()
    res[:, 1] = np.matmul(P, res[:, 1])

    return res


In [106]:
batch_size = 32
X_supp_test = X_test.copy()
for i in range(int(X_test.shape[0] / batch_size)):
    data_temp = X_test[i * batch_size:(i + 1) * batch_size,:].copy()
    X_supp_test[i * batch_size:(i + 1) * batch_size,:] = permutation_matrix_cpu(data_temp)

In [107]:
X_supp_test

array([[1, 1, 1, 1],
       [0, 0, 0, 1],
       [1, 1, 1, 0],
       ...,
       [0, 0, 1, 1],
       [1, 1, 1, 0],
       [1, 1, 0, 1]])

In [108]:
solvers.options['glpk'] = {'msg_lev' : 'GLP_MSG_OFF'}

In [109]:
# Fast training
agmm.fit(X_train, y_train, Xval=X_test, yval=y_test, Xval_supp = X_supp_test,
          earlystop_rounds=2, earlystop_delta=earlystop_delta,
          learner_lr=1e-04, learner_l2=learner_l2, learner_l1=learner_l1,
          n_epochs=100, bs=bs, target_reg_1=target_reg_1,
          riesz_weight_1=riesz_weight_1, target_reg_2=target_reg_2,
          riesz_weight_2=riesz_weight_2, target_reg_3=target_reg_3,
          riesz_weight_3=riesz_weight_3, optimizer='adam',
          model_dir=str(Path.home()), device=device, verbose=1)

Epoch #0
Validation losses: 0.29076183 0.019168727 0.03818228 -1.7440721 -1.0493206 -0.13344301 0.03726881 0.018688187 0.2907628 -1.049239419400692
Epoch #1
Validation losses: 0.25330698 0.0044166856 0.004364865 -2.0675323 -1.1694912 -0.7317278 0.00426376 0.0043267133 0.2532395 -1.543613777961582
Epoch #2
Validation losses: 0.23654123 0.0024472962 0.00108228 -2.06976 -1.2081627 -1.2475462 0.000997234 0.0022357046 0.23660083 -1.5898555094609037
Epoch #3
Validation losses: 0.23448615 0.0033667036 0.00012521243 -2.0704484 -1.2085326 -1.4934449 0.00019047856 0.0021976356 0.23484106 -1.5952411498437868
Epoch #4
Validation losses: 0.23462833 0.0035309047 0.0006500842 -2.0678058 -1.335981 -1.4608905 0.00080076157 0.0023852377 0.23505557 -1.5907548845862038
Epoch #5
Validation losses: 0.2360596 0.0019012977 0.00028948212 -2.0723062 -1.2324076 -1.3631506 0.00023318056 0.0012292302 0.23677124 -1.595822114875773
Epoch #6
Validation losses: 0.23599358 0.0037580722 0.0013813563 -2.0644891 -1.391572

<utils.riesznet.RieszNetRDERIE at 0x13be6c290>

In [110]:
# riesz_weight_1 = 0
# riesz_weight_2 = 0.1

# Fine tune
agmm.fit(X_train, y_train, Xval=X_test, yval=y_test, Xval_supp = X_supp_test,
          earlystop_rounds=2, earlystop_delta=earlystop_delta,
          learner_lr=learner_lr, learner_l2=learner_l2, learner_l1=learner_l1,
          n_epochs=100, bs=bs, target_reg_1=target_reg_1,
          riesz_weight_1=riesz_weight_1, target_reg_2=target_reg_2,
          riesz_weight_2=riesz_weight_2, target_reg_3=target_reg_3,
          riesz_weight_3=riesz_weight_3, optimizer='adam', warm_start=True,
          model_dir=str(Path.home()), device=device, verbose=1)

Epoch #0
Validation losses: 0.23509364 0.00085628574 0.0002645898 -2.0669708 -1.3824434 -1.3229094 0.00028707166 0.0008430716 0.23555477 -1.594071408646414
Epoch #1
Validation losses: 0.23501438 0.0008968842 0.00021251147 -2.0679421 -1.3705618 -1.2658304 0.00019870883 0.0009227988 0.23547791 -1.5952189549134346
Epoch #2
Validation losses: 0.23506232 0.0009056263 0.00017526893 -2.0637362 -1.422268 -1.2213322 0.00018660851 0.00093657203 0.23555389 -1.5909159128350439
Epoch #3
Validation losses: 0.23483448 0.0009652622 0.00023304067 -2.0668173 -1.3714325 -1.168101 0.00027003043 0.00090112584 0.23530799 -1.5943053518712986
Epoch #4
Validation losses: 0.23486628 0.0009486337 0.00022381938 -2.0644553 -1.396608 -1.1345406 0.00023383129 0.0011161263 0.23532195 -1.5917446297244169


<utils.riesznet.RieszNetRDERIE at 0x13be6c290>

In [111]:
# Freeze network parameters w.r.t. alpha1

for name, param in agmm.learner.named_parameters():
  if("riesz_nn1" in name or "riesz_poly1" in name or "common1" in name):
    print(name)
    param.requires_grad = False

learner.common1.1.weight
learner.common1.1.bias
learner.common1.4.weight
learner.common1.4.bias
learner.common1.7.weight
learner.common1.7.bias
learner.riesz_nn1.1.weight
learner.riesz_nn1.1.bias
learner.riesz_poly1.0.weight
learner.riesz_poly1.0.bias


In [112]:
# Fast training
riesz_weight_2 = 1
riesz_weight_1 = 0

agmm.fit(X_train, y_train, Xval=X_test, yval=y_test, Xval_supp = X_supp_test,
          earlystop_rounds=2, earlystop_delta=earlystop_delta,
          learner_lr=1e-4, learner_l2=learner_l2, learner_l1=learner_l1,
          n_epochs=100, bs=bs, target_reg_1=target_reg_1,
          riesz_weight_1=riesz_weight_1, target_reg_2=target_reg_2,
          riesz_weight_2=riesz_weight_2, target_reg_3=target_reg_3,
          riesz_weight_3=riesz_weight_3, optimizer='adam',
          model_dir=str(Path.home()), device=device, verbose=1, warm_start=True)

Epoch #0
Validation losses: 0.23437999 0.0017093649 0.0003143017 -2.0679421 -2.257577 -1.1544385 0.0003391573 0.002143942 0.23429564 -1.7843945476925
Epoch #1
Validation losses: 0.23389314 0.0011985464 0.001122034 -2.0679421 -2.302374 -1.148068 0.0010755459 0.0014551767 0.23427056 -1.8293588918168098
Epoch #2
Validation losses: 0.23436174 0.0009655937 0.00029869706 -2.0679421 -2.3034914 -1.1403993 0.00030582355 0.0011958943 0.2345075 -1.831856108008651
Epoch #3
Validation losses: 0.23694251 0.001405185 0.00028772288 -2.0679421 -2.304655 -1.1703688 0.0003021882 0.0012965506 0.23764038 -1.8267805346113164
Epoch #4
Validation losses: 0.23982607 0.0020270078 0.001071279 -2.0679421 -2.305677 -1.1550964 0.0010661553 0.0017761121 0.24068286 -1.8192274669418111
Epoch #5
Validation losses: 0.2369436 0.0014322561 0.0003013058 -2.0679421 -2.3062382 -1.1894574 0.00030000033 0.0012719727 0.23770882 -1.8282802205940243


<utils.riesznet.RieszNetRDERIE at 0x13be6c290>

In [113]:
# riesz_weight_1 = 0
# riesz_weight_2 = 0.1

# Fine tune
agmm.fit(X_train, y_train, Xval=X_test, yval=y_test, Xval_supp = X_supp_test,
          earlystop_rounds=2, earlystop_delta=earlystop_delta,
          learner_lr=1e-05, learner_l2=learner_l2, learner_l1=learner_l1,
          n_epochs=100, bs=bs, target_reg_1=target_reg_1,
          riesz_weight_1=riesz_weight_1, target_reg_2=target_reg_2,
          riesz_weight_2=riesz_weight_2, target_reg_3=target_reg_3,
          riesz_weight_3=riesz_weight_3, optimizer='adam', warm_start=True,
          model_dir=str(Path.home()), device=device, verbose=1)

Epoch #0
Validation losses: 0.2359726 0.0010469753 0.0002014115 -2.0679421 -2.3042192 -1.1539469 0.00020202543 0.0009494382 0.23653162 -1.8293151942489203
Epoch #1
Validation losses: 0.23566282 0.0010992517 0.0001840712 -2.0679421 -2.304861 -1.1505969 0.00018484154 0.001347396 0.23616205 -1.8302206289081369
Epoch #2
Validation losses: 0.23520973 0.0009241642 0.00026203424 -2.0679421 -2.3054032 -1.1445546 0.00027170865 0.0010603181 0.23561804 -1.8320572256925516
Epoch #3
Validation losses: 0.23529169 0.00095204805 0.00026163526 -2.0679421 -2.305879 -1.1496103 0.00026083962 0.00093003706 0.23576446 -1.8324184062948916
Epoch #4
Validation losses: 0.23554064 0.00088497036 0.0002099609 -2.0679421 -2.306267 -1.1401333 0.00021360819 0.00093579496 0.23602325 -1.8324588069372112
Epoch #5
Validation losses: 0.235253 0.00095592707 0.00021417446 -2.0679421 -2.3066604 -1.1537496 0.00021525407 0.00096644694 0.23573923 -1.8333163766219513
Epoch #6
Validation losses: 0.23507053 0.0010079115 0.00024766

<utils.riesznet.RieszNetRDERIE at 0x13be6c290>

In [114]:
# Freeze network parameters w.r.t. alpha1

for name, param in agmm.learner.named_parameters():
    if("riesz_nn2" in name or "riesz_poly2" in name or "common2" in name):
        print(name)
        param.requires_grad = False

learner.common2.1.weight
learner.common2.1.bias
learner.common2.4.weight
learner.common2.4.bias
learner.common2.7.weight
learner.common2.7.bias
learner.riesz_nn2.1.weight
learner.riesz_nn2.1.bias
learner.riesz_poly2.0.weight
learner.riesz_poly2.0.bias


In [115]:
# Fast training
riesz_weight_2 = 0
riesz_weight_3 = 1

agmm.fit(X_train, y_train, Xval=X_test, yval=y_test, Xval_supp = X_supp_test,
          earlystop_rounds=2, earlystop_delta=earlystop_delta,
          learner_lr=1e-4, learner_l2=learner_l2, learner_l1=learner_l1,
          n_epochs=100, bs=bs, target_reg_1=target_reg_1,
          riesz_weight_1=riesz_weight_1, target_reg_2=target_reg_2,
          riesz_weight_2=riesz_weight_2, target_reg_3=target_reg_3,
          riesz_weight_3=riesz_weight_3, optimizer='adam',
          model_dir=str(Path.home()), device=device, verbose=1, warm_start=True)

Epoch #0
Validation losses: 0.24257287 0.0009952494 0.0003053945 -2.0679421 -2.3081005 -2.0292478 0.0003055451 0.0011779157 0.24306522 -1.5408255551592447
Epoch #1
Validation losses: 0.23780447 0.0024561833 0.0010049659 -2.0679421 -2.3081005 -2.371417 0.0010121082 0.0025593666 0.23798767 -1.8885922787012532
Epoch #2
Validation losses: 0.23644577 0.00081163854 0.00023708712 -2.0679421 -2.3081005 -2.3788242 0.00023672263 0.0007940746 0.23654665 -1.9037522877479205
Epoch #3
Validation losses: 0.23353039 0.0016229722 0.0003007892 -2.0679421 -2.3081005 -2.3817704 0.00031186242 0.0015889788 0.23349625 -1.9109191254538018
Epoch #4
Validation losses: 0.23646697 0.0014934955 0.0004885808 -2.0679421 -2.3081005 -2.3846066 0.0004924237 0.0014691567 0.23672792 -1.9074680442572571
Epoch #5
Validation losses: 0.23653805 0.0021376046 0.00045196695 -2.0679421 -2.3081005 -2.3862772 0.00044835004 0.0022113798 0.23623765 -1.9082522027310915
Epoch #6
Validation losses: 0.2380837 0.0009769424 0.00023080155 

<utils.riesznet.RieszNetRDERIE at 0x13be6c290>

In [116]:
# # riesz_weight_1 = 0
# # riesz_weight_2 = 0.1

# # Fine tune
# agmm.fit(X_train, y_train, Xval=X_test, yval=y_test, Xval_supp = X_supp_test,
#           earlystop_rounds=10, earlystop_delta=earlystop_delta,
#           learner_lr=1e-05, learner_l2=learner_l2, learner_l1=learner_l1,
#           n_epochs=100, bs=bs, target_reg_1=target_reg_1,
#           riesz_weight_1=riesz_weight_1, target_reg_2=target_reg_2,
#           riesz_weight_2=riesz_weight_2, target_reg_3=target_reg_3,
#           riesz_weight_3=riesz_weight_3, optimizer='adam', warm_start=True,
#           model_dir=str(Path.home()), device=device, verbose=1)

# Task 1: See whether the estimators alpha1 and alpha2 are consistent with true values.

In [117]:
methods = ['dr', 'direct', 'ips', 'test']
srr = {'dr' : True, 'direct' : False, 'ips' : True, 'test': True}

In [118]:
import pandas as pd

In [119]:
alpha_results = pd.DataFrame(data={"alpha1": agmm.predict(X)[:,1], "alpha2": agmm.predict(X)[:,3], "alpha3": agmm.predict(X)[:,5]})

In [120]:
alpha_results

Unnamed: 0,alpha1,alpha2,alpha3
0,-0.057211,-0.008415,-0.101419
1,-0.057211,-0.018224,-0.090568
2,-0.057211,-0.018224,-0.006973
3,0.057557,0.023950,-0.025191
4,0.057557,0.019582,0.101540
...,...,...,...
1995,-0.057212,-0.018224,-0.090568
1996,0.057557,0.019582,0.101540
1997,2.285810,2.345643,2.176990
1998,0.057557,0.019582,0.041414


In [35]:
agmm.predict_avg_moment(X, Y,  model='earlystop', method = "test", srr = srr["test"])

(0.3120565523400965, 0.2575672797490606, 0.3665458249311324)