<a href="https://colab.research.google.com/github/acarbn/portfolio/blob/main/structure.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install biopython py3Dmol

Collecting biopython
  Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting py3Dmol
  Downloading py3Dmol-2.4.2-py2.py3-none-any.whl.metadata (1.9 kB)
Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m20.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading py3Dmol-2.4.2-py2.py3-none-any.whl (7.0 kB)
Installing collected packages: py3Dmol, biopython
Successfully installed biopython-1.85 py3Dmol-2.4.2


In [51]:
from Bio.PDB import PDBParser, Superimposer, PDBIO, PDBList
import py3Dmol
from io import StringIO
from google.colab import files
import numpy as np
import plotly.graph_objects as go

class StructureComparer:
    """Superimpose a mobile PDB entry onto a reference entry and compare them.

    Chains are given as strings with one character per chain ID (for example
    ``"ABCD"``); the i-th mobile chain is paired with the i-th reference chain.
    Residues are paired by iteration order within each chain pair, and only
    pairs where both residues have a ``CA`` atom are used.

    Typical order of calls: ``parse_structures()`` -> ``align_structures()`` ->
    ``diff_vector()`` / ``plot_diff_vector()`` / ``visualize()`` /
    ``save_aligned_structures()``.
    """

    def __init__(self, PDBmobile, chainmobile, PDBref, chainref):
        """Store the PDB IDs and chain selections; no I/O happens here.

        Args:
            PDBmobile: PDB ID of the structure that will be moved.
            chainmobile: Chain IDs of the mobile structure, e.g. ``"ABCD"``.
            PDBref: PDB ID of the structure that stays fixed.
            chainref: Chain IDs of the reference structure, e.g. ``"WXYZ"``.
        """
        self.PDBmobile = PDBmobile
        self.chainmobile = chainmobile  # String like "ABCD"
        self.PDBref = PDBref
        self.chainref = chainref        # String like "WXYZ"
        self.ref_atoms = []             # paired CA atoms of the reference
        self.mov_atoms = []             # paired CA atoms of the mobile structure
        self.chain_len = []             # cumulative paired-atom count at the end of each chain
        self.mov_pdb = None             # PDB text of the aligned mobile structure
        self.ref_pdb = None             # PDB text of the reference structure
        self.ref_structure = None
        self.mov_structure = None
        self._diff_vector_computed = False  # only used to print the summary once
        self.diffs = []                 # cached (residue number, distance) pairs

    def parse_structures(self):
        """Fetch both entries into the current directory and parse them."""
        pdbl = PDBList()
        ref_file = pdbl.retrieve_pdb_file(self.PDBref, pdir='.', file_format='pdb')
        mov_file = pdbl.retrieve_pdb_file(self.PDBmobile, pdir='.', file_format='pdb')

        parser = PDBParser(QUIET=True)
        self.ref_structure = parser.get_structure("ref", ref_file)
        self.mov_structure = parser.get_structure("mov", mov_file)
        print("Structures parsed.")

    def align_structures(self):
        """Superimpose the mobile structure onto the reference using paired CA atoms.

        Fills ``ref_atoms``, ``mov_atoms`` and ``chain_len``, moves every atom of
        the mobile structure in place, and stores both structures as PDB text in
        ``ref_pdb`` / ``mov_pdb``.

        Raises:
            ValueError: If the structures are not parsed, the chain strings
                differ in length, a chain ID is missing from model 0, or no CA
                pairs were found.
        """
        # Convert chain IDs to lists
        ref_chains = list(self.chainref)
        mov_chains = list(self.chainmobile)

        if not self.ref_structure or not self.mov_structure:
            raise ValueError("Parse structures first.")

        # Validate chain counts match
        if len(ref_chains) != len(mov_chains):
            raise ValueError(f"Mismatched chain counts: Reference has {len(ref_chains)}, Mobile has {len(mov_chains)}")

        # Reset everything derived from a previous alignment. chain_len and
        # diffs must be cleared too, otherwise a second alignment appends to
        # the old chain ends and plotting reuses stale distances.
        self.ref_atoms = []
        self.mov_atoms = []
        self.chain_len = []
        self.diffs = []

        paired = 0
        # Process each chain pair
        for ref_id, mov_id in zip(ref_chains, mov_chains):
            try:
                ref_chain = self.ref_structure[0][ref_id]
                mov_chain = self.mov_structure[0][mov_id]
            except KeyError as e:
                raise ValueError(f"Chain {e} not found in structure") from e

            # Pair residues from corresponding positions
            for ref_res, mov_res in zip(ref_chain, mov_chain):
                if ref_res.has_id('CA') and mov_res.has_id('CA'):
                    paired += 1
                    self.ref_atoms.append(ref_res['CA'])
                    self.mov_atoms.append(mov_res['CA'])
            # Cumulative count, so each entry is the plot position where the chain ends.
            self.chain_len.append(paired)
        print("Ending positions of each chain: ", str(self.chain_len))
        if not self.ref_atoms or not self.mov_atoms:
            raise ValueError("No matching CA atoms found for alignment")

        # Perform superposition (reference atoms are fixed, mobile atoms move)
        sup = Superimposer()
        sup.set_atoms(self.ref_atoms, self.mov_atoms)
        sup.apply(self.mov_structure.get_atoms())

        # Save aligned structures
        self.mov_pdb = self._structure_to_pdb_text(self.mov_structure)
        self.ref_pdb = self._structure_to_pdb_text(self.ref_structure)

    @staticmethod
    def _structure_to_pdb_text(structure):
        """Serialise a structure to PDB-format text held in memory."""
        io = PDBIO()
        io.set_structure(structure)
        with StringIO() as buffer:
            io.save(buffer)
            return buffer.getvalue()

    def visualize(self):
        """Return a py3Dmol view with the reference in blue and the mobile in red.

        Only the selected chains are drawn, as cartoons.

        Raises:
            ValueError: If ``align_structures()`` has not produced PDB text yet.
        """
        if self.ref_pdb is None or self.mov_pdb is None:
            raise ValueError("Align structures first")

        view = py3Dmol.view(width=800, height=600)

        # Add reference model (blue)
        view.addModelsAsFrames(self.ref_pdb, 'pdb')
        view.setStyle({'model': 0}, {})  # Clear default styles
        for chain_id in self.chainref:
            view.setStyle({'model': 0, 'chain': chain_id}, {'cartoon': {'color': 'blue'}})

        # Add mobile model (red)
        view.addModelsAsFrames(self.mov_pdb, 'pdb')
        view.setStyle({'model': 1}, {})  # Clear default styles
        for chain_id in self.chainmobile:
            view.setStyle({'model': 1, 'chain': chain_id}, {'cartoon': {'color': 'red'}})

        view.zoomTo()
        return view

    def save_aligned_structures(self):
        """Write both structures to disk and hand them to the Colab downloader.

        Raises:
            ValueError: If ``align_structures()`` has not produced PDB text yet.
        """
        if self.ref_pdb is None or self.mov_pdb is None:
            raise ValueError("Align structures first")

        mov_path = f"{self.PDBmobile}_aligned.pdb"
        ref_path = f"{self.PDBref}_ref.pdb"
        with open(mov_path, 'w') as f:
            f.write(self.mov_pdb)
        with open(ref_path, 'w') as f:
            f.write(self.ref_pdb)
        print("Files saved.")
        files.download(mov_path)
        files.download(ref_path)
        print("Files downloaded.")

    def diff_vector(self):
        """Compute the distance between every paired CA atom.

        Returns:
            A list of ``(residue_number, distance)`` tuples, where the residue
            number is taken from the reference residue ID. The list is also
            cached on ``self.diffs``.

        Raises:
            ValueError: If no atom pairs are available (alignment not run).
        """
        if not self.ref_atoms or not self.mov_atoms:
            raise ValueError("Align structures first")

        self.diffs = [
            (ref_atom.get_parent().id[1],
             float(np.linalg.norm(ref_atom.coord - mov_atom.coord)))
            for ref_atom, mov_atom in zip(self.ref_atoms, self.mov_atoms)
        ]
        if not self._diff_vector_computed:
            print("Difference vector calculated for", len(self.diffs), "residues.")
            self._diff_vector_computed = True
        return self.diffs

    def _chain_axis_ticks(self, window):
        """Return ``(tickvals, ticktext)`` with per-chain numbering every ``window``.

        Tick positions are absolute plot positions; labels restart at
        ``window`` for every chain. ``self.chain_len`` is read but never
        modified, so plotting repeatedly gives the same axis each time.
        """
        boundaries = [0] + list(self.chain_len)
        tickvals = []
        ticktext = []
        for start, end in zip(boundaries, boundaries[1:]):
            tickvals.extend(range(start + window, end, window))
            ticktext.extend(range(window, end - start, window))
        return tickvals, ticktext

    def plot_diff_vector(self, window):
        """Plot the per-residue CA distance along the concatenated chains.

        If no difference vector is cached, the user is asked on stdin whether
        to compute it; any answer other than ``y`` returns without plotting.

        Args:
            window: Positive integer spacing between x-axis ticks, counted
                from the start of each chain.

        Raises:
            ValueError: If ``window`` is smaller than 1.
        """
        if window < 1:
            raise ValueError("window must be a positive integer")

        if not self.diffs:
            print("Difference vector not calculated. Would you like to compute it now? (y/n)")
            user_input = input().strip().lower()
            if user_input == 'y':
                # Compute the difference vector
                diff_vector = self.diff_vector()
            else:
                print("Exiting without calculation.")
                return
        else:
            diff_vector = self.diffs

        distances = [distance for _, distance in diff_vector]
        # 1-based positions, one per distance. The upper bound is len + 1 so
        # x and y have equal length and the last residue is plotted.
        positions = list(range(1, len(distances) + 1))
        fig = go.Figure(data=go.Scatter(x=positions, y=distances, mode='lines'))

        tickvals, ticktext = self._chain_axis_ticks(window)
        fig.update_layout(
            title='Residue Differences',
            xaxis_title='Residue ID',
            yaxis_title='Distance',
            xaxis=dict(
                tickvals=tickvals,   # Define the tick positions
                ticktext=ticktext    # Define the custom labels for ticks
            )
        )

        # Dashed divider at the end of every chain except the last one.
        peak = max(distances)
        for idx in self.chain_len[:-1]:
            fig.add_shape(
                type="line",
                x0=idx,
                y0=0,
                x1=idx,
                y1=peak,
                line=dict(color="red", width=1, dash="dash")
            )
        fig.show()



In [57]:
# Alternative multi-chain example (four chain pairs), kept for reference:
#   comparer = StructureComparer(
#       PDBmobile="4FI3", chainmobile="ABCD",
#       PDBref="1L7V", chainref="ABCD",
#   )

# Single-chain example: chain A of 1ake is moved onto chain A of 4ake.
comparer = StructureComparer(
    PDBref="4ake",
    chainref="A",
    PDBmobile="1ake",
    chainmobile="A",
)

comparer.parse_structures()
comparer.align_structures()

# Optional follow-ups, disabled for this run:
# comparer.save_aligned_structures()
# comparer.visualize().show()

comparer.diff_vector()
comparer.plot_diff_vector(window=25)


Structure exists: './pdb4ake.ent' 
Structure exists: './pdb1ake.ent' 
Structures parsed.
Ending positions of each chain:  [214]
Difference vector calculated for 214 residues.
