# Imports

In [None]:
import os
from Bio import SeqIO
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
import pandas as pd
from scipy.stats import wilcoxon, spearmanr
from matplotlib.ticker import FuncFormatter, ScalarFormatter, LogLocator, LogFormatter
import numpy as np
from sklearn.metrics import f1_score
import statsmodels.formula.api as smf
from statsmodels.stats.multitest import multipletests

# Constants

In [None]:
# Constants
# Names of the experiments that are analysed.
EXPERIMENTS = ["order_experiments", "family_experiments", "genus_experiments"]
# Sample numbers; they appear in the "sample_<n>" directory names.
SAMPLES = list(range(1, 11))

# Clustering-based selection methods and the abbreviation used in figure labels.
_CLUSTERING_TOOLS = {
    "meshclust": "MC",
    "gclust": "GC",
    "single-linkage": "SL",
    "complete-linkage": "CL",
}
# Numeric suffixes of the clustering-based method names
# (presumably similarity thresholds -- confirm).
_THRESHOLDS = ("0.95", "0.97", "0.99")

# Reference-set selection methods; "all" comes first and is used as baseline.
METHODS = ["all", "centroid", "ggrasp"] + [
    f"{tool}_{threshold}"
    for tool in _CLUSTERING_TOOLS
    for threshold in _THRESHOLDS
]
# Display label of every method in METHODS (same order).
METHOD_LABELS = {
    "all": "All",
    "centroid": "Centroid",
    "ggrasp": "GGRaSP",
    **{
        f"{tool}_{threshold}": f"{abbreviation}-{threshold}"
        for tool, abbreviation in _CLUSTERING_TOOLS.items()
        for threshold in _THRESHOLDS
    },
}
# Significance level applied to the p-values when annotating the bar plots.
ALPHA = 0.05
# Directory below which the figures are saved.
OUTPUT_DIR = "figures/bacteria"

# Fetch data and helper functions

In [None]:
# One (initially empty) container per experiment.
data_per_experiment = {
    name: {}
    for name in ("order_experiments", "family_experiments", "genus_experiments")
}

def map_ids(experiment):
    """
    Collect lookup tables for all genomes of one experiment.

    Every ``*.fna`` file in ``root/simulation_genomes/<experiment>/<species>/``
    is parsed as FASTA. The name of the species directory is converted to an
    integer and used as taxid for every sequence in that directory.

    Returns a tuple ``(accessions, mappings)``: ``accessions`` is the set of
    ``.fna`` file names that were seen and ``mappings`` is a dict with the keys
    ``id2species``, ``species2idx``, ``idx2species``, ``id2length`` and
    ``id2accession``. ``id2species`` and ``id2length`` are keyed by record IDs
    as well as by file names; the file-name entry of ``id2length`` is the
    summed length of all records in that file.
    """
    root = f"root/simulation_genomes/{experiment}"

    seq2taxid = {}
    taxid2index = {}
    index2taxid = []
    seq2length = {}
    seq2file = {}
    genome_files = set()

    for species_dir in os.listdir(f"{root}"):
        fasta_names = [
            name
            for name in os.listdir(f"{root}/{species_dir}")
            if name.endswith(".fna")
        ]
        for fasta_name in fasta_names:
            taxid = int(species_dir)
            genome_files.add(fasta_name)
            genome_length = 0
            with open(f"{root}/{species_dir}/{fasta_name}", "r") as handle:
                for record in SeqIO.parse(handle, "fasta"):
                    record_length = len(record.seq)
                    genome_length += record_length
                    seq2taxid[record.id] = taxid
                    seq2length[record.id] = record_length
                    seq2file[record.id] = fasta_name
                    # Species are numbered in order of first appearance
                    if taxid not in taxid2index:
                        taxid2index[taxid] = len(index2taxid)
                        index2taxid.append(taxid)
            # The file itself is registered as an ID too (total genome length)
            seq2length[fasta_name] = genome_length
            seq2taxid[fasta_name] = taxid

    mappings = {
        "id2species": seq2taxid,
        "species2idx": taxid2index,
        "idx2species": index2taxid,
        "id2length": seq2length,
        "id2accession": seq2file,
    }
    return genome_files, mappings

