<a href="https://colab.research.google.com/github/arronelab/SWRITHE/blob/main/SWRITHE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#@title Clone Repo and Compile { display-mode: "form" }
#@markdown Run this cell to clone the writhe calculator git repo and compile its executables.
%%capture
# Start from the Colab working directory and remove any previous checkout.
%cd /content/
%shell rm -rf src
# Shallow clone (latest commit only) of the writhe calculator into /content/src.
!git clone --depth 1 https://github.com/arronelab/writheCalculator.git src
# Run the two build scripts from the repository's own src/ sub-directory.
# (The functions below call /content/src/src/DIwr and /content/src/src/AbsDI;
# presumably these scripts produce those executables - confirm against the repo.)
%cd src/src/
%shell sh makeFileAbs.sh
%shell sh makeFileDI.sh
# Return to /content/ so later cells resolve their paths from there.
%cd /content/

In [2]:
#@title Pip installs and imports
#@markdown Run this cell to install and import the requisite packages.
%%capture
!pip install biotite
!pip install Bio
!pip install PSIPREDauto
from PSIPREDauto.functions import single_submit
import urllib
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
import os, re, shutil
from tempfile import gettempdir
import biotite
import biotite.structure as struc
import biotite.structure.io.mmtf as mmtf
import biotite.sequence as seq
import biotite.database.rcsb as rcsb
from Bio.PDB import *
from sklearn.cluster import DBSCAN
from collections import defaultdict
from scipy import interpolate

In [7]:
#@title Functions
def get_chains_from_biotite(pdb_code):
    """Return the chain IDs found among the amino-acid atoms of an RCSB entry.

    The entry is fetched in MMTF format into the system temp directory (the
    code is upper-cased for the request) and only model 1 is read.
    """
    mmtf_path = rcsb.fetch(pdb_code.upper(), "mmtf", gettempdir())
    structure = mmtf.get_structure(mmtf.MMTFFile.read(mmtf_path), model=1)
    protein_atoms = structure[struc.filter_amino_acids(structure)]
    return struc.get_chains(protein_atoms)

def convert(s):
    """Concatenate an iterable of strings (e.g. a list of characters) into one str.

    Uses ``str.join`` instead of repeated ``+=`` so the cost is linear in the
    length of the input. An empty iterable gives the empty string.
    """
    return "".join(s)

def simple_ss_clean(fp):
    """Remove single-residue blips from a secondary-structure sequence.

    A position whose two neighbours agree with each other but differ from it
    is overwritten with the neighbours' label. The scan runs left to right on
    a working copy, so a correction made at position i is seen when position
    i+1 is examined.

    Parameters:
        fp: sequence of one-character labels (list, numpy array or str).

    Returns:
        The cleaned sequence as a single str.

    Only interior positions are examined. The previous loop started at index
    0, where ``fp[i-1]`` is ``fp[-1]``, so the first label could be rewritten
    from the *last* one. The input is no longer modified in place, which also
    makes plain ``str`` input work.
    """
    cleaned = list(fp)
    for i in range(1, len(cleaned) - 1):
        if cleaned[i - 1] == cleaned[i + 1] and cleaned[i - 1] != cleaned[i]:
            cleaned[i] = cleaned[i - 1]
    return "".join(cleaned)

def get_ss_from_biotite(pdb_code,chain):
    """Return the run-length encoded secondary structure of one chain.

    The entry is fetched from the RCSB in MMTF format (model 1), reduced to
    the non-hetero C-alpha atoms of `chain` whose residue is one of the 20
    standard amino acids, and its ``secStructList`` codes are mapped to the
    three labels used in this notebook: 'H' (helix), 'S' (strand) and '-'
    (everything else).

    Parameters:
        pdb_code: RCSB entry code (upper-cased before fetching).
        chain: chain ID to select.

    Returns:
        The result of ``get_sses`` on the cleaned label string, i.e. a list of
        ``[label, run_length]`` pairs.
    """
    # Residues that are kept; only the keys are used here.
    d = {'CYS': 'C', 'ASP': 'D', 'SER': 'S', 'GLN': 'Q', 'LYS': 'K',
         'ILE': 'I', 'PRO': 'P', 'THR': 'T', 'PHE': 'F', 'ASN': 'N',
         'GLY': 'G', 'HIS': 'H', 'LEU': 'L', 'ARG': 'R', 'TRP': 'W',
         'ALA': 'A', 'VAL':'V', 'GLU': 'E', 'TYR': 'Y', 'MET': 'M'}
    # Dictionary to convert 'secStructList' codes to DSSP values
    # https://github.com/rcsb/mmtf/blob/master/spec.md#secstructlist
    sec_struct_codes = {0 : "I",
                        1 : "S",
                        2 : "H",
                        3 : "E",
                        4 : "G",
                        5 : "B",
                        6 : "T",
                        7 : "C",
                       -1 : "C"}
    # Converter for the DSSP secondary structure elements
    # to the classical ones
    dssp_to_abc = {"I" : "H",
                  "S" : "-",
                  "H" : "H",
                  "E" : "S",
                  "G" : "H",
                  "B" : "S",
                  "T" : "-",
                  "C" : "-"}

    # Fetch and load structure
    file_name = rcsb.fetch(pdb_code.upper(), "mmtf", gettempdir())
    mmtf_file = mmtf.MMTFFile.read(file_name)
    array = mmtf.get_structure(mmtf_file, model=1)
    array = array[struc.filter_amino_acids(array)]
    array = array[array.chain_id==chain]
    array = array[array.atom_name=='CA']
    array = array[array.hetero==False]
    # Dict membership instead of rebuilding list(d.keys()) for every atom.
    array = array[np.array([atom.res_name in d for atom in array])]

    # NOTE(review): this keeps the FIRST len(array) entries of the file-wide
    # secStructList, whichever chain was requested. That looks right only when
    # the selected residues are the leading entries of the list - confirm for
    # chains other than the first one.
    sse = mmtf_file["secStructList"][:len(array)]
    sse = np.array([dssp_to_abc[sec_struct_codes[code]] for code in sse],
                   dtype="U1")
    return get_sses(simple_ss_clean(sse))

def get_backbone_from_biotite(pdb_code,chain):
    """Return the C-alpha coordinates of one chain of an RCSB entry.

    Model 1 of the MMTF file is filtered down to non-hetero 'CA' atoms of
    `chain` that belong to one of the 20 standard residues listed below.
    The result is a list with one coordinate array per kept atom.
    """
    standard_residues = {
        'CYS': 'C', 'ASP': 'D', 'SER': 'S', 'GLN': 'Q', 'LYS': 'K',
        'ILE': 'I', 'PRO': 'P', 'THR': 'T', 'PHE': 'F', 'ASN': 'N',
        'GLY': 'G', 'HIS': 'H', 'LEU': 'L', 'ARG': 'R', 'TRP': 'W',
        'ALA': 'A', 'VAL':'V', 'GLU': 'E', 'TYR': 'Y', 'MET': 'M'}

    mmtf_path = rcsb.fetch(pdb_code.upper(), "mmtf", gettempdir())
    structure = mmtf.get_structure(mmtf.MMTFFile.read(mmtf_path), model=1)

    # Successive boolean masks, applied in the same order as before.
    selection = structure[struc.filter_amino_acids(structure)]
    selection = selection[selection.chain_id == chain]
    selection = selection[selection.atom_name == 'CA']
    selection = selection[selection.hetero == False]
    is_standard = np.array([atom.res_name in standard_residues for atom in selection])
    selection = selection[is_standard]

    return [atom.coord for atom in selection]

def simplify_sec_struc(SS):
    """Map each DSSP-style code in `SS` to 'H', 'S' or '-'.

    Helix-like codes (H, G, I) become 'H', strand-like codes (B, E) become
    'S', and T, S, '-' and C become '-'. Returns a list of labels; an
    unknown code raises KeyError.
    """
    helix, strand, other = 'H', 'S', '-'
    lookup = {
        'H': helix, 'G': helix, 'I': helix,
        'B': strand, 'E': strand,
        'T': other, 'S': other, '-': other, 'C': other,
    }
    return [lookup[code] for code in SS]

