In [8]:
from collections import Counter
from ete3 import NCBITaxa
from functools import lru_cache

ncbi = NCBITaxa()

def LCA_to_dict(LCA):
    res = Counter()
    for item in LCA.split(" "):
        tup = item.split(":")
        res[int(tup[0])] += int(tup[1])
    return res

@lru_cache(maxsize=1000)
def get_level(ID, level='genus'):
    """
    Find the genus of a given species ID.
    """
    if ID is None:
        return None
        
    # Get the lineage of the species
    try:
        lineage = ncbi.get_lineage(int(ID))
        ranks = ncbi.get_rank(lineage)
    except:
        return None

    #print(ranks)
    for i in ranks:
        if ranks[i] == level:
            return i
        
    return None


def check_LCA(LCA, ground_truth, level='genus'):
    prediction = LCA_to_dict(LCA)
    acceptable_prediction = 0
    for species in prediction:
        if species != 0:
            if get_level(ground_truth, level) == get_level(species, level):
                acceptable_prediction += prediction[species]
    return acceptable_prediction


def num_unclassified(LCA):
    prediction = LCA_to_dict(LCA)
    return prediction[0]

In [20]:
import pandas as pd

metadata_df = pd.read_csv("../../bac120_metadata.tsv", sep="\t", index_col=0)
metadata_df

Unnamed: 0_level_0,ambiguous_bases,checkm2_completeness,checkm2_contamination,checkm2_model,checkm_completeness,checkm_contamination,checkm_marker_count,checkm_marker_lineage,checkm_marker_set_count,checkm_strain_heterogeneity,...,ssu_silva_blast_align_len,ssu_silva_blast_bitscore,ssu_silva_blast_evalue,ssu_silva_blast_perc_identity,ssu_silva_blast_subject_id,ssu_silva_taxonomy,total_gap_length,trna_aa_count,trna_count,trna_selenocysteine_count
accession,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
RS_GCF_000657795.2,0,100.00,0.14,Specific,99.53,0.00,426,o__Burkholderiales (UID4000),213,0.00,...,1528,2822,0,100,JHEP02000033.784.2325,Bacteria;Proteobacteria;Gammaproteobacteria;Bu...,0,19,55,1
RS_GCF_001072555.1,7,100.00,0.52,Specific,99.81,0.09,773,g__Staphylococcus (UID294),178,0.00,...,896,1655,0,100,CP030246.976624.978176,Bacteria;Firmicutes;Bacilli;Staphylococcales;S...,700,14,36,0
RS_GCF_003050715.1,0,100.00,0.04,Specific,99.60,0.22,769,g__Burkholderia (UID4006),248,0.00,...,1530,2826,0,100,CP012193.2238151.2239686,Bacteria;Proteobacteria;Gammaproteobacteria;Bu...,585,19,52,0
RS_GCF_016772635.1,0,100.00,0.16,Specific,100.00,0.04,1169,f__Enterobacteriaceae (UID5139),340,0.00,...,1538,2841,0,100,CP032396.2469808.2471361,Bacteria;Proteobacteria;Gammaproteobacteria;En...,0,19,86,1
GB_GCA_000615405.1,0,100.00,1.37,Specific,99.62,0.57,471,o__Lactobacillales (UID543),264,0.00,...,1545,2854,0,100,CP020604.1893382.1894929,Bacteria;Firmicutes;Bacilli;Lactobacillales;St...,0,17,45,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
GB_GCA_949039885.1,0,98.92,0.35,General,98.79,3.22,420,f__Lachnospiraceae (UID1286),207,11.11,...,none,none,none,none,none,none,200,16,31,0
GB_GCA_905234525.1,14,98.73,1.30,General,90.12,1.09,420,f__Lachnospiraceae (UID1286),207,33.33,...,none,none,none,none,none,none,1001,18,35,0
GB_GCA_948663365.1,8,76.65,0.14,Specific,76.98,0.38,406,o__Bacteroidales (UID2617),265,0.00,...,none,none,none,none,none,none,184,17,26,0
GB_GCA_948940555.1,0,99.99,0.11,Specific,97.55,0.13,406,o__Bacteroidales (UID2617),265,0.00,...,none,none,none,none,none,none,109,18,42,0


In [21]:
# Preprocess metadata

# Add the ncbi_accession column
ncbi_accession = [i[7:] for i in metadata_df.index]
metadata_df["ncbi_accession"] = ncbi_accession

In [12]:
metadata_df