def index_species(experiment):
    """
    Build accession/species lookups from the ``all.tsv`` reference set.

    Every line of ``root/reference_sets/<experiment>/all.tsv`` is split on
    tabs; the second field is read as the integer species taxid and the third
    field as the accession. Species are numbered consecutively in order of
    first appearance.

    Returns the tuple ``(accession2species, species2idx, idx2species)``.
    """
    tsv_path = f"root/reference_sets/{experiment}/all.tsv"
    accession_to_taxid = {}
    taxid_to_index = {}
    ordered_taxids = []

    with open(tsv_path, "r") as handle:
        for raw in handle:
            fields = raw.strip().split("\t")
            taxid, accession = int(fields[1]), fields[2]
            accession_to_taxid[accession] = taxid
            if taxid in taxid_to_index:
                continue
            taxid_to_index[taxid] = len(ordered_taxids)
            ordered_taxids.append(taxid)

    return accession_to_taxid, taxid_to_index, ordered_taxids

# Fetch the genome mappings of every experiment and replace the species
# indexing by the one derived from the experiment's all.tsv reference set
data_per_experiment = {}
accessions_per_experiment = {}
for experiment in EXPERIMENTS:
    accessions, mappings = map_ids(experiment)
    accessions_per_experiment[experiment] = accessions
    data_per_experiment[experiment] = mappings
    accession2species, species2idx, idx2species = index_species(experiment)
    mappings.update(
        {
            "accession2species": accession2species,
            "species2idx": species2idx,
            "idx2species": idx2species,
        }
    )

# Generate groundtruths

In [None]:
# Generate groundtruths (this takes ~9 minutes to run)
# groundtruths_genomes: per sample, the nucleotides sequenced from every genome
#   divided by that genome's length, summed per species and normalised to 1
#   -> DUDes
# groundtruths_reads: per sample, the fraction of reads that originates from
#   each species -> Bracken & Centrifuge
# Both map experiment -> array of shape (len(SAMPLES), number of species).
groundtruths_genomes = {}
groundtruths_reads = {}
for experiment in EXPERIMENTS:
    cur_data = data_per_experiment[experiment]
    num_species = len(cur_data["idx2species"])
    groundtruths_genomes[experiment] = np.zeros((len(SAMPLES), num_species), dtype=np.float64)
    groundtruths_reads[experiment] = np.zeros((len(SAMPLES), num_species), dtype=np.float64)
    cur_path = f"root/samples/{experiment}/"
    for sample in SAMPLES:
        # Rows are addressed with sample - 1, i.e. this relies on SAMPLES being 1..len(SAMPLES)
        row = sample - 1
        nucleotide_counts = {species: {} for species in cur_data["idx2species"]}
        read_counts = {species: 0 for species in cur_data["idx2species"]}
        total_reads = 0
        # calculate number of reads and nucleotides per species (only sample_1.fq is read)
        with open(f"{cur_path}/sample_{sample}/sample_1.fq", "r") as f_in:
            for line_count, line in enumerate(f_in):
                if line_count % 4 == 0:
                    # Header line: drop the leading character and keep the part before the first "-"
                    cur_id = line[1:].strip().split("-")[0]
                    cur_accession = cur_data["id2accession"][cur_id]
                    cur_species = cur_data["id2species"][cur_id]
                    read_counts[cur_species] += 1
                    total_reads += 1
                elif line_count % 4 == 1:
                    # Sequence line of the read announced by the preceding header
                    per_accession = nucleotide_counts[cur_species]
                    per_accession[cur_accession] = per_accession.get(cur_accession, 0) + len(line.strip())

        # Account for genome lengths
        for species, per_accession in nucleotide_counts.items():
            species_idx = cur_data["species2idx"][species]
            for accession, nucleotides in per_accession.items():
                groundtruths_genomes[experiment][row, species_idx] += nucleotides / cur_data["id2length"][accession]
        # Only consider number of reads assigned
        for species, count in read_counts.items():
            species_idx = cur_data["species2idx"][species]
            groundtruths_reads[experiment][row, species_idx] = count / total_reads
        # Normalise the genome-based profile exactly once per sample (this used
        # to be repeated inside the loop above for every species)
        groundtruths_genomes[experiment][row, :] = groundtruths_genomes[experiment][row, :] / np.sum(groundtruths_genomes[experiment][row, :])


