In [25]:
from scipy.optimize import minimize
import numpy as np
from Bio import PDB
from Bio.PDB.PDBIO import PDBIO
from Bio.PDB import Atom, Residue, Chain, Model, Structure
from Bio.PDB.vectors import Vector
from scipy.special import comb
import py3Dmol
import matplotlib.pyplot as plt
import networkx as nx

In [32]:


class PolyGLinker:
    
    def __init__(self):
        # Standard amino acid geometry
        self.ca_c_length = 1.52    # Å
        self.c_n_length = 1.33     # Å
        self.n_ca_length = 1.46    # Å
        self.ca_ca_distance = 3.8  # Å
        self.ca_c_n_angle = 117.2  # degrees
        self.c_n_ca_angle = 121.7  # degrees
        self.peptide_omega = 180.0 # degrees

    def _optimize_control_points(self, start_pos, end_pos, start_dir, end_dir, num_residues):
        """Optimize control points to achieve desired path length."""
        target_length = (num_residues - 1) * self.ca_ca_distance
        distance = np.linalg.norm(end_pos - start_pos)
        
        def objective(x):
            # x contains scaling factors for control points
            scale1, scale2 = x
            p0 = start_pos
            p1 = start_pos + start_dir * (distance * scale1)
            p2 = end_pos - end_dir * (distance * scale2)
            p3 = end_pos
            
            path_length = self._calculate_path_length((p0, p1, p2, p3))
            return (path_length - target_length)**2
        
        # Optimize scaling factors
        result = minimize(objective, [0.33, 0.33], bounds=[(0.1, 0.9), (0.1, 0.9)])
        scale1, scale2 = result.x
        scale1, scale2 = (.33,.33)
        # Return optimized control points
        return (
            start_pos,
            start_pos + start_dir * (distance * scale1),
            end_pos - end_dir * (distance * scale2),
            end_pos
        )
        
    def _generate_backbone_atoms_with_direction(self, prev_c, prev_ca, current_ca, direction):
        """Generate backbone atoms with proper geometry."""
        peptide_plane_normal = np.cross(prev_c - prev_ca, direction)
        peptide_plane_normal = peptide_plane_normal / np.linalg.norm(peptide_plane_normal)
        
        n_direction = np.cross(direction, peptide_plane_normal)
        n_pos = prev_c + self.c_n_length * n_direction
        
        ca_pos = current_ca
        
        c_direction = np.cross(direction, peptide_plane_normal)
        c_pos = ca_pos + self.ca_c_length * c_direction
        
        return {
            'N': n_pos,
            'CA': ca_pos,
            'C': c_pos
        }

    def generate_linker(self, segment1, segment2 ):
        """Generate backbone coordinates for a linker using an optimized Bezier curve."""
        start_res =segment1.c_res
        end_res = segment2.n_res
        start_direction =segment1._calculate_residue_direction(start_res)* (.5 + np.random.rand() * 2)
        end_direction = segment2._calculate_residue_direction(end_res) * (.5 + np.random.rand() * 2)
        
        start_ca = start_res['CA'].get_coord()
        end_ca = end_res['CA'].get_coord()
        
        distance = np.linalg.norm(end_ca - start_ca)
        # Generate a test path for the Bezier curve
        test_control_points = (start_ca, start_ca + start_direction * (distance / 3), end_ca - end_direction * (distance / 3), end_ca)
        test_path_length = self._calculate_path_length(test_control_points)
        num_residues = int(np.ceil(test_path_length / self.ca_ca_distance))
        
         
        
        # Optimize control points for target path length
        control_points = self._optimize_control_points(
            start_ca, end_ca, start_direction, end_direction, num_residues
        )
        
        # Generate points along optimized Bezier curve with equal arc length spacing
        # First, calculate a finely sampled path
        fine_t = np.linspace(0, 1, 1000)
        fine_path = np.array([self._evaluate_bezier(control_points, ti) for ti in fine_t])
                
        # Calculate cumulative arc length at each point
        segments = np.diff(fine_path, axis=0)
        segment_lengths = np.sqrt(np.sum(segments**2, axis=1))
        cumulative_length = np.concatenate(([0], np.cumsum(segment_lengths)))
        total_length = cumulative_length[-1]
                
        # Generate equally spaced points based on arc length
        target_lengths = np.linspace(0, total_length, num_residues)
        path = np.zeros((num_residues, 3))
                
        # For each target length, find corresponding point on curve
        for i, target in enumerate(target_lengths):
            # Find index of the segment containing our target length
            idx = np.searchsorted(cumulative_length, target)
            if idx == 0:
                path[i] = fine_path[0]
            elif idx >= len(fine_path):
                path[i] = fine_path[-1]
            else:
                # Interpolate between points
                prev_idx = idx - 1
                segment_start = cumulative_length[prev_idx]
                segment_end = cumulative_length[idx]
                segment_fraction = (target - segment_start) / (segment_end - segment_start)
                path[i] = fine_path[prev_idx] + segment_fraction * (fine_path[idx] - fine_path[prev_idx])
                
        # Create a new chain for the linker
        linker_chain = Chain.Chain("L")  # 'L' for linker
        prev_c = start_res['C'].get_coord()
        prev_ca = start_ca
        for i in range(num_residues):
            # Calculate tangent for better backbone placement
            if i < len(path) - 1:
                tangent = path[i+1] - path[i]
            else:
                tangent = end_direction
            tangent = tangent / np.linalg.norm(tangent)
                    
            res_atoms = self._generate_backbone_atoms_with_direction(
                prev_c, prev_ca, path[i], tangent
            )
            prev_c = res_atoms['C']
            prev_ca = res_atoms['CA']
            
            residue = Residue.Residue((' ', i+1, ' ') , "GLY", "")
            
            # Add atoms to residue
            for atom_name, position in res_atoms.items():
                atom_id = atom_name
                element = atom_name[0]  # First letter of atom name
                atom = Atom.Atom(atom_id, Vector(position), 0.0, 1.0, " ", atom_id, element)
                residue.add(atom)
            
            # Add residue to chain
            linker_chain.add(residue)

        return linker_chain
         
        
    
    def _evaluate_bezier(self, control_points, t):
        """Evaluate a cubic Bezier curve at parameter t."""
        p0, p1, p2, p3 = control_points
        return (
            (1-t)**3 * p0 + 
            3*(1-t)**2 * t * p1 + 
            3*(1-t) * t**2 * p2 + 
            t**3 * p3
        )

    def _calculate_path_length(self, control_points, num_points=100):
        """Calculate the total length of the Bezier curve."""
        t = np.linspace(0, 1, num_points)
        points = np.array([self._evaluate_bezier(control_points, ti) for ti in t])
        
        # Calculate total length as sum of segments
        segments = np.diff(points, axis=0)
        lengths = np.sqrt(np.sum(segments**2, axis=1))
        return np.sum(lengths) 
    
