In [1]:
import pickle as pkl
import os
import prody as pro
from Bio.SeqUtils import seq1
from Bio import pairwise2
import re
import collections

In [22]:
from matplotlib.pyplot import subplots
from itertools import chain, islice
from string import ascii_uppercase
from numpy.random import choice
import matplotlib.pyplot as plt
from venn import venn
import os
import shutil
import numpy as np
import pandas as pd
from pathlib import Path
import pickle as pkl
from Bio import pairwise2

In [23]:
import numpy as np
from tmtools import tm_align
from tmtools.io import get_structure, get_residue_data
from tmtools.testing import get_pdb_path

In [24]:
def myget_structure(mpath, name):
    """Load the first chain of a PDB model and return its residue data.

    Parameters
    ----------
    mpath : str
        Directory prefix (concatenated directly with ``name``, so it should
        end with a path separator).
    name : str
        Model file stem; callers in this notebook pass it without the
        ``.pdb`` extension.

    Returns
    -------
    tuple
        ``(coords, seq)`` as produced by ``tmtools.io.get_residue_data`` for
        the first chain of the structure.

    NOTE(review): ``get_pdb_path`` comes from ``tmtools.testing`` (a test
    helper); relying on it to resolve an arbitrary path is fragile — confirm.
    """
    pdb_file = get_pdb_path(mpath + name)
    first_chain = next(get_structure(pdb_file).get_chains())
    coords, seq = get_residue_data(first_chain)
    return coords, seq

In [2]:
def hamming_distance_inverse(str1, str2):
    """Count the positions at which two equal-length sequences agree.

    This is the complement of the Hamming distance:
    ``len(str1) - hamming(str1, str2)``.

    Raises
    ------
    AssertionError
        If the two sequences differ in length.
    """
    assert len(str1) == len(str2)
    matches = 0
    for left, right in zip(str1, str2):
        if left == right:
            matches += 1
    return matches

In [3]:
def produce_stride_output(af_models_path, stride_out_path):
    """Run STRIDE on ``<entry>/ranked_0.pdb`` for every entry of a directory.

    Parameters
    ----------
    af_models_path : str
        Directory whose entries are model sub-directories (presumably
        AlphaFold output folders); must end with '/' because paths are built
        by string concatenation.
    stride_out_path : str
        Output directory handed to ``prody.execSTRIDE``; each result is named
        after its entry.

    Prints the directory listing. Entries named '.DS_Store' and entries
    without a ``ranked_0.pdb`` file are skipped.
    """
    entries = os.listdir(af_models_path)
    print(entries)
    for entry in entries:
        if entry == '.DS_Store':
            continue
        model_pdb = af_models_path + entry + '/' + 'ranked_0.pdb'
        if not os.path.isfile(model_pdb):
            continue
        pro.execSTRIDE(model_pdb, outputname=entry, outputdir=stride_out_path)
    return

In [4]:
def stride_to_dict(mpath):
    """Parse the ASG (per-residue assignment) records of ``<mpath>.stride``.

    Returns
    -------
    (dict, dict)
        ``{residue_number: stride_class_name}`` (e.g. 'Coil', 'AlphaHelix')
        and ``{residue_number: three_letter_residue_name}``. Residue numbers
        come from the 4th whitespace-separated field, cast to int.

    NOTE(review): a residue number with an insertion code (e.g. '52A') makes
    ``int`` raise, and a repeated number (e.g. several chains) overwrites the
    earlier entry — confirm inputs are single-chain.
    """
    class_by_residue = {}
    name_by_residue = {}
    with open(mpath + '.stride') as handle:
        for record in handle:
            if not record.startswith('ASG'):
                continue
            fields = record.split()
            number = int(fields[3])
            name, sse_class = fields[1], fields[6]
            class_by_residue[number] = sse_class
            name_by_residue[number] = name
    return class_by_residue, name_by_residue

