In [126]:
import os, time, signal
import sys, random, string, re
import subprocess
import shutil
from tqdm import tqdm
from typing import Optional

from Bio.PDB import PDBIO, PDBParser

import glob
import numpy as np
import pandas as pd

from colabdesign.rf.utils import get_ca
from colabdesign.rf.utils import fix_contigs, fix_partial_contigs, fix_pdb, sym_it
from colabdesign.shared.protein import pdb_to_string
from colabdesign.shared.plot import plot_pseudo_3D

In [127]:
#input_pdb = "/home/tsatler/RFdif/ClusterProteinDesign/scripts/binder_design/examples/partial_diff/5fmv_domain3-4_5_3H_13_4_2_5_24_6.pdb"

input = "TEVp-test"
out_dir= f"output/{input}/partial_diff"

### Functions

In [128]:
# dictionaries
num2aa=[
    'ALA','ARG','ASN','ASP','CYS',
    'GLN','GLU','GLY','HIS','ILE',
    'LEU','LYS','MET','PHE','PRO',
    'SER','THR','TRP','TYR','VAL',
    'UNK','MAS',
    ]

one_letter = ["A", "R", "N", "D", "C", \
             "Q", "E", "G", "H", "I", \
             "L", "K", "M", "F", "P", \
             "S", "T", "W", "Y", "V", "?", "-"]