def split(word):
    """Return the items of `word` (e.g. the characters of a string) as a list."""
    return list(word)

def get_ss_from_fasta(fasta_file):
    """Read a secondary-structure string from a FASTA-like file and simplify it.

    The first whitespace-separated token of the last non-blank line is taken
    as the secondary-structure string; each of its characters is mapped with
    ``simplify_sec_struc`` and the labels are joined back into one string.

    Blank lines are skipped, so a trailing empty line no longer breaks the
    ``[0]`` lookup. The string is also no longer split into characters twice.
    A file with no non-blank line still raises IndexError.
    """
    with open(fasta_file,'r') as fin:
        records = [line.split() for line in fin if line.strip()]
    raw_ss = records[-1][0]
    return ''.join(simplify_sec_struc(raw_ss))

def get_sses(ss):
    """Run-length encode a secondary-structure sequence.

    Parameters:
        ss: sequence of labels (str, list or numpy array).

    Returns:
        A list of ``[label, count]`` pairs, one per run of consecutive equal
        labels, in order. An empty input gives an empty list (previously it
        raised IndexError on ``ss[-1]``).
    """
    from itertools import groupby

    return [[label, sum(1 for _ in run)] for label, run in groupby(ss)]

def get_first_chain_fasta(xyz_file,fasta_file):
    """Return the FASTA sequence truncated to the number of points in `xyz_file`.

    The sequence is the first whitespace-separated token on the second line
    of `fasta_file`; it is cut to ``len(np.genfromtxt(xyz_file))`` characters.
    """
    n_points = len(np.genfromtxt(xyz_file))
    with open(fasta_file,'r') as fin:
        tokens_per_line = [row.split() for row in fin]
    sequence = tokens_per_line[1][0]
    return sequence[:n_points]

def overwrite_fasta_file(xyz_file,fasta_file):
    """Rewrite `fasta_file` as its header line plus the truncated sequence.

    The header is the first line with its whitespace-separated tokens
    re-joined by single spaces; the sequence comes from
    ``get_first_chain_fasta``. No trailing newline is written.
    """
    with open(fasta_file,'r') as fin:
        rows = [row.split() for row in fin]
    header = ' '.join(rows[0])
    trimmed_sequence = get_first_chain_fasta(xyz_file,fasta_file)
    with open(fasta_file,'w') as fout:
        fout.write(header + '\n' + trimmed_sequence)

def running_mean(x, N):
    """Return the means of every length-`N` window of `x`.

    Computed from differences of a zero-prefixed cumulative sum, so the
    result has ``len(x) - N + 1`` entries.
    """
    prefix_sums = np.cumsum(np.insert(x, 0, 0))
    window_totals = prefix_sums[N:] - prefix_sums[:-N]
    return window_totals / float(N)

def set_axes_radius(ax, origin, radius):
    """Set the x, y and z limits of a 3D axis to ``origin[k] +/- radius``."""
    for axis_index, setter_name in enumerate(("set_xlim3d", "set_ylim3d", "set_zlim3d")):
        set_limits = getattr(ax, setter_name)
        centre = origin[axis_index]
        set_limits([centre - radius, centre + radius])


def set_axes_equal(ax, zoom=1.):
    '''
        Give the three axes of a 3D plot the same scale, so that spheres
        look like spheres and cubes like cubes. Works around Matplotlib's
        ax.set_aspect("equal") and ax.axis("equal") not working for 3D.
        input:
          ax:   a matplotlib axis, e.g., as output from plt.gca().
          zoom: the common half-width is divided by this factor.

    '''
    limits = np.array([ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()])

    # Centre of the current view and the largest extent along any axis.
    centre = np.mean(limits, axis=1)
    extents = np.abs(limits[:, 1] - limits[:, 0])
    half_width = 0.5 * np.max(extents) / zoom
    set_axes_radius(ax, centre, half_width)

def view_molecule(molecule):
    """Show an interactive 3D view of a backbone curve read from a text file.

    The points in `molecule` (one "x y z" row per point) are smoothed with a
    parametric spline (s=5) sampled at ten times the number of input points.
    The curve is drawn twice: a wide dark underlay and a narrower line
    coloured by position along the chain ('dense' colour scale). An empty
    marker trace supplies the Start/End colour bar.
    """
    def _blank_axis():
        # White grid and zero line, no background pane, no tick labels.
        return dict(
            gridcolor="white",
            showbackground=False,
            zerolinecolor="white",
            nticks=0,
            showticklabels=False)

    mol = np.genfromtxt(molecule)
    x_true, y_true, z_true = (mol[:, axis].tolist() for axis in range(3))
    tck, u = interpolate.splprep([x_true, y_true, z_true], s=5)
    x_fine, y_fine, z_fine = interpolate.splev(np.linspace(0, 1, 10*len(mol)), tck)

    # Position along the chain, scaled to 0-100, used for the colour ramp.
    n_fine = len(x_fine)
    progress = [i*100/n_fine for i in range(n_fine)]

    underlay = go.Scatter3d(
        x=x_fine, y=y_fine, z=z_fine,
        marker=dict(size=1, color='black'),
        line=dict(width=15, color='rgba(0,0,0,0.75)'),
    )
    ramp = go.Scatter3d(
        x=x_fine, y=y_fine, z=z_fine,
        marker=dict(size=1, color=progress, colorscale='dense'),
        line=dict(width=10, color=progress, colorscale='dense'),
    )
    colorbar_trace = go.Scatter3d(
        x=[None], y=[None], z=[None],
        mode='markers',
        marker=dict(
            colorscale='dense',
            showscale=True,
            cmin=-5,
            cmax=5,
            colorbar=dict(thickness=25, tickvals=[-5, 5], ticktext=['Start','End'], outlinewidth=0),
        ),
        hoverinfo='none',
    )

    fig = go.Figure()
    fig.add_trace(underlay)
    fig.add_trace(ramp)
    fig['layout']['showlegend'] = False
    fig.add_trace(colorbar_trace)
    fig.update_layout(
        scene=dict(
            xaxis_title='',
            yaxis_title='',
            zaxis_title='',
            aspectratio=dict(x=1, y=1, z=1),
            aspectmode='manual',
            xaxis=_blank_axis(),
            yaxis=_blank_axis(),
            zaxis=_blank_axis(),
        ),
    )
    fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
    return fig.show()

def get_all_edges(curve):
    """Return the consecutive point pairs ``[curve[k-1], curve[k]]`` of a curve."""
    return [[curve[k - 1], curve[k]] for k in range(1, len(curve))]

def split(word):
    """Return the items of `word` (e.g. the characters of a string) as a list.

    NOTE: `split` is already defined earlier in this cell with the same
    behaviour; this later definition is the one left bound after the cell runs.
    """
    return list(word)

def get_sses(ss):
    """Run-length encode a secondary-structure sequence.

    NOTE: `get_sses` is also defined earlier in this cell; this later
    definition is the one left bound after the cell runs.

    Parameters:
        ss: sequence of labels (str, list or numpy array).

    Returns:
        A list of ``[label, count]`` pairs, one per run of consecutive equal
        labels, in order. An empty input gives an empty list (previously it
        raised IndexError on ``ss[-1]``).
    """
    from itertools import groupby

    return [[label, sum(1 for _ in run)] for label, run in groupby(ss)]



def intersect_line_triangle(q1,q2,p1,p2,p3):
    """Test whether the segment q1-q2 passes through the triangle p1-p2-p3.

    All five arguments are 3-vectors given as numpy arrays. The test uses the
    signs of tetrahedron volumes: the segment end points must lie on different
    sides of the triangle's plane, and the segment must have the same
    orientation with respect to all three triangle edges.

    Returns:
        True if both conditions hold, otherwise False (plain Python bool).

    The intersection parameter previously computed on the True branch was
    never used and has been dropped. Only the sign of each volume matters, so
    the constant 1/6 factor is omitted.
    """
    def orientation(a, b, c, d):
        # Sign (-1, 0 or +1) of the signed volume of tetrahedron a, b, c, d.
        return np.sign(np.dot(np.cross(b - a, c - a), d - a))

    # Both end points on the same side of the plane: no crossing.
    if orientation(q1, p1, p2, p3) == orientation(q2, p1, p2, p3):
        return False

    s3 = orientation(q1, q2, p1, p2)
    s4 = orientation(q1, q2, p2, p3)
    s5 = orientation(q1, q2, p3, p1)
    return bool(s3 == s4 and s4 == s5)

