# Imports and Setup

In [6]:
import os
import os.path
from pprint import pprint
import json
import random
import numpy as np
from helpers.helper import get_cath

from Bio import pairwise2
from Bio.pairwise2 import format_alignment
from Bio.Seq import Seq
from Bio import SeqIO

import requests
from requests.packages.urllib3.util.retry import Retry
from requests.adapters import HTTPAdapter
import shutil
from Bio import pairwise2
from Bio.pairwise2 import format_alignment
from Bio.Seq import Seq
from Bio import SeqIO


from scipy.stats import ttest_ind
from pprint import pprint

cath = get_cath()


In [3]:
with open('../data/sword2/SWORD2/misc/new_iid/pdb_to_uniprot_map.json') as json_file:
    pdb_to_uni_map = json.load(json_file)

In [4]:
pdb_to_af_map = {}
for elt in pdb_to_uni_map['results']:
    pdb = elt['from']
    databases = [db['database'] for db in elt['to']['uniProtKBCrossReferences']]
    for i, db in enumerate(databases):
        if db == 'AlphaFoldDB':
            af_id = (elt['to']['uniProtKBCrossReferences'][i]['id'])
            if pdb_to_af_map.get(pdb):
                pdb_to_af_map[pdb].append(af_id)
            else:
                pdb_to_af_map[pdb] = [af_id]
    

In [5]:
total = 0
for k, v in pdb_to_af_map.items():
	total += len(v)
print("Total:", total)

Total: 11862


In [15]:
def download_af_model(id):
    """
    Download the Alphafold2 model corresponding to the Uniprot Id given by user
    https://alphafold.ebi.ac.uk/

    Returns:
        - File path (string): Path of the downloaded PDB file
        or
        False if wrong id
        "DOWNLOAD ERROR" if could not download
    """
    name = f"AF-{id}-F1-model_v3"
    url = f"https://alphafold.ebi.ac.uk/files/{name}.pdb"
    try:
        response = requests_retry_session().get(url)
    except Exception as x:
        return (False, x)
    with open(f"../data/sword2/SWORD2/misc/new_iid/af_pdbs/{name}.pdb", "w") as f:
        f.write(response.text)
    return (True, f"../data/sword2/SWORD2/misc/new_iid/af_pdbs/{name}.pdb")

def requests_retry_session(retries=3,
                           backoff_factor=0.3,
                           status_forcelist=(500, 502, 504),
                           session=None):
    session = session or requests.Session()
    retry = Retry(
        total=retries,
        read=retries,
        connect=retries,
        backoff_factor=backoff_factor,
        status_forcelist=status_forcelist,
    )
    adapter = HTTPAdapter(max_retries=retry)
    session.mount('http://', adapter)
    session.mount('https://', adapter)
    return session

In [78]:
# these are after i added extra 1-domain proteins
counter = 0
for _, codes in pdb_to_af_map.items():
	for code in codes:
			if not os.path.isfile(f"../data/sword2/SWORD2/misc/new_iid/af_pdbs/AF-{code}-F1-model_v3.pdb"):
				(bool, msg) = download_af_model(code)
				if not bool:
					print(msg)
				counter += 1
				if counter % 250 == 0:
					print(f"[{counter}/{1695}]")
print(f"[{counter}/{1695}]")

[0/1695]


In [79]:
# download AF PDB files
counter = 0
for _, codes in pdb_to_af_map.items():
	for code in codes:
		if not os.path.isfile(f"../data/sword2/SWORD2/misc/new_iid/af_pdbs/AF-{code}-F1-model_v3.pdb"):
			(bool, msg) = download_af_model(code)
			if not bool:
				print(msg)
			counter += 1
			if counter % 250 == 0:
				print(f"[{counter}/{total}]")
print(f"[{counter}/{total}]")

[0/11862]


In [21]:
def get_af_chain(code, chain):
    file_path = f"../data/sword2/SWORD2/misc/new_iid/af_pdbs/AF-{code}-F1-model_v3.pdb"
    chains = {record.id: record.seq for record in SeqIO.parse(file_path, 'pdb-seqres')}
    seq = chains.get(f'XXXX:{chain}')
    return seq