In [5]:
def get_1code_seq_from_stride(mfile):
    """Read ``<mfile>.stride`` and return its one-letter sequence plus SSE classes.

    Returns
    -------
    (str, dict)
        The residue sequence converted to one-letter codes with
        ``Bio.SeqUtils.seq1``, and the ``{residue_number: stride_class}``
        dict from ``stride_to_dict`` (same residue order as the sequence).
    """
    classes_by_residue, names_by_residue = stride_to_dict(mfile)
    three_letter_seq = ''.join(map(str, names_by_residue.values()))
    return seq1(three_letter_seq), classes_by_residue

def translate_sse(sse):
    """Collapse STRIDE first-letter class codes onto a reduced SSE alphabet.

    '3' (310Helix) becomes 'A' (helix). 'T' (Turn) and 'B' (Bridge) become
    'C' (coil). Every other character is left unchanged, including gaps '-'
    and 'P' (PiHelix).
    """
    return sse.translate(str.maketrans({'3': 'A', 'T': 'C', 'B': 'C'}))

def get_sse_align_seq(aligned_1, d1_class):
    """Project per-residue STRIDE classes onto a gapped aligned sequence.

    Each non-gap character of ``aligned_1`` consumes the next class from
    ``d1_class`` (in dict insertion order, which matches the residue order
    used to build the sequence). The first letter of that class is emitted;
    gaps stay '-'. The result is passed through ``translate_sse``.

    Raises IndexError if ``aligned_1`` has more residues than ``d1_class``.
    """
    classes = list(d1_class.values())
    symbols = []
    next_residue = 0
    for ch in aligned_1:
        if ch == '-':
            symbols.append('-')
        else:
            symbols.append(classes[next_residue][0])
            next_residue += 1
    return translate_sse(''.join(symbols))

In [6]:
def get_array_alignments(sse1, sse2, min_align_length):
    """Split two gapped, equal-length SSE strings into gap-free aligned segments.

    A segment is a maximal run of columns where neither string has '-'.
    Segments shorter than ``min_align_length`` are discarded.

    Parameters
    ----------
    sse1, sse2 : str
        Aligned SSE strings (same length, '-' marks a gap).
    min_align_length : int
        Minimum segment length to keep (inclusive).

    Returns
    -------
    list of tuple
        ``(start_1, segment_1, segment_2, start_2)``. ``start_1`` and
        ``start_2`` are the 0-based residue indices, gaps not counted, at
        which the segment begins in sequence 1 and sequence 2.
    """
    arr_align = []
    run_1 = []
    run_2 = []
    residues_1 = 0  # residues of sequence 1 seen so far (gaps excluded)
    residues_2 = 0  # residues of sequence 2 seen so far (gaps excluded)

    def _flush():
        # Record the current run if it is non-empty and long enough, then reset it.
        if run_1 and len(run_1) >= min_align_length:
            arr_align.append((residues_1 - len(run_1), ''.join(run_1),
                              ''.join(run_2), residues_2 - len(run_2)))
        run_1.clear()
        run_2.clear()

    for i in range(len(sse1)):
        c1 = sse1[i]
        c2 = sse2[i]
        if c1 != '-' and c2 != '-':
            run_1.append(c1)
            run_2.append(c2)
        else:
            _flush()
        if c1 != '-':
            residues_1 += 1
        if c2 != '-':
            residues_2 += 1
    # A run that reaches the end of the strings has no closing gap. The
    # previous version silently dropped it.
    _flush()
    return arr_align

In [7]:
def find_longest_sse(line):
    """Return the length of the longest uninterrupted run of each SSE letter.

    The letters considered are 'A', 'C', 'T', 'S', in that key order. For
    each letter, the intermediate split list is printed (debug output).

    NOTE(review): after ``translate_sse`` a 'T' never occurs, and 'P'
    (PiHelix) is not considered.
    """
    longest = {}
    for letter in ('A', 'C', 'T', 'S'):
        pieces = re.split(r'[^{0}]'.format(letter), line)
        print(pieces)
        longest[letter] = len(max(pieces, key=len))
    return longest


