# Imports

In [None]:
import os
from Bio import SeqIO
import matplotlib.pyplot as plt
import numpy as np
import math
from tqdm import tqdm
from multiprocessing import Pool
import seaborn as sns
from math import sqrt

# Constants

In [None]:
# Taxonomic levels for which a separate set of experiments is analysed.
EXPERIMENTS = [
    "order",
    "family",
    "genus",
]
# Reference-set selection methods. Methods that were run with a threshold are
# named "<method>_<threshold>".
# NOTE: the order is relied upon by the plotting cells: "all" has to be the
# first entry (its heatmap column is dropped) and "centroid" the second (the
# core-set analysis treats METHODS[2:] as the remaining methods).
METHODS = [
    "all",
    "centroid",
    "ggrasp",
    "meshclust_0.95",
    "meshclust_0.97",
    "meshclust_0.99",
    "gclust_0.95",
    "gclust_0.97",
    "gclust_0.99",
    "single-linkage_0.95",
    "single-linkage_0.97",
    "single-linkage_0.99",
    "complete-linkage_0.95",
    "complete-linkage_0.97",
    "complete-linkage_0.99",
]
# Labels for plotting
METHOD_LABELS = {
    "all": "All",
    "centroid": "Centroid",
    "ggrasp": "GGRaSP",
    "meshclust_0.95": "MC-0.95",
    "meshclust_0.97": "MC-0.97",
    "meshclust_0.99": "MC-0.99",
    "gclust_0.95": "GC-0.95",
    "gclust_0.97": "GC-0.97",
    "gclust_0.99": "GC-0.99",
    "single-linkage_0.95": "SL-0.95",
    "single-linkage_0.97": "SL-0.97",
    "single-linkage_0.99": "SL-0.99",
    "complete-linkage_0.95": "CL-0.95",
    "complete-linkage_0.97": "CL-0.97",
    "complete-linkage_0.99": "CL-0.99",
}
# Root folder for all figures written by this notebook.
OUTPUT_DIR = "figures/bacteria"

ALPHA = 0.05  # significance level after multiple testing correction

# All experimental conditions (not all of NCBI)
This section produces supplementary figures S1–S3 (one containment-index heatmap per experiment).

## Helper functions

In [None]:
def calculate_CI(refset1, refset2):
    """
    Calculate the containment indices between two reference sets.

    Both arguments have to be of the same kind:
        - NumPy arrays (including subclasses): the intersection size is the sum of the
          element-wise minima of the two arrays.
        - sets: the intersection size is the number of shared elements.

    Returns a list with two entries: the intersection size divided by the size of `refset1`,
    and the intersection size divided by the size of `refset2`.

    NOTE: no special handling is done for empty reference sets, so an empty set raises a
    ZeroDivisionError.
    """
    # isinstance (rather than an exact type comparison) also routes ndarray subclasses
    # to the array branch instead of failing on the set-only `.intersection` call.
    if isinstance(refset1, np.ndarray):
        intersection = np.sum(np.minimum(refset1, refset2))
        return [intersection / np.sum(refset1), intersection / np.sum(refset2)]
    intersection = len(refset1.intersection(refset2))
    return [intersection / len(refset1), intersection / len(refset2)]

def determine_nonsingletons(experiment):
    """
    Return the species that should be included in the overlap analysis for an experiment.

    The "all" reference set of the experiment is read line by line (three tab-separated
    fields, the first being the species identifier). Every species that occurs on more than
    one line is returned as an integer, in ascending order.
    """
    tally = {}
    with open(f"root/reference_sets/{experiment}_experiments/all.tsv", "r") as handle:
        for record in handle:
            taxon, _, _ = record.strip().split("\t")
            tally[taxon] = tally.get(taxon, 0) + 1
    multi_genome = []
    for taxon, occurrences in tally.items():
        if occurrences > 1:
            multi_genome.append(int(taxon))
    multi_genome.sort()
    return multi_genome
    
def determine_content(experiment, method, nonsingletons):
    """
    This function determines the content of a reference set for a given experiment and method.
    NOTE: If a method was used with a threshold (e.g. gclust and 0.95) then the method name
    should be provided as gclust_0.95.

    Only lines whose species is listed in `nonsingletons` are taken into account.

    Returns a tuple (counts, items):
        - counts: integer array aligned with `nonsingletons`, holding the number of lines
          (genomes) found for every species
        - items: set with the accessions found for those species
    """
    path = f"root/reference_sets/{experiment}_experiments/{method}.tsv"
    # Resolve species -> position once, so every line needs a single O(1) dictionary lookup
    # instead of two linear scans of the list. setdefault keeps the first position, which
    # matches list.index should a species ever be listed twice.
    position = {}
    for idx, species in enumerate(nonsingletons):
        position.setdefault(species, idx)
    counts = np.zeros(len(nonsingletons), dtype=int)
    items = set()
    with open(path, "r") as f_in:
        for line in f_in:
            species, accession, _ = line.strip().split("\t")
            idx = position.get(int(species))
            if idx is not None:
                counts[idx] += 1
                items.add(accession)
    return counts, items

