### Import libraries
+ Note: after you edit one of the util modules, re-running this import cell may not pick up the change, because Python caches modules that are already imported.
+ Restart the kernel to reload them.

In [None]:
import pandas as pd
import sys
import nest_asyncio
import itertools
nest_asyncio.apply() # for fetch_ec_improved to run in jupyter notebook

from EC40_loader import EC40_loader

from fetch_ec_improved import fetch_ec_async
from evaluate_ec import evaluate_ec

### File Paths

In [7]:
# All inputs/outputs of this benchmark, relative to the notebook's working directory.
EC40_DIR = "../dataset/ec40"
DIAMOND_RESULTS_DIR = "../dataset/diamond_results"

# EC40 dataset and its train/valid/test splits.
ec40_file_pkl = f"{EC40_DIR}/ec40.pkl"
ec40_train_file = f"{EC40_DIR}/ec40_train.csv"
ec40_valid_file = f"{EC40_DIR}/ec40_valid.csv"
ec40_test_file = f"{EC40_DIR}/ec40_test.csv"

# DIAMOND query, raw/filtered hits, fetched EC numbers and per-query evaluation.
test_seq_file = f"{DIAMOND_RESULTS_DIR}/test_sequences.fasta"
ec_results_file = f"{DIAMOND_RESULTS_DIR}/test_sequences_ec_results.csv"
diamond_output_file = f"{DIAMOND_RESULTS_DIR}/test_sequences_results.m8"
filtered_output_file = f"{DIAMOND_RESULTS_DIR}/filtered_results.m8"
# Alias kept for backward compatibility: both names refer to the same EC-results CSV,
# so define the path once instead of duplicating the literal.
ec_result_path = ec_results_file
evaluate_file = f"{DIAMOND_RESULTS_DIR}/evaluation_results.csv"

# Aggregate metrics table.
metrics_file = "../metrics/metrics.csv"

### Load EC40 Dataset

In [8]:
# Build the EC40 loader, write the train/valid/test CSV splits, then load the test split.
ec40_loader = EC40_loader(
    ec40_file_pkl,
    ec40_train_file,
    ec40_valid_file,
    ec40_test_file,
)
ec40_loader.load_and_split()
ec40_test = ec40_loader.load_test()
ec40_test

Traintest distribution (raw counts):
traintest
0    74399
1     8261
Name: count, dtype: int64

Traintest distribution (proportions):
traintest
0    0.90006
1    0.09994
Name: proportion, dtype: float64
Training Set saved in ../dataset/ec40/ec40_train.csv
Validation Set saved in ../dataset/ec40/ec40_valid.csv
Test Set saved in ../dataset/ec40/ec40_test.csv


Unnamed: 0.1,Unnamed: 0,accession,sequence,ec,traintest,negative_for,mainclass_set,sprot_version,len
0,4130,Q9D975,MGLRAGGALRRAGAGPGAPEGQGPGGAQGGSIHSGCIATVHNVPIA...,['1.8.98.2'],1,Transferases,Transferases,2017_03,1
1,4131,O66651,MEEKKVDLKDTLNLPRTEFPMKANLPQREPQILEKWKGLYEKIQKE...,['6.1.1.5'],1,Transferases,Transferases,2017_03,1
2,4132,Q0U8V9,MANDYPSSDEEIMEAQTGSRKRRKTSSDSESDTAPRAPTATSISRV...,['3.6.4.13'],1,Transferases,Transferases,2017_03,1
3,4133,Q978Z5,MLIYNTLTRRLQEFNEMHRGRVNLFVCGPTVQDHFHIGHARTYIFF...,['6.1.1.16'],1,,Ligases,2017_03,1
4,4134,Q9HKR6,MPQVKVTASAPCSSANLGSGFDTLAIALDAFHDRVTISDHDGFKLT...,['2.7.1.39'],1,,Transferases,2017_03,1
...,...,...,...,...,...,...,...,...,...
4126,8256,A4YD89,MNPVNDIIDSYSAIVYTHKTVGVDKLASHYLGWNEIKELSKYYDGE...,['1.2.1.70'],1,Hydrolases,Hydrolases,2017_03,1
4127,8257,P65179,MPTGSVTVRVPGKVNLYLAVGDRREDGYHELTTVFHAVSLVDEVTV...,['2.7.1.148'],1,,Transferases,2017_03,1
4128,8258,P13079,MAALLKRILRRRMAEKRSGRGRMAAARTTGAQSRKTAQRSGRSEAD...,['2.1.1.-'],1,,Transferases,2017_03,1
4129,8259,B5YHP2,MKIVIASRNRKKIEELKRILQGLEITILSVNDFPELEEVKEDGLTF...,['3.6.1.9'],1,,Hydrolases,2017_03,1


