In [None]:
import ot
import torch
import numpy as np
import pandas as pd
from itertools import product
from sklearn.neighbors import KDTree
import plotly.express as px
from plotly.subplots import make_subplots
from IPython.display import clear_output
from scipy.stats import pearsonr, spearmanr, kendalltau, norm, multivariate_normal, entropy
from pathlib import Path

import util
import ffjord as C
from Distributions import GaussianMixtureDA, Gaussian
from FitVAE import VAE, decode_data
# from FitKNN import logp_knn, logp_knn_screening, detect_breakpoint
from RealNVP import RealNVP
import GradualSelfTrain as G
from mainCNF import settings

In [None]:
# Layout options shared by the figures below (margins, serif font, and a
# horizontal legend centred under the plot).
plt_settings = dict(
    margin=dict(t=30, b=30, r=30),
    font=dict(family="PTSerif", size=18),
    legend=dict(orientation="h", bordercolor="Black", borderwidth=0.3,
                yanchor="bottom", y=-0.35, xanchor="center", x=0.5),
)

# plotly's default qualitative palette converted from hex strings to RGB lists
colors = list(map(lambda hex_code: list(px.colors.hex_to_rgb(hex_code)),
                  px.colors.qualitative.Plotly))

# candidate k values; later cells map argmin indices of per-k results onto them
ks = np.array([5, 10, 15, 20, 30])

def make_subplots_wrapper(rows: int, cols:int, **kwargs):
    # Create the subplot grid and also return the 1-indexed (row, col) of every
    # cell in row-major order, so callers can address cells with a flat index.
    # (No docstring here: it is replaced by make_subplots' docstring below.)
    fig = make_subplots(rows=rows, cols=cols, **kwargs)
    row_grid, col_grid = np.meshgrid(np.arange(1, rows + 1), np.arange(1, cols + 1), indexing='ij')
    pos = np.stack([row_grid, col_grid], axis=-1).reshape(-1, 2)
    return fig, pos
make_subplots_wrapper.__doc__ = make_subplots.__doc__


def remove_unnecessary_layout(fig):
    """Strip auto-generated labels from a plotly express figure, in place.

    Clears the legend title and rewrites every layout annotation of the form
    ``"<column>=<value>"`` (the facet titles plotly express produces for
    ``facet_col`` figures) to just ``"<value>"``.

    Args:
        fig: figure whose ``layout`` supports ``layout['legend']['title']['text']``
            and iteration over ``layout['annotations']``.

    Returns:
        The same ``fig`` object, for chaining.
    """
    # remove the title of the legend
    fig.layout['legend']['title']['text'] = ''
    # Drop the "<column>=" prefix of each facet title. Split only on the first
    # "=" so values that contain "=" survive intact, and leave annotations
    # without any "=" untouched instead of raising IndexError.
    for annotation in fig.layout['annotations']:
        annotation['text'] = annotation['text'].split('=', 1)[-1]
    return fig


