<a href="https://colab.research.google.com/github/arronelab/writheCalculator/blob/main/writheCalculator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title Clone Repo and Compile { display-mode: "form" }
%%capture
#@markdown Please execute this cell by pressing the *Play* button on 
#@markdown the left.
# Start from a clean checkout: any previous clone in /content/src is removed first.
%cd /content/
%shell rm -rf src
!git clone --depth 1 https://github.com/arronelab/writheCalculator.git src
# Run the two makeFile*.sh scripts shipped in src/src/ (presumably these build
# the executables, e.g. the /content/src/src/DIwr program called by later cells).
%cd src/src/
%shell sh makeFileAbs.sh
%shell sh makeFileDI.sh
%cd /content/

# Python packages required by the imports below.
!pip install biotite
!pip install Bio
import urllib
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import os, re, shutil
from tempfile import gettempdir
import biotite
import biotite.structure as struc
import biotite.structure.io.mmtf as mmtf
import biotite.sequence as seq
import biotite.database.rcsb as rcsb
from Bio.PDB import *
from scipy import interpolate

def get_chains_from_biotite(pdb_code):
    """Return the chain IDs found among the amino-acid atoms of a PDB entry.

    The entry is fetched from RCSB in MMTF format into the system temporary
    directory, and only model 1 is read.
    """
    mmtf_path = rcsb.fetch(pdb_code.upper(), "mmtf", gettempdir())
    atoms = mmtf.get_structure(mmtf.MMTFFile.read(mmtf_path), model=1)
    amino_acid_mask = struc.filter_amino_acids(atoms)
    return struc.get_chains(atoms[amino_acid_mask])

def get_ss_from_biotite(pdb_code,chain):
    """Return the secondary-structure runs of one chain as [symbol, length] pairs.

    The entry is fetched from RCSB as MMTF (model 1) and reduced to the
    non-hetero CA atoms of ``chain`` whose residue name is one of the 20
    listed below. The file's 'secStructList' codes are translated to
    'H', 'S' or '-' and run-length encoded by ``get_sses``.
    """
    # Residue names that are kept (three-letter -> one-letter code).
    d = {'CYS': 'C', 'ASP': 'D', 'SER': 'S', 'GLN': 'Q', 'LYS': 'K',
         'ILE': 'I', 'PRO': 'P', 'THR': 'T', 'PHE': 'F', 'ASN': 'N',
         'GLY': 'G', 'HIS': 'H', 'LEU': 'L', 'ARG': 'R', 'TRP': 'W',
         'ALA': 'A', 'VAL':'V', 'GLU': 'E', 'TYR': 'Y', 'MET': 'M'}
    # MMTF 'secStructList' integer code -> DSSP letter
    # https://github.com/rcsb/mmtf/blob/master/spec.md#secstructlist
    sec_struct_codes = {0 : "I", 1 : "S", 2 : "H", 3 : "E",
                        4 : "G", 5 : "B", 6 : "T", 7 : "C",
                       -1 : "C"}
    # DSSP letter -> three-symbol alphabet (helix 'H', strand 'S', other '-')
    dssp_to_abc = {"I" : "H", "S" : "-", "H" : "H", "E" : "S",
                   "G" : "H", "B" : "S", "T" : "-", "C" : "-"}

    mmtf_path = rcsb.fetch(pdb_code.upper(), "mmtf", gettempdir())
    mmtf_file = mmtf.MMTFFile.read(mmtf_path)
    atoms = mmtf.get_structure(mmtf_file, model=1)
    atoms = atoms[struc.filter_amino_acids(atoms)]
    atoms = atoms[atoms.chain_id==chain]
    atoms = atoms[atoms.atom_name=='CA']
    atoms = atoms[atoms.hetero==False]
    known_residue = np.array([atom.res_name in list(d.keys()) for atom in atoms])
    atoms = atoms[known_residue]
    # Not read afterwards; retained as-is.
    chain_id_per_res = atoms.chain_id[struc.get_residue_starts(atoms)]
    # NOTE(review): this keeps the first len(atoms) entries of the file-level
    # list regardless of which chain was selected above - confirm this is the
    # intended behaviour for chains other than the first.
    raw_codes = mmtf_file["secStructList"][:len(atoms)]
    simplified = np.array(
        [dssp_to_abc[sec_struct_codes[code]] for code in raw_codes],
        dtype="U1")
    return get_sses(simplified)

def get_backbone_from_biotite(pdb_code,chain):
    """Return the CA coordinates of one chain of a PDB entry as a list.

    The entry is fetched from RCSB as MMTF (model 1). Only non-hetero CA
    atoms of ``chain`` whose residue name is one of the 20 listed below are
    kept; one coordinate is returned per remaining atom, in file order.
    """
    # Residue names that are kept (three-letter -> one-letter code).
    d = {'CYS': 'C', 'ASP': 'D', 'SER': 'S', 'GLN': 'Q', 'LYS': 'K',
        'ILE': 'I', 'PRO': 'P', 'THR': 'T', 'PHE': 'F', 'ASN': 'N',
        'GLY': 'G', 'HIS': 'H', 'LEU': 'L', 'ARG': 'R', 'TRP': 'W',
        'ALA': 'A', 'VAL':'V', 'GLU': 'E', 'TYR': 'Y', 'MET': 'M'}
    mmtf_path = rcsb.fetch(pdb_code.upper(), "mmtf", gettempdir())
    atoms = mmtf.get_structure(mmtf.MMTFFile.read(mmtf_path), model=1)
    atoms = atoms[struc.filter_amino_acids(atoms)]
    atoms = atoms[atoms.chain_id==chain]
    atoms = atoms[atoms.atom_name=='CA']
    atoms = atoms[atoms.hetero==False]
    known_residue = np.array([atom.res_name in list(d.keys()) for atom in atoms])
    atoms = atoms[known_residue]
    return [atom.coord for atom in atoms]