def get_pdb_chain(code, chain):
    pdb_file_path = f"../data/pdb/new_iid/{code}.pdb"
    pdb_chains = {record.id: record.seq for record in SeqIO.parse(pdb_file_path, 'pdb-seqres')}

    for key in pdb_chains.keys():
        if key[-1] == chain:
            a_chain_pdb_seq = pdb_chains[key]
            return a_chain_pdb_seq
        

def sequence_sim(seq1, seq2, match_score = 1, mismatch_score = -1, gap_penalty = -2):
    alignments = pairwise2.align.globalxx(seq1, seq2)

    # Print the alignment(s)
    # for alignment in alignments:
        # print(format_alignment(*alignment))
    score = alignments[0].score
    norm_score = score / max(len(seq1), len(seq2))
    return norm_score

In [22]:
with open('../data/cath/iid/chains_to_seq_iid.json') as json_file:
    chain_to_seq_iid = json.load(json_file)

valid_pairs = []
for chain, seq in chain_to_seq_iid.items():
    if pdb_to_af_map.get(chain[:4]):
        pdb_filename = f"../data/pdb/new_iid/{chain[:4]}.pdb"
        chain_id = chain[-1]
        with open(pdb_filename, "r") as pdb_file:
            pdb_chains = {record.id: record.seq for record in SeqIO.parse(pdb_file, 'pdb-seqres')}
            for key in pdb_chains.keys():
                if key[-1] == chain[-1]:
                    pdb_seq = pdb_chains[key]
        af_seqs = []
        for uniprot in pdb_to_af_map.get(chain[:4]):
            af_seq = get_af_chain(uniprot, chain[-1])
            if af_seq:
                af_seqs.append((uniprot, af_seq))
        for (uniprot, af_seq) in af_seqs:
            if len(af_seq) == len(pdb_seq):
                sim = sequence_sim(af_seq, pdb_seq)
                if sim == 1.0:
                    valid_pairs.append((chain, uniprot + f':{chain[-1]}'))



In [80]:
for (a, b) in valid_pairs:
    if b[-1] != 'A' or a[-1] != 'A':
        print(a,b)
        
# These ARE ALL A chains

In [81]:
len(valid_pairs)

508

In [45]:
valid_pdbs = [x[0] for x in valid_pairs]
valid_afs = [x[1] for x in valid_pairs]

from collections import Counter
for k,v in Counter(valid_pdbs).items():
    if v > 1:
        print(k, v)

1p4x:A 2
1s7o:A 2


In [95]:
# read the new from 1-domains and check set difference
new_afs = 0
new_pdbs = 0
print("AFs")
with open('../data/sword2/SWORD2/misc/new_iid/sword_af.txt', "r") as af_f:
    uniprot_ids_from_file = set([x.strip() for x in af_f.readlines()])
    new_valid_uniprots = set([x[:-2] for x in valid_afs])
    print("New ones:", len(new_valid_uniprots))
    print("Old ones:", len(uniprot_ids_from_file))
    diff = new_valid_uniprots.difference(uniprot_ids_from_file)
    with open('../data/sword2/SWORD2/misc/new_iid/diff_1domain_sword_af.txt', "w") as daf:
        for elt in diff:
            daf.write(elt + '\n')
    
print('\n' + "PDBs")
with open('../data/sword2/SWORD2/misc/new_iid/sword_pdb.txt', "r") as pdb_f:
    pdb_codes_from_file = set([x.strip() for x in pdb_f.readlines()])
    new_valid_pdbs = set([x[:-2] for x in valid_pdbs])
    print("New ones:", len(new_valid_pdbs))
    print("Old ones:", len(pdb_codes_from_file))
    diff = new_valid_pdbs.difference(pdb_codes_from_file)
    with open('../data/sword2/SWORD2/misc/new_iid/diff_1domain_sword_pdb.txt', "w") as dpdb:
        for elt in diff:
            dpdb.write(elt + '\n')

AFs
New ones: 508
Old ones: 455

PDBs
New ones: 506
Old ones: 453


In [60]:
# write the pairs to their respective files
with open('../data/sword2/SWORD2/misc/new_iid/sword_af.txt', "w") as af_f:
    with open('../data/sword2/SWORD2/misc/new_iid/sword_pdb.txt', "w") as pdb_f:
        for af in set(valid_afs):
            af_f.write(af[:-2] + '\n')
        for pdb in set(valid_pdbs):
            pdb_f.write(pdb[:-2] + '\n')


