In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
import pandas as pd
from tqdm import tqdm
from PFNExperiments.Evaluation.ClassifcationBasedComparison import compare_samples_classifier_based
from PFNExperiments.Evaluation.MMD import compare_samples_mmd
from PFNExperiments.Evaluation.BasicMetrics import compare_Wasserstein
from scipy.stats import wilcoxon

In [4]:
# Sample data such that the marginals of x and y are always the same; only
# the joint distribution of (x, y) changes.
P = 5  # number of columns drawn by samplex
N_samples = 1000  # default number of rows drawn by samplex; also the row count the y samplers use for their noise

def samplex(N_samples = N_samples):
    """
    Draw a batch of standard-normal x values.

    Parameters
    ----------
    N_samples : int
        Number of rows to draw. The default is the module-level ``N_samples``
        as it was when this function was defined.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(N_samples, P)`` filled by ``torch.randn``.
    """
    shape = (N_samples, P)
    return torch.randn(shape)

def generate_sample_y_t(t, n_samples=None):
    """
    Build a sampler for y given x whose coupling to x is controlled by ``t``.

    The returned sampler draws ``y = 2*t*x + sqrt(4 - (2*t)**2) * eps`` with
    standard-normal ``eps``. For standard-normal x every coordinate of y has
    variance ``(2*t)**2 + (4 - (2*t)**2) = 4`` whatever ``t`` is, so only the
    joint distribution of (x, y) depends on ``t``:

        t = 1 -> deterministic dependence on x (y = 2*x)
        t = 0 -> no dependence on x (y is pure noise)

    Parameters
    ----------
    t : float or torch.Tensor
        Coupling strength; must satisfy ``|t| <= 1`` so that the noise
        variance ``4 - (2*t)**2`` is non-negative.
    n_samples : int, optional
        Number of noise rows each call of the sampler draws. When omitted the
        module-level ``N_samples`` is read at call time (the original
        behaviour).

    Returns
    -------
    callable
        ``sampley_t(x) -> torch.Tensor``. The noise has ``n_samples`` rows and
        as many columns as ``x``; an ``x`` with a single row is broadcast
        against it, which yields many y draws for one fixed x.

    Raises
    ------
    ValueError
        If any ``|t| > 1`` or ``t`` is NaN. Without this check the noise scale
        silently becomes NaN (tensor ``t``) or complex (float ``t``).
    """
    t_check = torch.as_tensor(t, dtype=torch.float64)
    if not bool((t_check.abs() <= 1).all()):
        raise ValueError(f"t must satisfy |t| <= 1, got {t!r}")

    # Both scales depend only on t, so compute them once per sampler instead
    # of on every call.
    signal_scale = 2*t
    noise_scale = (4 - (2*t)**2)**0.5

    def sampley_t(x):
        """Draw y for the given x with the coupling fixed at construction."""
        n_rows = N_samples if n_samples is None else n_samples
        # The row count is deliberately independent of x.shape[0]; see the
        # broadcasting note in the enclosing docstring.
        noise = torch.randn(n_rows, x.shape[-1]) * noise_scale

        return signal_scale*x + noise
    return sampley_t


In [5]:
# Largest coupling strength used in this notebook: it defines the reference
# sampler (sampley1) and the upper end of the t grids swept in later cells.
t_max = 0.99

In [6]:
# Reference sampler with the strongest coupling (t = t_max); every comparison
# in the cells below passes it as sample_ya.
sampley1 = generate_sample_y_t(t_max)

In [7]:
# Consistency check: can any difference in the marginal distribution of y be detected? (None is expected.)

def sample_marginal(n_x_samples, sample_ya, sample_yb):
    """
    Repeatedly compare the y samples of two samplers fed with independent x.

    Each repetition draws two separate x batches with ``samplex()``, passes
    one to each sampler and scores the two resulting y batches with
    ``compare_Wasserstein``.

    Parameters
    ----------
    n_x_samples : int
        Number of repetitions.
    sample_ya, sample_yb : callable
        Samplers mapping an x batch to a y batch.

    Returns
    -------
    pandas.DataFrame
        One row per repetition, built from the ``compare_Wasserstein`` results
        (presumably a mapping of metric name to value -- confirm against its
        implementation).
    """
    distances = []
    for _ in tqdm(range(n_x_samples)):
        # Two independent x batches, so only the marginals of y are compared.
        xa, xb = samplex(), samplex()
        ya, yb = sample_ya(xa), sample_yb(xb)
        distances.append(compare_Wasserstein(ya, yb))

    return pd.DataFrame(distances)

