# Run the models using the defaults which were selected in the previous notebook

In [15]:
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support
import numpy as np
import matplotlib.pyplot as plt

swissprot = pd.read_csv('../manuscript/data/reviewed_sprot_08042025.tsv', sep='\t')
data_dir = '/disk1/ariane/vscode/squidly/manuscript/AEGAN_extracted_sequences/'


def calculate_stats(df, id_col, true_col, pred_col, seq_col):   
    # Check the agreement
    predictions = []
    true = []
    missing = 0
    for seq_label, res_sq, res_pred, seq in df[[id_col, true_col, pred_col, seq_col]].values:
        res_sq = res_sq.split('|')
        if not res_pred or not isinstance(res_pred, str):
            res_pred = ''
        res_pred = res_pred.split('|')
        if len(res_pred) > 0:
            try:
                chosen_res_seq = [int(i) for i in res_pred]
            except:
                chosen_res_seq = []
                missing += 1
        res_sq = [int(i) for i in res_sq]
        for pos in range(0, len(seq)):
            if pos in res_sq:
                true.append(1)
            else:
                true.append(0)
            if pos in chosen_res_seq:
                predictions.append(1)
            else:
                predictions.append(0)
    precision, recall, f1, support = precision_recall_fscore_support(true, predictions)
    return precision[1], recall[1], f1[1], support[1]


def annotate_residue_from_uniprot(df):
    # Organise the active sites of these guys to be better
    active_sites = []
    active_site_residue_counts = []
    x = 0
    for act_site in df['Active site'].values:
        sites = []
        if isinstance(act_site, str):
            act_site = act_site.replace(" ", '')
            for act in act_site.split('ACT_SITE'):
                try:
                    sites.append(int(act.split(';')[0].replace(" ", '')) - 1) # Need to subtract 1 to make it fit with the fact that python is 0 encoded lol
                except:
                    x = 1
        if len(sites) != 0:
            active_sites.append('|'.join([str(s) for s in sites]))
            active_site_residue_counts.append(len(sites))
        else:
            active_sites.append('None')
            active_site_residue_counts.append(0)
    df['UniProtResidue'] = active_sites
    df['active_site_residue_counts'] = active_site_residue_counts
    return df


swissprot = annotate_residue_from_uniprot(swissprot)
swissprot = swissprot[swissprot['active_site_residue_counts'] > 0]

# Run cataloDB on squidly


In [10]:

from tqdm import tqdm 
from collections import defaultdict
from sciutil import SciUtil
import seaborn as sns

u = SciUtil()

training_ids = set(pd.read_csv('../manuscript/data/Low30_mmseq_ID_exp_subset_train.csv')['Entry'].values)
test_ids = set(pd.read_csv('../manuscript/data/Low30_mmseq_ID_exp_subset_test_foldseek.csv')['Entry'].values)

u.dp(['Number of training set:', len(training_ids), '\nNumber of test set:', len(test_ids)])

# After filtering for those in SwissProt
test_df = swissprot[swissprot['Entry'].isin(test_ids)]
training_df = swissprot[swissprot['Entry'].isin(training_ids)]
training_df['Residue'] = training_df['UniProtResidue'].values
training_df.to_csv('data/CataloDB_swissprot_training.csv', index=False)
# Save the training data as a database we're going to save this as a fasta file 
fasta_label = f'data/cataloDB_test.fasta'
with open(fasta_label, 'w+') as fout:
    for entry, seq in test_df[['Entry', 'Sequence']].values:
        if entry in test_ids:
            fout.write(f'>{entry}\n{seq}\n')
        else:
            # Just making double sure no contamination
            print(f"{entry} in training set?")