def simplify_sec_struc(SS):
    """Map each secondary-structure code in ``SS`` to 'H', 'S' or '-'.

    H, G, I -> 'H';  B, E -> 'S';  T, S, -, C -> '-'.
    Any other code raises KeyError. Returns a list, one entry per input code.
    """
    lookup = {}
    for target, sources in (('H', 'HGI'), ('S', 'BE'), ('-', 'TS-C')):
        for source in sources:
            lookup[source] = target
    return [lookup[code] for code in SS]

def split(word):
    """Return the items of ``word`` (e.g. the characters of a string) as a list."""
    return list(word)

def get_ss_from_fasta(fasta_file):
    """Read a secondary-structure string from a FASTA-like file and simplify it.

    The first whitespace-separated token of the last non-blank line is taken
    as the secondary-structure string; each of its characters is mapped to
    'H', 'S' or '-' by ``simplify_sec_struc``.

    Blank lines are ignored, so a file that ends with one or more empty lines
    no longer fails with IndexError.

    Returns:
        The simplified secondary structure as a single string.

    Raises:
        IndexError: if the file contains no non-blank line.
        KeyError: if the string contains a code unknown to ``simplify_sec_struc``.
    """
    with open(fasta_file,'r') as fin:
        records = [line.split() for line in fin if line.strip()]
    ss = records[-1][0]
    # A string already iterates character by character; no explicit split needed.
    return ''.join(simplify_sec_struc(ss))

def get_sses(ss):
    """Run-length encode a secondary-structure sequence.

    Args:
        ss: a sequence of symbols (string, list or 1-D array).

    Returns:
        A list of ``[symbol, run_length]`` pairs in order of appearance, e.g.
        ``"HHH--"`` -> ``[['H', 3], ['-', 2]]``. An empty input gives ``[]``
        (previously this raised IndexError).
    """
    from itertools import groupby

    return [[symbol, sum(1 for _ in run)] for symbol, run in groupby(ss)]

def get_first_chain_fasta(xyz_file,fasta_file):
    """Return the sequence on line 2 of ``fasta_file``, cut to the curve length.

    The number of rows read from ``xyz_file`` gives the maximum number of
    characters kept from the first token of the second line.
    """
    n_points = len(np.genfromtxt(xyz_file))
    with open(fasta_file,'r') as fin:
        tokenised = [row.split() for row in fin]
    sequence = tokenised[1][0]
    return sequence[:n_points]

def overwrite_fasta_file(xyz_file,fasta_file):
    """Rewrite ``fasta_file`` in place as two lines: header and trimmed sequence.

    The header is the first line with its whitespace normalised to single
    spaces; the sequence is whatever ``get_first_chain_fasta`` returns for
    the same pair of files. No trailing newline is written.
    """
    with open(fasta_file,'r') as fin:
        tokenised = [row.split() for row in fin]
    header = ' '.join(tokenised[0])
    # Must be read before the file is reopened for writing (which truncates it).
    trimmed_sequence = get_first_chain_fasta(xyz_file,fasta_file)
    with open(fasta_file,'w') as fout:
        for piece in (header, '\n', trimmed_sequence):
            fout.write(piece)

def running_mean(x, N):
    """Return the length-``N`` moving average of ``x`` using cumulative sums.

    The result has ``len(x) - N + 1`` entries.
    """
    # Prepend a zero so that totals[k] is the sum of the first k samples.
    totals = np.cumsum(np.insert(x, 0, 0))
    window_sums = totals[N:] - totals[:-N]
    return window_sums / float(N)

def set_axes_radius(ax, origin, radius):
    """Set the x, y and z limits of ``ax`` to ``origin[k] - radius .. origin[k] + radius``."""
    for axis, setter_name in enumerate(('set_xlim3d', 'set_ylim3d', 'set_zlim3d')):
        centre = origin[axis]
        getattr(ax, setter_name)([centre - radius, centre + radius])


def set_axes_equal(ax, zoom=1.):
    '''
        Give the three axes of a 3D plot the same scale, so that spheres
        appear as spheres, cubes as cubes, etc. This is one possible
        workaround for Matplotlib's ax.set_aspect("equal") and
        ax.axis("equal") not working for 3D.
        input:
          ax:   a matplotlib axis, e.g., as output from plt.gca().
          zoom: the common half-width is divided by this factor.

    '''
    limits = np.array([get_limits() for get_limits in
                       (ax.get_xlim3d, ax.get_ylim3d, ax.get_zlim3d)])
    centre = np.mean(limits, axis=1)
    spans = np.abs(limits[:, 1] - limits[:, 0])
    # Half of the widest axis span, scaled by the zoom factor.
    half_width = 0.5 * np.max(spans) / zoom
    set_axes_radius(ax, centre, half_width)

def view_molecule_helical(molecule,DI,res):
    """Show a curve in 3D, blue by default with selected index ranges in green.

    ``molecule`` is a whitespace-separated coordinate file (x y z per row).
    For every ``[start, end]`` pair in ``res`` the rows of ``DI`` whose first
    column equals ``start`` / ``end`` are located, and the points with indices
    from the first row index up to (not including) the second are coloured green.
    """
    points = np.genfromtxt(molecule)
    xs, ys, zs = points[:,0], points[:,1], points[:,2]
    colours = ['blue'] * len(xs)
    for section in res:
        first = np.where(DI[:,0] == section[0])[0][0]
        last = np.where(DI[:,0] == section[1])[0][0]
        for k in range(first, last):
            colours[k] = 'green'

    backbone = go.Scatter3d(
        x=xs, y=ys, z=zs,
        marker=dict(size=1, color=colours),
        line=dict(width=15, color=colours),
    )
    fig = go.Figure()
    fig.add_trace(backbone)
    fig['layout']['showlegend'] = False

    # One shared description of an invisible axis, copied for x, y and z.
    hidden_axis = dict(
        gridcolor="white",
        showbackground=False,
        zerolinecolor="white",
        nticks=0,
        showticklabels=False)
    fig.update_layout(
        scene=dict(
            xaxis_title='',
            yaxis_title='',
            zaxis_title='',
            aspectratio=dict(x=1, y=1, z=1),
            aspectmode='manual',
            xaxis=dict(hidden_axis),
            yaxis=dict(hidden_axis),
            zaxis=dict(hidden_axis),
        ),
    )
    fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
    return fig.show()