class Segment:
    def get_named_atom(self,residue, atom_name):
        """
        Retrieves a specific atom from a residue.

        Args:
            residue (Bio.PDB.Residue.Residue): The residue to search within.
            atom_name (str): The name of the atom to find (e.g., 'N', 'CA', 'C', 'O').

        Returns:
            Bio.PDB.Atom.Atom or None: The Atom object if found, otherwise None.
        """
        if atom_name in residue:
            return residue[atom_name]
        return None
    
    def _calculate_residue_direction(self, residue):
        """Calculate the direction vector of a residue's backbone."""
        ca_pos = residue['CA'].get_coord()
        if 'C' in residue:
            c_pos = residue['C'].get_coord()
            direction = c_pos - ca_pos
        else:
            n_pos = residue['N'].get_coord()
            direction = ca_pos - n_pos
        return direction / np.linalg.norm(direction)

    """Represents a contiguous segment of residues. multiple segments can be made from 1 chain"""
    def __init__(self, residues, original_chain_id, segment_idx_in_chain):
        if not residues:
            raise ValueError("Segment cannot be empty")
        self.residues = list(residues) 
        self.original_chain_id = original_chain_id
        
        first_res_id_str = f"{residues[0].id[0].strip()}{residues[0].id[1]}{residues[0].id[2].strip()}"
        last_res_id_str = f"{residues[-1].id[0].strip()}{residues[-1].id[1]}{residues[-1].id[2].strip()}"
        self.id = f"Seg-{original_chain_id}-{segment_idx_in_chain}({first_res_id_str}_to_{last_res_id_str})"

        self.n_res = self.residues[0]
        self.c_res = self.residues[-1]

        self.n_atom = self. get_named_atom(self.n_res, 'N')
        self.c_atom =  self.get_named_atom(self.c_res, 'C')

        if self.n_atom is None:
            raise ValueError(f"Segment {self.id} missing N-terminal atom for residue {self.n_res.id}.")
        if self.c_atom is None:
            raise ValueError(f"Segment {self.id} missing C-terminal atom for residue {self.c_res.id}.")


