In [3]:
from pyrosetta import *
from pyrosetta.rosetta.core.select.residue_selector import InterGroupInterfaceByVectorSelector, ChainSelector, OrResidueSelector, NotResidueSelector
import pandas as pd
import argparse
import os
from geometry_functions import *
init()

PyRosetta-4 2023 [Rosetta PyRosetta4.conda.linux.cxx11thread.serialization.CentOS.python311.Release 2023.27+release.e3ce6ea9faf661ae8fa769511e2a9b8596417e58 2023-07-07T12:00:46] retrieved from: http://www.pyrosetta.org
(C) Copyright Rosetta Commons Member Institutions. Created in JHU by Sergey Lyskov and PyRosetta Team.
core.init: Checking for fconfig files in pwd and ./rosetta/flags
core.init: Rosetta version: PyRosetta4.conda.linux.cxx11thread.serialization.CentOS.python311.Release r353 2023.27+release.e3ce6ea e3ce6ea9faf661ae8fa769511e2a9b8596417e58 http://www.pyrosetta.org 2023-07-07T12:00:46
core.init: command: PyRosetta -ex1 -ex2aro -database /home/linamp/miniconda3/envs/myenv_gl/envs/rosetta_env/lib/python3.11/site-packages/pyrosetta/database
basic.random.init_random_generator: 'RNG device' seed mode, using '/dev/urandom', seed=-1901301938 seed_offset=0 real_seed=-1901301938 thread_index=0
basic.random.init_random_generator: RandomGenerator:init: Normal mode, seed=-1901301938 RG

In [15]:
def group_selector (chains):
    sel = ChainSelector(chains[0])
    for ch in chains[1:]:
        sel = OrResidueSelector(sel, ChainSelector(ch))
    return sel

def run_design_positions_generator (interface, fixed_chains, CB_DISTANCE, EXTENDED_DESIGN, MAX_DISTANCE, PDB_DIR): 
    if EXTENDED_DESIGN and MAX_DISTANCE is None:
        raise ValueError("--extended_design requires MAX_DISTANCE to be set (default=12 Å)")

    grp1_str, grp2_str = interface.split('_', 1)
    GROUP1 = list(grp1_str)
    GROUP2 = list(grp2_str)
    FIXED_CHAINS = [x.strip() for x in fixed_chains.split(',')]

    cutoff = MAX_DISTANCE if EXTENDED_DESIGN else CB_DISTANCE
    mode   = "extended interface (C alpha distance)" if EXTENDED_DESIGN else "interface (C beta distance)"
    print(f"Mode = {mode}, cutoff = {cutoff:.1f} Å")
    print(f"Designing interface between {GROUP1} and {GROUP2}, keeping {FIXED_CHAINS} fixed")

    if not EXTENDED_DESIGN:
        print ("Designing residues at interface only")
        files = [f.path for f in os.scandir(PDB_DIR) if f.is_file() and f.path.endswith('.pdb')]
        for file in files:
            print (f'Processing {file}')
            pose = pose_from_pdb(file)
            start_pose = Pose()
            start_pose.assign(pose)

            sel1 = group_selector(GROUP1)
            sel2 = group_selector(GROUP2)

            interface_selector = InterGroupInterfaceByVectorSelector(sel1, sel2)
            interface_selector.cb_dist_cut(CB_DISTANCE)
            not_interface_selector = NotResidueSelector(interface_selector)
            interface_residues = interface_selector.apply(pose) # Boolean vector

            # Map the chain to the index of where it begins
            chain_map = {}
            for i in range(1, pose.num_chains()+1):
                ch = pose.pdb_info().chain(pose.chain_begin(i))
                chain_map[ch] = pose.chain_begin(i)
            print(f"Chain begin index: {chain_map}")

            # interface_residues_dict of designable residues
            interface_residues_dict = {
                **{c: [] for c in GROUP1}, # ** operator unpacks a dict into key:value
                **{c: [] for c in GROUP2}
                } # All unpacked key:values are merged

            for chain in interface_residues_dict.keys():
                if chain in FIXED_CHAINS:
                    continue
                ires = [str(i-chain_map[chain] + 1) for i in range(1, pose.size() + 1) if interface_residues[i] and pose.pdb_info().chain(i) == str(chain)]
                interface_residues_dict[chain] = ires

            PDB_NAME = os.path.basename(file).replace(".pdb", "")
            with open(f"{PDB_DIR}/{PDB_NAME}_design_residues.txt", "w") as f:
                f.write(" ".join(interface_residues_dict.keys())) # e.g. f.write("A B C D")
                f.write("\n")

                residue_strings = []
                for chain in interface_residues_dict.keys():
                    residue_strings.append(" ".join(interface_residues_dict[chain])) if interface_residues_dict[chain] else residue_strings.append("")
                f.write(", ".join(residue_strings))


    if EXTENDED_DESIGN:
        
        print("************ Make sure your chain numbering starts with 1!!!!!!! ************")

        print ("Designing residues at interface and nearby - using c alpha distance as cutoff")
        files = [f.path for f in os.scandir(PDB_DIR) if f.is_file() and f.path.endswith('.pdb')]
        for file in files:
            print (f'Processing {file}')
            PDB_NAME = os.path.basename(file).replace(".pdb", "")
            pairs_to_omit = (
                list(permutations(GROUP1, 2)) +
                list(permutations(GROUP2, 2))
            )
            _, interface_contacts = analyze_interface (file, min_distance=1, max_distance=MAX_DISTANCE, omit_chain_pairs=pairs_to_omit)
            interface_residues_dict = { c: [] for c in (GROUP1 + GROUP2) }

            for chain in interface_residues_dict:
                if chain in FIXED_CHAINS:
                    continue

                interactions = set()
                for other_chain in (GROUP1 + GROUP2):
                    if other_chain == chain:
                        continue
                    cc1 = chain + other_chain
                    cc2 = other_chain + chain
                    interactions |= interface_contacts.get(cc1, {}).get(chain, set())
                    interactions |= interface_contacts.get(cc2, {}).get(chain, set())
                interface_residues_dict[chain] = [i+1 for i in interactions]

            print (interface_residues_dict)
            with open(f"{PDB_DIR}/{PDB_NAME}_design_residues.txt", "w") as f:
                f.write(" ".join(interface_residues_dict.keys()))
                f.write("\n")

                residue_strings = []
                for c in interface_residues_dict:
                    residue_strings.append(" ".join(str(i) for i in interface_residues_dict[c]))

                f.write(", ".join(residue_strings))

