In [4]:
from Bio import PDB
import numpy as np
from Bio.PDB.PDBIO import PDBIO
from Bio.PDB import Atom, Residue, Chain, Model, Structure

def get_terminal_atom(residue, term_type='N'):
    """
    Get N or C terminal atom from a residue.
    Falls back to CA atom if N or C is not present.
    Falls back to the first atom if CA is also not present.

    Args:
        residue (Residue): Biopython Residue object.
        term_type (str): 'N' for N-terminus, 'C' for C-terminus.

    Returns:
        Atom: Biopython Atom object or None if residue is empty.
    """
    if term_type == 'N':
        atom_name = 'N'
    elif term_type == 'C':
        atom_name = 'C'
    else:
        raise ValueError("term_type must be 'N' or 'C'")

    if atom_name in residue:
        return residue[atom_name]
    elif 'CA' in residue:
        # print(f"Warning: Missing {atom_name} in residue {residue.id}, using CA.")
        return residue['CA']
    else:
        try:
            # Fallback to the first atom if CA is also missing
            first_atom = next(iter(residue.get_atoms()))
            # print(f"Warning: Missing {atom_name} and CA in residue {residue.id}, using first atom {first_atom.id}.")
            return first_atom
        except StopIteration:
            # Residue has no atoms
            print(f"Warning: Residue {residue.id} has no atoms.")
            return None


def connect_pdb_chains(input_pdb, output_pdb, start_chain_id='A'):
    """
    Load a PDB file with multiple chains, connect them sequentially starting
    from a specified chain by finding the closest terminal ends,
    and save as a single chain 'A'.

    Args:
        input_pdb (str): Path to input PDB file.
        output_pdb (str): Path to output PDB file.
        start_chain_id (str): ID of the chain to start merging from. Defaults to 'A'.
    """
    # Parse the PDB file
    parser = PDB.PDBParser(QUIET=True)
    try:
        structure = parser.get_structure('protein', input_pdb)
    except Exception as e:
        print(f"Error parsing PDB file {input_pdb}: {e}")
        return None
        
    model = structure[0]

    # Store chain information
    chain_data = {}
    for chain in model:
        residues = list(chain.get_residues())
        if not residues:
            continue

        n_term_res = residues[0]
        c_term_res = residues[-1]

        n_atom = get_terminal_atom(n_term_res, 'N')
        c_atom = get_terminal_atom(c_term_res, 'C')

        # Skip chain if terminal atoms couldn't be determined
        if n_atom is None or c_atom is None:
            print(f"Warning: Could not determine terminal atoms for chain {chain.id}. Skipping.")
            continue

        chain_data[chain.id] = {
            'chain_obj': chain,
            'residues': residues,
            'N_atom': n_atom,
            'C_atom': c_atom
        }

    if not chain_data:
        print("No chains with valid terminal residues found.")
        return None

    # Check if start_chain_id exists, otherwise use the first available chain
    if start_chain_id not in chain_data:
        original_start_id = start_chain_id
        start_chain_id = next(iter(chain_data.keys()))
        print(f"Warning: Start chain '{original_start_id}' not found. Using first available chain '{start_chain_id}'.")

    # Initialize the merged list with the starting chain's residues
    merged_residues = list(chain_data[start_chain_id]['residues'])
    remaining_chain_ids = set(chain_data.keys())
    remaining_chain_ids.remove(start_chain_id)

    # Keep track of the N and C terminal atoms of the growing merged chain
    current_N_term_atom = chain_data[start_chain_id]['N_atom']
    current_C_term_atom = chain_data[start_chain_id]['C_atom']
    
    print(f"Starting merge with chain: {start_chain_id}")

    # Iteratively connect the closest remaining chain
    while remaining_chain_ids:
        min_dist = float('inf')
        best_next_chain_id = None
        connection_point = None # 'N' (prepend) or 'C' (append) end of the merged chain

        current_N_coord = current_N_term_atom.get_coord()
        current_C_coord = current_C_term_atom.get_coord()

        for chain_id in remaining_chain_ids:
            cand_N_atom = chain_data[chain_id]['N_atom']
            cand_C_atom = chain_data[chain_id]['C_atom']
            cand_N_coord = cand_N_atom.get_coord()
            cand_C_coord = cand_C_atom.get_coord()

            # Check connection: Current C-term to Candidate N-term (Append candidate)
            dist_C_to_N = np.linalg.norm(current_C_coord - cand_N_coord)
            if dist_C_to_N < min_dist:
                min_dist = dist_C_to_N
                best_next_chain_id = chain_id
                connection_point = 'C' # Connect candidate N-term to current C-term

            # Check connection: Current N-term to Candidate C-term (Prepend candidate)
            dist_N_to_C = np.linalg.norm(current_N_coord - cand_C_coord)
            if dist_N_to_C < min_dist:
                min_dist = dist_N_to_C
                best_next_chain_id = chain_id
                connection_point = 'N' # Connect candidate C-term to current N-term

        if best_next_chain_id:
            next_chain_info = chain_data[best_next_chain_id]
            next_chain_residues = next_chain_info['residues']

            if connection_point == 'C':
                print(f"  Appending chain {best_next_chain_id} (C-term to N-term dist: {min_dist:.2f} A)")
                merged_residues.extend(next_chain_residues)
                # Update the C-terminus of the merged chain
                current_C_term_atom = next_chain_info['C_atom']
            elif connection_point == 'N':
                print(f"  Prepending chain {best_next_chain_id} (N-term to C-term dist: {min_dist:.2f} A)")
                merged_residues = next_chain_residues + merged_residues
                # Update the N-terminus of the merged chain
                current_N_term_atom = next_chain_info['N_atom']

            remaining_chain_ids.remove(best_next_chain_id)
        else:
            # Should only happen if remaining chains are disconnected
            print("Warning: Could not find a suitable chain to connect. Remaining chains:", remaining_chain_ids)
            break

    # Create a new structure with the merged chain
    new_structure = Structure.Structure('merged_protein')
    new_model = Model.Model(0)
    new_chain = Chain.Chain('A')

    # Add residues to the new chain, renumbering them sequentially
    current_residue_number = 1
    for original_res in merged_residues:
        # Create a deep copy to avoid modifying the original structure
        new_res = original_res.copy()
        
        # Preserve HETATM flag, renumber sequentially, remove insertion code
        hetfield = original_res.id[0]
        insertion_code = ' ' 
        new_id = (hetfield, current_residue_number, insertion_code)
        
        new_res.id = new_id
        new_res.parent = new_chain # Set parent to the new chain 'A'
        
        new_chain.add(new_res)
        current_residue_number += 1

    new_model.add(new_chain)
    new_structure.add(new_model)

    # Save the merged structure
    io = PDBIO()
    io.set_structure(new_structure)
    io.save(output_pdb)
    print(f"Successfully merged {len(chain_data)} chains into {output_pdb}")

    return new_structure

# Example usage: Replace paths with your actual file paths
# Make sure the input PDB exists and the output directory is writable.
try:
    # Create dummy PDB for testing if needed, or use your actual file
    # connect_pdb_chains("path/to/your/multi_chain.pdb", "path/to/your/connected_chain.pdb", start_chain_id='A')
    connect_pdb_chains(r"C:\Users\bashc\Desktop\working\ssBinding\working7.pdb", r"C:\Users\bashc\Desktop\working\ssBinding\connected_v2.pdb", start_chain_id='A')
except FileNotFoundError:
    print("Error: Input PDB file not found. Please check the path.")
except Exception as e:
    print(f"An error occurred during PDB connection: {e}")

Starting merge with chain: A
  Appending chain E (C-term to N-term dist: 2.49 A)
  Appending chain H (C-term to N-term dist: 1.65 A)
  Appending chain G (C-term to N-term dist: 6.86 A)
  Appending chain F (C-term to N-term dist: 4.85 A)
Successfully merged 5 chains into C:\Users\bashc\Desktop\working\ssBinding\connected_v2.pdb


In [136]:
from Bio import PDB
from Bio.PDB import *
import numpy as np
from scipy.spatial.transform import Rotation
from Bio.PDB import Residue, Atom
from IPython.display import clear_output
import py3Dmol

