In [26]:
import os
from pprint import pprint
import json
import random
import numpy as np
from helpers.helper import get_cath

from Bio import pairwise2
from Bio.pairwise2 import format_alignment
from Bio.Seq import Seq
from Bio import SeqIO

import requests
from requests.packages.urllib3.util.retry import Retry
from requests.adapters import HTTPAdapter
import shutil

from scipy.stats import ttest_ind


In [None]:
def get_sword2(code, version, verb=False):
    """Parse the SWORD2 result file written for chain A of one structure.

    Reads ``../data/sword2/SWORD2/results/{version}/{code}/{code}_A/sword.txt``.
    Blank lines and lines starting with ``PDB:``, ``#D`` or ``A`` are skipped.
    Every other line is split on ``|`` and its third field is read as a
    whitespace-separated list of domain delineation strings.

    Args:
        code: Identifier used for the result directory names.
        version: Sub-directory of ``results`` (this notebook passes ``'af'``
            or ``'pdb'``).
        verb: When true, pretty-print the parsed dictionary before returning.

    Returns:
        dict: ``{"option0": {"1": <str>, "2": <str>, ...}, "option1": ...}``
        with options numbered in file order and domains numbered from 1.
        The delineation strings are returned unparsed.

    Raises:
        FileNotFoundError: If the result file does not exist.
        IndexError: If a data line has fewer than three ``|``-separated fields.
    """
    file = f"../data/sword2/SWORD2/results/{version}/{code}/{code}_A/sword.txt"
    data = {}
    with open(file, "r") as f:
        for raw_line in f:
            line = raw_line.rstrip("\n")
            # Whitespace-only lines carry no record; skipping them avoids an
            # IndexError on the field lookup below.
            if not line.strip():
                continue
            if line.startswith(("PDB:", "#D", "A")):
                continue
            fields = line.split("|")
            # split() without an argument collapses repeated spaces, so no
            # empty domain strings are produced.
            domains = fields[2].split()
            data[f"option{len(data)}"] = {
                str(number): domain
                for number, domain in enumerate(domains, start=1)
            }
    if verb:
        pprint(data)
    return data

In [None]:
# Load the CATH annotations through the project helper. Later cells index the
# result as cath[pdb][chain] (e.g. cath[pdb]['A']).
cath = get_cath()

In [None]:
# Load the compact mapping file. The commented-out cell that produced it
# stores PDB id -> UniProt primary accession.
with open('../data/sword2/SWORD2/misc/mappings_compact.json') as json_file:
    pdb_uniprot_mappings = json.load(json_file)

In [None]:
len(pdb_uniprot_mappings)

In [None]:
# pdb_uniprot_mappings_reverse = {}


# for k, v in pdb_uniprot_mappings.items():
# 	pdb_uniprot_mappings_reverse[v] = k

In [None]:
# # write the reverse to a file for reproducibility
# with open('../data/sword2/SWORD2/misc/reverse_mappings_compact.json', "w") as json_file:
# 	json.dump(pdb_uniprot_mappings_reverse, json_file)

In [None]:
# Read the reversed mapping from disk. It is looked up by UniProt accession
# and yields a PDB id in the cells below.
with open('../data/sword2/SWORD2/misc/reverse_mappings_compact.json') as json_file:
    pdb_uniprot_mappings_reverse = json.load(json_file)

In [None]:
"""
After creating a dictionary which removes the many-to-many mapping we get this many mappings:
"""
# Only the last expression of a cell is displayed, so show both sizes at once
# as (forward mapping, reverse mapping).
len(pdb_uniprot_mappings), len(pdb_uniprot_mappings_reverse)

In [None]:
# Load the UniProt query result; the next code cell reads its 'results' list.
with open('../data/alpha/uniprot/uniprot_alpha.json') as json_file:
    alpha_uniprot = json.load(json_file)

In [None]:
# alphafold_uniprots = []
# not_alphafold_uniprots = []

# sword_usable_ids = []
# for id in alphafold_uniprots:
#     if id in pdb_uniprot_mappings.values():
#         sword_usable_ids.append(id)

# print("Usable IDs that have an AlphaFold prediction:", len(sword_usable_ids))

In [None]:
# Collect the UniProt accessions that (a) cross-reference AlphaFoldDB and
# (b) have an entry in the reverse (UniProt -> PDB) mapping.
res = alpha_uniprot['results']

