In [35]:
import pandas as pd
import numpy as np
from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    matthews_corrcoef,
)


In [36]:
def _as_label_vector(values):
    """Return ``values`` as a NumPy array of class labels.

    Objects exposing ``detach`` (e.g. torch tensors) are detached and moved to
    the CPU first; this is duck-typed so ``torch`` does not have to be
    imported. Inputs with more than one dimension are reduced with ``argmax``
    along axis 1; 1-D inputs are returned as-is.
    """
    if hasattr(values, "detach"):
        values = values.detach().cpu().numpy()
    labels = np.asarray(values)
    return labels.argmax(axis=1) if labels.ndim > 1 else labels


def calc_metrics(targets, outputs):
    """Compute binary-classification metrics and return them as a one-row frame.

    Parameters
    ----------
    targets : array-like
        Ground-truth labels. A 1-D input is used directly; an input with more
        than one dimension is converted to labels with ``argmax`` over axis 1.
    outputs : array-like
        Predicted labels, handled the same way as ``targets``.

    Returns
    -------
    pandas.DataFrame
        A single-row frame with the columns ``TP``, ``TN``, ``FP``, ``FN``,
        ``Accuracy``, ``Precision``, ``Recall``, ``F1-score``, ``AUC`` and
        ``MCC``. ``AUC`` is NaN when ``roc_auc_score`` raises ``ValueError``.

    Notes
    -----
    Labels are assumed to be encoded as 0 (negative) / 1 (positive); the
    precision, recall and F1 calls rely on scikit-learn's default positive
    label.
    """
    y_true = _as_label_vector(targets)
    y_pred = _as_label_vector(outputs)

    # Pin the label set so the matrix is always 2x2. Without it, data that
    # contains a single class yields a 1x1 matrix and the 4-way unpack fails.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)

    # NOTE(review): AUC is computed from the same hard label predictions as
    # the other metrics, not from probabilities/scores.
    try:
        auc = roc_auc_score(y_true, y_pred)
    except ValueError:
        # Undefined AUC (e.g. only one class in y_true). NaN keeps the column
        # numeric so callers can still scale and round it.
        auc = np.nan

    mcc = matthews_corrcoef(y_true, y_pred)

    # One-row frame: each metric becomes a column.
    metrics = {
        "TP": [tp],
        "TN": [tn],
        "FP": [fp],
        "FN": [fn],
        "Accuracy": [accuracy],
        "Precision": [precision],
        "Recall": [recall],
        "F1-score": [f1],
        "AUC": [auc],
        "MCC": [mcc],
    }
    return pd.DataFrame(metrics)

In [11]:
# List the test-set TSV via the shell; `ls` reports an error if the path is missing.
!ls '../data/AMP_new/AMP_2024_09_13__test.tsv'

../data/AMP_new/AMP_2024_09_13__test.tsv


In [53]:
# Load the test set. The file is read with header=None (its first line is
# treated as data), and the four column labels are attached afterwards.
test_df_path = '../data/AMP_new/AMP_2024_09_13__test.tsv'
test_df = pd.read_csv(test_df_path, sep='\t', header=None).set_axis(
    ['Name', 'Length', 'Sequence', 'AMP'], axis=1
)
test_df

Unnamed: 0,Name,Length,Sequence,AMP
0,CAMP_CAMPSQ11657,97,MAAPRLTLSVVFVLMLAFITLSEGLRGTGPKKCCFRFHESPVQKER...,1
1,dbAMP_02113,24,FLPLLAGLAANFLPKLFCKITKKC,1
2,sp|B2IUW7|RPOZ_NOSP7,78,MLKRSKFETTQSQIMHRAEELISAASNRYRITVQVANRAKRRRYED...,0
3,DBAASPS_15750,27,EQKIEELLKKAEEQQKKNEEELKKLEK,1
4,DBAASPS_14432,14,LRRHASEGGHGPHW,1
...,...,...,...,...
5516,sp|Q48PD1|MDCC_PSE14,99,METLSFEFPAGQPPKGRALVGVVGSGDLEVLLEPGSPGKLSIQVVT...,0
5517,sp|Q44141|ANRX_NOSS1,76,MTQKRPYEHRKAQKQVKNLESYQCMVCWEVNSKANGHHLIPYSEGG...,0
5518,sp|Q11XD0|RL33_CYTH3,60,MAKKGNRIQVILECTEHKATGLAGTSRHITTKNRKNTPERIELKKY...,0
5519,sp|A1AIK8|YJBT_ECOK1,92,MKRNLIKVVKMKSYFAALMLSVSVLPAYAGPLGTADKADLPQSNVS...,0


