# 一、 *NIST & USPS

In [79]:
import itertools
import pandas as pd
import numpy as np
from otdd.pytorch.datasets import load_torchvision_data
from otdd.pytorch.distance import DatasetDistance


# Datasets compared pairwise ('CIFAR10' is commented out of the list).
all_names = ['MNIST', 'EMNIST', 'FashionMNIST', 'KMNIST', 'USPS'] #, 'CIFAR10'
# Row position of every dataset in the result matrix.
name2idx = {task: i for i, task in enumerate(all_names)}
# Keep only the first element returned by load_torchvision_data; it is
# indexed by split name below (e.g. ['test']).
all_loaders = {task: load_torchvision_data(task, resize=28)[0] for task in all_names}
# Square result matrix: columns are dataset names, rows are integer positions.
# Cells never written (the diagonal) stay NaN.
all_otdds = pd.DataFrame(columns=all_names, index=range(len(all_names)))
for comb in itertools.combinations(range(len(all_names)), 2):
    src_name, tgt_name = all_names[comb[0]], all_names[comb[1]]
    print(f">>> Computing OTDD for {src_name}-{tgt_name}")
    src_loader, tgt_loader = all_loaders[src_name], all_loaders[tgt_name]

    dist = DatasetDistance(src_loader['test'], tgt_loader['test'],
                           inner_ot_method='exact',
                           debiased_loss=True,
                           p=2, entreg=1e-1,
                           device='cuda:0')

    d = dist.distance(maxsamples=10000).item()
    print(f'OTDD({src_name},{tgt_name})={d:.2f}')
    # Write both mirrored cells with a single .loc[row, col] assignment.
    # Chained indexing (`.iloc[row][col] = d`) assigns into a temporary row
    # and is silently dropped when pandas Copy-on-Write is active.
    all_otdds.loc[name2idx[tgt_name], src_name] = d
    all_otdds.loc[name2idx[src_name], tgt_name] = d

Fold Sizes: 54000/6000/10000 (train/valid/test)
Fold Sizes: 112320/12480/20800 (train/valid/test)
Fold Sizes: 54000/6000/10000 (train/valid/test)
Fold Sizes: 54000/6000/10000 (train/valid/test)
Fold Sizes: 6561/730/2007 (train/valid/test)
>>> Computing OTDD for MNIST-EMNIST


Computing label-to-label distances: 100%|██████████| 45/45 [00:00<00:00, 87.44it/s]
Computing label-to-label distances: 100%|██████████| 325/325 [00:03<00:00, 103.83it/s]
Computing label-to-label distances: 100%|██████████| 260/260 [00:02<00:00, 110.36it/s]


OTDD(MNIST,EMNIST)=852.10
>>> Computing OTDD for MNIST-FashionMNIST


Computing label-to-label distances: 100%|██████████| 45/45 [00:00<00:00, 84.00it/s]
Computing label-to-label distances: 100%|██████████| 45/45 [00:00<00:00, 77.30it/s]
Computing label-to-label distances: 100%|██████████| 100/100 [00:01<00:00, 86.24it/s]


OTDD(MNIST,FashionMNIST)=1233.05
>>> Computing OTDD for MNIST-KMNIST


Computing label-to-label distances: 100%|██████████| 45/45 [00:00<00:00, 109.85it/s]
Computing label-to-label distances: 100%|██████████| 45/45 [00:00<00:00, 107.81it/s]
Computing label-to-label distances: 100%|██████████| 100/100 [00:01<00:00, 77.10it/s]


OTDD(MNIST,KMNIST)=1106.96
>>> Computing OTDD for MNIST-USPS


Computing label-to-label distances: 100%|██████████| 45/45 [00:00<00:00, 88.84it/s]
Computing label-to-label distances: 100%|██████████| 45/45 [00:00<00:00, 121.12it/s]
Computing label-to-label distances: 100%|██████████| 100/100 [00:00<00:00, 119.85it/s]


OTDD(MNIST,USPS)=960.06
>>> Computing OTDD for EMNIST-FashionMNIST


Computing label-to-label distances: 100%|██████████| 325/325 [00:02<00:00, 114.88it/s]
Computing label-to-label distances: 100%|██████████| 45/45 [00:00<00:00, 107.86it/s]
Computing label-to-label distances: 100%|██████████| 260/260 [00:02<00:00, 106.09it/s]


OTDD(EMNIST,FashionMNIST)=1144.06
>>> Computing OTDD for EMNIST-KMNIST


Computing label-to-label distances: 100%|██████████| 325/325 [00:02<00:00, 115.26it/s]
Computing label-to-label distances: 100%|██████████| 45/45 [00:00<00:00, 109.47it/s]
Computing label-to-label distances: 100%|██████████| 260/260 [00:02<00:00, 106.80it/s]


OTDD(EMNIST,KMNIST)=1100.71
>>> Computing OTDD for EMNIST-USPS