# "y2" is the sampler with t = 0.1 and "y3" the one with t = 0.8; each is
# compared against the reference sampler sampley1.
marginal_y1y2_res = sample_marginal(n_x_samples = 500, sample_ya = sampley1, sample_yb = generate_sample_y_t(0.1))
marginal_y1y3_res = sample_marginal(n_x_samples = 500, sample_ya = sampley1, sample_yb =generate_sample_y_t(0.8))

# Mean distance over the 500 repetitions of each comparison.
print(f"""
      marginal_y1y2_res: {marginal_y1y2_res.mean()}
        marginal_y1y3_res: {marginal_y1y3_res.mean()}
        """)

# NOTE(review): scipy's wilcoxon is the paired signed-rank test, while the two
# result sets come from independent runs and are paired only by row index --
# confirm this is the intended test.
print(wilcoxon(marginal_y1y2_res, marginal_y1y3_res))

100%|██████████| 500/500 [02:13<00:00,  3.75it/s]
100%|██████████| 500/500 [02:17<00:00,  3.65it/s]


      marginal_y1y2_res: Wasserstein_distance with metric euclidean    1.632349
dtype: float64
        marginal_y1y3_res: Wasserstein_distance with metric euclidean    1.632636
dtype: float64
        
WilcoxonResult(statistic=array([62375.]), pvalue=array([0.93835016]))





In [84]:
# Sweep t over 10 evenly spaced values in [0, t_max] and record, for each, the
# mean marginal distance to the reference sampler.
# NOTE: the name different_t_res is reassigned by the later sweep cells.
different_t_res = []
for t in tqdm(torch.linspace(0, t_max, 10)):
    # t is a 0-dim tensor here (an element of the linspace), not a Python float.
    marginal_y1yt_res = sample_marginal(n_x_samples = 500, sample_ya = sampley1, sample_yb = generate_sample_y_t(t))
    different_t_res.append(marginal_y1yt_res.mean())

print(different_t_res)

100%|██████████| 500/500 [03:38<00:00,  2.29it/s]
100%|██████████| 500/500 [03:27<00:00,  2.41it/s]
100%|██████████| 500/500 [03:27<00:00,  2.41it/s]
100%|██████████| 500/500 [03:27<00:00,  2.41it/s]
100%|██████████| 500/500 [03:24<00:00,  2.44it/s]
100%|██████████| 500/500 [03:26<00:00,  2.43it/s]
100%|██████████| 500/500 [03:34<00:00,  2.33it/s]
100%|██████████| 500/500 [04:04<00:00,  2.04it/s]
100%|██████████| 500/500 [03:54<00:00,  2.14it/s]
100%|██████████| 500/500 [04:16<00:00,  1.95it/s]
100%|██████████| 10/10 [36:40<00:00, 220.01s/it]

[Wasserstein_distance with metric euclidean    1.632527
dtype: float64, Wasserstein_distance with metric euclidean    1.633233
dtype: float64, Wasserstein_distance with metric euclidean    1.631207
dtype: float64, Wasserstein_distance with metric euclidean    1.632713
dtype: float64, Wasserstein_distance with metric euclidean    1.631065
dtype: float64, Wasserstein_distance with metric euclidean    1.63283
dtype: float64, Wasserstein_distance with metric euclidean    1.632659
dtype: float64, Wasserstein_distance with metric euclidean    1.630798
dtype: float64, Wasserstein_distance with metric euclidean    1.631543
dtype: float64, Wasserstein_distance with metric euclidean    1.631402
dtype: float64]





In [85]:
# Check the expected difference between the distributions of y given one fixed x (the "posteriors")


def sample_posterior(n_x_samples, sample_ya, sample_yb):
    """
    Repeatedly compare the y samples of two samplers fed with the same x.

    Each repetition draws a single x row with ``samplex(N_samples=1)``, passes
    that same tensor to both samplers and scores the two resulting y batches
    with ``compare_Wasserstein``.

    Parameters
    ----------
    n_x_samples : int
        Number of repetitions, i.e. number of distinct x rows tried.
    sample_ya, sample_yb : callable
        Samplers mapping an x batch to a y batch.

    Returns
    -------
    pandas.DataFrame
        One row per repetition, built from the ``compare_Wasserstein`` results
        (presumably a mapping of metric name to value -- confirm against its
        implementation).
    """
    distances = []
    for _ in tqdm(range(n_x_samples)):
        # One shared x row: both samplers are conditioned on the same value.
        x_shared = samplex(N_samples=1)
        ya = sample_ya(x_shared)
        yb = sample_yb(x_shared)
        distances.append(compare_Wasserstein(ya, yb))

    return pd.DataFrame(distances)