def skmt(pdb_code,chain):
    """Simplify the C-alpha backbone of an RCSB chain, one secondary-structure run at a time.

    The backbone (``get_backbone_from_biotite``) is cut into consecutive
    sections whose lengths are the run lengths from ``get_ss_from_biotite``.
    Points are removed from a section only when the triangle test
    (``intersect_line_triangle``) reports that no edge of the current curve
    passes through the triangle swept by the removal.

    Parameters:
        pdb_code: RCSB entry code.
        chain: chain ID.

    Returns:
        The simplified curve as a list of points; the last backbone point is
        appended if it is not already the final entry.
    """
    mol = get_backbone_from_biotite(pdb_code,chain)
    ss = get_ss_from_biotite(pdb_code,chain)
    # Cut the backbone into one slice per [label, count] run.
    splitcurve = []
    index = 0
    for i in ss:
        splitcurve.append(mol[index:index+i[1]])
        index+=i[1]
    # Flattened view of all sections; rebuilt whenever a section changes.
    newcurve = []
    for i in range(len(splitcurve)):
        for j in range(len(splitcurve[i])):
            newcurve.append(splitcurve[i][j])
    for subsec in range(len(splitcurve)):
        if len(splitcurve[subsec])>2:
            # First attempt: can the whole section be collapsed to one point?
            # For every interior point, build a triangle from that point and the
            # two section ends extended outwards by one step (2*end - neighbour),
            # and test it against every edge of the current curve.
            checks = []
            for idx in range(1,len(splitcurve[subsec])-1):
                p1 = 2*splitcurve[subsec][0]-splitcurve[subsec][1]
                p2 = splitcurve[subsec][idx]
                p3 = 2*splitcurve[subsec][-1]-splitcurve[subsec][-2]
                for edge in get_all_edges(newcurve):
                    q0 = edge[0]
                    q1 = edge[1]
                    checks.append(intersect_line_triangle(q0,q1,p1,p2,p3))
            if not any(checks):
                # Nothing crosses: keep only the section's first point.
                splitcurve[subsec] = [splitcurve[subsec][0]]
                newcurve = []
                for l in range(len(splitcurve)):
                    for m in range(len(splitcurve[l])):
                        newcurve.append(splitcurve[l][m])
            else:
                # Fallback: walk along the section and try to drop one point at a
                # time. idx is the last vertex of the triple (idx-2, idx-1, idx).
                idx=2
                while idx<len(splitcurve[subsec]):
                    newcurve = []
                    for i in range(len(splitcurve)):
                        for j in range(len(splitcurve[i])):
                            newcurve.append(splitcurve[i][j])
                    p1 = splitcurve[subsec][idx-2]
                    p2 = splitcurve[subsec][idx-1]
                    p3 = splitcurve[subsec][idx]
                    checks = []
                    for edge in get_all_edges(newcurve):
                        q0 = edge[0]
                        q1 = edge[1]
                        checks.append(intersect_line_triangle(q0,q1,p1,p2,p3))
                    if not any(checks):
                        # Middle point is removable; restart the walk from the
                        # beginning of the (now shorter) section.
                        splitcurve[subsec] = np.delete(splitcurve[subsec],idx-1,axis=0)
                        idx=2
                    else:
                        idx+=1
        else:
            # Sections of at most two points are reduced to their first point.
            # NOTE(review): an empty section would raise IndexError here - confirm
            # the run lengths can never exceed the number of backbone points.
            splitcurve[subsec] = [splitcurve[subsec][0]]
            newcurve = []
            for l in range(len(splitcurve)):
                for m in range(len(splitcurve[l])):
                    newcurve.append(splitcurve[l][m])
    # Final flattening of the simplified sections.
    newcurve = []
    for i in range(len(splitcurve)):
        for j in range(len(splitcurve[i])):
          newcurve.append(splitcurve[i][j])
    # Make sure the curve ends on the original last backbone point.
    if not np.array_equal(newcurve[-1],mol[-1]):
        newcurve.append(mol[-1])
    return newcurve

def write_curve_to_file(curve,outfile_name):
    """Write a curve to a text file, one point per line, coordinates space-separated.

    Parameters:
        curve: sequence of points; each point is an iterable of coordinates.
        outfile_name: path of the file to create or overwrite.

    No trailing newline is written after the last point. An empty curve now
    produces an empty file instead of raising IndexError after the file had
    already been truncated.
    """
    rows = [' '.join(map(str, point)) for point in curve]
    with open(outfile_name, 'w') as f:
        f.write('\n'.join(rows))

def view_molecule_subset(molecule,start,end):
    """Show an interactive 3D view of rows ``start:end`` of a curve file.

    The selected points are drawn as one line coloured by position
    ('Rainbow' colour scale) in a 1250x1000 figure; an empty marker trace
    supplies the Start/End colour bar.
    """
    def _blank_axis():
        # White grid and zero line, no background pane, no tick labels.
        return dict(
            gridcolor="white",
            showbackground=False,
            zerolinecolor="white",
            nticks=0,
            showticklabels=False)

    mol = np.genfromtxt(molecule)
    xs, ys, zs = (mol[:, axis][start:end] for axis in range(3))

    # Position along the selection, scaled to 0-100, used for the colour ramp.
    n_points = len(xs)
    progress = [i*100/n_points for i in range(n_points)]

    curve_trace = go.Scatter3d(
        x=xs, y=ys, z=zs, opacity=0.9,
        marker=dict(size=1, color=progress, colorscale='Rainbow'),
        line=dict(width=10, color=progress, colorscale='Rainbow'),
    )
    colorbar_trace = go.Scatter3d(
        x=[None], y=[None], z=[None],
        mode='markers',
        marker=dict(
            colorscale='Rainbow',
            showscale=True,
            cmin=-5,
            cmax=5,
            colorbar=dict(thickness=25, tickvals=[-5, 5], ticktext=['Start','End'], outlinewidth=0),
        ),
        hoverinfo='none',
    )

    fig = go.Figure(data=curve_trace)
    fig['layout']['showlegend'] = False
    fig.add_trace(colorbar_trace)
    fig.update_layout(width=1250, height=1000)
    fig.update_layout(
        scene=dict(
            xaxis_title='',
            yaxis_title='',
            zaxis_title='',
            aspectratio=dict(x=1, y=1, z=1),
            aspectmode='manual',
            xaxis=_blank_axis(),
            yaxis=_blank_axis(),
            zaxis=_blank_axis(),
        ),
    )
    return fig.show()

def find_helical_sections(writhe_file):
    """Find the longest stretches over which the writhe profile changes steeply.

    Parameters:
        writhe_file: path of a writhe profile whose name ends in a 12-character
            suffix (e.g. ``_writhes.dat``); the matching curve is read from the
            same path with that suffix replaced by ``.xyz``.

    Only the first ``n_points - 4`` rows of the profile are used. For each row
    the subsection length is ``col1 - col0 + 1`` and the value is ``col2``.
    Every pair of rows at least 7 apart whose absolute slope (value difference
    over length difference) exceeds 0.1 is a candidate.

    Returns:
        A list of ``[start_length, end_length]`` pairs for the candidates with
        the largest span, in discovery order; ``[]`` if there are none.

    The largest span is now computed once instead of inside the selection
    loop, which was quadratic in the number of candidates.
    """
    n_points = len(np.genfromtxt(writhe_file[:-12]+'.xyz'))
    DI = np.genfromtxt(writhe_file)[:n_points-4]
    lengths = DI[:,1] - DI[:,0] + 1
    writhe = DI[:,2]

    candidates = []
    for i in range(len(lengths)-7):
        for j in range(i+7, len(lengths)):
            if abs((writhe[j]-writhe[i])/(lengths[j]-lengths[i])) > 0.1:
                candidates.append([lengths[i], lengths[j]])
    if not candidates:
        return []

    longest = max(pair[1]-pair[0] for pair in candidates)
    return [pair for pair in candidates if pair[1]-pair[0] == longest]