sword_usable_ids = []
for elt in res:
    dbs = {db['database'] for db in elt['uniProtKBCrossReferences']}
    id = elt['primaryAccession']

    if 'AlphaFoldDB' in dbs and id in pdb_uniprot_mappings_reverse:
        sword_usable_ids.append(id)


print("Usable IDs that have an AlphaFold prediction:", len(sword_usable_ids))

In [None]:
# pprint(res[0])

In [None]:
# len(res)
# random.choice(alphafold_uniprots)
# random.choice(not_alphafold_uniprots)

In [None]:
# write the pdb to uniprot dict because the file downloaded was 1.1gb
# compact_mappings = {}

# for elt in data['results']:
#     pdb = elt['from']
#     uniprot = elt['to']['primaryAccession']
#     compact_mappings[pdb] = uniprot

In [None]:
# with open('../data/sword2/SWORD2/mappings_compact.json', 'w') as fp:
#     json.dump(compact_mappings, fp)

In [None]:
"""
Check which IDs from the generated mappings have an AlphaFold Prediction using alphafold_uniprots
"""
# sword_usable_ids = []
# for id in alphafold_uniprots:
#     if id in pdb_uniprot_mappings.values():
#         sword_usable_ids.append(id)

print("Usable IDs that have an AlphaFold prediction:", len(sword_usable_ids))


In [None]:
len_sword_usable_ids_before = len(sword_usable_ids)

In [None]:
# for k, v in cath.items():
# 	if "a" in list(v.keys()):
# 		print(k, v)
# 		break

In [None]:
print(len(pdb_uniprot_mappings_reverse))

In [None]:
"""Some PDB's (from RCSB) do not have an A chain"""
# Drop every usable id whose PDB entry has no chain "A" in CATH, and remove it
# from the reverse mapping as well. The survivors are collected in a new list:
# removing from a list while iterating over it skips the element that follows
# each removal.
counter = 0
kept_ids = []
for id in sword_usable_ids:
    pdb = pdb_uniprot_mappings_reverse[id]
    if "A" in cath[pdb]:
        kept_ids.append(id)
    else:
        counter += 1
        del pdb_uniprot_mappings_reverse[id]
# Slice assignment keeps the same list object for anything still holding it.
sword_usable_ids[:] = kept_ids

In [None]:
print(len_sword_usable_ids_before)
print(counter)
print(len(sword_usable_ids))
print(len(pdb_uniprot_mappings_reverse))

In [None]:
print(counter)
assert counter == len_sword_usable_ids_before - len(sword_usable_ids)

In [None]:
# Write the usable UniProt accessions, one per line and without a trailing
# newline. An empty list produces an empty file instead of an IndexError.
with open("../data/sword2/SWORD2/misc/alphafold_dataset_overlap.txt", "w") as f:
    f.write("\n".join(sword_usable_ids))

In [None]:
def download_af_model(id, timeout=30):
    """Download the AlphaFold v3 model (fragment F1) for a UniProt accession.

    The file is fetched from https://alphafold.ebi.ac.uk/files/ and stored as
    ``../data/sword2/SWORD2/misc/af_pdbs/AF-{id}-F1-model_v3.pdb``.

    Args:
        id: UniProt accession used to build the model name.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        tuple: ``(True, path)`` with the path of the written PDB file, or
        ``(False, exception)`` when the request failed or the server replied
        with an HTTP error status (e.g. 404 for an unknown accession).
    """
    name = f"AF-{id}-F1-model_v3"
    url = f"https://alphafold.ebi.ac.uk/files/{name}.pdb"
    out_path = f"../data/sword2/SWORD2/misc/af_pdbs/{name}.pdb"
    try:
        response = requests_retry_session().get(url, timeout=timeout)
        # Without this check an HTML error page would be saved as a ".pdb" file.
        response.raise_for_status()
    except Exception as x:
        return (False, x)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with open(out_path, "w") as f:
        f.write(response.text)
    return (True, out_path)