In [48]:
# write a true_random_sample_keys.txt
with open('../data/cath/iid/true_random_sample_keys.txt', 'w') as f:
    for chain in chain_to_seq_iid.keys():
        f.write(chain + '\n')

In [49]:
# analyse time estimate
minutes = []
with open("../data/sword2/SWORD2/misc/new_iid/estimate_time.txt") as f:
	lines = f.readlines()
	for i in range(0, len(lines), 3):
		_, mins, id = lines[i], lines[i+1], lines[i+2]
		mins = mins.split('    ')
		mins = int(mins[1].strip())
		id = id.split('    ')
		id = id[1].strip()
		minutes.append((mins, id))

In [50]:
just_minutes = [x[0] for x in minutes]
np.mean(just_minutes)
len(just_minutes)

454

In [51]:
valid_pairs[:40]

[('1a2k:A', 'P61972:A'),
 ('1a91:A', 'P68699:A'),
 ('1al3:A', 'P45600:A'),
 ('1b2l:A', 'P10807:A'),
 ('1bs2:A', 'Q05506:A'),
 ('1c56:A', 'P59936:A'),
 ('1c6o:A', 'P57736:A'),
 ('1cby:A', 'Q04470:A'),
 ('1cc7:A', 'P38636:A'),
 ('1dce:A', 'Q08602:A'),
 ('1dgj:A', 'Q9REC4:A'),
 ('1dl5:A', 'Q56308:A'),
 ('1dqu:A', 'P28298:A'),
 ('1drt:A', 'Q05581:A'),
 ('1e19:A', 'P95474:A'),
 ('1e2u:A', 'P31101:A'),
 ('1e44:A', 'P02984:A'),
 ('1ei5:A', 'Q9ZBA9:A'),
 ('1eyt:A', 'P80176:A'),
 ('1fhu:A', 'P29208:A'),
 ('1fnn:A', 'Q8ZYK1:A'),
 ('1ft9:A', 'P72322:A'),
 ('1fy2:A', 'P36936:A'),
 ('1gm5:A', 'Q9WY48:A'),
 ('1gsh:A', 'P04425:A'),
 ('1gu9:A', 'P9WQB5:A'),
 ('1h2h:A', 'Q9X1X6:A'),
 ('1h37:A', 'P33247:A'),
 ('1h3f:A', 'P83453:A'),
 ('1he3:A', 'P30043:A'),
 ('1hmu:A', 'Q59288:A'),
 ('1i5p:A', 'P0A377:A'),
 ('1in0:A', 'P44096:A'),
 ('1iq8:A', 'O58843:A'),
 ('1j1h:A', 'Q8RR57:A'),
 ('1j5n:A', 'P11632:A'),
 ('1jdw:A', 'P50440:A'),
 ('1jey:A', 'P12956:A'),
 ('1jmv:A', 'P44880:A'),
 ('1jp4:A', 'Q9Z1N4:A')]

# Functions

## SWORD parsing

In [52]:
def get_sword2(code, version, verb=False):
    file = f"../data/sword2/SWORD2/results/{version}/{code}/{code}_A/sword.txt"
    with open(file, "r") as f:
        data = {}
        lines = f.readlines()
        option = 0
        for i, line in enumerate(lines):
            lines[i] = "".join([c for c in line if c not in ["\n",'']])
            if line != "\n":
                if not line.startswith(("PDB:", "#D", "A")):
                    res = lines[i].split("|")
                    boundaries = res[2]
                    domains = boundaries.strip().split(" ")
                    data[f"option{option}"] = {}
                    for j in range(len(domains)):
                        data[f"option{option}"][str(j+1)] = domains[j]
                    option += 1
    verb and pprint(data)
    return data

## Metrics

In [53]:
def sequence_sim(seq1, seq2, match_score = 1, mismatch_score = -1, gap_penalty = -2):
	alignments = pairwise2.align.globalxx(seq1, seq2)

	# Print the alignment(s)
	# for alignment in alignments:
		# print(format_alignment(*alignment))
	score = alignments[0].score
	norm_score = score / max(len(seq1), len(seq2))
	return norm_score