# Helper functions for fetching abundance estimates per profiler

In [None]:
def read_bracken(experiment, method, data_per_experiment):
    """
    Read the Bracken estimates of all samples for one experiment and method.

    For every sample in SAMPLES the file
    ``root/estimations/<experiment>/bracken/sample_<sample>/<method>.bracken``
    is parsed (tab separated, first line skipped). Column 2 is read as the
    taxid and column 6 as the read count of that taxon. Species whose share of
    the summed read count is not above 0.001 are set to zero and the remaining
    shares are renormalised to sum to 1.

    Parameters:
        experiment: key into ``data_per_experiment``, also part of the path.
        method: name of the reference set, part of the file name.
        data_per_experiment: dict experiment -> mappings; the entries
            "species2idx" and "idx2species" are used.

    Returns:
        Array of shape (len(SAMPLES), number of species) with the abundances;
        the row of a sample is ``sample - 1``.
    """
    species2idx = data_per_experiment[experiment]["species2idx"]
    num_species = len(data_per_experiment[experiment]["idx2species"])
    readcounts = np.zeros((len(SAMPLES), num_species), dtype=np.int64)
    abundances = np.zeros((len(SAMPLES), num_species), dtype=np.float64)
    for sample in SAMPLES:
        total_readcount = 0
        # Plain "{...}" placeholders: the former shell-style "${...}" inserted
        # literal "$" characters into the path (cf. the other profiler readers)
        with open(f"root/estimations/{experiment}/bracken/sample_{sample}/{method}.bracken", "r") as f_in:
            next(f_in) #skip header
            for line in f_in:
                line = line.strip().split("\t")
                taxid = int(line[1])
                taxid_idx = species2idx[taxid]
                readcount = int(line[5])
                readcounts[sample-1, taxid_idx] += readcount
                total_readcount += readcount
        # Discard species at or below the 0.1% threshold, then renormalise
        for taxid_idx in range(num_species):
            if readcounts[sample-1, taxid_idx] / total_readcount > 0.001:
                abundances[sample-1, taxid_idx] = readcounts[sample-1, taxid_idx] / total_readcount
        abundances[sample-1, :] = abundances[sample-1, :] / np.sum(abundances[sample-1, :])
    return abundances

def read_centrifuge(experiment, method, data_per_experiment):
    """
    Read the Centrifuge estimates of all samples for one experiment and method.

    For every sample in SAMPLES the report
    ``root/estimations/<experiment>/centrifuge/sample_<sample>/<method>.report``
    is parsed (tab separated, first line skipped). Only rows whose third column
    equals "species" are used; their sixth column is taken as read count.
    Species whose share of the summed read count is not above 0.001 stay at
    zero and the remaining shares are renormalised to sum to 1.

    Returns an array of shape (len(SAMPLES), number of species); the row of a
    sample is ``sample - 1``.
    """
    lookup = data_per_experiment[experiment]
    shape = (len(SAMPLES), len(lookup["idx2species"]))
    readcounts = np.zeros(shape, dtype=np.int64)
    abundances = np.zeros(shape, dtype=np.float64)

    for sample in SAMPLES:
        row = sample - 1
        total_reads = 0
        report_path = f"root/estimations/{experiment}/centrifuge/sample_{sample}/{method}.report"
        with open(report_path, "r") as report:
            next(report)  # skip header
            for raw in report:
                fields = raw.strip().split("\t")
                taxid = int(fields[1])
                if fields[2] != "species":
                    continue
                column = lookup["species2idx"][taxid]
                # fields[5] is the number of reads uniquely assigned to this taxon
                readcounts[row, column] += int(fields[5])
                total_reads += int(fields[5])

        # Keep only the shares above the 0.1% threshold, then renormalise
        shares = readcounts[row, :] / total_reads
        abundances[row, :] = np.where(shares > 0.001, shares, abundances[row, :])
        abundances[row, :] = abundances[row, :] / np.sum(abundances[row, :])
    return abundances

