In [1]:
from Bio import AlignIO
from Bio.Align import MultipleSeqAlignment
from Bio.Seq import Seq
from collections import Counter
import os
import pandas as pd
import numpy as np
from itertools import product
from Bio.PDB import PDBParser, NeighborSearch

In [2]:
def calculate_gapped_columns(alignment, gap_threshold):
    """Return the indices of alignment columns that are mostly gaps.

    A column is flagged when the fraction of '-' characters in it is strictly
    greater than ``gap_threshold``. Only '-' counts as a gap here ('.' does not).

    Parameters
    ----------
    alignment : Bio.Align.MultipleSeqAlignment
        Alignment to scan; must support ``len()``, ``alignment[:, col]`` and
        ``get_alignment_length()``.
    gap_threshold : float
        Largest tolerated gap fraction for a column.

    Returns
    -------
    set of int
        0-based column indices whose gap fraction exceeds ``gap_threshold``.
    """
    depth = len(alignment)
    return {
        column
        for column in range(alignment.get_alignment_length())
        if sum(symbol == '-' for symbol in alignment[:, column]) / depth > gap_threshold
    }

In [3]:
def read_dca_scores(dca_path, alignment, gapped_columns):
    """Read DCA coupling scores, drop gapped columns and keep the strongest ones.

    Each data line is expected to hold three whitespace-separated numbers,
    ``i j score``, with 1-based column indices. Lines starting with '#' and
    blank/whitespace-only lines are ignored.

    Parameters
    ----------
    dca_path : str
        Path to the DCA score file.
    alignment : Bio.Align.MultipleSeqAlignment
        Only ``get_alignment_length()`` is used, to size the cutoff.
    gapped_columns : set of int
        0-based column indices to exclude (either end of a pair).

    Returns
    -------
    list of (int, int, float)
        ``(i, j, score)`` with 0-based indices, sorted by descending score and
        truncated to the top ``2 * L + 1`` entries, where L is the alignment
        length. NOTE(review): confirm whether 2L (without the +1) was intended.
    """
    scores = []
    with open(dca_path) as handle:
        for line in handle:
            fields = line.split()
            # Skip comments and blank lines; an empty line (e.g. a trailing
            # newline at EOF) would otherwise break the 3-way unpack below.
            if not fields or line.startswith("#"):
                continue
            i, j, score = map(float, fields)
            # Convert 1-based file indices to 0-based alignment columns.
            i, j = int(i) - 1, int(j) - 1

            if i not in gapped_columns and j not in gapped_columns:
                scores.append((i, j, score))

    scores.sort(key=lambda entry: entry[2], reverse=True)
    return scores[:2 * alignment.get_alignment_length() + 1]

In [4]:
def pairs_frequencies(i, j, msa):
    """Return the relative frequency of each residue pair at columns i and j.

    Sequences with '-' or '.' at either position are skipped; frequencies are
    normalised over the remaining sequences only.

    Parameters
    ----------
    i, j : int
        0-based column indices.
    msa : iterable of indexable sequences
        E.g. a MultipleSeqAlignment (iterating yields records indexable by int).

    Returns
    -------
    dict
        Maps ``(residue_i, residue_j)`` to its fraction of the non-gapped pairs;
        empty when every sequence is gapped at i or j.
    """
    gap_symbols = '-.'
    observed = Counter(
        (record[i], record[j])
        for record in msa
        if record[i] not in gap_symbols and record[j] not in gap_symbols
    )
    n_pairs = sum(observed.values())
    return {pair: count / n_pairs for pair, count in observed.items()}

In [5]:
def compare_freqs(i, j, stabilize, reference):
    """Find the amino-acid pair at columns (i, j) most enriched in ``stabilize``.

    For each of the 400 ordered pairs of the 20 standard (uppercase) amino
    acids, the pair frequency in ``stabilize`` minus that in ``reference`` is
    computed (missing pairs count as 0). The pair with the largest difference
    wins; ties go to the first pair in ``product`` order ('A','A'), ('A','C'), ...

    Returns
    -------
    tuple
        ``(pair, difference, freq_in_stabilize, freq_in_reference)``.
    """
    stab_table = pairs_frequencies(i, j, stabilize)
    ref_table = pairs_frequencies(i, j, reference)

    residues = 'ACDEFGHIKLMNPQRSTVWY'
    candidates = (
        (pair, stab_table.get(pair, 0), ref_table.get(pair, 0))
        for pair in product(residues, repeat=2)
    )
    # max() returns the first maximal candidate, i.e. the same pair a strict
    # ">" running-maximum scan would keep.
    best_pair, best_stab, best_ref = max(candidates, key=lambda c: c[1] - c[2])
    return best_pair, best_stab - best_ref, best_stab, best_ref

