In [123]:
import subprocess
import os
import glob
import toytree
import tqdm
import pandas as pd
from subprocess import PIPE, Popen
import shlex
from Bio import SeqIO
import random
import itertools
import numpy as np

def Fident(str1,str2 , verbose = False):
    """Return the fraction of columns where both aligned strings agree.

    A column counts as identical when the two characters are equal and are
    not the gap character ``'-'``.  The count is divided by the full length
    of ``str1`` (gap columns included in the denominator).

    Parameters
    ----------
    str1, str2 : str
        Two aligned sequences of equal length.
    verbose : bool
        Accepted for interface compatibility; not used.

    Returns
    -------
    float
        Identical non-gap columns divided by ``len(str1)``; ``0.0`` when the
        inputs are empty.

    Raises
    ------
    ValueError
        If the two strings differ in length.
    """
    if len(str1) != len(str2):
        # Comparing arrays of different shapes is NumPy-version dependent
        # (broadcast, scalar result, or error), so reject it explicitly.
        raise ValueError(
            'Fident needs equally long aligned strings, got lengths '
            + str(len(str1)) + ' and ' + str(len(str2))
        )
    if len(str1) == 0:
        # Avoid a ZeroDivisionError on empty alignments.
        return 0.0
    # a == b together with a != '-' already implies b != '-'.
    identical = sum(1 for a, b in zip(str1, str2) if a == b and a != '-')
    return identical / len(str1)

def copyaln( aln, seq):
    """Thread ``seq`` onto the gap pattern of ``aln``.

    Every ``'-'`` in ``aln`` is kept as a gap; every other column is filled
    with the next unused character of ``seq``.  Characters of ``seq`` left
    over once ``aln`` is exhausted are ignored; running out of ``seq`` early
    raises ``StopIteration``.
    """
    residues = iter(seq)
    threaded = ['-' if column == '-' else next(residues) for column in aln]
    return ''.join(threaded)

def read_dbfiles3di(  AADB , threeDidb):
    """Read a sequence database and its 3Di counterpart into id-keyed dicts.

    Three files are read: ``threeDidb``, ``AADB + '.lookup'`` and ``AADB``.
    Each line of the two sequence files is stripped and has NUL characters
    removed.  The identifier of an entry is the second whitespace-separated
    field of the corresponding ``.lookup`` line; identifiers and sequences
    are paired by line position.

    Parameters
    ----------
    AADB : str
        Path of the amino-acid sequence file (its ``.lookup`` sits next to it).
    threeDidb : str
        Path of the 3Di sequence file.

    Returns
    -------
    tuple of (dict, dict)
        ``(mapper3di, mapperAA)``: identifier -> 3Di string and
        identifier -> amino-acid string.
    """
    def _read_sequences(path):
        # Context manager so the handle is closed even if parsing fails.
        with open(path) as handle:
            return [line.strip().replace('\x00', '') for line in handle]

    threeDiseq = _read_sequences(threeDidb)
    with open(AADB + '.lookup') as handle:
        ids = [line.split()[1].strip() for line in handle]
    AAs = _read_sequences(AADB)

    # zip() stops at the shortest input, so surplus sequence lines are ignored.
    mapper3di = dict(zip(ids, threeDiseq))
    mapperAA = dict(zip(ids, AAs))

    return mapper3di, mapperAA

def calc_fident_crossaln(row , verbose = False):
    """Project the 3Di strings of query and target onto a row's alignment.

    ``row`` must expose ``qaln``, ``taln``, ``qstart``, ``qend``, ``tstart``
    and ``tend`` as attributes and ``'3diq'``, ``'3dit'``, ``'AAq'`` and
    ``'AAt'`` as keys.  The start positions are treated as 1-based and the
    end positions as inclusive.

    Returns a ``pd.Series`` with the keys ``'3di_qaln_mode2'`` and
    ``'3di_taln_mode2'`` holding the gapped 3Di strings.  ``verbose`` is
    accepted but not used.
    """
    # gapped alignment strings
    query_aln = row.qaln
    target_aln = row.taln

    # aligned region on each sequence (1-based start)
    q_from, q_to = row.qstart, row.qend
    t_from, t_to = row.tstart, row.tend

    # full-length 3Di strings
    query_3di = row['3diq']
    target_3di = row['3dit']

    # Looked up but not used below; kept so that a row lacking these keys
    # fails exactly as before.
    _aa_query, _aa_target = row['AAq'], row['AAt']

    # cut out the aligned 3Di region and re-insert the alignment's gaps
    target_gapped = copyaln(target_aln, target_3di[t_from - 1:t_to])
    query_gapped = copyaln(query_aln, query_3di[q_from - 1:q_to])

    return pd.Series({
        '3di_qaln_mode2': query_gapped,
        '3di_taln_mode2': target_gapped,
    })

