In [48]:
import pandas as pd
import numpy as np
import os

In [3]:
# Column names for the headerless *_mapped.csv DCA tables, in file order.
cols = [
    'index',           # MSA row; compute_average_recall looks up f'arr_{index}'
    'seqid',
    'MSA_i', 'MSA_j',  # presumably the pair's MSA column positions — confirm
    'seq_i', 'seq_j',  # pair indices compared against contact-matrix rows/cols
    'score',
]

In [23]:
def recall_at_k(contact_matrix, df_dca, k, seq_separation_cutoff = 4.0):
    """Fraction of reference (DCA) contacts recovered by the top-k predicted pairs.

    Candidate predictions are the strictly-lower-triangle entries (i, j), i > j,
    of ``contact_matrix`` whose separation i - j exceeds ``seq_separation_cutoff``.
    They are ranked by descending score. Ties keep row-major order, matching a
    ``for i in range(L): for j in range(i)`` scan. The first ``k`` are kept.

    Parameters
    ----------
    contact_matrix : array-like, square
        Predicted contact scores. Only entries below the diagonal are read.
    df_dca : pandas.DataFrame
        Reference contacts with ``seq_i`` / ``seq_j`` columns, indexed in the same
        numbering as ``contact_matrix`` rows and columns.
        NOTE(review): assumes both are 0-based; confirm against the mapping script.
    k : int
        Number of top-scoring predicted pairs to keep (slice semantics: ``[:k]``).
    seq_separation_cutoff : float
        Predicted pairs with ``i - j <= cutoff`` are excluded.
        NOTE(review): reference pairs are NOT filtered by this cutoff, so short-range
        DCA pairs stay in the denominator even though they can never be predicted.

    Returns
    -------
    float
        ``|top-k ∩ reference| / |reference|``. Pairs are compared order-insensitively
        and duplicate reference rows are collapsed. Returns 0.0 when ``df_dca`` holds
        no pairs.
    """
    scores = np.asarray(contact_matrix)
    L = scores.shape[0]

    # Lower-triangle indices come back in row-major order, the same enumeration
    # order as the nested i/j loop. Since i > j here, |i - j| == i - j.
    row_idx, col_idx = np.tril_indices(L, k=-1)
    far_enough = (row_idx - col_idx) > seq_separation_cutoff
    row_idx, col_idx = row_idx[far_enough], col_idx[far_enough]

    # A stable sort on the negated score gives descending order and keeps ties
    # in enumeration order.
    ranking = np.argsort(-scores[row_idx, col_idx], kind='stable')[:k]
    top_k_pairs = set(zip(row_idx[ranking].tolist(), col_idx[ranking].tolist()))

    # Normalise reference pairs to (larger, smaller) so (i, j) and (j, i) both match
    # the lower-triangle predictions. The set gives O(1) lookups and drops duplicates.
    dca_pairs = {(max(i, j), min(i, j)) for i, j in zip(df_dca.seq_i, df_dca.seq_j)}
    if not dca_pairs:
        return 0.0

    return len(top_k_pairs & dca_pairs) / len(dca_pairs)


def compute_average_recall(esm_matrices, df_map, k=None, seq_separation_cutoff=4.0):
    """Unweighted mean of ``recall_at_k`` over every MSA row present in ``df_map``.

    For each distinct ``df_map['index']`` value, the contact matrix stored under
    ``f'arr_{index}'`` is scored against that row's DCA pairs. That key is the
    naming ``np.savez`` gives positional arrays; presumably the archives were
    written that way, so confirm against the export script.

    Parameters
    ----------
    esm_matrices : mapping of str -> 2-D array
        For example, the ``NpzFile`` returned by ``np.load`` on an ``.npz`` archive.
    df_map : pandas.DataFrame
        Per-sequence DCA contacts with at least ``index``, ``seq_i``, ``seq_j``.
    k : int or None
        Number of top predictions to score. ``None`` (the default) uses L, the row
        count of each contact matrix, i.e. recall at L.
    seq_separation_cutoff : float
        Forwarded to ``recall_at_k``.

    Returns
    -------
    float
        Mean of the per-row recalls.

    Raises
    ------
    ValueError
        If ``df_map`` yields no MSA rows to average over.
    """
    recalls = []
    # groupby(sort=False) visits rows in first-appearance order, like .unique(),
    # without re-scanning the whole frame once per sequence.
    for msa_row, df_seq in df_map.groupby('index', sort=False):
        contact_matrix = esm_matrices[f'arr_{msa_row}']
        top_k = contact_matrix.shape[0] if k is None else k
        recalls.append(recall_at_k(contact_matrix, df_seq, k=top_k,
                                   seq_separation_cutoff=seq_separation_cutoff))

    if not recalls:
        raise ValueError("df_map contains no MSA rows; cannot average recall")

    return sum(recalls) / len(recalls)

In [49]:
# Directory holding the ESM prediction archives (*.npz). The default is the original
# cluster path; set ESM_SCRATCH_DIR to run elsewhere. A single-argument
# os.path.join was a no-op, so it is dropped.
scratch_dir = os.environ.get('ESM_SCRATCH_DIR', '/fs/nexus-scratch/vla')
scratch_dir

'/fs/nexus-scratch/vla'

## Cadherin (PF00028)

In [50]:
# Per-sequence contact maps for this family: ESM + contact head, then ESMFold.
esm_contacthead_results, esmfold_results = (
    np.load(os.path.join(scratch_dir, 'cadherin_contacthead.npz')),
    np.load(os.path.join(scratch_dir, 'cadherin_esmfold.npz')),
)

