In [None]:
import gzip
import os
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


In [None]:
# Render floats with three decimal places whenever pandas displays a table
pd.set_option('display.float_format', '{:.3f}'.format)

In [None]:
# Date stamp (YYYY-MM-DD) that names the output sub-folder of this run
today = datetime.today().strftime('%Y-%m-%d')

# Create the base results folder and the dated sub-folder if they are missing
for results_dir in ('../results_diversity', f'../results_diversity/{today}'):
    os.makedirs(results_dir, exist_ok=True)

In [None]:
# MINOR PROCESSING FUNCTIONS
def process_df(df):
    """Rewrite the ``lineage`` column so each lineage is listed in reverse order.

    Every lineage string is split on ';', its parts are reversed and re-joined
    with ';', and the substrings 'root;' and 'cellular organisms;' are removed.

    Parameters
    ----------
    df : pandas.DataFrame
        Table with a ``lineage`` column of ';'-separated strings.

    Returns
    -------
    pandas.DataFrame
        The same object that was passed in; its ``lineage`` column is
        overwritten in place.
    """
    def _reverse_and_trim(lineage):
        # Flip the order of the ranks, then drop the two uninformative prefixes
        flipped = ';'.join(reversed(lineage.split(';')))
        return flipped.replace('root;', '').replace('cellular organisms;', '')

    df['lineage'] = [_reverse_and_trim(lineage) for lineage in df.lineage.values]
    return df

def _count_fastq_reads(fastq_gz_path, chunk_size=1 << 20):
    """Return the number of records in a gzip-compressed FASTQ file.

    The file is decompressed in binary chunks and its lines are counted; the
    line count is divided by four because a FASTQ record spans four lines.
    """
    n_lines = 0
    last_chunk = b''
    with gzip.open(fastq_gz_path, 'rb') as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b''):
            n_lines += chunk.count(b'\n')
            last_chunk = chunk

    # A final line without a trailing newline is still a line
    if last_chunk and not last_chunk.endswith(b'\n'):
        n_lines += 1

    return n_lines // 4


def get_FASTQ_len(POOL_list, fastq_dir='../results_rnaseq/star_salmon/unmapped'):
    """Count the reads in the unmapped mate-1 FASTQ file of every pool.

    The previous implementation returned the output of ``zcat | wc -l``, i.e.
    the number of *lines*, which is four times the number of reads. The value
    is used downstream as a read count, so reads are returned here. The files
    are read with the standard library instead of a shell pipeline, so the
    function no longer depends on IPython or on the local ``zcat``.

    Parameters
    ----------
    POOL_list : iterable of str
        Pool names; the file ``<fastq_dir>/<POOL>.unmapped_1.fastq.gz`` is read
        for each of them.
    fastq_dir : str, optional
        Directory that holds the FASTQ files. Defaults to the location that
        was previously hard-coded.

    Returns
    -------
    dict
        ``{pool name: number of reads}``.

    Raises
    ------
    FileNotFoundError
        If the FASTQ file of a pool does not exist.
    OSError
        If a file is not valid gzip data.
    """
    return {
        POOL: _count_fastq_reads(os.path.join(fastq_dir, f'{POOL}.unmapped_1.fastq.gz'))
        for POOL in POOL_list
    }


In [None]:
# MAJOR PROCESSING FUNCTIONS