def _view_molecule_sections(molecule, DI, res, start_shift):
    """Draw a curve in black with the sections in `res` highlighted (shared helper).

    Each section ``[start, end]`` is located by looking both values up in
    column 1 of `DI` (`start_shift` is added to the start before the lookup);
    the points between the two row indices get one colour from plotly's
    sequential 'dense' palette.
    """
    colors = px.colors.sequential.dense
    # One palette position per section, spread evenly and excluding both ends.
    colspace = np.linspace(0,10,len(res)+2)[1:-1]
    mol = np.genfromtxt(molecule)
    x = mol[:,0]
    y = mol[:,1]
    z = mol[:,2]
    cols = ['black' for _ in range(len(x))]
    for i in range(len(res)):
        stindex = np.where(DI[:,1] == res[i][0]+start_shift)[0][0]
        endex = np.where(DI[:,1]==res[i][1])[0][0]
        for j in range(stindex,endex):
            cols[j] = colors[int(colspace[i])]

    def _blank_axis():
        # White grid and zero line, no background pane, no tick labels.
        return dict(
            gridcolor="white",
            showbackground=False,
            zerolinecolor="white",
            nticks=0,
            showticklabels=False)

    fig = go.Figure()
    fig.add_trace(go.Scatter3d(
        x=x, y=y, z=z,
        marker=dict(
            size=1,
            color=cols
        ),
        line=dict(
            width=15,
            color=cols
        ),))
    fig['layout']['showlegend'] = False
    fig.update_layout(
        scene=dict(
            xaxis_title='',
            yaxis_title='',
            zaxis_title='',
            aspectratio = dict( x=1, y=1, z=1 ),
            aspectmode = 'manual',
            xaxis = _blank_axis(),
            yaxis = _blank_axis(),
            zaxis = _blank_axis(),),
    )
    fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
    return fig.show()


def view_molecule_helical(molecule,DI,res):
    """Show the curve in `molecule` with the helical sections `res` coloured.

    Parameters:
        molecule: path of the curve file ("x y z" rows).
        DI: 2-D numpy array of the writhe profile; column 1 is searched for
            the section bounds.
        res: list of ``[start, end]`` sections, e.g. from
            ``find_helical_sections``.
    """
    return _view_molecule_sections(molecule, DI, res, 0)


def view_molecule_roadie(molecule,DI,res):
    """Show the curve in `molecule` with the sections `res` coloured.

    Same as ``view_molecule_helical`` except that 4 is added to each section
    start before it is looked up in column 1 of `DI`.
    """
    return _view_molecule_sections(molecule, DI, res, 4)



def _write_subcurve_profile(curveFile, executable, suffix):
    """Run `executable` on every sub-curve of `curveFile` and record its output.

    For every start index i and every end index j with at least 5 points in
    ``coords[i:j]``, the sub-curve is saved to a scratch file, `executable` is
    run with that file as its only argument, and the line
    ``"<i+1> <j> <program stdout>\\n"`` is appended to
    ``curveFile[:-4] + suffix``.

    Differences from the previous inline implementation:
      * the scratch file lives in a ``tempfile.TemporaryDirectory`` rather
        than ``./tmp``, so a leftover directory from an interrupted run no
        longer makes the next call fail, and it is removed even on error;
      * the output file is closed by a ``with`` block on every exit path;
      * the program is started without a shell (argument list), so unusual
        characters in paths cannot be misinterpreted;
      * a missing executable now raises FileNotFoundError instead of silently
        producing lines with an empty value.
    """
    import subprocess
    import tempfile

    coords = np.loadtxt(curveFile)
    size = int(np.size(coords)/3)
    with tempfile.TemporaryDirectory() as scratch_dir, \
            open(curveFile[:-4]+suffix, 'w') as out:
        subfile = os.path.join(scratch_dir, "subcurve.dat")
        # we need at least 5 points for a reasonable writhe measure
        for i in range(0, size-4):
            for j in range(i+5, size+1):
                np.savetxt(subfile, coords[i:j])
                completed = subprocess.run([executable, subfile],
                                           stdout=subprocess.PIPE,
                                           universal_newlines=True)
                out.write(str(i+1)+" "+str(j)+" "+completed.stdout+"\n")


def calculate_writhe(curveFile):
    """Write the writhe profile of `curveFile` to ``<curveFile minus extension>_writhes.dat``.

    Uses the ``/content/src/src/DIwr`` executable; `curveFile` is expected to
    end in a 4-character extension such as ``.xyz``.
    """
    _write_subcurve_profile(curveFile, "/content/src/src/DIwr", '_writhes.dat')


def calculate_abs_writhe(curveFile):
    """Write the absolute-writhe profile of `curveFile` to ``<curveFile minus extension>_abswrithes.dat``.

    Uses the ``/content/src/src/AbsDI`` executable; `curveFile` is expected to
    end in a 4-character extension such as ``.xyz``.
    """
    _write_subcurve_profile(curveFile, "/content/src/src/AbsDI", '_abswrithes.dat')

def pdb_to_fasta(pdb_file_loc,chain):
    """Write the one-letter sequence of one chain of a PDB file next to it.

    The sequence is built from the C-alpha records of the first model (reading
    stops at ``ENDMDL``): ``ATOM`` records of any residue, plus ``HETATM``
    records of ``MSE``. It is written, without header or trailing newline, to
    the PDB path with its last four characters replaced by ``.fasta``.

    Parameters:
        pdb_file_loc: path of the PDB file.
        chain: chain ID to export. If blank (``""``/``None``/whitespace) and no
            chain with exactly that ID exists, the first chain in the file is
            used.

    Raises:
        KeyError: a residue name is not in the conversion table, the file has
            no matching C-alpha record, or the requested chain is absent.

    The loop used to store each record's chain ID in `chain`, overwriting the
    argument, so the sequence of the *last* chain in the file was written
    whatever had been requested.
    """
    aa3to1={
   'ALA':'A', 'VAL':'V', 'PHE':'F', 'PRO':'P', 'MET':'M',
   'ILE':'I', 'LEU':'L', 'ASP':'D', 'GLU':'E', 'LYS':'K',
   'ARG':'R', 'SER':'S', 'THR':'T', 'TYR':'Y', 'HIS':'H',
   'CYS':'C', 'ASN':'N', 'GLN':'Q', 'TRP':'W', 'GLY':'G',
   'MSE':'M', 'HID': 'H', 'HIP': 'H'
    }
    # Raw string: same pattern as before, without invalid-escape warnings.
    ca_pattern=re.compile(r"^ATOM\s{2,6}\d{1,5}\s{2}CA\s[\sA]([A-Z]{3})\s([\s\w])|^HETATM\s{0,4}\d{1,5}\s{2}CA\s[\sA](MSE)\s([\s\w])")
    chain_dict=dict()
    chain_list=[]
    with open(pdb_file_loc,'r') as fp:
        for line in fp.read().splitlines():
            if line.startswith("ENDMDL"):
                break
            match = ca_pattern.match(line)
            if match is None:
                continue
            atom_res, atom_chain, het_res, het_chain = match.groups()
            # Exactly one alternative of the pattern takes part in a match.
            if atom_res is not None:
                resn, record_chain = atom_res, atom_chain
            else:
                resn, record_chain = het_res, het_chain
            if record_chain in chain_dict:
                chain_dict[record_chain]+=aa3to1[resn]
            else:
                chain_dict[record_chain]=aa3to1[resn]
                chain_list.append(record_chain)

    if chain in chain_dict:
        selected = chain
    elif chain_list and (chain is None or not str(chain).strip()):
        selected = chain_list[0]
    else:
        raise KeyError("chain %r not found in %s (chains present: %s)"
                       % (chain, pdb_file_loc, chain_list))
    with open(pdb_file_loc[:-4]+'.fasta','w') as fout:
        fout.write(chain_dict[selected])

