In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
from datetime import datetime

In [None]:
# Show floats with three decimals in every DataFrame rendered below.
pd.set_option('display.float_format', '{:.3f}'.format)

In [None]:
# Root of the pipeline outputs, relative to this notebook's directory.
RESULTS_DIR = '../../results'

# Taxonomic profilers evaluated throughout the notebook.
LIST_PROFILERS = [
    'centrifuge',
    'ganon',
    'kaiju',
    'kmcp',
    'kraken2',
    'krakenuniq',
]

# In silico sample analysis

In this notebook we analyse the *in silico* samples and study several variables.
First, we study how performing a host-mapping step affects the quality of taxonomic profiling, for both human and non-human species.

Additionally, we study:
1) The modes of the profilers: how each profiler performs in each mode, since the modes are more or less sensitive.
2) The effect of the flagging system on the performance statistics.
3) The performance at genus versus species level.

To evaluate the performance of the profilers we are going to use several measures:
1) Sensitivity, specificity and F1-score of each mode for each profiler (before and after flagging)
2) AUC using the different modes for each profiler 

## How many reads are incorrectly mapped if we do not perform a host mapping step?

It has been reported that not mapping to human databases before profiling increases the number of reads assigned to other organisms. 

Here we run three checks on the *in silico* dataset, comparing pass2 (profiling after two rounds of host mapping) with pass0 (direct profiling without host mapping), and we check how the results depend on the sensitivity of each profiler mode:
1) We count the total number of reads mapped to the human reference, and how many reads that should have been mapped to human were left unmapped.
    - We do the same for the microbial reads, to see whether more microbial reads are assigned in pass0.
2) We compare the number of species detected in pass0 and pass2, and compute their Jaccard index.
3) We are going to calculate the ratio between the number of reads in pass0 and pass2.

In [None]:
# Host-mapping read counts, one row per sample (indexed by the SAMPLE column).
df_host_map_info = (
    pd.read_csv(f'{RESULTS_DIR}/counts/mapping_counts.txt', sep='\t')
    .set_index('SAMPLE')
)

In [None]:
# Ground-truth composition of the ARTIFICIAL sample: species, taxid and simulated read count.
artificial_taxid_counts = pd.read_csv(
    'table_artificial_taxid.csv',
    sep=';',
    names=['species', 'taxid', 'reads'],
)
# Halve the counts (truncated to int).
# NOTE(review): presumably `reads` counts both mates of each read pair — confirm.
artificial_taxid_counts['reads_true'] = artificial_taxid_counts['reads'].div(2).astype(int)

# NOTE(review): assumes the first row of the table is the human (host) entry — confirm.
n_true_human_reads = int(artificial_taxid_counts.loc[:, 'reads_true'].iloc[0])
n_true_human_reads

In [None]:
# Reads removed as human by the two host-mapping steps, before any profiling.
n_mapped_reads_1and2_maps = (
    df_host_map_info.loc['ARTIFICIAL', '1st_mapped']
    + df_host_map_info.loc['ARTIFICIAL', '2nd_mapped']
)
# Share (in %) of the ground-truth human reads that host mapping recovered.
pct_human_reads_mapped = 100 * n_mapped_reads_1and2_maps / n_true_human_reads

print(f'There is a total of {n_mapped_reads_1and2_maps} reads mapped to human during the 1st and 2nd map, which represents around {pct_human_reads_mapped:.2f} % of the true human reads.')
print(f'There is a total of {n_true_human_reads - n_mapped_reads_1and2_maps} reads remaining to be mapped.')

In [None]:
df_host_profile_info = pd.read_csv(f'{RESULTS_DIR}/counts/profiling_counts.txt', sep='\t')
# Explicit copy: derived columns are added below, and assigning into a boolean-filtered
# frame triggers pandas' SettingWithCopyWarning (chained assignment).
df_host_profile_info_artificial = df_host_profile_info[df_host_profile_info['SAMPLE'] == 'ARTIFICIAL'].copy()

# Reads already removed as human by the two host-mapping steps only apply to pass 2
# (pass 0 profiles the reads without host mapping), so they are added to pass-2 rows only.
df_host_profile_info_artificial['mapped_human_1_2_maps'] = 0
df_host_profile_info_artificial.loc[df_host_profile_info_artificial['pass'] == 2, 'mapped_human_1_2_maps'] = n_mapped_reads_1and2_maps

