In [1]:
import os
from pprint import pprint
import json
import random
import numpy as np
from helpers.helper import get_cath

from Bio import pairwise2
from Bio.pairwise2 import format_alignment
from Bio.Seq import Seq
from Bio import SeqIO

import requests
import shutil
from Bio import pairwise2
from Bio.pairwise2 import format_alignment
from Bio.Seq import Seq
from Bio import SeqIO


from scipy.stats import ttest_ind
from pprint import pprint

cath = get_cath()
metric_indices_dict = {
    'acc' : 0,
    'pre' : 1,
    'rec' : 2,
    'f1' : 3,
    'mcc' : 4,
    'dbd' : 5
}


In [2]:
def get_sword2(code, chain, version, verb=False):
    file = f"../data/sword2/final_results/results/{version}/{code}/{code}_{chain}/sword.txt"
    with open(file, "r") as f:
        data = {}
        lines = f.readlines()
        option = 0
        for i, line in enumerate(lines):
            lines[i] = "".join([c for c in line if c not in ["\n",'']])
            if line != "\n":
                if not line.startswith(("PDB:", "#D", "A")):
                    res = lines[i].split("|")
                    boundaries = res[2]
                    domains = boundaries.strip().split(" ")
                    data[f"option{option}"] = {}
                    for j in range(len(domains)):
                        data[f"option{option}"][str(j+1)] = domains[j]
                    option += 1
    verb and pprint(data)
    return data


def sequence_sim(seq1, seq2, match_score = 1, mismatch_score = -1, gap_penalty = -2):
    alignments = pairwise2.align.globalxx(seq1, seq2)

    # Print the alignment(s)
    # for alignment in alignments:
        # print(format_alignment(*alignment))
    score = alignments[0].score
    norm_score = score / max(len(seq1), len(seq2))
    return norm_score



def dbd_score(y_pred, y_true, margin=20):
    scores = []
    for i in range(len(y_pred)):
        window = y_true[max(0, i-margin):min(len(y_true), i+margin+1)]
        indices_window = list(range(max(0, i-margin), min(len(y_true), i+margin+1)))
        if y_pred[i] == 1.0:
            if 1.0 in window:
                # if it's within the window, calculate the score
                pos = np.where(window == 1.0)[0][0]
                j = indices_window[pos]
                diff = abs(i - j)
                k = 0 if diff == 0 else 1
                score = ((margin - diff) + k) / margin
            else:
                # false positive
                score = 0
            scores.append(score)

    number_of_true_boundaries = np.sum(y_true)
    number_of_pred_boundaries = np.sum(y_pred)
    max_len = max(number_of_true_boundaries,number_of_pred_boundaries)
    if max_len == 0:
        return 1.0

    return np.sum(scores) / max_len


def observations(y_pred, y_true, margin):
    tp = 0
    tn = 0
    fp = 0
    fn = 0

    dbd = dbd_score(y_pred, y_true, margin)
    for i in range(len(y_pred)):
        window = y_true[max(0, i-margin):min(len(y_true), i+margin+1)]
        indices_window = list(range(max(0, i-margin), min(len(y_true), i+margin+1)))
        if y_pred[i] == 1.0:
            if 1.0 in window:
                pos = np.where(window == 1.0)[0][0]
                j = indices_window[pos]
                y_true[j] = 0.0
                tp += 1
            else:
                fp += 1

    for i in range(len(y_pred)):
        if y_pred[i] == 0.0:
            if  y_true[i] == 1.0:
                fn += 1
            else:
                tn += 1

    return (tp, tn, fp, fn)


def metrics(y_pred, y_true, margin=20):
    tp, tn, fp, fn = observations(y_pred, y_true, margin)

    accuracy = (tn + tp) / (tn + tp + fn + fp) if (tn + tp + fn + fp) else 0
    precision = tp / (tp + fp) if (tp + fp) else 0
    recall = tp / (tp + fn) if (tp + fn) else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) else 0

    mcc_num = (tp * tn) - (fp * fn)
    mcc_den = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = mcc_num / mcc_den if mcc_den else 0

    dbd = dbd_score(y_pred, y_true, margin)

    return (accuracy, precision, recall, f1, mcc, dbd)