def requests_retry_session(retries=3,
                           backoff_factor=0.3,
                           status_forcelist=(500, 502, 504),
                           session=None):
    """Return a requests session that retries failed HTTP(S) requests.

    Args:
        retries: Value used for the total, read and connect retry limits.
        backoff_factor: Back-off factor handed to the retry policy.
        status_forcelist: HTTP status codes that trigger a retry.
        session: Existing session to configure; a new one is created when
            this is falsy.

    Returns:
        The session, with one retrying adapter mounted for both schemes.
    """
    if not session:
        session = requests.Session()
    retry_policy = Retry(
        total=retries,
        read=retries,
        connect=retries,
        backoff_factor=backoff_factor,
        status_forcelist=status_forcelist,
    )
    # A single adapter instance is shared by both URL schemes.
    retrying_adapter = HTTPAdapter(max_retries=retry_policy)
    for scheme in ('http://', 'https://'):
        session.mount(scheme, retrying_adapter)
    return session

In [None]:
# download all the usable ones

# for id in sword_usable_ids:
# 	download_af_model(id)

In [None]:
# create a very small file to test the pipeline that will run on the remote machine
# with open("../data/sword2/SWORD2/misc/TEST_alphafold_dataset_overlap.txt", "w") as f:
#     comma_sep_ids = [x + '\n' for x in sword_usable_ids[:3]]
#     comma_sep_ids.append(sword_usable_ids[3])
#     f.writelines(comma_sep_ids)

In [None]:
def sequence_sim(seq1, seq2, match_score = 1, mismatch_score = -1, gap_penalty = -2):
    """Return the global-alignment score of two sequences, normalised to [0, 1].

    The score comes from ``pairwise2.align.globalxx`` and is divided by the
    length of the longer sequence.

    Args:
        seq1: First sequence (``str`` or ``Bio.Seq.Seq``).
        seq2: Second sequence.
        match_score, mismatch_score, gap_penalty: Accepted for backward
            compatibility only. They are NOT used, because ``globalxx`` has
            its own fixed scoring scheme.

    Returns:
        float: Normalised score, or ``0.0`` when either sequence is empty.
    """
    # Nothing can be aligned against an empty sequence; this also avoids a
    # division by zero below.
    if len(seq1) == 0 or len(seq2) == 0:
        return 0.0

    # score_only=True returns just the best score instead of materialising
    # every optimal alignment, which is all the original code read anyway.
    score = pairwise2.align.globalxx(seq1, seq2, score_only=True)
    return score / max(len(seq1), len(seq2))

In [None]:
# Keep the UniProt ids whose AlphaFold chain A sequence is identical to the
# chain A SEQRES sequence of the mapped PDB entry.
counter = 0      # number of pairs with equal sequence length
perf_score = []  # ids whose normalised alignment score is exactly 1
missmatch = []   # never filled; kept so the name stays defined
for id in sword_usable_ids:
    af_pdb_file_path = f"../data/sword2/SWORD2/misc/af_pdbs/AF-{id}-F1-model_v3.pdb"
    af_chains = {record.id: record.seq for record in SeqIO.parse(af_pdb_file_path, 'pdb-seqres')}
    a_chain_uniprot_seq = af_chains['XXXX:A']

    pdb = pdb_uniprot_mappings_reverse[id]
    pdb_file_path = f"../data/pdb/bulk/balanced/backup/data/{pdb}.pdb"
    pdb_chains = {record.id: record.seq for record in SeqIO.parse(pdb_file_path, 'pdb-seqres')}

    # First record whose id ends in 'A'. Looking it up fresh for every entry
    # prevents reusing the previous protein's sequence when none is present.
    a_chain_pdb_seq = next(
        (seq for key, seq in pdb_chains.items() if key[-1] == 'A'), None
    )
    if a_chain_pdb_seq is None:
        continue

    if len(a_chain_pdb_seq) == len(a_chain_uniprot_seq):
        counter += 1
        score = sequence_sim(a_chain_pdb_seq, a_chain_uniprot_seq)
        if int(score) == 1:
            perf_score.append(id)


In [None]:
len(perf_score)