df_host_profile_info_artificial['mapped_human_total'] = df_host_profile_info_artificial['mapped_human_1_2_maps'] + df_host_profile_info_artificial['mapped_human']
df_host_profile_info_artificial['total_reads'] = df_host_profile_info_artificial['mapped_human_total'] + df_host_profile_info_artificial['mapped_others'] + df_host_profile_info_artificial['unmapped']

# Observed fractions of all reads: human, other organisms, unassigned.
df_host_profile_info_artificial['observed_human_prop'] = df_host_profile_info_artificial['mapped_human_total'] / df_host_profile_info_artificial['total_reads']
df_host_profile_info_artificial['observed_others_prop'] = df_host_profile_info_artificial['mapped_others'] / df_host_profile_info_artificial['total_reads']
df_host_profile_info_artificial['observed_unmapped_prop'] = df_host_profile_info_artificial['unmapped'] / df_host_profile_info_artificial['total_reads']

# Ground-truth human fraction of the simulated sample (noted as ~0.8 for this design).
df_host_profile_info_artificial['expected_human_prop'] = n_true_human_reads / artificial_taxid_counts['reads_true'].sum()
# Human fraction that was not assigned to human, and the remaining unassigned fraction attributed to other organisms.
df_host_profile_info_artificial['calculated_unmapped_human_prop'] = df_host_profile_info_artificial['expected_human_prop'] - df_host_profile_info_artificial['observed_human_prop']
df_host_profile_info_artificial['calculated_unmapped_others_prop'] = df_host_profile_info_artificial['observed_unmapped_prop'] - df_host_profile_info_artificial['calculated_unmapped_human_prop']

df_host_profile_info_artificial[df_host_profile_info_artificial['profiler'] == 'centrifuge']


In [None]:
# 1A) Check if there are differences in human read assignment.

# Step 1: one row per (profiler, mode) with the pass-0 and pass-2 human proportions as
# columns, then their difference expressed in percentage points.
pass_diff = df_host_profile_info_artificial.pivot_table(
    index=["profiler", "mode"], columns="pass", values="observed_human_prop"
).reset_index()
pass_diff["difference"] = (pass_diff[2] - pass_diff[0]) * 100

# Step 2: one line per profiler across modes; the dashed line marks "no difference".
fig, ax = plt.subplots(figsize=(6, 4))
sns.lineplot(
    data=pass_diff,
    x="mode",
    y="difference",
    hue="profiler",
    marker="o",
    palette="tab10",
    ax=ax,
)
ax.axhline(0, color="gray", linestyle="--", linewidth=0.8)

ax.set_title("Difference in Observed Human Proportion (Pass 2 - Pass 0)", fontsize=14)
ax.set_xlabel("Mode", fontsize=12)
ax.set_ylabel("Diff (%)", fontsize=12)
ax.grid(alpha=0.3)
ax.legend(title="Profiler", bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)

fig.tight_layout()
plt.show()

In [None]:
# 1B) Check if there are differences in non-human (other organisms) read assignment.

# Step 1: one row per (profiler, mode) with the pass-0 and pass-2 non-human proportions
# as columns, then their difference in percentage points (positive = more in pass 2).
pass_diff = (
    df_host_profile_info_artificial.pivot_table(
        index=["profiler", "mode"], columns="pass", values="observed_others_prop"
    )
    .reset_index()
)
pass_diff["difference"] = 100 * (pass_diff[2] - pass_diff[0])

# Step 2: Plot the differences using a lineplot (one line per profiler).
fig, ax = plt.subplots(figsize=(6, 4))
sns.lineplot(
    data=pass_diff,
    x="mode",
    y="difference",
    hue="profiler",
    marker="o",
    palette="tab10",
    ax=ax,
)

ax.axhline(0, color="gray", linestyle="--", linewidth=0.8)
# This cell plots `observed_others_prop`, so the title must say non-human.
ax.set_title("Difference in Observed Non-Human Proportion (Pass 2 - Pass 0)", fontsize=14)
ax.set_xlabel("Mode", fontsize=12)
ax.set_ylabel("Diff (%)", fontsize=12)

ax.grid(alpha=0.3)

