# Basic sample processing and analysis
In this notebook we analyse the outputs of the taxonomic profilers after they have been standardised by taxpasta.

First, we preprocess the data: we simplify the taxon lineage (reversing its order and removing the "root" and "cellular organisms" levels) and normalise the counts.
We normalise the counts of a sample by multiplying them by the ratio between the mean number of reads across all FASTQs and the number of reads in that sample's FASTQ (i.e. dividing them by the sample's size relative to the mean).
This correction removes the bias towards FASTQs with more reads: a FASTQ with more reads gets proportionally lower normalised counts.
We use the FASTQ size, and not the number of mapped reads, because we want a profiler-independent normalisation value;
and we count the reads that have NOT been filtered (i.e. before nf-core/rnaseq and Bowtie2 mapping), because the relevant number of reads is the one provided directly by the sequencer.


Then, we apply some QC cuts to the datasets, as follows:
* Taxa that have a high CV across profilers are flagged.
* Taxa that have a low minimum or a low median number of normalised counts across profilers are flagged (one flag for each criterion).
* Taxa that are missing from more than `cutoff_NA` profilers are flagged; taxa detected by only one profiler receive an additional flag, so (with the `cutoff_NA=1` used here) they are directly removed.

Taxa with 2 or more flags are then removed, and the resulting table is saved.

Lastly, correlations across profilers are computed: for each sample, the correlation (Spearman by default) between the log10(counts + 1) of every pair of profilers is computed, and then the mean value across all samples is computed.

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
from datetime import datetime


In [None]:
# Show floats with three decimals in every displayed DataFrame.
pd.set_option('display.float_format', '{:.3f}'.format)

In [None]:
# VARS
# --- inputs ---
# Text file listing the sample names to process (one per line).
pools_file = "../data/EM_EVPools/samples_profiling.txt"
# Folder with the raw FASTQs; 'POOL' in fastq_basename is replaced by each sample name.
dir_reads_fastq = '../data/EM_EVPools/control_sample/'
fastq_basename = 'POOL.fastq.gz'
# taxpasta-standardised profiler reports, one subdirectory per profiler.
dir_results_profiling = '../results_profiling/EM_EVPools'
# --- outputs ---
dir_diversity_output = '../results_diversity'

In [None]:
# Date stamp (YYYY-MM-DD) naming this run's output subfolder; the folder is created if missing.
today = f'{datetime.today():%Y-%m-%d}'
os.makedirs(os.path.join(dir_diversity_output, today), exist_ok=True)

In [None]:
# GENERAL VARIABLES
# Sample (pool) names, one per line of pools_file. This uses an IPython shell escape,
# so the cell only runs inside IPython/Jupyter.
POOL_list = !cat {pools_file}
# Control samples; their FASTQ sizes are kept in a separate dictionary (dict_FASTQ_len_ctrl).
POOL_list_control = ['ACIDOLA', 'BLACTIS']
# Profilers to compare. Each name is a subdirectory of dir_results_profiling holding the
# <sample>.report.standardised files, and becomes a count column in the merged tables.
profilers = ['kaiju', 'kraken_2', 'krakenuniq', 'centrifuge']

In [None]:
# MINOR PROCESSING FUNCTIONS
def process_df(df):
    """Reverse each taxon lineage and drop the "root" and "cellular organisms" levels.

    The ';'-separated ``lineage`` column is reversed element-wise. Then every occurrence of
    the substrings ``'root;'`` and ``'cellular organisms;'`` is removed. A lineage consisting
    only of ``'root'`` (no trailing ';') is therefore kept as is.

    Non-string lineages (e.g. NaN for a taxon without a lineage) are left unchanged instead
    of raising ``AttributeError``.

    Args:
        df: DataFrame with a ``lineage`` column. It is modified in place.

    Returns:
        The same ``df`` object, for chaining.
    """
    def _simplify(lineage):
        # Missing lineages cannot be split; keep them as they are.
        if not isinstance(lineage, str):
            return lineage
        reversed_lineage = ';'.join(reversed(lineage.split(';')))
        # NOTE: plain substring removal, so a taxon name ending in "root" followed by ';'
        # would also be shortened.
        return reversed_lineage.replace('root;', '').replace('cellular organisms;', '')

    df.lineage = [_simplify(lineage) for lineage in df.lineage.values]
    return df

