In [11]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
import torch
import pandas as pd
from tqdm import tqdm
from PFNExperiments.Evaluation.ClassifcationBasedComparison import compare_samples_classifier_based
from PFNExperiments.Evaluation.MMD import compare_samples_mmd
from scipy.stats import wilcoxon
from sklearn.tree import DecisionTreeClassifier

In [13]:
# sample data such that the marginals for x and y are always the same, only the joint distribution changes

P = 5
N_samples = 1000

def samplex(N_samples = N_samples):
    return torch.randn(N_samples, P)

def generate_sample_y_t(t):
    def sampley_t(x):
        """
        Interpolate between sampley3 and sampley1
        t = 1 -> deterministic dependence on x
        t = 0 -> no dependence on x
        """
        xt = 2*t*x

        noise = torch.randn(N_samples, P)*((4 - (2*t)**2)**0.5)

        return xt + noise
    return sampley_t




In [14]:
t_max = 0.99
sampley_gt = generate_sample_y_t(t_max)


In [15]:
# consistency check if any difference in the marginals regarding y can be detected

def sample_marginal(n_x_samples, sample_ya, sample_yb):
    res_list = []
    for i in tqdm(list(range(n_x_samples))):
        x1 = samplex()
        x2 = samplex()
        y1 = sample_ya(x1)
        y2 = sample_yb(x2)
        diff_mmd = compare_samples_classifier_based(y1, y2, used_model = DecisionTreeClassifier())

        res_list.append(diff_mmd)

    return pd.DataFrame(res_list)

marginal_y1y2_res = sample_marginal(n_x_samples = 100, sample_ya = sampley_gt, sample_yb = generate_sample_y_t(0.1))
marginal_y1y3_res = sample_marginal(n_x_samples = 100, sample_ya = sampley_gt, sample_yb = generate_sample_y_t(0.8))

print(f"""
      marginal_y1y2_res: {marginal_y1y2_res.mean()}
        marginal_y1y3_res: {marginal_y1y3_res.mean()}
        """)

print(wilcoxon(marginal_y1y2_res, marginal_y1y3_res))

100%|██████████| 100/100 [01:02<00:00,  1.59it/s]
100%|██████████| 100/100 [01:00<00:00,  1.67it/s]


      marginal_y1y2_res: CST_accuracy         0.500665
CST_roc_auc_score    0.500665
dtype: float64
        marginal_y1y3_res: CST_accuracy         0.500065
CST_roc_auc_score    0.500065
dtype: float64
        
WilcoxonResult(statistic=array([2407. , 2400.5]), pvalue=array([0.68494024, 0.66859404]))





In [16]:

different_t_res = []
for t in tqdm(torch.linspace(0, t_max, 10)):
    marginal_y1yt_res = sample_marginal(n_x_samples = 100, sample_ya = sampley_gt, sample_yb = generate_sample_y_t(t))
    different_t_res.append(marginal_y1yt_res.mean())

print(different_t_res)

100%|██████████| 100/100 [00:55<00:00,  1.80it/s]
100%|██████████| 100/100 [00:48<00:00,  2.04it/s]
100%|██████████| 100/100 [00:49<00:00,  2.02it/s]
100%|██████████| 100/100 [00:48<00:00,  2.06it/s]
100%|██████████| 100/100 [00:48<00:00,  2.05it/s]
100%|██████████| 100/100 [00:47<00:00,  2.10it/s]
100%|██████████| 100/100 [00:48<00:00,  2.04it/s]
100%|██████████| 100/100 [00:49<00:00,  2.04it/s]
100%|██████████| 100/100 [00:48<00:00,  2.04it/s]
100%|██████████| 100/100 [00:48<00:00,  2.04it/s]
100%|██████████| 10/10 [08:14<00:00, 49.49s/it]