def view_molecule(molecule):
    """Show a smoothed 3D rendering of a curve, coloured from start to end.

    ``molecule`` is a whitespace-separated coordinate file (x y z per row).
    A degree-5 smoothing spline (smoothing factor ``len(points) / 2``) is
    fitted through the points and evaluated at the fitted parameter values.
    The result is drawn twice: a slightly wider black line as an outline and
    a viridis-coloured line on top. An empty extra trace provides the
    'Start' / 'End' colour bar.

    Compared with the previous version, the unused spline evaluation at the
    knots and the unused ``u_fine`` grid were removed, and the colour ramp is
    computed once.
    """
    mol = np.genfromtxt(molecule)
    tck, u = interpolate.splprep([mol[:,0], mol[:,1], mol[:,2]], k=5, s=len(mol)/2)
    x_fine, y_fine, z_fine = interpolate.splev(u, tck)

    n_points = len(x_fine)
    colour_ramp = [i*100/n_points for i in range(n_points)]

    fig = go.Figure()
    # Black outline, drawn first so the coloured line sits on top of it.
    fig.add_trace(go.Scatter3d(x=x_fine, y=y_fine, z=z_fine,
                               marker=dict(size=1,color='black'),line=dict(width=17.5,color='black')
                               ))
    fig.add_trace(go.Scatter3d(
        x=x_fine, y=y_fine, z=z_fine,
        marker=dict(
            size=1,
            color=colour_ramp,
            colorscale='viridis'
        ),
        line=dict(
            width=15,
            color=colour_ramp,
            colorscale='viridis'
        ),))

    # Point-less trace whose only purpose is to display the colour bar.
    colorbar_trace = go.Scatter3d(x=[None],
                          y=[None], z=[None],
                          mode='markers',
                          marker=dict(
                              colorscale='viridis',
                              showscale=True,
                              cmin=-5,
                              cmax=5,
                              colorbar=dict(thickness=25, tickvals=[-5, 5], ticktext=['Start','End'], outlinewidth=0)
                          ),
                          hoverinfo='none'
                        )
    fig['layout']['showlegend'] = False
    fig.add_trace(colorbar_trace)

    hidden_axis = dict(
        gridcolor="white",
        showbackground=False,
        zerolinecolor="white",
        nticks=0,
        showticklabels=False)
    fig.update_layout(
        scene=dict(
            xaxis_title='',
            yaxis_title='',
            zaxis_title='',
            aspectratio=dict(x=1, y=1, z=1),
            aspectmode='manual',
            xaxis=dict(hidden_axis),
            yaxis=dict(hidden_axis),
            zaxis=dict(hidden_axis),
        ),
    )
    fig.update_layout(margin=dict(l=0, r=0, b=0, t=0))
    return fig.show()

def get_all_edges(curve):
    """Return the consecutive point pairs ``[curve[k], curve[k+1]]`` of a curve."""
    return [[curve[k], curve[k + 1]] for k in range(len(curve) - 1)]

def split(word):
    """Return the items of ``word`` (e.g. the characters of a string) as a list."""
    return [*word]

def get_sses(ss):
    """Run-length encode a secondary-structure sequence.

    Args:
        ss: a sequence of symbols (string, list or 1-D array).

    Returns:
        A list of ``[symbol, run_length]`` pairs in order of appearance, e.g.
        ``"HHH--"`` -> ``[['H', 3], ['-', 2]]``. An empty input gives ``[]``
        (previously this raised IndexError).
    """
    from itertools import groupby

    return [[symbol, sum(1 for _ in run)] for symbol, run in groupby(ss)]

def intersect_line_triangle(q1,q2,p1,p2,p3):
    """Test whether the segment ``q1``-``q2`` passes through triangle ``p1 p2 p3``.

    All arguments are 3-component NumPy arrays. The test uses orientation
    signs only:

    * ``q1`` and ``q2`` must give different orientation signs with respect to
      the triangle (a sign of 0 counts as different from +1 / -1), and
    * the three tetrahedra spanned by the segment and each triangle edge must
      all have the same orientation sign.

    Returns:
        ``True`` if both conditions hold, otherwise ``False``.

    The previous version also computed the intersection parameter without
    using it; that dead code (and its possible division by zero) is removed.
    """
    def signed_tetra_volume(a,b,c,d):
        # Only the sign of the signed volume of tetrahedron (a, b, c, d) is needed.
        return np.sign(np.dot(np.cross(b-a,c-a),d-a)/6.0)

    if signed_tetra_volume(q1,p1,p2,p3) == signed_tetra_volume(q2,p1,p2,p3):
        return False

    s3 = signed_tetra_volume(q1,q2,p1,p2)
    s4 = signed_tetra_volume(q1,q2,p2,p3)
    s5 = signed_tetra_volume(q1,q2,p3,p1)
    return bool(s3 == s4 and s4 == s5)

