In [19]:
from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    auc, roc_curve,
    matthews_corrcoef
)
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
import numpy as np

def get_metrics(y_test, y_test_predictions, y_probs):
    """Compute the binary-classification metrics used throughout this notebook.

    Parameters
    ----------
    y_test : array-like
        Ground-truth labels. Label ``1`` is used as the positive class when
        building the ROC curve.
    y_test_predictions : array-like
        Hard class predictions; scored with accuracy, precision, recall,
        F1 and the Matthews correlation coefficient.
    y_probs : array-like
        Scores for the positive class; used only for the ROC curve / AUC.

    Returns
    -------
    dict
        Keys ``"accuracy"``, ``"precision"``, ``"recall"``, ``"f1score"``,
        ``"auc"`` and ``"mcc"``, inserted in that order. ``get_all_metrics``
        reads ``list(metrics.values())`` and labels the rows positionally,
        so the insertion order must not change.
    """
    accuracy = accuracy_score(y_test, y_test_predictions)
    precision = precision_score(y_test, y_test_predictions)
    recall = recall_score(y_test, y_test_predictions)
    f1score = f1_score(y_test, y_test_predictions)
    mcc = matthews_corrcoef(y_test, y_test_predictions)

    # The AUC is taken from the ROC curve of the scores, not of the hard
    # predictions. The thresholds returned by roc_curve are not needed.
    fpr, tpr, _ = roc_curve(y_test, y_probs, pos_label=1)
    auc_val = auc(fpr, tpr)

    return {"accuracy":accuracy, "precision":precision, "recall":recall, "f1score":f1score, "auc":auc_val, "mcc":mcc}




# Join all predictions

Comparamos los modelos ESM2 (t6, t12, t30 y t33), TAPE y ProtBert, todos entrenados durante 3 épocas.

In [20]:
# junta todos los archivos de predicciones en uno solo

import pandas as pd
import numpy as np

def softmax(logits):
    """Return the softmax of a sequence of logits as a NumPy array.

    The largest logit is subtracted before exponentiating. This leaves the
    result unchanged mathematically but keeps ``np.exp`` from overflowing to
    ``inf`` (which would turn the output into ``nan``) for large logits.

    Parameters
    ----------
    logits : array-like of numbers
        Raw scores. The normalisation runs over all elements.

    Returns
    -------
    numpy.ndarray
        Values that sum to 1; an empty input yields an empty array.
    """
    scores = np.asarray(logits, dtype=float)
    if scores.size == 0:
        # Nothing to normalise; max() would raise on an empty array.
        return scores
    exps = np.exp(scores - scores.max())
    return exps / exps.sum()

# Test set; one "<model><variant>_prob" / "<model><variant>_pred" column pair
# is appended to it for every prediction file loaded below.
data = pd.read_csv("dataset/hlab/hlab_test2.csv")

types_train = ['', '_acc_steps', '_freeze', '_freeze_acc_steps']
types_models = ['esm2_t6', 'esm2_t12', 'esm2_t30', 'esm2_t33', 'tape', 'protbert']

for model in types_models:
    for type_t in types_train:
        pred_path = "predictions/" + model + "_rnn" + type_t + ".csv"
        print("loading", pred_path)
        tmp_data = pd.read_csv(pred_path, index_col=0)
        # The first two columns after the index are fed to softmax and the
        # second probability is kept. Positional access is spelled with .iloc:
        # integer keys on a label-indexed row (row[0]) are deprecated in pandas.
        # NOTE(review): presumably these are the class-0 / class-1 logits — confirm.
        data[model+type_t+'_prob'] = tmp_data.apply(
            lambda row: softmax([row.iloc[0], row.iloc[1]])[1], axis=1
        )
        # Column assignment aligns on the index, so tmp_data's index (first CSV
        # column) has to match data's row index — TODO confirm for these files.
        data[model+type_t+'_pred'] = tmp_data["prediction"]

print(data.head(5))

loading predictions/esm2_t6_rnn.csv
loading predictions/esm2_t6_rnn_acc_steps.csv
loading predictions/esm2_t6_rnn_freeze.csv
loading predictions/esm2_t6_rnn_freeze_acc_steps.csv
loading predictions/esm2_t12_rnn.csv
loading predictions/esm2_t12_rnn_acc_steps.csv
loading predictions/esm2_t12_rnn_freeze.csv
loading predictions/esm2_t12_rnn_freeze_acc_steps.csv
loading predictions/esm2_t30_rnn.csv
loading predictions/esm2_t30_rnn_acc_steps.csv
loading predictions/esm2_t30_rnn_freeze.csv
loading predictions/esm2_t30_rnn_freeze_acc_steps.csv
loading predictions/esm2_t33_rnn.csv
loading predictions/esm2_t33_rnn_acc_steps.csv
loading predictions/esm2_t33_rnn_freeze.csv
loading predictions/esm2_t33_rnn_freeze_acc_steps.csv
loading predictions/tape_rnn.csv
loading predictions/tape_rnn_acc_steps.csv
loading predictions/tape_rnn_freeze.csv
loading predictions/tape_rnn_freeze_acc_steps.csv
loading predictions/protbert_rnn.csv
loading predictions/protbert_rnn_acc_steps.csv
loading predictions/protbe

In [21]:
# Order the merged table by sample id, then persist it without the index column.
data = data.sort_values(by='id')
data.to_csv("predictions/all_preds.csv", index=False)

# Comparison of AUC for all combinations

In [23]:
# plot barplot for each model

from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def get_auc_models(data, models):
    """Map each model name to its AUC on ``data``.

    For every name in ``models`` the columns ``<name>_pred`` and
    ``<name>_prob`` are scored against ``data['Label']`` with ``get_metrics``
    and only the ``"auc"`` entry is kept.
    """
    return {
        name: get_metrics(data['Label'], data[name + '_pred'], data[name + '_prob'])["auc"]
        for name in models
    }


def plot_metrics_by_type(auc_models, file_name):
    """Save a grouped bar chart of AUC with one group per training variant.

    Each x-axis group is a training variant (Normal, Freeze, GAS, Freeze-GAS)
    and each bar inside a group is one model family.

    Parameters
    ----------
    auc_models : mapping
        AUC per run, keyed by ``<family><suffix>`` (for example
        ``"esm2_t6_freeze"``). Every family/variant combination used below
        must be present, otherwise a ``KeyError`` is raised.
    file_name : str
        Name of the image file, written inside ``plots/``.
    """
    metrics = ("Normal", "Freeze", "GAS", "Freeze-GAS")
    # Key suffixes in the same order as the tick labels in `metrics`.
    suffixes = ("", "_freeze", "_acc_steps", "_freeze_acc_steps")
    # Legend label -> key prefix used in `auc_models`.
    families = {
        'ESM2(t6)': 'esm2_t6',
        'ESM2(t12)': 'esm2_t12',
        'ESM2(t30)': 'esm2_t30',
        'ESM2(t33)': 'esm2_t33',
        'TAPE': 'tape',
        'ProtBert': 'protbert',
    }
    results = {
        label: tuple(auc_models[prefix + suffix] for suffix in suffixes)
        for label, prefix in families.items()
    }

    colors = {'ESM2(t6)':'#0C4483', 'ESM2(t12)':'#0A6AAE', 'ESM2(t30)':'#2C8DBE', 'ESM2(t33)':'#50B6D5', 'TAPE':'#7ECFC8', 'ProtBert':'#AADEB5'}

    x = np.arange(len(metrics))  # the label locations
    width = 0.15  # the width of the bars

    fig, ax = plt.subplots(layout='constrained')

    # One bar series per model family, shifted by one bar width each time.
    for position, (attribute, measurement) in enumerate(results.items()):
        offset = width * position
        ax.bar(x + offset - 0.2, measurement, width, label=attribute, color=colors[attribute])

    ax.set_ylabel('AUC')
    ax.set_xticks(x + width, metrics)
    ax.legend(loc='upper left', ncols=4)
    ax.set_ylim(0.4, 1.1)

    fig.savefig("plots/" + file_name, dpi=300, bbox_inches='tight')
    # Close instead of clear: clf() keeps the figure open, so every call left
    # an empty figure behind for the notebook to render.
    plt.close(fig)

def plot_metrics_by_model(auc_models, file_name):
    """Save a grouped bar chart of AUC with one group per model family.

    Each x-axis group is a model family and each bar inside a group is one
    training variant (Normal, Freeze, GAS, Freeze-GAS).

    Parameters
    ----------
    auc_models : mapping
        AUC per run, keyed by ``<family><suffix>`` (for example
        ``"tape_acc_steps"``). Every family/variant combination used below
        must be present, otherwise a ``KeyError`` is raised.
    file_name : str
        Name of the image file, written inside ``plots/``.
    """
    metrics = ("ESM2(t6)", "ESM2(t12)", "ESM2(t30)", "ESM2(t33)", 'TAPE', 'ProtBert')
    # Key prefixes in the same order as the tick labels in `metrics`.
    families = ("esm2_t6", "esm2_t12", "esm2_t30", "esm2_t33", "tape", "protbert")
    # Legend label -> key suffix used in `auc_models`.
    variants = {
        'Normal': '',
        'Freeze': '_freeze',
        'GAS': '_acc_steps',
        'Freeze-GAS': '_freeze_acc_steps',
    }
    results = {
        label: tuple(auc_models[prefix + suffix] for prefix in families)
        for label, suffix in variants.items()
    }

    colors = {'Normal':'#0C4483', 'Freeze':'#0A6AAE', 'GAS':'#2C8DBE', 'Freeze-GAS':'#50B6D5'}

    x = np.arange(len(metrics))  # the label locations
    width = 0.2  # the width of the bars

    fig, ax = plt.subplots(layout='constrained')

    # One bar series per training variant, shifted by one bar width each time.
    for position, (attribute, measurement) in enumerate(results.items()):
        offset = width * position
        ax.bar(x + offset, measurement, width, label=attribute, color=colors[attribute])

    ax.set_ylabel('AUC')
    ax.set_xticks(x + width + 0.1, metrics)
    ax.legend(loc='upper left', ncols=4)
    ax.set_ylim(0.45, 1.05)

    fig.savefig("plots/" + file_name, dpi=300, bbox_inches='tight')
    # Close instead of clear: clf() keeps the figure open, so every call left
    # an empty figure behind for the notebook to render.
    plt.close(fig)


# Load the merged predictions file.
data = pd.read_csv("predictions/all_preds.csv")

# Column prefixes of every run: each model family crossed with each training variant.
models = [
    family + variant
    for family in ('esm2_t6', 'esm2_t12', 'esm2_t30', 'esm2_t33', 'tape', 'protbert')
    for variant in ('', '_freeze', '_acc_steps', '_freeze_acc_steps')
]

# One AUC per run, drawn twice: grouped by training variant and grouped by model.
auc_models = get_auc_models(data, models)
plot_metrics_by_type(auc_models, "metrics_comparion_by_type.png")
plot_metrics_by_model(auc_models, "metrics_comparion_by_model.png")

<Figure size 640x480 with 0 Axes>

<Figure size 640x480 with 0 Axes>

# Getting all metrics

In [22]:
# Load the merged predictions file.
data = pd.read_csv("predictions/all_preds.csv")

# Column prefixes of every run: each model family crossed with each training variant.
models = [
    family + variant
    for family in ('esm2_t6', 'esm2_t12', 'esm2_t30', 'esm2_t33', 'tape', 'protbert')
    for variant in ('', '_freeze', '_acc_steps', '_freeze_acc_steps')
]

# Column prefix -> display name used as row label in the metrics table.
column_names = {
    key + suffix: label + tag
    for key, label in (
        ("esm2_t6", "ESM2(t6)"),
        ("esm2_t12", "ESM2(t12)"),
        ("esm2_t30", "ESM2(t30)"),
        ("esm2_t33", "ESM2(t33)"),
        ("tape", "TAPE"),
        ("protbert", "ProtBert"),
    )
    for suffix, tag in (
        ("", ""),
        ("_freeze", "-Freeze"),
        ("_acc_steps", "-GAS"),
        ("_freeze_acc_steps", "-Freeze-GAS"),
    )
}
def get_all_metrics():
    """Build the metrics table for every run listed in the global ``models``.

    Reads the module-level ``data``, ``models`` and ``column_names``. Each run
    is scored with ``get_metrics``; the result has one row per run (renamed
    through ``column_names``) and one column per metric.
    """
    per_model = {
        model_name: list(
            get_metrics(data["Label"], data[model_name + "_pred"], data[model_name + "_prob"]).values()
        )
        for model_name in models
    }

    table = pd.DataFrame(per_model).rename(columns=column_names)
    # Row labels follow the order of the values returned by get_metrics.
    table.index = ['Accuracy', 'Precision', 'Recall', 'F1-score', 'AUC', 'MCC']
    return table.T

# Build the summary table, print it, and save it with the run names as the first CSV column.
metrics_pd = get_all_metrics()
print(metrics_pd.head(n=24))
metrics_pd.to_csv("metrics_3_epochs.csv", index=True)

                      Accuracy  Precision    Recall  F1-score       AUC  \
ESM2(t6)              0.934448   0.933390  0.935373  0.934380  0.980460   
ESM2(t6)-Freeze       0.935103   0.925277  0.946359  0.935699  0.981208   
ESM2(t6)-GAS          0.898621   0.896613  0.900674  0.898639  0.960226   
ESM2(t6)-Freeze-GAS   0.886876   0.891343  0.880629  0.885954  0.952008   
ESM2(t12)             0.932675   0.924320  0.942213  0.933181  0.979930   
ESM2(t12)-Freeze      0.934419   0.925051  0.945140  0.934988  0.980781   
ESM2(t12)-GAS         0.901049   0.927921  0.869202  0.897602  0.965547   
ESM2(t12)-Freeze-GAS  0.880496   0.855601  0.914900  0.884257  0.947499   
ESM2(t30)             0.498957   0.498957  1.000000  0.665739  0.499899   
ESM2(t30)-Freeze      0.930258   0.918469  0.944025  0.931072  0.978608   
ESM2(t30)-GAS         0.908970   0.916686  0.899292  0.907906  0.967455   
ESM2(t30)-Freeze-GAS  0.856542   0.815603  0.920625  0.864938  0.931192   
ESM2(t33)             0.4

# AUC por k-mer y HLA

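Calcula el AUC de cada modelo por longitud de péptido (k-mer) y por alelo HLA. Los resultados se guardan en un diccionario indexado primero por k y luego por modelo, con una lista de AUC (uno por HLA).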

In [49]:
# Load the merged predictions file into `data`; the per-k-mer loop below filters it by 'Length'.
data = pd.read_csv("predictions/all_preds.csv")

def get_auc(data_by_kmer, models):
    """Collect, for every model, the AUC obtained on each HLA in ``data_by_kmer``.

    Returns a dict mapping each model name to a list of AUC values, one per
    distinct value of the ``HLA`` column, in the order ``unique()`` yields them.
    """
    per_model = {name: [] for name in models}

    for allele in data_by_kmer['HLA'].unique():
        subset = data_by_kmer[data_by_kmer['HLA'] == allele]
        for name in models:
            scores = get_metrics(subset['Label'], subset[name + '_pred'], subset[name + '_prob'])
            per_model[name].append(scores["auc"])

    return per_model

# Per-HLA AUC lists for every run, computed separately for peptide lengths 8 to 14.
models = [
    family + variant
    for family in ('esm2_t6', 'esm2_t12', 'esm2_t30', 'esm2_t33', 'tape', 'protbert')
    for variant in ('', '_freeze', '_acc_steps', '_freeze_acc_steps')
]

total_aucs = {}
for i in range(8, 15):
    data_by_kmer = data.loc[data['Length'] == i]
    total_aucs[i] = get_auc(data_by_kmer, models)



In [83]:
# One figure per peptide length: box plots of the per-HLA AUC values of the
# ESM2 runs, grouped by model size along the x axis.
for k in range(8,15):
    aucs_models = total_aucs[k]

    # Face colours indexed by training variant: Normal, Freeze, GAS, Freeze-GAS.
    palette = ['lightgreen', '#3470E0', 'y', '#3DCACE']

    # (key prefix, variant suffixes drawn, x positions, index of the first palette colour)
    groups = [
        ('esm2_t6', ['', '_freeze', '_acc_steps', '_freeze_acc_steps'], [1, 2, 3, 4], 0),
        ('esm2_t12', ['', '_freeze', '_acc_steps', '_freeze_acc_steps'], [6, 7, 8, 9], 0),
        # The plain (Normal) runs of t30 and t33 are not drawn, so their colours
        # start at the Freeze entry of the palette.
        # NOTE(review): presumably because those runs did not learn (ESM2(t30)
        # sits at ~0.5 AUC in the metrics table) — confirm.
        # Fix: the t30 group used to plot 'esm2_t12_freeze_acc_steps' as its third box.
        ('esm2_t30', ['_freeze', '_acc_steps', '_freeze_acc_steps'], [11, 12, 13], 1),
        ('esm2_t33', ['_freeze', '_acc_steps', '_freeze_acc_steps'], [15, 16, 17], 1),
    ]

    fig, ax = plt.subplots()
    for family, suffixes, positions, first_color in groups:
        series = [aucs_models[family + suffix] for suffix in suffixes]
        bp = ax.boxplot(series, labels=[''] * len(series), positions=positions,
                        showfliers=False, patch_artist=True)
        for box, face in zip(bp['boxes'], palette[first_color:]):
            # color= sets edge and face together, so the face colour goes second.
            box.set(color="black")
            box.set(facecolor=face)

    fig.savefig("plots/auc_distribution_esm2t6_" + str(k) + "-mer", dpi=300, bbox_inches='tight')
    # Close the figure so the loop does not leave figures open.
    plt.close(fig)

<Figure size 640x480 with 0 Axes>