In [9]:
def is_switch(rr, min_switch_length):
    """Decide whether two aligned SSE segments have different dominant elements.

    ``rr`` is a segment tuple whose items 1 and 2 are the SSE strings of the
    two sequences (as produced by ``get_array_alignments``). The dominant
    element of each string is the letter with the longest uninterrupted run;
    ties go to the earlier letter in 'A', 'C', 'T', 'S'. Intermediate values
    are printed.

    ``min_switch_length`` is accepted but not used.
    """
    longest_1 = find_longest_sse(rr[1])
    longest_2 = find_longest_sse(rr[2])
    print(longest_1)
    print(longest_2)
    dominant_1 = max(longest_1, key=longest_1.get)
    dominant_2 = max(longest_2, key=longest_2.get)
    print(dominant_1)
    print(dominant_2)
    return dominant_1 != dominant_2

In [10]:
def get_max_switch(rr, min_switch_length):
    """Find the longest run of columns where two aligned SSE strings disagree.

    Parameters
    ----------
    rr : tuple
        Segment tuple ``(start_1, sse_1, sse_2, start_2)`` as produced by
        ``get_array_alignments``. Only ``sse_1`` and ``sse_2`` are used.
    min_switch_length : int
        Minimum run length (inclusive) for the run to count as a switch.

    Returns
    -------
    (int, int, bool)
        ``(max_length, max_pos, is_switch)``. ``max_pos`` is the 0-based
        start of the first longest run (0 when there is none), and
        ``is_switch`` is ``max_length >= min_switch_length``.
    """
    lin1 = rr[1]
    lin2 = rr[2]
    max_length = 0
    max_pos = 0
    cur_length = 0
    cur_pos = 0
    for i in range(len(lin1)):
        if lin1[i] != lin2[i]:
            if cur_length == 0:
                cur_pos = i
            cur_length += 1
        else:
            if cur_length > max_length:
                max_pos = cur_pos
                max_length = cur_length
            cur_length = 0
    # A differing run that reaches the end of the segment is never closed by
    # a matching column. The previous version ignored it.
    if cur_length > max_length:
        max_pos = cur_pos
        max_length = cur_length
    return max_length, max_pos, max_length >= min_switch_length


In [11]:
def get_all_switch(rr, min_switch_length):
    """List every run of columns where two aligned SSE strings disagree.

    Parameters
    ----------
    rr : tuple
        Segment tuple ``(start_1, sse_1, sse_2, start_2)`` as produced by
        ``get_array_alignments``. Only ``sse_1`` and ``sse_2`` are used.
    min_switch_length : int
        Minimum run length to report. It is inclusive, consistent with
        ``get_max_switch`` and ``get_array_alignments``; the previous version
        used a strict ``>``.

    Returns
    -------
    list of tuple
        ``(start, length, sse_1[start:start+length], sse_2[start:start+length])``
        for each qualifying run, where ``start`` is relative to the segment.
    """
    lin1 = rr[1]
    lin2 = rr[2]
    res = []

    def _emit(start, length):
        # Keep non-empty runs that meet the (inclusive) minimum length.
        if length > 0 and length >= min_switch_length:
            res.append((start, length, lin1[start:start + length], lin2[start:start + length]))

    cur_length = 0
    cur_pos = 0
    for i in range(len(lin1)):
        if lin1[i] != lin2[i]:
            if cur_length == 0:
                cur_pos = i
            cur_length += 1
        else:
            _emit(cur_pos, cur_length)
            cur_length = 0
    # A differing run that reaches the end of the segment has no closing
    # match. The previous version dropped it.
    _emit(cur_pos, cur_length)
    return res

