PyTorch implementation of the NeurIPS 2022 workshop paper Latent GP-ODEs with Informative Priors by Ilze Amanda Auzina, Cagatay Yildiz and Efstratios Gavves. Paper
We tackle the existing limitations of GP-ODE models:
- They are applicable only in low-dimensional data settings: the input is either low-dimensional observations or simulation data of the actual ODE system.
- They make no use of prior knowledge of the system, an integral part of GP modeling.
We propose VAE-GP-ODE: a probabilistic dynamic model that extends previous work by learning dynamics from high-dimensional data with a structured GP prior. Our model is trained end-to-end using variational inference.
The code was developed and tested with Python 3.8 and PyTorch 1.13.
Fixed Initial Angle: the rotating MNIST dataset with a fixed initial angle can be downloaded from here. The data should be placed in the experiments/data/rot_mnist directory.
Random Initial Angle: the rotating MNIST dataset with a random initial angle can be created by running the code with the flag --rotrand True; the code will automatically shuffle the existing data in the correct manner.
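One plausible way to obtain a random initial angle from fixed-angle sequences is to circularly shift each sequence along its time axis. The sketch below illustrates that idea on toy data. It assumes each sequence covers one full rotation, so wrapping around is valid. The helper name `randomize_initial_angle` is hypothetical, and the repository's actual shuffling may differ.

```python
import random

def randomize_initial_angle(sequences, seed=0):
    """Circularly shift each sequence by a random offset.

    `sequences` is a list of sequences, each a list of T frames covering one
    full rotation. Rolling a sequence changes its starting angle while keeping
    the frames in rotation order. Hypothetical helper, not the repository's code.
    """
    rng = random.Random(seed)
    shifted = []
    for seq in sequences:
        offset = rng.randrange(len(seq))            # random starting frame
        shifted.append(seq[offset:] + seq[:offset])  # wrap around the end
    return shifted

# Toy data: each frame is represented by its angle index 0..15 (T=16).
data = [list(range(16)) for _ in range(4)]
rolled = randomize_initial_angle(data)
```

Every rolled sequence still contains all 16 frames in consecutive order (modulo 16); only the first frame changes.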
1st ODE: The experiments can be run from the command line by passing arguments for the relevant experimental setup. If the current directory is experiments, an example command is as follows:
python main.py --ode 1 --kernel RBF --D_in 6 --D_out 6 --latent_dim 6 --Nepoch 5000 --lr 0.001 --variance 1.0 --lengthscale 2.0 --rotrand True
The above command runs a first-order ODE model with an RBF kernel and a VAE latent space of dimension 6. For an experiment with the divergence-free kernel, change the kernel argument to --kernel DF.
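To make the roles of --variance and --lengthscale concrete, here is a minimal sketch of the standard RBF (squared-exponential) kernel, k(x, y) = variance * exp(-||x - y||^2 / (2 * lengthscale^2)). That the repository uses exactly this parameterisation is an assumption, and the function name `rbf_kernel` is illustrative. The divergence-free kernel is not shown.

```python
import math

def rbf_kernel(x, y, variance=1.0, lengthscale=2.0):
    """Standard RBF kernel between two vectors given as lists of floats.

    Defaults mirror the example command (--variance 1.0 --lengthscale 2.0).
    Illustrative only; the repository's kernel code may be parameterised differently.
    """
    sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
    return variance * math.exp(-0.5 * sq_dist / lengthscale ** 2)

# Two points in a 6-dimensional latent space, one unit apart.
x = [0.0] * 6
y = [1.0] + [0.0] * 5
k_same = rbf_kernel(x, x)  # equals the variance when the inputs coincide
k_near = rbf_kernel(x, y)  # decays with distance, at a rate set by the lengthscale
```

A larger lengthscale makes the kernel decay more slowly with distance, which corresponds to a prior over smoother dynamics.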
2nd ODE: For a second-order ODE system the above command should be changed to:
python main.py --ode 2 --kernel RBF --D_in 6 --D_out 3 --latent_dim 3 --Nepoch 5000 --lr 0.001 --variance 1.0 --lengthscale 2.0 --rotrand True
The main difference is that the output and latent dimensionality are reduced to 3, as there are now 2 encoders, one for the state and one for the velocity.
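The dimensions in the command are consistent with the usual rewriting of a second-order ODE as a first-order system: ds/dt = v and dv/dt = f(s, v), where f takes the concatenated state and velocity (6 inputs) and returns an acceleration (3 outputs). The sketch below illustrates this with an explicit Euler step. The function names and the toy acceleration (a harmonic oscillator) are hypothetical stand-ins; in the model the dynamics are given by the GP, and the repository's integrator may differ.

```python
def second_order_step(s, v, accel_fn, dt):
    """One explicit Euler step of ds/dt = v, dv/dt = accel_fn([s, v])."""
    a = accel_fn(s + v)  # input is the concatenation [s, v], length 6
    s_next = [si + dt * vi for si, vi in zip(s, v)]
    v_next = [vi + dt * ai for vi, ai in zip(v, a)]
    return s_next, v_next

def toy_accel(sv):
    """Hypothetical stand-in for the learned dynamics: a = -s (6 inputs, 3 outputs)."""
    s = sv[:3]
    return [-si for si in s]

# 3-dimensional state and 3-dimensional velocity, integrated up to t = 1.
s, v = [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]
for _ in range(100):
    s, v = second_order_step(s, v, toy_accel, dt=0.01)
```

With this toy acceleration the first state coordinate follows cos(t) approximately, up to the error of the Euler scheme.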
Pretrained VAE: To train a decoupled model (VAE training separate from GP-ODE training), run the main_vae.py file, which trains a VAE. For example, to train a model on a dataset with 16 rotation angles (T=16) and a latent space dimensionality of 6, run the following:
python main_vae.py --n_angle 16 --latent_dim 6
Running the above will create a new data directory 'data/moving_mnist' where the training data for the VAE is stored. The number of rotation angles can also be increased to 64, meaning that there will be 64 timesteps for a full rotation.
Subsequently, the trained VAE can be used in the GP-ODE training code as follows:
python main.py --ode 1 --rotrand True --pretrained True --vae_path 'specify your model path'
All results will be stored in the results directory, in a folder named with the corresponding timestamp. Each folder includes a log file with all experimental details and the training performance report.
All trained models reported in the final paper can be downloaded from here.
The plot_dynamics.ipynb notebook shows the performance results of the trained models as reported in the final paper.
The plot_dynamics_extended.ipynb notebook reports additional models for further investigation of the models' sensitivity to the training setup.
Figure 1. Reconstructed test sequences. Top row: ground truth. Second row: pretrained VAE (model according to \cite{solin2021scalable}). Bottom 3 rows: VAE-GP-ODE with varying order of the differential equation ($1^\text{st}$ or $2^\text{nd}$ order), and without (RBF kernel) or with (DF kernel) an informative prior. Conditioning: input for the encoder ($1^\text{st}$-order ODE: only $\mathbf{x}_1$; $2^\text{nd}$-order ODE: frames $\mathbf{x}_{1:5}$); Extrapolation: model's prediction within the training data sequence length ($T_{max}=16$); Forecasting: model's prediction outside the training data sequence length ($T>16$).
Figure 2. Latent space for the first-order ODE with the DF kernel. Each color corresponds to a latent trajectory associated with a distinct data sample. A circle indicates the start of the trajectory, and each subsequent star marks a later time point, up to T=32.