# Imports

In [None]:
import os
from Bio import SeqIO
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns
import pandas as pd
from scipy.stats import wilcoxon, spearmanr
from matplotlib.ticker import FuncFormatter, ScalarFormatter, LogLocator, LogFormatter
import numpy as np
from sklearn.metrics import f1_score
import statsmodels.formula.api as smf
from statsmodels.stats.multitest import multipletests

# Constants

In [None]:
# Constants
# Sample ids and abundance levels. Both are interpolated into input file names
# and together define the row layout of the abundance/score matrices.
SAMPLES = list(range(1, 21))
ABUNDANCES = [1, *range(10, 101, 10)]

# Suffixes shared by the single- and complete-linkage method keys.
_LINKAGE_LEVELS = (1, 5, 10, 25, 50, 90, 99)

# Display label of every method key; insertion order is the plotting order.
METHOD_LABELS = {
    "all": "All",
    "centroid": "Centroid",
    "ggrasp": "GGRaSP",
    "vlq": "VLQ",
    **{f"meshclust_{t}": f"MC-{t}" for t in ("0.95", "0.99")},
    **{f"gclust_{t}": f"GC-{t}" for t in ("0.95", "0.99", "0.999")},
    **{f"vsearch_{t}": f"VS-{t}" for t in ("0.95", "0.99", "0.999")},
    # Labels use matplotlib mathtext, e.g. "SL-$P_{10}$".
    **{f"single-linkage_{p}": f"SL-$P_{{{p}}}$" for p in _LINKAGE_LEVELS},
    **{f"complete-linkage_{p}": f"CL-$P_{{{p}}}$" for p in _LINKAGE_LEVELS},
}
# Method keys in the same order as their labels ("all" comes first).
METHODS = list(METHOD_LABELS)

EXPERIMENTS = ["global", "country", "state"]
# Significance level applied to the (corrected) p-values when marking bars.
ALPHA = 0.05
OUTPUT_DIR = "figures/viruses"

# Helper functions

In [None]:
def create_mapping(experiment):
    """
    Build the lineage <-> column-index mapping for one experiment.

    Lineages are collected first from column 0 of the experiment's reference
    set ("all.tsv") and then from column 11 of the metadata of sample 1; each
    new lineage receives the next free index in order of first appearance.

    Returns a tuple ``(idx2lineage, lineage2idx)``: the list of lineages by
    index and the inverse dictionary.
    """
    ordered_lineages = []
    index_of = {}

    def _register_column(path, column):
        # Add every not-yet-seen value of the given tab-separated column.
        with open(path, "r") as handle:
            next(handle)  # skip header
            for row in handle:
                name = row.strip().split("\t")[column]
                if name in index_of:
                    continue
                index_of[name] = len(ordered_lineages)
                ordered_lineages.append(name)

    # Lineages present in the reference set
    _register_column(f"root/reference_sets/{experiment}/all.tsv", 0)
    print(len(index_of), "unique lineages found in reference")

    # Lineages present in the simulated samples
    _register_column("root/samples/1/metadata.tsv", 11)
    print(len(index_of), "unique lineages found in total")

    return ordered_lineages, index_of

def _count_read_bases(fastq_path, nucleotides):
    """Add the length of every read in a FASTQ file to ``nucleotides``.

    The source genome of a read is the part of its header line (without the
    leading "@") before the first "|". ``nucleotides`` is updated in place and
    must already contain a counter for every id that occurs in the file.
    """
    source_id = None
    with open(fastq_path, "r") as f_in:
        for line_number, line in enumerate(f_in):
            position = line_number % 4  # FASTQ records span four lines
            if position == 0:
                source_id = line[1:].strip().split("|")[0]
            elif position == 1:
                nucleotides[source_id] += len(line.strip())