ax.legend(title="Profiler", bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
fig.tight_layout()

plt.show()

In [None]:
# 2) Number of species and Jaccard index between pass 0 and pass 2
#
# For this part we read all the .standardised.species reports, count the species in each
# one, and build a table with the number of species per pass and the Jaccard index.


def get_n_species(sample, mode, profiler, results_dir=None):
    """Compare the species reported by `profiler` in pass 0 and pass 2 for one sample/mode.

    Parameters
    ----------
    sample : str
        Sample name used in the report directory and file names (e.g. 'ARTIFICIAL').
    mode : int
        Profiler mode number, as used in the ``{sample}_mode{mode}`` directory name.
    profiler : str
        One of 'centrifuge', 'ganon', 'kaiju', 'kmcp', 'kraken2', 'krakenuniq'.
    results_dir : str, optional
        Root results directory; defaults to the notebook-level ``RESULTS_DIR``.

    Returns
    -------
    tuple of (int, int, float)
        Number of rows in the pass-0 report, number of rows in the pass-2 report, and the
        Jaccard index of their ``taxonomy_id`` values (NaN when both reports are empty).

    Raises
    ------
    ValueError
        If `profiler` has no known report file prefix.
    FileNotFoundError
        If either report file does not exist.
    """
    # File-name infix of each profiler's standardised species report.
    report_prefixes = {
        'kaiju': 'results',
        'krakenuniq': 'report',
        'kraken2': 'report',
        'ganon': 'report',
        'centrifuge': 'kreport',
        'kmcp': 'profile',
    }
    if profiler not in report_prefixes:
        raise ValueError(f'Unknown profiler {profiler!r}; expected one of {sorted(report_prefixes)}')
    prefix = report_prefixes[profiler]
    if results_dir is None:
        results_dir = RESULTS_DIR

    def _read_report(pass_number):
        """Read the standardised species report of one host-mapping pass."""
        path = (f'{results_dir}/profiling/{profiler}/pass{pass_number}/'
                f'{sample}_mode{mode}/{sample}_mode{mode}.{prefix}.standardised.species')
        return pd.read_csv(path, sep='\t')

    pass0_df = _read_report(0)
    pass2_df = _read_report(2)

    pass0_ids = pass0_df['taxonomy_id'].values
    pass2_ids = pass2_df['taxonomy_id'].values
    n_union = len(np.union1d(pass0_ids, pass2_ids))
    # The Jaccard index is undefined for two empty sets: report NaN instead of dividing by zero.
    jaccard_index = len(np.intersect1d(pass0_ids, pass2_ids)) / n_union if n_union else float('nan')

    return len(pass0_df), len(pass2_df), jaccard_index


# Species counts and Jaccard index for every profiler/mode combination that has reports.
dict_n_species = {'profiler': [],
                  'mode': [],
                  'pass 0 species': [],
                  'pass 2 species': [],
                  'jaccard': []}

for profiler in LIST_PROFILERS:
    for mode in range(1, 10):
        try:
            pass0_n_species, pass2_n_species, jaccard_index = get_n_species('ARTIFICIAL', mode, profiler)
        except Exception as err:
            # Not every profiler/mode combination has reports: report the reason and skip it.
            # (Catching Exception rather than a bare except keeps Ctrl-C / SystemExit working.)
            print(f'No entry added for profiler {profiler} and mode {mode}: {err}')
            continue
        dict_n_species['profiler'].append(profiler)
        dict_n_species['mode'].append(mode)
        dict_n_species['pass 0 species'].append(pass0_n_species)
        dict_n_species['pass 2 species'].append(pass2_n_species)
        dict_n_species['jaccard'].append(jaccard_index)

df_n_species = pd.DataFrame(dict_n_species)


In [None]:
# Species counts per pass (left axis) and pass0/pass2 Jaccard index (right axis), one panel per profiler.
# Grid of 3 columns with as many rows as needed (2 rows x 3 columns for 6 profilers).
n_cols = 3
n_rows = -(-len(LIST_PROFILERS) // n_cols)  # ceiling division

# Set the Seaborn theme before creating the figure so the new axes pick up its style.
sns.set_theme(style="white")

fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 5 * n_rows), squeeze=False)
axes = axes.flatten()

