# Predict Structures for Monomers using OmegaFold

## References

1.   OmegaFold [[Pre-Print](https://www.biorxiv.org/content/10.1101/2022.07.21.500999v1)], [[GitHub](https://github.com/HeliXonProtein/OmegaFold)]
2.   ColabFold [[Paper](https://www.nature.com/articles/s41592-022-01488-1)], [[GitHub](https://github.com/sokrypton/ColabFold)]
3. [The ColabFold Implementation of OmegaFold](https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/beta/omegafold.ipynb)
4. [Why should one use OmegaFold over AlphaFold or ESMFold?](https://310.ai/2023/05/17/benchmarking-machine-learning-methods-for-protein-folding-a-comparative-study-of-esmfold-omegafold-and-alphafold/)

TLDR:

*   OmegaFold is faster than methods that rely on MSA (AlphaFold, RoseTTAFold)
*   OmegaFold's accuracy is competitive with AlphaFold and RoseTTAFold, all of which beat ESMFold.
* ESMFold's advantage is speed.





## Imports
* biopython had to be downgraded to version 1.81 (pinned in the setup cell below).

In [None]:
import os,sys,re
from IPython.utils import io
# One-time environment setup. The SETUP_DONE guard makes re-running this cell
# skip the install/download steps.
if "SETUP_DONE" not in dir():
  import torch
  # Use the GPU when available; `device` is later passed to OmegaFold's --device option.
  device = "cuda" if torch.cuda.is_available() else "cpu"
  # Capture (hide) the output of the setup commands below; inspect it with
  # `captured.show()` if something goes wrong.
  with io.capture_output() as captured:
    # Clone/install/download only if the OmegaFold checkout is not present yet.
    # NOTE(review): if an earlier run cloned the repo but stopped before the
    # model download finished, this check skips the download entirely.
    if not os.path.isdir("OmegaFold"):
      # sokrypton's fork of OmegaFold, "beta" branch.
      %shell git clone --branch beta --quiet https://github.com/sokrypton/OmegaFold.git
      # %shell cd OmegaFold; pip -q install -r requirements.txt
      # biopython is pinned to 1.81.
      %shell pip -q install py3Dmol biopython==1.81
      %shell apt-get install aria2 -qq > /dev/null
      # Download the released model weights (aria2c with up to 16 connections).
      %shell aria2c -q -x 16 https://helixon.s3.amazonaws.com/release1.pt
      # Move the weights to ~/.cache/omegafold_ckpt/model.pt, presumably the path
      # OmegaFold/main.py loads them from -- TODO confirm against main.py.
      %shell mkdir -p ~/.cache/omegafold_ckpt
      %shell mv release1.pt ~/.cache/omegafold_ckpt/model.pt
  # NOTE(review): set even if a command above failed, since failures are captured silently.
  SETUP_DONE = True

In [None]:
import gc
from Bio import SeqIO
from google.colab import drive
import os, sys, re, hashlib
from IPython.utils import io
from string import ascii_uppercase, ascii_lowercase

In [None]:
# Mount Google Drive at /content/drive; the input FASTA file and the output
# directory used further down both live under /content/drive/MyDrive.
drive_root = '/content/drive'
drive.mount(drive_root)

## Helper Functions

In [None]:
# Function to generate job ID and hash
def get_hash(x):
    """Return the SHA-1 hex digest of the string ``x`` (encoded as UTF-8)."""
    hasher = hashlib.sha1()
    hasher.update(x.encode())
    return hasher.hexdigest()

def get_subbatch_size(L):
    """Pick the value passed to OmegaFold's ``--subbatch_size`` for total length ``L``.

    Longer inputs get smaller sub-batches (presumably to bound memory use):
    500 below 500 residues, 200 below 1000 residues, and 150 otherwise.
    """
    for limit, size in ((500, 500), (1000, 200)):
        if L < limit:
            return size
    return 150

def renum_pdb_str(pdb_str, Ls=None, renum=True, offset=1):
    """Renumber residues in a PDB string and optionally reassign chain IDs.

    Only ``ATOM`` records are kept; every other line (HETATM, TER, END, ...)
    is dropped from the result.

    Args:
        pdb_str: PDB-formatted text.
        Ls: Optional chain lengths in residues. When given, the k-th residue
            encountered (0-based) is assigned chain ``A``..``Z``, ``a``..``z``
            according to these lengths, and numbering restarts at ``offset``
            whenever the assigned chain changes. Residues beyond ``sum(Ls)``
            (or beyond 52 chains) raise ``KeyError``. When ``None``, every
            record keeps its original chain identifier.
        renum: If True, shift residue numbers so the first residue (of each
            assigned chain, when ``Ls`` is given) gets ``offset``, preserving
            the differences between consecutive input residue numbers. If
            False, the original residue numbers are written unchanged.
        offset: Number given to the first residue.

    Returns:
        The rewritten ATOM records joined by newlines (no trailing newline).
    """
    new_chain = None
    if Ls is not None:
        # Map residue index (0-based, order of appearance) -> chain letter.
        new_chain = {}
        L_init = 0
        for L, c in zip(Ls, ascii_uppercase + ascii_lowercase):
            new_chain.update({i: c for i in range(L_init, L_init + L)})
            L_init += L

    n, num, pdb_out = 0, offset, []
    resnum_ = None
    chain_ = None
    # Only look up the chain map when one was built (Ls=None used to hit a NameError here).
    new_chain_ = new_chain[0] if new_chain is not None else None
    for line in pdb_str.split("\n"):
        if line[:4] != "ATOM":
            continue
        chain = line[21:22]          # PDB column 22: chain identifier
        resnum = int(line[22:26])    # PDB columns 23-26: residue sequence number
        if resnum_ is None:
            resnum_ = resnum
        if chain_ is None:
            chain_ = chain
        # A new residue starts whenever the residue number or chain changes.
        # NOTE(review): without Ls, a chain change can make num jump backwards.
        if resnum != resnum_ or chain != chain_:
            num += resnum - resnum_
            n += 1
            resnum_, chain_ = resnum, chain
        if new_chain is not None:
            out_chain = new_chain[n]
            if out_chain != new_chain_:
                # Entering a new assigned chain: restart numbering.
                num = offset
                new_chain_ = out_chain
        else:
            # Keep the original chain character so all later PDB columns stay aligned.
            out_chain = chain
        N = num if renum else resnum
        pdb_out.append("%s%s%4i%s" % (line[:21], out_chain, N, line[26:]))
    return "\n".join(pdb_out)

## Main Function to get prediction

In [None]:
def process_sequence(jobname, sequence, num_cycle, offset_rope, device, output_dir):
    ID = jobname + "_" + get_hash(sequence)[:5]
    seqs = sequence.split(":")
    lengths = [len(s) for s in seqs]
    subbatch_size = get_subbatch_size(sum(lengths))

    with open(f"{ID}.fasta", "w") as out:
        out.write(f">{ID}\n{sequence}\n")

    # Running OmegaFold prediction
    %shell python OmegaFold/main.py --offset_rope={offset_rope} --device={device} --subbatch_size={subbatch_size} --num_cycle={num_cycle} {ID}.fasta .

    # Renumber PDB file
    pdb_file_path = os.path.join(output_dir, f"{jobname}.pdb")
    pdb_str = renum_pdb_str(open(f"{ID}.pdb", 'r').read(), Ls=lengths)
    with open(pdb_file_path, "w") as out:
        out.write(pdb_str)

    # Clear memory
    del pdb_str
    torch.cuda.empty_cache()
    gc.collect()

## Read Files and define paths

In [None]:
# Google Drive folder that receives the predicted PDB files and the checkpoint
# file; created (together with any missing parents) if it does not exist yet.
output_dir = '/content/drive/MyDrive/omegafold_predicted_structures'
os.makedirs(name=output_dir, exist_ok=True)

Checkpointing is supported, so you can resume from where you left off if you run out of memory or the Colab runtime expires.

In [None]:
# Checkpoint file to keep track of processed sequences: it holds the index of
# the most recently finished sequence.
checkpoint_file = os.path.join(output_dir, 'checkpoint.txt')

# Read the checkpoint. start_from is the index of the last finished sequence.
# It defaults to -1 (not 0) so that a fresh run, in which every index
# <= start_from is skipped, still processes sequence 0.
start_from = -1
if os.path.exists(checkpoint_file):
    with open(checkpoint_file, 'r') as fh:
        saved = file_contents = fh.read().strip()
    # Opening the checkpoint with mode 'w' truncates it before the new index is
    # written, so an interrupted write can leave it empty; treat that as "no
    # progress" instead of crashing on int('').
    if saved:
        start_from = int(saved)
    else:
        print(f"Checkpoint file {checkpoint_file} is empty; starting from the first sequence.")

The input FASTA file is assumed to have the format shown below; adapt the parsing code to your own format if needed.

\>Sequence0 TARGET=1.0 SET=test VALIDATION=False
MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTELEVLFQGPLDPNSMATYEVLC

\>Sequence1 TARGET=1.4459050863 SET=test VALIDATION=False
MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGIDGEWTYDDATKTFTVTELEVLFQGPLDPNSMATYEVLC

In [None]:
fasta_file = '/content/drive/MyDrive/input_sequences.fasta'  # Replace with your file path

# Expected record layout (one example):
#   >Sequence0 TARGET=1.0 SET=test VALIDATION=False
#   MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTELEVLFQGPLDPNSMATYEVLCEVARKLGTDDREVVLFLLNVFIPQPTLAQLIGALRALK
# Only the record id (first whitespace-delimited token of the header) is used
# as the job name; the key=value annotations after it are ignored.
records = list(SeqIO.parse(fasta_file, "fasta"))
sequence_names = [rec.id.split()[0] for rec in records]
sequences = [str(rec.seq) for rec in records]

OmegaFold uses [Rotary Position Embeddings (RoPE)](https://arxiv.org/abs/2104.09864). The `offset_rope` parameter introduces a bias into the attention computation — probably as a form of regularization, or to emphasize relative-position information.

In [None]:
# Passed to OmegaFold as --num_cycle and --offset_rope respectively
# (OmegaFold uses RoPE; offset_rope toggles the offset variant).
num_cycle, offset_rope = 4, False

## Make your predictions!

In [None]:
# Process each sequence and manage memory.
# checkpoint.txt stores the index of the last sequence that finished, and every
# index <= start_from is skipped.
for idx, (name, seq) in enumerate(zip(sequence_names, sequences)):
    if idx > start_from:
        jobname = name
        process_sequence(jobname, seq, num_cycle, offset_rope, device, output_dir)

        # Record progress only after process_sequence has returned.
        with open(checkpoint_file, 'w') as ckpt:
            ckpt.write(str(idx))

        # Optional: free up memory after every sequence.
        torch.cuda.empty_cache()
        gc.collect()