def get_ss_fp_psipred(fasta_file_loc):
    """Read a ``.ss`` prediction file and return its run-length encoded labels.

    The file is looked up at ``<fasta_file_loc> output/<fasta stem>.ss``. The
    third whitespace-separated field of every line is mapped to 'H', 'S' or
    '-', the sequence is cleaned with ``simple_ss_clean`` and encoded with
    ``get_sses``.
    """
    dssp_to_simp = {
        "I": "H", "H": "H", "G": "H",
        "E": "S", "B": "S",
        "S": "-", "T": "-", "C": "-",
    }
    stem = os.path.splitext(os.path.basename(fasta_file_loc))[0]
    ss_path = fasta_file_loc+' output/'+stem+'.ss'
    with open(ss_path,'r') as fin:
        fields_per_line = [row.split() for row in fin]
    labels = [dssp_to_simp[fields[2]] for fields in fields_per_line]
    return get_sses(simple_ss_clean(labels))


def skmt_upload(pdb_file_loc,chain):
    """Simplify the backbone of an uploaded structure, one secondary-structure run at a time.

    Coordinates are read from ``<pdb_file_loc minus 4 chars>.xyz`` and the
    secondary-structure runs come from ``get_ss_fp_psipred`` on the matching
    ``.fasta`` path. The procedure mirrors ``skmt``, except that the
    whole-section test uses the section's actual first and last points as
    triangle corners.

    Parameters:
        pdb_file_loc: path of the PDB file; only used to derive the .xyz and
            .fasta paths.
        chain: not referenced in the body.

    Returns:
        The simplified curve as a list of points; the last input point is
        appended if it is not already the final entry.
    """
    mol = np.genfromtxt(pdb_file_loc[:-4]+'.xyz')
    ss = get_ss_fp_psipred(pdb_file_loc[:-4]+'.fasta')
    # Cut the curve into one slice per [label, count] run.
    splitcurve = []
    index = 0
    for i in ss:
        splitcurve.append(mol[index:index+i[1]])
        index+=i[1]
    # Flattened view of all sections; rebuilt whenever a section changes.
    newcurve = []
    for i in range(len(splitcurve)):
        for j in range(len(splitcurve[i])):
            newcurve.append(splitcurve[i][j])
    for subsec in range(len(splitcurve)):
        if len(splitcurve[subsec])>2:
            # First attempt: can the whole section be collapsed to one point?
            # For every interior point, test the triangle (first, interior, last)
            # against every edge of the current curve.
            checks = []
            for idx in range(1,len(splitcurve[subsec])-1):
                p1 = splitcurve[subsec][0]
                p2 = splitcurve[subsec][idx]
                p3 = splitcurve[subsec][-1]
                for edge in get_all_edges(newcurve):
                    q0 = edge[0]
                    q1 = edge[1]
                    checks.append(intersect_line_triangle(q0,q1,p1,p2,p3))
            if not any(checks):
                # Nothing crosses: keep only the section's first point.
                splitcurve[subsec] = [splitcurve[subsec][0]]
                newcurve = []
                for l in range(len(splitcurve)):
                    for m in range(len(splitcurve[l])):
                        newcurve.append(splitcurve[l][m])
            else:
                # Fallback: walk along the section and try to drop one point at a
                # time. idx is the last vertex of the triple (idx-2, idx-1, idx).
                idx=2
                while idx<len(splitcurve[subsec]):
                    newcurve = []
                    for i in range(len(splitcurve)):
                        for j in range(len(splitcurve[i])):
                            newcurve.append(splitcurve[i][j])
                    p1 = splitcurve[subsec][idx-2]
                    p2 = splitcurve[subsec][idx-1]
                    p3 = splitcurve[subsec][idx]
                    checks = []
                    for edge in get_all_edges(newcurve):
                        q0 = edge[0]
                        q1 = edge[1]
                        checks.append(intersect_line_triangle(q0,q1,p1,p2,p3))
                    if not any(checks):
                        # Middle point is removable; restart the walk from the
                        # beginning of the (now shorter) section.
                        splitcurve[subsec] = np.delete(splitcurve[subsec],idx-1,axis=0)
                        idx=2
                    else:
                        idx+=1
        else:
            # Sections of at most two points are reduced to their first point.
            # NOTE(review): an empty section would raise IndexError here - confirm
            # the predicted run lengths can never exceed the number of points.
            splitcurve[subsec] = [splitcurve[subsec][0]]
            newcurve = []
            for l in range(len(splitcurve)):
                for m in range(len(splitcurve[l])):
                    newcurve.append(splitcurve[l][m])
    # Final flattening of the simplified sections.
    newcurve = []
    for i in range(len(splitcurve)):
        for j in range(len(splitcurve[i])):
          newcurve.append(splitcurve[i][j])
    # Make sure the curve ends on the original last point.
    if not np.array_equal(newcurve[-1],mol[-1]):
        newcurve.append(mol[-1])
    return newcurve

def sim(wr1,wr2,i,j,k,l):
    """Score how alike two writhe profiles are over equally long sections.

    `wr1` and `wr2` are row sequences whose column 0 is matched against `i`
    and `k` respectively and whose column 2 is the compared value. Starting
    at the first row of each profile with that column-0 value, ``j-i-3``
    consecutive rows are compared by relative absolute difference; the result
    is ``1 - sum(differences)/(j-i)``.

    Requires ``j-i == l-k`` and ``j-i >= 4`` (checked with assertions).
    """
    span = j-i
    assert span == l-k
    assert span>=4
    # First row of each profile whose column 0 equals the section start.
    for row_no, row in enumerate(wr1):
        if row[0]==i:
            idx1=row_no
            break
    for row_no, row in enumerate(wr2):
        if row[0]==k:
            idx2=row_no
            break
    total = 0
    for m in range(0,span-3):
        a = wr1[idx1+m][2]
        b = wr2[idx2+m][2]
        total = total + abs(a-b)/(abs(a)+abs(b))
    return 1-total/span


def find_sim_sections(wrfile1,wrfile2):
    """List equally long sections of two writhe profiles that score above 0.85.

    Parameters:
        wrfile1, wrfile2: paths of writhe profile files; column 1 of the last
            row of each is taken as that profile's length.

    Returns:
        A list of ``[i, j, k, l, score]`` entries, where ``i..j`` is the
        section in the first profile, ``k..l`` the equally long section in
        the second, and `score` the value of ``sim``. Entries are ordered by
        i, then j, then k.

    The end of the second section is fixed by the first one (``l = k + j - i``),
    so it is computed directly instead of being searched for in a fourth nested
    loop. As before, only spans of at least 9 are ever reported: the second
    section was required to have ``l >= k + 9`` even though the first was
    allowed to start at ``j = i + 8``.
    """
    wr1 = np.genfromtxt(wrfile1)
    wr2 = np.genfromtxt(wrfile2)
    lenwr1 = int(wr1[-1][1])
    lenwr2 = int(wr2[-1][1])
    sim_sections=[]
    for i in range(1,lenwr1-7):
        for j in range(i+9,lenwr1+1):
            span = j-i
            for k in range(1,lenwr2-7):
                l = k+span
                if l > lenwr2:
                    # l grows with k, so no later k can fit either.
                    break
                s = sim(wr1,wr2,i,j,k,l)
                if s>0.85:
                    sim_sections.append([i,j,k,l,s])
    return sim_sections

def find_longest_sim_sections(wrfile1,wrfile2):
    """Return the longest similar section pair of two writhe profiles.

    Among the entries of ``find_sim_sections`` with the largest span
    ``j - i``, the one with the highest score is returned as
    ``[i, j, k, l, score]``; ties go to the entry found first.

    Raises:
        ValueError: no similar section was found.

    The best entry is now picked with ``max(..., key=...)``. The previous
    ``np.where(list == score)`` lookup only worked because the score happened
    to be a numpy scalar, which turns the list comparison element-wise.
    """
    sim_sections = find_sim_sections(wrfile1,wrfile2)
    if not sim_sections:
        raise ValueError("no similar sections found between %s and %s"
                         % (wrfile1, wrfile2))
    max_sim_length = max(sec[1]-sec[0] for sec in sim_sections)
    longest_sim_sections = [sec for sec in sim_sections
                            if sec[1]-sec[0] == max_sim_length]
    return max(longest_sim_sections, key=lambda sec: sec[-1])