def create_POOL_table(POOL, cols_POOL, cutoff_NA, cutoff_CV, cutoff_min_reads):
    """Merge the per-profiler counts of one pool, normalise them and flag unreliable taxa.

    For every profiler in ``cols_POOL`` the file
    ``../results_profiling/<profiler>/<POOL>.report.standardised`` is read
    (tab-separated, with the columns ``taxonomy_id``, ``name``, ``lineage``
    and ``count``). The tables are merged on ``taxonomy_id``, the counts are
    scaled by the pool's FASTQ size, and summary statistics and quality flags
    are added. Two tables are written to ``../results_diversity/<today>/``:
    ``<POOL>.diversity_raw.tsv`` and ``<POOL>.diversity_cutoff.tsv``.

    This function reads the notebook globals ``dict_FASTQ_len`` (FASTQ size of
    every pool) and ``today`` (name of the dated output folder), and calls
    ``process_df``; all three must be defined before it is called.

    Parameters
    ----------
    POOL : str
        Pool name; used in the input/output file names and as key of
        ``dict_FASTQ_len``.
    cols_POOL : list of str
        Profiler names. Each one is a sub-folder of ``../results_profiling``
        and becomes the name of that profiler's count column.
    cutoff_NA : int
        A taxon is flagged when it is missing from more than ``cutoff_NA``
        profilers.
    cutoff_CV : float
        A taxon is flagged when the coefficient of variation of its
        normalised counts is greater than ``cutoff_CV``.
    cutoff_min_reads : int
        A taxon is flagged when its smallest raw count is below this value,
        and (separately) when the sum of its raw counts is below
        ``int(cutoff_min_reads * len(cols_POOL) * 0.7)``.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(df_POOL, df_POOL_cutoff)``: every taxon with a defined CV, and the
        subset of those with fewer than two quality flags. Both are sorted by
        ``mean`` in descending order.
    """
    # Load one report per profiler. process_df rewrites the lineage column and the
    # count column is renamed after the profiler so the tables can sit side by side.
    reports = []
    for tax_method in cols_POOL:
        report = pd.read_csv(
            f'../results_profiling/{tax_method}/{POOL}.report.standardised',
            sep='\t',
            index_col='taxonomy_id',
        )
        reports.append(process_df(report).rename(columns={'count': tax_method}))

    # Every taxonomy_id reported by at least one profiler
    all_tax_ids = list({tax_id for report in reports for tax_id in report.index.tolist()})

    # Merged table: name, lineage and one raw count column per profiler.
    # Where profilers overlap, name/lineage come from the last profiler that has the taxon.
    df_POOL = pd.DataFrame(index=all_tax_ids, columns=['name', 'lineage'] + cols_POOL)
    for report, tax_method in zip(reports, cols_POOL):
        merged_cols = ['name', 'lineage', tax_method]
        df_POOL.loc[report.index, merged_cols] = report.loc[report.index, merged_cols]

    # Normalise by FASTQ size: a pool with a larger FASTQ than average gets its counts scaled down
    cols_POOL_norm = [f'{tax_method}_norm' for tax_method in cols_POOL]
    mean_FASTQ_len = np.mean(list(dict_FASTQ_len.values()))
    correction_factor = mean_FASTQ_len / dict_FASTQ_len[POOL]

    # The merged frame is object dtype (it was created empty), so the counts are cast to
    # float before any arithmetic; reductions and NaN checks on object data are fragile.
    for tax_method, norm_col in zip(cols_POOL, cols_POOL_norm):
        df_POOL[norm_col] = df_POOL[tax_method].astype(float) * correction_factor

    # Simple statistics across profilers (missing values are skipped)
    norm_counts = df_POOL[cols_POOL_norm]
    df_POOL['mean'] = norm_counts.mean(axis=1)
    df_POOL['std'] = norm_counts.std(axis=1, ddof=0)  # population std, as np.std computes
    df_POOL['CV'] = df_POOL['std'] / df_POOL['mean']

    # Mean normalised count as a percentage of the mean FASTQ size. This is a true read
    # fraction only if dict_FASTQ_len holds read counts rather than raw line counts.
    df_POOL['mean (%)'] = 100 * df_POOL['mean'] / mean_FASTQ_len

    df_POOL = df_POOL.sort_values(by='mean', ascending=False)

    # Quality flags (1 = flagged, 0 = fine); nothing is removed at this point:
    #   quality_CV         counts vary too much across profilers (CV > cutoff_CV)
    #   quality_min_reads  lowest raw count among the profilers is below cutoff_min_reads
    #   quality_sum_reads  sum of the raw counts is below cutoff_sum_reads
    #   quality_NA         taxon is missing from more than cutoff_NA profilers
    cutoff_sum_reads = int(cutoff_min_reads * len(cols_POOL) * 0.7)
    raw_counts = df_POOL[cols_POOL].astype(float)

    df_POOL['quality_CV'] = (df_POOL['CV'] > cutoff_CV).astype('int64')
    df_POOL['quality_min_reads'] = (raw_counts.min(axis=1) < cutoff_min_reads).astype('int64')
    df_POOL['quality_sum_reads'] = (raw_counts.sum(axis=1) < cutoff_sum_reads).astype('int64')
    df_POOL['quality_NA'] = (raw_counts.isna().sum(axis=1) > cutoff_NA).astype('int64')

    quality_cols = ['quality_CV', 'quality_min_reads', 'quality_sum_reads', 'quality_NA']
    df_POOL['quality'] = df_POOL[quality_cols].sum(axis=1)

    # Taxa without a defined CV are dropped from both output tables
    df_POOL = df_POOL[df_POOL['CV'].notna()]

    # Keep taxa with at most one flag
    df_POOL_cutoff = df_POOL[df_POOL['quality'] < 2]

    df_POOL.to_csv(f'../results_diversity/{today}/{POOL}.diversity_raw.tsv', sep='\t')
    df_POOL_cutoff.to_csv(f'../results_diversity/{today}/{POOL}.diversity_cutoff.tsv', sep='\t')

    return df_POOL, df_POOL_cutoff



