In [None]:
import torch
import numpy as np
import pandas as pd
from itertools import product
import plotly.express as px
from plotly.subplots import make_subplots
from IPython.display import clear_output
from scipy.stats import pearsonr, spearmanr, kendalltau, multivariate_normal

import util
import ffjord as C
from Distributions import GaussianMixtureDA, Gaussian
from FitVAE import VAE, decode_data
from FitKNN import fit_knn_classifier
from RealNVP import RealNVP
import GradualSelfTrain as G
from mainCNF import settings

In [None]:
# Layout options shared by every figure in this notebook: margins, font and a
# horizontal, bordered legend anchored below the plotting area.
plt_settings = dict(margin=dict(t=30, b=30, r=30),
                    font=dict(family="PTSerif", size=18),
                    legend=dict(orientation="h", bordercolor="Black",
                                borderwidth=0.3, yanchor="bottom", y=-0.35, xanchor="center", x=0.5))

# Plotly's qualitative palette converted from hex strings to [r, g, b] lists, so that
# an alpha value can be appended to an entry (see add_var_scatter_plot).
colors = [list(px.colors.hex_to_rgb(_hex)) for _hex in px.colors.qualitative.Plotly]

def make_subplots_wrapper(rows: int, cols:int, **kwargs):
    # NOTE: no docstring here on purpose; __doc__ is copied from make_subplots below.
    figure = make_subplots(rows=rows, cols=cols, **kwargs)
    # One (row, col) pair per subplot, 1-based, listed row by row:
    # (1, 1), (1, 2), ..., (1, cols), (2, 1), ...
    cell_index = np.indices((rows, cols)).reshape(2, -1).T + 1
    return figure, cell_index
make_subplots_wrapper.__doc__ = make_subplots.__doc__



def add_var_scatter_plot(fig, x, y, color, name=None, showlegend=True, **kwargs):
    """
    Add a mean curve with a shaded band of +/- one standard deviation.

    @param
    fig: go.Figure, receives three scatter traces and is returned
    x: x coordinates of the curve
    y: 2-D values; the mean and the sample standard deviation (ddof=1) are taken
       along axis 1 with NaNs ignored
    color: int, index into the module-level ``colors`` palette
           (we prepare 10 colors, you can select the number 0 to 9)
    name: str, the name of plot (used for the mean curve only)
    showlegend: bool, whether the mean curve is listed in the legend
    kwargs: forwarded unchanged to every ``fig.add_scatter`` call
    """
    palette_entry = colors[color]
    line_color = 'rgb' + str(tuple(palette_entry))
    band_color = 'rgba' + str(tuple(palette_entry + [0.3]))  # opacity = 0.3

    center = np.nanmean(y, axis=1)
    spread = np.nanstd(y, ddof=1, axis=1)
    upper = center + spread
    lower = center - spread

    # mean curve
    fig.add_scatter(x=x, y=center, name=name, mode='markers+lines',
                    line=dict(color=line_color), showlegend=showlegend, **kwargs)
    # invisible upper bound; the lower bound is then filled up to it ("tonexty")
    fig.add_scatter(x=x, y=upper, mode='lines', line=dict(width=0),
                    showlegend=False, hoverinfo='none', **kwargs)
    fig.add_scatter(x=x, y=lower, mode='lines', fill="tonexty", line=dict(width=0),
                    showlegend=False, hoverinfo='none', fillcolor=band_color, **kwargs)
    return fig


def add_box_plot(fig, y, color, **kwargs):
    """
    Add a single box trace filled with a palette color and outlined in black.

    @param
    fig: go.Figure, receives the box trace and is returned
    y: values of the box; converted to an array and squeezed before plotting
    color: int, index into the module-level ``colors`` palette
           (we prepare 10 colors, you can select the number 0 to 9)
    kwargs: forwarded unchanged to ``fig.add_box`` (e.g. ``name``)
    """
    fill = 'rgb' + str(tuple(colors[color]))
    values = np.array(y).squeeze()
    outline = dict(color='rgb(0,0,0)')
    fig.add_box(y=values, fillcolor=fill, line=outline, **kwargs)
    return fig


def remove_unnecessary_layout(fig):
    """
    Tidy up a faceted figure in place and return it.

    The legend title is blanked, and every annotation text of the form
    ``"<column>=<value>"`` is reduced to ``"<value>"``.

    Only the first ``=`` is treated as the separator, so values that contain
    ``=`` themselves are kept whole. Annotations without ``=`` are left
    untouched, which makes it safe to call this function more than once on the
    same figure.

    @param
    fig: figure whose ``layout`` exposes ``legend`` and ``annotations``
    """
    # remove the title of legend
    fig.layout['legend']['title']['text'] = ''
    # remove "<column>=" on the title of each column
    for annotation in fig.layout['annotations']:
        annotation['text'] = annotation['text'].split('=', 1)[-1]
    return fig


def facet_col_warp_parser(facet_col_wrap:int, array:list) -> list:
    """
    Return the (row, col) grid position of every element of ``array`` when the
    elements are laid out left to right with ``facet_col_wrap`` items per row.

    Rows are counted from 0 and columns from 1. The callers in this file pass
    the pairs straight to ``fig.add_vline(row=..., col=...)``.
    NOTE(review): the 0-based row presumably relies on how plotly resolves row
    indices of wrapped facets -- confirm before using it with other layouts.

    @param
    facet_col_wrap: int, number of columns per row; values below 1 mean
                    "never wrap" (everything stays in row 0)
    array: any iterable; only the number of elements matters
    """
    positions = []
    for order, _ in enumerate(array):
        if facet_col_wrap < 1:
            row, col = 0, order
        else:
            row, col = divmod(order, facet_col_wrap)
        positions.append((row, col + 1))
    return positions