def kmt_ss(pdb_code,chain):
    """Reduce a chain's CA backbone, one secondary-structure run at a time.

    The backbone (``get_backbone_from_biotite``) is cut into consecutive
    pieces whose lengths come from the ``[symbol, length]`` runs returned by
    ``get_ss_from_biotite``. Each piece is then shortened:

    * a piece with at most two points is replaced by its first point;
    * a longer piece is replaced by its first point if no edge of the current
      curve intersects any triangle (first point, interior point, last point);
    * otherwise a window of three consecutive points is moved along the
      piece, and the middle point is deleted whenever no edge of the current
      curve intersects the window's triangle.

    Returns the remaining points as a flat list; the last point of the
    original backbone is appended if it is not already the final point.
    """
    mol = get_backbone_from_biotite(pdb_code,chain)
    ss = get_ss_from_biotite(pdb_code,chain)
    # Cut the backbone into consecutive pieces, one per [symbol, length] run.
    splitcurve = []
    index = 0
    for i in ss:
        splitcurve.append(mol[index:index+i[1]])
        index+=i[1]
    # Flattened copy of the current curve; its edges are what the triangles are tested against.
    newcurve = []
    for i in range(len(splitcurve)):
        for j in range(len(splitcurve[i])):
            newcurve.append(splitcurve[i][j])
    for subsec in range(len(splitcurve)):
        if len(splitcurve[subsec])>2:
            # Test every triangle (first, interior point, last) of this piece
            # against every edge of the whole current curve.
            checks = []
            for idx in range(1,len(splitcurve[subsec])-1):
                p1 = splitcurve[subsec][0]
                p2 = splitcurve[subsec][idx]
                p3 = splitcurve[subsec][-1]
                for edge in get_all_edges(newcurve):
                    q0 = edge[0]
                    q1 = edge[1]
                    checks.append(intersect_line_triangle(q0,q1,p1,p2,p3))
            if not any(checks):
                # Nothing intersects: keep only the first point of the piece
                # and rebuild the flattened curve.
                splitcurve[subsec] = [splitcurve[subsec][0]]
                newcurve = []
                for l in range(len(splitcurve)):
                    for m in range(len(splitcurve[l])):
                        newcurve.append(splitcurve[l][m])
            else:
                # Point-by-point pass: idx is the last point of a window of
                # three consecutive points of the piece.
                idx=2
                while idx<len(splitcurve[subsec]):
                    # Rebuild the flattened curve so earlier deletions are reflected.
                    newcurve = []
                    for i in range(len(splitcurve)):
                        for j in range(len(splitcurve[i])):
                            newcurve.append(splitcurve[i][j])
                    p1 = splitcurve[subsec][idx-2]
                    p2 = splitcurve[subsec][idx-1]
                    p3 = splitcurve[subsec][idx]
                    checks = []
                    for edge in get_all_edges(newcurve):
                        q0 = edge[0]
                        q1 = edge[1]
                        checks.append(intersect_line_triangle(q0,q1,p1,p2,p3))
                    if not any(checks):
                        # Drop the middle point and restart from the beginning of the piece.
                        splitcurve[subsec] = np.delete(splitcurve[subsec],idx-1,axis=0)
                        idx=2
                    else:
                        idx+=1
        else:
            # Pieces of one or two points are reduced to their first point.
            splitcurve[subsec] = [splitcurve[subsec][0]]
            newcurve = []
            for l in range(len(splitcurve)):
                for m in range(len(splitcurve[l])):
                    newcurve.append(splitcurve[l][m])
    # Final flattening of the reduced pieces.
    newcurve = []
    for i in range(len(splitcurve)):
        for j in range(len(splitcurve[i])):
          newcurve.append(splitcurve[i][j])
    # Make sure the curve still ends at the original last point.
    if not np.array_equal(newcurve[-1],mol[-1]):
        newcurve.append(mol[-1])
    return newcurve

def write_curve_to_file(curve,outfile_name):
    """Write a curve to a text file, one point per line.

    Args:
        curve: a sequence of points; each point is a sequence of values that
            are written with ``str`` and separated by single spaces.
        outfile_name: path of the file to create or overwrite.

    Lines are separated by ``'\\n'`` and no trailing newline is written.
    An empty ``curve`` now produces an empty file instead of raising
    IndexError after the file had already been truncated.
    """
    rows = [' '.join(map(str, point)) for point in curve]
    with open(outfile_name,'w') as f:
        f.write('\n'.join(rows))

def view_molecule_subset(molecule,start,end):
    """Show rows ``start:end`` of a coordinate file as a rainbow-coloured 3D line.

    ``molecule`` is a whitespace-separated coordinate file (x y z per row).
    An empty extra trace provides the 'Start' / 'End' colour bar.
    """
    points = np.genfromtxt(molecule)
    xs = points[:,0][start:end]
    ys = points[:,1][start:end]
    zs = points[:,2][start:end]
    n_points = len(xs)
    colour_ramp = [i*100/n_points for i in range(n_points)]

    curve_trace = go.Scatter3d(
        x=xs, y=ys, z=zs,opacity=0.9,
        marker=dict(size=1, color=colour_ramp, colorscale='Rainbow'),
        line=dict(width=10, color=list(colour_ramp), colorscale='Rainbow'),
    )
    fig = go.Figure(data=curve_trace)

    # Point-less trace whose only purpose is to display the colour bar.
    colorbar_trace = go.Scatter3d(
        x=[None], y=[None], z=[None],
        mode='markers',
        marker=dict(
            colorscale='Rainbow',
            showscale=True,
            cmin=-5,
            cmax=5,
            colorbar=dict(thickness=25, tickvals=[-5, 5], ticktext=['Start','End'], outlinewidth=0)
        ),
        hoverinfo='none'
    )
    fig['layout']['showlegend'] = False
    fig.add_trace(colorbar_trace)
    fig.update_layout(width=1250,height=1000)

    hidden_axis = dict(
        gridcolor="white",
        showbackground=False,
        zerolinecolor="white",
        nticks=0,
        showticklabels=False)
    fig.update_layout(
        scene=dict(
            xaxis_title='',
            yaxis_title='',
            zaxis_title='',
            aspectratio=dict(x=1, y=1, z=1),
            aspectmode='manual',
            xaxis=dict(hidden_axis),
            yaxis=dict(hidden_axis),
            zaxis=dict(hidden_axis),
        ),
    )
    return fig.show()