class HemeHelixAligner:
    def __init__(self, pdb_file):
        """Initialize with PDB file"""
        self.parser = PDB.PDBParser(QUIET=True)
        self.structure = self.parser.get_structure('protein', pdb_file)
        self.model = self.structure[0]
        
    def get_heme_atoms(self):
        """Get key heme atoms for plane calculation"""
        heme_chain = self.model['N']
        for residue in heme_chain.get_residues():
            if residue.get_resname() == "HEM":
                return {
                    'FE': residue['FE'].get_coord(),
                    'NA': residue['NA'].get_coord(),
                    'NB': residue['NB'].get_coord(),
                    'NC': residue['NC'].get_coord(),
                    'ND': residue['ND'].get_coord()
                }
        raise ValueError("Heme not found in chain N")
    
    def calculate_heme_plane(self, heme_atoms):
        """Calculate heme plane normal vector using nitrogen atoms"""
        # Use the four pyrrole nitrogens to define the plane
        points = np.array([heme_atoms[k] for k in ['NA', 'NB', 'NC', 'ND']])
        
        # Calculate center of mass of the nitrogens
        center = np.mean(points, axis=0)
        
        # Calculate normal vector using SVD
        centered_points = points - center
        _, _, vh = np.linalg.svd(centered_points)
        normal = vh[2]  # Third singular vector is normal to best-fit plane
        
        return center, normal
    
    def align_structure_to_xy(self):
        """Align heme to xy plane with iron at origin"""
        # Get heme atoms and calculate plane
        heme_atoms = self.get_heme_atoms()
        fe_pos = heme_atoms['FE']
        center, normal = self.calculate_heme_plane(heme_atoms)
        
        # First translate iron to origin
        self.translate_structure(-fe_pos)
        
        # Calculate rotation to align normal with z-axis
        z_axis = np.array([0, 0, 1])
        rotation_matrix = self.get_alignment_rotation(normal, z_axis)
        
        # Apply rotation to entire structure
        self.rotate_structure(rotation_matrix)
    
    def get_alignment_rotation(self, vector1, vector2):
        """Calculate rotation matrix to align vector1 with vector2"""
        vector1 = vector1 / np.linalg.norm(vector1)
        vector2 = vector2 / np.linalg.norm(vector2)
        
        cross_product = np.cross(vector1, vector2)
        if np.allclose(cross_product, 0):
            return np.eye(3)
            
        dot_product = np.dot(vector1, vector2)
        angle = np.arccos(np.clip(dot_product, -1.0, 1.0))
        
        rotation = Rotation.from_rotvec(cross_product / np.linalg.norm(cross_product) * angle)
        return rotation.as_matrix()
    
    def translate_structure(self, vector):
        """Translate entire structure by vector"""
        for atom in self.model.get_atoms():
            atom.set_coord(atom.get_coord() + vector)
    
    def rotate_structure(self, rotation_matrix):
        """Rotate entire structure around origin"""
        for atom in self.model.get_atoms():
            atom.set_coord(np.dot(rotation_matrix, atom.get_coord()))
    
    def position_helix(self, chain_id, above=True):
        """Position helix above or below heme plane"""
        chain = self.model[chain_id]
        his_coord = None
        
        # Find histidine NE2
        for residue in chain.get_residues():
            if residue.get_resname() == "HIS":
                his_coord = residue["NE2"].get_coord()
                break
        
        if his_coord is None:
            raise ValueError(f"No histidine found in chain {chain_id}")
        
        # Calculate translation needed to position histidine
        target_z = 2.2 if above else -2.2  # Fe-N coordination distance
        delta_z = target_z - his_coord[2]
        
        # Translate chain to correct position
        for atom in chain.get_atoms():
            coord = atom.get_coord()
            atom.set_coord(coord + np.array([-1*his_coord[0], -1*his_coord[1], delta_z]))
    
    def clone_and_replaceChain(self, source_ChainID, replace_ChainID):
        # Determine the location and main axis of helix in chain A
        chain_Replace = self.model[replace_ChainID]
        helix_axis_Replace, his_axis_Replace, center_Replace = self.get_barrel_axis(chain_Replace)
        
        
        # Clone chain source
        helix_axis_Source, his_axis_Source, center_Source = self.get_barrel_axis(self.model[source_ChainID])
        chain_Clone = self.model[source_ChainID].copy()
        
        # Translate and rotate the clone
        translation_vector = center_Replace - center_Source
        for atom in chain_Clone.get_atoms():
            atom.set_coord(atom.get_coord() + translation_vector)
         
        
        #rotation_matrix = self.get_alignment_rotation(helix_axis_Replace, his_axis_Replace)
        #self._rotate_chain_around_point(chain_Clone, rotation_matrix, center_Replace)
        
        
        # Remove the original chain A
        self.model.detach_child(replace_ChainID)
        
        # Rename the clone to chain id A
        chain_Clone.id = replace_ChainID
        self.model.add(chain_Clone)
        
        

    def align_all(self):
        """Perform complete alignment"""
        # First align heme to xy plane with Fe at origin
        self.align_structure_to_xy()
        
        chain_a = self.structure[0]['A']
        self.align_helix_and_histidine(chain_a, True)
        
        chain_b = self.structure[0]['B']
        self.align_helix_and_histidine(chain_b,False)        
        
        
        # Position helices
        self.position_helix('A', above=True)
        self.position_helix('B', above=False)

    
    def get_barrel_axis(self, chain):
        """Calculate helix axis using CA atoms"""
        ca_coords = np.array([atom.get_coord() for atom in chain.get_atoms() 
                            if atom.name == 'CA'])
        centered = ca_coords - np.mean(ca_coords, axis=0)
        _, _, vh = np.linalg.svd(centered)
        
        # Calculate the center of the helix
        ca_coords = np.array([atom.get_coord() for atom in chain.get_atoms() 
                            if atom.name == 'CA'])
        center = np.mean(ca_coords, axis=0)
        
        his = next(res for res in chain.get_residues() if res.get_resname() == "HIS")
        his_axis = his['NE2'].get_coord() - center
        
        return vh[0],his_axis, center 
    
    def align_helix_and_histidine(self,chain, above):
        """
        Align helix to x-axis and rotate histidine axis to z-plane
        
        Parameters:
            chain: BioPython Chain object containing helix and histidine
        """
        # Get CA coordinates and calculate helix axis
        ca_coords = np.array([atom.get_coord() for atom in chain.get_atoms() 
                            if atom.name == 'CA'])
        center = np.mean(ca_coords, axis=0)
        
        # Center the chain at origin
        for atom in chain.get_atoms():
            atom.set_coord(atom.get_coord() - center)
        
        # Recalculate centered CA coords for SVD
        ca_coords = np.array([atom.get_coord() for atom in chain.get_atoms() 
                            if atom.name == 'CA'])
        _, _, vh = np.linalg.svd(ca_coords)
        helix_axis = vh[0]
        
        # Ensure helix axis points in positive x direction
        if helix_axis[0] < 0:
            helix_axis = -helix_axis
        
        # Calculate rotation to align helix axis with x-axis
        x_axis = np.array([1, 0, 0])
        rotation1 =self. get_alignment_rotation(helix_axis, x_axis)
        
        # Apply first rotation to all atoms
        for atom in chain.get_atoms():
            atom.set_coord(np.dot(rotation1, atom.get_coord()))
        
        # Get histidine axis (now in rotated frame)
        his = next(res for res in chain.get_residues() if res.get_resname() == "HIS")
        his_axis = his['NE2'].get_coord()  # Center is at origin
        
        # Project histidine axis onto yz-plane (perpendicular to helix axis)
        his_yz = his_axis.copy()
        if above :
            his_yz=his_yz*-1
        his_yz[0] = 0  # Remove x component
        
        # Skip rotation if projection is too small
        if np.linalg.norm(his_yz) > 1e-6:
            his_yz = his_yz / np.linalg.norm(his_yz)
            
            # Calculate rotation around x-axis to align his_yz with z-axis
            z_axis = np.array([0, 0, 1])
            rotation2 = self. get_alignment_rotation_around_x(his_yz, z_axis)
            
            # Apply second rotation to all atoms
            for atom in chain.get_atoms():
                atom.set_coord(np.dot(rotation2, atom.get_coord()))

    def get_alignment_rotation(self,v1, v2):
        """Calculate rotation matrix to align v1 with v2"""
        v1 = v1 / np.linalg.norm(v1)
        v2 = v2 / np.linalg.norm(v2)
        
        cross = np.cross(v1, v2)
        if np.allclose(cross, 0):
            return np.eye(3)
            
        dot = np.dot(v1, v2)
        angle = np.arccos(np.clip(dot, -1.0, 1.0))
        
        rotation = Rotation.from_rotvec(cross / np.linalg.norm(cross) * angle)
        return rotation.as_matrix()

    def get_alignment_rotation_around_x(self,v1, v2):
        """Calculate rotation matrix around x-axis to align v1 with v2"""
        # Project both vectors to ensure they're in yz-plane
        v1[0] = 0
        v2[0] = 0
        v1 = v1 / np.linalg.norm(v1)
        v2 = v2 / np.linalg.norm(v2)
        
        # Calculate angle in yz-plane
        angle = np.arctan2(np.cross(v1, v2)[0], np.dot(v1, v2))
        
        # Create rotation matrix around x-axis
        cos_t = np.cos(angle)
        sin_t = np.sin(angle)
        return np.array([[1, 0, 0],
                        [0, cos_t, -sin_t],
                        [0, sin_t, cos_t]])
     
    
    def save_structure(self, output_file):
        """Save aligned structure"""
        io = PDB.PDBIO()
        io.set_structure(self.structure)
        io.save(output_file)


class HelixManipulator:
    def __init__(self, model, iron_selector="N:FE"):
        """
        Initialize the manipulator with a BioPython model.

        Args:
            model: BioPython Model object containing chains A, B (helices) and N (heme)
            iron_selector: String specifying chain:atom for the iron atom
        """
        self.model = model
        self.chain_centers = {}
        self.iron_pos = None
        self.iron_chain, self.iron_name = iron_selector.split(":")

    def get_chain_centers_and_width(self):
        """Calculate the geometric center and radius of the helix barrel."""
        all_coords = []
        max_distances = []

        # First calculate the barrel center using all CA atoms
        for chain_id in ["A", "B"]:
            chain = self.model[chain_id]
            chain_coords = []

            for residue in chain:
                for atom in residue:
                    if atom.name == "CA":
                        coord = atom.get_coord()
                        chain_coords.append(coord)
                        all_coords.append(coord)

            self.chain_centers[chain_id] = np.mean(chain_coords, axis=0)

        # Now calculate maximum radius from barrel center to any atom
        for chain_id in ["A", "B"]:
            chain = self.model[chain_id]
            chainCenter = self.chain_centers[chain_id]
            chainCenter = np.array([0, chainCenter[1], chainCenter[2]])

            for residue in chain:
                dists = []
                for atom in residue:
                    atomCoords = atom.get_coord()
                    atomCoords = np.array([0, atomCoords[1], atomCoords[2]])
                    dists.append(np.linalg.norm(atomCoords - chainCenter))
                max_distances.append(np.max(dists))

        self.helix_radius = np.mean(max_distances)

    def get_iron_position(self):
        """Get the coordinates of the iron atom in the heme."""
        for atom in self.model[self.iron_chain].get_atoms():
            if atom.name == self.iron_name:
                self.iron_pos = atom.get_coord()
                break

    def spin_heme(self, angle_degrees, angle_tilt):
        chain_heme = self.model[self.iron_chain]
        if self.iron_pos is None:
            self.get_iron_position()
            
        if angle_degrees != 0:
            self._rotate_chain_around_point(
                chain_heme, rotaxis2m(angle_degrees / 180 * np.pi, Vector([0, 0, 1])), self.iron_pos
            )
        if angle_tilt != 0:
            self._rotate_chain_around_point(
                chain_heme, rotaxis2m(angle_tilt / 180 * np.pi, Vector([1, 0, 0])), self.iron_pos
            )
            
    def align_to_z(self):
        """Rotate structure around x-axis to align helix centers with z-axis."""
        if self.iron_pos is None:
            self.get_iron_position()
        if not self.chain_centers:
            self.get_chain_centers_and_width()

        # Calculate current axis between helix centers
        axis = self.chain_centers["B"] - self.chain_centers["A"]

        # Project axis onto y-z plane
        y_component = axis[1]
        z_component = axis[2]

        # Calculate angle needed to rotate to z-axis in y-z plane
        angle = np.arctan2(y_component, z_component)

        # Create rotation matrix around x-axis
        rotator = rotaxis2m(angle, Vector([1, 0, 0]))

        # Apply rotation around iron
        self._rotate_around_point(rotator, self.iron_pos)

    def rotate_helices(self, angle_degrees):
        """
        Rotate helices in opposite directions around z-axis through iron.

        Args:
            angle_degrees: Rotation angle in degrees (positive for A clockwise)
        """
        angle_rad = np.radians(angle_degrees)

        # Create rotation matrices
        rot_matrix_A = rotaxis2m(angle_rad, Vector([0, 0, 1]))
        rot_matrix_B = rotaxis2m(-angle_rad, Vector([0, 0, 1]))

        # Rotate individual chains
        for atom in self.model["A"].get_atoms():
            coord = atom.get_coord()
            new_coord = np.dot(rot_matrix_A, coord - self.iron_pos) + self.iron_pos
            atom.set_coord(new_coord)

        for atom in self.model["B"].get_atoms():
            coord = atom.get_coord()
            new_coord = np.dot(rot_matrix_B, coord - self.iron_pos) + self.iron_pos
            atom.set_coord(new_coord)

    def _translate_chain(self, chain, translation_vector):
        """Helper method to translate all atoms in a chain."""
        for residue in chain:
            for atom in residue:
                coord = atom.get_coord()
                new_coord = coord + translation_vector
                atom.set_coord(new_coord)

    def _rotate_around_point(self, rotation_matrix, point):
        """Helper method to rotate all atoms around a point."""
        for chain in self.model:
            for residue in chain:
                for atom in residue:
                    coord = atom.get_coord()
                    new_coord = np.dot(rotation_matrix, coord - point) + point
                    atom.set_coord(new_coord)

    def _rotate_chain_around_point(self, chain, rotation_matrix, point):
        """Helper method to rotate all atoms around a point."""
        for residue in chain:
            for atom in residue:
                coord = atom.get_coord()
                new_coord = np.dot(rotation_matrix, coord - point) + point
                atom.set_coord(new_coord)

    def _spin_helix(self, chain, angle_degrees):
        """Helper method to spin a helix around its axis."""
        # First calculate the barrel center using all CA atoms
        chain_coords = []
        for residue in chain:
            for atom in residue:
                if atom.name == "CA":
                    coord = atom.get_coord()
                    chain_coords.append(coord)

        center = np.mean(chain_coords, axis=0)
        # deteremine the axis of the helix using svd
        centered = chain_coords - center
        _, _, vh = np.linalg.svd(centered)
        axis = vh[0]

        # Rotate all atoms around the helix center
        self._rotate_chain_around_point(
            chain, rotaxis2m(angle_degrees / 180 * np.pi, Vector(axis)), center
        )

    def create_packed_copies(
        self, angle_degrees, y_gapdelta, z_offset, spin, antiParallel=False, addHeme_row=False, barrel_stagger=0,heme_stagger =0 
    ):
        """
        Create copies of the helices arranged in a packed pattern.

        Args:
            angle_degrees: Rotation angle in degrees for helices
        """
        if self.helix_radius is None:
            self.get_chain_centers_and_width()

        # Calculate spacing based on helix diameter and add 10% buffer
        y_spacing = (y_gapdelta + 2 * self.helix_radius * 0.85) / np.cos(
            angle_degrees / 180 * np.pi
        )  # 2 * radius * 1.1

        # Create 6 copies
        cc = 1
        extra = 1
        newX=0
        for i in range(7):
            # Create new chains using Chain.copy()
            new_chains = {
                chain_id: self.model[chain_id].copy()
                for chain_id in ["A", "B", self.iron_chain]
            }

            # Calculate y-offset
            y_gapdelta = (i + 1) * y_spacing  # Center the pattern around y=0

            # if it is anti parallel and i is even, rotate the new chains by 180 degrees
            if antiParallel and i % 2 == 0:
                for chain in new_chains.values():
                    if chain.id != "N":
                        self._rotate_chain_around_point(
                            chain, rotaxis2m(np.pi, Vector([0, 0, 1])), self.iron_pos
                        )
            if spin != 0:
                for chain in new_chains.values():
                    if chain.id != "N":
                        self._spin_helix(chain, spin)

            # Translate the chains
            for chain in new_chains.values():
                if chain.id == "A":
                    self._translate_chain(chain, np.array([0, y_gapdelta, z_offset]))
                elif chain.id == "B":
                    self._translate_chain(chain, np.array([0, y_gapdelta, -1 * z_offset]))
                else:
                    self._translate_chain(chain, np.array([0, y_gapdelta, 0]))

            # Determine the x coordinate range of the new chains
            if addHeme_row:
                if newX==0:
                    x_coords = []
                    for chain in new_chains.values():
                        for atom in chain.get_atoms():
                            x_coords.append(atom.get_coord()[0])
                    x_range = (min(x_coords), max(x_coords))
                    newX = np.random.uniform(x_range[0]+45, x_range[1])
                newHeme = None
                for chain in new_chains.values():
                    if chain.id=='N':
                        newHeme = chain.copy()
                        newHeme_x = np.mean([atom.get_coord()[0] for atom in newHeme.get_atoms()])
                if newHeme:
                    self._translate_chain(newHeme, np.array([newX-newHeme_x, 0, 0]))
                    newHeme.id = "NN"
                    new_chains['NN']=newHeme
                    
            if barrel_stagger != 0:
                for chain in new_chains.values():
                    if chain.id=='A':
                        self._translate_chain(chain, np.array([0,y_spacing* barrel_stagger, 0]))
             
            if heme_stagger != 0 and i % 2==0:
                for chain in new_chains.values():
                    if chain.id=='N':
                        self._translate_chain(chain, np.array([ heme_stagger,0, 0]))
                        
            # Change chain IDs to the next letter and add them to the model
            for chain_id, new_chain in new_chains.items():
                new_chain.id = chr(ord("B") + cc)
                if new_chain.id == "N":
                    new_chain.id = str(extra)
                    extra += 1
                self.model.add(new_chain)
                cc += 1
        #remove chains A,B, N from the model
        self.model.detach_child("A")
        self.model.detach_child("B")
        self.model.detach_child("N")

    def compressHemes(self, num_to_add):
        heme_chains = [chain for chain in self.model if any(residue.get_resname() == "HEM" for residue in chain.get_residues())]
        
        # Calculate the total distance covered by hemes in the y direction
        heme_fe_positions = [residue['FE'].get_coord()[1] for chain in heme_chains for residue in chain.get_residues() if residue.get_resname() == "HEM"]
        minFe= min(heme_fe_positions)
        maxFe= max(heme_fe_positions)
        total_y_distance = maxFe - minFe
        shortest = total_y_distance / len(heme_chains)        
        
        # Calculate the y-axis translation needed to fit num_to_add hemes
        y_translation = total_y_distance/ (len(heme_chains) + num_to_add)
        
        source_chain = heme_chains[-2]
        for i in range(num_to_add):
            new_chain = source_chain.copy()
            new_chain.id = chr(ord("x") + i) # Assign a new unique ID
            self.model.add(new_chain)
            heme_chains.append(new_chain)
        
        # Move existing heme chains closer on the y-axis
        for i,chain in enumerate( heme_chains):
            fe_y= [residue['FE'].get_coord()[1] for residue in chain.get_residues() if residue.get_resname() == "HEM"]
            fe_y= min(fe_y)
            newY = minFe + i * y_translation
            self._translate_chain(chain, np.array([0, newY-fe_y, 0]))
        
        
        
        
        
    def check_clashes(self):
        """
        Check for clashes between all pairs of residues in the packed arrangement.

        Args:
            distance_threshold: Minimum distance (Å) between residues to be considered a clash

        Returns:
            int: Number of clashes found
        """
        clash_count = 0
        # get all the chains in the model
        chains = [chain for chain in self.model]
        # Check clashes between all pairs of models
        for i, chain1 in enumerate(chains):
            for chain2 in chains[i + 1 :]:
                clash_count += self._count_clashes_between_chains(                    chain1, chain2                )

        return clash_count

    def _count_clashes_between_chains(self, chain1, chain2):
        """Helper method to count clashes between two sets of chains."""
        clashes = 0
        residue_distance_threshold = 25.0
        #use the radius of the carbon atom as the threshold in angstroms
        atom_distance_threshold =  1.7 
        # First check residue centers (CA positions)
        for res1 in chain1:
            ca1 = res1["CA"].get_coord() if "CA" in res1 else None
            if ca1 is None:
                continue

            for res2 in chain2:
                ca2 = res2["CA"].get_coord() if "CA" in res2 else None
                if ca2 is None:
                    continue

                # If residues are close, check all atom pairs
                if np.linalg.norm(ca1 - ca2) < residue_distance_threshold:
                    # Detailed atom check for close residues
                    for atom1 in res1:
                        coord1 = atom1.get_coord()
                        for atom2 in res2:
                            coord2 = atom2.get_coord()
                            if np.linalg.norm(coord1 - coord2) < atom_distance_threshold:
                                clashes += 1
                                # Early exit for this residue pair
                                break
                        if clashes > 0:
                            break

        return clashes
    def saveModel(self, filename):
        io = PDBIO()
        io.set_structure(self.model)
        io.save(filename)