In [35]:
class ChainStich:
    def __init__(self, model, excluded_chains=None,connection_threshold_Ang= 10 ):
        """
        Initialize linker for connecting helix chains.
        
        Args:
            model: BioPython Model object containing the helices
            excluded_chains: List of chain IDs to exclude (e.g., hemes)
        """
        self.model = model
        self.excluded_chains = excluded_chains or []
        self.chain_ends = {}  # Store terminal residues for each chain
        self.connections = []  # Store pairs of chain ends to connect
        
        
        
        #
        # PEPTIDE_BOND_THRESHOLD: Distance (Angstroms) between C of res(i) and N of res(i+1)
        # to be considered a broken peptide bond (i.e., a new segment starts).
        self.PEPTIDE_BOND_THRESHOLD = 2.0  # Angstroms

        # CONNECTION_THRESHOLD: Maximum distance (Angstroms) between terminal atoms
        # of segments to consider them for connection.
        self.CONNECTION_THRESHOLD = connection_threshold_Ang # Angstroms
        
        # Initialize structure
        self.SegmentChains()
        
    def displaySegments(self):
        """
        Displays each segment as a ribbon with a different color using py3Dmol.
        Creates temporary chains from segments and renders the structure.
        """
        # Create a new structure for visualization
        struct = Structure.Structure("segmented")
        model = Model.Model(0)
        struct.add(model)
        
        # Generate a list of distinct colors
        colors = plt.cm.rainbow(np.linspace(0, 1, len(self.segments)))
        colors = [f'rgb({int(r*255)},{int(g*255)},{int(b*255)})' for r, g, b, _ in colors]
        
        # Create chains from segments
        for i, segment in enumerate(self.segments):
            chain_id = chr(65 + (i % 26))  # A-Z (looping if needed)
            chain = Chain.Chain(chain_id)
            
            for residue in segment.residues:
                # Create a deep copy of the residue
                new_res = Residue.Residue(residue.id, residue.resname, residue.segid)
                for atom in residue:
                    new_atom = Atom.Atom(atom.name, atom.coord, atom.bfactor, 
                                         atom.occupancy, atom.altloc, atom.fullname, 
                                         atom.serial_number, atom.element)
                    new_res.add(new_atom)
                chain.add(new_res)
            model.add(chain)
        
        # Save structure to temporary PDB file
        io = PDBIO()
        io.set_structure(struct)
        temp_pdb = "temp_segmented.pdb"
        io.save(temp_pdb)
        
        # Visualize with py3Dmol
        view = py3Dmol.view(width=800, height=600)
        view.addModel(open(temp_pdb, 'r').read(), 'pdb')
        
        # Color chains based on segment
        for i, segment in enumerate(self.segments):
            chain_id = chr(65 + (i % 26))
            view.setStyle({'chain': chain_id}, {'cartoon': {'color': colors[i]}})
        
        view.zoomTo()
        view.setBackgroundColor('white')
        return view.show()
        
    def SegmentChains(self):
 
        all_segments = []
        
        for chain in self.model:
            residues_in_chain = [res for res in chain.get_residues() if PDB.is_aa(res, standard=True) or res.id[0] != ' ']
            if not residues_in_chain:
                continue

            current_segment_residues = []
            segment_counter_for_chain = 0
            for i, res in enumerate(residues_in_chain):
                if not current_segment_residues:
                    current_segment_residues.append(res)
                else:
                    prev_res = current_segment_residues[-1]
                    if 'C' in prev_res and 'N' in res:
                        dist = np.linalg.norm(prev_res['C'].get_coord() - res['N'].get_coord())
                        if dist > self.PEPTIDE_BOND_THRESHOLD:
                            all_segments.append(Segment(current_segment_residues, chain.id, segment_counter_for_chain))
                            segment_counter_for_chain += 1
                            current_segment_residues = [res]
                        else:
                            current_segment_residues.append(res)
                    else: 
                        all_segments.append(Segment(current_segment_residues, chain.id, segment_counter_for_chain))
                        segment_counter_for_chain += 1
                        current_segment_residues = [res]
            
            if current_segment_residues:
                all_segments.append(Segment(current_segment_residues, chain.id, segment_counter_for_chain))

        if not all_segments:
            print("No valid segments found in the PDB file.")
            return
        
        print(f"Identified {len(all_segments)} segments in total.")       
        self.segments = all_segments 
        
    
    def getConnections(self,segments_to_connect, maxDistance_Ang ):
         # Find all potential connections between segments
         
        edges = []
        # Track which segments can form connections
        for i, seg1 in enumerate(segments_to_connect):
            bestDist  = np.inf
            bestSeg2 = None
            second_bestDist = np.inf
            second_bestSeg2 = None
            for j in range(i+1, len(segments_to_connect)):
                # Calculate distance between C-term of seg1 and N-term of seg2
                seg2 = segments_to_connect[j]
                c_n_distance = np.linalg.norm(
                    seg1.c_atom.get_coord() - seg2.n_atom.get_coord()
                )
                
                # If distance is within threshold, add to potential connections
                if c_n_distance <= maxDistance_Ang and c_n_distance < bestDist:
                        bestDist = c_n_distance
                        second_bestSeg2 = bestSeg2
                        second_bestDist = bestDist
                        bestSeg2 = j
            
            if bestSeg2 is not None:
                edges.append((i, bestSeg2, {"dist":bestDist}) )
            if second_bestSeg2 is not None:
                edges.append((i, second_bestSeg2, {"dist":second_bestDist}) )
                    
        return edges
    
    
    def find_optimal_connections(self, edges, segments):
        """Find the permutation that maximizes the number of connected segments"""
        # Create a graph from the edges
        remaining_segments = set(range(len(segments)))
        longestPaths = []
        while len(remaining_segments) > 1:
            # Create a directed graph from the edges
            G = nx.DiGraph()
            G.add_edges_from(edges)
            
            # Find the largest connected component
            longestPath =nx.dag_longest_path(G, weight="dist")
            if len(longestPath) ==0:
                break
            longestPaths.append(longestPath )
            
            #remove the segments in longest path from remaining segments, and edges
            for i in longestPath:
                remaining_segments.remove(i)
                edges = [edge for edge in edges if edge[0] != i and edge[1] != i]
                  
        #add the remaining segments as singletons
        for i in remaining_segments:
            longestPaths.append([i])
        return longestPaths
       
    
    def connectChains(self, chain_segments, segments_to_connect,    chain_id):
        """Connect segments in a chain using PolyGLinker."""
        # Create a new chain for the connected segments
        new_chain = Chain.Chain(chain_id)
        
        for i in range(len(chain_segments) - 1):
            residues = segments_to_connect[chain_segments[i]].residues[-1]
            for res in residues:
                newRes = res.copy()
                newRes.id = (' ', res.id[1], ' ')  # Reset residue ID
                new_chain.add(newRes)
                
            linker = PolyGLinker()
            linkResidues = linker.generate_linker(
                segments_to_connect[chain_segments[i]],
                segments_to_connect[chain_segments[i + 1]]
            )
            for res in linkResidues:
                newRes = res.copy()
                newRes.id = (' ', res.id[1], ' ')  # Reset residue ID
                new_chain.add(newRes)
        
        # Add the last segment to the new chain
        last_segment = segments_to_connect[chain_segments[-1]].residues
        for res in last_segment:
            newRes = res.copy()
            newRes.id = (' ', res.id[1], ' ')  # Reset residue ID
            new_chain.add(newRes)
        return new_chain
        
    def ConnectClosest(self, maxDistance_Ang=10):
        """
        Connect segments with the closest ends, respecting protein chain directionality.
        Uses PolyGLinker to generate connecting residues between segments.
        
        Args:
            maxDistance_Ang: Maximum distance in Angstroms to consider for connections
        
        Returns:
            A new PDB structure with connected segments
        """
        # Create a copy of the segments to avoid modifying the originals
        segments_to_connect = self.segments.copy()
        if not segments_to_connect:
            print("No segments to connect.")
            return None
        
        # Create a new structure for the connected protein
        connected_structure = Structure.Structure("connected_protein")
        connected_model = Model.Model(0)
        connected_structure.add(connected_model)
       
        potential_connections = self.getConnections(segments_to_connect, maxDistance_Ang)

        # Get the optimal connections
        chains = self.find_optimal_connections(potential_connections, segments_to_connect)

        print(f"Found {len(chains)} chains")

        # Create a linker for generating connecting residues
        

        # Organize segments by chain IDs and add them to the new structure
        chain_id_counter = 65  # Start with 'A'
        
        for chain_segments in chains:
            chain_id = chr(chain_id_counter)
            new_Chain = self.connectChains(chain_segments, segments_to_connect,    chain_id)
            connected_model.add(new_Chain)
            chain_id_counter += 1
        
        return connected_model
    
       
       
input_pdb_path=r'C:\Users\bashc\Desktop\working\ssBinding\working8.pdb'        
parser = PDB.PDBParser(QUIET=True)
try:
    structure = parser.get_structure('protein', input_pdb_path)
except FileNotFoundError:
    print(f"Error: Input PDB file not found at {input_pdb_path}")
except Exception as e:
    print(f"Error parsing PDB file {input_pdb_path}: {e}")

        
chainStich =    ChainStich(structure[0], excluded_chains=[ ], connection_threshold_Ang= 10)    

#chainStich.displaySegments()
chainStich.ConnectClosest(maxDistance_Ang=10)

Identified 18 segments in total.
Found 10 chains


