In [27]:
import mdsine2 as md2
from mdsine2.names import STRNAMES
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
import os
import pandas as pd
from pathlib import Path
import matplotlib.cm as cm
import seaborn as sns
from sklearn.metrics import mean_squared_error

# Ask mdsine2's visualization module to draw perturbations in gray
# (presumably only affects mdsine2's own plotting helpers; the plotting
# functions in this notebook shade their grey windows explicitly).
md2.visualization.set_perturbation_color('gray')

In [199]:
# Define useful lookup variables
cols = ['#016bff',
        '#b91f1c',
        '#308937',
        '#ff8137',
        '#6f4fc7',
        '#d4b300',
        '#893a2b',
        '#ff6c79',
        '#16c4ff',
        '#766f41',
        '#00c800',
        '#af3261']

pretty_names_dir = {'B_caccae': '$B. caccae$',
                'B_cellulosilyticus_WH2': '$B. cellulosilyticus$',
                'B_ovatus': '$B. ovatus$',
                'B_thetaiotaomicron': '$B. thetaiotaomicron$',
                'B_uniformis': '$B. uniformis$',
                'B_vulgatus': '$B. vulgatus$',
                'C_aerofaciens': '$C. aerofaciens$',
                'C_scindens': '$C. scindens$',
                'C_spiroforme': '$C. spiroforme$',
                'D_longicatena': '$D. longicatena$',
                'P_distasonis': '$P. distasonis$',
                'R_obeum': '$R. obeum$'}

order = ['B_cellulosilyticus_WH2', 
        'B_caccae', 
        'B_vulgatus', 
        'B_thetaiotaomicron', 
        'B_ovatus', 
        'R_obeum', 
        'B_uniformis', 
        'P_distasonis', 
        'C_scindens', 
        'C_aerofaciens', 
        'C_spiroforme', 
        'D_longicatena']

taxa_color = {order: cols[i] for i, order in enumerate(order)}

lf0 = ['1', '2', '3', '4', '5', '6', '7']
hf0 = ['8', '9', '10', '11', '12', '13', '14', '15']

In [176]:
def rsme_fwrsim(study_path, dataset, fwsim_path, seed):
    """Per-subject, per-taxon RMSE between observed and forward-simulated abundances.

    For each subject of ``dataset`` the saved forward-simulation array of the
    ``seed`` run is averaged over axis 0. The averaged trajectory is read off
    at the subject's observation times and compared, taxon by taxon, with the
    observed absolute abundances.

    Parameters
    ----------
    study_path : str
        Directory containing ``<dataset>/mcnulty_<dataset>.pkl``.
    dataset : {'LF0', 'HF0'}
        Dataset to evaluate; selects the subject list.
    fwsim_path : str
        Root containing
        ``mcnulty_<dataset>_seed<seed>/forward-simulate/Subject_<subj>/fwsim.npy``.
    seed : int
        Seed of the inference run to evaluate.

    Returns
    -------
    dict
        ``{subject_id: {taxon: rmse}}``. Taxon keys come from
        ``study[subj].taxa.names``.

    Raises
    ------
    ValueError
        If ``dataset`` is unknown, or an observation time lies farther than
        half a step from every simulated time point.
    """
    if dataset == 'LF0':
        subjs = lf0
    elif dataset == 'HF0':
        subjs = hf0
    else:
        raise ValueError(f'Unknown dataset: {dataset}')
    study = md2.Study.load(f'{study_path}/{dataset}/mcnulty_{dataset}.pkl')
    true_abundances = [study[subj].matrix()['abs'] for subj in subjs]
    times = [study[subj].times for subj in subjs]
    taxa = [list(study[subj].taxa.names.keys()) for subj in subjs]
    # Average over axis 0 (assumed to be the simulation-sample axis, giving an
    # (n_taxa, n_sim_times) trajectory per subject -- TODO confirm).
    fwsims = [np.mean(np.load(f'{fwsim_path}/mcnulty_{dataset}_seed{seed}/forward-simulate/Subject_{subj}/fwsim.npy'), axis=0) for subj in subjs]
    start = 0
    step = 0.01
    end = 40
    sim_times = np.arange(start, end+step+step, step)
    if len(true_abundances) != len(fwsims):
        raise ValueError('The number of true abundances and forward simulations are not the same')
    rmse_inds = {}
    for subj in range(len(true_abundances)):
        # Map each observation time to the nearest simulated time point.
        # Exact float equality with an arange grid can find no match because
        # of rounding error.
        ext_indx_sim = []
        for obs_t in times[subj]:
            idx = int(np.argmin(np.abs(sim_times - obs_t)))
            if abs(sim_times[idx] - obs_t) > step / 2:
                raise ValueError(
                    f'Observation time {obs_t} of subject {subjs[subj]} '
                    'is not on the simulated time grid')
            ext_indx_sim.append(idx)
        rsme_taxa = {}
        for taxon in range(len(taxa[subj])):
            predicted = fwsims[subj][taxon][ext_indx_sim]
            rsme_taxa[taxa[subj][taxon]] = np.sqrt(mean_squared_error(true_abundances[subj][taxon], predicted))
        rmse_inds[subjs[subj]] = rsme_taxa
    return rmse_inds

