Code to run and validate VAEs fit to protein multiple sequence alignment (MSA) data. This implementation provides an abstract base class that simplifies implementing new VAE architectures, as well as plotting functionality. This code was developed for publication [1] and was used to make its plots.
This implementation builds on the code from [2] hosted here, which itself appears related to the example VAE code provided with the Keras library here. Both previous implementations are under the MIT license. The Keras example code has since been updated; the new version is hosted here.
Related repositories:
- https://github.com/ahaldane/Mi3-GPU provides the implementation for the Potts model used in [1], as well as helper scripts useful to process MSA data.
- https://github.com/ahaldane/HOM_r20 provides a script used to make the "r20" plots used in [1] based on the VAE output.
- https://github.com/alagauche/generative_capacity_metrics provides code used to make other plots in [1] based on the VAE output.
[1] TBD, Upcoming.
[2] Sam Sinai, Eric Kelsic, George M. Church, and Martin A. Nowak. Variational auto-encoding of protein sequences. arXiv:1712.03346 [cs, q-bio], January 2018.
Requirements:
- Python modules: TensorFlow 2.0.0 (Keras 2.3.1), SciPy 1.5.0, matplotlib 3.2.2.
- C compiler
Optional: a fast GPU. Installing a GPU-enabled version of TensorFlow is recommended, since running on a GPU greatly speeds up the software.
Run "make" to compile the helper module for loading sequence files. After compiling, a file seqtools.xxx.so should have been created.
The example directory contains an example script, run_svae.sh, along with example data. To run it, go to the example directory and, on a system with a GPU, run the run_svae.sh script. It should take a few minutes in total and produce a number of intermediate files and plots, including:
- l8_10K_train_log.csv - TensorFlow training log
- l8_10K_param.pkl, l8_10K_vae.k - model parameters
- Training_loss_l8_10K.png - plot of the training loss per epoch
- LatentTraining_1d_l8_10K.png - 1d plots of training sequences in latent space
- LatentTraining_l8_10K.png - 2d plots of training sequences in latent space
- TVD_l8_10K.png - plot of the Total Variation Distance (TVD) from the target distribution per epoch
- gen_l8_10K_100K - 100K sequences generated by the model
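The training log is a plain CSV file. As a minimal sketch, assuming it follows the layout of Keras' CSVLogger callback (a header row, then one numeric row per epoch; this is not confirmed for this code), it can be loaded for custom plots with Python's standard csv module. parse_train_log and read_train_log are hypothetical helpers, not part of this repository.

```python
import csv


def parse_train_log(lines):
    """Parse CSV log lines into {column_name: [float, ...]}.

    Assumes a header row followed by one numeric row per epoch, as written
    by Keras' CSVLogger callback (an assumption about this repo's log format).
    """
    reader = csv.DictReader(lines)
    columns = {name: [] for name in reader.fieldnames}
    for row in reader:
        for name in reader.fieldnames:
            columns[name].append(float(row[name]))
    return columns


def read_train_log(path):
    """Read a training-log CSV file such as l8_10K_train_log.csv."""
    with open(path, newline="") as f:
        return parse_train_log(f)
```

For example, read_train_log("l8_10K_train_log.csv")["loss"] would return the per-epoch loss, provided the header contains a loss column.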
The plots should look similar to the example plots in the expected_output directory.
You can then analyze the resulting data by running run_analyze.sh, which will create a plot of the covariation scores.
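For orientation, a common definition of the covariation score between positions i and j for residues a and b is C_ij^ab = f_ij^ab - f_i^a f_j^b, where f_i^a and f_ij^ab are the single-site and pairwise residue frequencies of an MSA. The sketch below computes these scores with NumPy and compares two MSAs (for example, training and generated sequences) by the Pearson correlation of their scores. It illustrates that assumed definition only; run_analyze.sh may compute and plot the scores differently, and the function names are hypothetical. It builds the full L x L x q x q array, so it is only practical for short alignments or toy examples.

```python
import numpy as np


def covariances(msa, q):
    """Covariation scores C_ij^ab = f_ij^ab - f_i^a * f_j^b for all pairs i < j.

    msa: integer array of shape (N, L), residues encoded as 0..q-1.
    Returns an array of shape (L*(L-1)//2, q*q), one row per position pair.
    """
    N, L = msa.shape
    onehot = np.zeros((N, L, q))
    onehot[np.arange(N)[:, None], np.arange(L)[None, :], msa] = 1.0
    fi = onehot.mean(axis=0)                               # (L, q) site frequencies
    fij = np.einsum("nia,njb->ijab", onehot, onehot) / N   # (L, L, q, q) pair frequencies
    C = fij - fi[:, None, :, None] * fi[None, :, None, :]
    i, j = np.triu_indices(L, k=1)                         # keep each pair i < j once
    return C[i, j].reshape(len(i), q * q)


def covariance_correlation(msa_a, msa_b, q):
    """Pearson correlation between the covariation scores of two MSAs."""
    ca = covariances(msa_a, q).ravel()
    cb = covariances(msa_b, q).ravel()
    return np.corrcoef(ca, cb)[0, 1]
```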
The vaes.py script is the main tool and prints usage information when run. It can be run in one of the following ways:
$ vaes.py my_name train Church_VAE trainingMSA 2 250
"my_name" is a name used as the prefix for output files (such as the pickled VAE). "train" runs the script in training mode. "Church_VAE" selects the Church VAE (it can also be Deep_VAE or test_VAE). "trainingMSA" is the MSA file to train on; it should contain one sequence per line, without any header lines or other information (see the Mi3-GPU docs). "2" is the number of latent dimensions. "250" is the size of the encoder/decoder layers of the Church VAE; this argument applies only to Church_VAE, and the other VAEs take different options here.
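A minimal sketch of reading this MSA format (one aligned sequence per line, no headers) into an integer array. The 21-letter alphabet (gap followed by the 20 amino acids) is an assumption for illustration; check which alphabet your data and the VAE code actually use. parse_msa and load_msa are hypothetical helpers, not functions from vaes.py.

```python
import numpy as np

# Assumed alphabet (gap plus the 20 amino acids); adjust to match your data.
ALPHABET = "-ACDEFGHIKLMNPQRSTVWY"


def parse_msa(lines, alphabet=ALPHABET):
    """Encode one-sequence-per-line MSA text as an (N, L) integer array."""
    index = {letter: k for k, letter in enumerate(alphabet)}
    seqs = [line.strip() for line in lines if line.strip()]  # skip blank lines
    if not seqs:
        raise ValueError("MSA contains no sequences")
    length = len(seqs[0])
    for n, seq in enumerate(seqs):
        if len(seq) != length:
            raise ValueError(f"sequence {n} has length {len(seq)}, expected {length}")
        unknown = set(seq).difference(index)
        if unknown:
            raise ValueError(f"sequence {n} has unknown letters {sorted(unknown)}")
    return np.array([[index[c] for c in seq] for seq in seqs], dtype=np.int64)


def load_msa(path, alphabet=ALPHABET):
    """Read an MSA file with one aligned sequence per line and no headers."""
    with open(path) as f:
        return parse_msa(f, alphabet)
```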
Optional argument: "--TVDseqs msafile" tracks the TVD of the Hamming distance distribution at each epoch and makes a plot.
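The TVD between two discrete distributions p and q is 0.5 * sum_k |p_k - q_k|. Below is a sketch that applies it to histograms of pairwise Hamming distances within two integer-encoded MSAs of the same length. Whether --TVDseqs compares exactly these histograms (distances between all pairs within each set, rather than, say, distances to a reference set) is an assumption, and the function names are hypothetical. Because it forms all N(N-1)/2 pairs, large alignments should be subsampled first.

```python
import numpy as np


def hamming_histogram(msa):
    """Normalized histogram of pairwise Hamming distances (0..L) within an MSA.

    msa: integer array of shape (N, L) with N >= 2.
    """
    N, L = msa.shape
    i, j = np.triu_indices(N, k=1)               # all pairs of distinct sequences
    d = (msa[i] != msa[j]).sum(axis=1)           # Hamming distance per pair
    counts = np.bincount(d, minlength=L + 1)     # bins 0..L, same size for equal L
    return counts / counts.sum()


def tvd(p, q):
    """Total variation distance between two discrete distributions."""
    return 0.5 * np.abs(np.asarray(p) - np.asarray(q)).sum()
```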
In this mode the script creates plots of the loss function over epochs, and a CSV file containing the loss at each epoch.
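For a standard VAE, the training loss is the negative evidence lower bound: a reconstruction term plus the KL divergence between the encoder's Gaussian posterior and a standard normal prior. The NumPy sketch below computes that textbook quantity per sequence, assuming a decoder that outputs a categorical distribution at each position. It is not the loss function of Church_VAE itself, which may weight or implement the terms differently.

```python
import numpy as np


def vae_loss(onehot_x, decoder_probs, z_mean, z_log_var, eps=1e-10):
    """Per-sequence negative ELBO for a VAE with a per-position categorical decoder.

    onehot_x:          (N, L, q) one-hot encoded input sequences
    decoder_probs:     (N, L, q) decoder probabilities (each position sums to 1)
    z_mean, z_log_var: (N, d) mean and log-variance of the Gaussian posterior
    """
    # Reconstruction: categorical cross-entropy summed over positions.
    recon = -np.sum(onehot_x * np.log(decoder_probs + eps), axis=(1, 2))
    # Closed-form KL( N(mean, exp(log_var)) || N(0, I) ), summed over latent dims.
    kl = -0.5 * np.sum(1.0 + z_log_var - z_mean**2 - np.exp(z_log_var), axis=1)
    return recon + kl
```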
$ vaes.py my_name gen 100000 -o gen100K
"my_name" is the name prefix for the pickled VAE and output files. "gen" runs the script in sequence generation mode. "100000" is the number of sequences to generate. The "-o" option specifies the output file.
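Conceptually, a VAE generates sequences by drawing latent vectors z from the standard normal prior, decoding them to per-position residue probabilities, and sampling one residue per position. The NumPy sketch below shows that procedure, assuming a hypothetical decode callable that maps an (n, latent_dim) array to (n, L, q) probabilities; the gen mode of vaes.py may differ, for example in how it samples from the decoder output.

```python
import numpy as np


def sample_sequences(decode, n, latent_dim, alphabet, rng):
    """Sample n sequences from a VAE decoder via the standard normal prior.

    decode: hypothetical callable mapping (n, latent_dim) -> (n, L, q) probabilities.
    alphabet: string of q letters, one per residue index.
    rng: a numpy.random.Generator.
    """
    z = rng.standard_normal((n, latent_dim))      # z ~ N(0, I)
    probs = decode(z)                              # (n, L, q)
    cdf = probs.cumsum(axis=-1)
    u = rng.random(probs.shape[:2] + (1,))         # one uniform draw per position
    idx = (u > cdf).sum(axis=-1)                   # inverse-CDF sampling
    idx = np.minimum(idx, probs.shape[-1] - 1)     # guard against rounding in cdf
    letters = np.array(list(alphabet))
    return ["".join(row) for row in letters[idx]]
```

Writing the returned strings one per line gives a file in the same one-sequence-per-line format as the training MSA.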
$ vaes.py my_name energy MSAfile --ref_energy E.npy
"my_name" is the name prefix for the pickled VAE and output files. "energy" runs the script in energy computation mode. "MSAfile" is the MSA to compute energies for. "--ref_energy" (optional) gives reference energies to include in the comparison plot.
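One simple numeric summary to accompany the comparison plot is the Pearson correlation between the VAE energies and the reference energies. A minimal sketch, assuming both are available as equal-length arrays (a reference file such as E.npy can be loaded with numpy.load); compare_energies is a hypothetical helper, and nothing is assumed here about how vaes.py stores its own energies.

```python
import numpy as np


def compare_energies(vae_energies, ref_energies):
    """Pearson correlation between two per-sequence energy arrays of equal length."""
    a = np.asarray(vae_energies, dtype=float)
    b = np.asarray(ref_energies, dtype=float)
    if a.shape != b.shape:
        raise ValueError("energy arrays must have the same shape")
    return np.corrcoef(a, b)[0, 1]
```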
$ vaes.py my_name plot_latent trainingMSA
This creates PNG plots of the latent space.
See the example "run_svae.sh" script for further usage. The script can also compute the TVD plot, the C correlation plot, and more.