def get_leafset( treenode ):
    """
    Return the leaf names below ``treenode``.

    For a leaf this is a one-element list holding the node's own name;
    otherwise it is whatever ``treenode.get_leaf_names()`` returns.
    """
    return [treenode.name] if treenode.is_leaf() else treenode.get_leaf_names()


def sub2fasta( sub, outfile , fastacol1='qaln' , fastacol2='taln' ):
    """Write one pairwise alignment as a two-record fasta file.

    The first record is headed by ``sub['query']`` with the sequence
    ``sub[fastacol1]``; the second by ``sub['target']`` with
    ``sub[fastacol2]``.  ``outfile`` is overwritten and its path returned.
    """
    layout = (('query', fastacol1), ('target', fastacol2))
    with open(outfile, 'w') as handle:
        for header_key, sequence_key in layout:
            handle.write('>' + sub[header_key] + '\n')
            handle.write(sub[sequence_key] + '\n')
    return outfile

def retalns(allvall, leafname1,leafname2):
    """Pick the longest alignment linking two groups of sequences.

    Rows of ``allvall`` are kept when ``query`` is in ``leafname1``,
    ``target`` is in ``leafname2`` and query differs from target.  Among
    those, the first row with the largest span, defined as
    ``max(qend - qstart, tend - tstart)``, is returned.

    Parameters
    ----------
    allvall : pandas.DataFrame
        Needs the columns ``query``, ``target``, ``qstart``, ``qend``,
        ``tstart`` and ``tend``.  It is not modified.
    leafname1, leafname2 : list-like or set of str
        Allowed query and target identifiers.

    Returns
    -------
    pandas.Series
        The selected row.  Its ``alnlen`` entry holds the span computed
        here, replacing any ``alnlen`` value present in ``allvall``.

    Raises
    ------
    Exception
        ``'no sub'`` when no row links the two groups.
    """
    wanted = (
        allvall['query'].isin(leafname1)
        & allvall['target'].isin(leafname2)
        & (allvall['query'] != allvall['target'])
    )
    # Explicit copy: the column assignment below must not touch (or warn
    # about touching) the caller's frame.
    sub = allvall[wanted].copy()

    # Check before computing spans so an empty selection always yields the
    # intended error instead of whatever the span computation would raise.
    if len(sub)==0:
        print(leafname1, leafname2)
        raise Exception('no sub')

    #get max prot length aligned
    sub['alnlen'] = np.maximum(sub['qend'] - sub['qstart'], sub['tend'] - sub['tstart'])
    longest = sub[sub['alnlen'] == sub['alnlen'].max()]
    return longest.iloc[0]
def get_fasta_leafset(fasta):
    """
    Return the record ids of a fasta file as a list, in file order.

    Duplicated ids are kept as duplicates.
    """
    return [record.id for record in SeqIO.parse(fasta, 'fasta')]

