In [1]:
from Bio.PDB import PDBParser, Superimposer
from Bio.PDB.Polypeptide import is_aa
import re
import glob 
import pandas as pd
# Lookup from the 20 standard amino-acid three-letter residue names to their
# one-letter codes. Both sequences below are written in the same order, so
# zip() pairs each name with its code (ALA->A, CYS->C, ..., TYR->Y).
three_to_one_map = dict(
    zip(
        (
            "ALA", "CYS", "ASP", "GLU", "PHE",
            "GLY", "HIS", "ILE", "LYS", "LEU",
            "MET", "ASN", "PRO", "GLN", "ARG",
            "SER", "THR", "VAL", "TRP", "TYR",
        ),
        "ACDEFGHIKLMNPQRSTVWY",
    )
)
class ProteinMPNN:
    """
    Motif utilities used when redesigning binders with ProteinMPNN.

    Locates exact motif sequences in a PDB chain, reports their residue
    positions in condensed form (e.g. "A55,A57-58,A61") so they can be kept
    fixed during redesign, and computes the Cα RMSD of the motif regions
    between two structures.
    """

    def __init__(self):
        """
        Create the helper. It holds no state and takes no configuration.
        """

    @staticmethod
    def _chain_sequence(chain) -> tuple[list, str]:
        """
        Return the standard amino-acid residues of ``chain`` and their one-letter sequence.

        Position ``i`` of the returned sequence corresponds to element ``i`` of
        the returned residue list, so regex match spans on the sequence map
        directly to residues. Residues for which ``is_aa(..., standard=True)``
        is false are skipped.
        """
        residues = [residue for residue in chain if is_aa(residue, standard=True)]
        sequence = "".join(
            three_to_one_map.get(residue.get_resname(), "X") for residue in residues
        )
        return residues, sequence

    @staticmethod
    def _motif_match_indices(sequence: str, motifs: list[str]) -> dict[str, list[int]]:
        """
        Map each distinct, non-empty motif to the sequence indices covered by its exact matches.

        Motifs are matched literally (``re.escape``) with ``re.finditer``, so
        matches are non-overlapping and listed left to right. A motif without
        any match maps to an empty list. Empty motifs are skipped because an
        empty pattern "matches" at every position while covering no residue.
        Duplicate motifs are collapsed, keeping first-seen order.
        """
        indices = {}
        for motif in dict.fromkeys(motifs):
            if not motif:
                continue
            indices[motif] = [
                idx
                for match in re.finditer(re.escape(motif), sequence)
                for idx in range(match.start(), match.end())
            ]
        return indices

    def find_fix_residues(
        self,
        motifs: list[str],
        pdb_file: str,
        chain_id: str = "A"
    ) -> dict:
        """
        Find the positions of the residues to fix in the motifs.

        Inputs:
            motifs (list[str]): Motif sequences (one-letter codes) to be fixed.
            pdb_file (str): Path to the PDB file containing the protein structure.
            chain_id (str): The chain identifier in the PDB file (default "A").

        Process:
            - Parses the PDB file and takes the chain from the first model.
            - Builds the chain's sequence from standard amino acids.
            - Searches for exact occurrences of each motif within the sequence.
            - Records every residue covered by a match.

        Returns:
            dict[int, str]: Maps the residue sequence number (``residue.get_id()[1]``;
            insertion codes are dropped) of every motif residue to its
            one-letter code. Empty when the chain is missing or no motif matched.
        """
        parser = PDBParser(QUIET=True)
        model = parser.get_structure("complex", pdb_file)[0]

        if chain_id not in model:
            # Reported with print, like the motif-not-found message below.
            print(f"Chain {chain_id} not found in the PDB structure.")
            return {}

        residues, chain_seq = self._chain_sequence(model[chain_id])

        residue_map = {}
        for motif, indices in self._motif_match_indices(chain_seq, motifs).items():
            if not indices:
                print(f"Motif '{motif}' not found in chain {chain_id}.")
            for idx in indices:
                # Key on the residue number only so the dict can be condensed into ranges.
                residue_map[residues[idx].get_id()[1]] = chain_seq[idx]
        return residue_map

    def condense_residue_ranges(self, residue_dict, chain_id):
        """
        Condense a dictionary of residue positions into a string of ranges.

        Args:
            residue_dict (dict): Dictionary whose keys are integer residue positions.
            chain_id (str): Chain ID prefixed to every position/range.

        Returns:
            str: Condensed representation of consecutive runs (e.g. "A55,A57-58,A61"),
            or "" when ``residue_dict`` is empty.
        """
        positions = sorted(residue_dict)
        if not positions:
            return ""

        # Collect (first, last) pairs of consecutive positions.
        runs = []
        run_start = run_end = positions[0]
        for pos in positions[1:]:
            if pos == run_end + 1:
                run_end = pos
                continue
            runs.append((run_start, run_end))
            run_start = run_end = pos
        runs.append((run_start, run_end))

        return ",".join(
            f"{chain_id}{first}" if first == last else f"{chain_id}{first}-{last}"
            for first, last in runs
        )

    def find_residues_to_fix(
        self,
        motifs: list[str],
        pdb_file: str,
        chain_id: str = "A"
    ) -> str:
        """
        Return the motif residues of ``chain_id`` as a condensed position string.

        Combines ``find_fix_residues`` and ``condense_residue_ranges``; the
        resulting string is also printed.

        Args:
            motifs (list[str]): Motif sequences to locate.
            pdb_file (str): Path to the PDB file.
            chain_id (str): Chain to search. Defaults to "A".

        Returns:
            str: Comma-separated, chain-prefixed positions/ranges
            (e.g. "A24-38,A80-92"); "" when nothing was found.
        """
        fix_residues = self.find_fix_residues(
            motifs=motifs,
            pdb_file=pdb_file,
            chain_id=chain_id
        )
        condensed = self.condense_residue_ranges(fix_residues, chain_id)
        print(condensed)
        return condensed

    def get_ca_atoms(self, chain) -> list:
        """
        Return the CA atom of every standard amino-acid residue in ``chain`` that has one.

        Only backbone Cα atoms are returned, so RMSDs computed from them
        reflect the backbone trace.
        """
        return [residue["CA"] for residue in chain if "CA" in residue and is_aa(residue, standard=True)]

    def RMSD_two_pdbs_using_motifs(
        self,
        pdb_1: str,
        pdb_2: str,
        motifs: list,
        chain_id: str = "A"
    ) -> float:
        """
        Calculate the Cα RMSD between the motif regions of two PDB files.

        Steps:
        1. Locate each motif in the standard-residue sequence of ``chain_id``
           (first model) in both structures.
        2. Pair motif residues motif by motif, in sequence order. Pairing is by
           sequence position, not by residue number, so the two files may use
           different numbering schemes.
        3. Superimpose the paired Cα atoms (pdb_1 fixed, pdb_2 moving) and
           return the RMSD.

        If a motif has a different number of matched residues in the two
        structures, only the first ``min`` residues of that motif are paired.
        Residue pairs where either residue lacks a CA atom are skipped.

        Args:
            pdb_1 (str): Reference PDB file.
            pdb_2 (str): PDB file to compare against the reference.
            motifs (list): Motif sequences (one-letter codes).
            chain_id (str): Chain holding the motifs in both files. Defaults to "A".

        Returns:
            float: The RMSD in ångströms, rounded to 2 decimals.

        Raises:
            KeyError: If ``chain_id`` is absent from the first model of either file.
            ValueError: If no motif residue pair with Cα atoms could be formed.
        """
        parser = PDBParser(QUIET=True)
        chain_1 = parser.get_structure("structure_1", pdb_1)[0][chain_id]
        chain_2 = parser.get_structure("structure_2", pdb_2)[0][chain_id]

        residues_1, seq_1 = self._chain_sequence(chain_1)
        residues_2, seq_2 = self._chain_sequence(chain_2)
        matches_1 = self._motif_match_indices(seq_1, motifs)
        matches_2 = self._motif_match_indices(seq_2, motifs)

        fixed_atoms, moving_atoms = [], []
        for motif, indices_1 in matches_1.items():
            # Pair within each motif so a motif missing (or repeated) in one
            # structure cannot shift the pairing of the other motifs.
            for idx_1, idx_2 in zip(indices_1, matches_2[motif]):
                res_1, res_2 = residues_1[idx_1], residues_2[idx_2]
                if "CA" in res_1 and "CA" in res_2:
                    fixed_atoms.append(res_1["CA"])
                    moving_atoms.append(res_2["CA"])

        if not fixed_atoms:
            raise ValueError(
                f"No motif residues with CA atoms could be paired between {pdb_1} and {pdb_2}."
            )

        sup = Superimposer()
        sup.set_atoms(fixed_atoms, moving_atoms)
        return round(sup.rms, 2)