def rsme_fwrsim_seed(study_path, dataset, fwsim_path, seeds=(0, 3, 4, 23, 127), intra='sum', inter='sum', save=True):
    """Aggregate forward-simulation RMSE over taxa and subjects, one value per seed.

    For each seed, ``rsme_fwrsim`` yields per-subject, per-taxon RMSEs. These
    are reduced over taxa with ``intra`` and then over subjects with ``inter``.

    Parameters
    ----------
    study_path, dataset, fwsim_path
        Forwarded to ``rsme_fwrsim``.
    seeds : iterable of int
        Inference seeds to evaluate.
    intra : {'sum', 'mean'}
        Reduction over the taxa of one subject.
    inter : {'sum', 'mean'}
        Reduction over the per-subject values.
    save : bool
        If true, write the table as TSV to
        ``<fwsim_path>/mcnulty_<dataset>_rsme_fwrsim.tsv``.

    Returns
    -------
    pandas.DataFrame
        One ``RMSE`` column, indexed by ``Seed``.

    Raises
    ------
    ValueError
        If ``intra`` or ``inter`` is not a known reduction. This is checked
        before any seed is evaluated.
    """
    aggregators = {'sum': np.sum, 'mean': np.mean}
    # Fail fast: each rsme_fwrsim call loads a study and simulations from disk.
    if intra not in aggregators:
        raise ValueError(f'Unknown intra: {intra}')
    if inter not in aggregators:
        raise ValueError(f'Unknown inter: {inter}')
    intra_fn = aggregators[intra]
    inter_fn = aggregators[inter]

    rsme_seeds = {}
    for seed in seeds:
        rmse_inds = rsme_fwrsim(study_path, dataset, fwsim_path, seed)
        rsmes_subj_agg = [intra_fn(list(per_taxon.values())) for per_taxon in rmse_inds.values()]
        rsme_seeds[seed] = inter_fn(rsmes_subj_agg)
    rsme_table = pd.DataFrame.from_dict(rsme_seeds, orient='index', columns=['RMSE'])
    rsme_table.index.name = 'Seed'
    if save:
        # The table spans every seed, so it lives beside the per-seed run
        # directories rather than inside one of them.
        rsme_table.to_csv(f'{fwsim_path}/mcnulty_{dataset}_rsme_fwrsim.tsv', sep='\t')
    return rsme_table