def facet_col_warp_parser(facet_col_wrap:int, array:list) -> list:
    """Return the (row, col) facet position of each element of ``array``.

    Positions are assigned in order, filling ``facet_col_wrap`` columns per row
    before wrapping to the next row. Rows are 0-based and columns 1-based (the
    convention this helper has always produced). A ``facet_col_wrap`` below 1
    means "never wrap": every element lands on row 0.

    Args:
        facet_col_wrap: number of facet columns per row.
        array: the facet values; only their number and order matter.

    Returns:
        A list with one ``(row, col)`` tuple per element of ``array``.
    """
    # Iterate the argument itself; relying on a notebook global made the
    # result depend on whichever cell last assigned that global.
    if facet_col_wrap < 1:
        return [(0, idx + 1) for idx, _ in enumerate(array)]
    return [(idx // facet_col_wrap, idx % facet_col_wrap + 1) for idx, _ in enumerate(array)]


def load_result(file_name:str):
    """Load a trained CNF run and evaluate it on the first and held-out domains.

    Reads ``./result/args_{file_name}.pkl`` (run arguments), the dataset pickle
    named by ``args.dataset``, the model state ``./result/state_{file_name}.tar``
    and the loss history ``./result/lh_{file_name}.pkl``.

    Returns:
        ``(x_all, y_all, args, cnf, prior, s_acc, t_acc, fig, lh)`` where
        ``x_all``/``y_all`` are the training domains (the last stored domain is
        removed and used as the evaluation split), ``s_acc`` is the accuracy on
        ``x_all[0]``, ``t_acc`` the accuracy on the evaluation split, and
        ``fig``/``lh`` come from ``util.plot_loss_history``.

    Raises:
        ValueError: if ``args.base_distribution`` is neither ``'normal'`` nor
            ``'gmm'`` (this used to surface as an unbound ``prior``).
    """
    # load args
    args = pd.read_pickle(f'./result/args_{file_name}.pkl')
    # load data; the last stored domain is held out for evaluation
    x_all, y_all = pd.read_pickle(f'./data/data_{args.dataset}.pkl')[args.n_dim]
    x_eval, y_eval = x_all.pop(), y_all.pop()
    # these datasets store many domains: keep the first, the chosen
    # intermediate (args.inter_index) and a fixed last domain
    last_domain = {'mnist_dense': 28, 'gaussian': 4}.get(args.dataset)
    if last_domain is not None:
        given_domain = [0, args.inter_index, last_domain]
        x_all = [x_all[i].copy() for i in given_domain]
        y_all = [y_all[i].copy() for i in given_domain]
    if args.no_inter:
        # drop intermediate domains: keep only the first and the last
        x_all = [x_all[0].copy(), x_all[-1].copy()]
        y_all = [y_all[0].copy(), y_all[-1].copy()]
    # load model
    cnf = C.build_model_tabular(args, args.n_dim, None)
    cnf = util.load_model(cnf, f'./result/state_{file_name}.tar')
    cnf = util.torch_to(cnf)
    # load base distribution
    if args.base_distribution == 'normal':
        prior = Gaussian(args.n_dim, args.seed,)
    elif args.base_distribution == 'gmm':
        prior = GaussianMixtureDA(args.n_dim, args.n_class, args.seed, args.mean_r)
    else:
        raise ValueError(f'unknown base_distribution: {args.base_distribution!r}')
    # predict source and target(eval) data
    s_acc = C.predict_target(cnf, prior, x_all[0], y_all[0], 0)[-1]
    # t_acc = C.predict_target(cnf, prior, x_all[-1], y_all[-1], len(x_all)-2)[-1]
    t_acc = C.predict_target(cnf, prior, x_eval, y_eval, len(x_all)-1)[-1]
    # show loss history
    fig, lh = util.plot_loss_history(f'./result/lh_{file_name}.pkl')
    return x_all, y_all, args, cnf, prior, s_acc, t_acc, fig, lh


def make_summary(*settings):
    """Evaluate every CNF run in the Cartesian product of ``settings``.

    Each combination (one value per setting) is joined with ``'_'`` into a run
    name and loaded with ``load_result``.

    Returns:
        DataFrame with columns ``v0 .. v{len(settings)-1}`` (setting values),
        ``fn`` (run name), ``accuracy_s``, ``accuracy_t`` and ``loss`` (the
        minimum of the loss history summed over axis 0).
    """
    def summarize_run(run_name):
        outcome = load_result(run_name)
        acc_source, acc_target, history = outcome[5], outcome[6], outcome[-1]
        min_loss = np.sum(history, axis=0).min()
        return pd.Series([acc_source, acc_target, min_loss])

    columns = ['v%d' % idx for idx in range(len(settings))]
    table = pd.DataFrame(list(product(*settings)), columns=columns)
    table['fn'] = table.apply(lambda row: '_'.join(row), axis=1)
    table[['accuracy_s', 'accuracy_t', 'loss']] = table['fn'].apply(summarize_run)
    # load_result prints/plots while loading; wipe the cell output afterwards
    clear_output()
    return table


def make_cv_summary(*settings):
    """Summarise cross-validation results for every combination of ``settings``.

    For each run name (setting values joined by ``'_'``), element ``[1]`` of
    ``./result/cv_{name}.pkl`` is expanded into columns and the row-wise
    ``mean`` and ``std`` are added to the returned frame.
    """
    combos = pd.DataFrame(list(product(*settings)),
                          columns=['v' + str(j) for j in range(len(settings))])
    combos['fn'] = combos.apply(lambda row: '_'.join(row), axis=1)
    scores = combos['fn'].apply(
        lambda name: pd.Series(pd.read_pickle(f'./result/cv_{name}.pkl')[1]))
    combos['mean'] = scores.mean(axis=1)
    combos['std'] = scores.std(axis=1)
    return combos


def load_baselne_result(file_name:str):
    """Load a baseline run and evaluate it on the held-out target split.

    Returns:
        ``(x_all, y_all, args, model, t_acc)``. For run names containing
        ``'eaml'`` only ``args`` and the stored ``args.accuracy_score`` are
        returned; the other entries are ``None``.
    """
    # load args
    args = pd.read_pickle(f'./result/args_{file_name}.pkl')
    if 'eaml' in file_name:
        # the accuracy is stored in args; nothing to re-evaluate
        return None, None, args, None, args.accuracy_score
    # load data; the last stored domain is the evaluation split
    x_all, y_all = pd.read_pickle(f'./data/data_{args.dataset}.pkl')[args.n_dim]
    x_eval, y_eval = x_all.pop(), y_all.pop()
    # load model
    state_path = f'./result/state_{file_name}.tar'
    try:
        model = G.MLP(num_labels=args.n_class, input_dim=args.n_dim, hidden_dim=args.hidden_dim)
        model = util.load_model(model, state_path)
    except Exception:
        # Assumed cause: the checkpoint's output size differs from args.n_class.
        # Rebuild the MLP with the size of the last tensor in the checkpoint.
        # (`except Exception` rather than a bare `except:` so Ctrl-C still works.)
        print(f'param mismatch: {file_name}')
        map_location = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
        param = torch.load(state_path, map_location=map_location)
        n_labels = param[list(param.keys())[-1]].size()[0]
        model = G.MLP(num_labels=n_labels, input_dim=args.n_dim, hidden_dim=args.hidden_dim)
        model = util.load_model(model, state_path)
    # predict target(eval) data
    t_acc = G.calc_accuracy(model, x_eval, y_eval)
    return x_all, y_all, args, model, t_acc


def make_baseline_summary(*settings):
    """Collect the target accuracy of every baseline run in the product of ``settings``.

    Run names join one value per setting with ``'_'``; accuracies are the last
    element returned by ``load_baselne_result``. Columns: ``v0 ..``, ``fn``,
    ``accuracy_t``.
    """
    grid = pd.DataFrame(list(product(*settings)),
                        columns=['v' + str(j) for j in range(len(settings))])
    grid['fn'] = grid.apply(lambda row: '_'.join(row), axis=1)
    grid['accuracy_t'] = grid['fn'].apply(lambda name: load_baselne_result(name)[-1])
    return grid


def correlation(array1, array2):
    """Return the ``(pearsonr, spearmanr, kendalltau)`` correlation coefficients.

    Only the statistic of each scipy test is kept; p-values are discarded.
    """
    tests = (pearsonr, spearmanr, kendalltau)
    return tuple(test(array1, array2)[0] for test in tests)

def knnKL(Xs, Xt, k:int):
    """k-nearest-neighbour estimate of KL(P_s || P_t) from samples.

    Implements the estimator of Wang, Kulkarni & Verdu (2009):

        D = (d / n) * sum_i log(nu_k(i) / rho_k(i)) + log(m / (n - 1))

    ``rho_k(i)`` is the distance from ``Xs[i]`` to its k-th nearest neighbour
    among the other points of ``Xs``. ``nu_k(i)`` is the distance to its k-th
    nearest neighbour in ``Xt``. ``n``/``m`` are the sample sizes and ``d`` is
    the dimension.

    Points of ``Xt`` at distance exactly 0 from ``Xs[i]`` (e.g. copies of
    source points inside ``Xt``) are skipped when picking ``nu_k(i)``, and ``m``
    is reduced by that count for sample ``i``.

    Args:
        Xs: source samples, 2-D array ``(n, d)``.
        Xt: target samples, 2-D array ``(m, d)``.
        k: neighbour order.

    Returns:
        The scalar divergence estimate.
    """
    Gs = KDTree(Xs, metric='euclidean')
    Gt = KDTree(Xt, metric='euclidean')

    # k+1 because the nearest neighbour of each source point is itself
    rho, _ = Gs.query(Xs, k=k+1) # returns (distances, indices)
    rho = rho[:,-1].copy()

    # Query extra neighbours so zero-distance matches can be skipped, but never
    # more neighbours than Xt contains (KDTree.query raises in that case).
    k_buffer = min(max(30, k * 2), Xt.shape[0])
    nu_buffer, _ = Gt.query(Xs, k=k_buffer)
    n_zero = (nu_buffer == 0).sum(axis=1)
    # NOTE(review): when every source point has exactly one exact match in Xt
    # the zero-skipping is disabled; the intent is not evident from this file.
    if np.all(n_zero == 1): n_zero *= 0
    nu = np.array([_nu[k-1 + _nz] for _nu, _nz in zip(nu_buffer, n_zero)])

    n, d = Xs.shape
    m = Xt.shape[0] - n_zero
    # Only the log-ratio sum is scaled by d/n; the bias term log(m/(n-1)) is
    # added once (averaged over samples because m may differ per sample).
    div = (d / n) * np.log(nu / rho).sum() + np.mean(np.log(m / (n - 1)))

    return div

def marginal_knnKL(Xs, Xt, Ys, Yt, k:int, mode:str='KL'):
    """Class-weighted kNN divergence between source and target samples.

    For every label ``c`` present in ``Ys`` or ``Yt``, the divergence between the
    class-``c`` source and target samples is estimated (``knnKL`` for
    ``mode='KL'``, ``knnJS`` for ``mode='JS'``). Each estimate is weighted by the
    pooled frequency of ``c``, and the weighted estimates are summed. If the
    source set is larger than the target set, it is first subsampled
    deterministically (seed 12, without replacement) to the target size.

    Raises:
        ValueError: if ``mode`` is not ``'KL'`` or ``'JS'``.
    """
    if mode not in ('KL', 'JS'):
        raise ValueError(f"mode must be 'KL' or 'JS', got {mode!r}")

    xs_size, xt_size = Xs.shape[0], Xt.shape[0]
    if xs_size > xt_size:
        # Draw WITHOUT replacement: duplicated source points give zero
        # within-source kNN distances (rho = 0) in knnKL, which biases or blows
        # up the estimate. A local RandomState keeps the draw reproducible
        # without reseeding NumPy's global generator.
        rng = np.random.RandomState(12)
        idx = rng.choice(xs_size, size=xt_size, replace=False)
        Xs = Xs[idx]
        Ys = Ys[idx]

    estimator = knnKL if mode == 'KL' else knnJS
    Y = np.hstack([Ys, Yt])

    dist = []
    for c in np.unique(Y):
        # pooled class frequency used as the weight of this class
        Py = (Y == c).sum() / Y.size

        _xs = Xs[Ys == c].copy()
        _xt = Xt[Yt == c].copy()

        dist.append(estimator(_xs, _xt, k) * Py)

    return np.sum(dist)


def knnJS(Xs, Xt, k:int, alpha=0.5):
    """kNN estimate of the alpha-weighted Jensen-Shannon divergence.

    Draws a mixture sample ``M`` of size ``min(len(Xs), len(Xt))`` after
    resetting NumPy's global seed to 12. Each slot is taken from ``Xs`` with
    probability ``alpha`` and from ``Xt`` otherwise, with replacement. Returns
    ``alpha * knnKL(Xs, M) + (1 - alpha) * knnKL(Xt, M)``.
    """
    n_source, n_target = Xs.shape[0], Xt.shape[0]
    mix_size = min(n_source, n_target)

    np.random.seed(12)
    n_from_source = (np.random.uniform(size=mix_size) < alpha).sum()
    n_from_target = mix_size - n_from_source
    source_pick = np.random.choice(np.arange(n_source), size=n_from_source, replace=True)
    target_pick = np.random.choice(np.arange(n_target), size=n_from_target, replace=True)
    mixture = np.vstack((Xs[source_pick], Xt[target_pick]))

    kl_source = knnKL(Xs, mixture, k)
    kl_target = knnKL(Xt, mixture, k)
    return alpha * kl_source + (1-alpha) * kl_target


def compute_Ej(x_all, y_all, k:int):
    """Return ``(beta - js) / kl`` for a (source, intermediate, target) triple.

    ``beta`` is the average of the class-weighted kNN KL estimates
    KL(source || intermediate) and KL(target || intermediate). ``js`` is the
    class-weighted JS estimate between source and target, and ``kl`` is the
    class-weighted KL(source || target). ``x_all``/``y_all`` must each hold
    exactly three domains.
    """
    assert len(x_all) == 3
    (xs, xi, xt), (ys, yi, yt) = x_all, y_all

    # path through the intermediate domain
    beta = 0.5 * marginal_knnKL(xs, xi, ys, yi, k) + 0.5 * marginal_knnKL(xt, xi, yt, yi, k)
    # direct source/target comparisons
    js = marginal_knnKL(xs, xt, ys, yt, k, mode='JS')
    kl = marginal_knnKL(xs, xt, ys, yt, k, mode='KL')

    return (beta - js) / kl

# Gaussian Setting

In [None]:
# Target accuracy and E_j for the Gaussian setting, one row per choice of
# intermediate domain. Runs are named 'gaussian_inter_{inter_index}_{seed}'.
CONDITIONS = ['1', '2', '3'] 
SEED = ['1', '2', '3', '4', '5']
result = make_summary(['gaussian_inter'], CONDITIONS, SEED)
result = result.drop(['v0'], axis=1).rename(columns={'v1':'inter_index', 'v2':'seed'})
# mean / std of the target accuracy over seeds for each inter_index
summary = result.groupby(by=['inter_index'], as_index=False)
summary = summary.agg(mean=('accuracy_t','mean'), std=('accuracy_t','std'))

# E_j is computed from the data of the seed-1 run of each condition only
divergence = {}
for c in CONDITIONS:
    fn = f'gaussian_inter_{c}_1'
    x_all, y_all, args, _, _, _, _, _, _ = load_result(fn)
    clear_output()
    
    # k is taken from the run's args.log_prob_param
    k = args.log_prob_param
    Ej = compute_Ej(x_all, y_all, k)
    
    divergence[c] = {'inter_index':c, 'Ej': Ej}

# one row per condition; joined to the accuracy summary on inter_index
divergence = pd.DataFrame(divergence).T
summary = pd.merge(summary, divergence, on='inter_index').astype(float)
summary.round(3)

In [None]:
# Mean vectors of the two class-conditional Gaussians in each domain of the
# Gaussian setting; class 2 mirrors class 1 across x1 = 0.
means1 = np.array([(3.0, 1.0),
                   (6.0, 3.0),
                   (8.0, 3.0),
                   (3.0, 3.0),
                   (3.0, 5.0)])

means2 = np.array([(-3.0, 1.0),
                   (-6.0, 3.0),
                   (-8.0, 3.0),
                   (-3.0, 3.0),
                   (-3.0, 5.0)])

# one domain name per row of means1 / means2
domain = np.array(['source', 'inter1', 'inter2', 'inter3', 'target'])

means1 = pd.DataFrame(means1, columns=['x1', 'x2'])
means1['class'] = "1"
means1['domain'] = domain

means2 = pd.DataFrame(means2, columns=['x1', 'x2'])
means2['class'] = "2"
means2['domain'] = domain

df = pd.concat([means1, means2])

fig = px.scatter(df, x='x1', y='x2', color='class', 
                 labels={'x1':'first component of mean vector', 
                         'x2':'second component of mean vector'})
# text labels placed to the right of the class-2 points at each height
fig.add_annotation(x=-3, y=5, text="Target domain", xshift=90, showarrow=False)
fig.add_annotation(x=-3, y=3, text="Intermediate", xshift=90, showarrow=False)
fig.add_annotation(x=-3, y=1, text="Source domain", xshift=90, showarrow=False)
fig.update_traces(marker=dict(size=10))
fig.update_layout(width=650, height=500, **plt_settings)
fig.update_layout(legend_title_text='mean vector')
# rename the two traces (one per class) for the legend
fig.data[0]['name'] = 'class label 1'
fig.data[1]['name'] = 'class label 2'
fig.show()

# DNF vs. CNF

In [None]:
# Transform the intermediate-domain samples of the 'moon_seed_1' run with the
# CNF and with the RealNVP (DNF) model, for the comparison figure below.
file_name = 'moon_seed_1'

# CNF
x_all, y_all, args, cnf, prior = load_result(file_name)[:5]
inter = torch.tensor(x_all[1])
c_convert = C.visualize_trajectory_forward(cnf, inter, 1, .1)
# keep four snapshots of the trajectory
subset = [0, 2, 6, 10]
c_convert = [c_convert[i].numpy() for i in subset]  # time -> 2.0, 1.8, 1.4, 1.0

# DNF
# args.dims is a '-'-separated string of ints; its length sets the number of flows
args.dims = [int(i) for i in args.dims.split('-')]
dnf = RealNVP(n_flows=len(args.dims), data_dim=args.n_dim, n_hidden=args.dims[0])
dnf = util.load_model(dnf, f'./result/state_realnvp_{file_name}.tar')
# output discarded; presumably the forward pass fills dnf.inter_repr — verify in RealNVP
_ = dnf.forward(inter)
d_convert = dnf.inter_repr

clear_output()

In [None]:
# 3x4 grid: row 1 ground truth (intermediate and source domains), row 2 CNF
# trajectory snapshots, row 3 intermediate representations of the RealNVP.
titles = ['Intermediate', '', '', 'Source',
          't=2.0', 't=1.8', 't=1.4', 't=1.0',
          'Intermediate', 'Output of Block 1', 'Output of Block 2', 'Output of Block 3']
fig, pos = make_subplots_wrapper(rows=3, cols=4, horizontal_spacing=0.01, vertical_spacing=0.08, subplot_titles=titles)

# plot Ground Truth
# domain 1 (intermediate) goes to cell 0, domain 0 (source) to cell 3
for i in reversed(range(2)):
    x, y = x_all[i].copy(), y_all[i].copy()
    # color_plt = px.scatter(x=x[:,0], y=x[:,1], color=y.astype(str))
    color_plt = px.scatter(x=x[:,0], y=x[:,1])
    p = 3 if i == 0 else 0
    fig.add_traces(color_plt.data, rows=pos[p][0], cols=pos[p][1])
    
# labels only used by the commented-out colored variants below
inter_y = y_all[1].copy().astype(str)
# plot CNF
for i, p in enumerate(pos[4:8]):
    x_hat = c_convert[i]
    # color_plt = px.scatter(x=x_hat[:,0], y=x_hat[:,1], color=inter_y)
    color_plt = px.scatter(x=x_hat[:,0], y=x_hat[:,1])
    fig.add_traces(color_plt.data, rows=p[0], cols=p[1])

# plot DNF
# NOTE(review): assumes d_convert has at least 4 entries (one per remaining cell)
for i, p in enumerate(pos[8:]):
    x_hat = d_convert[i]
    # color_plt = px.scatter(x=x_hat[:,0], y=x_hat[:,1], color=inter_y)
    color_plt = px.scatter(x=x_hat[:,0], y=x_hat[:,1])
    fig.add_traces(color_plt.data, rows=p[0], cols=p[1])

# set layout
fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False)
fig.update_yaxes(title_text='Ground Truth', row=1, col=1)
fig.update_yaxes(title_text='CNF', row=2, col=1)
fig.update_yaxes(title_text='DNF', row=3, col=1)
fig.update_layout(width=850, height=600, showlegend=False, **plt_settings)
fig.update_annotations(dict(font={'family': 'PTSerif', 'size': 21}))
fig.show()
fig.write_image('./fig/fig_realnvp.pdf')

# w/o intermediate

In [None]:
# Ground truth vs. CNF trajectories of the target domain, trained with and
# without intermediate domains. DATA must be 'moon' or 'block'; any other value
# leaves titles/time/share_axes undefined and the cell fails below.
DATA = 'block'
CONDITIONS = ['seed', 'nointer']
# DEBUG prints the trajectory time of each plotted snapshot instead of clearing output
DEBUG = True

if DATA == 'moon':
    titles = ['Target', '', 'Intermediate', '', 'Source']
    time = ['t=3.0', 't=2.5', 't=2.0', 't=1.5', 't=1.0']
    time_no_inter = ['t=2.0', 't=1.7', 't=1.5', 't=1.2', 't=1.0']
    share_axes = False
elif DATA == 'block':
    titles = ['Target', 'Intermediate1', 'Intermediate2', 'Source']
    time = ['t=4.0', 't=3.0', 't=2.0', 't=1.0']
    time_no_inter = ['t=2.0', 't=1.6', 't=1.4', 't=1.0']
    share_axes = True
    
# trajectory snapshot indices per run; with step 0.1 they correspond to the
# times listed in `time` / `time_no_inter` above (see the DEBUG print)
subset = {'moon_seed_1': [0, 5, 10, 15, 20],
          'moon_nointer_1': [0, 3, 5, 8, 10],
          'block_seed_1': [0, 10, 20, 30],
          'block_nointer_1': [0, 4, 6, 10]}


titles = titles + time + time_no_inter
fig, pos = make_subplots_wrapper(rows=len(CONDITIONS)+1, cols=len(time), horizontal_spacing=0.01, vertical_spacing=0.06,
                                 subplot_titles=titles, shared_xaxes=share_axes, shared_yaxes=share_axes)

# plot Ground Truth
# row 1, target first; for moon the 3 domains are spread over 5 columns (step 2)
i = 0
file_name = f'{DATA}_seed_1'
x_all, y_all, args, cnf, prior, _, t_acc, _, _ = load_result(file_name)
for x, y in zip(x_all[::-1], y_all[::-1]):
    color_plt = px.scatter(x=x[:,0], y=x[:,1], color=(y+1).astype(str))
    # only the first ground-truth panel contributes legend entries
    if i > 0:
        for g in range(len(color_plt.data)):
            color_plt.data[g]['showlegend'] = False

    fig.add_traces(color_plt.data, rows=pos[i][0], cols=pos[i][1])
    
    if DATA == 'moon':
        i += 2
    elif DATA == 'block':
        i += 1

# plot CNF
# start at the first cell of row 2 (row length == len(time))
i = 5 if DATA == 'moon' else 4
for c in CONDITIONS:
    file_name = f'{DATA}_{c}_1'
    x_all, y_all, args, cnf, prior, _, t_acc, _, _ = load_result(file_name)
    print(f'{file_name}, {t_acc}\n')  
    
    target_x = torch.tensor(x_all[-1], dtype=torch.float32)
    target_y = y_all[-1].copy() + 1 
    # nointer runs start at t0=1, the others at the last domain index
    t0 = 1 if 'nointer' in file_name else len(x_all)-1
    c_convert = C.visualize_trajectory_forward(cnf, target_x, t0, .1)
    c_convert = [c_convert[i].numpy() for i in subset[file_name]]
    
    if DEBUG:
        time_for_debug = np.arange(0, t0+1+0.1, 0.1)[::-1]
        print(time_for_debug[subset[file_name]])
    else:
        clear_output()

    for x_hat in c_convert:
        color_plt = px.scatter(x=x_hat[:,0], y=x_hat[:,1], color=target_y.astype(str))
        for g in range(len(color_plt.data)):
            color_plt.data[g]['showlegend'] = False
        fig.add_traces(color_plt.data, rows=pos[i][0], cols=pos[i][1])
        i += 1

# layout settings
fig.update_layout(width=800, height=700, **plt_settings)
fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False)
fig.update_yaxes(title_text='Ground Truth', row=1, col=1)
fig.update_yaxes(title_text='w/ Intermediate', row=2, col=1)
fig.update_yaxes(title_text='w/o Intermediate', row=3, col=1)
fig.update_layout(legend=dict(y=-0.1, title_text='Class Label'))
fig.update_annotations(dict(font={'family': 'PTSerif', 'size': 22}))
fig.show()
fig.write_image(f'./fig/fig_no_inter_{DATA}.pdf')