IndexError: string index out of range

In [None]:

    def _find_closest_connections(self):
        """Find closest chain ends to connect."""
        chain_ids = list(self.chain_ends.keys())
        
        for i, chain1 in enumerate(chain_ids):
            distances = []
            for chain2 in chain_ids[i+1:]:  # Only look at chains we haven't considered yet
                end_carbon = self.chain_ends[chain1]['C']['C'].get_coord()
                start_n = self.chain_ends[chain2]['N']['N'].get_coord()
                
                distances.append({
                    'c2': chain2,
                    'distance': np.linalg.norm(end_carbon - start_n)
                })
            
            if distances:  # Only add if we found valid connections
                best_connection = min(distances, key=lambda x: x['distance'])
                self.connections.append({
                    'chain1': chain1,
                    'chain2': best_connection['c2'],
                    'distance': best_connection['distance']
                })

    


    

    def connect_chains(self):
        """Connect all identified chain pairs with linkers."""
        connection=  min(self.connections, key=lambda x: x['distance'])
        if connection['distance'] > 25:
            print('too far',connection['distance'])
            return False
        backbone_atoms = self.generate_linker(connection)
        
        # Get the chains to connect
        chain1_id = connection['chain1']
        chain2_id = connection['chain2']
        
        # Add linker residues to chain1
        last_res_id = max(int(res.id[1]) for res in self.model[chain1_id])
        
        # Add linker residues
        for i, atoms in enumerate(backbone_atoms):
            new_res_id = last_res_id + i + 1
            new_res = Residue.Residue((' ', new_res_id, ' '), 'GLY', '')
            
            for atom_name, coord in atoms.items():
                new_atom = Atom.Atom(
                    name=atom_name,
                    coord=coord,
                    bfactor=20.0,
                    occupancy=1.0,
                    altloc=' ',
                    fullname=atom_name,
                    serial_number=None,
                    element=atom_name[0]
                )
                new_res.add(new_atom)
            
            self.model[chain1_id].add(new_res)
        
        # Move residues from chain2 to chain1
        residues_to_move = list(self.model[chain2_id].get_residues())
        chain =self.model[chain1_id]
        new_res_id=max([max(int(res.id[1]) for res in chain), max(int(res.id[1]) for res in residues_to_move)])
        for res in residues_to_move:
            try:
                new_res_id +=1
                res.id = (' ', new_res_id, ' ')
                chain.add(res)
            except Exception as e:
                print(e)
                print("Error adding residue")
        
        # Remove chain2
        self.model.detach_child(chain2_id)
        print(chain1_id,chain2_id)
        return True
         

aligner = HemeHelixAligner("double_heme.pdb")
aligner.align_all()
aligner.save_structure('test.pdb')
print("Structure aligned and saved.")

model =aligner.model
# Initialize the manipulator with your model
manipulator = HelixManipulator(model)

# Align helices with z-axis
manipulator.align_to_z()

# Rotate helices (e.g., 30 degrees)
angle = 10
manipulator.rotate_helices(angle)
manipulator.create_packed_copies(angle)


io = PDBIO()
io.set_structure(model)
io.save("target-complex.pdb")

heme_chains = []
for chain in model:
    for residue in chain:
        if residue.get_resname() == "HEM":
            heme_chains.append(chain.id)
            break
        
heme_chains        

Structure aligned and saved.


['N', 'E', 'H', 'K', 'Z', 'Q']

In [None]:


def bernstein_polynomial(t, n, i):
    """Bernstein polynomial for Bezier curve calculation."""
    return comb(n, i) * (t ** i) * ((1 - t) ** (n - i))

def bezier_curve(control_points, num_points=100):
    """
    Generate points along a Bezier curve defined by control points.
    
    Args:
        control_points: List of 3D coordinates (numpy arrays)
        num_points: Number of points to generate along the curve
    
    Returns:
        List of coordinates along the curve
    """
    n = len(control_points) - 1
    curve_points = []
    
    for t in np.linspace(0, 1, num_points):
        point = np.zeros(3)
        for i, cp in enumerate(control_points):
            point += bernstein_polynomial(t, n, i) * cp
        curve_points.append(point)
    
    return curve_points