In [13]:
# Export the test-set sequences as FASTA, written next to the TSV
# (same path with '.tsv' replaced by '.fa').
output_fasta = test_df_path.replace('.tsv', '.fa')
with open(output_fasta, 'w') as fasta_file:
    # Read from `test_df` explicitly and address columns by name: a bare `df`
    # is not defined at this point on a fresh top-to-bottom run, and integer
    # keys on a string-labelled row are a deprecated positional fallback.
    for sequence_id, protein_sequence in zip(test_df['Name'], test_df['Sequence']):
        # One FASTA record per row: '>' header line, then the sequence.
        fasta_file.write(f">{sequence_id}\n{protein_sequence}\n")

In [42]:
# Load the Macrel predictions. header=1 takes the column names from the
# second line of the file, so the first line is skipped.
# The added 'AMP' column is 1 where 'is_AMP' is truthy and 0 otherwise.
macrel_df = pd.read_csv(
    '../data/existing_AMP_tools_results/macrel.out.tsv',
    header=1,
    sep='\t',
).assign(AMP=lambda frame: frame['is_AMP'].map(lambda flag: 1 if flag else 0))
macrel_df

Unnamed: 0,Access,Sequence,AMP_family,is_AMP,AMP_probability,Hemolytic,Hemolytic_probability,AMP
0,CAMP_CAMPSQ11657,AAPRLTLSVVFVLMLAFITLSEGLRGTGPKKCCFRFHESPVQKERV...,CDP,False,0.000,NonHemo,0.376,0
1,dbAMP_02113,FLPLLAGLAANFLPKLFCKITKKC,CDP,True,1.000,Hemo,1.000,1
2,sp|B2IUW7|RPOZ_NOSP7,LKRSKFETTQSQIMHRAEELISAASNRYRITVQVANRAKRRRYEDF...,CLP,False,0.000,NonHemo,0.079,0
3,DBAASPS_15750,EQKIEELLKKAEEQQKKNEEELKKLEK,ALP,False,0.337,NonHemo,0.069,0
4,DBAASPS_14432,LRRHASEGGHGPHW,CLP,False,0.406,NonHemo,0.218,0
...,...,...,...,...,...,...,...,...
5516,sp|Q48PD1|MDCC_PSE14,ETLSFEFPAGQPPKGRALVGVVGSGDLEVLLEPGSPGKLSIQVVTS...,ALP,False,0.000,NonHemo,0.109,0
5517,sp|Q44141|ANRX_NOSS1,TQKRPYEHRKAQKQVKNLESYQCMVCWEVNSKANGHHLIPYSEGGS...,CDP,False,0.020,NonHemo,0.257,0
5518,sp|Q11XD0|RL33_CYTH3,AKKGNRIQVILECTEHKATGLAGTSRHITTKNRKNTPERIELKKYN...,CDP,False,0.000,NonHemo,0.386,0
5519,sp|A1AIK8|YJBT_ECOK1,KRNLIKVVKMKSYFAALMLSVSVLPAYAGPLGTADKADLPQSNVSS...,CDP,False,0.050,NonHemo,0.228,0


In [47]:
# Load the AMP Scanner predictions and derive a binary 'AMP' column:
# 1 where 'Prediction_Class' equals 'AMP', 0 for anything else.
amp_scaner_df = pd.read_csv(
    '../data/existing_AMP_tools_results/AMP_scaner_v2.csv'
).assign(
    AMP=lambda frame: frame['Prediction_Class'].map(
        lambda label: 1 if label == 'AMP' else 0
    )
)
amp_scaner_df

