In [None]:
import numpy as np
from sklearn.metrics import f1_score, matthews_corrcoef, precision_recall_curve, auc

#### Download your AF3 predictions, convert them to PDB (or parse them directly as mmCIF), and extract contact maps from them.
#### The evaluation pipeline expects pairs of ground-truth (GT) and AF3 contact maps that have been indexed using the alignment indices computed below.


In [5]:
def longest_common_substring_indices(s1, s2):
    """Locate the longest common contiguous substring of two sequences.

    Used to align the sequence extracted from a structure with the
    corresponding FASTA sequence, so that both can be compared on a shared
    index range. See https://www.biostars.org/p/9588718/#9588721 for more
    details on the approach.

    Args:
        s1: First sequence (must support indexing, ``len`` and slicing).
        s2: Second sequence (must additionally support ``.find``, e.g. ``str``).

    Returns:
        ``((start1, end1), (start2, end2))``: inclusive start/end indices of
        the common substring in ``s1`` and in ``s2``.
        - If several substrings share the maximum length, the one ending
          earliest in ``s1`` is chosen.
        - ``start2`` is its first occurrence in ``s2``.
        - If there is no common character, the result is ``((1, 0), (0, -1))``,
          i.e. both inclusive ranges describe empty slices.
    """
    n = len(s2)
    # prev[j] / curr[j] hold the length of the longest common suffix of
    # s1[:i] and s2[:j]. Only the previous row is ever read, so two rows
    # suffice: memory is O(len(s2)) instead of O(len(s1) * len(s2)).
    prev = [0] * (n + 1)
    length = 0
    end_idx_s1 = 0

    for i in range(1, len(s1) + 1):
        curr = [0] * (n + 1)
        ch = s1[i - 1]
        for j in range(1, n + 1):
            if ch == s2[j - 1]:
                curr[j] = prev[j - 1] + 1
                # Strict '>' keeps the earliest-ending match in s1 on ties.
                if curr[j] > length:
                    length = curr[j]
                    end_idx_s1 = i - 1
        prev = curr

    start_idx_s1 = end_idx_s1 - length + 1

    common_substring = s1[start_idx_s1:end_idx_s1 + 1]
    start_idx_s2 = s2.find(common_substring)

    return (start_idx_s1, end_idx_s1), (start_idx_s2, start_idx_s2 + length - 1)




In [None]:
def calculate_net_metrics(pdbs, matrix_pairs):
    """Compute pooled ("net") F1, MCC and PR-AUC over a set of structures.

    For every entry in ``pdbs``, the first pair in ``matrix_pairs`` with the
    same id is used. Both maps are cropped to their common top-left overlap
    and flattened. All structures are then pooled and each metric is computed
    once on the pooled vectors.

    Args:
        pdbs: Iterable of 4-tuples whose first element is a structure id. The
            other three elements are ignored here.
        matrix_pairs: Iterable of ``(M1, M2, pdb_id)``. ``M1`` is scored as the
            prediction and ``M2`` as the ground truth. Both must expose
            ``.shape`` and 2-D slicing; presumably binary contact maps
            (TODO confirm values are 0/1). May be a one-shot iterator.

    Returns:
        ``(net_f1, net_mcc, net_pr_auc, all_pdbs)``. ``all_pdbs`` lists every
        id from ``pdbs`` in order, including ids that have no matching pair.

    Raises:
        ValueError: if the matched pairs contribute no cells to score.
    """
    # Index the pairs once. setdefault keeps the FIRST pair per id, matching
    # the original "break on first match" scan. Building the dict also avoids
    # rescanning (and, for iterators, exhausting) matrix_pairs per pdb.
    pairs_by_id = {}
    for M1, M2, pdb_id in matrix_pairs:
        pairs_by_id.setdefault(pdb_id, (M1, M2))

    all_pdbs = []
    pred_chunks = []
    truth_chunks = []

    for pdb, _, _, _ in pdbs:
        all_pdbs.append(pdb)
        pair = pairs_by_id.get(pdb)
        if pair is None:
            continue
        M1, M2 = pair
        # Crop both maps to their shared top-left region so the shapes agree.
        L_overlap = min(M1.shape[0], M2.shape[0])
        N_overlap = min(M1.shape[1], M2.shape[1])
        pred_chunks.append(np.asarray(M1[:L_overlap, :N_overlap]).ravel())
        truth_chunks.append(np.asarray(M2[:L_overlap, :N_overlap]).ravel())

    # Concatenate whole arrays instead of growing Python lists of scalars.
    all_preds = np.concatenate(pred_chunks) if pred_chunks else np.array([])
    all_truths = np.concatenate(truth_chunks) if truth_chunks else np.array([])

    if all_preds.size == 0:
        raise ValueError(
            "calculate_net_metrics: no cells to score; no pdb in `pdbs` has a "
            "non-empty matching entry in `matrix_pairs`"
        )

    net_f1 = f1_score(all_truths, all_preds, average='binary')
    net_mcc = matthews_corrcoef(all_truths, all_preds)
    # NOTE(review): if M1 holds hard 0/1 predictions rather than scores, the
    # PR curve has a single operating point. This PR-AUC is then only a
    # coarse summary, not a ranking metric.
    precision, recall, _ = precision_recall_curve(all_truths, all_preds)
    net_pr_auc = auc(recall, precision)

    return net_f1, net_mcc, net_pr_auc, all_pdbs