#traverse tree from root to leaves recursively
def traverse_tree_merge( treenode, topleafset, allvall , alnfolder , verbose = False):
    """
    Recursively build merged AA and 3Di alignments for the subtree at ``treenode``.

    Three cases are handled:

    * leaf: its name is removed from ``topleafset`` and the best pairwise
      alignment (via ``retalns``) against the remaining names is written;
    * cherry (exactly two leaf children): the pairwise alignment of the two
      leaves is written;
    * anything else: each child is aligned recursively (unless ``child.aln``
      is already set), a "bridge" pairwise alignment is written for every
      pair of children, and the pieces are combined with ``mergealns``.

    Parameters
    ----------
    treenode : tree node
        Must provide ``name``, ``is_leaf()``, ``get_children()``,
        ``get_leaf_names()``, ``up`` and writable ``aln``, ``aln3di`` and
        ``leafset`` attributes.
    topleafset : list of str
        Names a leaf may be aligned against.  Modified in place for leaves.
    allvall : pandas.DataFrame
        All-vs-all alignment table consumed by ``retalns`` / ``sub2fasta``.
    alnfolder : str
        Output directory prefix; file names are appended directly, so it
        needs a trailing separator.
    verbose : bool
        Print extra tracing.

    Returns
    -------
    tuple of (str, str)
        Paths of the AA and the 3Di alignment of this node, also stored on
        ``treenode.aln`` and ``treenode.aln3di``.

    Raises
    ------
    Exception
        ``'merge error 1'`` / ``'merge error 2'`` when a merge step fails;
        the underlying error is chained as ``__cause__``.

    Notes
    -----
    Bridge files are named ``<node>_<child1>_<child2>_bridge[.3di].fasta`` so
    that each pair of children gets its own file.  With one shared name per
    node, every pair overwrote the previous bridge and all entries of the
    bridge dictionaries pointed at the content of the last pair.
    """
    if verbose:
        print('traverse', treenode.name , treenode.is_leaf() , treenode.leafset)

    if treenode.is_leaf():
        # NOTE(review): this mutates the list object passed by the caller
        # (the parent's leafset in recursive calls), so later sibling leaves
        # can no longer pick this leaf as partner -- confirm that is intended.
        topleafset.remove(treenode.name)
        #if the node is a leaf, then we need to add it to the alignment with one of the pivots in the current leafset
        sub = retalns(allvall, [treenode.name] , topleafset)
        treenode.aln = sub2fasta(sub, alnfolder + treenode.name + '_inter.fasta')
        treenode.aln3di = sub2fasta(sub, alnfolder + treenode.name + '_inter.3di.fasta' , fastacol1='3di_qaln_mode2' , fastacol2='3di_taln_mode2')
        return treenode.aln, treenode.aln3di

    childalns3di = {}
    childalnsAA = {}
    bridges3di = {}
    bridgesAA = {}
    treenode.leafset = get_leafset(treenode)
    children = treenode.get_children()

    if len(children) == 2 and children[0].is_leaf() and children[1].is_leaf():
        #treat the case of a cherry
        print('cherry', children[0].name , children[1].name)
        # one lookup serves both the AA and the 3Di file
        cherry = retalns(allvall, [children[0].name] , [children[1].name])
        treenode.aln = sub2fasta( cherry , alnfolder + treenode.name + '_inter.fasta')
        treenode.aln3di = sub2fasta( cherry , alnfolder + treenode.name + '_inter.3di.fasta' , fastacol1='3di_qaln_mode2' , fastacol2='3di_taln_mode2')
        return treenode.aln, treenode.aln3di

    #not a cherry. one or both sides is a subtree
    print('not cherry', treenode.name  )
    print( 'children', [c.name for c in children])
    for c in children:
        #make sub aln for each child
        if verbose:
            print('traverse', c.name , c.is_leaf() , c.leafset)
        if not c.aln:
            c.aln,c.aln3di = traverse_tree_merge(c , treenode.leafset , allvall, alnfolder , verbose = verbose)
        childalnsAA[c] = { 'fasta': c.aln , 'protset':set(get_fasta_leafset(c.aln) ) }
        childalns3di[c] = { 'fasta': c.aln3di , 'protset':set(get_fasta_leafset(c.aln3di) ) }

    pairs = list(itertools.combinations(children, 2))

    for c1,c2 in pairs:
        bridge = retalns(allvall, childalnsAA[c1]['protset'] , childalnsAA[c2]['protset'] )
        # per-pair file name, see Notes in the docstring
        bridgestem = alnfolder + treenode.name + '_' + c1.name + '_' + c2.name
        bridgesAA[(c1,c2)] = { 'fasta': sub2fasta(bridge, bridgestem + '_bridge.fasta') , 'protset':set([bridge.query , bridge.target]) }
        bridges3di[(c1,c2)] = { 'fasta' : sub2fasta(bridge, bridgestem + '_bridge.3di.fasta' , fastacol1='3di_qaln_mode2' , fastacol2='3di_taln_mode2') , 'protset':set([bridge.query, bridge.target]) }

    #successively merge the alignments of the children
    for i, (c1, c2) in enumerate(pairs):
        if verbose:
            print('merge', c1.name , c2.name)
        if i == 0:
            #first merge
            try:
                print('first merge')

                alnAA = mergealns( childalnsAA[c1]['fasta'], bridgesAA[(c1,c2)]['fasta'] ,alnfolder + treenode.name + '_inter.fasta' , verbose=verbose)
                aln3di = mergealns( childalns3di[c1]['fasta'], bridges3di[(c1,c2)]['fasta'] ,alnfolder + treenode.name + '_inter3di.fasta', verbose=verbose)

                print('2 merge')

                # NOTE(review): this second step merges child c1 again; c2's
                # alignment is never merged in this branch.  Left as found --
                # confirm whether c2 was intended here.
                alnAA = mergealns( childalnsAA[c1]['fasta'], alnAA , alnfolder + treenode.name + '_inter.fasta' , verbose=verbose)
                aln3di = mergealns( childalns3di[c1]['fasta'], aln3di ,alnfolder + treenode.name + '_inter3di.fasta', verbose=verbose)

            except Exception as err:
                print( treenode , childalnsAA , childalns3di , bridgesAA , bridges3di)
                raise Exception('merge error 1') from err
        else:
            try:
                print('3 merge')

                # NOTE(review): the AA branch merges c2 then c1 while the 3Di
                # branch below merges c1 then c2, and the output path equals
                # one of the inputs.  Left as found -- confirm intent.
                alnAA = mergealns( childalnsAA[c2]['fasta'], bridgesAA[(c1,c2)]['fasta'] , alnAA , verbose=verbose)
                alnAA = mergealns( childalnsAA[c1]['fasta'], alnAA , alnAA , verbose=verbose)

                print('4 merge')

                aln3di = mergealns( childalns3di[c1]['fasta'], bridges3di[(c1,c2)]['fasta'] , aln3di  , verbose=verbose)
                aln3di = mergealns( childalns3di[c2]['fasta'], aln3di , aln3di , verbose=verbose)

            except Exception as err:
                print( treenode )
                print( childalnsAA , childalns3di , bridgesAA , bridges3di)
                raise Exception('merge error 2') from err
    treenode.aln = alnAA
    treenode.aln3di = aln3di
    if verbose:
        #check if node is root
        if treenode.up is None:
            print('final aln')
            print('childalnsAA', childalnsAA)
            print('childalns3di', childalns3di)
    return treenode.aln, treenode.aln3di