Unnamed: 0,SeqID,Prediction_Class,Prediction_Probability,Sequence,AMP
0,CAMP_CAMPSQ11657,AMP,0.9981,MAAPRLTLSVVFVLMLAFITLSEGLRGTGPKKCCFRFHESPVQKER...,1
1,dbAMP_02113,AMP,1.0000,FLPLLAGLAANFLPKLFCKITKKC,1
2,sp|B2IUW7|RPOZ_NOSP7,Non-AMP,0.0003,MLKRSKFETTQSQIMHRAEELISAASNRYRITVQVANRAKRRRYED...,0
3,DBAASPS_15750,Non-AMP,0.0006,EQKIEELLKKAEEQQKKNEEELKKLEK,0
4,DBAASPS_14432,Non-AMP,0.0038,LRRHASEGGHGPHW,0
...,...,...,...,...,...
5516,sp|Q48PD1|MDCC_PSE14,Non-AMP,0.0011,METLSFEFPAGQPPKGRALVGVVGSGDLEVLLEPGSPGKLSIQVVT...,0
5517,sp|Q44141|ANRX_NOSS1,AMP,0.5057,MTQKRPYEHRKAQKQVKNLESYQCMVCWEVNSKANGHHLIPYSEGG...,1
5518,sp|Q11XD0|RL33_CYTH3,Non-AMP,0.0891,MAKKGNRIQVILECTEHKATGLAGTSRHITTKNRKNTPERIELKKY...,0
5519,sp|A1AIK8|YJBT_ECOK1,Non-AMP,0.0730,MKRNLIKVVKMKSYFAALMLSVSVLPAYAGPLGTADKADLPQSNVS...,0


In [52]:
# Tool display name -> that tool's prediction frame (each carries an 'AMP' column).
ds_dict = dict(Macrel=macrel_df, AMPScaner=amp_scaner_df)

In [62]:
# Wide results table: one row per tool, one column per metric.
# NOTE(review): predictions are compared with the test labels by position, so
# each tool's rows must be in the same order as `test_df` - confirm.
tool_rows = []
for tool, df in ds_dict.items():
    out_df = calc_metrics(targets=test_df['AMP'].values, outputs=df['AMP'].values)
    out_df.index = [tool]  # label the single metrics row with the tool name
    tool_rows.append(out_df)
# Concatenate once instead of growing a frame inside the loop; fall back to an
# empty frame when there are no tools, as the original accumulator did.
res_df = pd.concat(tool_rows) if tool_rows else pd.DataFrame()
res_df

Unnamed: 0,TP,TN,FP,FN,Accuracy,Precision,Recall,F1-score,AUC,MCC
Macrel,1461,2996,25,1039,0.807281,0.983176,0.5844,0.733066,0.788062,0.646592
AMPScaner,2006,1908,1113,494,0.70893,0.643155,0.8024,0.714006,0.716989,0.435732


In [92]:
# Long-format results: one row per (tool, metric) with the columns
# tool / score / dataset / index, where 'index' holds the metric name.
# NOTE(review): predictions are compared with the test labels by position, so
# each tool's rows must be in the same order as `test_df` - confirm.
count_metrics = ['TP', 'TN', 'FP', 'FN']  # raw counts, left unscaled
tool_frames = []
for tool, df in ds_dict.items():
    out_df = (
        calc_metrics(targets=test_df['AMP'].values, outputs=df['AMP'].values)
        .T                              # metrics become rows
        .set_axis(['score'], axis=1)
        .assign(tool=tool)
        .reset_index()                  # metric names move into an 'index' column
        .assign(dataset='AMP global')
        .loc[:, ['tool', 'score', 'dataset', 'index']]
        .copy()                         # own the data before assigning into it
    )
    # Scale every non-count metric to a percentage with two decimals. Rows are
    # picked by metric name rather than by the positional slice `4:`.
    is_ratio = ~out_df['index'].isin(count_metrics)
    out_df.loc[is_ratio, 'score'] = out_df.loc[is_ratio, 'score'].apply(
        lambda x: np.round(x * 100, 2)
    )
    tool_frames.append(out_df)
# Concatenate once; each tool keeps its own 0..9 row labels, as before.
res_df = pd.concat(tool_frames) if tool_frames else pd.DataFrame()

In [94]:
# Display the long-format per-tool metric table assembled in the previous cell.
res_df

Unnamed: 0,tool,score,dataset,index
0,Macrel,1461.0,AMP global,TP
1,Macrel,2996.0,AMP global,TN
2,Macrel,25.0,AMP global,FP
3,Macrel,1039.0,AMP global,FN
4,Macrel,80.73,AMP global,Accuracy
5,Macrel,98.32,AMP global,Precision
6,Macrel,58.44,AMP global,Recall
7,Macrel,73.31,AMP global,F1-score
8,Macrel,78.81,AMP global,AUC
9,Macrel,64.66,AMP global,MCC