In [13]:
def find_sse_switch(sse1, sse2, min_align_length, min_switch_length, aligned_1, aligned_2):
    """Locate SSE switches inside the gap-free segments of an alignment.

    Parameters
    ----------
    sse1, sse2 : str
        Gapped, equal-length SSE strings of the two aligned sequences.
    min_align_length : int
        Minimum gap-free segment length (passed to ``get_array_alignments``).
    min_switch_length : int
        Minimum switch length (passed to ``get_all_switch``).
    aligned_1, aligned_2 : str
        Despite the names, ``get_sse_stride_alignment`` passes the *ungapped*
        one-letter sequences. The positions computed here are residue indices
        into them.

    Returns
    -------
    list of list
        ``[pos_1, pos_2, (length, sse_run_1, sse_run_2), identity]``. Here
        ``identity`` is the fraction of identical residues between the two
        sequence stretches covering the switch.

    NOTE (from the original author): to allow gaps in the alignment, change
    '-' to 'X'.
    """
    switches = []
    for segment in get_array_alignments(sse1, sse2, min_align_length):
        seg_start_1, seg_start_2 = segment[0], segment[3]
        for switch in get_all_switch(segment, min_switch_length):
            offset, length = switch[0], switch[1]
            start_1 = offset + seg_start_1
            start_2 = offset + seg_start_2
            stretch_1 = aligned_1[start_1:start_1 + length]
            stretch_2 = aligned_2[start_2:start_2 + length]
            identity = hamming_distance_inverse(stretch_1, stretch_2) / length
            switches.append([start_1, start_2, switch[1:], identity])
    return switches

In [14]:
def get_sse_stride_alignment(m1, m2, mpath, min_align_length=5, min_switch_length=5):
    """Align two STRIDE-annotated models and report SSE switches.

    Parameters
    ----------
    m1, m2 : str
        Model names. ``<mpath><m>.stride`` is read for each.
    mpath : str
        Directory prefix holding the STRIDE files (must end with '/').
    min_align_length : int, default 5
        Minimum gap-free aligned segment length.
    min_switch_length : int, default 5
        Minimum length of a disagreeing SSE run.

    Returns
    -------
    dict
        ``'align'`` holds ``[aligned_seq_1, aligned_seq_2, sse_1, sse_2]``
        (gapped strings). ``'switch'`` holds the list from ``find_sse_switch``.

    The sequences are aligned locally with ``pairwise2.align.localxs``
    (match 1, mismatch 0, gap open -0.5, extend 0), and only the first
    alignment is used.
    NOTE(review): if no alignment is returned, ``alignments[0]`` raises
    IndexError.
    """
    s1, d1_class = get_1code_seq_from_stride(mpath + m1)
    s2, d2_class = get_1code_seq_from_stride(mpath + m2)

    alignments = pairwise2.align.localxs(s1, s2, -0.5, -0.0, penalize_extend_when_opening=True)
    aligned_1 = alignments[0][0]
    aligned_2 = alignments[0][1]

    sse1 = get_sse_align_seq(aligned_1, d1_class)
    sse2 = get_sse_align_seq(aligned_2, d2_class)

    switches = find_sse_switch(sse1, sse2, min_align_length, min_switch_length, s1, s2)
    return {'align': [aligned_1, aligned_2, sse1, sse2], 'switch': switches}