def simulate(N, n1, n2, obs_ci, num_simulations=10_000, seed=None):
    """
    Perform Monte Carlo simulation to estimate the probability of observing a certain overlap between two sets.
    NOTE: We specifically test whether the selection of n2 is independent of n1, i.e. we test the null hypothesis
    that the selection of n2 is random with respect to n1 (thus the n1 draws are fixed). This means that the p-values
    we obtain are one-sided, i.e. we only test whether the overlap is significantly greater than expected by chance.

    Parameters:
    -----------
    N: numpy.ndarray
        A 1D vector with an entry for every nonsingleton species, indicating the total number of genomes
        for that species that could be selected (i.e. are in the "all" reference set).
    n1: numpy.ndarray
        A 1D vector with an entry for every nonsingleton species, indicating the number of genomes
        selected in the first set.
    n2: numpy.ndarray
        A 1D vector with an entry for every nonsingleton species, indicating the number of genomes
        selected in the second set.
    obs_ci: float
        The observed containment index, expressed relative to the size of the second set
        (the simulated overlaps are divided by sum(n2) before they are compared to it).
    num_simulations: int
        The number of Monte Carlo simulations to perform.
    seed: int, optional
        Random seed for reproducibility.

    Returns:
    --------
    (p, se): tuple
        p is the estimated p-value, i.e. the fraction of simulations with a containment index
        of at least `obs_ci`. One is added to both numerator and denominator, so p is never 0.
        se is the binomial standard error sqrt(p * (1 - p) / num_simulations) of that estimate.
    """
    rng = np.random.default_rng(seed)
    N = np.asarray(N, dtype=int)
    n1 = np.asarray(n1, dtype=int)
    n2 = np.asarray(n2, dtype=int)

    # Simulate overlaps: per species, the overlap of n2g random picks with the n1g fixed
    # picks out of Ng genomes is hypergeometric. Summing over species gives the total overlap.
    sim_totals = np.zeros(num_simulations, dtype=int)
    for Ng, n1g, n2g in zip(N, n1, n2):
        sim_totals += rng.hypergeometric(
            ngood=int(n1g),
            nbad=int(Ng - n1g),
            nsample=int(n2g),
            size=num_simulations,
        )

    sim_ci = sim_totals / np.sum(n2)

    # Estimate the one-sided p-value and its standard error
    count = np.sum(sim_ci >= obs_ci)
    p = (1 + count) / (1 + num_simulations)
    se = sqrt(p * (1 - p) / num_simulations)
    return p, se

def BH_adjust(pvalues, alpha):
    """
    Benjamini-Hochberg correction of a square matrix of p-values.

    The diagonal is excluded from the correction (its adjusted value is fixed at 1); all
    off-diagonal entries together form the family of tests. Returns a tuple
    (significant, adj_pvalues) where `adj_pvalues` holds the adjusted p-values (capped at 1)
    and `significant` is the boolean matrix `adj_pvalues < alpha`.
    """
    size = pvalues.shape[0]
    off_diagonal = ~np.eye(size, dtype=bool)
    flat = pvalues[off_diagonal]
    num_tests = len(flat)

    # Scale the sorted p-values by (number of tests / rank) ...
    ascending = np.argsort(flat)
    scaled = flat[ascending] * num_tests / np.arange(1, num_tests + 1)
    # ... and enforce monotonicity by taking the running minimum from the largest rank down
    monotone = np.minimum.accumulate(scaled[::-1])[::-1]

    # Undo the sorting and write the result back into matrix form
    adjusted_flat = np.empty_like(flat)
    adjusted_flat[ascending] = np.minimum(monotone, 1.0)
    adj_pvalues = np.ones_like(pvalues)
    adj_pvalues[off_diagonal] = adjusted_flat
    return adj_pvalues < alpha, adj_pvalues

## Actual plotting script

