In [1]:
import numpy as np
from helpers.helper import get_cath
from scipy.stats import ttest_ind
import matplotlib.pyplot as plt
cath = get_cath()


In [2]:
from Bio import pairwise2
from Bio.pairwise2 import format_alignment
from Bio.Seq import Seq
from Bio import SeqIO

In [3]:
def get_sword2(code, chain, version, verb=False):
    file = f"../data/sword2/final_results/results/{version}/{code}/{code}_{chain}/sword.txt"
    with open(file, "r") as f:
        data = {}
        lines = f.readlines()
        option = 0
        for i, line in enumerate(lines):
            lines[i] = "".join([c for c in line if c not in ["\n",'']])
            if line != "\n":
                if not line.startswith(("PDB:", "#D", "A")):
                    res = lines[i].split("|")
                    boundaries = res[2]
                    domains = boundaries.strip().split(" ")
                    data[f"option{option}"] = {}
                    for j in range(len(domains)):
                        data[f"option{option}"][str(j+1)] = domains[j]
                    option += 1
    return data


def get_af_chain(code, chain):
    file_path = f"../data/sword2/SWORD2/misc/new_iid/af_pdbs/AF-{code}-F1-model_v3.pdb"
    chains = {record.id: record.seq for record in SeqIO.parse(file_path, 'pdb-seqres')}
    seq = chains.get(f'XXXX:{chain}')
    return seq


def get_pdb_chain(code, chain):
    pdb_file_path = f"../data/pdb/new_iid/{code}.pdb"
    pdb_chains = {record.id: record.seq for record in SeqIO.parse(pdb_file_path, 'pdb-seqres')}

    for key in pdb_chains.keys():
        if key[-1] == chain:
            a_chain_pdb_seq = pdb_chains[key]
            return a_chain_pdb_seq
                
        
# def evaluate(pairs, metric_of_interest):
#     """
#     Return the mean metric after the evaluation
#     """
#     margin = 20
#     p = []
#     a = []
#     for (pdb, af) in pairs:
#         metric_index = metric_indices_dict[metric_of_interest]
#         chain_len = len(get_pdb_chain(pdb[:4], pdb[-1]))
#         baseline = cath[pdb[:4]][pdb[-1]]
#         baseline_boundaries = boundaries(chain_len, baseline, ',').astype(int)
#         if sum(baseline_boundaries) > 0:
#             pdb_sword = get_sword2(pdb[:-2], pdb[-1], 'pdb')
#             af_sword = get_sword2(af[:-2], af[-1], 'af')
#             pdb_metrics = get_best_sword_option(baseline_boundaries, pdb_sword, chain_len, margin, metric_index)
#             af_metrics = get_best_sword_option(baseline_boundaries, af_sword, chain_len, margin, metric_index)
#             a.append(af_metrics[metric_index])
#             p.append(pdb_metrics[metric_index])
        
#     return (np.mean(p), np.mean(a))

In [4]:
with open('../data/sword2/final_results/valid_pairs.txt', 'r') as f:
    valid_pairs = []
    for line in f.readlines():
        line = line.split(',')
        a = line[0].strip()
        b = line[1].strip()
        valid_pairs.append((a,b))

In [5]:
len(valid_pairs)

503

In [6]:
non_disc_pairs = set()
disc_pairs = set()

for (pdb, af) in valid_pairs:
    flag = False
    baseline = cath[pdb[:4]][pdb[-1]]
    for num, domain_range in baseline.items():
        if ',' in domain_range:
            disc_pairs.add((pdb, af))
            flag = True
    
    if not flag:
        non_disc_pairs.add((pdb, af))

In [7]:
len(non_disc_pairs) + len(disc_pairs)

503

In [8]:
def assign_domains(domains_dict, seq_len, deli):
    """
        Assign domains from a domains dictionary. Delimiter is a comma (',')
    """
    baseline_domains = [0] * seq_len
    for num, domain_range in domains_dict.items():
        doms = domain_range.split(deli)
        for elt in doms:
            elt = elt.split('-')
            start = int(elt[0])
            end = int(elt[1])
            for i in range(start-1, end):
                baseline_domains[i] = max(1,int(num))
    return baseline_domains