def read_dudes(experiment, method, data_per_experiment):
    """
    Read the DUDes estimates of all samples for one experiment and method.

    For every sample in SAMPLES the file
    ``root/estimations/<experiment>/dudes/sample_<sample>/<method>_dudes.out``
    is parsed (tab separated, first six lines skipped). Only rows whose second
    column equals "species" are used; their last column divided by 100 is the
    abundance. Abundances not above 0.001 are dropped and the rest is
    renormalised to sum to 1. If nothing remains for a sample, the experiment,
    method and sample are printed and the whole row is set to NaN.

    Returns an array of shape (len(SAMPLES), number of species); the row of a
    sample is ``sample - 1``.
    """
    lookup = data_per_experiment[experiment]
    abundances = np.zeros((len(SAMPLES), len(lookup["idx2species"])), dtype=np.float64)

    for sample in SAMPLES:
        row = sample - 1
        kept_total = 0
        out_path = f"root/estimations/{experiment}/dudes/sample_{sample}/{method}_dudes.out"
        with open(out_path, "r") as handle:
            # skip preamble
            skipped = 0
            while skipped < 6:
                next(handle)
                skipped += 1
            for raw in handle:
                fields = raw.strip().split("\t")
                taxid = int(fields[0])
                if fields[1] != "species":
                    continue
                column = lookup["species2idx"][taxid]
                fraction = float(fields[-1]) / 100  # scale to [0, 1]
                if fraction > 0.001:
                    abundances[row, column] += fraction
                    kept_total += fraction

        if kept_total != 0:
            abundances[row, :] = abundances[row, :] / np.sum(abundances[row, :])
        else:
            # No usable estimate for this sample: report it and mark the row
            print(experiment, method, sample)
            abundances[row, :] = np.nan
    return abundances

        

# Plotting functions

In [None]:
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import seaborn as sns
import pandas as pd
from scipy.stats import wilcoxon, spearmanr
import pandas as pd
from matplotlib.ticker import FuncFormatter, ScalarFormatter, LogLocator, LogFormatter


# Colour of each metric in the figures; the keys equal the values of the
# "metric" column that is used as hue in the bar plots.
COLOR_PALETTE = dict(
    zip(
        ("1 - L1/2", "F1"),
        (sns.color_palette("deep")[0], sns.color_palette("muted")[1]),
    )
)

