# Run the models using the defaults which were selected in the previous notebook

In [23]:
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support
import numpy as np
import matplotlib.pyplot as plt
import warnings
from tqdm import tqdm 
from collections import defaultdict
from sciutil import SciUtil
import seaborn as sns

# Helper object whose dp / err_p / warn_p methods are used below to print results
u = SciUtil()

# Suppress every Python warning for the rest of the session
warnings.filterwarnings('ignore')

# Tab-separated table; the 'Entry', 'Sequence' and 'Active site' columns are read further down.
# Presumably the reviewed Swiss-Prot export (going by the file name) - TODO confirm.
swissprot = pd.read_csv('../manuscript/data/reviewed_sprot_08042025.tsv', sep='\t')
# NOTE(review): absolute, machine-specific path - adjust when running elsewhere.
# Holds one sub-directory per benchmark family, each with a '<family>.txt' ID list.
data_dir = '/disk1/ariane/vscode/squidly/manuscript/AEGAN_extracted_sequences/'

def compute_uncertainties(df, prob_columns, seq_col, mean_prob=0.5, mean_var=1):
    """Summarise per-residue ensemble predictions as mean, entropy and variance.

    For every row of ``df`` the per-residue probabilities stored in
    ``prob_columns`` (one indexable sequence of probabilities per ensemble
    member; any number of members is accepted) are combined position by
    position along the sequence in ``seq_col``.

    Positions that are not covered by every ensemble member (the predictions
    can be shorter than the sequence), and rows whose predictions are missing
    altogether, receive the sentinel values mean=0, variance=1, entropy=1 so
    that they are easy to spot downstream.

    Args:
        df: Data frame holding the prediction and sequence columns.
        prob_columns: Names of the columns with per-residue probabilities.
        seq_col: Name of the column with the sequence.
        mean_prob: A position is selected only if its mean probability is
            strictly greater than this value.
        mean_var: A position is selected only if the variance across the
            ensemble is strictly less than this value.

    Returns:
        Tuple ``(means, entropy_values, variances, residues)``. The first
        three are lists (one entry per row) of per-position lists; ``residues``
        holds, per row, the selected zero-based positions joined by ``'|'``.
    """
    eps = 1e-8  # Added to every probability so that log2 never sees an exact zero
    means, variances, residues, entropy_values = [], [], [], []
    for *probs, seq in tqdm(df[list(prob_columns) + [seq_col]].values):
        try:
            # Only positions present in *every* ensemble member can be scored
            n_scored = min(len(p) for p in probs)
        except TypeError:
            # At least one prediction is missing (e.g. NaN): nothing can be scored
            n_scored = 0

        mean_values, variance_values, entropys, selected = [], [], [], []
        for j in range(len(seq)):
            if j >= n_scored:
                # No prediction for this position: flag it with the sentinel values
                mean_values.append(0)
                variance_values.append(1)
                entropys.append(1)
                continue
            all_probs = [p[j] + eps for p in probs]
            mean_probs = np.mean(all_probs)
            # Keep the entropy argument strictly inside (0, 1); without this a mean
            # of 1 + eps would make log2(1 - mean) undefined and yield NaN
            clipped = min(max(mean_probs, eps), 1 - eps)
            entropy = -((clipped * np.log2(clipped)) + ((1 - clipped) * np.log2(1 - clipped)))
            variance = np.var(all_probs)  # use variance as a proxy
            if mean_probs > mean_prob and variance < mean_var:  # Use the supplied cutoffs
                selected.append(j)
            mean_values.append(mean_probs)
            variance_values.append(variance)
            entropys.append(entropy)

        means.append(mean_values)
        variances.append(variance_values)
        entropy_values.append(entropys)
        residues.append('|'.join(str(s) for s in selected))
    return means, entropy_values, variances, residues


def calculate_stats(df, id_col, true_col, pred_col, seq_col):
    """Compute residue-level precision, recall, F1 and support for the positive class.

    Every position of every sequence contributes one binary label: 1 if the
    zero-based position is listed in the ``'|'``-separated string of the
    respective column, otherwise 0.

    Args:
        df: Data frame with one row per sequence.
        id_col: Name of the identifier column (selected but not otherwise used).
        true_col: Column with the ``'|'``-separated true positions.
        pred_col: Column with the ``'|'``-separated predicted positions. Missing,
            empty or unparsable values are treated as "nothing predicted".
        seq_col: Column with the sequence (only its length is used).

    Returns:
        Tuple ``(precision, recall, f1, support)`` for label 1.

    Raises:
        ValueError: If an entry of ``true_col`` cannot be parsed as integers.
    """
    predictions = []
    true = []
    for _, res_sq, res_pred, seq in df[[id_col, true_col, pred_col, seq_col]].values:
        true_positions = {int(i) for i in res_sq.split('|')}
        predicted_positions = set()
        if isinstance(res_pred, str) and res_pred:
            try:
                predicted_positions = {int(i) for i in res_pred.split('|')}
            except ValueError:
                # Malformed prediction string: count it as no prediction at all
                predicted_positions = set()
        for pos in range(len(seq)):
            true.append(1 if pos in true_positions else 0)
            predictions.append(1 if pos in predicted_positions else 0)
    # Pin the label order so index 1 always refers to the positive class, even when
    # no positive label occurs in either vector
    precision, recall, f1, support = precision_recall_fscore_support(true, predictions, labels=[0, 1])
    return precision[1], recall[1], f1[1], support[1]


