-
Notifications
You must be signed in to change notification settings - Fork 460
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Speed up data loading process #376
Merged
christinaflo
merged 12 commits into
aqlaboratory:multimer
from
dingquanyu:speedup-dataloader
Dec 11, 2023
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
aec1276
added timing steps
6f3e0c0
now used asynchronised version in parse_msa_data
2e1941a
now using multiprocessing style
c3c627e
now run in a subprocess
2204bbb
fixed errors when running in subprocess
4e58a6a
now use ThreadPoolExecutor
53c03a6
update config.py for the development for now
e72e4e6
remove debugging statement
28b9e2b
moved pase_msa_file into tools subfolder
08bfb1f
reverse back to multimer branch version
78ecfc6
remove unnecessary imports and statements
6f26b0a
Merge branch 'multimer' into speedup-dataloader
dingquanyu File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import os, argparse, pickle, tempfile, concurrent | ||
from openfold.data import parsers | ||
from concurrent.futures import ProcessPoolExecutor | ||
|
||
def parse_stockholm_file(alignment_dir: str, stockholm_file: str): | ||
path = os.path.join(alignment_dir, stockholm_file) | ||
file_name,_ = os.path.splitext(stockholm_file) | ||
with open(path, "r") as infile: | ||
msa = parsers.parse_stockholm(infile.read()) | ||
infile.close() | ||
return {file_name: msa} | ||
|
||
def parse_a3m_file(alignment_dir: str, a3m_file: str): | ||
path = os.path.join(alignment_dir, a3m_file) | ||
file_name,_ = os.path.splitext(a3m_file) | ||
with open(path, "r") as infile: | ||
msa = parsers.parse_a3m(infile.read()) | ||
infile.close() | ||
return {file_name: msa} | ||
|
||
def run_parse_all_msa_files_multiprocessing(stockholm_files: list, a3m_files: list, alignment_dir:str): | ||
# Number of workers based on the tasks | ||
msa_results={} | ||
a3m_tasks = [(alignment_dir, f) for f in a3m_files] | ||
sto_tasks = [(alignment_dir, f) for f in stockholm_files] | ||
with ProcessPoolExecutor(max_workers = len(a3m_tasks) + len(sto_tasks)) as executor: | ||
a3m_futures = {executor.submit(parse_a3m_file, *task): task for task in a3m_tasks} | ||
sto_futures = {executor.submit(parse_stockholm_file, *task): task for task in sto_tasks} | ||
|
||
for future in concurrent.futures.as_completed(a3m_futures | sto_futures): | ||
try: | ||
result = future.result() | ||
msa_results.update(result) | ||
except Exception as exc: | ||
print(f'Task generated an exception: {exc}') | ||
return msa_results | ||
|
||
def main(): | ||
parser = argparse.ArgumentParser(description='Process msa files in parallel') | ||
parser.add_argument('--alignment_dir', type=str, help='path to alignment dir') | ||
args = parser.parse_args() | ||
alignment_dir = args.alignment_dir | ||
stockholm_files = [i for i in os.listdir(alignment_dir) if (i.endswith('.sto') and ("hmm_output" not in i))] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. here can you add an exclusion "uniprot_hits" as well? I changed this recently, it is only used for msa pairing. |
||
a3m_files = [i for i in os.listdir(alignment_dir) if i.endswith('.a3m')] | ||
msa_data = run_parse_all_msa_files_multiprocessing(stockholm_files, a3m_files, alignment_dir) | ||
with tempfile.NamedTemporaryFile('wb', suffix='.pkl', delete=False) as outfile: | ||
pickle.dump(msa_data, outfile) | ||
print(outfile.name) | ||
|
||
if __name__ == "__main__": | ||
main() |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we already generated the pkl file, then we should check that it exists before re-parsing the msas. Or does it get removed somewhere?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh also, is there reason we couldn't just call a function to do this instead of running the script with subprocess?