In [15]:
def get_bfactor_and_stride(mpath_to_stride, mpath_to_pdb):
    """Pair each residue's STRIDE class with the B-factor of its first ATOM record.

    Parameters
    ----------
    mpath_to_stride : str
        Path prefix of the STRIDE file (``.stride`` is appended by
        ``stride_to_dict``).
    mpath_to_pdb : str
        PDB file to read (fixed-column ATOM records).

    Returns
    -------
    (dict, dict)
        ``{residue_number: [(stride_class, bfactor)]}``, using the first atom
        of each residue only, and ``{residue_number: residue_name}``.

    Raises KeyError if a PDB residue number is absent from the STRIDE file.
    """
    d = {}
    dn = {}
    sse_by_residue, _ = stride_to_dict(mpath_to_stride)
    with open(mpath_to_pdb) as pdb_file:
        for line in pdb_file:
            if not line.startswith('ATOM'):
                continue
            residue_name = line[17:20]
            residue_number = int(line[22:26])
            bfactor = float(line[60:66].strip())
            dn[residue_number] = residue_name
            sse_class = sse_by_residue[residue_number]
            # Only the first atom seen for each residue is recorded.
            if residue_number not in d:
                d[residue_number] = [(sse_class, bfactor)]
    return d, dn

In [16]:
def get_strideclass_dict():
    """Build a zero-initialised counter keyed by residue, STRIDE class and bin.

    Keys have the form ``'<aa>-<stride_class>-<bin>'``. They cover the 20
    one-letter amino acids (sorted), the 7 STRIDE classes (sorted) and the
    bins 'd', 'l', 'm', 'h' (in that order), for 560 keys in total. The
    meaning of the bin letters is not defined here — presumably confidence
    levels; TODO confirm.
    """
    # The group numbers are unused here; only the keys (the amino-acid
    # alphabet) matter.
    aadict = {'A': 1, 'G': 1, 'V': 1, 'M': 2, 'S': 2, 'T': 2, 'Y': 2, 'C': 7, 'D': 6, 'E': 6, 'R': 5, 'K': 5, 'I': 3,
              'L': 3, 'F': 3, 'P': 3, 'N': 4, 'Q': 4, 'H': 4, 'W': 4}
    residues = sorted(aadict)
    bins = ['d', 'l', 'm', 'h']
    stride_classes = sorted(['Coil', 'Turn', 'Strand', 'AlphaHelix', '310Helix', 'Bridge', 'PiHelix'])
    keys = ['-'.join((aa, sse, level))
            for aa in residues
            for sse in stride_classes
            for level in bins]
    return dict.fromkeys(keys, 0)

In [17]:
def save_to_pkl_sse_bfactor(mpath_to_stride, mpath_to_pdb, mpath_to_pkl):
    """Group per-residue B-factors by STRIDE class and pickle the result.

    The saved dict has a ``'seqlen'`` key (the number of residues) followed
    by one list of B-factors per STRIDE class, in sorted class order. It is
    written to ``<mpath_to_pkl>_sse_stride_afmodel.pickle`` and also returned.
    Residues with an unknown class print 'error' and are skipped.
    """
    per_residue, _ = get_bfactor_and_stride(mpath_to_stride, mpath_to_pdb)
    stride_classes = sorted(['Coil', 'Turn', 'Strand', 'AlphaHelix', '310Helix', 'Bridge', 'PiHelix'])
    summary = {'seqlen': len(per_residue)}
    for sse_class in stride_classes:
        summary[sse_class] = []
    for observations in per_residue.values():
        first = observations[0]
        if first[0] in summary:
            summary[first[0]].append(first[1])
        else:
            print('error')
    with open(mpath_to_pkl + '_sse_stride_afmodel.pickle', 'wb') as handle:
        pkl.dump(summary, handle)
    return summary

In [18]:
def save_to_pkl_sse_bfactor_pipe(mpath, mpath_stride, fprsse):
    """Run ``save_to_pkl_sse_bfactor`` for every 'ORF*' model directory.

    Parameters
    ----------
    mpath : str
        Directory (ending in '/') containing one sub-directory per model.
    mpath_stride : str
        Prefix of the STRIDE files; ``<mpath_stride><name>`` is passed on.
    fprsse : iterable of str
        Candidate directory names. Names not starting with 'ORF', or lacking
        ``ranked_0.pdb``, are skipped.
    """
    for model_dir in fprsse:
        if not model_dir.startswith('ORF'):
            continue
        model_root = mpath + model_dir + '/'
        model_pdb = model_root + 'ranked_0.pdb'
        if os.path.isfile(model_pdb):
            save_to_pkl_sse_bfactor(mpath_stride + model_dir, model_pdb, model_root)
    return

