In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import pandas as pd
from tqdm import tqdm
from PFNExperiments.Evaluation.ClassifcationBasedComparison import compare_samples_classifier_based
from PFNExperiments.Evaluation.MMD import compare_samples_mmd
from scipy.stats import wilcoxon

In [3]:
# sample data such that the marginals for x and y are always the same, only the joint distribution changes

P = 1
N_samples = 1000

def samplex(N_samples = N_samples):
    return torch.randn(N_samples, P)

def generate_sample_y_t(t):
    def sampley_t(x):
        """
        Interpolate between sampley3 and sampley1
        t = 1 -> deterministic dependence on x
        t = 0 -> no dependence on x
        """
        xt = 2*t*x

        noise = torch.randn(N_samples, P)*((4 - (2*t)**2)**0.5)

        return xt + noise
    return sampley_t



sampley_gt = generate_sample_y_t(1)
t_max = 0.99

In [5]:
t_max = 0.99
sampley1 = generate_sample_y_t(t_max)

In [6]:
# consistency check if any difference in the marginals regarding y can be detected

def sample_marginal(n_x_samples, sample_ya, sample_yb):
    res_list = []
    for i in tqdm(list(range(n_x_samples))):
        x1 = samplex()
        x2 = samplex()
        y1 = sample_ya(x1)
        y2 = sample_yb(x2)
        diff_mmd = compare_samples_classifier_based(y1, y2, n_folds=3)

        res_list.append(diff_mmd)

    return pd.DataFrame(res_list)

marginal_y1y2_res = sample_marginal(n_x_samples = 100, sample_ya = sampley1, sample_yb = generate_sample_y_t(0.5))
marginal_y1y3_res = sample_marginal(n_x_samples = 100, sample_ya = sampley1, sample_yb = generate_sample_y_t(0.0))

print(f"""
      marginal_y1y2_res: {marginal_y1y2_res.mean()}
        marginal_y1y3_res: {marginal_y1y3_res.mean()}
        """)

print(wilcoxon(marginal_y1y2_res, marginal_y1y3_res))

100%|██████████| 100/100 [01:51<00:00,  1.12s/it]
100%|██████████| 100/100 [01:11<00:00,  1.39it/s]


      marginal_y1y2_res: CST_accuracy         0.495813
CST_roc_auc_score    0.495844
dtype: float64
        marginal_y1y3_res: CST_accuracy         0.504255
CST_roc_auc_score    0.504296
dtype: float64
        
WilcoxonResult(statistic=array([2120., 2113.]), pvalue=array([0.16376327, 0.15660166]))





In [7]:
different_t_res = []
for t in tqdm(torch.linspace(0, t_max, 10)):
    marginal_y1yt_res = sample_marginal(n_x_samples = 100, sample_ya = sampley1, sample_yb = generate_sample_y_t(t))
    different_t_res.append(marginal_y1yt_res.mean())

print(different_t_res)

100%|██████████| 100/100 [39:17<00:00, 23.57s/it]
100%|██████████| 100/100 [00:41<00:00,  2.44it/s]]
100%|██████████| 100/100 [00:38<00:00,  2.58it/s] 
100%|██████████| 100/100 [00:39<00:00,  2.54it/s]
100%|██████████| 100/100 [00:39<00:00,  2.51it/s]
100%|██████████| 100/100 [00:46<00:00,  2.14it/s]
100%|██████████| 100/100 [01:31<00:00,  1.09it/s]
100%|██████████| 100/100 [01:34<00:00,  1.06it/s]
100%|██████████| 100/100 [01:20<00:00,  1.24it/s]
100%|██████████| 100/100 [01:16<00:00,  1.31it/s]
100%|██████████| 10/10 [48:25<00:00, 290.58s/it]

[CST_accuracy         0.491727
CST_roc_auc_score    0.491760
dtype: float64, CST_accuracy         0.506863
CST_roc_auc_score    0.506919
dtype: float64, CST_accuracy         0.502577
CST_roc_auc_score    0.502617
dtype: float64, CST_accuracy         0.494149
CST_roc_auc_score    0.494214
dtype: float64, CST_accuracy         0.499879
CST_roc_auc_score    0.499929
dtype: float64, CST_accuracy         0.498941
CST_roc_auc_score    0.499018
dtype: float64, CST_accuracy         0.501841
CST_roc_auc_score    0.501903
dtype: float64, CST_accuracy         0.500224
CST_roc_auc_score    0.500403
dtype: float64, CST_accuracy         0.498377
CST_roc_auc_score    0.498454
dtype: float64, CST_accuracy         0.501729
CST_roc_auc_score    0.501846
dtype: float64]