aligner = HemeHelixAligner("double_heme.pdb")
aligner.clone_and_replaceChain("B", "A")
aligner.align_all()
aligner.save_structure("test.pdb")
print("Structure aligned and saved.")

Structure aligned and saved.


In [None]:
cc=7
for i in range(1000):
    if np.random.rand() < 0.33:
        aligner = HemeHelixAligner("double_heme_end.pdb")
        edge = True
    else :
        aligner = HemeHelixAligner("double_heme.pdb")
        edge = False
    cloneProb = np.random.rand()
    if cloneProb < 0.25:
        aligner.clone_and_replaceChain("B", "A")
    elif cloneProb < 0.5:
        aligner.clone_and_replaceChain("A", "B")
        
    aligner.align_all()
    aligner.save_structure("test.pdb")
    print("Structure aligned and saved.")

    model = aligner.model
    # Initialize the manipulator with your model
    manipulator = HelixManipulator(model)

    tilt =0 
    if np.random.rand() < 0.1:
        tilt = 25*np.random.rand()
    tilt = 45
    manipulator.spin_heme(360*np.random.rand(), tilt)
    
    # Align helices with z-axis
    manipulator.align_to_z()

    # Rotate helices through a random angle
    angle = 45 * np.random.rand()
    
    yOffset = manipulator.helix_radius * 0.1 * np.random.rand()
    manipulator.rotate_helices(angle)
    spin = 0
    if np.random.rand() < 0.2:
        spin = 45 * np.random.rand()
        
    doubleUp=  np.random.rand()<.25

    heme_stagger = 0
    if np.random.rand() < 0.1:
        heme_stagger = -10 * (.5- np.random.rand())
    barrel_stagger = 0
    if np.random.rand() < 0.1:
        barrel_stagger =  np.random.rand()

    if edge or np.random.rand() > 0.5:
        manipulator.create_packed_copies(
            angle,
            (0.5 - np.random.rand()) * 2 * 0.5,
            (0.5 - np.random.rand()) * 2,
            spin,
            False, doubleUp, barrel_stagger=  barrel_stagger, heme_stagger=heme_stagger
        )
    else:
        manipulator.create_packed_copies(
            angle,
            (0.5 - np.random.rand()) * 2 * 3,
            (0.5 - np.random.rand()) * 2,
            spin,
            True, doubleUp, barrel_stagger=  barrel_stagger, heme_stagger=heme_stagger
        )
    
    clashes = manipulator.check_clashes()
    
    print(f'Found {clashes} clashes')
    if clashes <18 or True :
        manipulator.saveModel(f"D:/PythonProj/structures/perturbed_2_{cc}.pdb")
        if np.random.rand() < 0.1:
            cc+=1
            manipulator.compressHemes(1)
            manipulator.saveModel(f"D:/PythonProj/structures/perturbed_2_{cc}.pdb")
    
        clear_output(wait=True)
        # Clean the display
        view = py3Dmol.view(width=800, height=600)
        view.removeAllModels()

        # Load the model
        with open(f"D:/PythonProj/structures/perturbed_2_{cc}.pdb", 'r') as f:
            pdb_data = f.read()

        view.addModel(pdb_data, 'pdb')
        view.setStyle({'cartoon': {'color': 'spectrum'}})
        view.setStyle({'resn': 'HEM'}, {'stick': {}})
        view.zoomTo()
        view.show()
        cc+=1
     

Structure aligned and saved.


In [None]:

class HelixLinker:
    def __init__(self, model, excluded_chains=None):
        """
        Initialize linker for connecting helix chains.
        
        Args:
            model: BioPython Model object containing the helices
            excluded_chains: List of chain IDs to exclude (e.g., hemes)
        """
        self.model = model
        self.excluded_chains = excluded_chains or []
        self.chain_ends = {}  # Store terminal residues for each chain
        self.connections = []  # Store pairs of chain ends to connect
        
        # Standard amino acid geometry
        self.ca_c_length = 1.52    # Å
        self.c_n_length = 1.33     # Å
        self.n_ca_length = 1.46    # Å
        self.ca_c_n_angle = 117.2  # degrees
        self.c_n_ca_angle = 121.7  # degrees
        self.peptide_omega = 180.0 # degrees
        
        # Initialize structure
        self._identify_chain_ends()
        self._find_closest_connections()

    def _identify_chain_ends(self):
        """Identify N and C terminal residues for each chain."""
        for chain in self.model:
            if chain.id in self.excluded_chains:
                continue
                
            residues = list(chain)
            if not residues:
                continue
                
            self.chain_ends[chain.id] = {
                'N': residues[0],
                'C': residues[-1]
            }

    def _find_closest_connections(self):
        """Find closest chain ends to connect."""
        chain_ids = list(self.chain_ends.keys())
        
        for i, chain1 in enumerate(chain_ids):
            distances = []
            for chain2 in chain_ids:
                # Calculate distances between all possible end combinations
                end_carbon = self.chain_ends[chain1]['C']['C'].get_coord()
                start_n = self.chain_ends[chain2]['N']['N'].get_coord()
                
                distances .append( {   'c2':chain2, 'distance': np.linalg.norm(end_carbon - start_n)  })
                
            # Store the shortest connection
            best_connection = min(distances, key=lambda x: x['distance'])
            self.connections.append({
                'chain1': chain1,
                'chain2': best_connection['c2'],
                'distance': best_connection['distance']
            })

    def _calculate_residue_direction(self, residue):
        """Calculate the direction vector of a residue's backbone."""
        # Use CA->C vector for C-terminus and N->CA vector for N-terminus
        ca_pos = residue['CA'].get_coord()
        if 'C' in residue:
            c_pos = residue['C'].get_coord()
            direction = c_pos - ca_pos
        else:
            n_pos = residue['N'].get_coord()
            direction = ca_pos - n_pos
        return direction / np.linalg.norm(direction)

    def _evaluate_bezier(self, p0, p1, p2, p3, t):
        """
        Evaluate a cubic Bezier curve at parameter t.
        
        Args:
            p0, p1, p2, p3: Control points
            t: Parameter value between 0 and 1
            
        Returns:
            Point on the curve at parameter t
        """
        return (
            (1-t)**3 * p0 + 
            3*(1-t)**2 * t * p1 + 
            3*(1-t) * t**2 * p2 + 
            t**3 * p3
        )

    def generate_linker(self, connection, num_residues):
        """
        Generate backbone coordinates for a linker using a Bezier curve.
        
        Args:
            connection: Dict containing connection details
            num_residues: Number of residues in the linker
            
        Returns:
            List of dictionaries containing backbone atom coordinates
        """
        # Get start and end points and their directions
        baseChainID=connection['chain1']
        
        chain1 = self.chain_ends[baseChainID]
        chain2 = self.chain_ends[connection['chain2']]
        
        
        start_res = chain1['C']
        end_res = chain2['N']
        start_direction = self._calculate_residue_direction(start_res)
        end_direction = -self._calculate_residue_direction(end_res)  # Reverse for N-terminus
       
        
        start_ca = start_res['CA'].get_coord()
        end_ca = end_res['CA'].get_coord()
        
        # Calculate distance between endpoints
        distance = np.linalg.norm(end_ca - start_ca)
        
        # Set control points for Bezier curve
        # Use directions from residues to determine control points
        control_scale = distance / 3  # Scale factor for control point distance
        p0 = start_ca
        p1 = start_ca + start_direction * control_scale
        p2 = end_ca - end_direction * control_scale
        p3 = end_ca
        
        # Generate points along Bezier curve
        t = np.linspace(0, 1, num_residues)
        path = np.array([self._evaluate_bezier(p0, p1, p2, p3, ti) for ti in t])
        
        # Add small random perturbations to create more natural coil-like structure
        # but keep perturbations smaller than in previous version to maintain overall curve shape
        perturbations = np.random.normal(0, 0.5, (num_residues, 3))
        path += perturbations
        
        # Ensure proper spacing between CA atoms while maintaining curve shape
        for i in range(1, len(path)):
            direction = path[i] - path[i-1]
            current_distance = np.linalg.norm(direction)
            if current_distance != 3.8:  # Ideal CA-CA distance
                # Move point to maintain 3.8Å spacing
                path[i] = path[i-1] + 3.8 * direction/current_distance
        
        # Generate backbone atoms
        backbone = []
        for i in range(num_residues):
            if i == 0:
                prev_c = start_res['C'].get_coord()
                prev_ca = start_ca
            else:
                prev_c = backbone[i-1]['C']
                prev_ca = backbone[i-1]['CA']
            
            # Calculate tangent to curve at this point for better backbone placement
            if i < len(path) - 1:
                tangent = path[i+1] - path[i]
            else:
                tangent = end_direction
            tangent = tangent / np.linalg.norm(tangent)
            
            res_atoms = self._generate_backbone_atoms_with_direction(
                prev_c, prev_ca, path[i], tangent
            )
            backbone.append(res_atoms)
        
        return backbone
    
    def add_linker_to_chain(self, backbone_atoms, base_chain_id):
        """
        Add generated linker atoms to the specified chain.
        
        Args:
            backbone_atoms: List of dictionaries containing atom coordinates
            base_chain_id: ID of the chain to add linker to
        """
        # Get the last residue number of the base chain
        last_res_id = max(int(res.id[1]) for res in self.model[base_chain_id])
        
        # Add new residues to the chain
        for i, atoms in enumerate(backbone_atoms):
            # Create new residue (using GLY for simplicity)
            new_res_id = last_res_id + i + 1
            new_res = Residue.Residue((' ', new_res_id, ' '), 'GLY', '')
            
            # Add backbone atoms to residue
            for atom_name, coord in atoms.items():
                new_atom = Atom.Atom(
                    name=atom_name,
                    coord=coord,
                    bfactor=20.0,
                    occupancy=1.0,
                    altloc=' ',
                    fullname=atom_name,
                    serial_number=None,
                    element=atom_name[0]
                )
                new_res.add(new_atom)
            
            # Add residue to chain
            self.model[base_chain_id].add(new_res)

    def _generate_backbone_atoms_with_direction(self, prev_c, prev_ca, current_ca, direction):
        """
        Generate N, CA, C atoms with proper geometry, considering curve direction.
        
        Args:
            prev_c: Coordinates of previous residue's C atom
            prev_ca: Coordinates of previous residue's CA atom
            current_ca: Coordinates of current CA atom
            direction: Direction vector of the curve at this point
        """
        # Calculate N position using previous C and current direction
        peptide_plane_normal = np.cross(prev_c - prev_ca, direction)
        peptide_plane_normal = peptide_plane_normal / np.linalg.norm(peptide_plane_normal)
        n_direction = np.cross(direction, peptide_plane_normal)
        n_pos = prev_c + self.c_n_length * n_direction
        
        # CA position is given
        ca_pos = current_ca
        
        # Calculate C position using current direction
        c_direction = np.cross(direction, peptide_plane_normal)
        c_pos = ca_pos + self.ca_c_length * c_direction
        
        return {
            'N': n_pos,
            'CA': ca_pos,
            'C': c_pos
        }
        
