Protein Sidechain Sliced Wasserstein Autoencoder (SWAE)

This repository contains a Julia implementation of a Sliced Wasserstein Autoencoder (SWAE) designed to learn a generative latent space for protein sidechain conformations. The model maps the 3D coordinates of sidechain atoms (atoms 4-14 in the standard atom14 representation) into a low-dimensional latent space.

Model Overview

The model is a deterministic autoencoder regularized using the Sliced Wasserstein distance to force the latent distribution toward a standard multivariate normal distribution $N(0, I)$. This allows for efficient sampling and meaningful latent space interpolations.
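
As a rough illustration of the regularizer (not the repository's exact code), the sliced Wasserstein distance projects the latent codes onto random directions and compares each sorted 1-D projection against standard normal quantiles. All names below (`sliced_wasserstein`, `nproj`) are illustrative:

```julia
using LinearAlgebra, Statistics, Distributions

# Sliced Wasserstein penalty between latent codes (columns of a d×n
# matrix) and N(0, I). Hypothetical sketch, not the repository's API.
function sliced_wasserstein(z::AbstractMatrix; nproj::Int = 64)
    d, n = size(z)
    θ = randn(d, nproj)
    θ ./= sqrt.(sum(abs2, θ; dims = 1))        # random unit directions
    proj = θ' * z                              # nproj × n projections
    q = [quantile(Normal(), (i - 0.5) / n) for i in 1:n]  # N(0,1) quantiles
    # Average the 1-D squared Wasserstein distance over all projections.
    return mean(mean(abs2, sort(proj[k, :]) .- q) for k in 1:nproj)
end
```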

Key Features:

  • Latent Dimension: 8-dimensional latent space.
  • Architecture: Uses StarGLU layers and LayerNorm for robust feature extraction and reconstruction.
  • Loss Functions (sketched after this list):
    • Reconstruction Loss: Gaussian Negative Log-Likelihood with Maximum Likelihood variance estimation (gaussian_nll_mlvar).
    • SWAE Regularization: Sliced Wasserstein distance between projected latent codes and standard normal quantiles.
    • Inverse Consistency Loss: Ensures that $E(D(z)) \approx z$ for random $z \sim N(0, I)$, improving the generative quality of the latent space.
  • Training: Optimized using the Muon optimizer with a custom burn-in and decay learning rate schedule.
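
The two auxiliary loss terms above admit compact sketches. Assuming the reconstruction residuals are pooled into a single maximum-likelihood variance estimate, `gaussian_nll_mlvar` and the inverse consistency penalty might look like the following (illustrative signatures, not the repository's exact definitions):

```julia
using Statistics

# Gaussian NLL with the variance fixed at its ML estimate σ² = mean((x - x̂)²);
# substituting that optimum reduces the per-element NLL to a log term.
function gaussian_nll_mlvar(x̂, x)
    σ² = mean(abs2, x .- x̂)
    return 0.5 * (log(2π * σ²) + 1)
end

# Inverse consistency: re-encode decoded prior samples and penalize the
# distance from the original z ~ N(0, I).
inverse_consistency(enc, dec, z) = mean(abs2, enc(dec(z)) .- z)
```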

Implementation Details

Data Representation

The model operates on "local" coordinates. Sidechain atom positions are transformed into a local frame defined by the backbone atoms (N, CA, C). The 11 sidechain atoms (atoms 4 through 14) are flattened into a 33-feature vector (3 coordinates $\times$ 11 atoms).
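
A minimal sketch of this transformation, assuming a Gram–Schmidt frame built from the backbone atoms (the repository's exact frame convention may differ):

```julia
using LinearAlgebra

# Build an orthonormal frame from the backbone N, CA, C positions.
function local_frame(N, CA, C)
    e1 = normalize(C .- CA)
    u  = N .- CA
    e2 = normalize(u .- dot(u, e1) .* e1)   # Gram–Schmidt step
    e3 = cross(e1, e2)
    return hcat(e1, e2, e3)                 # 3×3 rotation matrix
end

# Express the 3×11 sidechain coordinates in the local frame and flatten
# them into the 33-feature vector the encoder consumes.
function sidechain_features(coords::AbstractMatrix, N, CA, C)
    R = local_frame(N, CA, C)
    return vec(R' * (coords .- CA))
end
```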

Model Components

  • Encoder: Maps 33 input features to 8 latent dimensions.
  • Decoder: Maps 8 latent dimensions back to 33 features (reconstructed local coordinates).
  • StarGLU: A custom Gated Linear Unit variant used in the dense blocks for better gradient flow (a generic gated-unit sketch follows this list).
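
Since StarGLU is defined inside the repository, the sketch below substitutes a generic gated linear unit to show the overall encoder/decoder shape in Flux; the 256-unit width is taken from the model256 checkpoint name, and the real architecture may differ:

```julia
using Flux

# Generic gated linear unit: an elementwise product of a value path and
# a sigmoid gate. Stand-in for the repository's StarGLU.
struct GLUBlock{A,B}
    value::A
    gate::B
end
Flux.@layer GLUBlock
GLUBlock(in::Int, out::Int) = GLUBlock(Dense(in => out), Dense(in => out))
(m::GLUBlock)(x) = m.value(x) .* sigmoid.(m.gate(x))

# 33 local-coordinate features → 8 latent dimensions, and back.
encoder = Chain(GLUBlock(33, 256), LayerNorm(256), Dense(256 => 8))
decoder = Chain(GLUBlock(8, 256), LayerNorm(256), Dense(256 => 33))
```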

Dependencies

The project uses the following Julia packages:

  • Flux.jl & Zygote.jl: Deep learning and auto-differentiation.
  • CUDA.jl: GPU acceleration.
  • DLProteinFormats.jl: For protein structure handling and featurization.
  • ProteinChains.jl: For PDB generation and coordinate manipulation.
  • CannotWaitForTheseOptimisers.jl: For the Muon optimizer.
  • LearningSchedules.jl: For learning rate management.
  • JLD2.jl: For saving and loading model states.

Usage

Training

To train the model, place the required data file pdb-atom14.jld2 (from here) in the repository root, then paste the contents of swae.jl into the Julia REPL.
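
In practice that amounts to something like the following REPL session (assuming the repository ships a Project.toml; otherwise install the dependencies listed above manually):

```julia
using Pkg
Pkg.activate(".")     # activate the repository environment
Pkg.instantiate()     # install the pinned dependencies
include("swae.jl")    # equivalent to pasting the script into the REPL
```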

Outputs

The script generates several outputs:

  1. Model Checkpoints: Saved as .jld files (e.g., sidechain_SWAE_dim8_model256_ns0.1.jld).
  2. Visualizations: A PDF panel (pdbs/latent_distributions_panel.pdf) showing the marginal distributions of the latent dimensions.
  3. PDB Files: Sampled and reconstructed protein structures in the pdbs/ directory:
    • original.pdb: The ground truth structure.
    • recon_nonoise.pdb: Reconstruction from the mean latent code.
    • recon_noise_i.pdb: Reconstructions with added latent noise for robustness checks.
    • pure_noise_i.pdb: Sidechains generated by sampling directly from the $N(0, I)$ prior (see the sketch below).
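
Decoding a prior draw is a one-liner once the model is trained; here `decoder` is the trained 8 → 33 network (illustrative names, as in the architecture sketch above):

```julia
z = randn(Float32, 8)                       # one draw from the N(0, I) prior
local_coords = reshape(decoder(z), 3, 11)   # 3 coordinates × 11 atoms
```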

Coordinate System

Coordinates are handled in nanometers (nm) for internal calculations and converted to Angstroms (Å) for PDB output to maintain compatibility with standard visualization tools like PyMOL or ChimeraX.
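
The conversion itself is a factor of ten, e.g.:

```julia
coords_Å = 10 .* coords_nm   # 1 nm = 10 Å, applied before writing PDBs
```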