In [8]:
# check expected difference in the posteriors


def sample_posterior(n_x_samples, sample_ya, sample_yb):
    res_list = []
    for i in tqdm(list(range(n_x_samples))):
        x1 = samplex(N_samples=1)
        x2 = samplex(N_samples=1)
        y1 = sample_ya(x1)
        y2 = sample_yb(x2)
        print(y1.shape, y2.shape)
        diff_mmd = compare_samples_classifier_based(y1, y2, n_folds=2)

        res_list.append(diff_mmd)

    return pd.DataFrame(res_list)

posterior_y1y2_res = sample_posterior(n_x_samples = 100, sample_ya = sampley1, sample_yb = generate_sample_y_t(0.5))
posterior_y1y3_res = sample_posterior(n_x_samples = 100, sample_ya = sampley1, sample_yb = generate_sample_y_t(0.0))

print(f"""
Posterior y1 y2: {posterior_y1y2_res.mean()}    
Posterior y1 y3: {posterior_y1y3_res.mean()}
      """)

print(wilcoxon(posterior_y1y2_res, posterior_y1y3_res))

  0%|          | 0/100 [00:00<?, ?it/s]

torch.Size([100, 1]) torch.Size([100, 1])


  1%|          | 1/100 [00:00<01:22,  1.20it/s]

torch.Size([100, 1]) torch.Size([100, 1])


  2%|▏         | 2/100 [00:01<01:00,  1.62it/s]

torch.Size([100, 1]) torch.Size([100, 1])


  3%|▎         | 3/100 [00:01<00:56,  1.72it/s]

torch.Size([100, 1]) torch.Size([100, 1])


  4%|▍         | 4/100 [00:02<00:54,  1.77it/s]

torch.Size([100, 1]) torch.Size([100, 1])


  5%|▌         | 5/100 [00:02<00:54,  1.75it/s]

torch.Size([100, 1]) torch.Size([100, 1])


  6%|▌         | 6/100 [00:03<00:55,  1.69it/s]

torch.Size([100, 1]) torch.Size([100, 1])


  7%|▋         | 7/100 [00:04<00:56,  1.64it/s]

torch.Size([100, 1]) torch.Size([100, 1])


  8%|▊         | 8/100 [00:04<00:54,  1.69it/s]