def remove_redundant( alignment ):
    """
    Rewrite a fasta file in place, keeping only the first record of each id.

    Record order is preserved.  Every kept record is written as a ``>id``
    header line followed by its sequence on one line.  Returns the path.
    """
    first_by_id = {}
    for record in SeqIO.parse(alignment, 'fasta'):
        # setdefault keeps the earliest record seen for an id
        first_by_id.setdefault(record.id, record)

    with open(alignment, 'w') as handle:
        for record in first_by_id.values():
            handle.write('>' + record.id + '\n')
            handle.write(str(record.seq) + '\n')
    return alignment

#remove all alns except the final merged one
def cleanup( filedir ):
    """
    Delete every file matching ``filedir + '*inter.fasta'``.

    ``filedir`` is used as a plain string prefix of the glob pattern.
    """
    pattern = filedir + '*inter.fasta'
    leftovers = glob.glob(pattern)
    for path in leftovers:
        os.remove(path)

In [124]:

def aln_mapping( s1 , s2, maxaln = 0 , coidx1= 0 , coidx2 = 0 , start1 = 0 , start2=0 ,  verbose = False):
    """Map positions of two strings onto each other after offsetting them.

    The offset comes from ``convolve_strings(s1, s2)``.  Both strings are cut
    at their start offsets and walked in lockstep; for each position where
    the characters agree an entry is added to four dictionaries.  At the
    first disagreement the function recurses on the remaining suffixes,
    merges the returned dictionaries into its own and stops walking.

    The ``maxaln``, ``coidx1`` and ``coidx2`` arguments are overwritten
    before they are read, so the values passed in have no effect.
    ``start1`` / ``start2`` are added to the values stored in the maps.

    Returns ``(maps1, maps2, oppositemap1, oppositemap2)``.

    NOTE(review): several ``print`` calls below run regardless of
    ``verbose``.
    NOTE(review): if the recursive call computes a zero offset and its first
    characters still differ, it appears to recurse on identical arguments
    (e.g. two different one-character strings) -- confirm termination.
    """
    #build a dictionary of the positions of the characters in the string
    #convolve the strings
    maxaln, maxcount = convolve_strings(s1,s2)

    if start1 != 0 or start2 != 0:
        print('start', start1, start2)
    if verbose == True:
        print('maxaln', maxaln, maxcount)
    
    #find starting points
    # maxaln is the offset of the shorter string relative to the longer one,
    # so which string has to skip characters depends on both the sign of
    # maxaln and on which of s1 / s2 is the shorter
    if len(s1) < len(s2):
        if maxaln < 0:
            coidx1 = np.abs(maxaln)
            coidx2 = 0
        else:
            coidx1 = 0
            coidx2 = maxaln
    else:
        if maxaln < 0:
            coidx1 = 0
            coidx2 = np.abs(maxaln)
        else:
            coidx1 = maxaln
            coidx2 = 0
    
    substr1 = s1[coidx1:]
    substr2 = s2[coidx2:]

    print('substr1', substr1)
    print('substr2', substr2)
    print('coidx1', coidx1)
    print('coidx2', coidx2)
    
    #find equivalent positions in the strings
    maps1 = {}
    maps2 = {}

    oppositemap1 = {}
    oppositemap2 = {}

    # NOTE(review): substr2[i] raises IndexError when substr1 is longer than
    # substr2 and no mismatch occurs before substr2 ends.
    for i, char in enumerate(substr1):
        if substr2[i] == char:
            maps1[i+coidx1] = i + start1
            maps2[i+coidx2] = i + start2
            
            oppositemap1[i+coidx1] = i + start2
            oppositemap2[i+coidx2] = i + start1
        else:
            #if there is mismatch convolve the remaining strings
            print('mismatch')
            # the recursive call always runs with verbose = False
            sub1,sub2 , om1 , om2 = aln_mapping( substr1[i:] , substr2[i:], maxaln = 0 , start1= start1+coidx1+i , start2 = start2+coidx2+i , verbose = False)
            maps1.update( sub1 )
            maps2.update( sub2 )
            oppositemap1.update( om1 )
            oppositemap2.update( om2 )
            break
    
    return maps1, maps2 , oppositemap1, oppositemap2

