### Step-by-step run of run_pretrained_openfold.py

Investigating new conformations in latent space.

In [None]:
# imports

import argparse
import logging
import math
import numpy as np
import os

from openfold.utils.script_utils import load_models_from_command_line, parse_fasta, run_model, prep_output, \
    update_timings, relax_protein

# Install the default root handler so log records are actually emitted.
logging.basicConfig()
# Name the logger after the module rather than `__file__`: `__file__` is
# normally not defined inside a notebook kernel, so referencing it raises
# NameError there, whereas `__name__` is always available.
logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)

import pickle

import random
import time
import torch

# Split the installed torch version string and keep its first two numeric
# components as integers.
torch_versions = torch.__version__.split(".")
torch_major_version, torch_minor_version = (
    int(torch_versions[0]),
    int(torch_versions[1]),
)

# Only call the matmul-precision switch on torch 1.12 or newer.
if (torch_major_version, torch_minor_version) >= (1, 12):
    # Gives a large speedup on Ampere-class GPUs
    torch.set_float32_matmul_precision("high")

# Turn autograd off globally for everything that follows.
torch.set_grad_enabled(False)

from openfold.config import model_config
from openfold.data import templates, feature_pipeline, data_pipeline
from openfold.np import residue_constants, protein
import openfold.np.relax.relax as relax

from openfold.utils.tensor_utils import (
    tensor_tree_map,
)
from openfold.utils.trace_utils import (
    pad_feature_dict_seq,
    trace_model_,
)
from scripts.utils import add_data_args


TRACING_INTERVAL = 50

In [None]:
# tools

def precompute_alignments(tags, seqs, alignment_dir, args):
    """Generate alignments for every sequence that does not have them yet.

    For each ``(tag, seq)`` pair the alignments live in
    ``os.path.join(alignment_dir, tag)``. They are generated with
    ``data_pipeline.AlignmentRunner`` only when
    ``args.use_precomputed_alignments`` is ``None`` and that directory does
    not exist yet; otherwise the existing alignments are reused.

    Args:
        tags: Iterable of sequence names, paired positionally with ``seqs``.
        seqs: Iterable of sequences.
        alignment_dir: Directory holding one sub-directory per tag.
        args: Namespace providing ``output_dir``,
            ``use_precomputed_alignments``, ``cpus`` and the binary/database
            path attributes forwarded to ``AlignmentRunner``.
    """
    for tag, seq in zip(tags, seqs):
        local_alignment_dir = os.path.join(alignment_dir, tag)

        if (
            args.use_precomputed_alignments is not None
            or os.path.isdir(local_alignment_dir)
        ):
            logger.info(
                f"Using precomputed alignments for {tag} at {alignment_dir}..."
            )
            continue

        logger.info(f"Generating alignments for {tag}...")

        # NOTE(review): the directory is created before the run, so a failed
        # run leaves it behind and the isdir() check above will treat it as
        # precomputed on the next call -- confirm whether that is acceptable.
        os.makedirs(local_alignment_dir)

        alignment_runner = data_pipeline.AlignmentRunner(
            jackhmmer_binary_path=args.jackhmmer_binary_path,
            hhblits_binary_path=args.hhblits_binary_path,
            hhsearch_binary_path=args.hhsearch_binary_path,
            uniref90_database_path=args.uniref90_database_path,
            mgnify_database_path=args.mgnify_database_path,
            bfd_database_path=args.bfd_database_path,
            uniclust30_database_path=args.uniclust30_database_path,
            pdb70_database_path=args.pdb70_database_path,
            no_cpus=args.cpus,
        )

        # The temporary FASTA file is only needed when alignments are
        # actually generated, so it is created here rather than for every tag.
        tmp_fasta_path = os.path.join(
            args.output_dir, f"tmp_{os.getpid()}.fasta"
        )
        try:
            with open(tmp_fasta_path, "w") as fp:
                fp.write(f">{tag}\n{seq}")
            alignment_runner.run(tmp_fasta_path, local_alignment_dir)
        finally:
            # Remove the temporary FASTA file even if the run raised.
            if os.path.exists(tmp_fasta_path):
                os.remove(tmp_fasta_path)


def round_up_seqlen(seqlen):
    """Round ``seqlen`` up to the next multiple of ``TRACING_INTERVAL``."""
    # Number of whole intervals needed to cover the sequence.
    n_intervals = int(math.ceil(seqlen / TRACING_INTERVAL))
    return n_intervals * TRACING_INTERVAL


def generate_feature_dict(
    tags,
    seqs,
    alignment_dir,
    data_processor,
    args,
):
    """Build the feature dict for one sequence or for a set of sequences.

    The sequences are written to a temporary FASTA file inside
    ``args.output_dir`` and handed to ``data_processor``. The temporary file
    is removed afterwards, including when processing raises.

    Args:
        tags: Sequence names, paired positionally with ``seqs``.
        seqs: Sequences. Exactly one sequence selects
            ``data_processor.process_fasta``; any other count selects
            ``data_processor.process_multiseq_fasta``.
        alignment_dir: Directory holding one alignment sub-directory per tag.
        data_processor: Object providing ``process_fasta`` and
            ``process_multiseq_fasta``.
        args: Namespace providing ``output_dir``.

    Returns:
        Whatever the selected ``data_processor`` method returns.
    """
    tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
    try:
        # One ">tag\nseq" record per sequence, newline-separated; for a single
        # sequence this is exactly ">tag\nseq".
        with open(tmp_fasta_path, "w") as fp:
            fp.write(
                '\n'.join(f">{tag}\n{seq}" for tag, seq in zip(tags, seqs))
            )

        if len(seqs) == 1:
            local_alignment_dir = os.path.join(alignment_dir, tags[0])
            feature_dict = data_processor.process_fasta(
                fasta_path=tmp_fasta_path, alignment_dir=local_alignment_dir
            )
        else:
            feature_dict = data_processor.process_multiseq_fasta(
                fasta_path=tmp_fasta_path, super_alignment_dir=alignment_dir,
            )
    finally:
        # Remove the temporary FASTA file even if processing failed.
        if os.path.exists(tmp_fasta_path):
            os.remove(tmp_fasta_path)

    return feature_dict