def load_result(file_name:str):
    """
    Load one trained CNF run: its arguments, data, model, prior and scores.

    Files read (relative to the working directory):
    ``./result/args_{file_name}.pkl``, ``./data/data_{args.dataset}.pkl``,
    ``./result/state_{file_name}.tar`` and ``./result/lh_{file_name}.pkl``.
    NOTE: the pickles are unpickled as-is; only load files you trust.

    @param
    file_name: str, identifier of the run, e.g. 'moon_seed_1'

    @return
    (x_all, y_all, args, cnf, prior, s_acc, t_acc, fig, lh) where
    x_all / y_all are the per-domain data without the last (evaluation) entry,
    s_acc / t_acc are the last items returned by ``C.predict_target`` for the
    first domain and for the evaluation data, and fig / lh come from
    ``util.plot_loss_history``.

    @raise
    ValueError: ``args.base_distribution`` is neither 'normal' nor 'gmm'
    """
    # load args
    args = pd.read_pickle(f'./result/args_{file_name}.pkl')
    # load data
    x_all, y_all = pd.read_pickle(f'./data/data_{args.dataset}.pkl')[args.n_dim]
    # the last entry is held out for evaluation
    x_eval, y_eval = x_all.pop(), y_all.pop()
    if args.dataset == 'mnist_dense':
        # keep only the first domain, the selected intermediate one and index 28
        # (presumably the last domain of the dense data -- TODO confirm)
        given_domain = [0, args.inter_index, 28]
        x_all = [x_all[i].copy() for i in given_domain]
        y_all = [y_all[i].copy() for i in given_domain]
    if args.no_inter:
        # drop every intermediate domain, keep only the two ends
        x_all = [x_all[0].copy(), x_all[-1].copy()]
        y_all = [y_all[0].copy(), y_all[-1].copy()]
    # load model
    cnf = C.build_model_tabular(args, args.n_dim, None)
    cnf = util.load_model(cnf, f'./result/state_{file_name}.tar')
    cnf = util.torch_to(cnf)
    # load base distribution
    if args.base_distribution == 'normal':
        prior = Gaussian(args.n_dim, args.seed,)
    elif args.base_distribution == 'gmm':
        prior = GaussianMixtureDA(args.n_dim, args.n_class, args.seed, args.mean_r)
    else:
        # fail here with a clear message instead of an unbound `prior` below
        raise ValueError(f'unknown base distribution: {args.base_distribution!r}')
    # predict source and target(eval) data
    s_acc = C.predict_target(cnf, prior, x_all[0], y_all[0], 0)[-1]
    # t_acc = C.predict_target(cnf, prior, x_all[-1], y_all[-1], len(x_all)-2)[-1]
    t_acc = C.predict_target(cnf, prior, x_eval, y_eval, len(x_all)-1)[-1]
    # show loss history
    fig, lh = util.plot_loss_history(f'./result/lh_{file_name}.pkl')
    return x_all, y_all, args, cnf, prior, s_acc, t_acc, fig, lh


def make_summary(*settings):
    """
    Summarise the CNF runs for every combination of the given settings.

    Each positional argument is an iterable of strings. The Cartesian product
    of all of them is enumerated, the parts of a combination are joined with
    '_' to form the run's file name, and ``load_result`` is called for it.

    @return
    DataFrame with one row per combination and the columns
    v0, v1, ... (the parts), fn (file name), accuracy_s, accuracy_t and loss
    (minimum over the loss history summed along axis 0).
    The notebook cell output is cleared before returning.
    """

    def _join_row(row):
        return '_'.join(row)

    def _scores_for(file_name):
        loaded = load_result(file_name)
        lowest_loss = np.sum(loaded[-1], axis=0).min()
        # source accuracy, target accuracy, min loss
        return pd.Series([loaded[5], loaded[6], lowest_loss])

    factor_names = [f'v{k}' for k in range(len(settings))]
    summary = pd.DataFrame(list(product(*settings)), columns=factor_names)
    summary['fn'] = summary.apply(_join_row, axis=1)
    summary[['accuracy_s', 'accuracy_t', 'loss']] = summary['fn'].apply(_scores_for)
    clear_output()
    return summary


def make_cv_summary(*settings):
    """
    Summarise stored cross-validation scores for every combination of settings.

    Each positional argument is an iterable of strings. For every combination
    the parts are joined with '_' into a file name ``fn`` and
    ``./result/cv_{fn}.pkl`` is unpickled; the element at index 1 of the loaded
    object is taken as the list of scores of that run.

    @return
    DataFrame with one row per combination and the columns
    v0, v1, ... (the parts), fn, mean and std (row-wise over the scores).
    """

    def _join_row(row):
        return '_'.join(row)

    def _scores_for(file_name):
        stored = pd.read_pickle(f'./result/cv_{file_name}.pkl')
        return pd.Series(stored[1])

    factor_names = [f'v{k}' for k in range(len(settings))]
    table = pd.DataFrame(list(product(*settings)), columns=factor_names)
    table['fn'] = table.apply(_join_row, axis=1)
    scores = table['fn'].apply(_scores_for)
    table['mean'] = scores.mean(axis=1)
    table['std'] = scores.std(axis=1)
    return table


def load_baselne_result(file_name:str):
    """
    Load one baseline run: its arguments, data, model and target accuracy.

    Files read (relative to the working directory):
    ``./result/args_{file_name}.pkl``, ``./data/data_{args.dataset}.pkl`` and
    ``./result/state_{file_name}.tar``.
    NOTE: the pickles are unpickled as-is; only load files you trust.

    @param
    file_name: str, identifier of the run. If it contains 'eaml', no data or
               model is loaded and the stored ``args.accuracy_score`` is
               returned as the accuracy.

    @return
    (x_all, y_all, args, model, t_acc); the first two and ``model`` are None
    for 'eaml' runs. x_all / y_all exclude the last (evaluation) entry.
    """
    # load args
    args = pd.read_pickle(f'./result/args_{file_name}.pkl')
    if 'eaml' in file_name:
        return None, None, args, None, args.accuracy_score
    # load data
    x_all, y_all = pd.read_pickle(f'./data/data_{args.dataset}.pkl')[args.n_dim]
    # the last entry is held out for evaluation
    x_eval, y_eval = x_all.pop(), y_all.pop()
    # load model
    state_path = f'./result/state_{file_name}.tar'
    try:
        model = G.MLP(num_labels=args.n_class, input_dim=args.n_dim, hidden_dim=args.hidden_dim)
        model = util.load_model(model, state_path)
    except Exception:
        # Best-effort fallback: rebuild the MLP with the number of labels read
        # from the first dimension of the last tensor in the checkpoint.
        # `Exception` (not a bare except) keeps Ctrl-C / SystemExit working.
        print(f'param mismatch: {file_name}')
        map_location = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
        param = torch.load(state_path, map_location=map_location)
        n_labels = param[list(param.keys())[-1]].size()[0]
        model = G.MLP(num_labels=n_labels, input_dim=args.n_dim, hidden_dim=args.hidden_dim)
        model = util.load_model(model, state_path)
    # predict target(eval) data
    t_acc = G.calc_accuracy(model, x_eval, y_eval)
    return x_all, y_all, args, model, t_acc


def make_baseline_summary(*settings):
    """
    Collect the target accuracy of the baseline runs for every combination of
    the given settings.

    Each positional argument is an iterable of strings. For every combination
    the parts are joined with '_' into a file name ``fn`` and the last item
    returned by ``load_baselne_result(fn)`` is stored as the accuracy.

    @return
    DataFrame with one row per combination and the columns
    v0, v1, ... (the parts), fn and accuracy_t.
    """

    def _join_row(row):
        return '_'.join(row)

    def _target_accuracy(file_name):
        return load_baselne_result(file_name)[-1]

    factor_names = [f'v{k}' for k in range(len(settings))]
    table = pd.DataFrame(list(product(*settings)), columns=factor_names)
    table['fn'] = table.apply(_join_row, axis=1)
    table['accuracy_t'] = table['fn'].apply(_target_accuracy)
    return table