#### RANDOM BASELINE

In [25]:
### Random Baseline
import numpy as np
from sklearn.metrics import f1_score, matthews_corrcoef, precision_recall_curve, auc

import sys
sys.path.append('..')
from utils.data import *

# Load the DNA dataset. Threshold semantics are defined by
# utils.data.load_and_process_data (not visible here).
random_baseline = load_and_process_data(mode='dna', lower_threshold=10, na_upper_threshold=100, protein_upper_threshold=1000, dataset_dir='../data')

# Complex contact maps for every loaded entry (full dataset, before the
# train/eval split below).
all_contacts = [elem["complex_contact_map"] for elem in random_baseline]

# Totals over the full dataset. sum_contacts / total_elems is the fraction
# of positive cells (presumably the contact density, if maps are 0/1).
# The loop variable is `contact_map` rather than `map`, so the builtin `map`
# is not shadowed in the kernel namespace for later cells.
sum_contacts = 0
total_elems = 0
for contact_map in all_contacts:
    sum_contacts += np.sum(contact_map)
    total_elems += contact_map.shape[0] * contact_map.shape[1]

train_data, eval_data = sequence_similarity_split(random_baseline, split_path='../mmseq2/train_test_clusters.pkl', mode='dna')


def run_random_baseline(Cmaps, seed=None):
    """Score a random contact predictor that matches the observed contact rate.

    ``p_success`` is the fraction of positive cells across all maps in
    ``Cmaps``. Every cell of every map is replaced by an independent
    Bernoulli(``p_success``) draw, and the draw is scored against the
    original map.

    Metrics are computed per map and then averaged (macro average). By
    contrast, ``calculate_net_metrics`` pools all cells before scoring.

    Args:
        Cmaps: Iterable of 2-D contact maps (numpy arrays; presumably 0/1,
            TODO confirm). Materialised into a list, so generators work too.
        seed: Optional seed.
            - ``None`` (default): draws come from the global ``np.random``
              state, exactly as before, so ``np.random.seed`` still controls
              reproducibility.
            - Otherwise: an independent ``np.random.default_rng(seed)``
              generator is used.

    Returns:
        dict with keys ``'F1'``, ``'MCC'`` and ``'PR-AUC'``, each holding the
        mean of the per-map scores.

    Raises:
        ValueError: if ``Cmaps`` contains no cells (``p_success`` would be
            undefined).
    """
    # The maps are traversed twice (rate estimate, then scoring), so a
    # one-shot iterator must be materialised first.
    Cmaps = list(Cmaps)

    # Calculate p_success over all maps.
    total_positives = sum(np.sum(cmap) for cmap in Cmaps)
    total_elems = sum(cmap.shape[0] * cmap.shape[1] for cmap in Cmaps)
    if total_elems == 0:
        raise ValueError("run_random_baseline: Cmaps contains no cells")
    p_success = total_positives / total_elems

    # One flat Bernoulli draw for all cells, consumed map by map below.
    if seed is None:
        bernoulli = np.random.binomial(1, p_success, size=total_elems)
    else:
        bernoulli = np.random.default_rng(seed).binomial(1, p_success, size=total_elems)

    f1_scores = []
    mcc_scores = []
    pr_aucs = []

    offset = 0
    for cmap in Cmaps:
        # The next cmap.size draws form this map's random prediction. They are
        # compared flat, so no reshape round-trip is needed.
        random_flat = bernoulli[offset:offset + cmap.size]
        offset += cmap.size
        original_flat = cmap.flatten()

        f1_scores.append(f1_score(original_flat, random_flat))
        mcc_scores.append(matthews_corrcoef(original_flat, random_flat))
        precision, recall, _ = precision_recall_curve(original_flat, random_flat)
        pr_aucs.append(auc(recall, precision))

    return {
        'F1': np.mean(f1_scores),
        'MCC': np.mean(mcc_scores),
        'PR-AUC': np.mean(pr_aucs)
    }