# Gaussian Mixture estimators for the log-likelihood

In [None]:
# Target-domain trajectories of two moon runs ('moon_seed_1', 'moon_q30_1'),
# sampled every 0.5 time units starting from t0=2.
convert = []
for file_name in ['moon_seed_1', 'moon_q30_1']:
    x_all, y_all, args, cnf, prior = load_result(file_name)[:5]
    target = torch.tensor(x_all[2])
    z_all = C.visualize_trajectory_forward(cnf, target, 2, .5)
    convert.append(z_all[:5]) # time -> 3.0, 2.5, 2.0, 1.5, 1.0
# flatten into one list: 5 snapshots of the first run, then 5 of the second
# (assumes visualize_trajectory_forward returns a list)
convert = sum(convert, [])
clear_output()

In [None]:
# 3x5 grid: row 1 ground truth (x_all/y_all of the last run loaded above), rows
# 2 and 3 the 5 snapshots of each run in `convert`.
titles = ['Target', '', 'Intermediate', '', 'Source',
          't=3.0', 't=2.5', 't=2.0', 't=1.5', 't=1.0',
          't=3.0', 't=2.5', 't=2.0', 't=1.5', 't=1.0']
fig, pos = make_subplots_wrapper(rows=3, cols=5, horizontal_spacing=0.01, vertical_spacing=0.06, subplot_titles=titles)

