In [1]:
from scipy.optimize import minimize
import numpy as np
from Bio import PDB
from Bio.PDB.PDBIO import PDBIO
from Bio.PDB import Atom, Residue, Chain, Model, Structure
from Bio.PDB.vectors import Vector
from scipy.special import comb
import py3Dmol
import matplotlib.pyplot as plt
import networkx as nx

In [None]:
class PolyGLinker:

    def __init__(self):
        # Standard amino acid geometry
        self.ca_c_length = 1.52  # Å
        self.c_n_length = 1.33  # Å
        self.n_ca_length = 1.46  # Å
        self.ca_ca_distance = 3.8  # Å
        self.ca_c_n_angle = 117.2  # degrees
        self.c_n_ca_angle = 121.7  # degrees
        self.peptide_omega = 180.0  # degrees

    def _optimize_control_points(
        self, start_pos, end_pos, start_dir, end_dir, num_residues
    ):
        """Optimize control points to achieve desired path length."""
        target_length = (num_residues - 1) * self.ca_ca_distance
        distance = np.linalg.norm(end_pos - start_pos)

        def objective(x):
            # x contains scaling factors for control points
            scale1, scale2 = x
            p0 = start_pos
            p1 = start_pos + start_dir * (distance * scale1)
            p2 = end_pos - end_dir * (distance * scale2)
            p3 = end_pos

            path_length = self._calculate_path_length((p0, p1, p2, p3))
            return (path_length - target_length) ** 2

        # Optimize scaling factors
        result = minimize(objective, [0.33, 0.33], bounds=[(0.1, 0.9), (0.1, 0.9)])
        scale1, scale2 = result.x
        scale1, scale2 = (0.33, 0.33)
        # Return optimized control points
        return (
            start_pos,
            start_pos + start_dir * (distance * scale1),
            end_pos - end_dir * (distance * scale2),
            end_pos,
        )

    def _calculate_residue_direction(self, residue):
        """Calculate the direction vector of a residue's backbone."""
        ca_pos = residue["CA"].get_coord()
        if "C" in residue:
            c_pos = residue["C"].get_coord()
            direction = c_pos - ca_pos
        else:
            n_pos = residue["N"].get_coord()
            direction = ca_pos - n_pos
        return direction / np.linalg.norm(direction)

    def calculateIdealPath(self, start_res, end_res):
        start_direction = self._calculate_residue_direction(start_res)
        end_direction = self._calculate_residue_direction(end_res)

        start_ca = start_res["CA"].get_coord()
        end_ca = end_res["CA"].get_coord()

        distance = np.linalg.norm(end_ca - start_ca)
        # Generate a test path for the Bezier curve
        test_control_points = (
            start_ca,
            start_ca + start_direction * (distance / 3),
            end_ca - end_direction * (distance / 3),
            end_ca,
        )
        test_path_length = self._calculate_path_length(test_control_points)
        num_residues = int(np.ceil(test_path_length / self.ca_ca_distance))

        # Optimize control points for target path length
        control_points = self._optimize_control_points(
            start_ca, end_ca, start_direction, end_direction, num_residues
        )

        # Generate points along optimized Bezier curve with equal arc length spacing
        # First, calculate a finely sampled path
        fine_t = np.linspace(0, 1, 1000)
        fine_path = np.array(
            [self._evaluate_bezier(control_points, ti) for ti in fine_t]
        )
        self.ca_ca_distance

        path = np.zeros((num_residues, 3))
        lastDist = 0
        currentDist = 0
        for i in range(fine_path):
            currentDist += np.linalg.norm(fine_path[i] - fine_path[i - 1])
            if currentDist - lastDist > self.ca_ca_distance:
                path[i] = fine_path[i]
                lastDist = currentDist

        return path, end_direction

    def generate_linker(self, segment1, segment2):
        """Generate backbone coordinates for a linker using an optimized Bezier curve."""
        start_res = segment1.c_res
        end_res = segment2.n_res

        path, end_direction = self.calculateIdealPath(start_res, end_res)

        # Create a new chain for the linker
        linker_chain = Chain.Chain("L")  # 'L' for linker
        for i in range(1, path.shape[0]):
            residue = Residue.Residue((" ", i, " "), "GLY", "")

            glyStart = path[i - 1]
            glyEnd = path[i]
            # Define atom coordinates for glycine based on position and direction
            # Calculate direction vector from previous to current position
            direction = glyEnd - glyStart
            direction = direction / np.linalg.norm(direction)

            # Calculate perpendicular vectors to create a coordinate system
            # Use cross product with arbitrary vector (0,0,1) to get perpendicular vector
            if np.allclose(direction, [0, 0, 1]) or np.allclose(direction, [0, 0, -1]):
                perp1 = np.array([1, 0, 0])
            else:
                perp1 = np.cross(direction, [0, 0, 1])
                perp1 = perp1 / np.linalg.norm(perp1)

            # Create atom positions
            ca_pos = glyStart
            c_pos = ca_pos + direction * self.ca_c_length
            n_pos = glyEnd - direction * self.n_ca_length
            o_pos = c_pos + perp1 * 1.23  # Typical C-O bond length

            # Create atoms
            ca_atom = Atom.Atom("CA", ca_pos, 0, 1.0, " ", "CA", 0, "C")
            c_atom = Atom.Atom("C", c_pos, 0, 1.0, " ", "C", 0, "C")
            n_atom = Atom.Atom("N", n_pos, 0, 1.0, " ", "N", 0, "N")
            o_atom = Atom.Atom("O", o_pos, 0, 1.0, " ", "O", 0, "O")

            # Add atoms to residue
            residue.add(n_atom)
            residue.add(ca_atom)
            residue.add(c_atom)
            residue.add(o_atom)
            # Add residue to chain
            linker_chain.add(residue)

        return linker_chain

    def _evaluate_bezier(self, control_points, t):
        """Evaluate a cubic Bezier curve at parameter t."""
        p0, p1, p2, p3 = control_points
        return (
            (1 - t) ** 3 * p0
            + 3 * (1 - t) ** 2 * t * p1
            + 3 * (1 - t) * t**2 * p2
            + t**3 * p3
        )

    def _calculate_path_length(self, control_points, num_points=100):
        """Calculate the total length of the Bezier curve."""
        t = np.linspace(0, 1, num_points)
        points = np.array([self._evaluate_bezier(control_points, ti) for ti in t])

        # Calculate total length as sum of segments
        segments = np.diff(points, axis=0)
        lengths = np.sqrt(np.sum(segments**2, axis=1))
        return np.sum(lengths)