### Prepare test sequences

In [10]:
# Write the EC40 test split to a FASTA file for DIAMOND.
# `ec40_test` (from EC40_loader.load_test()) is already the test split (its rows
# have traintest == 1), so no further filtering happens here.
test_df = ec40_test

print(f"Found {len(test_df)} test sequences.")

# The 'accession' column becomes the FASTA header, so DIAMOND reports it back as
# the qseqid (first column) of every hit.
with open(test_seq_file, "w") as fout:
    fout.writelines(
        f">{accession}\n{sequence}\n"
        for accession, sequence in zip(test_df["accession"], test_df["sequence"])
    )

Found 4131 test sequences.


### DIAMOND Query

Download DIAMOND
+ If DIAMOND is not installed yet, uncomment the cell below (Linux binary).

In [None]:
# linux
# !wget http://github.com/bbuchfink/diamond/releases/download/v2.0.4/diamond-linux64.tar.gz
# !tar xzf diamond-linux64.tar.gz

Prepare the DIAMOND database folder
+ If the folder does not exist yet, uncomment the cell below.

In [None]:
# !mkdir ../dataset/dimond_db/
# %cd ../dataset/dimond_db/

Download UniRef90
+ if not downloaded, uncomment below

In [None]:
# !wget ftp://ftp.uniprot.org/pub/databases/uniprot/current_release/uniref/uniref90/uniref90.fasta.gz

In [None]:
# %cd ../../experiments

Generate the DIAMOND database
+ Takes roughly 45 minutes with 24 CPU threads.
+ If the database has not been generated yet, uncomment the cell below.

In [None]:
# !diamond makedb --in ../dataset/dimond_db/uniref90.fasta.gz -d ../dataset/dimond_db/uniref90.dmnd

Run DIAMOND search
+ Takes about 84 minutes for all test sequences.
+ DIAMOND output format (default BLAST tabular format 6, one tab-separated line per hit):

            qseqid means Query Seq-id
           sseqid means Subject Seq-id
           pident means Percentage of identical matches
           length means Alignment length
         mismatch means Number of mismatches
          gapopen means Number of gap openings
           qstart means Start of alignment in query
             qend means End of alignment in query
           sstart means Start of alignment in subject
             send means End of alignment in subject
           evalue means Expect value
         bitscore means Bit score

In [11]:
# Search the test queries against UniRef90 (~84 min on 24 threads).
# `--outfmt 6` pins the 12-column BLAST tabular layout listed above (DIAMOND's default);
# filter_diamond_output relies on column 3 being the percent identity.
# NOTE(review): the log shows only 25 targets per query are reported. The later
# <= 40% identity filter then leaves many queries with no usable hit.
# Consider raising --max-target-seqs.
!diamond blastp --db ../dataset/dimond_db/uniref90.dmnd \
                --query {test_seq_file} \
                --out {diamond_output_file} \
                --outfmt 6 \
                --quiet

diamond v2.0.4.142 (C) Max Planck Society for the Advancement of Science
Documentation, support and updates available at http://www.diamondsearch.org