def plot_accuracy(experiment, metrics, results, pvalues_l1, pvalues_f1):
    """
    This function plots the accuracy barplots for each profiler and experiment (Figure 3 in manuscript).

    For every profiler a horizontal grouped bar plot (median with an 80%
    percentile interval) of the "1 - L1/2" and "F1" scores per selection
    method is drawn. Dashed vertical lines mark the medians obtained with the
    "all" reference set and an asterisk is placed next to every bar whose
    p-value is below ALPHA. Each figure is saved as
    ``{OUTPUT_DIR}/{experiment}/{profiler}_barplot.svg`` (the directory is
    created if needed) and closed afterwards.

    Parameters:
        experiment: experiment name; used in the title and the output path.
        metrics: dict profiler -> list of records with the keys "method",
            "metric" and "value".
        results: dict profiler -> method -> {"L1": [...], "F1": [...]}; only
            the "all" entry is read here.
        pvalues_l1: dict profiler -> {method label: p-value} for "1 - L1/2".
        pvalues_f1: dict profiler -> {method label: p-value} for "F1".
    """
    profiler_names = {"bracken": "Bracken", "centrifuge": "Centrifuge", "dudes": "DUDes"}
    experiment_name = experiment.split('_')[0].capitalize()
    # savefig does not create missing directories
    output_dir = f"{OUTPUT_DIR}/{experiment}"
    os.makedirs(output_dir, exist_ok=True)

    for profiler, profiler_name in profiler_names.items():
        df = pd.DataFrame(metrics[profiler])
        g = sns.catplot(
            data=df,
            kind="bar",
            y="method",
            x="value",
            hue="metric",
            estimator="median",
            errorbar=("pi", 80),
            err_kws={"linewidth": 4},
            width=0.9,
            alpha=0.8,
            order=[METHOD_LABELS[m] for m in METHODS[1:]],
            palette=COLOR_PALETTE,
        )
        g.set_axis_labels("Score", "", fontsize=16)
        # Reference lines: medians obtained with the complete ("all") reference set
        g.ax.axvline(x=np.median(results[profiler]["all"]["L1"]), ymin=0, ymax=1, color=COLOR_PALETTE["1 - L1/2"], linestyle="--", alpha=1, linewidth=2)
        g.ax.axvline(x=np.median(results[profiler]["all"]["F1"]), ymin=0, ymax=1, color=COLOR_PALETTE["F1"], linestyle="--", alpha=1, linewidth=2)
        g.ax.grid(True, which="major", axis="x", linestyle=":", alpha=0.5)

        plt.title(f"{profiler_name} - {experiment_name}", fontsize=20)

        plt.yticks(fontsize=15)
        plt.xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], fontsize=13)
        plt.xlim([0.0, 1.05])
        # To account for the fact that DUDes didn't provide results for the genus experiments with the centroid selection
        # (the labels below are paired with the drawn bars in order; presumably
        # no bar is drawn for that method -- confirm when changing the plot)
        if not (experiment == "genus_experiments" and profiler == "dudes"):
            methods_labels = [METHOD_LABELS[m] for m in METHODS[1:]]
        else:
            methods_labels = [METHOD_LABELS[m] for m in METHODS[2:]]

        cur_pvalues_l1 = pvalues_l1[profiler]
        cur_pvalues_f1 = pvalues_f1[profiler]
        for ax in g.axes.flat:
            for metric, container in zip(["1 - L1/2", "F1"], ax.containers):
                for method_label, bar in zip(methods_labels, container.patches):
                    group = df[(df["method"] == method_label) & (df["metric"] == metric)]
                    if group.empty:
                        continue

                    # Upper end of the 80% percentile interval: the marker goes just right of it
                    x_90 = np.percentile(group["value"], 90)
                    # Choose significant p-values
                    pval_dict = cur_pvalues_f1 if metric == "F1" else cur_pvalues_l1
                    if pval_dict.get(method_label, 1.0) < ALPHA:
                        y_center = bar.get_y() + bar.get_height() / 2
                        ax.text(
                            x_90 + 0.025,
                            y_center,
                            "\u2217",
                            ha="left",
                            va="center",
                            fontsize=10,
                            fontweight="bold"
                        )

        # Fix legend
        g.ax.legend(title="", fontsize=15, ncol=2)

        g.figure.set_size_inches(8, 6)
        g.figure.tight_layout()
        g.savefig(f"{output_dir}/{profiler}_barplot.svg", dpi=500, bbox_inches="tight")
        plt.close()