class Segment:
    def get_named_atom(self, residue, atom_name):
        """
        Retrieves a specific atom from a residue.

        Args:
            residue (Bio.PDB.Residue.Residue): The residue to search within.
            atom_name (str): The name of the atom to find (e.g., 'N', 'CA', 'C', 'O').

        Returns:
            Bio.PDB.Atom.Atom or None: The Atom object if found, otherwise None.
        """
        if atom_name in residue:
            return residue[atom_name]
        return None

    """Represents a contiguous segment of residues. multiple segments can be made from 1 chain"""

    def __init__(self, residues, original_chain_id, segment_idx_in_chain):
        if not residues:
            raise ValueError("Segment cannot be empty")
        self.residues = list(residues)
        self.original_chain_id = original_chain_id

        first_res_id_str = (
            f"{residues[0].id[0].strip()}{residues[0].id[1]}{residues[0].id[2].strip()}"
        )
        last_res_id_str = f"{residues[-1].id[0].strip()}{residues[-1].id[1]}{residues[-1].id[2].strip()}"
        self.id = f"Seg-{original_chain_id}-{segment_idx_in_chain}({first_res_id_str}_to_{last_res_id_str})"

        self.n_res = self.residues[0]
        self.c_res = self.residues[-1]

        self.n_atom = self.get_named_atom(self.n_res, "N")
        self.c_atom = self.get_named_atom(self.c_res, "C")

        if self.n_atom is None:
            raise ValueError(
                f"Segment {self.id} missing N-terminal atom for residue {self.n_res.id}."
            )
        if self.c_atom is None:
            raise ValueError(
                f"Segment {self.id} missing C-terminal atom for residue {self.c_res.id}."
            )