def dbd_score(y_pred, y_true, margin=20):
    scores = []
    for i in range(len(y_pred)):
        window = y_true[max(0, i-margin):min(len(y_true), i+margin+1)]
        indices_window = list(range(max(0, i-margin), min(len(y_true), i+margin+1)))
        if y_pred[i] == 1.0:
            if 1.0 in window:
                # if it's within the window, calculate the score
                pos = np.where(window == 1.0)[0][0]
                j = indices_window[pos]
                diff = abs(i - j)
                k = 0 if diff == 0 else 1
                score = ((margin - diff) + k) / margin
            else:
                # false positive
                score = 0
            scores.append(score)

    number_of_true_boundaries = np.sum(y_true)
    number_of_pred_boundaries = np.sum(y_pred)
    max_len = max(number_of_true_boundaries,number_of_pred_boundaries)
    if max_len == 0:
        return 1.0

    return np.sum(scores) / max_len


def observations(y_pred, y_true, margin):
    tp = 0
    tn = 0
    fp = 0
    fn = 0

    dbd = dbd_score(y_pred, y_true, margin)
    for i in range(len(y_pred)):
        window = y_true[max(0, i-margin):min(len(y_true), i+margin+1)]
        indices_window = list(range(max(0, i-margin), min(len(y_true), i+margin+1)))
        if y_pred[i] == 1.0:
            if 1.0 in window:
                pos = np.where(window == 1.0)[0][0]
                j = indices_window[pos]
                y_true[j] = 0.0
                tp += 1
            else:
                fp += 1


        elif y_pred[i] == 0.0:
            if  y_true[i] == 1.0:
                fn += 1
            else:
                tn += 1

    return (tp, tn, fp, fn)


def observations__(y_pred, y_true, margin):
    tp = 0
    tn = 0
    fp = 0
    fn = 0

    dbd = dbd_score(y_pred, y_true, margin)
    for i in range(len(y_pred)):
        window = y_true[max(0, i-margin):min(len(y_true), i+margin+1)]
        indices_window = list(range(max(0, i-margin), min(len(y_true), i+margin+1)))
        if y_pred[i] == 1.0:
            if 1.0 in window:
                pos = np.where(window == 1.0)[0][0]
                j = indices_window[pos]
                y_true[j] = 0.0
                tp += 1
            else:
                fp += 1

    for i in range(len(y_pred)):
        if y_pred[i] == 0.0:
            if  y_true[i] == 1.0:
                fn += 1
            else:
                tn += 1

    return (tp, tn, fp, fn)


def metrics(y_pred, y_true, margin=20):
    tp, tn, fp, fn = observations__(y_pred, y_true, margin)

    accuracy = (tn + tp) / (tn + tp + fn + fp) if (tn + tp + fn + fp) else 0
    precision = tp / (tp + fp) if (tp + fp) else 0
    recall = tp / (tp + fn) if (tp + fn) else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) else 0

    mcc_num = (tp * tn) - (fp * fn)
    mcc_den = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = mcc_num / mcc_den if mcc_den else 0

    dbd = dbd_score(y_pred, y_true, margin)

    return (accuracy, precision, recall, f1, mcc, dbd)

## Other

In [54]:
def requests_retry_session(retries=3,
                           backoff_factor=0.3,
                           status_forcelist=(500, 502, 504),
                           session=None):
    session = session or requests.Session()
    retry = Retry(
        total=retries,
        read=retries,
        connect=retries,
        backoff_factor=backoff_factor,
        status_forcelist=status_forcelist,
    )
    adapter = HTTPAdapter(max_retries=retry)
    session.mount('http://', adapter)
    session.mount('https://', adapter)
    return session

def download_af_model(id):
    """
    Download the Alphafold2 model corresponding to the Uniprot Id given by user
    https://alphafold.ebi.ac.uk/

    Returns:
        - File path (string): Path of the downloaded PDB file
        or
        False if wrong id
        "DOWNLOAD ERROR" if could not download
    """
    RESULTS_DIR = "../data/sword2/SWORD2/misc/new_iid/af_pdbs/"
    name = f"AF-{id}-F1-model_v3"
    url = f"https://alphafold.ebi.ac.uk/files/{name}.pdb"
    try:
        response = requests_retry_session().get(url)
    except Exception as x:
        return (False, x)
    with open(f"{RESULTS_DIR}/{name}.pdb", "w") as f:
        f.write(response.text)
    return (True, f"{RESULTS_DIR}/{name}.pdb")

## Boundaries