def calculate_stats_uncertainty(df, id_col, true_col, pred_col, seq_col):
    """Build a long-format table with one row per residue and its uncertainty values.

    ``df`` must also contain the columns ``'mean_prob'``, ``'entropy'`` and
    ``'variance'``, each holding one per-position sequence of values per row.

    Args:
        df: Data frame with one row per sequence.
        id_col: Name of the identifier column.
        true_col: Column with the ``'|'``-separated true zero-based positions.
            Rows where this is empty or not a string are skipped.
        pred_col: Column with the predicted positions (selected but not used).
        seq_col: Column with the sequence.

    Returns:
        ``pandas.DataFrame`` with the columns ``Entry``, ``Position``,
        ``Residue``, ``True Catalytic`` (1/0), ``Mean Prob``, ``Entropy`` and
        ``Variance``.

    Raises:
        ValueError: If a non-empty entry of ``true_col`` cannot be parsed as integers.
    """
    columns = [id_col, true_col, pred_col, seq_col, 'mean_prob', 'entropy', 'variance']
    rows = []
    for seq_label, res_sq, _, seq, mean_prob, alea, var in df[columns].values:
        if not isinstance(res_sq, str) or not res_sq:
            # No annotation for this sequence: nothing to compare against
            continue
        catalytic = {int(i) for i in res_sq.split('|')}
        for pos, residue in enumerate(seq):
            is_catalytic = 1 if pos in catalytic else 0
            rows.append([seq_label, pos, residue, is_catalytic, mean_prob[pos], alea[pos], var[pos]])

    return pd.DataFrame(rows, columns=['Entry', 'Position', 'Residue', 'True Catalytic', 'Mean Prob', 'Entropy', 'Variance'])


def annotate_residue_from_uniprot(df):
    """Extract zero-based active-site positions from the ``'Active site'`` column.

    Each string value is stripped of spaces and split on ``'ACT_SITE'``; the
    text before the first ``';'`` of every piece is parsed as a one-based
    position and converted to zero-based. Pieces that are not plain integers
    are ignored. Non-string values (e.g. NaN) yield no positions.

    Two columns are added to ``df`` in place:

    * ``'UniProtResidue'``: positions joined by ``'|'``, or the string
      ``'None'`` when no position was found.
    * ``'active_site_residue_counts'``: number of positions found.

    Args:
        df: Data frame with an ``'Active site'`` column.

    Returns:
        The same ``df`` object, with the two columns added.
    """
    active_sites = []
    active_site_residue_counts = []
    for act_site in df['Active site'].values:
        sites = []
        if isinstance(act_site, str):
            for act in act_site.replace(' ', '').split('ACT_SITE'):
                position = act.split(';')[0]
                try:
                    # Subtract 1: the annotation is one-based, Python indexing is zero-based
                    sites.append(int(position) - 1)
                except ValueError:
                    # Not a plain integer (e.g. the empty piece before the first ACT_SITE)
                    continue
        if sites:
            active_sites.append('|'.join(str(s) for s in sites))
        else:
            active_sites.append('None')
        active_site_residue_counts.append(len(sites))
    df['UniProtResidue'] = active_sites
    df['active_site_residue_counts'] = active_site_residue_counts
    return df


# Add the zero-based active-site positions and keep only entries that have at least one
swissprot = annotate_residue_from_uniprot(swissprot)
swissprot = swissprot[swissprot['active_site_residue_counts'] > 0]
training_ids = set(pd.read_csv('../manuscript/data/AEGAN_real_training_set.txt', header=None)[0].values)
# Take an explicit copy: assigning a column to a filtered slice of swissprot is the
# chained-assignment pattern pandas warns about (the warning is silenced globally above)
training_df = swissprot[swissprot['Entry'].isin(training_ids)].copy()
training_df['Residue'] = training_df['UniProtResidue'].values

training_df.to_csv('data/AEGAN_swissprot_training.csv', index=False)
u.dp(['Number of AEGAN training set:', len(training_df)])