In [None]:
for experiment in EXPERIMENTS: # iterate over experiments
    # Determine nonsingletons and reference sets
    ns = determine_nonsingletons(experiment)
    reference_sets = {
        method: determine_content(experiment, method, ns) for method in METHODS
    } # maps method to (genome counts per nonsingleton species, selected accessions)
    # Genomes available per species: the pool that every method selected from
    genomes_available = reference_sets["all"][0]
    # Initialize matrices to store containment indices and p-values
    containment_indices = np.ones((len(METHODS), len(METHODS)), dtype=np.float64)
    pvalues = np.ones((len(METHODS), len(METHODS)), dtype=np.float64)
    annot_matrix = np.empty_like(containment_indices, dtype=object) # annotation matrix for plotting
    # Calculate containment indices and p-values for all method pairs.
    # tqdm wraps METHODS itself (not the enumerate object) so that it knows the total.
    for i, method1 in enumerate(tqdm(METHODS)):
        n1, r1 = reference_sets[method1]
        for j, method2 in enumerate(METHODS):
            if i == j:
                continue # the diagonal keeps its initial value of 1
            n2, r2 = reference_sets[method2]
            # Entry [j, i]: fraction of the genomes of method2 (row) that method1 (column) also selected
            obs_ci = calculate_CI(r1, r2)[1]
            containment_indices[j, i] = obs_ci
            # A distinct, fixed seed per ordered pair keeps the p-values reproducible
            pvalues[j, i] = simulate(genomes_available, n1, n2, obs_ci, seed=i*10000 + j)[0]
    # Correct for multiple testing
    significant_pairs, _ = BH_adjust(pvalues, alpha=ALPHA)
    # Annotate every cell with its containment index; significant cells get a trailing "*"
    for i, row in enumerate(significant_pairs):
        for j, is_significant in enumerate(row):
            marker = "*" if is_significant else ""
            annot_matrix[i, j] = f"{containment_indices[i, j]:.2f}{marker}"
    # Start plotting
    method_names = [METHOD_LABELS[method] for method in METHODS]
    plt.figure()
    # The first column (METHODS[0], i.e. "all") is left out of the heatmap
    sns.heatmap(
        containment_indices[:, 1:],
        annot=annot_matrix[:, 1:],
        cmap="Greys",
        cbar=True,
        linewidth=0.5,
        xticklabels=method_names[1:],
        yticklabels=method_names,
        fmt="",
        annot_kws={"size": 4},
        )
    plt.title(f"{experiment.capitalize()} experiments", size=20)
    plt.tight_layout()
    # Create the output folder if needed, so a missing folder cannot discard the finished simulations
    os.makedirs(f"{OUTPUT_DIR}/{experiment}", exist_ok=True)
    plt.savefig(f"{OUTPUT_DIR}/{experiment}/containment_indices.pdf", dpi=500, bbox_inches="tight")
    plt.close()

# NCBI Bacteria

## Helper functions

In [None]:
def determine_nonsingletons_global():
    """
    This function determines the species which occur on more than one line of the global
    "all" reference set (root/reference_sets/all.tsv) and should be included in the
    overlap analysis.

    Every line is expected to hold three tab-separated fields, the first being the species
    identifier. The selected species are returned as a sorted list of integers.
    """
    from collections import Counter # local import keeps this cell self-contained

    path = "root/reference_sets/all.tsv"
    count_per_species = Counter()
    with open(path, "r") as f_in:
        for line in f_in:
            species, _, _ = line.strip().split("\t")
            count_per_species[species] += 1
    return sorted(int(species) for species, count in count_per_species.items() if count > 1)
    
def determine_content_global(method, nonsingletons):
    """
    This function determines the content of the global reference set of a method
    (root/reference_sets/<method>.tsv), restricted to the species in `nonsingletons`.
    NOTE: If a method was used with a threshold (e.g. gclust and 0.95) then the method name
    should be provided as gclust_0.95.

    Returns a tuple (counts, items):
        - counts: integer array aligned with `nonsingletons`, holding the number of lines
          (genomes) found for every species
        - items: set with the accessions found for those species
    """
    path = f"root/reference_sets/{method}.tsv"
    # Resolve species -> position once, so every line needs a single O(1) dictionary lookup
    # instead of two linear scans of the list. setdefault keeps the first position, which
    # matches list.index should a species ever be listed twice.
    position = {}
    for idx, species in enumerate(nonsingletons):
        position.setdefault(species, idx)
    counts = np.zeros(len(nonsingletons), dtype=int)
    items = set()
    with open(path, "r") as f_in:
        for line in f_in:
            species, accession, _ = line.strip().split("\t")
            idx = position.get(int(species))
            if idx is not None:
                counts[idx] += 1
                items.add(accession)
    return counts, items

## Actual plotting