# plot Ground Truth
# target, intermediate, source in columns 1, 3, 5
p = 0
for i in reversed(range(3)):
    x, y = x_all[i].copy(), y_all[i].copy()
    # color_plt = px.scatter(x=x[:,0], y=x[:,1], color=y.astype(str))
    color_plt = px.scatter(x=x[:,0], y=x[:,1])
    fig.add_traces(color_plt.data, rows=pos[p][0], cols=pos[p][1])
    p += 2

# plot kNN and GMM
# NOTE: rebinds the global `pos` to the cells of rows 2-3
pos = pos[5:]
target_y = y_all[2].copy().astype(str)
for p, x_hat in zip(pos, convert):
    # color_plt = px.scatter(x=x_hat[:,0], y=x_hat[:,1], color=target_y)
    color_plt = px.scatter(x=x_hat[:,0], y=x_hat[:,1])
    fig.add_traces(color_plt.data, rows=p[0], cols=p[1])

# set layout
fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False)
fig.update_yaxes(title_text='Ground Truth', row=1, col=1)
fig.update_yaxes(title_text='kNN', row=2, col=1)
fig.update_yaxes(title_text='fitted GM', row=3, col=1)
fig.update_layout(width=850, height=550, showlegend=False, **plt_settings)
fig.update_annotations(dict(font={'family': 'PTSerif', 'size': 22}))
fig.show()
fig.write_image('./fig/fig_fitted_gmm.pdf')

# OT vs. NF

In [None]:
# Settings
METHOD = {'gst':'GradualSelfTrain', 'goat':'GOAT', 'saux':'Sequential AuxSelfTrain'}
DATASET = {'moon':'Two Moon'}
CONDITIONS = ['seed']
SEED = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']

# Baselines
baselines = make_baseline_summary(METHOD, DATASET, SEED)
baselines = baselines.rename(columns={'v0':'method', 'v1':'data','v2':'seed'})
baselines['method'] = baselines['method'].map(METHOD)
baselines = baselines.groupby('method', as_index=False).agg(mean=('accuracy_t','mean'), std=('accuracy_t','std'))
baselines.round(3)

In [None]:
from datasets2 import make_gradual_data

# Intermediate domains generated by OT and by AuxSelfTrain for the figure below.
x_all, y_all = make_gradual_data()
# the last domain is popped off (x_eval / y_eval are not used below)
x_eval, y_eval = x_all.pop(), y_all.pop()

# generate by OT
# Named `x_ot` rather than `ot` so the list does not shadow the `ot` (POT)
# module imported at the top of the notebook.
x_ot = [x_all[0].copy()]
for i in range(len(x_all) - 1):
    xs, xt = x_all[i].copy(), x_all[i+1].copy()
    # labels are passed only for the first pair
    ys = y_all[0].copy() if i == 0 else None
    x_ot += G.generate_domains(1, xs, xt, ys)
# target -> source order, matching the figure columns
x_ot = x_ot[::-1]

# generated by AuxSelfTrain
aux = [x_all[0].copy()]
_, _aux = G.AuxSelfTrain(x_all[:2], y_all[:2], num_inter=3) # source -> inter
aux.append(_aux[0])
aux.append(x_all[1])
_, _aux = G.AuxSelfTrain(x_all[1:], y_all[1:], num_inter=3) # inter -> target
aux.append(_aux[0])
aux.append(x_all[2])
aux = aux[::-1]

# one list of panels: ground truth row, then OT row, then AuxSelfTrain row
X = x_all[::-1] + x_ot + aux

In [None]:
# 3x5 grid: ground truth, OT-generated and AuxSelfTrain ("Mixed Samples") rows.
titles = ['Target', 'Generated', 'Intermediate', 'Generated', 'Source',]
fig, pos = make_subplots_wrapper(rows=3, cols=5, horizontal_spacing=0.01, vertical_spacing=0.05, subplot_titles=titles)
# the ground-truth row has no generated domains: drop cells (1,2) and (1,4)
pos = np.delete(pos, [1, 3], axis=0)

for p, x in zip(pos, X):
    color_plt = px.scatter(x=x[:,0], y=x[:,1])
    fig.add_traces(color_plt.data, rows=p[0], cols=p[1])
    
fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False)
fig.update_yaxes(title_text='Ground Truth', row=1, col=1)
fig.update_yaxes(title_text='OT', row=2, col=1)
fig.update_yaxes(title_text='Mixed Samples', row=3, col=1)
fig.update_layout(width=850, height=550, showlegend=False, **plt_settings)
fig.update_annotations(dict(font={'size':22}))
fig.write_image('./fig/fig_simple_interpolate.pdf')
fig.show()

# UMAP

In [None]:
# Cross-validation mean/std per dataset for the 4-d and 32-d runs
# (runs named '{dataset}_{dim}d').
DATASET = {'mnist':'Rotating MNIST', 'portraits':'Portraits', 'shift15m':'SHIFT15M', 'rxrx1':'RxRx1',
           'tox21a':'Tox21 NHOHCount', 'tox21b':'Tox21 RingCount', 'tox21c':'Tox21 NumHDonors',}
CONDITIONS = ["4d", "32d"]