def boundaries(len_seq, domain, discontinuity_delimiter):
    """
        Defines a boundary as the beginning of a domain ONLY in multi-domain proteins
    """
    first_start = np.inf
    bounds = np.zeros((len_seq), dtype=np.int8)
    for k, v in domain.items():
        boundary_positions = v.split(discontinuity_delimiter)
        for b in boundary_positions:
            start, end = [int(i) for i in b.split('-')]
            if start < first_start:
                first_start = start
            bounds[start-1] = 1
    bounds[first_start-1] = 0            
    return np.array(bounds, dtype=np.bool_)


def get_af_chain(code, chain):
    file_path = f"../data/sword2/SWORD2/misc/new_iid/af_pdbs/AF-{code}-F1-model_v3.pdb"
    chains = {record.id: record.seq for record in SeqIO.parse(file_path, 'pdb-seqres')}
    seq = chains.get(f'XXXX:{chain}')
    return seq


def get_pdb_chain(code, chain):
    pdb_file_path = f"../data/pdb/new_iid/{code}.pdb"
    pdb_chains = {record.id: record.seq for record in SeqIO.parse(pdb_file_path, 'pdb-seqres')}

    for key in pdb_chains.keys():
        if key[-1] == chain:
            a_chain_pdb_seq = pdb_chains[key]
            return a_chain_pdb_seq
                
        
# def get_num_of_domains(code)


def get_best_sword_option(true_boundaries, sword_results_dict, seq_len, margin, metric_index):
    mccs = []
    results = []
    for option, domain in sword_results_dict.items():
        baseline_boundaries = boundaries(seq_len, domain, ';').astype(int)
        metrics_results = metrics(baseline_boundaries, true_boundaries, margin)
        mcc = metrics_results[metric_index]
        mccs.append(mcc)
        results.append(metrics_results)
    return results[np.argmax(mccs)]

In [3]:
with open('../data/sword2/SWORD2/misc/new_iid/pdb_to_uniprot_map.json') as json_file:
    pdb_to_uni_map = json.load(json_file)
    
pdb_to_af_map = {}
for elt in pdb_to_uni_map['results']:
    pdb = elt['from']
    databases = [db['database'] for db in elt['to']['uniProtKBCrossReferences']]
    for i, db in enumerate(databases):
        if db == 'AlphaFoldDB':
            af_id = (elt['to']['uniProtKBCrossReferences'][i]['id'])
            if pdb_to_af_map.get(pdb):
                pdb_to_af_map[pdb].append(af_id)
            else:
                pdb_to_af_map[pdb] = [af_id]
    

In [4]:
with open('../data/cath/iid/chains_to_seq_iid.json') as json_file:
    chain_to_seq_iid = json.load(json_file)

valid_pairs = []
for chain, seq in chain_to_seq_iid.items():
    if pdb_to_af_map.get(chain[:4]):
        pdb_filename = f"../data/pdb/new_iid/{chain[:4]}.pdb"
        chain_id = chain[-1]
        with open(pdb_filename, "r") as pdb_file:
            pdb_chains = {record.id: record.seq for record in SeqIO.parse(pdb_file, 'pdb-seqres')}
            for key in pdb_chains.keys():
                if key[-1] == chain[-1]:
                    pdb_seq = pdb_chains[key]
        af_seqs = []
        for uniprot in pdb_to_af_map.get(chain[:4]):
            af_seq = get_af_chain(uniprot, chain[-1])
            if af_seq:
                af_seqs.append((uniprot, af_seq))
        for (uniprot, af_seq) in af_seqs:
            if len(af_seq) == len(pdb_seq):
                sim = sequence_sim(af_seq, pdb_seq)
                if sim == 1.0:
                    valid_pairs.append((chain, uniprot + f':{chain[-1]}'))

