# Check performance on the EasIFA dataset

In [6]:
import pandas as pd
import os
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support
import numpy as np
import matplotlib.pyplot as plt
import warnings
from tqdm import tqdm 
from collections import defaultdict
from sciutil import SciUtil
import seaborn as sns

u = SciUtil()

warnings.filterwarnings('ignore')

swissprot = pd.read_csv('data/reviewed_sprot_08042025.tsv', sep='\t')
data_dir = 'data/AEGAN_extracted_sequences/'


def calculate_stats(df, id_col, true_col, pred_col, seq_col):   
    # Check the agreement
    predictions = []
    true = []
    missing = 0
    for seq_label, res_sq, res_pred, seq in df[[id_col, true_col, pred_col, seq_col]].values:
        res_sq = res_sq.split('|')
        if not res_pred or not isinstance(res_pred, str):
            res_pred = ''
        res_pred = res_pred.split('|')
        if len(res_pred) > 0:
            try:
                chosen_res_seq = [int(i) for i in res_pred]
            except:
                chosen_res_seq = []
                missing += 1
        res_sq = [int(i) for i in res_sq]
        for pos in range(0, len(seq)):
            if pos in res_sq:
                true.append(1)
            else:
                true.append(0)
            if pos in chosen_res_seq:
                predictions.append(1)
            else:
                predictions.append(0)
    precision, recall, f1, support = precision_recall_fscore_support(true, predictions)
    return precision[1], recall[1], f1[1], support[1]

def annotate_residue_from_uniprot(df):
    # Organise the active sites of these guys to be better
    active_sites = []
    active_site_residue_counts = []
    x = 0
    for act_site in df['Active site'].values:
        sites = []
        if isinstance(act_site, str):
            act_site = act_site.replace(" ", '')
            for act in act_site.split('ACT_SITE'):
                try:
                    sites.append(int(act.split(';')[0].replace(" ", '')) - 1) # Need to subtract 1 to make it fit with the fact that python is 0 encoded lol
                except:
                    x = 1
        if len(sites) != 0:
            active_sites.append('|'.join([str(s) for s in sites]))
            active_site_residue_counts.append(len(sites))
        else:
            active_sites.append('None')
            active_site_residue_counts.append(0)
    df['UniProtResidue'] = active_sites
    df['active_site_residue_counts'] = active_site_residue_counts
    return df


swissprot = annotate_residue_from_uniprot(swissprot)
swissprot = swissprot[swissprot['active_site_residue_counts'] > 0]
training_ids = set(pd.read_csv('data/AEGAN_real_training_set.txt', header=None)[0].values)
training_df = swissprot[swissprot['Entry'].isin(training_ids)]
training_df['Residue'] = training_df['UniProtResidue'].values

training_df.to_csv('data/AEGAN_swissprot_training.csv', index=False)
u.dp(['Number of AEGAN training set:', len(training_df)])

df = pd.read_csv('data/EasIFA_benchmark_catalytic_only.csv', index_col=0)
fasta_label = 'data/EasIFA_benchmark_catalytic_only.fasta'
with open(fasta_label, 'w+') as fout:
    for label, seq in df[['Entry', 'Sequence']].values:
        fout.write(f'>{label}\n{seq}\n')

model_dir = 'output/uni3175_aegan/models/'

#os.system(f'squidly run {fasta_label} esm2_t36_3B_UR50D output/EasIFA_3B/ --model-folder {model_dir} --database data/AEGAN_swissprot_training.csv --blast-threshold 30 --no-filter-blast')
#os.system(f'squidly run {fasta_label} esm2_t48_15B_UR50D output/EasIFA_15B/ --model-folder {model_dir} --database data/AEGAN_swissprot_training.csv --blast-threshold 30 --no-filter-blast')