def correlation(array1, array2):
    """ return (pearsonr, spearmanr, kendalltau)

    Only the coefficient (first element of each scipy result) is kept;
    the p-values are discarded.
    """
    measures = (pearsonr, spearmanr, kendalltau)
    return tuple(measure(array1, array2)[0] for measure in measures)

# DNF vs. CNF

In [None]:
# Transform the intermediate-domain samples with the trained CNF and with the
# trained RealNVP (DNF), keeping the intermediate representations for plotting.
file_name = 'moon_seed_1'

# CNF
x_all, y_all, args, cnf, prior = load_result(file_name)[:5]
inter = torch.tensor(x_all[1])  # second domain in x_all (plotted as 'Intermediate' below)
c_convert = C.visualize_trajectory_forward(cnf, inter, 1, .1)
# keep four snapshots of the trajectory returned above
subset = [0, 2, 6, 10]
c_convert = [c_convert[i].numpy() for i in subset]  # time -> 2.0, 1.8, 1.4, 1.0

# DNF
# args.dims is stored as a '-'-separated string; it is overwritten with a list of ints
args.dims = [int(i) for i in args.dims.split('-')]
dnf = RealNVP(n_flows=len(args.dims), data_dim=args.n_dim, n_hidden=args.dims[0])
dnf = util.load_model(dnf, f'./result/state_realnvp_{file_name}.tar')
# the forward output is not needed; the call is made so that dnf.inter_repr can be read
# (presumably populated by forward() -- TODO confirm in RealNVP)
_ = dnf.forward(inter)
d_convert = dnf.inter_repr

clear_output()

In [None]:
# 3x4 grid: row 1 = ground truth, row 2 = CNF snapshots, row 3 = DNF representations.
# Empty titles belong to the two ground-truth cells that stay empty.
titles = ['Intermediate', '', '', 'Source', 't=2.0', 't=1.8', 't=1.4', 't=1.0', 'Intermediate', 'Step 1', 'Step 2', 'Step 3']
fig, pos = make_subplots_wrapper(rows=3, cols=4, horizontal_spacing=0.01, vertical_spacing=0.08, subplot_titles=titles)

# plot Ground Truth
# x_all[1] is drawn in the first cell and x_all[0] in the fourth cell of row 1
for i in reversed(range(2)):
    x, y = x_all[i].copy(), y_all[i].copy()
    color_plt = px.scatter(x=x[:,0], y=x[:,1], color=y.astype(str))
    p = 3 if i == 0 else 0
    fig.add_traces(color_plt.data, rows=pos[p][0], cols=pos[p][1])
    
# every transformed point keeps the label of the intermediate sample it came from
inter_y = y_all[1].copy().astype(str)
# plot CNF
for i, p in enumerate(pos[4:8]):
    x_hat = c_convert[i]
    color_plt = px.scatter(x=x_hat[:,0], y=x_hat[:,1], color=inter_y)
    fig.add_traces(color_plt.data, rows=p[0], cols=p[1])

# plot DNF
for i, p in enumerate(pos[8:]):
    x_hat = d_convert[i]
    color_plt = px.scatter(x=x_hat[:,0], y=x_hat[:,1], color=inter_y)
    fig.add_traces(color_plt.data, rows=p[0], cols=p[1])

# set layout
fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False)
# the y-axis title of the first column is used as the row label
fig.update_yaxes(title_text='Ground Truth', row=1, col=1)
fig.update_yaxes(title_text='CNF', row=2, col=1)
fig.update_yaxes(title_text='DNF', row=3, col=1)
fig.update_layout(width=750, height=550, showlegend=False, **plt_settings)
fig.update_annotations(dict(font={'family': 'PTSerif', 'size': 21}))
fig.show()
fig.write_image('./fig/fig_realnvp.pdf')

# w/o intermediate

In [None]:
# Compare the CNF trajectories of a run trained with intermediate domains ('seed')
# and a run trained without them ('nointer') against the ground-truth domains.
DATA = 'block'
CONDITIONS = ['seed', 'nointer']
DEBUG = False

# Per-dataset subplot titles. `time` labels the row of the 'seed' run and
# `time_no_inter` the row of the 'nointer' run.
if DATA == 'moon':
    titles = ['Target', '', 'Intermediate', '', 'Source']
    time = ['t=3.0', 't=2.5', 't=2.0', 't=1.5', 't=1.0']
    time_no_inter = ['t=2.0', 't=1.7', 't=1.5', 't=1.2', 't=1.0']
    share_axes = False
elif DATA == 'block':
    titles = ['Target', 'Intermediate1', 'Intermediate2', 'Source']
    time = ['t=4.0', 't=3.0', 't=2.0', 't=1.0']
    time_no_inter = ['t=2.0', 't=1.6', 't=1.4', 't=1.0']
    share_axes = True
    
# indices of the trajectory snapshots to draw, keyed by run name
subset = {'moon_seed_1': [0, 5, 10, 15, 20],
          'moon_nointer_1': [0, 3, 5, 8, 10],
          'block_seed_1': [0, 10, 20, 30],
          'block_nointer_1': [0, 4, 6, 10]}


titles = titles + time + time_no_inter
fig, pos = make_subplots_wrapper(rows=len(CONDITIONS)+1, cols=len(time), horizontal_spacing=0.01, vertical_spacing=0.06,
                                 subplot_titles=titles, shared_xaxes=share_axes, shared_yaxes=share_axes)

# plot Ground Truth
i = 0
file_name = f'{DATA}_seed_1'
x_all, y_all, args, cnf, prior, _, t_acc, _, _ = load_result(file_name)
# domains are drawn in reverse order (last domain first); for 'moon' every second
# cell is skipped, matching the empty titles above
for x, y in zip(x_all[::-1], y_all[::-1]):
    color_plt = px.scatter(x=x[:,0], y=x[:,1], color=y.astype(str))
    fig.add_traces(color_plt.data, rows=pos[i][0], cols=pos[i][1])
    if DATA == 'moon':
        i += 2
    elif DATA == 'block':
        i += 1

# plot CNF
# first cell of the second row (5 columns for 'moon', 4 for 'block')
i = 5 if DATA == 'moon' else 4
for c in CONDITIONS:
    file_name = f'{DATA}_{c}_1'
    # x_all, y_all, args, cnf, prior = load_result(file_name)[:5]
    x_all, y_all, args, cnf, prior, _, t_acc, _, _ = load_result(file_name)
    # x_eval, y_eval = x_all.pop(), y_all.pop()
    print(f'{file_name}, {t_acc}')
    target_x = torch.tensor(x_all[-1], dtype=torch.float32)
    target_y = y_all[-1].copy().astype(str)
    # index of the last domain: 1 without intermediate domains, len(x_all)-1 otherwise
    t0 = 1 if 'nointer' in file_name else len(x_all)-1
    c_convert = C.visualize_trajectory_forward(cnf, target_x, t0, .1)
    # the comprehension's `i` is local to it and does not touch the cell counter `i`
    c_convert = [c_convert[i].numpy() for i in subset[file_name]]
    
    if DEBUG:
        # print the time stamps that correspond to the selected snapshots
        time_for_debug = np.arange(0, t0+1+0.1, 0.1)[::-1]
        print(time_for_debug[subset[file_name]])
    else:
        clear_output()

    for x_hat in c_convert:
        color_plt = px.scatter(x=x_hat[:,0], y=x_hat[:,1], color=target_y)
        fig.add_traces(color_plt.data, rows=pos[i][0], cols=pos[i][1])
        i += 1