[94m--------------------------------------------------------------------------------[0m
[94m             Number of training set:	5355	
Number of test set:	239	             [0m
[94m--------------------------------------------------------------------------------[0m


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  training_df['Residue'] = training_df['UniProtResidue'].values


In [11]:
training_df

Unnamed: 0,Entry,Length,Sequence,Active site,PDB,UniProtResidue,active_site_residue_counts,Residue
0,A0A009IHW8,269,MSLEQKKGADIISKILQIQNSIGKTTSPSTLKTKLSEISRKEQENA...,"ACT_SITE 208; /evidence=""ECO:0000255|PROSITE-P...",7UWG;7UXU;8G83;,207,1,207
1,A0A023I7E1,796,MRFQVIVAAATITMITSYIPGVASQSTSDGDDLFVPVSNFDPKSIF...,"ACT_SITE 500; /evidence=""ECO:0000255|PROSITE-P...",4K35;4K3A;5XBZ;5XC2;,499|577|581,3,499|577|581
4,A0A024SC78,248,MRSLAILTTLLAGHAFAYPKPAPQSVNRRDWPSINEFLSELAKVMP...,"ACT_SITE 164; /note=""Nucleophile""; /evidence=""...",4PSC;4PSD;4PSE;,163|215|228,3,163|215|228
8,A0A059TC02,333,MRSVSGQVVCVTGAGGFIASWLVKILLEKGYTVRGTVRNPDDPKNG...,"ACT_SITE 161; /note=""Proton donor""; /evidence=...",4R1S;4R1T;,160,1,160
39,A0A075TRK9,628,MRLTSGIFHAAIAVAAVGAVLPEGPSSSKTHRNEYARRMLGSSFGI...,"ACT_SITE 564; /note=""Proton acceptor""; /eviden...",8BXL;,563,1,563
...,...,...,...,...,...,...,...,...
534049,Q9RWH9,247,MTDQPDLFGLAPDAPRPIIPANLPEDWQEALLPEFSAPYFHELTDF...,"ACT_SITE 83; /note=""Proton acceptor""; /evidenc...",2BOO;3UFM;4UQM;,82,1,82
535768,Q9X0Z9,475,MIDLDFRKLTIEECLKLSEEEREKLPQLSLETIKRLDPHVKAFISV...,"ACT_SITE 66; /note=""Charge relay system""; /evi...",2GI3;3AL0;,65|140|164,3,65|140|164
535838,Q9X2E1,207,MEELLKELERIREEAKPLVEQRFEEFKRLGEEGTEEDLFCELSFCV...,"ACT_SITE 129; /evidence=""ECO:0000255|HAMAP-Rul...",3N0U;,128|146,2,128|146
536113,Q9XW42,647,MFLEKINQKTGEREWVVAEEDYDMAQELARSRFGDMILDFDRNDKF...,"ACT_SITE 140; /evidence=""ECO:0000250""; ACT_SIT...",3WST;3X0D;,139|148,2,139|148


In [12]:
import os
# esm2_t48_15B_UR50D
os.system(f'squidly run {fasta_label} esm2_t36_3B_UR50D output/cataloDB_3B/  --database data/CataloDB_swissprot_training.csv --blast-threshold 30 --no-filter-blast')

[94m--------------------------------------------------------------------------------[0m
[94m                             Starting squidly... 	                              [0m
[94m--------------------------------------------------------------------------------[0m


100%|█████████████████████████████████████| 239/239 [00:00<00:00, 92169.79it/s]
diamond v2.1.12.166 (C) Max Planck Society for the Advancement of Science, Benjamin Buchfink, University of Tuebingen
Documentation, support and updates available at http://www.diamondsearch.org
Please cite: http://dx.doi.org/10.1038/s41592-021-01101-x Nature Methods (2021)

#CPU threads: 384
Scoring parameters: (Matrix=BLOSUM62 Lambda=0.267 K=0.041 Penalties=11/1)
Database input file: output/cataloDB_3B/squidly_database.fasta
Opening the database file...  [0.001s]
Loading sequences...  [0.017s]
Masking sequences...  [0.028s]
Writing sequences...  [0.001s]
Hashing sequences...  [0s]
Loading sequences...  [0s]
Writing trailer...  [0s]
Closing the input file...  [0s]
Closing the database file...  [0.003s]

Database sequences  5355
  Database letters  2195051
     Database hash  c9ace6d022667b97e1802b18a96a2ca4
        Total time  0.052000s


[93m--------------------------------------------------------------------------------[0m
[93m   Running BLAST on the following DB: 	data/CataloDB_swissprot_training.csv	    [0m
[93m--------------------------------------------------------------------------------[0m
['diamond', 'blastp', '--ultra-sensitive', '-d', '/tmp/tmpmwcqwkbw/zQfMXr7eW4_db', '-q', '/tmp/tmpmwcqwkbw/zQfMXr7eW4_query.fasta', '-o', '/tmp/tmpmwcqwkbw/zQfMXr7eW4_matches.tsv']
[94m--------------------------------------------------------------------------------[0m
[94mRunning command	diamond blastp --ultra-sensitive -d /tmp/tmpmwcqwkbw/zQfMXr7eW4_db -q /tmp/tmpmwcqwkbw/zQfMXr7eW4_query.fasta -o /tmp/tmpmwcqwkbw/zQfMXr7eW4_matches.tsv	[0m
[94m--------------------------------------------------------------------------------[0m
[93m--------------------------------------------------------------------------------[0m
[93m                                    Output:	                                    [0m
[93m-----

diamond v2.1.12.166 (C) Max Planck Society for the Advancement of Science, Benjamin Buchfink, University of Tuebingen
Documentation, support and updates available at http://www.diamondsearch.org
Please cite: http://dx.doi.org/10.1038/s41592-021-01101-x Nature Methods (2021)

#CPU threads: 384
Scoring parameters: (Matrix=BLOSUM62 Lambda=0.267 K=0.041 Penalties=11/1)
Temporary directory: /tmp/tmpmwcqwkbw
#Target sequences to report alignments for: 25
Opening the database...  [0.003s]
Database: /tmp/tmpmwcqwkbw/zQfMXr7eW4_db (type: Diamond database, sequences: 5355, letters: 2195051)
Block size = 400000000
Opening the input file...  [0s]
Opening the output file...  [0s]
Loading query sequences...  [0s]
Masking queries...  [0.034s]
Algorithm: Double-indexed
Building query histograms...  [0.097s]
Seeking in database...  [0s]
Loading reference sequences...  [0.001s]
Masking reference...  [0.028s]
Initializing temporary storage...  [0.001s]
Building reference histograms...  [0.11s]
Allocating


[91m--------------------------------------------------------------------------------[0m
[94m--------------------------------------------------------------------------------[0m
[94m              Time for command to run (min): 	0.18574542751690995	              [0m
[94m--------------------------------------------------------------------------------[0m
         0       1     2    3    4   ...   7    8    9             10     11
0    A5JTM5  A4YI89  28.9  228  149  ...  225    3  221  7.370000e-27  102.0
1    A5JTM5  Q5SKU3  34.3  207  129  ...  211    3  205  1.340000e-23   94.0
2    A5JTM5  P52046  27.7  224  156  ...  232   11  229  3.010000e-23   93.2
3    A5JTM5  Q5LLW6  29.8  215  135  ...  213   18  227  8.640000e-21   86.7
4    A5JTM5  O53561  30.6  196  120  ...  204   21  207  2.850000e-18   79.7
..      ...     ...   ...  ...  ...  ...  ...  ...  ...           ...    ...
428  Q56694  P94358  19.6  424  297  ...  413   27  417  1.460000e-06   48.1
429  Q56694  Q56R04  24

100%|███████████████████████████████████████| 239/239 [00:01<00:00, 147.09it/s]


0

In [None]:
os.system(f'squidly run {fasta_label} esm2_t48_15B_UR50D output/cataloDB_15B/  --database data/CataloDB_swissprot_training.csv --blast-threshold 30 --no-filter-blast')

[94m--------------------------------------------------------------------------------[0m
[94m                             Starting squidly... 	                              [0m
[94m--------------------------------------------------------------------------------[0m


100%|████████████████████████████████████| 239/239 [00:00<00:00, 170280.05it/s]
diamond v2.1.12.166 (C) Max Planck Society for the Advancement of Science, Benjamin Buchfink, University of Tuebingen
Documentation, support and updates available at http://www.diamondsearch.org
Please cite: http://dx.doi.org/10.1038/s41592-021-01101-x Nature Methods (2021)

#CPU threads: 384
Scoring parameters: (Matrix=BLOSUM62 Lambda=0.267 K=0.041 Penalties=11/1)
Database input file: output/cataloDB_15B/squidly_database.fasta
Opening the database file...  [0.001s]
Loading sequences...  [0.017s]
Masking sequences...  [0.039s]
Writing sequences...  [0.001s]
Hashing sequences...  [0.001s]
Loading sequences...  [0s]
Writing trailer...  [0s]
Closing the input file...  [0s]
Closing the database file...  [0.003s]

Database sequences  5355
  Database letters  2195051
     Database hash  c9ace6d022667b97e1802b18a96a2ca4
        Total time  0.066000s


# Read in each from the folder and print the output

In [17]:
for model in ['3B']:
    # First join each of the data frames from the individual runs then compute the uncertainties 
    squidly_ensemble = pd.read_pickle(f'output/cataloDB_{model}/squidly_squidly.pkl')
    squidly_ensemble.set_index('label', inplace=True)
    
    true_df = swissprot[swissprot['Entry'].isin(test_ids)]
    true_df.set_index('Entry', inplace=True)
    true_df = true_df.join(squidly_ensemble, how='left', rsuffix='_')
    true_df['label'] = true_df.index 

    precision, recall, f1, support = calculate_stats(true_df, 'label', 'UniProtResidue', 'Squidly_Ensemble_Residues', 'Sequence')
    u.err_p(['Squidly:', model, f'Precision: {precision}', f'Recall: {recall}', f'F1: {f1}', f'Support: {support}'])

    # Also print out the BLAST + squidly results
    true_df = swissprot[swissprot['Entry'].isin(test_ids)]
    true_df.set_index('Entry', inplace=True)
    squidly_blast = pd.read_csv(f'output/cataloDB_{model}/squidly_blast.csv')
    squidly_blast['label'] = squidly_blast['From'].values
    squidly_blast.set_index('label', inplace=True)
    true_df = true_df.join(squidly_blast, how='left', rsuffix='_')
    true_df['label'] = true_df.index
    
    precision, recall, f1, support = calculate_stats(true_df, 'label', 'UniProtResidue', 'BLAST_residues', 'Sequence')

    u.warn_p(['BLAST', model, f'Precision: {precision}', f'Recall: {recall}', f'F1: {f1}', f'Support: {support}'])

    # Also print out the BLAST + squidly results
    true_df = swissprot[swissprot['Entry'].isin(test_ids)]
    true_df.set_index('Entry', inplace=True)
    squidly_blast_ensemble = pd.read_csv(f'output/cataloDB_{model}/squidly_ensemble.csv')
    squidly_blast_ensemble['label'] = squidly_blast_ensemble['id'].values
    squidly_blast_ensemble.set_index('label', inplace=True)
    true_df = true_df.join(squidly_blast_ensemble, how='left', rsuffix='_')
    true_df['label'] = true_df.index
    
    precision, recall, f1, support = calculate_stats(true_df, 'label', 'UniProtResidue', 'residues', 'Sequence')

    u.dp(['Squidly + BLAST', model, f'Precision: {precision}', f'Recall: {recall}', f'F1: {f1}', f'Support: {support}'])


[91m--------------------------------------------------------------------------------[0m
[91mSquidly:	3B	Precision: 0.8184615384615385	Recall: 0.6186046511627907	F1: 0.7046357615894039	Support: 430	[0m
[91m--------------------------------------------------------------------------------[0m
[93m--------------------------------------------------------------------------------[0m
[93mBLAST	3B	Precision: 0.6948051948051948	Recall: 0.24883720930232558	F1: 0.3664383561643836	Support: 430	[0m
[93m--------------------------------------------------------------------------------[0m
[94m--------------------------------------------------------------------------------[0m
[94mSquidly + BLAST	3B	Precision: 0.7597597597597597	Recall: 0.5883720930232558	F1: 0.6631716906946264	Support: 430	[0m
[94m--------------------------------------------------------------------------------[0m


# For each of the families do the plot of the best range and then also calculate for the cutoff

In [13]:
import warnings
warnings.filterwarnings('ignore')

In [None]:


files = ['PC',
         'NN',
         'EF_superfamily',
         'EF_fold',
         'EF_family',
         'HA_superfamily'] 

for model in ['15B', '3B']:
    # First join each of the data frames from the individual runs then compute the uncertainties 
    squidly_ensemble = pd.read_pickle(f'output/families_{b}/{family}/{family}_squidly_0.pkl')
    for model_i in range(1, 5):
        squidly_df = pd.read_pickle(f'output/families_{b}/{family}/{family}_squidly_{model_i}.pkl')
        squidly_ensemble = squidly_ensemble.join(squidly_df, how='outer', rsuffix=f'_{model_i}')
        
    squidly_ensemble.set_index('label', inplace=True)
    
    if family != 'HA_superfamily':
        squidly_ensemble['Entry'] = [e.split('|')[1] for e in squidly_ensemble.index.values]
    else:
        squidly_ensemble['Entry'] = squidly_ensemble.index.values
    squidly_ensemble.set_index('Entry', inplace=True)
    
    test_ids = set(pd.read_csv(f'{data_dir}/{family}/{family}.txt', header=None)[0].values)
    true_df = swissprot[swissprot['Entry'].isin(test_ids)]
    true_df.set_index('Entry', inplace=True)
    true_df = true_df.join(squidly_ensemble, how='left')

    true_df['label'] = true_df.index

    means, entropy_values, epistemics, residues = compute_uncertainties(true_df, ['all_AS_probs', 'all_AS_probs_1', 'all_AS_probs_2', 'all_AS_probs_3', 'all_AS_probs_4'], 'Sequence', 0.6, 0.225)
    true_df['mean_prob'] = means
    true_df['entropy'] = entropy_values
    true_df['variance'] = epistemics
    true_df['residues'] = residues    

    precision, recall, f1, support = calculate_stats(true_df, 'label', 'UniProtResidue', 'residues', 'Sequence')

    u.warn_p([model, family, f'Precision: {precision}', f'Recall: {recall}', f'F1: {f1}', f'Support: {support}'])
    unc_df = calculate_stats_uncertainty(true_df, 'label', 'UniProtResidue', 'residues', 'Sequence')
    cols = defaultdict(list)
    for mean_prob in [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]:
        for vari in [0.05, 0.0725, 0.1, 0.125, 0.15, 0.175, 0.2, 0.225, 0.25]:
            # Make a pred
            preds_prob = 1.0 * unc_df['Mean Prob'].values > mean_prob
            preds_var = 1.0 * unc_df['Variance'].values < vari
            cols[f'm{mean_prob}_v{vari}'] = 1.0*preds_prob*preds_var
    
    for c in cols:
        unc_df[c] = cols[c]
        
    # Now calculate the different predictions for each of them
    rows = []
    for c in cols:
        precision, recall, f1, support = precision_recall_fscore_support(list(unc_df['True Catalytic'].values), list(unc_df[c].values))
        # Maybe save the 0's too but for now look only at precision and recall of getting one correct
        rows.append([c, c.split('_')[0][1:], c.split('_')[1][1:], precision[1], recall[1], f1[1], support[1]])
    
    pred_df = pd.DataFrame(rows, columns=['label', 'mean_pred', 'variance', 'precision', 'recall', 'f1', 'support'])
    
    plt.rcParams['figure.figsize'] = (5,4)
    
    df_plot = pd.DataFrame(pred_df, columns=["mean_pred", "variance", "f1"])
    pivot = df_plot.pivot("mean_pred", "variance", "f1")
    sns.heatmap(pivot, annot=True, cmap="viridis")
    plt.ylabel("Prediction cutoff")
    plt.xlabel("Variance cutoff")
    plt.title(f'F1 {family} {model}')
    plt.show()

    df_plot = pd.DataFrame(pred_df, columns=["mean_pred", "variance", "precision"])
    pivot = df_plot.pivot("mean_pred", "variance", "precision")
    sns.heatmap(pivot, annot=True, cmap="viridis")
    plt.ylabel("Prediction cutoff")
    plt.xlabel("Variance cutoff")
    plt.title(f'Precision {family} {model}')
    plt.show()

    df_plot = pd.DataFrame(pred_df, columns=["mean_pred", "variance", "recall"])
    pivot = df_plot.pivot("mean_pred", "variance", "recall")
    sns.heatmap(pivot, annot=True, cmap="viridis")
    plt.ylabel("Prediction cutoff")
    plt.xlabel("Variance cutoff")
    plt.title(f'Recall {family} {model}')
    plt.show()
    

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 55/55 [00:00<00:00, 158.28it/s]


[93m--------------------------------------------------------------------------------[0m
[93m15B	PC	Precision: 0.8631578947368421	Recall: 0.8282828282828283	F1: 0.8453608247422681	Support: 99	[0m
[93m--------------------------------------------------------------------------------[0m