def get_FASTQ_len(POOL_list):
    # Obtains, for each sample in POOL_list, the size of its FASTQ before profiling.
    # The size is the number of LINES that `wc -l` counts in the decompressed file.
    # NOTE(review): for standard 4-line FASTQ records this is 4x the number of reads;
    # ratios between samples are unaffected, but absolute percentages are - confirm intended.
    # Uses an IPython shell escape (`!`), so it only runs inside IPython/Jupyter.
    # Returns a dict {sample name: line count}.
    dict_FASTQ_len = {}
    
    for POOL in POOL_list:
        # File path: notebook globals dir_reads_fastq + fastq_basename with 'POOL' replaced by the sample name
        n_counts_fastq = !gzip -dc {dir_reads_fastq}/{fastq_basename.replace('POOL', POOL)} | wc -l
        dict_FASTQ_len[POOL] = int(n_counts_fastq[0])
    
    return dict_FASTQ_len


In [None]:
# MAJOR PROCESSING FUNCTIONS

def _load_profiler_tables(POOL, profilers):
    # Load every profiler's standardised report for POOL and merge name, lineage and counts
    # into one table indexed by taxonomy_id (union of the ids reported by any profiler).
    list_df_pools = []
    index_df_POOL = []
    for tax_method in profilers:
        df_method = pd.read_csv(
            f'{dir_results_profiling}/{tax_method}/{POOL}.report.standardised',
            sep='\t',
            index_col='taxonomy_id',
        )
        # Simplify the lineage and name the count column after the profiler
        df_method = process_df(df_method).rename(columns={'count': tax_method})
        list_df_pools.append(df_method)
        index_df_POOL += df_method.index.tolist()

    df_POOL = pd.DataFrame(index=list(set(index_df_POOL)), columns=['name', 'lineage', *profilers])
    for df_method, tax_method in zip(list_df_pools, profilers):
        cols = ['name', 'lineage', tax_method]
        # Later profilers overwrite name/lineage of shared ids; each count goes to its own column
        df_POOL.loc[df_method.index, cols] = df_method.loc[df_method.index, cols]
    return df_POOL


def _flag_quality(df_POOL, cols_norm, n_profilers, cutoff_NA, cutoff_CV, cutoff_min_reads):
    # Add the quality_* flag columns and their sum, 'quality', to df_POOL in place.
    norm = df_POOL[cols_norm]
    n_missing = norm.isna().sum(axis=1)
    cutoff_median_reads = int(cutoff_min_reads * 2.5)

    df_POOL['quality_CV'] = (df_POOL['CV'] > cutoff_CV).astype(int)
    df_POOL['quality_min_reads'] = (norm.min(axis=1, skipna=True) < cutoff_min_reads).astype(int)
    # The median check has its own column so that each flag column reflects a single criterion
    df_POOL['quality_sum_reads'] = (norm.median(axis=1, skipna=True) < cutoff_median_reads).astype(int)
    # +1 if missing in more than cutoff_NA profilers, +1 more if only one profiler reports the taxon
    df_POOL['quality_NA'] = ((n_missing > cutoff_NA).astype(int)
                             + (n_missing == n_profilers - 1).astype(int))
    df_POOL['quality'] = df_POOL[
        ['quality_CV', 'quality_min_reads', 'quality_sum_reads', 'quality_NA']
    ].sum(axis=1)