# Find all chains in new_model that have a HEM marker
heme_chains = []
for chain in model:
    for residue in chain:
        if residue.get_resname() == "HEM":
            heme_chains.append(chain.id)
            break

# print("Chains with HEM marker:", heme_chains)
        
# linker = HelixLinker(model, excluded_chains=heme_chains)  # Exclude heme chain

# # Generate linker atoms
# backbone_atoms = linker.generate_linker(linker.connections[0], 10)

# # Add linker to the base chain
# linker.add_linker_to_chain(backbone_atoms, linker.connections[0]['chain1'])

# # Save structure
# io = PDBIO()
# io.set_structure(model)
# io.save("linked_structure.pdb")        

In [54]:
 
from scipy.optimize import minimize

class HelixLinker:
    def __init__(self, model, excluded_chains=None):
        """
        Initialize linker for connecting helix chains.
        
        Args:
            model: BioPython Model object containing the helices
            excluded_chains: List of chain IDs to exclude (e.g., hemes)
        """
        self.model = model
        self.excluded_chains = excluded_chains or []
        self.chain_ends = {}  # Store terminal residues for each chain
        self.connections = []  # Store pairs of chain ends to connect
        
        # Standard amino acid geometry
        self.ca_c_length = 1.52    # Å
        self.c_n_length = 1.33     # Å
        self.n_ca_length = 1.46    # Å
        self.ca_ca_distance = 3.8  # Å
        self.ca_c_n_angle = 117.2  # degrees
        self.c_n_ca_angle = 121.7  # degrees
        self.peptide_omega = 180.0 # degrees
        
        # Initialize structure
        self._identify_chain_ends()
        self._find_closest_connections()

    def _identify_chain_ends(self):
        """Identify N and C terminal residues for each chain."""
        for chain in self.model:
            if chain.id in self.excluded_chains:
                continue
                
            residues = list(chain)
            if not residues:
                continue
                
            self.chain_ends[chain.id] = {
                'N': residues[0],
                'C': residues[-1]
            }

    def _find_closest_connections(self):
        """Find closest chain ends to connect."""
        chain_ids = list(self.chain_ends.keys())
        
        for i, chain1 in enumerate(chain_ids):
            distances = []
            for chain2 in chain_ids[i+1:]:  # Only look at chains we haven't considered yet
                end_carbon = self.chain_ends[chain1]['C']['C'].get_coord()
                start_n = self.chain_ends[chain2]['N']['N'].get_coord()
                
                distances.append({
                    'c2': chain2,
                    'distance': np.linalg.norm(end_carbon - start_n)
                })
            
            if distances:  # Only add if we found valid connections
                best_connection = min(distances, key=lambda x: x['distance'])
                self.connections.append({
                    'chain1': chain1,
                    'chain2': best_connection['c2'],
                    'distance': best_connection['distance']
                })

    def _calculate_residue_direction(self, residue):
        """Calculate the direction vector of a residue's backbone."""
        ca_pos = residue['CA'].get_coord()
        if 'C' in residue:
            c_pos = residue['C'].get_coord()
            direction = c_pos - ca_pos
        else:
            n_pos = residue['N'].get_coord()
            direction = ca_pos - n_pos
        return direction / np.linalg.norm(direction)

    def _evaluate_bezier(self, control_points, t):
        """Evaluate a cubic Bezier curve at parameter t."""
        p0, p1, p2, p3 = control_points
        return (
            (1-t)**3 * p0 + 
            3*(1-t)**2 * t * p1 + 
            3*(1-t) * t**2 * p2 + 
            t**3 * p3
        )

    def _calculate_path_length(self, control_points, num_points=100):
        """Calculate the total length of the Bezier curve."""
        t = np.linspace(0, 1, num_points)
        points = np.array([self._evaluate_bezier(control_points, ti) for ti in t])
        
        # Calculate total length as sum of segments
        segments = np.diff(points, axis=0)
        lengths = np.sqrt(np.sum(segments**2, axis=1))
        return np.sum(lengths)

    def _optimize_control_points(self, start_pos, end_pos, start_dir, end_dir, num_residues):
        """Optimize control points to achieve desired path length."""
        target_length = (num_residues - 1) * self.ca_ca_distance
        distance = np.linalg.norm(end_pos - start_pos)
        
        def objective(x):
            # x contains scaling factors for control points
            scale1, scale2 = x
            p0 = start_pos
            p1 = start_pos + start_dir * (distance * scale1)
            p2 = end_pos - end_dir * (distance * scale2)
            p3 = end_pos
            
            path_length = self._calculate_path_length((p0, p1, p2, p3))
            return (path_length - target_length)**2
        
        # Optimize scaling factors
        result = minimize(objective, [0.33, 0.33], bounds=[(0.1, 0.9), (0.1, 0.9)])
        scale1, scale2 = result.x
        scale1, scale2 = (.33,.33)
        # Return optimized control points
        return (
            start_pos,
            start_pos + start_dir * (distance * scale1),
            end_pos - end_dir * (distance * scale2),
            end_pos
        )

    def generate_linker(self, connection ):
        """Generate backbone coordinates for a linker using an optimized Bezier curve."""
        chain1 = self.chain_ends[connection['chain1']]
        chain2 = self.chain_ends[connection['chain2']]
        
        start_res = chain1['C']
        end_res = chain2['N']
        start_direction =self._calculate_residue_direction(start_res)* (.5 + np.random.rand() * 2)
        end_direction = self._calculate_residue_direction(end_res) * (.5 + np.random.rand() * 2)
        
        start_ca = start_res['CA'].get_coord()
        end_ca = end_res['CA'].get_coord()
        
        distance = np.linalg.norm(end_ca - start_ca)
        # Generate a test path for the Bezier curve
        test_control_points = (start_ca, start_ca + start_direction * (distance / 3), end_ca - end_direction * (distance / 3), end_ca)
        test_path_length = self._calculate_path_length(test_control_points)
        num_residues = int(np.ceil(test_path_length / self.ca_ca_distance))
        
         
        
        # Optimize control points for target path length
        control_points = self._optimize_control_points(
            start_ca, end_ca, start_direction, end_direction, num_residues
        )
        
        # Generate points along optimized Bezier curve
        t = np.linspace(0, 1, num_residues)
        path = np.array([self._evaluate_bezier(control_points, ti) for ti in t])
        
        # Generate backbone atoms
        backbone = []
        for i in range(num_residues):
            if i == 0:
                prev_c = start_res['C'].get_coord()
                prev_ca = start_ca
            else:
                prev_c = backbone[i-1]['C']
                prev_ca = backbone[i-1]['CA']
            
            # Calculate tangent for better backbone placement
            if i < len(path) - 1:
                tangent = path[i+1] - path[i]
            else:
                tangent = end_direction
            tangent = tangent / np.linalg.norm(tangent)
            
            res_atoms = self._generate_backbone_atoms_with_direction(
                prev_c, prev_ca, path[i], tangent
            )
            backbone.append(res_atoms)
        
        return backbone

    def _generate_backbone_atoms_with_direction(self, prev_c, prev_ca, current_ca, direction):
        """Generate backbone atoms with proper geometry."""
        peptide_plane_normal = np.cross(prev_c - prev_ca, direction)
        peptide_plane_normal = peptide_plane_normal / np.linalg.norm(peptide_plane_normal)
        
        n_direction = np.cross(direction, peptide_plane_normal)
        n_pos = prev_c + self.c_n_length * n_direction
        
        ca_pos = current_ca
        
        c_direction = np.cross(direction, peptide_plane_normal)
        c_pos = ca_pos + self.ca_c_length * c_direction
        
        return {
            'N': n_pos,
            'CA': ca_pos,
            'C': c_pos
        }

    def connect_chains(self):
        """Connect all identified chain pairs with linkers."""
        connection=  min(self.connections, key=lambda x: x['distance'])
        if connection['distance'] > 25:
            print('too far',connection['distance'])
            return False
        backbone_atoms = self.generate_linker(connection)
        
        # Get the chains to connect
        chain1_id = connection['chain1']
        chain2_id = connection['chain2']
        
        # Add linker residues to chain1
        last_res_id = max(int(res.id[1]) for res in self.model[chain1_id])
        
        # Add linker residues
        for i, atoms in enumerate(backbone_atoms):
            new_res_id = last_res_id + i + 1
            new_res = Residue.Residue((' ', new_res_id, ' '), 'GLY', '')
            
            for atom_name, coord in atoms.items():
                new_atom = Atom.Atom(
                    name=atom_name,
                    coord=coord,
                    bfactor=20.0,
                    occupancy=1.0,
                    altloc=' ',
                    fullname=atom_name,
                    serial_number=None,
                    element=atom_name[0]
                )
                new_res.add(new_atom)
            
            self.model[chain1_id].add(new_res)
        
        # Move residues from chain2 to chain1
        residues_to_move = list(self.model[chain2_id].get_residues())
        chain =self.model[chain1_id]
        new_res_id=max([max(int(res.id[1]) for res in chain), max(int(res.id[1]) for res in residues_to_move)])
        for res in residues_to_move:
            try:
                new_res_id +=1
                res.id = (' ', new_res_id, ' ')
                chain.add(res)
            except Exception as e:
                print(e)
                print("Error adding residue")
        
        # Remove chain2
        self.model.detach_child(chain2_id)
        print(chain1_id,chain2_id)
        return True
         