def plot_genomes(experiment, results):
    """
    This function plots the median scores against the number of genomes in the reference set.

    For every profiler one scatter plot is drawn with one point per method and
    metric ("1 - L1/2" and "F1"). The title reports the Spearman correlation
    between the number of genomes and each median score (NaN values are
    omitted). Each figure is saved as
    ``{OUTPUT_DIR}/{experiment}/{profiler}_l1_f1_genomes.svg`` (the directory
    is created if needed); the figure is not closed here.

    Parameters:
        experiment: experiment name; used in the output path.
        results: dict profiler -> method -> {"L1": [...], "F1": [...],
            "genomes": int}.
    """
    # savefig does not create missing directories
    output_dir = f"{OUTPUT_DIR}/{experiment}"
    os.makedirs(output_dir, exist_ok=True)

    for profiler in ["bracken", "centrifuge", "dudes"]:
        median_outcomes_l1 = np.array([np.median(results[profiler][method]["L1"]) for method in METHODS], dtype=np.float64)
        median_outcomes_f1 = np.array([np.median(results[profiler][method]["F1"]) for method in METHODS], dtype=np.float64)
        genome_counts = np.array([results[profiler][method]["genomes"] for method in METHODS], dtype=np.int64)

        rho_l1, _ = spearmanr(genome_counts, median_outcomes_l1, nan_policy="omit")
        rho_f1, _ = spearmanr(genome_counts, median_outcomes_f1, nan_policy="omit")

        fig = plt.figure(figsize=(8,6))
        ax = sns.scatterplot(x=genome_counts, y=median_outcomes_l1, s=100, color=sns.color_palette("deep")[0], label="1 - L1/2", edgecolor="black", alpha=0.5)
        sns.scatterplot(x=genome_counts, y=median_outcomes_f1, s=100, color=sns.color_palette("muted")[1], label="F1", edgecolor="black", alpha=0.5)

        plt.xlabel("Number of genomes", fontsize=16)
        plt.ylabel("Score", fontsize=16)
        ax.grid(True, which="major", axis="y", linestyle=":", alpha=0.5)
        plt.title(rf"$\rho(1-L1/2)={rho_l1:.3f}$, $\rho(F1)={rho_f1:.3f}$", fontsize=16)
        plt.ylim([0, 1.05])
        # Leave some room to the right of the largest reference set
        plt.xlim([0, max(genome_counts)+300])
        plt.xticks(fontsize=13)
        plt.yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], fontsize=13)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.legend(title="", fontsize=15, ncol=2)
        plt.tight_layout()
        fig.savefig(f"{output_dir}/{profiler}_l1_f1_genomes.svg", dpi=500, bbox_inches="tight")

# Determine L1 and F1 for all methods and plot
**NOTE**: The experiment to analyse is hard-coded in the `experiment` variable in the cell below. Change it there to run the analysis for a different experiment.

In [None]:
# Experiment that is analysed by the code below
experiment = "genus_experiments"  # Change this to run for different experiments!!

# results: profiler -> method -> scores, filled per method further down
results = {profiler_name: {} for profiler_name in ("bracken", "centrifuge", "dudes")}
# metrics: profiler -> list of records, one per method, sample and metric
metrics = {profiler_name: [] for profiler_name in ("bracken", "centrifuge", "dudes")}
def _score_profile(estimate, groundtruth, skip_f1=False):
    """
    Score one estimated abundance profile against its ground truth.

    Returns the tuple ``(l1_score, f1)``: ``l1_score`` is ``1 - L1/2`` of the
    two abundance vectors and ``f1`` is the F1 score of the presence/absence
    calls (abundance > 0). With ``skip_f1`` the F1 score is reported as NaN
    without being computed.
    """
    l1_score = 1.0 - np.sum(np.abs(estimate - groundtruth)) / 2.0
    if skip_f1:
        return l1_score, np.nan
    return l1_score, f1_score(groundtruth > 0, estimate > 0, zero_division=np.nan)