def boundaries(domain, len_seq, discontinuity_delimiter=','):
	"""
		Defines a boundary as the beginning of a domain ONLY in multi-domain proteins
	"""
	first_start = np.inf
	bounds = np.zeros((len_seq), dtype=np.int8)
	for k, v in domain.items():
		boundary_positions = v.split(discontinuity_delimiter)
		for b in boundary_positions:
			start, end = [int(i) for i in b.split('-')]
			if start < first_start:
				first_start = start
			bounds[start-1] = 1
	bounds[first_start-1] = 0            
	return np.array(bounds, dtype=np.bool_)

In [9]:
# l = [0,1,2,3,4,5,6,7,8,9,10]
# i,m = 5,3
# l[i-m:i+m+1]

In [10]:
x = np.array([True, False, True])
np.argwhere(x == 1).flatten()

array([0, 2])

In [11]:
def confusion_matrix_mcc(y_pred, y_true, margin):
    tp = 0
    tn = 0
    fp = 0
    fn = 0

    y_true_cp = np.copy(y_true)
    for i in range(len(y_pred)):
        window = y_true_cp[max(0, i-margin):min(len(y_true_cp), i+margin+1)]
        indices_window = list(range(max(0, i-margin), min(len(y_true_cp), i+margin+1)))
        if y_pred[i] == 1.0:
            if 1.0 in window:
                pos = np.where(window == 1.0)[0][0]
                j = indices_window[pos]
                y_true_cp[j] = 0.0
                tp += 1
            else:
                fp += 1

    for i in range(len(y_pred)):
        if y_pred[i] == 0.0:
            if  y_true_cp[i] == 1.0:
                fn += 1
            else:
                tn += 1

    return (tp, tn, fp, fn)

In [24]:
def ndo(y_pred, y_true):
    table = np.zeros((max(y_pred)+2,max(y_true)+2))
    
    for d_pred, d_true in zip(y_pred, y_true):
        table[d_pred, d_true] += 1

    # columns
    for i in range(1, table.shape[1]-1):
        table[-1, i] = 2 * max(table[:, i]) - sum(table[:, i])

    # rows
    for i in range(1, table.shape[0]-1):
        table[i, -1] = 2 * max(table[i, :]) - sum(table[i, :])

    table[-1, -1] = (sum(table[:, -1]) + sum(table[-1, :])) / 2


    num_of_defined_domains = 0
    for elt in y_true:
        if elt > 0:
            num_of_defined_domains += 1
    score = table[-1, -1] / num_of_defined_domains

    return score


def dbd(y_pred, y_true, margin=20):
    """We do NOT consider false positives. The rational is that DBD is a metric of
        how closes the true boundaries are from the prediction. Other metrics can be used
        to indicate if false positives were found or not."""
    scores = []
    number_of_true_boundaries = np.sum(y_true)
    number_of_pred_boundaries = np.sum(y_pred)
    denominator = max(number_of_true_boundaries,number_of_pred_boundaries)

    if denominator == 0:
        return 1.0

    y_true_cp = np.copy(y_true)
    for i in range(len(y_pred)):
        window = y_true_cp[max(0, i-margin):min(len(y_true_cp), i+margin+1)]
        indices_window = list(range(max(0, i-margin), min(len(y_true_cp), i+margin+1)))
        if y_pred[i] == 1.0:
            if 1.0 in window:
                # if it's within the window, calculate the score
                pos = np.where(window == 1.0)[0][0]
                j = indices_window[pos]
                diff = abs(i - j)
                # some ambiguity exists regarding giving a point if the boundary
                # is EXACTLY at the position of the true boundary. We assume
                # we do and divide by (margin + 1)
                score = ((margin - diff) + 1) / (margin + 1)
                y_true_cp[j] = 0
                scores.append(score)
    return np.sum(scores) / denominator

def dbd_var(y_pred, y_true, margin=20):
    """A variation of DBD above that instead of finding a 1 in the window
        we find the closest one. We expect this would increase the mean dbd"""
    scores = []
    number_of_true_boundaries = np.sum(y_true)
    number_of_pred_boundaries = np.sum(y_pred)
    denominator = max(number_of_true_boundaries,number_of_pred_boundaries)

    if denominator == 0:
        return 'n/a'

    y_true_cp = np.copy(y_true)
    for i in range(len(y_pred)):
        window = y_true_cp[max(0, i-margin):min(len(y_true_cp), i+margin+1)]
        indices_window = list(range(max(0, i-margin), min(len(y_true_cp), i+margin+1)))
        if y_pred[i] == 1.0:
            if 1.0 in window:
                positions_with_boundaries = np.argwhere(window == 1).flatten()
                js = [indices_window[pos] for pos in positions_with_boundaries]
                closest_j = js[0]
                for j in js:
                    if abs(j - i) < abs(closest_j - i):
                        closest_j = j
                diff = abs(i - closest_j)
                score = ((margin - diff) + 1) / (margin + 1)
                y_true_cp[closest_j] = 0
                scores.append(score)
    return np.sum(scores) / denominator


