This repository contains a Julia implementation of a Sliced Wasserstein Autoencoder (SWAE) designed to learn a generative latent space for protein sidechain conformations. The model maps the 3D coordinates of sidechain atoms (atoms 4-14 in the standard atom14 representation) into a low-dimensional latent space.
The model is a deterministic autoencoder regularized using the Sliced Wasserstein distance to force the latent distribution toward a standard multivariate normal distribution.
- Latent Dimension: 8-dimensional latent space.
- Architecture: Uses `StarGLU` layers and `LayerNorm` for robust feature extraction and reconstruction.
- Loss Functions (each term is sketched below):
  - Reconstruction Loss: Gaussian Negative Log-Likelihood with Maximum Likelihood variance estimation (`gaussian_nll_mlvar`).
  - SWAE Regularization: Sliced Wasserstein distance between projected latent codes and standard normal quantiles.
  - Inverse Consistency Loss: Ensures that $E(D(z)) \approx z$ for random $z \sim N(0, I)$, improving the generative quality of the latent space.
- Training: Optimized using the `Muon` optimizer with a custom burn-in and decay learning rate schedule.
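To make the three loss terms concrete, here is a minimal plain-Julia sketch. Everything suffixed `_sketch` is illustrative, not the repository's API; the actual definitions (including `gaussian_nll_mlvar`) live in `swae.jl` and may differ in detail.

```julia
using Distributions, Statistics

# Sketch of the reconstruction term: substituting the maximum-likelihood
# variance σ̂² = mean((x .- x̂).^2) back into the Gaussian NLL collapses
# the mean NLL per coordinate to (log(2πσ̂²) + 1) / 2. The repository's
# gaussian_nll_mlvar may differ in detail.
function gaussian_nll_mlvar_sketch(x, x̂)
    σ² = mean(abs2, x .- x̂)
    return (log(2π * σ²) + 1) / 2
end

# Sketch of the SWAE term: project the latent batch onto random unit
# directions, sort each 1D projection, and compare against standard
# normal quantiles at matching probability levels.
function sliced_wasserstein_sketch(z::AbstractMatrix; nproj::Int = 64)
    d, n = size(z)
    θ = randn(eltype(z), d, nproj)                # random projection directions
    θ ./= sqrt.(sum(abs2, θ; dims = 1))           # normalize columns to unit length
    proj = sort(θ' * z; dims = 2)                 # sorted 1D projections, (nproj, n)
    q = quantile.(Normal(), ((1:n) .- 0.5) ./ n)  # N(0, 1) quantiles
    return mean(abs2, proj .- q')                 # quantile matching per slice
end

# Sketch of the inverse consistency term: decode prior samples, re-encode,
# and penalize the round-trip error so that E(D(z)) ≈ z.
inverse_consistency_sketch(E, D, z) = mean(abs2, E(D(z)) .- z)
```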
The model operates on "local" coordinates. Sidechain atom positions are transformed into a local frame defined by the backbone atoms (N, CA, C). The 11 sidechain atoms (atoms 4 through 14) are flattened into a 33-feature vector (3 coordinates × 11 atoms).
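As an illustration of this featurization, here is one standard way to build such a frame (origin at CA, axes from a Gram–Schmidt orthonormalization of the CA→N and CA→C directions); the repository's exact convention may differ.

```julia
using LinearAlgebra

# Build an orthonormal local frame from the backbone atoms. The
# convention here (origin at CA, Gram–Schmidt on CA→N and CA→C) is an
# assumption; the repo may use a different but equivalent recipe.
function local_frame(N::AbstractVector, CA::AbstractVector, C::AbstractVector)
    e1 = normalize(N .- CA)
    v = C .- CA
    e2 = normalize(v .- dot(v, e1) .* e1)   # component of CA→C perpendicular to e1
    e3 = cross(e1, e2)                      # right-handed third axis
    return hcat(e1, e2, e3), CA             # 3×3 rotation R and origin
end

# Rotate the 3×11 sidechain coordinates into the local frame and
# flatten to the model's 33-feature input vector.
function featurize_sidechain(R, origin, sidechain::AbstractMatrix)
    return vec(R' * (sidechain .- origin))  # 33 = 3 coords × 11 atoms
end
```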
- Encoder: Maps 33 input features to 8 latent dimensions (shape sketch below).
- Decoder: Maps 8 latent dimensions back to 33 features (reconstructed local coordinates).
- `StarGLU`: A custom Gated Linear Unit variant used in the dense blocks for better gradient flow.
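A hedged Flux sketch of these shapes, with plain `Dense`/`LayerNorm` blocks standing in for the repo's `StarGLU` layers (whose actual definition is in the source); the hidden width of 256 is a guess taken from the checkpoint name.

```julia
using Flux

nfeatures, latent_dim, hidden = 33, 8, 256   # hidden width guessed from "model256"

# Plain Dense + LayerNorm stand-ins for the repository's StarGLU blocks.
encoder = Chain(
    Dense(nfeatures => hidden, gelu), LayerNorm(hidden),
    Dense(hidden => hidden, gelu), LayerNorm(hidden),
    Dense(hidden => latent_dim),
)
decoder = Chain(
    Dense(latent_dim => hidden, gelu), LayerNorm(hidden),
    Dense(hidden => hidden, gelu), LayerNorm(hidden),
    Dense(hidden => nfeatures),
)

x = randn(Float32, nfeatures, 16)   # a batch of 16 flattened sidechains
z = encoder(x)                      # 8×16 latent codes
x̂ = decoder(z)                      # 33×16 reconstructed local coordinates
```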
The project uses the following Julia packages:
- `Flux.jl` & `Zygote.jl`: Deep learning and auto-differentiation.
- `CUDA.jl`: GPU acceleration.
- `DLProteinFormats.jl`: For protein structure handling and featurization.
- `ProteinChains.jl`: For PDB generation and coordinate manipulation.
- `CannotWaitForTheseOptimisers.jl`: For the `Muon` optimizer.
- `LearningSchedules.jl`: For learning rate management.
- `JLD2.jl`: For saving and loading model states.
To train the model, ensure the required data file `pdb-atom14.jld2` (from here) is in the root directory, then paste the contents of `swae.jl` into the Julia REPL.
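Equivalently, assuming `swae.jl` is self-contained, you should be able to run it from a Julia session started in the repository root:

```julia
# Equivalent to pasting the file: run from the repository root with
# pdb-atom14.jld2 present alongside swae.jl.
include("swae.jl")
```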
The script generates several outputs:
- Model Checkpoints: Saved as `.jld` files (e.g., `sidechain_SWAE_dim8_model256_ns0.1.jld`).
- Visualizations: A PDF panel (`pdbs/latent_distributions_panel.pdf`) showing the marginal distributions of the latent dimensions.
- PDB Files: Sampled and reconstructed protein structures in the `pdbs/` directory:
  - `original.pdb`: The ground truth structure.
  - `recon_nonoise.pdb`: Reconstruction from the mean latent code.
  - `recon_noise_i.pdb`: Reconstructions with added latent noise for robustness checks.
  - `pure_noise_i.pdb`: Sidechains generated by sampling directly from the $N(0, I)$ prior (conceptual sketch below).
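Conceptually, the `pure_noise_i.pdb` structures come from decoding prior samples. A sketch, reusing the illustrative names from the earlier snippets:

```julia
# Conceptual sketch of prior sampling (illustrative names): draw a
# latent code from N(0, I), decode it to local coordinates, then map
# back to the global frame before writing the PDB.
z = randn(Float32, 8)                        # z ~ N(0, I) in latent space
local_coords = reshape(decoder(z), 3, 11)    # 33 features → 3×11 coordinates
global_coords = R * local_coords .+ origin   # undo the local-frame transform
```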
Coordinates are handled in nanometers (nm) for internal calculations and converted to Angstroms (Å) for PDB output to maintain compatibility with standard visualization tools like PyMOL or ChimeraX.
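The conversion itself is just a factor of ten:

```julia
# 1 nm = 10 Å; applied to model-space coordinates before PDB writing.
coords_Å = coords_nm .* 10
```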