In [21]:
def get_plddt_from_AF(mpath_bio):
    """Compute, per PDB model, the fraction of atoms with B-factor below 50.

    In AlphaFold models the B-factor column presumably stores pLDDT, so this
    is the fraction of low-confidence atoms.

    Parameters
    ----------
    mpath_bio : str
        Directory (ending in '/'). Every file ending in '.pdb' is parsed with
        ``prody.parsePDB``, and each file name is printed.

    Returns
    -------
    (dict, dict)
        ``{file_name: low_confidence_fraction}`` and
        ``{file_name: atom_count}``.
    """
    low_fraction = {}
    atom_counts = {}
    pdb_files = [name for name in os.listdir(mpath_bio) if name.endswith('.pdb')]
    for pdb_name in pdb_files:
        print(pdb_name)
        betas = pro.parsePDB(mpath_bio + pdb_name).getBetas()
        low = betas[betas < 50]
        atom_counts[pdb_name] = len(betas)
        low_fraction[pdb_name] = len(low) / len(betas)
    return low_fraction, atom_counts

In [33]:
def save_structural_info(mpath, mpath_bio, dict_isoforms):
    """Compare all pairs of isoform models per gene (TM-align + global alignment).

    Parameters
    ----------
    mpath : str
        Output root (ending in '/'). Results go to ``<mpath>info/<gene>/``.
    mpath_bio : str
        Directory (ending in '/') holding ``<isoform>_ranked_0.pdb`` models.
    dict_isoforms : dict
        Maps each gene name to an indexable sequence of isoform names (model
        file stems without the ``_ranked_0.pdb`` suffix).

    Side effects
    ------------
    For every pair (i < j) whose two model files exist, this writes
    ``<mpath>info/<gene>/<name1>#<name2>.pickle``. That dict holds:

    - the TM-align ``t``/``u`` results;
    - ``tm_norm_chain1``/``tm_norm_chain2``;
    - both sequences;
    - ``nw_1``/``nw_2`` (global-alignment ``globalxx`` score divided by each
      sequence length).

    Afterwards it writes ``<mpath>info/<gene>/<gene>.pickle``, mapping
    ``(name1, name2) -> [(tm1, tm2), (nw1, nw2)]``.
    """

    def _load(cache, name):
        # Parse each model at most once per gene. The previous version
        # re-parsed model i for every partner j. Assumes tm_align does not
        # modify its input arrays — NOTE(review): confirm.
        if name not in cache:
            cache[name] = myget_structure(mpath_bio, name)
        return cache[name]

    for g, infog in dict_isoforms.items():
        pkl_path = mpath + 'info/' + g + '/'
        d_res = {}
        Path(pkl_path).mkdir(parents=True, exist_ok=True)
        structures = {}
        n_isoforms = len(infog)

        for i in range(n_isoforms):
            if not Path(mpath_bio + infog[i] + '_ranked_0.pdb').is_file():
                continue
            name1 = infog[i] + '_ranked_0'
            for j in range(i + 1, n_isoforms):
                name2 = infog[j] + '_ranked_0'
                if not Path(mpath_bio + infog[j] + '_ranked_0.pdb').is_file():
                    continue
                if name1 == name2:  # duplicated isoform entry
                    continue
                # Local names avoid shadowing Bio.SeqUtils.seq1 imported above.
                coords_a, seq_a = _load(structures, name1)
                coords_b, seq_b = _load(structures, name2)
                res2 = tm_align(coords_a, coords_b, seq_a, seq_b)
                alignments = pairwise2.align.globalxx(seq_a, seq_b)
                nw_1 = alignments[0].score / len(seq_a)
                nw_2 = alignments[0].score / len(seq_b)
                res_dict = {
                    't': res2.t,
                    'u': res2.u,
                    'tm_norm_chain1': res2.tm_norm_chain1,
                    'tm_norm_chain2': res2.tm_norm_chain2,
                    'seq1': seq_a,
                    'seq2': seq_b,
                    'nw_1': nw_1,
                    'nw_2': nw_2,
                }
                d_res[(name1, name2)] = [(res2.tm_norm_chain1, res2.tm_norm_chain2), (nw_1, nw_2)]
                with open(pkl_path + name1 + '#' + name2 + '.pickle', 'wb') as handle:
                    pkl.dump(res_dict, handle)

        with open(pkl_path + g + '.pickle', 'wb') as handle2:
            pkl.dump(d_res, handle2)