def find_helical_sections(writhe_file, min_separation=7, min_slope=0.1):
    """Find the index ranges over which the writhe profile changes steeply.

    Args:
        writhe_file: path of a two-column text file (first column: position,
            second column: writhe value), read with ``np.genfromtxt``.
        min_separation: minimum number of rows between the two ends of a
            candidate range (default 7, the previous hard-coded value).
        min_slope: minimum absolute average slope
            ``|(y_j - y_i) / (x_j - x_i)|`` for a candidate (default 0.1,
            the previous hard-coded value).

    Returns:
        A list of ``[start, end]`` pairs (values from the first column)
        obtained by merging all overlapping candidate ranges in the order in
        which they are found. When no candidate qualifies, ``[]`` is returned;
        the previous version raised IndexError in that case.

    The merge assumes the first column is in ascending order, as it is in the
    files written by this notebook - verify for other inputs. The previous
    ``max_length`` bookkeeping and the second pass over the merged ranges had
    no effect on the result for such input and were removed.
    """
    DI = np.genfromtxt(writhe_file)

    # Every pair of rows at least `min_separation` apart whose average slope
    # is steep enough becomes a candidate (start, end) range.
    candidates = []
    for i in range(len(DI) - min_separation):
        for j in range(i + min_separation, len(DI)):
            slope = (DI[j][1] - DI[i][1]) / (DI[j][0] - DI[i][0])
            if abs(slope) >= min_slope:
                candidates.append((DI[i][0], DI[j][0]))
    if not candidates:
        return []

    # Merge candidates that overlap or touch the range currently being built.
    merged = []
    current_start, current_end = candidates[0]
    for start, end in candidates[1:]:
        if start > current_end:
            merged.append([current_start, current_end])
            current_start, current_end = start, end
        else:
            current_end = max(current_end, end)
    merged.append([current_start, current_end])
    return merged

In [None]:
%%capture
#@title Download PDB File and FASTA
#@markdown Leave chain_id blank if unsure, this will take the first chain in the PDB file.
# `import urllib` alone does not guarantee that the `request` submodule is loaded.
import urllib.request

pdb_code = "6eni" #@param {type:"string"}
chain_id = "" #@param {type:"string"}
# An empty chain_id selects the first chain reported for the entry.
if chain_id == "":
    chain_id = get_chains_from_biotite(pdb_code)[0]
# All files for an entry live in /content/<lower-case code>/.
os.makedirs('/content/'+pdb_code.lower(), exist_ok=True)
urllib.request.urlretrieve('https://files.rcsb.org/download/'+pdb_code.lower()+'.pdb', '/content/'+pdb_code.lower()+'/'+pdb_code.lower()+'.pdb')
#urllib.request.urlretrieve('https://www.rcsb.org/fasta/entry/'+pdb_code.upper()+'/download', '/content/'+pdb_code.lower()+'/'+pdb_code.lower()+'.fasta')

In [None]:
#@title Extract backbone
# Save the CA coordinates of the selected chain to /content/<code>/<code>.xyz.
write_curve_to_file(get_backbone_from_biotite(pdb_code,chain_id),'/content/'+pdb_code.lower()+'/'+pdb_code.lower()+'.xyz')

In [None]:
#@title Smooth Backbone (Secondary Structures + Essential Crossings)
# Save the curve returned by kmt_ss for the selected chain to /content/<code>/smooth_<code>.xyz.
write_curve_to_file(kmt_ss(pdb_code.lower(),chain_id),'/content/'+pdb_code.lower()+'/smooth_'+pdb_code.lower()+'.xyz')

In [None]:
#@title Smooth Backbone (Every n<sup>th</sup> point) - Recommend 3
n = 3 #@param {type:"integer"}
def smooth_backbone_n(name,n):
    """Down-sample a backbone by keeping every ``n``-th point of each SSE run.

    Reads the secondary structure from ``/content/<name>/ss_<name>.fasta`` and
    the coordinates from ``/content/<name>/<name>.xyz``. The coordinates are
    cut into consecutive runs (lengths from ``get_sses``); within each run
    every ``n``-th point is kept, starting with the run's first point. The
    result is written to ``/content/<name>/smooth_<n>_<name>.xyz`` with
    ``write_curve_to_file`` (same file format as before, without duplicating
    the writer logic here).
    """
    sses = get_sses(get_ss_from_fasta('/content/'+name+'/ss_'+name+'.fasta'))
    coords = np.genfromtxt('/content/'+name+'/'+name+'.xyz')
    downmol = []
    idx = 0
    for _, run_length in sses:
        downmol.extend(coords[idx:idx+run_length:n])
        idx += run_length
    write_curve_to_file(downmol, '/content/'+name+'/smooth_'+str(n)+'_'+name+'.xyz')

smooth_backbone_n(pdb_code.lower(),n)


In [None]:
#@title Calculate Writhe Profile of Smoothed Backbone (Recommended)
import subprocess
import tempfile

smoothed_by_n = False #@param {type:"boolean"}
# Choose the smoothed curve to analyse.
if smoothed_by_n:
    curveFile='/content/'+pdb_code.lower()+'/'+'smooth_'+str(n)+'_'+pdb_code.lower()+'.xyz'
else:
    curveFile='/content/'+pdb_code.lower()+'/'+'smooth_'+pdb_code.lower()+'.xyz'