def create_POOL_table(POOL, profilers, dict_FASTQ_len, cutoff_NA, cutoff_CV, cutoff_min_reads):
    """Merge the profiler reports of one sample, normalise them, flag low-quality taxa and save the tables.

    For every profiler, reads ``{dir_results_profiling}/{profiler}/{POOL}.report.standardised``.
    The counts are merged into one table indexed by taxonomy_id, which gains these columns:

    * ``<profiler>_norm``: counts * mean(dict_FASTQ_len values) / dict_FASTQ_len[POOL];
    * ``mean`` and ``std`` (population std, ddof=0) of the normalised counts, skipping NaNs;
    * ``CV`` = std / mean;
    * ``mean (%)`` = 100 * mean / mean(dict_FASTQ_len values);
    * flag columns and their sum ``quality``:
        - quality_CV: CV > cutoff_CV;
        - quality_min_reads: minimum normalised count < cutoff_min_reads;
        - quality_sum_reads: median normalised count < int(2.5 * cutoff_min_reads);
        - quality_NA: +1 if missing in more than cutoff_NA profilers, +1 more if present in only one.

    Taxa whose mean normalised count is not > 0 are dropped, because their CV would be undefined.

    Args:
        POOL: sample name. It is used in file names and as the key into dict_FASTQ_len.
        profilers: list of profiler names (subdirectories of dir_results_profiling).
        dict_FASTQ_len: mapping of sample name to FASTQ size. The mean over all its values is the
            normalisation reference.
        cutoff_NA, cutoff_CV, cutoff_min_reads: thresholds described above.

    Returns:
        (df_POOL, df_POOL_cutoff): the full table sorted by decreasing ``mean``, and the subset of
        its rows with ``quality`` < 2.

    Raises:
        KeyError: if POOL is not a key of dict_FASTQ_len.

    Side effects:
        Writes ``{POOL}.diversity_raw.tsv`` and ``{POOL}.diversity_cutoff.tsv`` (tab-separated)
        into ``{dir_diversity_output}/{today}``. The notebook globals dir_results_profiling,
        dir_diversity_output and today must be defined.
    """
    df_POOL = _load_profiler_tables(POOL, profilers)

    # Normalisation: a FASTQ with more total reads than the mean gets proportionally lower counts
    cols_POOL_norm = [f'{tax_method}_norm' for tax_method in profilers]
    mean_FASTQ_len = np.mean(list(dict_FASTQ_len.values()))
    correction_factor = mean_FASTQ_len / dict_FASTQ_len[POOL]
    for tax_method in profilers:
        # The merged table is object-dtype; cast so that the statistics below are numeric
        df_POOL[f'{tax_method}_norm'] = df_POOL[tax_method].astype(float) * correction_factor

    norm = df_POOL[cols_POOL_norm]
    df_POOL['mean'] = norm.mean(axis=1, skipna=True)
    df_POOL['std'] = norm.std(axis=1, ddof=0, skipna=True)

    # Drop zero/NaN means (CV undefined). .copy() so the column assignments below act on an
    # independent frame rather than on a slice of the unfiltered one.
    df_POOL = df_POOL[df_POOL['mean'] > 0].copy()
    df_POOL['CV'] = df_POOL['std'] / df_POOL['mean']
    # NOTE(review): get_FASTQ_len reports `wc -l` line counts; with 4-line FASTQ records this
    # percentage is 4x smaller than the fraction of reads - confirm the intended unit.
    df_POOL['mean (%)'] = 100 * df_POOL['mean'] / mean_FASTQ_len
    df_POOL = df_POOL.sort_values(by='mean', ascending=False)

    _flag_quality(df_POOL, cols_POOL_norm, len(profilers), cutoff_NA, cutoff_CV, cutoff_min_reads)

    # Keep taxa with 0 or 1 flag only
    df_POOL_cutoff = df_POOL[df_POOL['quality'] < 2]

    # Same directory that the notebook creates from dir_diversity_output and today
    out_dir = f'{dir_diversity_output}/{today}'
    df_POOL.to_csv(f'{out_dir}/{POOL}.diversity_raw.tsv', sep='\t')
    df_POOL_cutoff.to_csv(f'{out_dir}/{POOL}.diversity_cutoff.tsv', sep='\t')

    return df_POOL, df_POOL_cutoff



In [None]:
# PLOTTING FUNCTIONS