aligner = HemeHelixAligner("double_heme.pdb")
aligner.align_all()
aligner.save_structure('test.pdb')
print("Structure aligned and saved.")

model =aligner.model
# Initialize the manipulator with your model
manipulator = HelixManipulator(model)

# Align helices with z-axis
manipulator.align_to_z()

# Rotate helices (e.g., 30 degrees)
angle = 10
manipulator.rotate_helices(angle)
manipulator.create_packed_copies(angle)


io = PDBIO()
io.set_structure(model)
io.save("target-complex.pdb")

heme_chains = []
for chain in model:
    for residue in chain:
        if residue.get_resname() == "HEM":
            heme_chains.append(chain.id)
            break
        
heme_chains        

Structure aligned and saved.


['N', 'E', 'H', 'K', 'Z', 'Q']

In [None]:

# Find all chains in model that have a HEM marker
heme_chains = []
for chain in model:
    for residue in chain:
        if residue.get_resname() == "HEM":
            heme_chains.append(chain.id)
            break

# Create linker and connect chains
linker = HelixLinker(model, excluded_chains=heme_chains)
linker.connect_chains( )

In [55]:
parser = PDB.PDBParser(QUIET=True)
structure = parser.get_structure('protein', 'folded.pdb')
model = structure[0]

# Save structure
io = PDBIO()
io.set_structure(model)
io.save("connected_structure.pdb")

 
shortest=0
while shortest <30:
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure('protein', 'connected_structure.pdb')
    model = structure[0]
    # Find all chains in model that have a HEM marker
    heme_chains = []
    for chain in model:
        for residue in chain:
            if residue.get_resname() == "HEM":
                heme_chains.append(chain.id)
                break

    linker = HelixLinker(model, excluded_chains=heme_chains)
    shortest = min(linker.connections, key=lambda x: x['distance'])['distance']
    print('short',shortest)
    minFound = linker.connect_chains( )
    print('minFound',minFound)
    if minFound==False:
        break

    # Save structure
    io = PDBIO()
    io.set_structure(model)
    io.save("connected_structure.pdb")

short 9.8761215
G J
minFound True
short 10.232591
A D
minFound True
short 10.558442
E H
minFound True
short 10.602797
I L
minFound True
short 12.28846
B C
minFound True
short 18.03264
E G
minFound True
short 18.522594
F I
minFound True
short 19.149797
A K
minFound True
short 11.411663
A F
minFound True
short 16.879282
A E
minFound True
short 22.073984
A B
minFound True


ValueError: min() iterable argument is empty

In [None]:


class ModelManipulator:
    def __init__(self, model):
        self.model = model

    
    def calculate_center(self, chain_id):
        """Calculates the geometric center of a chain."""
        chain = self.model[chain_id]
        # Get CA coordinates 
        ca_coords = np.array([atom.get_coord() for atom in chain.get_atoms() 
                            if atom.name == 'CA'])
        center = np.mean(ca_coords, axis=0)
        
        return center

    def rotate_around_axis(self, axis_point, axis_vector, angle):
        """Rotates the model around an arbitrary axis."""
        # Normalize the axis vector
        axis_vector = axis_vector / np.linalg.norm(axis_vector)
        cos_theta = np.cos(angle)
        sin_theta = np.sin(angle)
        u = axis_vector

        # Rotation matrix using Rodrigues' rotation formula
        rotation_matrix = np.array([
            [cos_theta + u[0]**2 * (1 - cos_theta),
             u[0]*u[1]*(1 - cos_theta) - u[2]*sin_theta,
             u[0]*u[2]*(1 - cos_theta) + u[1]*sin_theta],
            [u[1]*u[0]*(1 - cos_theta) + u[2]*sin_theta,
             cos_theta + u[1]**2 * (1 - cos_theta),
             u[1]*u[2]*(1 - cos_theta) - u[0]*sin_theta],
            [u[2]*u[0]*(1 - cos_theta) - u[1]*sin_theta,
             u[2]*u[1]*(1 - cos_theta) + u[0]*sin_theta,
             cos_theta + u[2]**2 * (1 - cos_theta)]
        ])

        for chain in self.model:
            for atom in chain.get_atoms():
                atom.coord = axis_point + np.dot(rotation_matrix, atom.coord - axis_point)

    def align_chains_z(self, iron_coord, center_a, center_b):
        """Aligns the centers of chains A and B to the z-axis above and below the iron."""
        center_vector_a = center_a - iron_coord
        center_vector_b = center_b - iron_coord

        z_vector_a = np.array([0, 0, np.linalg.norm(center_vector_a)])
        z_vector_b = np.array([0, 0, -np.linalg.norm(center_vector_b)])

        self.rotate_chain_to_vector("A", center_vector_a, z_vector_a, iron_coord)
        self.rotate_chain_to_vector("B", center_vector_b, z_vector_b, iron_coord)

    def rotate_chain_to_vector(self, chain_id, from_vector, to_vector, origin):
        """Rotates a chain so that from_vector aligns with to_vector."""
        from_vector = from_vector / np.linalg.norm(from_vector)
        to_vector = to_vector / np.linalg.norm(to_vector)

        # Calculate rotation axis and angle
        rotation_axis = np.cross(from_vector, to_vector)
        if np.linalg.norm(rotation_axis) < 1e-6:
            return  # No rotation needed

        rotation_axis = rotation_axis / np.linalg.norm(rotation_axis)
        angle = np.arccos(np.clip(np.dot(from_vector, to_vector), -1.0, 1.0))
        self.rotate_around_axis(origin, rotation_axis, angle)

    def rotate_chains_ab(self, iron_coord, angle):
        """Rotates chain A clockwise and chain B counterclockwise."""
        axis_vector = np.array([0, 0, 1])
        self.rotate_chain("A", iron_coord, axis_vector, angle)
        self.rotate_chain("B", iron_coord, axis_vector, -angle)

    def rotate_chain(self, chain_id, origin, axis_vector, angle):
        """Rotates a specific chain around a given axis."""
        chain = self.model[chain_id]
        for atom in chain.get_atoms():
            atom.coord = origin + np.dot(self.get_rotation_matrix(axis_vector, angle), atom.coord - origin)

    def duplicate_model(self, repeats, y_distance):
        """Creates copies of chains A, B, and N along the y-axis."""
        new_chains = []
        for i in range(1, repeats + 1):
            offset = np.array([0, i * y_distance, 0])
            for chain_id in ["A", "B", "N"]:
                chain = self.model[chain_id]
                for residue in chain:
                    for atom in residue:
                        atom.coord += offset
                new_chains.append(chain)
        return new_chains

    def save_model(self, output_filename):
        """Saves the modified structure."""
        io = PDBIO()
        io.set_structure(self.model)
        io.save(output_filename)

In [None]:

class HelixManipulator:
    def __init__(self, model, iron_selector="N:FE"):
        """
        Initialize the manipulator with a BioPython model.

        Args:
            model: BioPython Model object containing chains A, B (helices) and N (heme)
            iron_selector: String specifying chain:atom for the iron atom
        """
        self.model = model
        self.chain_centers = {}
        self.iron_pos = None
        self.iron_chain, self.iron_name = iron_selector.split(":")
        self.helix_radius = None
        self.copies = []
        self.barrel_center = None

    def get_chain_centers_and_width(self):
        """Calculate the geometric center and radius of the helix barrel."""
        all_coords = []
        max_distances = []
        
        # First calculate the barrel center using all CA atoms
        for chain_id in ["A", "B"]:
            chain = self.model[chain_id]
            chain_coords = []
            
            for residue in chain:
                for atom in residue:
                    if atom.name == "CA":
                        coord = atom.get_coord()
                        chain_coords.append(coord)
                        all_coords.append(coord)
                        
            self.chain_centers[chain_id] = np.mean(chain_coords, axis=0)
            
        # Calculate barrel center as midpoint between helix centers
        self.barrel_center = np.mean([self.chain_centers["A"], self.chain_centers["B"]], axis=0)
        
        # Now calculate maximum radius from barrel center to any atom
        for chain_id in ["A", "B"]:
            chain = self.model[chain_id]
            for residue in chain:
                for atom in residue:
                    dist = np.linalg.norm(atom.get_coord() - self.barrel_center)
                    max_distances.append(dist)
        
        self.helix_radius = np.max(max_distances)

    def create_packed_copies(self, angle_degrees):
        """
        Create copies of the helices arranged in a packed pattern.
        
        Args:
            angle_degrees: Rotation angle in degrees for helices
        """
        if self.helix_radius is None:
            self.get_chain_centers_and_width()
            
        # Calculate spacing based on helix diameter and add 10% buffer
        y_spacing = self.helix_radius * 4.2  # 2 * radius * 1.1
        
        # Create 6 copies
        for i in range(6):
            # Create new chains using Chain.copy()
            new_chains = {
                chain_id: self.model[chain_id].copy() 
                for chain_id in ["A", "B", self.iron_chain]
            }
            
            # Calculate y-offset
            y_offset = (i - 2.5) * y_spacing  # Center the pattern around y=0
            
            # Translate the chains
            for chain in new_chains.values():
                self._translate_chain(chain, np.array([0, y_offset, 0]))
            
            # Rotate the helices
            self._rotate_helices_in_chains(new_chains, angle_degrees)
            
            self.copies.append(new_chains)

    def check_clashes(self, distance_threshold=2.0):
        """
        Check for clashes between all pairs of residues in the packed arrangement.
        
        Args:
            distance_threshold: Minimum distance (Å) between residues to be considered a clash
            
        Returns:
            int: Number of clashes found
        """
        clash_count = 0
        
        # Check clashes between all pairs of models
        for i, chains1 in enumerate(self.copies):
            for chains2 in self.copies[i+1:]:
                clash_count += self._count_clashes_between_chains(
                    chains1, chains2, distance_threshold)
                
        return clash_count

    def _count_clashes_between_chains(self, chains1, chains2, threshold):
        """Helper method to count clashes between two sets of chains."""
        clashes = 0
        
        # Compare residues between helix chains (A and B only)
        for chain_id1 in ["A", "B"]:
            for chain_id2 in ["A", "B"]:
                chain1 = chains1[chain_id1]
                chain2 = chains2[chain_id2]
                
                # First check residue centers (CA positions)
                for res1 in chain1:
                    ca1 = res1["CA"].get_coord() if "CA" in res1 else None
                    if ca1 is None:
                        continue
                        
                    for res2 in chain2:
                        ca2 = res2["CA"].get_coord() if "CA" in res2 else None
                        if ca2 is None:
                            continue
                            
                        # If residues are close, check all atom pairs
                        if np.linalg.norm(ca1 - ca2) < threshold * 2:
                            # Detailed atom check for close residues
                            for atom1 in res1:
                                coord1 = atom1.get_coord()
                                for atom2 in res2:
                                    coord2 = atom2.get_coord()
                                    if np.linalg.norm(coord1 - coord2) < threshold:
                                        clashes += 1
                                        # Early exit for this residue pair
                                        break
                                if clashes > 0:
                                    break
                    
        return clashes

    def _translate_chain(self, chain, translation_vector):
        """Helper method to translate all atoms in a chain."""
        for residue in chain:
            for atom in residue:
                coord = atom.get_coord()
                new_coord = coord + translation_vector
                atom.set_coord(new_coord)

    def _rotate_helices_in_chains(self, chains, angle_degrees):
        """Helper method to rotate helices in copied chains."""
        angle_rad = np.radians(angle_degrees)
        
        # Create rotation matrices
        rot_matrix_A = rotaxis2m(angle_rad, Vector([0, 0, 1]))
        rot_matrix_B = rotaxis2m(-angle_rad, Vector([0, 0, 1]))
        
        # Get iron position
        iron_pos = None
        for atom in chains[self.iron_chain].get_atoms():
            if atom.name == self.iron_name:
                iron_pos = atom.get_coord()
                break
                
        if iron_pos is None:
            raise ValueError("Iron atom not found in copied chains")
            
        # Rotate individual chains
        for atom in chains["A"].get_atoms():
            coord = atom.get_coord()
            new_coord = np.dot(rot_matrix_A, coord - iron_pos) + iron_pos
            atom.set_coord(new_coord)
            
        for atom in chains["B"].get_atoms():
            coord = atom.get_coord()
            new_coord = np.dot(rot_matrix_B, coord - iron_pos) + iron_pos
            atom.set_coord(new_coord)

    