def aln_mapping_full( s1 , s2, maxaln = 0 , coidx1= 0 , coidx2 = 0 , verbose = False):
    """Run ``aln_mapping`` and also return the inverse of each mapping.

    Parameters
    ----------
    s1, s2 : str
        Strings to map onto each other.
    maxaln, coidx1, coidx2 : int
        Forwarded to ``aln_mapping``.
    verbose : bool
        Forwarded to ``aln_mapping``.  The hard-coded ``verbose = False``
        that used to be passed made this argument a no-op.

    Returns
    -------
    tuple
        ``(maps1, maps2, om1, om2, revmap1, revmap2, revmapom1, revmapom2)``
        where each ``rev*`` dictionary swaps the keys and values of its
        forward counterpart.  When several keys share a value, the inverse
        keeps the key met last.
    """
    maps1, maps2 , om1 , om2 = aln_mapping( s1 , s2, maxaln = maxaln , coidx1= coidx1 , coidx2 = coidx2 , verbose = verbose)
    #add the reverse mapping
    revmap1 = { v:k for k,v in maps1.items()}
    revmap2 = { v:k for k,v in maps2.items()}

    revmapom1 = { v:k for k,v in om1.items()}
    revmapom2 = { v:k for k,v in om2.items()}

    return maps1, maps2, om1, om2 , revmap1, revmap2 , revmapom1, revmapom2

In [158]:


def convolve_strings(str1, str2):
    """Find the offset at which two strings share the most equal characters.

    The shorter string is slid along the longer one (``str1`` is treated as
    the longer one when lengths are equal), from one character of overlap at
    the left end to one character of overlap at the right end.  Returns
    ``(offset, matches)``: the offset of the shorter string relative to the
    longer one and the number of equal characters there.  Ties go to the
    smallest offset; ``(0, 0)`` is returned when nothing matches.
    """
    if len(str1) < len(str2):
        longer, shorter = str2, str1
    else:
        longer, shorter = str1, str2
    n_long, n_short = len(longer), len(shorter)

    best_offset, best_matches = 0, 0
    for offset in range(1 - n_short, n_long):
        # indices of the shorter string that overlap the longer one
        first = max(0, -offset)
        stop = min(n_short, n_long - offset)
        matches = sum(
            1 for k in range(first, stop) if longer[offset + k] == shorter[k]
        )
        if matches > best_matches:
            best_offset, best_matches = offset, matches
    return best_offset, best_matches


def _skip_leading_residues(raw_iter, n_residues):
    """Advance ``raw_iter`` past ``n_residues`` non-gap characters.

    Returns the number of characters (gaps included) that were consumed.
    """
    consumed = 0
    while n_residues > 0:
        if next(raw_iter) != '-':
            n_residues -= 1
        consumed += 1
    return consumed


def alnchop(s1,s2,rawaln1,rawaln2,aln1,aln2,maxaln = 0):
    """Trim the start of two alignments so their pivot sequences line up.

    ``maxaln`` is the offset of the shorter of ``s1`` / ``s2`` relative to
    the longer one.  Depending on its sign and on which string is shorter,
    ``abs(maxaln)`` leading residues are dropped from one of the two sides;
    gap columns met while skipping those residues are dropped as well.

    Parameters
    ----------
    s1, s2 : str
        Ungapped pivot sequences; only their lengths are used.
    rawaln1, rawaln2 : str
        Gapped pivot rows of the two alignments.
    aln1, aln2 : sequence
        Per-column data of the two alignments; sliced, never modified.
    maxaln : int
        Offset between the two pivot sequences.

    Returns
    -------
    tuple
        ``(aln1, aln2, rawaln1, rawaln2)`` with the leading columns removed;
        the raw rows are returned as strings.

    Raises
    ------
    StopIteration
        If a raw row holds fewer residues than have to be skipped.
    """
    #align the two sequences
    if len(s1) < len(s2):
        if maxaln < 0:
            coidx1, coidx2 = abs(maxaln), 0
        else:
            coidx1, coidx2 = 0, maxaln
    else:
        if maxaln < 0:
            coidx1, coidx2 = 0, abs(maxaln)
        else:
            coidx1, coidx2 = maxaln, 0

    rawaln1 = iter(rawaln1)
    aln1 = aln1[_skip_leading_residues(rawaln1, coidx1):]

    # The second side must count down coidx2.  Testing coidx1 here meant the
    # loop never ran and alignment 2 was returned untrimmed.
    rawaln2 = iter(rawaln2)
    aln2 = aln2[_skip_leading_residues(rawaln2, coidx2):]

    # whatever is left on the iterators is the trimmed raw row
    rawaln1 = ''.join(rawaln1)
    rawaln2 = ''.join(rawaln2)

    return aln1, aln2, rawaln1, rawaln2