Computing label-to-label distances: 100%|██████████| 325/325 [00:03<00:00, 107.44it/s]
Computing label-to-label distances: 100%|██████████| 45/45 [00:00<00:00, 109.48it/s]
Computing label-to-label distances: 100%|██████████| 260/260 [00:02<00:00, 109.64it/s]


OTDD(EMNIST,USPS)=993.17
>>> Computing OTDD for FashionMNIST-KMNIST


Computing label-to-label distances: 100%|██████████| 45/45 [00:00<00:00, 83.76it/s]
Computing label-to-label distances: 100%|██████████| 45/45 [00:00<00:00, 103.90it/s]
Computing label-to-label distances: 100%|██████████| 100/100 [00:00<00:00, 105.87it/s]


OTDD(FashionMNIST,KMNIST)=1235.62
>>> Computing OTDD for FashionMNIST-USPS


Computing label-to-label distances: 100%|██████████| 45/45 [00:00<00:00, 76.68it/s]
Computing label-to-label distances: 100%|██████████| 45/45 [00:00<00:00, 114.43it/s]
Computing label-to-label distances: 100%|██████████| 100/100 [00:00<00:00, 112.74it/s]


OTDD(FashionMNIST,USPS)=771.77
>>> Computing OTDD for KMNIST-USPS


Computing label-to-label distances: 100%|██████████| 45/45 [00:00<00:00, 100.37it/s]
Computing label-to-label distances: 100%|██████████| 45/45 [00:00<00:00, 121.84it/s]
Computing label-to-label distances: 100%|██████████| 100/100 [00:00<00:00, 115.17it/s]


OTDD(KMNIST,USPS)=986.04


In [80]:
# Show the pairwise OTDD matrix; each distance was written to both mirrored
# cells by the loop above, and the diagonal was never assigned.
all_otdds

Unnamed: 0,MNIST,EMNIST,FashionMNIST,KMNIST,USPS
0,,852.098511,1233.050049,1106.957886,960.061523
1,852.098511,,1144.060913,1100.706055,993.173706
2,1233.050049,1144.060913,,1235.623535,771.768372
3,1106.957886,1100.706055,1235.623535,,986.038696
4,960.061523,993.173706,771.768372,986.038696,


In [92]:
from scipy.stats import spearmanr
# Reference distances to compare against, entered in thousands and rescaled
# by 1000; presumably rows/columns follow the order of `all_names` -- confirm.
# NOTE(review): entries [0, 4] (1.26) and [4, 0] (1.28) disagree although every
# other mirrored pair matches -- one of the two is probably a typo; confirm
# against the source of these numbers.
claimed_otdds = np.array([[np.nan, 1.04, 1.74, 1.41, 1.26],
                          [1.04, np.nan, 1.57, 1.28, 1.32],
                          [1.74, 1.57, np.nan, 1.67, 1.10],
                          [1.41, 1.28, 1.67, np.nan, 1.30],
                          [1.28, 1.32, 1.10, 1.30, np.nan]]) * 1000

all_otdds_f = all_otdds.values.astype(np.float32).flatten()
claimed_otdds_f = claimed_otdds.flatten()
# Use ONE mask for both vectors so position k always refers to the same
# dataset pair in each. Filtering them independently would shift the pairing
# as soon as either matrix had a NaN the other one lacks.
idxs = ~(np.isnan(all_otdds_f) | np.isnan(claimed_otdds_f))
all_otdds_f = all_otdds_f[idxs]
claimed_otdds_f = claimed_otdds_f[idxs]
# NOTE(review): the full matrices are flattened, so every unordered pair
# contributes two observations; the reported p-value is therefore optimistic.
print(spearmanr(all_otdds_f, claimed_otdds_f))
# Ratio of mean claimed distance to mean reproduced distance.
print(claimed_otdds_f.mean()/all_otdds_f.mean())

SpearmanrResult(correlation=0.9310086072192505, pvalue=2.613673456616178e-09)
1.3193960401801275


In [93]:
# Persist the reproduced matrix as-is.
all_otdds.to_pickle(
    path="/home/brian/work/ICLR-Representation_Transferability_X/cv_emd/results/reproduced_cv_raw_otdds.pkl"
)
# Wrap the claimed array in a DataFrame labelled with the dataset names
# before pickling it next to the reproduced one.
pd.DataFrame(data=claimed_otdds, columns=all_names).to_pickle(
    path="/home/brian/work/ICLR-Representation_Transferability_X/cv_emd/results/claimed_cv_raw_otdds.pkl"
)

In [53]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns

def truncate_colormap(cmapIn='jet', minval=0.0, maxval=1.0, n=100):
    """Return a new colormap built from a sub-range of an existing one.

    The colormap named by ``cmapIn`` is looked up with ``plt.get_cmap`` and
    sampled at ``n`` evenly spaced positions between ``minval`` and
    ``maxval``; those samples become the colours of a new
    ``LinearSegmentedColormap`` whose name records the source colormap and
    the two bounds.
    """
    base_cmap = plt.get_cmap(cmapIn)
    sample_positions = np.linspace(minval, maxval, n)
    label = 'trunc({n},{a:.2f},{b:.2f})'.format(n=base_cmap.name, a=minval, b=maxval)
    return mpl.colors.LinearSegmentedColormap.from_list(label, base_cmap(sample_positions))