In [None]:
from Bio import PDB
import numpy as np
from scipy.spatial.transform import Rotation
import random

class HistidineRotamerBuilder:
    def __init__(self, pdb_file):
        self.parser = PDB.PDBParser(QUIET=True)
        self.structure = self.parser.get_structure('protein', pdb_file)
        self.model = self.structure[0]
        
    def get_histidine_chi_angles(self):
        """Common rotamer chi angles for histidine (in degrees)"""
        # Based on common histidine rotamer library
        # Chi1 (N-CA-CB-CG), Chi2 (CA-CB-CG-ND1)
        return [
            (-180, -180), (-180, -90), (-180, 0), (-180, 90), 
            (-90, -180), (-90, -90), (-90, 0), (-90, 90),
            (0, -180), (0, -90), (0, 0), (0, 90),
            (90, -180), (90, -90), (90, 0), (90, 90)
        ]
    
    def get_histidine_atoms(self, residue):
        """Get ordered list of atoms for chi angle calculation"""
        chi1_atoms = ['N', 'CA', 'CB', 'CG']
        chi2_atoms = ['CA', 'CB', 'CG', 'ND1']
        
        atoms_chi1 = [residue[atom] for atom in chi1_atoms]
        atoms_chi2 = [residue[atom] for atom in chi2_atoms]
        
        return atoms_chi1, atoms_chi2
    
    def calculate_dihedral(self, p1, p2, p3, p4):
        """Calculate dihedral angle between 4 points"""
        v1 = p2.get_coord() - p1.get_coord()
        v2 = p3.get_coord() - p2.get_coord()
        v3 = p4.get_coord() - p3.get_coord()
        
        n1 = np.cross(v1, v2)
        n2 = np.cross(v2, v3)
        
        angle = np.arctan2(np.dot(np.cross(n1, n2), v2/np.linalg.norm(v2)), 
                          np.dot(n1, n2))
        return np.degrees(angle)
    
    def rotate_about_bond(self, movable_atoms, axis_start, axis_end, angle):
        """Rotate atoms around a bond axis by given angle"""
        axis = axis_end.get_coord() - axis_start.get_coord()
        axis = axis / np.linalg.norm(axis)
        
        rotation = Rotation.from_rotvec(axis * np.radians(angle))
        center = axis_start.get_coord()
        
        for atom in movable_atoms:
            coord = atom.get_coord()
            centered = coord - center
            rotated = rotation.apply(centered)
            atom.set_coord(rotated + center)
    
    def apply_histidine_rotamer(self, residue, chi1, chi2):
        """Apply specific chi angles to histidine"""
        atoms_chi1, atoms_chi2 = self.get_histidine_atoms(residue)
        
        # Get current chi angles
        current_chi1 = self.calculate_dihedral(*atoms_chi1)
        current_chi2 = self.calculate_dihedral(*atoms_chi2)
        
        # Calculate needed rotations
        delta_chi1 = chi1 - current_chi1
        delta_chi2 = chi2 - current_chi2
        
        # Rotate around chi1
        movable_atoms = [atom for atom in residue.get_atoms() 
                        if atom.name not in ['N', 'CA', 'HA', 'H']]
        self.rotate_about_bond(movable_atoms, residue['CA'], residue['CB'], delta_chi1)
        
        # Rotate around chi2
        movable_atoms = [atom for atom in residue.get_atoms() 
                        if atom.name not in ['N', 'CA', 'CB', 'HB2', 'HB3', 'HA', 'H']]
        self.rotate_about_bond(movable_atoms, residue['CB'], residue['CG'], delta_chi2)
    
    
    
    def get_histidine_axis(self, residue):
        """Get vector from CB to NE2 of histidine"""
        return residue['NE2'].get_coord() - residue['CB'].get_coord()
    
    
    
    def align_histidine_to_iron(self, chain, above=True):
        """Align histidine to point towards iron"""
        his = next(res for res in chain.get_residues() if res.get_resname() == "HIS")
        his_axis = self.get_histidine_axis(his)
        
        # Target vector (up or down)
        target = np.array([0, 0, 1]) if above else np.array([0, 0, -1])
        
        # Project histidine axis onto yz plane
        his_yz = his_axis.copy()
        his_yz[0] = 0
        his_yz = his_yz / np.linalg.norm(his_yz)
        
        rotation = self.get_alignment_rotation(his_yz, target)
        for atom in chain.get_atoms():
            coord = atom.get_coord()
            atom.set_coord(np.dot(rotation, coord))
    
    def get_alignment_rotation(self, v1, v2):
        """Get rotation matrix to align v1 with v2"""
        v1 = v1 / np.linalg.norm(v1)
        v2 = v2 / np.linalg.norm(v2)
        
        cross = np.cross(v1, v2)
        if np.allclose(cross, 0):
            return np.eye(3)
            
        dot = np.dot(v1, v2)
        angle = np.arccos(np.clip(dot, -1.0, 1.0))
        
        rotation = Rotation.from_rotvec(cross / np.linalg.norm(cross) * angle)
        return rotation.as_matrix()
    
    def generate_conformer(self):
        """Generate a single conformer with random histidine rotamers"""
        new_struct = self.structure.copy()
        rotamers = self.get_histidine_chi_angles()
        
        # Process chain A
        chain_a = new_struct[0]['A']
        his_a = next(res for res in chain_a.get_residues() if res.get_resname() == "HIS")
        chi1, chi2 = random.choice(rotamers)
        self.apply_histidine_rotamer(his_a, chi1, chi2)
        self.align_chain_to_x(chain_a)
        self.align_histidine_to_iron(chain_a, above=True)
        
        # Process chain B
        chain_b = new_struct[0]['B']
        his_b = next(res for res in chain_b.get_residues() if res.get_resname() == "HIS")
        chi1, chi2 = random.choice(rotamers)
        self.apply_histidine_rotamer(his_b, chi1, chi2)
        self.align_chain_to_x(chain_b)
        self.align_histidine_to_iron(chain_b, above=False)
        
        return new_struct

def generate_rotamer_conformers(input_pdb, output_prefix, num_conformers=10):
    """Generate multiple conformers with different histidine rotamers"""
    builder = HistidineRotamerBuilder(input_pdb)
    io = PDB.PDBIO()
    
    for i in range(num_conformers):
        structure = builder.generate_conformer()
        io.set_structure(structure)
        io.save(f"{output_prefix}_{i+1}.pdb")
        print(f"Generated conformer {i+1}")
            
generate_rotamer_conformers("test.pdb", "conformer")            

Generated conformer 1
Generated conformer 2
Generated conformer 3
Generated conformer 4
Generated conformer 5
Generated conformer 6
Generated conformer 7
Generated conformer 8
Generated conformer 9
Generated conformer 10


In [2]:
from PIL import Image
import os

picfolder = r'C:\Users\bashc\Desktop\art'
# Get all PNG files in the folder
png_files = [f for f in os.listdir(picfolder) if f.endswith('.png')]

# Resize images to 128x128
resized_images = []
for file in png_files:
    img = Image.open(os.path.join(picfolder, file))
    img = img.resize((128, 128))
    resized_images.append(img)

# Create a blank sprite sheet
sprite_sheet = Image.new('RGBA', (2048, 2048))

# Paste images into the sprite sheet
x_offset = 0
y_offset = 0
for img in resized_images:
    sprite_sheet.paste(img, (x_offset, y_offset))
    x_offset += 128
    if x_offset >= 2048:
        x_offset = 0
        y_offset += 128

# Save the sprite sheet
sprite_sheet.save(os.path.join(picfolder, 'sprite_sheet.png'))

In [10]:
import numpy as np
from Bio import PDB
from Bio.PDB.PDBIO import PDBIO
from Bio.PDB import Atom, Residue, Chain, Model, Structure
from scipy.special import comb
#
# PEPTIDE_BOND_THRESHOLD: Distance (Angstroms) between C of res(i) and N of res(i+1)
# to be considered a broken peptide bond (i.e., a new segment starts).
PEPTIDE_BOND_THRESHOLD = 2.0  # Angstroms

# CONNECTION_THRESHOLD: Maximum distance (Angstroms) between terminal atoms
# of segments to consider them for connection.
CONNECTION_THRESHOLD = 25.0 # Angstroms

# PEPTIDE_BOND_LENGTH: Ideal peptide bond length (C-N) for connecting residues
PEPTIDE_BOND_LENGTH = 1.33  # Angstroms

# Average distance between consecutive alpha carbons in a protein (~3.8Å)
CA_CA_DISTANCE = 3.8  # Angstroms

def bernstein_polynomial(t, n, i):
    """Bernstein polynomial for Bezier curve calculation."""
    return comb(n, i) * (t ** i) * ((1 - t) ** (n - i))

def bezier_curve(control_points, num_points=100):
    """
    Generate points along a Bezier curve defined by control points.
    
    Args:
        control_points: List of 3D coordinates (numpy arrays)
        num_points: Number of points to generate along the curve
    
    Returns:
        List of coordinates along the curve
    """
    n = len(control_points) - 1
    curve_points = []
    
    for t in np.linspace(0, 1, num_points):
        point = np.zeros(3)
        for i, cp in enumerate(control_points):
            point += bernstein_polynomial(t, n, i) * cp
        curve_points.append(point)
    
    return curve_points

def create_glycine_residue(res_id, position, prev_c_pos=None, next_n_pos=None):
    """
    Create a glycine residue with coordinates based on the given position.
    Adjusts N and C atoms to form peptide bonds of PEPTIDE_BOND_LENGTH.
    Places O atom using standard geometry relative to N, CA, C of the glycine.

    Args:
        res_id: Residue ID tuple (hetflag, resnum, icode)
        position: 3D coordinates for CA atom (numpy array)
        prev_c_pos: Position of C atom from previous residue for N-terminal connection
        next_n_pos: Position of N atom from next residue for C-terminal connection

    Returns:
        Bio.PDB.Residue with proper glycine geometry
    """
    res = Residue.Residue(res_id, "GLY", " ")
    ca_coord = np.array(position)

    # Ideal Glycine coordinates relative to CA at origin (N-CA-C plane in XY)
    n_ideal_offset_from_ca = np.array([-0.589, -1.354, 0.0]) 
    c_ideal_offset_from_ca = np.array([1.523, 0.0, 0.0])

    n_coord = ca_coord + n_ideal_offset_from_ca
    c_coord = ca_coord + c_ideal_offset_from_ca

    if prev_c_pos is not None:
        prev_c_coord = np.array(prev_c_pos)
        vec_prevC_to_glyCA = ca_coord - prev_c_coord
        dist_prevC_to_glyCA = np.linalg.norm(vec_prevC_to_glyCA)
        if dist_prevC_to_glyCA > 1e-3: 
            dir_prevC_to_glyCA = vec_prevC_to_glyCA / dist_prevC_to_glyCA
            n_coord = prev_c_coord + dir_prevC_to_glyCA * PEPTIDE_BOND_LENGTH

    if next_n_pos is not None:
        next_n_coord = np.array(next_n_pos)
        vec_glyCA_to_nextN = next_n_coord - ca_coord
        dist_glyCA_to_nextN = np.linalg.norm(vec_glyCA_to_nextN)
        if dist_glyCA_to_nextN > 1e-3: 
            dir_glyCA_to_nextN = vec_glyCA_to_nextN / dist_glyCA_to_nextN
            c_coord = next_n_coord - dir_glyCA_to_nextN * PEPTIDE_BOND_LENGTH

    vec_ca_c = c_coord - ca_coord
    vec_ca_n = n_coord - ca_coord 

    norm_vec_ca_c = np.linalg.norm(vec_ca_c)
    norm_vec_ca_n = np.linalg.norm(vec_ca_n)

    if norm_vec_ca_c < 1e-3 or norm_vec_ca_n < 1e-3:
        o_default_offset_from_c = np.array([0.624, 1.078, 0.0]) 
        o_coord = c_coord + o_default_offset_from_c 
    else:
        u_ca_c = vec_ca_c / norm_vec_ca_c 
        u_ca_n = vec_ca_n / norm_vec_ca_n 

        plane_normal = np.cross(u_ca_n, u_ca_c)
        norm_plane_normal = np.linalg.norm(plane_normal)

        if norm_plane_normal < 1e-3: 
            if abs(u_ca_c[0]) > 0.9: arbitrary_perp_ref = np.array([0.0, 1.0, 0.0])
            else: arbitrary_perp_ref = np.array([1.0, 0.0, 0.0])
            plane_normal = np.cross(u_ca_c, arbitrary_perp_ref)
            if np.linalg.norm(plane_normal) < 1e-3: 
                 plane_normal = np.cross(u_ca_c, np.array([0.0,0.0,1.0]))
            plane_normal /= np.linalg.norm(plane_normal)

        m_vec = np.cross(plane_normal, u_ca_c)
        m_vec = m_vec / np.linalg.norm(m_vec)
        
        if np.dot(m_vec, u_ca_n - u_ca_c * np.dot(u_ca_n, u_ca_c)) < 0:
            m_vec = -m_vec
            
        C_O_BOND_LENGTH = 1.23 
        ANGLE_CA_C_O = np.radians(120.8) 

        o_coord_comp_along_cca = C_O_BOND_LENGTH * np.cos(ANGLE_CA_C_O)
        o_coord_comp_along_mvec = C_O_BOND_LENGTH * np.sin(ANGLE_CA_C_O)
        
        o_coord = c_coord + o_coord_comp_along_cca * (-u_ca_c) + o_coord_comp_along_mvec * m_vec

    n_atom = Atom.Atom('N', n_coord, 0.0, 1.0, ' ', ' N  ', 0, 'N')
    ca_atom = Atom.Atom('CA', ca_coord, 0.0, 1.0, ' ', ' CA ', 0, 'C')
    c_atom = Atom.Atom('C', c_coord, 0.0, 1.0, ' ', ' C  ', 0, 'C')
    o_atom = Atom.Atom('O', o_coord, 0.0, 1.0, ' ', ' O  ', 0, 'O')
    
    res.add(n_atom)
    res.add(ca_atom)
    res.add(c_atom)
    res.add(o_atom)
    
    return res