In [None]:
def rsme_filtering(study_path, dataset, fwsim_path, seed):
    """Per-subject, per-taxon RMSE between observed abundances and the filtering median.

    Parameters
    ----------
    study_path : str
        Directory containing ``<dataset>/mcnulty_<dataset>.pkl``.
    dataset : {'LF0', 'HF0'}
        Dataset to evaluate; selects the subject list.
    fwsim_path : str
        Root directory of the ``mcnulty_<dataset>_seed<seed>`` run outputs.
    seed : int
        Seed of the inference run to evaluate.

    Returns
    -------
    dict
        ``{subject_id: {taxon: rmse}}``.

    Raises
    ------
    ValueError
        If ``dataset`` is unknown, or a filtering table has a different number
        of time points than the subject's observations.
    """
    if dataset == 'LF0':
        subjs = lf0
    elif dataset == 'HF0':
        subjs = hf0
    else:
        raise ValueError(f'Unknown dataset: {dataset}')
    study = md2.Study.load(f'{study_path}/{dataset}/mcnulty_{dataset}.pkl')
    true_abundances = [study[subj].matrix()['abs'] for subj in subjs]
    times = [study[subj].times for subj in subjs]
    taxa = [list(study[subj].taxa.names.keys()) for subj in subjs]
    # median.tsv is a text table, which np.load cannot parse.
    # Assumed layout: taxa as rows (first column = row label), time points as
    # columns, and rows in the same order as the study matrix -- TODO confirm.
    # NOTE(review): 'filetring' looks like a typo for 'filtering'; confirm it
    # matches the directory the pipeline actually writes.
    filt = [pd.read_csv(f'{fwsim_path}/mcnulty_{dataset}_seed{seed}/posteriors/filetring/Subject_{subj}/median.tsv', sep='\t', index_col=0) for subj in subjs]
    # Iterate over the tables directly: ``subjs`` holds id strings, not list positions.
    filt_times = [table.columns.values for table in filt]
    if len(true_abundances) != len(filt):
        raise ValueError('The number of true abundances and filtering are not the same')
    rmse_inds = {}
    for subj in range(len(true_abundances)):
        if len(times[subj]) != len(filt_times[subj]):
            raise ValueError('The number of true times and filtering times are not the same')
        rsme_taxa = {}
        for taxon in range(len(taxa[subj])):
            # ``taxon`` is a row position, so select positionally.
            predicted = filt[subj].iloc[taxon, :].values
            rsme_taxa[taxa[subj][taxon]] = np.sqrt(mean_squared_error(true_abundances[subj][taxon], predicted))
        rmse_inds[subjs[subj]] = rsme_taxa
    return rmse_inds

def rsme_filtering_seed(study_path, dataset, fwsim_path, seeds=(0, 3, 4, 23, 127), intra='sum', inter='sum', save=True):
    """Aggregate filtering RMSE over taxa and subjects, one value per seed.

    For each seed, ``rsme_filtering`` yields per-subject, per-taxon RMSEs.
    These are reduced over taxa with ``intra`` and then over subjects with
    ``inter``.

    Parameters
    ----------
    study_path, dataset, fwsim_path
        Forwarded to ``rsme_filtering``.
    seeds : iterable of int
        Inference seeds to evaluate.
    intra : {'sum', 'mean'}
        Reduction over the taxa of one subject.
    inter : {'sum', 'mean'}
        Reduction over the per-subject values.
    save : bool
        If true, write the table as TSV to
        ``<fwsim_path>/mcnulty_<dataset>_rmse_filtering.tsv``.

    Returns
    -------
    pandas.DataFrame
        One ``RMSE`` column, indexed by ``Seed``.

    Raises
    ------
    ValueError
        If ``intra`` or ``inter`` is not a known reduction. This is checked
        before any seed is evaluated.
    """
    aggregators = {'sum': np.sum, 'mean': np.mean}
    # Fail fast: each rsme_filtering call loads a study and tables from disk.
    if intra not in aggregators:
        raise ValueError(f'Unknown intra: {intra}')
    if inter not in aggregators:
        raise ValueError(f'Unknown inter: {inter}')
    intra_fn = aggregators[intra]
    inter_fn = aggregators[inter]

    rsme_seeds = {}
    for seed in seeds:
        rmse_inds = rsme_filtering(study_path, dataset, fwsim_path, seed)
        rsmes_subj_agg = [intra_fn(list(per_taxon.values())) for per_taxon in rmse_inds.values()]
        rsme_seeds[seed] = inter_fn(rsmes_subj_agg)
    rsme_table = pd.DataFrame.from_dict(rsme_seeds, orient='index', columns=['RMSE'])
    rsme_table.index.name = 'Seed'
    if save:
        # The table spans every seed, so it lives beside the per-seed run
        # directories rather than inside one of them.
        rsme_table.to_csv(f'{fwsim_path}/mcnulty_{dataset}_rmse_filtering.tsv', sep='\t')
    return rsme_table