In [None]:
# Reference sets of all methods, restricted to the nonsingleton species of the global "all" set
ns = determine_nonsingletons_global()
reference_sets = dict(
    (method, determine_content_global(method, ns)) for method in METHODS
)
num_methods = len(METHODS)
containment_indices = np.ones((num_methods, num_methods), dtype=np.float64)
pvalues = np.ones((num_methods, num_methods), dtype=np.float64)
annot_matrix = np.empty_like(containment_indices, dtype=object)
# Containment index and Monte Carlo p-value for every ordered pair of distinct methods;
# the diagonal keeps its initial value of 1
for i, method1 in enumerate(METHODS):
    for j, method2 in enumerate(METHODS):
        if i == j:
            continue
        n1, r1 = reference_sets[method1]
        n2, r2 = reference_sets[method2]
        obs_ci = calculate_CI(r1, r2)[1]
        containment_indices[j, i] = obs_ci
        estimate = simulate(reference_sets["all"][0], n1, n2, obs_ci, seed=i*10000 + j)
        pvalues[j, i] = estimate[0]
# Correct for multiple testing
significant_pairs, _ = BH_adjust(pvalues, alpha=ALPHA)
# Cell annotations: containment index with two decimals, followed by "*" when significant
for i in range(num_methods):
    for j in range(num_methods):
        text = f"{containment_indices[i,j]:.2f}"
        if significant_pairs[i, j]:
            text = text + "*"
        annot_matrix[i, j] = text
method_names = [METHOD_LABELS[method] for method in METHODS]
# The first column (METHODS[0]) is not shown
heatmap_options = {
    "annot": annot_matrix[:, 1:],
    "cmap": "Greys",
    "cbar": True,
    "linewidth": 0.5,
    "xticklabels": method_names[1:],
    "yticklabels": method_names,
    "fmt": "",
    "annot_kws": {"size": 4},
}
sns.heatmap(containment_indices[:, 1:], **heatmap_options)
plt.title("NCBI Bacteria", size=20)
plt.tight_layout()
plt.savefig(f"{OUTPUT_DIR}/containment_indices.pdf", dpi=500, bbox_inches="tight", format="pdf")

# Core genomes figure
This section produces the panels of supplementary figure S7.

## All experimental conditions (not all of NCBI)

### Helper function

In [None]:
def determine_content_coreset(experiment, method):
    """
    Read the reference set of a method for a given experiment, grouped per species.

    Every line holds three tab-separated fields: species, accession and a selection flag.
    Returns a tuple (genomes per species, successful species):
        - a dictionary mapping every species (as string) to the set of its accessions
        - the set of species that have at least one line flagged with "+"; species without
          such a flag are those where the selection method failed and genomes were picked
          randomly instead
    """
    accessions_by_species = {}
    species_with_success = set()
    with open(f"root/reference_sets/{experiment}_experiments/{method}.tsv", "r") as handle:
        for row in handle:
            taxon, genome, flag = row.strip().split("\t")
            accessions_by_species.setdefault(taxon, set()).add(genome)
            if flag == "+":
                species_with_success.add(taxon)
    return accessions_by_species, species_with_success

### Actual plotting