# layout settings
fig.update_layout(width=800, height=700, showlegend=False, **plt_settings)
fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False)
# the y-axis title of the first column is used as the row label
fig.update_yaxes(title_text='Ground Truth', row=1, col=1)
fig.update_yaxes(title_text='w/ Intermediate', row=2, col=1)
fig.update_yaxes(title_text='w/o Intermediate', row=3, col=1)
fig.update_annotations(dict(font={'family': 'PTSerif', 'size': 22}))
fig.show()
fig.write_image(f'./fig/fig_no_inter_{DATA}.pdf')

# Intermediate Index

In [None]:
def deg2radstr(deg:float):
    r"""
    Format an angle given in degrees as a LaTeX fraction of pi.

    0 is returned as the plain string "0". Any other angle is returned as
    ``$\pi/d$`` where ``d`` is 180/deg rounded to two decimals,
    e.g. 60 -> ``$\pi/3.0$``.

    @param
    deg: angle in degrees; anything accepted by ``float()``

    @raise
    ValueError: the angle does not survive the degree -> fraction -> degree
                round trip (e.g. NaN)
    """
    deg = float(deg)
    if deg == 0:
        return "0"
    denominator = np.pi / np.deg2rad(deg)
    # sanity check: converting the fraction back must give the input angle
    deg_check = np.rad2deg(np.pi / denominator)
    if not np.isclose(deg_check, deg):
        raise ValueError(f'cannot express {deg!r} degrees as a fraction of pi')
    # raw string: "\p" is not a valid escape sequence in a normal literal
    return r'$\pi/' + str(round(denominator, 2)) + '$'

In [None]:
# Rotation angle of every domain: the two end points plus the intermediate
# domains, evenly spaced between 0 and 60 degrees.
num_inter_domain = 27
angles = np.linspace(0, 60, num_inter_domain+2)

# indices of the domain that was given as the single intermediate domain
CONDITIONS = ['4', "8", '12', "14", '16', '20', '24']
SEED = ['1', '2', '3', '4', '5']

# runs 'mnist_index_{index}_seed_{seed}'
inter = make_summary(['mnist_index'], CONDITIONS, ['seed'], SEED)
inter = inter.drop(['v0', 'v2'], axis=1).rename(columns={'v1':'index', 'v3':'seed'})
# domain index -> rotation angle in degrees -> LaTeX label
inter['index'] = inter['index'].astype(int)
inter['index'] = inter['index'].apply(lambda i: angles[i])
inter['index'] = inter['index'].apply(deg2radstr)

# runs 'mnist_nointer_seed_{seed}' (trained without an intermediate domain)
nointer = make_summary(['mnist_nointer_seed'], SEED)
nointer = nointer.drop(['v0'], axis=1).rename(columns={'v1':'seed'})
# raw string: "\;" is not a valid escape sequence in a normal literal
nointer['index'] = r"$w/o \; Intermediate$"

In [None]:
# Mean and standard deviation of the target accuracy per intermediate-domain angle.
df = inter.copy()
df = df.groupby(by=['index'], as_index=False)
df = df.agg(mean=('accuracy_t','mean'), std=('accuracy_t','std'),)

# x-axis categories in the order of CONDITIONS (the labels are strings, so the
# order has to be given explicitly)
index_order = [deg2radstr(angles[int(c)]) for c in CONDITIONS]

fig = px.scatter(df, x='index', y='mean', error_y='std',
                 labels={'mean':'Accuracy', 'index':'Rotation angle of intermediate domain'},
                 category_orders={'index':index_order},)
fig.update_layout(plt_settings).update_layout(width=700, height=300)
fig.update_layout(yaxis=dict(range=(0.55, 1.0)))
fig.write_image('./fig/fig_index_mnist.pdf')

In [None]:
# Compare the intermediate-domain setting with the lowest mean accuracy against
# the runs trained without an intermediate domain.
# `df` is the per-angle summary computed in the previous cell.
worst = df.set_index('index')['mean'].idxmin()
print(worst)
df = pd.concat([inter.query('index==@worst'), nointer])
df = df.groupby(by=['index'], as_index=False)
df = df.agg(mean=('accuracy_t','mean'), std=('accuracy_t','std'),)
# df.loc[0, 'index'] = "$w \; Intermediate(\pi/3.5)$"
df = df.sort_index()

fig = px.scatter(df, x='index', y='mean', error_y='std',
                 labels={'mean':'Accuracy', 'index':''},)
fig.update_layout(plt_settings)
fig.update_layout(width=400, height=300)
# x range widened by half a unit on both sides of the two categories
fig.update_layout(xaxis=dict(range=(-0.5, 1.5)), yaxis=dict(range=(0.55, 1.0)))
fig.write_image('./fig/fig_no_inter_mnist.pdf')

# Gaussian Mixture estimators for the log-likelihood

In [None]:
# Collect five trajectory snapshots of the third domain for two runs.
# In the next cell the first run is drawn on the row labelled 'kNN' and the
# second run on the row labelled 'fitted GM'.
convert = []
for file_name in ['moon_seed_1', 'moon_q30_1']:
    x_all, y_all, args, cnf, prior = load_result(file_name)[:5]
    target = torch.tensor(x_all[2])
    z_all = C.visualize_trajectory_forward(cnf, target, 2, .5)
    convert.append(z_all[:5]) # time -> 3.0, 2.5, 2.0, 1.5, 1.0
# flatten the list of two lists into one list of ten snapshots
convert = sum(convert, [])
clear_output()

In [None]:
# 3x5 grid: row 1 = ground truth, rows 2 and 3 = trajectory snapshots of the two runs.
titles = ['Target', '', 'Intermediate', '', 'Source',
          't=3.0', 't=2.5', 't=2.0', 't=1.5', 't=1.0',
          't=3.0', 't=2.5', 't=2.0', 't=1.5', 't=1.0']
fig, pos = make_subplots_wrapper(rows=3, cols=5, horizontal_spacing=0.01, vertical_spacing=0.05, subplot_titles=titles)

# plot Ground Truth
# x_all / y_all are the ones left over from the last run loaded in the previous cell;
# domains 2, 1, 0 go to the first, third and fifth cell of row 1
p = 0
for i in reversed(range(3)):
    x, y = x_all[i].copy(), y_all[i].copy()
    color_plt = px.scatter(x=x[:,0], y=x[:,1], color=y.astype(str))
    fig.add_traces(color_plt.data, rows=pos[p][0], cols=pos[p][1])
    p += 2