for method in METHODS:
    bracken = read_bracken(experiment, method, data_per_experiment)
    centrifuge = read_centrifuge(experiment, method, data_per_experiment)
    dudes = read_dudes(experiment, method, data_per_experiment)

    # Number of genomes in the reference set = number of lines in its TSV file
    with open(f"root/reference_sets/{experiment}/{method}.tsv", "r") as f_in:
        num_genomes = sum(1 for _ in f_in)

    # Bracken and Centrifuge are compared with the read-based ground truth,
    # DUDes with the genome-length corrected one
    comparisons = {
        "bracken": (bracken, groundtruths_reads[experiment]),
        "centrifuge": (centrifuge, groundtruths_reads[experiment]),
        "dudes": (dudes, groundtruths_genomes[experiment]),
    }
    for profiler, (estimates, groundtruth) in comparisons.items():
        # DUDes has no result for the centroid selection in the genus
        # experiments: its F1 is reported as NaN. This is applied to results
        # and metrics alike, so that both plots treat the case the same way
        # (results used to hold an F1 computed from the missing profile).
        no_f1 = profiler == "dudes" and method == "centroid" and experiment == "genus_experiments"
        scores = [
            _score_profile(estimates[sample, :], groundtruth[sample, :], skip_f1=no_f1)
            for sample in range(len(SAMPLES))
        ]

        # Per-method summary (used for the p-values and the genome-count plot)
        results[profiler][method] = {
            "L1": [l1 for l1, _ in scores],
            "F1": [f1 for _, f1 in scores],
            "genomes": num_genomes,
        }
        # Long-format records (used for the bar plots)
        for l1, f1 in scores:
            for metric_name, value in (("1 - L1/2", l1), ("F1", f1)):
                metrics[profiler].append(
                    {
                        "method": METHOD_LABELS[method],
                        "metric": metric_name,
                        "value": value,
                        "genomes": num_genomes
                    }
                )
# Determine pvalues and significance
def _wilcoxon_greater_pvalue(candidate, baseline):
    """
    Return the p-value of a one-sided paired Wilcoxon signed-rank test
    (alternative: candidate scores greater than baseline scores).

    If scipy rejects the input with a ValueError (depending on the scipy
    version this happens, for instance, when all paired differences are zero),
    1.0 is returned, i.e. the comparison counts as not significant.
    """
    try:
        _, pval = wilcoxon(candidate, baseline, alternative="greater")
    except ValueError:
        return 1.0
    return pval


# Raw p-values are stored as arrays (one entry per method except "all"),
# corrected p-values as dicts keyed by the method label
pvalues_l1 = {}
pvalues_corrected_l1 = {}
pvalues_f1 = {}
pvalues_corrected_f1 = {}
for profiler in ["bracken", "centrifuge", "dudes"]:
    cur_results = results[profiler]

    cur_pvalues_l1 = np.ones(len(METHODS)-1, dtype=np.float64)
    cur_pvalues_f1 = np.ones(len(METHODS)-1, dtype=np.float64)
    # Every method is compared with the "all" reference set
    for i, method in enumerate(METHODS[1:]):
        if method == "centroid" and experiment == "genus_experiments" and profiler == "dudes":
            # No DUDes result for this combination: the placeholder 1.0 is kept
            continue
        cur_pvalues_l1[i] = _wilcoxon_greater_pvalue(cur_results[method]["L1"], cur_results["all"]["L1"])
        cur_pvalues_f1[i] = _wilcoxon_greater_pvalue(cur_results[method]["F1"], cur_results["all"]["F1"])

    # Multiple testing correction (Benjamini-Hochberg)
    # NOTE(review): the placeholder 1.0 of a skipped comparison is part of the
    # corrected family and therefore counts as a test -- confirm this is intended
    methods_for_pvalues = [METHOD_LABELS[m] for m in METHODS[1:]]
    pvalues_l1[profiler] = cur_pvalues_l1
    pvalues_corrected_l1[profiler] = dict(zip(methods_for_pvalues, multipletests(cur_pvalues_l1, method="fdr_bh")[1]))
    pvalues_f1[profiler] = cur_pvalues_f1
    pvalues_corrected_f1[profiler] = dict(zip(methods_for_pvalues, multipletests(cur_pvalues_f1, method="fdr_bh")[1]))


# Bar plots are annotated with the corrected p-values
plot_accuracy(
    experiment,
    metrics,
    results,
    pvalues_l1=pvalues_corrected_l1,
    pvalues_f1=pvalues_corrected_f1,
)
# Scatter plots of the median scores against the reference-set size
plot_genomes(experiment, results=results)