
ahaldane/MSA_VAE


Description

Code to run and validate VAEs fit to protein multiple sequence alignment (MSA) data. This implementation provides an abstract base class that simplifies implementing new VAE architectures, as well as plotting functionality. This code was developed for publication [1] and was used to make its plots.

This implementation builds on the code from [2] hosted here, which itself appears to be related to the example VAE code distributed with the Keras library (here). Both previous implementations are under the MIT license. The Keras example code has since been updated and is hosted here.

References:

[1] TBD (upcoming).

[2] Sam Sinai, Eric Kelsic, George M. Church, and Martin A. Nowak. Variational auto-encoding of protein sequences. arXiv:1712.03346 [cs, q-bio], January 2018.

Setup

Requirements:

  • Python modules: TensorFlow 2.0.0 (Keras 2.3.1), SciPy 1.5.0, matplotlib 3.2.2.
  • C compiler

Optional: a fast GPU. It is recommended to install a GPU-enabled version of TensorFlow, since running on GPUs greatly speeds up the software.

Run "make" to compile the helper module for loading sequence files. After compiling, a file seqtools.xxx.so should have been created.

Example Run

The example directory contains a run_svae.sh script and example data. To run it, go to the example directory and, on a system with a GPU, run the run_svae.sh script. It should take a few minutes in total and produce a number of intermediate files and plots, including:

  • l8_10K_train_log.csv - TensorFlow training log
  • l8_10K_param.pkl, l8_10K_vae.k - model parameters
  • Training_loss_l8_10K.png - plot of the training loss per epoch
  • LatentTraining_1d_l8_10K.png - 1d plots of training sequences in latent space
  • LatentTraining_l8_10K.png - 2d plots of training sequences in latent space
  • TVD_l8_10K.png - plot of Total Variation Distance from target distribution per epoch
  • gen_l8_10K_100K - 100K sequences generated by the model
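The training log can also be inspected directly. A minimal sketch, assuming l8_10K_train_log.csv has the layout written by Keras's CSVLogger callback (a header row with an "epoch" column and one column per metric, such as "loss"); best_epoch is a hypothetical helper, not part of the repository:

```python
import csv


def best_epoch(log_path, metric="loss"):
    """Return (epoch, value) for the epoch with the lowest `metric`.

    Assumes a CSVLogger-style file: a header row containing an "epoch"
    column and a column named `metric`, then one row per epoch.
    """
    best = None
    with open(log_path, newline="") as f:
        for row in csv.DictReader(f):
            epoch, value = int(row["epoch"]), float(row[metric])
            # Keep the first epoch that achieves the minimum value
            if best is None or value < best[1]:
                best = (epoch, value)
    if best is None:
        raise ValueError(f"no data rows in {log_path}")
    return best
```

For example, best_epoch("l8_10K_train_log.csv") returns the epoch with the lowest training loss together with that loss value.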

The plots should look similar to the example plots in the expected_output directory.

Then you can analyze the resulting data by running run_analyze.sh, which will create a plot of the covariation scores.

Usage

The vaes.py script is the main tool, and will output usage info when run. It can be run in one of the following ways:

To train a new model:

$ vaes.py my_name train Church_VAE trainingMSA 2 250

The arguments are:

  • "my_name": a prefix for output files (e.g. the pickled VAE).
  • "train": runs the script in training mode.
  • "Church_VAE": selects the Church VAE architecture (alternatives are Deep_VAE and test_VAE).
  • "trainingMSA": the MSA file to train on. It should contain one sequence per line, without header lines or other information (see the Mi3-GPU docs).
  • "2": the number of latent dimensions.
  • "250": the size of the encoder/decoder layers. This argument applies to Church_VAE; other VAEs take different options here.
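Reading an MSA in this format takes only a few lines. The repository uses its compiled seqtools helper module for loading sequence files; the pure-Python sketch below is a hypothetical stand-in, not that module's API. It assumes an alphabet of a gap character plus the 20 standard amino acids (not necessarily the alphabet vaes.py uses) and shows one-hot encoding, a common input representation for sequence VAEs:

```python
import numpy as np

# Assumed alphabet: gap plus the 20 standard amino acids (illustrative only)
ALPHA = "-ACDEFGHIKLMNPQRSTVWY"


def load_msa(path, alpha=ALPHA):
    """Read an MSA with one sequence per line (no headers) into an
    (N, L) array of alphabet indices."""
    index = {c: i for i, c in enumerate(alpha)}
    with open(path) as f:
        seqs = [line.strip() for line in f if line.strip()]
    L = len(seqs[0])
    if any(len(s) != L for s in seqs):
        raise ValueError("all sequences must have the same aligned length")
    return np.array([[index[c] for c in s] for s in seqs], dtype=np.uint8)


def one_hot(msa, q=len(ALPHA)):
    """Flatten an (N, L) index array into (N, L*q) one-hot vectors."""
    N, L = msa.shape
    out = np.zeros((N, L, q), dtype=np.float32)
    # Set out[n, l, msa[n, l]] = 1 for every sequence n and position l
    out[np.arange(N)[:, None], np.arange(L)[None, :], msa] = 1.0
    return out.reshape(N, L * q)
```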

Optional argument: "--TVDseqs msafile" tracks, at each epoch, the total variation distance (TVD) of the Hamming distance distribution, computed using the sequences in msafile, and makes a plot.

In this mode the script creates plots of the loss function over epochs, as well as a CSV file containing the loss per epoch.
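For two distributions p and q over Hamming distances d, the TVD is TVD(p, q) = 1/2 Σ_d |p(d) - q(d)|, which ranges from 0 (identical) to 1 (no overlap). The sketch below builds each distribution from all distinct sequence pairs within an integer-coded MSA. Exactly which pairs and which sequence sets vaes.py compares is not stated here, so hamming_hist and tvd are hypothetical helpers illustrating the quantity, not the script's implementation:

```python
import numpy as np


def hamming_hist(msa):
    """Normalized histogram of pairwise Hamming distances (0..L) over all
    distinct pairs of rows in an (N, L) integer-coded MSA."""
    msa = np.asarray(msa)
    N, L = msa.shape
    if N < 2:
        raise ValueError("need at least two sequences")
    counts = np.zeros(L + 1)
    for i in range(N - 1):
        # Distances from sequence i to every later sequence
        d = (msa[i + 1:] != msa[i]).sum(axis=1)
        counts += np.bincount(d, minlength=L + 1)
    return counts / counts.sum()


def tvd(p, q):
    """Total variation distance: half the L1 distance between distributions."""
    return 0.5 * np.abs(np.asarray(p) - np.asarray(q)).sum()
```

For example, tvd(hamming_hist(gen), hamming_hist(target)) compares two equal-length MSAs gen and target.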

To generate sequences:

$ vaes.py my_name gen 100000 -o gen100K

"my_name" is the name prefix for the pickled VAE and output files. "gen" runs the script in sequence-generation mode. "100000" is the number of sequences to generate, and the "-o" option specifies the output file.
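A quick sanity check on generated sequences is to count how many are distinct and how many exactly reproduce a training sequence. A minimal sketch, assuming the generated file (e.g. gen100K) uses the same one-sequence-per-line format as the training MSA; novelty_stats is a hypothetical helper:

```python
def _read_seqs(path):
    """Read one sequence per line, skipping blank lines."""
    with open(path) as f:
        return [line.strip() for line in f if line.strip()]


def novelty_stats(gen_path, train_path):
    """Summarize generated sequences relative to the training MSA."""
    gen = _read_seqs(gen_path)
    train = set(_read_seqs(train_path))
    if not gen:
        raise ValueError(f"no sequences in {gen_path}")
    return {
        "n_generated": len(gen),
        "n_unique": len(set(gen)),
        # Fraction of generated sequences identical to some training sequence
        "frac_in_training": sum(s in train for s in gen) / len(gen),
    }
```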

To compute energies:

$ vaes.py my_name energy MSAfile --ref_energy E.npy

"my_name" is the name prefix for the pickled VAE and output files. "energy" runs the script in energy-computation mode. "MSAfile" is the MSA whose sequence energies are computed. The optional "--ref_energy E.npy" argument supplies reference energies to include in a comparison plot.
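The E.npy name suggests a NumPy .npy array. A minimal sketch of preparing such a file and of comparing two sets of energies, assuming one energy per sequence of MSAfile in the same order (the exact shape vaes.py expects is an assumption); save_ref_energies and energy_correlation are hypothetical helpers:

```python
import numpy as np


def save_ref_energies(energies, path="E.npy"):
    """Save reference energies as a 1-D float array
    (assumed: one value per MSA sequence, in file order)."""
    arr = np.asarray(energies, dtype=np.float64)
    if arr.ndim != 1:
        raise ValueError("expected a 1-D array of energies")
    np.save(path, arr)
    return path


def energy_correlation(e_model, e_ref):
    """Pearson correlation between two aligned energy arrays."""
    e_model = np.asarray(e_model, dtype=float)
    e_ref = np.asarray(e_ref, dtype=float)
    if e_model.shape != e_ref.shape:
        raise ValueError("energy arrays must align one-to-one")
    return np.corrcoef(e_model, e_ref)[0, 1]
```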

To plot the latent space:

$ vaes.py my_name plot_latent trainingMSA

This creates PNG plots of the sequences in trainingMSA in the latent space.
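The modes above can be chained from Python, for instance inside a larger experiment script. A minimal sketch, assuming vaes.py sits in the current directory and runs under the current Python interpreter; vae_pipeline and run_pipeline are hypothetical helpers, and the example run_svae.sh script remains the reference for a full run:

```python
import shlex
import subprocess
import sys


def vae_pipeline(name, train_msa, n_latent=2, layer_size=250,
                 n_gen=100000, gen_out="gen100K", script="vaes.py"):
    """Return the vaes.py invocations documented in this README, in order:
    train a Church_VAE, generate sequences, then plot the latent space."""
    base = [sys.executable, script, name]
    return [
        base + ["train", "Church_VAE", train_msa, str(n_latent), str(layer_size)],
        base + ["gen", str(n_gen), "-o", gen_out],
        base + ["plot_latent", train_msa],
    ]


def run_pipeline(commands, dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in commands:
        print("$", shlex.join(cmd))
        if not dry_run:
            # check=True stops the pipeline if any step exits with an error
            subprocess.run(cmd, check=True)
```

Calling run_pipeline(vae_pipeline("my_name", "trainingMSA")) only prints the commands; pass dry_run=False to execute them.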

Other computations:

See the example run_svae.sh script for further usage. The vaes.py script can also compute a TVD plot, a C correlation plot, and more.
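The C correlation plot presumably refers to the connected pairwise correlations C_ij(a, b) = f_ij(a, b) - f_i(a) f_j(b), where f_i and f_ij are single-site and pairwise residue frequencies; comparing these between training and generated MSAs is a standard check of generative sequence models. That interpretation is an assumption based on the name. A minimal NumPy sketch with a hypothetical helper:

```python
import numpy as np


def connected_correlations(msa, q):
    """Return C with shape (L, L, q, q), where
    C[i, j, a, b] = f_ij(a, b) - f_i(a) * f_j(b),
    from an (N, L) integer-coded MSA with letters in range(q)."""
    msa = np.asarray(msa)
    N, L = msa.shape
    x = np.zeros((N, L, q))
    x[np.arange(N)[:, None], np.arange(L)[None, :], msa] = 1.0
    f1 = x.mean(axis=0)                          # (L, q) site frequencies
    f2 = np.einsum("nia,njb->ijab", x, x) / N    # (L, L, q, q) pair frequencies
    return f2 - np.einsum("ia,jb->ijab", f1, f1)
```

Plotting C from a generated MSA against C from the training MSA (for example, flattened into a scatter plot) shows how well the model reproduces pairwise covariation.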