In [1]:
# Reload edited modules (e.g. ToyTrajectoryNet) automatically before each cell runs,
# so library changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from ToyTrajectoryNet.losses import MMD_loss, OT_loss, Density_loss, Local_density_loss
from ToyTrajectoryNet.utils import group_extract, sample, to_np, generate_steps
from ToyTrajectoryNet.models import ToyModel, make_model, Autoencoder
from ToyTrajectoryNet.plots import plot_comparision, plot_losses
from ToyTrajectoryNet.train import train, train_ae
from ToyTrajectoryNet.constants import ROOT_DIR, DATA_DIR, NTBK_DIR, IMGS_DIR, RES_DIR
from ToyTrajectoryNet.datasets import (
    make_diamonds, make_swiss_roll, make_tree, make_eb_data, 
    make_dyngen_data, relabel_data
)
from ToyTrajectoryNet.ode import NeuralODE, ODEF
from ToyTrajectoryNet.geo import GeoEmbedding, DiffusionDistance, old_DiffusionDistance
from ToyTrajectoryNet.exp import setup_exp
from ToyTrajectoryNet.eval import generate_plot_data

import os, pandas as pd, numpy as np, \
    seaborn as sns, matplotlib as mpl, matplotlib.pyplot as plt, \
    torch, torch.nn as nn, pickle
import random
import sys

from tqdm.notebook import tqdm
from phate import PHATE

# for geodesic learning
from sklearn.gaussian_process.kernels import RBF
from sklearn.manifold import MDS

# Make CUDA kernel launches synchronous so GPU errors are reported at the call that
# caused them (a debugging aid that slows GPU work). Changes to os.environ are
# inherited by subprocesses, so this also applies to the `!python ...` TrajectoryNet
# runs further down. NOTE(review): consider removing once debugging is done.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Run TrajectoryNet

**NOTE:** here we hold out one time point at a time to see how well TJNet interpolates the missing time point.

In [4]:
datasets = 'dyngen petals'.split()
dataset = datasets[1]

# Load the prepared dataframe for `dataset`. `df` is not created in any earlier cell
# of this notebook, so it must come from disk rather than leftover kernel state
# (writing it here with pickle.dump fails with NameError on a fresh kernel).
# NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
df_path = os.path.expanduser(os.path.join('~/Downloads', f'{dataset}_df.pkl'))
df = pd.read_pickle(df_path)

# Later cells split the data on this column; fail early on an unexpected schema.
assert 'samples' in df.columns, f"{df_path} has no 'samples' column"

Here we create the datasets used by TJNet. Each is an `npz` file with two arrays: the embedding, stored under the key passed as `--embedding_name` (here `phate`), and `sample_labels`, which holds the time-point label of each row.

In [5]:
def filepattern(h):
    """Return the path of the TJNet ``npz`` dataset with time point ``h`` held out.

    The time point is formatted with ``int()`` so float labels (e.g. ``1.0``)
    map to the same file name as integer labels.
    """
    return os.path.expanduser(os.path.join('~/Downloads', f'{dataset}_tjnet_ho_{int(h)}.npz'))

groups = sorted(df.samples.unique())

for hold_out in groups:
    # Select the remaining rows with a boolean mask. Dropping by index labels would
    # also remove kept rows that happen to share a label with a held-out row.
    df_ho = df[df['samples'] != hold_out]
    # Do not reassign `groups` here: the training cell iterates over it, and
    # overwriting it would leave it without the last time point.

    np.savez(
        filepattern(hold_out),
        phate=df_ho.drop(columns='samples').values,
        sample_labels=df_ho.samples.astype(int).values.reshape(-1),
    )

In [7]:
for hold_out in groups:
    !python -m TrajectoryNet.main --dataset \
        ~/Downloads/{dataset}_tjnet_ho_{hold_out}.npz \
        --embedding_name "phate" \
        --max_dim 10 \
        --niter 1000 \
        --whiten \
        --save ~/Downloads/{dataset}_tjnet_ho_{hold_out}

    !python -m TrajectoryNet.eval --dataset \
        ~/Downloads/{dataset}_tjnet_ho_{hold_out}.npz \
        --embedding_name "phate" \
        --max_dim 10 \
        --niter 1000 \
        --vecint 1e-4 \
        --whiten \
        --save ~/Downloads/{dataset}_tjnet_ho_{hold_out}