In [None]:
# Sanity check: every id in perf_score must have chain A sequences of equal
# length in the AlphaFold model and in the mapped PDB file.
for id in perf_score:
    af_pdb_file_path = f"../data/sword2/SWORD2/misc/af_pdbs/AF-{id}-F1-model_v3.pdb"
    af_chains = {record.id: record.seq for record in SeqIO.parse(af_pdb_file_path, 'pdb-seqres')}
    a_chain_uniprot_seq = af_chains['XXXX:A']

    pdb = pdb_uniprot_mappings_reverse[id]
    pdb_file_path = f"../data/pdb/bulk/balanced/backup/data/{pdb}.pdb"
    pdb_chains = {record.id: record.seq for record in SeqIO.parse(pdb_file_path, 'pdb-seqres')}

    # Fail loudly instead of silently comparing against the sequence left over
    # from the previous iteration.
    a_chain_pdb_seq = next(
        (seq for key, seq in pdb_chains.items() if key[-1] == 'A'), None
    )
    if a_chain_pdb_seq is None:
        raise ValueError(f"No chain A SEQRES record found in {pdb_file_path}")

    if len(a_chain_pdb_seq) != len(a_chain_uniprot_seq):
        print(a_chain_pdb_seq)
        print(a_chain_uniprot_seq)
        print(abs(len(a_chain_pdb_seq) - len(a_chain_uniprot_seq)))
        raise ValueError("Different sequence lengths is not expected")

In [None]:
# Parse the run-time estimates. The file is read as records of three lines:
# the first is ignored, the second holds the minutes and the third the id,
# each taken as the second field after splitting on four spaces.
minutes = []
perf_score_ids = set(perf_score)  # O(1) membership tests
with open("../data/sword2/SWORD2/misc/total_time_estimate.txt") as f:
    lines = f.readlines()

# Stopping at len(lines) - 2 ignores an incomplete trailing record (e.g. a
# stray blank line at the end) instead of raising IndexError.
for i in range(0, len(lines) - 2, 3):
    mins = int(lines[i + 1].split('    ')[1].strip())
    id = lines[i + 2].split('    ')[1].strip()
    if id in perf_score_ids:
        minutes.append((mins, id))

In [None]:
len(minutes)

In [None]:
# Seed kept for reproducibility of any random sub-sampling added later.
random.seed(2023)

# Keep only the entries whose estimated run time is at most MAX_MINUTES.
MAX_MINUTES = 10
filtered = [(mins, id) for mins, id in minutes if mins <= MAX_MINUTES]

In [None]:
len(filtered)

In [None]:
# pdb_uniprot_mappings_reverse['P43235']

In [None]:
# Remove entries whose PDB has no chain "A" in CATH. The kept entries go into
# a new list: deleting by index while enumerating shifts the remaining items
# and skips the entry that follows each deletion.
counter = 0
kept_entries = []
for mins, id in filtered:
    pdb = pdb_uniprot_mappings_reverse[id]
    if "A" in cath[pdb]:
        kept_entries.append((mins, id))
    else:
        counter += 1
        print(id)
filtered[:] = kept_entries
counter

In [None]:
print(len(filtered))

In [None]:
# Count how many filtered entries have one vs. several CATH domains on chain A
# and keep only the ids of the multi-domain ones.
single = 0
multi = 0
multis = []
for (_, id) in filtered:
    pdb = pdb_uniprot_mappings_reverse[id]
    num = len(cath[pdb]['A'])
    if num <= 1:
        single += 1
        continue
    multi += 1
    multis.append(id)

# From here on `filtered` holds plain UniProt ids, no longer (minutes, id) pairs.
filtered = list(multis)

print(f"Single: {single}  --  Multi: {multi}")

In [None]:
# Write the selected UniProt ids and their PDB ids to two parallel text files,
# one entry per line, for the bulk run.
with open("../data/sword2/SWORD2/misc/filtered_uniprots.txt", "w") as ids_f, \
        open("../data/sword2/SWORD2/misc/filtered_pdbs.txt", "w") as pdbs_f:
    for id in filtered:
        pdb = pdb_uniprot_mappings_reverse[id]
        ids_f.write(id + '\n')
        pdbs_f.write(pdb + '\n')

In [None]:
# Small test set: the same two files, restricted to the first three entries.
with open("../data/sword2/SWORD2/misc/TEST_filtered_uniprots.txt", "w") as ids_f, \
        open("../data/sword2/SWORD2/misc/TEST_filtered_pdbs.txt", "w") as pdbs_f:
    for id in filtered[:3]:
        pdb = pdb_uniprot_mappings_reverse[id]
        ids_f.write(id + '\n')
        pdbs_f.write(pdb + '\n')