def _mean_correlation(corr_mat_list):
    # Element-wise mean of equally-labelled correlation matrices, ignoring NaNs (NaN where every
    # matrix is NaN); keeps the row/column labels of the first matrix.
    stacked = np.stack([corr_mat.to_numpy(dtype=float) for corr_mat in corr_mat_list])
    n_valid = (~np.isnan(stacked)).sum(axis=0)
    total = np.nansum(stacked, axis=0)
    mean = np.divide(total, n_valid, out=np.full(total.shape, np.nan), where=n_valid > 0)
    return pd.DataFrame(mean, index=corr_mat_list[0].index, columns=corr_mat_list[0].columns)


def plot_allPOOL_correlations(POOL_list, profilers, corr_method='spearman'):
    """Plot per-sample and averaged correlation heatmaps between profilers.

    The steps below run once for each table type ('raw' and 'cutoff'):

    * read the ``{POOL}.diversity_{type}.tsv`` files from ``{dir_diversity_output}/{today}``;
    * correlate log10(count + 1) of every pair of profiler columns per sample (``DataFrame.corr``);
    * draw one heatmap per sample;
    * save the grid as ``correlation_{corr_method}_{type}.png``.

    Then, for both table types, the NaN-ignoring element-wise mean of the per-sample matrices is
    drawn side by side (with profiler labels) and saved as ``correlation_{corr_method}_mean.png``.

    Args:
        POOL_list: sample names; must not be empty.
        profilers: list of profiler (count column) names.
        corr_method: correlation method passed to ``DataFrame.corr``.

    Raises:
        ValueError: if POOL_list is empty.
    """
    n_pools = len(POOL_list)
    if n_pools == 0:
        raise ValueError('POOL_list must contain at least one sample')

    out_dir = f'{dir_diversity_output}/{today}'
    ncols = int(n_pools ** 0.5)
    nrows = n_pools // ncols + int(n_pools % ncols != 0)

    mean_heatmaps = {}
    for type_plot in ('raw', 'cutoff'):
        # squeeze=False: always a 2-D array of Axes, even for a single sample
        fig, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)
        axs = axs.ravel()
        corr_mat_list = []
        for ax, POOL in zip(axs, POOL_list):
            df_POOL = pd.read_csv(f'{out_dir}/{POOL}.diversity_{type_plot}.tsv', sep='\t')
            # log10 is monotonic, so it only changes the result for Pearson correlation
            df_corr = np.log10(df_POOL.loc[:, profilers].astype(float) + 1)
            corr_mat = df_corr.corr(method=corr_method)
            corr_mat_list.append(corr_mat)

            sns.heatmap(corr_mat, cmap='Blues', annot=True, ax=ax)
            ax.set_title(POOL)
        for ax in axs[n_pools:]:
            ax.axis('off')  # hide grid cells without a sample

        fig.suptitle(f'Correlation ({corr_method}, {type_plot})')
        fig.tight_layout()
        fig.savefig(f'{out_dir}/correlation_{corr_method}_{type_plot}.png', dpi=300)

        mean_heatmaps[type_plot] = _mean_correlation(corr_mat_list)

    fig_all, axs_all = plt.subplots(1, 2, figsize=(4 * 2, 4 * 1))
    for ax, (type_plot, mean_ht) in zip(axs_all.ravel(), mean_heatmaps.items()):
        sns.heatmap(mean_ht, cmap='Blues', annot=True, ax=ax)
        ax.set_title(type_plot)

    fig_all.tight_layout()
    fig_all.savefig(f'{out_dir}/correlation_{corr_method}_mean.png', dpi=300)

    


In [None]:
# Count the FASTQ lines of every pool (decompresses each FASTQ) and display the resulting dict.
dict_FASTQ_len = get_FASTQ_len(POOL_list)
dict_FASTQ_len

In [None]:
# Same for the control samples, kept in a separate dict so they can be normalised on their own.
dict_FASTQ_len_ctrl = get_FASTQ_len(POOL_list_control)
dict_FASTQ_len_ctrl