In [None]:
# Compute the motif Cα RMSD of every ProteinMPNN redesign against the original
# RFdiffusion structure and save the table to rmsd_motifs.csv.
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path

pdb = ProteinMPNN()
origin_pdb = "/pasteur/appa/scratch/dvu/github/DeltaGAgileDesign/RFdiffusion/Spike_coronavirus/pdb_outputs/covid_spike_0.pdb"

# glob returns paths in filesystem order; sort so the CSV row order is reproducible.
redesigned_pdbs = sorted(glob.glob("/pasteur/appa/scratch/dvu/github/DeltaGAgileDesign/AgileDesign/results/structures/*_mpnn_*.pdb"))
motifs = ["SNNLDSKVGGNYNYR", "YGFQPTNGVGYQP"]


def compute_rmsd(redesigned_pdb):
    """Return ``(design_id, motif_rmsd)`` for one redesigned PDB against ``origin_pdb``."""
    rmsd_value = pdb.RMSD_two_pdbs_using_motifs(origin_pdb, redesigned_pdb, motifs)
    # Path.stem strips only the final ".pdb", so ids containing dots stay intact
    # and "<id>.pdb" can be rebuilt from them later.
    identifier = Path(redesigned_pdb).stem
    return identifier, rmsd_value


# NOTE(review): workers must be able to find compute_rmsd, which lives in the
# notebook's __main__. That works with the "fork" start method (the Linux default
# before Python 3.14); confirm on other platforms / Python versions.
with ProcessPoolExecutor() as executor:
    # Map the compute_rmsd function to all redesigned pdbs in parallel
    results = list(executor.map(compute_rmsd, redesigned_pdbs))

