In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell executes, so edits to imported source files are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from MIOFlow.losses import MMD_loss, OT_loss, Density_loss, Local_density_loss
from MIOFlow.utils import group_extract, sample, to_np, generate_steps
from MIOFlow.models import ToyModel, make_model, Autoencoder
from MIOFlow.plots import plot_comparision, plot_losses
from MIOFlow.train import train, train_ae
from MIOFlow.constants import ROOT_DIR, DATA_DIR, NTBK_DIR, IMGS_DIR, RES_DIR
from MIOFlow.datasets import (
    make_diamonds, make_swiss_roll, make_tree, make_eb_data, 
    make_dyngen_data, relabel_data
)
from MIOFlow.ode import NeuralODE, ODEF
from MIOFlow.geo import GeoEmbedding, DiffusionDistance, old_DiffusionDistance
from MIOFlow.exp import setup_exp
from MIOFlow.eval import generate_plot_data

import os, pandas as pd, numpy as np, \
    seaborn as sns, matplotlib as mpl, matplotlib.pyplot as plt, \
    torch, torch.nn as nn, pickle
import random

from tqdm.notebook import tqdm
from phate import PHATE

# for geodesic learning
from sklearn.gaussian_process.kernels import RBF
from sklearn.manifold import MDS

# Debugging aid: ask CUDA to launch kernels synchronously so that a device-side
# error is reported at the call that triggered it (at the cost of GPU speed).
os.environ.update({'CUDA_LAUNCH_BLOCKING': '1'})

# Run TrajectoryNet

**NOTE**: here we hold out one time point at a time to see how well TrajectoryNet (TJNet) does at interpolating the missing time point.

In [None]:
# Names of the datasets this notebook can export for TrajectoryNet; `dataset`
# selects the one used by every following cell (index 1 -> 'petals').
datasets = 'dyngen petals'.split()
dataset = datasets[1]

# NOTE(review): `df` is not created by any cell visible above — presumably it
# comes from one of the MIOFlow.datasets `make_*` helpers; confirm and add that
# cell so the notebook survives Restart & Run All.
# The check happens *before* the file is opened: opening with 'wb' truncates
# the target, so a missing `df` would otherwise replace an existing pickle with
# an empty file and only then raise.
if 'df' not in globals():
    raise NameError(
        f"`df` is not defined; build the '{dataset}' data frame before running this cell"
    )

# Persist the data frame next to the TrajectoryNet inputs for later reuse.
df_pickle_path = os.path.expanduser(os.path.join('~/Downloads', f'{dataset}_df.pkl'))
with open(df_pickle_path, 'wb') as f:
    pickle.dump(df, f)

Here we create the datasets that are used by TJNet. Namely, they are `npz` files containing two arrays: the embedding, stored under the key that is later passed as `--embedding_name` (here `phate`), and `sample_labels`, which holds the time point label of each row.

In [None]:
def filepattern(h):
    """Return the path of the TrajectoryNet input file for held-out time point `h`.

    `h` is cast with `int()` so the label appears in the file name as an integer.
    """
    return os.path.expanduser(os.path.join('~/Downloads', f'{dataset}_tjnet_ho_{int(h)}.npz'))


# Every distinct time point label in `df`; each one is held out in turn.
# This list is deliberately NOT reassigned inside the loop, so it still holds
# all labels once the cell has finished.
groups = sorted(df.samples.unique())

for hold_out in groups:
    # Keep every row except those of the held-out time point. A boolean mask is
    # used instead of dropping by index label so that duplicate index labels
    # cannot remove rows belonging to other time points.
    df_ho = df[df['samples'] != hold_out]

    # One archive per hold-out: all non-label columns go in as the `phate`
    # embedding, the integer time point labels as a flat `sample_labels` array.
    np.savez(
        filepattern(hold_out),
        phate=df_ho.drop(columns='samples').values,
        sample_labels=df_ho.samples.astype(int).values.reshape(-1)
    )

In [None]:
for hold_out in groups:
    !python -m TrajectoryNet.main --dataset \
        ~/Downloads/{dataset}_tjnet_ho_{hold_out}.npz \
        --embedding_name "phate" \
        --max_dim 10 \
        --niter 1000 \
        --whiten \
        --save ~/Downloads/{dataset}_tjnet_ho_{hold_out}

    !python -m TrajectoryNet.eval --dataset \
        ~/Downloads/{dataset}_tjnet_ho_{hold_out}.npz \
        --embedding_name "phate" \
        --max_dim 10 \
        --niter 1000 \
        --vecint 1e-4 \
        --whiten \
        --save ~/Downloads/{dataset}_tjnet_ho_{hold_out}