def plot_longest_sim_sections(wrfile1,wrfile2):
    """Plot two writhe profiles with their longest similar sections in black.

    Each profile is drawn as a coloured line (firebrick for the first file,
    royal blue for the second) using columns 1 and 2 of its leading rows; the
    section returned by ``find_longest_sim_sections`` is overdrawn in black.
    Legend labels are the second '_'-separated piece of each file path.
    """
    best = find_longest_sim_sections(wrfile1,wrfile2)
    wr1 = np.genfromtxt(wrfile1)
    wr2 = np.genfromtxt(wrfile2)
    fig=go.Figure()

    def _add_profile(wr, wrfile, colour, sec_start, sec_end):
        # Leading rows of the profile: (last row's column 1) - 4 of them.
        head = wr[:int(wr[-1][1]-4)]
        xs = head[:,1]
        ys = head[:,2]
        label = wrfile.split('_')[1]
        fig.add_trace(go.Scatter(x=xs,
                                 y=ys,
                                 name=label,
                                 legendgroup=label,
                                 mode='lines',
                                 marker=dict(color=colour,size=10),
                                 line=dict(width=5)))
        # A section starting at or before the first plotted x begins at row 0.
        if sec_start<=xs[0]:
            first = None
        else:
            first = np.where(xs==sec_start)[0][0]
        last = np.where(xs==sec_end)[0][0]
        fig.add_trace(go.Scatter(x=xs[first:last],
                                 y=ys[first:last],
                                 name=label,
                                 legendgroup=label,
                                 showlegend=False,
                                 mode='lines',
                                 marker=dict(color='black',size=10),
                                 line=dict(width=5)))

    _add_profile(wr1, wrfile1, 'firebrick', best[0], best[1])
    _add_profile(wr2, wrfile2, 'royalblue', best[2], best[3])

    fig.update_layout(autosize=False,
                      width=1000,
                      height=0.75*1000)
    fig.update_layout(font_family="Tenorite",
                      font_color="black",
                      title_font_family="Tenorite",
                      title_font_color="black",
                      legend_title_font_color="black",
                      xaxis_title="Subsection Length",
                      yaxis_title="Writhe",
                      font=dict(size=16)
    )
    fig.update_layout(template='simple_white')
    fig.show()

def measure_roadieness(wr_list_in):
    """Summarise growing prefixes of a writhe profile.

    Rows whose column 1 is not greater than column 0 are discarded. For every
    remaining row i >= 2, one entry
    ``[row_i[0], row_i[1], max_abs, end_diff]`` is produced, where `max_abs`
    is the largest absolute column-2 value among rows ``0..i-1`` and
    `end_diff` is ``value[i-1] - value[0]``.

    Parameters:
        wr_list_in: sequence of rows ``(start, end, value, ...)``.

    Returns:
        List of entries as described; empty if fewer than three rows remain.

    NOTE(review): the statistics cover the rows *before* row i while the
    bounds come from row i itself - kept as-is, confirm this is intended.

    The prefix maximum is now carried along as a running maximum rather than
    recomputed for every i, and a branch that could never run (prefixes always
    have at least two rows) has been removed.
    """
    rows = [item for item in wr_list_in if item[1] > item[0]]
    if len(rows) < 3:
        return []
    values = [item[2] for item in rows]
    running_max = max(abs(values[0]), abs(values[1]))
    result = []
    for i in range(2, len(rows)):
        result.append([rows[i][0], rows[i][1], running_max, values[i-1] - values[0]])
        running_max = max(running_max, abs(values[i]))
    return result

def retrieve_clusters(clustered_list):
    """Return the widest representative(s) of the non-noise clusters.

    `clustered_list` holds ``(pair, label)`` items. Items labelled -1 are
    ignored; for every other label the last pair seen is kept as that
    cluster's representative. The representatives with the largest
    ``pair[1] - pair[0]`` are returned, in order of first appearance of their
    label.
    """
    representative = {}
    for entry in clustered_list:
        if entry[1] == -1:
            continue
        pair, label = entry
        # Later members overwrite earlier ones; the label keeps its position.
        representative[label] = pair
    returns = list(representative.values())
    max_length = max([pair[1]-pair[0] for pair in returns])
    return [pair for pair in returns if pair[1]-pair[0]==max_length]

def find_roadie_sections(fp):
    """Locate candidate "roadie" subsections in a writhe profile file.

    Parameters
    ----------
    fp : str
        Path of a whitespace-delimited table readable by ``np.genfromtxt``
        and accepted by ``measure_roadieness``.

    Returns
    -------
    list
        The sections selected by ``retrieve_clusters``, or ``[]`` when no row
        passes the thresholds, when at most one bound pair is found, or when
        no bound pair is longer than 4.
    """
    rtests = measure_roadieness(np.genfromtxt(fp))
    # Keep rows with a large running |writhe| but almost no net change.
    potential_secs = [item for item in rtests if item[2] > 0.95 and abs(item[3]) < 0.05]
    if not potential_secs:
        return []
    uppers = list(set([item[1] for item in potential_secs]))
    bound_pairs = []
    for upper in uppers:
        # For each upper bound keep the tightest (largest) lower bound.
        max_lower = max([item[0] for item in potential_secs if item[1] == upper])
        bound_pairs.append([max_lower, upper])
    if len(bound_pairs) <= 1:
        return []
    clus = [pair for pair in bound_pairs if pair[1] - pair[0] > 4]
    if not clus:
        # DBSCAN rejects an empty sample set, so report "nothing found" here.
        return []
    clustering = DBSCAN(eps=3, min_samples=1).fit_predict(clus)
    return retrieve_clusters(list(zip(clus, clustering)))

In [None]:
%%capture
#@title Download PDB File and FASTA
#@markdown Leave chain_id blank if unsure, this will take the first chain in the PDB file.
pdb_code = "" #@param {type:"string"}
chain_id = "" #@param {type:"string"}
if chain_id == "":
    chain_id = get_chains_from_biotite(pdb_code)[0]
if not os.path.exists('/content/'+pdb_code.lower()):
    os.mkdir('/content/'+pdb_code.lower())
urllib.request.urlretrieve('http://files.rcsb.org/download/'+pdb_code.lower()+'.pdb', '/content/'+pdb_code.lower()+'/'+pdb_code.lower()+'.pdb')

In [None]:
#@title Extract backbone
pdb_code = "" #@param {type:"string"}
chain_id = "" #@param {type:"string"}
# Obtain the curve for the requested entry/chain, then store it next to the
# downloaded PDB file as <pdb_code>.xyz.
backbone_curve = get_backbone_from_biotite(pdb_code,chain_id)
backbone_xyz_path = '/content/'+pdb_code.lower()+'/'+pdb_code.lower()+'.xyz'
write_curve_to_file(backbone_curve,backbone_xyz_path)

In [None]:
#@title Smooth Backbone using the SKMT method
pdb_code = "" #@param {type:"string"}
chain_id = "" #@param {type:"string"}
# Run skmt on the lower-cased code and save its result as smooth_<pdb_code>.xyz
# inside the entry's working directory.
smoothed_curve = skmt(pdb_code.lower(),chain_id)
smoothed_xyz_path = '/content/'+pdb_code.lower()+'/smooth_'+pdb_code.lower()+'.xyz'
write_curve_to_file(smoothed_curve,smoothed_xyz_path)

In [None]:
#@title Calculate Writhe Profile of Smoothed Backbone (Recommended)
pdb_code = "" #@param {type:"string"}
# The smoothed curve lives in the entry's working directory.
entry_dir = '/content/'+pdb_code.lower()+'/'
curveFile = entry_dir+'smooth_'+pdb_code.lower()+'.xyz'
calculate_writhe(curveFile)