aa2long=[
    (" N  "," CA "," C  "," O  "," CB ",  None,  None,  None,  None,  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB ","3HB ",  None,  None,  None,  None,  None,  None,  None,  None), # ala
    (" N  "," CA "," C  "," O  "," CB "," CG "," CD "," NE "," CZ "," NH1"," NH2",  None,  None,  None," H  "," HA ","1HB ","2HB ","1HG ","2HG ","1HD ","2HD "," HE ","1HH1","2HH1","1HH2","2HH2"), # arg
    (" N  "," CA "," C  "," O  "," CB "," CG "," OD1"," ND2",  None,  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB ","1HD2","2HD2",  None,  None,  None,  None,  None,  None,  None), # asn
    (" N  "," CA "," C  "," O  "," CB "," CG "," OD1"," OD2",  None,  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB ",  None,  None,  None,  None,  None,  None,  None,  None,  None), # asp
    (" N  "," CA "," C  "," O  "," CB "," SG ",  None,  None,  None,  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB "," HG ",  None,  None,  None,  None,  None,  None,  None,  None), # cys
    (" N  "," CA "," C  "," O  "," CB "," CG "," CD "," OE1"," NE2",  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB ","1HG ","2HG ","1HE2","2HE2",  None,  None,  None,  None,  None), # gln
    (" N  "," CA "," C  "," O  "," CB "," CG "," CD "," OE1"," OE2",  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB ","1HG ","2HG ",  None,  None,  None,  None,  None,  None,  None), # glu
    (" N  "," CA "," C  "," O  ",  None,  None,  None,  None,  None,  None,  None,  None,  None,  None," H  ","1HA ","2HA ",  None,  None,  None,  None,  None,  None,  None,  None,  None,  None), # gly
    (" N  "," CA "," C  "," O  "," CB "," CG "," ND1"," CD2"," CE1"," NE2",  None,  None,  None,  None," H  "," HA ","1HB ","2HB "," HD2"," HE1"," HE2",  None,  None,  None,  None,  None,  None), # his
    (" N  "," CA "," C  "," O  "," CB "," CG1"," CG2"," CD1",  None,  None,  None,  None,  None,  None," H  "," HA "," HB ","1HG2","2HG2","3HG2","1HG1","2HG1","1HD1","2HD1","3HD1",  None,  None), # ile
    (" N  "," CA "," C  "," O  "," CB "," CG "," CD1"," CD2",  None,  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB "," HG ","1HD1","2HD1","3HD1","1HD2","2HD2","3HD2",  None,  None), # leu
    (" N  "," CA "," C  "," O  "," CB "," CG "," CD "," CE "," NZ ",  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB ","1HG ","2HG ","1HD ","2HD ","1HE ","2HE ","1HZ ","2HZ ","3HZ "), # lys
    (" N  "," CA "," C  "," O  "," CB "," CG "," SD "," CE ",  None,  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB ","1HG ","2HG ","1HE ","2HE ","3HE ",  None,  None,  None,  None), # met
    (" N  "," CA "," C  "," O  "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ ",  None,  None,  None," H  "," HA ","1HB ","2HB "," HD1"," HD2"," HE1"," HE2"," HZ ",  None,  None,  None,  None), # phe
    (" N  "," CA "," C  "," O  "," CB "," CG "," CD ",  None,  None,  None,  None,  None,  None,  None," HA ","1HB ","2HB ","1HG ","2HG ","1HD ","2HD ",  None,  None,  None,  None,  None,  None), # pro
    (" N  "," CA "," C  "," O  "," CB "," OG ",  None,  None,  None,  None,  None,  None,  None,  None," H  "," HG "," HA ","1HB ","2HB ",  None,  None,  None,  None,  None,  None,  None,  None), # ser
    (" N  "," CA "," C  "," O  "," CB "," OG1"," CG2",  None,  None,  None,  None,  None,  None,  None," H  "," HG1"," HA "," HB ","1HG2","2HG2","3HG2",  None,  None,  None,  None,  None,  None), # thr
    (" N  "," CA "," C  "," O  "," CB "," CG "," CD1"," CD2"," NE1"," CE2"," CE3"," CZ2"," CZ3"," CH2"," H  "," HA ","1HB ","2HB "," HD1"," HE1"," HZ2"," HH2"," HZ3"," HE3",  None,  None,  None), # trp
    (" N  "," CA "," C  "," O  "," CB "," CG "," CD1"," CD2"," CE1"," CE2"," CZ "," OH ",  None,  None," H  "," HA ","1HB ","2HB "," HD1"," HE1"," HE2"," HD2"," HH ",  None,  None,  None,  None), # tyr
    (" N  "," CA "," C  "," O  "," CB "," CG1"," CG2",  None,  None,  None,  None,  None,  None,  None," H  "," HA "," HB ","1HG1","2HG1","3HG1","1HG2","2HG2","3HG2",  None,  None,  None,  None), # val
    (" N  "," CA "," C  "," O  "," CB ",  None,  None,  None,  None,  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB ","3HB ",  None,  None,  None,  None,  None,  None,  None,  None), # unk
    (" N  "," CA "," C  "," O  "," CB ",  None,  None,  None,  None,  None,  None,  None,  None,  None," H  "," HA ","1HB ","2HB ","3HB ",  None,  None,  None,  None,  None,  None,  None,  None), # mask
]

aa2num= {x:i for i,x in enumerate(num2aa)}

aa_321 = {a:b for a,b in zip(num2aa,one_letter)}
aa_123 = {val:key for key,val in aa_321.items()}

# Sokrypton's code
MODRES = {'MSE':'MET','MLY':'LYS','FME':'MET','HYP':'PRO',
          'TPO':'THR','CSO':'CYS','SEP':'SER','M3L':'LYS',
          'HSK':'HIS','SAC':'SER','PCA':'GLU','DAL':'ALA',
          'CME':'CYS','CSD':'CYS','OCS':'CYS','DPR':'PRO',
          'B3K':'LYS','ALY':'LYS','YCM':'CYS','MLZ':'LYS',
          '4BF':'TYR','KCX':'LYS','B3E':'GLU','B3D':'ASP',
          'HZP':'PRO','CSX':'CYS','BAL':'ALA','HIC':'HIS',
          'DBZ':'ALA','DCY':'CYS','DVA':'VAL','NLE':'LEU',
          'SMC':'CYS','AGM':'ARG','B3A':'ALA','DAS':'ASP',
          'DLY':'LYS','DSN':'SER','DTH':'THR','GL3':'GLY',
          'HY3':'PRO','LLP':'LYS','MGN':'GLN','MHS':'HIS',
          'TRQ':'TRP','B3Y':'TYR','PHI':'PHE','PTR':'TYR',
          'TYS':'TYR','IAS':'ASP','GPL':'LYS','KYN':'TRP',
          'CSD':'CYS','SEC':'CYS'}

restype_1to3 = {
    'A': 'ALA',
    'R': 'ARG',
    'N': 'ASN',
    'D': 'ASP',
    'C': 'CYS',
    'Q': 'GLN',
    'E': 'GLU',
    'G': 'GLY',
    'H': 'HIS',
    'I': 'ILE',
    'L': 'LEU',
    'K': 'LYS',
    'M': 'MET',
    'F': 'PHE',
    'P': 'PRO',
    'S': 'SER',
    'T': 'THR',
    'W': 'TRP',
    'Y': 'TYR',
    'V': 'VAL',
}

restype_3to1 = {v: k for k, v in restype_1to3.items()}

def pdb_to_string(
        pdb_file: str, 
        chains: Optional[str] = None, 
        models: Optional[list] = None
    ) -> str:
  '''read pdb file and return as string'''

  if chains is not None:
    if "," in chains: chains = chains.split(",")
    if not isinstance(chains,list): chains = [chains]
  if models is not None:
    if not isinstance(models,list): models = [models]

  modres = {**MODRES}
  lines = []
  seen = []
  model = 1

  if "\n" in pdb_file:
    old_lines = pdb_file.split("\n")
  else:
    with open(pdb_file,"rb") as f:
      old_lines = [line.decode("utf-8","ignore").rstrip() for line in f]  
  for line in old_lines:
    if line[:5] == "MODEL":
      model = int(line[5:])
    if models is None or model in models:
      if line[:6] == "MODRES":
        k = line[12:15]
        v = line[24:27]
        if k not in modres and v in restype_3to1:
          modres[k] = v
      if line[:6] == "HETATM":
        k = line[17:20]
        if k in modres:
          line = "ATOM  "+line[6:17]+modres[k]+line[20:]
      if line[:4] == "ATOM":
        chain = line[21:22]
        if chains is None or chain in chains:
          atom = line[12:12+4].strip()
          resi = line[17:17+3]
          resn = line[22:22+5].strip()
          if resn[-1].isalpha(): # alternative atom
            resn = resn[:-1]
            line = line[:26]+" "+line[27:]
          key = f"{model}_{chain}_{resn}_{resi}_{atom}"
          if key not in seen: # skip alternative placements
            lines.append(line)
            seen.append(key)
      if line[:5] == "MODEL" or line[:3] == "TER" or line[:6] == "ENDMDL":
        lines.append(line)
  return "\n".join(lines)

# RFdif
def parse_pdb(filename, **kwargs):
  """extract xyz coords for all heavy atoms"""
  lines = open(filename, "r").readlines()
  return parse_pdb_lines(lines, **kwargs)


def get_pdb_seq(lines) -> str:
    """get the sequence from a pdb file"""
    seq = ""
    for line in lines:
        if line[:4] == "ATOM":
            if line[13:15] == "CA":
                seq += line[17:20]
    return seq

def parse_pdb_lines(lines, parse_hetatom=False, ignore_het_h=True):
    # indices of residues observed in the structure
    res, pdb_idx = [],[]
    for l in lines:
      if l[:4] == "ATOM" and l[12:16].strip() == "CA":
        res.append((l[22:26], l[17:20]))
        # chain letter, res num
        pdb_idx.append((l[21:22].strip(), int(l[22:26].strip())))
    seq = [aa2num[r[1]] if r[1] in aa2num.keys() else 20 for r in res]

    # 4 BB + up to 10 SC atoms
    xyz = np.full((len(res), 14, 3), np.nan, dtype=np.float32)
    for l in lines:
        if l[:4] != "ATOM":
            continue
        chain, resNo, atom, aa = (
            l[21:22],
            int(l[22:26]),
            " " + l[12:16].strip().ljust(3),
            l[17:20],
        )
        if (chain,resNo) in pdb_idx:
          idx = pdb_idx.index((chain, resNo))
          # for i_atm, tgtatm in enumerate(util.aa2long[util.aa2num[aa]]):
          for i_atm, tgtatm in enumerate(
              aa2long[aa2num[aa]][:14]
          ):  # Nate's proposed change
              if (
                  tgtatm is not None and tgtatm.strip() == atom.strip()
              ):  # ignore whitespace
                  xyz[idx, i_atm, :] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
                  break

    # save atom mask
    mask = np.logical_not(np.isnan(xyz[..., 0]))
    xyz[np.isnan(xyz[..., 0])] = 0.0

    # remove duplicated (chain, resi)
    new_idx = []
    i_unique = []
    for i, idx in enumerate(pdb_idx):
        if idx not in new_idx:
            new_idx.append(idx)
            i_unique.append(i)

    pdb_idx = new_idx
    xyz = xyz[i_unique]
    mask = mask[i_unique]

    seq = np.array(seq)[i_unique]

    out = {
        "xyz": xyz,  # cartesian coordinates, [Lx14]
        "mask": mask,  # mask showing which atoms are present in the PDB file, [Lx14]
        "idx": np.array(
            [i[1] for i in pdb_idx]
        ),  # residue numbers in the PDB file, [L]
        "seq": np.array(seq),  # amino acid sequence, [L]
        "pdb_idx": pdb_idx,  # list of (chain letter, residue number) in the pdb file, [L]
    }

    # heteroatoms (ligands, etc)
    if parse_hetatom:
        xyz_het, info_het = [], []
        for l in lines:
            if l[:6] == "HETATM" and not (ignore_het_h and l[77] == "H"):
                info_het.append(
                    dict(
                        idx=int(l[7:11]),
                        atom_id=l[12:16],
                        atom_type=l[77],
                        name=l[16:20],
                    )
                )
                xyz_het.append([float(l[30:38]), float(l[38:46]), float(l[46:54])])

        out["xyz_het"] = np.array(xyz_het)
        out["info_het"] = info_het

    return out


def parse_pdb_sequences(file_path):
  sequences = {}
  with open(file_path, 'r') as pdb_file:
      current_chain_id = None
      current_sequence = ''
      for line in pdb_file:
          if line[:4] == "ATOM":
            if line[13:15] == "CA":
                #seq += line[17:20]
                chain_id = line[21]
                if chain_id != current_chain_id:
                    if current_chain_id is not None:
                        sequences[current_chain_id] = current_sequence
                    current_chain_id = chain_id
                    current_sequence = ''
                amino_acid = line[17:20].strip()
                aa = restype_3to1[amino_acid] if amino_acid in restype_3to1 else "X"
                current_sequence += aa
      # Add the last chain sequence
      if current_chain_id is not None:
          sequences[current_chain_id] = current_sequence
  return sequences

def swap_chains(pdb_file):
    with open(pdb_file, 'r') as f:
        lines = f.readlines()

    modified_lines_before = []
    modified_lines_A = []
    modified_lines_B = []
    modified_lines_after = []

    for line in lines:
        if line.startswith('ATOM') or line.startswith('HETATM') or line.startswith('TER'):
            chain_id = line[21]
            if chain_id == 'A':
                modified_lines_B.append(line[:21] + 'B' + line[22:])
            elif chain_id == 'B':
                modified_lines_A.append(line[:21] + 'A' + line[22:])
            else:
               return ValueError('Chain ID is not A or B')
        elif line.startswith('MODEL'):
            modified_lines_before.append(line)
        else:
            modified_lines_after.append(line)
    
    return modified_lines_before + modified_lines_A + modified_lines_B + modified_lines_after

    # with open('modified_' + pdb_file, 'w') as f:
    #     f.writelines(modified_lines)


def swap_chain_names(pdb_file, renames):
    """
    Swap chain names in a PDB file according to the provided mapping.
    
    Args:
    - pdb_file (str): Path to the input PDB file.
    - renames (dict): Mapping of old chain names to new chain names.
    
    Returns:
    - str: Path to the modified PDB file.
    """
    # Load the PDB structure
    parser = PDBParser()
    structure = parser.get_structure('structure', pdb_file)

    # First rename
    for model in structure:
        for chain in model:
            old_name = chain.id
            new_name = renames.get(old_name)
            if new_name:
                chain.id = new_name
    
    # Second rename
    for model in structure:
        for chain in model:
            old_name = chain.id
            new_name = renames.get(old_name)
            if new_name:
                chain.id = new_name

    # Save modified PDB file
    output_file = pdb_file.replace('.pdb', '_modif.pdb')
    io = PDBIO()
    io.set_structure(structure)
    io.save(output_file)

    return output_file





### Use partial diffusion to generate more and better scaffolds
Currently partial diffusion only works if we have binders as chain A...

### Get best initial models from first notebook

In [129]:
best_pdbs = pd.read_csv(f"output/{input}/af2_best.csv")
best_pdbs

Unnamed: 0,mpnn,plddt,i_ptm,i_pae,rmsd,seq,model_path,binder-rmsd,input_pdb
0,0.820425,0.939712,0.892487,5.746988,2.045719,PRDYNPISSTICHLTNESDGHTTSLYGIGFGPFIITNKHLFRRNNG...,output/TEVp-test/mpnn_af2/-test_1/af2/-test_1_...,,output/TEVp-test/rf_dock/TEVp-test_1/-test_1_0...
1,0.817176,0.946283,0.881043,5.657455,2.209037,PRDYNPISSTICHLTNESDGHTTSLYGIGFGPFIITNKHLFRRNNG...,output/TEVp-test/mpnn_af2/-test_1/af2/-test_1_...,,output/TEVp-test/rf_dock/TEVp-test_1/-test_1_0...


###

In [130]:
filtered = best_pdbs[(best_pdbs["plddt"] > 0.9) & (best_pdbs["rmsd"] < 3)]
filtered

Unnamed: 0,mpnn,plddt,i_ptm,i_pae,rmsd,seq,model_path,binder-rmsd,input_pdb
0,0.820425,0.939712,0.892487,5.746988,2.045719,PRDYNPISSTICHLTNESDGHTTSLYGIGFGPFIITNKHLFRRNNG...,output/TEVp-test/mpnn_af2/-test_1/af2/-test_1_...,,output/TEVp-test/rf_dock/TEVp-test_1/-test_1_0...
1,0.817176,0.946283,0.881043,5.657455,2.209037,PRDYNPISSTICHLTNESDGHTTSLYGIGFGPFIITNKHLFRRNNG...,output/TEVp-test/mpnn_af2/-test_1/af2/-test_1_...,,output/TEVp-test/rf_dock/TEVp-test_1/-test_1_0...


## Prepare models for diffusion

In [131]:
# Copy best models to a new directory
initial_dir = f"{out_dir}/model_picks"
os.makedirs(initial_dir, exist_ok=True)

for i, row in filtered.iterrows():
    pdb_file = row["model_path"]
    shutil.copy(pdb_file, initial_dir)

# Swapping chain names
swapped_chains_dir = f"{out_dir}/model_picks/swapped_chains"
os.makedirs(swapped_chains_dir, exist_ok=True)

for pdb_file in glob.glob(f"{initial_dir}/*.pdb"):
    pdb_string = swap_chains(pdb_file)
    with open(f"{swapped_chains_dir}/{os.path.basename(pdb_file)}", 'w') as f:
        f.writelines(pdb_string)

# Random hex strings
pdbs = glob.glob(swapped_chains_dir + "/*.pdb")
array_len = len(pdbs) 
random_names = []
for pdb in pdbs:
    random_names.append(''.join(random.choices(string.ascii_uppercase + string.digits, k=4)))
        

#### Parameters

In [136]:
# RFdiffusion
number_of_diffusions = 5 # per step
diffusion_steps = [15,20,25,30,35,40] # different partial diffusions steps
diffusion_steps = [20,25,30] # different partial diffusions steps

# ProteinMPNN / AF2
mpnn_seqs_per_protein = 16
sampling_temp = 0.1

num_recycles = 3

# Slurm
num_separate_array_jobs = 1
batch_size = array_len // num_separate_array_jobs + 1

array_limit_per_job = 2

print(f"🚀 Initiating array job for {array_len} models with {number_of_diffusions * len(diffusion_steps)} diffusions each...")
print(f"Total diffusions to be generated: {array_len * number_of_diffusions * len(diffusion_steps)}")
print(f"Adjusting batch size to {batch_size}")
print(f"This will execute {num_separate_array_jobs} separate array jobs, each with an array limit of {array_limit_per_job}. Thus, totaling {num_separate_array_jobs * array_limit_per_job} jobs.")


🚀 Initiating array job for 2 models with 15 diffusions each...
Total diffusions to be generated: 30
Adjusting batch size to 3
This will execute 1 separate array jobs, each with an array limit of 2. Thus, totaling 2 jobs.


#### Split model into batches and prepare inputs

In [137]:
batches = [pdbs[i:i + batch_size] for i in range(0, len(pdbs), batch_size)]

for i, batch in enumerate(batches):

    print(f"Batch {i+1}: {len(batch)} models")
    # For each batch prepare a separate folder and in there a folder for inputs and outputs
    diffusion_out_dir = f"{out_dir}/diffusions/batch_{i+1}"
    batch_dir = f"{diffusion_out_dir}/inputs"
    os.makedirs(batch_dir, exist_ok=True)


    # for each pdb within the batch, save commands to inputs folder
    for i, pdb in enumerate(batch):
        pdb_seqs = parse_pdb_sequences(pdb)
        binder_len = len(pdb_seqs["A"])
        target_len = len(pdb_seqs["B"])
        pdb_basename = pdb.split("/")[-1].split(".")[0]
        random_name = random_names[i]

        pdb_command_file = f"{batch_dir}/{random_name}_{pdb_basename}.sh"

        protein_dir = f"{diffusion_out_dir}/{random_name}_{pdb_basename}"
        # os.makedirs(protein_dir, exist_ok=True)

        all_cmds = []
        for steps in diffusion_steps:

            cmd = f"""source /home/tsatler/anaconda3/etc/profile.d/conda.sh
conda activate SE3nv

# Run RFdiffusion
echo "Running RFDiffusion for {pdb_basename} with {steps} steps"

mkdir -p {protein_dir}

python /home/tsatler/RFdif/RFdiffusion/scripts/run_inference.py \
inference.output_prefix={protein_dir}/{random_name}_pd_{steps}_{pdb_basename} \
inference.input_pdb={pdb} \
'contigmap.contigs=[{binder_len}-{binder_len}/0 B1-{target_len}]' \
inference.num_designs={number_of_diffusions} \
diffuser.partial_T={steps}

# delete traj and trp files
if [ -e "{protein_dir}/{random_name}_pd_{steps}_{pdb_basename}*.trb" ]; then
    rm {protein_dir}/{random_name}_pd_{steps}_{pdb_basename}*.trb # delete trb files
fi
if [ -d "{protein_dir}/traj" ]; then
    rm -r "{protein_dir}/traj" # delete trajectories
fi

source /home/tsatler/anaconda3/etc/profile.d/conda.sh
conda activate colabthread

# Run ProteinMPNN and AF2
echo Running ProteinMPNN and AF2

input_files=("{protein_dir}/{random_name}_pd_{steps}_{pdb_basename}"*.pdb)
script="helper_scripts/partial_af2.py"

mkdir -p {protein_dir}/mpnn_af2

for ((i=0; i<${{#input_files[@]}}; i++)); do
  pdb_file=${{input_files[$i]}}
  echo "Running $pdb_file ..."
  af_out={protein_dir}/mpnn_af2/{steps}

  python $script $pdb_file $af_out B A --sampling_temp {sampling_temp} \
  --num_recycles {num_recycles} --num_seqs {mpnn_seqs_per_protein} --num_filt_seq {mpnn_seqs_per_protein} \
  --results_dataframe {out_dir}/diffusions --save_best_only
done
"""
            all_cmds.append(cmd)
        # print(all_cmds)
        with open(pdb_command_file, 'w') as f:
            f.writelines(all_cmds)


Batch 1: 2 models


#### Prepare bash commands for each batch

In [138]:
for i, batch in enumerate(batches):
    num_models = len(batch)
    print(f"Batch {i+1}: {num_models} models")
    batch_script = f"{out_dir}/diffusions/batch_{i+1}_script.sh"

    batch = f"""#!/bin/bash
#SBATCH --partition=gpu
#SBATCH --gres=gpu:A40:1
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=2
#SBATCH --array=0-{num_models-1}%{array_limit_per_job}

set -e

input_commands=({out_dir}/diffusions/batch_{i+1}/inputs/*.sh)
command_file=${{input_commands[$SLURM_ARRAY_TASK_ID]}}
# mapfile -t commands < "$command_file"

echo "Running command file: $command_file"                   
bash $command_file


"""
    
    with open(batch_script, 'w') as f:
        f.writelines(batch)

Batch 1: 2 models


#### Run bash scripts

In [139]:
print("🚀 Launching array bathc job...")
for i, batch in enumerate(batches):
    batch_script = f"{out_dir}/diffusions/batch_{i+1}_script.sh"
    os.system(f"sbatch {batch_script}")

🚀 Launching array bathc job...
Submitted batch job 463650


### Filter results and store best for cycling / metrics

In [5]:
store_path = f"output/best_cd5"
dataframes = glob.glob("output/2ja4*/af2/af2_best.csv")
print(dataframes)
len(dataframes)

['output/2ja4_cd5_2-8sk7_2H3E_9_9_2_mod/af2/af2_best.csv', 'output/2ja4_cd5_43_3H_17_8_3_mod/af2/af2_best.csv', 'output/2ja4_cd5_30_3H_16_0_4_mod/af2/af2_best.csv', 'output/2ja4_cd5_28_3H_15_3_3_mod/af2/af2_best.csv', 'output/2ja4_cd5_20_2H4E_14_4_2_mod/af2/af2_best.csv', 'output/2ja4_cd5_43_lcb_55_8_3_mod/af2/af2_best.csv', 'output/2ja4_cd5_23_4H_19_6_2_mod/af2/af2_best.csv', 'output/2ja4_cd5_4-7rdh_3H_14_3_2_mod/af2/af2_best.csv', 'output/2ja4_cd5_21_2H4E_17_5_3_mod/af2/af2_best.csv', 'output/2ja4_cd5_1-lcb3_3H_19_3_0_mod/af2/af2_best.csv', 'output/2ja4_cd5_25_4H_20_4_1_mod/af2/af2_best.csv', 'output/2ja4_cd5_16_3H_19_2_1_mod/af2/af2_best.csv', 'output/2ja4_cd5_14_3H_20_1_2_mod/af2/af2_best.csv', 'output/2ja4_cd5_42_3H_13_8_4_mod/af2/af2_best.csv', 'output/2ja4_cd5_15_3H_11_0_0_mod/af2/af2_best.csv', 'output/2ja4_cd5_29_3H_20_2_4_mod/af2/af2_best.csv', 'output/2ja4_cd5_27_3H_20_0_0_mod/af2/af2_best.csv', 'output/2ja4_cd5_24_4H_13_6_0_mod/af2/af2_best.csv']


18

In [6]:
# merge all dataframes in dataframes list
for df_path in dataframes:
    df = pd.read_csv(df_path)
    if df_path == dataframes[0]:
        merged = df
    else:
        merged = pd.concat([merged, df])

merged

Unnamed: 0,mpnn,plddt,i_ptm,i_pae,rmsd,seq,model_path,binder-rmsd,input_pdb
0,1.107793,0.904767,0.731257,7.115674,2.100018,AFQPKVQSRLVGGSSICEGTVEVRQGAQWAALCDSSSARSSLRWEE...,output/2ja4_cd5_2-8sk7_2H3E_9_9_2_mod/af2/af2/...,,output/2ja4_cd5_2-8sk7_2H3E_9_9_2_mod/0ZZ6_pd_...
1,1.102261,0.897332,0.675108,8.255573,2.515133,AFQPKVQSRLVGGSSICEGTVEVRQGAQWAALCDSSSARSSLRWEE...,output/2ja4_cd5_2-8sk7_2H3E_9_9_2_mod/af2/af2/...,,output/2ja4_cd5_2-8sk7_2H3E_9_9_2_mod/0ZZ6_pd_...
2,1.080848,0.898672,0.719799,7.383804,2.020807,AFQPKVQSRLVGGSSICEGTVEVRQGAQWAALCDSSSARSSLRWEE...,output/2ja4_cd5_2-8sk7_2H3E_9_9_2_mod/af2/af2/...,,output/2ja4_cd5_2-8sk7_2H3E_9_9_2_mod/0ZZ6_pd_...
3,1.074102,0.886665,0.653436,8.271711,1.190259,AFQPKVQSRLVGGSSICEGTVEVRQGAQWAALCDSSSARSSLRWEE...,output/2ja4_cd5_2-8sk7_2H3E_9_9_2_mod/af2/af2/...,,output/2ja4_cd5_2-8sk7_2H3E_9_9_2_mod/0ZZ6_pd_...
4,1.074048,0.902687,0.718427,7.230995,1.226888,AFQPKVQSRLVGGSSICEGTVEVRQGAQWAALCDSSSARSSLRWEE...,output/2ja4_cd5_2-8sk7_2H3E_9_9_2_mod/af2/af2/...,,output/2ja4_cd5_2-8sk7_2H3E_9_9_2_mod/0ZZ6_pd_...
...,...,...,...,...,...,...,...,...,...
1598,1.208489,0.784615,0.523055,10.516017,2.563221,AFQPKVQSRLVGGSSICEGTVEVRQGAQWAALCDSSSARSSLRWEE...,output/2ja4_cd5_24_4H_13_6_0_mod/af2/af2/YLEB_...,,output/2ja4_cd5_24_4H_13_6_0_mod/YLEB_pd_20_8.pdb
1599,1.181706,0.782212,0.477208,10.761100,2.627660,AFQPKVQSRLVGGSSICEGTVEVRQGAQWAALCDSSSARSSLRWEE...,output/2ja4_cd5_24_4H_13_6_0_mod/af2/af2/YLEB_...,,output/2ja4_cd5_24_4H_13_6_0_mod/YLEB_pd_20_8.pdb
1600,1.175341,0.776156,0.543966,9.265256,2.996432,AFQPKVQSRLVGGSSICEGTVEVRQGAQWAALCDSSSARSSLRWEE...,output/2ja4_cd5_24_4H_13_6_0_mod/af2/af2/YLEB_...,,output/2ja4_cd5_24_4H_13_6_0_mod/YLEB_pd_20_8.pdb
1601,1.172967,0.752545,0.418754,12.291875,2.388027,AFQPKVQSRLVGGSSICEGTVEVRQGAQWAALCDSSSARSSLRWEE...,output/2ja4_cd5_24_4H_13_6_0_mod/af2/af2/YLEB_...,,output/2ja4_cd5_24_4H_13_6_0_mod/YLEB_pd_20_8.pdb


In [7]:
merged.describe()

Unnamed: 0,mpnn,plddt,i_ptm,i_pae,rmsd,binder-rmsd
count,49687.0,49687.0,49687.0,49687.0,49687.0,0.0
mean,1.137604,0.841529,0.606492,8.86843,1.893695,
std,0.081687,0.074829,0.171629,3.71233,0.575526,
min,0.835833,0.439428,0.069751,4.141545,0.408412,
25%,1.085944,0.803511,0.514278,6.165107,1.442757,
50%,1.13945,0.860757,0.655632,7.629819,1.870749,
75%,1.192706,0.896616,0.738203,10.437557,2.353969,
max,1.477788,0.958538,0.86632,25.382884,2.999982,


In [24]:
filtered = merged[(merged["plddt"] > 0.9)&
                  (merged["i_pae"] < 5.5)&
                  (merged["rmsd"] < 2)
                  ]
filtered = filtered.sort_values("i_pae", ascending=True)
os.makedirs(store_path, exist_ok=True)
filtered.to_csv(store_path + "/cd5_bestpartial.csv", index=False)
filtered

Unnamed: 0,mpnn,plddt,i_ptm,i_pae,rmsd,seq,model_path,binder-rmsd,input_pdb
696,1.087203,0.949473,0.857809,4.141545,1.288186,AFQPKVQSRLVGGSSICEGTVEVRQGAQWAALCDSSSARSSLRWEE...,output/2ja4_cd5_27_3H_20_0_0_mod/af2/af2/42ML_...,,output/2ja4_cd5_27_3H_20_0_0_mod/42ML_pd_20_22...
2084,1.135136,0.957026,0.846879,4.180467,1.486491,AFQPKVQSRLVGGSSICEGTVEVRQGAQWAALCDSSSARSSLRWEE...,output/2ja4_cd5_43_3H_17_8_3_mod/af2/af2/QY63_...,,output/2ja4_cd5_43_3H_17_8_3_mod/QY63_pd_25_37...
750,1.023279,0.952732,0.838807,4.238931,1.343033,AFQPKVQSRLVGGSSICEGTVEVRQGAQWAALCDSSSARSSLRWEE...,output/2ja4_cd5_43_3H_17_8_3_mod/af2/af2/OC25_...,,output/2ja4_cd5_43_3H_17_8_3_mod/OC25_pd_20_33...
888,1.045610,0.954768,0.853614,4.253094,1.230778,AFQPKVQSRLVGGSSICEGTVEVRQGAQWAALCDSSSARSSLRWEE...,output/2ja4_cd5_43_3H_17_8_3_mod/af2/af2/OC25_...,,output/2ja4_cd5_43_3H_17_8_3_mod/OC25_pd_20_37...
2257,1.049232,0.949300,0.849913,4.254208,1.212411,AFQPKVQSRLVGGSSICEGTVEVRQGAQWAALCDSSSARSSLRWEE...,output/2ja4_cd5_43_3H_17_8_3_mod/af2/af2/QY63_...,,output/2ja4_cd5_43_3H_17_8_3_mod/QY63_pd_25_43...
...,...,...,...,...,...,...,...,...,...
527,1.147216,0.909553,0.780373,5.499053,1.946588,AFQPKVQSRLVGGSSICEGTVEVRQGAQWAALCDSSSARSSLRWEE...,output/2ja4_cd5_42_3H_13_8_4_mod/af2/af2/60ZU_...,,output/2ja4_cd5_42_3H_13_8_4_mod/60ZU_pd_25_18...
2772,1.140657,0.923014,0.771877,5.499333,1.271049,AFQPKVQSRLVGGSSICEGTVEVRQGAQWAALCDSSSARSSLRWEE...,output/2ja4_cd5_1-lcb3_3H_19_3_0_mod/af2/af2/O...,,output/2ja4_cd5_1-lcb3_3H_19_3_0_mod/OZT9_pd_2...
2329,1.107046,0.931662,0.780930,5.499363,1.493340,AFQPKVQSRLVGGSSICEGTVEVRQGAQWAALCDSSSARSSLRWEE...,output/2ja4_cd5_29_3H_20_2_4_mod/af2/af2/V1I3_...,,output/2ja4_cd5_29_3H_20_2_4_mod/V1I3_pd_25_26...
1649,1.238068,0.904065,0.766409,5.499387,1.067629,AFQPKVQSRLVGGSSICEGTVEVRQGAQWAALCDSSSARSSLRWEE...,output/2ja4_cd5_27_3H_20_0_0_mod/af2/af2/42ML_...,,output/2ja4_cd5_27_3H_20_0_0_mod/42ML_pd_20_4.pdb


In [None]:
.csv", index=False)
filtered

In [19]:
grouped = filtered.groupby("input_pdb")
len(grouped)

510

In [36]:
out= f"{store_path}/5fmv_domain4_2_lcb3"
os.makedirs(out, exist_ok=True)

for i, row in filt2.iterrows():
    model = row["model_path"]
    shutil.copy(model, f"{out}/{model.split('/')[-1]}")

In [40]:
df3 = pd.read_csv(dataframes[2])

filt3 = df3[
    (df3["plddt"] > 0.9) & 
    (df3["rmsd"] < 1.5) &
    (df3["i_pae"] < 8)
]

filt3 = filt3.drop_duplicates(["input_pdb"])

filt3

Unnamed: 0,mpnn,plddt,i_ptm,i_pae,rmsd,seq,model_path,binder-rmsd,input_pdb
21,1.115748,0.905856,0.834014,7.466749,1.497733,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/af...,,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/4D...
32,1.244384,0.90851,0.855917,6.967074,1.316984,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/af...,,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/4D...
105,1.208348,0.907436,0.860606,6.892799,1.45069,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/af...,,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/4D...
195,1.124334,0.925211,0.791968,7.97795,1.444799,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/af...,,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/4D...
327,1.139189,0.901043,0.812098,7.809068,1.358346,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/af...,,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/4D...
340,1.299763,0.901468,0.811883,7.606027,1.202947,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/af...,,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/4D...
429,1.14625,0.905463,0.824724,7.656832,1.399906,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/af...,,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/4D...
524,1.16879,0.901929,0.848107,7.318871,1.406448,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/af...,,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/4D...
547,1.266944,0.902509,0.848655,7.24848,0.556795,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/af...,,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/4D...
558,1.220614,0.900082,0.821461,7.664502,1.435184,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/af...,,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/4D...


In [41]:
out= f"{store_path}/5fmv_domain3-4_5_3H"
os.makedirs(out, exist_ok=True)

for i, row in filt3.iterrows():
    model = row["model_path"]
    shutil.copy(model, f"{out}/{model.split('/')[-1]}")

In [45]:
df4 = pd.read_csv(dataframes[3])

filt4 = df4[
    (df4["plddt"] > 0.9) & 
    (df4["rmsd"] < 1.5) &
    (df4["i_pae"] < 7)
]

filt4 = filt4.drop_duplicates(["input_pdb"])

filt4

Unnamed: 0,mpnn,plddt,i_ptm,i_pae,rmsd,seq,model_path,binder-rmsd,input_pdb
241,0.967114,0.919704,0.848756,6.872086,1.363722,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/cd45binder_otherscaff_18_unrelaxed_rank...,,output/cd45binder_otherscaff_18_unrelaxed_rank...


In [47]:
# merge all dataframes in dataframes list
for df_path in dataframes:
    df = pd.read_csv(df_path)
    if df_path == dataframes[0]:
        merged = df
    else:
        merged = pd.concat([merged, df])

merged

Unnamed: 0,mpnn,plddt,i_ptm,i_pae,rmsd,seq,model_path,binder-rmsd,input_pdb
0,1.277649,0.631686,0.551593,13.829327,1.747557,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/5fmv_domain3-4_5_3H_13_4_2_5_24_6_mod/a...,,output/5fmv_domain3-4_5_3H_13_4_2_5_24_6_mod/6...
1,1.275846,0.683484,0.612635,12.576947,1.761258,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/5fmv_domain3-4_5_3H_13_4_2_5_24_6_mod/a...,,output/5fmv_domain3-4_5_3H_13_4_2_5_24_6_mod/6...
2,1.221064,0.799177,0.775171,9.504641,2.358474,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/5fmv_domain3-4_5_3H_13_4_2_5_24_6_mod/a...,,output/5fmv_domain3-4_5_3H_13_4_2_5_24_6_mod/6...
3,1.219265,0.765851,0.719844,10.546119,2.316575,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/5fmv_domain3-4_5_3H_13_4_2_5_24_6_mod/a...,,output/5fmv_domain3-4_5_3H_13_4_2_5_24_6_mod/6...
4,1.216010,0.757928,0.706146,10.792172,2.208929,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/5fmv_domain3-4_5_3H_13_4_2_5_24_6_mod/a...,,output/5fmv_domain3-4_5_3H_13_4_2_5_24_6_mod/6...
...,...,...,...,...,...,...,...,...,...
2328,1.091968,0.823915,0.530615,11.541245,2.955819,PSQVWNMTVSMTSDNSMHVKCRPPRDRNGPHERYHLEVEAGNTLVR...,output/5fmv_domain4_2_2H3E_13_7_9_10_4_7_mod/a...,,output/5fmv_domain4_2_2H3E_13_7_9_10_4_7_mod/H...
2329,1.089464,0.803525,0.543450,10.610820,1.362002,PSQVWNMTVSMTSDNSMHVKCRPPRDRNGPHERYHLEVEAGNTLVR...,output/5fmv_domain4_2_2H3E_13_7_9_10_4_7_mod/a...,,output/5fmv_domain4_2_2H3E_13_7_9_10_4_7_mod/H...
2330,1.087865,0.762951,0.300635,16.463484,1.910863,PSQVWNMTVSMTSDNSMHVKCRPPRDRNGPHERYHLEVEAGNTLVR...,output/5fmv_domain4_2_2H3E_13_7_9_10_4_7_mod/a...,,output/5fmv_domain4_2_2H3E_13_7_9_10_4_7_mod/H...
2331,1.120579,0.611589,0.541237,13.655488,2.748455,PSQVWNMTVSMTSDNSMHVKCRPPRDRNGPHERYHLEVEAGNTLVR...,output/5fmv_domain4_2_2H3E_13_7_9_10_4_7_mod/a...,,output/5fmv_domain4_2_2H3E_13_7_9_10_4_7_mod/H...


In [51]:
# filter merged
filtered = merged[
    (merged["plddt"] > 0.9) & 
    (merged["rmsd"] < 1.5) &
    (merged["i_pae"] < 7)
]


filtered = filtered.drop_duplicates(["input_pdb"])
filtered

Unnamed: 0,mpnn,plddt,i_ptm,i_pae,rmsd,seq,model_path,binder-rmsd,input_pdb
1008,1.125475,0.914588,0.860514,6.877782,1.330983,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/5fmv_domain3-4_5_3H_13_4_2_5_24_6_mod/a...,,output/5fmv_domain3-4_5_3H_13_4_2_5_24_6_mod/L...
32,1.244384,0.908510,0.855917,6.967074,1.316984,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/af...,,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/4D...
105,1.208348,0.907436,0.860606,6.892799,1.450690,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/af...,,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/4D...
201,1.109776,0.928341,0.856327,6.889598,1.470566,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/af...,,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/4D...
716,1.188016,0.915262,0.860234,6.908476,1.260212,FGSPGEPQIIFCRSEAAHQGVITWNPPQRSFHNFTLCYIKETEKDC...,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/af...,,output/5fmv_domain3-4_5_3H_13_4_2_1_5_5_mod/4D...
...,...,...,...,...,...,...,...,...,...
2156,1.098406,0.913968,0.846424,5.353226,1.383325,PSQVWNMTVSMTSDNSMHVKCRPPRDRNGPHERYHLEVEAGNTLVR...,output/5fmv_domain4_2_2H3E_13_7_9_10_4_7_mod/a...,,output/5fmv_domain4_2_2H3E_13_7_9_10_4_7_mod/H...
2194,1.004237,0.910334,0.752790,6.977907,1.475668,PSQVWNMTVSMTSDNSMHVKCRPPRDRNGPHERYHLEVEAGNTLVR...,output/5fmv_domain4_2_2H3E_13_7_9_10_4_7_mod/a...,,output/5fmv_domain4_2_2H3E_13_7_9_10_4_7_mod/H...
2229,1.291051,0.928157,0.795200,5.898710,1.178445,PSQVWNMTVSMTSDNSMHVKCRPPRDRNGPHERYHLEVEAGNTLVR...,output/5fmv_domain4_2_2H3E_13_7_9_10_4_7_mod/a...,,output/5fmv_domain4_2_2H3E_13_7_9_10_4_7_mod/H...
2262,1.103472,0.915764,0.835721,5.519806,1.165705,PSQVWNMTVSMTSDNSMHVKCRPPRDRNGPHERYHLEVEAGNTLVR...,output/5fmv_domain4_2_2H3E_13_7_9_10_4_7_mod/a...,,output/5fmv_domain4_2_2H3E_13_7_9_10_4_7_mod/H...


In [52]:
# save dataframe and copy all models to folder /home/tsatler/RFdif/ClusterProteinDesign/scripts/partial_diff/output/best_cd45/all
out= f"{store_path}/all"
os.makedirs(out, exist_ok=True)

for i, row in filtered.iterrows():
    model = row["model_path"]
    shutil.copy(model, f"{out}/{model.split('/')[-1]}")

In [54]:
filtered.to_csv(f"{store_path}/all/all.csv", index=False)