In [55]:
def boundaries2(len_seq, domain, discontinuity_delimiter):
	"""
		Defines a boundary as the beginning of a domain ONLY in multi-domain proteins
	"""
	first_start = np.inf
	bounds = np.zeros((len_seq), dtype=np.int8)
	for k, v in domain.items():
		boundary_positions = v.split(discontinuity_delimiter)
		for b in boundary_positions:
			start, end = [int(i) for i in b.split('-')]
			if start < first_start:
				first_start = start
			bounds[start-1] = 1
	bounds[first_start-1] = 0            
	return np.array(bounds, dtype=np.bool_)

In [56]:
def get_af_chain(code, chain):
	file_path = f"../data/sword2/SWORD2/misc/new_iid/af_pdbs/AF-{code}-F1-model_v3.pdb"
	chains = {record.id: record.seq for record in SeqIO.parse(file_path, 'pdb-seqres')}
	seq = chains.get(f'XXXX:{chain}')
	return seq


def get_pdb_chain(code, chain):
	pdb_file_path = f"../data/pdb/new_iid/{code}.pdb"
	pdb_chains = {record.id: record.seq for record in SeqIO.parse(pdb_file_path, 'pdb-seqres')}

	for key in pdb_chains.keys():
		if key[-1] == 'A':
			a_chain_pdb_seq = pdb_chains[key]
			return a_chain_pdb_seq

In [57]:
ypr = np.array([0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0], dtype=np.float16)
ytr = np.array([0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0], dtype=np.float16)

mcc = metrics(ypr, ytr)[-2]
mcc

0.6882472016116853

# Analysis

In [58]:
path = '../data/sword2/final_results/results/af'

uniprots = os.listdir(path)
# print(uniprots)
# print()

pdb_mcc = []
af_mcc = []

pdb_f1 = []
af_f1 = []

pdb_no_bounds = []
af_no_bounds = []

for i, id in enumerate(uniprots):

    pdb = pdb_uniprot_mappings_reverse[id]
    a_chain_uniprot_seq = get_af_chain(id)
    a_chain_pdb_seq = get_pdb_chain(pdb)
    chain_len = None
    if len(a_chain_pdb_seq) != len(a_chain_uniprot_seq):
        raise ValueError("Different sequence lengths is not expected")
    else:
        chain_len = len(a_chain_pdb_seq)
    
    baseline = cath[pdb]['A']
    af_sword_results = get_sword2(id, 'af', verb=False)
    try:
        pdb_sword_results = get_sword2(pdb, 'pdb', verb=False)
    except FileNotFoundError:
        print("File not found", id, pdb)
        continue

    margin = 20
    baseline_boundaries = boundaries2(len(a_chain_pdb_seq), baseline, ',').astype(int)
    pdb_mccs = []
    af_mccs = []
    pdb_dbds = []
    af_dbds = []
    pdb_f1s = []
    af_f1s = []
    pdb_bounds = []
    af_bounds = []

    for option, domain in pdb_sword_results.items():
        pdb_sword_boundaries = boundaries2(chain_len, domain, ';').astype(int)
        pdb_sword_metrics = metrics(pdb_sword_boundaries, baseline_boundaries, margin)
        pdb_sword_mcc = pdb_sword_metrics[-2]
        pdb_sword_dbd = pdb_sword_metrics[-1]
        pdb_sword_f1 = pdb_sword_metrics[-3]
        pdb_mccs.append(pdb_sword_mcc)
        pdb_dbds.append(pdb_sword_dbd)
        pdb_f1s.append(pdb_sword_f1)
        pdb_bounds.append(np.sum(pdb_sword_boundaries))

    for option, domain in af_sword_results.items():
        af_sword_boundaries = boundaries2(chain_len, domain, ';').astype(int)
        af_sword_metrics = metrics(af_sword_boundaries, baseline_boundaries, margin)
        af_sword_mcc = af_sword_metrics[-2]
        af_sword_dbd = af_sword_metrics[-1]
        af_sword_f1 = af_sword_metrics[-3]
        af_mccs.append(af_sword_mcc)
        af_dbds.append(af_sword_dbd)
        af_f1s.append(af_sword_f1)
        af_bounds.append(np.sum(af_sword_boundaries))

    af_f1.append(max(af_f1s))
    pdb_f1.append(max(pdb_f1s))

    best_pdb_i = np.argmax(pdb_mccs)
    best_af_i = np.argmax(af_mccs)
    pdb_no_bounds.append(pdb_bounds[best_pdb_i])

    pdb_mcc.append(pdb_mccs[best_pdb_i])
    af_mcc.append(af_mccs[best_af_i])
    af_no_bounds.append(af_bounds[best_af_i])

    if (i + 1) % 50 == 0:
        print(f"[{i + 1}/{len(uniprots)}]")