def generate_groundtruth(idx2lineage, lineage2idx):
    """Compute the true relative lineage abundances of every simulated sample.

    For each sample and abundance level, the bases of both read files are
    summed per source genome, divided by that genome's length, accumulated per
    lineage and normalised so that each row sums to one.

    Args:
        idx2lineage: Lineages by column index; only its length is used.
        lineage2idx: Mapping from lineage name to column index.

    Returns:
        Array of shape ``(len(SAMPLES) * len(ABUNDANCES), len(idx2lineage))``;
        the row of sample ``s`` at abundance position ``i`` is
        ``s - 1 + i * len(SAMPLES)``.
    """
    # Map sequence ids to lineages
    id2lineage = {}
    with open("root/samples/1/metadata.tsv", "r") as f_in:
        next(f_in)  # skip header
        for line in f_in:
            fields = line.strip().split("\t")
            id2lineage[fields[0]] = fields[11]

    # The B.1.1.7 genome does not depend on the sample, so parse it only once.
    variant_lengths = {
        record.id.strip().split("|")[0]: len(record.seq)
        for record in SeqIO.parse("root/B.1.1.7_sequence.fasta", "fasta")
    }

    abundances = np.zeros((len(SAMPLES)*len(ABUNDANCES), len(idx2lineage)), dtype=np.float64)
    for sample in SAMPLES:
        # Map sequence ids to genome length (file is parsed as FASTA despite
        # its .tsv extension); B.1.1.7 entries take precedence on id clashes.
        id2length = {
            record.id.strip().split("|")[0]: len(record.seq)
            for record in SeqIO.parse(f"root/samples/{sample}/sequences.tsv", "fasta")
        }
        id2length.update(variant_lengths)

        for i, abundance in enumerate(ABUNDANCES):
            row = sample - 1 + i * len(SAMPLES)
            nucleotides = dict.fromkeys(id2length, 0)
            for mate in (1, 2):
                _count_read_bases(
                    f"root/samples/{sample}/wwsim_B.1.1.7_sequence_ab{abundance}_{mate}.fastq",
                    nucleotides,
                )

            for seq_id, base_count in nucleotides.items():
                # Ids missing from the metadata are attributed to B.1.1.7.
                lineage = id2lineage.get(seq_id, "B.1.1.7")
                abundances[row, lineage2idx[lineage]] += base_count / id2length[seq_id]
            abundances[row, :] = abundances[row, :] / np.sum(abundances[row, :])

    return abundances

def read_estimations(method, experiment, lineage2idx):
    """Load the predicted lineage abundances of one method for one experiment.

    Every prediction file has three preamble lines followed by four
    tab-separated columns; the first column is the lineage and the last one its
    abundance in percent. Fractions of at most 0.001 are ignored, and each row
    is normalised to sum to one afterwards.

    Returns an array of shape ``(len(SAMPLES) * len(ABUNDANCES),
    len(lineage2idx))`` where the row of sample ``s`` at abundance position
    ``i`` is ``s - 1 + i * len(SAMPLES)``.
    """
    basepath = f"root/estimations/{experiment}"
    n_samples = len(SAMPLES)
    estimates = np.zeros((n_samples * len(ABUNDANCES), len(lineage2idx)), dtype=np.float64)

    for level, abundance in enumerate(ABUNDANCES):
        for sample in SAMPLES:
            row = sample - 1 + level * n_samples
            with open(f"{basepath}/{sample}/{method}_ab{abundance}_predictions.tsv", "r") as handle:
                # skip preamble
                for _ in range(3):
                    next(handle)
                for record in handle:
                    lineage, _, _, percentage = record.strip().split("\t")
                    fraction = float(percentage) / 100
                    if fraction > 0.001:
                        estimates[row, lineage2idx[lineage]] = fraction
            estimates[row, :] = estimates[row, :] / np.sum(estimates[row, :])

    return estimates


# Plotting script

In [None]:
# Colour per metric name: first colour of seaborn's "deep" palette for the
# L1-based score, second colour of the "muted" palette for F1.
COLOR_PALETTE = dict(
    zip(
        ("1 - L1/2", "F1"),
        (sns.color_palette("deep")[0], sns.color_palette("muted")[1]),
    )
)