torch.Size([100, 1]) torch.Size([100, 1])


  9%|▉         | 9/100 [00:05<00:55,  1.65it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 10%|█         | 10/100 [00:05<00:51,  1.76it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 11%|█         | 11/100 [00:06<00:47,  1.87it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 12%|█▏        | 12/100 [00:06<00:45,  1.92it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 13%|█▎        | 13/100 [00:07<00:46,  1.86it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 14%|█▍        | 14/100 [00:07<00:45,  1.91it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 15%|█▌        | 15/100 [00:08<00:43,  1.97it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 16%|█▌        | 16/100 [00:08<00:41,  2.00it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 17%|█▋        | 17/100 [00:09<00:42,  1.94it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 18%|█▊        | 18/100 [00:09<00:42,  1.94it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 19%|█▉        | 19/100 [00:10<00:41,  1.95it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 20%|██        | 20/100 [00:11<00:45,  1.75it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 21%|██        | 21/100 [00:11<00:43,  1.82it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 22%|██▏       | 22/100 [00:12<00:41,  1.88it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 23%|██▎       | 23/100 [00:12<00:38,  1.98it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 24%|██▍       | 24/100 [00:13<00:40,  1.90it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 25%|██▌       | 25/100 [00:13<00:39,  1.91it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 26%|██▌       | 26/100 [00:14<00:39,  1.89it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 27%|██▋       | 27/100 [00:14<00:39,  1.83it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 28%|██▊       | 28/100 [00:15<00:37,  1.93it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 29%|██▉       | 29/100 [00:15<00:36,  1.96it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 30%|███       | 30/100 [00:16<00:35,  1.96it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 31%|███       | 31/100 [00:16<00:38,  1.81it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 32%|███▏      | 32/100 [00:17<00:36,  1.85it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 33%|███▎      | 33/100 [00:17<00:35,  1.87it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 34%|███▍      | 34/100 [00:18<00:35,  1.86it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 35%|███▌      | 35/100 [00:19<00:35,  1.82it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 36%|███▌      | 36/100 [00:19<00:35,  1.82it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 37%|███▋      | 37/100 [00:20<00:34,  1.82it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 38%|███▊      | 38/100 [00:20<00:35,  1.73it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 39%|███▉      | 39/100 [00:21<00:37,  1.65it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 40%|████      | 40/100 [00:22<00:35,  1.69it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 41%|████      | 41/100 [00:22<00:39,  1.50it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 42%|████▏     | 42/100 [00:23<00:38,  1.50it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 43%|████▎     | 43/100 [00:24<00:35,  1.62it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 44%|████▍     | 44/100 [00:24<00:33,  1.68it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 45%|████▌     | 45/100 [00:25<00:30,  1.80it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 46%|████▌     | 46/100 [00:25<00:29,  1.85it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 47%|████▋     | 47/100 [00:26<00:28,  1.85it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 48%|████▊     | 48/100 [00:26<00:28,  1.85it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 49%|████▉     | 49/100 [00:27<00:29,  1.73it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 50%|█████     | 50/100 [00:27<00:27,  1.85it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 51%|█████     | 51/100 [00:28<00:27,  1.81it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 52%|█████▏    | 52/100 [00:28<00:25,  1.88it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 53%|█████▎    | 53/100 [00:29<00:24,  1.94it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 54%|█████▍    | 54/100 [00:29<00:23,  1.94it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 55%|█████▌    | 55/100 [00:30<00:23,  1.89it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 56%|█████▌    | 56/100 [00:30<00:23,  1.91it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 57%|█████▋    | 57/100 [00:31<00:22,  1.95it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 58%|█████▊    | 58/100 [00:31<00:21,  1.95it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 59%|█████▉    | 59/100 [00:32<00:22,  1.82it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 60%|██████    | 60/100 [00:33<00:21,  1.90it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 61%|██████    | 61/100 [00:33<00:20,  1.88it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 62%|██████▏   | 62/100 [00:34<00:19,  1.95it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 63%|██████▎   | 63/100 [00:34<00:18,  1.99it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 64%|██████▍   | 64/100 [00:34<00:17,  2.03it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 65%|██████▌   | 65/100 [00:35<00:18,  1.94it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 66%|██████▌   | 66/100 [00:36<00:18,  1.79it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 67%|██████▋   | 67/100 [00:36<00:18,  1.76it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 68%|██████▊   | 68/100 [00:37<00:17,  1.88it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 69%|██████▉   | 69/100 [00:37<00:17,  1.73it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 70%|███████   | 70/100 [00:38<00:16,  1.83it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 71%|███████   | 71/100 [00:38<00:15,  1.87it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 72%|███████▏  | 72/100 [00:39<00:15,  1.81it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 73%|███████▎  | 73/100 [00:40<00:14,  1.85it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 74%|███████▍  | 74/100 [00:40<00:14,  1.77it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 75%|███████▌  | 75/100 [00:41<00:14,  1.69it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 76%|███████▌  | 76/100 [00:41<00:13,  1.73it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 77%|███████▋  | 77/100 [00:42<00:13,  1.72it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 78%|███████▊  | 78/100 [00:42<00:12,  1.81it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 79%|███████▉  | 79/100 [00:43<00:13,  1.60it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 80%|████████  | 80/100 [00:44<00:12,  1.58it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 81%|████████  | 81/100 [00:44<00:11,  1.66it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 82%|████████▏ | 82/100 [00:45<00:10,  1.71it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 83%|████████▎ | 83/100 [00:45<00:09,  1.77it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 84%|████████▍ | 84/100 [00:46<00:09,  1.77it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 85%|████████▌ | 85/100 [00:47<00:08,  1.77it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 86%|████████▌ | 86/100 [00:47<00:08,  1.73it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 87%|████████▋ | 87/100 [00:48<00:07,  1.85it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 88%|████████▊ | 88/100 [00:48<00:06,  1.90it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 89%|████████▉ | 89/100 [00:49<00:06,  1.70it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 90%|█████████ | 90/100 [00:49<00:05,  1.78it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 91%|█████████ | 91/100 [00:50<00:04,  1.83it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 92%|█████████▏| 92/100 [00:50<00:04,  1.87it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 93%|█████████▎| 93/100 [00:51<00:03,  1.87it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 94%|█████████▍| 94/100 [00:51<00:03,  1.94it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 95%|█████████▌| 95/100 [00:52<00:02,  1.97it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 96%|█████████▌| 96/100 [00:52<00:02,  1.99it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 97%|█████████▋| 97/100 [00:53<00:01,  2.00it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 98%|█████████▊| 98/100 [00:53<00:00,  2.01it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 99%|█████████▉| 99/100 [00:54<00:00,  1.93it/s]

torch.Size([100, 1]) torch.Size([100, 1])


100%|██████████| 100/100 [00:54<00:00,  1.82it/s]
  0%|          | 0/100 [00:00<?, ?it/s]

torch.Size([100, 1]) torch.Size([100, 1])


  1%|          | 1/100 [00:00<00:50,  1.95it/s]

torch.Size([100, 1]) torch.Size([100, 1])


  2%|▏         | 2/100 [00:01<00:50,  1.94it/s]

torch.Size([100, 1]) torch.Size([100, 1])


  3%|▎         | 3/100 [00:01<00:47,  2.03it/s]

torch.Size([100, 1]) torch.Size([100, 1])


  4%|▍         | 4/100 [00:01<00:46,  2.05it/s]

torch.Size([100, 1]) torch.Size([100, 1])


  5%|▌         | 5/100 [00:02<00:45,  2.07it/s]

torch.Size([100, 1]) torch.Size([100, 1])


  6%|▌         | 6/100 [00:02<00:46,  2.04it/s]

torch.Size([100, 1]) torch.Size([100, 1])


  7%|▋         | 7/100 [00:03<00:46,  2.02it/s]

torch.Size([100, 1]) torch.Size([100, 1])


  8%|▊         | 8/100 [00:04<00:53,  1.71it/s]

torch.Size([100, 1]) torch.Size([100, 1])


  9%|▉         | 9/100 [00:04<00:51,  1.76it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 10%|█         | 10/100 [00:05<00:51,  1.76it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 11%|█         | 11/100 [00:05<00:49,  1.80it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 12%|█▏        | 12/100 [00:06<00:46,  1.88it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 13%|█▎        | 13/100 [00:06<00:45,  1.92it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 14%|█▍        | 14/100 [00:07<00:43,  1.98it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 15%|█▌        | 15/100 [00:07<00:42,  2.02it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 16%|█▌        | 16/100 [00:08<00:43,  1.95it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 17%|█▋        | 17/100 [00:08<00:42,  1.97it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 18%|█▊        | 18/100 [00:09<00:40,  2.03it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 19%|█▉        | 19/100 [00:09<00:41,  1.95it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 20%|██        | 20/100 [00:10<00:44,  1.79it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 21%|██        | 21/100 [00:11<00:43,  1.83it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 22%|██▏       | 22/100 [00:11<00:43,  1.80it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 23%|██▎       | 23/100 [00:12<00:47,  1.64it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 24%|██▍       | 24/100 [00:12<00:44,  1.71it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 25%|██▌       | 25/100 [00:13<00:43,  1.72it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 26%|██▌       | 26/100 [00:13<00:40,  1.83it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 27%|██▋       | 27/100 [00:14<00:38,  1.90it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 28%|██▊       | 28/100 [00:14<00:36,  1.98it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 29%|██▉       | 29/100 [00:15<00:35,  2.03it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 30%|███       | 30/100 [00:15<00:34,  2.05it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 31%|███       | 31/100 [00:16<00:33,  2.07it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 32%|███▏      | 32/100 [00:16<00:32,  2.07it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 33%|███▎      | 33/100 [00:17<00:31,  2.11it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 34%|███▍      | 34/100 [00:17<00:31,  2.10it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 35%|███▌      | 35/100 [00:18<00:32,  1.99it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 36%|███▌      | 36/100 [00:18<00:31,  2.04it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 37%|███▋      | 37/100 [00:19<00:30,  2.08it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 38%|███▊      | 38/100 [00:19<00:29,  2.07it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 39%|███▉      | 39/100 [00:20<00:29,  2.07it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 40%|████      | 40/100 [00:20<00:29,  2.02it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 41%|████      | 41/100 [00:21<00:28,  2.04it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 42%|████▏     | 42/100 [00:21<00:30,  1.91it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 43%|████▎     | 43/100 [00:22<00:35,  1.62it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 44%|████▍     | 44/100 [00:23<00:39,  1.41it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 45%|████▌     | 45/100 [00:24<00:38,  1.44it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 46%|████▌     | 46/100 [00:24<00:38,  1.39it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 47%|████▋     | 47/100 [00:25<00:36,  1.43it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 48%|████▊     | 48/100 [00:26<00:35,  1.48it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 49%|████▉     | 49/100 [00:26<00:32,  1.57it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 50%|█████     | 50/100 [00:27<00:32,  1.55it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 51%|█████     | 51/100 [00:28<00:32,  1.52it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 52%|█████▏    | 52/100 [00:28<00:31,  1.54it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 53%|█████▎    | 53/100 [00:29<00:29,  1.60it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 54%|█████▍    | 54/100 [00:29<00:28,  1.64it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 55%|█████▌    | 55/100 [00:30<00:28,  1.61it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 56%|█████▌    | 56/100 [00:31<00:27,  1.61it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 57%|█████▋    | 57/100 [00:31<00:28,  1.50it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 58%|█████▊    | 58/100 [00:32<00:26,  1.56it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 59%|█████▉    | 59/100 [00:33<00:24,  1.64it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 60%|██████    | 60/100 [00:33<00:23,  1.70it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 61%|██████    | 61/100 [00:34<00:22,  1.73it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 62%|██████▏   | 62/100 [00:34<00:21,  1.77it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 63%|██████▎   | 63/100 [00:35<00:20,  1.81it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 64%|██████▍   | 64/100 [00:35<00:19,  1.83it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 65%|██████▌   | 65/100 [00:36<00:19,  1.84it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 66%|██████▌   | 66/100 [00:36<00:19,  1.72it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 67%|██████▋   | 67/100 [00:37<00:21,  1.57it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 68%|██████▊   | 68/100 [00:38<00:19,  1.64it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 69%|██████▉   | 69/100 [00:38<00:17,  1.75it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 70%|███████   | 70/100 [00:39<00:16,  1.86it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 71%|███████   | 71/100 [00:39<00:15,  1.88it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 72%|███████▏  | 72/100 [00:40<00:14,  1.89it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 73%|███████▎  | 73/100 [00:40<00:13,  1.95it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 74%|███████▍  | 74/100 [00:41<00:13,  1.94it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 75%|███████▌  | 75/100 [00:41<00:13,  1.88it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 76%|███████▌  | 76/100 [00:42<00:13,  1.77it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 77%|███████▋  | 77/100 [00:43<00:13,  1.76it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 78%|███████▊  | 78/100 [00:43<00:12,  1.70it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 79%|███████▉  | 79/100 [00:44<00:12,  1.69it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 80%|████████  | 80/100 [00:44<00:12,  1.63it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 81%|████████  | 81/100 [00:45<00:11,  1.65it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 82%|████████▏ | 82/100 [00:46<00:11,  1.61it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 83%|████████▎ | 83/100 [00:46<00:10,  1.59it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 84%|████████▍ | 84/100 [00:47<00:09,  1.68it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 85%|████████▌ | 85/100 [00:47<00:09,  1.62it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 86%|████████▌ | 86/100 [00:48<00:09,  1.53it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 87%|████████▋ | 87/100 [00:49<00:08,  1.62it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 88%|████████▊ | 88/100 [00:49<00:07,  1.56it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 89%|████████▉ | 89/100 [00:50<00:06,  1.61it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 90%|█████████ | 90/100 [00:51<00:05,  1.73it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 91%|█████████ | 91/100 [00:51<00:04,  1.82it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 92%|█████████▏| 92/100 [00:52<00:04,  1.86it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 93%|█████████▎| 93/100 [00:52<00:03,  1.87it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 94%|█████████▍| 94/100 [00:53<00:03,  1.92it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 95%|█████████▌| 95/100 [00:53<00:02,  1.90it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 96%|█████████▌| 96/100 [00:54<00:02,  1.84it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 97%|█████████▋| 97/100 [00:54<00:01,  1.87it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 98%|█████████▊| 98/100 [00:55<00:01,  1.97it/s]

torch.Size([100, 1]) torch.Size([100, 1])


 99%|█████████▉| 99/100 [00:55<00:00,  2.02it/s]

torch.Size([100, 1]) torch.Size([100, 1])


100%|██████████| 100/100 [00:56<00:00,  1.78it/s]


Posterior y1 y2: CST_accuracy         0.85215
CST_roc_auc_score    0.85215
dtype: float64    
Posterior y1 y3: CST_accuracy         0.84045
CST_roc_auc_score    0.84045
dtype: float64
      
WilcoxonResult(statistic=array([2021. , 2006.5]), pvalue=array([0.15173056, 0.1375889 ]))





In [9]:
different_t_res = []
for t in tqdm(torch.linspace(0, t_max, 10)):
    posterior_y1yt_res = sample_posterior(n_x_samples = 100, sample_ya = sampley1, sample_yb = generate_sample_y_t(t))
    different_t_res.append(posterior_y1yt_res.mean())

print(different_t_res)

  0%|          | 0/10 [00:00<?, ?it/s]

torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])


100%|██████████| 100/100 [00:52<00:00,  1.90it/s]
 10%|█         | 1/10 [00:52<07:53, 52.66s/it]

torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])


100%|██████████| 100/100 [00:52<00:00,  1.90it/s]
 20%|██        | 2/10 [01:45<07:01, 52.63s/it]

torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])


100%|██████████| 100/100 [00:50<00:00,  1.99it/s]
 30%|███       | 3/10 [02:35<06:01, 51.59s/it]

torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])


100%|██████████| 100/100 [00:54<00:00,  1.83it/s]
 40%|████      | 4/10 [03:30<05:16, 52.78s/it]

torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])




torch.Size([100, 1]) torch.Size([100, 1])


 32%|███▏      | 32/100 [00:17<00:36,  1.85it/s]
 40%|████      | 4/10 [03:47<05:41, 56.88s/it]


KeyboardInterrupt: 

In [None]:
def sample_joint(
    n_x_samples = 1000,
    sample_ya = sampley1,
    sample_yb = sampley2
):
    res_list = []

    for i in tqdm(list(range(n_x_samples))):
        x1 = samplex(N_samples=N_samples)
        x2 = samplex(N_samples=N_samples)

        y1 = sample_ya(x1)
        y2 = sample_yb(x2)

        xy1 = torch.cat([x1, y1], dim=1)
        xy2 = torch.cat([x2, y2], dim=1)

        mmd = compare_samples_classifier_based(xy1, xy2, n_folds=3)

        res_list.append(mmd)

    return pd.DataFrame(res_list)



joint_y1y2_res = sample_joint(n_x_samples=100, sample_ya=sampley1, sample_yb=generate_sample_y_t(0.5))
joint_y1y3_res = sample_joint(n_x_samples=100, sample_ya=sampley1, sample_yb=generate_sample_y_t(0.0))

print(f"""
Joint y1 y2: {joint_y1y2_res.mean()}
Joint y1 y3: {joint_y1y3_res.mean()}
      """)

print(wilcoxon(joint_y1y2_res, joint_y1y3_res))

 20%|█▉        | 197/1000 [00:10<00:42, 18.71it/s]

In [None]:
different_t_res = []
for t in tqdm(torch.linspace(0, t_max, 10)):
    joint_y1yt_res = sample_joint(n_x_samples=100, sample_ya=sampley1, sample_yb=generate_sample_y_t(t))
    different_t_res.append(joint_y1yt_res.mean())

print(different_t_res)

  0%|          | 0/10 [00:00<?, ?it/s]