# The profile is written next to the curve: <curve name without .xyz>_writhes.dat
writheFile = curveFile[:-4]+'_writhes.dat'
# Load the curve into a numpy array.
coords = np.loadtxt(curveFile)
size=int(np.size(coords)/3)
# A private temporary directory is used for the sub-curve files, so a leftover
# "tmp" folder from an interrupted run can no longer make this cell fail, and
# both the directory and the output file are closed/removed even on error.
with tempfile.TemporaryDirectory() as tmp_dir, open(writheFile, 'w+') as f3:
    # we need at least 5 points for a reasonable writhe
    for i in range(5,size+1):
        # Write the first i points to a file and run DIwr on it.
        subfile = os.path.join(tmp_dir, "subcurve{:d}.dat".format(i-4))
        np.savetxt(subfile, coords[:i])
        writhe_output = subprocess.run(["/content/src/src/DIwr", subfile],
                                       stdout=subprocess.PIPE,
                                       universal_newlines=True).stdout
        f3.write(str(i)+" " + writhe_output+"\n")
        os.remove(subfile)

In [None]:
#@title Plot Writhe Profile 
smoothed_by_n = False #@param {type:"boolean"}
highlight_helical_subsections = True #@param {type:"boolean"}
# Writhe profile produced by the "Calculate Writhe Profile of Smoothed Backbone" cell.
if smoothed_by_n:
    writhe_path = '/content/'+pdb_code.lower()+'/'+'smooth_'+str(n)+'_'+pdb_code.lower()+'_writhes.dat'
else:
    writhe_path = '/content/'+pdb_code.lower()+'/'+'smooth_'+pdb_code.lower()+'_writhes.dat'
DI = np.genfromtxt(writhe_path)

# Both smoothing variants now get the same treatment: if no helical
# subsection can be determined, say so and plot the plain profile.
res = []
if highlight_helical_subsections:
    try:
        res = find_helical_sections(writhe_path)
    except Exception:
        res = []
    if not res:
        print('No helical subsections')
        highlight_helical_subsections=False

x = DI[:,0]
y = DI[:,1]
fig=go.Figure()
fig.add_trace(go.Scatter(x=x,y=y,mode='lines',name=pdb_code.upper(),
                         marker=dict(color='blue',size=10),
                         line=dict(width=5)))
if highlight_helical_subsections:
    for i, (section_start, section_end) in enumerate(res):
        stindex = np.where(x == section_start)[0][0]
        endex = np.where(x == section_end)[0][0]
        # Only the first highlighted trace gets a legend entry.
        fig.add_trace(go.Scatter(x=x[stindex:endex],y=y[stindex:endex],
                                  mode='lines',
                                  name='Helical Subsections',
                                  legendgroup="helical",
                                  showlegend=(i==0),
                                  marker=dict(color='green',size=10),
                                  line=dict(width=5)
                                  )
        )
fig.update_layout(
    autosize=False,
    width=1000,
    height=0.75*1000)
fig.update_layout(
    font_family="Tenorite",
    font_color="black",
    title_font_family="Tenorite",
    title_font_color="black",
    legend_title_font_color="black",
    xaxis_title="Subsection Length",
    yaxis_title="Writhe",
    font=dict(size=16)
)
fig.show()

In [None]:
#@title View Smoothed Molecule
smoothed_by_n = False #@param {type:"boolean"}
highlight_helical_subsections = True #@param {type:"boolean"}
# Common path prefix of the smoothed curve (.xyz) and its writhe profile (_writhes.dat).
if smoothed_by_n:
    smooth_stem = '/content/'+pdb_code.lower()+'/smooth_'+str(n)+'_'+pdb_code.lower()
else:
    smooth_stem = '/content/'+pdb_code.lower()+'/smooth_'+pdb_code.lower()

res = []
if highlight_helical_subsections:
    try:
        res = find_helical_sections(smooth_stem+'_writhes.dat')
    except IndexError:
        # Raised when no subsection qualifies; fall back to the plain view.
        res = []
    if not res:
        print('No helical subsections')

if res:
    DI = np.genfromtxt(smooth_stem+'_writhes.dat')
    view_molecule_helical(smooth_stem+'.xyz',DI,res)
else:
    view_molecule(smooth_stem+'.xyz')

In [None]:
#@title View Molecule
# Show the unsmoothed backbone written by the "Extract backbone" cell.
view_molecule('/content/'+pdb_code.lower()+'/'+pdb_code.lower()+'.xyz')

In [None]:
#@title Calculate Writhe Profile of Full Backbone (Noisy and Slow)
import subprocess
import tempfile

# Read in the full (unsmoothed) backbone.
curveFile='/content/'+pdb_code.lower()+'/'+pdb_code.lower()+'.xyz'
# Load the curve into a numpy array.
coords = np.loadtxt(curveFile)
size=int(np.size(coords)/3)
# A private temporary directory is used for the sub-curve files, so a leftover
# "tmp" folder from an interrupted run can no longer make this cell fail, and
# both the directory and the output file are closed/removed even on error.
with tempfile.TemporaryDirectory() as tmp_dir, \
        open('/content/'+pdb_code.lower()+'/'+pdb_code.lower()+'_writhes.dat', 'w+') as f3:
    # we need at least 5 points for a reasonable writhe.
    # The range now matches the smoothed-backbone cells (5 .. size inclusive);
    # it previously ran over 4 .. size-1, which started below that minimum
    # and left out the complete curve.
    for i in range(5,size+1):
        # Write the first i points to a file and run DIwr on it.
        subfile = os.path.join(tmp_dir, "subcurve{:d}.dat".format(i-4))
        np.savetxt(subfile, coords[:i])
        writhe_output = subprocess.run(["/content/src/src/DIwr", subfile],
                                       stdout=subprocess.PIPE,
                                       universal_newlines=True).stdout
        f3.write(str(i)+" " + writhe_output+"\n")
        os.remove(subfile)