def plot_results(experiment, metrics, results, methods, pvalues_l1, pvalues_f1):
    """Draw the grouped bar plot of both scores per method and save it as SVG.

    Bars show the median with an 80% percentile interval. The first entry of
    ``methods`` is not drawn as a bar; instead the medians of
    ``results["all"]`` are drawn as dashed horizontal lines. A star is placed
    above every bar whose p-value is below ``ALPHA``.

    Args:
        experiment: Experiment name; used as title and output sub-directory.
        metrics: Records with the keys "method" (display label), "metric"
            ("1 - L1/2" or "F1") and "value".
        results: Mapping that must contain ``results["all"]["L1"]`` and
            ``results["all"]["F1"]``.
        methods: Method keys to plot, in plotting order.
        pvalues_l1: Mapping display label -> p-value for the L1-based score.
        pvalues_f1: Mapping display label -> p-value for the F1 score.

    The figure is written to ``OUTPUT_DIR/<experiment>/barplot.svg``; the
    directory is created when missing.
    """
    metric_names = ["1 - L1/2", "F1"]
    bar_labels = [METHOD_LABELS[m] for m in methods[1:]]  # display names

    df = pd.DataFrame(metrics)
    # Keep only rows where 'method' is in the provided methods list (by label)
    df = df[df["method"].isin([METHOD_LABELS[m] for m in methods])].copy()

    # catplot creates its own figure, so no separate plt.figure() is needed.
    g = sns.catplot(
        data=df,
        kind="bar",
        x="method",
        y="value",
        hue="metric",
        estimator="median",
        errorbar=("pi", 80),
        err_kws={"linewidth": 4},
        width=0.9,
        alpha=0.8,
        order=bar_labels,
        # Pin the hue order so that ax.containers lines up with metric_names.
        hue_order=metric_names,
        palette=COLOR_PALETTE,
    )
    g.set_axis_labels("", "Score", fontsize=16)
    g.ax.axhline(y=np.median(results["all"]["L1"]), xmin=0, xmax=1, color=COLOR_PALETTE["1 - L1/2"], linestyle="--", alpha=1, linewidth=2)
    g.ax.axhline(y=np.median(results["all"]["F1"]), xmin=0, xmax=1, color=COLOR_PALETTE["F1"], linestyle="--", alpha=1, linewidth=2)
    g.ax.grid(True, which="major", axis="y", linestyle=":", alpha=0.5)

    plt.title(f"{experiment.capitalize()}", fontsize=20)

    plt.xticks(fontsize=15, rotation=90, ha="right")
    plt.yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], fontsize=13)
    plt.ylim([0.0, 1.05])

    for ax in g.axes.flat:
        # one BarContainer per metric
        for metric, container in zip(metric_names, ax.containers):
            # one bar per method
            for method_label, bar in zip(bar_labels, container.patches):
                group = df[(df["metric"] == metric) & (df["method"] == method_label)]
                if group.empty:
                    continue

                # p-values are dicts keyed by display label
                pval_dict = pvalues_f1 if metric == "F1" else pvalues_l1
                if pval_dict.get(method_label, 1.0) >= ALPHA:
                    continue

                # Place the star just above the 90th percentile of the values
                y_90 = np.percentile(group["value"], 90)
                ax.text(
                    bar.get_x() + bar.get_width() / 2,
                    y_90 + 0.025,
                    "\u2217",
                    ha="center",
                    va="bottom",
                    fontsize=10,
                    fontweight="bold",
                )
    # Fix legend
    g.ax.legend(title="", fontsize=15, ncol=2)

    g.figure.set_size_inches(8, 6)
    g.figure.tight_layout()

    # The previous f-string "${OUTPUT_DIR}/viruses/..." emitted a literal "$"
    # and repeated the "viruses" segment that OUTPUT_DIR already ends with.
    output_dir = os.path.join(OUTPUT_DIR, experiment)
    os.makedirs(output_dir, exist_ok=True)
    g.savefig(os.path.join(output_dir, "barplot.svg"), dpi=500, bbox_inches="tight")
    plt.close(g.figure)

def plot_genomes(experiment, results):
    """Scatter the median scores of every method against its genome count.

    The Spearman correlation between genome count and each median score is
    shown in the title. The figure is written to
    ``OUTPUT_DIR/<experiment>/l1_f1_genomes.svg``; the directory is created
    when missing.

    Args:
        experiment: Experiment name; selects the x-axis margin and the output
            sub-directory.
        results: Mapping method key -> dict with "L1" and "F1" score lists and
            a "genomes" count, for every entry of ``METHODS``.
    """
    median_outcomes_l1 = np.array([np.median(results[method]["L1"]) for method in METHODS], dtype=np.float64)
    median_outcomes_f1 = np.array([np.median(results[method]["F1"]) for method in METHODS], dtype=np.float64)
    genome_counts = np.array([results[method]["genomes"] for method in METHODS], dtype=np.int64)

    rho_l1, _ = spearmanr(genome_counts, median_outcomes_l1, nan_policy="omit")
    rho_f1, _ = spearmanr(genome_counts, median_outcomes_f1, nan_policy="omit")

    fig = plt.figure(figsize=(8,6))
    ax = sns.scatterplot(x=genome_counts, y=median_outcomes_l1, s=100, color=sns.color_palette("deep")[0], label="1 - L1/2", edgecolor="black", alpha=0.5)
    sns.scatterplot(x=genome_counts, y=median_outcomes_f1, s=100, color=sns.color_palette("muted")[1], label="F1", edgecolor="black", alpha=0.5, ax=ax)

    plt.xlabel("Number of genomes", fontsize=16)
    plt.ylabel("Score", fontsize=16)
    ax.grid(True, which="major", axis="y", linestyle=":", alpha=0.5)
    plt.title(rf"$\rho(1-L1/2)={rho_l1:.3f}$, $\rho(F1)={rho_f1:.3f}$", fontsize=16)
    plt.ylim([0, 1.05])
    # Account for scales: right-hand margin per experiment (100 otherwise)
    x_margin = {"global": 3000, "country": 500}.get(experiment, 100)
    plt.xlim([0, max(genome_counts) + x_margin])
    plt.xticks(fontsize=13)
    plt.yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0], fontsize=13)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.legend(title="", fontsize=15, ncol=2)
    plt.tight_layout()

    # The previous f-string "${OUTPUT_DIR}/viruses/..." emitted a literal "$"
    # and repeated the "viruses" segment that OUTPUT_DIR already ends with.
    output_dir = os.path.join(OUTPUT_DIR, experiment)
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, "l1_f1_genomes.svg"), dpi=500, bbox_inches="tight")
    # NOTE(review): the figure is not closed here, so it stays registered with
    # pyplot - presumably so the notebook renders it inline; confirm.

