Speed up data loading process #376

Merged · 12 commits · Dec 11, 2023
47 changes: 27 additions & 20 deletions openfold/data/data_pipeline.py
@@ -21,16 +21,17 @@
from multiprocessing import cpu_count
import tempfile
from typing import Mapping, Optional, Sequence, Any, MutableMapping, Union

import subprocess
import numpy as np
import torch

import pickle
from openfold.data import templates, parsers, mmcif_parsing, msa_identifiers, msa_pairing, feature_processing_multimer
from openfold.data.templates import get_custom_template_features, empty_template_feats
from openfold.data.tools import jackhmmer, hhblits, hhsearch, hmmsearch
from openfold.data.tools.utils import to_date
from openfold.np import residue_constants, protein

import concurrent
from concurrent.futures import ThreadPoolExecutor

FeatureDict = MutableMapping[str, np.ndarray]
TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch]
@@ -735,22 +736,11 @@ def read_msa(start, size):

fp.close()
else:
for f in os.listdir(alignment_dir):
path = os.path.join(alignment_dir, f)
filename, ext = os.path.splitext(f)

if(ext == ".a3m"):
with open(path, "r") as fp:
msa = parsers.parse_a3m(fp.read())
elif(ext == ".sto" and filename not in ["uniprot_hits", "hmm_output"]):
with open(path, "r") as fp:
msa = parsers.parse_stockholm(
fp.read()
)
else:
continue

msa_data[f] = msa
# Split the MSA parsing below across multiple processes
Collaborator:

If we already generated the pkl file, then we should check that it exists before re-parsing the msas. Or does it get removed somewhere?
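
A minimal sketch of the cache check being suggested, assuming a fixed pickle name such as parsed_msas.pkl inside the alignment directory (the script as written uses a fresh temp file each run, so this name is hypothetical):

# Hypothetical cache check before re-parsing (assumed file name).
cache_path = os.path.join(alignment_dir, "parsed_msas.pkl")
if os.path.exists(cache_path):
    with open(cache_path, "rb") as fh:
        msa_data = pickle.load(fh)
else:
    ...  # fall through to the parsing path below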

Collaborator:

Oh also, is there a reason we couldn't just call a function to do this instead of running the script with subprocess?
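
For comparison, a direct in-process call along those lines might look like the following sketch (assuming parse_msa_files remains importable as a module):

from openfold.data.tools.parse_msa_files import run_parse_all_msa_files_multiprocessing

# Sketch: build the same file lists the script builds, then call directly.
stockholm_files = [f for f in os.listdir(alignment_dir) if f.endswith(".sto") and "hmm_output" not in f]
a3m_files = [f for f in os.listdir(alignment_dir) if f.endswith(".a3m")]
msa_data = run_parse_all_msa_files_multiprocessing(stockholm_files, a3m_files, alignment_dir)

This would avoid the subprocess round-trip and the temporary pickle entirely.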

# The helper script writes the parsed MSAs to a temp pickle and prints its path.
current_directory = os.path.dirname(os.path.abspath(__file__))
cmd = f"{current_directory}/tools/parse_msa_files.py"
result = subprocess.run(['python', cmd, f"--alignment_dir={alignment_dir}"], capture_output=True, text=True)
with open(result.stdout.strip(), 'rb') as fh:
    msa_data = pickle.load(fh)

return msa_data

@@ -826,6 +816,7 @@ def _process_msa_feats(
input_sequence: Optional[str] = None,
alignment_index: Optional[str] = None
) -> Mapping[str, Any]:

msas = self._get_msas(
alignment_dir, input_sequence, alignment_index
)
@@ -1216,8 +1207,10 @@ def process_fasta(self,
with open(fasta_path) as f:
input_fasta_str = f.read()


input_seqs, input_descs = parsers.parse_fasta(input_fasta_str)


all_chain_features = {}
sequence_features = {}
is_homomer_or_monomer = len(set(input_seqs)) == 1
@@ -1228,6 +1221,7 @@
)
continue


chain_features = self._process_single_chain(
chain_id=desc,
sequence=seq,
@@ -1236,24 +1230,28 @@
is_homomer_or_monomer=is_homomer_or_monomer
)


chain_features = convert_monomer_features(
chain_features,
chain_id=desc
)
all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features


all_chain_features = add_assembly_features(all_chain_features)


np_example = feature_processing_multimer.pair_and_merge(
all_chain_features=all_chain_features,
)


# Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512)

return np_example

def get_mmcif_features(
self, mmcif_object: mmcif_parsing.MmcifObject, chain_id: str
) -> FeatureDict:
@@ -1284,18 +1282,21 @@ def process_mmcif(
alignment_index: Optional[str] = None,
) -> FeatureDict:


all_chain_features = {}
sequence_features = {}
is_homomer_or_monomer = len(set(mmcif.chain_to_seqres.values())) == 1
for chain_id, seq in mmcif.chain_to_seqres.items():
desc = "_".join([mmcif.file_id, chain_id])


if seq in sequence_features:
all_chain_features[desc] = copy.deepcopy(
sequence_features[seq]
)
continue


chain_features = self._process_single_chain(
chain_id=desc,
sequence=seq,
@@ -1304,23 +1305,29 @@
is_homomer_or_monomer=is_homomer_or_monomer
)


chain_features = convert_monomer_features(
chain_features,
chain_id=desc
)


mmcif_feats = self.get_mmcif_features(mmcif, chain_id)
chain_features.update(mmcif_feats)
all_chain_features[desc] = chain_features
sequence_features[seq] = chain_features


all_chain_features = add_assembly_features(all_chain_features)


np_example = feature_processing_multimer.pair_and_merge(
all_chain_features=all_chain_features,
)


# Pad MSA to avoid zero-sized extra_msa.
np_example = pad_msa(np_example, 512)


return np_example
51 changes: 51 additions & 0 deletions openfold/data/tools/parse_msa_files.py
@@ -0,0 +1,51 @@
import os
import argparse
import pickle
import tempfile

from concurrent.futures import ProcessPoolExecutor, as_completed
from openfold.data import parsers

def parse_stockholm_file(alignment_dir: str, stockholm_file: str):
    path = os.path.join(alignment_dir, stockholm_file)
    file_name, _ = os.path.splitext(stockholm_file)
    with open(path, "r") as infile:
        msa = parsers.parse_stockholm(infile.read())
    return {file_name: msa}

def parse_a3m_file(alignment_dir: str, a3m_file: str):
    path = os.path.join(alignment_dir, a3m_file)
    file_name, _ = os.path.splitext(a3m_file)
    with open(path, "r") as infile:
        msa = parsers.parse_a3m(infile.read())
    return {file_name: msa}

def run_parse_all_msa_files_multiprocessing(stockholm_files: list, a3m_files: list, alignment_dir: str):
    # One worker per MSA file; each file is parsed in its own process.
    msa_results = {}
    a3m_tasks = [(alignment_dir, f) for f in a3m_files]
    sto_tasks = [(alignment_dir, f) for f in stockholm_files]
    with ProcessPoolExecutor(max_workers=len(a3m_tasks) + len(sto_tasks)) as executor:
        a3m_futures = {executor.submit(parse_a3m_file, *task): task for task in a3m_tasks}
        sto_futures = {executor.submit(parse_stockholm_file, *task): task for task in sto_tasks}

        # Merge the two future maps; as_completed yields results as they finish.
        for future in as_completed({**a3m_futures, **sto_futures}):
            try:
                msa_results.update(future.result())
            except Exception as exc:
                print(f'Task generated an exception: {exc}')
    return msa_results

def main():
parser = argparse.ArgumentParser(description='Process msa files in parallel')
parser.add_argument('--alignment_dir', type=str, help='path to alignment dir')
args = parser.parse_args()
alignment_dir = args.alignment_dir
stockholm_files = [i for i in os.listdir(alignment_dir) if (i.endswith('.sto') and ("hmm_output" not in i))]
Collaborator:

Here, can you add an exclusion for "uniprot_hits" as well? I changed this recently; it is only used for MSA pairing.
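
A sketch of the suggested change, mirroring the "uniprot_hits" exclusion already used by the sequential path in data_pipeline.py:

stockholm_files = [i for i in os.listdir(alignment_dir)
                   if i.endswith('.sto') and "hmm_output" not in i and "uniprot_hits" not in i]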

a3m_files = [i for i in os.listdir(alignment_dir) if i.endswith('.a3m')]
msa_data = run_parse_all_msa_files_multiprocessing(stockholm_files, a3m_files, alignment_dir)
with tempfile.NamedTemporaryFile('wb', suffix='.pkl', delete=False) as outfile:
pickle.dump(msa_data, outfile)
print(outfile.name)

if __name__ == "__main__":
main()
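
For reference, a minimal caller sketch matching the _get_msas change above: run the script, read the temp-pickle path it prints on stdout, and load the result (the alignment path here is an assumption):

import pickle
import subprocess

proc = subprocess.run(
    ['python', 'openfold/data/tools/parse_msa_files.py', '--alignment_dir=/path/to/alignments'],
    capture_output=True, text=True,
)
with open(proc.stdout.strip(), 'rb') as fh:
    msa_data = pickle.load(fh)  # dict mapping file stem -> parsed MSA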