In [None]:
"""
The idea here is to generate a figure (one per experiment) that shows how many species (percentually) have a genome in the core set. The x-axis shows the
cut-off (i.e. the minimum number of genomes in "all") and the y-axis shows the percentage of species with at least one genome in the core set.
This requires:
    - Determining the which genomes are selected for which species for all methods
    - For each cut-off, determine the number of species that have at least one genome in the core set
        + this can be achieved using set intersections
"""
# NOTE(review): the description above speaks of percentages, but the code below plots
# absolute numbers of species (first figure) and of genomes (second figure).
# Two figures are written per experiment: coreset_species.pdf and coreset_genomes.pdf.
for experiment in EXPERIMENTS:
    reference_sets = {
        method: determine_content_coreset(experiment, method) for method in METHODS
    } #maps method to (genomes per species, successful species)
    max_species = 0 #max for the x-axis -> max number of genomes in "all" for any species
    # Coresets per species:
    #   coresets_per_species          -> genomes shared by all methods in METHODS[2:] that succeeded
    #   coresets_per_species_centroid -> the same, additionally intersected with the centroid selection
    coresets_per_species = {}
    coresets_per_species_centroid = {}
    for s in reference_sets["all"][0]:
        max_species = max(max_species, len(reference_sets["all"][0][s]))
        # Start from all genomes of the species and narrow down method by method
        coresets_per_species[s] = reference_sets["all"][0][s].copy()
        coresets_per_species_centroid[s] = reference_sets["all"][0][s].copy()
        # assumes every species of "all" is also present in the centroid file (KeyError otherwise)
        coresets_per_species_centroid[s] = coresets_per_species_centroid[s].intersection(reference_sets["centroid"][0][s])
        for method in METHODS[2:]: #skip 'all' and 'centroid'
            if s in reference_sets[method][1]: #only consider species where method succeeded
                coresets_per_species[s] = coresets_per_species[s].intersection(reference_sets[method][0][s]) #contains the genomes shared by every method for every species
                coresets_per_species_centroid[s] = coresets_per_species_centroid[s].intersection(reference_sets[method][0][s])
    # Per-cutoff series; entry k of every list belongs to cutoff lb + k
    total_per_cutoff = []              # number of species with at least `cutoff` genomes in "all"
    total_genomes_per_cutoff = []      # number of genomes in "all" of those species
    num_per_cutoff = []                # species with a non-empty coreset (excl. centroid)
    num_per_cutoff_centroid = []       # species with a non-empty coreset (incl. centroid)
    coreset_per_cutoff = []            # genomes in the coresets (excl. centroid)
    coreset_per_cutoff_centroid = []   # genomes in the coresets (incl. centroid)

    lb = 1
    for cutoff in range(lb, max_species+1):
        # Species that have at least `cutoff` genomes in "all"
        cur_species = set()
        cur_total = 0
        for s in reference_sets["all"][0]:
            if len(reference_sets["all"][0][s]) >= cutoff:
                cur_species.add(s)
                cur_total += len(reference_sets["all"][0][s])
        total_per_cutoff.append(len(cur_species))
        total_genomes_per_cutoff.append(cur_total)
        # Count the species with a non-empty coreset and the genomes in those coresets
        num_species_in_coreset = 0
        num_species_in_coreset_centroid = 0
        num_genomes_in_coreset = 0
        num_genomes_in_coreset_centroid = 0
        for species in cur_species:
            if len(coresets_per_species[species]) > 0:
                num_species_in_coreset += 1
                num_genomes_in_coreset += len(coresets_per_species[species])
            if len(coresets_per_species_centroid[species]) > 0:
                num_species_in_coreset_centroid += 1
                num_genomes_in_coreset_centroid += len(coresets_per_species_centroid[species])
        num_per_cutoff.append(num_species_in_coreset)
        num_per_cutoff_centroid.append(num_species_in_coreset_centroid)
        coreset_per_cutoff.append(num_genomes_in_coreset)
        coreset_per_cutoff_centroid.append(num_genomes_in_coreset_centroid)

    # Figure 1: number of species per cutoff
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(range(lb, max_species+1), total_per_cutoff, color="green", linewidth=1, label="Total number of species")
    ax.plot(range(lb, max_species+1), num_per_cutoff, color="blue", alpha=0.5, label="Number of species in coreset (excl. centroid)")
    ax.plot(range(lb, max_species+1), num_per_cutoff_centroid, color="orange", alpha=0.5, label="Number of species in coreset (incl. centroid)")
    ax.set_xlabel("Minimal number of genomes per species")
    ax.set_ylabel("Number of species")
    ax.set_xlim(lb, max_species)
    ax.set_xscale("log", base=10)
    ax.set_ylim(bottom=0)
    ax.set_title(f"Experiment: {experiment.capitalize()}")
    ax.legend()
    # Add annotation for sets with >= 2 genomes
    # Index 1 of the per-cutoff lists is the value at cutoff 2 (because lb = 1); this
    # raises an IndexError when no species has at least two genomes.
    ax.vlines(2, ymin=0, ymax=num_per_cutoff_centroid[1], color="orange", linestyles=":", linewidth=1)
    ax.vlines(2, ymin=num_per_cutoff_centroid[1], ymax=num_per_cutoff[1], color="blue", linestyles=":", linewidth=1)
    ax.vlines(2, ymin=num_per_cutoff[1], ymax=total_per_cutoff[1], color="green", linestyles=":", linewidth=1)
    ax.hlines(num_per_cutoff_centroid[1], xmin=lb, xmax=2, color="orange", linestyles=":", linewidth=1)
    ax.hlines(num_per_cutoff[1], xmin=lb, xmax=2, color="blue", linestyles=":", linewidth=1)
    ax.hlines(total_per_cutoff[1], xmin=lb, xmax=2, color="green", linestyles=":", linewidth=1)
    # NOTE(review): xticks is computed but never applied (there is no ax.set_xticks call) - confirm intent
    xticks = list(ax.get_xticks()) + [2]
    # Add the three annotated values as extra y-ticks
    yticks = list(ax.get_yticks()) + [num_per_cutoff_centroid[1]] +  [num_per_cutoff[1]] + [total_per_cutoff[1]]
    ax.set_yticks(sorted(set(yticks)))
    # Color
    # Tick labels whose text equals one of the annotated values get the color of the
    # corresponding curve and are shifted slightly to the left
    for label in ax.get_yticklabels():
        if label.get_text() == f"{num_per_cutoff_centroid[1]}":
            label.set_color("orange")
            label.set_x(label.get_position()[0] - 0.05)
        elif label.get_text() == f"{num_per_cutoff[1]}":
            label.set_color("blue")
            label.set_x(label.get_position()[0] - 0.05)
        elif label.get_text() == f"{total_per_cutoff[1]}":
            label.set_color("green")
            label.set_x(label.get_position()[0] - 0.05)
    fig.savefig(f"{OUTPUT_DIR}/{experiment}/coreset_species.pdf", dpi=500, bbox_inches="tight")

    # Figure 2: number of genomes in the coresets per cutoff
    fig, ax = plt.subplots(figsize=(8, 6))
    #ax.plot(range(lb, max_species+1), total_genomes_per_cutoff, color="black", linewidth=1, label="Total number of genomes")
    ax.plot(range(lb, max_species+1), coreset_per_cutoff, color="blue", alpha=0.5, label="Number of genomes in coreset (excl. centroid)")
    ax.plot(range(lb, max_species+1), coreset_per_cutoff_centroid, color="orange", alpha=0.5, label="Number of genomes in coreset (incl. centroid)")
    ax.set_xlabel("Minimal number of genomes per species")
    ax.set_ylabel("Number of genomes")
    ax.set_xlim(lb, max_species)
    ax.set_xscale("log", base=10)
    ax.set_ylim(bottom=0)
    ax.set_title(f"Experiment: {experiment.capitalize()}")
    ax.legend()
    # Add annotation for sets with >= 2 genomes
    ax.vlines(2, ymin=0, ymax=coreset_per_cutoff_centroid[1], color="orange", linestyles=":", linewidth=1)
    ax.vlines(2, ymin=coreset_per_cutoff_centroid[1], ymax=coreset_per_cutoff[1], color="blue", linestyles=":", linewidth=1)
    ax.hlines(coreset_per_cutoff_centroid[1], xmin=lb, xmax=2, color="orange", linestyles=":", linewidth=1)
    ax.hlines(coreset_per_cutoff[1], xmin=lb, xmax=2, color="blue", linestyles=":", linewidth=1)
    # NOTE(review): as above, xticks is computed but never applied - confirm intent
    xticks = list(ax.get_xticks()) + [2]
    yticks = list(ax.get_yticks()) + [coreset_per_cutoff_centroid[1]] +  [coreset_per_cutoff[1]]
    ax.set_yticks(sorted(set(yticks)))
    # Color
    for label in ax.get_yticklabels():
        if label.get_text() == f"{coreset_per_cutoff_centroid[1]}":
            label.set_color("orange")
            label.set_x(label.get_position()[0] - 0.05)
        elif label.get_text() == f"{coreset_per_cutoff[1]}":
            label.set_color("blue")
            label.set_x(label.get_position()[0] - 0.05)
    fig.savefig(f"{OUTPUT_DIR}/{experiment}/coreset_genomes.pdf", dpi=500, bbox_inches="tight")

