In [1]:
import argparse
import numpy as np
import pandas as pd
import os
import time
import scipy.stats as st
from sklearn.metrics import pairwise_distances
from sklearn.linear_model import LogisticRegression

from kpt import kernel_two_sample_test_reweight
from dr_kpt import xMMD2dr_cross_fit
from environment_ihdp import X_all, T_all, Y_all, make_scenario_binary

In [2]:
EXPNAME = "ihdp_experiment"
RESULT_DIR = f"results/{EXPNAME}/"
PARAM_CSV = f"experiment_parameters/{EXPNAME}_parameters.csv"
os.makedirs(RESULT_DIR, exist_ok=True)
os.makedirs("experiment_parameters", exist_ok=True)
NB_SEEDS = 100
SAMPLE_SIZE = 500
ITERATIONS = 10000

METHODS = ["KPT-linear", "KPT-rbf", "DR-KPT"]
SCENARIOS = ["I", "II", "III", "IV"]

In [3]:

scenario_id = 'I'
seed = 42

rng = np.random.RandomState(seed)
idx = rng.choice(len(X_all), SAMPLE_SIZE, replace=False)
X = X_all[idx]
T = T_all[idx]
Y = Y_all[idx][:, None]

logreg = LogisticRegression(max_iter=1000, random_state=seed).fit(X, T)
pi0_scores = logreg.predict_proba(X)[:, 1]
pi0_scores = np.clip(pi0_scores, 1e-4, 1 - 1e-4)
pi0_probs = np.stack([1 - pi0_scores, pi0_scores], axis=1)

pi_fn, pi_prime_fn = make_scenario_binary(scenario_id, X, logreg.coef_[0])
pi_probs = pi_fn(X)
pi_prime_probs = pi_prime_fn(X)

w_pi = pi_probs[np.arange(len(T)), T] / pi0_probs[np.arange(len(T)), T]
w_pi_prime = pi_prime_probs[np.arange(len(T)), T] / pi0_probs[np.arange(len(T)), T]

pi_samples = np.array([rng.choice([0, 1], p=p) for p in pi_probs])
pi_prime_samples = np.array([rng.choice([0, 1], p=p) for p in pi_prime_probs])

w_pi, w_pi_prime

(array([2.11590807, 0.4214689 , 2.1483974 , 1.79419649, 1.64151637,
        2.12818638, 0.31963148, 2.14491894, 2.01632203, 0.27198834,
        0.31405878, 0.2503729 , 2.23852434, 2.26876803, 0.30439314,
        2.02577328, 0.33869462, 0.28438398, 1.88939183, 2.55806265,
        0.296054  , 0.29576056, 1.97446507, 1.98391386, 0.35282752,
        2.31612606, 0.26269103, 1.87827608, 0.31696714, 0.27880676,
        0.34076174, 0.25770565, 2.28422604, 0.25856311, 2.61271537,
        0.24508984, 0.36182933, 0.3072706 , 2.04131432, 1.9493849 ,
        0.32671369, 0.31707464, 0.3152738 , 0.32582824, 0.28508683,
        2.03389801, 2.13608098, 0.32301271, 0.28576119, 0.33510975,
        0.30197285, 0.25962804, 1.83194039, 2.30843311, 0.38187475,
        2.14586996, 1.979075  , 0.28834933, 2.32089546, 0.28346791,
        0.28433734, 0.33106646, 0.24560645, 0.20742541, 0.3773967 ,
        2.47331129, 0.34681765, 0.27787678, 0.2800179 , 2.43944566,
        0.32232605, 1.71091941, 1.95746894, 0.28

In [4]:
try:
    sigma2 = np.median(pairwise_distances(Y, Y)) ** 2
    gamma_k = 1.0 / sigma2
except:
    gamma_k = None


gamma_k

np.float64(18484.89500000005)

In [5]:
np.var(Y)

np.float64(5.501287997578564e-05)

In [6]:
Y_all

array([0.0441309 , 0.03309817, 0.02794957, 0.01581357, 0.0268463 ,
       0.03052387, 0.02794957, 0.03493696, 0.04045332, 0.02611078,
       0.03604023, 0.03493696, 0.03052387, 0.04045332, 0.03493696,
       0.02427199, 0.03052387, 0.03677575, 0.02022666, 0.02978836,
       0.02280096, 0.0268463 , 0.0268463 , 0.02353648, 0.03493696,
       0.03971781, 0.03861454, 0.03309817, 0.03420145, 0.02427199,
       0.03420145, 0.04045332, 0.04045332, 0.04229211, 0.02868508,
       0.03309817, 0.04229211, 0.02978836, 0.03420145, 0.02868508,
       0.03236266, 0.02978836, 0.04229211, 0.03236266, 0.0268463 ,
       0.02868508, 0.02794957, 0.02978836, 0.02500751, 0.04155659,
       0.03309817, 0.03052387, 0.03861454, 0.03420145, 0.04229211,
       0.03052387, 0.03604023, 0.02280096, 0.03861454, 0.03162714,
       0.01765236, 0.0441309 , 0.03420145, 0.03346593, 0.03493696,
       0.02868508, 0.03971781, 0.03162714, 0.03309817, 0.03162714,
       0.03052387, 0.03089163, 0.04045332, 0.01912339, 0.02794

In [7]:
stat = xMMD2dr_cross_fit(
    Y, X, T, w_pi, w_pi_prime,
    pi_samples, pi_prime_samples,
    kernel_function="rbf", 
    reg_lambda=1e-1)
pval = 1 - st.norm.cdf(stat)

In [8]:
stat, pval

(np.float64(-0.3089702323061419), np.float64(0.6213279145720778))

In [15]:
Y, X, T

(array([[0.02978836],
        [0.0268463 ],
        [0.04486641],
        [0.03971781],
        [0.03309817],
        [0.0467052 ],
        [0.02611078],
        [0.04192435],
        [0.03861454],
        [0.02978836],
        [0.03420145],
        [0.03604023],
        [0.03162714],
        [0.02611078],
        [0.02868508],
        [0.04486641],
        [0.03420145],
        [0.02868508],
        [0.02794957],
        [0.03162714],
        [0.03236266],
        [0.03162714],
        [0.01765236],
        [0.02427199],
        [0.01581357],
        [0.05111829],
        [0.03309817],
        [0.02978836],
        [0.03971781],
        [0.03493696],
        [0.02868508],
        [0.02096218],
        [0.03162714],
        [0.03604023],
        [0.04486641],
        [0.03162714],
        [0.03236266],
        [0.02353648],
        [0.04339538],
        [0.03309817],
        [0.02500751],
        [0.04486641],
        [0.03346593],
        [0.03971781],
        [0.02868508],
        [0