In [7]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import os
from scipy.stats import pearsonr, spearmanr

In [12]:
def read_dist_ids(pdb):
    """Load every distance matrix and its row identifiers for one system.

    For each entry ``<run>`` of ``<pdb>/reps_MSATransformer/`` this reads
    ``<run>/dist.npy`` and, for every file name listed in ``<run>/reps/``,
    the identically named array from ``<run>/ids/``. The first ids file keeps
    all of its entries; every later one drops its first element (presumably a
    shared/query id repeated in each file -- TODO confirm).

    Parameters
    ----------
    pdb : str
        Path prefix of the system directory (e.g. ``'Mad2'``).

    Returns
    -------
    matrices : list of numpy arrays
        One loaded ``dist.npy`` per run, in ``os.listdir`` order.
    ids_merg : list of lists of numpy arrays
        For each run, the id chunks whose concatenation labels that run's
        matrix rows.
    """
    base = pdb + '/reps_MSATransformer/'
    matrices, ids_merg = [], []

    # NOTE(review): os.listdir returns entries in arbitrary order. Both which
    # ids file keeps its leading entry and how the concatenated ids line up
    # with the rows of dist.npy depend on that order -- TODO confirm it matches
    # the order used when dist.npy was built.
    for run in os.listdir(base):
        run_dir = f'{base}{run}'
        matrices.append(np.load(f'{run_dir}/dist.npy'))

        run_ids = []
        # NOTE(review): names are taken from reps/ but loaded from ids/; this
        # assumes both folders contain the same file names -- TODO verify.
        for k, name in enumerate(os.listdir(f'{run_dir}/reps/')):
            loaded = np.load(f'{run_dir}/ids/{name}')
            run_ids.append(loaded if k == 0 else loaded[1:])
        ids_merg.append(run_ids)

    return matrices, ids_merg

In [19]:
def order_matrices(matrices, ids_merg):
    """Align all matrices to the id order of the first and flatten their upper triangles.

    Parameters
    ----------
    matrices : list of 2-D numpy arrays
        Square matrices; ``matrices[0]`` is the reference.
    ids_merg : list of lists of 1-D arrays
        ``ids_merg[i]`` holds the id chunks whose concatenation labels the
        rows/columns of ``matrices[i]``.

    Returns
    -------
    list of 1-D numpy arrays
        The strict upper triangle (``k=1``) of each matrix. Matrices ``1..``
        are first reindexed (rows and columns) to the reference id order, so
        entry ``p`` of every vector refers to the same id pair. Returns ``[]``
        for empty input.

    Raises
    ------
    ValueError
        If an id of the reference is absent from another matrix's ids.
    IndexError
        If ``ids_merg`` has fewer entries than ``matrices``.
    """
    if len(matrices) == 0:
        return []

    ids_ref = [val for chunk in ids_merg[0] for val in chunk]

    matrix_ref = matrices[0]
    ordered = [matrix_ref[np.triu_indices(matrix_ref.shape[0], k=1)]]

    # Every reindexed matrix has len(ids_ref) rows, so its triangle indices
    # are loop-invariant.
    triu_other = np.triu_indices(len(ids_ref), k=1)

    for i in range(1, len(matrices)):
        ids_other = [val for chunk in ids_merg[i] for val in chunk]

        # First-occurrence position of each id: same result as list.index,
        # but O(n) overall instead of O(n^2).
        position = {}
        for k, val in enumerate(ids_other):
            position.setdefault(val, k)

        missing = [val for val in ids_ref if val not in position]
        if missing:
            raise ValueError(
                f"{len(missing)} reference id(s) not found in ids_merg[{i}], "
                f"e.g. {missing[0]!r}"
            )

        indices = [position[val] for val in ids_ref]
        ordered.append(matrices[i][np.ix_(indices, indices)][triu_other])

    return ordered

In [29]:
def pearsonr_pairs(ordered, n_exclude=30, method='spearman'):
    """Correlate every unordered pair among the leading vectors of ``ordered``.

    Despite the name, the default is Spearman rank correlation, as in the
    original implementation; pass ``method='pearson'`` for Pearson.

    Parameters
    ----------
    ordered : sequence of 1-D arrays
        Equal-length vectors, e.g. the output of ``order_matrices``.
    n_exclude : int, default 30
        Number of trailing vectors left out of the comparison. The original
        hard-coded ``len(ordered) - 30``; the reason for excluding them is not
        visible here -- TODO document.
    method : {'spearman', 'pearson'}, default 'spearman'
        Correlation to compute (``scipy.stats.spearmanr`` or ``pearsonr``).

    Returns
    -------
    corrs, pvals : list, list
        One coefficient and one p-value per pair ``(i, j)`` with ``i < j``,
        in row-major order. Both lists are empty when fewer than two vectors
        remain after the exclusion; callers averaging them then get nan.

    Raises
    ------
    ValueError
        If ``method`` is not recognised.
    """
    corr_funcs = {'spearman': spearmanr, 'pearson': pearsonr}
    if method not in corr_funcs:
        raise ValueError(f"method must be one of {sorted(corr_funcs)}, got {method!r}")
    corr_func = corr_funcs[method]

    rows = len(ordered) - n_exclude
    corrs = []
    pvals = []
    for i in range(rows):
        for j in range(i + 1, rows):
            corr, pval = corr_func(ordered[i], ordered[j])
            corrs.append(corr)
            pvals.append(pval)
    return corrs, pvals

In [31]:
# For each system: load its distance matrices, align them to a common id
# order, correlate every pair of flattened upper triangles, and report the
# spread. pearsonr_pairs computes Spearman (rank) correlation here.
pdbs = ['Mad2', 'KB', 'RfaH']
for pdb in pdbs:
    matrices, ids_merg = read_dist_ids(pdb)
    ordered = order_matrices(matrices, ids_merg)
    corrs, pvals = pearsonr_pairs(ordered)
    # Printed columns: mean(corrs), std(corrs), mean(pvals), std(pvals).
    print(pdb, *(summary(sample) for sample in (corrs, pvals) for summary in (np.mean, np.std)))

Mad2 0.8045739337939242 0.013849470544411931 0.0 0.0
KB 0.8603276936705659 0.004843570315544147 0.0 0.0