## NCBI Bacteria

### Helper function

In [None]:
def determine_content_coreset_global(method):
    """
    Read the global reference set of a method (root/reference_sets/<method>.tsv), grouped
    per species.

    Every line holds three tab-separated fields: species, accession and a selection flag.
    Returns a tuple (genomes per species, successful species):
        - a dictionary mapping every species (as string) to the set of its accessions
        - the set of species that have at least one line flagged with "+"; species without
          such a flag are those where the selection method failed and genomes were picked
          randomly instead
    """
    per_species = {}
    succeeded = set()
    with open(f"root/reference_sets/{method}.tsv", "r") as tsv:
        for entry in tsv:
            taxon, genome, status = entry.strip().split("\t")
            known = per_species.get(taxon)
            if known is None:
                known = per_species[taxon] = set()
            known.add(genome)
            if status == "+":
                succeeded.add(taxon)
    return per_species, succeeded

### Actual plotting

In [None]:
# Core-set figures for the NCBI Bacteria reference sets (not split per experiment).
# Two figures are generated. The x-axis of both shows the cut-off (i.e. the minimum number of
# genomes a species has in "all"). The y-axis shows, among the species that reach the cut-off,
#     - coreset_species.pdf: the number of species with at least one genome in the core set
#     - coreset_genomes.pdf: the number of genomes in those core sets
# This requires:
#     - Determining which genomes are selected for which species for all methods
#     - For each cut-off, determining the number of species that have at least one genome in the core set
#         + this is achieved using set intersections
reference_sets = {
    method: determine_content_coreset_global(method) for method in METHODS
} #maps method to (genomes per species, successful species)
max_species = 0 #max for the x-axis -> max number of genomes in "all" for any species
# Coresets per species:
#   coresets_per_species          -> genomes shared by all methods in METHODS[2:] that succeeded
#   coresets_per_species_centroid -> the same, additionally intersected with the centroid selection
coresets_per_species = {}
coresets_per_species_centroid = {}
for s in reference_sets["all"][0]:
    max_species = max(max_species, len(reference_sets["all"][0][s]))
    # Start from all genomes of the species and narrow down method by method
    coresets_per_species[s] = reference_sets["all"][0][s].copy()
    coresets_per_species_centroid[s] = reference_sets["all"][0][s].copy()
    # assumes every species of "all" is also present in the centroid file (KeyError otherwise)
    coresets_per_species_centroid[s] = coresets_per_species_centroid[s].intersection(reference_sets["centroid"][0][s])
    for method in METHODS[2:]: #skip 'all' and 'centroid'
        if s in reference_sets[method][1]: #only consider species where method succeeded
            coresets_per_species[s] = coresets_per_species[s].intersection(reference_sets[method][0][s]) #contains the genomes shared by every method for every species
            coresets_per_species_centroid[s] = coresets_per_species_centroid[s].intersection(reference_sets[method][0][s])
