In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import os
from datetime import datetime
from matplotlib.colors import Normalize, ListedColormap
from scipy.stats import linregress, pearsonr
from scipy.stats import mannwhitneyu, wilcoxon
from statsmodels.stats.multitest import multipletests

from statannotations.Annotator import Annotator

In [2]:
# Display options: floats with three decimals, never truncate the column list.
pd.set_option('display.float_format', '{:.3f}'.format)
pd.options.display.max_columns = None

# Matplotlib defaults for every figure in this notebook:
# on-screen resolution, thin plot frames and invisible major ticks.
plt.rcParams.update({
    'figure.dpi': 170,
    'axes.linewidth': 0.65,
    'xtick.major.width': 0.0,
    'ytick.major.width': 0.0,
})

# Resolution used when figures are written to disk with savefig.
DPI = 250

In [3]:
from list_vars import LIST_PROFILERS, DIR_FIGURES, RESULTS_DIR

# In silico sample analysis

In this notebook we are going to do an analysis on the *in silico* samples, where we are going to study several variables.

---

## How many reads are incorrectly mapped if we do not perform a host mapping step?

It has been reported that not mapping to human databases before profiling increases the number of reads assigned to other organisms. 

In this case, we are going to run several checks on the *in silico* dataset using pass2 (profiling after two rounds of host mapping) and pass0 (direct profiling without host mapping), and we are going to check the influence of the sensitivity parameter (mode):
-  We are going to see what is the total number of reads mapped to the human dataset, and what is the offset left unmapped which should have been mapped to human.
    - We are also going to do the same with the microbial reads, and see if more microbial reads have been assigned to the pass0 dataset.

Later in the analysis we are going to do two additional analyses:
-  We are going to compare the total number of species detected in pass0 and pass2, and compute their Jaccard index.
- We are going to calculate the ratio between the number of reads in pass0 and pass2.

In [4]:
# Host-mapping read counts per sample; rows are indexed by the SAMPLE column.
df_host_map_info = pd.read_csv(
    f'{RESULTS_DIR}/counts/mapping_counts.txt',
    sep='\t',
    index_col='SAMPLE',
)

In [None]:
# Ground-truth composition of the in silico sample (species; taxid; reads).
# NOTE(review): this path is relative to the working directory, while a later cell
# reads '../../src/version_2/table_artificial_taxid.csv' — confirm they are the same file.
artificial_taxid_counts = pd.read_csv('table_artificial_taxid.csv', sep=';', names=['species', 'taxid', 'reads'])
# NOTE(review): 'reads' is halved into 'reads_true' — presumably because the table
# counts both mates of each read pair; confirm against the simulation script.
artificial_taxid_counts['reads_true'] = (artificial_taxid_counts['reads'] / 2).astype(int)

# Pick the human row by its taxid instead of assuming it is the first row.
HUMAN_TAXID = 9606
human_reads_true = artificial_taxid_counts.loc[artificial_taxid_counts['taxid'] == HUMAN_TAXID, 'reads_true']
assert len(human_reads_true) == 1, f'expected exactly one human (taxid {HUMAN_TAXID}) row, found {len(human_reads_true)}'

n_true_human_reads = int(human_reads_true.iloc[0])
n_true_human_reads

In [None]:
# Human reads removed by the two host-mapping rounds, compared with the ground truth.
human_mapped_first = df_host_map_info.loc['ARTIFICIAL', '1st_mapped']
human_mapped_second = df_host_map_info.loc['ARTIFICIAL', '2nd_mapped']
n_mapped_reads_1and2_maps = human_mapped_first + human_mapped_second

pct_mapped = 100 * n_mapped_reads_1and2_maps / n_true_human_reads
n_remaining = n_true_human_reads - n_mapped_reads_1and2_maps

print(f'There is a total of {n_mapped_reads_1and2_maps} reads mapped to human during the 1st and 2nd map, which represents around {pct_mapped} % of the total number of reads ({n_true_human_reads}).')
print(f'There is a total of {n_remaining} reads remaining to be mapped.')

In [None]:
# Profiling read counts for the in silico sample (one row per profiler/mode/pass).
df_host_profile_info = pd.read_csv(f'{RESULTS_DIR}/counts/profiling_counts_ARTIFICIAL.txt', sep='\t')
# .copy(): the columns added below must go into an independent frame, not into a
# filtered slice of df_host_profile_info (avoids SettingWithCopyWarning).
df_host_profile_info_artificial = df_host_profile_info[df_host_profile_info['SAMPLE'] == 'ARTIFICIAL'].copy()

# pass2 only profiles the reads left after the two host-mapping rounds, so the human
# reads removed by those rounds are added back for pass-2 rows; other passes get 0.
df_host_profile_info_artificial['mapped_human_1_2_maps'] = 0
df_host_profile_info_artificial.loc[df_host_profile_info_artificial['pass'] == 2, 'mapped_human_1_2_maps'] = n_mapped_reads_1and2_maps

df_host_profile_info_artificial['mapped_human_total'] = df_host_profile_info_artificial['mapped_human_1_2_maps'] + df_host_profile_info_artificial['mapped_human']
df_host_profile_info_artificial['total_reads'] = df_host_profile_info_artificial['mapped_human_total'] + df_host_profile_info_artificial['mapped_others'] + df_host_profile_info_artificial['unmapped']

# Observed fractions of all reads: human, non-human and unassigned.
for category, count_col in [('human', 'mapped_human_total'), ('others', 'mapped_others'), ('unmapped', 'unmapped')]:
    df_host_profile_info_artificial[f'observed_{category}_prop'] = df_host_profile_info_artificial[count_col] / df_host_profile_info_artificial['total_reads']

# Expected fractions from the ground-truth table (the same value on every row).
expected_human_prop = n_true_human_reads / artificial_taxid_counts['reads_true'].sum()
df_host_profile_info_artificial['expected_human_prop'] = expected_human_prop
df_host_profile_info_artificial['expected_others_prop'] = 1 - expected_human_prop

# Positive values = expected fraction not reached in that category.
df_host_profile_info_artificial['calculated_unmapped_human_prop'] = df_host_profile_info_artificial['expected_human_prop'] - df_host_profile_info_artificial['observed_human_prop']
df_host_profile_info_artificial['calculated_unmapped_others_prop'] = df_host_profile_info_artificial['expected_others_prop'] - df_host_profile_info_artificial['observed_others_prop']

# Observed / expected non-human fraction (1.0 = all expected non-human reads assigned).
df_host_profile_info_artificial['proportion_mapped_other_reads'] = df_host_profile_info_artificial['observed_others_prop'] / df_host_profile_info_artificial['expected_others_prop']


for profiler in LIST_PROFILERS:
    display(profiler)
    display(df_host_profile_info_artificial[df_host_profile_info_artificial['profiler'] == profiler])


In [8]:
# Colours shared by the per-profiler line plots below (one entry per hue level).
# NOTE(review): the first five codes look like the IBM colour-blind-safe palette,
# followed by a neutral grey — confirm before reusing elsewhere.
custom_palette = [
    "#648FFF",
    "#785EF0",
    "#DC267F",
    "#FE6100",
    "#FFB000",
    "#848484",
]

In [None]:
# Figure A: difference (pass2 - pass0, in percentage points) of the fraction of reads
# assigned to human (left panel) and to non-human taxa (right panel), per profiler and mode.
fig, axs = plt.subplots(1, 2, figsize=(8, 3))

# 1A) Check if there are differences in human read assignment.

# Step 1: Calculate the differences between pass 2 and pass 0 for each profiler and mode

# Wide table: one row per (profiler, mode), one column per pass value.
pass_diff = (
    df_host_profile_info_artificial.pivot_table(
        index=["profiler", "mode"], columns="pass", values="observed_human_prop"
    )
    .reset_index()
)

# Ensure column names are integers
pass_diff.columns.name = None  # Remove the columns' name from pivot_table
# NOTE(review): assumes the data contain exactly the passes 0 and 2, in that order.
pass_diff.columns = ['profiler', 'mode', 0, 2]  # Explicitly rename columns

# Calculate the difference
# The proportions are fractions, so x100 gives percentage points.
pass_diff["difference"] = 100 * (pass_diff[2] - pass_diff[0])



# Step 2: Plot the differences using a lineplot
# No legend here; the shared legend is drawn on the right panel.
g = sns.lineplot(
    data=pass_diff,
    x="mode",
    y="difference",
    hue="profiler",
    marker="o",
    palette=custom_palette,
    ax=axs[0],
    legend=False
)

# Zero line = no difference between passes.
axs[0].axhline(0, color="#848484", linestyle="--", linewidth=0.8)
# NOTE(review): title fontsize is 14 here but 12 on the right panel — confirm this is intended.
axs[0].set_title("Human assignment", fontsize=14)
axs[0].set_xlabel("Mode", fontsize=12)
axs[0].set_ylabel("Diff pass2 - pass0 (%)", fontsize=12)



# 1B) Check if there are differences in non-human read assignment.

# Same pivot for the non-human fraction; the pivot's integer pass columns 0 and 2
# are used directly (no explicit rename, unlike panel 1A).
pass_diff = (
    df_host_profile_info_artificial.pivot_table(
        index=["profiler", "mode"], columns="pass", values="observed_others_prop"
    )
    .reset_index()
)

pass_diff["difference"] = 100 * (pass_diff[2] - pass_diff[0])



# Step 2: Plot the differences using a lineplot

# This panel keeps seaborn's legend; it is moved outside the axes below.
h = sns.lineplot(
    data=pass_diff,
    x="mode",
    y="difference",
    hue="profiler",
    marker="o",
    palette=custom_palette,
    ax=axs[1],
)

sns.despine(top=True, right=True)

axs[1].axhline(0, color="gray", linestyle="--", linewidth=0.8)
axs[1].set_title("Non-human assignment", fontsize=12)
axs[1].set_xlabel("Mode", fontsize=12)
axs[1].set_ylabel("", fontsize=12)

# One tick per profiler mode (1-9).
for ax in axs:
    ax.set_xticks(range(1, 10))

axs[1].legend(title="Profiler", frameon=False, bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
plt.tight_layout()
# NOTE(review): rc settings generally apply to artists created afterwards, so this line
# probably does not change the figure already drawn — confirm it is still needed.
plt.rc('axes', linewidth=0.65)
# Save the same figure as PNG and TIFF (`format` shadows the builtin within this cell).
for format in ['png', 'tiff']: 
    plt.savefig(f'{RESULTS_DIR}/figures/paper/figA.{format}', dpi=DPI, )


In [None]:
# Show the full per-(profiler, mode, pass) table with observed and expected proportions.
df_host_profile_info_artificial

In [None]:
# Figure A2: observed percentage of reads assigned to human (panels 1-2) and to
# non-human taxa (panels 3-4), for pass 0 and pass 2, per profiler and mode.
fig, axs = plt.subplots(1, 4, figsize=(10, 3))

# 1A) Plot human read assignment for pass 0 and pass 2
for i, pass_num in enumerate([0, 2]):
    # Wide table: one row per (profiler, mode), one column per pass value.
    # NOTE(review): the same pivot is rebuilt on every iteration; it could be computed once.
    pass_data = (
        df_host_profile_info_artificial.pivot_table(
            index=["profiler", "mode"], columns="pass", values="observed_human_prop"
        )
        .reset_index()
    )
    
    # Fraction -> percentage, only for the pass plotted in this panel.
    pass_data[pass_num] *= 100

    sns.lineplot(
        data=pass_data,
        x="mode",
        y=pass_num,
        hue="profiler",
        marker="o",
        palette=custom_palette,
        legend=False,
        ax=axs[i]
    )

    axs[i].set_title(f"Human assignment\n(pass {pass_num})", fontsize=14)
    axs[i].set_xlabel("Mode", fontsize=12)
    axs[i].set_ylabel("Observed %", fontsize=12)

# 1B) Plot non-human read assignment for pass 0 and pass 2
for i, pass_num in enumerate([0, 2]):
    pass_data = (
        df_host_profile_info_artificial.pivot_table(
            index=["profiler", "mode"], columns="pass", values="observed_others_prop"
        )
        .reset_index()
    )

    pass_data[pass_num] *= 100

    # Panels 3 and 4 (offset by the two human panels).
    sns.lineplot(
        data=pass_data,
        x="mode",
        y=pass_num,
        hue="profiler",
        marker="o",
        palette=custom_palette,
        legend=False,
        ax=axs[i + 2]
    )

    axs[i + 2].set_title(f"Non-human assignment\n(pass {pass_num})", fontsize=14)
    axs[i + 2].set_xlabel("Mode", fontsize=12)
    axs[i + 2].set_ylabel("Observed %", fontsize=12)

# Shared, hard-coded y-ranges within each pair of panels so both passes are comparable.
# NOTE(review): confirm these limits still cover the data if the inputs change.
for ax in [axs[0], axs[1]]:
    ax.set_ylim([-1, 85])

for ax in [axs[2], axs[3]]:
    ax.set_ylim([0, 22])

# One tick per profiler mode (1-9).
for ax in axs:
    ax.set_xticks(np.arange(1, 10))

# Keep the y label only on the left-most panel.
for ax in axs[1:]:
    ax.set_ylabel('')

sns.despine(top=True, right=True)
plt.tight_layout()
plt.rc('axes', linewidth=0.65)
for format in ['png', 'tiff']: 
    plt.savefig(f'{RESULTS_DIR}/figures/paper/figA2.{format}', dpi=DPI)
plt.show()


**What do we see here?**
- The number of reads assigned to human without host mapping varies widely between profilers. Centrifuge, krakenuniq and ganon map human reads correctly, whereas kaiju and kraken2 fail to map the reads to human. The differences tend to decrease with the sensitivity mode; that is, paradoxically, a stricter read assignment leads to a larger number of human-mapped reads. This makes sense because more reads are assigned in general, so both human and non-human reads are mapped.
- However, this difference does not occur for non-human species. In general, non-human species are assigned equally with or without host mapping. This is interesting because we would expect more reads to be assigned to non-human species as a result of false-positive assignments of human reads, but this does not seem to be the case, even in profilers that leave a large amount of human reads unmapped.
    - Still, we have to take into account that the profiler databases include a host mapping step.


In [None]:
# Ground truth restricted to the non-human taxa, with each taxon's share of the
# non-human reads expressed as a percentage ('abundance').
# NOTE(review): an earlier cell reads 'table_artificial_taxid.csv' from the working
# directory instead of '../../src/version_2/' — confirm both are the same file.
# NOTE(review): 'count' is the raw table value (not halved like 'reads_true' earlier);
# it is later used as the expected value for '*_norm' columns — confirm the units match.
table_artificial_taxcounts = (
    pd.read_csv('../../src/version_2/table_artificial_taxid.csv', sep=';', names=['species', 'taxid', 'count'])
    .loc[lambda frame: frame['taxid'] != 9606]  # drop the human (host) row
    .assign(abundance=lambda frame: 100 * frame['count'] / frame['count'].sum())
)
table_artificial_taxcounts

# Computing detection stats to answer the questions

One of the parameters used during profiling is the mode of the profilers. Each profiler has a different set of parameters that decide whether a read is accepted. This may result in the detection of false positives and false negatives. 

Here, we are going to study this effect in the *in silico* samples to see whether there are major changes. We can measure performance using several variables: 
- Categorical values: we can use each of the columns in the flag system to check how well species were assigned. We use precision, $TP/(TP + FP)$, recall, $TP/(TP + FN)$, F1-score, $2 \cdot \text{precision} \cdot \text{recall} / (\text{precision} + \text{recall})$, and Cohen's kappa:
$$\kappa = \frac{p_0-p_e}{1-p_e} \quad p_0 = \frac{TP + TN}{TP + FP + FN + TN} \quad p_e=\frac{TP + FP}{TP + FP + FN + TN}\cdot\frac{TP + FN }{TP + FP + FN + TN} + \frac{TN + FP}{TP + FP + FN + TN}\cdot\frac{TN + FN}{TP + FP + FN + TN}$$

- Numerical values: we can use the normalized value and the abundance to see how well are reads classified. For that we can use the expected number of reads and abundance. With that we will calculate the (1) difference between observed and expected categories and (2) the mean error and the (3) mean absolute error:
$$(1) \qquad DIFF_i = 100\cdot\frac{x_{obs,i} - x_{exp,i}}{x_{exp,i}}$$ 
$$(2) \qquad ME = E[DIFF_i] = \frac{100}{N}\sum\frac{x_{obs,i} - x_{exp,i}}{x_{exp,i}}$$ 
$$(3A) \qquad MAE = E[|DIFF_i|] = \frac{100}{N}\sum\frac{|x_{obs,i} - x_{exp,i}|}{x_{exp,i}}$$ 
$$(3B) \qquad MAED = \sigma[|DIFF_i|]$$ 
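- For numerical values we also fit a linear regression of the observed on the expected values and report its slope and $R^2$ (only when more than 35 species are shared). Note that the current implementation uses the untransformed values; no log10(1+x) transform is applied.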

## Categorical values

In [13]:
def calculate_nominal_metrics(df_tax_ground_truth, df_flags_observed, column):
    """Score one flag column of an observed profile against the ground truth.

    A taxon counts as *detected* when its flag in ``column`` equals ``False`` and
    as *rejected* when it equals ``True``.

    Parameters
    ----------
    df_tax_ground_truth : pandas.DataFrame
        Expected taxa; must contain a ``'taxid'`` column.
    df_flags_observed : pandas.DataFrame
        Observed taxa; must contain ``'taxonomy_id'`` and ``column``.
    column : str
        Name of the flag column to evaluate.

    Returns
    -------
    tuple
        ``(precision, recall, f1, kappa, TP, FN, FP, TN)``. Precision, recall
        and F1 are 0.0 when their denominator is zero (e.g. nothing detected);
        kappa is NaN when the chance agreement ``pe`` equals 1 (undefined).

    Raises
    ------
    KeyError
        If ``column`` is not in ``df_flags_observed`` (callers rely on this to
        skip columns missing from a summary file).
    """
    flags = df_flags_observed[column]  # raises KeyError for a missing column

    expected_taxids = list(df_tax_ground_truth['taxid'].astype(int).values)
    observed_true_taxids = list(df_flags_observed.loc[flags.eq(False), 'taxonomy_id'].astype(int).values)
    observed_false_taxids = list(df_flags_observed.loc[flags.eq(True), 'taxonomy_id'].astype(int).values)

    # Sets give O(1) membership tests; iterating over the lists keeps the counts
    # identical to a per-element count if a list ever contained duplicates.
    expected_set = set(expected_taxids)
    observed_true_set = set(observed_true_taxids)

    TP = sum(1 for taxid in expected_taxids if taxid in observed_true_set)
    FN = len(expected_taxids) - TP
    FP = sum(1 for taxid in observed_true_taxids if taxid not in expected_set)
    TN = sum(1 for taxid in observed_false_taxids if taxid not in expected_set)

    # Every taxon must land in exactly one confusion-matrix cell.
    assert len(expected_set | observed_true_set | set(observed_false_taxids)) == TP + FN + FP + TN

    precision = TP / (TP + FP) if (TP + FP) else 0.0
    recall = TP / (TP + FN) if (TP + FN) else 0.0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) else 0.0

    # Cohen's kappa: observed agreement p0 vs. agreement expected by chance pe.
    ALL = TP + FN + FP + TN
    if ALL == 0:
        kappa = float('nan')
    else:
        p0 = (TP + TN) / (ALL)
        pe = (TP + FP)/ALL * (TP + FN)/ALL + (TN + FP)/ALL * (TN + FN)/ALL
        # pe == 1 happens when every taxon falls in a single class; kappa is undefined.
        kappa = (p0 - pe) / (1 - pe) if pe != 1 else float('nan')

    return precision, recall, f1, kappa, TP, FN, FP, TN



In [14]:
# Flag columns scored against the ground truth: per-profiler normalised counts and
# relative abundances, plus the cross-profiler mean and CV columns.
columns_selected = ['centrifuge_norm', 'ganon_norm', 'kaiju_norm', 'kmcp_norm', 'kraken2_norm', 'krakenuniq_norm',
                    'centrifuge_relab', 'ganon_relab', 'kaiju_relab', 'kmcp_relab', 'kraken2_relab', 'krakenuniq_relab',
                    'mean_norm', 'CV_norm', 'mean_relab', 'CV_relab']
nominal_columns = ['pass', 'mode', 'S', 'column', 'precision', 'recall', 'f1', 'kappa', 'TP|FN|FP|TN']

nominal_records = []
# NOTE: the loop variables (passn, mode, S, column) stay in the notebook namespace
# after this cell; later cells read `S` (last value: 15).
for passn in [0, 2]:
    for mode in range(1, 10):
        for S in [0, 1, 2, 3, 4, 5, 6, 7, 10, 15]:
            # One flags file per (pass, mode, S): read it once and score every column on it.
            summary_table_flags = pd.read_csv(f'{RESULTS_DIR}/summary/ARTIFICIAL_pass{passn}_mode{mode}_taxspecies_S{S}.flags.tsv', sep='\t')
            for column in columns_selected:
                try:
                    precision, recall, f1, kappa, TP, FN, FP, TN = calculate_nominal_metrics(table_artificial_taxcounts, summary_table_flags, column)
                except KeyError:
                    # Column not present in this summary file: skip it.
                    continue

                nominal_records.append({
                    'pass': passn,
                    'mode': mode,
                    'S': S,
                    'column': column,
                    'precision': precision,
                    'recall': recall,
                    'f1': f1,
                    'kappa': kappa,
                    'TP|FN|FP|TN': (TP, FN, FP, TN),
                })

df_nominal_stats = pd.DataFrame(nominal_records, columns=nominal_columns)

In [None]:
# One row per (pass, mode, S, flag column) with precision, recall, F1, kappa and the confusion counts.
df_nominal_stats

## Numerical values

In [16]:
def compute_mad(values):
    """Return the median absolute deviation (MAD) of ``values``.

    MAD = median(|x - median(x)|), a robust measure of spread. No consistency
    scaling constant is applied.

    Parameters
    ----------
    values : array-like of numbers
        Plain lists and tuples are accepted as well as NumPy arrays / pandas Series.

    Returns
    -------
    float
        The MAD (NaN for empty input).
    """
    # np.asarray lets plain sequences work: `list - float` would raise TypeError.
    arr = np.asarray(values)
    median = np.median(arr)
    mad = np.median(np.abs(arr - median))
    return mad

def calculate_numerical_metrics(df_tax_ground_truth, df_counts_observed, profiler, suffix, regression_min_taxa=36):
    """Compare observed and expected quantities for taxa present in both tables.

    Parameters
    ----------
    df_tax_ground_truth : pandas.DataFrame
        Ground truth with columns ``'taxid'``, ``'species'``, ``'count'`` and ``'abundance'``.
    df_counts_observed : pandas.DataFrame
        Observed summary with a ``'taxonomy_id'`` column and a ``f'{profiler}_{suffix}'`` column.
    profiler : str
        Profiler name (column prefix).
    suffix : str
        ``'norm'`` compares against the expected ``'count'``; any other suffix
        compares against the expected ``'abundance'``.
    regression_min_taxa : int, default 36
        Minimum number of shared taxa required to fit the regression (the
        default reproduces the original "more than 35" rule).

    Returns
    -------
    tuple
        ``(diff, MRE, MRED, MAE, MAED, MACV, slope, r2, taxids, species,
        expected, observed)`` where ``diff`` is the per-taxon relative error in
        percent, MRE/MRED its mean/std, MAE/MAED mean/std of its absolute value,
        MACV = MAED / MAE. ``slope`` and ``r2`` come from a linear regression of
        observed on expected values (not a correlation coefficient) and are NaN
        when there are too few shared taxa or all expected values are equal.
        Taxa missing from either table are ignored; NaN observed values count as 0.

    Raises
    ------
    KeyError
        If the ``f'{profiler}_{suffix}'`` column is missing.
    """
    # set_index returns new frames, so the caller's tables are not modified.
    truth = df_tax_ground_truth.set_index('taxid')
    observed = df_counts_observed.set_index('taxonomy_id')

    # Only taxa present in both tables are compared.
    combined_taxid = np.intersect1d(truth.index.astype(int).values, observed.index.astype(int).values)
    species = truth.loc[combined_taxid, 'species'].values
    observed_counts = observed.loc[combined_taxid, f'{profiler}_{suffix}'].fillna(0).values

    expected_col = 'count' if suffix == 'norm' else 'abundance'
    expected_counts = truth.loc[combined_taxid, expected_col].values

    # Per-taxon relative error in percent.
    diff_counts = 100 * (observed_counts - expected_counts) / expected_counts
    abs_diff = np.abs(diff_counts)

    MRE_counts = np.mean(diff_counts)
    MRED_counts = np.std(diff_counts)
    MAE_counts = np.mean(abs_diff)
    MAED_counts = np.std(abs_diff)
    MACV_counts = MAED_counts / MAE_counts

    # linregress needs at least two points and non-constant x (it raises otherwise).
    enough_taxa = len(combined_taxid) >= max(regression_min_taxa, 2)
    if enough_taxa and np.unique(expected_counts).size > 1:
        slope, _, r_value, _, _ = linregress(expected_counts, observed_counts)
        r2 = r_value ** 2
    else:
        slope, r2 = np.nan, np.nan

    return diff_counts, MRE_counts, MRED_counts, MAE_counts, MAED_counts, MACV_counts, slope, r2, combined_taxid, species, expected_counts, observed_counts


In [None]:
# Spot-check one configuration: per-species observed vs expected normalised counts.
# The settings are explicit; the previous version read `S` leaked from the
# nominal-metrics loop, whose last value is 15.
INSPECT_PASS = 2
INSPECT_MODE = 5
INSPECT_PROFILER = 'mean'
INSPECT_S = 15

summary_table_counts = pd.read_csv(
    f'{RESULTS_DIR}/summary/ARTIFICIAL_pass{INSPECT_PASS}_mode{INSPECT_MODE}_taxspecies_S{INSPECT_S}.diversity.tsv',
    sep='\t',
)

# A missing profiler column raises KeyError, as before.
(diff_counts, MRE_counts, MRED_counts, MAE_counts, MAED_counts, MACV_counts,
 slope_counts, r2_counts, taxids_counts, species_counts, expected_counts, observed_counts) = \
    calculate_numerical_metrics(table_artificial_taxcounts, summary_table_counts, INSPECT_PROFILER, suffix='norm')

# The relative-error column is computed from counts, so it is labelled 'diff_counts'
# (it was previously mislabelled 'diff_abundance').
df = pd.DataFrame({'species': species_counts, 'expected_counts': expected_counts, 'observed_counts': observed_counts, 'diff_counts': diff_counts})
df


In [18]:
# Numerical accuracy metrics for every (pass, mode, profiler) on the in silico sample.
# S is explicit: the previous version used `S` leaked from the nominal-metrics loop
# (last value 15), so the same files are read.
S_NUMERICAL = 15

# Order of the values returned by calculate_numerical_metrics; each name gets a
# '_counts' or '_abundance' suffix. 'corr' holds the regression slope and 'R2' its R².
METRIC_KEYS = ['diff', 'MRE', 'MRED', 'MAE', 'MAED', 'MACV', 'corr', 'R2', 'taxid', 'species', 'expected', 'observed']

numerical_records = []
for passn in [0, 2]:
    for mode in range(1, 10):
        # One summary file per (pass, mode), shared by all profilers.
        summary_table_counts = pd.read_csv(f'{RESULTS_DIR}/summary/ARTIFICIAL_pass{passn}_mode{mode}_taxspecies_S{S_NUMERICAL}.diversity.tsv', sep='\t')
        for profiler in LIST_PROFILERS + ['mean']:
            # A missing profiler column raises KeyError and stops the cell (unchanged behaviour).
            metrics_counts = calculate_numerical_metrics(table_artificial_taxcounts, summary_table_counts, profiler, suffix='norm')
            metrics_abundance = calculate_numerical_metrics(table_artificial_taxcounts, summary_table_counts, profiler, suffix='relab')

            record = {'pass': passn, 'mode': mode, 'profiler': profiler}
            record.update({f'{key}_counts': value for key, value in zip(METRIC_KEYS, metrics_counts)})
            record.update({f'{key}_abundance': value for key, value in zip(METRIC_KEYS, metrics_abundance)})
            numerical_records.append(record)

df_numerical_stats = pd.DataFrame(numerical_records)

In [None]:
# One row per (pass, mode, profiler); the diff/taxid/species/expected/observed cells hold per-taxon arrays.
df_numerical_stats

## Analysis of kappa/F1

F1-score and $\kappa$ are quite different measures but we observe that they are correlated in this data.

In [None]:
# Compute and print correlation (Pearson by default)
corr = df_nominal_stats['f1'].corr(df_nominal_stats['kappa'])

print("Pearson correlation between F1 and Kappa:", corr)

# Scatter plot with the identity line
plt.figure(figsize=(4, 4))

# Add the identity line (y = x)
plt.plot([0, 1], [0, 1], color='red', linestyle='--', alpha=0.5)

sns.scatterplot(x='f1', y='kappa', data=df_nominal_stats, s=20)

plt.text(0.01, 0.95, f'R$^2$: {corr**2:.5f}')

sns.despine(top=True, right=True)

# Add title and labels
plt.title('F1-score vs Kappa')
plt.xlabel('F1-score')
plt.ylabel('Kappa')
plt.tight_layout()

for format in ['png', 'tiff']: 
    plt.savefig(f'{RESULTS_DIR}/figures/paper/fig0.kappa-f1.{format}', dpi=DPI, bbox_inches='tight' )

plt.show()


In that sense, we can use one of the measures to explain the results and do not need the second one. We are going to select the F1-score because it is easier to interpret and is directly related to precision and recall, which are already being used.

In [None]:
import random

# Monte Carlo check: over random confusion matrices, how strongly are F1 and
# Cohen's kappa correlated? Seeded so that Restart & Run All reproduces the figure.
MC_SEED = 42
mc_rng = random.Random(MC_SEED)


def cohen_kappa(tp, fp, fn, tn):
    """
    Compute Cohen's kappa given the confusion matrix counts:
      TP = true positives
      FP = false positives
      FN = false negatives
      TN = true negatives

    kappa = (po - pe) / (1 - pe), with po the observed agreement and pe the
    agreement expected by chance from the marginal totals. Returns NaN when pe
    is (numerically) 1, where kappa is undefined. The counts must sum to > 0.
    """
    total = tp + fp + fn + tn

    # Observed agreement
    po = (tp + tn) / total

    # Chance agreement: P(actual +) * P(predicted +) + P(actual -) * P(predicted -)
    p_actual_pos = (tp + fn) / total
    p_pred_pos = (tp + fp) / total
    p_actual_neg = (fp + tn) / total
    p_pred_neg = (fn + tn) / total
    pe = p_actual_pos * p_pred_pos + p_actual_neg * p_pred_neg

    if np.isclose(pe, 1.0):
        # The denominator (1 - pe) would be ~0.
        return np.nan

    return (po - pe) / (1 - pe)


def f1_score_simple(tp, fp, fn):
    """
    Compute the F1 score from TP, FP, FN:
      F1 = 2 * TP / (2*TP + FP + FN)
    Returns NaN when TP, FP and FN are all zero.
    """
    denom = (2 * tp + fp + fn)
    if denom == 0:
        return np.nan
    return 2 * tp / denom


# --- Simulation parameters ---
N = 3000        # total samples per confusion matrix
n_sims = 100000  # number of random simulations

f1_values = []
kappa_values = []
table_values = []

for _ in range(n_sims):
    # Draw TP, then FP, then FN uniformly from what is left; TN takes the rest,
    # so the four counts always sum to N. NOTE: this is not uniform over all
    # confusion matrices (on average TP ~ N/2, FP ~ N/4, FN ~ TN ~ N/8).
    tp = mc_rng.randint(0, N)
    fp = mc_rng.randint(0, N - tp)
    fn = mc_rng.randint(0, N - tp - fp)
    tn = N - tp - fp - fn

    f1 = f1_score_simple(tp, fp, fn)
    kappa = cohen_kappa(tp, fp, fn, tn)

    # Keep only draws where both metrics are defined.
    if not np.isnan(f1) and not np.isnan(kappa):
        f1_values.append(f1)
        kappa_values.append(kappa)
        table_values.append((tp, fp, fn, tn))

# Convert to numpy arrays for correlation
f1_values = np.array(f1_values)
kappa_values = np.array(kappa_values)

# Compute Pearson correlation
corr, p_value = pearsonr(f1_values, kappa_values)

print(f"Number of valid simulations: {len(f1_values)} / {n_sims}")
print(f"Pearson correlation between F1 and Kappa: {corr:.3f} (p = {p_value:.3e})")

# Plot the scatter of F1 vs Kappa
plt.figure(figsize=(7, 5))
plt.scatter(f1_values, kappa_values, alpha=0.2, s=10)
plt.title("Monte Carlo Simulation (TP, FP, FN, TN random)")
plt.xlabel("F1 Score")
plt.ylabel("Cohen's Kappa")

# Identity line y = x for reference.
plt.plot([0, 1], [0, 1])

plt.grid(True)
plt.show()


In [None]:
# Simulated tables whose F1 and kappa are both within 0.03 of N.
N = 0.8

df = pd.DataFrame({'f1': f1_values, 'kappa': kappa_values, 'table': table_values})

lower, upper = N - 0.03, N + 0.03
df[df['f1'].gt(lower) & df['f1'].lt(upper) & df['kappa'].lt(upper) & df['kappa'].gt(lower)]

In [None]:
# Same window, centred on perfect agreement (F1 = kappa = 1).
N = 1

lower, upper = N - 0.03, N + 0.03
df[df['f1'].gt(lower) & df['f1'].lt(upper) & df['kappa'].lt(upper) & df['kappa'].gt(lower)]

In [None]:
# Same window, centred on F1 = kappa = 0.
N = 0

lower, upper = N - 0.03, N + 0.03
df[df['f1'].gt(lower) & df['f1'].lt(upper) & df['kappa'].lt(upper) & df['kappa'].gt(lower)]

# How does the S parameter used during curve fitting affect the results?

The S parameter is useful to tweak the detection results, so that we can include more or fewer species during the flagging step. Since it is a structural parameter, we want to fit it first so that we can answer several other comparisons.

To do this we are going to use the nominal variables and their derived statistics.


## Checking recall/precision/F1-score for inclusion/exclusion of species

In [25]:
# Selection for the S-sweep figures: normalised-count flag columns of each profiler
# plus the cross-profiler mean, modes 2-8, host-pre-mapped reads (pass 2) only.
cols = [*(f'{profiler_name}_norm' for profiler_name in LIST_PROFILERS), 'mean_norm']
modes = range(2, 9)
passn = [2]
S_values = pd.unique(df_nominal_stats['S'])

In [26]:
# Rows of df_nominal_stats matching the pass / column / mode selection above.
subset_df = df_nominal_stats.loc[
    df_nominal_stats['pass'].isin(passn)
    & df_nominal_stats['column'].isin(cols)
    & df_nominal_stats['mode'].isin(modes)
]

In [None]:
# Figure E: recall, precision and F1 as a function of S, one line per mode,
# for each selected profiler column.
# Long format: one row per (mode, S, column, metric) with the metric value in 'score'.
melted_df = pd.melt(
    subset_df,
    id_vars=['mode', 'S', 'column'],
    value_vars=['recall', 'precision', 'f1'],
    var_name='metric',
    value_name='score'
)

# Create a colormap for 'mode'
norm = Normalize(vmin=melted_df['mode'].min(), vmax=melted_df['mode'].max())
cmap = plt.cm.viridis  # Choose a colormap (e.g., 'viridis', 'plasma', 'cividis')

# Create a FacetGrid: one row per metric and one column per profiler column
g = sns.FacetGrid(
    melted_df, 
    col='column', 
    row='metric', 
    height=2, 
    sharey=True, 
    sharex=True
)

# Map the lineplot to the grid
# Draws one line per mode on the facet's current axes, coloured by mode.
def lineplot_with_cmap(data, **kwargs):
    for mode in sorted(data['mode'].unique()):
        subset = data[data['mode'] == mode]
        plt.plot(subset['S'], subset['score'], label=f"Mode {mode}",
                 color=cmap(norm(mode)), marker='o')

g.map_dataframe(lineplot_with_cmap)

# Create a legend for the discrete modes
handles = [
    plt.Line2D([0], [0], color=cmap(norm(mode)), marker='o', linestyle='', label=f"Mode {mode}")
    for mode in sorted(melted_df['mode'].unique())
]
# The legend is attached to the current (last-drawn) axes and pushed to the right of
# the grid; NOTE(review): y=3 in axes units was presumably tuned to centre it vertically.
plt.legend(
    handles=handles, 
    title="", 
    bbox_to_anchor=(1.05, 3), 
    loc='center left', 
    frameon=False
)

# Set x-axis ticks (if you have specific S values)
g.set(xticks=subset_df['S'].unique())

# Clear the default "metric = ... | column = ..." facet titles.
for ax in g.axes.ravel():
    ax.set_title('')

# Add axis labels and titles
# Facet columns follow the order of appearance of 'column' in melted_df.
for ax, profiler in zip(g.axes[0, :], melted_df['column'].unique()):
    ax.set_title(profiler.replace('_norm', ''))

# Rows follow the melt order of value_vars: recall, precision, f1.
for ax, score in zip(g.axes[:, 0], ['recall', 'precision', 'F1-score']):
    ax.set_ylabel(score)

plt.tight_layout()

for format in ['png', 'tiff']: 
    plt.savefig(f'{RESULTS_DIR}/figures/paper/figE.{format}', dpi=DPI, bbox_inches='tight')


plt.show()

In [None]:
# Distribution over modes of recall / precision / F1 for each S value (x axis),
# one facet column per profiler column and one facet row per metric.
g = sns.FacetGrid(
    melted_df, 
    col='column', 
    row='metric', 
    height=3, 
    sharey=True, 
    sharex=True
)

# An explicit, sorted S order puts boxes at the same positions in every facet
# (seaborn warns that FacetGrid.map with boxplot and no `order` can be wrong).
s_order = sorted(melted_df['S'].unique())
g.map(sns.boxplot, 'S', 'score', order=s_order)

# Clear the default "metric = ... | column = ..." facet titles.
for ax in g.axes.ravel():
    ax.set_title('')

# Top row: profiler names (facet columns follow the order of appearance in melted_df).
for ax, profiler in zip(g.axes[0, :], melted_df['column'].unique()):
    ax.set_title(profiler.replace('_norm', ''))

# First column: metric names (rows follow the melt order recall, precision, f1).
for ax, score in zip(g.axes[:, 0], ['recall', 'precision', 'F1-score']):
    ax.set_ylabel(score)

plt.subplots_adjust(top=0.9)
# Boxes aggregate over modes, so the x grouping is S rather than mode.
g.fig.suptitle("Metrics by Profiler and S (distribution across modes)", fontsize=16)

plt.show()

The aim of this part of the analysis was to select the "optimal" `S` to then make other comparisons and extract proper conclusions. 
If we look at individual profilers, the aim is not to select the `S` with the best F1 score, but to select the smallest `S` that provides a sufficiently high recall, ensuring that we don't lose TP species. This threshold depends on the profiler. For CEN it is 6-7, GAN is 7-10, KAI is 1-2, KR2 is 4-5, KRU is 5-6. We see that at these values the precision drops (as expected), but it remains stable afterwards for most profilers. Therefore, a value of `S=2` should be sufficient to ensure that the results are correct. 

**IMPORTANT**: conceptually, if we are choosing the number of species based on the mean number of reads, using a low `S` is still good, because the number of reads reported by profilers that are not flagged is still considered. That is, the # of reads assigned by one profiler is the same regardless of S. 

The advantage of using the mean value instead of the individual profilers is that it tends to give more stable precision across S values and modes. Each profiler has its own dynamics, and the mean value averages them all. Therefore, using the averaged value for `S` allows us to choose the parameter more predictably, ensuring that we don't choose a very high value while at the same time keeping the correct number of reads.

Therefore, we are going to choose `S=2` and `S=7` for comparisons of robustness with biological samples. 

#  Does pass0/pass2 (no host pre-mapping vs host pre-mapping) affect the detection of the species?

For this part we are going to run three analyses:
- Retrieve the species detected in each pass and calculate their Jaccard index.
- Calculate $\chi^2$ for each case and see whether there are significant differences.
- Calculate the Pearson correlation + RMSE for several mode values.


## Total number of species and Jaccard index

In [29]:
# Species detected with vs. without host pre-mapping (pass2 vs pass0) and their Jaccard index.
# Reads the per-sample .diversity.tsv summaries and counts the species each profiler reports.


def get_n_species(sample, mode, profiler, S=15, results_dir=None):
    """Count the species a profiler reports in pass0 and pass2 and their Jaccard index.

    Reads ``{results_dir}/summary/{sample}_pass{P}_mode{mode}_taxspecies_S{S}.diversity.tsv``
    for P = 0 and P = 2 (tab-separated, indexed by ``taxonomy_id``). A species counts as
    detected in a pass when its ``profiler`` column is not NaN.

    Parameters
    ----------
    sample : str
        Sample name used in the file names.
    mode : int
        Mode used in the file names.
    profiler : str
        Column to read (e.g. ``'kraken2_norm'``).
    S : int, default 15
        S threshold used in the file names.
    results_dir : str, optional
        Results root directory; defaults to the notebook-level ``RESULTS_DIR``.

    Returns
    -------
    tuple
        ``(n_pass0, n_pass2, jaccard)``. ``jaccard`` is NaN when neither pass detects any
        species (the union is empty and the index is undefined).

    Raises
    ------
    FileNotFoundError
        If one of the two summary files does not exist.
    KeyError
        If ``profiler`` is not a column of a summary file.
    """
    base_dir = RESULTS_DIR if results_dir is None else results_dir

    def _detected_taxids(pass_n):
        # Taxonomy ids with a non-missing value for `profiler` in the given pass.
        path = f'{base_dir}/summary/{sample}_pass{pass_n}_mode{mode}_taxspecies_S{S}.diversity.tsv'
        report = pd.read_csv(path, sep='\t').set_index('taxonomy_id')
        return report[profiler].dropna().index.values

    pass0_taxids = _detected_taxids(0)
    pass2_taxids = _detected_taxids(2)

    n_union = len(np.union1d(pass0_taxids, pass2_taxids))
    if n_union:
        jaccard_index = len(np.intersect1d(pass0_taxids, pass2_taxids)) / n_union
    else:
        jaccard_index = np.nan

    return len(pass0_taxids), len(pass2_taxids), jaccard_index


# Species counts and pass0/pass2 Jaccard index for every profiler (plus the mean) and modes 1-9.
dict_n_species = {'profiler': [],
                  'mode': [],
                  'pass 0 species': [],
                  'pass 2 species': [],
                  'jaccard': []}

for profiler in LIST_PROFILERS + ['mean']:
    for mode in range(1, 10):
        try:
            pass0_n_species, pass2_n_species, jaccard_index = get_n_species('ARTIFICIAL', mode, profiler + '_norm')
        except Exception as e:
            # Some profiler/mode combinations have no report or column; report why and skip.
            # (A bare `except:` would also swallow KeyboardInterrupt and hide the cause.)
            print(f'No entry added for profiler {profiler} and mode {mode} ({type(e).__name__}: {e})')
            continue
        dict_n_species['profiler'].append(profiler)
        dict_n_species['mode'].append(mode)
        dict_n_species['pass 0 species'].append(pass0_n_species)
        dict_n_species['pass 2 species'].append(pass2_n_species)
        dict_n_species['jaccard'].append(jaccard_index)

df_n_species = pd.DataFrame(dict_n_species)


In [None]:
# sns.set_theme() rewrites rcParams (axes.linewidth included), so it is applied once, before any
# axes exist, and the thinner frame is set afterwards so it holds for every axis (twin axes too).
sns.set_theme(style="white")
plt.rc('axes', linewidth=0.65)

# One panel per profiler plus the mean, three panels per row.
panel_names = LIST_PROFILERS + ['mean']
n_cols = 3
n_rows = -(-len(panel_names) // n_cols)  # ceiling division

fig, axes = plt.subplots(n_rows, n_cols, figsize=(11, 3 * n_rows), squeeze=False)
axes = axes.flatten()

# Shared legend for all panels
legend_elements = [
    plt.Line2D([0], [0], marker='o', color='#648FFF', label='Pass 0', markersize=8, linestyle='None'),
    plt.Line2D([0], [0], marker='o', color='#785EF0', label='Pass 2', markersize=8, linestyle='None'),
    plt.Line2D([0], [0], marker='o', color='#DC267F', label='Jaccard Index', markersize=8, linestyle='-')
]

for ax1, profiler in zip(axes, panel_names):
    df_profiler = df_n_species[df_n_species['profiler'] == profiler]

    # Species counts per mode for both passes on the primary y-axis
    sns.scatterplot(data=df_profiler, x='mode', y='pass 0 species', ax=ax1, color='#648FFF', s=50)
    sns.scatterplot(data=df_profiler, x='mode', y='pass 2 species', ax=ax1, color='#785EF0', s=50)

    ax1.set_xticks(range(1, 10))
    ax1.set_xlabel('Mode')
    ax1.set_ylabel('Species Count', color='black')
    ax1.set_title(f'{profiler.capitalize()}')

    # Jaccard index between the pass0 and pass2 species sets on a secondary y-axis
    ax2 = ax1.twinx()
    sns.lineplot(data=df_profiler, x='mode', y='jaccard', ax=ax2, color='#DC267F', marker='o')
    ax2.set_ylabel('Jaccard Index', color='#DC267F')

    sns.despine(ax=ax1, top=True, right=True)
    sns.despine(ax=ax2, top=True, right=True)

# Hide grid cells without a profiler
for ax in axes[len(panel_names):]:
    ax.axis('off')

fig.legend(
    handles=legend_elements, loc='center right', frameon=False, title="", bbox_to_anchor=(1.15, 0.5)
)

fig.tight_layout()

os.makedirs(f'{RESULTS_DIR}/figures/paper', exist_ok=True)
for fmt in ['png', 'tiff']:
    fig.savefig(f'{RESULTS_DIR}/figures/paper/figB.{fmt}', dpi=DPI, bbox_inches='tight')

In [None]:
# Jaccard index (pass0 vs pass2 species sets) per mode, one line per profiler.
ax = sns.lineplot(data=df_n_species, x='mode', y='jaccard', hue='profiler', marker='o', palette=custom_palette)
ax.legend(loc='center right', frameon=False, bbox_to_anchor=(1.35, 0.5))

The number of detected species tends to go up with the mode, which is expected (with the exception of krakenuniq). For most profilers the increase is linear, but for centrifuge it is exponential, although it may eventually plateau at the maximum number of detectable species.

We see that in general Jaccard indexes are very high. We should consider that this plot is not "strictly" relevant because many, many species are false positives, so we don't really care about the number of falsely assigned species. Still, it is important to note that the pass does not have, grosso modo, a relevant impact. It is interesting to find that for kaiju and centrifuge the mode has an impact. For kaiju the Jaccard index decreases, which might be because the number of spurious species has increased. For centrifuge the Jaccard index may increase because its number of species is so high (12000!!) that we may be reaching the maximum number of discoverable species, and in that case the Jaccard index will obviously increase. For the rest of the profilers the mode does not seem to affect the Jaccard index, regardless of the total number of species detected. 

## Checking statistical differences in the truth table.

We are going to use the truth tables to compute statistically differential capture of species. We are going to use `S=2` and `S=7` throughout the different modes and profilers.

In [32]:
from scipy.stats import chi2_contingency

In [33]:
# Pseudocount added to every contingency-table cell so chi2_contingency never sees a zero
# expected frequency.
SMALL_VAL = 0.1

df_sub = df_nominal_stats[(df_nominal_stats['column'].isin([f'{i}_norm' for i in LIST_PROFILERS] + ['mean_norm'])) & 
                          (df_nominal_stats['S'].isin([0, 1, 2, 3, 4, 5, 6, 7, 10, 15]))].copy()


def _chi2_with_pseudocount(table):
    """Return (chi2, p_value) of `chi2_contingency` on `table` after adding SMALL_VAL to each cell."""
    chi2, p_value, _, _ = chi2_contingency(np.asarray(table, dtype=float) + SMALL_VAL)
    return chi2, p_value


# NOTE(review): pass0 and pass2 are evaluated on the same species set, so the two rows of each
# table are not independent samples; treat these chi2 p-values as a screening heuristic.
chi2_results = []

for (mode, S, column), group in df_sub.groupby(['mode', 'S', 'column']):
    try:
        # First row of each pass. Every value below comes from the current group only, so nothing
        # leaks in from earlier iterations (the old `locals()` check could read stale values).
        rows = {}
        for p in (0, 2):
            data = group[group['pass'] == p]
            if data.empty:
                raise ValueError(f'no rows for pass {p}')
            rows[p] = data.iloc[0]

        full_contingency_table = []     # per pass: [TP, FN, FP, TN]
        partial_contingency_table = []  # per pass: [TP, FP, FN] (TN excluded)
        for p in (0, 2):
            tp, fn, fp, tn = rows[p]['TP|FN|FP|TN']
            full_contingency_table.append([tp, fn, fp, tn])
            partial_contingency_table.append([tp, fp, fn])

        chi2_full, p_value_full = _chi2_with_pseudocount(full_contingency_table)
        chi2_partial, p_value_partial = _chi2_with_pseudocount(partial_contingency_table)

        chi2_results.append({
            'mode': mode,
            'S': S,
            'column': column,
            'chi2_full': chi2_full,
            'p_value_full': p_value_full,
            'chi2_partial': chi2_partial,
            'p_value_partial': p_value_partial,
            # Positive values mean pass2 scores higher than pass0.
            'precision_diff': rows[2]['precision'] - rows[0]['precision'],
            'recall_diff': rows[2]['recall'] - rows[0]['recall'],
            'f1_diff': rows[2]['f1'] - rows[0]['f1'],
            'stats_pass0': full_contingency_table[0],
            'stats_pass2': full_contingency_table[1],
        })
    except Exception as e:
        print(f"Error processing {mode}, {S}, {column}: {e}")

df_pass_chi2_stats = pd.DataFrame(chi2_results)

# Many (mode, S, column) tests are run; add Benjamini-Hochberg adjusted p-values as extra columns.
if not df_pass_chi2_stats.empty:
    for p_col in ('p_value_full', 'p_value_partial'):
        df_pass_chi2_stats[f'{p_col}_fdr'] = multipletests(df_pass_chi2_stats[p_col], method='fdr_bh')[1]


In [None]:
# Significance threshold applied to the raw (uncorrected) chi2 p-values.
alpha = 0.1 

significance_order = ['Both Significant', 'Only Full Significant',
                      'Only Partial Significant', 'Neither Significant']
significance_palette = {
    'Both Significant': '#bc0000',
    'Only Full Significant': '#0000bc',
    'Only Partial Significant': '#00bc00',
    'Neither Significant': '#aaaaaa'
}

# Categorise each test; NaN p-values compare False and end up in 'Neither Significant'.
full_sig = df_pass_chi2_stats['p_value_full'] < alpha
partial_sig = df_pass_chi2_stats['p_value_partial'] < alpha
df_pass_chi2_stats['significance'] = np.select(
    [full_sig & partial_sig, full_sig, partial_sig],
    significance_order[:3],
    default='Neither Significant',
)

# Counts per category, taken from the same column that colours the points.
quadrant_counts = (df_pass_chi2_stats['significance'].value_counts()
                   .reindex(significance_order, fill_value=0)
                   .to_dict())

fig, axs = plt.subplots(1, 2, figsize=(12, 6))

# First plot: chi2_full vs chi2_partial
sns.scatterplot(data=df_pass_chi2_stats, x='chi2_full', y='chi2_partial', ax=axs[0])
axs[0].set_title('Chi2 Full vs Partial')
axs[0].set_xlabel('Chi2 Full')
axs[0].set_ylabel('Chi2 Partial')

# Second plot: p_value_full vs p_value_partial coloured by significance category
sns.scatterplot(
    data=df_pass_chi2_stats,
    x='p_value_full',
    y='p_value_partial',
    hue='significance',
    hue_order=significance_order,
    palette=significance_palette,
    ax=axs[1]
)

# Significance thresholds: full-table significant left of the vertical line,
# partial-table significant below the horizontal one.
axs[1].axhline(alpha, color='black', linestyle='--', linewidth=1)
axs[1].axvline(alpha, color='black', linestyle='--', linewidth=1)

axs[1].set_title('P-Value Comparison with Significance')
axs[1].set_xlabel('P-Value Full')
axs[1].set_ylabel('P-Value Partial')
axs[1].set_xlim(0, 1)
axs[1].set_ylim(0, 1)

# The four regions are not equal quadrants of the [0, 1] square (they are split at alpha),
# so report the per-category counts in the legend rather than as text inside the panel.
handles, labels = axs[1].get_legend_handles_labels()
labels = [f"{lab} ({quadrant_counts[lab]})" if lab in quadrant_counts else lab for lab in labels]
axs[1].legend(handles, labels, loc='center right', frameon=False, bbox_to_anchor=(1.65, 0.5))

fig.tight_layout()
plt.show()

In [None]:
df_pass_chi2_stats.query("significance != 'Neither Significant'").sort_values(by=['mode', 'S'])

### Only using the full table (for publication)

In [None]:
plt.rc('axes', linewidth=0.65)

# True when the full-table chi2 test rejects "pass0 == pass2" at the raw threshold `alpha`.
df_pass_chi2_stats['are_diff'] = df_pass_chi2_stats['p_value_full'] < alpha

# Number of differing vs. non-differing (mode, S, column) combinations
n_diff = int(df_pass_chi2_stats['are_diff'].sum())
print(n_diff, len(df_pass_chi2_stats) - n_diff)

display(df_pass_chi2_stats)
display(df_pass_chi2_stats[df_pass_chi2_stats['are_diff']])

# One panel per metric, sharing the y-axis
fig, axs = plt.subplots(1, 3, figsize=(7.5, 2.7), sharey=True)

metrics = ['f1_diff', 'precision_diff', 'recall_diff']
titles = ['F1 Score', 'Precision', 'Recall']
pairs = [(False, True)]

for ax, metric, title in zip(axs, metrics, titles):
    sns.boxplot(
        data=df_pass_chi2_stats,
        x="are_diff",
        y=metric,
        palette=sns.color_palette(['#4E79A7', '#A0CBE8']),
        ax=ax
    )
    sns.despine(top=True, right=True, ax=ax)

    ax.set_title(title, fontsize=14, pad=40)
    ax.set_xlabel("Group", fontsize=12)
    # Only the first panel carries the y-axis label
    if metric == 'f1_diff':
        ax.set_ylabel("Difference (pass2 - pass0)", fontsize=12)
    else:
        ax.set_ylabel("")

    # Median labels; groupby sorts False < True, matching the boxplot's x order
    medians = df_pass_chi2_stats.groupby("are_diff")[metric].median()
    for pos, median in enumerate(medians):
        ax.text(pos, median, f"{median:.2f}", ha='center', va='bottom', fontsize=10)

    annotator = Annotator(ax, pairs, data=df_pass_chi2_stats, x='are_diff', y=metric)
    annotator.configure(test='Mann-Whitney', text_format='simple', loc='outside', line_width=1)
    annotator.apply_and_annotate()

    ax.set_ylim([-0.8, 1])
    ax.set_yticks([-0.5, 0, 0.5])

fig.tight_layout()

os.makedirs(f'{RESULTS_DIR}/figures/paper', exist_ok=True)
for fmt in ['png', 'tiff']:
    # bbox_inches='tight' keeps the significance bars drawn above the axes (loc='outside')
    fig.savefig(f'{RESULTS_DIR}/figures/paper/figC.{fmt}', dpi=DPI, bbox_inches='tight')

plt.show()



In [None]:
df_pass_chi2_stats.query('are_diff').groupby('S').count()

In [None]:
df_pass_chi2_stats.query('are_diff and S == 0')[['precision_diff', 'recall_diff', 'f1_diff']].median()

In [None]:
df_pass_chi2_stats.query('are_diff and S == 3')[['precision_diff', 'recall_diff', 'f1_diff']].median()

In [None]:
df_pass_chi2_stats.query('are_diff').groupby('column').count()

In [None]:
df_pass_chi2_stats.query("are_diff and column == 'ganon_norm'")[['precision_diff', 'recall_diff', 'f1_diff']].median()

In [None]:
df_pass_chi2_stats.query("are_diff and column == 'mean_norm'")[['precision_diff', 'recall_diff', 'f1_diff']].median()

In [None]:
df_pass_chi2_stats.query('are_diff').groupby('mode').count()

As we see, using a second pass generally does not significantly change the assignment of TP/FP/FN | TN.

In the cases where it changes, the F1-score generally improves, either by increasing the number of TPs (reducing the number of FNs) or by reducing the number of false positives.

Therefore, looking at nominal info **choosing pass2 is the best option**.

## Checking the correlation between the read counts

Now we are going to check the proportion of reads that are correctly assigned to ground-truth species. For that we are going to show: 
- A correlation plot with pass0 and pass2 reads. We are also going to calculate the correlation between pass0 / pass2 and the expected counts.

In [None]:
display(df_numerical_stats)

In [45]:
# Pass 0 and pass 2 statistics side by side, keyed by (mode, profiler).
df_pass_0 = df_numerical_stats[df_numerical_stats['pass'] == 0].set_index(['mode', 'profiler'])
df_pass_2 = df_numerical_stats[df_numerical_stats['pass'] == 2].set_index(['mode', 'profiler'])
aligned_df = pd.concat([df_pass_0, df_pass_2], axis=1, keys=['pass0', 'pass2'])


def _comparison_record(key, row):
    """Flatten one aligned (mode, profiler) row into a pass0-vs-pass2 comparison record."""
    mode_value, profiler_name = key
    stats0, stats2 = row['pass0'], row['pass2']
    return {
        'passn': 'comparison',
        'mode': mode_value,
        'profiler': profiler_name,
        # Metadata shared by both passes is taken from pass 0.
        'taxid_counts': stats0['taxid_counts'],
        'expected_counts': stats0['expected_counts'],
        'pass0_counts': stats0['observed_counts'],
        'pass2_counts': stats2['observed_counts'],
        'pass0_MAE': stats0['MAE_counts'],
        'pass0_diff': stats0['diff_counts'],
        'pass2_MAE': stats2['MAE_counts'],
        'pass2_diff': stats2['diff_counts'],
    }


comparison_df = pd.DataFrame(
    [_comparison_record(key, row) for key, row in aligned_df[['pass0', 'pass2']].iterrows()]
)

In [None]:
plt.rc('axes', linewidth=0.65)

# Mode shown in figure D and the ticks used on every panel.
MODE_FIG_D = 5
COUNT_TICKS = [0, 300000]

mode_df = comparison_df[comparison_df['mode'] == MODE_FIG_D]

profilers = mode_df['profiler'].unique()
num_profilers = len(profilers)

# Three single-row figures (pass0 vs pass2, pass2 vs expected, pass0 vs expected);
# squeeze=False keeps an indexable array even with a single profiler.
subplot_kwargs = dict(figsize=(2 * num_profilers, 2.5), sharex=True, sharey=True,
                      layout='compressed', squeeze=False)
fig1, axs1 = plt.subplots(1, num_profilers, **subplot_kwargs)
fig2, axs2 = plt.subplots(1, num_profilers, **subplot_kwargs)
fig3, axs3 = plt.subplots(1, num_profilers, **subplot_kwargs)
axs1, axs2, axs3 = axs1[0], axs2[0], axs3[0]


def plot_pass0_vs_pass2(ax, data, profiler, i):
    """Scatter pass0 vs pass2 counts from the first row of `data` on `ax`.

    Adds a y = x reference line over the observed range and annotates the least-squares
    slope and R^2 from `linregress`. The panel is titled with `profiler`; the row label is
    only set on the first panel (`i == 0`).
    """
    x = np.asarray(data['pass0_counts'].values[0], dtype=float)
    y = np.asarray(data['pass2_counts'].values[0], dtype=float)

    ax.scatter(x, y, alpha=0.7)

    # y = x reference spanning the data (previously hard-coded to 350000)
    max_val = max(x.max(), y.max())
    ax.plot([0, max_val], [0, max_val], '--', color='#DC267F')

    slope, intercept, r_value, _, _ = linregress(x, y)
    ax.annotate(f"Slope: {slope:.2f}\nR$^2$: {r_value ** 2:.2f}", xy=(0.05, 0.95),
                xycoords='axes fraction', fontsize=12, verticalalignment='top')
    if i == 0:
        ax.set_ylabel("Pass0 vs Pass2", fontsize=12)
    ax.set_title(profiler, fontsize=14)

    sns.despine(top=True, right=True, ax=ax)


def plot_with_regression(ax, x, y, xlabel, i):
    """Scatter `y` against `x` on `ax` with the least-squares fit line and its slope / R^2.

    `xlabel` names the row and is used as the y-axis label of the first panel (`i == 0`).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    ax.scatter(x, y, alpha=0.7)

    slope, intercept, r_value, _, _ = linregress(x, y)
    x_line = np.sort(x)
    ax.plot(x_line, slope * x_line + intercept, '-', color='#DC267F')

    ax.annotate(f"Slope: {slope:.2f}\nR$^2$: {r_value ** 2:.2f}", xy=(0.05, 0.95),
                xycoords='axes fraction', fontsize=12, verticalalignment='top')
    if i == 0:
        ax.set_ylabel(xlabel, fontsize=12)
    sns.despine(top=True, right=True, ax=ax)


for i, profiler in enumerate(profilers):
    profiler_data = mode_df[mode_df['profiler'] == profiler]
    x_pass0 = profiler_data['pass0_counts'].values[0]
    x_expected = profiler_data['expected_counts'].values[0]
    y_pass2 = profiler_data['pass2_counts'].values[0]

    ax1, ax2, ax3 = axs1[i], axs2[i], axs3[i]
    plot_pass0_vs_pass2(ax1, profiler_data, profiler, i)
    plot_with_regression(ax2, x_expected, y_pass2, "Pass2 vs Expected", i)
    plot_with_regression(ax3, x_expected, x_pass0, "Pass0 vs Expected", i)

    for ax in (ax1, ax2, ax3):
        ax.set_xticks(COUNT_TICKS)
        ax.set_yticks(COUNT_TICKS)

for axs in (axs1, axs2, axs3):
    for ax in axs:
        ax.set_xlabel("Counts", fontsize=12)

# Save each figure explicitly: plt.savefig() would save the current figure (fig3) every time.
os.makedirs(f'{RESULTS_DIR}/figures/paper', exist_ok=True)
for i, fig in enumerate([fig1, fig2, fig3], start=1):
    for fmt in ['png', 'tiff']:
        fig.savefig(f'{RESULTS_DIR}/figures/paper/figD{i}.{fmt}', dpi=DPI)

plt.show()



We see that the correlation in number of counts for the ground truth species is almost one. Additionally, if there are differences, they improve the correlation to the expected amount of counts, so it has a positive impact.

## Checking MAE values

We are going to see if the MAE values are equal or higher in pass2 vs pass0. 

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt


# Long format: one row per (mode, profiler, pass) holding that pass's per-species diff list.
comparison_df_exploded = comparison_df.melt(
    id_vars=['mode', 'profiler'],
    value_vars=['pass0_diff', 'pass2_diff'],
    var_name='pass_type',
    value_name='diff_value'
)

# One row per species-level difference; explode leaves object dtype, so convert back to numbers.
comparison_df_exploded = comparison_df_exploded.explode('diff_value')
comparison_df_exploded['diff_value'] = pd.to_numeric(comparison_df_exploded['diff_value'])

profilers = comparison_df_exploded['profiler'].unique()
num_profilers = len(profilers)

# squeeze=False keeps `axs` indexable even when there is a single profiler.
fig, axs = plt.subplots(1, num_profilers, figsize=(6 * num_profilers, 6), sharey=True, squeeze=False)

for i, (ax, profiler) in enumerate(zip(axs[0], profilers)):
    profiler_data = comparison_df_exploded[comparison_df_exploded['profiler'] == profiler]

    # Per-mode distribution of differences, pass0 vs pass2 side by side
    sns.boxplot(
        data=profiler_data,
        x='mode',
        y='diff_value',
        hue='pass_type',
        ax=ax,
        palette='viridis'
    )

    ax.set_title(f'Profiler: {profiler}', fontsize=14)
    ax.set_xlabel('Mode', fontsize=12)
    ax.set_ylabel('Diff Values' if i == 0 else '', fontsize=12)

fig.tight_layout()
plt.show()


We see that, again, the differences in MAE are not due to the pass, with the exception of ganon (and mean by extension), where running the data with pass2 improves the error rate. 

Thus, **both based on nominal and numerical criteria, running with pass2 is the best option**.

# Which mode is best for the data?

We have used several modes to study how it affects the detection of species, and the number of reads it detects. We are going to use the same methods as before to check for the answer.
- We are going to check which mode (using S=2 and S=7) retains the best F1 scores.
- We are going to check which mode keeps the best correlation / MAE with expected counts.

## Checking F1-scores

In [48]:
# Pass-2 rows for representative S thresholds, restricted to the normalised profiler columns.
df_nominal_stats_sub = df_nominal_stats[
    df_nominal_stats['pass'].eq(2)
    & df_nominal_stats['S'].isin([0, 2, 7, 15])
    & df_nominal_stats['column'].isin([f'{name}_norm' for name in LIST_PROFILERS + ['mean']])
]

In [None]:
display(df_nominal_stats_sub)

In [None]:
plt.rc('axes', linewidth=0.65)

metric_order = ['recall', 'precision', 'f1']
metric_labels = ['recall', 'precision', 'F1-score']

melted_df = pd.melt(
    df_nominal_stats_sub,
    id_vars=['mode', 'S', 'column'],
    value_vars=metric_order,
    var_name='metric',
    value_name='score'
)

# Line colour encodes the S threshold, so the colour scale must span the S values.
# It previously spanned the modes, which pushed S=0 and S=15 outside the colormap range.
norm = Normalize(vmin=melted_df['S'].min(), vmax=melted_df['S'].max())
cmap = plt.cm.viridis

column_order = list(melted_df['column'].unique())

# Grid: one row per metric, one column per profiler column
g = sns.FacetGrid(
    melted_df,
    col='column',
    row='metric',
    row_order=metric_order,
    col_order=column_order,
    height=2,
    sharey=True,
    sharex=True
)


def lineplot_with_cmap(data, **kwargs):
    """Draw one score-vs-mode line per S value on the current facet, coloured by S."""
    for S in sorted(data['S'].unique()):
        subset = data[data['S'] == S].sort_values('mode')
        plt.plot(subset['mode'], subset['score'], label=f"S {S}",
                 color=cmap(norm(S)), marker='o')


g.map_dataframe(lineplot_with_cmap)

# Shared legend (one entry per S value), attached to the bottom-right facet
handles = [
    plt.Line2D([0], [0], color=cmap(norm(s_value)), marker='o', linestyle='', label=f"S {s_value}")
    for s_value in sorted(melted_df['S'].unique())
]
g.axes[-1, -1].legend(
    handles=handles,
    title="",
    bbox_to_anchor=(1.05, 3),
    loc='center left',
    frameon=False
)

# x ticks at the modes present in the data
g.set(xticks=df_nominal_stats_sub['mode'].unique())

for ax in g.axes.ravel():
    ax.set_title('')

# Profiler names on the top row, metric names on the first column
for ax, profiler in zip(g.axes[0, :], column_order):
    ax.set_title(profiler.replace('_norm', ''))

for ax, score in zip(g.axes[:, 0], metric_labels):
    ax.set_ylabel(score)

plt.tight_layout()

os.makedirs(f'{RESULTS_DIR}/figures/paper', exist_ok=True)
for fmt in ['png', 'tiff']:
    plt.savefig(f'{RESULTS_DIR}/figures/paper/figF.{fmt}', dpi=DPI, bbox_inches='tight')

plt.show()

Based on the truth table, precision, recall and F1-score depend more on `S` than on the mode, and their patterns are more profiler-dependent than anything else. Therefore, the mode is not critical for proper species detection; and since higher mode values lead to lower precision (more FPs), we could consider using a lower mode if we want to minimize FPs. 

## Checking correlation and MAE values

In [51]:
# Pass-2 numerical statistics for the odd modes only.
df_numerical_stats_sub = df_numerical_stats.loc[
    df_numerical_stats['pass'].eq(2) & df_numerical_stats['mode'].isin([1, 3, 5, 7, 9])
]

In [None]:
# Observed vs expected counts for every (profiler, mode) pair of the pass-2 subset.
profilers = df_numerical_stats_sub['profiler'].unique()
modes = df_numerical_stats_sub['mode'].unique()

# squeeze=False keeps `axs` 2-D even with a single profiler or mode.
fig, axs = plt.subplots(len(profilers), len(modes), figsize=(2 * len(modes), 2 * len(profilers)),
                        sharex=True, sharey=True, squeeze=False)


def plot_regression(ax, data, profiler=None, i=None):
    """Scatter observed vs expected counts of the first row of `data` on `ax`.

    Draws a y = x reference line over the observed range and annotates the least-squares
    slope and R^2 from `linregress`. `profiler` and `i` are unused; they are kept (now
    optional) so existing calls keep working.
    """
    x = np.asarray(data['expected_counts'].values[0], dtype=float)
    y = np.asarray(data['observed_counts'].values[0], dtype=float)

    ax.scatter(x, y, alpha=0.7)

    # y = x reference spanning the data (previously hard-coded to 350000)
    max_val = max(x.max(), y.max())
    ax.plot([0, max_val], [0, max_val], 'r--')

    slope, intercept, r_value, _, _ = linregress(x, y)
    ax.annotate(f"Slope: {slope:.2f}\nR2: {r_value ** 2:.2f}", xy=(0.05, 0.95),
                xycoords='axes fraction', fontsize=10, verticalalignment='top')


for i, profiler in enumerate(profilers):
    for j, mode in enumerate(modes):
        ax = axs[i, j]
        profiler_data = df_numerical_stats_sub[(df_numerical_stats_sub['profiler'] == profiler) &
                                               (df_numerical_stats_sub['mode'] == mode)]

        # Not every profiler necessarily has every mode; mark the gap instead of crashing.
        if profiler_data.empty:
            ax.text(0.5, 0.5, 'no data', transform=ax.transAxes, ha='center', va='center')
        else:
            plot_regression(ax, profiler_data, profiler, i)

        if j == 0:
            ax.set_ylabel(profiler, fontsize=12)
        if i == 0:
            ax.set_title(f'Mode {mode}', fontsize=14)

# Only the bottom row gets x-labels
for ax in axs[-1, :]:
    ax.set_xlabel("Counts", fontsize=12)

fig.tight_layout()
plt.show()



In [None]:
# Long format for the boxplots: one row per (profiler, mode, species-level count difference).
plot_data = pd.DataFrame([
    {'profiler': prof_name, 'mode': mode_value, 'diff_counts': diff}
    for prof_name, mode_value, diffs in zip(
        df_numerical_stats_sub['profiler'],
        df_numerical_stats_sub['mode'],
        df_numerical_stats_sub['diff_counts'],
    )
    for diff in diffs
])

# One facet per profiler, with the per-mode distribution of differences
g = sns.FacetGrid(plot_data, col="profiler", height=4, aspect=1)
g.map_dataframe(sns.boxplot, x="mode", y="diff_counts", palette="viridis")

g.set_axis_labels("Mode", "Difference Counts")
g.set_titles(col_template="{col_name} Profiler")

plt.tight_layout()
plt.show()

In [None]:
from matplotlib.ticker import MaxNLocator

plt.rc('axes', linewidth=0.65)

# One facet per profiler: correlation vs R² of observed/expected counts, coloured by mode.
g = sns.FacetGrid(
    df_numerical_stats_sub, col="profiler",
    height=2.3, aspect=0.85, sharex=False, sharey=False,
)
g.map_dataframe(sns.scatterplot, x="corr_counts", y="R2_counts", hue="mode", palette="viridis")

g.add_legend(title="Mode", bbox_to_anchor=(1.05, 0.5))
g.set_axis_labels("Correlation Counts", "R² Counts")
g.set_titles(col_template="{col_name}")

# At most two x ticks keeps the narrow facets readable
for facet_ax in g.axes.flat:
    facet_ax.xaxis.set_major_locator(MaxNLocator(nbins=2))

plt.tight_layout()

for fmt in ('png', 'tiff'):
    plt.savefig(f'{RESULTS_DIR}/figures/paper/figGA.{fmt}', dpi=DPI, bbox_inches='tight')

plt.show()

In [None]:
# One facet per profiler: MAE vs MAED of the counts, coloured by mode.
g = sns.FacetGrid(
    df_numerical_stats_sub, col="profiler",
    height=2.3, aspect=0.85, sharex=False, sharey=False,
)
g.map_dataframe(sns.scatterplot, x="MAE_counts", y="MAED_counts", hue="mode", palette="viridis")

g.add_legend(title="Mode", bbox_to_anchor=(1.05, 0.5))
g.set_axis_labels("MAE", "MAED")
g.set_titles(col_template="{col_name}")

plt.tight_layout()

for fmt in ('png', 'tiff'):
    plt.savefig(f'{RESULTS_DIR}/figures/paper/figGB.{fmt}', dpi=DPI, bbox_inches='tight')

plt.show()

In [None]:
# One facet per profiler: MRE vs MRED of the counts, coloured by mode.
g = sns.FacetGrid(
    df_numerical_stats_sub, col="profiler",
    height=2.3, aspect=0.85, sharex=False, sharey=False,
)
g.map_dataframe(sns.scatterplot, x="MRE_counts", y="MRED_counts", hue="mode", palette="viridis")

g.add_legend(title="Mode", bbox_to_anchor=(1.05, 0.5))
g.set_axis_labels("MRE", "MRED")
g.set_titles(col_template="{col_name}")

plt.tight_layout()

for fmt in ('png', 'tiff'):
    plt.savefig(f'{RESULTS_DIR}/figures/paper/figGC.{fmt}', dpi=DPI, bbox_inches='tight')

plt.show()

In [None]:
# Preview of the pass-2 numerical statistics that the following difference-from-mean cells start from.
df_numerical_stats[df_numerical_stats['pass'] == 2].head()

In [None]:
# For each error statistic, compare every profiler's pass-2 value against the 'mean' profiler's
# value in the same mode, test whether the profiler is systematically higher, and plot the gaps.
for stat in ['MRE', 'MRED']:
    df = df_numerical_stats[df_numerical_stats['pass'] == 2].copy()

    # Attach the 'mean' profiler's value for the same mode to every row
    df_mean = df[df['profiler'] == 'mean'][['mode', f'{stat}_counts']].rename(columns={f'{stat}_counts': f'mean_{stat}'})
    df = df.merge(df_mean, on='mode')
    df['difference'] = df[f'{stat}_counts'] - df[f'mean_{stat}']

    # Drop the mean profiler itself; sorting fixes the boxplot order used for the median labels
    df_filtered = df[df['profiler'] != 'mean'].sort_values('profiler')

    # Paired one-sided Wilcoxon signed-rank test per profiler (pairs = modes);
    # H1: the profiler's statistic is greater than the mean profiler's.
    results = []
    for profiler in df_filtered['profiler'].unique():
        group_data = df_filtered[df_filtered['profiler'] == profiler]
        profiler_values = group_data[f'{stat}_counts'].values
        mean_values = group_data[f'mean_{stat}'].values
        try:
            statistic, p = wilcoxon(profiler_values, mean_values, alternative='greater')
        except ValueError as e:
            # e.g. every paired difference is zero, so the test is undefined
            print(f'Wilcoxon test skipped for {profiler} ({stat}): {e}')
            statistic, p = np.nan, np.nan
        results.append({'profiler': profiler, 'statistic': statistic, 'p_value': p})

    results_df = pd.DataFrame(results, columns=['profiler', 'statistic', 'p_value'])

    # Benjamini-Hochberg across profilers; undefined (NaN) p-values are left out and stay NaN
    p_raw = results_df['p_value'].to_numpy(dtype=float)
    p_corrected = np.full(len(p_raw), np.nan)
    valid = ~np.isnan(p_raw)
    if valid.any():
        p_corrected[valid] = multipletests(p_raw[valid], method='fdr_bh')[1]
    results_df['p_value_corrected'] = p_corrected

    display(results_df)

    # Distribution of differences from the mean profiler
    fig, ax = plt.subplots(1, 1, figsize=(5, 3))

    sns.boxplot(data=df_filtered, x='profiler', y='difference', palette=custom_palette, ax=ax)
    medians = df_filtered.groupby("profiler")['difference'].median()
    for pos, median in enumerate(medians):
        ax.text(pos, median, f"{median:.2f}", ha='center', va='bottom', fontsize=8)

    ax.axhline(0, linestyle='--', color='#848484', linewidth=0.5)
    ax.set_ylabel(f'{stat} difference from mean')
    ax.set_xlabel('')

    sns.despine(ax=ax, top=True, right=True)

    fig.tight_layout()

    os.makedirs(f'{RESULTS_DIR}/figures/paper', exist_ok=True)
    for fmt in ['png', 'tiff']:
        fig.savefig(f'{RESULTS_DIR}/figures/paper/figJ_{stat}.{fmt}', dpi=DPI, bbox_inches='tight')

    plt.show()

In [None]:
# MAED version of the difference-from-mean analysis: compare every profiler's pass-2 MAED against
# the 'mean' profiler's MAED in the same mode and test whether the profiler is systematically higher.
stat = 'MAED'

df = df_numerical_stats[df_numerical_stats['pass'] == 2].copy()

# Attach the 'mean' profiler's value for the same mode to every row
df_mean = df[df['profiler'] == 'mean'][['mode', f'{stat}_counts']].rename(columns={f'{stat}_counts': f'mean_{stat}'})
df = df.merge(df_mean, on='mode')
df['difference'] = df[f'{stat}_counts'] - df[f'mean_{stat}']

# Drop the mean profiler itself; sorting fixes the boxplot order used for the median labels
df_filtered = df[df['profiler'] != 'mean'].sort_values('profiler')

# Paired one-sided Wilcoxon signed-rank test per profiler (pairs = modes);
# H1: the profiler's MAED is greater than the mean profiler's.
results = []
for profiler in df_filtered['profiler'].unique():
    group_data = df_filtered[df_filtered['profiler'] == profiler]
    profiler_values = group_data[f'{stat}_counts'].values
    mean_values = group_data[f'mean_{stat}'].values
    try:
        statistic, p = wilcoxon(profiler_values, mean_values, alternative='greater')
    except ValueError as e:
        # e.g. every paired difference is zero, so the test is undefined
        print(f'Wilcoxon test skipped for {profiler} ({stat}): {e}')
        statistic, p = np.nan, np.nan
    results.append({'profiler': profiler, 'statistic': statistic, 'p_value': p})

results_df = pd.DataFrame(results, columns=['profiler', 'statistic', 'p_value'])

# Benjamini-Hochberg across profilers; undefined (NaN) p-values are left out and stay NaN
p_raw = results_df['p_value'].to_numpy(dtype=float)
p_corrected = np.full(len(p_raw), np.nan)
valid = ~np.isnan(p_raw)
if valid.any():
    p_corrected[valid] = multipletests(p_raw[valid], method='fdr_bh')[1]
results_df['p_value_corrected'] = p_corrected

display(results_df)

# Distribution of differences from the mean profiler
fig, ax = plt.subplots(1, 1, figsize=(5, 3))

sns.boxplot(data=df_filtered, x='profiler', y='difference', palette=custom_palette, ax=ax)
medians = df_filtered.groupby("profiler")['difference'].median()
for pos, median in enumerate(medians):
    ax.text(pos, median, f"{median:.2f}", ha='center', va='bottom', fontsize=8)

ax.axhline(0, linestyle='--', color='#848484', linewidth=0.5)
ax.set_ylabel('MAED difference from mean')
ax.set_xlabel('')

sns.despine(ax=ax, top=True, right=True)

fig.tight_layout()

os.makedirs(f'{RESULTS_DIR}/figures/paper', exist_ok=True)
for fmt in ['png', 'tiff']:
    fig.savefig(f'{RESULTS_DIR}/figures/paper/figJ_MAED.{fmt}', dpi=DPI, bbox_inches='tight')

plt.show()

In [None]:
df_numerical_stats_sub.query('mode in [3, 7]')

In [61]:
def plot_prof_couns(df_combo, mode=None):
    """Plot the per-species "PRE" values (``diff_counts``) of each profiler.

    Every row of ``df_combo`` is one profiler whose ``diff_counts`` cell holds a
    list-like of per-species values. The values are drawn as a strip plot, and
    each profiler gets an overlay: a marker at the median and a horizontal bar
    spanning median +/- standard deviation. The figure is saved as PNG and TIFF
    to ``{RESULTS_DIR}/figures/paper/figH{mode}.*`` and then shown.

    Parameters
    ----------
    df_combo : pandas.DataFrame
        Must contain ``profiler`` and ``diff_counts`` columns.
    mode : optional
        Value used in the output file name. If None (default), it is inferred
        from ``df_combo['mode']``, which must then hold exactly one unique value.

    Raises
    ------
    ValueError
        If ``mode`` is None and ``df_combo['mode']`` is not single-valued.
    """
    # Resolve the mode from the data rather than from a global variable.
    if mode is None:
        modes = df_combo['mode'].unique()
        if len(modes) != 1:
            raise ValueError(
                f"Cannot infer a single mode from df_combo (found {list(modes)}); "
                "pass mode= explicitly."
            )
        mode = modes[0]

    # One row per (profiler, species value) for the strip plot.
    df_expanded = df_combo.explode('diff_counts')
    df_expanded['diff_counts'] = pd.to_numeric(df_expanded['diff_counts'])

    fig, ax = plt.subplots(figsize=(10, 2.5))

    # Per-profiler summary: median marker plus a median +/- std bar.
    for profiler, values in zip(df_combo['profiler'], df_combo['diff_counts']):
        center = np.median(values)
        spread = np.std(values)
        ax.scatter(center, profiler, color='#DC267F', label='_nolegend_', s=100, zorder=3, marker='|')
        ax.plot([center - spread, center + spread], [profiler, profiler], color='#DC267F', label='_nolegend_')

    sns.stripplot(
        data=df_expanded,
        x='diff_counts',
        y='profiler',
        jitter=True,   # spread overlapping points vertically
        size=5,
        alpha=0.7,
        color="#648FFF",
        ax=ax,
    )

    # Zero reference line.
    ax.axvline(0, color='#848484', linestyle='--', linewidth=1)

    sns.despine(ax=ax, top=True, right=True)

    ax.set_xlabel('PRE', fontsize=12)
    ax.set_ylabel('', fontsize=12)
    ax.grid(True, axis='x', linestyle='--', alpha=0.6)
    fig.tight_layout()

    for fmt in ['png', 'tiff']:
        fig.savefig(f'{RESULTS_DIR}/figures/paper/figH{mode}.{fmt}', dpi=DPI, bbox_inches='tight')

    plt.show()

In [None]:
# Draw the per-profiler PRE strip plot for pass-2 results at several modes.
# No filter on S is applied here. The loop variable must stay named `mode`,
# because plot_prof_couns may read the global `mode` to name its output file.
for mode in (3, 5, 7):
    print(mode)
    df_combo = df_numerical_stats.loc[
        (df_numerical_stats['pass'] == 2) & (df_numerical_stats['mode'] == mode)
    ]
    plot_prof_couns(df_combo)


These results show that a higher mode increases the correlation (slope) and R2 and decreases the MAE, although it increases the MAED. Again, however, the effect is strongly profiler-dependent. It is also worth noting that some species have a very poor detection rate across all profiler/mode/S values.

Higher modes therefore (1) lower the MAE and increase the correlation, but (2) decrease the precision. Because the effect on (1) is more pronounced than the effect on (2), we are going to choose a high, but not extreme, mode value: for instance, `mode = 7`.

# Which species have the worst assignment?

In [63]:
# Kingdom of every species in the in silico mock community, keyed by species name.
# Built from one list per kingdom; insertion order is Bacteria, Fungi, Virus.
species_to_kingdom = {
    species: kingdom
    for kingdom, members in (
        ("Bacteria", [
            "Cutibacterium acnes",
            "Lactobacillus acidophilus",
            "Bifidobacterium bifidum",
            "Akkermansia muciniphila",
            "Blautia coccoides",
            "Blautia luti",
            "Bacteroides ovatus",
            "Bacteroides intestinalis",
            "Bacteroides fragilis",
            "Escherichia coli",
            "Dietzia lutea",
            "Ruthenibacterium lactatiformans",
            "Faecalibacterium prausnitzii",
            "Parabacteroides distasonis",
            "Parabacteroides merdae",
            "Fusicatenibacter saccharivorans",
            "Erysipelatoclostridium ramosum",
            "Streptococcus salivarius",
            "Hungatella hathewayi",
            "Eisenbergiella porci",
            "Butyricimonas faecalis",
            "Alistipes indistinctus",
            "Alistipes finegoldii",
            "Eubacterium callanderi",
            "Acidaminococcus intestini",
        ]),
        ("Fungi", [
            "Aspergillus chevalieri",
            "Aspergillus flavus",
            "Saccharomyces cerevisiae",
            "Saccharomyces kudriavzevii",
            "Saccharomyces mikatae",
            "Candida albicans",
            "Candida dubliniensis",
            "Candida orthopsilosis",
            "Malassezia restricta",
            "Alternaria dauci",
            "Kazachstania africana",
            "Penicillium digitatum",
            "Pichia kudriavzevii",
            "Trichoderma asperellum",
            "Akanthomyces muscarius",
            "Fusarium falciforme",
            "Eremothecium sinecaudum",
            "Cryptococcus decagattii",
            "Kwoniella shandongensis",
            "Puccinia triticina",
        ]),
        ("Virus", [
            "Tobacco mosaic virus",
            "Rotavirus A",
            "Rotavirus B",
            "Rotavirus C",
            "Bacteriophage P2",
            "Escherichia phage T4",
            "Human immunodeficiency virus 1",
            "Human adenovirus 7",
            "Hepatitis C virus",
            "Bovine alphaherpesvirus 2",
            "Human herpesvirus 4 type 2",
            "Mimivirus terra2",
            "Dengue virus",
            "Norovirus GI",
            "Zaire ebolavirus",
        ]),
    )
    for species in members
}


In [64]:
# Mode used for the per-species analysis below.
# NOTE(review): the markdown above suggests mode = 7; confirm which value is intended.
mode = 5

In [65]:
# Pass-2 statistics for the selected mode, indexed by profiler name
# (presumably one row per profiler; verify against df_numerical_stats).
df_combo = (
    df_numerical_stats
    .loc[(df_numerical_stats['pass'] == 2) & (df_numerical_stats['mode'] == mode)]
    .set_index('profiler')
)

In [None]:
# Display the pass-2 statistics selected for the current mode.
df_combo

In [67]:
# Run the section below with prof = 'mean' for modes 1, 3, 5, 7 and 9,
# and with mode = 5 for each of the remaining profilers.

In [None]:
# Profiler whose per-species expected vs. observed counts are analysed below.
prof = 'mean'

# Species that share their genus with at least one other species in the mock community.
family_species = ['Blautia coccoides', 'Blautia luti', 'Bacteroides ovatus', 'Bacteroides intestinalis',
                  'Bacteroides fragilis', 'Parabacteroides distasonis', 'Parabacteroides merdae',
                  'Alistipes indistinctus', 'Alistipes finegoldii', 'Aspergillus chevalieri',
                  'Aspergillus flavus', 'Saccharomyces cerevisiae', 'Saccharomyces kudriavzevii',
                  'Saccharomyces mikatae', 'Candida albicans', 'Candida dubliniensis',
                  'Candida orthopsilosis', 'Rotavirus A', 'Rotavirus B', 'Rotavirus C']

# One row per species for the selected profiler.
df_diffab = pd.DataFrame({
    'species': df_combo.loc[prof, 'species_counts'],
    'expected_counts': df_combo.loc[prof, 'expected_counts'],
    'observed_counts': df_combo.loc[prof, 'observed_counts'],
    'diff_abundance': df_combo.loc[prof, 'diff_counts'],
})
# A direct dict lookup raises KeyError for any species missing from species_to_kingdom.
df_diffab['kingdom'] = [species_to_kingdom[name] for name in df_diffab['species']]
df_diffab['isfamily'] = df_diffab['species'].isin(family_species)

df_diffab.sort_values(by=['expected_counts', 'diff_abundance'], ascending=False)

In [None]:
# Create a grid of three axes for the plots
fig, axs = plt.subplots(1, 3, figsize=(9, 3.5), sharey=True)

pairs=[(True, False)]


ax1 = sns.boxplot(df_diffab, x='isfamily', y='diff_abundance', ax=axs[0], palette=sns.color_palette(['#4E79A7', '#A0CBE8']))
medians = df_diffab.groupby("isfamily")['diff_abundance'].median()
for i, median in enumerate(medians):
        ax1.text(i, median - 8, f"{median:.2f}", ha='center', va='bottom', fontsize=8)

axs[0].set_ylabel("PRE", fontsize=12)
annotator = Annotator(ax1, pairs, data=df_diffab, x='isfamily', y='diff_abundance')
annotator.configure(test='Mann-Whitney', text_format='simple', loc='outside', line_width=1)
annotator.apply_and_annotate()
axs[0].set_xlabel("Species belongs to\na shared genus", fontsize=12)



ax2 = sns.boxplot(df_diffab, x='kingdom', y='diff_abundance', ax=axs[1], palette=sns.color_palette(['#4E79A7', '#77a2c8', '#A0CBE8']))
pairs=[("Virus", 'Bacteria'), ('Fungi', 'Bacteria'), ('Fungi', 'Virus')]
annotator = Annotator(ax2, pairs, data=df_diffab, x='kingdom', y='diff_abundance')
annotator.configure(test='Mann-Whitney', text_format='simple', loc='outside', line_width=1)
annotator.apply_and_annotate()
medians = df_diffab.groupby("kingdom")['diff_abundance'].median()
for i, median in enumerate(medians):
        ax2.text(i, median - 8, f"{median:.2f}", ha='center', va='bottom', fontsize=8)




df_bac = df_diffab[df_diffab['kingdom'] == 'Bacteria']

ax3 = sns.boxplot(df_bac, x='isfamily', y='diff_abundance', ax=axs[2], palette=sns.color_palette(['#4E79A7', '#A0CBE8']))
pairs=[(True, False)]
annotator = Annotator(ax3, pairs, data=df_bac, x='isfamily', y='diff_abundance')
annotator.configure(test='Mann-Whitney', text_format='simple', loc='outside', line_width=1)
annotator.apply_and_annotate()
axs[2].set_xlabel("Bacterial species belongs\nto a shared genus", fontsize=12)
medians = df_bac.groupby("isfamily")['diff_abundance'].median()
for i, median in enumerate(medians):
        ax3.text(i, median - 8, f"{median:.2f}", ha='center', va='bottom', fontsize=8)


for ax in axs:
    sns.despine(top=True, right=True, ax=ax)


plt.tight_layout()

for format in ['png', 'tiff']: 
    plt.savefig(f'{RESULTS_DIR}/figures/paper/figI-{mode}-{prof}.{format}', dpi=DPI, bbox_inches='tight')

plt.show()