def get_terminal_atom(residue, atom_name):
    """
    Retrieves a specific atom from a residue.

    Args:
        residue (Bio.PDB.Residue.Residue): The residue to search within.
        atom_name (str): The name of the atom to find (e.g., 'N', 'CA', 'C', 'O').

    Returns:
        Bio.PDB.Atom.Atom or None: The Atom object if found, otherwise None.
    """
    if atom_name in residue:
        return residue[atom_name]
    return None

class Segment:
    """Represents a contiguous segment of residues."""
    def __init__(self, residues, original_chain_id, segment_idx_in_chain):
        if not residues:
            raise ValueError("Segment cannot be empty")
        self.residues = list(residues) 
        self.original_chain_id = original_chain_id
        
        first_res_id_str = f"{residues[0].id[0].strip()}{residues[0].id[1]}{residues[0].id[2].strip()}"
        last_res_id_str = f"{residues[-1].id[0].strip()}{residues[-1].id[1]}{residues[-1].id[2].strip()}"
        self.id = f"Seg-{original_chain_id}-{segment_idx_in_chain}({first_res_id_str}_to_{last_res_id_str})"

        self.n_res = self.residues[0]
        self.c_res = self.residues[-1]

        self.n_atom = get_terminal_atom(self.n_res, 'N')
        self.c_atom = get_terminal_atom(self.c_res, 'C')

        if self.n_atom is None:
            raise ValueError(f"Segment {self.id} missing N-terminal atom for residue {self.n_res.id}.")
        if self.c_atom is None:
            raise ValueError(f"Segment {self.id} missing C-terminal atom for residue {self.c_res.id}.")