# plot kNN and GMM
# skip the five ground-truth cells; the ten snapshots fill rows 2 and 3 in order
pos = pos[5:]
target_y = y_all[2].copy().astype(str)
for p, x_hat in zip(pos, convert):
    color_plt = px.scatter(x=x_hat[:,0], y=x_hat[:,1], color=target_y)
    fig.add_traces(color_plt.data, rows=p[0], cols=p[1])

# set layout
fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False)
# the y-axis title of the first column is used as the row label
fig.update_yaxes(title_text='Ground Truth', row=1, col=1)
fig.update_yaxes(title_text='kNN', row=2, col=1)
fig.update_yaxes(title_text='fitted GM', row=3, col=1)
fig.update_layout(width=850, height=550, showlegend=False, **plt_settings)
fig.update_annotations(dict(font={'family': 'PTSerif', 'size': 22}))
fig.show()
fig.write_image('./fig/fig_fitted_gmm.pdf')

# OT vs. NF

In [None]:
# Build two sequences of five point clouds (target -> source): one interpolated
# by the trained CNF and one generated by G.generate_domains (labelled 'OT' in the plot).
from datasets2 import make_gradual_data

x_all, y_all = make_gradual_data()
# the last entry is removed and kept as evaluation data (not used further in this cell)
x_eval, y_eval = x_all.pop(), y_all.pop()

# generate by NF
_, _, args, cnf, prior = load_result('moon_seed_1')[:5]
target = torch.tensor(x_all[2])
nf = C.visualize_trajectory_forward(cnf, target, 2, .5)
nf = nf[:5] # time -> 3.0, 2.5, 2.0, 1.5, 1.0
nf = [x.numpy() for x in nf]
# set ground truth
# positions 0, 2 and 4 are replaced by the observed domains; only positions 1 and 3
# remain generated by the flow
nf[0] = x_all[2].copy()
nf[2] = x_all[1].copy()
nf[4] = x_all[0].copy()

# generate by OT
# start from the first domain and append what generate_domains returns for each
# pair of consecutive domains; labels are passed for the first pair only
ot = [x_all[0].copy()]
for i in range(len(x_all))[:-1]:
    xs, xt = x_all[i].copy(), x_all[i+1].copy()
    ys = y_all[0].copy() if i==0 else None
    ot += G.generate_domains(1, xs, xt, ys)
# reverse so that the order matches `nf` (last domain first)
ot = ot[::-1]

clear_output()

In [None]:
# 2x5 grid: row 1 = CNF sequence, row 2 = OT sequence. Only five titles are given,
# so they label the first row.
titles = ['Target', 'Generated', 'Intermediate', 'Generated', 'Source',]
fig, pos = make_subplots_wrapper(rows=2, cols=5, horizontal_spacing=0.01, vertical_spacing=0.05, subplot_titles=titles)

# plot NF
for p, x in zip(pos[:5], nf):
    color_plt = px.scatter(x=x[:,0], y=x[:,1])
    fig.add_traces(color_plt.data, rows=p[0], cols=p[1])

# plot ot
for p, x in zip(pos[5:], ot):
    color_plt = px.scatter(x=x[:,0], y=x[:,1])
    fig.add_traces(color_plt.data, rows=p[0], cols=p[1])

    
# set layout
fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False)
# the y-axis title of the first column is used as the row label
fig.update_yaxes(title_text='CNF', row=1, col=1)
fig.update_yaxes(title_text='OT', row=2, col=1)
fig.update_layout(width=1200, height=500, showlegend=False, **plt_settings)
fig.update_annotations(dict(font={'size':22}))
fig.write_image('./fig/fig_ot.pdf')
fig.show()


# UMAP

In [None]:
# Cross-validation accuracy against the number of embedding dimensions, one facet
# per dataset. DATASET maps the key used in file names to the displayed name.
DATASET = {'mnist':'Rotating MNIST', 'portraits2':'Portraits', 'shift15m':'SHIFT15M', 'rxrx1':'RxRx1',
           'tox21a':'Tox21 NHOHCount', 'tox21b':'Tox21 RingCount', 'tox21c':'Tox21 NumHDonors',}
CONDITIONS = ["2d", "4d", "8d", "16d", "32d"]

# reads ./result/cv_{data}_{dim}.pkl for every combination
df = make_cv_summary(DATASET.keys(), CONDITIONS)
df = df.rename(columns=dict(v0='data', v1='dim'))
# '16d' -> 16
df['dim'] = df['dim'].str.replace('d', '').astype(int)
df['data'] = df['data'].map(DATASET)

# print, per dataset, the dimension with the highest mean accuracy
for d in DATASET.values():
    subset = df.query('data== @d')
    idx = subset.set_index('dim')['mean'].idxmax()
    print(f'{d}, {idx}')

fig = px.scatter(df, x='dim', y='mean', error_y='std', facet_col='data', facet_col_wrap=4, width=1500, height=800,
                 labels={'mean':'Accuracy', 'dim':'Number of Embedding Dimension'})
fig.update_layout(plt_settings)
fig = remove_unnecessary_layout(fig)
# show tick labels on every facet
fig.update_xaxes(showticklabels=True)
fig.write_image('./fig/fig_umap_dim.pdf')
fig.show()

# mean_r

## theoretical

In [None]:
def get_probability(n_dims:int, n_labels:int, mean_r:int):
    """
    Density at the midpoint between the first two component means.

    The means are taken from ``GaussianMixtureDA(...).get_means(mean_r=mean_r)``.
    Two identity-covariance normal distributions are centred on the first and
    on the second mean, both are evaluated at the point halfway between the two
    means, and the larger density is returned.

    @param
    n_dims: int, dimension of the space
    n_labels: int, number of mixture components
    mean_r: passed on as ``mean_r``; the cells below also call this with
            non-integer values
    """
    mixture = GaussianMixtureDA(n_dims=n_dims, n_labels=n_labels, mean_r=mean_r)
    centers = mixture.get_means(mean_r=mean_r).numpy()
    midpoint = np.mean(centers[:2], axis=0)
    identity = np.eye(n_dims)
    density_first = multivariate_normal(centers[0], identity).pdf(midpoint)
    density_second = multivariate_normal(centers[1], identity).pdf(midpoint)
    return max(density_first, density_second)

def draw_circle(mean, radius):
    """
    Return the x and y coordinates of 999 points on a circle.

    The angles run from 0 towards 2*pi in 1000 evenly spaced steps with the
    final one (2*pi itself, a duplicate of 0) dropped.

    @param
    mean: centre of the circle; ``mean[0]`` is x and ``mean[1]`` is y
    radius: radius of the circle
    """
    center_x, center_y = mean[0], mean[1]
    angles = np.linspace(0, 2*np.pi, 1000)[:-1]
    return center_x + radius * np.cos(angles), center_y + radius * np.sin(angles)