In [None]:
#@title Calculate acn Profile of Smoothed Backbone (Recommended)
pdb_code = "" #@param {type:"string"}
# The smoothed curve lives in the entry's working directory.
entry_dir = '/content/'+pdb_code.lower()+'/'
curveFile = entry_dir+'smooth_'+pdb_code.lower()+'.xyz'
calculate_abs_writhe(curveFile)

In [None]:
#@title Plot Writhe Profile
pdb_code = "" #@param {type:"string"}
colors = px.colors.sequential.dense
highlight_helical_subsections = False #@param {type:"boolean"}
highlight_roadie_subsections = False #@param {type:"boolean"}
writhes_path = '/content/'+pdb_code.lower()+'/'+'smooth_'+pdb_code.lower()+'_writhes.dat'
smooth_xyz_path = '/content/'+pdb_code.lower()+'/smooth_'+pdb_code.lower()+'.xyz'
# Keep only the first (number of smoothed points - 4) rows of the writhe table.
DI = np.genfromtxt(writhes_path)[:len(np.genfromtxt(smooth_xyz_path))-4]
# Each highlight keeps its own sections and colour spacing so that enabling
# both at once does not make one overwrite the other.
if highlight_helical_subsections:
    try:
        res1 = find_helical_sections(writhes_path)
        helical_colspace = np.linspace(0,10,len(res1)+2)[1:-1]
    except Exception:
        print('No helical subsections')
        highlight_helical_subsections=False
if highlight_roadie_subsections:
    try:
        res2 = find_roadie_sections(writhes_path)
        roadie_colspace = np.linspace(0,10,len(res2)+2)[1:-1]
    except Exception:
        print('No roadie subsections')
        # Disable the roadie highlight itself (not the helical one) on failure.
        highlight_roadie_subsections=False
x1 = DI[:,0]
x2 = DI[:,1]
x = x2-x1+1
y = DI[:,2]
fig=go.Figure()
fig.add_trace(go.Scatter(x=x,y=y,mode='lines',name=pdb_code.upper(),
                         marker=dict(color='black',size=10),
                         line=dict(width=5)))
# (legend prefix, sections, colour spacing, offset added to the section start)
highlight_specs = []
if highlight_helical_subsections:
    highlight_specs.append(('Helical Subsection ', res1, helical_colspace, 0))
if highlight_roadie_subsections:
    highlight_specs.append(('Roadie Subsection ', res2, roadie_colspace, 4))
for label, sections, colspace, start_offset in highlight_specs:
    for i in range(len(sections)):
        # Translate the section bounds into positions along the plotted x values.
        stindex = np.where(x == sections[i][0]+start_offset)[0][0]
        endex = np.where(x==sections[i][1])[0][0]
        fig.add_trace(go.Scatter(x=x[stindex:endex],y=y[stindex:endex],
                                  mode='lines',
                                  name=label+str(i+1),
                                  showlegend=True,
                                  marker=dict(color=colors[int(colspace[i])],
                                              size=10),
                                  line=dict(width=7.5)
                                  )
        )
fig.update_layout(
    autosize=False,
    width=1000,
    height=0.75*1000)
fig.update_layout(
    font_family="Tenorite",
    font_color="black",
    title_font_family="Tenorite",
    title_font_color="black",
    legend_title_font_color="black",
    xaxis_title="Subsection Length",
    yaxis_title="Writhe",
    font=dict(size=16)
)
fig.update_layout(template='simple_white')
fig.show()

In [None]:
#@title Plot acn Profile
pdb_code = "" #@param {type:"string"}
colors = px.colors.sequential.dense
# Table for the smoothed curve: column 0 goes on the x axis, column 1 on the y axis.
acn_table_path = '/content/'+pdb_code.lower()+'/'+'smooth_'+pdb_code.lower()+'_abswrithes.dat'
DI = np.genfromtxt(acn_table_path)
x, y = DI[:,0], DI[:,1]
profile_trace = go.Scatter(x=x,y=y,mode='lines',name=pdb_code.upper(),
                           marker=dict(color='black',size=10),
                           line=dict(width=5))
fig = go.Figure()
fig.add_trace(profile_trace)
# Fixed 4:3 canvas.
figure_size = dict(autosize=False,
                   width=1000,
                   height=0.75*1000)
# Fonts, colours and axis titles.
text_style = dict(font_family="Tenorite",
                  font_color="black",
                  title_font_family="Tenorite",
                  title_font_color="black",
                  legend_title_font_color="black",
                  xaxis_title="Subsection Length",
                  yaxis_title="acn",
                  font=dict(size=16))
fig.update_layout(**figure_size)
fig.update_layout(**text_style)
fig.update_layout(template='simple_white')
fig.show()

In [None]:
#@title View Smoothed Molecule
pdb_code = "" #@param {type:"string"}
highlight_helical_subsections = False #@param {type:"boolean"}
highlight_roadie_subsections = False #@param {type:"boolean"}
smooth_xyz_path = '/content/'+pdb_code.lower()+'/smooth_'+pdb_code.lower()+'.xyz'
writhes_path = '/content/'+pdb_code.lower()+'/'+'smooth_'+pdb_code.lower()+'_writhes.dat'
if highlight_helical_subsections:
    # Keep only the first (number of smoothed points - 4) rows of the writhe table.
    DI = np.genfromtxt(writhes_path)[:len(np.genfromtxt(smooth_xyz_path))-4]
    res = find_helical_sections(writhes_path)
    view_molecule_helical(smooth_xyz_path,DI,res)
if highlight_roadie_subsections:
    DI = np.genfromtxt(writhes_path)[:len(np.genfromtxt(smooth_xyz_path))-4]
    res = find_roadie_sections(writhes_path)
    view_molecule_roadie(smooth_xyz_path,DI,res)
# Show the plain view only when no highlight was requested; previously it was
# also drawn whenever helical highlighting was on and roadie highlighting off.
if not (highlight_helical_subsections or highlight_roadie_subsections):
    view_molecule(smooth_xyz_path)

In [None]:
#@title View Molecule
pdb_code = "" #@param {type:"string"}
# Display the unsmoothed curve stored as <pdb_code>.xyz in the entry's directory.
backbone_xyz_path = '/content/'+pdb_code.lower()+'/'+pdb_code.lower()+'.xyz'
view_molecule(backbone_xyz_path)

In [None]:
#@title Compare writhe profiles using Smax
#@markdown Make sure you've run the above cells to process each PDB file and compute its writhe profile before running this one.
pdb_code1 = "" #@param {type:"string"}
pdb_code2 = "" #@param {type:"string"}
# Writhe tables of the two smoothed curves, one per entry directory.
writhes_path1 = '/content/'+pdb_code1.lower()+'/smooth_'+pdb_code1.lower()+'_writhes.dat'
writhes_path2 = '/content/'+pdb_code2.lower()+'/smooth_'+pdb_code2.lower()+'_writhes.dat'
plot_longest_sim_sections(writhes_path1, writhes_path2)

In [None]:
#@title Calculate Writhe Profile of Full Backbone (Noisy and Slow)
pdb_code = "" #@param {type:"string"}
# The unsmoothed curve lives in the entry's working directory.
entry_dir = '/content/'+pdb_code.lower()+'/'
curveFile = entry_dir+pdb_code.lower()+'.xyz'
calculate_writhe(curveFile)


In [None]:
#@title Plot Full Writhe Profile
pdb_code = "" #@param {type:"string"}
colors = px.colors.sequential.dense
writhes_path = '/content/'+pdb_code.lower()+'/'+pdb_code.lower()+'_writhes.dat'
backbone_xyz_path = '/content/'+pdb_code.lower()+'/'+pdb_code.lower()+'.xyz'
# Keep only the first (number of backbone points - 4) rows of the writhe table.
DI = np.genfromtxt(writhes_path)[:len(np.genfromtxt(backbone_xyz_path))-4]
x1 = DI[:,0]
x2 = DI[:,1]
# Same definition of the x axis as the smoothed-profile plots (end - start + 1),
# so the "Subsection Length" axes are directly comparable.
x = x2-x1+1
y = DI[:,2]
fig=go.Figure()
fig.add_trace(go.Scatter(x=x,y=y,mode='lines',name=pdb_code.upper(),
                         marker=dict(color='black',size=10),
                         line=dict(width=5)))