[94m--------------------------------------------------------------------------------[0m
[94m                      Number of AEGAN training set:	9888	                       [0m
[94m--------------------------------------------------------------------------------[0m


In [7]:
from ast import literal_eval

df['UniProtResidue'] = ['|'.join([str(x) for x in literal_eval(str(a))]) for a in df['true_CR_labels_list'].values]
df

Unnamed: 0,reaction,ec,pdb_id,Entry,Sequence,site_labels,site_types,pdb_files,dataset_flag,aa_sequence_calculated,...,squid_8,squid_9,true_CR_labels,true_CR_labels_list,easifa_train_test_pident,AEGAN_mmseqs_hit_pident,EasIFA_CR_posis,label,ensemble_CR_posis,UniProtResidue
4,NC(=O)C1=CN([C@@H]2O[C@H](COP(=O)([O-])OP(=O)(...,1.-.-.-,3E9Q;3F1K;3F1L;,P31808,MHYQPKQDLLNDRIILVTGASDGIGREAAMTYARYGATVILLGRNE...,"[[16, 40], [152], [165]]","[0, 0, 1]",/home/xiaoruiwang/data/ubuntu_work_beta/single...,test,MHYQPKQDLLNDRIILVTGASDGIGREAAMTYARYGATVILLGRNE...,...,164,150|151|164,164,[164],0.318,0.239,[164],P31808,"[150, 164]",164
7,CCCC(=O)CC(=O)SCCNC(=O)CCNC(=O)[C@H](O)C(C)(C)...,1.1.1.100,2B4Q;,Q9RPT1,MHPYFSLAGRIALVTGGSRGIGQMIAQGLLEAGARVFICARDAEAC...,"[[14, 38], [148], [162]]","[0, 0, 1]",/home/xiaoruiwang/data/ubuntu_work_beta/single...,test,MHPYFSLAGRIALVTGGSRGIGQMIAQGLLEAGARVFICARDAEAC...,...,161|165,161,161,[161],0.291,0.284,[161],Q9RPT1,"[146, 161]",161
8,CC(C)(CO)C(=O)C(=O)[O-].NC(=O)C1=CN([C@@H]2O[C...,1.1.1.169,3EGO;,O34661,MKIGIIGGGSVGLLCAYYLSLYHDVTVVTRRQEQAAAIQSEGIRLY...,"[[7, 12], [98], [98], [124], [183], [187], [19...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 1]",/home/xiaoruiwang/data/ubuntu_work_beta/single...,test,MKIGIIGGGSVGLLCAYYLSLYHDVTVVTRRQEQAAAIQSEGIRLY...,...,178,178,178,[178],0.337,0.218,[178],O34661,[178],178
10,CC(O)CSCCS(=O)(=O)O.NC(=O)c1ccc[n+]([C@@H]2O[C...,1.1.1.269,4GH5;4ITU;,A7IQH5,MSNRLKNEVIAITGGGAGIGLAIASAALREGAKVALIDLDQGLAER...,"[[19], [38], [64, 65], [91], [143], [156], [16...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2]",/home/xiaoruiwang/data/ubuntu_work_beta/single...,test,MSNRLKNEVIAITGGGAGIGLAIASAALREGAKVALIDLDQGLAER...,...,141|155|159,155|159,155,[155],0.679,0.355,[155],A7IQH5,[155],155
11,CC(O)C(=O)[O-].NC(=O)c1ccc[n+]([C@@H]2O[C@H](C...,1.1.1.27,7NAY;,Q9EVR0,MNNRRKIVVIGASNVGSAVANKIADFQLATEVVLIDLNEDKAWGEA...,"[[13, 41], [93], [125], [125], [156], [234], [...","[0, 0, 0, 0, 0, 0, 1]",/home/xiaoruiwang/data/ubuntu_work_beta/single...,test,MNNRRKIVVIGASNVGSAVANKIADFQLATEVVLIDLNEDKAWGEA...,...,179,179,179,[179],0.402,0.380,[179],Q9EVR0,[179],179
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
214,CCCCCCC/C=C/CC(=O)SCCNC(=O)CCNC(=O)[C@H](O)C(C...,5.3.3.8,1HNO;1HNU;1K39;1PJH;4ZDB;4ZDC;4ZDD;4ZDE;4ZDF;,Q05871,MSQEIRQNEKISYRIEGPFFIIHLMNPDNLNALEGEDYIYLGELLE...,"[[68, 72], [126], [158]]","[0, 0, 1]",/home/xiaoruiwang/data/ubuntu_work_beta/single...,test,MSQEIRQNEKISYRIEGPFFIIHLMNPDNLNALEGEDYIYLGELLE...,...,,157,157,[157],0.202,0.000,[],Q05871,[],157
215,CC1=N[C@@H](C(=O)[O-])S/C1=C\COP(=O)([O-])[O-]...,5.3.99.10,1YAD;3QH2;,P25053,MELHAITDDSKPVEELARIIITIQNEVDFIHIRERSKSAADILKLL...,"[[102], [156], [176, 177], [122]]","[0, 0, 0, 1]",/home/xiaoruiwang/data/ubuntu_work_beta/single...,test,MELHAITDDSKPVEELARIIITIQNEVDFIHIRERSKSAADILKLL...,...,,101,121,[121],0.302,0.000,[61],P25053,[],121
216,O=P(O)(O)OC1OC(CO)C(O)C(O)C1O|MKLQGVIFDLDGVITD...,5.4.2.6,4G9B;,P77366,MKLQGVIFDLDGVITDTAHLHFQAWQQIAAEIGISIDAQFNESLKG...,"[[9, 11], [9], [9], [11], [11], [25], [44, 49]...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...",/home/xiaoruiwang/data/ubuntu_work_beta/single...,test,MKLQGVIFDLDGVITDTAHLHFQAWQQIAAEIGISIDAQFNESLKG...,...,,,10,[10],0.281,0.204,[8],P77366,[8],10
217,O=P(O)(O)OC1OC(CO)C(O)C(O)C1O|MKAVIFDLDGVITDTA...,5.4.2.6,3NAS;,O06995,MKAVIFDLDGVITDTAEYHFLAWKHIAEQIDIPFDRDMNERLKGIS...,"[[7, 9], [7], [7], [9], [9], [23], [42, 47], [...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ...",/home/xiaoruiwang/data/ubuntu_work_beta/single...,test,MKAVIFDLDGVITDTAEYHFLAWKHIAEQIDIPFDRDMNERLKGIS...,...,,,8,[8],0.307,0.307,"[6, 8]",O06995,[6],8


In [11]:
for model in ['3B', '15B']:
    # First join each of the data frames from the individual runs then compute the uncertainties 
    squidly_ensemble = pd.read_pickle(f'output/EasIFA_{model}/squidly_squidly.pkl')
    squidly_ensemble.set_index('label', inplace=True)
    
    true_df = df.copy()
    true_df.set_index('Entry', inplace=True)
    true_df = true_df.join(squidly_ensemble, how='left', rsuffix='_')
    true_df['label'] = true_df.index 

    precision, recall, f1, support = calculate_stats(true_df, 'label', 'UniProtResidue', 'Squidly_Ensemble_Residues', 'Sequence')
    u.err_p(['Squidly:', model, f'Precision: {precision}', f'Recall: {recall}', f'F1: {f1}', f'Support: {support}'])

    # Also print out the BLAST + squidly results
    true_df = df.copy()
    true_df.set_index('Entry', inplace=True)
    squidly_blast = pd.read_csv(f'output/cataloDB_{model}/squidly_blast.csv')
    squidly_blast['label'] = squidly_blast['From'].values
    squidly_blast.set_index('label', inplace=True)
    true_df = true_df.join(squidly_blast, how='left', rsuffix='_')
    true_df['label'] = true_df.index
    
    precision, recall, f1, support = calculate_stats(true_df, 'label', 'UniProtResidue', 'BLAST_residues', 'Sequence')

    u.warn_p(['BLAST', model, f'Precision: {precision}', f'Recall: {recall}', f'F1: {f1}', f'Support: {support}'])

    # Also print out the BLAST + squidly results
    true_df = df.copy()
    true_df.set_index('Entry', inplace=True)
    squidly_blast_ensemble = pd.read_csv(f'output/cataloDB_{model}/squidly_ensemble.csv')
    squidly_blast_ensemble['label'] = squidly_blast_ensemble['id'].values
    squidly_blast_ensemble.set_index('label', inplace=True)
    true_df = true_df.join(squidly_blast_ensemble, how='left', rsuffix='_')
    true_df['label'] = true_df.index
    
    precision, recall, f1, support = calculate_stats(true_df, 'label', 'UniProtResidue', 'residues', 'Sequence')

    u.dp(['Squidly + BLAST', model, f'Precision: {precision}', f'Recall: {recall}', f'F1: {f1}', f'Support: {support}'])


[91m--------------------------------------------------------------------------------[0m
[91mSquidly:	3B	Precision: 0.7387387387387387	Recall: 0.780952380952381	F1: 0.7592592592592593	Support: 105	[0m
[91m--------------------------------------------------------------------------------[0m
[93m--------------------------------------------------------------------------------[0m
[93mBLAST	3B	Precision: 0.25	Recall: 0.009523809523809525	F1: 0.01834862385321101	Support: 105	[0m
[93m--------------------------------------------------------------------------------[0m
[94m--------------------------------------------------------------------------------[0m
[94mSquidly + BLAST	3B	Precision: 0.6666666666666666	Recall: 0.05714285714285714	F1: 0.10526315789473684	Support: 105	[0m
[94m--------------------------------------------------------------------------------[0m
[91m--------------------------------------------------------------------------------[0m
[91mSquidly:	15B	Precision: 0.