In [None]:
# Find the smallest r for which get_probability drops below alpha, for one
# (n_dims, n_labels) setting, and plot the probability against r.
n_dims = 4
n_labels = 2
alpha = 1e-3
rs = np.arange(1, 21 ,1)

probs = np.array([get_probability(n_dims, n_labels, mean_r=r) for r in rs])
# NOTE(review): .min() raises on an empty selection, i.e. when no r in `rs`
# reaches a probability below alpha
min_r = rs[probs < alpha].min()
print(min_r)

fig = px.scatter(x=rs, y=probs, labels={'x':'r', 'y':'probability'}, width=500, height=450)
# dashed horizontal line at the threshold
fig.add_hline(y=alpha, line_width=2, line_dash="dash", line_color="LightSeaGreen")
fig.show()

# In 2-D additionally draw samples of the prior with one circle per component
# mean; the radius is half the distance between the first two means.
if n_dims == 2:
    prior = GaussianMixtureDA(n_dims=n_dims, n_labels=n_labels, mean_r=min_r)    
    means = prior.get_means(mean_r=min_r).numpy()
    radius = np.linalg.norm(means[1]-means[0], ord=2) / 2
    x, _ = prior.sample(total_size=n_labels*1000)
    # colour labels: 1000 consecutive samples per component
    # (assumes sample() returns the samples grouped by component -- TODO confirm)
    color = np.hstack([i*np.ones(1000, dtype=int) for i in range(n_labels)])
    
    fig = px.scatter(x=x[:,0], y=x[:,1], color=color.astype(str), labels={'x':'', 'y':''}, width=500, height=450)
    for mean in means:
        # the circle shape is specified by the bounding box of the circle points
        x, y = draw_circle(mean, radius)
        fig.add_shape(type="circle", xref="x", yref="y", x0=min(x), y0=min(y), x1=max(x), y1=max(y), line_color="LightSeaGreen",)
    fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False)
    fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False)
    fig.show()

In [None]:
# U(r) against r for several numbers of classes at a fixed number of dimensions.
# `include_title` is also read by the next cell.
include_title = False
n_dims = 2
n_labels = [2, 5, 10]
rs = np.arange(1, 15, 0.5)

probs = []
for nl in n_labels:
    p = np.array([get_probability(n_dims=n_dims, n_labels=nl, mean_r=r) for r in rs])
    probs.append(p)
# one column per number of classes, indexed by r
probs = pd.DataFrame(np.array(probs).T, columns=n_labels, index=rs)

title = '(a) Number of Dimensions is Fixed' if include_title else None

fig1 = px.scatter(probs, labels={'value':'U(r)', 'index':'r'}, title=title)
fig1.update_layout(plt_settings, legend_title_text='number of classes', title_x=0.5, xaxis=dict(dtick=3), yaxis=dict(dtick=0.05),
                   width=550, height=500, margin=dict(t=50))
# NOTE(review): the active file name ends in '_dimension' while the commented-out one
# ends in '_class'; the next cell has the opposite pairing -- confirm which is intended.
# fig1.write_image('./fig/fig_mean_r_class.png')
fig1.write_image('./fig/fig_mean_r_dimension.pdf')
fig1.show()

In [None]:
# U(r) against r for several numbers of dimensions at a fixed number of classes.
# Relies on `include_title` defined in the previous cell.
n_dims = [2, 4, 8]
n_labels = 10
rs = np.arange(1, 15, 0.5)

probs = []
for nd in n_dims:
    p = np.array([get_probability(n_dims=nd, n_labels=n_labels, mean_r=r) for r in rs])
    probs.append(p)
# one column per number of dimensions, indexed by r
probs = pd.DataFrame(np.array(probs).T, columns=n_dims, index=rs)

title = '(b) Number of Classes is Fixed' if include_title else None

fig2 = px.scatter(probs, labels={'value':'U(r)', 'index':'r'}, title=title)
fig2.update_layout(plt_settings, legend_title_text='number of dimensions', title_x=0.5, xaxis=dict(dtick=3), yaxis=dict(dtick=0.05),
                   width=550, height=500, margin=dict(t=50))
# NOTE(review): the active file name ends in '_class' while the commented-out one
# ends in '_dimension'; the previous cell has the opposite pairing -- confirm which is intended.
# fig2.write_image('./fig/fig_mean_r_dimension.png')
fig2.write_image('./fig/fig_mean_r_class.pdf')
fig2.show()

## experimental

In [None]:
# Cross-validation accuracy for several values of r, per dataset.
# DATASET maps the key used in file names to the displayed name.
DATASET = {'mnist':'Rotating MNIST', 'portraits':'Portraits', 'shift15m':'SHIFT15M', 'rxrx1':'RxRx1',
           'tox21a':'Tox21 NHOHCount', 'tox21b':'Tox21 RingCount', 'tox21c':'Tox21 NumHDonors'}
CONDITIONS = ['r0.5', 'r1', 'r3', 'r5', 'r8', 'r10', 'r15','r20']

# iterating the dict yields its keys; reads ./result/cv_{data}_{r}.pkl
df = make_cv_summary(DATASET, CONDITIONS)
df = df.rename(columns=dict(v0='data', v1='r'))
# 'r0.5' -> 0.5
df['r'] = df['r'].str.replace('r','').astype(float)
df['data'] = df['data'].map(DATASET)

In [None]:
# For every dataset, the smallest integer r in 1..20 whose probability (see
# get_probability) is below alpha. `settings[name]` is unpacked as (n_dims, n_labels).
Ur = []
alpha = 1e-3
rs = np.arange(1, 21 ,1)
for name in DATASET.keys():
    n_dims, n_labels = settings[name]
    probs = np.array([get_probability(n_dims, n_labels, mean_r=r) for r in rs])
    # NOTE(review): .min() raises on an empty selection (no r below alpha)
    min_r = rs[probs < alpha].min()
    Ur.append(min_r)

In [None]:
# Accuracy against r, one facet per dataset, with a dashed red line at the
# threshold r computed in the previous cell.
facet_col_wrap = 4
# (row, col) of each dataset's facet, in the order of DATASET
positions = facet_col_warp_parser(facet_col_wrap, DATASET.keys())
fig = px.scatter(df, x='r', y='mean', facet_col='data', error_y='std', width=1200, height=600,
                 facet_col_wrap=facet_col_wrap, facet_row_spacing=0.15,
                 category_orders={'data':DATASET.values()}, labels={'mean':'Accuracy'})
for u, pos in zip(Ur, positions):
    fig.add_vline(u, row=pos[0], col=pos[1], line_width=1, line_dash="dash", line_color="red")
fig.update_layout(plt_settings)
fig = remove_unnecessary_layout(fig)
fig.update_xaxes(showticklabels=True).update_annotations(font=dict(size=22))
fig.write_image('./fig/fig_mean_r.pdf')
fig.show()

# kNN