print(f"[{i + 1}/{len(uniprots)}]")


NameError: name 'pdb_uniprot_mappings_reverse' is not defined

In [None]:
print("af", np.mean(af_no_bounds))
print("pdb", np.mean(pdb_no_bounds))

af 2.2448979591836733
pdb 2.941043083900227


In [None]:
t_stat, p_val = ttest_ind(pdb_mcc, af_mcc)

print("Mean PDB MCC:", np.mean(pdb_mcc))
print("Mean AF MCC:", np.mean(af_mcc))

print(t_stat, float(p_val))

Mean PDB MCC: 0.6210332441611623
Mean AF MCC: 0.13217732495763904
27.3760386666944 7.451803918104624e-120


In [None]:
t_stat, p_val = ttest_ind(pdb_mcc, af_mcc)

print("Mean PDB MCC:", np.mean(pdb_mcc))
print("Mean AF MCC:", np.mean(af_mcc))

print(t_stat, float(p_val))

Mean PDB MCC: 0.6740580770787745
Mean AF MCC: 0.1436396457669393
28.077329120492674 2.2876669819673796e-124


In [None]:
t_stat, p_val = ttest_ind(af_no_bounds, pdb_no_bounds)

print("Mean PDB MCC:", np.mean(pdb_no_bounds))
print("Mean AF MCC:", np.mean(af_no_bounds))

print(t_stat, float(p_val))

Mean PDB MCC: 2.9115646258503403
Mean AF MCC: 2.240362811791383
-5.973230598758096 3.373535054941592e-09


Mean PDB MCC: 0.7312201785001735

Mean AF MCC: 0.04587217486107879

56.20185031552613 2.0480532531702836e-293

In [None]:
# import tensorflow as tf
# from keras import backend as K

# def sensitivity(y_true, y_pred):
#     true_label = K.argmax(y_true, axis=-1)
#     pred_label = K.argmax(y_pred, axis=-1)
#     INTERESTING_CLASS_ID = 2
#     sample_mask = K.cast(K.not_equal(true_label, INTERESTING_CLASS_ID), 'int32')

#     TP_tmp1 = K.cast(K.equal(true_label, 0), 'int32') * sample_mask
#     TP_tmp2 = K.cast(K.equal(pred_label, 0), 'int32') * sample_mask    
#     TP = K.sum(TP_tmp1 * TP_tmp2)

#     FN_tmp1 = K.cast(K.equal(true_label, 0), 'int32') * sample_mask
#     FN_tmp2 = K.cast(K.not_equal(pred_label, 0), 'int32') * sample_mask    
#     FN = K.sum(FN_tmp1 * FN_tmp2)

#     epsilon = 0.000000001
#     return K.cast(TP, 'float') / (K.cast(TP, 'float') + K.cast(FN, 'float') + epsilon)


# def precision(y_true, y_pred):
#     true_label = K.argmax(y_true, axis=-1)
#     pred_label = K.argmax(y_pred, axis=-1)
#     INTERESTING_CLASS_ID = 2
#     sample_mask = K.cast(K.not_equal(true_label, INTERESTING_CLASS_ID), 'int32')

#     TP_tmp1 = K.cast(K.equal(true_label, 0), 'int32') * sample_mask
#     TP_tmp2 = K.cast(K.equal(pred_label, 0), 'int32') * sample_mask
#     TP = K.sum(TP_tmp1 * TP_tmp2)

#     FP_tmp1 = K.cast(K.not_equal(true_label, 0), 'int32') * sample_mask
#     FP_tmp2 = K.cast(K.equal(pred_label, 0), 'int32') * sample_mask
#     FP = K.sum(FP_tmp1 * FP_tmp2)

#     epsilon = 0.000000001
#     return K.cast(TP, 'float') / (K.cast(TP, 'float') + K.cast(FP, 'float') + epsilon)

# def f1_score(y_true, y_pred):
#     pre = precision(y_true, y_pred)
#     sen = sensitivity(y_true, y_pred)
#     epsilon = 0.000000001
#     f1 = 2 * pre * sen / (pre + sen + epsilon)
#     return f1