[CST_accuracy         0.497685
CST_roc_auc_score    0.497685
dtype: float64, CST_accuracy         0.498315
CST_roc_auc_score    0.498315
dtype: float64, CST_accuracy         0.499365
CST_roc_auc_score    0.499365
dtype: float64, CST_accuracy         0.500985
CST_roc_auc_score    0.500985
dtype: float64, CST_accuracy         0.500245
CST_roc_auc_score    0.500245
dtype: float64, CST_accuracy         0.499365
CST_roc_auc_score    0.499365
dtype: float64, CST_accuracy         0.5007
CST_roc_auc_score    0.5007
dtype: float64, CST_accuracy         0.498685
CST_roc_auc_score    0.498685
dtype: float64, CST_accuracy         0.500175
CST_roc_auc_score    0.500175
dtype: float64, CST_accuracy         0.49924
CST_roc_auc_score    0.49924
dtype: float64]





In [17]:
# check expected difference in the posteriors


def sample_posterior(n_x_samples, sample_ya, sample_yb):
    res_list = []
    for i in tqdm(list(range(n_x_samples))):
        x1 = samplex(N_samples=1)
        x2 = x1
        y1 = sample_ya(x1)
        y2 = sample_yb(x2)
        diff_mmd = compare_samples_classifier_based(y1, y2, used_model = DecisionTreeClassifier())

        res_list.append(diff_mmd)

    return pd.DataFrame(res_list)

posterior_y1y2_res = sample_posterior(n_x_samples = 1000, sample_ya = sampley_gt, sample_yb = generate_sample_y_t(0.1))
posterior_y1y3_res = sample_posterior(n_x_samples = 1000, sample_ya = sampley_gt, sample_yb = generate_sample_y_t(0.8))

print(f"""
Posterior y1 y2: {posterior_y1y2_res.mean()}    
Posterior y1 y3: {posterior_y1y3_res.mean()}
      """)

print(wilcoxon(posterior_y1y2_res, posterior_y1y3_res))

100%|██████████| 1000/1000 [03:51<00:00,  4.32it/s]
100%|██████████| 1000/1000 [05:18<00:00,  3.14it/s]


Posterior y1 y2: CST_accuracy         0.991135
CST_roc_auc_score    0.991135
dtype: float64    
Posterior y1 y3: CST_accuracy         0.955395
CST_roc_auc_score    0.955395
dtype: float64
      
WilcoxonResult(statistic=array([0., 0.]), pvalue=array([3.31260033e-165, 3.31087845e-165]))





In [18]:
different_t_res = []
for t in tqdm(torch.linspace(0, t_max, 10)):
    posterior_y1yt_res = sample_posterior(n_x_samples = 1000, sample_ya = sampley_gt, sample_yb = generate_sample_y_t(t))
    different_t_res.append(posterior_y1yt_res.mean())

print(different_t_res)

100%|██████████| 1000/1000 [03:40<00:00,  4.54it/s]
100%|██████████| 1000/1000 [03:49<00:00,  4.37it/s]
100%|██████████| 1000/1000 [04:00<00:00,  4.15it/s]
100%|██████████| 1000/1000 [05:06<00:00,  3.26it/s]
100%|██████████| 1000/1000 [05:06<00:00,  3.27it/s]
100%|██████████| 1000/1000 [05:27<00:00,  3.05it/s]
100%|██████████| 1000/1000 [05:03<00:00,  3.30it/s]
100%|██████████| 1000/1000 [06:00<00:00,  2.77it/s]
100%|██████████| 1000/1000 [07:47<00:00,  2.14it/s]
100%|██████████| 1000/1000 [08:29<00:00,  1.96it/s]
100%|██████████| 10/10 [54:32<00:00, 327.26s/it]

[CST_accuracy         0.992131
CST_roc_auc_score    0.992131
dtype: float64, CST_accuracy         0.990771
CST_roc_auc_score    0.990771
dtype: float64, CST_accuracy         0.989069
CST_roc_auc_score    0.989069
dtype: float64, CST_accuracy         0.986926
CST_roc_auc_score    0.986926
dtype: float64, CST_accuracy         0.984438
CST_roc_auc_score    0.984438
dtype: float64, CST_accuracy         0.980236
CST_roc_auc_score    0.980236
dtype: float64, CST_accuracy         0.974055
CST_roc_auc_score    0.974054
dtype: float64, CST_accuracy         0.961328
CST_roc_auc_score    0.961328
dtype: float64, CST_accuracy         0.923077
CST_roc_auc_score    0.923077
dtype: float64, CST_accuracy         0.500088
CST_roc_auc_score    0.500088
dtype: float64]





