# Revision

Reviewers commented that the selection of cutoffs could be improved by providing users with more information, so we opted to improve the clarity of this section.

The goal now is to create predictions for the uni3175 dataset, recording both the mean and the variance of the predicted values. These will then be used as the cutoffs for the other downstream models, so that we can better see whether the cutoff selection generalizes.




In [1]:
import pandas as pd

from sklearn.metrics import precision_recall_fscore_support
import numpy as np
import matplotlib.pyplot as plt

def compute_uncertainties(df, prob_columns, mean_prob=0.5):
    """Summarise per-position ensemble predictions for every row of ``df``.

    Each column named in ``prob_columns`` is expected to hold, for every row,
    a sequence of per-position probabilities produced by one model. For each
    row the models are combined position by position.

    Args:
        df: DataFrame with one row per sequence.
        prob_columns: list of column names, one per ensemble member. Any
            number of columns is accepted (the original required exactly 5).
        mean_prob: unused; kept so existing callers keep working.

    Returns:
        A 4-tuple ``(means, entropy_values, epistemics, residues)`` with one
        entry per row of ``df``:

        * ``means`` - list of per-position mean probabilities (each input
          probability is offset by ``1e-8`` before averaging).
        * ``entropy_values`` - list of per-position binary entropies (base 2)
          of the mean probability.
        * ``epistemics`` - list of per-position variances across the models.
        * ``residues`` - ``'|'``-joined string of the position indices.
    """
    # The progress bar is optional: fall back to plain iteration when tqdm is
    # not installed so the computation itself never depends on it.
    try:
        from tqdm import tqdm as progress
    except ImportError:
        def progress(iterable):
            return iterable

    eps = 1e-8  # Keeps log2 away from zero.
    means, entropy_values, epistemics, residues = [], [], [], []
    for model_probs in progress(df[prob_columns].values):
        # The first model defines how many positions are evaluated.
        n_positions = len(model_probs[0])
        stacked = np.array(
            [np.asarray(p, dtype=float)[:n_positions] for p in model_probs]
        ) + eps

        mean_probs = stacked.mean(axis=0)
        epistemic = stacked.var(axis=0)  # use variance as a proxy

        # The eps offset can push a mean of exactly 1.0 above 1, which made
        # log2(1 - mean) NaN. Clip only for the entropy so the reported means
        # are unchanged.
        clipped = np.clip(mean_probs, eps, 1 - eps)
        entropy = -(clipped * np.log2(clipped) + (1 - clipped) * np.log2(1 - clipped))

        means.append(mean_probs.tolist())
        epistemics.append(epistemic.tolist())
        entropy_values.append(entropy.tolist())
        residues.append('|'.join(str(j) for j in range(n_positions)))
    return means, entropy_values, epistemics, residues
    
def calculate_stats_uncertainty(df, id_col, true_col, pred_col, seq_col):
    """Expand per-sequence annotations into one row per residue position.

    Besides the four named columns, ``df`` must contain ``'mean_prob'``,
    ``'entropy'`` and ``'variance'`` columns holding per-position values that
    can be indexed by position.

    Args:
        df: DataFrame with one row per sequence.
        id_col: column holding the sequence identifier.
        true_col: column holding ``'|'``-separated integer positions. These
            positions are flagged with 1 in the output.
        pred_col: column that is read but not used by this function.
        seq_col: column holding the sequence string.

    Returns:
        DataFrame with columns ``Entry``, ``Position``, ``Residue``,
        ``Catalytic Pred``, ``Mean Prob``, ``Entropy`` and ``Variance``.
        Sequences whose ``true_col`` value is empty or missing contribute no
        rows.
    """
    # NOTE(review): the 'Catalytic Pred' flag is derived from `true_col`, not
    # `pred_col` - confirm this is intended by the callers.
    value_columns = [id_col, true_col, pred_col, seq_col, 'mean_prob', 'entropy', 'variance']
    rows = []
    for seq_label, res_sq, _res_pred, seq, mean_prob, alea, var in df[value_columns].values:
        # Skip sequences without annotated positions. The original tried to
        # count them in an uninitialised variable (UnboundLocalError), and a
        # NaN cell is not falsy, so it crashed on `.split`.
        if not res_sq or (isinstance(res_sq, float) and np.isnan(res_sq)):
            continue
        flagged = {int(i) for i in res_sq.split('|')}  # set: O(1) membership
        for pos, residue in enumerate(seq):
            rows.append([seq_label, pos, residue, int(pos in flagged),
                         mean_prob[pos], alea[pos], var[pos]])

    return pd.DataFrame(rows, columns=['Entry', 'Position', 'Residue', 'Catalytic Pred', 'Mean Prob', 'Entropy', 'Variance'])