In [None]:
# Loss and accuracies of the CNF runs for several values of k, aggregated over seeds,
# plus the n_neighbors selected by fit_knn_classifier for each dataset.
DATASET = {#'moon':'Two Moon', 'block':'Block',
           'mnist':'Rotating MNIST', 'portraits':'Portraits', 'shift15m':'SHIFT15M', 'rxrx1':'RxRx1',
           'tox21a':'Tox21 NHOHCount', 'tox21b':'Tox21 RingCount', 'tox21c':'Tox21 NumHDonors'}
CONDITIONS = ['k5', "k10", 'k15', "k20", 'k30'] 
SEED = ['1', '2', '3']

# runs '{data}_{k}_{seed}'
df = make_summary(DATASET, CONDITIONS, SEED)
df = df.rename(columns=dict(v0="data", v1="k", v2="seed"))
# 'k10' -> 10
df['k'] = df['k'].str.replace('k','').astype(int)
df['seed'] = df['seed'].astype(int)
df['data'] = df['data'].map(DATASET)
df = df.groupby(by=['data', 'k'], as_index=False)
# l* = loss, s* = source accuracy, t* = target accuracy; *m = mean, *s = std over seeds
df = df.agg(lm=('loss','mean'), ls=('loss','std'),
            sm=('accuracy_s','mean'), ss=('accuracy_s','std'),
            tm=('accuracy_t','mean'), ts=('accuracy_t','std'),)

# best n_neighbors per dataset, in the order of DATASET
cv = []
for key, name in DATASET.items():
    res = fit_knn_classifier(key)
    cv.append(res.best_params_['n_neighbors'])

In [None]:
# Target accuracy against k, one facet per dataset, with a dashed red line at the
# n_neighbors collected in `cv` in the previous cell.
facet_col_wrap = 4
# (row, col) of each dataset's facet, in the order of DATASET
positions = facet_col_warp_parser(facet_col_wrap, DATASET.keys())
fig = px.scatter(df, x='k', y='tm', facet_col='data', error_y='ts', width=1200, height=600,
                 facet_col_wrap=facet_col_wrap, facet_row_spacing=0.15, 
                 category_orders={'data':DATASET.values()}, labels={'tm':'Accuracy'})
for x, pos in zip(cv, positions):
    fig.add_vline(x, row=pos[0], col=pos[1], line_width=1, line_dash="dash", line_color="red")
fig.update_layout(plt_settings)
fig = remove_unnecessary_layout(fig)
fig.update_xaxes(showticklabels=True, nticks=10).update_annotations(font=dict(size=22))
fig.write_image('./fig/fig_knn.pdf')
fig.show()

# Comparison with baseline methods

## GST, GOAT

In [None]:
# Settings
METHOD = {'sourceonly':'SourceOnly', 'gst':'GradualSelfTrain', 'goat':'GOAT'}
DATASET = {'mnist':'Rotating MNIST', 'portraits':'Portraits', 'shift15m':'SHIFT15M', 'rxrx1':'RxRx1', 'tox21a':'Tox21'}
CONDITIONS = ['seed']
SEED = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']

# Ours
ours = make_summary(DATASET, CONDITIONS, SEED)
ours = ours.drop('v1', axis=1).rename(columns={'v0':'data','v2':'seed'})
ours['data'] = ours['data'].map(DATASET)
ours['method'] = 'Ours'

# Baselines
baselines = make_baseline_summary(METHOD, DATASET, SEED)
baselines = baselines.rename(columns={'v0':'method', 'v1':'data','v2':'seed'})
baselines['data'] = baselines['data'].map(DATASET)
baselines['method'] = baselines['method'].map(METHOD)

# Ours vs. Baselines
df = pd.concat([ours, baselines], ignore_index=True)

In [None]:
# Box plot of the target accuracy per method, one facet per dataset.
# First variant: all facets in a single row. Saving is currently commented out.
fig = px.box(df, x='method', y='accuracy_t', color='method', facet_col='data', facet_col_wrap=5, labels={'accuracy_t':'Accuracy'})
fig.update_layout(plt_settings).update_layout(width=1200, height=300)
fig = remove_unnecessary_layout(fig)
# the methods are identified by colour/legend, so the x axis is hidden
fig.update_xaxes(visible=False)
fig.update_layout(dict(legend={"y":-0.25}))
# fig.write_image('./fig/fig_baseline.pdf')
fig.show()


# Second variant: three facets per row, vertical legend and larger font.
fig = px.box(df, x='method', y='accuracy_t', color='method', facet_col='data', facet_col_wrap=3, labels={'accuracy_t':'Accuracy'})
fig.update_layout(plt_settings).update_layout(width=800, height=500)
fig = remove_unnecessary_layout(fig)
fig.update_xaxes(visible=False)
fig.update_layout(dict(legend={"orientation":"v", "x": 0.85, "y":0.1}, margin=dict(t=35, b=20), height=520, font=dict(family="PTSerif", size=24)))
# fig.write_image('./fig/fig_baseline_mr.pdf')
fig.show()

## GIFT, AuxSelfTrain

In [None]:
# Settings
METHOD = {'gift':'GIFT', 'sgift':'Sequential GIFT', 'aux':'AuxSelfTrain', 'saux':'Sequential AuxSelfTrain'}

# Baselines
baselines = make_baseline_summary(METHOD, DATASET, SEED)
baselines = baselines.rename(columns={'v0':'method', 'v1':'data','v2':'seed'})
baselines['data'] = baselines['data'].map(DATASET)
baselines['method'] = baselines['method'].map(METHOD)

# Ours vs. Baselines
df = pd.concat([ours, baselines], ignore_index=True)

In [None]:
# Box plot of the target accuracy per method, all dataset facets in a single row.
# Saving is currently commented out.
fig = px.box(df, x='method', y='accuracy_t', color='method', facet_col='data', facet_col_wrap=5, labels={'accuracy_t':'Accuracy'})
fig.update_layout(plt_settings).update_layout(width=1200, height=300)
fig = remove_unnecessary_layout(fig)
# the methods are identified by colour/legend, so the x axis is hidden
fig.update_xaxes(visible=False)
fig.update_layout(dict(legend={"y":-0.25}))
# fig.write_image('./fig/fig_baseline_appendix.pdf')
fig.show()

## Tox21

In [None]:
# Settings
METHOD = {'sourceonly':'Source Only', 'gst':'Gradual Self Train', 'goat':'GOAT'}
DATASET = {'tox21a':'Tox21 NHOHCount', 'tox21b':'Tox21 RingCount', 'tox21c':'Tox21 NumHDonors'}
CONDITIONS = ['seed']
SEED = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']

# Ours
ours = make_summary(DATASET, CONDITIONS, SEED)
ours = ours.drop('v1', axis=1).rename(columns={'v0':'data','v2':'seed'})
ours['data'] = ours['data'].map(DATASET)
ours['method'] = 'Ours'