df = make_cv_summary(DATASET.keys(), CONDITIONS)
df = df.rename(columns=dict(v0='data', v1='dim'))
# '32d' -> 32
df['dim'] = df['dim'].str.replace('d', '').astype(int)
df['data'] = df['data'].map(DATASET)
df

# mean_r

## theoretical

In [None]:
def get_probability(n_dims:int, n_labels:int, mean_r:int):
    """Density at the midpoint of the first two mixture means.

    Builds ``GaussianMixtureDA(n_dims, n_labels, mean_r)`` and takes the
    midpoint of its first two mean vectors. There it evaluates the pdf of the
    identity-covariance Gaussians centred at each of those two means, and
    returns the larger of the two values.
    """
    mixture = GaussianMixtureDA(n_dims=n_dims, n_labels=n_labels, mean_r=mean_r)
    centers = mixture.get_means(mean_r=mean_r).numpy()
    midpoint = centers[:2].mean(axis=0)
    identity = np.eye(n_dims)
    pdf_first = multivariate_normal(centers[0], identity).pdf(midpoint)
    pdf_second = multivariate_normal(centers[1], identity).pdf(midpoint)
    return max(pdf_first, pdf_second)

def draw_circle(mean, radius):
    """Return ``(x, y)`` arrays of 999 points on the circle of ``radius`` around ``mean``.

    The angles are 1000 evenly spaced values over [0, 2*pi], with the last one
    (which repeats the first point) dropped.
    """
    angles = np.linspace(0, 2*np.pi, 1000)[:-1]
    center_x, center_y = mean[0], mean[1]
    return center_x + radius * np.cos(angles), center_y + radius * np.sin(angles)

In [None]:
# Smallest integer mean radius r in [1, 20] for which the midpoint density
# between the first two mixture components falls below alpha.
n_dims = 4
n_labels = 2
alpha = 1e-3
rs = np.arange(1, 21 ,1)

probs = np.array([get_probability(n_dims, n_labels, mean_r=r) for r in rs])
# raises ValueError (empty reduction) if no r reaches the threshold
min_r = rs[probs < alpha].min()
print(min_r)

fig = px.scatter(x=rs, y=probs, labels={'x':'r', 'y':'probability'}, width=500, height=450)
fig.add_hline(y=alpha, line_width=2, line_dash="dash", line_color="LightSeaGreen")
fig.show()

# visual check (2-D only): samples of the mixture at min_r with circles of
# radius half the distance between the first two means around every mean
if n_dims == 2:
    prior = GaussianMixtureDA(n_dims=n_dims, n_labels=n_labels, mean_r=min_r)    
    means = prior.get_means(mean_r=min_r).numpy()
    radius = np.linalg.norm(means[1]-means[0], ord=2) / 2
    x, _ = prior.sample(total_size=n_labels*1000)
    # assumes samples come grouped by class in blocks of 1000 — verify in GaussianMixtureDA.sample
    color = np.hstack([i*np.ones(1000, dtype=int) for i in range(n_labels)])
    
    fig = px.scatter(x=x[:,0], y=x[:,1], color=color.astype(str), labels={'x':'', 'y':''}, width=500, height=450)
    for mean in means:
        x, y = draw_circle(mean, radius)
        fig.add_shape(type="circle", xref="x", yref="y", x0=min(x), y0=min(y), x1=max(x), y1=max(y), line_color="LightSeaGreen",)
    fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
    fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False)
    fig.show()

In [None]:
# U(r) curves for 2-D mixtures with 2, 5 and 10 classes.
# include_title is also read by the next cell.
include_title = False
n_dims = 2
n_labels = [2, 5, 10]
rs = np.arange(1, 15, 0.5)

probs = []
for nl in n_labels:
    p = np.array([get_probability(n_dims=n_dims, n_labels=nl, mean_r=r) for r in rs])
    probs.append(p)
# one column per number of classes, indexed by r
probs = pd.DataFrame(np.array(probs).T, columns=n_labels, index=rs)

title = '(a) Number of Dimensions is Fixed' if include_title else None

fig1 = px.scatter(probs, labels={'value':'U(r)', 'index':'r'}, title=title)
fig1.update_layout(plt_settings, legend_title_text='number of classes', title_x=0.5, xaxis=dict(dtick=3), yaxis=dict(dtick=0.05),
                   width=550, height=500, margin=dict(t=50))
fig1.write_image('./fig/fig_mean_r_dimension.pdf')
fig1.show()

In [None]:
# U(r) curves for 10-class mixtures in 2, 4 and 8 dimensions
# (uses include_title from the previous cell).
n_dims = [2, 4, 8]
n_labels = 10
rs = np.arange(1, 15, 0.5)

probs = []
for nd in n_dims:
    p = np.array([get_probability(n_dims=nd, n_labels=n_labels, mean_r=r) for r in rs])
    probs.append(p)
# one column per dimensionality, indexed by r
probs = pd.DataFrame(np.array(probs).T, columns=n_dims, index=rs)

title = '(b) Number of Classes is Fixed' if include_title else None

fig2 = px.scatter(probs, labels={'value':'U(r)', 'index':'r'}, title=title)
fig2.update_layout(plt_settings, legend_title_text='number of dimensions', title_x=0.5, xaxis=dict(dtick=3), yaxis=dict(dtick=0.05),
                   width=550, height=500, margin=dict(t=50))
fig2.write_image('./fig/fig_mean_r_class.pdf')
fig2.show()

## experimental

In [None]:
# CV accuracy vs. mean radius r per dataset, with a dashed line at the
# theoretical threshold radius of each dataset.
DATASET = {'mnist':'Rotating MNIST', 'portraits':'Portraits', 'shift15m':'SHIFT15M', 'rxrx1':'RxRx1',
           'tox21a':'Tox21 NHOHCount', 'tox21b':'Tox21 RingCount', 'tox21c':'Tox21 NumHDonors'}
CONDITIONS = ['r0.5', 'r1', 'r3', 'r5', 'r8', 'r10', 'r15','r20']

df = make_cv_summary(DATASET, CONDITIONS)
df = df.rename(columns=dict(v0='data', v1='r'))
df['r'] = df['r'].str.replace('r','').astype(float)
df['data'] = df['data'].map(DATASET)

# threshold radius per dataset: smallest r in [1, 20] whose midpoint density < alpha
Ur = []
alpha = 1e-3
rs = np.arange(1, 21 ,1)
for name in DATASET.keys():
    # settings (from mainCNF) maps dataset name -> (n_dims, n_labels)
    n_dims, n_labels = settings[name]
    probs = np.array([get_probability(n_dims, n_labels, mean_r=r) for r in rs])
    min_r = rs[probs < alpha].min()
    Ur.append(min_r)
    
facet_col_wrap = 4
positions = facet_col_warp_parser(facet_col_wrap, DATASET.keys())
fig = px.scatter(df, x='r', y='mean', facet_col='data', error_y='std', width=1200, height=600,
                 facet_col_wrap=facet_col_wrap, facet_row_spacing=0.15,
                 category_orders={'data':DATASET.values()}, labels={'mean':'Accuracy'})
# NOTE: `pos` here rebinds the global name used by earlier cells
# NOTE(review): positions use 0-based rows; confirm they address the intended
# facets (plotly subplot rows are 1-based)
for u, pos in zip(Ur, positions):
    fig.add_vline(u, row=pos[0], col=pos[1], line_width=1, line_dash="dash", line_color="red")
fig.update_layout(plt_settings)
fig = remove_unnecessary_layout(fig)
fig.update_xaxes(showticklabels=True).update_annotations(font=dict(size=22))
fig.write_image('./fig/fig_mean_r.pdf')
fig.show()

### Two Moon

In [None]:
# Two Moon, r=3: (a) source domain, (b) source mapped by the CNF with the two
# GMM mean vectors and the radius r drawn on top.
x_all, y_all, args, cnf, prior, s_acc, _, _, _ = load_result('moon_sourceonly_r3_1')

x = x_all[0].copy()
y = (y_all[0] + 1).astype(str)

# z[1]: second snapshot of the trajectory from t0=0 with step 1
# (presumably the base-distribution space — verify in ffjord)
z = C.visualize_trajectory_forward(cnf, torch.tensor(x), 0, 1)
z = z[1]

# segment from the origin to the first mean vector, drawn as "r"
org = np.array([0, 0 ])
mu1 = prior.gaussians[0].mean.numpy()
mu2 = prior.gaussians[1].mean.numpy()
r = np.vstack([org, mu1])

fig = px.scatter(x=x[:,0], y=x[:,1], color=y, labels={'x':'', 'y':''})
fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False)
fig.update_layout(plt_settings, font=dict(size=22), 
                  legend=dict(y=-0.13, title_text='Class Label'), margin=dict(l=30), width=550, height=500)
# fig.update_layout(title_text='Source Domain',title_x=0.5)
fig.write_image('./fig/fig_mean_moon_a.pdf')
fig.show()