fig.update_layout(
    autosize=False,
    width=1000,
    height=0.75*1000)
fig.update_layout(
    font_family="Tenorite",
    font_color="black",
    title_font_family="Tenorite",
    title_font_color="black",
    legend_title_font_color="black",
    xaxis_title="Subsection Length",
    yaxis_title="Writhe",
    font=dict(size=16)
)
fig.update_layout(template='simple_white')
fig.show()

In [None]:
#@title Process your own PDB file
#@markdown First, using the file explorer on the left, upload your own PDB file. Then enter its location below and press the play button.
#@markdown This will retrieve the backbone curve and secondary structure prediction using PSIPRED<sup>[1]</sup>.
pdb_file_loc = "" #@param {type:"string"}
chain_id = "" #@param {type:"string"}
# File name without its 4-character extension, e.g. "1abc" for "1abc.pdb".
pdb_stem = os.path.basename(pdb_file_loc)[:-4]
ca_only_path = '/content/'+pdb_stem+'CA.dat'
# Copy only the ATOM records of CA atoms into a temporary file; the context
# manager closes both handles even if reading or writing fails.
with open(pdb_file_loc, 'r') as fr, open(ca_only_path, 'w+') as fw:
    for record in fr:
        if re.search(r'^ATOM\s+\d+\s+CA\s+', record):
            fw.write(record)
try:
    parser = PDBParser(PERMISSIVE=False,QUIET=False)
    structure = parser.get_structure(pdb_stem,ca_only_path)
finally:
    # The temporary file is only needed for parsing; remove it even on error.
    os.remove(ca_only_path)
XYZ=[]
for model in structure:
    ids = [chain.get_id() for chain in model.get_chains()]
    # Use the requested chain for the backbone so it matches the chain passed
    # to pdb_to_fasta / skmt_upload below; fall back to the first chain when
    # chain_id is blank or not present in the file.
    if chain_id != "" and chain_id not in ids:
        print('Chain '+chain_id+' not found, using chain '+ids[0]+' for the backbone')
    backbone_chain_id = chain_id if chain_id in ids else ids[0]
    chain = model[backbone_chain_id]
    for residue in chain:
        for atom in residue:
            # Store plain Python floats rather than numpy scalars.
            XYZ.append([float(coord) for coord in atom.get_coord()])
write_curve_to_file(XYZ,'/content/'+pdb_stem+'.xyz')
pdb_to_fasta(pdb_file_loc,chain_id)
# NOTE(review): the e-mail address passed to single_submit is a placeholder.
single_submit(pdb_file_loc[:-4]+'.fasta', "foo@bar.com", pdb_file_loc)
write_curve_to_file(skmt_upload(pdb_file_loc,chain_id),pdb_file_loc[:-4]+'_smooth.xyz')
#shutil.rmtree(pdb_file_loc[:-4]+'.fasta output')

In [None]:
#@title Calculate Writhe Profile of your Smoothed Backbone
pdb_file_loc = "" #@param {type:"string"}
# Drop the 4-character extension and point at the matching *_smooth.xyz curve.
upload_stem = pdb_file_loc[:-4]
curveFile = upload_stem+'_smooth.xyz'
calculate_writhe(curveFile)

In [None]:
#@title Plot Writhe Profile
colors = px.colors.sequential.dense
pdb_file_loc = "" #@param {type:"string"}
highlight_helical_subsections = False #@param {type:"boolean"}
highlight_roadie_subsections = False #@param {type:"boolean"}
writhes_path = pdb_file_loc[:-4]+'_smooth_writhes.dat'
smooth_xyz_path = pdb_file_loc[:-4]+'_smooth.xyz'
# Keep only the first (number of smoothed points - 4) rows of the writhe table.
DI = np.genfromtxt(writhes_path)[:len(np.genfromtxt(smooth_xyz_path))-4]
# Each highlight keeps its own sections and colour spacing so that enabling
# both at once does not make one overwrite the other.
if highlight_helical_subsections:
    try:
        res1 = find_helical_sections(writhes_path)
        helical_colspace = np.linspace(0,10,len(res1)+2)[1:-1]
    except Exception:
        print('No helical subsections')
        highlight_helical_subsections=False
if highlight_roadie_subsections:
    try:
        res2 = find_roadie_sections(writhes_path)
        roadie_colspace = np.linspace(0,10,len(res2)+2)[1:-1]
    except Exception:
        print('No roadie subsections')
        highlight_roadie_subsections=False
x1 = DI[:,0]
x2 = DI[:,1]
x = x2-x1+1
y = DI[:,2]
fig=go.Figure()
fig.add_trace(go.Scatter(x=x,y=y,mode='lines',name=os.path.basename(pdb_file_loc)[:-4],
                         marker=dict(color='black',size=10),
                         line=dict(width=5)))
# (legend prefix, sections, colour spacing, offset added to the section start)
highlight_specs = []
if highlight_helical_subsections:
    highlight_specs.append(('Helical Subsection ', res1, helical_colspace, 0))
if highlight_roadie_subsections:
    highlight_specs.append(('Roadie Subsection ', res2, roadie_colspace, 4))
for label, sections, colspace, start_offset in highlight_specs:
    for i in range(len(sections)):
        # Translate the section bounds into positions along the plotted x values.
        stindex = np.where(x == sections[i][0]+start_offset)[0][0]
        endex = np.where(x==sections[i][1])[0][0]
        fig.add_trace(go.Scatter(x=x[stindex:endex],y=y[stindex:endex],
                                  mode='lines',
                                  name=label+str(i+1),
                                  showlegend=True,
                                  marker=dict(color=colors[int(colspace[i])],
                                              size=10),
                                  line=dict(width=7.5)
                                  )
        )
fig.update_layout(
    autosize=False,
    width=1000,
    height=0.75*1000)
fig.update_layout(
    font_family="Tenorite",
    font_color="black",
    title_font_family="Tenorite",
    title_font_color="black",
    legend_title_font_color="black",
    xaxis_title="Subsection Length",
    yaxis_title="Writhe",
    font=dict(size=16)
)
fig.update_layout(template='simple_white')
fig.show()

In [None]:
#@title View Smoothed Molecule
pdb_file_loc = "" #@param {type:"string"}
highlight_helical_subsections = False #@param {type:"boolean"}
highlight_roadie_subsections = False #@param {type:"boolean"}
# Use the same locations the processing and writhe cells write to:
# <pdb_file_loc without extension>_smooth.xyz and ..._smooth_writhes.dat.
smooth_xyz_path = pdb_file_loc[:-4]+'_smooth.xyz'
writhes_path = pdb_file_loc[:-4]+'_smooth_writhes.dat'
if highlight_helical_subsections:
    # Keep only the first (number of smoothed points - 4) rows of the writhe table.
    DI = np.genfromtxt(writhes_path)[:len(np.genfromtxt(smooth_xyz_path))-4]
    res = find_helical_sections(writhes_path)
    view_molecule_helical(smooth_xyz_path,DI,res)
if highlight_roadie_subsections:
    DI = np.genfromtxt(writhes_path)[:len(np.genfromtxt(smooth_xyz_path))-4]
    res = find_roadie_sections(writhes_path)
    # Roadie sections go to the roadie viewer, as in the PDB-code version of this cell.
    view_molecule_roadie(smooth_xyz_path,DI,res)
# Show the plain view only when no highlight was requested.
if not (highlight_helical_subsections or highlight_roadie_subsections):
    view_molecule(smooth_xyz_path)

<sup>[1]</sup>**The PSIPRED Server** Buchan DWA, Jones DT (2019). The PSIPRED Protein Analysis Workbench: 20 years on. Nucleic Acids Research. https://doi.org/10.1093/nar/gkz297