# Fetch results

In [None]:
# Per experiment: scores of every method ("L1", "F1", "genomes") and the same
# scores flattened into long-format records for plotting.
results_per_experiment = {}
metrics_per_experiment = {}
num_genomes_per_experiment = {}
for experiment in ["state", "country", "global"]:
    idx2lineage, lineage2idx = create_mapping(experiment)
    groundtruth = generate_groundtruth(idx2lineage, lineage2idx)
    n_rows = len(SAMPLES) * len(ABUNDANCES)

    experiment_results = results_per_experiment[experiment] = {}
    experiment_metrics = metrics_per_experiment[experiment] = []

    for method in METHODS:
        cur_results = read_estimations(method, experiment, lineage2idx)

        # Size of the reference set = number of rows after the header
        with open(f"root/reference_sets/{experiment}/{method}.tsv", "r") as f_in:
            next(f_in)  # skip header
            num_genomes = sum(1 for _ in f_in)

        l1_scores = []
        f1_scores = []
        for row in range(n_rows):
            estimated = cur_results[row, :]
            expected = groundtruth[row, :]
            # 1 - L1/2: one minus half the summed absolute difference
            l1_scores.append(1.0 - np.sum( np.abs(estimated - expected)) / 2.0)
            # F1 on presence/absence (abundance > 0) of each lineage
            f1_scores.append(f1_score(expected > 0, estimated > 0, zero_division=1))

        experiment_results[method] = {
            "L1": l1_scores,
            "F1": f1_scores,
            "genomes": num_genomes,
        }

        label = METHOD_LABELS[method]
        for l1_value, f1_value in zip(l1_scores, f1_scores):
            for metric_name, value in (("1 - L1/2", l1_value), ("F1", f1_value)):
                experiment_metrics.append(
                    {
                        "method": label,
                        "metric": metric_name,
                        "value": value,
                        "genomes": num_genomes,
                    }
                )

# Run plotting
**NOTE**: We aggregate results over all abundances!

In [None]:
# Determine, per method group, the best overall performance based on average median L1 and F1 score across experiments
method_groups = [
    ["all"],
    ["centroid"],
    ["ggrasp"],
    ["vlq"],
    ["meshclust_0.95", "meshclust_0.99"],
    ["gclust_0.95", "gclust_0.99", "gclust_0.999"],
    ["vsearch_0.95", "vsearch_0.99", "vsearch_0.999"],
    [
        "single-linkage_1",
        "single-linkage_5",
        "single-linkage_10",
        "single-linkage_25",
        "single-linkage_50",
        "single-linkage_90",
        "single-linkage_99",
    ],
    [
        "complete-linkage_1",
        "complete-linkage_5",
        "complete-linkage_10",
        "complete-linkage_25",
        "complete-linkage_50",
        "complete-linkage_90",
        "complete-linkage_99",
    ],
]
# Find best method per group
final_methods = []
for method_group in method_groups:
    if len(method_group) == 1:
        final_methods.append(method_group[0])
    else:
        best_method = None
        best_score = -1
        for method in method_group:
            avg_median_l1 = np.mean([np.median(results_per_experiment[exp][method]["L1"]) for exp in ["state", "country", "global"]])
            avg_median_f1 = np.mean([np.median(results_per_experiment[exp][method]["F1"]) for exp in ["state", "country", "global"]])
            score = avg_median_l1 + avg_median_f1
            if score > best_score:
                best_score = score
                best_method = method
        final_methods.append(best_method)
