The code in this notebook replicates the CIFAR-10 vs CIFAR-10.1 experiment of
Liu et al. (Learning Deep Kernels for Non-Parametric Two-Sample Tests, ICML 2020).
We use their code, which is released under the MIT license:
https://github.com/fengliu90/DK-for-TST/blob/master/Deep_Baselines_CIFAR10.py

# Environment: mmdfuse-env

Run the cells in this section with a kernel from the `mmdfuse-env` environment.

In [None]:
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch
from jax import random
import jax.numpy as jnp
from tqdm.auto import tqdm
from pathlib import Path
results_dir = Path("results")
results_dir.mkdir(exist_ok=True)  # output directory for the saved .npy results

In [8]:
from all_tests import mmdfuse_test, mmdfuse_ae_test, mmdfuse_ae_new_test
from all_tests import mmd_median_test, mmd_split_test
from all_tests import mmdagg_test, mmdagginc_test, deep_mmd_test
from all_tests import met_test, scf_test
from all_tests import ctt_test, actt_test

In [None]:
# Experiment parameters
N1 = 1_000        # images drawn from each dataset as the "train" part in every outer repetition
img_size = 64     # images are resized to img_size x img_size
batch_size = 100  # mini-batch size (not used by the two-sample tests themselves)
K = 10            # outer repetitions (fresh train indices for both datasets)
N = 100           # inner repetitions per outer repetition (fresh CIFAR-10 test draws)

In [5]:
# Load the CIFAR-10 test split and CIFAR-10.1 (v4); both are resized to
# img_size x img_size and normalised with mean = std = 0.5 per channel.

# Fixed seed for the two shuffles below. Without it, data_all and data_T are
# permuted differently after every kernel restart, so the (seeded) index draws
# in the experiment cell select different images and results are not reproducible.
DATA_SEED = 0

# Preprocessing shared by both datasets.
TT = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 test split (downloaded on first run).
dataset_test = datasets.CIFAR10(root='./cifar_data/cifar10', download=True, train=False,
                                transform=TT)

# Load the whole split as ONE shuffled batch. batch_size=len(dataset_test)
# guarantees a single batch, so data_all really holds every image (the previous
# loop kept only the last batch, which was only correct for exactly 10000 images).
dataloader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=len(dataset_test),
    shuffle=True,
    num_workers=1,
    generator=torch.Generator().manual_seed(DATA_SEED),
)
data_all, label_all = next(iter(dataloader_test))
Ind_all = np.arange(len(data_all))

# CIFAR-10.1 images: stored channels-last, transposed to channels-first for ToPILImage.
data_new = np.load('./cifar_data/cifar10.1_v4_data.npy')
data_T = np.transpose(data_new, [0, 3, 1, 2])
ind_M = np.random.default_rng(DATA_SEED).permutation(len(data_T))
data_T = data_T[ind_M]

# Apply the same preprocessing as for CIFAR-10 (via a PIL image, as the dataset does).
trans = transforms.ToPILImage()
data_T_tensor = torch.from_numpy(data_T)
data_trans = torch.stack([TT(trans(img)) for img in data_T_tensor])
Ind_v4_all = np.arange(len(data_T))

Files already downloaded and verified


In [1]:
# Run experiment: K outer draws of the CIFAR-10 / CIFAR-10.1 "train" indices,
# each followed by N inner draws of the CIFAR-10 test subsample. Every test in
# `tests` is applied to the same (X, Y) pair and its output is recorded.

# Previously `save` was referenced but never defined (NameError on a fresh kernel).
save = True  # write the averaged outputs to results/cifar.npy; set False to skip

seed = 0
key = random.PRNGKey(42)

tests = (mmdfuse_test, mmd_median_test, mmd_split_test, mmdagg_test, mmdagginc_test, ctt_test, actt_test)

outputs = [[] for _ in range(len(tests))]
for kk in tqdm(range(K)):
    torch.manual_seed(kk * 19 + N1)
    torch.cuda.manual_seed(kk * 19 + N1)
    np.random.seed(seed=1102 * (kk + 10) + N1)

    # CIFAR-10: N1 "train" images; the remaining indices form the test pool
    Ind_tr = np.random.choice(len(data_all), N1, replace=False)
    Ind_te = np.delete(Ind_all, Ind_tr)

    # CIFAR-10.1: N1 "train" images; all remaining images are used as test data
    np.random.seed(seed=819 * (kk + 9) + N1)
    Ind_tr_v4 = np.random.choice(len(data_T), N1, replace=False)
    Ind_te_v4 = np.delete(Ind_v4_all, Ind_tr_v4)

    # Invariant across the inner loop, so computed once per kk
    s1_tr = data_all[Ind_tr]
    s2_tr = data_trans[Ind_tr_v4]
    s2_te = data_trans[Ind_te_v4]
    data_all_te = data_all[Ind_te]
    N_te = len(data_trans) - N1  # CIFAR-10 test draws match the CIFAR-10.1 test size
    Y = jnp.array(torch.cat((s2_tr, s2_te)).numpy())  # identical for every k

    for k in tqdm(range(N)):
        # Seeded by k only, so the same draw sequence repeats for every kk
        np.random.seed(seed=1102 * (k + 1) + N1)
        Ind_N_te = np.random.choice(len(Ind_te), N_te, replace=False)
        s1_te = data_all_te[Ind_N_te]

        # The tests receive the concatenated (unsplit) samples
        X = jnp.array(torch.cat((s1_tr, s1_te)).numpy())

        seed += 1
        key, subkey = random.split(key)
        for t, test in enumerate(tests):
            outputs[t].append(test(X, Y, subkey, seed))

# Mean over all K * N repetitions, one value per test
output = jnp.mean(jnp.array(outputs), -1)

if save:
    jnp.save("results/cifar.npy", output)

for test, out in zip(tests, output):
    print(" ")
    print(test)
    print(out)


<function mmdfuse_test at 0x7f2bd070e9d0>
0.9366197
 
<function mmd_median_test at 0x7f291878a430>
0.67800003
 
<function mmd_split_test at 0x7f291878a4c0>
0.25100002
 
<function mmdagg_test at 0x7f291878a550>
0.883
 
<function mmdagginc_test at 0x7f291878a5e0>
0.28077313
 
<function ctt_test at 0x7f291878a820>
0.711
 
<function actt_test at 0x7f291878a8b0>
0.652



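# Environment: autogluon-env

Run the cells in this section with a kernel from the `autogluon-env` environment (used for the AutoTST baseline).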

In [2]:
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch
from jax import random
import jax.numpy as jnp
from tqdm.auto import tqdm
from pathlib import Path
results_dir = Path("results")
results_dir.mkdir(exist_ok=True)  # output directory for the saved .npy results

In [7]:
import autotst
from utils import HiddenPrints

def autotst_test(X, Y, key, seed, time=60, alpha=0.05, permutations=10000):
    """Run the AutoTST two-sample test (AutoGluon tabular witness) on X and Y.

    Parameters
    ----------
    X, Y : array-like
        The two samples, passed unchanged to ``autotst.AutoTST``.
        NOTE(review): in this notebook X/Y are stacks of images built in the
        experiment cell; confirm AutoTST's tabular predictor handles them as intended.
    key, seed :
        Unused. They are accepted so the call signature matches the other tests,
        which are all called as ``test(X, Y, subkey, seed)``.
    time : int, default 60
        Time limit in seconds for fitting the witness.
    alpha : float, default 0.05
        Significance level of the test.
    permutations : int, default 10000
        Number of permutations used to estimate the p-value.

    Returns
    -------
    int
        1 if the null hypothesis is rejected (p-value <= alpha), otherwise 0.
    """
    # Suppress the (verbose) output of AutoTST / AutoGluon
    with HiddenPrints():
        tst = autotst.AutoTST(X, Y, split_ratio=0.5, model=autotst.model.AutoGluonTabularPredictor)
        tst.split_data()
        tst.fit_witness(time_limit=time)
        p_value = tst.p_value_evaluate(permutations=permutations)
    return int(p_value <= alpha)

In [None]:
# Experiment parameters
N1 = 1_000        # images drawn from each dataset as the "train" part in every outer repetition
img_size = 64     # images are resized to img_size x img_size
batch_size = 100  # mini-batch size (not used by the two-sample tests themselves)
K = 10            # outer repetitions (fresh train indices for both datasets)
N = 100           # inner repetitions per outer repetition (fresh CIFAR-10 test draws)

In [8]:
# Load the CIFAR-10 test split and CIFAR-10.1 (v4); both are resized to
# img_size x img_size and normalised with mean = std = 0.5 per channel.

# Fixed seed for the two shuffles below. Without it, data_all and data_T are
# permuted differently after every kernel restart, so the (seeded) index draws
# in the experiment cell select different images and results are not reproducible.
DATA_SEED = 0

# Preprocessing shared by both datasets.
TT = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 test split (downloaded on first run).
dataset_test = datasets.CIFAR10(root='./cifar_data/cifar10', download=True, train=False,
                                transform=TT)

# Load the whole split as ONE shuffled batch. batch_size=len(dataset_test)
# guarantees a single batch, so data_all really holds every image (the previous
# loop kept only the last batch, which was only correct for exactly 10000 images).
dataloader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=len(dataset_test),
    shuffle=True,
    num_workers=1,
    generator=torch.Generator().manual_seed(DATA_SEED),
)
data_all, label_all = next(iter(dataloader_test))
Ind_all = np.arange(len(data_all))

# CIFAR-10.1 images: stored channels-last, transposed to channels-first for ToPILImage.
data_new = np.load('./cifar_data/cifar10.1_v4_data.npy')
data_T = np.transpose(data_new, [0, 3, 1, 2])
ind_M = np.random.default_rng(DATA_SEED).permutation(len(data_T))
data_T = data_T[ind_M]

# Apply the same preprocessing as for CIFAR-10 (via a PIL image, as the dataset does).
trans = transforms.ToPILImage()
data_T_tensor = torch.from_numpy(data_T)
data_trans = torch.stack([TT(trans(img)) for img in data_T_tensor])
Ind_v4_all = np.arange(len(data_T))

Files already downloaded and verified


In [2]:
# Run experiment (AutoTST only), with the same sampling scheme as the mmdfuse-env
# section: K outer draws of the "train" indices, N inner draws of the CIFAR-10
# test subsample per outer draw.

seed = 0
key = random.PRNGKey(42)

tests = (autotst_test, )

outputs = [[] for _ in range(len(tests))]
for kk in tqdm(range(K)):
    torch.manual_seed(kk * 19 + N1)
    torch.cuda.manual_seed(kk * 19 + N1)
    np.random.seed(seed=1102 * (kk + 10) + N1)

    # CIFAR-10: N1 "train" images; the remaining indices form the test pool
    Ind_tr = np.random.choice(len(data_all), N1, replace=False)
    Ind_te = np.delete(Ind_all, Ind_tr)

    # CIFAR-10.1: N1 "train" images; all remaining images are used as test data
    np.random.seed(seed=819 * (kk + 9) + N1)
    Ind_tr_v4 = np.random.choice(len(data_T), N1, replace=False)
    Ind_te_v4 = np.delete(Ind_v4_all, Ind_tr_v4)

    # Invariant across the inner loop, so computed once per kk
    s1_tr = data_all[Ind_tr]
    s2_tr = data_trans[Ind_tr_v4]
    s2_te = data_trans[Ind_te_v4]
    data_all_te = data_all[Ind_te]
    N_te = len(data_trans) - N1  # CIFAR-10 test draws match the CIFAR-10.1 test size
    Y = jnp.array(torch.cat((s2_tr, s2_te)).numpy())  # identical for every k

    for k in tqdm(range(N)):
        # Seeded by k only, so the same draw sequence repeats for every kk
        np.random.seed(seed=1102 * (k + 1) + N1)
        Ind_N_te = np.random.choice(len(Ind_te), N_te, replace=False)
        s1_te = data_all_te[Ind_N_te]

        # The "train" and "test" parts are concatenated; AutoTST performs its own
        # split internally (split_ratio=0.5 in autotst_test).
        X = jnp.array(torch.cat((s1_tr, s1_te)).numpy())

        seed += 1
        key, subkey = random.split(key)
        for t, test in enumerate(tests):
            # Called directly: the previous `client.submit(...)` used an undefined
            # `client`, and its ungathered Futures could not be averaged below.
            outputs[t].append(test(X, Y, subkey, seed))

# Mean over all K * N repetitions, one value per test
output = jnp.mean(jnp.array(outputs), -1)

jnp.save("results/cifar_autotst.npy", output)

for test, out in zip(tests, output):
    print(" ")
    print(test)
    print(out)


<function autotst_test at 0x7f4c1cf84af0>
0.5438067