my_colors = px.colors.qualitative.Plotly

fig = px.scatter(x=z[:,0], y=z[:,1], color=y, opacity=0.1, labels={'x':'', 'y':''})
fig = remove_unnecessary_layout(fig)
fig.add_scattergl(x=[mu2[0]], y=[mu2[1]], mode='markers', legendgroup=2, marker=dict(size=12, color=my_colors[1]), name='mean vector 2')
fig.add_scattergl(x=[mu1[0]], y=[mu1[1]], mode='markers', legendgroup=3,marker=dict(size=12, color=my_colors[0]), name='mean vector 1')
fig.add_scattergl(x=r[:,0], y=r[:, 1], mode='lines', legendgroup=4, line=dict(color='black', width=3, dash='dot'), name='r') # 'dash', 'dot', 'dashdot'
# reverse so the added traces come first: [r, mean 1, mean 2, class traces...]
fig.data = fig.data[::-1]
fig.add_annotation(x=mu1[0]/2, y=mu1[1], text="<b>r</b>", showarrow=False, yshift=12, font=dict(size=22),)

# legend shows r and the two mean vectors; the faint class scatters are hidden
# (indices depend on the trace order established above)
fig.data[0]['showlegend'] = True
fig.data[1]['showlegend'] = True
fig.data[2]['showlegend'] = True
fig.data[3]['showlegend'] = False
fig.data[4]['showlegend'] = False

fig.update_xaxes(showticklabels=False, showgrid=False, range=[-6.5, 6.5])
fig.update_yaxes(showticklabels=False, showgrid=False, range=[-6.5, 6.5])
fig.update_layout(plt_settings, font=dict(size=22), margin=dict(l=30), legend=dict(y=-0.13,), width=550, height=500)
# fig.update_layout(title_text='Gaussian Mixture (r=3.0)', title_x=0.5)
fig.write_image('./fig/fig_mean_moon_b.pdf')
fig.show()



# (c) the same plot for the r=0.5 run
x_all, y_all, args, cnf, prior, s_acc, _, _, _ = load_result('moon_sourceonly_r0.5_1')

x = x_all[0].copy()
y = (y_all[0] + 1).astype(str)

z = C.visualize_trajectory_forward(cnf, torch.tensor(x), 0, 1)
z = z[1]

org = np.array([0, 0 ])
mu1 = prior.gaussians[0].mean.numpy()
mu2 = prior.gaussians[1].mean.numpy()
r = np.vstack([org, mu1])

my_colors = px.colors.qualitative.Plotly

fig = px.scatter(x=z[:,0], y=z[:,1], color=y, opacity=0.1, labels={'x':'', 'y':''})
fig = remove_unnecessary_layout(fig)
fig.add_scattergl(x=[mu2[0]], y=[mu2[1]], mode='markers', legendgroup=2, marker=dict(size=12, color=my_colors[1]), name='mean vector 2')
fig.add_scattergl(x=[mu1[0]], y=[mu1[1]], mode='markers', legendgroup=3,marker=dict(size=12, color=my_colors[0]), name='mean vector 1')
fig.add_scattergl(x=r[:,0], y=r[:, 1], mode='lines', legendgroup=4, line=dict(color='black', width=3, dash='dot'), name='r') # 'dash', 'dot', 'dashdot'
fig.data = fig.data[::-1]
fig.add_annotation(x=mu1[0]/2, y=mu1[1], text="<b>r</b>", showarrow=False, yshift=12, font=dict(size=22),)

fig.data[0]['showlegend'] = True
fig.data[1]['showlegend'] = True
fig.data[2]['showlegend'] = True
fig.data[3]['showlegend'] = False
fig.data[4]['showlegend'] = False

fig.update_xaxes(showticklabels=False, showgrid=False, range=[-6.5, 6.5])
fig.update_yaxes(showticklabels=False, showgrid=False, range=[-6.5, 6.5])
fig.update_layout(plt_settings, font=dict(size=22), margin=dict(l=30), legend=dict(y=-0.13,), width=550, height=500)
# fig.update_layout(title_text='Gaussian Mixture (r=0.5)', title_x=0.5)
fig.write_image('./fig/fig_mean_moon_c.pdf')
fig.show()

### Block

In [None]:
# Block dataset, r=10: (a) source domain, (b) source mapped by the CNF with the
# first GMM mean vector and the radius r drawn on top.
x_all, y_all, args, cnf, prior, s_acc, _, _, _ = load_result('block_sourceonly_r10_1')

x = x_all[0].copy()
y = (y_all[0] + 1).astype(str)

# z[1]: second snapshot of the trajectory from t0=0 with step 1
# (presumably the base-distribution space — verify in ffjord)
z = C.visualize_trajectory_forward(cnf, torch.tensor(x, dtype=torch.float32), 0, 1)
z = z[1]

# only mu1 is drawn below; mu2..mu5 are unused here
org = np.array([0, 0 ])
mu1 = prior.gaussians[0].mean.numpy()
mu2 = prior.gaussians[1].mean.numpy()
mu3 = prior.gaussians[2].mean.numpy()
mu4 = prior.gaussians[3].mean.numpy()
mu5 = prior.gaussians[4].mean.numpy()
r = np.vstack([org, mu1])

fig = px.scatter(x=x[:,0], y=x[:,1], color=y, labels={'x':'', 'y':''})
fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False)
fig.update_layout(plt_settings, font=dict(size=22), 
                  legend=dict(y=-0.13, title_text='Class Label'), margin=dict(l=10), width=550, height=500)
fig.write_image('./fig/fig_mean_block_a.pdf')
fig.show()

my_colors = px.colors.qualitative.Plotly

fig = px.scatter(x=z[:,0], y=z[:,1], color=y, opacity=0.1, labels={'x':'', 'y':''})
fig = remove_unnecessary_layout(fig)
fig.add_scattergl(x=[mu1[0]], y=[mu1[1]], mode='markers', legendgroup=10,marker=dict(size=12, color=my_colors[0]), name='mean vector 1')
fig.add_scattergl(x=r[:,0], y=r[:, 1], mode='lines', legendgroup=11, line=dict(color='black', width=3, dash='dot'), name='r') # 'dash', 'dot', 'dashdot'
# reverse so the added traces come first: [r, mean 1, class traces...]
fig.data = fig.data[::-1]
fig.add_annotation(x=mu1[0]/2, y=mu1[1], text="<b>r</b>", showarrow=False, yshift=12, font=dict(size=22),)

# legend shows r and mean vector 1; the five faint class scatters are hidden
# (indices depend on the trace order established above)
fig.data[0]['showlegend'] = True
fig.data[1]['showlegend'] = True
fig.data[2]['showlegend'] = False
fig.data[3]['showlegend'] = False
fig.data[4]['showlegend'] = False
fig.data[5]['showlegend'] = False
fig.data[6]['showlegend'] = False

fig.update_xaxes(showticklabels=False, showgrid=False, range=[-12.5, 13.5])
fig.update_yaxes(showticklabels=False, showgrid=False, range=[-12.5, 13.5])
fig.update_layout(plt_settings, font=dict(size=22), margin=dict(l=10), legend=dict(y=-0.13,), width=550, height=500)
# fig.update_layout(title_text='Gaussian Mixture (r=3.0)', title_x=0.5)
fig.write_image('./fig/fig_mean_block_b.pdf')
fig.show()


# (c) the same plot for the r=2 run
x_all, y_all, args, cnf, prior, s_acc, _, _, _ = load_result('block_sourceonly_r2_1')

x = x_all[0].copy()
y = (y_all[0] + 1).astype(str)

z = C.visualize_trajectory_forward(cnf, torch.tensor(x, dtype=torch.float32), 0, 1)
z = z[1]

org = np.array([0, 0 ])
mu1 = prior.gaussians[0].mean.numpy()
mu2 = prior.gaussians[1].mean.numpy()
mu3 = prior.gaussians[2].mean.numpy()
mu4 = prior.gaussians[3].mean.numpy()
mu5 = prior.gaussians[4].mean.numpy()
r = np.vstack([org, mu1])

my_colors = px.colors.qualitative.Plotly

fig = px.scatter(x=z[:,0], y=z[:,1], color=y, opacity=0.1, labels={'x':'', 'y':''})
fig = remove_unnecessary_layout(fig)
fig.add_scattergl(x=[mu1[0]], y=[mu1[1]], mode='markers', legendgroup=10,marker=dict(size=12, color=my_colors[0]), name='mean vector 1')
fig.add_scattergl(x=r[:,0], y=r[:, 1], mode='lines', legendgroup=11, line=dict(color='black', width=3, dash='dot'), name='r') # 'dash', 'dot', 'dashdot'
fig.data = fig.data[::-1]
fig.add_annotation(x=mu1[0]/2, y=mu1[1], text="<b>r</b>", showarrow=False, yshift=12, font=dict(size=22),)