In [27]:
def structure_filter_isoforms(mpath_bio, mpath=None, genes=None, max_low_plddt_fraction=0.4):
    """Select isoform pairs whose models both have few low-confidence atoms.

    Parameters
    ----------
    mpath_bio : str
        Directory of ``*.pdb`` models, passed to ``get_plddt_from_AF``.
    mpath : str, optional
        Root used by ``save_structural_info`` (``<mpath>info/<gene>/``).
        The previous version read an undefined notebook global of this name.
        If omitted, that global is still used when it exists; otherwise
        ValueError is raised.
    genes : iterable of str, optional
        Genes to inspect. If omitted, the notebook global ``genes`` is used
        when present; otherwise the sub-directories of ``<mpath>info/`` are
        used (sorted).
    max_low_plddt_fraction : float, default 0.4
        A pair is kept when the larger of its two low-confidence fractions
        (from ``get_plddt_from_AF``) is ``<=`` this value.

    Returns
    -------
    list of tuple
        ``(scores, gene, (name1, name2), worst_fraction)``, where ``scores``
        is the value stored by ``save_structural_info`` for that pair.

    Pairs whose pLDDT fraction is unknown for either model are skipped. The
    previous version reused a stale value from an earlier pair, or crashed.
    The pickles read here are the ones this pipeline writes; do not point it
    at untrusted files.
    """
    namespace = globals()
    if mpath is None:
        mpath = namespace.get('mpath')
        if mpath is None:
            raise ValueError('structure_filter_isoforms: pass `mpath` '
                             '(or define it in the notebook namespace)')
    if genes is None:
        genes = namespace.get('genes')
    if genes is None:
        info_dir = mpath + 'info/'
        genes = sorted(entry for entry in os.listdir(info_dir)
                       if os.path.isdir(info_dir + entry))

    res_plddt2, _atom_len = get_plddt_from_AF(mpath_bio)
    arr_sav = []
    for g in genes:
        summary_file = mpath + 'info/' + g + '/' + g + '.pickle'
        if not Path(summary_file).is_file():
            continue
        with open(summary_file, 'rb') as handle:
            pairs = pkl.load(handle)
        for k, v in pairs.items():
            pl1 = res_plddt2.get(k[0] + '.pdb')
            pl2 = res_plddt2.get(k[1] + '.pdb')
            if pl1 is None or pl2 is None:
                continue
            plddt = max(pl1, pl2)
            if plddt <= max_low_plddt_fraction:
                arr_sav.append((v, g, k, plddt))
    return arr_sav