def create_glycine_residue(res_id, position, prev_c_pos=None, next_n_pos=None):
    """
    Create a glycine residue with coordinates based on the given position.
    Adjusts N and C atoms to form peptide bonds of PEPTIDE_BOND_LENGTH.
    Places O atom using standard geometry relative to N, CA, C of the glycine.

    Args:
        res_id: Residue ID tuple (hetflag, resnum, icode)
        position: 3D coordinates for CA atom (numpy array)
        prev_c_pos: Position of C atom from previous residue for N-terminal connection
        next_n_pos: Position of N atom from next residue for C-terminal connection

    Returns:
        Bio.PDB.Residue with proper glycine geometry
    """
    res = Residue.Residue(res_id, "GLY", " ")
    ca_coord = np.array(position)

    # Ideal Glycine coordinates relative to CA at origin (N-CA-C plane in XY)
    n_ideal_offset_from_ca = np.array([-0.589, -1.354, 0.0]) 
    c_ideal_offset_from_ca = np.array([1.523, 0.0, 0.0])

    n_coord = ca_coord + n_ideal_offset_from_ca
    c_coord = ca_coord + c_ideal_offset_from_ca

    if prev_c_pos is not None:
        prev_c_coord = np.array(prev_c_pos)
        vec_prevC_to_glyCA = ca_coord - prev_c_coord
        dist_prevC_to_glyCA = np.linalg.norm(vec_prevC_to_glyCA)
        if dist_prevC_to_glyCA > 1e-3: 
            dir_prevC_to_glyCA = vec_prevC_to_glyCA / dist_prevC_to_glyCA
            n_coord = prev_c_coord + dir_prevC_to_glyCA * PEPTIDE_BOND_LENGTH

    if next_n_pos is not None:
        next_n_coord = np.array(next_n_pos)
        vec_glyCA_to_nextN = next_n_coord - ca_coord
        dist_glyCA_to_nextN = np.linalg.norm(vec_glyCA_to_nextN)
        if dist_glyCA_to_nextN > 1e-3: 
            dir_glyCA_to_nextN = vec_glyCA_to_nextN / dist_glyCA_to_nextN
            c_coord = next_n_coord - dir_glyCA_to_nextN * PEPTIDE_BOND_LENGTH

    vec_ca_c = c_coord - ca_coord
    vec_ca_n = n_coord - ca_coord 

    norm_vec_ca_c = np.linalg.norm(vec_ca_c)
    norm_vec_ca_n = np.linalg.norm(vec_ca_n)

    if norm_vec_ca_c < 1e-3 or norm_vec_ca_n < 1e-3:
        o_default_offset_from_c = np.array([0.624, 1.078, 0.0]) 
        o_coord = c_coord + o_default_offset_from_c 
    else:
        u_ca_c = vec_ca_c / norm_vec_ca_c 
        u_ca_n = vec_ca_n / norm_vec_ca_n 

        plane_normal = np.cross(u_ca_n, u_ca_c)
        norm_plane_normal = np.linalg.norm(plane_normal)

        if norm_plane_normal < 1e-3: 
            if abs(u_ca_c[0]) > 0.9: arbitrary_perp_ref = np.array([0.0, 1.0, 0.0])
            else: arbitrary_perp_ref = np.array([1.0, 0.0, 0.0])
            plane_normal = np.cross(u_ca_c, arbitrary_perp_ref)
            if np.linalg.norm(plane_normal) < 1e-3: 
                 plane_normal = np.cross(u_ca_c, np.array([0.0,0.0,1.0]))
            plane_normal /= np.linalg.norm(plane_normal)

        m_vec = np.cross(plane_normal, u_ca_c)
        m_vec = m_vec / np.linalg.norm(m_vec)
        
        if np.dot(m_vec, u_ca_n - u_ca_c * np.dot(u_ca_n, u_ca_c)) < 0:
            m_vec = -m_vec
            
        C_O_BOND_LENGTH = 1.23 
        ANGLE_CA_C_O = np.radians(120.8) 

        o_coord_comp_along_cca = C_O_BOND_LENGTH * np.cos(ANGLE_CA_C_O)
        o_coord_comp_along_mvec = C_O_BOND_LENGTH * np.sin(ANGLE_CA_C_O)
        
        o_coord = c_coord + o_coord_comp_along_cca * (-u_ca_c) + o_coord_comp_along_mvec * m_vec

    n_atom = Atom.Atom('N', n_coord, 0.0, 1.0, ' ', ' N  ', 0, 'N')
    ca_atom = Atom.Atom('CA', ca_coord, 0.0, 1.0, ' ', ' CA ', 0, 'C')
    c_atom = Atom.Atom('C', c_coord, 0.0, 1.0, ' ', ' C  ', 0, 'C')
    o_atom = Atom.Atom('O', o_coord, 0.0, 1.0, ' ', ' O  ', 0, 'O')
    
    res.add(n_atom)
    res.add(ca_atom)
    res.add(c_atom)
    res.add(o_atom)
    
    return res