def mergealns( aln1f, aln2f, outfile , verbose = False):
    """Merge two fasta alignments that share at least one sequence id.

    If both files hold the same set of ids, ``aln1f`` is returned untouched.
    Otherwise a pivot id is chosen among the shared ids (the one whose
    ungapped sequences give the highest ``convolve_strings`` count), the two
    alignments are walked column by column along the pivot rows, and the
    combined rows are written to ``outfile``.  ``remove_redundant`` then
    drops repeated ids from that file, the file is printed, and its path is
    returned.

    Parameters
    ----------
    aln1f, aln2f : str
        Paths of the two fasta alignments.
    outfile : str
        Path the merged alignment is written to.
    verbose : bool
        Not read inside this function; the diagnostic prints below are
        unconditional.

    Raises
    ------
    Exception
        ``'no common ids'`` when the two files share no id.

    NOTE(review): the walk below looks inconsistent in several places (see
    the inline notes); treat the merged output with care until they have
    been checked.
    """
    if set(get_fasta_leafset(aln1f)) == set(get_fasta_leafset(aln2f)):
        print('identical')
        return aln1f

    #find sequences in common between the two alignments
    aln1 = SeqIO.parse(aln1f, 'fasta')
    aln2 = SeqIO.parse(aln2f, 'fasta')
    ids1 = {s.id:str(s.seq) for s in aln1}
    ids2 = {s.id:str(s.seq) for s in aln2}
    aln1 = SeqIO.parse(aln1f, 'fasta')
    aln2 = SeqIO.parse(aln2f, 'fasta')
    # row labels for the output: ids of file 1 followed by ids of file 2
    idlist = [ s.id for s in aln1] + [ s.id for s in aln2]
    commonids = set(ids1.keys()).intersection(set(ids2.keys()))
    try:
        assert len(commonids) > 0
    except:
        print('no common ids')
        print('ids1', ids1)
        print('ids2', ids2)
        raise Exception('no common ids')
    #transform both alignments into numpy matrices
    aln1 = SeqIO.parse(aln1f, 'fasta')
    aln2 = SeqIO.parse(aln2f, 'fasta')
    aln1 = np.array([ list(str(s.seq)) for s in aln1])
    aln2 = np.array([ list(str(s.seq)) for s in aln2])
    nrows1 = aln1.shape[0]
    nrows2 = aln2.shape[0]

    #generate a list of column arrays
    aln1 = [ aln1[:,i] for i in range(aln1.shape[1])]
    aln2 = [ aln2[:,i] for i in range(aln2.shape[1])]
    #find the best common sequence
    maxconv = 0
    maxaln = 0

    print(ids1)
    print(ids2)
    
    # NOTE(review): ID, s1 and s2 are only bound when some shared id has a
    # count above 0; otherwise the lookup of ID after this loop fails with a
    # NameError.
    for commonid in commonids:
        s1t = ids1[commonid]
        s2t = ids2[commonid]
        s1t = s1t.replace('-','')
        s2t = s2t.replace('-','')
        #if the common subsequence is not found start by removing the first character of the common sequence
        #convolution of the two sequences
        aln, count = convolve_strings(s1t,s2t)
        if count > maxconv:
            maxconv = count
            maxaln = aln
            ID = commonid
            s1 = s1t
            s2 = s2t

    rawaln1 = ids1[ID]
    rawaln2 = ids2[ID]
    print('pivot' , ID)
    print('s1',s1)
    print('s2',s2)

    print('rawaln1', rawaln1)
    print('rawaln2', rawaln2)
    
    print('maxaln', maxaln)
    #use the sequence convolution to align the two alignment arrays
    # NOTE(review): the coidx1 / coidx2 values computed here are only printed;
    # nothing trims the alignments with them before the walk starts.
    if len(s1) < len(s2):
        if maxaln < 0:
            coidx1 = np.abs(maxaln)
            coidx2 = 0
        else:
            coidx1 = 0
            coidx2 = maxaln
    else:
        if maxaln < 0:
            coidx1 = 0
            coidx2 = np.abs(maxaln)
        else:
            coidx1 = maxaln
            coidx2 = 0
    print('coidx1', coidx1)
    print('coidx2', coidx2)


    #remove the leading gaps
    # NOTE(review): both tests compare the whole string with '-' rather than
    # the character at position i, so each loop breaks at i == 0 unless the
    # string is exactly '-'.  The second loop also reads rawaln1 although it
    # goes on to trim aln2 / rawaln2 -- presumably rawaln2[i] was meant.
    for i in range(len(rawaln1)):
        if rawaln1 != '-':
            break
    aln1 = aln1[i:]
    rawaln1 = rawaln1[i:]
    for i in range(len(rawaln1)):
        if rawaln1 != '-':
            break
    aln2 = aln2[i:]
    rawaln2 = rawaln2[i:]

    #remove the trailing gaps
    # NOTE(review): for i == 0 the index -i is 0, i.e. the FIRST character,
    # so the scan does not start at the last column -- presumably
    # rawaln[-(i + 1)] was meant.
    for i in range(len(rawaln1)):
        if rawaln1[-i] != '-':
            break
    aln1 = aln1[:len(rawaln1)-i]
    rawaln1 = rawaln1[:len(rawaln1)-i]
    for i in range(len(rawaln2)):
        if rawaln2[-i] != '-':
            break
    aln2 = aln2[:len(rawaln2)-i]
    rawaln2 = rawaln2[:len(rawaln2)-i]
    

    #construct alignment with common sequence

    rawaln1 = iter(rawaln1)
    rawaln2 = iter(rawaln2)    
    char1 = next(rawaln1)
    char2 = next(rawaln2)
    # NOTE(review): char1 / char2 hold the character of column 0 while the
    # column indices i / j start at 1, so the columns appended below are one
    # position ahead of the characters being compared -- confirm.
    i = 1
    j = 1 
    pchar1 = char1
    pchar2 = char2
    newaln1 = []
    newaln2 = []
    convolved = False
    while True:
        try:
            if pchar1 == '-' and char1 != '-':
                print('end insertion1')

            if pchar2 == '-' and char2 != '-':
                print('end insertion2')

            if char1 == '-' and char2 != '-':
                # pivot row of alignment 1 has a gap: keep the column of
                # alignment 1 and pad alignment 2 with an all-gap column
                print('insertion1')
                newaln2.append(['-']*nrows2)
                newaln1.append(aln1[i])
                pchar1 = char1
                char1 = next(rawaln1)
                i +=1

            elif char2 == '-' and char1 != '-':
                # mirror case: gap in the pivot row of alignment 2
                print('insertion2')
                newaln2.append(aln2[j])
                newaln1.append(['-']*nrows1)
                pchar2 = char2
                char2 = next(rawaln2)
                j +=1

            elif char1 == char2 and char1 != '-' and char2 != '-':
                # same residue on both sides: keep both columns.  The
                # iterators are advanced before the columns are appended, so
                # the last column is lost when either iterator runs out here.
                char1 = next(rawaln1)
                char2 = next(rawaln2)
                newaln2.append(aln2[j])
                newaln1.append(aln1[i])
                pchar1 = char1
                pchar2 = char2
                j+= 1
                i+= 1

            elif char1 != '-' and char2 != '-' and char1 != char2:
                convolved = True
                print('mismatch')
                #mismatch reconvolve remaining strings
                rawaln1 = ''.join([ s for s in iter(rawaln1)])
                rawaln2 = ''.join([ s for s in iter(rawaln2)])
                s1 = rawaln1.replace('-','')
                s2 = rawaln2.replace('-','')

                # NOTE(review): the freshly computed offset `aln` is not
                # used; the offsets below are derived from the earlier
                # `maxaln` -- presumably `aln` was meant.
                aln, count = convolve_strings(s1,s2)
                if len(s1) < len(s2):
                    if maxaln < 0:
                        coidx1 = np.abs(maxaln)
                        coidx2 = 0
                    else:
                        coidx1 = 0
                        coidx2 = maxaln
                else:
                    if maxaln < 0:
                        coidx1 = 0
                        coidx2 = np.abs(maxaln)
                    else:
                        coidx1 = maxaln
                        coidx2 = 0
                
                # skip coidx1 residues (and any gaps in between) on side 1
                rawaln1 = iter(rawaln1)
                discardcount1 = 0
                count1 = 0
                while count1 < coidx1:
                    discardcount1 += 1
                    char1 = next(rawaln1)
                    if char1 != '-':
                        count1 += 1
                aln1 = aln1[discardcount1:]
                # same for side 2
                rawaln2 = iter(rawaln2) 
                discardcount2 = 0
                count2 = 0
                while count2 < coidx2:
                    discardcount2 += 1
                    char2 = next(rawaln2)
                    if char2 != '-':
                        count2 += 1
                aln2 = aln2[discardcount2:]                
                i = 0 
                j = 0
                print('char1', char1)
                print('char2', char2)
                
            else:
                # NOTE(review): the only remaining case is a gap in BOTH
                # pivot rows at once, which ends the walk and drops every
                # later column -- confirm that this is intended.
                print('end')
                break
        
        except StopIteration:
            # one of the pivot rows is exhausted
            break
    
    # NOTE(review): if no column was collected, newaln1 / newaln2 are empty
    # lists here; np.vstack is expected to reject that -- confirm.
    newaln1 = np.vstack(newaln1).T
    newaln2 = np.vstack(newaln2).T
    newaln = np.concatenate((newaln1, newaln2), axis = 0)
    #write out the new alignment
    with open(outfile, 'w') as f:
        for i in range(newaln.shape[0]):
            #print('>' + idlist[i] + '\n' + ''.join(list(newaln[i,:])) + '\n')
            f.write('>' + idlist[i] + '\n')
            f.write(''.join(list(newaln[i,:])) + '\n')
    # ids present in both inputs were written twice; keep the first of each
    remove_redundant( outfile )
    with open(outfile) as out:
        print(out.read())
        
    return outfile  