In [None]:
#@title Plot Full Writhe Profile
# Load the full-backbone writhe profile: column 0 is plotted on x, column 1 on y.
DI = np.genfromtxt('/content/'+pdb_code.lower()+'/'+pdb_code.lower()+'_writhes.dat')
x, y = DI[:,0], DI[:,1]

profile_trace = go.Scatter(
    x=x,
    y=y,
    mode='lines',
    name=pdb_code.upper(),
    marker=dict(color='blue',size=10),
    line=dict(width=5),
)
fig = go.Figure()
fig.add_trace(profile_trace)

# Fixed-size figure (2:1 aspect).
fig.update_layout(autosize=False, width=1500, height=0.5*1500)
# Fonts and axis labels.
fig.update_layout(
    font_family="Tenorite",
    font_color="black",
    title_font_family="Tenorite",
    title_font_color="black",
    legend_title_font_color="black",
    xaxis_title="No. Residues",
    yaxis_title="Writhe",
    font=dict(size=24),
)
fig.show()

In [None]:
#@title Upload your own PDB and SS Fingerprint file.
#@markdown Note: The fingerprint file should be in the format ---HHH---SSS--- etc. where H is an alpha helix, S a beta strand, and - a linker.
pdb_file_loc = "/content/human_SMARCAL1.pdb" #@param {type:"string"}
ssfp_file_loc = "/content/fp.dat" #@param {type:"string"}
# Name of the uploaded file without its last four characters (e.g. ".pdb").
pdb_base_name = os.path.basename(pdb_file_loc)[:-4]
ca_only_path = '/content/'+pdb_base_name+'CA.dat'

# Copy only the ATOM records of CA atoms into a temporary file. Both files are
# closed even if reading or writing fails.
ca_record = re.compile(r'^ATOM\s+\d+\s+CA\s+')
with open(pdb_file_loc, 'r') as fr, open(ca_only_path,'w+') as fw:
    for record in fr:
        if ca_record.search(record):
            fw.write(record)

try:
    parser = PDBParser(PERMISSIVE=False,QUIET=False)
    structure = parser.get_structure(pdb_base_name, ca_only_path)
    XYZ=[]
    for model in structure:
        # Only the first chain of the model is used.
        ids = [chain.get_id() for chain in model.get_chains()]
        chain = model[ids[0]]
        for residue in chain:
            for atom in residue:
                # Store plain Python floats rather than numpy scalars.
                XYZ.append([float(value) for value in atom.get_coord()])
    write_curve_to_file(XYZ,'/content/'+pdb_base_name+'.xyz')
finally:
    # The CA-only file is an intermediate; remove it even if parsing failed.
    os.remove(ca_only_path)

In [None]:
#@title Smooth your file
def read_in_ss(flname):
    """Return the symbols on the first line of a fingerprint file as a list.

    Leading and trailing whitespace, including the line's newline, is
    stripped, so a file that ends with a newline no longer yields a spurious
    ``'\\n'`` symbol (which previously made the length check in
    ``kmt_ss_diy`` fail). Only the first line is read. An empty file now
    gives ``[]`` instead of raising IndexError.
    """
    with open(flname) as fin:
        first_line = fin.readline()
    return list(first_line.strip())

def kmt_ss_diy(pdb_file_loc,ssfp_file_loc):
    """Reduce an uploaded backbone, one secondary-structure run at a time.

    The coordinates are read from ``/content/<uploaded file name without its
    last four characters>.xyz`` and the runs from the fingerprint file
    ``ssfp_file_loc`` (via ``read_in_ss`` and ``get_sses``). The number of
    coordinates must equal the total length of the runs (checked with
    ``assert``). Each run's piece of the curve is then shortened:

    * a piece with at most two points is replaced by its first point;
    * a longer piece is replaced by its first point if no edge of the current
      curve intersects any triangle (first point, interior point, last point);
    * otherwise a window of three consecutive points is moved along the
      piece, and the middle point is deleted whenever no edge of the current
      curve intersects the window's triangle.

    Returns the remaining points as a flat list; the last point of the
    original curve is appended if it is not already the final point.
    """
    mol = np.genfromtxt('/content/'+os.path.basename(pdb_file_loc)[:-4]+'.xyz')
    ss = get_sses(read_in_ss(ssfp_file_loc))
    # The fingerprint must describe exactly one symbol per coordinate row.
    assert len(mol) == sum([i[1] for i in ss])
    # Cut the curve into consecutive pieces, one per [symbol, length] run.
    splitcurve = []
    index = 0
    for i in ss:
        splitcurve.append(mol[index:index+i[1]])
        index+=i[1]
    # Flattened copy of the current curve; its edges are what the triangles are tested against.
    newcurve = []
    for i in range(len(splitcurve)):
        for j in range(len(splitcurve[i])):
            newcurve.append(splitcurve[i][j])
    for subsec in range(len(splitcurve)):
        if len(splitcurve[subsec])>2:
            # Test every triangle (first, interior point, last) of this piece
            # against every edge of the whole current curve.
            checks = []
            for idx in range(1,len(splitcurve[subsec])-1):
                p1 = splitcurve[subsec][0]
                p2 = splitcurve[subsec][idx]
                p3 = splitcurve[subsec][-1]
                for edge in get_all_edges(newcurve):
                    q0 = edge[0]
                    q1 = edge[1]
                    checks.append(intersect_line_triangle(q0,q1,p1,p2,p3))
            if not any(checks):
                # Nothing intersects: keep only the first point of the piece
                # and rebuild the flattened curve.
                splitcurve[subsec] = [splitcurve[subsec][0]]
                newcurve = []
                for l in range(len(splitcurve)):
                    for m in range(len(splitcurve[l])):
                        newcurve.append(splitcurve[l][m])
            else:
                # Point-by-point pass: idx is the last point of a window of
                # three consecutive points of the piece.
                idx=2
                while idx<len(splitcurve[subsec]):
                    # Rebuild the flattened curve so earlier deletions are reflected.
                    newcurve = []
                    for i in range(len(splitcurve)):
                        for j in range(len(splitcurve[i])):
                            newcurve.append(splitcurve[i][j])
                    p1 = splitcurve[subsec][idx-2]
                    p2 = splitcurve[subsec][idx-1]
                    p3 = splitcurve[subsec][idx]
                    checks = []
                    for edge in get_all_edges(newcurve):
                        q0 = edge[0]
                        q1 = edge[1]
                        checks.append(intersect_line_triangle(q0,q1,p1,p2,p3))
                    if not any(checks):
                        # Drop the middle point and restart from the beginning of the piece.
                        splitcurve[subsec] = np.delete(splitcurve[subsec],idx-1,axis=0)
                        idx=2
                    else:
                        idx+=1
        else:
            # Pieces of one or two points are reduced to their first point.
            splitcurve[subsec] = [splitcurve[subsec][0]]
            newcurve = []
            for l in range(len(splitcurve)):
                for m in range(len(splitcurve[l])):
                    newcurve.append(splitcurve[l][m])
    # Final flattening of the reduced pieces.
    newcurve = []
    for i in range(len(splitcurve)):
        for j in range(len(splitcurve[i])):
          newcurve.append(splitcurve[i][j])
    # Make sure the curve still ends at the original last point.
    if not np.array_equal(newcurve[-1],mol[-1]):
        newcurve.append(mol[-1])
    return newcurve

