Single-sequence model #354

Merged — 38 commits, Oct 10, 2023

Commits (38)
43e1e5c
Added embedder for handling single-sequence embeddings.
sachinkadyan7 Oct 11, 2022
062a3f0
Added sequence-embedding mode config.
sachinkadyan7 Oct 14, 2022
1e42b70
Added dummy MSA generation for seq-emb mode.
sachinkadyan7 Oct 14, 2022
e6dec86
Added switch in inference flow for using sequence embedding instead o…
sachinkadyan7 Oct 14, 2022
7663b70
Added loading of sequence embeddings in inference flow when in seq_em…
sachinkadyan7 Oct 14, 2022
a718ceb
Added single seq mode in inference script and forwarded to the FASTA …
sachinkadyan7 Oct 14, 2022
ab8ccf2
Added switch for using the single sequence embedder when using the mo…
sachinkadyan7 Oct 18, 2022
432f8c8
Added configuration options for the new PreembeddingEmbedder.
sachinkadyan7 Oct 18, 2022
01c3e20
Added switching off of column attention in evoformer when using seque…
sachinkadyan7 Oct 18, 2022
d4acab8
Added switch in the MMCIF processing pipeline for using sequence embe…
sachinkadyan7 Oct 18, 2022
2e5073d
Added passing of sequence embedding mode flag from `data_modules` to …
sachinkadyan7 Oct 18, 2022
518557a
Added training preset for sequence embedding initial training.
sachinkadyan7 Oct 18, 2022
1ab1004
Added training preset for sequence embedding finetuning training.
sachinkadyan7 Oct 18, 2022
63c5a24
[BUGFIX] Fix an import bug in `data_pipeline.py`
sachinkadyan7 Oct 19, 2022
3e80bbb
Optimized type-changing of features from numpy to torch
sachinkadyan7 Oct 19, 2022
c058b7b
Changed the seq embedding tensor passed to the data pipeline to be a …
sachinkadyan7 Oct 21, 2022
d542dc6
Added the seq_emb features to the list of features to be processed by…
sachinkadyan7 Oct 28, 2022
aceb092
Added a separate AlignmentRunner for handling seq_emb mode.
sachinkadyan7 Oct 28, 2022
ca72982
Added inference model preset for seqemb mode.
sachinkadyan7 Oct 28, 2022
a3fe6c9
Added documentation for some sequence embedding model changes.
sachinkadyan7 Oct 28, 2022
c2c994c
Added switch for sequence embedding mode to the PDB file pipeline.
sachinkadyan7 Oct 29, 2022
0ac23e4
Fix for a bug in data_transforms which wouldn't allow creation of MSA…
sachinkadyan7 Nov 1, 2022
c4aded6
Added config presets for esm1b model inference
sachinkadyan7 Sep 12, 2023
57bf182
Properly reading the embedding file
sachinkadyan7 Sep 12, 2023
f612689
Renamed `preembedding_embedder` to `input_embedder`
sachinkadyan7 Sep 12, 2023
6c9aaf2
Bugfix for timings.json - now store timings per tag.
sachinkadyan7 Sep 13, 2023
ae9bbaa
Added flag in training script for using sequence embeddings
sachinkadyan7 Sep 15, 2023
2c50816
Add sequence embedding mode option to .core file parser
sachinkadyan7 Sep 15, 2023
e7f713e
Default value for --use_single_seq_mode arg
sachinkadyan7 Sep 15, 2023
5047ca4
Added test for PreembeddingEmbedder
sachinkadyan7 Sep 18, 2023
b7e50a1
Added sequence embedding mode test for `model`.
sachinkadyan7 Sep 19, 2023
3be83e8
Added test for no column attention Evoformer
sachinkadyan7 Sep 19, 2023
05a7284
Updated README: Running seqemb model inference
sachinkadyan7 Sep 27, 2023
55fd315
Fix typos
gahdritz Sep 29, 2023
f8d517b
Reduce redundancy in seq embedding config presets
sachinkadyan7 Oct 2, 2023
3162e91
Limit the MSA distillation clusters to 1 in seq mode
sachinkadyan7 Oct 3, 2023
f14e599
Separate out the seq mode configs from vanilla OF config
sachinkadyan7 Oct 3, 2023
e5a44aa
Improved UX: Automatically set the single seq mode flag
sachinkadyan7 Oct 5, 2023
22 changes: 22 additions & 0 deletions README.md
@@ -232,6 +232,28 @@ efficient AlphaFold-Multimer more than double the time. Use the
at once. The `run_pretrained_openfold.py` script can enable this config option with the
`--long_sequence_inference` command line option

#### Single-Sequence Model Inference
To run inference on a sequence with the single-sequence model, you first need the ESM-1b embedding for that sequence. This requires setting up the ESM model on your system ([ESM](https://www.github.com/facebookresearch/esm.git)). Once the setup is ready, use the following command in the ESM model directory to generate an embedding:

```bash
cd <esm_dir>
python scripts/extract.py esm1b_t33_650M_UR50S <fasta> output_dir --include per_tok
```

Once you have the `*.pt` embedding file, place it in that sequence's alignments directory (the same layout used by the MSA-based OpenFold model). That is, inside the top-level alignments directory there is one subdirectory per sequence you want to run inference on, like so: `alignments_dir/{sequence_id}/{sequence_id}.pt`. You can also place a `*.hhr` file in the same directory, containing the hits for the structures you want to use as templates.
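
If you want to sanity-check the embedding file before running inference, the sketch below (using a hypothetical `alignments_dir/seq1/seq1.pt` path) loads it the same way the OpenFold data pipeline does and prints the per-residue embedding shape:

```python
import torch

# Hypothetical example path; substitute your own sequence ID.
emb = torch.load("alignments_dir/seq1/seq1.pt")

# extract.py with --include per_tok stores per-token representations keyed by
# layer index; the seqemb pipeline reads the final ESM-1b layer (33).
rep = emb["representations"][33]
print(rep.shape)  # expected: (sequence_length, 1280)
```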

Now, you are ready to run inference:
```bash
python run_pretrained_openfold.py \
fasta_dir \
data/pdb_mmcif/mmcif_files/ \
--use_precomputed_alignments alignments_dir \
--output_dir ./ \
--model_device "cuda:0" \
--config_preset "seq_model_esm1b" \
    --openfold_checkpoint_path openfold/resources/openfold_params/seq_model_esm1b.pt
```

Collaborator review comment (on `--config_preset`): This preset should be `seq_model_esm1b_ptm`.

### Training

To train the model, you will first need to precompute protein alignments.
71 changes: 71 additions & 0 deletions openfold/config.py
@@ -152,9 +152,42 @@ def model_config(
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
# SINGLE SEQUENCE EMBEDDING PRESETS
elif name == "seqemb_initial_training":
c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1
c.data.train.max_distillation_msa_clusters = 1
elif name == "seqemb_finetuning":
c.data.train.max_msa_clusters = 1
c.data.eval.max_msa_clusters = 1
c.data.train.max_distillation_msa_clusters = 1
c.data.train.crop_size = 384
c.loss.violation.weight = 1.
c.loss.experimentally_resolved.weight = 0.01
elif name == "seq_model_esm1b":
c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
c.data.predict.max_msa_clusters = 1
elif name == "seq_model_esm1b_ptm":
Collaborator review comment: What is the difference between `seq_model_esm1b` and `seq_model_esm1b_ptm`? Could we include this information as a comment here?

c.data.common.use_templates = True
c.data.common.use_template_torsion_angles = True
c.model.template.enabled = True
c.data.predict.max_msa_clusters = 1
c.model.heads.tm.enabled = True
c.loss.tm.weight = 0.1
else:
raise ValueError("Invalid model name")

if name.startswith("seq"):
Collaborator review comment: Could we change this to `name.startswith("seqemb")`, and change the two other presets `seq_model_esm1b` and `seq_model_esm1b_ptm` to begin with `seq_emb`? This will help in case future presets unrelated to solo seq also begin with 'seq'.

# Tell the data pipeline that we will use sequence embeddings instead of MSAs.
c.data.seqemb_mode.enabled = True
c.globals.seqemb_mode_enabled = True
# In seqemb mode, we turn off the ExtraMSAStack and Evoformer's column attention.
c.model.extra_msa.enabled = False
c.model.evoformer_stack.no_column_attention = True
c.update(seq_mode_config.copy_and_resolve_references())

if long_sequence_inference:
assert(not train)
c.globals.offload_inference = True
@@ -189,6 +222,11 @@ def model_config(
c_t = mlc.FieldReference(64, field_type=int)
c_e = mlc.FieldReference(64, field_type=int)
c_s = mlc.FieldReference(384, field_type=int)

# For seqemb mode: dimension of the per-residue sequence embedding passed to the model.
# In the current model this is the ESM-1b embedding dimension, i.e. 1280.
preemb_dim_size = mlc.FieldReference(1280, field_type=int)

blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
@@ -301,6 +339,9 @@ def model_config(
"use_templates": templates_enabled,
"use_template_torsion_angles": embed_template_torsion_angles,
},
"seqemb_mode": { # Configuration for sequence embedding mode
"enabled": False, # If True, use seq emb instead of MSA
},
"supervised": {
"clamp_prob": 0.9,
"supervised_features": [
@@ -365,6 +406,7 @@ def model_config(
},
# Recurring FieldReferences that can be changed globally here
"globals": {
"seqemb_mode_enabled": False, # Global flag for enabling seq emb mode
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
# Use Staats & Rabe's low-memory attention algorithm. Mutually
@@ -493,6 +535,7 @@ def model_config(
"transition_n": 4,
"msa_dropout": 0.15,
"pair_dropout": 0.25,
"no_column_attention": False,
"blocks_per_ckpt": blocks_per_ckpt,
"clear_cache_between_blocks": False,
"tune_chunk_size": tune_chunk_size,
@@ -614,3 +657,31 @@ def model_config(
"ema": {"decay": 0.999},
}
)

seq_mode_config = mlc.ConfigDict({
"data": {
"common": {
"feat": {
"seq_embedding": [NUM_RES, None],
},
"seqemb_features": [ # List of features to be generated in seqemb mode
"seq_embedding"
],
},
"seqemb_mode": { # Configuration for sequence embedding mode
"enabled": True, # If True, use seq emb instead of MSA
},
},
"globals": {
"seqemb_mode_enabled": True,
},
"model": {
"preembedding_embedder": { # Used in sequence embedding mode
"tf_dim": 22,
"preembedding_dim": preemb_dim_size,
"c_z": c_z,
"c_m": c_m,
"relpos_k": 32,
},
}
})
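
As a quick orientation for the presets above, here is a minimal usage sketch (assuming the `model_config` helper exported by `openfold/config.py` and the preset names added in this PR):

```python
from openfold.config import model_config

# "seq_model_esm1b" is one of the seqemb inference presets added above; any
# preset whose name starts with "seq" additionally merges in seq_mode_config.
config = model_config("seq_model_esm1b")

assert config.data.seqemb_mode.enabled            # data pipeline uses embeddings, not MSAs
assert config.globals.seqemb_mode_enabled         # global flag read by the model
assert not config.model.extra_msa.enabled         # ExtraMSAStack is turned off in seqemb mode
assert config.model.evoformer_stack.no_column_attention
```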
6 changes: 5 additions & 1 deletion openfold/data/data_modules.py
@@ -186,7 +186,8 @@ def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
mmcif=mmcif_object,
alignment_dir=alignment_dir,
chain_id=chain_id,
alignment_index=alignment_index
alignment_index=alignment_index,
seqemb_mode=self.config.seqemb_mode.enabled
)

return data
@@ -239,6 +240,7 @@ def __getitem__(self, idx):
elif(ext == ".core"):
data = self.data_pipeline.process_core(
path, alignment_dir, alignment_index,
seqemb_mode=self.config.seqemb_mode.enabled,
)
elif(ext == ".pdb"):
structure_index = None
@@ -251,6 +253,7 @@ def __getitem__(self, idx):
chain_id=chain_id,
alignment_index=alignment_index,
_structure_index=structure_index,
seqemb_mode=self.config.seqemb_mode.enabled,
)
else:
raise ValueError("Extension branch missing")
@@ -260,6 +263,7 @@ def __getitem__(self, idx):
fasta_path=path,
alignment_dir=alignment_dir,
alignment_index=alignment_index,
seqemb_mode=self.config.seqemb_mode.enabled,
)

if(self._output_raw):
76 changes: 67 additions & 9 deletions openfold/data/data_pipeline.py
@@ -19,6 +19,7 @@
from typing import Mapping, Optional, Sequence, Any

import numpy as np
import torch

from openfold.data import templates, parsers, mmcif_parsing
from openfold.data.templates import get_custom_template_features
@@ -260,6 +261,18 @@ def make_msa_features(
return features


# Generate 1-sequence MSA features having only the input sequence
def make_dummy_msa_feats(input_sequence):
msas = [[input_sequence]]
deletion_matrices = [[[0 for _ in input_sequence]]]
msa_features = make_msa_features(
msas=msas,
deletion_matrices=deletion_matrices,
)

return msa_features


def make_sequence_features_with_custom_template(
sequence: str,
mmcif_path: str,
@@ -627,11 +640,28 @@ def _process_msa_feats(

return msa_features

# Load and process sequence embedding features
def _process_seqemb_features(self,
alignment_dir: str,
) -> Mapping[str, Any]:
seqemb_features = {}
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
ext = os.path.splitext(f)[-1]

if (ext == ".pt"):
# Load embedding file
seqemb_data = torch.load(path)
seqemb_features["seq_embedding"] = seqemb_data["representations"][33]

return seqemb_features

def process_fasta(
self,
fasta_path: str,
alignment_dir: str,
alignment_index: Optional[str] = None,
seqemb_mode: bool = False,
) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f:
@@ -658,12 +688,19 @@ def process_fasta(
num_res=num_res,
)

msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
sequence_embedding_features = {}
# If using seqemb mode, generate dummy MSA features using just the sequence
if seqemb_mode:
msa_features = make_dummy_msa_feats(input_sequence)
sequence_embedding_features = self._process_seqemb_features(alignment_dir)
else:
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)

return {
**sequence_features,
**msa_features,
**template_features
**template_features,
**sequence_embedding_features,
}

def process_mmcif(
@@ -672,6 +709,7 @@
alignment_dir: str,
chain_id: Optional[str] = None,
alignment_index: Optional[str] = None,
seqemb_mode: bool = False,
) -> FeatureDict:
"""
Assembles features for a specific chain in an mmCIF object.
@@ -696,10 +734,16 @@
self.template_featurizer,
query_release_date=to_date(mmcif.header["release_date"])
)

msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)

return {**mmcif_feats, **template_features, **msa_features}
sequence_embedding_features = {}
# If using seqemb mode, generate dummy MSA features using just the sequence
if seqemb_mode:
msa_features = make_dummy_msa_feats(input_sequence)
sequence_embedding_features = self._process_seqemb_features(alignment_dir)
else:
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)

return {**mmcif_feats, **template_features, **msa_features, **sequence_embedding_features}

def process_pdb(
self,
@@ -709,6 +753,7 @@
chain_id: Optional[str] = None,
_structure_index: Optional[str] = None,
alignment_index: Optional[str] = None,
seqemb_mode: bool = False,
) -> FeatureDict:
"""
Assembles features for a protein in a PDB file.
@@ -742,15 +787,22 @@
self.template_featurizer,
)

msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
sequence_embedding_features = {}
# If in sequence embedding mode, generate dummy MSA features using just the input sequence
if seqemb_mode:
msa_features = make_dummy_msa_feats(input_sequence)
sequence_embedding_features = self._process_seqemb_features(alignment_dir)
else:
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)

return {**pdb_feats, **template_features, **msa_features}
return {**pdb_feats, **template_features, **msa_features, **sequence_embedding_features}

def process_core(
self,
core_path: str,
alignment_dir: str,
alignment_index: Optional[str] = None,
seqemb_mode: bool = False,
) -> FeatureDict:
"""
Assembles features for a protein in a ProteinNet .core file.
@@ -770,9 +822,15 @@
self.template_featurizer,
)

msa_features = self._process_msa_feats(alignment_dir, input_sequence)
sequence_embedding_features = {}
# If in sequence embedding mode, generate dummy MSA features using just the input sequence
if seqemb_mode:
msa_features = make_dummy_msa_feats(input_sequence)
sequence_embedding_features = self._process_seqemb_features(alignment_dir)
else:
msa_features = self._process_msa_feats(alignment_dir, input_sequence)

return {**core_feats, **template_features, **msa_features}
return {**core_feats, **template_features, **msa_features, **sequence_embedding_features}

def process_multiseq_fasta(self,
fasta_path: str,
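
To make the seqemb branch above concrete, here is a rough sketch (hypothetical query sequence; the helper is the module-level function added in this file) of the dummy single-sequence MSA features produced in seqemb mode:

```python
from openfold.data.data_pipeline import make_dummy_msa_feats

# A short hypothetical query sequence.
input_sequence = "MKTAYIAKQR"

# Dummy "MSA" features containing a single row: the query itself.
msa_features = make_dummy_msa_feats(input_sequence)
print(msa_features["msa"].shape)          # (1, len(input_sequence))
print(msa_features["num_alignments"][0])  # 1

# The embedding itself is read separately from the sequence's alignment
# directory by DataPipeline._process_seqemb_features, which returns the
# {"seq_embedding": ...} entry loaded from the *.pt file.
```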
10 changes: 8 additions & 2 deletions openfold/data/feature_pipeline.py
@@ -40,9 +40,11 @@ def np_to_tensor_dict(
Returns:
A dictionary of features mapping feature names to features. Only the given
features are returned, all other ones are filtered out.
"""
"""
# torch generates warnings if feature is already a torch Tensor
to_tensor = lambda t: torch.tensor(t) if type(t) != torch.Tensor else t.clone().detach()
tensor_dict = {
k: torch.tensor(v) for k, v in np_example.items() if k in features
k: to_tensor(v) for k, v in np_example.items() if k in features
}

return tensor_dict
@@ -61,6 +63,10 @@ def make_data_config(

feature_names = cfg.common.unsupervised_features

# Add seqemb related features if using seqemb mode.
if cfg.seqemb_mode.enabled:
feature_names += cfg.common.seqemb_features

if cfg.common.use_templates:
feature_names += cfg.common.template_features

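
The `to_tensor` change above matters because, in seqemb mode, the feature dict mixes NumPy arrays (sequence/MSA features) with an already-materialized torch tensor (the loaded `seq_embedding`). A small illustrative sketch of the conversion rule, under that assumption:

```python
import numpy as np
import torch

# Same rule as np_to_tensor_dict above: wrap NumPy arrays in a new tensor, but
# clone/detach anything that is already a torch.Tensor to avoid the warning
# torch.tensor() emits when given an existing tensor.
to_tensor = lambda t: torch.tensor(t) if type(t) != torch.Tensor else t.clone().detach()

features = {
    "aatype": np.zeros((10, 21), dtype=np.float32),  # NumPy feature from the data pipeline
    "seq_embedding": torch.zeros(10, 1280),          # already a tensor, loaded from the *.pt file
}
tensor_dict = {k: to_tensor(v) for k, v in features.items()}
print({k: type(v).__name__ for k, v in tensor_dict.items()})  # both entries are Tensor
```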