# Per-cutoff series; entry k of every list belongs to cutoff lb + k
total_per_cutoff = []              # number of species with at least `cutoff` genomes in "all"
total_genomes_per_cutoff = []      # number of genomes in "all" of those species
num_per_cutoff = []                # species with a non-empty coreset (excl. centroid)
num_per_cutoff_centroid = []       # species with a non-empty coreset (incl. centroid)
coreset_per_cutoff = []            # genomes in the coresets (excl. centroid)
coreset_per_cutoff_centroid = []   # genomes in the coresets (incl. centroid)

lb = 1
for cutoff in range(lb, max_species+1):
    # Species that have at least `cutoff` genomes in "all"
    cur_species = set()
    cur_total = 0
    for s in reference_sets["all"][0]:
        if len(reference_sets["all"][0][s]) >= cutoff:
            cur_species.add(s)
            cur_total += len(reference_sets["all"][0][s])
    total_per_cutoff.append(len(cur_species))
    total_genomes_per_cutoff.append(cur_total)
    # Count the species with a non-empty coreset and the genomes in those coresets
    num_species_in_coreset = 0
    num_species_in_coreset_centroid = 0
    num_genomes_in_coreset = 0
    num_genomes_in_coreset_centroid = 0
    for species in cur_species:
        if len(coresets_per_species[species]) > 0:
            num_species_in_coreset += 1
            num_genomes_in_coreset += len(coresets_per_species[species])
        if len(coresets_per_species_centroid[species]) > 0:
            num_species_in_coreset_centroid += 1
            num_genomes_in_coreset_centroid += len(coresets_per_species_centroid[species])
    num_per_cutoff.append(num_species_in_coreset)
    num_per_cutoff_centroid.append(num_species_in_coreset_centroid)
    coreset_per_cutoff.append(num_genomes_in_coreset)
    coreset_per_cutoff_centroid.append(num_genomes_in_coreset_centroid)