# Custom handles for a single shared legend (each panel mixes two y-axes).
legend_elements = [
    plt.Line2D([0], [0], marker='o', color='blue', label='Pass 0 Species', markersize=8, linestyle='None'),
    plt.Line2D([0], [0], marker='o', color='green', label='Pass 2 Species', markersize=8, linestyle='None'),
    plt.Line2D([0], [0], marker='o', color='red', label='Jaccard Index', markersize=8, linestyle='-')
]

for i, profiler in enumerate(LIST_PROFILERS):
    df_profiler = df_n_species[df_n_species['profiler'] == profiler]
    ax1 = axes[i]

    # Species counts of both passes on the primary y-axis.
    sns.scatterplot(
        data=df_profiler, x='mode', y='pass 0 species', ax=ax1, color='blue', s=50
    )
    sns.scatterplot(
        data=df_profiler, x='mode', y='pass 2 species', ax=ax1, color='green', s=50
    )

    ax1.set_xlabel('Mode')
    ax1.set_ylabel('Species Count', color='black')
    # Optional: ax1.set_yscale('log') when the counts span several orders of magnitude.
    ax1.set_title(f'{profiler.capitalize()}')

    # Jaccard index on a secondary y-axis.
    ax2 = ax1.twinx()
    sns.lineplot(
        data=df_profiler, x='mode', y='jaccard', ax=ax2, color='red', marker='o'
    )
    ax2.set_ylabel('Jaccard Index', color='red')
    # The Jaccard index lies in [0, 1]; a fixed range keeps panels comparable across profilers.
    ax2.set_ylim(0, 1.05)

# Hide panels left over when the profiler count does not fill the grid.
for ax in axes[len(LIST_PROFILERS):]:
    ax.axis('off')

fig.legend(
    handles=legend_elements, loc='center right', frameon=True, title="Legend", bbox_to_anchor=(1.15, 0.5)
)

plt.tight_layout()
plt.show()

In [None]:
# 3) Ratio of per-species read counts between pass 0 and pass 2


def get_species_details(sample, mode, profiler, results_dir=None):
    """Per-species read counts of pass 0 vs pass 2 for one sample/mode/profiler.

    Species are matched by the report's ``name`` column; only species present in both
    passes are kept (inner join) and 'Homo sapiens' is removed.

    Parameters
    ----------
    sample : str
        Sample name used in the report directory and file names (e.g. 'ARTIFICIAL').
    mode : int
        Profiler mode number, as used in the ``{sample}_mode{mode}`` directory name.
    profiler : str
        One of 'centrifuge', 'ganon', 'kaiju', 'kmcp', 'kraken2', 'krakenuniq'.
    results_dir : str, optional
        Root results directory; defaults to the notebook-level ``RESULTS_DIR``.

    Returns
    -------
    pandas.DataFrame
        Columns ``profiler, mode, name, pass0, pass2, ratio, mean`` where ``ratio`` is
        pass0 / pass2 (inf when the pass-2 count is 0) and ``mean`` the mean of both counts.

    Raises
    ------
    ValueError
        If `profiler` has no known report file prefix.
    FileNotFoundError
        If either report file does not exist.
    """
    # File-name infix of each profiler's standardised species report.
    report_prefixes = {
        'kaiju': 'results',
        'krakenuniq': 'report',
        'kraken2': 'report',
        'ganon': 'report',
        'centrifuge': 'kreport',
        'kmcp': 'profile',
    }
    if profiler not in report_prefixes:
        raise ValueError(f'Unknown profiler {profiler!r}; expected one of {sorted(report_prefixes)}')
    prefix = report_prefixes[profiler]
    if results_dir is None:
        results_dir = RESULTS_DIR

    def _read_counts(pass_number):
        """Return the (name, count) columns of one pass, with count renamed to 'pass<N>'."""
        path = (f'{results_dir}/profiling/{profiler}/pass{pass_number}/'
                f'{sample}_mode{mode}/{sample}_mode{mode}.{prefix}.standardised.species')
        report = pd.read_csv(path, sep='\t')
        return report[['name', 'count']].rename(columns={'count': f'pass{pass_number}'})

    # Align both passes on the species name; species seen in only one pass are dropped.
    # NOTE(review): duplicated names within a report would multiply rows here — confirm names are unique.
    merged_df = pd.merge(_read_counts(0), _read_counts(2), on='name', how='inner')

    # Remove the host, then derive the columns. assign() builds a new frame, so there is
    # no SettingWithCopyWarning from writing into the filtered slice.
    merged_df = merged_df[merged_df['name'] != 'Homo sapiens'].assign(
        ratio=lambda d: d['pass0'] / d['pass2'],
        mean=lambda d: d[['pass0', 'pass2']].mean(axis=1),
        profiler=profiler,
        mode=mode,
    )

    return merged_df[['profiler', 'mode', 'name', 'pass0', 'pass2', 'ratio', 'mean']]