[94m--------------------------------------------------------------------------------[0m
[94m                      Number of AEGAN training set:	9888	                       [0m
[94m--------------------------------------------------------------------------------[0m


# For each family, plot the best range and also calculate the statistics at the chosen cutoff

In [24]:

# Benchmark families to evaluate; each name is used as a sub-directory of both
# 'output/families_<model>/' and data_dir
files = ['PC',
         'NN',
         'EF_superfamily',
         'EF_fold',
         'EF_family',
         'HA_superfamily'] 

rows = []  # Not used within this cell
for model in ['3B']:
    for family in files:
        # --- Squidly predictions ---
        # Load the pickled Squidly output for this family and index it by 'label'
        squidly_ensemble = pd.read_pickle(f'output/families_{model}/{family}/squidly_squidly.pkl')
        if family != 'HA_superfamily': # Already fine
            squidly_ensemble['label'] = [c[2:8] for c in squidly_ensemble['label'].values] # Need to only keep the uniprot IDs
        squidly_ensemble.set_index('label', inplace=True)
        
        # Ground truth: the swissprot rows whose 'Entry' is in this family's ID list,
        # left-joined with the predictions (entries without a prediction keep NaN)
        test_ids = set(pd.read_csv(f'{data_dir}/{family}/{family}.txt', header=None)[0].values)
        true_df = swissprot[swissprot['Entry'].isin(test_ids)]
        true_df.set_index('Entry', inplace=True)
        true_df = true_df.join(squidly_ensemble, how='left', rsuffix='_')
        true_df['label'] = true_df.index 

        precision, recall, f1, support = calculate_stats(true_df, 'label', 'UniProtResidue', 'Squidly_Ensemble_Residues', 'Sequence')
        u.err_p(['Squidly:', model, family, f'Precision: {precision}', f'Recall: {recall}', f'F1: {f1}', f'Support: {support}'])

        # --- BLAST predictions ---
        # Rebuild the ground-truth frame and score the 'BLAST_residues' column of squidly_blast.csv
        true_df = swissprot[swissprot['Entry'].isin(test_ids)]
        true_df.set_index('Entry', inplace=True)
        squidly_blast = pd.read_csv(f'output/families_{model}/{family}/squidly_blast.csv')
        if family != 'HA_superfamily': # Already fine
            squidly_blast['label'] = [c[2:8] for c in squidly_blast['From'].values] # Need to only keep the uniprot IDs
        else:
            squidly_blast['label'] = squidly_blast['From'].values
        squidly_blast.set_index('label', inplace=True)
        true_df = true_df.join(squidly_blast, how='left', rsuffix='_')
        true_df['label'] = true_df.index
        
        precision, recall, f1, support = calculate_stats(true_df, 'label', 'UniProtResidue', 'BLAST_residues', 'Sequence')

        u.warn_p(['BLAST', model, family, f'Precision: {precision}', f'Recall: {recall}', f'F1: {f1}', f'Support: {support}'])

        # --- Squidly + BLAST predictions ---
        # Rebuild the ground-truth frame and score the 'residues' column of squidly_ensemble.csv
        true_df = swissprot[swissprot['Entry'].isin(test_ids)]
        true_df.set_index('Entry', inplace=True)
        squidly_blast_ensemble = pd.read_csv(f'output/families_{model}/{family}/squidly_ensemble.csv')
        if family != 'HA_superfamily': # Already fine
            squidly_blast_ensemble['label'] = [c[2:8] for c in squidly_blast_ensemble['id'].values] # Need to only keep the uniprot IDs
        else:
            squidly_blast_ensemble['label'] = squidly_blast_ensemble['id'].values
        squidly_blast_ensemble.set_index('label', inplace=True)
        true_df = true_df.join(squidly_blast_ensemble, how='left', rsuffix='_')
        true_df['label'] = true_df.index
        
        precision, recall, f1, support = calculate_stats(true_df, 'label', 'UniProtResidue', 'residues', 'Sequence')

        u.dp(['Squidly + BLAST', model, family, f'Precision: {precision}', f'Recall: {recall}', f'F1: {f1}', f'Support: {support}'])



[91m--------------------------------------------------------------------------------[0m
[91mSquidly:	3B	PC	Precision: 0.8571428571428571	Recall: 0.8484848484848485	F1: 0.8527918781725888	Support: 99	[0m
[91m--------------------------------------------------------------------------------[0m
[93m--------------------------------------------------------------------------------[0m
[93mBLAST	3B	PC	Precision: 0.9361702127659575	Recall: 0.8888888888888888	F1: 0.9119170984455959	Support: 99	[0m
[93m--------------------------------------------------------------------------------[0m
[94m--------------------------------------------------------------------------------[0m
[94mSquidly + BLAST	3B	PC	Precision: 0.9361702127659575	Recall: 0.8888888888888888	F1: 0.9119170984455959	Support: 99	[0m
[94m--------------------------------------------------------------------------------[0m
[91m--------------------------------------------------------------------------------[0m
[91mSquidly:	3

# Make the plots as in the previous version of the paper