This repository contains the Protein Set Transformer (PST) framework for contextualizing protein language model embeddings at genome-scale to produce genome embeddings. You can use this code to train your own models. Using our foundation model pre-trained on viruses (vPST), you can also generate genome embeddings for input viruses.
For more information, see our manuscript:
Protein Set Transformer: A protein-based genome language model to power high diversity viromics.
Cody Martin, Anthony Gitter, and Karthik Anantharaman.
bioRxiv, 2024, doi: 10.1101/2024.07.26.605391.
We highly recommend using `uv` for installation, since it will be significantly faster to solve the dependencies and install everything. If you don't have the ability to install `uv`, then just remove `uv` from the following commands.
If you wish, you can set up a virtual environment to install the PST dependencies into using, for example, conda, mamba, or pyenv:

```shell
mamba create -n pst -c conda-forge 'python>=3.9'
```

Just make sure you activate your virtual environment before proceeding with the installation.
You can try simply doing:

```shell
uv pip install torch
uv pip install ptn-set-transformer --no-build-isolation
```

This will do 2 things:

- Install the latest version of `PyTorch` with the default `CUDA` runtime, even if your system does not have GPUs
  - This will run fine on CPU-only systems, but the install will be larger
- Install the `PST` library and force some of the required `PyTorch` extension libraries (specifically `PyTorch-Scatter`) to build on your target machine. This will take a few minutes.
For most use cases, this should work fine.
Optional Note: If you would like to install the latest release from this repository, you will likely need to link your git command line interface with an online GitHub account. Follow this link for help setting up git at the command line.
If you would like a more advanced setup, or if you ran into issues, then try the more manual setups below.
You can check the installation pages for PyTorch and PyTorch-Geometric.
As of PyTorch 2.5.0, conda is no longer a supported way to install PyTorch, and we have uncapped the PyTorch version since minor updates will not affect PST. Thus, these examples will show `pip`/`uv pip`.
```shell
uv pip install torch --index-url https://download.pytorch.org/whl/cpu
```

Then, depending on the version of PyTorch installed, the following command needs to be updated to install the PyTorch extension libraries (PyTorch-Geometric, PyTorch-Scatter, PyTorch-Sparse):

```shell
uv pip install torch_geometric torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-{TORCHVERSION}+cpu.html
```

where `{TORCHVERSION}` is replaced with the specific PyTorch version (i.e. 2.8.0)

```shell
uv pip install ptn-set-transformer
```

```shell
uv pip install torch
```

This will install the PyTorch library with the default CUDA runtime.
If you wish to download a precompiled PyTorch library with a different CUDA runtime, then you can adjust the command to be:
```shell
uv pip install torch --index-url https://download.pytorch.org/whl/{CUDAVERSION}
```

where `{CUDAVERSION}` is a supported CUDA version of the latest PyTorch release (i.e. cu129).
Note the {CUDAVERSION} of the PyTorch library installed since it will be needed for the next step.
Then, depending on the version of PyTorch installed, the following command needs to be updated to install the PyTorch extension libraries (PyTorch-Geometric, PyTorch-Scatter, PyTorch-Sparse):
```shell
uv pip install torch_geometric torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-{TORCHVERSION}+{CUDAVERSION}.html
```

where `{TORCHVERSION}` is replaced with the specific PyTorch version (i.e. 2.8.0) and `{CUDAVERSION}` is the CUDA version of the installed PyTorch release (i.e. cu129)

```shell
uv pip install ptn-set-transformer
```

Due to the various PyTorch dependencies, which are typically shipped as precompiled binaries for specific Python/CUDA/GCC/Linux/etc versions, there can sometimes be version conflict issues that can be hard to resolve.
We have primarily encountered these errors when installing the PyTorch extension libraries, so we will focus on how to resolve issues installing `torch_geometric`, `torch_scatter`, and `torch_sparse`.
If PyTorch was installed CPU-only, then the extension libraries also need to be installed CPU-only.
Ensure that they are installed like this:

```shell
uv pip install torch_geometric torch_scatter torch_sparse -f https://data.pyg.org/whl/torch-{TORCHVERSION}+cpu.html
```

where `{TORCHVERSION}` is replaced with the specific PyTorch version (i.e. 2.8.0)
Conversely, if PyTorch is installed with GPU support, then the extension libraries also need to be installed with GPU support corresponding to the same PyTorch version and CUDA runtime version. If you are unsure about this, then you can obtain this information like this:
```shell
python -c 'import torch; print(torch.__version__)'
```

which will return a string such as `2.8.0+cu126` or `2.8.0+cpu`.
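That version string contains both placeholders needed for the extension-library install commands. As a small sketch (the helper name and return format are illustrative, mirroring the wheel index URLs used above):

```python
def pyg_wheel_index(torch_version_string: str) -> str:
    """Build the PyG extension-library wheel index URL from torch.__version__.

    Inputs look like "2.8.0+cu126" or "2.8.0+cpu"; the part before the "+"
    fills {TORCHVERSION} and the part after fills {CUDAVERSION} (or "cpu").
    """
    version, _, runtime = torch_version_string.partition("+")
    return f"https://data.pyg.org/whl/torch-{version}+{runtime}.html"
```

For example, `pyg_wheel_index("2.8.0+cu126")` returns the index URL to pass to `-f` for a CUDA 12.6 build of PyTorch 2.8.0.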
The precompiled binaries are compiled with specific versions of your system's C compiler, which may not be present on your system. You could update your C compiler/C lib or install a version that is compatible with the precompiled binaries. However, it is much simpler to recompile these libraries for your target system:
```shell
uv pip install torch_geometric torch_scatter torch_sparse --verbose --no-build-isolation
```

This will take several minutes.
We implemented a hyperparameter-tuning cross-validation workflow using Lightning Fabric in a base library called lightning-crossval. Part of our specific implementation for hyperparameter tuning is also included in the PST library.
The latest versions of PST require the tuning dependencies, so no additional installation is needed.
Upon successful installation, you will have the `pst` executable to train, tune, and predict. There are also other utility modules included, which you can see using `pst -h`.
You will need to first download a trained vPST model:
```shell
pst download --model.choices="[PST-TL-P__small]"
```

This will download the requested vPST model into `./pstdata`, but you can change the download location using `--outdir`.
You can use the test data for a test prediction run:
```shell
pst predict \
    --file examples/sample_data.graphfmt.h5 \
    --predict.checkpoint pstdata/PST-TL-P__small.ckpt \
    --predict.output PST-TL-P__small_predictions.h5
```

The results from the above command are available at `examples/PST-TL-P__small_predictions.h5`. This test run takes less than 1 minute using a single CPU.
If you are unfamiliar with `.h5` files, you can use pytables (installed with PST as a dependency) to inspect `.h5` files in Python, or you can install hdf5 and use the `h5ls` command line tool to inspect the fields in the output file.
There should be 3 fields in the prediction file:

- `attn`, which contains the per-protein attention values (shape: $N_{prot} \times N_{heads}$)
- `ctx_ptn`, which contains the contextualized PST protein embeddings (shape: $N_{prot} \times D$)
- `genome`, which contains the PST genome embeddings (shape: $N_{genome} \times D$)
  - Prior to version 1.2.0, this was called `data`.
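As a quick sketch of the pytables route, a helper like the following (illustrative, not part of the PST library) lists each top-level field and its shape:

```python
import tables  # pytables, installed with PST as a dependency

def list_h5_fields(path: str) -> dict[str, tuple]:
    """Map each top-level field in an .h5 file to its array shape."""
    with tables.open_file(path) as fp:
        return {node.name: tuple(node.shape) for node in fp.root}
```

Running it on the example prediction file above would show the `attn`, `ctx_ptn`, and `genome` fields with shapes matching your dataset.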
We have provided a PST inference notebook that can be used within a Google Colab runtime environment. You can use free (although less powerful and lower-memory) GPUs for inference on relatively small datasets (i.e. <10k genomes encoding <250k proteins).
All data associated with the initial model training can be found here: https://doi.org/10.5061/dryad.d7wm37q8w
We have provided the README from the DRYAD data repository here. Additionally, we have provided a programmatic way to access the data from the command line using `pst download`:
NOTE: we have recently changed the DRYAD repository corresponding to manuscript resubmission, so these commands will not work at the moment. However, the latest dataset will be available to download directly through DRYAD soon.
```text
usage: pst [options] download [-h] [--config CONFIG] [--print_config[=flags]] [--manuscript CONFIG]
                              [--manuscript.choices CHOICES] [--cluster CONFIG] [--cluster.choices CHOICES]
                              [--model CONFIG] [--model.choices CHOICES] [--embeddings CONFIG]
                              [--embeddings.choices CHOICES] [--all {true,false}] [--outdir OUTDIR]

Download mode to download data and trained models from DRYAD. Example usage: pst download
--manuscript.choices="[source_data, supplementary_data]"

options:
  -h, --help            Show this help message and exit.
  --config CONFIG       Path to a configuration file.
  --print_config[=flags]
                        Print the configuration after applying all other arguments and exit. The optional flags
                        customize the output and are one or more keywords separated by comma. The supported
                        flags are: comments, skip_default, skip_null.
  --all {true,false}    Download all data from the DRYAD repository. (type: bool, default: False)
  --outdir OUTDIR       Output directory to save files. (type: <class 'Path'>, default: pstdata)

MANUSCRIPT DATA:
  --manuscript CONFIG   Path to a configuration file.
  --manuscript.choices CHOICES, --manuscript.choices+ CHOICES
                        Download manuscript-specific data. Defaults to only the README. (type:
                        list[Literal['source_data', 'supplementary_data', 'supplementary_tables',
                        'host_prediction', 'fasta', 'foldseek_databases', 'README']] | None, default: null)

CLUSTER DATA:
  --cluster CONFIG      Path to a configuration file.
  --cluster.choices CHOICES, --cluster.choices+ CHOICES
                        Download genome or protein clusters. (type: list[Literal['genome', 'protein']] | None,
                        default: null)

TRAINED MODELS:
  --model CONFIG        Path to a configuration file.
  --model.choices CHOICES, --model.choices+ CHOICES
                        Download pretrained models. (type: list[Literal['PST-TL-P__small', 'PST-TL-P__large',
                        'PST-TL-T__small', 'PST-TL-T__large', 'PST-MLM']] | None, default: null)

EMBEDDINGS:
  --embeddings CONFIG   Path to a configuration file.
  --embeddings.choices CHOICES, --embeddings.choices+ CHOICES
                        Download embedding files. (type: list[Literal['esm2', 'IMGVR_PST-TL-P__large',
                        'IMGVR_PST-TL-P__small', 'IMGVR_PST-TL-T__large', 'IMGVR_PST-TL-T__small',
                        'MGnify_PST-TL-P__large', 'MGnify_PST-TL-P__small', 'MGnify_PST-TL-T__large',
                        'MGnify_PST-TL-T__small', 'genslm_ORF', 'train_PST-TL-P__large', 'train_PST-TL-P__small',
                        'train_PST-TL-T__large', 'train_PST-TL-T__small', 'PST-TL_genome', 'other_genome']] | None,
                        default: null)
```

Example Usage:
You need to write your arguments like a Python list, all in quotes, which enables downloading multiple files at a time.
You still need to write the command this way even if you download 1 file.
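If you script these downloads, the quoted list argument can be assembled with a small helper like this one (purely illustrative; it is not part of the `pst` CLI):

```python
def format_choices(flag: str, choices: list[str]) -> str:
    """Format a choices argument the way the pst CLI expects:
    a Python-style list inside double quotes, even for a single item.
    """
    return f'--{flag}="[{", ".join(choices)}]"'
```

For example, `format_choices("model.choices", ["PST-TL-P__small"])` produces `--model.choices="[PST-TL-P__small]"`, matching the commands shown here.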
```shell
pst download \
    --model.choices="[PST-TL-P__small, PST-TL-P__large]" \
    --manuscript.choices="[supplementary_tables]"
```

The DRYAD repository contains all PST models pretrained on our viral genome dataset. Each model was trained with the same input data.
The training and test data are also available in the above data repository.
Here is a summary of each model:
| Model | # Encoder layers | # Attention heads | # Params | Embedding dim |
|---|---|---|---|---|
| PST-TL-T__small | 5 | 4 | 5.4M | 400 |
| PST-TL-T__large | 20 | 32 | 177.9M | 1280 |
| PST-TL-P__small | 5 | 4 | 5.4M | 400 |
| PST-TL-P__large | 5 | 4 | 21.3M | 800 |
| PST-MLM-T__small | 5 | 4 | 23.8M | 960 |
| PST-MLM-T__large | 5 | 4 | 93.6M | 1920 |
| PST-MLM-P__small | 30 | 32 | 93M | 960 |
| PST-MLM-P__large | 10 | 8 | 185.8M | 1920 |
The model name follows this format: `PST-OBJECTIVE-CV__ESMsize`, where:

- `OBJECTIVE` refers to the training objective:
  - `TL` = triplet loss
  - `MLM` = masked language modeling
- `CV` refers to how the cross validation groups were defined:
  - `P` = non-overlapping protein diversity
  - `T` = viral taxonomic realm
- `ESMsize` refers to the relative size of the ESM2 embeddings used to train each model, not the size of the PST model itself:
  - `large` = `esm2_t30_150M` (640 dim)
  - `small` = `esm2_t6_8M` (320 dim)
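The naming convention can be sketched as a tiny parser (illustrative only, not part of the PST API; it assumes the `PST-OBJECTIVE-CV__ESMsize` format described here):

```python
def parse_model_name(name: str) -> dict[str, str]:
    """Split a model name like "PST-TL-P__small" into its parts."""
    prefix, _, esm_size = name.partition("__")
    _, objective, cv = prefix.split("-")
    return {"objective": objective, "cv": cv, "esm_size": esm_size}
```

For example, `parse_model_name("PST-MLM-T__large")` identifies a masked-language-modeling model with taxonomy-based cross validation folds, trained on the larger ESM2 embeddings.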
Please read the wiki for more information about how to use these models, extend them for finetuning and transfer learning, and the specific model API to integrate new models into your own workflows. Note: this is still a work in progress. There is an example Jupyter notebook provided.
The expected runtime for training the final models after hyperparameter tuning can be found in Supplementary Table 4 and ranged from 3.9-33.7h on 1 A100 GPU.
These are estimates of inference times for a dataset composed of ~12k viral genomes encoding ~140k proteins (such as the MGnify test dataset):
| Model Size | Accelerator | ESM2 embedding* | PST inference | Total Time |
|---|---|---|---|---|
| Large | 1 A100 GPU | 18 min | <1 min | ~19 min |
| Large | 128 CPUs | 6h | <1 min | ~6h |
| Large | 8 CPUs | 96h | 11 min | ~96h |
| Small | 1 A100 GPU | 9 min | <1 min | ~9 min |
| Small | 128 CPUs | 3h | <1 min | ~3h |
| Small | 8 CPUs | 48h | 6 min | ~48h |
* ESM2 embeddings are computed independently for each protein, so input FASTA files can be split into equal batches and processed in parallel with as many GPUs as available.
- These will need to be concatenated in the same order as the original FASTA file.
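That concatenation step can be sketched with NumPy (the helper is illustrative; the key point is that the batch list must preserve the original FASTA order):

```python
import numpy as np

def concat_in_fasta_order(batch_embeddings: list[np.ndarray]) -> np.ndarray:
    """Stack per-batch ESM2 protein embeddings into one array.

    The list must be ordered exactly as the batches were split from the
    original FASTA file so protein rows line up with the genome data.
    """
    return np.concatenate(batch_embeddings, axis=0)
```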
Memory usage for inference should be modest, especially if using a `LazyGenomeDataset`: less than 4 GB of memory is needed.
We have provided code for all analyses associated with the manuscript in the manuscript folder. The README in that folder links each method section from the manuscript to a specific Jupyter notebook code implementation.
There are several other repositories associated with the model code and the manuscript:
| Repository | Description |
|---|---|
| lightning-crossval | Our fold-synchronized cross validation strategy implemented with Lightning Fabric |
| esm_embed | Our user-friendly way of embedding proteins from a FASTA file with ESM2 models |
| genslm_embed | Code to generate GenSLM ORF and genome embeddings |
| hyena-dna-embed | Code to generate Hyena-DNA genome embeddings |
| PST_host_prediction | Model and evaluation code for our host prediction proof of concept analysis |
Please cite our preprint if you find our work useful:
Martin C, Gitter A, Anantharaman K. (2024) "Protein Set Transformer: A protein-based genome language model to power high diversity viromics."
```bibtex
@article{Martin2024.07.26.605391,
    author = {Cody Martin and Anthony Gitter and Karthik Anantharaman},
    title = {Protein Set Transformer: A protein-based genome language model to power high diversity viromics},
    elocation-id = {2024.07.26.605391},
    year = {2024},
    doi = {10.1101/2024.07.26.605391},
    publisher = {Cold Spring Harbor Laboratory},
    URL = {https://www.biorxiv.org/content/10.1101/2024.07.26.605391v1},
    eprint = {https://www.biorxiv.org/content/10.1101/2024.07.26.605391v1.full.pdf},
    journal = {bioRxiv}
}
```