def connect_broken_chains(input_pdb_path, output_pdb_path):
    """
    Loads a PDB file, identifies all segments, determines the optimal way to connect
    them into a primary long chain, and connects remaining segments into secondary chains.
    Fills gaps with glycine residues. Outputs all resulting chains.
    """
    parser = PDB.PDBParser(QUIET=True)
    try:
        structure = parser.get_structure('protein', input_pdb_path)
    except FileNotFoundError:
        print(f"Error: Input PDB file not found at {input_pdb_path}")
        return
    except Exception as e:
        print(f"Error parsing PDB file {input_pdb_path}: {e}")
        return

    model = structure[0]
    all_segments = []
    
    for chain in model:
        residues_in_chain = [res for res in chain.get_residues() if PDB.is_aa(res, standard=True) or res.id[0] != ' ']
        if not residues_in_chain:
            continue

        current_segment_residues = []
        segment_counter_for_chain = 0
        for i, res in enumerate(residues_in_chain):
            if not current_segment_residues:
                current_segment_residues.append(res)
            else:
                prev_res = current_segment_residues[-1]
                if 'C' in prev_res and 'N' in res:
                    dist = np.linalg.norm(prev_res['C'].get_coord() - res['N'].get_coord())
                    if dist > PEPTIDE_BOND_THRESHOLD:
                        all_segments.append(Segment(current_segment_residues, chain.id, segment_counter_for_chain))
                        segment_counter_for_chain += 1
                        current_segment_residues = [res]
                    else:
                        current_segment_residues.append(res)
                else: 
                    all_segments.append(Segment(current_segment_residues, chain.id, segment_counter_for_chain))
                    segment_counter_for_chain += 1
                    current_segment_residues = [res]
        
        if current_segment_residues:
            all_segments.append(Segment(current_segment_residues, chain.id, segment_counter_for_chain))

    if not all_segments:
        print("No valid segments found in the PDB file.")
        return
    
    print(f"Identified {len(all_segments)} segments in total.")

    # --- 3. Find the best single chain by trying all possible start segments ---
    best_connection_run = {
        "initial_segment_index": -1,
        "connected_segments_ordered": [], 
        "num_original_segments": 0
    }

    for i, initial_seg in enumerate(all_segments):
        current_merged_chain_segments = [initial_seg]
        remaining_for_this_run = [s for idx, s in enumerate(all_segments) if idx != i]
        
        chain_n_end_atom = initial_seg.n_atom 
        chain_c_end_atom = initial_seg.c_atom

        while remaining_for_this_run:
            best_append_candidate = None 
            min_dist_append = float('inf')
            best_prepend_candidate = None
            min_dist_prepend = float('inf')

            chain_c_coord = chain_c_end_atom.get_coord()
            chain_n_coord = chain_n_end_atom.get_coord()

            for j, seg_candidate in enumerate(remaining_for_this_run):
                dist_c_to_n = np.linalg.norm(chain_c_coord - seg_candidate.n_atom.get_coord())
                if dist_c_to_n < min_dist_append:
                    min_dist_append = dist_c_to_n
                    best_append_candidate = (j, seg_candidate, dist_c_to_n)

                dist_n_to_c = np.linalg.norm(chain_n_coord - seg_candidate.c_atom.get_coord())
                if dist_n_to_c < min_dist_prepend:
                    min_dist_prepend = dist_n_to_c
                    best_prepend_candidate = (j, seg_candidate, dist_n_to_c)
            
            added_in_iteration = False
            choice_made = None
            
            can_append = best_append_candidate and best_append_candidate[2] < CONNECTION_THRESHOLD
            can_prepend = best_prepend_candidate and best_prepend_candidate[2] < CONNECTION_THRESHOLD

            if can_append and can_prepend:
                choice_made = "append" if best_append_candidate[2] <= best_prepend_candidate[2] else "prepend"
            elif can_append:
                choice_made = "append"
            elif can_prepend:
                choice_made = "prepend"

            if choice_made == "append":
                idx_rem, seg_to_add, _ = best_append_candidate
                current_merged_chain_segments.append(seg_to_add)
                chain_c_end_atom = seg_to_add.c_atom 
                remaining_for_this_run.pop(idx_rem)
                added_in_iteration = True
            elif choice_made == "prepend":
                idx_rem, seg_to_add, _ = best_prepend_candidate
                current_merged_chain_segments.insert(0, seg_to_add)
                chain_n_end_atom = seg_to_add.n_atom 
                remaining_for_this_run.pop(idx_rem)
                added_in_iteration = True
            
            if not added_in_iteration:
                break 

        if len(current_merged_chain_segments) > best_connection_run["num_original_segments"]:
            best_connection_run["initial_segment_index"] = i
            best_connection_run["connected_segments_ordered"] = list(current_merged_chain_segments)
            best_connection_run["num_original_segments"] = len(current_merged_chain_segments)

    # --- 4. Build the primary chain ---
    final_chains_residues = [] 
    used_segments_for_main_chain = set()
    
    max_res_num_in_pdb = 0
    for seg_glob in all_segments:
        for res_glob in seg_glob.residues:
            if res_glob.id[1] > max_res_num_in_pdb:
                max_res_num_in_pdb = res_glob.id[1]
    temp_glycine_res_num_counter = max_res_num_in_pdb

    if not best_connection_run["connected_segments_ordered"]:
        print("Warning: Could not form any primary connected chain.")
    else:
        print(f"Best primary chain will connect {best_connection_run['num_original_segments']} segments.")
        primary_chain_merged_residues = []
        segments_for_main_chain = best_connection_run["connected_segments_ordered"]
        
        primary_chain_merged_residues.extend(segments_for_main_chain[0].residues)
        # The C-atom of the *last added residue* to primary_chain_merged_residues
        current_c_atom_obj_main = get_terminal_atom(primary_chain_merged_residues[-1], 'C')
        if not current_c_atom_obj_main: # Should not happen if segment is valid
             raise ValueError(f"Segment {segments_for_main_chain[0].id} has no C-terminal C atom in its last residue.")


        for i in range(1, len(segments_for_main_chain)):
            prev_segment_in_main = segments_for_main_chain[i-1] # For context/CA
            current_segment_to_add_main = segments_for_main_chain[i]

            prev_c_for_glycine = current_c_atom_obj_main.get_coord()
            next_n_for_glycine = current_segment_to_add_main.n_atom.get_coord()
            dist_c_n = np.linalg.norm(prev_c_for_glycine - next_n_for_glycine)

            added_glycines_main = []
            if dist_c_n > PEPTIDE_BOND_THRESHOLD:
                p0 = prev_c_for_glycine
                p3 = next_n_for_glycine
                ca_prev_res = primary_chain_merged_residues[-1] # Last residue of growing chain
                ca_next_res = current_segment_to_add_main.residues[0] # First residue of next segment

                if 'CA' in ca_prev_res and 'CA' in ca_next_res:
                    ca_prev = ca_prev_res['CA'].get_coord()
                    ca_next = ca_next_res['CA'].get_coord()
                    v0 = p0 - ca_prev 
                    v3 = p3 - ca_next 
                    p1 = p0 + v0 * 0.5 
                    p2 = p3 + v3 * 0.5 
                else: 
                    mid_point = (p0 + p3) / 2.0
                    p1 = p0 + (mid_point - p0) * 0.5
                    p2 = p3 + (mid_point - p3) * 0.5
                control_points = [p0, p1, p2, p3]

                curve_length_approx = np.linalg.norm(p3 - p0) 
                num_glycines = max(0, int(round(curve_length_approx / CA_CA_DISTANCE)) -1)

                if num_glycines > 0:
                    print(f"  Main chain: adding {num_glycines} glycines between ...{prev_segment_in_main.id} and {current_segment_to_add_main.id} (C-N dist: {dist_c_n:.2f} A)")
                    ca_p0 = ca_prev_res['CA'].get_coord() if 'CA' in ca_prev_res else p0
                    ca_p3 = ca_next_res['CA'].get_coord() if 'CA' in ca_next_res else p3
                    
                    ca_v0 = ca_p0 - p0 
                    ca_v3 = ca_p3 - p3

                    ca_control_points = [ca_p0, ca_p0 + ca_v0 * 0.5, ca_p3 + ca_v3 * 0.5, ca_p3]
                    if num_glycines == 1 and np.linalg.norm(ca_p0 - ca_p3) < CA_CA_DISTANCE * 1.5 :
                         gly_ca_positions = [(ca_p0 + ca_p3)/2.0]
                    else:
                        gly_ca_curve = bezier_curve(ca_control_points, num_points=num_glycines + 2)
                        gly_ca_positions = gly_ca_curve[1:-1]
                    
                    current_connection_c_coord_main = prev_c_for_glycine
                    for k_gly, gly_ca_pos in enumerate(gly_ca_positions):
                        temp_glycine_res_num_counter += 1
                        prev_c_for_this_gly = current_connection_c_coord_main
                        next_n_for_this_gly = next_n_for_glycine if k_gly == num_glycines - 1 else None
                        gly_res = create_glycine_residue(
                            (' ', temp_glycine_res_num_counter, ' '), gly_ca_pos,
                            prev_c_pos=prev_c_for_this_gly, next_n_pos=next_n_for_this_gly
                        )
                        added_glycines_main.append(gly_res)
                        current_connection_c_coord_main = gly_res['C'].get_coord()
            
            primary_chain_merged_residues.extend(added_glycines_main)
            primary_chain_merged_residues.extend(current_segment_to_add_main.residues)
            current_c_atom_obj_main = get_terminal_atom(primary_chain_merged_residues[-1], 'C')
            if not current_c_atom_obj_main:
                 raise ValueError(f"Segment {current_segment_to_add_main.id} resulted in no C-terminal C atom after adding to main chain.")

        final_chains_residues.append(primary_chain_merged_residues)
        used_segments_for_main_chain = set(segments_for_main_chain)

    # --- 5. Process unused segments ---
    all_segments_set = set(all_segments)
    unused_segments_list = list(all_segments_set - used_segments_for_main_chain)
    
    unused_by_original_chain = {}
    for seg in unused_segments_list:
        unused_by_original_chain.setdefault(seg.original_chain_id, []).append(seg)

    for chain_id_orig, seg_list_orig in unused_by_original_chain.items():
        seg_list_orig.sort(key=lambda s: (s.residues[0].id[1], s.residues[0].id[2].strip())) # Sort by resnum then icode

    processed_unused_segments = set()

    for original_chain_id_of_unused, segments_from_one_orig_chain_initial in unused_by_original_chain.items():
        
        segments_to_process_in_group = [s for s in segments_from_one_orig_chain_initial if s not in processed_unused_segments]
        
        while segments_to_process_in_group:
            if not segments_to_process_in_group: break

            if len(segments_to_process_in_group) == 1:
                seg_single = segments_to_process_in_group[0]
                final_chains_residues.append(list(seg_single.residues))
                processed_unused_segments.add(seg_single)
                print(f"  Unused segment {seg_single.id} from original chain {original_chain_id_of_unused} forms a separate chain.")
                segments_to_process_in_group.pop(0)
                continue

            best_internal_connection_run = { "connected_segments_ordered": [], "num_original_segments": 0 }

            for i_start_internal, initial_seg_internal in enumerate(segments_to_process_in_group):
                current_merged_internal = [initial_seg_internal]
                remaining_internal = [s for idx, s in enumerate(segments_to_process_in_group) if idx != i_start_internal]
                
                chain_n_end_atom_internal = initial_seg_internal.n_atom
                chain_c_end_atom_internal = initial_seg_internal.c_atom

                while remaining_internal:
                    best_append_internal = None; min_dist_app_int = float('inf')
                    best_prepend_internal = None; min_dist_pre_int = float('inf')
                    chain_c_coord_int = chain_c_end_atom_internal.get_coord()
                    chain_n_coord_int = chain_n_end_atom_internal.get_coord()

                    for j_int, seg_cand_int in enumerate(remaining_internal):
                        dist_c_n_int = np.linalg.norm(chain_c_coord_int - seg_cand_int.n_atom.get_coord())
                        if dist_c_n_int < min_dist_app_int:
                            min_dist_app_int = dist_c_n_int
                            best_append_internal = (j_int, seg_cand_int, dist_c_n_int)
                        dist_n_c_int = np.linalg.norm(chain_n_coord_int - seg_cand_int.c_atom.get_coord())
                        if dist_n_c_int < min_dist_pre_int:
                            min_dist_pre_int = dist_n_c_int
                            best_prepend_internal = (j_int, seg_cand_int, dist_n_c_int)
                    
                    added_in_iter_internal = False; choice_made_internal = None
                    can_app_int = best_append_internal and best_append_internal[2] < CONNECTION_THRESHOLD
                    can_pre_int = best_prepend_internal and best_prepend_internal[2] < CONNECTION_THRESHOLD

                    if can_app_int and can_pre_int: choice_made_internal = "append" if best_append_internal[2] <= best_prepend_internal[2] else "prepend"
                    elif can_app_int: choice_made_internal = "append"
                    elif can_pre_int: choice_made_internal = "prepend"

                    if choice_made_internal == "append":
                        idx_rem_int, seg_to_add_int, _ = best_append_internal
                        current_merged_internal.append(seg_to_add_int)
                        chain_c_end_atom_internal = seg_to_add_int.c_atom
                        remaining_internal.pop(idx_rem_int); added_in_iter_internal = True
                    elif choice_made_internal == "prepend":
                        idx_rem_int, seg_to_add_int, _ = best_prepend_internal
                        current_merged_internal.insert(0, seg_to_add_int)
                        chain_n_end_atom_internal = seg_to_add_int.n_atom
                        remaining_internal.pop(idx_rem_int); added_in_iter_internal = True
                    if not added_in_iter_internal: break
                
                if len(current_merged_internal) > best_internal_connection_run["num_original_segments"]:
                    best_internal_connection_run["connected_segments_ordered"] = list(current_merged_internal)
                    best_internal_connection_run["num_original_segments"] = len(current_merged_internal)

            if best_internal_connection_run["connected_segments_ordered"]:
                print(f"  Connecting {best_internal_connection_run['num_original_segments']} unused segments from original chain {original_chain_id_of_unused} into a new chain.")
                secondary_chain_residues = []
                segments_for_secondary = best_internal_connection_run["connected_segments_ordered"]
                
                secondary_chain_residues.extend(segments_for_secondary[0].residues)
                current_c_atom_obj_secondary = get_terminal_atom(secondary_chain_residues[-1], 'C')
                if not current_c_atom_obj_secondary:
                    raise ValueError(f"Segment {segments_for_secondary[0].id} has no C-terminal C atom for secondary chain.")

                processed_segments_in_this_group_run = set(segments_for_secondary)

                for i_sec in range(1, len(segments_for_secondary)):
                    prev_seg_sec = segments_for_secondary[i_sec-1]
                    curr_seg_sec = segments_for_secondary[i_sec]
                    prev_c_sec = current_c_atom_obj_secondary.get_coord()
                    next_n_sec = curr_seg_sec.n_atom.get_coord()
                    dist_c_n_sec = np.linalg.norm(prev_c_sec - next_n_sec)
                    added_glycines_sec = []
                    if dist_c_n_sec > PEPTIDE_BOND_THRESHOLD:
                        p0_sec, p3_sec = prev_c_sec, next_n_sec
                        ca_prev_res_sec, ca_next_res_sec = secondary_chain_residues[-1], curr_seg_sec.residues[0]
                        if 'CA' in ca_prev_res_sec and 'CA' in ca_next_res_sec:
                            ca_prev_sec, ca_next_sec = ca_prev_res_sec['CA'].get_coord(), ca_next_res_sec['CA'].get_coord()
                            v0_sec,v3_sec = p0_sec - ca_prev_sec, p3_sec - ca_next_sec
                            p1_sec,p2_sec = p0_sec + v0_sec*0.5, p3_sec + v3_sec*0.5
                        else: 
                            mid_point_sec = (p0_sec+p3_sec)/2.0
                            p1_sec,p2_sec = p0_sec+(mid_point_sec-p0_sec)*0.5, p3_sec+(mid_point_sec-p3_sec)*0.5
                        ctrl_pts_sec = [p0_sec,p1_sec,p2_sec,p3_sec]
                        curve_len_approx_sec = np.linalg.norm(p3_sec-p0_sec)
                        num_gly_sec = max(0,int(round(curve_len_approx_sec/CA_CA_DISTANCE))-1)
                        if num_gly_sec > 0:
                            print(f"    Secondary chain: adding {num_gly_sec} glycines between ...{prev_seg_sec.id} and {curr_seg_sec.id} (C-N dist: {dist_c_n_sec:.2f} A)")
                            ca_p0_sec = ca_prev_res_sec['CA'].get_coord() if 'CA' in ca_prev_res_sec else p0_sec
                            ca_p3_sec = ca_next_res_sec['CA'].get_coord() if 'CA' in ca_next_res_sec else p3_sec
                            ca_v0_sec,ca_v3_sec = ca_p0_sec-p0_sec, ca_p3_sec-p3_sec
                            ca_ctrl_pts_sec = [ca_p0_sec,ca_p0_sec+ca_v0_sec*0.5,ca_p3_sec+ca_v3_sec*0.5,ca_p3_sec]
                            gly_ca_pos_sec_list = [(ca_p0_sec+ca_p3_sec)/2.0] if num_gly_sec==1 and np.linalg.norm(ca_p0_sec-ca_p3_sec)<CA_CA_DISTANCE*1.5 else bezier_curve(ca_ctrl_pts_sec,num_points=num_gly_sec+2)[1:-1]
                            curr_conn_c_coord_sec = prev_c_sec
                            for k_gly_sec, gly_ca_pos_val_sec in enumerate(gly_ca_pos_sec_list):
                                temp_glycine_res_num_counter+=1
                                prev_c_this_gly_sec,next_n_this_gly_sec = curr_conn_c_coord_sec, next_n_sec if k_gly_sec==num_gly_sec-1 else None
                                gly_res_sec = create_glycine_residue((' ',temp_glycine_res_num_counter,' '),gly_ca_pos_val_sec,prev_c_pos=prev_c_this_gly_sec,next_n_pos=next_n_this_gly_sec)
                                added_glycines_sec.append(gly_res_sec)
                                curr_conn_c_coord_sec = gly_res_sec['C'].get_coord()
                    secondary_chain_residues.extend(added_glycines_sec)
                    secondary_chain_residues.extend(curr_seg_sec.residues)
                    current_c_atom_obj_secondary = get_terminal_atom(secondary_chain_residues[-1], 'C')
                    if not current_c_atom_obj_secondary:
                        raise ValueError(f"Segment {curr_seg_sec.id} resulted in no C-terminal C atom for secondary chain.")
                final_chains_residues.append(secondary_chain_residues)
                processed_unused_segments.update(processed_segments_in_this_group_run)
            elif segments_to_process_in_group : # No internal connections, but current segment exists
                seg_single = segments_to_process_in_group[0] # Process the first one as a single chain
                final_chains_residues.append(list(seg_single.residues))
                processed_unused_segments.add(seg_single)
                print(f"  Unused segment {seg_single.id} from original chain {original_chain_id_of_unused} forms a separate chain (no internal connection found for this pass).")
            
            # Update segments_to_process_in_group for the next iteration of the while loop
            segments_to_process_in_group = [s for s in segments_from_one_orig_chain_initial if s not in processed_unused_segments]


    # --- 6. Create new PDB structure with all chains ---
    new_structure = Structure.Structure('connected_protein_multi')
    new_model = Model.Model(0)
    
    chain_letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
    if len(final_chains_residues) > len(chain_letters):
        print(f"Warning: More than {len(chain_letters)} chains generated ({len(final_chains_residues)}). Some will be skipped in output.")
        final_chains_residues = final_chains_residues[:len(chain_letters)]

    for chain_idx, residue_list_for_chain in enumerate(final_chains_residues):
        if not residue_list_for_chain: continue

        new_chain_id = chain_letters[chain_idx % len(chain_letters)]
        new_chain = Chain.Chain(new_chain_id)
        
        current_residue_seq_num = 1
        for original_res in residue_list_for_chain:
            new_res = Residue.Residue(original_res.id, original_res.resname, original_res.segid)
            for atom in original_res.get_atoms():
                new_atom = Atom.Atom(atom.name, atom.coord, atom.bfactor, atom.occupancy,
                                     atom.altloc, atom.fullname, 0, 
                                     element=atom.element, pqr_charge=atom.pqr_charge, radius=atom.radius)
                new_res.add(new_atom)
            hetfield = original_res.id[0]
            new_id = (hetfield, current_residue_seq_num, ' ')
            new_res.id = new_id
            new_chain.add(new_res)
            current_residue_seq_num += 1
        
        new_model.add(new_chain)
        print(f"Outputting Chain {new_chain_id} with {len(new_chain)} residues.")

    new_structure.add(new_model)
    io = PDBIO()
    io.set_structure(new_structure)
    io.save(output_pdb_path)
    print(f"Successfully processed segments. Output saved to {output_pdb_path}")



connect_broken_chains(r"C:\Users\bashc\Desktop\working\ssBinding\working9.pdb",r"C:\Users\bashc\Desktop\working\ssBinding\connected.pdb"  )

Identified 7 segments in total.
Best primary chain will connect 6 segments.
  Unused segment Seg-B-0(1_to_47) from original chain B forms a separate chain.
Outputting Chain A with 492 residues.
Outputting Chain B with 47 residues.
Successfully processed segments. Output saved to C:\Users\bashc\Desktop\working\ssBinding\connected.pdb
