In [6]:
import pandas as pd
import numpy as np
import os

In [7]:
# Column names assigned to the plmDCA mapping tables after they are read
# with header=None.
cols = [
    'index', 'seqid',
    'MSA_i', 'MSA_j',
    'seq_i', 'seq_j',
    'score',
]

In [8]:
def recall_at_k(contact_matrix, df_dca, k, seq_separation_cutoff = 4.0):
    """Fraction of the DCA reference pairs found among the top-k scored pairs.

    Candidate pairs are the index pairs ``(i, j)`` with ``j < i`` and
    ``i - j > seq_separation_cutoff``, scored by ``contact_matrix[i, j]``.
    They are ranked by score, highest first (ties keep their row-major
    order), and the first ``k`` are compared against the
    ``(seq_i, seq_j)`` pairs listed in ``df_dca``.

    Parameters
    ----------
    contact_matrix : 2-D array
        Score matrix indexed as ``contact_matrix[i, j]``. Only entries with
        ``j < i < contact_matrix.shape[0]`` are read.
    df_dca : pandas.DataFrame
        Reference pairs; must expose ``seq_i`` and ``seq_j`` columns.
    k : int
        Number of top-ranked candidate pairs to keep.
    seq_separation_cutoff : float, optional
        Pairs with ``i - j`` less than or equal to this value are ignored.

    Returns
    -------
    float
        ``TP / len(df_dca)``, where ``TP`` is the number of top-k candidate
        pairs that also occur in ``df_dca``. Returns ``0.0`` when ``df_dca``
        has no rows (instead of raising ``ZeroDivisionError``).

    Notes
    -----
    * Pairs are matched as ordered tuples with ``i > j``; a reference row
      with ``seq_i < seq_j`` can never match.
      NOTE(review): confirm the upstream table always stores the larger
      index in ``seq_i``.
    * The denominator counts every row of ``df_dca``, including rows that
      fail the separation filter and duplicated rows.
    """
    n_positions = contact_matrix.shape[0]

    # Strict lower triangle in row-major order, filtered by sequence separation.
    scored_pairs = [
        ((i, j), contact_matrix[i, j])
        for i in range(n_positions)
        for j in range(i)
        if i - j > seq_separation_cutoff
    ]

    # Stable sort, descending by score; keep only the pair of the first k entries.
    ranked = sorted(scored_pairs, key = lambda item: -item[1])
    top_k_pairs = [pair for pair, _score in ranked[:k]]

    # Reference DCA pairs exactly as stored in the table.
    dca_pairs = list(zip(df_dca.seq_i, df_dca.seq_j))
    if not dca_pairs:
        # No reference pairs: recall is undefined, report 0.0 rather than crash.
        return 0.0

    # Set membership is O(1) per lookup; the list is kept for the denominator
    # so duplicated rows still count as before.
    dca_lookup = set(dca_pairs)
    TP = sum(1 for pair in top_k_pairs if pair in dca_lookup)

    return TP / len(dca_pairs)


def compute_average_recall(esm_matrices, df_map):
    """Mean recall-at-L over every distinct value of ``df_map['index']``.

    For each distinct ``index`` value ``i`` (in order of first appearance),
    the rows of ``df_map`` with that value are scored against the matrix
    stored under the key ``'arr_<i>'`` in ``esm_matrices``, using
    ``recall_at_k`` with ``k`` set to the matrix's number of rows.

    Parameters
    ----------
    esm_matrices : mapping
        Looked up as ``esm_matrices['arr_<i>']``; each value must have a
        ``shape`` attribute.
    df_map : pandas.DataFrame
        Must contain an ``index`` column plus whatever columns
        ``recall_at_k`` reads.

    Returns
    -------
    float
        Sum of the per-index recalls divided by the number of distinct
        ``index`` values.
    """
    row_ids = df_map['index'].unique()

    running_sum = 0
    for row_id in row_ids:
        pairs_for_row = df_map[df_map['index'] == row_id]
        scores = esm_matrices[f'arr_{row_id}']
        # recall at L: k is the number of rows of the score matrix
        running_sum = running_sum + recall_at_k(
            scores, pairs_for_row, k = scores.shape[0]
        )

    return running_sum / len(row_ids)

In [9]:
# Directory that the .npz score archives are loaded from.
scratch_dir = '/fs/nexus-scratch/vla'
scratch_dir

'/fs/nexus-scratch/vla'

## Cadherin

In [10]:
# Cadherin: load the contact-head and ESMFold score archives from scratch_dir.
# compute_average_recall reads their entries by the key 'arr_<index>'.
esm_contacthead_results = np.load(os.path.join(scratch_dir,'cadherin_contacthead.npz'))
esmfold_results = np.load(os.path.join(scratch_dir, 'cadherin_esmfold.npz'))

In [11]:
# Cadherin plmDCA pair table. The CSV is read with header=None, so the
# column names from `cols` are assigned afterwards.
df_plmdca_by_seq = pd.read_csv('/nfshomes/vla/cmsc702-protein-lm/results/cadherin/PF00028_all_plmdca_mapped.csv', header = None)
df_plmdca_by_seq.columns = cols
# Display the first row to check the column assignment.
df_plmdca_by_seq[:1]

Unnamed: 0,index,seqid,MSA_i,MSA_j,seq_i,seq_j,score
0,0,CADH1_HUMAN/267-366,283,276,69,67,1.954369


In [12]:
# plmDCA vs. ESM contact head (cadherin): recall at L, averaged over the
# distinct 'index' values of the plmDCA table.
compute_average_recall(esm_contacthead_results, df_plmdca_by_seq)

0.22034388731658433

In [13]:
# plmDCA vs. ESMFold (cadherin): recall at L, averaged over the distinct
# 'index' values of the plmDCA table.
compute_average_recall(esmfold_results, df_plmdca_by_seq)

0.16968297561300383

In [None]:
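# NOTE(review): disabled cell. The variable is named df_mfdca_* but the path
# points at a plmDCA file (PF00028_plmdca_mapped.csv) -- confirm the intended
# input file before re-enabling.
# df_mfdca_by_seq = pd.read_csv('/nfshomes/vla/cmsc702-protein-lm/results/cadherin/PF00028_plmdca_mapped.csv', header = None)
# df_mfdca_by_seq.columns = cols
# df_mfdca_by_seq[:1]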

## RRM

In [17]:
# RRM: load the contact-head and ESMFold score archives from scratch_dir.
# This rebinds the names used for the cadherin archives, so every later
# cell that uses them evaluates RRM.
esm_contacthead_results = np.load(os.path.join(scratch_dir,'rrm_contacthead.npz'))
esmfold_results = np.load(os.path.join(scratch_dir, 'rrm_esmfold.npz'))

In [18]:
# RRM plmDCA pair table (rebinds df_plmdca_by_seq). The CSV is read with
# header=None, so the column names from `cols` are assigned afterwards.
df_plmdca_by_seq = pd.read_csv('/nfshomes/vla/cmsc702-protein-lm/results/rrm/rrm_all_plmdca_mapped.csv', header = None)
df_plmdca_by_seq.columns = cols
# Display the first row to check the column assignment.
df_plmdca_by_seq[:1]

Unnamed: 0,index,seqid,MSA_i,MSA_j,seq_i,seq_j,score
0,0,ELAV3_MOUSE/41-111,390,368,51,48,0.937223


In [19]:
# plmDCA vs. ESM contact head (RRM): recall at L, averaged over the
# distinct 'index' values of the plmDCA table.
compute_average_recall(esm_contacthead_results, df_plmdca_by_seq)

0.14701365413803674

In [20]:
# plmDCA vs. ESMFold (RRM): recall at L, averaged over the distinct
# 'index' values of the plmDCA table.
compute_average_recall(esmfold_results, df_plmdca_by_seq)

0.1220521120145991

## PF00011 

In [21]:
# PF00011: load the contact-head and ESMFold score archives from scratch_dir.
# This rebinds the names used for the previous family, so every later cell
# that uses them evaluates PF00011.
esm_contacthead_results = np.load(os.path.join(scratch_dir, 'pf00011_contacthead.npz'))
esmfold_results = np.load(os.path.join(scratch_dir, 'pf00011_esmfold.npz'))

In [22]:
# PF00011 plmDCA pair table (rebinds df_plmdca_by_seq). The CSV is read with
# header=None, so the column names from `cols` are assigned afterwards.
df_plmdca_by_seq = pd.read_csv('/nfshomes/vla/cmsc702-protein-lm/results/pf00011/PF00011_all_plmdca_mapped.csv', header = None)
df_plmdca_by_seq.columns = cols
# Display the first row to check the column assignment.
df_plmdca_by_seq[:1]

Unnamed: 0,index,seqid,MSA_i,MSA_j,seq_i,seq_j,score
0,0,CRYAA_BOVIN/63-162,281,135,51,35,0.996449


In [23]:
# plmDCA vs. ESM contact head (PF00011): recall at L, averaged over the
# distinct 'index' values of the plmDCA table.
compute_average_recall(esm_contacthead_results, df_plmdca_by_seq)

0.13304799501892345

In [24]:
# plmDCA vs. ESMFold (PF00011): recall at L, averaged over the distinct
# 'index' values of the plmDCA table.
compute_average_recall(esmfold_results, df_plmdca_by_seq)

0.09815407899759249

## PF00043

In [25]:
# PF00043: load the contact-head and ESMFold score archives from scratch_dir.
# This rebinds the names used for the previous family, so every later cell
# that uses them evaluates PF00043.
esm_contacthead_results = np.load(os.path.join(scratch_dir, 'pf00043_contacthead.npz'))
esmfold_results = np.load(os.path.join(scratch_dir, 'pf00043_esmfold.npz'))

In [26]:
# PF00043 plmDCA pair table (rebinds df_plmdca_by_seq). The CSV is read with
# header=None, so the column names from `cols` are assigned afterwards.
df_plmdca_by_seq = pd.read_csv('/nfshomes/vla/cmsc702-protein-lm/results/pf00043/PF00043_all_plmdca_mapped.csv', header = None)
df_plmdca_by_seq.columns = cols
# Display the first row to check the column assignment.
df_plmdca_by_seq[:1]

Unnamed: 0,index,seqid,MSA_i,MSA_j,seq_i,seq_j,score
0,0,GSTM2_RAT/104-192,280,276,83,79,1.891365


In [27]:
# plmDCA vs. ESM contact head (PF00043): recall at L, averaged over the
# distinct 'index' values of the plmDCA table.
compute_average_recall(esm_contacthead_results, df_plmdca_by_seq)

0.17833629277591623

In [28]:
# plmDCA vs. ESMFold (PF00043): recall at L, averaged over the distinct
# 'index' values of the plmDCA table.
compute_average_recall(esmfold_results, df_plmdca_by_seq)

0.15238377590988322