In [6]:
from Bio.PDB import PDBParser
import numpy as np

def find_contacts_single_chain(pdb_file, threshold=12.0):
    """Return residue pairs whose C-alpha atoms are closer than ``threshold``.

    Only the first chain of the first model in ``pdb_file`` is considered.

    Parameters
    ----------
    pdb_file : str
        Path to a PDB-format file.
    threshold : float, default 12.0
        Distance cutoff (same units as the file coordinates); a pair is a
        contact when its CA-CA distance is strictly below it.

    Returns
    -------
    list of (int, int)
        ``(resnum_a - 1, resnum_b - 1)`` for each pair with residue ``a``
        before ``b`` in chain order. The ``- 1`` turns PDB residue numbers into
        0-based indices; this assumes numbering starts at 1 and lines up with
        alignment columns -- TODO confirm for each structure.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("protein", pdb_file)
    chain = structure[0].get_list()[0]

    # Keep only genuine alpha carbons: a bound calcium ion is also an atom
    # named "CA" (element "CA") and would otherwise be treated as a residue.
    ca_atoms = [
        (residue["CA"], residue.get_id()[1])
        for residue in chain
        if "CA" in residue and residue["CA"].element == "C"
    ]
    if len(ca_atoms) < 2:
        return []

    coords = np.array([atom.coord for atom, _ in ca_atoms])
    numbers = [number for _, number in ca_atoms]

    # Full pairwise distance matrix; triu_indices yields (a, b) with a < b in
    # row-major order, the same order as a nested i < j loop.
    distances = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
    rows, cols = np.triu_indices(len(ca_atoms), k=1)
    close = distances[rows, cols] < threshold

    return [
        (numbers[a] - 1, numbers[b] - 1)
        for a, b in zip(rows[close].tolist(), cols[close].tolist())
    ]

In [7]:
def main(stabilize, reference, dca_stab_path, dca_ref_path, pdb_stab_file, pdb_ref_file, diff_thresh, sing_thresh,
         apply_thresholds=False):
    """Find coevolving residue pairs specific to the ``stabilize`` alignment.

    A DCA pair (i, j) from ``stabilize`` is reported when all of these hold:

    * it is not among the retained top DCA pairs of ``reference``;
    * ``abs(i - j) > 4`` (sequence-local pairs are skipped);
    * the amino-acid pair most enriched in ``stabilize`` vs ``reference``
      differs from the query residue at BOTH positions (the query is the first
      sequence of ``reference``);
    * (i, j) is a CA-CA contact in ``pdb_stab_file`` but not in ``pdb_ref_file``;
    * only when ``apply_thresholds`` is True: ``max_diff > diff_thresh``,
      ``ref_freq < 0.1`` and ``stab_freq > sing_thresh``.

    Parameters
    ----------
    stabilize, reference : Bio.Align.MultipleSeqAlignment
        The alignment being analysed and the one it is compared against.
    dca_stab_path, dca_ref_path : str
        DCA score files for the two alignments.
    pdb_stab_file, pdb_ref_file : str
        Structures used for the contact check.
    diff_thresh, sing_thresh : float
        Filter cutoffs; ignored unless ``apply_thresholds`` is True.
    apply_thresholds : bool, default False
        Enable the frequency filters listed above.

    Returns
    -------
    pandas.DataFrame
        Columns ``i``, ``j`` (1-based), ``mutation`` (enriched pair),
        ``query0``/``query1`` (query residues at i/j), ``diff`` and ``freq``
        (pair frequency in ``stabilize``).
    """
    # Query residues indexed by alignment column (may include gap characters).
    query = str(reference[0].seq)

    gap_threshold = 0.6
    gapped_ref = calculate_gapped_columns(reference, gap_threshold)
    gapped_stab = calculate_gapped_columns(stabilize, gap_threshold)

    dca_ref = read_dca_scores(dca_ref_path, reference, gapped_ref)
    dca_stab = read_dca_scores(dca_stab_path, stabilize, gapped_stab)

    # Built once instead of on every loop iteration.
    ref_pairs = {(k, l) for k, l, _ in dca_ref}

    # Contact maps are parsed lazily, at most once, the first time a pair
    # reaches the structural check (previously both PDBs were re-parsed per pair).
    stab_contacts = ref_contacts = None

    results = []
    for i, j, _score in dca_stab:
        if (i, j) in ref_pairs or abs(i - j) <= 4:
            continue

        max_diff_pair, max_diff, stab_freq, ref_freq = compare_freqs(i, j, stabilize, reference)
        if max_diff_pair[0] == query[i] or max_diff_pair[1] == query[j]:
            continue

        if stab_contacts is None:
            stab_contacts = set(find_contacts_single_chain(pdb_stab_file))
            ref_contacts = set(find_contacts_single_chain(pdb_ref_file))
        if (i, j) not in stab_contacts or (i, j) in ref_contacts:
            continue

        if apply_thresholds and not (max_diff > diff_thresh and ref_freq < 0.1 and stab_freq > sing_thresh):
            continue

        results.append([i + 1, j + 1, max_diff_pair, query[i], query[j], max_diff, stab_freq])

    return pd.DataFrame(results, columns=['i', 'j', 'mutation', 'query0', 'query1', 'diff', 'freq'])

In [8]:
def _first_data_file(directory):
    """Return the path of the alphabetically first regular, non-hidden file in ``directory``."""
    # os.listdir order is arbitrary, and Jupyter may create a hidden
    # `.ipynb_checkpoints` folder; sort and skip dot-entries/non-files so the
    # same score file is chosen on every run.
    names = sorted(
        name for name in os.listdir(directory)
        if not name.startswith('.') and os.path.isfile(os.path.join(directory, name))
    )
    if not names:
        raise FileNotFoundError(f"no data file found in {directory!r}")
    return os.path.join(directory, names[0])


def _load_variant(prot, variant, alignment_name):
    """Load (alignment, DCA score path, PDB path) for one variant of ``prot``."""
    concat_dir = 'concat_' + prot
    alignment = AlignIO.read(os.path.join(concat_dir, alignment_name), "fasta")
    dca_path = _first_data_file(prot + '_' + variant)
    pdb_file = os.path.join(concat_dir, 'pdb_' + variant)
    return alignment, dca_path, pdb_file


def read_files(alternative, prot):
    """Load the inputs for ``main`` for protein ``prot``.

    If ``alternative`` is true, the alternative alignment is the "stabilize"
    set and the full alignment is the reference (thresholds 0.2); otherwise
    the roles are swapped (thresholds 0.1).

    Returns
    -------
    tuple
        ``(stabilize, reference, dca_stab_path, dca_ref_path,
        pdb_stab_file, pdb_ref_file, diff_thresh, sing_thresh)``.
    """
    alt = _load_variant(prot, 'alternative', 'max_alternative')
    full = _load_variant(prot, 'full', 'full.fasta')

    if alternative:
        stab_set, ref_set = alt, full
        diff_thresh = sing_thresh = 0.2
    else:
        stab_set, ref_set = full, alt
        diff_thresh = sing_thresh = 0.1

    stabilize, dca_stab_path, pdb_stab_file = stab_set
    reference, dca_ref_path, pdb_ref_file = ref_set

    return stabilize, reference, dca_stab_path, dca_ref_path, pdb_stab_file, pdb_ref_file, diff_thresh, sing_thresh

In [10]:
# Run configuration: protein identifier, and whether the alternative
# alignment plays the "stabilize" role.
prot = 'KB91'
alternative = False

# Load the inputs for this configuration, then compare the two alignments.
(stabilize, reference,
 dca_stab_path, dca_ref_path,
 pdb_stab_file, pdb_ref_file,
 diff_thresh, sing_thresh) = read_files(alternative, prot)
data = main(stabilize, reference,
            dca_stab_path, dca_ref_path,
            pdb_stab_file, pdb_ref_file,
            diff_thresh, sing_thresh)

# Same rows, ordered by frequency difference (largest first).
sort = data.sort_values(ascending=False, by="diff")

In [11]:
# Display the results in loop order (the diff-sorted view is stored in `sort`).
data

Unnamed: 0,i,j,mutation,query0,query1,diff,freq
0,27,82,"(L, L)",I,V,0.06257,0.065904
1,7,64,"(R, L)",K,V,0.094965,0.098298
2,23,75,"(N, T)",T,D,0.052723,0.056045


In [122]:
# Run configuration: protein identifier, and whether the alternative
# alignment plays the "stabilize" role.
prot = 'KB91'
alternative = False

# Load the inputs for this configuration, then compare the two alignments.
(stabilize, reference,
 dca_stab_path, dca_ref_path,
 pdb_stab_file, pdb_ref_file,
 diff_thresh, sing_thresh) = read_files(alternative, prot)
data = main(stabilize, reference,
            dca_stab_path, dca_ref_path,
            pdb_stab_file, pdb_ref_file,
            diff_thresh, sing_thresh)

# Same rows, ordered by pair frequency in the stabilize alignment (largest first).
sort = data.sort_values(ascending=False, by="freq")

In [123]:
# Display the results in loop order (the freq-sorted view is stored in `sort`).
data

Unnamed: 0,i,j,mutation,query0,query1,diff,freq
0,27,82,"(L, L)",I,V,0.06257,0.065904
1,7,64,"(R, L)",K,V,0.094965,0.098298
2,23,75,"(N, T)",T,D,0.052723,0.056045