# Figure 1: number of species per cutoff
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(range(lb, max_species+1), total_per_cutoff, color="green", linewidth=1, label="Total number of species")
ax.plot(range(lb, max_species+1), num_per_cutoff, color="blue", alpha=0.5, label="Number of species in coreset (excl. centroid)")
ax.plot(range(lb, max_species+1), num_per_cutoff_centroid, color="orange", alpha=0.5, label="Number of species in coreset (incl. centroid)")
ax.set_xlabel("Minimal number of genomes per species")
ax.set_ylabel("Number of species")
ax.set_xlim(lb, max_species)
ax.set_xscale("log", base=10)
ax.set_ylim(bottom=0)
ax.set_title("NCBI Bacteria")
ax.legend()
# Add annotation for sets with >= 2 genomes
# Index 1 of the per-cutoff lists is the value at cutoff 2 (because lb = 1); this
# raises an IndexError when no species has at least two genomes.
ax.vlines(2, ymin=0, ymax=num_per_cutoff_centroid[1], color="orange", linestyles=":", linewidth=1)
ax.vlines(2, ymin=num_per_cutoff_centroid[1], ymax=num_per_cutoff[1], color="blue", linestyles=":", linewidth=1)
ax.vlines(2, ymin=num_per_cutoff[1], ymax=total_per_cutoff[1], color="green", linestyles=":", linewidth=1)
ax.hlines(num_per_cutoff_centroid[1], xmin=lb, xmax=2, color="orange", linestyles=":", linewidth=1)
ax.hlines(num_per_cutoff[1], xmin=lb, xmax=2, color="blue", linestyles=":", linewidth=1)
ax.hlines(total_per_cutoff[1], xmin=lb, xmax=2, color="green", linestyles=":", linewidth=1)
# NOTE(review): xticks is computed but never applied (there is no ax.set_xticks call) - confirm intent
xticks = list(ax.get_xticks()) + [2]
# Add the three annotated values as extra y-ticks
yticks = list(ax.get_yticks()) + [num_per_cutoff_centroid[1]] +  [num_per_cutoff[1]] + [total_per_cutoff[1]]
ax.set_yticks(sorted(set(yticks)))
# Color: tick labels whose text equals one of the annotated values get the color of the
# corresponding curve and are shifted slightly to the left
for label in ax.get_yticklabels():
    if label.get_text() == f"{num_per_cutoff_centroid[1]}":
        label.set_color("orange")
        label.set_x(label.get_position()[0] - 0.05)
    elif label.get_text() == f"{num_per_cutoff[1]}":
        label.set_color("blue")
        label.set_x(label.get_position()[0] - 0.05)
    elif label.get_text() == f"{total_per_cutoff[1]}":
        label.set_color("green")
        label.set_x(label.get_position()[0] - 0.05)
fig.savefig(f"{OUTPUT_DIR}/coreset_species.pdf", dpi=500, bbox_inches="tight")

# Figure 2: number of genomes in the coresets per cutoff
fig, ax = plt.subplots(figsize=(8, 6))
#ax.plot(range(lb, max_species+1), total_genomes_per_cutoff, color="black", linewidth=1, label="Total number of genomes")
ax.plot(range(lb, max_species+1), coreset_per_cutoff, color="blue", alpha=0.5, label="Number of genomes in coreset (excl. centroid)")
ax.plot(range(lb, max_species+1), coreset_per_cutoff_centroid, color="orange", alpha=0.5, label="Number of genomes in coreset (incl. centroid)")
ax.set_xlabel("Minimal number of genomes per species")
ax.set_ylabel("Number of genomes")
ax.set_xlim(lb, max_species)
ax.set_xscale("log", base=10)
ax.set_ylim(bottom=0)
ax.set_title("NCBI Bacteria")
ax.legend()
# Add annotation for sets with >= 2 genomes
ax.vlines(2, ymin=0, ymax=coreset_per_cutoff_centroid[1], color="orange", linestyles=":", linewidth=1)
ax.vlines(2, ymin=coreset_per_cutoff_centroid[1], ymax=coreset_per_cutoff[1], color="blue", linestyles=":", linewidth=1)
ax.hlines(coreset_per_cutoff_centroid[1], xmin=lb, xmax=2, color="orange", linestyles=":", linewidth=1)
ax.hlines(coreset_per_cutoff[1], xmin=lb, xmax=2, color="blue", linestyles=":", linewidth=1)
# NOTE(review): as above, xticks is computed but never applied - confirm intent
xticks = list(ax.get_xticks()) + [2]
yticks = list(ax.get_yticks()) + [coreset_per_cutoff_centroid[1]] +  [coreset_per_cutoff[1]]
ax.set_yticks(sorted(set(yticks)))
# Color
for label in ax.get_yticklabels():
    if label.get_text() == f"{coreset_per_cutoff_centroid[1]}":
        label.set_color("orange")
        label.set_x(label.get_position()[0] - 0.05)
    elif label.get_text() == f"{coreset_per_cutoff[1]}":
        label.set_color("blue")
        label.set_x(label.get_position()[0] - 0.05)
# Saved next to coreset_species.pdf. The path must not contain `experiment`: that name is
# only a leftover of the per-experiment loops and would redirect this figure into the
# folder of the last experiment, overwriting its coreset_genomes.pdf.
fig.savefig(f"{OUTPUT_DIR}/coreset_genomes.pdf", dpi=500, bbox_inches="tight")