Unnamed: 0,ambiguous_bases,checkm2_completeness,checkm2_contamination,checkm2_model,checkm_completeness,checkm_contamination,checkm_marker_count,checkm_marker_lineage,checkm_marker_set_count,checkm_strain_heterogeneity,...,ssu_silva_blast_bitscore,ssu_silva_blast_evalue,ssu_silva_blast_perc_identity,ssu_silva_blast_subject_id,ssu_silva_taxonomy,total_gap_length,trna_aa_count,trna_count,trna_selenocysteine_count,ncbi_accession
GCF_000657795.2,0,100.00,0.14,Specific,99.53,0.00,426,o__Burkholderiales (UID4000),213,0.00,...,2822,0,100,JHEP02000033.784.2325,Bacteria;Proteobacteria;Gammaproteobacteria;Bu...,0,19,55,1,GCF_000657795.2
GCF_001072555.1,7,100.00,0.52,Specific,99.81,0.09,773,g__Staphylococcus (UID294),178,0.00,...,1655,0,100,CP030246.976624.978176,Bacteria;Firmicutes;Bacilli;Staphylococcales;S...,700,14,36,0,GCF_001072555.1
GCF_003050715.1,0,100.00,0.04,Specific,99.60,0.22,769,g__Burkholderia (UID4006),248,0.00,...,2826,0,100,CP012193.2238151.2239686,Bacteria;Proteobacteria;Gammaproteobacteria;Bu...,585,19,52,0,GCF_003050715.1
GCF_016772635.1,0,100.00,0.16,Specific,100.00,0.04,1169,f__Enterobacteriaceae (UID5139),340,0.00,...,2841,0,100,CP032396.2469808.2471361,Bacteria;Proteobacteria;Gammaproteobacteria;En...,0,19,86,1,GCF_016772635.1
GCA_000615405.1,0,100.00,1.37,Specific,99.62,0.57,471,o__Lactobacillales (UID543),264,0.00,...,2854,0,100,CP020604.1893382.1894929,Bacteria;Firmicutes;Bacilli;Lactobacillales;St...,0,17,45,0,GCA_000615405.1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
GCA_949039885.1,0,98.92,0.35,General,98.79,3.22,420,f__Lachnospiraceae (UID1286),207,11.11,...,none,none,none,none,none,200,16,31,0,GCA_949039885.1
GCA_905234525.1,14,98.73,1.30,General,90.12,1.09,420,f__Lachnospiraceae (UID1286),207,33.33,...,none,none,none,none,none,1001,18,35,0,GCA_905234525.1
GCA_948663365.1,8,76.65,0.14,Specific,76.98,0.38,406,o__Bacteroidales (UID2617),265,0.00,...,none,none,none,none,none,184,17,26,0,GCA_948663365.1
GCA_948940555.1,0,99.99,0.11,Specific,97.55,0.13,406,o__Bacteroidales (UID2617),265,0.00,...,none,none,none,none,none,109,18,42,0,GCA_948940555.1


In [28]:
metadata_df["ncbi_taxid"]

accession
RS_GCF_000657795.2    1331258
RS_GCF_001072555.1       1282
RS_GCF_003050715.1    2135698
RS_GCF_016772635.1      90371
GB_GCA_000615405.1    1236944
                       ...   
GB_GCA_949039885.1     297314
GB_GCA_905234525.1     297314
GB_GCA_948663365.1    2301481
GB_GCA_948940555.1    2301481
GB_GCA_000753355.2    1499689
Name: ncbi_taxid, Length: 584382, dtype: int64

In [29]:
from functools import lru_cache

@lru_cache(maxsize=100)
def accession_to_taxid(accession):
    sub_accession = accession[4:]
    try:
        return metadata_df[metadata_df["ncbi_accession"] == sub_accession]["ncbi_taxid"][0]
    except:
        return None

In [40]:
import numpy as np

def analyze_kraken_results(kraken2_output_file, ground_truth_file):
    # Read the kraken2 output file
    kraken2_predictions = pd.read_csv(kraken2_output_file, sep='\t', header=None)
    kraken2_predictions.rename(columns={0: 'classified', 1: 'read_id', 2: 'classification_result', 3: 'read_length', 4: 'LCA'}, inplace=True)
    
    # Read the ground truth file
    ground_truth = []
    with open(ground_truth_file, 'r') as gt:
        for line in gt:
            ground_truth.append(accession_to_taxid(line.strip()))

    # Benchmark the accuracy
    kraken2_predictions["ground_truth"] = ground_truth

    species_correct = np.full(len(ground_truth), False, dtype=bool)
    family_correct = np.full(len(ground_truth), False, dtype=bool)
    genus_correct = np.full(len(ground_truth), False, dtype=bool)
    prediction_at_species_level = np.full(len(ground_truth), False, dtype=bool)


    for i in range(len(ground_truth)):
        if get_level(ground_truth[i], 'species') == get_level(kraken2_predictions["classification_result"][i], 'species'):
            species_correct[i] = True
        if get_level(ground_truth[i], 'genus') == get_level(kraken2_predictions["classification_result"][i], 'genus'):
            genus_correct[i] = True
        if get_level(ground_truth[i], 'family') == get_level(kraken2_predictions["classification_result"][i], 'family'):
            family_correct[i] = True

    return np.sum(species_correct), np.sum(genus_correct), np.sum(family_correct)

In [38]:
# Perform benchmarking
import glob
ground_truth_files = "/home/zhenhao/htc/data/zymo_test_reads/train_reads.label"
kraken2_reports = glob.glob("/mnt/c/Users/zhenh/kraken2_benchmark/db*_query0.output")

In [42]:
for report in kraken2_reports:
    print(analyze_kraken_results(report, ground_truth_files))
#analyze_kraken_results(kraken2_reports[0], ground_truth_files)

(14751, 20070, 20474)
(14788, 20315, 20735)
(14798, 20498, 20914)
(14780, 20459, 20859)
(14249, 18865, 19129)
(1458, 7553, 7645)