# NOTE(review): the automatic selection above is superseded by the manual list
# below, which is in turn superseded by the full METHODS list. Only the last
# assignment takes effect; the earlier ones are kept as alternative selections.
final_methods = [
    "all",
    "centroid",
    "ggrasp",
    "vlq",
    "meshclust_0.95",
    "meshclust_0.99",
    "gclust_0.95",
    "gclust_0.99",
    "gclust_0.999",
    "vsearch_0.95",
    "vsearch_0.99",
    "vsearch_0.999",
    "single-linkage_1",
    "single-linkage_25",
    "single-linkage_99",
    "complete-linkage_1",
    "complete-linkage_25",
    "complete-linkage_99",
]
final_methods = METHODS # override for final plots


def _one_sided_wilcoxon_pvalue(candidate, baseline):
    """Return the p-value of a paired Wilcoxon test for ``candidate > baseline``.

    Falls back to 1.0 (not significant) when the test cannot be evaluated:
    either ``wilcoxon`` rejects the input with a ``ValueError`` (e.g. all
    paired differences are zero) or it yields a non-finite p-value. The latter
    matters because a single NaN would propagate through the
    Benjamini-Hochberg correction applied afterwards.
    """
    try:
        _, pval = wilcoxon(candidate, baseline, alternative="greater")
    except ValueError:
        return 1.0
    return pval if np.isfinite(pval) else 1.0


# Determine p-values and significance (every method against "all")
pvalues_l1_per_experiment = {}
pvalues_l1_corrected_per_experiment = {}
pvalues_f1_per_experiment = {}
pvalues_f1_corrected_per_experiment = {}
for experiment in ["state", "country", "global"]:
    results = results_per_experiment[experiment]
    metrics = metrics_per_experiment[experiment]

    pvalues_l1 = np.ones((len(METHODS)-1,), dtype=np.float64)
    pvalues_f1 = np.ones((len(METHODS)-1,), dtype=np.float64)
    for i, method in enumerate(METHODS[1:]):
        pvalues_l1[i] = _one_sided_wilcoxon_pvalue(results[method]["L1"], results["all"]["L1"])
        pvalues_f1[i] = _one_sided_wilcoxon_pvalue(results[method]["F1"], results["all"]["F1"])

    # Multiple testing correction (Benjamini-Hochberg), computed once per metric
    corrected_l1 = multipletests(pvalues_l1, method="fdr_bh")[1]
    corrected_f1 = multipletests(pvalues_f1, method="fdr_bh")[1]

    # p-value dicts are keyed by display label, as expected by plot_results
    methods_for_pvalues = [METHOD_LABELS[m] for m in METHODS[1:]]
    pvalues_l1_per_experiment[experiment] = dict(zip(methods_for_pvalues, pvalues_l1))
    pvalues_l1_corrected_per_experiment[experiment] = dict(zip(methods_for_pvalues, corrected_l1))
    pvalues_f1_per_experiment[experiment] = dict(zip(methods_for_pvalues, pvalues_f1))
    pvalues_f1_corrected_per_experiment[experiment] = dict(zip(methods_for_pvalues, corrected_f1))
    print(f"Experiment: {experiment}")
    print(corrected_l1)

for experiment in ["global", "country", "state"]:
    plot_results(experiment, metrics_per_experiment[experiment], results_per_experiment[experiment], final_methods, pvalues_l1_corrected_per_experiment[experiment], pvalues_f1_corrected_per_experiment[experiment])
    plot_genomes(experiment, results_per_experiment[experiment])

# Calculate correlation over all reference sets (across experiments)

In [None]:
# One entry per (experiment, method) pair: median scores and reference-set size
median_l1, median_f1, genomes = [], [], []
for experiment in ["state", "country", "global"]:
    per_method = results_per_experiment[experiment]
    median_l1.extend(np.median(per_method[name]["L1"]) for name in METHODS)
    median_f1.extend(np.median(per_method[name]["F1"]) for name in METHODS)
    genomes.extend(per_method[name]["genomes"] for name in METHODS)

# Spearman rank correlation between reference-set size and each median score
rho_l1, _ = spearmanr(genomes, median_l1, nan_policy="omit")
rho_f1, _ = spearmanr(genomes, median_f1, nan_policy="omit")

print(f"Overall correlation genomes vs L1: {rho_l1:.3f}")
print(f"Overall correlation genomes vs F1: {rho_f1:.3f}")