# Run each model separately so that we can test different strategies for ensembling

Ran the following on each of the first five AEGAN models:

```
 squidly run /disk1/ariane/vscode/squidly/revision/data/uni3175/uni3175_unduplicated.fasta esm2_t36_3B_UR50D output/uni3175_aegan/model1/ cataloDB_3B --as-threshold 0.5 --lstm-model-as ../models/FinalModels/CLEANED_reproducing_AEGAN_benchmark_squidly_scheme_3_esm2_t36_3B_UR50D_2025-03-04/Scheme3_16000_1/3B/LSTM_3B.pth --cr-model-as ../models/FinalModels/CLEANED_reproducing_AEGAN_benchmark_squidly_scheme_3_esm2_t36_3B_UR50D_2025-03-04/Scheme3_16000_1/3B/CL_3B.pt --no-ensemble

squidly run /disk1/ariane/vscode/squidly/revision/data/uni3175/uni3175_unduplicated.fasta esm2_t36_3B_UR50D output/uni3175_aegan/model2/ cataloDB_3B --as-threshold 0.5 --lstm-model-as ../models/FinalModels/CLEANED_reproducing_AEGAN_benchmark_squidly_scheme_3_esm2_t36_3B_UR50D_2025-03-04/Scheme3_16000_2/3B/LSTM_3B.pth --cr-model-as ../models/FinalModels/CLEANED_reproducing_AEGAN_benchmark_squidly_scheme_3_esm2_t36_3B_UR50D_2025-03-04/Scheme3_16000_2/3B/CL_3B.pt --no-ensemble

squidly run /disk1/ariane/vscode/squidly/revision/data/uni3175/uni3175_unduplicated.fasta esm2_t36_3B_UR50D output/uni3175_aegan/model3/ cataloDB_3B --as-threshold 0.5 --lstm-model-as ../models/FinalModels/CLEANED_reproducing_AEGAN_benchmark_squidly_scheme_3_esm2_t36_3B_UR50D_2025-03-04/Scheme3_16000_3/3B/LSTM_3B.pth --cr-model-as ../models/FinalModels/CLEANED_reproducing_AEGAN_benchmark_squidly_scheme_3_esm2_t36_3B_UR50D_2025-03-04/Scheme3_16000_3/3B/CL_3B.pt --no-ensemble

squidly run /disk1/ariane/vscode/squidly/revision/data/uni3175/uni3175_unduplicated.fasta esm2_t36_3B_UR50D output/uni3175_aegan/model4/ cataloDB_3B --as-threshold 0.5 --lstm-model-as ../models/FinalModels/CLEANED_reproducing_AEGAN_benchmark_squidly_scheme_3_esm2_t36_3B_UR50D_2025-03-04/Scheme3_16000_4/3B/LSTM_3B.pth --cr-model-as ../models/FinalModels/CLEANED_reproducing_AEGAN_benchmark_squidly_scheme_3_esm2_t36_3B_UR50D_2025-03-04/Scheme3_16000_4/3B/CL_3B.pt --no-ensemble

squidly run /disk1/ariane/vscode/squidly/revision/data/uni3175/uni3175_unduplicated.fasta esm2_t36_3B_UR50D output/uni3175_aegan/model5/ cataloDB_3B --as-threshold 0.5 --lstm-model-as ../models/FinalModels/CLEANED_reproducing_AEGAN_benchmark_squidly_scheme_3_esm2_t36_3B_UR50D_2025-03-04/Scheme3_16000_5/3B/LSTM_3B.pth --cr-model-as ../models/FinalModels/CLEANED_reproducing_AEGAN_benchmark_squidly_scheme_3_esm2_t36_3B_UR50D_2025-03-04/Scheme3_16000_5/3B/CL_3B.pt --no-ensemble


```

