In [46]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [47]:
import torch
import pandas as pd
from tqdm import tqdm
from PFNExperiments.Evaluation.ClassifcationBasedComparison import compare_samples_classifier_based
from PFNExperiments.Evaluation.MMD import compare_samples_mmd
from scipy.stats import wilcoxon

In [48]:
# sample data such that the marginals for x and y are always the same, only the joint distribution changes

P = 5
N_samples = 1000

def samplex(N_samples = N_samples):
    return torch.randn(N_samples, P)

def generate_sample_y_t(t):
    def sampley_t(x):
        """
        Interpolate between sampley3 and sampley1
        t = 1 -> deterministic dependence on x
        t = 0 -> no dependence on x
        """
        xt = 2*t*x

        noise = torch.randn(N_samples, P)*((4 - (2*t)**2)**0.5)

        return xt + noise
    return sampley_t




In [49]:
t_max = 0.99
sampley_gt = generate_sample_y_t(t_max)


In [50]:
# consistency check if any difference in the marginals regarding y can be detected

def sample_marginal(n_x_samples, sample_ya, sample_yb):
    res_list = []
    for i in tqdm(list(range(n_x_samples))):
        x1 = samplex()
        x2 = samplex()
        y1 = sample_ya(x1)
        y2 = sample_yb(x2)
        diff_mmd = compare_samples_mmd(y1, y2)

        res_list.append(diff_mmd)

    return pd.DataFrame(res_list)

marginal_y1y2_res = sample_marginal(n_x_samples = 1000, sample_ya = sampley_gt, sample_yb = generate_sample_y_t(0.1))
marginal_y1y3_res = sample_marginal(n_x_samples = 1000, sample_ya = sampley_gt, sample_yb = generate_sample_y_t(0.8))

print(f"""
      marginal_y1y2_res: {marginal_y1y2_res.mean()}
        marginal_y1y3_res: {marginal_y1y3_res.mean()}
        """)

print(wilcoxon(marginal_y1y2_res, marginal_y1y3_res))

100%|██████████| 1000/1000 [05:02<00:00,  3.31it/s]
100%|██████████| 1000/1000 [04:39<00:00,  3.58it/s]


      marginal_y1y2_res: MMD    0.005658
dtype: float64
        marginal_y1y3_res: MMD    0.005645
dtype: float64
        
WilcoxonResult(statistic=array([247366.]), pvalue=array([0.75223811]))





In [51]:

different_t_res = []
for t in tqdm(torch.linspace(0, t_max, 10)):
    marginal_y1yt_res = sample_marginal(n_x_samples = 1000, sample_ya = sampley_gt, sample_yb = generate_sample_y_t(t))
    different_t_res.append(marginal_y1yt_res.mean())

print(different_t_res)

100%|██████████| 1000/1000 [03:55<00:00,  4.24it/s]
100%|██████████| 1000/1000 [04:34<00:00,  3.65it/s]
100%|██████████| 1000/1000 [04:23<00:00,  3.79it/s]
100%|██████████| 1000/1000 [04:12<00:00,  3.96it/s]
100%|██████████| 1000/1000 [03:10<00:00,  5.26it/s]
100%|██████████| 1000/1000 [02:09<00:00,  7.75it/s]
100%|██████████| 1000/1000 [02:31<00:00,  6.60it/s]
100%|██████████| 1000/1000 [02:02<00:00,  8.17it/s]
100%|██████████| 1000/1000 [02:02<00:00,  8.15it/s]
100%|██████████| 1000/1000 [02:03<00:00,  8.09it/s]
100%|██████████| 10/10 [31:05<00:00, 186.58s/it]

[MMD    0.005661
dtype: float64, MMD    0.005692
dtype: float64, MMD    0.005545
dtype: float64, MMD    0.00559
dtype: float64, MMD    0.005683
dtype: float64, MMD    0.005696
dtype: float64, MMD    0.005593
dtype: float64, MMD    0.005585
dtype: float64, MMD    0.005537
dtype: float64, MMD    0.005659
dtype: float64]





In [52]:
# check expected difference in the posteriors


def sample_posterior(n_x_samples, sample_ya, sample_yb):
    res_list = []
    for i in tqdm(list(range(n_x_samples))):
        x1 = samplex(N_samples=1)
        x2 = x1
        y1 = sample_ya(x1)
        y2 = sample_yb(x2)
        diff_mmd = compare_samples_mmd(y1, y2)

        res_list.append(diff_mmd)

    return pd.DataFrame(res_list)

posterior_y1y2_res = sample_posterior(n_x_samples = 1000, sample_ya = sampley_gt, sample_yb = generate_sample_y_t(0.1))
posterior_y1y3_res = sample_posterior(n_x_samples = 1000, sample_ya = sampley_gt, sample_yb = generate_sample_y_t(0.8))

print(f"""
Posterior y1 y2: {posterior_y1y2_res.mean()}    
Posterior y1 y3: {posterior_y1y3_res.mean()}
      """)

print(wilcoxon(posterior_y1y2_res, posterior_y1y3_res))

