In [None]:
import os
from Bio import SeqIO
import matplotlib.pyplot as plt
import numpy as np
import math
from tqdm import tqdm
from multiprocessing import Pool
import seaborn as sns

# Code for generating the containment index figures
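**NOTE:** File paths have been anonymized: every `path_prefix/...` path below is a placeholder that must be replaced with the actual location of the data before running the notebook.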

## Code for the experimental settings (order, family and genus experiments; a subset of the Bacteria in NCBI)

In [None]:
# Names of the experiments; each one is used to build the input and output paths.
EXPERIMENTS = ["order", "family", "genus"]

# Identifiers of the reference sets that are compared, in plotting order.
# Each identifier is also the file-name stem of the corresponding reference set.
METHODS = [
    "all", "centroid", "ggrasp",
    "meshclust-0.95", "meshclust-0.97", "meshclust-0.99",
    "gclust-0.95", "gclust-0.97", "gclust-0.99",
    "single-linkage-0.95", "single-linkage-0.97", "single-linkage-0.99",
    "complete-linkage-0.95", "complete-linkage-0.97", "complete-linkage-0.99",
]

# Short tick labels used on the heatmap axes, keyed by method identifier.
METHOD_LABELS = {
    # baselines
    "all": "All",
    "centroid": "Centroid",
    "ggrasp": "GGRaSP",
    # meshclust
    "meshclust-0.95": "MC-0.95",
    "meshclust-0.97": "MC-0.97",
    "meshclust-0.99": "MC-0.99",
    # gclust
    "gclust-0.95": "GC-0.95",
    "gclust-0.97": "GC-0.97",
    "gclust-0.99": "GC-0.99",
    # single linkage
    "single-linkage-0.95": "SL-0.95",
    "single-linkage-0.97": "SL-0.97",
    "single-linkage-0.99": "SL-0.99",
    # complete linkage
    "complete-linkage-0.95": "CL-0.95",
    "complete-linkage-0.97": "CL-0.97",
    "complete-linkage-0.99": "CL-0.99",
}

# Significance threshold applied to the multiple-testing-adjusted p-values.
ALPHA = 0.05

In [None]:
import numpy as np
from math import sqrt

def calculate_CI(refset1, refset2):
    """
    Compute the containment index between two reference sets in both directions.

    Parameters:
    -----------
    refset1, refset2: numpy.ndarray or set
        Either two count vectors of equal length (the overlap is then the sum of the
        element-wise minimum, and the sizes are the sums of the vectors), or two sets
        (the overlap is the size of the set intersection). The branch is chosen from
        the type of `refset1` only.

    Returns:
    --------
    list
        `[overlap / size(refset1), overlap / size(refset2)]`.

    NOTE: an empty set raises ZeroDivisionError; an all-zero count vector yields
    nan/inf through numpy division.
    """
    # isinstance (rather than an exact type comparison) so that ndarray subclasses
    # take the vector branch instead of failing on the set-only `.intersection` call.
    if isinstance(refset1, np.ndarray):
        intersection = np.sum(np.minimum(refset1, refset2))
        return [intersection / np.sum(refset1), intersection / np.sum(refset2)]

    intersection = len(refset1.intersection(refset2))
    return [intersection / len(refset1), intersection / len(refset2)]

def calculate_jaccard(refset1, refset2):
    """
    Compute the Jaccard index of two count vectors.

    The overlap is the sum of the element-wise minimum and the union is the sum of the
    element-wise maximum of `refset1` and `refset2`; the ratio overlap / union is returned.
    """
    shared = np.minimum(refset1, refset2)
    combined = np.maximum(refset1, refset2)
    return np.sum(shared) / np.sum(combined)

def determine_nonsingletons(experiment):
    """
    List the species that have more than one genome in the "all" reference set.

    The file `all_corrected.txt` of the given experiment is read as tab-separated lines
    with exactly four columns (kingdom, species, accession, and one ignored column).
    Only lines whose kingdom column equals "2" are counted
    (presumably the NCBI taxonomy ID of Bacteria - TODO confirm).

    Returns the species identifiers (converted to int) that occur on more than one
    such line, sorted in ascending order.
    """
    genomes_per_species = {}
    listing = f"path_prefix/{experiment}_experiments/reference_sets_full/all_corrected.txt"
    with open(listing, "r") as handle:
        for record in handle:
            kingdom, species, accession, _ = record.strip().split("\t")
            if kingdom != "2":
                continue
            genomes_per_species[species] = genomes_per_species.get(species, 0) + 1

    return sorted(
        int(species) for species, total in genomes_per_species.items() if total > 1
    )
    
def determine_content(experiment, method, nonsingletons):
    """
    Summarise which genomes of the nonsingleton species a reference set contains.

    Parameters:
    -----------
    experiment: str
        Name of the experiment; used to build the input path.
    method: str
        Name of the selection method; `{method}_corrected.txt` is read.
    nonsingletons: sequence of int
        Species identifiers to consider. The position of a species in this sequence
        is its index in the returned count vector.

    Returns:
    --------
    counts: numpy.ndarray
        Integer vector with, for every entry of `nonsingletons`, the number of lines
        in the file that belong to that species.
    items: set
        Accessions (third column) of all lines belonging to a nonsingleton species.
    """
    path = f"path_prefix/{experiment}_experiments/reference_sets_full/{method}_corrected.txt"

    # Build the species -> position lookup once, so every input line costs a single
    # dict lookup instead of two linear scans of the list. setdefault keeps the first
    # occurrence, matching what list.index() would return.
    position_of = {}
    for position, species in enumerate(nonsingletons):
        position_of.setdefault(species, position)

    counts = np.zeros(len(nonsingletons), dtype=int)
    items = set()
    with open(path, "r") as f_in:
        for line in f_in:
            # Lines must have exactly four tab-separated columns; only the species
            # and the accession are used here.
            _kingdom, species, accession, _ = line.strip().split("\t")
            idx = position_of.get(int(species))
            if idx is not None:
                counts[idx] += 1
                items.add(accession)
    return counts, items

def simulate(N, n1, n2, obs_ci, num_simulations=10_000, seed=None):
    """
    Perform Monte Carlo simulation to estimate the probability of observing a certain overlap between two sets.
    NOTE: We specifically test whether the selection of n2 is independent of n1, i.e. we test the null hypothesis
    that the selection of n2 is random with respect to n1 (thus the n1 draws are fixed). This means that the p-values
    we obtain are one-sided, i.e. we only test whether the overlap is significantly greater than expected by chance.

    Parameters:
    -----------
    N: numpy.ndarray
        A 1D vector with an entry for every nonsingleton species, indicating the total number of genomes
        for that species that could be selected (i.e. are in the "all" reference set).
    n1: numpy.ndarray
        A 1D vector with an entry for every nonsingleton species, indicating the number of genomes
        selected in the first set.
    n2: numpy.ndarray
        A 1D vector with an entry for every nonsingleton species, indicating the number of genomes
        selected in the second set.
    obs_ci: float
        The observed containment index (overlap divided by the total size of the second set)
        that the simulated values are compared against.
    num_simulations: int
        The number of Monte Carlo simulations to perform.
    seed: int, optional
        Random seed for reproducibility.

    Returns:
    --------
    p: float
        Monte Carlo p-value, (1 + #{simulated CI >= obs_ci}) / (1 + num_simulations).
    se: float
        Binomial standard error of the estimate, sqrt(p * (1 - p) / num_simulations).

    Raises:
    -------
    ValueError
        If N, n1 and n2 do not all have the same shape.
    """
    rng = np.random.default_rng(seed)
    N = np.asarray(N, dtype=int)
    n1 = np.asarray(n1, dtype=int)
    n2 = np.asarray(n2, dtype=int)

    # zip() below would silently drop trailing species on a length mismatch and
    # produce a p-value for the wrong problem, so fail loudly instead.
    if not (N.shape == n1.shape == n2.shape):
        raise ValueError(
            f"N, n1 and n2 must have the same shape, got {N.shape}, {n1.shape} and {n2.shape}"
        )

    # Simulate overlaps: per species, the n1 selected genomes are the "good" items and
    # n2 genomes are drawn without replacement from all N genomes of that species.
    sim_totals = np.zeros(num_simulations, dtype=int)
    for Ng, n1g, n2g in zip(N, n1, n2):
        draws = rng.hypergeometric(
            ngood=int(n1g),
            nbad=int(Ng - n1g),
            nsample=int(n2g),
            size=num_simulations
        )
        sim_totals += draws

    # Express the simulated overlaps on the same scale as obs_ci.
    sim_ci = sim_totals / np.sum(n2)

    # Estimate the one-sided p-value; the +1 terms keep it strictly positive.
    count = np.sum(sim_ci >= obs_ci)
    p = (1 + count) / (1 + num_simulations)
    se = sqrt(p * (1 - p) / num_simulations)
    return p, se

def BH_adjust(pvalues, alpha):
    """
    Benjamini-Hochberg adjustment of the off-diagonal entries of a square p-value matrix.

    The diagonal is excluded from the correction; its adjusted value is set to 1.
    The original author's note states that the adjustment mirrors
    statsmodels.stats.multitest.multipletests with method='fdr_bh'. Note that
    significance is decided here with a strict `adjusted < alpha` comparison.

    Parameters:
    -----------
    pvalues: numpy.ndarray
        Square matrix of p-values.
    alpha: float
        Threshold applied to the adjusted p-values.

    Returns:
    --------
    significant: numpy.ndarray
        Boolean matrix, True where the adjusted p-value is below `alpha`.
    adj_pvalues: numpy.ndarray
        Matrix of adjusted p-values with the same shape as `pvalues`.
    """
    n_rows = pvalues.shape[0]
    off_diagonal = ~np.eye(n_rows, dtype=bool)
    tested = pvalues[off_diagonal]
    n_tests = len(tested)

    # Scale the i-th smallest p-value by n_tests / i ...
    ascending = np.argsort(tested)
    scaled = tested[ascending] * n_tests / np.arange(1, n_tests + 1)
    # ... and enforce monotonicity with a running minimum taken from the largest rank down.
    monotone = np.minimum.accumulate(scaled[::-1])[::-1]

    # Undo the sort and cap the adjusted values at 1.
    adjusted = np.empty_like(tested)
    adjusted[ascending] = np.minimum(monotone, 1.0)

    adj_pvalues = np.ones_like(pvalues)
    adj_pvalues[off_diagonal] = adjusted
    return adj_pvalues < alpha, adj_pvalues

In [None]:
for experiment in EXPERIMENTS:
    # Genome counts per nonsingleton species and accession sets for every method.
    ns = determine_nonsingletons(experiment)
    reference_sets = {}
    for method in METHODS:
        reference_sets[method] = determine_content(experiment, method, ns)

    matrix_shape = (len(METHODS), len(METHODS))
    containment_indices = np.ones(matrix_shape, dtype=np.float64)
    pvalues = np.ones(matrix_shape, dtype=np.float64)
    annot_matrix = np.empty_like(containment_indices, dtype=object)

    # Entry [j, i] holds |set_i & set_j| / |set_j|: the fraction of the row method's
    # reference set that is also present in the column method's reference set.
    # The diagonal keeps its initial value of 1 (containment index and p-value).
    for i, method1 in enumerate(METHODS):
        for j, method2 in enumerate(METHODS):
            if i == j:
                continue
            n1, r1 = reference_sets[method1]
            n2, r2 = reference_sets[method2]
            obs_ci = calculate_CI(r1, r2)[1]
            containment_indices[j, i] = obs_ci
            # A distinct, fixed seed per ordered pair keeps the run reproducible.
            pvalues[j,i] = simulate(reference_sets["all"][0], n1, n2, obs_ci, seed=i*10000 + j)[0]

    # Correct for multiple testing and star the entries that remain significant.
    significant_pairs, _ = BH_adjust(pvalues, alpha=ALPHA)
    for i in range(len(METHODS)):
        for j in range(len(METHODS)):
            annot_matrix[i,j] = (
                f"{containment_indices[i,j]:.2f}*"
                if significant_pairs[i,j]
                else f"{containment_indices[i,j]:.2f}"
            )

    method_names = [METHOD_LABELS[method] for method in METHODS]
    fig = plt.figure()
    # The first column (METHODS[0], i.e. "all") is left out of the plot.
    sns.heatmap(
        containment_indices[:, 1:],
        annot=annot_matrix[:, 1:],
        cmap="Greys",
        cbar=True,
        linewidth=0.5,
        xticklabels=method_names[1:],
        yticklabels=method_names,
        fmt="",
        annot_kws={"size": 4},
    )
    plt.title(f"{experiment.capitalize()} experiments", size=20)
    plt.tight_layout()
    plt.savefig(f"path_prefix/{experiment}/{experiment}_containment_indices.pdf", dpi=500, bbox_inches="tight")


## Code for all bacterial species in NCBI

In [None]:
import numpy as np
from math import sqrt


def determine_nonsingletons_global():
    """
    List the species that have more than one genome in the global "all" reference set.

    Reads `path_prefix/all_corrected.txt` as tab-separated lines with exactly four
    columns (kingdom, species, accession, and one ignored column) and counts, per
    species, the lines whose kingdom column equals "2"
    (presumably the NCBI taxonomy ID of Bacteria - TODO confirm).

    Returns the species identifiers (converted to int) counted more than once,
    sorted in ascending order.
    """
    tally = {}
    source = f"path_prefix/all_corrected.txt"
    with open(source, "r") as handle:
        for row in handle:
            kingdom, species, accession, _ = row.strip().split("\t")
            if kingdom == "2":
                tally[species] = tally.get(species, 0) + 1

    repeated = [int(species) for species, occurrences in tally.items() if occurrences > 1]
    repeated.sort()
    return repeated
    
def determine_content_global(method, nonsingletons):
    """
    Summarise which genomes of the nonsingleton species a global reference set contains.

    Parameters:
    -----------
    method: str
        Name of the selection method; `path_prefix/{method}_corrected.txt` is read.
    nonsingletons: sequence of int
        Species identifiers to consider. The position of a species in this sequence
        is its index in the returned count vector.

    Returns:
    --------
    counts: numpy.ndarray
        Integer vector with, for every entry of `nonsingletons`, the number of lines
        in the file that belong to that species.
    items: set
        Accessions (third column) of all lines belonging to a nonsingleton species.
    """
    path = f"path_prefix/{method}_corrected.txt"

    # Build the species -> position lookup once, so every input line costs a single
    # dict lookup instead of two linear scans of the list. setdefault keeps the first
    # occurrence, matching what list.index() would return.
    position_of = {}
    for position, species in enumerate(nonsingletons):
        position_of.setdefault(species, position)

    counts = np.zeros(len(nonsingletons), dtype=int)
    items = set()
    with open(path, "r") as f_in:
        for line in f_in:
            # Lines must have exactly four tab-separated columns; only the species
            # and the accession are used here.
            _kingdom, species, accession, _ = line.strip().split("\t")
            idx = position_of.get(int(species))
            if idx is not None:
                counts[idx] += 1
                items.add(accession)
    return counts, items


In [None]:
# Genome counts per nonsingleton species and accession sets for every method.
ns = determine_nonsingletons_global()
reference_sets = {
    method: determine_content_global(method, ns) for method in METHODS
}
containment_indices = np.ones((len(METHODS), len(METHODS)), dtype=np.float64)
pvalues = np.ones((len(METHODS), len(METHODS)), dtype=np.float64)
annot_matrix = np.empty_like(containment_indices, dtype=object)

# Entry [j, i] holds |set_i & set_j| / |set_j|: the fraction of the row method's
# reference set that is also present in the column method's reference set.
# The diagonal keeps its initial value of 1 (containment index and p-value).
# The annotations are filled in only after the multiple-testing correction below.
for i, method1 in enumerate(METHODS):
    for j, method2 in enumerate(METHODS):
        if i != j:
            n1, r1 = reference_sets[method1]
            n2, r2 = reference_sets[method2]
            obs_ci = calculate_CI(r1, r2)[1]
            containment_indices[j, i] = obs_ci
            # A distinct, fixed seed per ordered pair keeps the run reproducible.
            pvalues[j,i] = simulate(reference_sets["all"][0], n1, n2, obs_ci, seed=i*10000 + j)[0]

# Correct for multiple testing and star the entries that remain significant.
significant_pairs, _ = BH_adjust(pvalues, alpha=ALPHA)
for i, _ in enumerate(METHODS):
    for j, _ in enumerate(METHODS):
        if significant_pairs[i,j]:
            annot_matrix[i,j] = f"{containment_indices[i,j]:.2f}*"
        else:
            annot_matrix[i,j] = f"{containment_indices[i,j]:.2f}"

method_names = [METHOD_LABELS[method] for method in METHODS]
# Start from a fresh figure so the heatmap is never drawn on top of a figure left
# over from earlier plotting (same as in the per-experiment loop).
fig = plt.figure()
# The first column (METHODS[0], i.e. "all") is left out of the plot.
sns.heatmap(
    containment_indices[:, 1:],
    annot=annot_matrix[:, 1:],
    cmap="Greys",
    cbar=True,
    linewidth=0.5,
    xticklabels=method_names[1:],
    yticklabels=method_names,
    fmt="",
    annot_kws={"size": 4},
    )
plt.title(f"NCBI Bacteria", size=20)
plt.tight_layout()
plt.savefig(f"path_prefix/NCBI_containment_indices.pdf", dpi=500, bbox_inches="tight", format="pdf")