fig.data[0]['showlegend'] = True
fig.data[1]['showlegend'] = True
fig.data[2]['showlegend'] = False
fig.data[3]['showlegend'] = False
fig.data[4]['showlegend'] = False
fig.data[5]['showlegend'] = False
fig.data[6]['showlegend'] = False

fig.update_xaxes(showticklabels=False, showgrid=False, range=[-12.5, 13.5])
fig.update_yaxes(showticklabels=False, showgrid=False, range=[-12.5, 13.5])
fig.update_layout(plt_settings, font=dict(size=22), margin=dict(l=10), legend=dict(y=-0.13,), width=550, height=500)
# fig.update_layout(title_text='Gaussian Mixture (r=3.0)', title_x=0.5)
fig.write_image('./fig/fig_mean_block_c.pdf')
fig.show()

# kNN

In [None]:
# Loss and source/target accuracy (mean/std over seeds) per dataset and kNN k.
DATASET = {'mnist':'Rotating MNIST', 'portraits':'Portraits', 'shift15m':'SHIFT15M', 'rxrx1':'RxRx1',
           'tox21a':'Tox21 NHOHCount', 'tox21b':'Tox21 RingCount', 'tox21c':'Tox21 NumHDonors'}
CONDITIONS = ['k5', "k10", 'k15', "k20", 'k30'] 
SEED = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']

df = []
for d in DATASET:
    for c in CONDITIONS:
        # per-(dataset, k) summaries are cached on disk; delete the pickle to recompute
        file_name = f'./result/knn_summary_{d}_{c}.pkl'
        # print(file_name)
        if Path(file_name).exists():
            res = pd.read_pickle(file_name)
        else:
            res = make_summary([d], [c], SEED)
            pd.to_pickle(res, file_name)
        df.append(res)

df = pd.concat(df)
df = df.rename(columns=dict(v0="data", v1="k", v2="seed"))
# 'k15' -> 15
df['k'] = df['k'].str.replace('k','').astype(int)
df['seed'] = df['seed'].astype(int)
df['data'] = df['data'].map(DATASET)
df = df.groupby(by=['data', 'k'], as_index=False)
# l*/s*/t* = loss / source accuracy / target accuracy; *m = mean, *s = std
df = df.agg(lm=('loss','mean'), ls=('loss','std'),
            sm=('accuracy_s','mean'), ss=('accuracy_s','std'),
            tm=('accuracy_t','mean'), ts=('accuracy_t','std'),)

In [None]:
# Target accuracy vs. k, one facet per dataset (4 facets per row).
facet_col_wrap = 4
# NOTE(review): `positions` is computed but not used in this cell
positions = facet_col_warp_parser(facet_col_wrap, DATASET.keys())
fig = px.scatter(df, x='k', y='tm', facet_col='data', error_y='ts', width=1200, height=600,
                 facet_col_wrap=facet_col_wrap, facet_row_spacing=0.15, 
                 category_orders={'data':DATASET.values()}, labels={'tm':'Accuracy'})
fig.update_layout(plt_settings)
fig = remove_unnecessary_layout(fig)
fig.update_xaxes(showticklabels=True, nticks=10).update_annotations(font=dict(size=22))
fig.write_image('./fig/fig_knn.pdf')
fig.show()

## log-likelihood

In [None]:
# For every dataset and domain, the k in `ks` with the lowest log-likelihood MSE.
DATASET = {'mnist':'Rotating MNIST', 'portraits':'Portraits', 'shift15m':'SHIFT15M', 'rxrx1':'RxRx1', 
           'tox21a':'Tox21 NHOHCount', 'tox21b':'Tox21 RingCount', 'tox21c':'Tox21 NumHDonors'}

df = {}
for d in DATASET:
    ll_result = pd.read_pickle(f'./result/likelihood_{d}.pkl')
    # assumes each entry of ll_result['mse'] holds one error per k in `ks`, in order
    mse = np.array([ks[np.argmin(res)] for res in ll_result['mse']])
    df[d] = mse
df = pd.DataFrame(df).T
# assumes exactly three entries per dataset: source, intermediate, target
df.columns = ['source', 'intermediate', 'target']
df.index = df.index.map(DATASET)
df

# Intermediate Index

In [None]:
def deg2radstr(deg: float) -> str:
    r"""Format an angle in degrees as a LaTeX label of the form ``$\pi/d$``.

    ``d`` is ``180 / deg`` rounded to two decimals, so ``30`` becomes
    ``'$\pi/6.0$'``. An angle of zero is returned as the plain string ``"0"``.

    Parameters
    ----------
    deg : float
        Angle in degrees (anything ``float()`` accepts).

    Returns
    -------
    str
        The LaTeX label.

    Raises
    ------
    ValueError
        If converting ``pi / d`` back to degrees does not reproduce ``deg``
        (e.g. for NaN input).
    """
    deg = float(deg)
    if deg == 0:
        return "0"
    denominator = np.pi / np.deg2rad(deg)
    # Sanity check: pi / denominator must round-trip to the input angle.
    deg_check = np.rad2deg(np.pi / denominator)
    if not np.isclose(deg_check, deg):
        raise ValueError(f"cannot express {deg!r} degrees as pi/d")
    # Raw string: "\p" is not a valid Python escape sequence.
    return r'$\pi/' + str(round(denominator, 2)) + '$'

In [None]:
# Best k (lowest MSE on the `ks` grid) for each intermediate-domain index of
# Rotating MNIST, laid out as a one-row table with index label 'k'.
ll_result = pd.read_pickle('./result/likelihood_mnist_index.pkl')
best_k_per_index = np.array([ks[np.argmin(curve)] for curve in ll_result['mse']])
mse = pd.DataFrame(best_k_per_index[np.newaxis, :], columns=ll_result['index'], index=['k'])
mse

In [None]:
# Effect of which rotation angle is used as the intermediate domain on
# Rotating MNIST. `angles` holds num_inter_domain + 2 evenly spaced angles in
# [0, 60] degrees; CONDITIONS are indices into it.
num_inter_domain = 27
angles = np.linspace(0, 60, num_inter_domain+2)

CONDITIONS = ['4', "8", '12', "14", '16', '20', '24']
SEED = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']
# Raw string: "\;" is not a valid Python escape sequence (SyntaxWarning on 3.12+).
NOINTER_LABEL = r"$w/o \; Intermediate$"

inter = []
for c in CONDITIONS:
    file_name = f'./result/summary_mnist_index_{c}.pkl'
    print(file_name)
    if Path(file_name).exists():
        res = pd.read_pickle(file_name)
    else:
        res = make_summary(['mnist_index'], [c], ['seed'], SEED)
        pd.to_pickle(res, file_name)
    inter.append(res)

inter = pd.concat(inter)
inter = inter.drop(['v0', 'v2'], axis=1).rename(columns={'v1':'index_int', 'v3':'seed'})
# Despite its name, 'index_int' ends up holding the angle in degrees.
inter['index_int'] = inter['index_int'].astype(int)
inter['index_int'] = inter['index_int'].apply(lambda i: angles[i])
inter['index_str'] = inter['index_int'].apply(deg2radstr)

# E_j is computed from the seed-1 run of each index, using the best k from the
# previous cell, and written to every row whose 'fn' contains that run's name
# without the seed number (presumably all seeds of that index).
# NOTE(review): relies on the column order of `mse` matching CONDITIONS — confirm.
for c, k in zip(CONDITIONS, mse.values[0]):
    prefix = '_'.join(['mnist', 'index', c, 'seed'])
    x_all, y_all, args, _, _, _, _, _, _ = load_result(prefix + '_1')
    # Literal substring match; independent of how many digits the seed has.
    idx = inter['fn'].str.contains(prefix, regex=False)
    print(f'index: {c}, k={k}')
    inter.loc[idx, 'Ej'] = compute_Ej(x_all, y_all, k=k)

nointer = make_summary(['mnist_nointer_seed'], SEED)
nointer = nointer.drop(['v0'], axis=1).rename(columns={'v1':'seed'})
nointer['index_str'] = NOINTER_LABEL

df = pd.concat([inter, nointer])
df = df.groupby(by=['index_str'], as_index=False)
df = df.agg(index_int=('index_int','mean'), mean=('accuracy_t','mean'), std=('accuracy_t','std'), Ej=('Ej','mean'))

# Sort rows by intermediate angle, with the no-intermediate baseline last.
index_order = [deg2radstr(angles[int(c)]) for c in CONDITIONS] + [NOINTER_LABEL]
index_order = {label: rank for rank, label in enumerate(index_order)}

df = df.sort_values('index_str', key=lambda col: col.map(index_order))
df.round(3)