if not results:
    print("No redesigned PDB files matched the glob pattern.")

# Split (id, rmsd) pairs into two columns; unlike zip(*results), this also works when results is empty.
ids = [identifier for identifier, _ in results]
rmsd_list = [rmsd_value for _, rmsd_value in results]

# Create a dataframe with the results and save to CSV
df_rmsd = pd.DataFrame({"id": ids, "motifs_rmsd": rmsd_list})
df_rmsd.to_csv("rmsd_motifs.csv", index=False)


In [14]:
# Join the motif RMSDs with the redesign metrics, rank designs by motif RMSD
# (lowest first), and copy the 10 best structures into top_10_redesigned/.
import os
import shutil

# NOTE(review): the RMSD cell writes rmsd_motifs.csv but this reads
# rmsd_motifs_final.csv; confirm which file is intended.
df_rmsd = pd.read_csv("rmsd_motifs_final.csv")
df_redesigned = pd.read_csv("redesign_results_final.csv")

# Left join keeps every RMSD row, even designs absent from the redesign results.
df_rmsd = df_rmsd.merge(df_redesigned, on="id", how="left")
# Sort once; this ordering is used both for the CSV and for the top-10 selection.
df_rmsd = df_rmsd.sort_values(by="motifs_rmsd", ascending=True)
df_rmsd.to_csv("rmsd_motifs_final_joined.csv", index=False)

# Copy the PDB files of the 10 lowest-RMSD designs.
top_10_redesigned = df_rmsd.head(10)
top_dir = "top_10_redesigned"
structures_dir = os.path.join("results", "structures")
os.makedirs(top_dir, exist_ok=True)
for design_id in top_10_redesigned["id"]:
    # f-string also tolerates ids that pandas parsed as non-strings.
    file_name = f"{design_id}.pdb"
    shutil.copy(os.path.join(structures_dir, file_name), os.path.join(top_dir, file_name))








Empty DataFrame
Columns: [id, motifs_rmsd, seq, binder_ptm, binder_plddt, binder_pae]
Index: []