# Per-species pass-0 vs pass-2 comparison for every profiler/mode combination with reports.
detailed_data = []

for profiler in LIST_PROFILERS:
    for mode in range(1, 10):
        try:
            df_details = get_species_details('ARTIFICIAL', mode, profiler)
            detailed_data.append(df_details)
        except Exception as e:
            # Not every profiler/mode combination has reports: report the reason and skip it.
            print(f'No entry for profiler {profiler} and mode {mode}: {e}')

# pd.concat raises on an empty list, so fall back to an empty frame with the expected columns.
if detailed_data:
    df_detailed = pd.concat(detailed_data, ignore_index=True)
else:
    df_detailed = pd.DataFrame(columns=['profiler', 'mode', 'name', 'pass0', 'pass2', 'ratio', 'mean'])

In [None]:
# log10(mean count) vs log10(pass0 / pass2) per species:
# one row of panels per profiler, one column per selected mode.
list_modes = [1, 5, 9]

# Set the Seaborn theme before creating the figure so the new axes pick up its style.
sns.set_theme(style="white")

fig, axes = plt.subplots(
    len(LIST_PROFILERS), len(list_modes),
    figsize=(3 * len(list_modes), 3 * len(LIST_PROFILERS)),
    squeeze=False,  # always a 2-D array, even with a single profiler or mode
)

for i, profiler in enumerate(LIST_PROFILERS):
    df_profiler = df_detailed[(df_detailed['profiler'] == profiler) & (df_detailed['mode'].isin(list_modes))]

    for j, mode in enumerate(list_modes):
        ax = axes[i, j]
        df_mode = df_profiler[df_profiler['mode'] == mode]

        # assign() returns a new frame (no SettingWithCopyWarning on the filtered slice).
        # Zero counts give log10(0) = -inf: silence that warning and drop non-finite points.
        with np.errstate(divide='ignore', invalid='ignore'):
            df_mode = df_mode.assign(
                logmean=np.log10(df_mode['mean'].astype(float)),
                logratio=np.log10(df_mode['ratio'].astype(float)),
            )
        df_mode = df_mode[np.isfinite(df_mode['logmean']) & np.isfinite(df_mode['logratio'])]

        if df_mode.empty:
            ax.text(0.5, 0.5, 'no data', ha='center', va='center', transform=ax.transAxes)
        else:
            # No hue here, so no palette: seaborn would ignore it and warn.
            sns.scatterplot(
                data=df_mode, x='logmean', y='logratio', ax=ax, s=50, edgecolor='black'
            )
            # log10 ratio of 0 means identical counts in both passes.
            ax.axhline(0, color='gray', linestyle='--', linewidth=0.8)

        ax.set_xlabel('log$_{10}$ mean')
        ax.set_ylabel('log$_{10}$ ratio')
        ax.set_title(f'{profiler.capitalize()} - Mode {mode}')

plt.tight_layout()
plt.show()

In [None]:
# For each profiler and mode, show the 10 species with the largest score
# log10(ratio)^2 * log10(mean): large pass0/pass2 disagreements weighted by abundance.
# NOTE(review): modes here are (1, 5, 8) while the scatter grid uses (1, 5, 9) — confirm which is intended.
for profiler in LIST_PROFILERS:
    for mode in (1, 5, 8):
        print(f'PROFILER {profiler} | MODE {mode}')
        is_selected = (df_detailed['profiler'] == profiler) & (df_detailed['mode'] == mode)
        df_sub = df_detailed.loc[is_selected].copy()
        log_ratio = np.log10(df_sub['ratio'].values)
        log_mean = np.log10(df_sub['mean'].values)
        val = log_ratio ** 2 * log_mean
        df_sub['xoxo'] = val
        top10 = df_sub.sort_values(by='xoxo', ascending=False).head(10)
        display(top10)