In [None]:
def plot_data_fit_pred_by_subj(study_path, dataset, fwsim_path, seed=0, subj2plot='all', taxa2plot='all', fwsim=True, filt=True, subplots=True, save=True):
    """Plot observed abundances with forward simulation and filtering, per subject.

    One figure is drawn per subject. If ``subplots`` is true, an additional
    grid figure with one panel per subject is drawn first.

    Parameters
    ----------
    study_path : str
        Directory containing ``<dataset>/mcnulty_<dataset>.pkl``.
    dataset : {'LF0', 'HF0'}
        Dataset to plot; selects the subject list and the shaded windows.
    fwsim_path : str
        Root directory of the ``mcnulty_<dataset>_seed<seed>`` run outputs.
    seed : int
        Seed of the inference run to plot.
    subj2plot : 'all', str or list
        Subject identifier(s) to plot; each must belong to ``dataset``.
    taxa2plot : 'all', str or list of str
        Taxon name(s) to plot.
    fwsim : bool
        Overlay the mean forward simulation with a +/- 1 std band.
    filt : bool
        Overlay the filtering median trajectory (dashed).
    subplots : bool
        Also draw the grid figure ``Sims_subplots.pdf``.
    save : bool
        Save figures under ``<run dir>/posteriors/`` and close them; otherwise
        show them.

    Raises
    ------
    ValueError
        Unknown dataset, subjects outside the dataset, an empty selection, or a
        filtering table whose time-point count differs from the observations.
    """
    if dataset == 'LF0':
        all_subjs = lf0
    elif dataset == 'HF0':
        all_subjs = hf0
    else:
        raise ValueError(f'Unknown dataset: {dataset}')

    # Accept a single name as well as a list.  Without this, ``in`` on a bare
    # string would do substring matching.
    if isinstance(subj2plot, str) and subj2plot != 'all':
        subj2plot = [subj2plot]
    if isinstance(taxa2plot, str) and taxa2plot != 'all':
        taxa2plot = [taxa2plot]

    # Select subjects *before* loading, so every per-subject list below is
    # aligned with ``subjs``.
    subjs = list(all_subjs) if subj2plot == 'all' else [str(s) for s in subj2plot]
    unknown = [s for s in subjs if s not in all_subjs]
    if unknown:
        raise ValueError(f'Subjects {unknown} are not part of dataset {dataset}')
    if not subjs:
        raise ValueError('No subjects selected for plotting')

    study = md2.Study.load(f'{study_path}/{dataset}/mcnulty_{dataset}.pkl')
    true_abundances = [study[subj].matrix()['abs'] for subj in subjs]
    times = [study[subj].times for subj in subjs]
    taxa = [list(study[subj].taxa.names.keys()) for subj in subjs]
    # Row positions (per subject) of the taxa selected by ``taxa2plot``.
    taxa_idx = [[t for t, name in enumerate(names) if taxa2plot == 'all' or name in taxa2plot]
                for names in taxa]

    run_dir = f'{fwsim_path}/mcnulty_{dataset}_seed{seed}'
    step = 0.01  # simulation time step, as assumed in rsme_fwrsim

    fwsim_mean, fwsim_std, sim_times = [], [], []
    if fwsim:
        for subj in subjs:
            # NOTE(review): mean/std over axis 0 need the raw simulation
            # samples, so this reads the same fwsim.npy as rsme_fwrsim
            # (assumed (n_samples, n_taxa, n_sim_times)). np.load cannot
            # parse a median.tsv. Confirm the path.
            samples = np.load(f'{run_dir}/forward-simulate/Subject_{subj}/fwsim.npy')
            fwsim_mean.append(samples.mean(axis=0))
            fwsim_std.append(samples.std(axis=0))
            # Derive the time axis from the data so x and y lengths always agree.
            sim_times.append(step * np.arange(samples.shape[-1]))

    filt_tables = []
    if filt:
        for subj, obs_t in zip(subjs, times):
            # Assumed layout: taxa rows (first column = label), time-point
            # columns, same row order as the study -- TODO confirm.
            # NOTE(review): 'filetring' looks like a typo for 'filtering'; confirm.
            table = pd.read_csv(f'{run_dir}/posteriors/filetring/Subject_{subj}/median.tsv', sep='\t', index_col=0)
            if table.shape[1] != len(obs_t):
                raise ValueError('The number of true times and filtering times are not the same')
            filt_tables.append(table)

    def draw_subject(ax, i):
        """Draw the selected taxa of the i-th subject in ``subjs`` onto ``ax``."""
        for t in taxa_idx[i]:
            name = taxa[i][t]
            colour = taxa_color[name]
            if fwsim:
                mean, std = fwsim_mean[i][t], fwsim_std[i][t]
                ax.plot(sim_times[i], mean, color=colour, linewidth=2)
                ax.fill_between(sim_times[i], mean - std, mean + std, color=colour, alpha=0.2)
            if filt:
                # The filtering values share the observation time points (length checked above).
                ax.plot(times[i], filt_tables[i].iloc[t, :].values, marker='.', linestyle='--', linewidth=0.8, color=colour)
            ax.plot(times[i], true_abundances[i][t], marker='o', markersize=4, linestyle='none', color=colour, label=name)
        # Shade the fixed windows of the dataset once per axes.
        if dataset == 'LF0':
            ax.fill_between(x=[14, 26], y1=1, y2=1e12, color='grey', alpha=0.2)
        elif dataset == 'HF0':
            ax.fill_between(x=[0, 12], y1=1, y2=1e12, color='grey', alpha=0.2)
            ax.fill_between(x=[28, 40], y1=1, y2=1e12, color='grey', alpha=0.2)
        ax.set_yscale('log')
        ax.set_ylim(1e4, 1e10)

    if subplots:
        # One panel per subject; the grid grows with the number of subjects.
        ncols = 3
        nrows = int(np.ceil(len(subjs) / ncols))
        fig, axs = plt.subplots(nrows, ncols, figsize=(20, 4 * nrows), squeeze=False)
        for i, ax in enumerate(axs.flat):
            if i >= len(subjs):
                ax.set_visible(False)  # unused grid cell
                continue
            draw_subject(ax, i)
            ax.set_title(f'Subject {subjs[i]}')
        handles, labels = axs[0, 0].get_legend_handles_labels()
        fig.legend(handles, labels, loc='lower center', ncol=4)
        fig.supxlabel('Time (Days)', y=0.08)
        fig.supylabel('$Log_{10}$ Abundance', x=0.08)
        fig.suptitle(f'{dataset} abundance as a function of time', y=0.92)
        if save:
            fig.savefig(f'{run_dir}/posteriors/Sims_subplots.pdf', bbox_inches='tight')
            plt.close(fig)
        else:
            plt.show()

    # One standalone figure per subject, holding all of its selected taxa.
    for i, subj in enumerate(subjs):
        fig, ax = plt.subplots()
        draw_subject(ax, i)
        ax.set_title(f'{dataset} subject {subj} abundance as a function of time')
        ax.legend()
        ax.set_xlabel('Time (Days)')
        ax.set_ylabel('$Log_{10}$ Abundance')
        if save:
            fig.savefig(f'{run_dir}/posteriors/Sims_subject{subj}.pdf', bbox_inches='tight')
            plt.close(fig)
        else:
            plt.show()