In [None]:
# Two stacked panels sharing the x axis (intermediate angle label): E(j=2) on
# top, target accuracy (mean with std error bars) below. The last row of `df`
# is left out.
with_inter = df.iloc[:-1]
x = with_inter['index_str'].values
y_mean = with_inter['mean'].values
y_std = with_inter['std'].values
Ej = with_inter['Ej'].values

line_style = dict(color="#636EFA")
fig, pos = make_subplots_wrapper(rows=2, cols=1, vertical_spacing=0.03, shared_xaxes=True)
(top_row, top_col), (bottom_row, bottom_col) = pos[0], pos[1]
fig.add_scatter(x=x, y=Ej, row=top_row, col=top_col, marker=line_style)
fig.add_scatter(x=x, y=y_mean,
                error_y=dict(type='data', array=y_std, visible=True),
                row=bottom_row, col=bottom_col, marker=line_style)
for row, title in ((1, r'$E(j=2)$'), (2, r'$\text{Accuracy}$')):
    fig.update_yaxes(title_text=title, row=row, col=1)
fig.update_layout(width=850, height=350, showlegend=False, **plt_settings)
fig.show()
fig.write_image('./fig/fig_index_mnist.pdf')

# Comparison with baseline methods

### Box plot

In [None]:
# Settings: method / dataset display names and the runs to load.
METHOD = {'sourceonly':'SourceOnly', 'gst':'GradualSelfTrain', 'goat':'GOAT',
          'gift':'GIFT', 'sgift':'Sequential GIFT', 'aux':'AuxSelfTrain', 'saux':'Sequential AuxSelfTrain', 'eaml':'EAML'}
DATASET = {'mnist':'Rotating MNIST', 'portraits':'Portraits', 'shift15m':'SHIFT15M', 'rxrx1':'RxRx1', 
           'tox21a':'Tox21 NHOHCount', 'tox21b':'Tox21 RingCount', 'tox21c':'Tox21 NumHDonors'}
CONDITIONS = ['seed']
SEED = [str(seed) for seed in range(1, 11)]

# Ours: one cached summary per dataset (built and pickled on first use).
ours_parts = []
for d in DATASET:
    cache_path = f'./result/summary_{d}.pkl'
    if not Path(cache_path).exists():
        res = make_summary([d], CONDITIONS, SEED)
        pd.to_pickle(res, cache_path)
    else:
        res = pd.read_pickle(cache_path)
    ours_parts.append(res)
ours = (
    pd.concat(ours_parts)
      .drop('v1', axis=1)
      .rename(columns={'v0':'data','v2':'seed'})
)
ours['data'] = ours['data'].map(DATASET)
ours['method'] = 'Ours'

# Baselines, with method and dataset keys replaced by display names.
baselines = make_baseline_summary(METHOD, DATASET, SEED).rename(
    columns={'v0':'method', 'v1':'data','v2':'seed'})
baselines['data'] = baselines['data'].map(DATASET)
baselines['method'] = baselines['method'].map(METHOD)

# Ours vs. baselines: one box per method, one facet per dataset.
df = pd.concat([ours, baselines], ignore_index=True)

fig = px.box(df, x='method', y='accuracy_t', color='method', facet_col='data', 
             facet_col_wrap=4, facet_row_spacing=0.1, labels={'accuracy_t':'Accuracy'})
fig.update_layout(plt_settings)
fig.update_layout(width=1000, height=600)
fig = remove_unnecessary_layout(fig)
# Methods are told apart by colour / legend, so the x axis is hidden.
fig.update_xaxes(visible=False)
fig.update_layout(font=dict(family="PTSerif", size=24), legend=dict(y=-0.2, x=0.5))
fig.write_image('./fig/fig_baseline_mr.pdf')
fig.show()

### Discrepancy between domains

In [None]:
# k used for E_j on each dataset.
OPT_K = {'mnist': 5, 'portraits': 5, 'shift15m': 5, 'rxrx1': 10, 'tox21a': 15, 'tox21b': 30, 'tox21c': 30}

nointer_parts, result = [], {}
for d, k in OPT_K.items():
    # Accuracy of runs without intermediate domains (cached summary).
    file_name = f'./result/summary_nointer_{d}.pkl'
    print(file_name)
    if not Path(file_name).exists():
        res = make_summary([d], ['nointer_seed'], SEED)
        pd.to_pickle(res, file_name)
    else:
        res = pd.read_pickle(file_name)
    nointer_parts.append(res)

    # Domain discrepancy E_j (via compute_Ej) from the seed-1 run.
    x_all, y_all, args, _, _, _, _, _, _ = load_result(f'{d}_seed_1')
    clear_output()
    result[DATASET[d]] = compute_Ej(x_all, y_all, k)

nointer = (
    pd.concat(nointer_parts)
      .drop('v1', axis=1)
      .rename(columns={'v0':'data','v2':'seed'})
)
nointer['data'] = nointer['data'].map(DATASET)
nointer['method'] = 'Ours_nointer'

# Per-dataset mean / std of target accuracy with and without intermediates.
ours_agg = ours.groupby('data').agg(with_inter_mean=('accuracy_t','mean'), with_inter_std=('accuracy_t','std'))
nointer_agg = nointer.groupby('data').agg(without_inter_mean=('accuracy_t','mean'), without_inter_std=('accuracy_t','std'))
df = pd.merge(ours_agg, nointer_agg, on='data')
# Relative change in mean accuracy from adding intermediate domains.
df['diff'] = (df['with_inter_mean'] - df['without_inter_mean']) / df['without_inter_mean']
df['Ej'] = df.index.map(result)
df = df.reindex(index=DATASET.values())
df.round(3)

# VAE

In [None]:
file_name = 'mnist_vae_4D_1'

# Load the trained CNF run together with its data and settings.
x_all, y_all, args, cnf, prior, s_acc, t_acc, fig, lh = load_result(file_name)
clear_output()

# Load the trained VAE; x_all entries are decoded with it below, so they are
# presumably VAE latent codes of dimension args.n_dim — confirm.
vae = VAE((1,28,28), args.n_dim)
vae = util.load_model(vae, f'./data/state_mnist_vae_{args.n_dim}D.tar')

# Visualize one randomly chosen sample from each of five classes
# (fixed seed for a reproducible selection).
target_label = [1, 3, 4, 6, 9]
y = y_all[0].copy()
np.random.seed(987)
idx = np.hstack([np.random.choice(np.where(y == label)[0], size=1, replace=False)
                 for label in target_label])

# Decode the selected samples from every entry of x_all.
img_org = np.array([decode_data(vae, x[idx]) for x in x_all])

# Map the first-domain samples with the CNF (last forward step), then run the
# backward trajectory from there; the first two returned steps are dropped.
z = C.visualize_trajectory_forward(cnf, torch.tensor(x_all[0][idx]), 0)[-1]
logpz = prior.log_prob(z)[1]
x_hat = C.visualize_trajectory_backward(cnf, z, logpz, 3, 0.5)
x_hat = x_hat[2:]
_generate = [decode_data(vae, x) for x in x_hat]


def _gapped_row(panels):
    """Place exactly three equally sized images side by side, separated by blank gaps."""
    source, inter, target = panels
    gap = np.zeros_like(source)
    return np.hstack([source, gap, inter, gap, target])


# Originals: one row per class, one column per entry of x_all (expects three).
original = np.vstack([_gapped_row(img_org[:, j, :, :]) for j in range(len(target_label))])
# Generated: one row per class, one column per kept step of x_hat.
generate = np.vstack([np.hstack([step[i] for step in _generate])
                      for i in range(len(target_label))])

img = np.array([original, generate])

# Vertical layout: originals above CNF samples.
fig = px.imshow(img, facet_col=0, facet_col_wrap=1, facet_row_spacing=0.1)
fig.layout['annotations'][1]['text'] = '(a) Originals'
fig.layout['annotations'][1]['x'] = 0.48
fig.layout['annotations'][0]['text'] = '(b) Samples from CNF'
fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)
fig.update_coloraxes(colorscale='gray', showscale=False)
fig.update_layout(plt_settings)
fig.update_layout(margin=dict(t=30, b=10, r=0, l=0), width=200, height=450)
fig.write_image('./fig/fig_vae_vertical.pdf')
fig.show()

# Horizontal layout: originals next to CNF samples.
fig = px.imshow(img, facet_col=0, facet_col_wrap=2, facet_row_spacing=0.1)
fig.layout['annotations'][0]['text'] = '(a) Originals'
fig.layout['annotations'][1]['text'] = '(b) Samples from CNF'
fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)
fig.update_coloraxes(colorscale='gray', showscale=False)
fig.update_layout(plt_settings)
fig.update_layout(margin=dict(t=50, b=10, r=0, l=0), width=800, height=400)
fig.update_annotations(font=dict(family="PTSerif", size=30))
fig.write_image('./fig/fig_vae_horizontal.pdf')
fig.show()