# Dataset processing

In this notebook we process each of the taxpasta files using the proposed methodology:
1) Raw read counts are normalized against the rest of the samples. Two normalizations are possible in this scenario:
    - Merged normalization of POOL and CONTROL samples (`NORM+`). 
    - Individual normalization of POOL samples on one hand and CONTROL samples on the other (`NORM|`).

2) Select species based on the flagging system. The following criteria are used for flagging:
    - Number of profilers where the species is detected.
    - Minimum number of reads across profilers.
    - Mean number of reads across profilers.
    - CV across profilers.

    Based on these criteria, a cutoff is applied to the number of flags.

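The results are saved in `DIR_SUMMARY_OUTPUT` as `{SAMPLE}_pass{PASS}_mode{MODE}_tax{LEVEL}_S{S}{NORM}.diversity.tsv` (per-taxon statistics) and `{SAMPLE}_pass{PASS}_mode{MODE}_tax{LEVEL}_S{S}{NORM}.flags.tsv` (per-taxon flags).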


In [None]:
!pip install kneed

In [None]:
from datetime import datetime
import glob
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import seaborn as sns
from kneed import KneeLocator


In [None]:
# Display floats with three decimals and never truncate the column list of a DataFrame.
pd.set_option('display.float_format', '{:.3f}'.format)
pd.options.display.max_columns = None

In [None]:
from list_vars import LIST_PROFILERS, DIR_FIGURES, DIR_SUMMARY_OUTPUT, DIR_PROFILING

## Variable setting

In [None]:
# Make sure the summary-table and figure output directories exist (no error if already present).
for _out_dir in (DIR_SUMMARY_OUTPUT, DIR_FIGURES):
    os.makedirs(f'{_out_dir}', exist_ok=True)

## Functions

In [None]:
# MAD (Median Absolute Deviation)
# MAD = median(|X - median(X)|)
def mad(series):
    """
    Return the median absolute deviation of *series*, ignoring missing values.

    MAD = median(|X - median(X)|), computed over the non-NaN entries only.
    Skipping NaN keeps this consistent with ``DataFrame.median`` (which skips NaN
    by default) so that a taxon not reported by every profiler still gets a MAD
    that pairs with its median.

    Args:
        series: Any 1-D array-like of numbers (list, ndarray, pandas Series).

    Returns:
        float: The MAD, or NaN when *series* has no non-missing values.
    """
    values = np.asarray(series, dtype=float)
    values = values[~np.isnan(values)]
    if values.size == 0:
        return np.nan
    med = np.median(values)
    return np.median(np.abs(values - med))

def process_df(df):
    """
    Reverse each lineage and drop the generic "root" / "cellular organisms" labels.

    The ``lineage`` column holds ';'-separated taxon names. The order of the names
    is reversed, and entries that are exactly ``root`` or ``cellular organisms``
    are removed wherever they occur. Whole entries are compared rather than
    substrings, so a taxon name that merely contains "root" is never altered.
    Non-string lineages (e.g. NaN) are left untouched.

    Args:
        df (pd.DataFrame): Table with a ``lineage`` column; modified in place.

    Returns:
        pd.DataFrame: The same DataFrame with its ``lineage`` column rewritten.
    """
    dropped_labels = {'root', 'cellular organisms'}

    def _reorder(lineage):
        # Missing lineages would crash on .split(); pass them through unchanged.
        if not isinstance(lineage, str):
            return lineage
        taxa = [taxon for taxon in reversed(lineage.split(';')) if taxon not in dropped_labels]
        return ';'.join(taxa)

    df['lineage'] = [_reorder(lineage) for lineage in df['lineage']]
    return df


def load_and_process_tables(sample, profilers, taxonomic_level, mode, passn, remove_human, verbose=False):
    """
    Load and process the standardised profiling tables of one sample.

    For each profiler the file
    ``{DIR_PROFILING}/{profiler}/pass{passn}/{sample}_mode{mode}/{sample}_mode{mode}.*.standardised.{taxonomic_level}``
    is read, its lineage is reordered with ``process_df`` and its ``count`` column
    is renamed to the profiler name. Profilers without a matching file are skipped
    with a message.

    Args:
        sample (str): The sample name.
        profilers (list): List of profiler names.
        taxonomic_level (str): Either 'species' or 'genus'.
        mode (int): The processing mode used to build the file paths.
        passn (int): The pass number used to build the file paths.
        remove_human (bool): If True, drop rows whose taxonomy_id is 9606 (human).
        verbose (bool): If True, print the path of every loaded file.

    Returns:
        tuple[list, list]: The processed DataFrames, and the names of the profilers
        whose file was found (same order as the DataFrames).
    """
    list_df_samples = []
    list_available_profilers = []

    for profiler in profilers:
        file_pattern = f'{DIR_PROFILING}/{profiler}/pass{passn}/{sample}_mode{mode}/{sample}_mode{mode}.*.standardised.{taxonomic_level}'
        # glob order is filesystem-dependent; sort so the chosen file is reproducible.
        matching_files = sorted(glob.glob(file_pattern))

        if not matching_files:
            print(f"No matching file found for {profiler} with pattern {file_pattern}")
            continue
        if len(matching_files) > 1:
            print(f"Multiple files match {file_pattern} for {profiler}; using {matching_files[0]}")

        file_path = matching_files[0]
        if verbose:
            print(f"Loading file: {file_path}")
        df_sample_method = process_df(pd.read_csv(file_path, sep='\t'))

        df_sample_method['taxonomy_id'] = df_sample_method['taxonomy_id'].astype(int)
        # The count column is named after its profiler so tables can be merged side by side.
        df_sample_method.rename(columns={'count': profiler}, inplace=True)

        if remove_human:
            df_sample_method = df_sample_method[df_sample_method['taxonomy_id'] != 9606]

        list_df_samples.append(df_sample_method)
        list_available_profilers.append(profiler)

    return list_df_samples, list_available_profilers


def merge_tables(list_df_samples, profilers, verbose=False):
    """
    Outer-merge the per-profiler tables into one table with one count column per profiler.

    Rows are matched on the shared columns ``name``, ``lineage`` and
    ``taxonomy_id``. A taxon that a profiler did not report gets NaN in that
    profiler's column.

    Args:
        list_df_samples (list): Profiling DataFrames, each holding ``name``,
            ``lineage``, ``taxonomy_id`` and a count column named after its profiler.
        profilers (list): Profiler names, in the same order as ``list_df_samples``.
        verbose (bool): If True, print the dimensions of the merged table.

    Returns:
        pd.DataFrame: Merged DataFrame containing all profiling data.

    Raises:
        ValueError: If the two lists differ in length or no table is given.
    """
    if len(list_df_samples) != len(profilers):
        raise ValueError(
            f"Got {len(list_df_samples)} tables but {len(profilers)} profiler names."
        )
    if not list_df_samples:
        raise ValueError("No profiling tables to merge: no profiler output was found.")

    key_cols = ['name', 'lineage', 'taxonomy_id']
    merged_df = None
    for df_sample, profiler in zip(list_df_samples, profilers):
        table = df_sample[key_cols + [profiler]]
        # merge() without `on=` joins on all shared columns (the three key columns).
        # NOTE(review): one taxonomy_id whose name/lineage differs between profilers
        # therefore ends up on separate rows.
        merged_df = table.copy() if merged_df is None else merged_df.merge(table, how='outer')

    if verbose:
        print("Merged table dimensions:", merged_df.shape)
    return merged_df


def normalize_counts(df_sample, profilers, sample, dict_FASTQ_len, verbose):
    """
    Scale each profiler's counts to the median FASTQ length and add relative abundances.

    Every ``<profiler>`` column is multiplied by
    ``median(all FASTQ lengths) / FASTQ length of sample``; the result is stored in
    ``<profiler>_norm``. Then ``<profiler>_relab`` holds each taxon's percentage of
    that profiler's total normalized reads (0.0 everywhere when the total is not
    positive). All ``_norm`` columns are added before the ``_relab`` columns.

    Args:
        df_sample (pd.DataFrame): Merged table; new columns are added in place.
        profilers (list): Names of the count columns to normalize.
        sample (str): Key of this sample in ``dict_FASTQ_len``.
        dict_FASTQ_len (dict): Sample name -> FASTQ length.
        verbose (bool): Print the correction factor and per-profiler totals.

    Returns:
        pd.DataFrame: The same DataFrame with the added columns.
    """
    all_lengths = np.array(list(dict_FASTQ_len.values()))
    correction_factor = np.median(all_lengths) / dict_FASTQ_len[sample]

    if verbose:
        print(f'The correction factor is {correction_factor}')

    for name in profilers:
        df_sample[name + '_norm'] = df_sample[name] * correction_factor

    if verbose:
        for name in profilers:
            print(f'{name}: Total normalized reads = {df_sample[name + "_norm"].sum()}')

    for name in profilers:
        scaled = df_sample[name + '_norm']
        grand_total = scaled.sum()
        # A profiler with no reads would divide by zero; give it 0 % everywhere instead.
        df_sample[name + '_relab'] = scaled / grand_total * 100 if grand_total > 0 else 0.0

    return df_sample



def calculate_stats(df_sample, profilers, verbose=False):
    """
    Add per-taxon cross-profiler statistics and sort taxa by median relative abundance.

    ``n_profilers`` counts the profilers reporting more than 0 reads (missing
    values count as 0). Then, for raw counts (``<profiler>``), normalized counts
    (``<profiler>_norm``) and relative abundances (``<profiler>_relab``), the
    columns ``median_<kind>``, ``MAD_<kind>`` and ``CV_<kind>_median_MAD`` are
    added in that order (kind = raw, norm, relab). The CV is MAD / median and is
    NaN where the median is 0.

    Args:
        df_sample (pd.DataFrame): Table produced by ``normalize_counts``; columns
            are added to it in place.
        profilers (list): Profiler names whose columns are summarised.
        verbose (bool): Currently unused; kept for interface compatibility.

    Returns:
        pd.DataFrame: The table sorted by ``median_relab`` in descending order.
    """
    cols_sample_norm = [f'{i}_norm' for i in profilers]
    cols_sample_relab = [f'{i}_relab' for i in profilers]

    df_sample['n_profilers'] = (df_sample[profilers].fillna(0) > 0).sum(axis=1)

    for kind, cols in (('raw', profilers), ('norm', cols_sample_norm), ('relab', cols_sample_relab)):
        median_col = f'median_{kind}'
        mad_col = f'MAD_{kind}'
        df_sample[median_col] = df_sample[cols].median(axis=1)
        df_sample[mad_col] = df_sample[cols].apply(mad, axis=1)
        # Robust CV = MAD / median. Only a zero median makes it undefined (NaN);
        # a zero MAD (all profilers agree) correctly yields CV = 0.
        nonzero_median = df_sample[median_col].where(df_sample[median_col] != 0)
        df_sample[f'CV_{kind}_median_MAD'] = df_sample[mad_col] / nonzero_median

    return df_sample.sort_values(by='median_relab', ascending=False)



def _save_knee_plot(x, values_sorted, kneedle, col, out_path):
    """Plot one descending value curve with its knee point (if found) and save it to *out_path*."""
    plt.figure()
    try:
        plt.plot(x, values_sorted, label=col)
        if kneedle.knee_y is not None:
            plt.scatter(kneedle.knee, kneedle.knee_y, color='red', zorder=3, label='Knee Point')
            plt.annotate(
                f"Knee: ({kneedle.knee}, {kneedle.knee_y:.2f})",
                xy=(kneedle.knee, kneedle.knee_y),
                xytext=(kneedle.knee + len(values_sorted) * 0.1, kneedle.knee_y),
                arrowprops=dict(arrowstyle='->', color='black'),
            )
        plt.title(f'Knee Detection for {col}')
        plt.xlabel('Index')
        plt.ylabel('Value')
        plt.grid(True)
        plt.legend()
        plt.savefig(out_path)
    finally:
        # Always release the figure, even if saving fails, so figures don't accumulate.
        plt.close()


def apply_flagging_system(df_sample, sample, mode, passn, verbose, norm_string, taxonomic_level, S=1, save_knee_plots=True):
    """
    Flag, for each statistic column, the taxa whose value is at or above that column's knee.

    Every column after the first three is flagged; the first three are assumed to
    be the metadata columns (name, lineage, taxonomy_id). The column's non-missing
    values are sorted in descending order, and ``KneeLocator`` (convex, decreasing,
    sensitivity ``S``) locates the knee. If no knee is found, the median of the
    values is used as the threshold. A taxon is flagged True when its value is >=
    the threshold. Values below the threshold, missing values, and columns with no
    values at all give False.

    Args:
        df_sample (pd.DataFrame): Statistics table (metadata columns first).
        sample (str): Sample name, used for the plot directory and file names.
        mode (int): Processing mode, used in plot file names.
        passn (int): Pass number, used in plot file names.
        verbose (bool): Print the knee found for each column.
        norm_string (str): Normalization tag, used in plot file names.
        taxonomic_level (str): Taxonomic level, used in plot file names.
        S (float): KneeLocator sensitivity.
        save_knee_plots (bool): Save one knee plot per column under
            ``{DIR_FIGURES}/knee_plot/{sample}``.

    Returns:
        pd.DataFrame: Copy of ``df_sample`` with the statistic columns replaced by boolean flags.
    """
    # NOTE(review): the same ">= knee" rule is applied to MAD/CV columns, where a high
    # value means poor agreement between profilers; confirm this direction is intended.
    columns_flagging = df_sample.columns[3:]  # exclude name, lineage, and taxonomy_id
    df_flags = df_sample.copy()
    df_flags.iloc[:, 3:] = 0  # columns that end up with no values stay unflagged

    if save_knee_plots:
        plot_dir = f"{DIR_FIGURES}/knee_plot/{sample}"
        os.makedirs(plot_dir, exist_ok=True)

    for col in columns_flagging:
        col_values = df_sample[col].dropna().values
        if len(col_values) == 0:
            continue

        # Descending curve of the column's values; the knee marks where it flattens.
        col_values_sorted = np.sort(col_values)[::-1]
        x = np.arange(len(col_values_sorted))
        kneedle = KneeLocator(x, col_values_sorted, curve='convex', direction='decreasing', S=S)
        knee_y = kneedle.knee_y if kneedle.knee_y is not None else np.median(col_values_sorted)

        # True if value >= knee_y; NaN compares False, so missing values are never flagged.
        df_flags[col] = df_sample[col] >= knee_y

        if save_knee_plots:
            _save_knee_plot(
                x, col_values_sorted, kneedle, col,
                f'{plot_dir}/{sample}_pass{passn}_mode{mode}_tax{taxonomic_level}_S{S}{norm_string}_{col}',
            )

        if verbose:
            print(f"Column {col}: Knee = {(kneedle.knee, kneedle.knee_y)}")

    df_flags[df_flags.columns[3:]] = df_flags[df_flags.columns[3:]].astype(bool)

    return df_flags


In [None]:
def create_sample_table(sample, profilers, taxonomic_level, mode, passn, dict_FASTQ_len, S, norm_string, verbose=False, remove_human=True, save_knee_plots=True):
    """
    Run the per-sample pipeline (load, merge, normalize, summarise, flag) and save the results.

    Two tab-separated files are written to ``DIR_SUMMARY_OUTPUT``:
    ``{sample}_pass{passn}_mode{mode}_tax{taxonomic_level}_S{S}{norm_string}.diversity.tsv``
    (statistics) and the same prefix with ``.flags.tsv`` (flags).

    Args:
        sample (str): The sample name.
        profilers (list): Profiler names to look for.
        taxonomic_level (str): Taxonomic level of the input tables (e.g. 'species').
        mode (int): Processing mode used in input and output paths.
        passn (int): Pass number used in input and output paths.
        dict_FASTQ_len (dict): Sample name -> FASTQ length, used for normalization.
        S (float): KneeLocator sensitivity.
        norm_string (str): Tag appended to the output file names.
        verbose (bool): Print progress messages.
        remove_human (bool): Drop rows with taxonomy_id 9606.
        save_knee_plots (bool): Save one knee plot per flagged column.

    Returns:
        tuple: (statistics DataFrame, flags DataFrame).

    Raises:
        ValueError: If no profiler output file is found for this sample/pass/mode.
    """
    # Step 1: Load and process tables
    if verbose:
        print(f">>> Loading and processing tables for sample: {sample}, taxonomic level: {taxonomic_level}")
    list_df_samples, available_profilers = load_and_process_tables(sample, profilers, taxonomic_level, mode, passn, remove_human, verbose)

    if verbose:
        print(f"the list of available profilers is {available_profilers}")
    # Without any table the later steps would fail with an unrelated TypeError.
    if not available_profilers:
        raise ValueError(
            f"No profiling tables found for sample {sample} (pass {passn}, mode {mode}, {taxonomic_level})"
        )

    # Step 2: Merge tables
    if verbose:
        print(">>> Merging tables from all profilers")
    df_sample = merge_tables(list_df_samples, available_profilers, verbose)

    # Step 3: Normalize counts
    if verbose:
        print(">>> Normalizing counts")
    df_sample = normalize_counts(df_sample, available_profilers, sample, dict_FASTQ_len, verbose)

    # Step 4: Calculate statistics
    if verbose:
        print(">>> Calculating statistics")
    df_sample = calculate_stats(df_sample, available_profilers, verbose)

    # Step 5: Apply flagging system
    if verbose:
        print(">>> Applying flagging system")
    df_flags = apply_flagging_system(df_sample, sample, mode, passn, verbose, norm_string, taxonomic_level, S, save_knee_plots=save_knee_plots)

    # Save results
    if verbose:
        print(f">>> Saving results for sample: {sample}")
    out_prefix = f'{DIR_SUMMARY_OUTPUT}/{sample}_pass{passn}_mode{mode}_tax{taxonomic_level}_S{S}{norm_string}'
    df_sample.to_csv(f'{out_prefix}.diversity.tsv', sep='\t')
    df_flags.to_csv(f'{out_prefix}.flags.tsv', sep='\t')

    return df_sample, df_flags

In [None]:
# Process the ARTIFICIAL sample for every mode 1-9 and for passes 0 and 2.
# norm_string tags the output files with the normalization type used. The original
# note reads: '0' = no normalization, '+' = POOLS + CONTROLS, '-' = POOLS or CONTROLS.
# NOTE(review): the notebook header calls the split normalization `NORM|`, not '-'; confirm.
for mode in range(1, 10):
    for passn in (0, 2):
        df_sample, df_flags = create_sample_table(
            sample='ARTIFICIAL',
            profilers=LIST_PROFILERS,
            taxonomic_level='species',
            mode=mode,
            passn=passn,
            dict_FASTQ_len={'ARTIFICIAL': 50000000},
            norm_string='',
            S=5,
            verbose=False,
            remove_human=True,
        )