In [31]:
def aggregate_info_switches(mpath, mpstride, mpath_bio):
    """Compute SSE switches for every isoform pair kept by ``structure_filter_isoforms``.

    Parameters
    ----------
    mpath : str
        Unused here; kept for interface compatibility.
    mpstride : str
        Directory prefix of the STRIDE files (see ``get_sse_stride_alignment``).
    mpath_bio : str
        Model directory passed to ``structure_filter_isoforms``.

    Returns
    -------
    (list, dict)
        The alignment/switch dicts of pairs with at least one switch (each
        augmented with ``'scores'``), and the same dicts keyed by
        ``(name1, name2)``.

    For each switch, the dominant SSE letter of each side is combined into a
    sorted two-letter code. The code is printed and tallied per gene.
    Codes other than 'AC', 'CS' and 'AS' (e.g. 'CC', or ones involving 'P'
    from PiHelix) are now tallied instead of raising KeyError.
    NOTE(review): the tallies (``dictres``/``dictresnum``) are not returned.
    """
    candidate_pairs = structure_filter_isoforms(mpath_bio)
    ress = []
    dict_pkl = {}
    dictres = {'AC': set(), 'CS': set(), 'AS': set()}
    dictresnum = {'AC': 0, 'CS': 0, 'AS': 0}

    for k in candidate_pairs:
        m1 = k[2][0] + '.pdb'
        m2 = k[2][1] + '.pdb'
        s = get_sse_stride_alignment(m1, m2, mpstride)
        s['scores'] = k
        gene = k[1]

        if not s['switch']:
            continue

        for sw in s['switch']:
            sw1 = sw[2][1]
            sw2 = sw[2][2]
            a1 = collections.Counter(sw1).most_common(1)[0][0]
            a2 = collections.Counter(sw2).most_common(1)[0][0]
            scode = ''.join(sorted(a1 + a2))
            print('code')
            print(scode)
            dictres.setdefault(scode, set()).add(gene)
            dictresnum[scode] = dictresnum.get(scode, 0) + 1

        ress.append(s)
        dict_pkl[k[2]] = s
    return ress, dict_pkl
    


In [32]:
def process_SSE(mpdb='/Users/hadarovi/ddcode/prediction/proteomes_calc/human/sse/beta_ex/pdb/',
                mstride='/Users/hadarovi/ddcode/prediction/proteomes_calc/human/sse/beta_ex/stride/',
                mpath=None, mpath_bio=None, dict_isoforms=None, mpstride=None):
    """Run the full SSE-switch pipeline.

    The steps are: STRIDE on every model, then pairwise structural
    comparison, then switch aggregation.

    Parameters
    ----------
    mpdb, mstride : str
        Model input directory and STRIDE output directory for
        ``produce_stride_output``. The defaults are the original hard-coded
        local paths. NOTE(review): replace them with a configurable data
        directory.
    mpath, mpath_bio, dict_isoforms : optional
        Inputs for ``save_structural_info``. The previous version used
        undefined names here. If omitted, notebook globals of the same names
        are used; ValueError is raised when they are unavailable. Validation
        happens before any work starts.
    mpstride : str, optional
        STRIDE directory for ``aggregate_info_switches``. It falls back to a
        global ``mpstride``, then to ``mstride``. That directory is presumably
        the one the original meant — NOTE(review): confirm.

    Returns
    -------
    The ``(ress, dict_pkl)`` result of ``aggregate_info_switches``; the
    previous version discarded it.
    """
    namespace = globals()
    if mpath is None:
        mpath = namespace.get('mpath')
    if mpath_bio is None:
        mpath_bio = namespace.get('mpath_bio')
    if dict_isoforms is None:
        dict_isoforms = namespace.get('dict_isoforms')
    if mpstride is None:
        mpstride = namespace.get('mpstride', mstride)

    missing = [label for label, value in (('mpath', mpath),
                                          ('mpath_bio', mpath_bio),
                                          ('dict_isoforms', dict_isoforms))
               if value is None]
    if missing:
        raise ValueError('process_SSE: missing required input(s): ' + ', '.join(missing)
                         + ' (pass them as arguments or define them in the notebook)')

    produce_stride_output(mpdb, mstride)
    # dict_isoforms maps each gene to the names of its isoform models.
    save_structural_info(mpath, mpath_bio, dict_isoforms)
    return aggregate_info_switches(mpath, mpstride, mpath_bio)