# Save the reduced curve to /content/smooth_<uploaded file name without its last four characters>.xyz.
write_curve_to_file(kmt_ss_diy(pdb_file_loc,ssfp_file_loc),'/content/smooth_'+os.path.basename(pdb_file_loc)[:-4]+'.xyz')

In [None]:
#@title Calculate Writhe Profile of your Smoothed Backbone
import subprocess
import tempfile

curveFile='/content/smooth_'+os.path.basename(pdb_file_loc)[:-4]+'.xyz'
# Load the curve into a numpy array.
coords = np.loadtxt(curveFile)
size=int(np.size(coords)/3)
# A private temporary directory is used for the sub-curve files, so a leftover
# "tmp" folder from an interrupted run can no longer make this cell fail, and
# both the directory and the output file are closed/removed even on error.
with tempfile.TemporaryDirectory() as tmp_dir, \
        open('/content/smooth_'+os.path.basename(pdb_file_loc)[:-4]+'_writhes.dat', 'w+') as f3:
    # we need at least 5 points for a reasonable writhe
    for i in range(5,size+1):
        # Write the first i points to a file and run DIwr on it.
        subfile = os.path.join(tmp_dir, "subcurve{:d}.dat".format(i-4))
        np.savetxt(subfile, coords[:i])
        writhe_output = subprocess.run(["/content/src/src/DIwr", subfile],
                                       stdout=subprocess.PIPE,
                                       universal_newlines=True).stdout
        f3.write(str(i)+" " + writhe_output+"\n")
        os.remove(subfile)

In [None]:
#@title Plot Writhe Profile 
highlight_helical_subsections = True #@param {type:"boolean"}
writhe_path = '/content/smooth_'+os.path.basename(pdb_file_loc)[:-4]+'_writhes.dat'
DI = np.genfromtxt(writhe_path)

# If no helical subsection can be determined, say so and plot the plain profile.
res = []
if highlight_helical_subsections:
    try:
        res = find_helical_sections(writhe_path)
    except IndexError:
        # Raised when no subsection qualifies.
        res = []
    if not res:
        print('No helical subsections')
        highlight_helical_subsections=False

x = DI[:,0]
y = DI[:,1]
fig=go.Figure()
# The trace is labelled with the uploaded file's name; it previously used
# pdb_code, which belongs to the RCSB download workflow and may be unrelated
# to (or undefined for) an uploaded structure.
fig.add_trace(go.Scatter(x=x,y=y,mode='lines',name=os.path.basename(pdb_file_loc)[:-4],
                         marker=dict(color='blue',size=10),
                         line=dict(width=5)))
if highlight_helical_subsections:
    for i, (section_start, section_end) in enumerate(res):
        stindex = np.where(x == section_start)[0][0]
        endex = np.where(x == section_end)[0][0]
        # Only the first highlighted trace gets a legend entry.
        fig.add_trace(go.Scatter(x=x[stindex:endex],y=y[stindex:endex],
                                  mode='lines',
                                  name='Helical Subsections',
                                  legendgroup="helical",
                                  showlegend=(i==0),
                                  marker=dict(color='green',size=10),
                                  line=dict(width=5)
                                  )
        )
fig.update_layout(
    autosize=False,
    width=1000,
    height=0.75*1000)
fig.update_layout(
    font_family="Tenorite",
    font_color="black",
    title_font_family="Tenorite",
    title_font_color="black",
    legend_title_font_color="black",
    xaxis_title="Subsection Length",
    yaxis_title="Writhe",
    font=dict(size=16)
)
fig.show()

In [None]:
#@title View Smoothed Molecule
highlight_helical_subsections = True #@param {type:"boolean"}
# Common path prefix of the smoothed curve (.xyz) and its writhe profile (_writhes.dat).
smooth_stem = '/content/smooth_'+os.path.basename(pdb_file_loc)[:-4]

res = []
if highlight_helical_subsections:
    try:
        res = find_helical_sections(smooth_stem+'_writhes.dat')
    except IndexError:
        # Raised when no subsection qualifies; fall back to the plain view.
        res = []
    if not res:
        print('No helical subsections')

if res:
    DI = np.genfromtxt(smooth_stem+'_writhes.dat')
    view_molecule_helical(smooth_stem+'.xyz',DI,res)
else:
    view_molecule(smooth_stem+'.xyz')