# "y2" is the sampler with t = 0.1 and "y3" the one with t = 0.8; each is
# compared against the reference sampler sampley1 on shared x values.
posterior_y1y2_res = sample_posterior(n_x_samples = 1000, sample_ya = sampley1, sample_yb = generate_sample_y_t(0.1))
posterior_y1y3_res = sample_posterior(n_x_samples = 1000, sample_ya = sampley1, sample_yb = generate_sample_y_t(0.8))

# Mean distance over the 1000 repetitions of each comparison.
print(f"""
Posterior y1 y2: {posterior_y1y2_res.mean()}    
Posterior y1 y3: {posterior_y1y3_res.mean()}
      """)

# NOTE(review): scipy's wilcoxon is the paired signed-rank test, while the two
# result sets come from separate runs with different x draws and are paired
# only by row index -- confirm this is the intended test.
print(wilcoxon(posterior_y1y2_res, posterior_y1y3_res))

100%|██████████| 1000/1000 [08:10<00:00,  2.04it/s]
100%|██████████| 1000/1000 [08:48<00:00,  1.89it/s]


Posterior y1 y2: Wasserstein_distance with metric euclidean    5.266666
dtype: float64    
Posterior y1 y3: Wasserstein_distance with metric euclidean    2.149913
dtype: float64
      
WilcoxonResult(statistic=array([0.]), pvalue=array([3.32585912e-165]))





In [86]:
# Sweep t over 10 evenly spaced values in [0, t_max] and record, for each, the
# mean conditional ("posterior") distance to the reference sampler.
# NOTE: this overwrites the different_t_res list built by the earlier sweep cell.
different_t_res = []
for t in tqdm(torch.linspace(0, t_max, 10)):
    # t is a 0-dim tensor here (an element of the linspace), not a Python float.
    posterior_y1yt_res = sample_posterior(n_x_samples = 1000, sample_ya = sampley1, sample_yb = generate_sample_y_t(t))
    different_t_res.append(posterior_y1yt_res.mean())

print(different_t_res)

100%|██████████| 1000/1000 [09:19<00:00,  1.79it/s]
100%|██████████| 1000/1000 [08:09<00:00,  2.04it/s]
100%|██████████| 1000/1000 [09:09<00:00,  1.82it/s]
100%|██████████| 1000/1000 [08:44<00:00,  1.91it/s]
100%|██████████| 1000/1000 [08:27<00:00,  1.97it/s]
100%|██████████| 1000/1000 [09:04<00:00,  1.83it/s]
100%|██████████| 1000/1000 [05:03<00:00,  3.29it/s]
100%|██████████| 1000/1000 [04:58<00:00,  3.35it/s]
100%|██████████| 1000/1000 [04:34<00:00,  3.65it/s]
100%|██████████| 1000/1000 [03:44<00:00,  4.46it/s]
100%|██████████| 10/10 [1:11:16<00:00, 427.61s/it]

[Wasserstein_distance with metric euclidean    5.53613
dtype: float64, Wasserstein_distance with metric euclidean    5.226158
dtype: float64, Wasserstein_distance with metric euclidean    4.864709
dtype: float64, Wasserstein_distance with metric euclidean    4.458535
dtype: float64, Wasserstein_distance with metric euclidean    4.004716
dtype: float64, Wasserstein_distance with metric euclidean    3.524912
dtype: float64, Wasserstein_distance with metric euclidean    2.984637
dtype: float64, Wasserstein_distance with metric euclidean    2.344583
dtype: float64, Wasserstein_distance with metric euclidean    1.53766
dtype: float64, Wasserstein_distance with metric euclidean    0.230151
dtype: float64]