def connect_broken_chains(input_pdb_path, output_pdb_path):
    """
    Loads a PDB file, identifies all segments, determines the optimal way to connect
    them into a primary long chain, and connects remaining segments into secondary chains.
    Fills gaps with glycine residues. Outputs all resulting chains.
    """
    parser = PDB.PDBParser(QUIET=True)
    try:
        structure = parser.get_structure('protein', input_pdb_path)
    except FileNotFoundError:
        print(f"Error: Input PDB file not found at {input_pdb_path}")
        return
    except Exception as e:
        print(f"Error parsing PDB file {input_pdb_path}: {e}")
        return

    

    # --- 3. Find the best single chain by trying all possible start segments ---
    best_connection_run = {
        "initial_segment_index": -1,
        "connected_segments_ordered": [], 
        "num_original_segments": 0
    }

    for i, initial_seg in enumerate(all_segments):
        current_merged_chain_segments = [initial_seg]
        remaining_for_this_run = [s for idx, s in enumerate(all_segments) if idx != i]
        
        chain_n_end_atom = initial_seg.n_atom 
        chain_c_end_atom = initial_seg.c_atom

        while remaining_for_this_run:
            best_append_candidate = None 
            min_dist_append = float('inf')
            best_prepend_candidate = None
            min_dist_prepend = float('inf')

            chain_c_coord = chain_c_end_atom.get_coord()
            chain_n_coord = chain_n_end_atom.get_coord()

            for j, seg_candidate in enumerate(remaining_for_this_run):
                dist_c_to_n = np.linalg.norm(chain_c_coord - seg_candidate.n_atom.get_coord())
                if dist_c_to_n < min_dist_append:
                    min_dist_append = dist_c_to_n
                    best_append_candidate = (j, seg_candidate, dist_c_to_n)

                dist_n_to_c = np.linalg.norm(chain_n_coord - seg_candidate.c_atom.get_coord())
                if dist_n_to_c < min_dist_prepend:
                    min_dist_prepend = dist_n_to_c
                    best_prepend_candidate = (j, seg_candidate, dist_n_to_c)
            
            added_in_iteration = False
            choice_made = None
            
            can_append = best_append_candidate and best_append_candidate[2] < CONNECTION_THRESHOLD
            can_prepend = best_prepend_candidate and best_prepend_candidate[2] < CONNECTION_THRESHOLD

            if can_append and can_prepend:
                choice_made = "append" if best_append_candidate[2] <= best_prepend_candidate[2] else "prepend"
            elif can_append:
                choice_made = "append"
            elif can_prepend:
                choice_made = "prepend"

            if choice_made == "append":
                idx_rem, seg_to_add, _ = best_append_candidate
                current_merged_chain_segments.append(seg_to_add)
                chain_c_end_atom = seg_to_add.c_atom 
                remaining_for_this_run.pop(idx_rem)
                added_in_iteration = True
            elif choice_made == "prepend":
                idx_rem, seg_to_add, _ = best_prepend_candidate
                current_merged_chain_segments.insert(0, seg_to_add)
                chain_n_end_atom = seg_to_add.n_atom 
                remaining_for_this_run.pop(idx_rem)
                added_in_iteration = True
            
            if not added_in_iteration:
                break 

        if len(current_merged_chain_segments) > best_connection_run["num_original_segments"]:
            best_connection_run["initial_segment_index"] = i
            best_connection_run["connected_segments_ordered"] = list(current_merged_chain_segments)
            best_connection_run["num_original_segments"] = len(current_merged_chain_segments)

    # --- 4. Build the primary chain ---
    final_chains_residues = [] 
    used_segments_for_main_chain = set()
    
    max_res_num_in_pdb = 0
    for seg_glob in all_segments:
        for res_glob in seg_glob.residues:
            if res_glob.id[1] > max_res_num_in_pdb:
                max_res_num_in_pdb = res_glob.id[1]
    temp_glycine_res_num_counter = max_res_num_in_pdb

    if not best_connection_run["connected_segments_ordered"]:
        print("Warning: Could not form any primary connected chain.")
    else:
        print(f"Best primary chain will connect {best_connection_run['num_original_segments']} segments.")
        primary_chain_merged_residues = []
        segments_for_main_chain = best_connection_run["connected_segments_ordered"]
        
        primary_chain_merged_residues.extend(segments_for_main_chain[0].residues)
        # The C-atom of the *last added residue* to primary_chain_merged_residues
        current_c_atom_obj_main = get_terminal_atom(primary_chain_merged_residues[-1], 'C')
        if not current_c_atom_obj_main: # Should not happen if segment is valid
             raise ValueError(f"Segment {segments_for_main_chain[0].id} has no C-terminal C atom in its last residue.")


        for i in range(1, len(segments_for_main_chain)):
            prev_segment_in_main = segments_for_main_chain[i-1] # For context/CA
            current_segment_to_add_main = segments_for_main_chain[i]

            prev_c_for_glycine = current_c_atom_obj_main.get_coord()
            next_n_for_glycine = current_segment_to_add_main.n_atom.get_coord()
            dist_c_n = np.linalg.norm(prev_c_for_glycine - next_n_for_glycine)

            added_glycines_main = []
            if dist_c_n > PEPTIDE_BOND_THRESHOLD:
                p0 = prev_c_for_glycine
                p3 = next_n_for_glycine
                ca_prev_res = primary_chain_merged_residues[-1] # Last residue of growing chain
                ca_next_res = current_segment_to_add_main.residues[0] # First residue of next segment

                if 'CA' in ca_prev_res and 'CA' in ca_next_res:
                    ca_prev = ca_prev_res['CA'].get_coord()
                    ca_next = ca_next_res['CA'].get_coord()
                    v0 = p0 - ca_prev 
                    v3 = p3 - ca_next 
                    p1 = p0 + v0 * 0.5 
                    p2 = p3 + v3 * 0.5 
                else: 
                    mid_point = (p0 + p3) / 2.0
                    p1 = p0 + (mid_point - p0) * 0.5
                    p2 = p3 + (mid_point - p3) * 0.5
                control_points = [p0, p1, p2, p3]

                curve_length_approx = np.linalg.norm(p3 - p0) 
                num_glycines = max(0, int(round(curve_length_approx / CA_CA_DISTANCE)) -1)

                if num_glycines > 0:
                    print(f"  Main chain: adding {num_glycines} glycines between ...{prev_segment_in_main.id} and {current_segment_to_add_main.id} (C-N dist: {dist_c_n:.2f} A)")
                    ca_p0 = ca_prev_res['CA'].get_coord() if 'CA' in ca_prev_res else p0
                    ca_p3 = ca_next_res['CA'].get_coord() if 'CA' in ca_next_res else p3
                    
                    ca_v0 = ca_p0 - p0 
                    ca_v3 = ca_p3 - p3

                    ca_control_points = [ca_p0, ca_p0 + ca_v0 * 0.5, ca_p3 + ca_v3 * 0.5, ca_p3]
                    if num_glycines == 1 and np.linalg.norm(ca_p0 - ca_p3) < CA_CA_DISTANCE * 1.5 :
                         gly_ca_positions = [(ca_p0 + ca_p3)/2.0]
                    else:
                        gly_ca_curve = bezier_curve(ca_control_points, num_points=num_glycines + 2)
                        gly_ca_positions = gly_ca_curve[1:-1]
                    
                    current_connection_c_coord_main = prev_c_for_glycine
                    for k_gly, gly_ca_pos in enumerate(gly_ca_positions):
                        temp_glycine_res_num_counter += 1
                        prev_c_for_this_gly = current_connection_c_coord_main
                        next_n_for_this_gly = next_n_for_glycine if k_gly == num_glycines - 1 else None
                        gly_res = create_glycine_residue(
                            (' ', temp_glycine_res_num_counter, ' '), gly_ca_pos,
                            prev_c_pos=prev_c_for_this_gly, next_n_pos=next_n_for_this_gly
                        )
                        added_glycines_main.append(gly_res)
                        current_connection_c_coord_main = gly_res['C'].get_coord()
            
            primary_chain_merged_residues.extend(added_glycines_main)
            primary_chain_merged_residues.extend(current_segment_to_add_main.residues)
            current_c_atom_obj_main = get_terminal_atom(primary_chain_merged_residues[-1], 'C')
            if not current_c_atom_obj_main:
                 raise ValueError(f"Segment {current_segment_to_add_main.id} resulted in no C-terminal C atom after adding to main chain.")

        final_chains_residues.append(primary_chain_merged_residues)
        used_segments_for_main_chain = set(segments_for_main_chain)

    # --- 5. Process unused segments ---
    all_segments_set = set(all_segments)
    unused_segments_list = list(all_segments_set - used_segments_for_main_chain)
    
    unused_by_original_chain = {}
    for seg in unused_segments_list:
        unused_by_original_chain.setdefault(seg.original_chain_id, []).append(seg)

    for chain_id_orig, seg_list_orig in unused_by_original_chain.items():
        seg_list_orig.sort(key=lambda s: (s.residues[0].id[1], s.residues[0].id[2].strip())) # Sort by resnum then icode

    processed_unused_segments = set()

    for original_chain_id_of_unused, segments_from_one_orig_chain_initial in unused_by_original_chain.items():
        
        segments_to_process_in_group = [s for s in segments_from_one_orig_chain_initial if s not in processed_unused_segments]
        
        while segments_to_process_in_group:
            if not segments_to_process_in_group: break

            if len(segments_to_process_in_group) == 1:
                seg_single = segments_to_process_in_group[0]
                final_chains_residues.append(list(seg_single.residues))
                processed_unused_segments.add(seg_single)
                print(f"  Unused segment {seg_single.id} from original chain {original_chain_id_of_unused} forms a separate chain.")
                segments_to_process_in_group.pop(0)
                continue

            best_internal_connection_run = { "connected_segments_ordered": [], "num_original_segments": 0 }

            for i_start_internal, initial_seg_internal in enumerate(segments_to_process_in_group):
                current_merged_internal = [initial_seg_internal]
                remaining_internal = [s for idx, s in enumerate(segments_to_process_in_group) if idx != i_start_internal]
                
                chain_n_end_atom_internal = initial_seg_internal.n_atom
                chain_c_end_atom_internal = initial_seg_internal.c_atom

                while remaining_internal:
                    best_append_internal = None; min_dist_app_int = float('inf')
                    best_prepend_internal = None; min_dist_pre_int = float('inf')
                    chain_c_coord_int = chain_c_end_atom_internal.get_coord()
                    chain_n_coord_int = chain_n_end_atom_internal.get_coord()

                    for j_int, seg_cand_int in enumerate(remaining_internal):
                        dist_c_n_int = np.linalg.norm(chain_c_coord_int - seg_cand_int.n_atom.get_coord())
                        if dist_c_n_int < min_dist_app_int:
                            min_dist_app_int = dist_c_n_int
                            best_append_internal = (j_int, seg_cand_int, dist_c_n_int)
                        dist_n_c_int = np.linalg.norm(chain_n_coord_int - seg_cand_int.c_atom.get_coord())
                        if dist_n_c_int < min_dist_pre_int:
                            min_dist_pre_int = dist_n_c_int
                            best_prepend_internal = (j_int, seg_cand_int, dist_n_c_int)
                    
                    added_in_iter_internal = False; choice_made_internal = None
                    can_app_int = best_append_internal and best_append_internal[2] < CONNECTION_THRESHOLD
                    can_pre_int = best_prepend_internal and best_prepend_internal[2] < CONNECTION_THRESHOLD

                    if can_app_int and can_pre_int: choice_made_internal = "append" if best_append_internal[2] <= best_prepend_internal[2] else "prepend"
                    elif can_app_int: choice_made_internal = "append"
                    elif can_pre_int: choice_made_internal = "prepend"

                    if choice_made_internal == "append":
                        idx_rem_int, seg_to_add_int, _ = best_append_internal
                        current_merged_internal.append(seg_to_add_int)
                        chain_c_end_atom_internal = seg_to_add_int.c_atom
                        remaining_internal.pop(idx_rem_int); added_in_iter_internal = True
                    elif choice_made_internal == "prepend":
                        idx_rem_int, seg_to_add_int, _ = best_prepend_internal
                        current_merged_internal.insert(0, seg_to_add_int)
                        chain_n_end_atom_internal = seg_to_add_int.n_atom
                        remaining_internal.pop(idx_rem_int); added_in_iter_internal = True
                    if not added_in_iter_internal: break
                
                if len(current_merged_internal) > best_internal_connection_run["num_original_segments"]:
                    best_internal_connection_run["connected_segments_ordered"] = list(current_merged_internal)
                    best_internal_connection_run["num_original_segments"] = len(current_merged_internal)

            if best_internal_connection_run["connected_segments_ordered"]:
                print(f"  Connecting {best_internal_connection_run['num_original_segments']} unused segments from original chain {original_chain_id_of_unused} into a new chain.")
                secondary_chain_residues = []
                segments_for_secondary = best_internal_connection_run["connected_segments_ordered"]
                
                secondary_chain_residues.extend(segments_for_secondary[0].residues)
                current_c_atom_obj_secondary = get_terminal_atom(secondary_chain_residues[-1], 'C')
                if not current_c_atom_obj_secondary:
                    raise ValueError(f"Segment {segments_for_secondary[0].id} has no C-terminal C atom for secondary chain.")

                processed_segments_in_this_group_run = set(segments_for_secondary)

                for i_sec in range(1, len(segments_for_secondary)):
                    prev_seg_sec = segments_for_secondary[i_sec-1]
                    curr_seg_sec = segments_for_secondary[i_sec]
                    prev_c_sec = current_c_atom_obj_secondary.get_coord()
                    next_n_sec = curr_seg_sec.n_atom.get_coord()
                    dist_c_n_sec = np.linalg.norm(prev_c_sec - next_n_sec)
                    added_glycines_sec = []
                    if dist_c_n_sec > PEPTIDE_BOND_THRESHOLD:
                        p0_sec, p3_sec = prev_c_sec, next_n_sec
                        ca_prev_res_sec, ca_next_res_sec = secondary_chain_residues[-1], curr_seg_sec.residues[0]
                        if 'CA' in ca_prev_res_sec and 'CA' in ca_next_res_sec:
                            ca_prev_sec, ca_next_sec = ca_prev_res_sec['CA'].get_coord(), ca_next_res_sec['CA'].get_coord()
                            v0_sec,v3_sec = p0_sec - ca_prev_sec, p3_sec - ca_next_sec
                            p1_sec,p2_sec = p0_sec + v0_sec*0.5, p3_sec + v3_sec*0.5
                        else: 
                            mid_point_sec = (p0_sec+p3_sec)/2.0
                            p1_sec,p2_sec = p0_sec+(mid_point_sec-p0_sec)*0.5, p3_sec+(mid_point_sec-p3_sec)*0.5
                        ctrl_pts_sec = [p0_sec,p1_sec,p2_sec,p3_sec]
                        curve_len_approx_sec = np.linalg.norm(p3_sec-p0_sec)
                        num_gly_sec = max(0,int(round(curve_len_approx_sec/CA_CA_DISTANCE))-1)
                        if num_gly_sec > 0:
                            print(f"    Secondary chain: adding {num_gly_sec} glycines between ...{prev_seg_sec.id} and {curr_seg_sec.id} (C-N dist: {dist_c_n_sec:.2f} A)")
                            ca_p0_sec = ca_prev_res_sec['CA'].get_coord() if 'CA' in ca_prev_res_sec else p0_sec
                            ca_p3_sec = ca_next_res_sec['CA'].get_coord() if 'CA' in ca_next_res_sec else p3_sec
                            ca_v0_sec,ca_v3_sec = ca_p0_sec-p0_sec, ca_p3_sec-p3_sec
                            ca_ctrl_pts_sec = [ca_p0_sec,ca_p0_sec+ca_v0_sec*0.5,ca_p3_sec+ca_v3_sec*0.5,ca_p3_sec]
                            gly_ca_pos_sec_list = [(ca_p0_sec+ca_p3_sec)/2.0] if num_gly_sec==1 and np.linalg.norm(ca_p0_sec-ca_p3_sec)<CA_CA_DISTANCE*1.5 else bezier_curve(ca_ctrl_pts_sec,num_points=num_gly_sec+2)[1:-1]
                            curr_conn_c_coord_sec = prev_c_sec
                            for k_gly_sec, gly_ca_pos_val_sec in enumerate(gly_ca_pos_sec_list):
                                temp_glycine_res_num_counter+=1
                                prev_c_this_gly_sec,next_n_this_gly_sec = curr_conn_c_coord_sec, next_n_sec if k_gly_sec==num_gly_sec-1 else None
                                gly_res_sec = create_glycine_residue((' ',temp_glycine_res_num_counter,' '),gly_ca_pos_val_sec,prev_c_pos=prev_c_this_gly_sec,next_n_pos=next_n_this_gly_sec)
                                added_glycines_sec.append(gly_res_sec)
                                curr_conn_c_coord_sec = gly_res_sec['C'].get_coord()
                    secondary_chain_residues.extend(added_glycines_sec)
                    secondary_chain_residues.extend(curr_seg_sec.residues)
                    current_c_atom_obj_secondary = get_terminal_atom(secondary_chain_residues[-1], 'C')
                    if not current_c_atom_obj_secondary:
                        raise ValueError(f"Segment {curr_seg_sec.id} resulted in no C-terminal C atom for secondary chain.")
                final_chains_residues.append(secondary_chain_residues)
                processed_unused_segments.update(processed_segments_in_this_group_run)
            elif segments_to_process_in_group : # No internal connections, but current segment exists
                seg_single = segments_to_process_in_group[0] # Process the first one as a single chain
                final_chains_residues.append(list(seg_single.residues))
                processed_unused_segments.add(seg_single)
                print(f"  Unused segment {seg_single.id} from original chain {original_chain_id_of_unused} forms a separate chain (no internal connection found for this pass).")
            
            # Update segments_to_process_in_group for the next iteration of the while loop
            segments_to_process_in_group = [s for s in segments_from_one_orig_chain_initial if s not in processed_unused_segments]


    # --- 6. Create new PDB structure with all chains ---
    new_structure = Structure.Structure('connected_protein_multi')
    new_model = Model.Model(0)
    
    chain_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
    if len(final_chains_residues) > len(chain_letters):
        print(f"Warning: More than {len(chain_letters)} chains generated ({len(final_chains_residues)}). Some will be skipped in output.")
        final_chains_residues = final_chains_residues[:len(chain_letters)]

    for chain_idx, residue_list_for_chain in enumerate(final_chains_residues):
        if not residue_list_for_chain: continue

        new_chain_id = chain_letters[chain_idx % len(chain_letters)]
        new_chain = Chain.Chain(new_chain_id)
        
        current_residue_seq_num = 1
        for original_res in residue_list_for_chain:
            new_res = Residue.Residue(original_res.id, original_res.resname, original_res.segid)
            for atom in original_res.get_atoms():
                new_atom = Atom.Atom(atom.name, atom.coord, atom.bfactor, atom.occupancy,
                                     atom.altloc, atom.fullname, 0, 
                                     element=atom.element, pqr_charge=atom.pqr_charge, radius=atom.radius)
                new_res.add(new_atom)
            hetfield = original_res.id[0]
            new_id = (hetfield, current_residue_seq_num, ' ')
            new_res.id = new_id
            new_chain.add(new_res)
            current_residue_seq_num += 1
        
        new_model.add(new_chain)
        print(f"Outputting Chain {new_chain_id} with {len(new_chain)} residues.")

    new_structure.add(new_model)
    io = PDBIO()
    io.set_structure(new_structure)
    io.save(output_pdb_path)
    print(f"Successfully processed segments. Output saved to {output_pdb_path}")