In [None]:
# PLOTTING FUNCTIONS

def plot_allPOOL_correlations(POOL_list, cols_POOL, corr_method='spearman'):
    """Plot and save the between-profiler correlation heatmaps of every pool.

    For each table type ('raw' and 'cutoff') the file
    ``../results_diversity/<today>/<POOL>.diversity_<type>.tsv`` of every pool
    is read, the profiler count columns are transformed with ``log10(x + 1)``
    and their pairwise correlation is drawn as one heatmap per pool. A third
    figure shows, for each table type, the element-wise mean of the per-pool
    correlation matrices.

    Files written to ``../results_diversity/<today>/``:
    ``correlation_<corr_method>_raw.png``,
    ``correlation_<corr_method>_cutoff.png`` and
    ``correlation_<corr_method>_mean.png``.

    The notebook global ``today`` must be defined, and the TSV files must
    already exist. The figures are left open so the notebook can display them.

    Parameters
    ----------
    POOL_list : list of str
        Pool names to plot; must not be empty.
    cols_POOL : list of str
        Names of the profiler count columns to correlate.
    corr_method : str, optional
        Correlation method passed to ``DataFrame.corr``.

    Raises
    ------
    ValueError
        If ``POOL_list`` is empty.
    """
    n_pools = len(POOL_list)
    if n_pools == 0:
        raise ValueError('POOL_list is empty: there is nothing to plot')

    # Roughly square grid with enough cells for every pool
    ncols = int(n_pools ** 0.5)
    nrows, leftover = divmod(n_pools, ncols)
    if leftover:
        nrows += 1

    table_types = ['raw', 'cutoff']
    list_mean_heatmaps = []

    for type_plot in table_types:
        # squeeze=False keeps axs two-dimensional even for a 1x1 grid (single pool)
        fig, axs = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)
        flat_axs = axs.ravel()

        corr_mat_list = []  # kept to average the correlations across pools afterwards
        for ax, POOL in zip(flat_axs, POOL_list):
            df_POOL = pd.read_csv(f'../results_diversity/{today}/{POOL}.diversity_{type_plot}.tsv', sep='\t')
            df_corr = np.log10(df_POOL.loc[:, cols_POOL].astype(float) + 1)
            corr_mat = df_corr.corr(method=corr_method)
            corr_mat_list.append(corr_mat)

            sns.heatmap(corr_mat, cmap='Blues', annot=True, ax=ax)
            ax.set_title(POOL)

        # Hide the grid cells that have no pool assigned
        for unused_ax in flat_axs[n_pools:]:
            unused_ax.set_visible(False)

        fig.suptitle(f'Correlation ({corr_method}, {type_plot})')
        fig.tight_layout()
        fig.savefig(f'../results_diversity/{today}/correlation_{corr_method}_{type_plot}.png', dpi=300)

        # Element-wise mean across pools; sum() builds a new frame, so the per-pool
        # matrices are not modified
        list_mean_heatmaps.append(sum(corr_mat_list) / n_pools)

    fig_mean, axs_mean = plt.subplots(1, len(table_types), figsize=(4 * 2, 4 * 1))

    for ax, type_plot, mean_ht in zip(axs_mean.ravel(), table_types, list_mean_heatmaps):
        sns.heatmap(mean_ht, cmap='Blues', annot=True, ax=ax)
        ax.set_title(type_plot)

    fig_mean.tight_layout()
    fig_mean.savefig(f'../results_diversity/{today}/correlation_{corr_method}_mean.png', dpi=300)

    


In [None]:
# GENERAL VARIABLES
# Sample identifiers POOL1 ... POOL12
POOL_list = [f'POOL{i}' for i in range(1, 13)]

# Profilers to compare; each name is also the name of that profiler's count column
cols_POOL = ['kaiju', 'kraken_2', 'krakenuniq', 'centrifuge']

# Per-pool FASTQ sizes used to normalise the counts (slow: every FASTQ file is read)
dict_FASTQ_len = get_FASTQ_len(POOL_list)

# Last expression of the cell, so the dictionary is actually displayed
dict_FASTQ_len

In [None]:
# Build and save the raw and cutoff tables of every pool with the same quality cutoffs
quality_cutoffs = dict(cutoff_NA=1, cutoff_CV=1.0, cutoff_min_reads=100)
for pool in POOL_list:
    _, _ = create_POOL_table(pool, cols_POOL, **quality_cutoffs)


In [None]:
# Draw and save the per-pool and the pool-averaged correlation heatmaps
plot_allPOOL_correlations(POOL_list=POOL_list, cols_POOL=cols_POOL, corr_method='spearman')

# FIXME: this used to fail here and I do not know why