In [5]:
valid_pairs[0]

('1a2k:A', 'P61972:A')

In [6]:
len(valid_pairs)

508

In [7]:
for (pdb, af) in valid_pairs:
    pdb_seq = get_pdb_chain(pdb[:-2], pdb[-1])
    af_seq = get_af_chain(af[:-2], af[-1])
    assert pdb_seq == af_seq
print("Pairs' order match")
    

Pairs' order match


In [8]:
path = '../data/sword2/final_results/results/'
for (pdb, af) in valid_pairs:
    try:
        pdb_sword = get_sword2(pdb[:-2], pdb[-1], 'pdb')
    except FileNotFoundError:
        print(f"PDB File {pdb} does not exist")
    
    try:
        af_sword = get_sword2(af[:-2], af[-1], 'af')
    except FileNotFoundError:
        print(f"PDB File {af} does not exist")



PDB File 2h47:A does not exist
PDB File P13479:A does not exist
PDB File 2w02:A does not exist
PDB File B5THI3:A does not exist
PDB File 4wl9:A does not exist


In [11]:
for i, (pdb, af) in enumerate(valid_pairs):
    if pdb == "2h47:A":
        print("2h47:A and", af)

    if pdb == "2w02:A":
        print("2w02:A  and", af)

    if pdb == "4wl9:A":
        print("4wl9:A  and", af)
        
    if af == "P13479:A":
        print("P13479:A and", pdb)

    if af == "B5THI3:A":
        print("B5THI3:A and", pdb)

2h47:A and P84888:A
P13479:A and 2vlp:A
2w02:A  and Q93AT8:A
B5THI3:A and 4rej:A
4wl9:A  and P16113:A


In [12]:
for i, (pdb, af) in enumerate(valid_pairs):
    if pdb in ["2h47:A", "2w02:A", "4wl9:A"]:
        del valid_pairs[i]
        
    if af in ["P13479:A", "B5THI3:A"]:
        del valid_pairs[i]
        
for i, (pdb, af) in enumerate(valid_pairs):
    assert pdb != "2h47:A" and af != "P84888:A"
    assert af != "P13479:A" and pdb != "2vlp:A"
    assert af != "2w02:A" and pdb != "Q93AT8:A"
    assert af != "B5THI3:A" and pdb != "4rej:A"
    assert af != "4wl9:A" and pdb != "P16113:A"

print("Go ahead and evaluate!")


Go ahead and evaluate!


In [15]:
len(valid_pairs)

503

In [13]:
def evaluate(metric_of_interest):
    """
    Return the mean metric after the evaluation
    """
    margin = 20
    p = []
    a = []
    for (pdb, af) in valid_pairs:
        metric_index = metric_indices_dict[metric_of_interest]
        chain_len = len(get_pdb_chain(pdb[:4], pdb[-1]))
        baseline = cath[pdb[:4]][pdb[-1]]
        baseline_boundaries = boundaries(chain_len, baseline, ',').astype(int)
        if sum(baseline_boundaries) > 0:
            pdb_sword = get_sword2(pdb[:-2], pdb[-1], 'pdb')
            af_sword = get_sword2(af[:-2], af[-1], 'af')
            pdb_metrics = get_best_sword_option(baseline_boundaries, pdb_sword, chain_len, margin, metric_index)
            af_metrics = get_best_sword_option(baseline_boundaries, af_sword, chain_len, margin, metric_index)
            a.append(af_metrics[metric_index])
            p.append(pdb_metrics[metric_index])
        
    return (np.mean(p), np.mean(a))
        

In [16]:
evaluate('mcc')

(0.5794446057895699, 0.19707455073244642)

In [95]:
evaluate('acc')

(0.996145081302511, 0.9965586647401254)