# Baselines
torch.manual_seed(1234)
baselines = make_baseline_summary(METHOD, DATASET, SEED)
baselines = baselines.rename(columns={'v0':'method', 'v1':'data','v2':'seed'})
baselines['data'] = baselines['data'].map(DATASET)
baselines['method'] = baselines['method'].map(METHOD)

# Ours vs. Baselines
df = pd.concat([ours, baselines], ignore_index=True)

In [None]:
# Box plot of the target accuracy per method, one facet per Tox21 task.
# Saving is currently commented out.
fig = px.box(df, x='method', y='accuracy_t', color='method', facet_col='data', facet_col_wrap=3, labels={'accuracy_t':'Accuracy'})
fig.update_layout(plt_settings).update_layout(width=800, height=300)
fig = remove_unnecessary_layout(fig)
# the methods are identified by colour/legend, so the x axis is hidden
fig.update_xaxes(visible=False)
fig.update_layout(dict(legend={"y":-0.25}))
# fig.write_image('./fig/fig_baseline_tox21.pdf')
fig.show()

## ALL Ver.

In [None]:
# Settings
METHOD = {'sourceonly':'SourceOnly', 'gst':'GradualSelfTrain', 'goat':'GOAT',
          'gift':'GIFT', 'sgift':'Sequential GIFT', 'aux':'AuxSelfTrain', 'saux':'Sequential AuxSelfTrain', 'eaml':'EAML'}
DATASET = {#'moon':'Two Moon', 'block':'Block',
           'mnist':'Rotating MNIST', 'portraits':'Portraits', 'shift15m':'SHIFT15M', 'rxrx1':'RxRx1', 
           'tox21a':'Tox21 NHOHCount', 'tox21b':'Tox21 RingCount', 'tox21c':'Tox21 NumHDonors'}
CONDITIONS = ['seed']
SEED = ['1', '2', '3', '4', '5', '6', '7', '8', '9', '10']

# Ours
ours = make_summary(DATASET, CONDITIONS, SEED)
ours = ours.drop('v1', axis=1).rename(columns={'v0':'data','v2':'seed'})
ours['data'] = ours['data'].map(DATASET)
ours['method'] = 'Ours'

# Baselines
baselines = make_baseline_summary(METHOD, DATASET, SEED)
baselines = baselines.rename(columns={'v0':'method', 'v1':'data','v2':'seed'})
baselines['data'] = baselines['data'].map(DATASET)
baselines['method'] = baselines['method'].map(METHOD)

# Ours vs. Baselines
df = pd.concat([ours, baselines], ignore_index=True)

In [None]:
# Box plot of the target accuracy per method, four dataset facets per row.
fig = px.box(df, x='method', y='accuracy_t', color='method', facet_col='data', 
             facet_col_wrap=4, facet_row_spacing=0.1, labels={'accuracy_t':'Accuracy'})
fig.update_layout(plt_settings).update_layout(width=1200, height=600)
fig = remove_unnecessary_layout(fig)
# the methods are identified by colour/legend, so the x axis is hidden
fig.update_xaxes(visible=False)
fig.update_layout(font=dict(family="PTSerif", size=24), legend=dict(y=-0.2, x=0.5))
fig.write_image('./fig/fig_baseline_mr.pdf')
fig.show()

# VAE

In [None]:
# Decode original latent samples and CNF-generated latent samples with the trained
# VAE and show them side by side as images.
file_name = 'mnist_vae_4D_1'

# load trained CNF
x_all, y_all, args, cnf, prior, s_acc, t_acc, fig, lh = load_result(file_name)
clear_output()

# load trained VAE
vae = VAE((1,28,28), args.n_dim)
vae = util.load_model(vae, f'./data/state_mnist_vae_{args.n_dim}D.tar')

# five classes are chosen for the visualization
target_label = [1, 3, 4, 6, 9]
y = y_all[0].copy()

# select one sample from each class
# (fixed seed so that the same samples are picked on every run)
np.random.seed(987)
idx = np.hstack([np.random.choice(np.where(y==l)[0], size=1, replace=False) for l in target_label])

# get original image
# the same row indices `idx` are decoded in every domain of x_all
img_org = np.array([decode_data(vae, x[idx]) for x in x_all])

# get generated image
# map the selected first-domain samples forward with the CNF, evaluate the prior
# on the result, then run the flow backward from there
z = C.visualize_trajectory_forward(cnf, torch.tensor(x_all[0][idx]), 0)[-1]
logpz = prior.log_prob(z)[1]
x_hat = C.visualize_trajectory_backward(cnf, z, logpz, 3, 0.5)
# the first two entries of the backward trajectory are not shown
x_hat = x_hat[2:]
_generate = [decode_data(vae, x) for x in x_hat]

# stack original image
# one image row per class: three domain images separated by blank tiles
# (the unpacking below requires exactly three domains in x_all)
original = []
for j in range(len(target_label)):
    s, i, t = img_org[:,j,:,:]
    blank = np.zeros_like(s)
    new_img = np.hstack([s, blank, i, blank, t])
    original.append(new_img)
original = np.vstack(original)

# stack generated image
# one image row per class: all trajectory steps side by side
generate = []
for i in range(len(target_label)):
    new_img = []
    for k in range(len(_generate)):
        new_img.append(_generate[k][i])
    generate.append(np.hstack(new_img))
generate = np.vstack(generate)

# show image
# np.array requires both stacked images to have the same shape
img = np.array([original, generate])

# vertical layout: one facet per row
fig = px.imshow(img, facet_col=0, facet_col_wrap=1, facet_row_spacing=0.1)
# annotation 1 is given the '(a)' caption and annotation 0 the '(b)' caption here,
# the reverse of the horizontal layout below
fig.layout['annotations'][1]['text'] = '(a) Originals'
fig.layout['annotations'][1]['x'] = 0.48
fig.layout['annotations'][0]['text'] = '(b) Samples from CNF'
fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)
fig.update_coloraxes(colorscale='gray', showscale=False)
fig.update_layout(plt_settings)
fig.update_layout(margin=dict(t=30, b=10, r=0, l=0), width=200, height=450)
fig.write_image(f'./fig/fig_vae_vertical.pdf')
fig.show()


# horizontal layout: both facets in one row
fig = px.imshow(img, facet_col=0, facet_col_wrap=2, facet_row_spacing=0.1)
fig.layout['annotations'][0]['text'] = '(a) Originals'
fig.layout['annotations'][1]['text'] = '(b) Samples from CNF'
fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)
fig.update_coloraxes(colorscale='gray', showscale=False)
fig.update_layout(plt_settings)
fig.update_layout(margin=dict(t=30, b=10, r=0, l=0), width=400, height=200)
fig.update_annotations(dict(font={"family":"PTSerif"}))
# this second update overrides the margin and size set two lines above
fig.update_layout(margin=dict(t=50, b=10, r=0, l=0), width=800, height=400)
fig.update_annotations(dict(font={"size":30}))
fig.write_image(f'./fig/fig_vae_horizontal.pdf')
fig.show()