In [11]:
class ChainStich:
    def __init__(self, model, excluded_chains=None,connection_threshold_Ang= 10 ):
        """
        Initialize linker for connecting helix chains.
        
        Args:
            model: BioPython Model object containing the helices
            excluded_chains: List of chain IDs to exclude (e.g., hemes)
        """
        self.model = model
        self.excluded_chains = excluded_chains or []
        self.chain_ends = {}  # Store terminal residues for each chain
        self.connections = []  # Store pairs of chain ends to connect
        
        
        
        #
        # PEPTIDE_BOND_THRESHOLD: Distance (Angstroms) between C of res(i) and N of res(i+1)
        # to be considered a broken peptide bond (i.e., a new segment starts).
        self.PEPTIDE_BOND_THRESHOLD = 2.0  # Angstroms

        # CONNECTION_THRESHOLD: Maximum distance (Angstroms) between terminal atoms
        # of segments to consider them for connection.
        self.CONNECTION_THRESHOLD = connection_threshold_Ang # Angstroms
        
        # Initialize structure
        self.SegmentChains()
        
    def displaySegments(self):
        """
        Displays each segment as a ribbon with a different color using py3Dmol.
        Creates temporary chains from segments and renders the structure.
        """
        # Create a new structure for visualization
        struct = Structure.Structure("segmented")
        model = Model.Model(0)
        struct.add(model)
        
        # Generate a list of distinct colors
        colors = plt.cm.rainbow(np.linspace(0, 1, len(self.segments)))
        colors = [f'rgb({int(r*255)},{int(g*255)},{int(b*255)})' for r, g, b, _ in colors]
        
        # Create chains from segments
        for i, segment in enumerate(self.segments):
            chain_id = chr(65 + (i % 26))  # A-Z (looping if needed)
            chain = Chain.Chain(chain_id)
            
            for residue in segment.residues:
                # Create a deep copy of the residue
                new_res = Residue.Residue(residue.id, residue.resname, residue.segid)
                for atom in residue:
                    new_atom = Atom.Atom(atom.name, atom.coord, atom.bfactor, 
                                         atom.occupancy, atom.altloc, atom.fullname, 
                                         atom.serial_number, atom.element)
                    new_res.add(new_atom)
                chain.add(new_res)
            model.add(chain)
        
        # Save structure to temporary PDB file
        io = PDBIO()
        io.set_structure(struct)
        temp_pdb = "temp_segmented.pdb"
        io.save(temp_pdb)
        
        # Visualize with py3Dmol
        view = py3Dmol.view(width=800, height=600)
        view.addModel(open(temp_pdb, 'r').read(), 'pdb')
        
        # Color chains based on segment
        for i, segment in enumerate(self.segments):
            chain_id = chr(65 + (i % 26))
            view.setStyle({'chain': chain_id}, {'cartoon': {'color': colors[i]}})
        
        view.zoomTo()
        view.setBackgroundColor('white')
        return view.show()
        
    def SegmentChains(self):
 
        all_segments = []
        
        for chain in self.model:
            residues_in_chain = [res for res in chain.get_residues() if PDB.is_aa(res, standard=True) or res.id[0] != ' ']
            if not residues_in_chain:
                continue

            current_segment_residues = []
            segment_counter_for_chain = 0
            for i, res in enumerate(residues_in_chain):
                if not current_segment_residues:
                    current_segment_residues.append(res)
                else:
                    prev_res = current_segment_residues[-1]
                    if 'C' in prev_res and 'N' in res:
                        dist = np.linalg.norm(prev_res['C'].get_coord() - res['N'].get_coord())
                        if dist > self.PEPTIDE_BOND_THRESHOLD:
                            all_segments.append(Segment(current_segment_residues, chain.id, segment_counter_for_chain))
                            segment_counter_for_chain += 1
                            current_segment_residues = [res]
                        else:
                            current_segment_residues.append(res)
                    else: 
                        all_segments.append(Segment(current_segment_residues, chain.id, segment_counter_for_chain))
                        segment_counter_for_chain += 1
                        current_segment_residues = [res]
            
            if current_segment_residues:
                all_segments.append(Segment(current_segment_residues, chain.id, segment_counter_for_chain))

        if not all_segments:
            print("No valid segments found in the PDB file.")
            return
        
        print(f"Identified {len(all_segments)} segments in total.")       
        self.segments = all_segments 
        
    
    def getConnections(self,segments_to_connect, maxDistance_Ang ):
         # Find all potential connections between segments
         
        edges = []
        # Track which segments can form connections
        for i, seg1 in enumerate(segments_to_connect):
            bestDist  = np.inf
            bestSeg2 = None
            second_bestDist = np.inf
            second_bestSeg2 = None
            for j in range(i+1, len(segments_to_connect)):
                # Calculate distance between C-term of seg1 and N-term of seg2
                seg2 = segments_to_connect[j]
                c_n_distance = np.linalg.norm(
                    seg1.c_atom.get_coord() - seg2.n_atom.get_coord()
                )
                
                # If distance is within threshold, add to potential connections
                if c_n_distance <= maxDistance_Ang and c_n_distance < bestDist:
                        bestDist = c_n_distance
                        second_bestSeg2 = bestSeg2
                        second_bestDist = bestDist
                        bestSeg2 = j
            
            if bestSeg2 is not None:
                edges.append((i, bestSeg2, {"dist":bestDist}) )
            if second_bestSeg2 is not None:
                edges.append((i, second_bestSeg2, {"dist":second_bestDist}) )
                    
        return edges
    
    
    def find_optimal_connections(self, edges, segments):
        """Find the permutation that maximizes the number of connected segments"""
        # Create a graph from the edges
        remaining_segments = set(range(len(segments)))
        longestPaths = []
        while len(remaining_segments) > 1:
            # Create a directed graph from the edges
            G = nx.DiGraph()
            G.add_edges_from(edges)
            
            # Find the largest connected component
            longestPath =nx.dag_longest_path(G, weight="dist")
            if len(longestPath) ==0:
                break
            longestPaths.append(longestPath )
            
            #remove the segments in longest path from remaining segments, and edges
            for i in longestPath:
                remaining_segments.remove(i)
                edges = [edge for edge in edges if edge[0] != i and edge[1] != i]
                  
        #add the remaining segments as singletons
        for i in remaining_segments:
            longestPaths.append([i])
        return longestPaths
       
    
    def connectChains(self, chain_segments, segments_to_connect,    chain_id):
        """Connect segments in a chain using PolyGLinker."""
        # Create a new chain for the connected segments
        new_chain = Chain.Chain(chain_id)
        resId=0
        for i in range(len(chain_segments) - 1):
            residues = segments_to_connect[chain_segments[i]].residues 
            
            for res in residues:
                newRes = res.copy()
                newRes.id = (' ',resId, ' ')  # Reset residue ID
                new_chain.add(newRes)
                resId += 1
                
            linker = PolyGLinker()
            linkResidues = linker.generate_linker(
                segments_to_connect[chain_segments[i]],
                segments_to_connect[chain_segments[i + 1]]
            )
            for res in linkResidues:
                newRes = res.copy()
                newRes.id = (' ', resId, ' ')  # Reset residue ID
                new_chain.add(newRes)
                resId += 1
        
        # Add the last segment to the new chain
        last_segment = segments_to_connect[chain_segments[-1]].residues
        for res in last_segment:
            newRes = res.copy()
            newRes.id = (' ', resId, ' ')  # Reset residue ID
            new_chain.add(newRes)
            resId += 1
        return new_chain
        
    def ConnectClosest(self, maxDistance_Ang=10):
        """
        Connect segments with the closest ends, respecting protein chain directionality.
        Uses PolyGLinker to generate connecting residues between segments.
        
        Args:
            maxDistance_Ang: Maximum distance in Angstroms to consider for connections
        
        Returns:
            A new PDB structure with connected segments
        """
        # Create a copy of the segments to avoid modifying the originals
        segments_to_connect = self.segments.copy()
        if not segments_to_connect:
            print("No segments to connect.")
            return None
        
        # Create a new structure for the connected protein
        connected_structure = Structure.Structure("connected_protein")
        connected_model = Model.Model(0)
        connected_structure.add(connected_model)
       
        potential_connections = self.getConnections(segments_to_connect, maxDistance_Ang)

        # Get the optimal connections
        chains = self.find_optimal_connections(potential_connections, segments_to_connect)

        print(f"Found {len(chains)} chains")

        # Create a linker for generating connecting residues
        

        # Organize segments by chain IDs and add them to the new structure
        chain_id_counter = 65  # Start with 'A'
        
        for chain_segments in chains:
            chain_id = chr(chain_id_counter)
            new_Chain = self.connectChains(chain_segments, segments_to_connect,    chain_id)
            connected_model.add(new_Chain)
            chain_id_counter += 1
        
        return connected_model
    
       
       
input_pdb_path=r'C:\Users\bashc\Desktop\working\ssBinding\working9.pdb'        
parser = PDB.PDBParser(QUIET=True)
try:
    structure = parser.get_structure('protein', input_pdb_path)
except FileNotFoundError:
    print(f"Error: Input PDB file not found at {input_pdb_path}")
except Exception as e:
    print(f"Error parsing PDB file {input_pdb_path}: {e}")

        
chainStich =    ChainStich(structure[0], excluded_chains=[ ], connection_threshold_Ang= 10)    

#chainStich.displaySegments()
connectedModel = chainStich.ConnectClosest(maxDistance_Ang=25)


# Save the connected structure to a new PDB file
output_pdb_path = r"C:\Users\bashc\Desktop\working\ssBinding\connected_protein.pdb"
io = PDBIO()
io.set_structure(connectedModel)
io.save(output_pdb_path)
print(f"Connected structure saved to {output_pdb_path}")

ValueError: Segment Seg-A-4(H_ADP807_to_H_ADP807) missing N-terminal atom for residue ('H_ADP', 807, ' ').