In [19]:
def sample_joint(
    n_x_samples = 1000,
    sample_ya = None,
    sample_yb = None 
):
    res_list = []

    for i in tqdm(list(range(n_x_samples))):
        x1 = samplex(N_samples=N_samples)
        x2 = samplex(N_samples=N_samples)

        y1 = sample_ya(x1)
        y2 = sample_yb(x2)

        xy1 = torch.cat([x1, y1], dim=1)
        xy2 = torch.cat([x2, y2], dim=1)

        mmd = compare_samples_classifier_based(xy1, xy2, used_model = DecisionTreeClassifier())

        res_list.append(mmd)

    return pd.DataFrame(res_list)



joint_y1y2_res = sample_joint(n_x_samples = 1000, sample_ya = sampley_gt, sample_yb = generate_sample_y_t(0.1))
joint_y1y3_res = sample_joint(n_x_samples = 1000, sample_ya = sampley_gt, sample_yb = generate_sample_y_t(0.8))

print(f"""
Joint y1 y2: {joint_y1y2_res.mean()}
Joint y1 y3: {joint_y1y3_res.mean()}
      """)

print(wilcoxon(joint_y1y2_res, joint_y1y3_res))

100%|██████████| 1000/1000 [13:24<00:00,  1.24it/s]
100%|██████████| 1000/1000 [14:28<00:00,  1.15it/s]


Joint y1 y2: CST_accuracy         0.864961
CST_roc_auc_score    0.864961
dtype: float64
Joint y1 y3: CST_accuracy         0.690791
CST_roc_auc_score    0.690791
dtype: float64
      
WilcoxonResult(statistic=array([0., 0.]), pvalue=array([3.32085497e-165, 3.31891032e-165]))





In [20]:
different_t_res = []
for t in tqdm(torch.linspace(0, t_max, 10)):
    joint_y1yt_res = sample_joint(n_x_samples=100, sample_ya=sampley_gt, sample_yb=generate_sample_y_t(t))
    different_t_res.append(joint_y1yt_res.mean())

print(different_t_res)

100%|██████████| 100/100 [01:17<00:00,  1.29it/s]
100%|██████████| 100/100 [01:13<00:00,  1.37it/s]
100%|██████████| 100/100 [01:37<00:00,  1.03it/s]
100%|██████████| 100/100 [01:38<00:00,  1.02it/s]
100%|██████████| 100/100 [01:23<00:00,  1.20it/s]
100%|██████████| 100/100 [01:21<00:00,  1.22it/s]
100%|██████████| 100/100 [01:24<00:00,  1.19it/s]
100%|██████████| 100/100 [01:32<00:00,  1.09it/s]
100%|██████████| 100/100 [01:41<00:00,  1.01s/it]
100%|██████████| 100/100 [01:08<00:00,  1.47it/s]
100%|██████████| 10/10 [14:18<00:00, 85.80s/it]

[CST_accuracy         0.877275
CST_roc_auc_score    0.877275
dtype: float64, CST_accuracy         0.863255
CST_roc_auc_score    0.863255
dtype: float64, CST_accuracy         0.849185
CST_roc_auc_score    0.849185
dtype: float64, CST_accuracy         0.830785
CST_roc_auc_score    0.830785
dtype: float64, CST_accuracy         0.811605
CST_roc_auc_score    0.811605
dtype: float64, CST_accuracy         0.78664
CST_roc_auc_score    0.78664
dtype: float64, CST_accuracy         0.75262
CST_roc_auc_score    0.75262
dtype: float64, CST_accuracy         0.70723
CST_roc_auc_score    0.70723
dtype: float64, CST_accuracy         0.63391
CST_roc_auc_score    0.63391
dtype: float64, CST_accuracy         0.49862
CST_roc_auc_score    0.49862
dtype: float64]