In [159]:

import toytree
import os
import pandas as pd
import glob


# Load the all-vs-all alignment table (no header row in the file).
alndf = pd.read_table('../testdata/allvall_1.csv', header = None)

# Directory part of the table path, with a trailing slash.
infolder = '../testdata/allvall_1.csv'.rsplit('/', 1)[0] + '/'
mapper3di, mapperAA = read_dbfiles3di( "../testdata/outdb" , "../testdata/outdb_ss")

# Name the columns of the table.
columns = [
    'query', 'target', 'fident', 'alnlen', 'mismatch', 'gapopen',
    'qstart', 'qend', 'tstart', 'tend', 'evalue', 'bits', 'lddt',
    'qaln', 'taln', 'cigar', 'lntmscore',
]
alndf.columns = columns

# Attach the full 3Di and amino-acid strings of query and target to each row.
alndf['3diq'] = alndf['query'].map(mapper3di)
alndf['3dit'] = alndf['target'].map(mapper3di)
alndf['AAq'] = alndf['query'].map(mapperAA)
alndf['AAt'] = alndf['target'].map(mapperAA)

# Add the gapped 3Di versions of every alignment as two extra columns.
res = alndf.apply(calc_fident_crossaln , axis = 1)
alndf = pd.concat([alndf, res], axis = 1)