In [None]:
# Hard-coded FASTQ sizes that overwrite the dictionaries computed by the two cells above.
# NOTE(review): presumably their saved output (line counts from get_FASTQ_len), kept to avoid
# decompressing the FASTQs again - update these values if the FASTQs change.
dict_FASTQ_len = {
    'POOL1': 226194968,
    'POOL2': 168822704,
    'POOL3': 195169004,
    'POOL4': 188583804,
    'POOL5': 178776084,
    'POOL6': 222561592,
    'POOL7': 222791744,
    'POOL8': 191307976,
    'POOL9': 171410164,
    'POOL10': 220270844,
    'POOL11': 166626812,
    'POOL12': 161471504,
}

dict_FASTQ_len_ctrl = {
    'BLACTIS': 8436468,
    'ACIDOLA': 9853568,
}

In [None]:
# Control ACIDOLA, normalised against the mean FASTQ size of the controls: full and QC-filtered tables.
dfpool, dfpoolcut = create_POOL_table(
    'ACIDOLA', profilers, dict_FASTQ_len_ctrl,
    cutoff_NA=1, cutoff_CV=1.0, cutoff_min_reads=100,
)
display(dfpool, dfpoolcut)


In [None]:
# Control BLACTIS, normalised against the mean FASTQ size of the controls: full and QC-filtered tables.
dfpool, dfpoolcut = create_POOL_table(
    'BLACTIS', profilers, dict_FASTQ_len_ctrl,
    cutoff_NA=1, cutoff_CV=1.0, cutoff_min_reads=100,
)
display(dfpool, dfpoolcut)

In [None]:
# POOL1, normalised against the mean FASTQ size of the pools: full and QC-filtered tables.
dfpool, dfpoolcut = create_POOL_table(
    'POOL1', profilers, dict_FASTQ_len,
    cutoff_NA=1, cutoff_CV=1.0, cutoff_min_reads=100,
)
display(dfpool, dfpoolcut)

In [None]:
# POOL3, normalised against the mean FASTQ size of the pools: full and QC-filtered tables.
dfpool, dfpoolcut = create_POOL_table(
    'POOL3', profilers, dict_FASTQ_len,
    cutoff_NA=1, cutoff_CV=1.0, cutoff_min_reads=100,
)
display(dfpool, dfpoolcut)

In [None]:
# POOL6, normalised against the mean FASTQ size of the pools: full and QC-filtered tables.
dfpool, dfpoolcut = create_POOL_table(
    'POOL6', profilers, dict_FASTQ_len,
    cutoff_NA=1, cutoff_CV=1.0, cutoff_min_reads=100,
)
display(dfpool, dfpoolcut)

In [None]:
# Exploratory check: the 'Saccharomycodes' rows that survive the QC cutoffs in every pool, then every control.
# NOTE(review): the normalisation reference here is the mean FASTQ size of pools AND controls
# combined, unlike the cell that writes the final tables; that cell rewrites the same TSV files.
for pool in [*POOL_list, *POOL_list_control]:
    dfpool, dfpoolcut = create_POOL_table(
        pool, profilers, dict_FASTQ_len | dict_FASTQ_len_ctrl,
        cutoff_NA=1, cutoff_CV=1.0, cutoff_min_reads=100,
    )
    display(pool)
    display(dfpoolcut.loc[dfpoolcut['name'] == 'Saccharomycodes'])

In [None]:
# Build and save the full and QC-filtered tables of every sample.
# Pools are normalised against the pools' mean FASTQ size, controls against the controls' mean.
for pool in POOL_list:
    _, _ = create_POOL_table(
        pool, profilers, dict_FASTQ_len,
        cutoff_NA=1, cutoff_CV=1.0, cutoff_min_reads=100,
    )

for pool in POOL_list_control:
    _, _ = create_POOL_table(
        pool, profilers, dict_FASTQ_len_ctrl,
        cutoff_NA=1, cutoff_CV=1.0, cutoff_min_reads=100,
    )

In [None]:
# Per-pool and averaged Spearman correlation heatmaps (raw and QC-filtered tables), saved as PNGs.
plot_allPOOL_correlations(POOL_list, profilers, 'spearman')