True
/opt/anaconda/anaconda3/lib/python3.9/site-packages/TrajectoryNet/main.py
""" main.py

Learns ODE from scrna data

"""
import os
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import time

import torch
import torch.nn.functional as F
import torch.optim as optim

from TrajectoryNet.lib.growth_net import GrowthNet
from TrajectoryNet.lib import utils
from TrajectoryNet.lib.visualize_flow import visualize_transform
from TrajectoryNet.lib.viz_scrna import (
    save_trajectory,
    trajectory_to_video,
    save_vectors,
)
from TrajectoryNet.lib.viz_scrna import save_trajectory_density


# from train_misc import standard_normal_logprob
from TrajectoryNet.train_misc import (
    set_cnf_options,
    count_nfe,
    count_parameters,
    count_total_time,
    add_spectral_norm,
    spectral_norm_power_iteration,
    create_regularization_fns,
    get_regularization,
    append_regularization_to_log,
    build_model_tabular,
)


SequentialFlow(
  (chain): ModuleList(
    (0): CNF(
      (odefunc): RegularizedODEfunc(
        (odefunc): ODEfunc(
          (diffeq): ODEnet(
            (layers): ModuleList(
              (0): ConcatSquashLinear(
                (_layer): Linear(in_features=2, out_features=64, bias=True)
                (_hyper_bias): Linear(in_features=1, out_features=64, bias=False)
                (_hyper_gate): Linear(in_features=1, out_features=64, bias=True)
              )
              (1): ConcatSquashLinear(
                (_layer): Linear(in_features=64, out_features=64, bias=True)
                (_hyper_bias): Linear(in_features=1, out_features=64, bias=False)
                (_hyper_gate): Linear(in_features=1, out_features=64, bias=True)
              )
              (2): ConcatSquashLinear(
                (_layer): Linear(in_features=64, out_features=64, bias=True)
                (_hyper_bias): Linear(in_features=1, out_features=64, bias=False)
                (_hyper_gate): Li

Iter 0063 | Time 1.0242(1.0134) | Loss 2.788877(2.828362) | NFE Forward 20(20.4) | NFE Backward 108(101.5)
Iter 0064 | Time 0.9920(1.0119) | Loss 2.796521(2.826133) | NFE Forward 20(20.3) | NFE Backward 102(101.5)
Iter 0065 | Time 1.0346(1.0135) | Loss 2.788632(2.823508) | NFE Forward 20(20.3) | NFE Backward 102(101.6)
Iter 0066 | Time 1.0416(1.0154) | Loss 2.785254(2.820830) | NFE Forward 20(20.3) | NFE Backward 108(102.0)
Iter 0067 | Time 1.0102(1.0151) | Loss 2.787621(2.818506) | NFE Forward 20(20.3) | NFE Backward 102(102.0)
Iter 0068 | Time 0.9626(1.0114) | Loss 2.793643(2.816765) | NFE Forward 20(20.3) | NFE Backward 96(101.6)
Iter 0069 | Time 1.0527(1.0143) | Loss 2.782822(2.814389) | NFE Forward 20(20.2) | NFE Backward 108(102.1)
Iter 0070 | Time 0.9751(1.0115) | Loss 2.793791(2.812947) | NFE Forward 20(20.2) | NFE Backward 96(101.6)
Iter 0071 | Time 0.9553(1.0076) | Loss 2.786801(2.811117) | NFE Forward 20(20.2) | NFE Backward 96(101.2)
Iter 0072 | Time 0.9834(1.0059) | Loss 2

Iter 0140 | Time 1.0468(1.0464) | Loss 2.717940(2.735407) | NFE Forward 26(22.1) | NFE Backward 108(106.7)
Iter 0141 | Time 1.0752(1.0484) | Loss 2.721359(2.734423) | NFE Forward 26(22.4) | NFE Backward 108(106.8)
Iter 0142 | Time 1.0613(1.0493) | Loss 2.712755(2.732906) | NFE Forward 20(22.2) | NFE Backward 108(106.9)
Iter 0143 | Time 1.0616(1.0502) | Loss 2.702561(2.730782) | NFE Forward 20(22.0) | NFE Backward 108(107.0)
Iter 0144 | Time 1.0312(1.0488) | Loss 2.705053(2.728981) | NFE Forward 20(21.9) | NFE Backward 108(107.1)
Iter 0145 | Time 1.0538(1.0492) | Loss 2.706103(2.727380) | NFE Forward 26(22.2) | NFE Backward 108(107.1)
Iter 0146 | Time 1.0850(1.0517) | Loss 2.707494(2.725988) | NFE Forward 26(22.5) | NFE Backward 108(107.2)
Iter 0147 | Time 1.0614(1.0524) | Loss 2.699121(2.724107) | NFE Forward 26(22.7) | NFE Backward 108(107.2)
Iter 0148 | Time 1.0556(1.0526) | Loss 2.696152(2.722150) | NFE Forward 26(22.9) | NFE Backward 108(107.3)
Iter 0149 | Time 1.1271(1.0578) | Los

Iter 0217 | Time 1.4469(1.3606) | Loss 2.411361(2.477059) | NFE Forward 32(31.6) | NFE Backward 144(132.7)
Iter 0218 | Time 1.5492(1.3738) | Loss 2.397303(2.471476) | NFE Forward 32(31.6) | NFE Backward 144(133.5)
Iter 0219 | Time 1.6055(1.3900) | Loss 2.394134(2.466062) | NFE Forward 32(31.6) | NFE Backward 144(134.2)
Iter 0220 | Time 1.4598(1.3949) | Loss 2.369069(2.459273) | NFE Forward 32(31.6) | NFE Backward 144(134.9)
Iter 0221 | Time 1.4566(1.3992) | Loss 2.365885(2.452736) | NFE Forward 32(31.7) | NFE Backward 144(135.6)
Iter 0222 | Time 1.4760(1.4046) | Loss 2.366292(2.446685) | NFE Forward 38(32.1) | NFE Backward 144(136.2)
Iter 0223 | Time 1.4157(1.4054) | Loss 2.350489(2.439951) | NFE Forward 32(32.1) | NFE Backward 138(136.3)
Iter 0224 | Time 1.4776(1.4104) | Loss 2.343921(2.433229) | NFE Forward 32(32.1) | NFE Backward 144(136.8)
Iter 0225 | Time 1.4804(1.4153) | Loss 2.338608(2.426605) | NFE Forward 32(32.1) | NFE Backward 144(137.3)
Iter 0226 | Time 1.4734(1.4194) | Los

Iter 0294 | Time 1.8313(1.8838) | Loss 1.764726(1.831344) | NFE Forward 44(44.0) | NFE Backward 174(178.9)
Iter 0295 | Time 1.9149(1.8860) | Loss 1.760401(1.826378) | NFE Forward 44(44.0) | NFE Backward 186(179.4)
Iter 0296 | Time 1.8583(1.8840) | Loss 1.740008(1.820332) | NFE Forward 44(44.0) | NFE Backward 180(179.5)
Iter 0297 | Time 1.8012(1.8782) | Loss 1.758944(1.816035) | NFE Forward 44(44.0) | NFE Backward 180(179.5)
Iter 0298 | Time 1.8768(1.8781) | Loss 1.738050(1.810576) | NFE Forward 44(44.0) | NFE Backward 180(179.5)
Iter 0299 | Time 1.8305(1.8748) | Loss 1.745162(1.805997) | NFE Forward 44(44.0) | NFE Backward 180(179.6)
Iter 0300 | Time 1.8178(1.8708) | Loss 1.747312(1.801889) | NFE Forward 44(44.0) | NFE Backward 180(179.6)
[TEST] Iter 0300 | Test Loss 1.745672 | NFE 44
Iter 0301 | Time 1.7687(1.8636) | Loss 1.743810(1.797824) | NFE Forward 44(44.0) | NFE Backward 174(179.2)
Iter 0302 | Time 1.8244(1.8609) | Loss 1.734364(1.793381) | NFE Forward 44(44.0) | NFE Backward 1

Iter 0371 | Time 2.0091(1.9276) | Loss 1.595850(1.592805) | NFE Forward 50(48.4) | NFE Backward 192(184.2)
Iter 0372 | Time 1.9470(1.9289) | Loss 1.572300(1.591370) | NFE Forward 50(48.5) | NFE Backward 180(183.9)
Iter 0373 | Time 1.9464(1.9302) | Loss 1.574687(1.590202) | NFE Forward 50(48.6) | NFE Backward 180(183.6)
Iter 0374 | Time 1.9522(1.9317) | Loss 1.556487(1.587842) | NFE Forward 50(48.7) | NFE Backward 180(183.3)
Iter 0375 | Time 1.9500(1.9330) | Loss 1.576347(1.587037) | NFE Forward 50(48.8) | NFE Backward 186(183.5)
Iter 0376 | Time 1.9903(1.9370) | Loss 1.534977(1.583393) | NFE Forward 50(48.9) | NFE Backward 186(183.7)
Iter 0377 | Time 1.9570(1.9384) | Loss 1.549225(1.581001) | NFE Forward 50(49.0) | NFE Backward 186(183.9)
Iter 0378 | Time 1.9319(1.9379) | Loss 1.545647(1.578527) | NFE Forward 50(49.1) | NFE Backward 180(183.6)
Iter 0379 | Time 1.9802(1.9409) | Loss 1.537723(1.575670) | NFE Forward 50(49.1) | NFE Backward 192(184.2)
Iter 0380 | Time 1.9978(1.9449) | Los

Iter 0448 | Time 1.9426(2.0152) | Loss 1.468722(1.476274) | NFE Forward 50(51.5) | NFE Backward 180(186.0)
Iter 0449 | Time 2.0704(2.0190) | Loss 1.463399(1.475373) | NFE Forward 50(51.4) | NFE Backward 186(186.0)
Iter 0450 | Time 2.0926(2.0242) | Loss 1.454046(1.473880) | NFE Forward 50(51.3) | NFE Backward 180(185.6)
Iter 0451 | Time 2.1352(2.0320) | Loss 1.468516(1.473505) | NFE Forward 56(51.7) | NFE Backward 180(185.2)
Iter 0452 | Time 2.2151(2.0448) | Loss 1.455264(1.472228) | NFE Forward 56(52.0) | NFE Backward 192(185.7)
Iter 0453 | Time 2.2056(2.0560) | Loss 1.465634(1.471766) | NFE Forward 56(52.3) | NFE Backward 198(186.6)
Iter 0454 | Time 2.0486(2.0555) | Loss 1.465962(1.471360) | NFE Forward 56(52.5) | NFE Backward 180(186.1)
Iter 0455 | Time 2.0552(2.0555) | Loss 1.450198(1.469879) | NFE Forward 56(52.8) | NFE Backward 186(186.1)
Iter 0456 | Time 2.1285(2.0606) | Loss 1.445982(1.468206) | NFE Forward 56(53.0) | NFE Backward 186(186.1)
Iter 0457 | Time 2.0488(2.0598) | Los

Iter 0525 | Time 1.9847(2.0036) | Loss 1.390614(1.402042) | NFE Forward 56(55.3) | NFE Backward 192(192.1)
Iter 0526 | Time 1.9782(2.0018) | Loss 1.403866(1.402170) | NFE Forward 56(55.3) | NFE Backward 192(192.1)
Iter 0527 | Time 2.0469(2.0049) | Loss 1.396151(1.401748) | NFE Forward 56(55.4) | NFE Backward 186(191.6)
Iter 0528 | Time 2.0284(2.0066) | Loss 1.392334(1.401089) | NFE Forward 56(55.4) | NFE Backward 198(192.1)
Iter 0529 | Time 1.8959(1.9988) | Loss 1.392372(1.400479) | NFE Forward 50(55.0) | NFE Backward 180(191.2)
Iter 0530 | Time 2.0656(2.0035) | Loss 1.410434(1.401176) | NFE Forward 56(55.1) | NFE Backward 192(191.3)
Iter 0531 | Time 2.0076(2.0038) | Loss 1.389933(1.400389) | NFE Forward 56(55.2) | NFE Backward 192(191.3)
Iter 0532 | Time 2.0312(2.0057) | Loss 1.383854(1.399231) | NFE Forward 56(55.2) | NFE Backward 192(191.4)
Iter 0533 | Time 2.0777(2.0108) | Loss 1.377591(1.397717) | NFE Forward 56(55.3) | NFE Backward 198(191.9)
Iter 0534 | Time 2.0284(2.0120) | Los

Iter 0602 | Time 2.1942(2.0547) | Loss 1.344355(1.350030) | NFE Forward 56(56.2) | NFE Backward 198(195.5)
Iter 0603 | Time 2.2345(2.0673) | Loss 1.332846(1.348828) | NFE Forward 56(56.1) | NFE Backward 198(195.7)
Iter 0604 | Time 2.2041(2.0769) | Loss 1.348023(1.348771) | NFE Forward 56(56.1) | NFE Backward 192(195.4)
Iter 0605 | Time 2.3305(2.0947) | Loss 1.327962(1.347315) | NFE Forward 56(56.1) | NFE Backward 204(196.0)
Iter 0606 | Time 2.3621(2.1134) | Loss 1.339742(1.346785) | NFE Forward 56(56.1) | NFE Backward 222(197.8)
Iter 0607 | Time 2.2680(2.1242) | Loss 1.364141(1.347999) | NFE Forward 56(56.1) | NFE Backward 192(197.4)
Iter 0608 | Time 2.2408(2.1324) | Loss 1.346485(1.347893) | NFE Forward 62(56.5) | NFE Backward 198(197.5)
Iter 0609 | Time 2.1930(2.1366) | Loss 1.327689(1.346479) | NFE Forward 56(56.5) | NFE Backward 198(197.5)
Iter 0610 | Time 2.1843(2.1399) | Loss 1.325276(1.344995) | NFE Forward 56(56.4) | NFE Backward 192(197.1)
Iter 0611 | Time 2.2501(2.1477) | Los

Iter 0679 | Time 2.3105(2.2959) | Loss 1.279804(1.302633) | NFE Forward 56(56.9) | NFE Backward 210(201.1)
Iter 0680 | Time 2.3582(2.3003) | Loss 1.308484(1.303043) | NFE Forward 56(56.8) | NFE Backward 204(201.3)
Iter 0681 | Time 2.2593(2.2974) | Loss 1.308234(1.303406) | NFE Forward 56(56.8) | NFE Backward 192(200.7)
Iter 0682 | Time 2.3449(2.3007) | Loss 1.287156(1.302269) | NFE Forward 62(57.1) | NFE Backward 198(200.5)
Iter 0683 | Time 2.3650(2.3052) | Loss 1.299724(1.302091) | NFE Forward 62(57.5) | NFE Backward 204(200.7)
Iter 0684 | Time 2.3712(2.3099) | Loss 1.285798(1.300950) | NFE Forward 56(57.4) | NFE Backward 198(200.5)
Iter 0685 | Time 2.3439(2.3122) | Loss 1.307103(1.301381) | NFE Forward 62(57.7) | NFE Backward 198(200.4)
Iter 0686 | Time 2.3958(2.3181) | Loss 1.311582(1.302095) | NFE Forward 62(58.0) | NFE Backward 216(201.5)
Iter 0687 | Time 2.3416(2.3197) | Loss 1.302940(1.302154) | NFE Forward 56(57.9) | NFE Backward 198(201.2)
Iter 0688 | Time 2.2914(2.3177) | Los

Iter 0756 | Time 2.1432(2.1675) | Loss 1.270433(1.271096) | NFE Forward 56(57.0) | NFE Backward 204(204.7)
Iter 0757 | Time 2.1034(2.1630) | Loss 1.279793(1.271705) | NFE Forward 56(56.9) | NFE Backward 198(204.2)
Iter 0758 | Time 2.1241(2.1603) | Loss 1.257949(1.270742) | NFE Forward 62(57.3) | NFE Backward 204(204.2)
Iter 0759 | Time 2.1284(2.1581) | Loss 1.282560(1.271569) | NFE Forward 56(57.2) | NFE Backward 204(204.2)
Iter 0760 | Time 2.0710(2.1520) | Loss 1.265557(1.271148) | NFE Forward 56(57.1) | NFE Backward 198(203.7)
Iter 0761 | Time 2.2465(2.1586) | Loss 1.241835(1.269096) | NFE Forward 56(57.0) | NFE Backward 222(205.0)
Iter 0762 | Time 2.1520(2.1581) | Loss 1.264356(1.268764) | NFE Forward 62(57.4) | NFE Backward 204(204.9)
Iter 0763 | Time 2.1381(2.1567) | Loss 1.277210(1.269356) | NFE Forward 56(57.3) | NFE Backward 204(204.9)
Iter 0764 | Time 2.1072(2.1533) | Loss 1.252516(1.268177) | NFE Forward 62(57.6) | NFE Backward 204(204.8)
Iter 0765 | Time 2.1508(2.1531) | Los

Iter 0833 | Time 2.2558(2.1879) | Loss 1.254084(1.240353) | NFE Forward 62(60.4) | NFE Backward 222(213.9)
Iter 0834 | Time 2.1403(2.1846) | Loss 1.232861(1.239829) | NFE Forward 62(60.6) | NFE Backward 210(213.6)
Iter 0835 | Time 2.4613(2.2040) | Loss 1.243619(1.240094) | NFE Forward 62(60.7) | NFE Backward 216(213.8)
Iter 0836 | Time 2.1938(2.2033) | Loss 1.230763(1.239441) | NFE Forward 56(60.3) | NFE Backward 210(213.5)
Iter 0837 | Time 2.2328(2.2053) | Loss 1.227434(1.238600) | NFE Forward 56(60.0) | NFE Backward 228(214.5)
Iter 0838 | Time 2.1688(2.2028) | Loss 1.250951(1.239465) | NFE Forward 56(59.7) | NFE Backward 216(214.6)
Iter 0839 | Time 2.1298(2.1977) | Loss 1.249742(1.240184) | NFE Forward 56(59.5) | NFE Backward 210(214.3)
Iter 0840 | Time 2.2548(2.2017) | Loss 1.218276(1.238651) | NFE Forward 62(59.7) | NFE Backward 222(214.9)
Iter 0841 | Time 2.2761(2.2069) | Loss 1.232519(1.238222) | NFE Forward 62(59.8) | NFE Backward 228(215.8)
Iter 0842 | Time 2.2560(2.2103) | Los

Iter 0910 | Time 2.3222(2.2363) | Loss 1.173867(1.210756) | NFE Forward 62(59.9) | NFE Backward 234(221.1)
Iter 0911 | Time 2.3124(2.2416) | Loss 1.205291(1.210374) | NFE Forward 62(60.0) | NFE Backward 234(222.0)
Iter 0912 | Time 2.3269(2.2476) | Loss 1.222002(1.211188) | NFE Forward 62(60.2) | NFE Backward 240(223.3)
Iter 0913 | Time 2.1820(2.2430) | Loss 1.227286(1.212314) | NFE Forward 62(60.3) | NFE Backward 210(222.4)
Iter 0914 | Time 2.2862(2.2460) | Loss 1.236725(1.214023) | NFE Forward 56(60.0) | NFE Backward 234(223.2)
Iter 0915 | Time 2.2296(2.2449) | Loss 1.212039(1.213884) | NFE Forward 68(60.6) | NFE Backward 222(223.1)
Iter 0916 | Time 2.1660(2.2393) | Loss 1.242872(1.215913) | NFE Forward 62(60.7) | NFE Backward 204(221.8)
Iter 0917 | Time 2.1384(2.2323) | Loss 1.209224(1.215445) | NFE Forward 56(60.3) | NFE Backward 204(220.5)
Iter 0918 | Time 2.2369(2.2326) | Loss 1.175334(1.212637) | NFE Forward 62(60.5) | NFE Backward 222(220.6)
Iter 0919 | Time 2.2434(2.2333) | Los

Iter 0987 | Time 2.3385(2.2562) | Loss 1.167503(1.186049) | NFE Forward 62(61.1) | NFE Backward 228(223.7)
Iter 0988 | Time 2.1985(2.2522) | Loss 1.187906(1.186179) | NFE Forward 62(61.1) | NFE Backward 216(223.2)
Iter 0989 | Time 2.3085(2.2561) | Loss 1.199634(1.187121) | NFE Forward 62(61.2) | NFE Backward 228(223.5)
Iter 0990 | Time 2.2150(2.2533) | Loss 1.172325(1.186085) | NFE Forward 62(61.2) | NFE Backward 222(223.4)
Iter 0991 | Time 2.2567(2.2535) | Loss 1.183592(1.185911) | NFE Forward 62(61.3) | NFE Backward 222(223.3)
Iter 0992 | Time 2.2259(2.2516) | Loss 1.177768(1.185341) | NFE Forward 62(61.3) | NFE Backward 222(223.2)
Iter 0993 | Time 2.1875(2.2471) | Loss 1.185739(1.185369) | NFE Forward 62(61.4) | NFE Backward 216(222.7)
Iter 0994 | Time 2.2867(2.2499) | Loss 1.183351(1.185227) | NFE Forward 62(61.4) | NFE Backward 228(223.1)
Iter 0995 | Time 2.2949(2.2530) | Loss 1.185099(1.185218) | NFE Forward 62(61.5) | NFE Backward 228(223.4)
Iter 0996 | Time 2.2943(2.2559) | Los