def list_files_with_extensions(dir, extensions):
    """Return the names in ``dir`` that end with one of ``extensions``.

    Args:
        dir: Directory to list (passed to ``os.listdir``).
        extensions: A single suffix string, or any iterable of suffix
            strings (tuple, list, set, ...).

    Returns:
        List of matching entry names (not full paths), in ``os.listdir``
        order.
    """
    # str.endswith only accepts a str or a tuple of str; normalise any other
    # iterable (e.g. a list) so callers are not forced to pass a tuple.
    if not isinstance(extensions, str):
        extensions = tuple(extensions)
    return [f for f in os.listdir(dir) if f.endswith(extensions)]



In [None]:
# argparser emulation
class pseudo_parser:
    """Stand-in for the argparse namespace of run_pretrained_openfold.py.

    Each command-line option of the script is exposed as an instance
    attribute initialised to that option's argparse default, so the script
    body can be executed cell by cell without a command line. The one
    deliberate deviation is ``p``, which is switched on here.

    Args:
        fasta_dir: Path to directory containing FASTA files, one sequence
            per file. Positional on the command line; defaults to ``None``
            here so that ``pseudo_parser()`` can be called without arguments.
        template_mmcif_dir: Positional on the command line; defaults to
            ``None`` here for the same reason.
        **overrides: Replacement values for any of the attributes set below,
            keyed by attribute name.

    Raises:
        TypeError: If ``overrides`` names an option that does not exist.
    """

    def __init__(self, fasta_dir=None, template_mmcif_dir=None, **overrides):
        # Imported locally because `date` is not imported at file level.
        from datetime import date

        # --- positional arguments ---
        self.fasta_dir = fasta_dir
        self.template_mmcif_dir = template_mmcif_dir

        # --- inference options ---
        # Path to alignment directory. If provided, alignment computation
        # is skipped and database path arguments are ignored.
        self.use_precomputed_alignments = None
        # Name of the directory in which to output the prediction.
        self.output_dir = os.getcwd()
        # Name of the device on which to run the model. Any valid torch
        # device name is accepted (e.g. "cpu", "cuda:0").
        self.model_device = "cpu"
        # Name of a model config preset defined in openfold/config.py.
        self.config_preset = "model_1"
        # Path to JAX model parameters. If None, and
        # openfold_checkpoint_path is also None, parameters are selected
        # automatically according to the model name from
        # openfold/resources/params.
        self.jax_param_path = None
        # Path to OpenFold checkpoint. Can be either a DeepSpeed checkpoint
        # directory or a .pt file.
        self.openfold_checkpoint_path = None
        # Whether to save all model outputs, including embeddings, etc.
        self.save_outputs = False
        # Number of CPUs with which to run alignment tools.
        self.cpus = 4
        # One of 'reduced_dbs' or 'full_dbs'.
        self.preset = 'full_dbs'
        # Postfix for output prediction filenames.
        self.output_postfix = None
        # Declared with type=str on the command line.
        self.data_random_seed = None
        self.skip_relaxation = False
        # Residue index offset between multiple sequences, if provided.
        self.multimer_ri_gap = 200
        # Whether to convert parts of each model to TorchScript.
        # Significantly improves runtime at the cost of lengthy
        # 'compilation.' Useful for large batch jobs.
        self.trace_model = False
        # Whether to output (100 - pLDDT) in the B-factor column instead of
        # the pLDDT itself.
        self.subtract_plddt = False
        # Trigger print statements during runtime. Print statements describe
        # each step in the process. The command-line default is False; it is
        # enabled here for the step-by-step run.
        self.p = True

        # --- database and binary paths ---
        self.uniref90_database_path = None
        self.mgnify_database_path = None
        self.pdb70_database_path = None
        self.uniclust30_database_path = None
        self.bfd_database_path = None
        self.jackhmmer_binary_path = '/usr/bin/jackhmmer'
        self.hhblits_binary_path = '/usr/bin/hhblits'
        self.hhsearch_binary_path = '/usr/bin/hhsearch'
        self.kalign_binary_path = '/usr/bin/kalign'
        self.max_template_date = date.today().strftime("%Y-%m-%d")
        self.obsolete_pdbs_path = None
        self.release_dates_path = None

        # Apply caller-supplied values last; reject misspelled option names
        # instead of silently creating new attributes.
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"pseudo_parser has no option named {name!r}")
            setattr(self, name, value)

args = pseudo_parser()

# When neither JAX parameters nor an OpenFold checkpoint were given, point
# at the JAX parameter file named after the selected config preset.
if args.jax_param_path is None:
    if args.openfold_checkpoint_path is None:
        args.jax_param_path = os.path.join(
            "openfold",
            "resources",
            "params",
            "params_" + args.config_preset + ".npz",
        )

# Warn when CUDA is available but the model is still set to run on the CPU.
if args.model_device == "cpu":
    if torch.cuda.is_available():
        logging.warning(
            """The model is being run on CPU. Consider specifying 
        --model_device for better performance"""
        )

#main(args)

In [None]:
args.template_mmcif_dir

# main