In [16]:
MAX_DISTANCE = 8.5
CB_DISTANCE = 7
EXTENDED_DESIGN = True
interface = "A_BC"
PDB_DIR = "input_pdbs"
fixed_chains = "C"

run_design_positions_generator (interface, fixed_chains, CB_DISTANCE, EXTENDED_DESIGN, MAX_DISTANCE, PDB_DIR)

Mode = extended interface (C alpha distance), cutoff = 8.5 Å
Designing interface between ['A'] and ['B', 'C'], keeping ['C'] fixed
************ Make sure your chain numbering starts with 1!!!!!!! ************
Designing residues at interface and nearby - using c alpha distance as cutoff
Processing input_pdbs/8.pdb
Analyzing PDB file 8.pdb

{'A': [1, 3, 6, 7, 9, 10, 13, 14, 16, 17, 18, 20, 21, 22], 'B': [34, 35, 46, 23, 26, 27, 28, 30, 31], 'C': []}




#### Check selected positions for design

In [17]:
from Bio.PDB import PDBParser, PDBIO
import os

design_positions_file = "input_pdbs/8_design_residues.txt"
basename = os.path.basename(design_positions_file).replace("_design_residues.txt", "")
pdb_file    = os.path.join(os.path.dirname(design_positions_file), f"{basename}.pdb")
output_file = os.path.join(os.path.dirname(design_positions_file),
                           f"{basename}_selection_as_bfactor.pdb")

with open(design_positions_file) as f:
    chains = f.readline().split()
    lists  = [seg.strip() for seg in f.readline().split(',')]
design_map = {}
for ch, seg in zip(chains, lists):
    if seg:
        design_map[ch] = [int(x) for x in seg.split()]
    else:
        design_map[ch] = []

parser = PDBParser(QUIET=True)
structure = parser.get_structure(basename, pdb_file)

for model in structure:
    for chain in model:
        chID = chain.id
        residues = list(chain)  
        first_residue_in_chain = residues[0].id[1]
        for residue in residues:
            resnum_pdb = residue.id[1]
            resnum = resnum_pdb - first_residue_in_chain + 1
            is_design = (resnum in design_map.get(chID, []))
            bval = 100.0 if is_design else 0.0
            for atom in residue:
                atom.set_bfactor(bval)

io = PDBIO()
io.set_structure(structure)
io.save(output_file)

# To color it, in pymol use: spectrum b, lightblue red, minimum=0, maximum=100