100%|██████████| 1000/1000 [02:40<00:00,  6.22it/s]
100%|██████████| 1000/1000 [02:36<00:00,  6.38it/s]


Posterior y1 y2: MMD    2.850188
dtype: float64    
Posterior y1 y3: MMD    1.624913
dtype: float64
      
WilcoxonResult(statistic=array([1.]), pvalue=array([3.33585995e-165]))





In [53]:
different_t_res = []
for t in tqdm(torch.linspace(0, t_max, 10)):
    posterior_y1yt_res = sample_posterior(n_x_samples = 1000, sample_ya = sampley_gt, sample_yb = generate_sample_y_t(t))
    different_t_res.append(posterior_y1yt_res.mean())

print(different_t_res)

100%|██████████| 1000/1000 [04:08<00:00,  4.02it/s]
100%|██████████| 1000/1000 [03:41<00:00,  4.51it/s]
100%|██████████| 1000/1000 [02:46<00:00,  6.01it/s]
100%|██████████| 1000/1000 [03:22<00:00,  4.93it/s]
100%|██████████| 1000/1000 [02:00<00:00,  8.27it/s]
100%|██████████| 1000/1000 [02:03<00:00,  8.07it/s]
100%|██████████| 1000/1000 [03:10<00:00,  5.25it/s]
100%|██████████| 1000/1000 [03:20<00:00,  4.99it/s]
100%|██████████| 1000/1000 [03:10<00:00,  5.26it/s]
100%|██████████| 1000/1000 [03:37<00:00,  4.60it/s]
100%|██████████| 10/10 [31:23<00:00, 188.33s/it]

[MMD    2.993758
dtype: float64, MMD    2.809594
dtype: float64, MMD    2.65092
dtype: float64, MMD    2.484032
dtype: float64, MMD    2.320952
dtype: float64, MMD    2.141462
dtype: float64, MMD    1.956068
dtype: float64, MMD    1.707329
dtype: float64, MMD    1.293886
dtype: float64, MMD    0.005551
dtype: float64]





In [54]:
def sample_joint(
    n_x_samples = 1000,
    sample_ya = sampley1,
    sample_yb = sampley2
):
    res_list = []

    for i in tqdm(list(range(n_x_samples))):
        x1 = samplex(N_samples=N_samples)
        x2 = samplex(N_samples=N_samples)

        y1 = sample_ya(x1)
        y2 = sample_yb(x2)

        xy1 = torch.cat([x1, y1], dim=1)
        xy2 = torch.cat([x2, y2], dim=1)

        mmd = compare_samples_mmd(xy1, xy2)

        res_list.append(mmd)

    return pd.DataFrame(res_list)



joint_y1y2_res = sample_joint(n_x_samples = 1000, sample_ya = sampley_gt, sample_yb = generate_sample_y_t(0.1))
joint_y1y3_res = sample_joint(n_x_samples = 1000, sample_ya = sampley_gt, sample_yb = generate_sample_y_t(0.8))

print(f"""
Joint y1 y2: {joint_y1y2_res.mean()}
Joint y1 y3: {joint_y1y3_res.mean()}
      """)

print(wilcoxon(joint_y1y2_res, joint_y1y3_res))

100%|██████████| 1000/1000 [02:38<00:00,  6.31it/s]
100%|██████████| 1000/1000 [03:58<00:00,  4.20it/s]


Joint y1 y2: MMD    0.044684
dtype: float64
Joint y1 y3: MMD    0.008151
dtype: float64
      
WilcoxonResult(statistic=array([0.]), pvalue=array([3.3258404e-165]))





In [55]:
different_t_res = []
for t in tqdm(torch.linspace(0, t_max, 10)):
    joint_y1yt_res = sample_joint(n_x_samples=1000, sample_ya=sampley_gt, sample_yb=generate_sample_y_t(t))
    different_t_res.append(joint_y1yt_res.mean())

print(different_t_res)

100%|██████████| 1000/1000 [03:23<00:00,  4.92it/s]
100%|██████████| 1000/1000 [03:41<00:00,  4.51it/s]
100%|██████████| 1000/1000 [02:29<00:00,  6.68it/s]
100%|██████████| 1000/1000 [02:19<00:00,  7.16it/s]
100%|██████████| 1000/1000 [03:34<00:00,  4.66it/s]
100%|██████████| 1000/1000 [02:19<00:00,  7.16it/s]
100%|██████████| 1000/1000 [02:17<00:00,  7.30it/s]
100%|██████████| 1000/1000 [04:09<00:00,  4.01it/s]
100%|██████████| 1000/1000 [03:19<00:00,  5.01it/s]
100%|██████████| 1000/1000 [03:26<00:00,  4.85it/s]
100%|██████████| 10/10 [31:00<00:00, 186.07s/it]

[MMD    0.052592
dtype: float64, MMD    0.043865
dtype: float64, MMD    0.036127
dtype: float64, MMD    0.02898
dtype: float64, MMD    0.022631
dtype: float64, MMD    0.017106
dtype: float64, MMD    0.012478
dtype: float64, MMD    0.008892
dtype: float64, MMD    0.00656
dtype: float64, MMD    0.005581
dtype: float64]