connect_broken_chains(r"C:\Users\bashc\Desktop\working\ssBinding\working8.pdb",r"C:\Users\bashc\Desktop\working\ssBinding\connected.pdb"  )

Identified 18 segments in total.
Best primary chain will connect 17 segments.
  Main chain: adding 5 glycines between ...Seg-W-5(205_to_216) and Seg-Z-4(146_to_162) (C-N dist: 23.57 A)
  Main chain: adding 4 glycines between ...Seg-Z-4(146_to_162) and Seg-Z-2(45_to_67) (C-N dist: 18.71 A)
  Main chain: adding 1 glycines between ...Seg-Z-2(45_to_67) and Seg-Z-0(5_to_5) (C-N dist: 7.71 A)
  Main chain: adding 2 glycines between ...Seg-Z-1(6_to_37) and Seg-C-0(20_to_44) (C-N dist: 10.30 A)
  Main chain: adding 2 glycines between ...Seg-C-0(20_to_44) and Seg-A-0(16_to_44) (C-N dist: 11.21 A)
  Main chain: adding 1 glycines between ...Seg-A-0(16_to_44) and Seg-W-6(229_to_256) (C-N dist: 9.45 A)
  Main chain: adding 1 glycines between ...Seg-W-6(229_to_256) and Seg-W-7(259_to_259) (C-N dist: 8.08 A)
  Main chain: adding 2 glycines between ...Seg-W-8(260_to_276) and Seg-W-4(159_to_188) (C-N dist: 12.12 A)
  Main chain: adding 2 glycines between ...Seg-W-3(131_to_155) and Seg-W-1(60_to_75) (C-