#CPU threads: 24
Scoring parameters: (Matrix=BLOSUM62 Lambda=0.267 K=0.041 Penalties=11/1)
Temporary directory: ../dataset/diamond_results
Opening the database...  [0.857s]
#Target sequences to report alignments for: 25
Reference = ../dataset/dimond_db/uniref90.dmnd
Sequences = 204806910
Letters = 70410540112
Block size = 2000000000
Opening the input file...  [0.016s]
Opening the output file...  [0.086s]
Loading query sequences...  [0.002s]
Masking queries...  [0.006s]
Building query seed set...  [0.006s]
[1;33mThe host system is detected to have 134 GB of RAM. It is recommended to increase the block size for better performance using these parameters : -b5 -c1
[0;39mAlgorithm: Query-indexed
Building query histograms...  [0s]
Allocating buffers...  [0s]
Loading reference sequences...  [46.251s]
Masking reference...  [4.136s]
Initializin

In [12]:
# Peek at the first 10 raw DIAMOND hits (qseqid, sseqid, pident, ...).
!head -n 10 {diamond_output_file}

Q9D975	UniRef90_Q9D975	100.0	136	0	0	1	136	1	136	1.9e-68	268.1
Q9D975	UniRef90_UPI002452E298	93.4	136	9	0	1	136	20	155	3.1e-63	250.8
Q9D975	UniRef90_A0A8J6KV60	93.4	136	9	0	1	136	1	136	4.1e-63	250.4
Q9D975	UniRef90_G3I8Z5	92.6	136	10	0	1	136	20	155	1.2e-62	248.8
Q9D975	UniRef90_A0A6I9LVF8	91.2	136	12	0	1	136	20	155	5.9e-62	246.5
Q9D975	UniRef90_A0AAW0HZB7	91.9	136	11	0	1	136	20	155	1.0e-61	245.7
Q9D975	UniRef90_Q9BYN0	89.8	137	13	1	1	136	1	137	8.0e-59	236.1
Q9D975	UniRef90_A0A8C2NLF8	89.1	137	14	1	1	136	188	324	1.0e-58	235.7
Q9D975	UniRef90_A0A8D2AFX9	88.3	137	15	1	1	136	1	137	3.0e-58	234.2
Q9D975	UniRef90_A0A6J0ANF4	87.8	139	14	2	1	136	1	139	6.8e-58	233.0


In [18]:
def _read_query_ids(original_query_file):
    """Return the unique query IDs listed in ``original_query_file``, in file order.

    Supported layouts:
    * ``.csv`` file (e.g. the EC40 test split): IDs come from the ``accession``
      column, or from the first column if there is no ``accession`` header.
    * FASTA file (first non-blank line starts with ``>``): IDs are the first
      whitespace-separated token of each header line; sequence lines are ignored.
    * Anything else: one query per line, ID = first whitespace-separated token.
    """
    import csv  # local import keeps this cell self-contained

    ids = []
    if original_query_file.lower().endswith(".csv"):
        with open(original_query_file, "r", newline="") as f:
            reader = csv.DictReader(f)
            fieldnames = reader.fieldnames or []
            if not fieldnames:
                return ids
            column = "accession" if "accession" in fieldnames else fieldnames[0]
            for row in reader:
                value = (row.get(column) or "").strip()
                if value:
                    ids.append(value)
    else:
        is_fasta = None
        with open(original_query_file, "r") as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    continue
                if is_fasta is None:
                    # Decide the layout from the first non-blank line.
                    is_fasta = stripped.startswith(">")
                if is_fasta:
                    if stripped.startswith(">"):
                        header = stripped[1:].split()
                        if header:
                            ids.append(header[0])
                else:
                    ids.append(stripped.split()[0])
    # Drop duplicates while keeping first-seen order.
    return list(dict.fromkeys(ids))


def _dummy_record(query):
    """Placeholder line: the query ID followed by eleven "-1" fields (12 columns total)."""
    return "\t".join([query] + ["-1"] * 11) + "\n"


def filter_diamond_output(diamond_output_file, filtered_output_file, original_query_file,
                          max_identity=40.0):
    """
    Keep only DIAMOND hits whose percent identity is <= ``max_identity`` (default 40).

    Every query gets at least one line in the output:
    * A query present in the DIAMOND output writes its qualifying hits. If none
      qualify, it writes one dummy record (query ID followed by eleven "-1" fields).
    * A query listed in ``original_query_file`` but absent from the DIAMOND output
      gets the same dummy record.

    Output order is deterministic: queries in order of first appearance in the
    DIAMOND output, then missing queries in the order of ``original_query_file``.

    :param diamond_output_file: DIAMOND tabular output. Column 1 is the query ID
                                and column 3 is the percent identity.
    :param filtered_output_file: Path where the filtered output is written (overwritten).
    :param original_query_file: File listing all query IDs. It may be:
                                a CSV with an ``accession`` column (such as the
                                EC40 test split), a FASTA file, or a plain file
                                whose lines start with the query ID.
    :param max_identity: Inclusive upper bound on percent identity for kept hits.
    :raises ValueError: if a non-blank DIAMOND line has fewer than 3 columns or a
                        non-numeric identity value.
    """
    original_queries = _read_query_ids(original_query_file)

    # Maps query ID -> hits passing the identity filter. A dict preserves
    # first-seen order. Accumulating, rather than using itertools.groupby, stays
    # correct even if a query's hits are not contiguous in the file.
    kept_hits = {}
    with open(diamond_output_file, "r") as infile:
        for line_no, line in enumerate(infile, start=1):
            if not line.strip():
                continue  # tolerate blank / trailing lines
            fields = line.rstrip("\r\n").split("\t")
            if len(fields) < 3:
                raise ValueError(
                    f"{diamond_output_file}:{line_no}: expected >= 3 tab-separated "
                    f"columns, got {len(fields)}"
                )
            query = fields[0]
            hits = kept_hits.setdefault(query, [])
            try:
                identity = float(fields[2])
            except ValueError as exc:
                raise ValueError(
                    f"{diamond_output_file}:{line_no}: non-numeric identity {fields[2]!r}"
                ) from exc
            if identity <= max_identity:
                # Guarantee a newline, so an unterminated last line cannot merge
                # with the next record written.
                hits.append(line if line.endswith("\n") else line + "\n")

    with open(filtered_output_file, "w") as outfile:
        for query, hits in kept_hits.items():
            if hits:
                outfile.writelines(hits)
            else:
                # Hits exist, but none meet the identity criterion.
                outfile.write(_dummy_record(query))
        # Queries DIAMOND reported nothing for, in input-file order.
        for query in original_queries:
            if query not in kept_hits:
                outfile.write(_dummy_record(query))

In [None]:
# Keep only hits with <= 40% identity to the query and write them to filtered_output_file.
filter_diamond_output(
    diamond_output_file=diamond_output_file,
    filtered_output_file=filtered_output_file,
    original_query_file=ec40_test_file,
)

### Fetch EC number from UniProt API

In [14]:
# Look up EC numbers (UniProt API) for the filtered DIAMOND hits and save them as CSV.
# Use `ec_results_file`, the same name the evaluation cell reads, rather than the
# duplicate `ec_result_path` that points at the same file.
fetch_ec_async(
    filtered_output_file,
    ec_results_file,
)

Parsing DIAMOND output and fetching EC numbers concurrently...


Processing queries: 100%|██████████| 3328/3328 [00:03<00:00, 919.42it/s]  

Saving results to CSV...
Results saved to '../dataset/diamond_results/test_sequences_ec_results.csv'





### Evaluate EC Results

In [None]:
# Make sure the metrics output directory exists before evaluate_ec writes to it.
# Idempotent, so the notebook still runs top-to-bottom on a fresh checkout.
import os

os.makedirs(os.path.dirname(metrics_file) or ".", exist_ok=True)

In [16]:
# Score predicted EC numbers against the test split; writes metrics_file and evaluate_file.
evaluate_ec(
    ec_results_file,
    metrics_file,
    evaluate_file,
    ec40_test_file,
)

Loading data...
Matching predictions...


Matching Predictions: 100%|██████████| 4131/4131 [00:00<00:00, 4384.99it/s]


Computing evaluation metrics...
Exact Match Accuracy: 4.36%
First Number Match Accuracy: 4.96%
Per First-Number Accuracy:
  EC 7: 0.00%
  EC 2: 96.05%
  EC 3: 92.86%
  EC 6: 100.00%
  EC 5: 96.67%
  EC 4: 100.00%
  EC 1: 90.91%
Exact Precision: 1.00, Recall: 0.04, F1-Score: 0.08
First Number Precision: 1.00, Recall: 0.05, F1-Score: 0.09
Saved results to ../metrics/metrics.csv and ../dataset/diamond_results/evaluation_results.csv


### Visualize Metrics

In [17]:
# Reuse the configured path instead of repeating the literal. The first CSV column
# is an unnamed saved index (it loaded as "Unnamed: 0"), so use it as the index.
metrics = pd.read_csv(metrics_file, index_col=0)
metrics

Unnamed: 0,Method,Exact Match Accuracy,First Number Match Accuracy,Exact Precision,Exact Recall,Exact F1-Score,First Number Precision,First Number Recall,First Number F1-Score,No EC number found,No Prediction,EC 1,EC 2,EC 3,EC 4,EC 5,EC 6,EC 7
0,DIMOND Benchmark,4.357298,4.962479,1.0,0.043573,0.083507,1.0,0.049625,0.094557,0.94626,0.0,90.909091,96.052632,92.857143,100.0,96.666667,100.0,0.0