In [None]:
# Re-validate the written id file: for every listed UniProt id, chain A of the
# AlphaFold model and of the mapped PDB file must have the same length.
with open("../data/sword2/SWORD2/misc/filtered_uniprots.txt", "r") as f:
    for line in f:
        id = line.strip()
        if not id:
            # A blank line would otherwise fail as a KeyError on the mapping.
            continue
        pdb = pdb_uniprot_mappings_reverse[id]

        af_pdb_file_path = f"../data/sword2/SWORD2/misc/af_pdbs/AF-{id}-F1-model_v3.pdb"
        af_chains = {record.id: record.seq for record in SeqIO.parse(af_pdb_file_path, 'pdb-seqres')}
        a_chain_uniprot_seq = af_chains['XXXX:A']

        pdb_file_path = f"../data/pdb/bulk/balanced/backup/data/{pdb}.pdb"
        pdb_chains = {record.id: record.seq for record in SeqIO.parse(pdb_file_path, 'pdb-seqres')}

        # Fail loudly instead of reusing the sequence of the previous entry.
        a_chain_pdb_seq = next(
            (seq for key, seq in pdb_chains.items() if key[-1] == 'A'), None
        )
        if a_chain_pdb_seq is None:
            raise ValueError(f"No chain A SEQRES record found in {pdb_file_path}")

        if len(a_chain_pdb_seq) != len(a_chain_uniprot_seq):
            print(a_chain_pdb_seq)
            print(a_chain_uniprot_seq)
            print(abs(len(a_chain_pdb_seq) - len(a_chain_uniprot_seq)))
            raise ValueError("Different sequence lengths is not expected")

In [None]:
# Copy the PDB file of every selected UniProt id into a separate directory.
pdb_files_dir = '../data/sword2/SWORD2/misc/pdb_files/'
# shutil.copy does not create the target directory, so make sure it exists.
os.makedirs(pdb_files_dir, exist_ok=True)

with open("../data/sword2/SWORD2/misc/filtered_uniprots.txt", "r") as f:
    for line in f:
        id = line.strip()
        if not id:
            # A blank line would otherwise fail as a KeyError on the mapping.
            continue
        pdb = pdb_uniprot_mappings_reverse[id]
        pdb_path = f'../data/pdb/bulk/balanced/backup/data/{pdb}.pdb'

        shutil.copy(pdb_path, pdb_files_dir)

In [None]:
def boundaries2(len_seq, domain, discontinuity_delimiter):
    """Build a per-residue boolean mask of domain boundaries.

    A boundary is the start of a domain segment. The smallest start over all
    segments is not counted, so a single continuous domain yields no boundary.

    Args:
        len_seq: Length of the returned mask.
        domain: Dict whose values are strings of ``start-end`` segments joined
            by ``discontinuity_delimiter`` (1-based positions). The keys are
            not used.
        discontinuity_delimiter: Separator between the segments of one domain.

    Returns:
        numpy.ndarray: Boolean array of length ``len_seq`` with ``True`` at
        index ``start - 1`` of every segment except the earliest one. All
        ``False`` when ``domain`` contains no segments.

    Raises:
        ValueError: If a non-empty segment is not of the form ``start-end``.
        IndexError: If a start position lies beyond ``len_seq``.
    """
    starts = []
    for segments in domain.values():
        for segment in segments.split(discontinuity_delimiter):
            # Tolerate empty pieces such as those left by a trailing delimiter.
            if not segment.strip():
                continue
            start, _end = [int(part) for part in segment.split('-')]
            starts.append(start)

    bounds = np.zeros(len_seq, dtype=np.bool_)
    if not starts:
        # No segments at all: there is nothing to mark and no first start.
        return bounds

    for start in starts:
        bounds[start - 1] = True
    # The beginning of the first domain is not a boundary between domains.
    bounds[min(starts) - 1] = False
    return bounds