def mcc(y_pred, y_true, margin=20):
    tp, tn, fp, fn = confusion_matrix_mcc(y_pred, y_true, margin)
    mcc_num = (tp * tn) - (fp * fn)
    mcc_den = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = mcc_num / mcc_den if mcc_den else 0
    return mcc

In [32]:
def best_option_for_ndo(baseline_domain, sword_results_dict, seq_len):
    ndos = []
    for option, domain in sword_results_dict.items():
        assigned_domains = assign_domains(domain, seq_len, ';')
        ndo_score = ndo(assigned_domains, baseline_domain)
        if ndo_score > 1.5:
            print(domain)
        ndos.append(ndo_score)
    return max(ndos)


def best_option_for_dbd(baseline_boundaries, sword_results_dict, seq_len, margin=8):
    dbds = []
    for option, domain in sword_results_dict.items():
        bounds = boundaries(domain, seq_len, ';')
        dbd_score = dbd_var(bounds, baseline_boundaries, margin)
        if dbd_score != 'n/a':
            dbds.append(dbd_score)

    if len(dbds) == 0:
        return 'n/a'
    return max(dbds)

def best_option_for_mcc(baseline_boundaries, sword_results_dict, seq_len, margin=20):
    mccs = []
    for option, domain in sword_results_dict.items():
        bounds = boundaries(domain, seq_len, ';')
        mcc_score = mcc(bounds, baseline_boundaries, margin)
        mccs.append(mcc_score)
    return max(mccs)

# Boundary prediction

## NDO

In [14]:
pdb_ndos = []
af_ndos = []
for (pdb, af) in non_disc_pairs:
    chain_len = len(get_pdb_chain(pdb[:4], pdb[-1]))
    baseline = cath[pdb[:4]][pdb[-1]]
    baseline_domains = assign_domains(baseline, chain_len, ',')
    pdb_sword = get_sword2(pdb[:-2], pdb[-1], 'pdb')
    af_sword = get_sword2(af[:-2], af[-1], 'af')
    pdb_ndo = best_option_for_ndo(baseline_domains,pdb_sword, chain_len)
    af_ndo = best_option_for_ndo(baseline_domains,af_sword, chain_len)
    pdb_ndos.append(pdb_ndo)
    af_ndos.append(af_ndo)

print("Mean NDO for AlphaFold/SWORD:", np.mean(af_ndos))
print("Mean NDO for PDB/SWORD:", np.mean(pdb_ndos))

{'1': '1-131', '2': '132-244', '3': '245-330', '4': '331-413', '5': '414-583'}
{'1': '1-131', '2': '132-244', '3': '245-330', '4': '331-413', '5': '414-511', '6': '512-583'}
{'1': '1-131', '2': '132-244', '3': '245-330', '4': '331-511', '5': '512-583'}
{'1': '1-131', '2': '132-244', '3': '245-330', '4': '331-583'}
{'1': '1-131', '2': '132-244', '3': '245-413', '4': '414-583'}
{'1': '1-131', '2': '132-244', '3': '245-583'}
{'1': '1-352', '2': '353-650', '3': '651-747', '4': '748-1188'}
{'1': '1-35;353-650', '2': '36-352', '3': '651-747', '4': '748-1188'}
Mean NDO for AlphaFold/SWORD: 0.9171428995310758
Mean NDO for PDB/SWORD: 0.8268182539319743


## DBD / DBD Var

In [34]:
pdb_dbds = []
af_dbds = []
for (pdb, af) in valid_pairs:
    chain_len = len(get_pdb_chain(pdb[:4], pdb[-1]))
    baseline = cath[pdb[:4]][pdb[-1]]
    baseline_boundaries = boundaries(baseline, chain_len, ',')
    pdb_sword = get_sword2(pdb[:-2], pdb[-1], 'pdb')
    af_sword = get_sword2(af[:-2], af[-1], 'af')
    pdb_dbd = best_option_for_dbd(baseline_boundaries,pdb_sword, chain_len)
    af_dbd = best_option_for_dbd(baseline_boundaries,af_sword, chain_len)

    if pdb_dbd != 'n/a' and af_dbd != 'n/a':
        pdb_dbds.append(pdb_dbd)
        af_dbds.append(af_dbd)