In [3]:
import os 

# Print the `squidly run` command for each of the five AEGAN models so that
# every model can be run on its own (no ensembling).
for i in range(1, 6):
    # Directory holding the contrastive and LSTM checkpoints for model `i`.
    scheme_dir = f'../models/FinalModels/CLEANED_reproducing_AEGAN_benchmark_squidly_scheme_3_esm2_t36_3B_UR50D_2025-03-04/Scheme3_16000_{i}'
    cr_model = f'{scheme_dir}/models/temp_best_model.pt'
    # The LSTM checkpoint name is timestamped, so take the first file listed.
    lstm_model_name = os.listdir(f'{scheme_dir}/LSTM/models/')[0]
    lstm_model = f'{scheme_dir}/LSTM/models/{lstm_model_name}'
    command = (
        f'squidly run data/uni3175/uni3175_unduplicated.fasta esm2_t36_3B_UR50D output/uni3175_aegan/model{i}/ cataloDB_3B'
        f' --as-threshold 0.5 --lstm-model-as {lstm_model} --cr-model-as {cr_model} --no-ensemble'
    )
    print(command)

squidly run data/uni3175/uni3175_unduplicated.fasta esm2_t36_3B_UR50D output/uni3175_aegan/model1/ cataloDB_3B --as-threshold 0.5 --lstm-model-as ../models/FinalModels/CLEANED_reproducing_AEGAN_benchmark_squidly_scheme_3_esm2_t36_3B_UR50D_2025-03-04/Scheme3_16000_1/LSTM/models/04-03-25_12-49_128_2_0.2_400_best_model.pth --cr-model-as ../models/FinalModels/CLEANED_reproducing_AEGAN_benchmark_squidly_scheme_3_esm2_t36_3B_UR50D_2025-03-04/Scheme3_16000_1/models/temp_best_model.pt --no-ensemble
squidly run data/uni3175/uni3175_unduplicated.fasta esm2_t36_3B_UR50D output/uni3175_aegan/model2/ cataloDB_3B --as-threshold 0.5 --lstm-model-as ../models/FinalModels/CLEANED_reproducing_AEGAN_benchmark_squidly_scheme_3_esm2_t36_3B_UR50D_2025-03-04/Scheme3_16000_2/LSTM/models/04-03-25_14-09_128_2_0.2_400_best_model.pth --cr-model-as ../models/FinalModels/CLEANED_reproducing_AEGAN_benchmark_squidly_scheme_3_esm2_t36_3B_UR50D_2025-03-04/Scheme3_16000_2/models/temp_best_model.pt --no-ensemble
squidly 

# See what the F1, precision and recall are for different cutoff values

In [None]:
squidly run /disk1/ariane/vscode/squidly/revision/data/uni3175/uni3175_unduplicated.fasta esm2_t36_3B_UR50D output/uni3175/model1/ cataloDB_3B --as-threshold 0.5 --lstm-model-as ../models/FinalModels/CLEANED_reproducing_AEGAN_benchmark_squidly_scheme_3_esm2_t36_3B_UR50D_2025-03-04/Scheme3_16000_1/3B/LSTM_3B.pth --cr-model-as ../models/FinalModels/CLEANED_reproducing_AEGAN_benchmark_squidly_scheme_3_esm2_t36_3B_UR50D_2025-03-04/Scheme3_16000_1/3B/CL_3B.pt