def dbd_score(y_pred, y_true, margin=20):
    """Score predicted boundaries by their distance to the nearest true one.

    Every predicted boundary (a position where ``y_pred == 1``) contributes:

    * ``1.0`` when it coincides with a true boundary,
    * ``(margin - d + 1) / margin`` when the nearest true boundary is
      ``d`` positions away with ``1 <= d <= margin``,
    * ``0`` when no true boundary lies within ``margin`` (false positive).

    The sum is divided by the larger of the number of true and predicted
    boundaries.

    Args:
        y_pred: 0/1 sequence of predicted boundaries (list or array).
        y_true: 0/1 sequence of true boundaries, same indexing as ``y_pred``.
        margin: Maximum distance at which a prediction still earns credit.

    Returns:
        float: The score; ``1.0`` when neither sequence contains a boundary.
    """
    y_pred = np.asarray(y_pred)
    y_true = np.asarray(y_true)
    true_positions = np.flatnonzero(y_true == 1.0)

    scores = []
    for i in np.flatnonzero(y_pred == 1.0):
        # Measure against the NEAREST true boundary. Taking the first one in
        # the window under-scores a hit when an earlier boundary is also
        # within range.
        diff = np.min(np.abs(true_positions - i)) if true_positions.size else None
        if diff is None or diff > margin:
            score = 0  # false positive
        elif diff == 0:
            score = 1.0  # equals margin / margin, and is safe for margin == 0
        else:
            score = ((margin - diff) + 1) / margin
        scores.append(score)

    number_of_true_boundaries = np.sum(y_true)
    number_of_pred_boundaries = np.sum(y_pred)
    max_len = max(number_of_true_boundaries, number_of_pred_boundaries)
    if max_len == 0:
        return 1.0

    return np.sum(scores) / max_len


def metrics(y_pred, y_true, margin=20):
    """Compute classification metrics for predicted boundaries.

    A predicted boundary is a true positive when any true boundary lies within
    ``margin`` positions of it, and a false positive otherwise. A position
    predicted as 0 is a false negative when ``y_true`` is 1 at that exact
    position, and a true negative otherwise.

    Args:
        y_pred: 0/1 sequence of predicted boundaries.
        y_true: 0/1 sequence of true boundaries.
        margin: Tolerance window (also forwarded to ``dbd_score``).

    Returns:
        tuple: ``(accuracy, precision, recall, f1, mcc, dbd)``. Any ratio
        whose denominator is zero is reported as ``0``.
    """
    def _ratio(numerator, denominator):
        # Same fallback the metrics share: an undefined ratio counts as 0.
        try:
            return numerator / denominator
        except ZeroDivisionError:
            return 0

    tp = tn = fn = fp = 0
    n_true = len(y_true)

    for idx, predicted in enumerate(y_pred):
        if predicted == 1.0:
            lo = max(0, idx - margin)
            hi = min(n_true, idx + margin + 1)
            if 1.0 in y_true[lo:hi]:
                tp += 1
            else:
                fp += 1
        elif predicted == 0.0:
            if y_true[idx] == 1.0:
                fn += 1
            else:
                tn += 1

    accuracy = _ratio(tn + tp, tn + tp + fn + fp)
    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    f1 = _ratio(2 * precision * recall, precision + recall)
    mcc = _ratio(
        (tp * tn) - (fp * fn),
        ((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))**0.5,
    )

    dbd = dbd_score(y_pred, y_true, margin)

    return (accuracy, precision, recall, f1, mcc, dbd)

In [None]:
9

In [48]:
# Evaluate SWORD2 on experimental PDB structures vs. AlphaFold models.
# For every UniProt id that has AF results, each SWORD2 option is scored
# against the CATH boundaries of chain A, and the best MCC per protein is
# stored in pdb_mcc / af_mcc (the two lists are appended together).
path = '../data/sword2/SWORD2/results/af'

uniprots = os.listdir(path)

# Set to a UniProt accession (e.g. 'A0NLY7') to evaluate only that entry with
# verbose per-option output; None evaluates every entry found under `path`.
DEBUG_ONLY_ID = None
verbose = DEBUG_ONLY_ID is not None

# Tolerance window passed to metrics() when matching boundaries.
margin = 8

pdb_mcc = []
af_mcc = []