In [87]:
def sample_joint(
    n_x_samples = 100,
    sample_ya = sampley1,
    sample_yb = None
):
    """
    Repeatedly compare the joint (x, y) samples produced by two samplers.

    Each repetition draws two independent x batches with
    ``samplex(N_samples=N_samples)``, passes one to each sampler, concatenates
    every x with its y along dim 1 and scores the two joint batches with
    ``compare_Wasserstein``.

    Parameters
    ----------
    n_x_samples : int
        Number of repetitions.
    sample_ya : callable
        First sampler mapping an x batch to a y batch. Defaults to the
        ``sampley1`` that exists when this function is defined.
    sample_yb : callable, optional
        Second sampler. When omitted, ``generate_sample_y_t(0.1)`` is used.

    Returns
    -------
    pandas.DataFrame
        One row per repetition, built from the ``compare_Wasserstein`` results
        (presumably a mapping of metric name to value -- confirm against its
        implementation).
    """
    if sample_yb is None:
        # The previous default, `sampley2`, is not defined anywhere in this
        # notebook. Defaults are evaluated when the `def` runs, so the cell
        # raised NameError on a fresh kernel. The t = 0.1 sampler is what the
        # "y1y2" comparisons in this notebook use.
        # NOTE(review): confirm this is the intended default.
        sample_yb = generate_sample_y_t(0.1)

    res_list = []

    for _ in tqdm(range(n_x_samples)):
        # Independent x batches for the two samplers.
        x1 = samplex(N_samples=N_samples)
        x2 = samplex(N_samples=N_samples)

        y1 = sample_ya(x1)
        y2 = sample_yb(x2)

        # Joint samples: x columns followed by y columns.
        xy1 = torch.cat([x1, y1], dim=1)
        xy2 = torch.cat([x2, y2], dim=1)

        distance = compare_Wasserstein(xy1, xy2)

        res_list.append(distance)

    return pd.DataFrame(res_list)



# "y2" is the sampler with t = 0.1 and "y3" the one with t = 0.8; each is
# compared against the reference sampler sampley1 on the joint (x, y) samples.
joint_y1y2_res = sample_joint(n_x_samples=500, sample_ya=sampley1, sample_yb=generate_sample_y_t(0.1))
joint_y1y3_res = sample_joint(n_x_samples=500, sample_ya=sampley1, sample_yb=generate_sample_y_t(0.8))

# Mean distance over the 500 repetitions of each comparison.
print(f"""
Joint y1 y2: {joint_y1y2_res.mean()}
Joint y1 y3: {joint_y1y3_res.mean()}
      """)

# NOTE(review): scipy's wilcoxon is the paired signed-rank test, while the two
# result sets come from independent runs and are paired only by row index --
# confirm this is the intended test.
print(wilcoxon(joint_y1y2_res, joint_y1y3_res))

100%|██████████| 500/500 [02:22<00:00,  3.52it/s]
100%|██████████| 500/500 [02:26<00:00,  3.40it/s]


Joint y1 y2: Wasserstein_distance with metric euclidean    3.186253
dtype: float64
Joint y1 y3: Wasserstein_distance with metric euclidean    2.228198
dtype: float64
      
WilcoxonResult(statistic=array([0.]), pvalue=array([1.26471895e-83]))





In [88]:
different_t_res = []
for t in tqdm(torch.linspace(0, t_max, 10)):
    joint_y1yt_res = sample_joint(n_x_samples=500, sample_ya=sampley1, sample_yb=generate_sample_y_t(t))
    different_t_res.append(joint_y1yt_res.mean())

print(different_t_res)

100%|██████████| 500/500 [02:38<00:00,  3.16it/s]
100%|██████████| 500/500 [02:31<00:00,  3.31it/s]
100%|██████████| 500/500 [02:22<00:00,  3.50it/s]
100%|██████████| 500/500 [02:18<00:00,  3.60it/s]
100%|██████████| 500/500 [02:18<00:00,  3.60it/s]
100%|██████████| 500/500 [02:16<00:00,  3.67it/s]
100%|██████████| 500/500 [02:35<00:00,  3.22it/s]
100%|██████████| 500/500 [02:15<00:00,  3.69it/s]
100%|██████████| 500/500 [02:14<00:00,  3.73it/s]
100%|██████████| 500/500 [02:10<00:00,  3.82it/s]
100%|██████████| 10/10 [23:41<00:00, 142.14s/it]

[Wasserstein_distance with metric euclidean    3.306616
dtype: float64, Wasserstein_distance with metric euclidean    3.17288
dtype: float64, Wasserstein_distance with metric euclidean    3.038421
dtype: float64, Wasserstein_distance with metric euclidean    2.90107
dtype: float64, Wasserstein_distance with metric euclidean    2.755061
dtype: float64, Wasserstein_distance with metric euclidean    2.606836
dtype: float64, Wasserstein_distance with metric euclidean    2.447963
dtype: float64, Wasserstein_distance with metric euclidean    2.277621
dtype: float64, Wasserstein_distance with metric euclidean    2.090024
dtype: float64, Wasserstein_distance with metric euclidean    1.870655
dtype: float64]