print("Mean DBD for AlphaFold/SWORD:", np.mean(af_dbds))
print("Mean DBD for PDB/SWORD:", np.mean(pdb_dbds))

Mean DBD for AlphaFold/SWORD: 0.3173435471054519
Mean DBD for PDB/SWORD: 0.1688398195440449


In [25]:
pdb_dbds = []
af_dbds = []
for (pdb, af) in valid_pairs:
    chain_len = len(get_pdb_chain(pdb[:4], pdb[-1]))
    baseline = cath[pdb[:4]][pdb[-1]]
    baseline_boundaries = boundaries(baseline, chain_len, ',')
    pdb_sword = get_sword2(pdb[:-2], pdb[-1], 'pdb')
    af_sword = get_sword2(af[:-2], af[-1], 'af')
    pdb_dbd = best_option_for_dbd(baseline_boundaries,pdb_sword, chain_len)
    af_dbd = best_option_for_dbd(baseline_boundaries,af_sword, chain_len)
    pdb_dbds.append(pdb_dbd)
    af_dbds.append(af_dbd)

print("Mean DBD for AlphaFold/SWORD:", np.mean(af_dbds))
print("Mean DBD for PDB/SWORD:", np.mean(pdb_dbds))

TypeError: '>' not supported between instances of 'numpy.ndarray' and 'str'

## MCC

In [17]:
# 
pdb_mccs = []
af_mccs = []
for (pdb, af) in disc_pairs:
    chain_len = len(get_pdb_chain(pdb[:4], pdb[-1]))
    baseline = cath[pdb[:4]][pdb[-1]]
    baseline_boundaries = boundaries(baseline, chain_len, ',')
    pdb_sword = get_sword2(pdb[:-2], pdb[-1], 'pdb')
    af_sword = get_sword2(af[:-2], af[-1], 'af')
    pdb_mcc = best_option_for_mcc(baseline_boundaries,pdb_sword, chain_len)
    af_mcc = best_option_for_mcc(baseline_boundaries,af_sword, chain_len)
    pdb_mccs.append(pdb_mcc)
    af_mccs.append(af_mcc)

print("Mean MCC for AlphaFold/SWORD:", np.mean(af_mccs))
print("Mean MCC for PDB/SWORD:", np.mean(pdb_mccs))

Mean MCC for AlphaFold/SWORD: 0.8388908544978673
Mean MCC for PDB/SWORD: 0.6642438681646838


In [18]:
def get_type_of_domain(domain):
    if len(domain) > 1:
        return 'multi'
    
    if len(domain) == 1:
            return 'single'
        
    raise ValueError


def options_best_type(sword_results_dict, truth):
    dbds = []
    for option, domain in sword_results_dict.items():
        domain_type = get_type_of_domain(domain)
        if domain_type == truth:
             return domain_type
    if truth == 'single':
         return 'multi'
    
    if truth == 'multi':
         return 'single'
    
    raise ValueError

def domain_number_metrics(confusion_matrix):
    ts, tm, fs, fm = confusion_matrix
    # single
    pre_single = ts / (ts + fs) if (ts + fs) != 0 else 0
    rec_single = ts / (ts + fm) if (ts + fm) != 0 else 0
    
    # multi
    pre_multi = tm / (tm + fm) if (tm + fm) != 0 else 0
    rec_multi = tm / (tm + fs) if (tm + fs) != 0 else 0

    acc = (tm + ts) / (tm + ts + fm + fs)

    # mcc
    mcc_num = (tm * ts) - (fm * fs)
    mcc_den = ((tm + fm) * (tm + fs) * (fm + ts) * (ts + fs))**0.5
    mcc = mcc_num / mcc_den if mcc_den != 0 else 0

    return (pre_single, rec_single, pre_multi, rec_multi, acc, mcc)

# Domain number prediction

assumption: confusion matrix is generated from both multi and single domains

In [19]:
# CONFUSION MATRIX FOR NUMBER OF DOMAINS
af_ts, af_tm, af_fs, af_fm = 0, 0, 0, 0
pdb_ts, pdb_tm, pdb_fs, pdb_fm = 0, 0, 0, 0