for i, id in enumerate(uniprots):

    if DEBUG_ONLY_ID is not None and id != DEBUG_ONLY_ID:
        continue

    pdb = pdb_uniprot_mappings_reverse[id]

    af_pdb_file_path = f"../data/sword2/SWORD2/misc/af_pdbs/AF-{id}-F1-model_v3.pdb"
    af_chains = {record.id: record.seq for record in SeqIO.parse(af_pdb_file_path, 'pdb-seqres')}
    a_chain_uniprot_seq = af_chains['XXXX:A']

    pdb_file_path = f"../data/pdb/bulk/balanced/backup/data/{pdb}.pdb"
    pdb_chains = {record.id: record.seq for record in SeqIO.parse(pdb_file_path, 'pdb-seqres')}

    # Fail loudly instead of reusing the sequence of the previous protein.
    a_chain_pdb_seq = next(
        (seq for key, seq in pdb_chains.items() if key[-1] == 'A'), None
    )
    if a_chain_pdb_seq is None:
        raise ValueError(f"No chain A SEQRES record found in {pdb_file_path}")

    if len(a_chain_pdb_seq) != len(a_chain_uniprot_seq):
        print(a_chain_pdb_seq)
        print(a_chain_uniprot_seq)
        print(abs(len(a_chain_pdb_seq) - len(a_chain_uniprot_seq)))
        raise ValueError("Different sequence lengths is not expected")

    baseline = cath[pdb]['A']
    if verbose:
        print("True domain boundaries from CATH:", baseline)

    # Either result file may be missing; skip the protein in both cases.
    try:
        af_sword_results = get_sword2(id, 'af', verb=False)
        pdb_sword_results = get_sword2(pdb, 'pdb', verb=False)
    except FileNotFoundError:
        print("File not found")
        print(id, pdb)
        continue

    baseline_boundaries = boundaries2(len(a_chain_pdb_seq), baseline, ',').astype(int)
    pdb_mccs = []
    af_mccs = []
    pdb_dbds = []
    af_dbds = []

    if verbose:
        print(id, pdb)
        print("pdb")
    for option, domain in pdb_sword_results.items():
        pdb_sword_boundaries = boundaries2(len(a_chain_pdb_seq), domain, ';').astype(int)
        pdb_sword_metrics = metrics(pdb_sword_boundaries, baseline_boundaries, margin)
        pdb_sword_mcc = pdb_sword_metrics[-2]
        pdb_sword_dbd = pdb_sword_metrics[-1]
        pdb_mccs.append(pdb_sword_mcc)
        pdb_dbds.append(pdb_sword_dbd)
        if verbose:
            print(option)
            print(domain)
            print(pdb_sword_mcc)
            print()

    if verbose:
        print("af")
    for option, domain in af_sword_results.items():
        af_sword_boundaries = boundaries2(len(a_chain_pdb_seq), domain, ';').astype(int)
        af_sword_metrics = metrics(af_sword_boundaries, baseline_boundaries, margin)
        af_sword_mcc = af_sword_metrics[-2]
        af_sword_dbd = af_sword_metrics[-1]
        af_mccs.append(af_sword_mcc)
        af_dbds.append(af_sword_dbd)
        if verbose:
            print(option)
            print(domain)
            print(af_sword_boundaries)
            print(af_sword_mcc)
            print()

    # A result file without any option line gives nothing to compare.
    if not pdb_mccs or not af_mccs:
        continue

    pdb_mcc.append(max(pdb_mccs))
    af_mcc.append(max(af_mccs))

    if (i + 1) % 50 == 0:
        print(f"[{i + 1}/{len(uniprots)}]")


In [36]:
# Independent two-sample t-test on the first 100 best-option MCCs per method.
# NOTE(review): pdb_mcc and af_mcc are appended together, once per protein, in
# the evaluation loop, so the samples look paired. Consider
# scipy.stats.ttest_rel; confirm before reporting the p-value.
t_stat, p_val = ttest_ind(pdb_mcc[:100], af_mcc[:100])

print(t_stat, float(p_val))

-3.39084160550465 0.0008410867748462028


In [37]:
# Mean best-option MCC over the first 100 proteins: PDB first, AlphaFold second.
print(np.mean(pdb_mcc[:100]))
print(np.mean(af_mcc[:100]))

0.43234073554287933
0.5675947560551442


In [None]:
# Summary statistics over ALL evaluated proteins.
# pdb_mcc / af_mcc hold one best MCC per protein. The similarly named
# pdb_mccs / af_mccs only hold the option scores of the last protein processed
# and must not be used here.
# (np and ttest_ind are already imported in the first cell.)

# calculate the mean MCC values
mean_mcc_a = np.mean(pdb_mcc)
mean_mcc_b = np.mean(af_mcc)

# perform a two-sample t-test
t_stat, p_value = ttest_ind(pdb_mcc, af_mcc)

print("Mean MCC for model A: ", mean_mcc_a)
print("Mean MCC for model B: ", mean_mcc_b)
print("t-statistic: ", t_stat)
print("p-value: ", p_value)