# Output a fasta with the 3Di sequence of every query (ids without '.pdb').
with open('../testdata/3diseqs.fasta' , 'w') as out:
    for seq in alndf['query'].unique():
        out.write('>' + seq.replace('.pdb', '') + '\n')
        out.write(mapper3di[seq] + '\n')

# From here on identifiers no longer carry the '.pdb' suffix.
alndf['query'] = alndf['query'].map(lambda name: name.replace('.pdb', ''))
alndf['target'] = alndf['target'].map(lambda name: name.replace('.pdb', ''))



In [160]:
# Prepare the tree: every node gets empty alignment slots, and nodes whose
# name is empty are labelled by their traversal index.
tre = toytree.tree('../testdata/foldtree_struct_tree.PP.nwk.rooted.final'  )
for i,n in enumerate(tre.treenode.traverse()):
    n.aln = None
    n.aln3di = None
    n.leafset = None
    if len(n.name) == 0:
        n.name = f'internal_{i}'

# Scratch directory for the intermediate alignments.
alnfolder = infolder + 'alnscratch/'
if not os.path.exists(alnfolder):
    os.mkdir(alnfolder)

# Build the merged alignments starting from the root.
finalaln, finalaln3di = traverse_tree_merge(
    tre.treenode.get_tree_root(),
    get_leafset(tre.treenode.get_tree_root()),
    alndf,
    alnfolder,
    verbose = True,
)

# Print the final alignments.
print('finalaln',finalaln)
with open(finalaln) as out:
    print(out.read())

print('finalaln3di',finalaln3di)
with open(finalaln3di) as out:
    print(out.read())

print('nsequences' , len(tre.get_tip_labels()))

# Print the number of sequences in each final alignment.
with open(finalaln) as f:
    print(f.read().count('>'))

with open(finalaln3di) as f:
    print(f.read().count('>'))


traverse 59 False None
not cherry 59
children ['A0A3B3BIE0', '58', '57']
traverse A0A3B3BIE0 True None
traverse A0A3B3BIE0 True None
traverse 58 False None
traverse 58 False None
not cherry 58
children ['56', '55']
traverse 56 False None
traverse 56 False None
cherry A0A3Q2ZTT6 A0A1A8U1R0
traverse 55 False None
traverse 55 False None
not cherry 55
children ['A0A3Q2GCF1', '52']
traverse A0A3Q2GCF1 True None
traverse A0A3Q2GCF1 True None
traverse 52 False None
traverse 52 False None
not cherry 52
children ['A0A3P9PIM8', '47']
traverse A0A3P9PIM8 True None
traverse A0A3P9PIM8 True None
traverse 47 False None
traverse 47 False None
cherry A0A3B5Q007 A0A087X4E6
merge A0A3P9PIM8 47
first merge
identical
identical
2 merge
identical
identical
merge A0A3Q2GCF1 52
first merge
identical
identical
2 merge
identical
identical
merge 56 55
first merge
{'A0A3Q2ZTT6': 'MSLSIGDKIEDFKVLTLLGKGSFACVYRAKSVKTGVEVAIKMIDKKAMHKAGMVQRVANEVEIHCRLKHPSILELYNYFEDSNYVYLVLEMCHNGEMSRYLKERKVPFSEDEARHFMHQIIKGMLYLHTHGILHR