for (pdb, af) in disc_pairs:
    chain_len = len(get_pdb_chain(pdb[:4], pdb[-1]))
    baseline = cath[pdb[:4]][pdb[-1]]
    baseline_type = get_type_of_domain(baseline)
    print(baseline_type)
    pdb_sword = get_sword2(pdb[:-2], pdb[-1], 'pdb')
    af_sword = get_sword2(af[:-2], af[-1], 'af')
    pdb_type = options_best_type(pdb_sword, baseline_type) # this is a list of options
    af_type = options_best_type(af_sword, baseline_type)
    
    if baseline_type == 'single' and pdb_type == 'single':
        pdb_ts += 1
        
    elif baseline_type == 'multi' and pdb_type == 'multi':
        pdb_tm += 1
        
    elif baseline_type == 'multi' and pdb_type == 'single':
        pdb_fs += 1
        
    elif baseline_type == 'single' and pdb_type == 'multi':
        pdb_fm += 1
        

    if baseline_type == 'single' and af_type == 'single':
        af_ts += 1
        
    elif baseline_type == 'multi' and af_type == 'multi':
        af_tm += 1
        
    elif baseline_type == 'multi' and af_type == 'single':
        af_fs += 1
        
    elif baseline_type == 'single' and af_type == 'multi':
        af_fm += 1

multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi
multi


In [20]:
af_pre_single, af_rec_single, af_pre_multi, af_rec_multi, af_acc, af_mcc = domain_number_metrics((af_ts, af_tm, af_fs, af_fm))
print("Pre(single):", af_pre_single)
print("Rec(single):", af_rec_single)
print("Pre(multi):", af_pre_multi)
print("Rec(multi):", af_rec_multi)
print("Acc:", af_acc)
print("MCC:", af_mcc)

Pre(single): 0
Rec(single): 0
Pre(multi): 1.0
Rec(multi): 1.0
Acc: 1.0
MCC: 0


In [21]:
pdb_pre_single, pdb_rec_single, pdb_pre_multi, pdb_rec_multi, pdb_acc, pdb_mcc = domain_number_metrics((pdb_ts, pdb_tm, pdb_fs, pdb_fm))
print("Pre(single):", pdb_pre_single)
print("Rec(single):", pdb_rec_single)
print("Pre(multi):", pdb_pre_multi)
print("Rec(multi):", pdb_rec_multi)
print("Acc:", pdb_acc)
print("MCC:", pdb_mcc)

Pre(single): 0
Rec(single): 0
Pre(multi): 1.0
Rec(multi): 1.0
Acc: 1.0
MCC: 0


In [23]:
disc_pairs

{('1al3:A', 'P45600:A'),
 ('1dce:A', 'Q08602:A'),
 ('1dgj:A', 'Q9REC4:A'),
 ('1dqu:A', 'P28298:A'),
 ('1fhu:A', 'P29208:A'),
 ('1fnn:A', 'Q8ZYK1:A'),
 ('1gm5:A', 'Q9WY48:A'),
 ('1gsh:A', 'P04425:A'),
 ('1h2h:A', 'Q9X1X6:A'),
 ('1h37:A', 'P33247:A'),
 ('1i5p:A', 'P0A377:A'),
 ('1in0:A', 'P44096:A'),
 ('1jey:A', 'P12956:A'),
 ('1k6i:A', 'Q5AU62:A'),
 ('1kea:A', 'P29588:A'),
 ('1keu:A', 'P26391:A'),
 ('1l5j:A', 'P36683:A'),
 ('1lfp:A', 'O67517:A'),
 ('1lns:A', 'P22346:A'),
 ('1m0u:A', 'P41043:A'),
 ('1nyq:A', 'Q8NW68:A'),
 ('1peq:A', 'Q08698:A'),
 ('1pj5:A', 'Q9AGP8:A'),
 ('1q7g:A', 'P31116:A'),
 ('1q9j:A', 'P9WIN5:A'),
 ('1rqg:A', 'Q9V011:A'),
 ('1s7j:A', 'Q839P3:A'),
 ('1t6j:A', 'P11544:A'),
 ('1tdj:A', 'P04968:A'),
 ('1tmo:A', 'O87948:A'),
 ('1u2s:A', 'P74325:A'),
 ('1u7l:A', 'P31412:A'),
 ('1w0o:A', 'P0C6E9:A'),
 ('1w6j:A', 'P48449:A'),
 ('1wqg:A', 'P9WGY1:A'),
 ('1wxq:A', 'O58261:A'),
 ('1wyv:A', 'Q5SKW8:A'),
 ('1wz2:A', 'O58698:A'),
 ('1x3l:A', 'O58231:A'),
 ('1xqp:A', 'Q8ZVK6:A'),