In [None]:
def plot_data_pred_mean(study_path, dataset, fwsim_path, seed=0, taxa2plot='all', points=True, save=True):
    """Plot the across-subject mean forward simulation of each taxon.

    Each subject's simulation samples are averaged over axis 0. The line is
    the mean of those per-subject trajectories, and the band spans +/- one
    standard deviation across subjects. Observed abundances can be overlaid
    as points.

    Parameters
    ----------
    study_path : str
        Directory containing ``<dataset>/mcnulty_<dataset>.pkl``.
    dataset : {'LF0', 'HF0'}
        Dataset to plot; selects the subject list and the shaded windows.
    fwsim_path : str
        Root directory of the ``mcnulty_<dataset>_seed<seed>`` run outputs.
    seed : int
        Seed of the inference run to plot.
    taxa2plot : 'all', str or list of str
        Taxon name(s) to plot.
    points : bool
        Overlay every subject's observed abundances.
    save : bool
        Save to ``<run dir>/posteriors/Sims.pdf`` and close; otherwise show.

    Raises
    ------
    ValueError
        If ``dataset`` is unknown or subjects do not share one taxon ordering.
    """
    if dataset == 'LF0':
        subjs = lf0
    elif dataset == 'HF0':
        subjs = hf0
    else:
        raise ValueError(f'Unknown dataset: {dataset}')
    if isinstance(taxa2plot, str) and taxa2plot != 'all':
        taxa2plot = [taxa2plot]
    study = md2.Study.load(f'{study_path}/{dataset}/mcnulty_{dataset}.pkl')
    true_abundances = [study[subj].matrix()['abs'] for subj in subjs]
    times = [study[subj].times for subj in subjs]
    taxa = [list(study[subj].taxa.names.keys()) for subj in subjs]
    # Trajectories are averaged across subjects row by row, so rows must mean
    # the same taxon in every subject.
    if any(names != taxa[0] for names in taxa):
        raise ValueError('Subjects do not share the same taxa ordering')

    run_dir = f'{fwsim_path}/mcnulty_{dataset}_seed{seed}'
    # NOTE(review): reads the same per-subject fwsim.npy as rsme_fwrsim (axis 0
    # assumed to be simulation samples); np.load cannot parse a median.tsv.
    per_subj_mean = np.stack([np.load(f'{run_dir}/forward-simulate/Subject_{subj}/fwsim.npy').mean(axis=0) for subj in subjs])
    fwsim_mean = per_subj_mean.mean(axis=0)
    fwsim_std = per_subj_mean.std(axis=0)
    step = 0.01  # simulation time step, as assumed in rsme_fwrsim
    sim_times = step * np.arange(fwsim_mean.shape[-1])

    fig, ax = plt.subplots()
    for t, name in enumerate(taxa[0]):
        if taxa2plot != 'all' and name not in taxa2plot:
            continue
        colour = taxa_color[name]
        ax.plot(sim_times, fwsim_mean[t], color=colour, linewidth=2, label=name)
        ax.fill_between(sim_times, fwsim_mean[t] - fwsim_std[t], fwsim_mean[t] + fwsim_std[t], color=colour, alpha=0.2)
        if points:
            for i in range(len(subjs)):
                ax.plot(times[i], true_abundances[i][t], marker='.', linestyle='none', color=colour)
    # Figure-level decorations are drawn once, not once per taxon.
    if dataset == 'LF0':
        ax.fill_between(x=[14, 26], y1=1, y2=1e12, color='grey', alpha=0.2)
    elif dataset == 'HF0':
        ax.fill_between(x=[0, 12], y1=1, y2=1e12, color='grey', alpha=0.2)
        ax.fill_between(x=[28, 40], y1=1, y2=1e12, color='grey', alpha=0.2)
    ax.set_title(f'{dataset} abundance as a function of time')
    ax.set_yscale('log')
    ax.set_ylim(1e4, 1e10)
    ax.legend()
    ax.set_xlabel('Time (Days)')
    ax.set_ylabel('$Log_{10}$ Abundance')
    if save:
        fig.savefig(f'{run_dir}/posteriors/Sims.pdf', bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()

In [None]:
def plot_data_pred_mean_area(study_path, dataset, fwsim_path, seed=0, save=True):
    """Stacked-area plot of the across-subject mean forward simulation.

    Each subject's simulation samples are averaged over axis 0, and those
    per-subject trajectories are then averaged across subjects. The result is
    plotted as one stacked area per taxon.

    Parameters
    ----------
    study_path : str
        Directory containing ``<dataset>/mcnulty_<dataset>.pkl``.
    dataset : {'LF0', 'HF0'}
        Dataset to plot; selects the subject list and the shaded windows.
    fwsim_path : str
        Root directory of the ``mcnulty_<dataset>_seed<seed>`` run outputs.
    seed : int
        Seed of the inference run to plot.
    save : bool
        Save to ``<run dir>/posteriors/Sims_area.pdf`` and close; otherwise
        show. The filename differs from ``Sims.pdf`` so that this plot does not
        overwrite the output of ``plot_data_pred_mean``.

    Raises
    ------
    ValueError
        If ``dataset`` is unknown.
    """
    if dataset == 'LF0':
        subjs = lf0
    elif dataset == 'HF0':
        subjs = hf0
    else:
        raise ValueError(f'Unknown dataset: {dataset}')
    study = md2.Study.load(f'{study_path}/{dataset}/mcnulty_{dataset}.pkl')
    taxa = [list(study[subj].taxa.names.keys()) for subj in subjs]

    run_dir = f'{fwsim_path}/mcnulty_{dataset}_seed{seed}'
    # NOTE(review): reads the same per-subject fwsim.npy as rsme_fwrsim (axis 0
    # assumed to be simulation samples); np.load cannot parse a median.tsv.
    per_subj_mean = np.stack([np.load(f'{run_dir}/forward-simulate/Subject_{subj}/fwsim.npy').mean(axis=0) for subj in subjs])
    fwsim_mean = per_subj_mean.mean(axis=0)
    step = 0.01  # simulation time step, as assumed in rsme_fwrsim
    area = pd.DataFrame(fwsim_mean.T, columns=taxa[0], index=step * np.arange(fwsim_mean.shape[1]))

    fig, ax = plt.subplots()
    # Colour by taxon name so each taxon matches ``taxa_color`` regardless of
    # the study's taxon order (``cols`` is ordered like ``order``, not the study).
    area.plot.area(ax=ax,
                   stacked=True,
                   color=[taxa_color[name] for name in area.columns],
                   rot=0,
                   linewidth=0.3)
    if dataset == 'LF0':
        ax.fill_between(x=[14, 26], y1=1, y2=1e12, color='grey', alpha=0.2)
    elif dataset == 'HF0':
        ax.fill_between(x=[0, 12], y1=1, y2=1e12, color='grey', alpha=0.2)
        ax.fill_between(x=[28, 40], y1=1, y2=1e12, color='grey', alpha=0.2)
    ax.set_title(f'{dataset} abundance as a function of time')
    ax.set_yscale('log')
    ax.set_ylim(1e4, 1e10)
    ax.legend(bbox_to_anchor=(1.01, 1), loc='upper left')
    ax.set_xlabel('Time (Days)')
    ax.set_ylabel('$Log_{10}$ Abundance')
    if save:
        fig.savefig(f'{run_dir}/posteriors/Sims_area.pdf', bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()