# Keep only the 0.2-0.8 portion of the 'flare' colormap.
# NOTE(review): new_cmap is not referenced again in the visible cells
# (plotter passes cmap='OrRd') -- confirm whether it is still needed.
new_cmap = truncate_colormap('flare', minval=0.2, maxval=0.8)

def plotter(dist_df, *, vmin=0, vmax=1.02, fmt='.3f'):
    """Draw an annotated heatmap of a pairwise-distance DataFrame.

    Missing cells are masked out rather than drawn. The x tick labels are
    moved to the top and tick marks are hidden on both axes.

    Args:
        dist_df: DataFrame of distances; NaN marks cells to hide. It must be
            convertible to float.
        vmin: Lower bound of the colour scale (default 0, as before).
        vmax: Upper bound of the colour scale (default 1.02, as before).
        fmt: Format spec for the cell annotations (default '.3f', as before).
    """
    # Coerce to float up front. A frame created without a dtype is
    # object-typed, and relying on fillna() to downcast it implicitly is
    # deprecated in pandas.
    values = dist_df.astype(float)
    ax = sns.heatmap(values.fillna(0), cmap='OrRd', annot=True, fmt=fmt, annot_kws={'size': 'small'},
                     cbar=False, linewidths=1.5, mask=values.isna(), vmin=vmin, vmax=vmax)
    ax.xaxis.tick_top()
    ax.tick_params(axis='both', which='both', length=0)
# Render the pairwise OTDD matrix as a heatmap and display the figure.
# NOTE(review): plotter defaults to a colour range of 0 to 1.02, while the
# values printed for all_otdds are in the hundreds to thousands -- confirm
# whether the matrix was meant to be normalised before plotting.
plotter(all_otdds)
plt.show()

# 二、MNIST-CIFAR10

In [90]:
import pickle
import torch
from torchvision.models import resnet18

from otdd.pytorch.datasets import load_torchvision_data
from otdd.pytorch.distance import DatasetDistance, FeatureCost

# Load both datasets at 28x28. MNIST is expanded to 3 channels
# (to3channels=True); presumably this makes it match CIFAR10's channel
# count -- confirm. Only the first element of each returned value is kept,
# and it is indexed by split name below.
loaders_src = load_torchvision_data('CIFAR10', resize=28)[0]#, maxsize=20
loaders_tgt = load_torchvision_data('MNIST', resize=28, to3channels=True)[0]# ,maxsize=20

# Optional embedding path, currently DISABLED: embed both datasets with a
# pretrained (+frozen) resnet and pass the resulting FeatureCost to
# DatasetDistance via `feature_cost`.
# embedder = resnet18(pretrained=True).eval()
# embedder.fc = torch.nn.Identity()
# for p in embedder.parameters():
#     p.requires_grad = False

# Here we use same embedder for both datasets
# feature_cost = FeatureCost(src_embedding = embedder,
#                            src_dim = (3,28,28),
#                            tgt_embedding = embedder,
#                            tgt_dim = (3,28,28),
#                            p = 2,
#                            device='cpu')

# Distance between the two test splits, computed on CPU without a feature cost.
dist = DatasetDistance(loaders_src['test'], loaders_tgt['test'],
                          inner_ot_method = 'exact',
                          debiased_loss = True,
                        #   feature_cost = feature_cost,
                          sqrt_method = 'spectral',
                          sqrt_niters=10,
                          precision='single',
                          p = 2, entreg = 1e-1,
                          device='cpu')

d = dist.distance(maxsamples = 10000)
# Label the value as raw: no embedding feature cost is passed above, and the
# result is stored below under a "raw_otdd" file name.
print(f'Raw OTDD(CIFAR10,MNIST)={d:8.2f}')

# Store the scalar under a (dataset, dataset) key and pickle it.
d = {('mnist', 'cifar10'): d.item()}
with open("/home/brian/work/ICLR-Representation_Transferability_X/cv_emd/results/mnist_cifar10_raw_otdd.pkl", 'wb') as handle:
    pickle.dump(d, handle)

Files already downloaded and verified
Files already downloaded and verified
Fold Sizes: 45000/5000/10000 (train/valid/test)
Fold Sizes: 54000/6000/10000 (train/valid/test)


Computing label-to-label distances: 100%|██████████| 45/45 [00:04<00:00, 10.13it/s]
Computing label-to-label distances: 100%|██████████| 45/45 [00:04<00:00, 10.73it/s]
Computing label-to-label distances: 100%|██████████| 100/100 [00:09<00:00, 10.81it/s]


Embedded OTDD(CIFAR10,MNIST)= 4322.91