In [53]:
# plmDCA contacts mapped onto each sequence's numbering (headerless CSV).
df_plmdca_by_seq = pd.read_csv(
    '/nfshomes/vla/cmsc702-protein-lm/results/cadherin/PF00028_plmdca_mapped.csv',
    header=None,
).set_axis(cols, axis='columns')
df_plmdca_by_seq.head(1)

Unnamed: 0,index,seqid,MSA_i,MSA_j,seq_i,seq_j,score
0,0,CADH1_HUMAN/267-366,232,198,59,57,0.686212


In [54]:
# Mean recall@L: plmDCA reference vs. ESM + contact-head predictions.
compute_average_recall(esm_matrices=esm_contacthead_results, df_map=df_plmdca_by_seq)

0.0098

In [55]:
# Mean recall@L: plmDCA reference vs. ESMFold predictions.
compute_average_recall(esm_matrices=esmfold_results, df_map=df_plmdca_by_seq)

0.01

In [None]:
# mfDCA contacts mapped onto each sequence's numbering, for comparison with plmDCA.
# This must read the mfDCA output; pointing at PF00028_plmdca_mapped.csv would make
# this table a silent copy of the plmDCA one.
# TODO(review): confirm the mfDCA file name follows the *_mapped.csv convention.
df_mfdca_by_seq = pd.read_csv('/nfshomes/vla/cmsc702-protein-lm/results/cadherin/PF00028_mfdca_mapped.csv', header = None)
df_mfdca_by_seq.columns = cols
df_mfdca_by_seq[:1]

## RRM (RNA recognition motif)

In [56]:
# Per-sequence contact maps for this family: ESM + contact head, then ESMFold.
esm_contacthead_results, esmfold_results = (
    np.load(os.path.join(scratch_dir, 'rrm_contacthead.npz')),
    np.load(os.path.join(scratch_dir, 'rrm_esmfold.npz')),
)

In [29]:
# TODO(review): disabled. plmdca_rrm_output.csv is named like raw plmDCA output, not
# a per-sequence *_mapped.csv table, and the column assignment below is commented out.
# Presumably its layout does not match `cols`; confirm before re-enabling.
# df_plmdca_by_seq = pd.read_csv('/nfshomes/vla/cmsc702-protein-lm/results/rrm/plmdca_rrm_output.csv', header = None)
# df_plmdca_by_seq
# # df_plmdca_by_seq.columns = cols

In [None]:
# plmdca - esm+contacthead
# compute_average_recall needs the per-sequence DCA table as its second argument;
# calling it with only the contact maps raises TypeError.
# esm_contacthead_results already holds rrm_contacthead.npz from the loading cell at
# the top of this section, so it is not reloaded from a hard-coded path here.
# TODO(review): file name assumed to follow the *_plmdca_mapped.csv convention used
# by the other families; confirm it exists.
df_plmdca_by_seq = pd.read_csv('/nfshomes/vla/cmsc702-protein-lm/results/rrm/rrm_plmdca_mapped.csv', header = None)
df_plmdca_by_seq.columns = cols
compute_average_recall(esm_contacthead_results, df_plmdca_by_seq)


## PF00011 (Hsp20/alpha-crystallin family)

In [57]:
# Per-sequence contact maps for this family: ESM + contact head, then ESMFold.
esm_contacthead_results, esmfold_results = (
    np.load(os.path.join(scratch_dir, 'pf00011_contacthead.npz')),
    np.load(os.path.join(scratch_dir, 'pf00011_esmfold.npz')),
)

In [33]:
# plmDCA contacts mapped onto each sequence's numbering (headerless CSV).
df_plmdca_by_seq = pd.read_csv(
    '/nfshomes/vla/cmsc702-protein-lm/results/pf00011/pf00011_plmdca_mapped.csv',
    header=None,
).set_axis(cols, axis='columns')

In [34]:
df_plmdca_by_seq.head(1)  # first mapped DCA pair, as a quick sanity check

Unnamed: 0,index,seqid,MSA_i,MSA_j,seq_i,seq_j,score
0,0,CRYAA_BOVIN/63-162,402,3,81,2,0.636316


In [None]:
# Mean recall@L: plmDCA reference vs. ESM + contact-head predictions.
compute_average_recall(esm_matrices=esm_contacthead_results, df_map=df_plmdca_by_seq)

0.00895650471343101

In [35]:
# Mean recall@L: plmDCA reference vs. ESMFold predictions.
compute_average_recall(esm_matrices=esmfold_results, df_map=df_plmdca_by_seq)

0.00633361484979726

## PF00043 (glutathione S-transferase, C-terminal domain)

In [58]:
# Per-sequence contact maps for this family: ESM + contact head, then ESMFold.
esm_contacthead_results, esmfold_results = (
    np.load(os.path.join(scratch_dir, 'pf00043_contacthead.npz')),
    np.load(os.path.join(scratch_dir, 'pf00043_esmfold.npz')),
)

In [45]:
# plmDCA contacts mapped onto each sequence's numbering (headerless CSV).
df_plmdca_by_seq = pd.read_csv(
    '/nfshomes/vla/cmsc702-protein-lm/results/pf00043/pf00043_plmdca_mapped.csv',
    header=None,
).set_axis(cols, axis='columns')
df_plmdca_by_seq.head(1)

Unnamed: 0,index,seqid,MSA_i,MSA_j,seq_i,seq_j,score
0,0,GSTM2_RAT/104-192,282,278,85,81,0.734966


In [46]:
# Mean recall@L: plmDCA reference vs. ESM + contact-head predictions.
compute_average_recall(esm_matrices=esm_contacthead_results, df_map=df_plmdca_by_seq)

0.08199790457157424

In [47]:
# Mean recall@L: plmDCA reference vs. ESMFold predictions.
compute_average_recall(esm_matrices=esmfold_results, df_map=df_plmdca_by_seq)

0.086425188846375